1use std::collections::HashMap;
2
3use openmls::prelude::{KeyPackage, KeyPackageRef, MlsCredentialType as OpenMlsCredential};
4use openmls_traits::OpenMlsCryptoProvider;
5
6use core_crypto_keystore::connection::FetchFromDatabase;
7use core_crypto_keystore::{entities::MlsKeyPackage, CryptoKeystoreMls};
8use mls_crypto_provider::MlsCryptoProvider;
9
10use crate::context::CentralContext;
11use crate::e2e_identity::init_certificates::NewCrlDistributionPoint;
12#[cfg(not(target_family = "wasm"))]
13use crate::e2e_identity::refresh_token::RefreshToken;
14use crate::{
15 mls::credential::{ext::CredentialExt, x509::CertificatePrivateKey, CredentialBundle},
16 prelude::{
17 CertificateBundle, Client, ConversationId, CryptoError, CryptoResult, E2eIdentityError, E2eiEnrollment,
18 MlsCiphersuite, MlsCommitBundle, MlsConversation, MlsCredentialType,
19 },
20 MlsError,
21};
22
23impl CentralContext {
24 pub async fn e2ei_new_activation_enrollment(
30 &self,
31 display_name: String,
32 handle: String,
33 team: Option<String>,
34 expiry_sec: u32,
35 ciphersuite: MlsCiphersuite,
36 ) -> CryptoResult<E2eiEnrollment> {
37 let mls_provider = self.mls_provider().await?;
38 let cb = self
40 .mls_client()
41 .await?
42 .find_most_recent_credential_bundle(ciphersuite.signature_algorithm(), MlsCredentialType::Basic)
43 .await
44 .map_err(|_| E2eIdentityError::MissingExistingClient(MlsCredentialType::Basic))?;
45 let client_id = cb.credential().identity().into();
46
47 let sign_keypair = Some((&cb.signature_key).try_into()?);
48
49 E2eiEnrollment::try_new(
50 client_id,
51 display_name,
52 handle,
53 team,
54 expiry_sec,
55 &mls_provider,
56 ciphersuite,
57 sign_keypair,
58 #[cfg(not(target_family = "wasm"))]
59 None, )
61 }
62
63 pub async fn e2ei_new_rotate_enrollment(
69 &self,
70 display_name: Option<String>,
71 handle: Option<String>,
72 team: Option<String>,
73 expiry_sec: u32,
74 ciphersuite: MlsCiphersuite,
75 ) -> CryptoResult<E2eiEnrollment> {
76 let mls_provider = self.mls_provider().await?;
77 let cb = self
79 .mls_client()
80 .await?
81 .find_most_recent_credential_bundle(ciphersuite.signature_algorithm(), MlsCredentialType::X509)
82 .await
83 .map_err(|_| E2eIdentityError::MissingExistingClient(MlsCredentialType::X509))?;
84 let client_id = cb.credential().identity().into();
85 let sign_keypair = Some((&cb.signature_key).try_into()?);
86 let existing_identity = cb
87 .to_mls_credential_with_key()
88 .extract_identity(ciphersuite, None)?
89 .x509_identity
90 .ok_or(E2eIdentityError::ImplementationError)?;
91
92 let display_name = display_name.unwrap_or(existing_identity.display_name);
93 let handle = handle.unwrap_or(existing_identity.handle);
94
95 E2eiEnrollment::try_new(
96 client_id,
97 display_name,
98 handle,
99 team,
100 expiry_sec,
101 &mls_provider,
102 ciphersuite,
103 sign_keypair,
104 #[cfg(not(target_family = "wasm"))]
105 Some(RefreshToken::find(&mls_provider.keystore()).await?), )
107 }
108
109 pub async fn e2ei_rotate_all(
113 &self,
114 enrollment: &mut E2eiEnrollment,
115 certificate_chain: String,
116 new_key_packages_count: usize,
117 ) -> CryptoResult<MlsRotateBundle> {
118 let sk = enrollment.get_sign_key_for_mls()?;
119 let cs = enrollment.ciphersuite;
120 let certificate_chain = enrollment
121 .certificate_response(
122 certificate_chain,
123 self.mls_provider()
124 .await?
125 .authentication_service()
126 .borrow()
127 .await
128 .as_ref()
129 .ok_or(CryptoError::ConsumerError)?,
130 )
131 .await?;
132
133 let private_key = CertificatePrivateKey {
134 value: sk,
135 signature_scheme: cs.signature_algorithm(),
136 };
137
138 let crl_new_distribution_points = self.extract_dp_on_init(&certificate_chain[..]).await?;
139
140 let cert_bundle = CertificateBundle {
141 certificate_chain,
142 private_key,
143 };
144 let client = &self.mls_client().await?;
145
146 let new_cb = client
147 .save_new_x509_credential_bundle(
148 &self.mls_provider().await?.keystore(),
149 cs.signature_algorithm(),
150 cert_bundle,
151 )
152 .await?;
153
154 let commits = self.e2ei_update_all(client, &new_cb).await?;
155
156 let key_package_refs_to_remove = self.find_key_packages_to_remove(&new_cb).await?;
157
158 let new_key_packages = client
159 .generate_new_keypackages(&self.mls_provider().await?, cs, &new_cb, new_key_packages_count)
160 .await?;
161
162 Ok(MlsRotateBundle {
163 commits,
164 new_key_packages,
165 key_package_refs_to_remove,
166 crl_new_distribution_points,
167 })
168 }
169
170 async fn find_key_packages_to_remove(&self, cb: &CredentialBundle) -> CryptoResult<Vec<KeyPackageRef>> {
171 let transaction = self.keystore().await?;
172 let nb_kp = transaction.count::<MlsKeyPackage>().await?;
173 let kps: Vec<KeyPackage> = transaction.mls_fetch_keypackages(nb_kp as u32).await?;
174
175 let mut kp_refs = vec![];
176
177 let provider = self.mls_provider().await?;
178 for kp in kps {
179 let kp_cred = kp.leaf_node().credential().mls_credential();
180 let local_cred = cb.credential().mls_credential();
181 let mut push_kpr = || {
182 let kpr = kp.hash_ref(provider.crypto()).map_err(MlsError::from)?;
183 kp_refs.push(kpr);
184 CryptoResult::Ok(())
185 };
186
187 match (kp_cred, local_cred) {
188 (_, OpenMlsCredential::Basic(_)) => return Err(CryptoError::ImplementationError),
189 (OpenMlsCredential::X509(kp_cert), OpenMlsCredential::X509(local_cert)) if kp_cert != local_cert => {
190 push_kpr()?
191 }
192 (OpenMlsCredential::Basic(_), _) => push_kpr()?,
193 _ => {}
194 }
195 }
196 Ok(kp_refs)
197 }
198
199 async fn e2ei_update_all(
200 &self,
201 client: &Client,
202 cb: &CredentialBundle,
203 ) -> CryptoResult<HashMap<ConversationId, MlsCommitBundle>> {
204 let all_conversations = self.get_all_conversations().await?;
205
206 let mut commits = HashMap::with_capacity(all_conversations.len());
207 for conv in all_conversations {
208 let mut conv = conv.write().await;
209 let id = conv.id().clone();
210 let commit = conv.e2ei_rotate(&self.mls_provider().await?, client, Some(cb)).await?;
211 let _ = commits.insert(id, commit);
212 }
213 Ok(commits)
214 }
215
216 pub async fn e2ei_rotate(
220 &self,
221 id: &crate::prelude::ConversationId,
222 cb: Option<&CredentialBundle>,
223 ) -> CryptoResult<MlsCommitBundle> {
224 let client = &self.mls_client().await?;
225 self.get_conversation(id)
226 .await?
227 .write()
228 .await
229 .e2ei_rotate(&self.mls_provider().await?, client, cb)
230 .await
231 }
232}
233
234impl MlsConversation {
235 #[cfg_attr(test, crate::durable)]
236 pub(crate) async fn e2ei_rotate(
237 &mut self,
238 backend: &MlsCryptoProvider,
239 client: &Client,
240 cb: Option<&CredentialBundle>,
241 ) -> CryptoResult<MlsCommitBundle> {
242 let cb = match cb {
243 Some(cb) => cb,
244 None => &client
245 .find_most_recent_credential_bundle(self.ciphersuite().signature_algorithm(), MlsCredentialType::X509)
246 .await
247 .map_err(|_| E2eIdentityError::MissingExistingClient(MlsCredentialType::X509))?,
248 };
249 let mut leaf_node = self.group.own_leaf().ok_or(CryptoError::InternalMlsError)?.clone();
250 leaf_node.set_credential_with_key(cb.to_mls_credential_with_key());
251 self.update_keying_material(client, backend, Some(cb), Some(leaf_node))
252 .await
253 }
254}
255
/// Everything produced by `CentralContext::e2ei_rotate_all` when a client rotates its credential
/// in all its conversations at once.
#[derive(Debug, Clone)]
pub struct MlsRotateBundle {
    /// The rotation commit created for each conversation, keyed by conversation id
    pub commits: HashMap<ConversationId, MlsCommitBundle>,
    /// Key packages freshly generated with the new credential
    pub new_key_packages: Vec<KeyPackage>,
    /// References of the stored key packages that do not match the new credential
    pub key_package_refs_to_remove: Vec<KeyPackageRef>,
    /// CRL distribution points extracted from the new certificate chain
    pub crl_new_distribution_points: NewCrlDistributionPoint,
}
268
269impl MlsRotateBundle {
270 #[allow(clippy::type_complexity)]
272 pub fn to_bytes(
273 self,
274 ) -> CryptoResult<(
275 HashMap<String, MlsCommitBundle>,
276 Vec<Vec<u8>>,
277 Vec<Vec<u8>>,
278 NewCrlDistributionPoint,
279 )> {
280 use openmls::prelude::TlsSerializeTrait as _;
281
282 let commits_size = self.commits.len();
283 let commits = self
284 .commits
285 .into_iter()
286 .try_fold(HashMap::with_capacity(commits_size), |mut acc, (id, c)| {
287 let id = hex::encode(id);
289 let _ = acc.insert(id, c);
290 CryptoResult::Ok(acc)
291 })?;
292
293 let kp_size = self.new_key_packages.len();
294 let new_key_packages =
295 self.new_key_packages
296 .into_iter()
297 .try_fold(Vec::with_capacity(kp_size), |mut acc, kp| {
298 acc.push(kp.tls_serialize_detached().map_err(MlsError::from)?);
299 CryptoResult::Ok(acc)
300 })?;
301 let key_package_refs_to_remove = self
302 .key_package_refs_to_remove
303 .into_iter()
304 .map(|r| r.as_slice().to_vec())
306 .collect::<Vec<_>>();
307 Ok((
308 commits,
309 new_key_packages,
310 key_package_refs_to_remove,
311 self.crl_new_distribution_points,
312 ))
313 }
314}
315
316#[cfg(test)]
317pub(crate) mod tests {
319 use std::collections::HashSet;
320
321 use openmls::prelude::SignaturePublicKey;
322 use tls_codec::Deserialize;
323 use wasm_bindgen_test::*;
324
325 use core_crypto_keystore::entities::{EntityFindParams, MlsCredential};
326
327 use crate::{
328 e2e_identity::tests::*,
329 mls::credential::ext::CredentialExt,
330 prelude::key_package::INITIAL_KEYING_MATERIAL_COUNT,
331 test_utils::{x509::X509TestChain, *},
332 };
333
334 use super::*;
335
336 wasm_bindgen_test_configure!(run_in_browser);
337
338 pub(crate) mod all {
339 use openmls_traits::types::SignatureScheme;
340
341 use crate::test_utils::context::TEAM;
342
343 use super::*;
344
345 pub(crate) async fn failsafe_ctx(
346 ctxs: &mut [&mut ClientContext],
347 sc: SignatureScheme,
348 ) -> std::sync::Arc<Option<X509TestChain>> {
349 let mut found_test_chain = None;
350 for ctx in ctxs.iter() {
351 if ctx.x509_test_chain.is_some() {
352 found_test_chain.replace(ctx.x509_test_chain.clone());
353 break;
354 }
355 }
356
357 let found_test_chain = found_test_chain.unwrap_or_else(|| Some(X509TestChain::init_empty(sc)).into());
358
359 for ctx in ctxs.iter_mut() {
361 if ctx.x509_test_chain.is_none() {
362 ctx.replace_x509_chain(found_test_chain.clone());
363 }
364 }
365
366 let x509_test_chain = found_test_chain.as_ref().as_ref().unwrap();
367
368 for ctx in ctxs {
369 let _ = x509_test_chain.register_with_central(&ctx.context).await;
370 }
371
372 found_test_chain
373 }
374
        // Rotates Alice's credential in N conversations at once through `e2ei_rotate_all`, then checks
        // the keystore bookkeeping before the rotation, after it, and after deleting the key packages
        // the rotation flagged for removal.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        async fn enrollment_should_rotate_all(case: TestCase) {
            run_test_with_client_ids(
                case.clone(),
                ["alice", "bob", "charlie"],
                move |[mut alice_central, mut bob_central, mut charlie_central]| {
                    Box::pin(async move {
                        // N: number of conversations; NB_KEY_PACKAGE: key packages requested from the rotation
                        const N: usize = 50;
                        const NB_KEY_PACKAGE: usize = 50;

                        let mut ids = vec![];

                        let x509_test_chain_arc = failsafe_ctx(
                            &mut [&mut alice_central, &mut bob_central, &mut charlie_central],
                            case.signature_scheme(),
                        )
                        .await;

                        let x509_test_chain = x509_test_chain_arc.as_ref().as_ref().unwrap();

                        // Alice creates N conversations and invites Bob into each of them
                        for _ in 0..N {
                            let id = conversation_id();
                            alice_central
                                .context
                                .new_conversation(&id, case.credential_type, case.cfg.clone())
                                .await
                                .unwrap();
                            alice_central.invite_all(&case, &id, [&bob_central]).await.unwrap();
                            ids.push(id)
                        }

                        // Entity counts before the rotation: only the initial keying material and one credential
                        let before_rotate = alice_central.context.count_entities().await;
                        assert_eq!(before_rotate.key_package, INITIAL_KEYING_MATERIAL_COUNT);

                        assert_eq!(before_rotate.hpke_private_key, INITIAL_KEYING_MATERIAL_COUNT);

                        assert_eq!(before_rotate.encryption_keypair, INITIAL_KEYING_MATERIAL_COUNT);

                        assert_eq!(before_rotate.credential, 1);
                        let old_credential = alice_central
                            .find_most_recent_credential_bundle(case.signature_scheme(), case.credential_type)
                            .await
                            .unwrap()
                            .clone();

                        // A client that already has an x509 credential renews it instead of activating
                        let is_renewal = case.credential_type == MlsCredentialType::X509;

                        let (mut enrollment, cert) = e2ei_enrollment(
                            &mut alice_central,
                            &case,
                            x509_test_chain,
                            None,
                            is_renewal,
                            init_activation_or_rotation,
                            noop_restore,
                        )
                        .await
                        .unwrap();

                        let rotate_bundle = alice_central
                            .context
                            .e2ei_rotate_all(&mut enrollment, cert, NB_KEY_PACKAGE)
                            .await
                            .unwrap();

                        // The rotation added exactly NB_KEY_PACKAGE key packages...
                        let after_rotate = alice_central.context.count_entities().await;
                        assert_eq!(after_rotate.key_package - before_rotate.key_package, NB_KEY_PACKAGE);

                        // ...and exactly one credential
                        assert_eq!(after_rotate.credential - before_rotate.credential, 1);

                        // Bob decrypts every rotation commit, then Alice merges it and her credential in that
                        // conversation is checked against the new handle and display name
                        for (id, commit) in rotate_bundle.commits.into_iter() {
                            let decrypted = bob_central
                                .context
                                .decrypt_message(&id, commit.commit.to_bytes().unwrap())
                                .await
                                .unwrap();
                            alice_central.verify_sender_identity(&case, &decrypted).await;

                            alice_central.context.commit_accepted(&id).await.unwrap();
                            alice_central
                                .verify_local_credential_rotated(&id, NEW_HANDLE, NEW_DISPLAY_NAME)
                                .await;
                        }

                        // Every new key package carries an x509 credential with the new identity
                        let new_credentials = rotate_bundle
                            .new_key_packages
                            .iter()
                            .map(|kp| kp.leaf_node().to_credential_with_key());
                        for c in new_credentials {
                            assert_eq!(c.credential.credential_type(), openmls::prelude::CredentialType::X509);
                            let identity = c.extract_identity(case.ciphersuite(), None).unwrap();
                            assert_eq!(identity.x509_identity.as_ref().unwrap().display_name, NEW_DISPLAY_NAME);
                            assert_eq!(
                                identity.x509_identity.as_ref().unwrap().handle,
                                format!("wireapp://%40{NEW_HANDLE}@world.com")
                            );
                        }

                        // At this point the old credential bundle can still be found
                        assert!(alice_central
                            .find_credential_bundle(
                                case.signature_scheme(),
                                case.credential_type,
                                &old_credential.signature_key.public().into()
                            )
                            .await
                            .is_some());

                        // Counts right before the old key packages get deleted
                        let before_delete = alice_central.context.count_entities().await;
                        assert_eq!(
                            before_delete.hpke_private_key - before_rotate.hpke_private_key,
                            NB_KEY_PACKAGE
                        );

                        assert_eq!(before_delete.key_package - before_rotate.key_package, NB_KEY_PACKAGE);

                        // The old signature keypair is still in the keystore as well
                        assert!(alice_central
                            .find_signature_keypair_from_keystore(old_credential.signature_key.public())
                            .await
                            .is_some());

                        // Delete the key packages the rotation reported as outdated
                        alice_central
                            .context
                            .delete_keypackages(&rotate_bundle.key_package_refs_to_remove[..])
                            .await
                            .unwrap();

                        // Only the NB_KEY_PACKAGE x509 key packages remain
                        let nb_x509_kp = alice_central
                            .count_key_package(case.ciphersuite(), Some(MlsCredentialType::X509))
                            .await;
                        assert_eq!(nb_x509_kp, NB_KEY_PACKAGE);
                        // ...and no Basic one is left
                        let nb_basic_kp = alice_central
                            .count_key_package(case.ciphersuite(), Some(MlsCredentialType::Basic))
                            .await;
                        assert_eq!(nb_basic_kp, 0);

                        // After the deletion a single credential is left and the old one is gone from the
                        // keystore (presumably removed as a side effect of `delete_keypackages` - TODO confirm)
                        let after_delete = alice_central.context.count_entities().await;
                        assert_eq!(after_delete.credential, 1);
                        assert!(alice_central
                            .find_credential_from_keystore(&old_credential)
                            .await
                            .is_none());

                        // One HPKE private key per remaining key package
                        assert_eq!(after_delete.hpke_private_key, NB_KEY_PACKAGE);

                        // The deletion dropped as many encryption keypairs as there were initial key packages
                        assert_eq!(
                            after_rotate.encryption_keypair - after_delete.encryption_keypair,
                            INITIAL_KEYING_MATERIAL_COUNT
                        );

                        // Finally Charlie creates a conversation and invites Alice with one of her x509 key
                        // packages, which must succeed
                        let id = conversation_id();
                        charlie_central
                            .context
                            .new_conversation(&id, case.credential_type, case.cfg.clone())
                            .await
                            .unwrap();
                        let alice = alice_central
                            .rand_key_package_of_type(&case, MlsCredentialType::X509)
                            .await;
                        charlie_central
                            .invite_all_members(&case, &id, [(&alice_central, alice)])
                            .await
                            .unwrap();
                    })
                },
            )
            .await
        }
565
        // After a rotation, a fresh `Client` loaded with all the credentials found in the keystore must
        // still report the new x509 credential as the most recent one.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        async fn should_restore_credentials_in_order(case: TestCase) {
            run_test_with_client_ids(case.clone(), ["alice"], move |[mut alice_central]| {
                Box::pin(async move {
                    let x509_test_chain_arc = failsafe_ctx(&mut [&mut alice_central], case.signature_scheme()).await;

                    let x509_test_chain = x509_test_chain_arc.as_ref().as_ref().unwrap();

                    let id = conversation_id();
                    alice_central
                        .context
                        .new_conversation(&id, case.credential_type, case.cfg.clone())
                        .await
                        .unwrap();

                    // Keep a copy of the credential bundle in use before the rotation
                    let old_cb = alice_central
                        .find_most_recent_credential_bundle(case.signature_scheme(), case.credential_type)
                        .await
                        .unwrap()
                        .clone();

                    // Presumably ensures the new credential gets a strictly later `created_at` than the
                    // old one - TODO confirm the timestamp resolution
                    async_std::task::sleep(core::time::Duration::from_secs(1)).await;

                    let is_renewal = case.credential_type == MlsCredentialType::X509;

                    let (mut enrollment, cert) = e2ei_enrollment(
                        &mut alice_central,
                        &case,
                        x509_test_chain,
                        None,
                        is_renewal,
                        init_activation_or_rotation,
                        noop_restore,
                    )
                    .await
                    .unwrap();

                    alice_central
                        .context
                        .e2ei_rotate_all(&mut enrollment, cert, 10)
                        .await
                        .unwrap();

                    alice_central.context.commit_accepted(&id).await.unwrap();

                    // The most recent x509 credential now carries the new identity
                    let cb = alice_central
                        .find_most_recent_credential_bundle(case.signature_scheme(), MlsCredentialType::X509)
                        .await
                        .unwrap();
                    let identity = cb
                        .to_mls_credential_with_key()
                        .extract_identity(case.ciphersuite(), None)
                        .unwrap();
                    assert_eq!(identity.x509_identity.as_ref().unwrap().display_name, NEW_DISPLAY_NAME);
                    assert_eq!(
                        identity.x509_identity.as_ref().unwrap().handle,
                        format!("wireapp://%40{NEW_HANDLE}@world.com")
                    );

                    // The old bundle is still retrievable through its public signature key
                    let old_spk = SignaturePublicKey::from(old_cb.signature_key.public());
                    let old_cb_found = alice_central
                        .find_credential_bundle(case.signature_scheme(), case.credential_type, &old_spk)
                        .await
                        .unwrap();
                    assert_eq!(old_cb, old_cb_found);
                    // Gather what is needed to load a brand new client: client id, every stored credential
                    // with its creation timestamp, the signature schemes and the current identity count
                    let (cid, all_credentials, scs, old_nb_identities) = {
                        let alice_client = alice_central.client().await;
                        let old_nb_identities = alice_client.identities_count().await.unwrap();

                        let cid = alice_client.id().await.unwrap();
                        let scs = HashSet::from([case.signature_scheme()]);
                        let all_credentials = alice_central
                            .context
                            .keystore()
                            .await
                            .unwrap()
                            .find_all::<MlsCredential>(EntityFindParams::default())
                            .await
                            .unwrap()
                            .into_iter()
                            .map(|c| {
                                let credential =
                                    openmls::prelude::Credential::tls_deserialize(&mut c.credential.as_slice())
                                        .unwrap();
                                (credential, c.created_at)
                            })
                            .collect::<Vec<_>>();
                        // Old and new credential
                        assert_eq!(all_credentials.len(), 2);
                        (cid, all_credentials, scs, old_nb_identities)
                    };
                    // Commit the current keystore transaction and open a new one before loading the client
                    let backend = &alice_central.context.mls_provider().await.unwrap();
                    backend.keystore().commit_transaction().await.unwrap();
                    backend.keystore().new_transaction().await.unwrap();

                    let new_client = Client::default();

                    new_client.load(backend, &cid, all_credentials, scs).await.unwrap();

                    // The freshly loaded client must also see the rotated credential as the most recent one
                    let cb = new_client
                        .find_most_recent_credential_bundle(case.signature_scheme(), MlsCredentialType::X509)
                        .await
                        .unwrap();
                    let identity = cb
                        .to_mls_credential_with_key()
                        .extract_identity(case.ciphersuite(), None)
                        .unwrap();

                    assert_eq!(identity.x509_identity.as_ref().unwrap().display_name, NEW_DISPLAY_NAME);
                    assert_eq!(
                        identity.x509_identity.as_ref().unwrap().handle,
                        format!("wireapp://%40{NEW_HANDLE}@world.com")
                    );

                    // Loading must neither lose nor duplicate identities
                    assert_eq!(new_client.identities_count().await.unwrap(), old_nb_identities);
                })
            })
            .await
        }
691
        // Alice rotates her credential and Bob decrypts her commit, then Bob rotates his and Alice
        // decrypts that commit: rotation must work in both directions within the same conversation.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        async fn rotate_should_roundtrip(case: TestCase) {
            run_test_with_client_ids(
                case.clone(),
                ["alice", "bob"],
                move |[mut alice_central, mut bob_central]| {
                    Box::pin(async move {
                        let x509_test_chain_arc =
                            failsafe_ctx(&mut [&mut alice_central, &mut bob_central], case.signature_scheme()).await;

                        let x509_test_chain = x509_test_chain_arc.as_ref().as_ref().unwrap();

                        let id = conversation_id();
                        alice_central
                            .context
                            .new_conversation(&id, case.credential_type, case.cfg.clone())
                            .await
                            .unwrap();

                        alice_central.invite_all(&case, &id, [&bob_central]).await.unwrap();
                        // Alice's new identity
                        const ALICE_NEW_HANDLE: &str = "new_alice_wire";
                        const ALICE_NEW_DISPLAY_NAME: &str = "New Alice Smith";

                        // Starts Alice's enrollment: activation when she has a Basic credential, rotation
                        // when she already has an x509 one
                        fn init_alice(wrapper: E2eiInitWrapper) -> InitFnReturn<'_> {
                            Box::pin(async move {
                                let E2eiInitWrapper { context: cc, case } = wrapper;
                                let cs = case.ciphersuite();
                                match case.credential_type {
                                    MlsCredentialType::Basic => {
                                        cc.e2ei_new_activation_enrollment(
                                            ALICE_NEW_DISPLAY_NAME.to_string(),
                                            ALICE_NEW_HANDLE.to_string(),
                                            Some(TEAM.to_string()),
                                            E2EI_EXPIRY,
                                            cs,
                                        )
                                        .await
                                    }
                                    MlsCredentialType::X509 => {
                                        cc.e2ei_new_rotate_enrollment(
                                            Some(ALICE_NEW_DISPLAY_NAME.to_string()),
                                            Some(ALICE_NEW_HANDLE.to_string()),
                                            Some(TEAM.to_string()),
                                            E2EI_EXPIRY,
                                            cs,
                                        )
                                        .await
                                    }
                                }
                            })
                        }

                        let is_renewal = case.credential_type == MlsCredentialType::X509;

                        let (mut enrollment, cert) = e2ei_enrollment(
                            &mut alice_central,
                            &case,
                            x509_test_chain,
                            None,
                            is_renewal,
                            init_alice,
                            noop_restore,
                        )
                        .await
                        .unwrap();

                        let rotate_bundle = alice_central
                            .context
                            .e2ei_rotate_all(&mut enrollment, cert, 10)
                            .await
                            .unwrap();

                        // Only one conversation exists, so pick its commit directly
                        let commit = &rotate_bundle.commits.get(&id).unwrap().commit;

                        // Bob processes Alice's rotation commit
                        let decrypted = bob_central
                            .context
                            .decrypt_message(&id, commit.to_bytes().unwrap())
                            .await
                            .unwrap();
                        alice_central.verify_sender_identity(&case, &decrypted).await;

                        alice_central.context.commit_accepted(&id).await.unwrap();
                        alice_central
                            .verify_local_credential_rotated(&id, ALICE_NEW_HANDLE, ALICE_NEW_DISPLAY_NAME)
                            .await;

                        // Now the other way around: Bob's new identity
                        const BOB_NEW_HANDLE: &str = "new_bob_wire";
                        const BOB_NEW_DISPLAY_NAME: &str = "New Bob Smith";

                        // Same as `init_alice`, with Bob's values
                        fn init_bob(wrapper: E2eiInitWrapper) -> InitFnReturn<'_> {
                            Box::pin(async move {
                                let E2eiInitWrapper { context: cc, case } = wrapper;
                                let cs = case.ciphersuite();
                                match case.credential_type {
                                    MlsCredentialType::Basic => {
                                        cc.e2ei_new_activation_enrollment(
                                            BOB_NEW_DISPLAY_NAME.to_string(),
                                            BOB_NEW_HANDLE.to_string(),
                                            Some(TEAM.to_string()),
                                            E2EI_EXPIRY,
                                            cs,
                                        )
                                        .await
                                    }
                                    MlsCredentialType::X509 => {
                                        cc.e2ei_new_rotate_enrollment(
                                            Some(BOB_NEW_DISPLAY_NAME.to_string()),
                                            Some(BOB_NEW_HANDLE.to_string()),
                                            Some(TEAM.to_string()),
                                            E2EI_EXPIRY,
                                            cs,
                                        )
                                        .await
                                    }
                                }
                            })
                        }
                        let is_renewal = case.credential_type == MlsCredentialType::X509;

                        let (mut enrollment, cert) = e2ei_enrollment(
                            &mut bob_central,
                            &case,
                            x509_test_chain,
                            None,
                            is_renewal,
                            init_bob,
                            noop_restore,
                        )
                        .await
                        .unwrap();

                        let rotate_bundle = bob_central
                            .context
                            .e2ei_rotate_all(&mut enrollment, cert, 10)
                            .await
                            .unwrap();

                        let commit = &rotate_bundle.commits.get(&id).unwrap().commit;

                        // Alice processes Bob's rotation commit
                        let decrypted = alice_central
                            .context
                            .decrypt_message(&id, commit.to_bytes().unwrap())
                            .await
                            .unwrap();
                        bob_central.verify_sender_identity(&case, &decrypted).await;

                        bob_central.context.commit_accepted(&id).await.unwrap();
                        bob_central
                            .verify_local_credential_rotated(&id, BOB_NEW_HANDLE, BOB_NEW_DISPLAY_NAME)
                            .await;
                    })
                },
            )
            .await
        }
850 }
851
852 mod one {
853 use super::*;
854
        // Rotates the credential of a single conversation with an explicitly provided credential bundle
        // and checks that the key material counts are unchanged afterwards. Only runs for x509 cases.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        pub async fn should_rotate_one_conversations_credential(case: TestCase) {
            if case.is_x509() {
                run_test_with_client_ids(case.clone(), ["alice", "bob"], move |[alice_central, bob_central]| {
                    Box::pin(async move {
                        let id = conversation_id();
                        alice_central
                            .context
                            .new_conversation(&id, case.credential_type, case.cfg.clone())
                            .await
                            .unwrap();

                        alice_central.invite_all(&case, &id, [&bob_central]).await.unwrap();

                        // Counts to compare against once the rotation is done
                        let init_count = alice_central.context.count_entities().await;
                        let x509_test_chain = alice_central.x509_test_chain.as_ref().as_ref().unwrap();

                        let intermediate_ca = x509_test_chain.find_local_intermediate_ca();
                        // Alice's original certificate, looked up by actor name in the test chain
                        let alice_og_cert = &x509_test_chain
                            .actors
                            .iter()
                            .find(|actor| actor.name == "alice")
                            .unwrap()
                            .certificate;

                        let alice_cid = alice_central.get_client_id().await;
                        // Create the new credential bundle with a new handle and display name
                        let (new_handle, new_display_name) = ("new_alice_wire", "New Alice Smith");
                        let cb = alice_central
                            .rotate_credential(&case, new_handle, new_display_name, alice_og_cert, intermediate_ca)
                            .await;

                        // No commit has been created yet, so the identity seen in the conversation must not
                        // carry the new values
                        let alice_old_identities = alice_central
                            .context
                            .get_device_identities(&id, &[alice_cid])
                            .await
                            .unwrap();
                        let alice_old_identity = alice_old_identities.first().unwrap();
                        assert_ne!(
                            alice_old_identity.x509_identity.as_ref().unwrap().display_name,
                            new_display_name
                        );
                        assert_ne!(
                            alice_old_identity.x509_identity.as_ref().unwrap().handle,
                            format!("{new_handle}@world.com")
                        );

                        // Create the rotation commit for this conversation only
                        let commit = alice_central.context.e2ei_rotate(&id, Some(&cb)).await.unwrap();

                        // Bob decrypts the commit
                        let decrypted = bob_central
                            .context
                            .decrypt_message(&id, commit.commit.to_bytes().unwrap())
                            .await
                            .unwrap();
                        alice_central.verify_sender_identity(&case, &decrypted).await;

                        // Alice merges her commit; her credential in the conversation is now the new one
                        alice_central.context.commit_accepted(&id).await.unwrap();
                        alice_central
                            .verify_local_credential_rotated(&id, new_handle, new_display_name)
                            .await;

                        // Rotating a single conversation must leave these counts untouched
                        let final_count = alice_central.context.count_entities().await;
                        assert_eq!(init_count.encryption_keypair, final_count.encryption_keypair);
                        assert_eq!(
                            init_count.epoch_encryption_keypair,
                            final_count.epoch_encryption_keypair
                        );
                        assert_eq!(init_count.key_package, final_count.key_package);
                    })
                })
                .await
            }
        }
934
        // Alice creates a rotation commit that she never merges because Bob's concurrent commit is
        // processed instead. The rotation must come back as a renewed proposal, which Alice can then
        // commit successfully. Only runs for x509 cases.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        pub async fn rotate_should_be_renewable_when_commit_denied(case: TestCase) {
            if case.is_x509() {
                run_test_with_client_ids(case.clone(), ["alice", "bob"], move |[alice_central, bob_central]| {
                    Box::pin(async move {
                        let id = conversation_id();
                        alice_central
                            .context
                            .new_conversation(&id, case.credential_type, case.cfg.clone())
                            .await
                            .unwrap();

                        alice_central.invite_all(&case, &id, [&bob_central]).await.unwrap();

                        let init_count = alice_central.context.count_entities().await;

                        let x509_test_chain = alice_central.x509_test_chain.as_ref().as_ref().unwrap();

                        let intermediate_ca = x509_test_chain.find_local_intermediate_ca();

                        // Create the new credential bundle with a new handle and display name
                        let (new_handle, new_display_name) = ("new_alice_wire", "New Alice Smith");
                        let cb = alice_central
                            .rotate_credential(
                                &case,
                                new_handle,
                                new_display_name,
                                x509_test_chain.find_certificate_for_actor("alice").unwrap(),
                                intermediate_ca,
                            )
                            .await;

                        // Alice creates the rotation commit but it is never merged (note the unused binding)
                        let _rotate_commit = alice_central.context.e2ei_rotate(&id, Some(&cb)).await.unwrap();

                        // Meanwhile Bob creates his own commit and merges it
                        let bob_commit = bob_central.context.update_keying_material(&id).await.unwrap();
                        bob_central.context.commit_accepted(&id).await.unwrap();

                        // Alice processes Bob's commit instead of her own
                        let decrypted = alice_central
                            .context
                            .decrypt_message(&id, bob_commit.commit.to_bytes().unwrap())
                            .await
                            .unwrap();

                        // Exactly one proposal comes back to be fanned out: the renewed rotation
                        assert_eq!(decrypted.proposals.len(), 1);
                        let renewed_proposal = decrypted.proposals.first().unwrap();
                        bob_central
                            .context
                            .decrypt_message(&id, renewed_proposal.proposal.to_bytes().unwrap())
                            .await
                            .unwrap();

                        // Alice commits the pending (renewed) proposal
                        let rotate_commit = alice_central
                            .context
                            .commit_pending_proposals(&id)
                            .await
                            .unwrap()
                            .unwrap();

                        // Once merged, Alice's credential in the conversation is the new one
                        alice_central.context.commit_accepted(&id).await.unwrap();
                        alice_central
                            .verify_local_credential_rotated(&id, new_handle, new_display_name)
                            .await;

                        // Bob decrypts the commit carrying the rotation
                        let decrypted = bob_central
                            .context
                            .decrypt_message(&id, rotate_commit.commit.to_bytes().unwrap())
                            .await
                            .unwrap();
                        alice_central.verify_sender_identity(&case, &decrypted).await;

                        // The detour must not change the number of encryption keypairs
                        let final_count = alice_central.context.count_entities().await;
                        assert_eq!(init_count.encryption_keypair, final_count.encryption_keypair);
                    })
                })
                .await
            }
        }
1028
        // A conversation created with a Basic credential must end up with the x509 credential after
        // `e2ei_rotate` is called without an explicit credential bundle. Only runs for x509 cases.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        pub async fn rotate_should_replace_existing_basic_credentials(case: TestCase) {
            if case.is_x509() {
                run_test_with_client_ids(case.clone(), ["alice", "bob"], move |[alice_central, bob_central]| {
                    Box::pin(async move {
                        let id = conversation_id();
                        // The conversation is deliberately created with a Basic credential
                        alice_central
                            .context
                            .new_conversation(&id, MlsCredentialType::Basic, case.cfg.clone())
                            .await
                            .unwrap();

                        alice_central.invite_all(&case, &id, [&bob_central]).await.unwrap();

                        let x509_test_chain = alice_central.x509_test_chain.as_ref().as_ref().unwrap();
                        let intermediate_ca = x509_test_chain.find_local_intermediate_ca();
                        // Alice's original certificate, looked up by actor name in the test chain
                        let alice_og_cert = &x509_test_chain
                            .actors
                            .iter()
                            .find(|actor| actor.name == "alice")
                            .unwrap()
                            .certificate;

                        let alice_cid = alice_central.get_client_id().await;
                        // Create the new credential; the returned bundle is not needed since the rotation
                        // below is called with `None`
                        let (new_handle, new_display_name) = ("new_alice_wire", "New Alice Smith");
                        alice_central
                            .rotate_credential(&case, new_handle, new_display_name, alice_og_cert, intermediate_ca)
                            .await;

                        // Before the rotation commit, Alice is still seen with a Basic credential
                        let alice_old_identities = alice_central
                            .context
                            .get_device_identities(&id, &[alice_cid])
                            .await
                            .unwrap();
                        let alice_old_identity = alice_old_identities.first().unwrap();
                        assert_eq!(alice_old_identity.credential_type, MlsCredentialType::Basic);
                        assert_eq!(alice_old_identity.x509_identity, None);

                        // With `None`, the rotation falls back to the most recent x509 credential bundle
                        let commit = alice_central.context.e2ei_rotate(&id, None).await.unwrap();

                        // Bob decrypts the commit
                        let decrypted = bob_central
                            .context
                            .decrypt_message(&id, commit.commit.to_bytes().unwrap())
                            .await
                            .unwrap();
                        alice_central.verify_sender_identity(&case, &decrypted).await;

                        // Alice merges her commit; her credential in the conversation is now the new one
                        alice_central.context.commit_accepted(&id).await.unwrap();
                        alice_central
                            .verify_local_credential_rotated(&id, new_handle, new_display_name)
                            .await;
                    })
                })
                .await
            }
        }
1092 }
1093}