1use openmls::prelude::{Credential, CredentialWithKey, CryptoConfig, KeyPackage, KeyPackageRef, Lifetime};
2use openmls_traits::OpenMlsCryptoProvider;
3use std::collections::HashMap;
4use std::ops::{Deref, DerefMut};
5use tls_codec::{Deserialize, Serialize};
6
7use core_crypto_keystore::{
8 connection::FetchFromDatabase,
9 entities::{EntityFindParams, MlsEncryptionKeyPair, MlsHpkePrivateKey, MlsKeyPackage},
10};
11use mls_crypto_provider::{CryptoKeystore, MlsCryptoProvider};
12
13use super::{Error, Result};
14use crate::{
15 KeystoreError, MlsError, RecursiveError,
16 mls::{credential::CredentialBundle, session::SessionInner},
17 prelude::{MlsCiphersuite, MlsConversationConfiguration, MlsCredentialType, Session},
18 transaction_context::TransactionContext,
19};
20
/// Amount of initial keying material (key packages): 100 in regular builds, 10 when the crate
/// is compiled for tests.
pub const INITIAL_KEYING_MATERIAL_COUNT: usize = if cfg!(test) { 10 } else { 100 };
27
/// Default key package lifetime: three periods of 28 days, i.e. 84 days expressed in seconds.
pub(crate) const KEYPACKAGE_DEFAULT_LIFETIME: std::time::Duration =
    std::time::Duration::from_secs(3 * 28 * 24 * 60 * 60);

impl Session {
33 pub async fn generate_one_keypackage_from_credential_bundle(
41 &self,
42 backend: &MlsCryptoProvider,
43 cs: MlsCiphersuite,
44 cb: &CredentialBundle,
45 ) -> Result<KeyPackage> {
46 match self.inner.read().await.deref() {
47 None => Err(Error::MlsNotInitialized),
48 Some(SessionInner {
49 keypackage_lifetime, ..
50 }) => {
51 let keypackage = KeyPackage::builder()
52 .leaf_node_capabilities(MlsConversationConfiguration::default_leaf_capabilities())
53 .key_package_lifetime(Lifetime::new(keypackage_lifetime.as_secs()))
54 .build(
55 CryptoConfig {
56 ciphersuite: cs.into(),
57 version: openmls::versions::ProtocolVersion::default(),
58 },
59 backend,
60 &cb.signature_key,
61 CredentialWithKey {
62 credential: cb.credential.clone(),
63 signature_key: cb.signature_key.public().into(),
64 },
65 )
66 .await
67 .map_err(KeystoreError::wrap("building keypackage"))?;
68
69 Ok(keypackage)
70 }
71 }
72 }
73
74 pub async fn request_key_packages(
85 &self,
86 count: usize,
87 ciphersuite: MlsCiphersuite,
88 credential_type: MlsCredentialType,
89 backend: &MlsCryptoProvider,
90 ) -> Result<Vec<KeyPackage>> {
91 self.prune_keypackages(backend, &[]).await?;
93 use core_crypto_keystore::CryptoKeystoreMls as _;
94
95 let mut existing_kps = backend
96 .key_store()
97 .mls_fetch_keypackages::<KeyPackage>(count as u32)
98 .await.map_err(KeystoreError::wrap("fetching mls keypackages"))?
99 .into_iter()
100 .filter(|kp|
102 kp.ciphersuite() == ciphersuite.0 && MlsCredentialType::from(kp.leaf_node().credential().credential_type()) == credential_type)
103 .collect::<Vec<_>>();
104
105 let kpb_count = existing_kps.len();
106 let mut kps = if count > kpb_count {
107 let to_generate = count - kpb_count;
108 let cb = self
109 .find_most_recent_credential_bundle(ciphersuite.signature_algorithm(), credential_type)
110 .await?;
111 self.generate_new_keypackages(backend, ciphersuite, &cb, to_generate)
112 .await?
113 } else {
114 vec![]
115 };
116
117 existing_kps.reverse();
118
119 kps.append(&mut existing_kps);
120 Ok(kps)
121 }
122
123 pub(crate) async fn generate_new_keypackages(
124 &self,
125 backend: &MlsCryptoProvider,
126 ciphersuite: MlsCiphersuite,
127 cb: &CredentialBundle,
128 count: usize,
129 ) -> Result<Vec<KeyPackage>> {
130 let mut kps = Vec::with_capacity(count);
131
132 for _ in 0..count {
133 let kp = self
134 .generate_one_keypackage_from_credential_bundle(backend, ciphersuite, cb)
135 .await?;
136 kps.push(kp);
137 }
138
139 Ok(kps)
140 }
141
142 pub async fn valid_keypackages_count(
144 &self,
145 backend: &MlsCryptoProvider,
146 ciphersuite: MlsCiphersuite,
147 credential_type: MlsCredentialType,
148 ) -> Result<usize> {
149 let kps: Vec<MlsKeyPackage> = backend
150 .key_store()
151 .find_all(EntityFindParams::default())
152 .await
153 .map_err(KeystoreError::wrap("finding all key packages"))?;
154
155 let mut valid_count = 0;
156 for kp in kps
157 .into_iter()
158 .map(|kp| core_crypto_keystore::deser::<KeyPackage>(&kp.keypackage))
159 .filter(|kp| {
161 kp.as_ref()
162 .map(|b| b.ciphersuite() == ciphersuite.0 && MlsCredentialType::from(b.leaf_node().credential().credential_type()) == credential_type)
163 .unwrap_or_default()
164 })
165 {
166 let kp = kp.map_err(KeystoreError::wrap("counting valid keypackages"))?;
167 if !Self::is_mls_keypackage_expired(&kp) {
168 valid_count += 1;
169 }
170 }
171
172 Ok(valid_count)
173 }
174
175 fn is_mls_keypackage_expired(kp: &KeyPackage) -> bool {
178 let Some(lifetime) = kp.leaf_node().life_time() else {
179 return false;
180 };
181
182 !(lifetime.has_acceptable_range() && lifetime.is_valid())
183 }
184
185 pub async fn prune_keypackages(&self, backend: &MlsCryptoProvider, refs: &[KeyPackageRef]) -> Result<()> {
191 let keystore = backend.keystore();
192 let kps = self.find_all_keypackages(&keystore).await?;
193 let _ = self._prune_keypackages(&kps, &keystore, refs).await?;
194 Ok(())
195 }
196
197 pub(crate) async fn prune_keypackages_and_credential(
198 &mut self,
199 backend: &MlsCryptoProvider,
200 refs: &[KeyPackageRef],
201 ) -> Result<()> {
202 match self.inner.write().await.deref_mut() {
203 None => Err(Error::MlsNotInitialized),
204 Some(SessionInner { identities, .. }) => {
205 let keystore = backend.key_store();
206 let kps = self.find_all_keypackages(keystore).await?;
207 let kp_to_delete = self._prune_keypackages(&kps, keystore, refs).await?;
208
209 let mut grouped_kps = HashMap::<Vec<u8>, Vec<KeyPackageRef>>::new();
211 for (_, kp) in &kps {
212 let cred = kp
213 .leaf_node()
214 .credential()
215 .tls_serialize_detached()
216 .map_err(Error::tls_serialize("keypackage"))?;
217 let kp_ref = kp
218 .hash_ref(backend.crypto())
219 .map_err(MlsError::wrap("computing keypackage hashref"))?;
220 grouped_kps
221 .entry(cred)
222 .and_modify(|kprfs| kprfs.push(kp_ref.clone()))
223 .or_insert(vec![kp_ref]);
224 }
225
226 for (credential, kps) in &grouped_kps {
227 let all_to_delete = kps.iter().all(|kpr| kp_to_delete.contains(&kpr.as_slice()));
229 if all_to_delete {
230 backend
232 .keystore()
233 .cred_delete_by_credential(credential.clone())
234 .await
235 .map_err(KeystoreError::wrap("deleting credential"))?;
236 let credential = Credential::tls_deserialize(&mut credential.as_slice())
237 .map_err(Error::tls_deserialize("credential"))?;
238 identities.remove(&credential).await?;
239 }
240 }
241 Ok(())
242 }
243 }
244 }
245
246 async fn _prune_keypackages<'a>(
251 &self,
252 kps: &'a [(MlsKeyPackage, KeyPackage)],
253 keystore: &CryptoKeystore,
254 refs: &[KeyPackageRef],
255 ) -> Result<Vec<&'a [u8]>, Error> {
256 let kp_to_delete: Vec<_> = kps
257 .iter()
258 .filter_map(|(store_kp, kp)| {
259 let is_expired = Self::is_mls_keypackage_expired(kp);
260 let mut to_delete = is_expired;
261 if !(is_expired || refs.is_empty()) {
262 to_delete = refs.iter().any(|r| r.as_slice() == store_kp.keypackage_ref);
265 }
266
267 to_delete.then_some((kp, &store_kp.keypackage_ref))
268 })
269 .collect();
270
271 for (kp, kp_ref) in &kp_to_delete {
272 keystore
274 .remove::<MlsKeyPackage, &[u8]>(kp_ref.as_slice())
275 .await
276 .map_err(KeystoreError::wrap("removing key package from keystore"))?;
277 keystore
278 .remove::<MlsHpkePrivateKey, &[u8]>(kp.hpke_init_key().as_slice())
279 .await
280 .map_err(KeystoreError::wrap("removing private key from keystore"))?;
281 keystore
282 .remove::<MlsEncryptionKeyPair, &[u8]>(kp.leaf_node().encryption_key().as_slice())
283 .await
284 .map_err(KeystoreError::wrap("removing encryption keypair from keystore"))?;
285 }
286
287 let kp_to_delete = kp_to_delete
288 .into_iter()
289 .map(|(_, kpref)| &kpref[..])
290 .collect::<Vec<_>>();
291
292 Ok(kp_to_delete)
293 }
294
295 async fn find_all_keypackages(&self, keystore: &CryptoKeystore) -> Result<Vec<(MlsKeyPackage, KeyPackage)>> {
296 let kps: Vec<MlsKeyPackage> = keystore
297 .find_all(EntityFindParams::default())
298 .await
299 .map_err(KeystoreError::wrap("finding all keypackages"))?;
300
301 let kps = kps
302 .into_iter()
303 .map(|raw_kp| -> Result<_> {
304 let kp = core_crypto_keystore::deser::<KeyPackage>(&raw_kp.keypackage)
305 .map_err(KeystoreError::wrap("deserializing keypackage"))?;
306 Ok((raw_kp, kp))
307 })
308 .collect::<Result<Vec<_>, _>>()?;
309
310 Ok(kps)
311 }
312
/// Test-only helper that overwrites the key package lifetime stored in the session's inner
/// state with `duration`.
///
/// # Errors
/// [`Error::MlsNotInitialized`] when the session has no inner state.
#[cfg(test)]
pub async fn set_keypackage_lifetime(&self, duration: std::time::Duration) -> Result<()> {
    use std::ops::DerefMut;

    let mut guard = self.inner.write().await;
    let Some(SessionInner {
        keypackage_lifetime, ..
    }) = guard.deref_mut()
    else {
        return Err(Error::MlsNotInitialized);
    };

    *keypackage_lifetime = duration;
    Ok(())
}
328}
329
330impl TransactionContext {
331 pub async fn get_or_create_client_keypackages(
345 &self,
346 ciphersuite: MlsCiphersuite,
347 credential_type: MlsCredentialType,
348 amount_requested: usize,
349 ) -> Result<Vec<KeyPackage>> {
350 let client = self
351 .session()
352 .await
353 .map_err(RecursiveError::transaction("getting mls client"))?;
354 client
355 .request_key_packages(
356 amount_requested,
357 ciphersuite,
358 credential_type,
359 &self
360 .mls_provider()
361 .await
362 .map_err(RecursiveError::transaction("getting mls provider"))?,
363 )
364 .await
365 }
366
367 #[cfg_attr(test, crate::idempotent)]
369 pub async fn client_valid_key_packages_count(
370 &self,
371 ciphersuite: MlsCiphersuite,
372 credential_type: MlsCredentialType,
373 ) -> Result<usize> {
374 let client = self
375 .session()
376 .await
377 .map_err(RecursiveError::transaction("getting mls client"))?;
378 client
379 .valid_keypackages_count(
380 &self
381 .mls_provider()
382 .await
383 .map_err(RecursiveError::transaction("getting mls provider"))?,
384 ciphersuite,
385 credential_type,
386 )
387 .await
388 }
389
390 #[cfg_attr(test, crate::dispotent)]
393 pub async fn delete_keypackages(&self, refs: &[KeyPackageRef]) -> Result<()> {
394 if refs.is_empty() {
395 return Err(Error::EmptyKeypackageList);
396 }
397 let mut client = self
398 .session()
399 .await
400 .map_err(RecursiveError::transaction("getting mls client"))?;
401 client
402 .prune_keypackages_and_credential(
403 &self
404 .mls_provider()
405 .await
406 .map_err(RecursiveError::transaction("getting mls provider"))?,
407 refs,
408 )
409 .await
410 }
411}
412
413#[cfg(test)]
414mod tests {
415 use openmls::prelude::{KeyPackage, KeyPackageIn, KeyPackageRef, ProtocolVersion};
416 use openmls_traits::OpenMlsCryptoProvider;
417 use openmls_traits::types::VerifiableCiphersuite;
418 use wasm_bindgen_test::*;
419
420 use mls_crypto_provider::MlsCryptoProvider;
421
422 use crate::e2e_identity::enrollment::test_utils::{e2ei_enrollment, init_activation_or_rotation, noop_restore};
423 use crate::prelude::MlsConversationConfiguration;
424 use crate::prelude::key_package::INITIAL_KEYING_MATERIAL_COUNT;
425 use crate::test_utils::*;
426 use core_crypto_keystore::DatabaseKey;
427
428 use super::Session;
429
430 wasm_bindgen_test_configure!(run_in_browser);
431
432 #[apply(all_cred_cipher)]
433 #[wasm_bindgen_test]
434 async fn can_assess_keypackage_expiration(case: TestCase) {
435 run_test_with_central(case.clone(), move |[central]| {
436 Box::pin(async move {
437 let (cs, ct) = (case.ciphersuite(), case.credential_type);
438 let key = DatabaseKey::generate();
439 let backend = MlsCryptoProvider::try_new_in_memory(&key).await.unwrap();
440 let x509_test_chain = if case.is_x509() {
441 let x509_test_chain = crate::test_utils::x509::X509TestChain::init_empty(case.signature_scheme());
442 x509_test_chain.register_with_provider(&backend).await;
443 Some(x509_test_chain)
444 } else {
445 None
446 };
447
448 backend.new_transaction().await.unwrap();
449 let client = central.session;
450 client
451 .random_generate(
452 &case,
453 x509_test_chain.as_ref().map(|chain| chain.find_local_intermediate_ca()),
454 false,
455 )
456 .await
457 .unwrap();
458
459 let kp_std_exp = client.generate_one_keypackage(&backend, cs, ct).await.unwrap();
461 assert!(!Session::is_mls_keypackage_expired(&kp_std_exp));
462
463 client
465 .set_keypackage_lifetime(std::time::Duration::from_secs(1))
466 .await
467 .unwrap();
468 let kp_1s_exp = client.generate_one_keypackage(&backend, cs, ct).await.unwrap();
469 async_std::task::sleep(std::time::Duration::from_secs(2)).await;
471 assert!(Session::is_mls_keypackage_expired(&kp_1s_exp));
472 })
473 })
474 .await;
475 }
476
477 #[apply(all_cred_cipher)]
478 #[wasm_bindgen_test]
479 async fn requesting_x509_key_packages_after_basic(case: TestCase) {
480 if !case.is_basic() {
482 return;
483 }
484 run_test_with_client_ids(case.clone(), ["alice"], move |[mut session_context]| {
485 Box::pin(async move {
486 let signature_scheme = case.signature_scheme();
487 let cipher_suite = case.ciphersuite();
488
489 let _basic_key_packages = session_context
491 .context
492 .get_or_create_client_keypackages(cipher_suite, MlsCredentialType::Basic, 5)
493 .await
494 .unwrap();
495
496 let test_chain = x509::X509TestChain::init_for_random_clients(signature_scheme, 1);
498
499 let (mut enrollment, cert_chain) = e2ei_enrollment(
500 &mut session_context,
501 &case,
502 &test_chain,
503 None,
504 false,
505 init_activation_or_rotation,
506 noop_restore,
507 )
508 .await
509 .unwrap();
510
511 let _rotate_bundle = session_context
512 .context
513 .save_x509_credential(&mut enrollment, cert_chain)
514 .await
515 .unwrap();
516
517 assert!(session_context.context.e2ei_is_enabled(signature_scheme).await.unwrap());
519
520 let x509_key_packages = session_context
522 .context
523 .get_or_create_client_keypackages(cipher_suite, MlsCredentialType::X509, 5)
524 .await
525 .unwrap();
526
527 assert!(x509_key_packages.iter().all(|kp| MlsCredentialType::X509
529 == MlsCredentialType::from(kp.leaf_node().credential().credential_type())));
530 })
531 })
532 .await
533 }
534
535 #[apply(all_cred_cipher)]
536 #[wasm_bindgen_test]
537 async fn generates_correct_number_of_kpbs(case: TestCase) {
538 run_test_with_client_ids(case.clone(), ["alice"], move |[cc]| {
539 Box::pin(async move {
540 const N: usize = 2;
541 const COUNT: usize = 109;
542
543 let init = cc.context.count_entities().await;
544 assert_eq!(init.key_package, INITIAL_KEYING_MATERIAL_COUNT);
545 assert_eq!(init.encryption_keypair, INITIAL_KEYING_MATERIAL_COUNT);
546 assert_eq!(init.hpke_private_key, INITIAL_KEYING_MATERIAL_COUNT);
547 assert_eq!(init.credential, 1);
548 assert_eq!(init.signature_keypair, 1);
549
550 let transactional_provider = cc.context.mls_provider().await.unwrap();
553 let crypto_provider = transactional_provider.crypto();
554 let mut pinned_kp = None;
555
556 let mut prev_kps: Option<Vec<KeyPackage>> = None;
557 for _ in 0..N {
558 let mut kps = cc
559 .context
560 .get_or_create_client_keypackages(case.ciphersuite(), case.credential_type, COUNT + 1)
561 .await
562 .unwrap();
563
564 pinned_kp = Some(kps.pop().unwrap());
566
567 assert_eq!(kps.len(), COUNT);
568 let after_creation = cc.context.count_entities().await;
569 assert_eq!(after_creation.key_package, COUNT + 1);
570 assert_eq!(after_creation.encryption_keypair, COUNT + 1);
571 assert_eq!(after_creation.hpke_private_key, COUNT + 1);
572 assert_eq!(after_creation.credential, 1);
573
574 let kpbs_refs = kps
575 .iter()
576 .map(|kp| kp.hash_ref(crypto_provider).unwrap())
577 .collect::<Vec<KeyPackageRef>>();
578
579 if let Some(pkpbs) = prev_kps.replace(kps) {
580 let pkpbs_refs = pkpbs
581 .into_iter()
582 .map(|kpb| kpb.hash_ref(crypto_provider).unwrap())
583 .collect::<Vec<KeyPackageRef>>();
584
585 let has_duplicates = kpbs_refs.iter().any(|href| pkpbs_refs.contains(href));
586 assert!(!has_duplicates);
588 }
589 cc.context.delete_keypackages(&kpbs_refs).await.unwrap();
590 }
591
592 let count = cc
593 .context
594 .client_valid_key_packages_count(case.ciphersuite(), case.credential_type)
595 .await
596 .unwrap();
597 assert_eq!(count, 1);
598
599 let pinned_kpr = pinned_kp.unwrap().hash_ref(crypto_provider).unwrap();
600 cc.context.delete_keypackages(&[pinned_kpr]).await.unwrap();
601 let count = cc
602 .context
603 .client_valid_key_packages_count(case.ciphersuite(), case.credential_type)
604 .await
605 .unwrap();
606 assert_eq!(count, 0);
607 let after_delete = cc.context.count_entities().await;
608 assert_eq!(after_delete.key_package, 0);
609 assert_eq!(after_delete.encryption_keypair, 0);
610 assert_eq!(after_delete.hpke_private_key, 0);
611 assert_eq!(after_delete.credential, 0);
612 })
613 })
614 .await
615 }
616
617 #[apply(all_cred_cipher)]
618 #[wasm_bindgen_test]
619 async fn automatically_prunes_lifetime_expired_keypackages(case: TestCase) {
620 run_test_with_central(case.clone(), move |[central]| {
621 Box::pin(async move {
622 const UNEXPIRED_COUNT: usize = 125;
623 const EXPIRED_COUNT: usize = 200;
624 let key = DatabaseKey::generate();
625 let backend = MlsCryptoProvider::try_new_in_memory(&key).await.unwrap();
626 let x509_test_chain = if case.is_x509() {
627 let x509_test_chain = crate::test_utils::x509::X509TestChain::init_empty(case.signature_scheme());
628 x509_test_chain.register_with_provider(&backend).await;
629 Some(x509_test_chain)
630 } else {
631 None
632 };
633 backend.new_transaction().await.unwrap();
634 let client = central.session().await;
635 client
636 .random_generate(
637 &case,
638 x509_test_chain.as_ref().map(|chain| chain.find_local_intermediate_ca()),
639 false,
640 )
641 .await
642 .unwrap();
643
644 let unexpired_kpbs = client
646 .request_key_packages(UNEXPIRED_COUNT, case.ciphersuite(), case.credential_type, &backend)
647 .await
648 .unwrap();
649 let len = client
650 .valid_keypackages_count(&backend, case.ciphersuite(), case.credential_type)
651 .await
652 .unwrap();
653 assert_eq!(len, unexpired_kpbs.len());
654 assert_eq!(len, UNEXPIRED_COUNT);
655
656 client
658 .set_keypackage_lifetime(std::time::Duration::from_secs(10))
659 .await
660 .unwrap();
661
662 let partially_expired_kpbs = client
664 .request_key_packages(EXPIRED_COUNT, case.ciphersuite(), case.credential_type, &backend)
665 .await
666 .unwrap();
667 assert_eq!(partially_expired_kpbs.len(), EXPIRED_COUNT);
668
669 async_std::task::sleep(std::time::Duration::from_secs(10)).await;
671
672 let fresh_kpbs = client
675 .request_key_packages(EXPIRED_COUNT, case.ciphersuite(), case.credential_type, &backend)
676 .await
677 .unwrap();
678 let len = client
679 .valid_keypackages_count(&backend, case.ciphersuite(), case.credential_type)
680 .await
681 .unwrap();
682 assert_eq!(len, fresh_kpbs.len());
683 assert_eq!(len, EXPIRED_COUNT);
684
685 let (unexpired_match, expired_match) =
687 fresh_kpbs
688 .iter()
689 .fold((0usize, 0usize), |(mut unexpired_match, mut expired_match), fresh| {
690 if unexpired_kpbs.iter().any(|kp| kp == fresh) {
691 unexpired_match += 1;
692 } else if partially_expired_kpbs.iter().any(|kpb| kpb == fresh) {
693 expired_match += 1;
694 }
695
696 (unexpired_match, expired_match)
697 });
698
699 assert_eq!(unexpired_match, UNEXPIRED_COUNT);
701 assert_eq!(expired_match, 0);
702 })
703 })
704 .await
705 }
706
707 #[apply(all_cred_cipher)]
708 #[wasm_bindgen_test]
709 async fn new_keypackage_has_correct_extensions(case: TestCase) {
710 run_test_with_client_ids(case.clone(), ["alice"], move |[cc]| {
711 Box::pin(async move {
712 let kps = cc
713 .context
714 .get_or_create_client_keypackages(case.ciphersuite(), case.credential_type, 1)
715 .await
716 .unwrap();
717 let kp = kps.first().unwrap();
718
719 let _ = KeyPackageIn::from(kp.clone())
721 .standalone_validate(&cc.context.mls_provider().await.unwrap(), ProtocolVersion::Mls10, true)
722 .await
723 .unwrap();
724
725 assert!(kp.extensions().is_empty());
727
728 assert_eq!(kp.leaf_node().capabilities().versions(), &[ProtocolVersion::Mls10]);
729 assert_eq!(
730 kp.leaf_node().capabilities().ciphersuites().to_vec(),
731 MlsConversationConfiguration::DEFAULT_SUPPORTED_CIPHERSUITES
732 .iter()
733 .map(|c| VerifiableCiphersuite::from(*c))
734 .collect::<Vec<_>>()
735 );
736 assert!(kp.leaf_node().capabilities().proposals().is_empty());
737 assert!(kp.leaf_node().capabilities().extensions().is_empty());
738 assert_eq!(
739 kp.leaf_node().capabilities().credentials(),
740 MlsConversationConfiguration::DEFAULT_SUPPORTED_CREDENTIALS
741 );
742 })
743 })
744 .await
745 }
746}