core_crypto/mls/
proposal.rs1use openmls::prelude::{hash_ref::ProposalRef, KeyPackage};
18
19use mls_crypto_provider::MlsCryptoProvider;
20
21use crate::{
22 mls::{ClientId, ConversationId, MlsConversation},
23 prelude::{Client, CryptoError, CryptoResult, MlsProposalBundle},
24};
25
26use crate::context::CentralContext;
27
28#[derive(Debug, Clone, Eq, PartialEq, derive_more::From, derive_more::Deref, derive_more::Display)]
30pub struct MlsProposalRef(ProposalRef);
31
32impl From<Vec<u8>> for MlsProposalRef {
33 fn from(value: Vec<u8>) -> Self {
34 Self(ProposalRef::from_slice(value.as_slice()))
35 }
36}
37
38impl MlsProposalRef {
39 pub fn into_inner(self) -> ProposalRef {
41 self.0
42 }
43
44 pub(crate) fn to_bytes(&self) -> Vec<u8> {
45 self.0.as_slice().to_vec()
46 }
47}
48
/// Test-only conversion exposing the raw bytes of a proposal reference.
#[cfg(test)]
impl From<MlsProposalRef> for Vec<u8> {
    // Delegates to `to_bytes` so both byte conversions share a single implementation.
    fn from(prop_ref: MlsProposalRef) -> Self {
        prop_ref.to_bytes()
    }
}
55
/// A proposal to create in a conversation; turned into an [`MlsProposalBundle`] by the
/// crate-internal `create` method.
// NOTE(review): the lint is silenced presumably because `KeyPackage` is much larger than the
// other variants' payloads; boxing it would change the public `Add` variant type, so it is kept.
#[allow(clippy::large_enum_variant)]
pub enum MlsProposal {
    /// Propose adding the member described by this key package
    /// (dispatched to `MlsConversation::propose_add_member`).
    Add(KeyPackage),
    /// Propose an update for the calling client
    /// (dispatched to `MlsConversation::propose_self_update`).
    Update,
    /// Propose removing the group member whose credential identity equals this client id
    /// (dispatched to `MlsConversation::propose_remove_member`).
    Remove(ClientId),
}
70
71impl MlsProposal {
72 async fn create(
74 self,
75 client: &Client,
76 backend: &MlsCryptoProvider,
77 mut conversation: impl std::ops::DerefMut<Target = MlsConversation>,
78 ) -> CryptoResult<MlsProposalBundle> {
79 let proposal = match self {
80 MlsProposal::Add(key_package) => {
81 (*conversation)
82 .propose_add_member(client, backend, key_package.into())
83 .await
84 }
85 MlsProposal::Update => (*conversation).propose_self_update(client, backend).await,
86 MlsProposal::Remove(client_id) => {
87 let index = conversation
88 .group
89 .members()
90 .find(|kp| kp.credential.identity() == client_id.as_slice())
91 .ok_or(CryptoError::ClientNotFound(client_id))
92 .map(|kp| kp.index)?;
93 (*conversation).propose_remove_member(client, backend, index).await
94 }
95 }?;
96 Ok(proposal)
97 }
98}
99
100impl CentralContext {
101 #[cfg_attr(test, crate::idempotent)]
103 pub async fn new_add_proposal(
104 &self,
105 id: &ConversationId,
106 key_package: KeyPackage,
107 ) -> CryptoResult<MlsProposalBundle> {
108 self.new_proposal(id, MlsProposal::Add(key_package)).await
109 }
110
111 #[cfg_attr(test, crate::idempotent)]
113 pub async fn new_remove_proposal(
114 &self,
115 id: &ConversationId,
116 client_id: ClientId,
117 ) -> CryptoResult<MlsProposalBundle> {
118 self.new_proposal(id, MlsProposal::Remove(client_id)).await
119 }
120
121 #[cfg_attr(test, crate::dispotent)]
123 pub async fn new_update_proposal(&self, id: &ConversationId) -> CryptoResult<MlsProposalBundle> {
124 self.new_proposal(id, MlsProposal::Update).await
125 }
126
127 async fn new_proposal(&self, id: &ConversationId, proposal: MlsProposal) -> CryptoResult<MlsProposalBundle> {
140 let conversation = self.get_conversation(id).await?;
141 let client = &self.mls_client().await?;
142 proposal
143 .create(client, &self.mls_provider().await?, conversation.write().await)
144 .await
145 }
146}
147
#[cfg(test)]
mod tests {
    use wasm_bindgen_test::*;

    use crate::{prelude::MlsCommitBundle, prelude::*, test_utils::*};

    wasm_bindgen_test_configure!(run_in_browser);

    mod add {
        use super::*;

        // Alice proposes adding Bob and commits the pending proposal; Bob joins from the Welcome
        // and must then be able to talk to Alice.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        pub async fn should_add_member(case: TestCase) {
            run_test_with_client_ids(case.clone(), ["alice", "bob"], move |[alice_central, bob_central]| {
                Box::pin(async move {
                    let id = conversation_id();
                    alice_central
                        .context
                        .new_conversation(&id, case.credential_type, case.cfg.clone())
                        .await
                        .unwrap();
                    let bob_kp = bob_central.get_one_key_package(&case).await;
                    alice_central.context.new_add_proposal(&id, bob_kp).await.unwrap();
                    let MlsCommitBundle { welcome, .. } = alice_central
                        .context
                        .commit_pending_proposals(&id)
                        .await
                        .unwrap()
                        .unwrap();
                    alice_central.context.commit_accepted(&id).await.unwrap();
                    assert_eq!(alice_central.get_conversation_unchecked(&id).await.members().len(), 2);
                    let new_id = bob_central
                        .context
                        .process_welcome_message(welcome.unwrap().into(), case.custom_cfg())
                        .await
                        .unwrap()
                        .id;
                    assert_eq!(id, new_id);
                    assert!(bob_central.try_talk_to(&id, &alice_central).await.is_ok());
                })
            })
            .await
        }
    }

    mod update {
        use super::*;

        // Committing a self-update proposal must change the conversation's (first) encryption key.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        pub async fn should_update_hpke_key(case: TestCase) {
            run_test_with_central(case.clone(), move |[central]| {
                Box::pin(async move {
                    let id = conversation_id();
                    central
                        .context
                        .new_conversation(&id, case.credential_type, case.cfg.clone())
                        .await
                        .unwrap();
                    // Plain `next()`: the former `find_or_first(|_| true)` was an indirect way to
                    // take the first key.
                    let before = central
                        .get_conversation_unchecked(&id)
                        .await
                        .encryption_keys()
                        .next()
                        .unwrap();
                    central.context.new_update_proposal(&id).await.unwrap();
                    central.context.commit_pending_proposals(&id).await.unwrap();
                    central.context.commit_accepted(&id).await.unwrap();
                    let after = central
                        .get_conversation_unchecked(&id)
                        .await
                        .encryption_keys()
                        .next()
                        .unwrap();
                    assert_ne!(before, after)
                })
            })
            .await
        }
    }

    mod remove {
        use super::*;

        // Alice proposes and commits Bob's removal; once Bob processes the proposal and the
        // commit, the conversation no longer exists on Bob's side.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        pub async fn should_remove_member(case: TestCase) {
            run_test_with_client_ids(case.clone(), ["alice", "bob"], |[alice_central, bob_central]| {
                Box::pin(async move {
                    let id = conversation_id();
                    alice_central
                        .context
                        .new_conversation(&id, case.credential_type, case.cfg.clone())
                        .await
                        .unwrap();
                    alice_central.invite_all(&case, &id, [&bob_central]).await.unwrap();
                    assert_eq!(alice_central.get_conversation_unchecked(&id).await.members().len(), 2);
                    assert_eq!(bob_central.get_conversation_unchecked(&id).await.members().len(), 2);

                    let remove_proposal = alice_central
                        .context
                        .new_remove_proposal(&id, bob_central.get_client_id().await)
                        .await
                        .unwrap();
                    bob_central
                        .context
                        .decrypt_message(&id, remove_proposal.proposal.to_bytes().unwrap())
                        .await
                        .unwrap();
                    let MlsCommitBundle { commit, .. } = alice_central
                        .context
                        .commit_pending_proposals(&id)
                        .await
                        .unwrap()
                        .unwrap();
                    alice_central.context.commit_accepted(&id).await.unwrap();
                    assert_eq!(alice_central.get_conversation_unchecked(&id).await.members().len(), 1);

                    bob_central
                        .context
                        .decrypt_message(&id, commit.to_bytes().unwrap())
                        .await
                        .unwrap();
                    assert!(matches!(
                        bob_central.context.get_conversation(&id).await.unwrap_err(),
                        CryptoError::ConversationNotFound(conv_id) if conv_id == id
                    ));
                })
            })
            .await
        }

        // Proposing the removal of a client id that is not a group member fails with
        // `ClientNotFound` carrying that id.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        pub async fn should_fail_when_unknown_client(case: TestCase) {
            run_test_with_client_ids(case.clone(), ["alice"], move |[alice_central]| {
                Box::pin(async move {
                    let id = conversation_id();
                    alice_central
                        .context
                        .new_conversation(&id, case.credential_type, case.cfg.clone())
                        .await
                        .unwrap();

                    let remove_proposal = alice_central
                        .context
                        .new_remove_proposal(&id, b"unknown"[..].into())
                        .await;
                    assert!(matches!(
                        remove_proposal.unwrap_err(),
                        CryptoError::ClientNotFound(client_id) if client_id == b"unknown"[..].into()
                    ));
                })
            })
            .await
        }
    }
}