core_crypto/mls/
proposal.rs1use openmls::prelude::{KeyPackage, hash_ref::ProposalRef};
18
19use mls_crypto_provider::MlsCryptoProvider;
20
21use super::{Error, Result};
22use crate::{
23 RecursiveError,
24 mls::{ClientId, ConversationId, MlsConversation},
25 prelude::{Client, MlsProposalBundle},
26};
27
28use crate::context::CentralContext;
29
30#[derive(Debug, Clone, Eq, PartialEq, derive_more::From, derive_more::Deref, derive_more::Display)]
32pub struct MlsProposalRef(ProposalRef);
33
34impl From<Vec<u8>> for MlsProposalRef {
35 fn from(value: Vec<u8>) -> Self {
36 Self(ProposalRef::from_slice(value.as_slice()))
37 }
38}
39
40impl MlsProposalRef {
41 pub fn into_inner(self) -> ProposalRef {
43 self.0
44 }
45
46 pub(crate) fn to_bytes(&self) -> Vec<u8> {
47 self.0.as_slice().to_vec()
48 }
49}
50
/// Test-only convenience: turns a proposal reference back into its raw bytes.
#[cfg(test)]
impl From<MlsProposalRef> for Vec<u8> {
    fn from(prop_ref: MlsProposalRef) -> Self {
        // Same byte copy as the crate-internal serializer; reuse it rather than duplicate it.
        prop_ref.to_bytes()
    }
}
57
58#[allow(clippy::large_enum_variant)]
63pub enum MlsProposal {
64 Add(KeyPackage),
66 Update,
69 Remove(ClientId),
71}
72
73impl MlsProposal {
74 async fn create(
76 self,
77 client: &Client,
78 backend: &MlsCryptoProvider,
79 mut conversation: impl std::ops::DerefMut<Target = MlsConversation>,
80 ) -> Result<MlsProposalBundle> {
81 let proposal = match self {
82 MlsProposal::Add(key_package) => (*conversation)
83 .propose_add_member(client, backend, key_package.into())
84 .await
85 .map_err(RecursiveError::mls_conversation("proposing to add member"))?,
86 MlsProposal::Update => (*conversation)
87 .propose_self_update(client, backend)
88 .await
89 .map_err(RecursiveError::mls_conversation("proposing self update"))?,
90 MlsProposal::Remove(client_id) => {
91 let index = conversation
92 .group
93 .members()
94 .find(|kp| kp.credential.identity() == client_id.as_slice())
95 .ok_or(Error::ClientNotFound(client_id))
96 .map(|kp| kp.index)?;
97 (*conversation)
98 .propose_remove_member(client, backend, index)
99 .await
100 .map_err(RecursiveError::mls_conversation("proposing to remove member"))?
101 }
102 };
103 Ok(proposal)
104 }
105}
106
107impl CentralContext {
108 #[cfg_attr(test, crate::idempotent)]
110 pub async fn new_add_proposal(&self, id: &ConversationId, key_package: KeyPackage) -> Result<MlsProposalBundle> {
111 self.new_proposal(id, MlsProposal::Add(key_package)).await
112 }
113
114 #[cfg_attr(test, crate::idempotent)]
116 pub async fn new_remove_proposal(&self, id: &ConversationId, client_id: ClientId) -> Result<MlsProposalBundle> {
117 self.new_proposal(id, MlsProposal::Remove(client_id)).await
118 }
119
120 #[cfg_attr(test, crate::dispotent)]
122 pub async fn new_update_proposal(&self, id: &ConversationId) -> Result<MlsProposalBundle> {
123 self.new_proposal(id, MlsProposal::Update).await
124 }
125
126 async fn new_proposal(&self, id: &ConversationId, proposal: MlsProposal) -> Result<MlsProposalBundle> {
139 let mut conversation = self
140 .conversation(id)
141 .await
142 .map_err(RecursiveError::mls_conversation("getting conversation by id"))?;
143 let client = &self
144 .mls_client()
145 .await
146 .map_err(RecursiveError::root("getting mls client"))?;
147 proposal
148 .create(
149 client,
150 &self
151 .mls_provider()
152 .await
153 .map_err(RecursiveError::root("getting mls provider"))?,
154 conversation.conversation_mut().await,
155 )
156 .await
157 }
158}
159
#[cfg(test)]
mod tests {
    use wasm_bindgen_test::*;

    use crate::{prelude::*, test_utils::*};

    wasm_bindgen_test_configure!(run_in_browser);

    mod add {
        use super::*;

        /// Alice proposes Bob and commits the pending proposal; Bob then joins via the Welcome
        /// and the two can exchange messages.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        pub async fn should_add_member(case: TestCase) {
            run_test_with_client_ids(case.clone(), ["alice", "bob"], move |[alice_central, bob_central]| {
                Box::pin(async move {
                    let id = conversation_id();
                    alice_central
                        .context
                        .new_conversation(&id, case.credential_type, case.cfg.clone())
                        .await
                        .unwrap();

                    let bob_key_package = bob_central.get_one_key_package(&case).await;
                    alice_central
                        .context
                        .new_add_proposal(&id, bob_key_package)
                        .await
                        .unwrap();
                    alice_central
                        .context
                        .conversation(&id)
                        .await
                        .unwrap()
                        .commit_pending_proposals()
                        .await
                        .unwrap();

                    let welcome = alice_central.mls_transport.latest_welcome_message().await;
                    let alice_members = alice_central.get_conversation_unchecked(&id).await.members().len();
                    assert_eq!(alice_members, 2);

                    let joined_id = bob_central
                        .context
                        .process_welcome_message(welcome.into(), case.custom_cfg())
                        .await
                        .unwrap()
                        .id;
                    assert_eq!(id, joined_id);
                    assert!(bob_central.try_talk_to(&id, &alice_central).await.is_ok());
                })
            })
            .await
        }
    }

    mod update {
        use super::*;

        /// Committing a self-update proposal must change the conversation's encryption key.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        pub async fn should_update_hpke_key(case: TestCase) {
            run_test_with_central(case.clone(), move |[central]| {
                Box::pin(async move {
                    let id = conversation_id();
                    central
                        .context
                        .new_conversation(&id, case.credential_type, case.cfg.clone())
                        .await
                        .unwrap();

                    // `next()` yields the first key, exactly as `find_or_first(|_| true)` did.
                    let key_before = central
                        .get_conversation_unchecked(&id)
                        .await
                        .encryption_keys()
                        .next()
                        .unwrap();

                    central.context.new_update_proposal(&id).await.unwrap();
                    central
                        .context
                        .conversation(&id)
                        .await
                        .unwrap()
                        .commit_pending_proposals()
                        .await
                        .unwrap();

                    let key_after = central
                        .get_conversation_unchecked(&id)
                        .await
                        .encryption_keys()
                        .next()
                        .unwrap();
                    assert_ne!(key_before, key_after)
                })
            })
            .await
        }
    }

    mod remove {
        use super::*;

        /// Alice proposes Bob's removal and Bob processes the proposal. Once Alice commits it,
        /// Bob processes the commit and the conversation is no longer found for him.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        pub async fn should_remove_member(case: TestCase) {
            use crate::mls;

            run_test_with_client_ids(case.clone(), ["alice", "bob"], |[alice_central, bob_central]| {
                Box::pin(async move {
                    let id = conversation_id();
                    alice_central
                        .context
                        .new_conversation(&id, case.credential_type, case.cfg.clone())
                        .await
                        .unwrap();
                    alice_central.invite_all(&case, &id, [&bob_central]).await.unwrap();
                    let alice_members = alice_central.get_conversation_unchecked(&id).await.members().len();
                    assert_eq!(alice_members, 2);
                    let bob_members = bob_central.get_conversation_unchecked(&id).await.members().len();
                    assert_eq!(bob_members, 2);

                    let bob_client_id = bob_central.get_client_id().await;
                    let remove_proposal = alice_central
                        .context
                        .new_remove_proposal(&id, bob_client_id)
                        .await
                        .unwrap();
                    let proposal_bytes = remove_proposal.proposal.to_bytes().unwrap();
                    bob_central
                        .context
                        .conversation(&id)
                        .await
                        .unwrap()
                        .decrypt_message(proposal_bytes)
                        .await
                        .unwrap();

                    alice_central
                        .context
                        .conversation(&id)
                        .await
                        .unwrap()
                        .commit_pending_proposals()
                        .await
                        .unwrap();
                    let alice_members = alice_central.get_conversation_unchecked(&id).await.members().len();
                    assert_eq!(alice_members, 1);

                    let commit = alice_central.mls_transport.latest_commit().await;
                    let commit_bytes = commit.to_bytes().unwrap();
                    bob_central
                        .context
                        .conversation(&id)
                        .await
                        .unwrap()
                        .decrypt_message(commit_bytes)
                        .await
                        .unwrap();

                    // After processing the commit that removes him, Bob no longer has the conversation.
                    let lookup_error = bob_central.context.conversation(&id).await.unwrap_err();
                    assert!(matches!(
                        lookup_error,
                        mls::conversation::Error::Leaf(LeafError::ConversationNotFound(conv_id)) if conv_id == id
                    ));
                })
            })
            .await
        }

        /// Proposing the removal of a client that is not a member fails with `ClientNotFound`.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        pub async fn should_fail_when_unknown_client(case: TestCase) {
            use crate::mls;

            run_test_with_client_ids(case.clone(), ["alice"], move |[alice_central]| {
                Box::pin(async move {
                    let id = conversation_id();
                    alice_central
                        .context
                        .new_conversation(&id, case.credential_type, case.cfg.clone())
                        .await
                        .unwrap();

                    let outcome = alice_central
                        .context
                        .new_remove_proposal(&id, b"unknown"[..].into())
                        .await;
                    assert!(matches!(
                        outcome.unwrap_err(),
                        mls::Error::ClientNotFound(client_id) if client_id == b"unknown"[..].into()
                    ));
                })
            })
            .await
        }
    }
}