core_crypto/transaction_context/conversation/proposal.rs

1use openmls::prelude::KeyPackage;
2
3use super::{Error, Result};
4use crate::{
5    RecursiveError,
6    prelude::{ClientId, ConversationId, MlsProposal, MlsProposalBundle},
7    transaction_context::TransactionContext,
8};
9
10impl TransactionContext {
11    /// Creates a new Add proposal
12    #[cfg_attr(test, crate::idempotent)]
13    pub async fn new_add_proposal(&self, id: &ConversationId, key_package: KeyPackage) -> Result<MlsProposalBundle> {
14        self.new_proposal(id, MlsProposal::Add(key_package)).await
15    }
16
17    /// Creates a new Add proposal
18    #[cfg_attr(test, crate::idempotent)]
19    pub async fn new_remove_proposal(&self, id: &ConversationId, client_id: ClientId) -> Result<MlsProposalBundle> {
20        self.new_proposal(id, MlsProposal::Remove(client_id)).await
21    }
22
23    /// Creates a new Add proposal
24    #[cfg_attr(test, crate::dispotent)]
25    pub async fn new_update_proposal(&self, id: &ConversationId) -> Result<MlsProposalBundle> {
26        self.new_proposal(id, MlsProposal::Update).await
27    }
28
29    /// Creates a new proposal within a group
30    ///
31    /// # Arguments
32    /// * `conversation` - the group/conversation id
33    /// * `proposal` - the proposal do be added in the group
34    ///
35    /// # Return type
36    /// A [MlsProposalBundle] with the proposal in a Mls message and a reference to that proposal in order to rollback it if required
37    ///
38    /// # Errors
39    /// If the conversation is not found, an error will be returned. Errors from OpenMls can be
40    /// returned as well, when for example there's a commit pending to be merged
41    async fn new_proposal(&self, id: &ConversationId, proposal: MlsProposal) -> Result<MlsProposalBundle> {
42        let mut conversation = self.conversation(id).await?;
43        let mut conversation = conversation.conversation_mut().await;
44        let client = &self.session().await?;
45        let backend = &self.mls_provider().await?;
46        let proposal = match proposal {
47            MlsProposal::Add(key_package) => conversation
48                .propose_add_member(client, backend, key_package.into())
49                .await
50                .map_err(RecursiveError::mls_conversation("proposing to add member"))?,
51            MlsProposal::Update => conversation
52                .propose_self_update(client, backend)
53                .await
54                .map_err(RecursiveError::mls_conversation("proposing self update"))?,
55            MlsProposal::Remove(client_id) => {
56                let index = conversation
57                    .group
58                    .members()
59                    .find(|kp| kp.credential.identity() == client_id.as_slice())
60                    .ok_or(Error::ClientNotFound(client_id))
61                    .map(|kp| kp.index)?;
62                (*conversation)
63                    .propose_remove_member(client, backend, index)
64                    .await
65                    .map_err(RecursiveError::mls_conversation("proposing to remove member"))?
66            }
67        };
68        Ok(proposal)
69    }
70}
71
#[cfg(test)]
mod tests {
    use wasm_bindgen_test::*;

    use super::Error;
    use crate::{prelude::*, test_utils::*};

    wasm_bindgen_test_configure!(run_in_browser);

    mod add {
        use super::*;

        // Alice proposes adding Bob, commits the pending proposal, and Bob joins via the Welcome.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        pub async fn should_add_member(case: TestContext) {
            let [alice_central, bob_central] = case.sessions().await;
            Box::pin(async move {
                let id = conversation_id();
                alice_central
                    .transaction
                    .new_conversation(&id, case.credential_type, case.cfg.clone())
                    .await
                    .unwrap();
                let bob_kp = bob_central.get_one_key_package(&case).await;
                alice_central.transaction.new_add_proposal(&id, bob_kp).await.unwrap();
                alice_central
                    .transaction
                    .conversation(&id)
                    .await
                    .unwrap()
                    .commit_pending_proposals()
                    .await
                    .unwrap();
                let welcome = alice_central.mls_transport().await.latest_welcome_message().await;
                assert_eq!(alice_central.get_conversation_unchecked(&id).await.members().len(), 2);
                let new_id = bob_central
                    .transaction
                    .process_welcome_message(welcome.into(), case.custom_cfg())
                    .await
                    .unwrap()
                    .id;
                assert_eq!(id, new_id);
                assert!(bob_central.try_talk_to(&id, &alice_central).await.is_ok());
            })
            .await
        }
    }

    mod update {
        use super::*;

        // A committed self-update proposal must rotate the member's encryption key.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        pub async fn should_update_hpke_key(case: TestContext) {
            let [session] = case.sessions().await;
            let id = conversation_id();
            session
                .transaction
                .new_conversation(&id, case.credential_type, case.cfg.clone())
                .await
                .unwrap();
            // Only one member exists, so the first encryption key is the one under test.
            let before = session
                .get_conversation_unchecked(&id)
                .await
                .encryption_keys()
                .next()
                .unwrap();
            session.transaction.new_update_proposal(&id).await.unwrap();
            session
                .transaction
                .conversation(&id)
                .await
                .unwrap()
                .commit_pending_proposals()
                .await
                .unwrap();
            let after = session
                .get_conversation_unchecked(&id)
                .await
                .encryption_keys()
                .next()
                .unwrap();
            assert_ne!(before, after)
        }
    }

    mod remove {
        use super::*;

        // Alice proposes removing Bob and commits the proposal. After processing the commit,
        // Bob no longer has the conversation.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        pub async fn should_remove_member(case: TestContext) {
            let [alice_central, bob_central] = case.sessions().await;
            Box::pin(async move {
                let id = conversation_id();
                alice_central
                    .transaction
                    .new_conversation(&id, case.credential_type, case.cfg.clone())
                    .await
                    .unwrap();
                alice_central.invite_all(&case, &id, [&bob_central]).await.unwrap();
                assert_eq!(alice_central.get_conversation_unchecked(&id).await.members().len(), 2);
                assert_eq!(bob_central.get_conversation_unchecked(&id).await.members().len(), 2);

                let remove_proposal = alice_central
                    .transaction
                    .new_remove_proposal(&id, bob_central.get_client_id().await)
                    .await
                    .unwrap();
                bob_central
                    .transaction
                    .conversation(&id)
                    .await
                    .unwrap()
                    .decrypt_message(remove_proposal.proposal.to_bytes().unwrap())
                    .await
                    .unwrap();
                alice_central
                    .transaction
                    .conversation(&id)
                    .await
                    .unwrap()
                    .commit_pending_proposals()
                    .await
                    .unwrap();
                assert_eq!(alice_central.get_conversation_unchecked(&id).await.members().len(), 1);

                let commit = alice_central.mls_transport().await.latest_commit().await;
                bob_central
                    .transaction
                    .conversation(&id)
                    .await
                    .unwrap()
                    .decrypt_message(commit.to_bytes().unwrap())
                    .await
                    .unwrap();
                assert!(matches!(
                    bob_central.transaction.conversation(&id).await.unwrap_err(),
                    Error::Leaf(LeafError::ConversationNotFound(conv_id)) if conv_id == id
                ));
            })
            .await
        }

        // Proposing to remove an identity that is not a member yields `ClientNotFound`.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        pub async fn should_fail_when_unknown_client(case: TestContext) {
            let [alice_central] = case.sessions().await;
            Box::pin(async move {
                let id = conversation_id();
                alice_central
                    .transaction
                    .new_conversation(&id, case.credential_type, case.cfg.clone())
                    .await
                    .unwrap();

                let remove_proposal = alice_central
                    .transaction
                    .new_remove_proposal(&id, b"unknown"[..].into())
                    .await;
                assert!(matches!(
                    remove_proposal.unwrap_err(),
                    Error::ClientNotFound(client_id) if client_id == b"unknown"[..].into()
                ));
            })
            .await
        }
    }
}