core_crypto/transaction_context/conversation/proposal.rs

1use openmls::prelude::KeyPackage;
2
3use super::{Error, Result};
4use crate::{
5    RecursiveError,
6    prelude::{ClientId, ConversationId, MlsProposal, MlsProposalBundle},
7    transaction_context::TransactionContext,
8};
9
10impl TransactionContext {
11    /// Creates a new Add proposal
12    #[cfg_attr(test, crate::idempotent)]
13    pub async fn new_add_proposal(&self, id: &ConversationId, key_package: KeyPackage) -> Result<MlsProposalBundle> {
14        self.new_proposal(id, MlsProposal::Add(key_package)).await
15    }
16
17    /// Creates a new Add proposal
18    #[cfg_attr(test, crate::idempotent)]
19    pub async fn new_remove_proposal(&self, id: &ConversationId, client_id: ClientId) -> Result<MlsProposalBundle> {
20        self.new_proposal(id, MlsProposal::Remove(client_id)).await
21    }
22
23    /// Creates a new Add proposal
24    #[cfg_attr(test, crate::dispotent)]
25    pub async fn new_update_proposal(&self, id: &ConversationId) -> Result<MlsProposalBundle> {
26        self.new_proposal(id, MlsProposal::Update).await
27    }
28
29    /// Creates a new proposal within a group
30    ///
31    /// # Arguments
32    /// * `conversation` - the group/conversation id
33    /// * `proposal` - the proposal do be added in the group
34    ///
35    /// # Return type
36    /// A [MlsProposalBundle] with the proposal in a Mls message and a reference to that proposal in order to rollback it if required
37    ///
38    /// # Errors
39    /// If the conversation is not found, an error will be returned. Errors from OpenMls can be
40    /// returned as well, when for example there's a commit pending to be merged
41    async fn new_proposal(&self, id: &ConversationId, proposal: MlsProposal) -> Result<MlsProposalBundle> {
42        let mut conversation = self.conversation(id).await?;
43        let mut conversation = conversation.conversation_mut().await;
44        let client = &self.session().await?;
45        let backend = &self.mls_provider().await?;
46        let proposal = match proposal {
47            MlsProposal::Add(key_package) => conversation
48                .propose_add_member(client, backend, key_package.into())
49                .await
50                .map_err(RecursiveError::mls_conversation("proposing to add member"))?,
51            MlsProposal::Update => conversation
52                .propose_self_update(client, backend)
53                .await
54                .map_err(RecursiveError::mls_conversation("proposing self update"))?,
55            MlsProposal::Remove(client_id) => {
56                let index = conversation
57                    .group
58                    .members()
59                    .find(|kp| kp.credential.identity() == client_id.as_slice())
60                    .ok_or(Error::ClientNotFound(client_id))
61                    .map(|kp| kp.index)?;
62                (*conversation)
63                    .propose_remove_member(client, backend, index)
64                    .await
65                    .map_err(RecursiveError::mls_conversation("proposing to remove member"))?
66            }
67        };
68        Ok(proposal)
69    }
70}
71
#[cfg(test)]
mod tests {
    use wasm_bindgen_test::*;

    use super::Error;
    use crate::{prelude::*, test_utils::*};

    wasm_bindgen_test_configure!(run_in_browser);

    mod add {
        use super::*;

        // Alice proposes to add Bob and commits the pending proposal. Bob then joins through the
        // resulting Welcome message, and the two must be able to exchange messages.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        pub async fn should_add_member(case: TestContext) {
            run_test_with_client_ids(case.clone(), ["alice", "bob"], move |[alice_central, bob_central]| {
                Box::pin(async move {
                    let id = conversation_id();
                    alice_central
                        .transaction
                        .new_conversation(&id, case.credential_type, case.cfg.clone())
                        .await
                        .unwrap();
                    let bob_kp = bob_central.get_one_key_package(&case).await;
                    alice_central.transaction.new_add_proposal(&id, bob_kp).await.unwrap();
                    alice_central
                        .transaction
                        .conversation(&id)
                        .await
                        .unwrap()
                        .commit_pending_proposals()
                        .await
                        .unwrap();
                    let welcome = alice_central.mls_transport.latest_welcome_message().await;
                    assert_eq!(alice_central.get_conversation_unchecked(&id).await.members().len(), 2);
                    let new_id = bob_central
                        .transaction
                        .process_welcome_message(welcome.into(), case.custom_cfg())
                        .await
                        .unwrap()
                        .id;
                    assert_eq!(id, new_id);
                    assert!(bob_central.try_talk_to(&id, &alice_central).await.is_ok());
                })
            })
            .await
        }
    }

    mod update {
        use super::*;

        // Committing a self-update proposal must change the member's (first) encryption key.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        pub async fn should_update_hpke_key(case: TestContext) {
            let [session] = case.sessions().await;
            let id = conversation_id();
            session
                .transaction
                .new_conversation(&id, case.credential_type, case.cfg.clone())
                .await
                .unwrap();
            let before = session
                .get_conversation_unchecked(&id)
                .await
                .encryption_keys()
                .next()
                .unwrap();
            session.transaction.new_update_proposal(&id).await.unwrap();
            session
                .transaction
                .conversation(&id)
                .await
                .unwrap()
                .commit_pending_proposals()
                .await
                .unwrap();
            let after = session
                .get_conversation_unchecked(&id)
                .await
                .encryption_keys()
                .next()
                .unwrap();
            assert_ne!(before, after)
        }
    }

    mod remove {
        use super::*;

        // Alice proposes to remove Bob and Bob processes the proposal. Alice then commits it.
        // After Bob decrypts the commit, his conversation must no longer be found.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        pub async fn should_remove_member(case: TestContext) {
            run_test_with_client_ids(case.clone(), ["alice", "bob"], move |[alice_central, bob_central]| {
                Box::pin(async move {
                    let id = conversation_id();
                    alice_central
                        .transaction
                        .new_conversation(&id, case.credential_type, case.cfg.clone())
                        .await
                        .unwrap();
                    alice_central.invite_all(&case, &id, [&bob_central]).await.unwrap();
                    assert_eq!(alice_central.get_conversation_unchecked(&id).await.members().len(), 2);
                    assert_eq!(bob_central.get_conversation_unchecked(&id).await.members().len(), 2);

                    let remove_proposal = alice_central
                        .transaction
                        .new_remove_proposal(&id, bob_central.get_client_id().await)
                        .await
                        .unwrap();
                    bob_central
                        .transaction
                        .conversation(&id)
                        .await
                        .unwrap()
                        .decrypt_message(remove_proposal.proposal.to_bytes().unwrap())
                        .await
                        .unwrap();
                    alice_central
                        .transaction
                        .conversation(&id)
                        .await
                        .unwrap()
                        .commit_pending_proposals()
                        .await
                        .unwrap();
                    assert_eq!(alice_central.get_conversation_unchecked(&id).await.members().len(), 1);

                    let commit = alice_central.mls_transport.latest_commit().await;
                    bob_central
                        .transaction
                        .conversation(&id)
                        .await
                        .unwrap()
                        .decrypt_message(commit.to_bytes().unwrap())
                        .await
                        .unwrap();
                    assert!(matches!(
                        bob_central.transaction.conversation(&id).await.unwrap_err(),
                        Error::Leaf(LeafError::ConversationNotFound(conv_id)) if conv_id == id
                    ));
                })
            })
            .await
        }

        // Proposing to remove a client id that no member carries must fail with ClientNotFound.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        pub async fn should_fail_when_unknown_client(case: TestContext) {
            run_test_with_client_ids(case.clone(), ["alice"], move |[alice_central]| {
                Box::pin(async move {
                    let id = conversation_id();
                    alice_central
                        .transaction
                        .new_conversation(&id, case.credential_type, case.cfg.clone())
                        .await
                        .unwrap();

                    let remove_proposal = alice_central
                        .transaction
                        .new_remove_proposal(&id, b"unknown"[..].into())
                        .await;
                    assert!(matches!(
                        remove_proposal.unwrap_err(),
                        Error::ClientNotFound(client_id) if client_id == b"unknown"[..].into()
                    ));
                })
            })
            .await
        }
    }
}