core_crypto/transaction_context/conversation/proposal.rs

1use openmls::prelude::KeyPackage;
2
3use super::{Error, Result};
4use crate::{
5    RecursiveError,
6    prelude::{ClientId, ConversationId, MlsProposal, MlsProposalBundle},
7    transaction_context::TransactionContext,
8};
9
10impl TransactionContext {
11    /// Creates a new Add proposal
12    #[cfg_attr(test, crate::idempotent)]
13    pub async fn new_add_proposal(&self, id: &ConversationId, key_package: KeyPackage) -> Result<MlsProposalBundle> {
14        self.new_proposal(id, MlsProposal::Add(key_package)).await
15    }
16
17    /// Creates a new Add proposal
18    #[cfg_attr(test, crate::idempotent)]
19    pub async fn new_remove_proposal(&self, id: &ConversationId, client_id: ClientId) -> Result<MlsProposalBundle> {
20        self.new_proposal(id, MlsProposal::Remove(client_id)).await
21    }
22
23    /// Creates a new Add proposal
24    #[cfg_attr(test, crate::dispotent)]
25    pub async fn new_update_proposal(&self, id: &ConversationId) -> Result<MlsProposalBundle> {
26        self.new_proposal(id, MlsProposal::Update).await
27    }
28
29    /// Creates a new proposal within a group
30    ///
31    /// # Arguments
32    /// * `conversation` - the group/conversation id
33    /// * `proposal` - the proposal do be added in the group
34    ///
35    /// # Return type
36    /// A [MlsProposalBundle] with the proposal in a Mls message and a reference to that proposal in order to rollback it if required
37    ///
38    /// # Errors
39    /// If the conversation is not found, an error will be returned. Errors from OpenMls can be
40    /// returned as well, when for example there's a commit pending to be merged
41    async fn new_proposal(&self, id: &ConversationId, proposal: MlsProposal) -> Result<MlsProposalBundle> {
42        let mut conversation = self.conversation(id).await?;
43        let mut conversation = conversation.conversation_mut().await;
44        let client = &self.session().await?;
45        let backend = &self.mls_provider().await?;
46        let proposal = match proposal {
47            MlsProposal::Add(key_package) => conversation
48                .propose_add_member(client, backend, key_package.into())
49                .await
50                .map_err(RecursiveError::mls_conversation("proposing to add member"))?,
51            MlsProposal::Update => conversation
52                .propose_self_update(client, backend)
53                .await
54                .map_err(RecursiveError::mls_conversation("proposing self update"))?,
55            MlsProposal::Remove(client_id) => {
56                let index = conversation
57                    .group
58                    .members()
59                    .find(|kp| kp.credential.identity() == client_id.as_slice())
60                    .ok_or(Error::ClientNotFound(client_id))
61                    .map(|kp| kp.index)?;
62                (*conversation)
63                    .propose_remove_member(client, backend, index)
64                    .await
65                    .map_err(RecursiveError::mls_conversation("proposing to remove member"))?
66            }
67        };
68        Ok(proposal)
69    }
70}
71
// Tests for `new_add_proposal`, `new_update_proposal` and `new_remove_proposal`.
// Each test is instantiated through the `all_cred_cipher` template (presumably every
// credential-type / ciphersuite combination — confirm against `test_utils`) and is also
// registered as a wasm-bindgen test.
#[cfg(test)]
mod tests {
    use wasm_bindgen_test::*;

    use crate::{prelude::*, test_utils::*};

    wasm_bindgen_test_configure!(run_in_browser);
    use super::Error;

    mod add {
        use super::*;

        // Alice proposes adding Bob, commits the pending proposal, and Bob joins through the
        // resulting Welcome. Both sides must end up on the same conversation id and be able to
        // exchange messages.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        pub async fn should_add_member(case: TestCase) {
            run_test_with_client_ids(case.clone(), ["alice", "bob"], move |[alice_central, bob_central]| {
                Box::pin(async move {
                    let id = conversation_id();
                    alice_central
                        .context
                        .new_conversation(&id, case.credential_type, case.cfg.clone())
                        .await
                        .unwrap();
                    let bob_kp = bob_central.get_one_key_package(&case).await;
                    alice_central.context.new_add_proposal(&id, bob_kp).await.unwrap();
                    // Commit the pending Add proposal; afterwards Alice's group has 2 members.
                    alice_central
                        .context
                        .conversation(&id)
                        .await
                        .unwrap()
                        .commit_pending_proposals()
                        .await
                        .unwrap();
                    // Fetch the Welcome produced by that commit from the test transport.
                    let welcome = alice_central.mls_transport.latest_welcome_message().await;
                    assert_eq!(alice_central.get_conversation_unchecked(&id).await.members().len(), 2);
                    let new_id = bob_central
                        .context
                        .process_welcome_message(welcome.into(), case.custom_cfg())
                        .await
                        .unwrap()
                        .id;
                    assert_eq!(id, new_id);
                    assert!(bob_central.try_talk_to(&id, &alice_central).await.is_ok());
                })
            })
            .await
        }
    }

    mod update {
        use super::*;
        use itertools::Itertools;

        // Committing a self-Update proposal must change the client's encryption key.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        pub async fn should_update_hpke_key(case: TestCase) {
            run_test_with_central(case.clone(), move |[central]| {
                Box::pin(async move {
                    let id = conversation_id();
                    central
                        .context
                        .new_conversation(&id, case.credential_type, case.cfg.clone())
                        .await
                        .unwrap();
                    // `find_or_first(|_| true)` just takes the first key yielded (same as `.next()`).
                    let before = central
                        .get_conversation_unchecked(&id)
                        .await
                        .encryption_keys()
                        .find_or_first(|_| true)
                        .unwrap();
                    central.context.new_update_proposal(&id).await.unwrap();
                    central
                        .context
                        .conversation(&id)
                        .await
                        .unwrap()
                        .commit_pending_proposals()
                        .await
                        .unwrap();
                    let after = central
                        .get_conversation_unchecked(&id)
                        .await
                        .encryption_keys()
                        .find_or_first(|_| true)
                        .unwrap();
                    assert_ne!(before, after)
                })
            })
            .await
        }
    }

    mod remove {
        use super::*;

        // Alice proposes removing Bob and commits it. Bob processes the Remove proposal first,
        // then the commit; afterwards his conversation can no longer be found.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        pub async fn should_remove_member(case: TestCase) {
            run_test_with_client_ids(case.clone(), ["alice", "bob"], |[alice_central, bob_central]| {
                Box::pin(async move {
                    let id = conversation_id();
                    alice_central
                        .context
                        .new_conversation(&id, case.credential_type, case.cfg.clone())
                        .await
                        .unwrap();
                    alice_central.invite_all(&case, &id, [&bob_central]).await.unwrap();
                    assert_eq!(alice_central.get_conversation_unchecked(&id).await.members().len(), 2);
                    assert_eq!(bob_central.get_conversation_unchecked(&id).await.members().len(), 2);

                    let remove_proposal = alice_central
                        .context
                        .new_remove_proposal(&id, bob_central.get_client_id().await)
                        .await
                        .unwrap();
                    // Bob receives the serialized proposal message before the commit.
                    bob_central
                        .context
                        .conversation(&id)
                        .await
                        .unwrap()
                        .decrypt_message(remove_proposal.proposal.to_bytes().unwrap())
                        .await
                        .unwrap();
                    alice_central
                        .context
                        .conversation(&id)
                        .await
                        .unwrap()
                        .commit_pending_proposals()
                        .await
                        .unwrap();
                    assert_eq!(alice_central.get_conversation_unchecked(&id).await.members().len(), 1);

                    let commit = alice_central.mls_transport.latest_commit().await;
                    bob_central
                        .context
                        .conversation(&id)
                        .await
                        .unwrap()
                        .decrypt_message(commit.to_bytes().unwrap())
                        .await
                        .unwrap();
                    // After processing the commit that removes him, Bob no longer has the conversation.
                    assert!(matches!(
                        bob_central.context.conversation(&id).await.unwrap_err(),
                        Error::Leaf(LeafError::ConversationNotFound(conv_id)) if conv_id == id
                    ));
                })
            })
            .await
        }

        // Proposing to remove a client id that matches no group member fails with
        // `ClientNotFound`, carrying the requested id.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        pub async fn should_fail_when_unknown_client(case: TestCase) {
            run_test_with_client_ids(case.clone(), ["alice"], move |[alice_central]| {
                Box::pin(async move {
                    let id = conversation_id();
                    alice_central
                        .context
                        .new_conversation(&id, case.credential_type, case.cfg.clone())
                        .await
                        .unwrap();

                    let remove_proposal = alice_central
                        .context
                        .new_remove_proposal(&id, b"unknown"[..].into())
                        .await;
                    assert!(matches!(
                        remove_proposal.unwrap_err(),
                        Error::ClientNotFound(client_id) if client_id == b"unknown"[..].into()
                    ));
                })
            })
            .await
        }
    }
}