1use openmls::prelude::{Credential, CredentialWithKey, CryptoConfig, KeyPackage, KeyPackageRef, Lifetime};
18use openmls_traits::OpenMlsCryptoProvider;
19use std::collections::HashMap;
20use std::ops::{Deref, DerefMut};
21use tls_codec::{Deserialize, Serialize};
22
23use core_crypto_keystore::{
24 connection::FetchFromDatabase,
25 entities::{EntityFindParams, MlsEncryptionKeyPair, MlsHpkePrivateKey, MlsKeyPackage},
26};
27use mls_crypto_provider::{CryptoKeystore, MlsCryptoProvider};
28
29use crate::context::CentralContext;
30use crate::mls::client::ClientInner;
31use crate::{
32 mls::credential::CredentialBundle,
33 prelude::{
34 Client, CryptoError, CryptoResult, MlsCiphersuite, MlsConversationConfiguration, MlsCredentialType, MlsError,
35 },
36};
37
/// Initial keying material count: 100 in regular builds, lowered to 10 when the crate is
/// compiled with `cfg(test)`.
pub const INITIAL_KEYING_MATERIAL_COUNT: usize = if cfg!(test) { 10 } else { 100 };
43
/// Default key package lifetime: 3 x 28 days (84 days), expressed in seconds.
pub(crate) const KEYPACKAGE_DEFAULT_LIFETIME: std::time::Duration =
    std::time::Duration::from_secs(3 * 28 * 24 * 60 * 60);

impl Client {
49 pub async fn generate_one_keypackage_from_credential_bundle(
57 &self,
58 backend: &MlsCryptoProvider,
59 cs: MlsCiphersuite,
60 cb: &CredentialBundle,
61 ) -> CryptoResult<KeyPackage> {
62 match self.state.read().await.deref() {
63 None => Err(CryptoError::MlsNotInitialized),
64 Some(ClientInner {
65 keypackage_lifetime, ..
66 }) => {
67 let keypackage = KeyPackage::builder()
68 .leaf_node_capabilities(MlsConversationConfiguration::default_leaf_capabilities())
69 .key_package_lifetime(Lifetime::new(keypackage_lifetime.as_secs()))
70 .build(
71 CryptoConfig {
72 ciphersuite: cs.into(),
73 version: openmls::versions::ProtocolVersion::default(),
74 },
75 backend,
76 &cb.signature_key,
77 CredentialWithKey {
78 credential: cb.credential.clone(),
79 signature_key: cb.signature_key.public().into(),
80 },
81 )
82 .await
83 .map_err(MlsError::from)?;
84
85 Ok(keypackage)
86 }
87 }
88 }
89
90 pub async fn request_key_packages(
101 &self,
102 count: usize,
103 ciphersuite: MlsCiphersuite,
104 credential_type: MlsCredentialType,
105 backend: &MlsCryptoProvider,
106 ) -> CryptoResult<Vec<KeyPackage>> {
107 self.prune_keypackages(backend, &[]).await?;
109 use core_crypto_keystore::CryptoKeystoreMls as _;
110
111 let mut existing_kps = backend
112 .key_store()
113 .mls_fetch_keypackages::<KeyPackage>(count as u32)
114 .await?
115 .into_iter()
116 .filter(|kp|
118 kp.ciphersuite() == ciphersuite.0 && MlsCredentialType::from(kp.leaf_node().credential().credential_type()) == credential_type)
119 .collect::<Vec<_>>();
120
121 let kpb_count = existing_kps.len();
122 let mut kps = if count > kpb_count {
123 let to_generate = count - kpb_count;
124 let cb = self
125 .find_most_recent_credential_bundle(ciphersuite.signature_algorithm(), credential_type)
126 .await?;
127 self.generate_new_keypackages(backend, ciphersuite, &cb, to_generate)
128 .await?
129 } else {
130 vec![]
131 };
132
133 existing_kps.reverse();
134
135 kps.append(&mut existing_kps);
136 Ok(kps)
137 }
138
139 pub(crate) async fn generate_new_keypackages(
140 &self,
141 backend: &MlsCryptoProvider,
142 ciphersuite: MlsCiphersuite,
143 cb: &CredentialBundle,
144 count: usize,
145 ) -> CryptoResult<Vec<KeyPackage>> {
146 let mut kps = Vec::with_capacity(count);
147
148 for _ in 0..count {
149 let kp = self
150 .generate_one_keypackage_from_credential_bundle(backend, ciphersuite, cb)
151 .await?;
152 kps.push(kp);
153 }
154
155 Ok(kps)
156 }
157
158 pub async fn valid_keypackages_count(
160 &self,
161 backend: &MlsCryptoProvider,
162 ciphersuite: MlsCiphersuite,
163 credential_type: MlsCredentialType,
164 ) -> CryptoResult<usize> {
165 let kps: Vec<MlsKeyPackage> = backend.key_store().find_all(EntityFindParams::default()).await?;
166
167 let valid_count = kps
168 .into_iter()
169 .map(|kp| core_crypto_keystore::deser::<KeyPackage>(&kp.keypackage))
170 .filter(|kp| {
172 kp.as_ref()
173 .map(|b| b.ciphersuite() == ciphersuite.0 && MlsCredentialType::from(b.leaf_node().credential().credential_type()) == credential_type)
174 .unwrap_or_default()
175 })
176 .try_fold(0usize, |mut valid_count, kp| {
177 if !Self::is_mls_keypackage_expired(&kp?) {
178 valid_count += 1;
179 }
180 CryptoResult::Ok(valid_count)
181 })?;
182
183 Ok(valid_count)
184 }
185
186 fn is_mls_keypackage_expired(kp: &KeyPackage) -> bool {
189 let Some(lifetime) = kp.leaf_node().life_time() else {
190 return false;
191 };
192
193 !(lifetime.has_acceptable_range() && lifetime.is_valid())
194 }
195
196 pub async fn prune_keypackages(&self, backend: &MlsCryptoProvider, refs: &[KeyPackageRef]) -> CryptoResult<()> {
202 let keystore = backend.keystore();
203 let kps = self.find_all_keypackages(&keystore).await?;
204 let _ = self._prune_keypackages(&kps, &keystore, refs).await?;
205 Ok(())
206 }
207
208 pub(crate) async fn prune_keypackages_and_credential(
209 &mut self,
210 backend: &MlsCryptoProvider,
211 refs: &[KeyPackageRef],
212 ) -> CryptoResult<()> {
213 match self.state.write().await.deref_mut() {
214 None => Err(CryptoError::MlsNotInitialized),
215 Some(ClientInner { identities, .. }) => {
216 let keystore = backend.key_store();
217 let kps = self.find_all_keypackages(keystore).await?;
218 let kp_to_delete = self._prune_keypackages(&kps, keystore, refs).await?;
219
220 let mut grouped_kps = HashMap::<Vec<u8>, Vec<KeyPackageRef>>::new();
222 for (_, kp) in &kps {
223 let cred = kp
224 .leaf_node()
225 .credential()
226 .tls_serialize_detached()
227 .map_err(MlsError::from)?;
228 let kp_ref = kp.hash_ref(backend.crypto()).map_err(MlsError::from)?;
229 grouped_kps
230 .entry(cred)
231 .and_modify(|kprfs| kprfs.push(kp_ref.clone()))
232 .or_insert(vec![kp_ref]);
233 }
234
235 for (credential, kps) in &grouped_kps {
236 let all_to_delete = kps.iter().all(|kpr| kp_to_delete.contains(&kpr.as_slice()));
238 if all_to_delete {
239 backend.keystore().cred_delete_by_credential(credential.clone()).await?;
241 let credential =
242 Credential::tls_deserialize(&mut credential.as_slice()).map_err(MlsError::from)?;
243 identities.remove(&credential).await?;
244 }
245 }
246 Ok(())
247 }
248 }
249 }
250
251 async fn _prune_keypackages<'a>(
256 &self,
257 kps: &'a [(MlsKeyPackage, KeyPackage)],
258 keystore: &CryptoKeystore,
259 refs: &[KeyPackageRef],
260 ) -> Result<Vec<&'a [u8]>, CryptoError> {
261 let kp_to_delete: Vec<_> = kps
262 .iter()
263 .filter_map(|(store_kp, kp)| {
264 let is_expired = Self::is_mls_keypackage_expired(kp);
265 let mut to_delete = is_expired;
266 if !(is_expired || refs.is_empty()) {
267 to_delete = refs.iter().any(|r| r.as_slice() == store_kp.keypackage_ref);
270 }
271
272 to_delete.then_some((kp, &store_kp.keypackage_ref))
273 })
274 .collect();
275
276 for (kp, kp_ref) in &kp_to_delete {
277 keystore.remove::<MlsKeyPackage, &[u8]>(kp_ref.as_slice()).await?;
279 keystore
280 .remove::<MlsHpkePrivateKey, &[u8]>(kp.hpke_init_key().as_slice())
281 .await?;
282 keystore
283 .remove::<MlsEncryptionKeyPair, &[u8]>(kp.leaf_node().encryption_key().as_slice())
284 .await?;
285 }
286
287 let kp_to_delete = kp_to_delete
288 .into_iter()
289 .map(|(_, kpref)| &kpref[..])
290 .collect::<Vec<_>>();
291
292 Ok(kp_to_delete)
293 }
294
295 async fn find_all_keypackages(&self, keystore: &CryptoKeystore) -> CryptoResult<Vec<(MlsKeyPackage, KeyPackage)>> {
296 let kps: Vec<MlsKeyPackage> = keystore.find_all(EntityFindParams::default()).await?;
297
298 let kps = kps.into_iter().try_fold(vec![], |mut acc, raw_kp| {
299 let kp = core_crypto_keystore::deser::<KeyPackage>(&raw_kp.keypackage)?;
300 acc.push((raw_kp, kp));
301 CryptoResult::Ok(acc)
302 })?;
303
304 Ok(kps)
305 }
306
/// Test-only helper overriding the key package lifetime held in the client state.
///
/// # Errors
/// [`CryptoError::MlsNotInitialized`] when the client state is `None`.
#[cfg(test)]
pub async fn set_keypackage_lifetime(&self, duration: std::time::Duration) -> CryptoResult<()> {
    use std::ops::DerefMut;

    let mut state = self.state.write().await;
    let inner = state.deref_mut().as_mut().ok_or(CryptoError::MlsNotInitialized)?;
    inner.keypackage_lifetime = duration;
    Ok(())
}
323}
324
325impl CentralContext {
326 pub async fn get_or_create_client_keypackages(
340 &self,
341 ciphersuite: MlsCiphersuite,
342 credential_type: MlsCredentialType,
343 amount_requested: usize,
344 ) -> CryptoResult<Vec<KeyPackage>> {
345 let client = self.mls_client().await?;
346 client
347 .request_key_packages(
348 amount_requested,
349 ciphersuite,
350 credential_type,
351 &self.mls_provider().await?,
352 )
353 .await
354 }
355
356 #[cfg_attr(test, crate::idempotent)]
358 pub async fn client_valid_key_packages_count(
359 &self,
360 ciphersuite: MlsCiphersuite,
361 credential_type: MlsCredentialType,
362 ) -> CryptoResult<usize> {
363 let client = self.mls_client().await?;
364 client
365 .valid_keypackages_count(&self.mls_provider().await?, ciphersuite, credential_type)
366 .await
367 }
368
369 #[cfg_attr(test, crate::dispotent)]
372 pub async fn delete_keypackages(&self, refs: &[KeyPackageRef]) -> CryptoResult<()> {
373 if refs.is_empty() {
374 return Err(CryptoError::ConsumerError);
375 }
376 let mut client = self.mls_client().await?;
377 client
378 .prune_keypackages_and_credential(&self.mls_provider().await?, refs)
379 .await
380 }
381}
382
383#[cfg(test)]
384mod tests {
385 use openmls::prelude::{KeyPackage, KeyPackageIn, KeyPackageRef, ProtocolVersion};
386 use openmls_traits::types::VerifiableCiphersuite;
387 use openmls_traits::OpenMlsCryptoProvider;
388 use wasm_bindgen_test::*;
389
390 use mls_crypto_provider::MlsCryptoProvider;
391
392 use crate::e2e_identity::tests::{e2ei_enrollment, init_activation_or_rotation, noop_restore};
393 use crate::prelude::key_package::INITIAL_KEYING_MATERIAL_COUNT;
394 use crate::prelude::MlsConversationConfiguration;
395 use crate::test_utils::*;
396
397 use super::Client;
398
399 wasm_bindgen_test_configure!(run_in_browser);
400
401 #[apply(all_cred_cipher)]
402 #[wasm_bindgen_test]
403 async fn can_assess_keypackage_expiration(case: TestCase) {
404 let (cs, ct) = (case.ciphersuite(), case.credential_type);
405 let backend = MlsCryptoProvider::try_new_in_memory("test").await.unwrap();
406 let x509_test_chain = if case.is_x509() {
407 let x509_test_chain = crate::test_utils::x509::X509TestChain::init_empty(case.signature_scheme());
408 x509_test_chain.register_with_provider(&backend).await;
409 Some(x509_test_chain)
410 } else {
411 None
412 };
413
414 backend.new_transaction().await.unwrap();
415 let client = Client::random_generate(
416 &case,
417 &backend,
418 x509_test_chain.as_ref().map(|chain| chain.find_local_intermediate_ca()),
419 false,
420 )
421 .await
422 .unwrap();
423
424 let kp_std_exp = client.generate_one_keypackage(&backend, cs, ct).await.unwrap();
426 assert!(!Client::is_mls_keypackage_expired(&kp_std_exp));
427
428 client
430 .set_keypackage_lifetime(std::time::Duration::from_secs(1))
431 .await
432 .unwrap();
433 let kp_1s_exp = client.generate_one_keypackage(&backend, cs, ct).await.unwrap();
434 async_std::task::sleep(std::time::Duration::from_secs(2)).await;
436 assert!(Client::is_mls_keypackage_expired(&kp_1s_exp));
437 }
438
439 #[apply(all_cred_cipher)]
440 #[wasm_bindgen_test]
441 async fn requesting_x509_key_packages_after_basic(case: TestCase) {
442 if !case.is_basic() {
444 return;
445 }
446 run_test_with_client_ids(case.clone(), ["alice"], move |[mut client_context]| {
447 Box::pin(async move {
448 let signature_scheme = case.signature_scheme();
449 let cipher_suite = case.ciphersuite();
450
451 let _basic_key_packages = client_context
453 .context
454 .get_or_create_client_keypackages(cipher_suite, MlsCredentialType::Basic, 5)
455 .await
456 .unwrap();
457
458 let test_chain = x509::X509TestChain::init_for_random_clients(signature_scheme, 1);
460
461 let (mut enrollment, cert_chain) = e2ei_enrollment(
462 &mut client_context,
463 &case,
464 &test_chain,
465 None,
466 false,
467 init_activation_or_rotation,
468 noop_restore,
469 )
470 .await
471 .unwrap();
472
473 let _rotate_bundle = client_context
474 .context
475 .e2ei_rotate_all(&mut enrollment, cert_chain, 5)
476 .await
477 .unwrap();
478
479 assert!(client_context.context.e2ei_is_enabled(signature_scheme).await.unwrap());
481
482 let x509_key_packages = client_context
484 .context
485 .get_or_create_client_keypackages(cipher_suite, MlsCredentialType::X509, 5)
486 .await
487 .unwrap();
488
489 assert!(x509_key_packages.iter().all(|kp| MlsCredentialType::X509
491 == MlsCredentialType::from(kp.leaf_node().credential().credential_type())));
492 })
493 })
494 .await
495 }
496
497 #[apply(all_cred_cipher)]
498 #[wasm_bindgen_test]
499 async fn generates_correct_number_of_kpbs(case: TestCase) {
500 run_test_with_client_ids(case.clone(), ["alice"], move |[cc]| {
501 Box::pin(async move {
502 const N: usize = 2;
503 const COUNT: usize = 109;
504
505 let init = cc.context.count_entities().await;
506 assert_eq!(init.key_package, INITIAL_KEYING_MATERIAL_COUNT);
507 assert_eq!(init.encryption_keypair, INITIAL_KEYING_MATERIAL_COUNT);
508 assert_eq!(init.hpke_private_key, INITIAL_KEYING_MATERIAL_COUNT);
509 assert_eq!(init.credential, 1);
510 assert_eq!(init.signature_keypair, 1);
511
512 let transactional_provider = cc.context.mls_provider().await.unwrap();
515 let crypto_provider = transactional_provider.crypto();
516 let mut pinned_kp = None;
517
518 let mut prev_kps: Option<Vec<KeyPackage>> = None;
519 for _ in 0..N {
520 let mut kps = cc
521 .context
522 .get_or_create_client_keypackages(case.ciphersuite(), case.credential_type, COUNT + 1)
523 .await
524 .unwrap();
525
526 pinned_kp = Some(kps.pop().unwrap());
528
529 assert_eq!(kps.len(), COUNT);
530 let after_creation = cc.context.count_entities().await;
531 assert_eq!(after_creation.key_package, COUNT + 1);
532 assert_eq!(after_creation.encryption_keypair, COUNT + 1);
533 assert_eq!(after_creation.hpke_private_key, COUNT + 1);
534 assert_eq!(after_creation.credential, 1);
535
536 let kpbs_refs = kps
537 .iter()
538 .map(|kp| kp.hash_ref(crypto_provider).unwrap())
539 .collect::<Vec<KeyPackageRef>>();
540
541 if let Some(pkpbs) = prev_kps.replace(kps) {
542 let pkpbs_refs = pkpbs
543 .into_iter()
544 .map(|kpb| kpb.hash_ref(crypto_provider).unwrap())
545 .collect::<Vec<KeyPackageRef>>();
546
547 let has_duplicates = kpbs_refs.iter().any(|href| pkpbs_refs.contains(href));
548 assert!(!has_duplicates);
550 }
551 cc.context.delete_keypackages(&kpbs_refs).await.unwrap();
552 }
553
554 let count = cc
555 .context
556 .client_valid_key_packages_count(case.ciphersuite(), case.credential_type)
557 .await
558 .unwrap();
559 assert_eq!(count, 1);
560
561 let pinned_kpr = pinned_kp.unwrap().hash_ref(crypto_provider).unwrap();
562 cc.context.delete_keypackages(&[pinned_kpr]).await.unwrap();
563 let count = cc
564 .context
565 .client_valid_key_packages_count(case.ciphersuite(), case.credential_type)
566 .await
567 .unwrap();
568 assert_eq!(count, 0);
569 let after_delete = cc.context.count_entities().await;
570 assert_eq!(after_delete.key_package, 0);
571 assert_eq!(after_delete.encryption_keypair, 0);
572 assert_eq!(after_delete.hpke_private_key, 0);
573 assert_eq!(after_delete.credential, 0);
574 })
575 })
576 .await
577 }
578
579 #[apply(all_cred_cipher)]
580 #[wasm_bindgen_test]
581 async fn automatically_prunes_lifetime_expired_keypackages(case: TestCase) {
582 const UNEXPIRED_COUNT: usize = 125;
583 const EXPIRED_COUNT: usize = 200;
584 let backend = MlsCryptoProvider::try_new_in_memory("test").await.unwrap();
585 let x509_test_chain = if case.is_x509() {
586 let x509_test_chain = crate::test_utils::x509::X509TestChain::init_empty(case.signature_scheme());
587 x509_test_chain.register_with_provider(&backend).await;
588 Some(x509_test_chain)
589 } else {
590 None
591 };
592 backend.new_transaction().await.unwrap();
593 let client = Client::random_generate(
594 &case,
595 &backend,
596 x509_test_chain.as_ref().map(|chain| chain.find_local_intermediate_ca()),
597 false,
598 )
599 .await
600 .unwrap();
601
602 let unexpired_kpbs = client
604 .request_key_packages(UNEXPIRED_COUNT, case.ciphersuite(), case.credential_type, &backend)
605 .await
606 .unwrap();
607 let len = client
608 .valid_keypackages_count(&backend, case.ciphersuite(), case.credential_type)
609 .await
610 .unwrap();
611 assert_eq!(len, unexpired_kpbs.len());
612 assert_eq!(len, UNEXPIRED_COUNT);
613
614 client
616 .set_keypackage_lifetime(std::time::Duration::from_secs(10))
617 .await
618 .unwrap();
619
620 let partially_expired_kpbs = client
622 .request_key_packages(EXPIRED_COUNT, case.ciphersuite(), case.credential_type, &backend)
623 .await
624 .unwrap();
625 assert_eq!(partially_expired_kpbs.len(), EXPIRED_COUNT);
626
627 async_std::task::sleep(std::time::Duration::from_secs(10)).await;
629
630 let fresh_kpbs = client
633 .request_key_packages(EXPIRED_COUNT, case.ciphersuite(), case.credential_type, &backend)
634 .await
635 .unwrap();
636 let len = client
637 .valid_keypackages_count(&backend, case.ciphersuite(), case.credential_type)
638 .await
639 .unwrap();
640 assert_eq!(len, fresh_kpbs.len());
641 assert_eq!(len, EXPIRED_COUNT);
642
643 let (unexpired_match, expired_match) =
645 fresh_kpbs
646 .iter()
647 .fold((0usize, 0usize), |(mut unexpired_match, mut expired_match), fresh| {
648 if unexpired_kpbs.iter().any(|kp| kp == fresh) {
649 unexpired_match += 1;
650 } else if partially_expired_kpbs.iter().any(|kpb| kpb == fresh) {
651 expired_match += 1;
652 }
653
654 (unexpired_match, expired_match)
655 });
656
657 assert_eq!(unexpired_match, UNEXPIRED_COUNT);
659 assert_eq!(expired_match, 0);
660 }
661
662 #[apply(all_cred_cipher)]
663 #[wasm_bindgen_test]
664 async fn new_keypackage_has_correct_extensions(case: TestCase) {
665 run_test_with_client_ids(case.clone(), ["alice"], move |[cc]| {
666 Box::pin(async move {
667 let kps = cc
668 .context
669 .get_or_create_client_keypackages(case.ciphersuite(), case.credential_type, 1)
670 .await
671 .unwrap();
672 let kp = kps.first().unwrap();
673
674 let _ = KeyPackageIn::from(kp.clone())
676 .standalone_validate(&cc.context.mls_provider().await.unwrap(), ProtocolVersion::Mls10, true)
677 .await
678 .unwrap();
679
680 assert!(kp.extensions().is_empty());
682
683 assert_eq!(kp.leaf_node().capabilities().versions(), &[ProtocolVersion::Mls10]);
684 assert_eq!(
685 kp.leaf_node().capabilities().ciphersuites().to_vec(),
686 MlsConversationConfiguration::DEFAULT_SUPPORTED_CIPHERSUITES
687 .iter()
688 .map(|c| VerifiableCiphersuite::from(*c))
689 .collect::<Vec<_>>()
690 );
691 assert!(kp.leaf_node().capabilities().proposals().is_empty());
692 assert!(kp.leaf_node().capabilities().extensions().is_empty());
693 assert_eq!(
694 kp.leaf_node().capabilities().credentials(),
695 MlsConversationConfiguration::DEFAULT_SUPPORTED_CREDENTIALS
696 );
697 })
698 })
699 .await
700 }
701}