1use openmls::prelude::{Credential, CredentialWithKey, CryptoConfig, KeyPackage, KeyPackageRef, Lifetime};
18use openmls_traits::OpenMlsCryptoProvider;
19use std::collections::HashMap;
20use std::ops::{Deref, DerefMut};
21use tls_codec::{Deserialize, Serialize};
22
23use core_crypto_keystore::{
24 connection::FetchFromDatabase,
25 entities::{EntityFindParams, MlsEncryptionKeyPair, MlsHpkePrivateKey, MlsKeyPackage},
26};
27use mls_crypto_provider::{CryptoKeystore, MlsCryptoProvider};
28
29use super::{Error, Result};
30use crate::{KeystoreError, MlsError};
31use crate::{
32 RecursiveError,
33 context::CentralContext,
34 mls::{client::ClientInner, credential::CredentialBundle},
35 prelude::{Client, MlsCiphersuite, MlsConversationConfiguration, MlsCredentialType},
36};
37
/// Number of keypackages (with their HPKE private keys and leaf encryption keypairs) a freshly
/// initialized client starts out with.
///
/// Test builds use 10 instead of 100, presumably to keep test setup cheap.
pub const INITIAL_KEYING_MATERIAL_COUNT: usize = if cfg!(test) { 10 } else { 100 };
44
45pub(crate) const KEYPACKAGE_DEFAULT_LIFETIME: std::time::Duration =
47 std::time::Duration::from_secs(60 * 60 * 24 * 28 * 3); impl Client {
50 pub async fn generate_one_keypackage_from_credential_bundle(
58 &self,
59 backend: &MlsCryptoProvider,
60 cs: MlsCiphersuite,
61 cb: &CredentialBundle,
62 ) -> Result<KeyPackage> {
63 match self.state.read().await.deref() {
64 None => Err(Error::MlsNotInitialized),
65 Some(ClientInner {
66 keypackage_lifetime, ..
67 }) => {
68 let keypackage = KeyPackage::builder()
69 .leaf_node_capabilities(MlsConversationConfiguration::default_leaf_capabilities())
70 .key_package_lifetime(Lifetime::new(keypackage_lifetime.as_secs()))
71 .build(
72 CryptoConfig {
73 ciphersuite: cs.into(),
74 version: openmls::versions::ProtocolVersion::default(),
75 },
76 backend,
77 &cb.signature_key,
78 CredentialWithKey {
79 credential: cb.credential.clone(),
80 signature_key: cb.signature_key.public().into(),
81 },
82 )
83 .await
84 .map_err(KeystoreError::wrap("building keypackage"))?;
85
86 Ok(keypackage)
87 }
88 }
89 }
90
91 pub async fn request_key_packages(
102 &self,
103 count: usize,
104 ciphersuite: MlsCiphersuite,
105 credential_type: MlsCredentialType,
106 backend: &MlsCryptoProvider,
107 ) -> Result<Vec<KeyPackage>> {
108 self.prune_keypackages(backend, &[]).await?;
110 use core_crypto_keystore::CryptoKeystoreMls as _;
111
112 let mut existing_kps = backend
113 .key_store()
114 .mls_fetch_keypackages::<KeyPackage>(count as u32)
115 .await.map_err(KeystoreError::wrap("fetching mls keypackages"))?
116 .into_iter()
117 .filter(|kp|
119 kp.ciphersuite() == ciphersuite.0 && MlsCredentialType::from(kp.leaf_node().credential().credential_type()) == credential_type)
120 .collect::<Vec<_>>();
121
122 let kpb_count = existing_kps.len();
123 let mut kps = if count > kpb_count {
124 let to_generate = count - kpb_count;
125 let cb = self
126 .find_most_recent_credential_bundle(ciphersuite.signature_algorithm(), credential_type)
127 .await?;
128 self.generate_new_keypackages(backend, ciphersuite, &cb, to_generate)
129 .await?
130 } else {
131 vec![]
132 };
133
134 existing_kps.reverse();
135
136 kps.append(&mut existing_kps);
137 Ok(kps)
138 }
139
140 pub(crate) async fn generate_new_keypackages(
141 &self,
142 backend: &MlsCryptoProvider,
143 ciphersuite: MlsCiphersuite,
144 cb: &CredentialBundle,
145 count: usize,
146 ) -> Result<Vec<KeyPackage>> {
147 let mut kps = Vec::with_capacity(count);
148
149 for _ in 0..count {
150 let kp = self
151 .generate_one_keypackage_from_credential_bundle(backend, ciphersuite, cb)
152 .await?;
153 kps.push(kp);
154 }
155
156 Ok(kps)
157 }
158
159 pub async fn valid_keypackages_count(
161 &self,
162 backend: &MlsCryptoProvider,
163 ciphersuite: MlsCiphersuite,
164 credential_type: MlsCredentialType,
165 ) -> Result<usize> {
166 let kps: Vec<MlsKeyPackage> = backend
167 .key_store()
168 .find_all(EntityFindParams::default())
169 .await
170 .map_err(KeystoreError::wrap("finding all key packages"))?;
171
172 let mut valid_count = 0;
173 for kp in kps
174 .into_iter()
175 .map(|kp| core_crypto_keystore::deser::<KeyPackage>(&kp.keypackage))
176 .filter(|kp| {
178 kp.as_ref()
179 .map(|b| b.ciphersuite() == ciphersuite.0 && MlsCredentialType::from(b.leaf_node().credential().credential_type()) == credential_type)
180 .unwrap_or_default()
181 })
182 {
183 let kp = kp.map_err(KeystoreError::wrap("counting valid keypackages"))?;
184 if !Self::is_mls_keypackage_expired(&kp) {
185 valid_count += 1;
186 }
187 }
188
189 Ok(valid_count)
190 }
191
192 fn is_mls_keypackage_expired(kp: &KeyPackage) -> bool {
195 let Some(lifetime) = kp.leaf_node().life_time() else {
196 return false;
197 };
198
199 !(lifetime.has_acceptable_range() && lifetime.is_valid())
200 }
201
202 pub async fn prune_keypackages(&self, backend: &MlsCryptoProvider, refs: &[KeyPackageRef]) -> Result<()> {
208 let keystore = backend.keystore();
209 let kps = self.find_all_keypackages(&keystore).await?;
210 let _ = self._prune_keypackages(&kps, &keystore, refs).await?;
211 Ok(())
212 }
213
214 pub(crate) async fn prune_keypackages_and_credential(
215 &mut self,
216 backend: &MlsCryptoProvider,
217 refs: &[KeyPackageRef],
218 ) -> Result<()> {
219 match self.state.write().await.deref_mut() {
220 None => Err(Error::MlsNotInitialized),
221 Some(ClientInner { identities, .. }) => {
222 let keystore = backend.key_store();
223 let kps = self.find_all_keypackages(keystore).await?;
224 let kp_to_delete = self._prune_keypackages(&kps, keystore, refs).await?;
225
226 let mut grouped_kps = HashMap::<Vec<u8>, Vec<KeyPackageRef>>::new();
228 for (_, kp) in &kps {
229 let cred = kp
230 .leaf_node()
231 .credential()
232 .tls_serialize_detached()
233 .map_err(Error::tls_serialize("keypackage"))?;
234 let kp_ref = kp
235 .hash_ref(backend.crypto())
236 .map_err(MlsError::wrap("computing keypackage hashref"))?;
237 grouped_kps
238 .entry(cred)
239 .and_modify(|kprfs| kprfs.push(kp_ref.clone()))
240 .or_insert(vec![kp_ref]);
241 }
242
243 for (credential, kps) in &grouped_kps {
244 let all_to_delete = kps.iter().all(|kpr| kp_to_delete.contains(&kpr.as_slice()));
246 if all_to_delete {
247 backend
249 .keystore()
250 .cred_delete_by_credential(credential.clone())
251 .await
252 .map_err(KeystoreError::wrap("deleting credential"))?;
253 let credential = Credential::tls_deserialize(&mut credential.as_slice())
254 .map_err(Error::tls_deserialize("credential"))?;
255 identities.remove(&credential).await?;
256 }
257 }
258 Ok(())
259 }
260 }
261 }
262
263 async fn _prune_keypackages<'a>(
268 &self,
269 kps: &'a [(MlsKeyPackage, KeyPackage)],
270 keystore: &CryptoKeystore,
271 refs: &[KeyPackageRef],
272 ) -> Result<Vec<&'a [u8]>, Error> {
273 let kp_to_delete: Vec<_> = kps
274 .iter()
275 .filter_map(|(store_kp, kp)| {
276 let is_expired = Self::is_mls_keypackage_expired(kp);
277 let mut to_delete = is_expired;
278 if !(is_expired || refs.is_empty()) {
279 to_delete = refs.iter().any(|r| r.as_slice() == store_kp.keypackage_ref);
282 }
283
284 to_delete.then_some((kp, &store_kp.keypackage_ref))
285 })
286 .collect();
287
288 for (kp, kp_ref) in &kp_to_delete {
289 keystore
291 .remove::<MlsKeyPackage, &[u8]>(kp_ref.as_slice())
292 .await
293 .map_err(KeystoreError::wrap("removing key package from keystore"))?;
294 keystore
295 .remove::<MlsHpkePrivateKey, &[u8]>(kp.hpke_init_key().as_slice())
296 .await
297 .map_err(KeystoreError::wrap("removing private key from keystore"))?;
298 keystore
299 .remove::<MlsEncryptionKeyPair, &[u8]>(kp.leaf_node().encryption_key().as_slice())
300 .await
301 .map_err(KeystoreError::wrap("removing encryption keypair from keystore"))?;
302 }
303
304 let kp_to_delete = kp_to_delete
305 .into_iter()
306 .map(|(_, kpref)| &kpref[..])
307 .collect::<Vec<_>>();
308
309 Ok(kp_to_delete)
310 }
311
312 async fn find_all_keypackages(&self, keystore: &CryptoKeystore) -> Result<Vec<(MlsKeyPackage, KeyPackage)>> {
313 let kps: Vec<MlsKeyPackage> = keystore
314 .find_all(EntityFindParams::default())
315 .await
316 .map_err(KeystoreError::wrap("finding all keypackages"))?;
317
318 let kps = kps
319 .into_iter()
320 .map(|raw_kp| -> Result<_> {
321 let kp = core_crypto_keystore::deser::<KeyPackage>(&raw_kp.keypackage)
322 .map_err(KeystoreError::wrap("deserializing keypackage"))?;
323 Ok((raw_kp, kp))
324 })
325 .collect::<Result<Vec<_>, _>>()?;
326
327 Ok(kps)
328 }
329
330 #[cfg(test)]
333 pub async fn set_keypackage_lifetime(&self, duration: std::time::Duration) -> Result<()> {
334 use std::ops::DerefMut;
335 match self.state.write().await.deref_mut() {
336 None => Err(Error::MlsNotInitialized),
337 Some(ClientInner {
338 keypackage_lifetime, ..
339 }) => {
340 *keypackage_lifetime = duration;
341 Ok(())
342 }
343 }
344 }
345}
346
347impl CentralContext {
348 pub async fn get_or_create_client_keypackages(
362 &self,
363 ciphersuite: MlsCiphersuite,
364 credential_type: MlsCredentialType,
365 amount_requested: usize,
366 ) -> Result<Vec<KeyPackage>> {
367 let client = self
368 .mls_client()
369 .await
370 .map_err(RecursiveError::root("getting mls client"))?;
371 client
372 .request_key_packages(
373 amount_requested,
374 ciphersuite,
375 credential_type,
376 &self
377 .mls_provider()
378 .await
379 .map_err(RecursiveError::root("getting mls provider"))?,
380 )
381 .await
382 }
383
384 #[cfg_attr(test, crate::idempotent)]
386 pub async fn client_valid_key_packages_count(
387 &self,
388 ciphersuite: MlsCiphersuite,
389 credential_type: MlsCredentialType,
390 ) -> Result<usize> {
391 let client = self
392 .mls_client()
393 .await
394 .map_err(RecursiveError::root("getting mls client"))?;
395 client
396 .valid_keypackages_count(
397 &self
398 .mls_provider()
399 .await
400 .map_err(RecursiveError::root("getting mls provider"))?,
401 ciphersuite,
402 credential_type,
403 )
404 .await
405 }
406
407 #[cfg_attr(test, crate::dispotent)]
410 pub async fn delete_keypackages(&self, refs: &[KeyPackageRef]) -> Result<()> {
411 if refs.is_empty() {
412 return Err(Error::EmptyKeypackageList);
413 }
414 let mut client = self
415 .mls_client()
416 .await
417 .map_err(RecursiveError::root("getting mls client"))?;
418 client
419 .prune_keypackages_and_credential(
420 &self
421 .mls_provider()
422 .await
423 .map_err(RecursiveError::root("getting mls provider"))?,
424 refs,
425 )
426 .await
427 }
428}
429
#[cfg(test)]
mod tests {
    //! Tests for keypackage generation, counting, pruning and deletion.

    use openmls::prelude::{KeyPackage, KeyPackageIn, KeyPackageRef, ProtocolVersion};
    use openmls_traits::OpenMlsCryptoProvider;
    use openmls_traits::types::VerifiableCiphersuite;
    use wasm_bindgen_test::*;

    use mls_crypto_provider::MlsCryptoProvider;

    use crate::e2e_identity::tests::{e2ei_enrollment, init_activation_or_rotation, noop_restore};
    use crate::prelude::MlsConversationConfiguration;
    use crate::prelude::key_package::INITIAL_KEYING_MATERIAL_COUNT;
    use crate::test_utils::*;

    use super::Client;

    wasm_bindgen_test_configure!(run_in_browser);

    /// A keypackage built with the default lifetime is not expired. One built with a 1-second
    /// lifetime is expired after sleeping 2 seconds.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    async fn can_assess_keypackage_expiration(case: TestCase) {
        let (cs, ct) = (case.ciphersuite(), case.credential_type);
        let backend = MlsCryptoProvider::try_new_in_memory("test").await.unwrap();
        // X509 cases need a test chain registered with the provider before the client is created.
        let x509_test_chain = if case.is_x509() {
            let x509_test_chain = crate::test_utils::x509::X509TestChain::init_empty(case.signature_scheme());
            x509_test_chain.register_with_provider(&backend).await;
            Some(x509_test_chain)
        } else {
            None
        };

        backend.new_transaction().await.unwrap();
        let client = Client::random_generate(
            &case,
            &backend,
            x509_test_chain.as_ref().map(|chain| chain.find_local_intermediate_ca()),
            false,
        )
        .await
        .unwrap();

        // Keypackage generated with the client's default lifetime.
        let kp_std_exp = client.generate_one_keypackage(&backend, cs, ct).await.unwrap();
        assert!(!Client::is_mls_keypackage_expired(&kp_std_exp));

        // Shorten the lifetime to 1 second for the next keypackage.
        client
            .set_keypackage_lifetime(std::time::Duration::from_secs(1))
            .await
            .unwrap();
        let kp_1s_exp = client.generate_one_keypackage(&backend, cs, ct).await.unwrap();
        // Sleep past the 1-second lifetime so the keypackage is expired.
        async_std::task::sleep(std::time::Duration::from_secs(2)).await;

        assert!(Client::is_mls_keypackage_expired(&kp_1s_exp));
    }

    /// After requesting Basic keypackages and then enrolling an X509 credential, requesting X509
    /// keypackages must only return keypackages carrying an X509 credential.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    async fn requesting_x509_key_packages_after_basic(case: TestCase) {
        // The scenario starts from a Basic-credential client, so it only runs for Basic cases.
        if !case.is_basic() {
            return;
        }
        run_test_with_client_ids(case.clone(), ["alice"], move |[mut client_context]| {
            Box::pin(async move {
                let signature_scheme = case.signature_scheme();
                let cipher_suite = case.ciphersuite();

                // Populate the keystore with Basic keypackages first.
                let _basic_key_packages = client_context
                    .context
                    .get_or_create_client_keypackages(cipher_suite, MlsCredentialType::Basic, 5)
                    .await
                    .unwrap();

                // Enroll an X509 credential through the E2E identity flow.
                let test_chain = x509::X509TestChain::init_for_random_clients(signature_scheme, 1);

                let (mut enrollment, cert_chain) = e2ei_enrollment(
                    &mut client_context,
                    &case,
                    &test_chain,
                    None,
                    false,
                    init_activation_or_rotation,
                    noop_restore,
                )
                .await
                .unwrap();

                let _rotate_bundle = client_context
                    .context
                    .save_x509_credential(&mut enrollment, cert_chain)
                    .await
                    .unwrap();

                // Sanity check: E2E identity is now enabled for this signature scheme.
                assert!(client_context.context.e2ei_is_enabled(signature_scheme).await.unwrap());

                // Now request X509 keypackages.
                let x509_key_packages = client_context
                    .context
                    .get_or_create_client_keypackages(cipher_suite, MlsCredentialType::X509, 5)
                    .await
                    .unwrap();

                // None of the previously created Basic keypackages may leak into the result.
                assert!(x509_key_packages.iter().all(|kp| MlsCredentialType::X509
                    == MlsCredentialType::from(kp.leaf_node().credential().credential_type())));
            })
        })
        .await
    }

    /// Checks entity counts across repeated request/delete rounds.
    ///
    /// Each round requests `COUNT + 1` keypackages and deletes all but one (the "pinned" one).
    /// It verifies that:
    /// - the keystore entity counts match what was requested;
    /// - no deleted keypackage is returned again;
    /// - deleting the last keypackage also deletes the credential.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    async fn generates_correct_number_of_kpbs(case: TestCase) {
        run_test_with_client_ids(case.clone(), ["alice"], move |[cc]| {
            Box::pin(async move {
                const N: usize = 2;
                const COUNT: usize = 109;

                // A freshly initialized client starts with INITIAL_KEYING_MATERIAL_COUNT
                // keypackages and a single credential/signature keypair.
                let init = cc.context.count_entities().await;
                assert_eq!(init.key_package, INITIAL_KEYING_MATERIAL_COUNT);
                assert_eq!(init.encryption_keypair, INITIAL_KEYING_MATERIAL_COUNT);
                assert_eq!(init.hpke_private_key, INITIAL_KEYING_MATERIAL_COUNT);
                assert_eq!(init.credential, 1);
                assert_eq!(init.signature_keypair, 1);

                let transactional_provider = cc.context.mls_provider().await.unwrap();
                let crypto_provider = transactional_provider.crypto();
                let mut pinned_kp = None;

                let mut prev_kps: Option<Vec<KeyPackage>> = None;
                for _ in 0..N {
                    let mut kps = cc
                        .context
                        .get_or_create_client_keypackages(case.ciphersuite(), case.credential_type, COUNT + 1)
                        .await
                        .unwrap();

                    // Keep the last keypackage out of the deletion below, so exactly one
                    // survives each round.
                    pinned_kp = Some(kps.pop().unwrap());

                    assert_eq!(kps.len(), COUNT);
                    let after_creation = cc.context.count_entities().await;
                    assert_eq!(after_creation.key_package, COUNT + 1);
                    assert_eq!(after_creation.encryption_keypair, COUNT + 1);
                    assert_eq!(after_creation.hpke_private_key, COUNT + 1);
                    assert_eq!(after_creation.credential, 1);

                    let kpbs_refs = kps
                        .iter()
                        .map(|kp| kp.hash_ref(crypto_provider).unwrap())
                        .collect::<Vec<KeyPackageRef>>();

                    if let Some(pkpbs) = prev_kps.replace(kps) {
                        let pkpbs_refs = pkpbs
                            .into_iter()
                            .map(|kpb| kpb.hash_ref(crypto_provider).unwrap())
                            .collect::<Vec<KeyPackageRef>>();

                        // Keypackages deleted in the previous round must not come back.
                        let has_duplicates = kpbs_refs.iter().any(|href| pkpbs_refs.contains(href));
                        assert!(!has_duplicates);
                    }
                    cc.context.delete_keypackages(&kpbs_refs).await.unwrap();
                }

                // Only the pinned keypackage is left.
                let count = cc
                    .context
                    .client_valid_key_packages_count(case.ciphersuite(), case.credential_type)
                    .await
                    .unwrap();
                assert_eq!(count, 1);

                // Deleting the pinned keypackage empties the store, including the credential,
                // since no keypackage references it anymore.
                let pinned_kpr = pinned_kp.unwrap().hash_ref(crypto_provider).unwrap();
                cc.context.delete_keypackages(&[pinned_kpr]).await.unwrap();
                let count = cc
                    .context
                    .client_valid_key_packages_count(case.ciphersuite(), case.credential_type)
                    .await
                    .unwrap();
                assert_eq!(count, 0);
                let after_delete = cc.context.count_entities().await;
                assert_eq!(after_delete.key_package, 0);
                assert_eq!(after_delete.encryption_keypair, 0);
                assert_eq!(after_delete.hpke_private_key, 0);
                assert_eq!(after_delete.credential, 0);
            })
        })
        .await
    }

    /// Expired keypackages are pruned on the next request and never handed out again.
    /// Unexpired ones keep being reused.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    async fn automatically_prunes_lifetime_expired_keypackages(case: TestCase) {
        const UNEXPIRED_COUNT: usize = 125;
        const EXPIRED_COUNT: usize = 200;
        let backend = MlsCryptoProvider::try_new_in_memory("test").await.unwrap();
        // X509 cases need a test chain registered with the provider before the client is created.
        let x509_test_chain = if case.is_x509() {
            let x509_test_chain = crate::test_utils::x509::X509TestChain::init_empty(case.signature_scheme());
            x509_test_chain.register_with_provider(&backend).await;
            Some(x509_test_chain)
        } else {
            None
        };
        backend.new_transaction().await.unwrap();
        let client = Client::random_generate(
            &case,
            &backend,
            x509_test_chain.as_ref().map(|chain| chain.find_local_intermediate_ca()),
            false,
        )
        .await
        .unwrap();

        // Generate UNEXPIRED_COUNT keypackages with the default lifetime.
        let unexpired_kpbs = client
            .request_key_packages(UNEXPIRED_COUNT, case.ciphersuite(), case.credential_type, &backend)
            .await
            .unwrap();
        let len = client
            .valid_keypackages_count(&backend, case.ciphersuite(), case.credential_type)
            .await
            .unwrap();
        assert_eq!(len, unexpired_kpbs.len());
        assert_eq!(len, UNEXPIRED_COUNT);

        // Keypackages generated from now on get a 10-second lifetime.
        client
            .set_keypackage_lifetime(std::time::Duration::from_secs(10))
            .await
            .unwrap();

        // The UNEXPIRED_COUNT stored ones are reused. Only EXPIRED_COUNT - UNEXPIRED_COUNT new,
        // short-lived keypackages are generated.
        let partially_expired_kpbs = client
            .request_key_packages(EXPIRED_COUNT, case.ciphersuite(), case.credential_type, &backend)
            .await
            .unwrap();
        assert_eq!(partially_expired_kpbs.len(), EXPIRED_COUNT);

        // Wait for the short-lived keypackages to expire.
        async_std::task::sleep(std::time::Duration::from_secs(10)).await;

        // This request prunes the expired keypackages and generates replacements.
        let fresh_kpbs = client
            .request_key_packages(EXPIRED_COUNT, case.ciphersuite(), case.credential_type, &backend)
            .await
            .unwrap();
        let len = client
            .valid_keypackages_count(&backend, case.ciphersuite(), case.credential_type)
            .await
            .unwrap();
        assert_eq!(len, fresh_kpbs.len());
        assert_eq!(len, EXPIRED_COUNT);

        // Classify the fresh batch:
        // - unexpired_match: reused from the first (default-lifetime) batch;
        // - expired_match: one of the short-lived keypackages from the second batch.
        let (unexpired_match, expired_match) =
            fresh_kpbs
                .iter()
                .fold((0usize, 0usize), |(mut unexpired_match, mut expired_match), fresh| {
                    if unexpired_kpbs.iter().any(|kp| kp == fresh) {
                        unexpired_match += 1;
                    } else if partially_expired_kpbs.iter().any(|kpb| kpb == fresh) {
                        expired_match += 1;
                    }

                    (unexpired_match, expired_match)
                });

        // All default-lifetime keypackages are still handed out...
        assert_eq!(unexpired_match, UNEXPIRED_COUNT);
        // ...and none of the expired ones are.
        assert_eq!(expired_match, 0);
    }

    /// A newly generated keypackage:
    /// - passes standalone validation;
    /// - has no keypackage extensions;
    /// - advertises the default versions, ciphersuites and credentials in its leaf-node
    ///   capabilities.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    async fn new_keypackage_has_correct_extensions(case: TestCase) {
        run_test_with_client_ids(case.clone(), ["alice"], move |[cc]| {
            Box::pin(async move {
                let kps = cc
                    .context
                    .get_or_create_client_keypackages(case.ciphersuite(), case.credential_type, 1)
                    .await
                    .unwrap();
                let kp = kps.first().unwrap();

                // The keypackage must pass openmls standalone validation for MLS 1.0.
                let _ = KeyPackageIn::from(kp.clone())
                    .standalone_validate(&cc.context.mls_provider().await.unwrap(), ProtocolVersion::Mls10, true)
                    .await
                    .unwrap();

                // No keypackage-level extensions are set.
                assert!(kp.extensions().is_empty());

                // Leaf-node capabilities match the default conversation configuration.
                assert_eq!(kp.leaf_node().capabilities().versions(), &[ProtocolVersion::Mls10]);
                assert_eq!(
                    kp.leaf_node().capabilities().ciphersuites().to_vec(),
                    MlsConversationConfiguration::DEFAULT_SUPPORTED_CIPHERSUITES
                        .iter()
                        .map(|c| VerifiableCiphersuite::from(*c))
                        .collect::<Vec<_>>()
                );
                assert!(kp.leaf_node().capabilities().proposals().is_empty());
                assert!(kp.leaf_node().capabilities().extensions().is_empty());
                assert_eq!(
                    kp.leaf_node().capabilities().credentials(),
                    MlsConversationConfiguration::DEFAULT_SUPPORTED_CREDENTIALS
                );
            })
        })
        .await
    }
}