1use openmls::prelude::{Credential, CredentialWithKey, CryptoConfig, KeyPackage, KeyPackageRef, Lifetime};
2use openmls_traits::OpenMlsCryptoProvider;
3use std::collections::{HashMap, HashSet};
4use tls_codec::{Deserialize, Serialize};
5
6use core_crypto_keystore::{
7 connection::FetchFromDatabase,
8 entities::{EntityFindParams, MlsEncryptionKeyPair, MlsHpkePrivateKey, MlsKeyPackage},
9};
10use mls_crypto_provider::{CryptoKeystore, MlsCryptoProvider};
11
12use super::{Error, Result};
13use crate::{
14 KeystoreError, MlsError,
15 mls::{credential::CredentialBundle, session::SessionInner},
16 prelude::{MlsCiphersuite, MlsConversationConfiguration, MlsCredentialType, Session},
17};
18
/// Number of key packages a newly initialized session starts with.
///
/// The `generates_correct_number_of_kpbs` test asserts that a fresh session's keystore holds exactly
/// this many key packages, HPKE private keys and encryption keypairs.
#[cfg(not(test))]
pub const INITIAL_KEYING_MATERIAL_COUNT: usize = 100;
/// Test-build value of `INITIAL_KEYING_MATERIAL_COUNT`; smaller, presumably to keep test setup fast.
#[cfg(test)]
pub const INITIAL_KEYING_MATERIAL_COUNT: usize = 10;
25
26pub(crate) const KEYPACKAGE_DEFAULT_LIFETIME: std::time::Duration =
28 std::time::Duration::from_secs(60 * 60 * 24 * 28 * 3); impl Session {
31 pub async fn generate_one_keypackage_from_credential_bundle(
39 &self,
40 backend: &MlsCryptoProvider,
41 cs: MlsCiphersuite,
42 cb: &CredentialBundle,
43 ) -> Result<KeyPackage> {
44 let guard = self.inner.read().await;
45 let SessionInner {
46 keypackage_lifetime, ..
47 } = guard.as_ref().ok_or(Error::MlsNotInitialized)?;
48
49 let keypackage = KeyPackage::builder()
50 .leaf_node_capabilities(MlsConversationConfiguration::default_leaf_capabilities())
51 .key_package_lifetime(Lifetime::new(keypackage_lifetime.as_secs()))
52 .build(
53 CryptoConfig {
54 ciphersuite: cs.into(),
55 version: openmls::versions::ProtocolVersion::default(),
56 },
57 backend,
58 &cb.signature_key,
59 CredentialWithKey {
60 credential: cb.credential.clone(),
61 signature_key: cb.signature_key.public().into(),
62 },
63 )
64 .await
65 .map_err(KeystoreError::wrap("building keypackage"))?;
66
67 Ok(keypackage)
68 }
69
70 pub async fn request_key_packages(
81 &self,
82 count: usize,
83 ciphersuite: MlsCiphersuite,
84 credential_type: MlsCredentialType,
85 backend: &MlsCryptoProvider,
86 ) -> Result<Vec<KeyPackage>> {
87 self.prune_keypackages(backend, std::iter::empty()).await?;
89 use core_crypto_keystore::CryptoKeystoreMls as _;
90
91 let mut existing_kps = backend
92 .key_store()
93 .mls_fetch_keypackages::<KeyPackage>(count as u32)
94 .await.map_err(KeystoreError::wrap("fetching mls keypackages"))?
95 .into_iter()
96 .filter(|kp|
98 kp.ciphersuite() == ciphersuite.0 && MlsCredentialType::from(kp.leaf_node().credential().credential_type()) == credential_type)
99 .collect::<Vec<_>>();
100
101 let kpb_count = existing_kps.len();
102 let mut kps = if count > kpb_count {
103 let to_generate = count - kpb_count;
104 let cb = self
105 .find_most_recent_credential_bundle(ciphersuite.signature_algorithm(), credential_type)
106 .await?;
107 self.generate_new_keypackages(backend, ciphersuite, &cb, to_generate)
108 .await?
109 } else {
110 vec![]
111 };
112
113 existing_kps.reverse();
114
115 kps.append(&mut existing_kps);
116 Ok(kps)
117 }
118
119 pub(crate) async fn generate_new_keypackages(
120 &self,
121 backend: &MlsCryptoProvider,
122 ciphersuite: MlsCiphersuite,
123 cb: &CredentialBundle,
124 count: usize,
125 ) -> Result<Vec<KeyPackage>> {
126 let mut kps = Vec::with_capacity(count);
127
128 for _ in 0..count {
129 let kp = self
130 .generate_one_keypackage_from_credential_bundle(backend, ciphersuite, cb)
131 .await?;
132 kps.push(kp);
133 }
134
135 Ok(kps)
136 }
137
138 pub async fn valid_keypackages_count(
140 &self,
141 backend: &MlsCryptoProvider,
142 ciphersuite: MlsCiphersuite,
143 credential_type: MlsCredentialType,
144 ) -> Result<usize> {
145 let kps: Vec<MlsKeyPackage> = backend
146 .key_store()
147 .find_all(EntityFindParams::default())
148 .await
149 .map_err(KeystoreError::wrap("finding all key packages"))?;
150
151 let mut valid_count = 0;
152 for kp in kps
153 .into_iter()
154 .map(|kp| core_crypto_keystore::deser::<KeyPackage>(&kp.keypackage))
155 .filter(|kp| {
157 kp.as_ref()
158 .map(|b| b.ciphersuite() == ciphersuite.0 && MlsCredentialType::from(b.leaf_node().credential().credential_type()) == credential_type)
159 .unwrap_or_default()
160 })
161 {
162 let kp = kp.map_err(KeystoreError::wrap("counting valid keypackages"))?;
163 if !Self::is_mls_keypackage_expired(&kp) {
164 valid_count += 1;
165 }
166 }
167
168 Ok(valid_count)
169 }
170
171 fn is_mls_keypackage_expired(kp: &KeyPackage) -> bool {
174 let Some(lifetime) = kp.leaf_node().life_time() else {
175 return false;
176 };
177
178 !(lifetime.has_acceptable_range() && lifetime.is_valid())
179 }
180
181 pub async fn prune_keypackages(
187 &self,
188 backend: &MlsCryptoProvider,
189 refs: impl IntoIterator<Item = KeyPackageRef>,
190 ) -> Result<()> {
191 let keystore = backend.keystore();
192 let kps = self.find_all_keypackages(&keystore).await?;
193 let _ = self._prune_keypackages(&kps, &keystore, refs).await?;
194 Ok(())
195 }
196
197 pub(crate) async fn prune_keypackages_and_credential(
198 &mut self,
199 backend: &MlsCryptoProvider,
200 refs: impl IntoIterator<Item = KeyPackageRef>,
201 ) -> Result<()> {
202 let mut guard = self.inner.write().await;
203 let SessionInner { identities, .. } = guard.as_mut().ok_or(Error::MlsNotInitialized)?;
204
205 let keystore = backend.key_store();
206 let kps = self.find_all_keypackages(keystore).await?;
207 let kp_to_delete = self._prune_keypackages(&kps, keystore, refs).await?;
208
209 let mut grouped_kps = HashMap::<Vec<u8>, Vec<KeyPackageRef>>::new();
211 for (_, kp) in &kps {
212 let cred = kp
213 .leaf_node()
214 .credential()
215 .tls_serialize_detached()
216 .map_err(Error::tls_serialize("keypackage"))?;
217 let kp_ref = kp
218 .hash_ref(backend.crypto())
219 .map_err(MlsError::wrap("computing keypackage hashref"))?;
220 grouped_kps
221 .entry(cred)
222 .and_modify(|kprfs| kprfs.push(kp_ref.clone()))
223 .or_insert(vec![kp_ref]);
224 }
225
226 for (credential, kps) in &grouped_kps {
227 let all_to_delete = kps.iter().all(|kpr| kp_to_delete.contains(&kpr.as_slice()));
229 if all_to_delete {
230 backend
232 .keystore()
233 .cred_delete_by_credential(credential.clone())
234 .await
235 .map_err(KeystoreError::wrap("deleting credential"))?;
236 let credential = Credential::tls_deserialize(&mut credential.as_slice())
237 .map_err(Error::tls_deserialize("credential"))?;
238 identities.remove(&credential).await?;
239 }
240 }
241
242 Ok(())
243 }
244
245 async fn _prune_keypackages<'a>(
250 &self,
251 kps: &'a [(MlsKeyPackage, KeyPackage)],
252 keystore: &CryptoKeystore,
253 refs: impl IntoIterator<Item = KeyPackageRef>,
254 ) -> Result<HashSet<&'a [u8]>, Error> {
255 let refs = refs
256 .into_iter()
257 .map(|kp| {
258 kp.as_slice().to_owned()
268 })
269 .collect::<HashSet<_>>();
270
271 let kp_to_delete = kps.iter().filter_map(|(store_kp, kp)| {
272 let is_expired = Self::is_mls_keypackage_expired(kp);
273 let to_delete = is_expired || refs.contains(store_kp.keypackage_ref.as_slice());
274 to_delete.then_some((kp, &store_kp.keypackage_ref))
275 });
276
277 for (kp, kp_ref) in kp_to_delete.clone() {
279 keystore
280 .remove::<MlsKeyPackage, &[u8]>(kp_ref.as_slice())
281 .await
282 .map_err(KeystoreError::wrap("removing key package from keystore"))?;
283 keystore
284 .remove::<MlsHpkePrivateKey, &[u8]>(kp.hpke_init_key().as_slice())
285 .await
286 .map_err(KeystoreError::wrap("removing private key from keystore"))?;
287 keystore
288 .remove::<MlsEncryptionKeyPair, &[u8]>(kp.leaf_node().encryption_key().as_slice())
289 .await
290 .map_err(KeystoreError::wrap("removing encryption keypair from keystore"))?;
291 }
292
293 Ok(kp_to_delete.map(|(_, kpref)| kpref.as_slice()).collect())
294 }
295
296 async fn find_all_keypackages(&self, keystore: &CryptoKeystore) -> Result<Vec<(MlsKeyPackage, KeyPackage)>> {
297 let kps: Vec<MlsKeyPackage> = keystore
298 .find_all(EntityFindParams::default())
299 .await
300 .map_err(KeystoreError::wrap("finding all keypackages"))?;
301
302 let kps = kps
303 .into_iter()
304 .map(|raw_kp| -> Result<_> {
305 let kp = core_crypto_keystore::deser::<KeyPackage>(&raw_kp.keypackage)
306 .map_err(KeystoreError::wrap("deserializing keypackage"))?;
307 Ok((raw_kp, kp))
308 })
309 .collect::<Result<Vec<_>, _>>()?;
310
311 Ok(kps)
312 }
313
/// Test-only: overrides the lifetime given to key packages generated from now on.
///
/// # Errors
/// [`Error::MlsNotInitialized`] if the session has not been initialized.
#[cfg(test)]
pub async fn set_keypackage_lifetime(&self, duration: std::time::Duration) -> Result<()> {
    let mut guard = self.inner.write().await;
    let SessionInner {
        keypackage_lifetime, ..
    } = guard.as_mut().ok_or(Error::MlsNotInitialized)?;
    *keypackage_lifetime = duration;
    Ok(())
}
329}
330
331#[cfg(test)]
332mod tests {
333 use openmls::prelude::{KeyPackage, KeyPackageIn, KeyPackageRef, ProtocolVersion};
334 use openmls_traits::OpenMlsCryptoProvider;
335 use openmls_traits::types::VerifiableCiphersuite;
336
337 use mls_crypto_provider::{CryptoKeystore, MlsCryptoProvider};
338
339 use crate::e2e_identity::enrollment::test_utils::{e2ei_enrollment, init_activation_or_rotation, noop_restore};
340 use crate::prelude::MlsConversationConfiguration;
341 use crate::prelude::key_package::INITIAL_KEYING_MATERIAL_COUNT;
342 use crate::test_utils::*;
343 use core_crypto_keystore::{ConnectionType, DatabaseKey};
344
345 use super::Session;
346
/// Checks that `is_mls_keypackage_expired` distinguishes a key package built with the session's
/// initial lifetime from one whose 1 s lifetime has elapsed.
#[apply(all_cred_cipher)]
async fn can_assess_keypackage_expiration(case: TestContext) {
    let [session] = case.sessions().await;
    let (cs, ct) = (case.ciphersuite(), case.credential_type);
    // Standalone in-memory keystore and provider, separate from the session context's transaction.
    let key = DatabaseKey::generate();
    let key_store = CryptoKeystore::open(ConnectionType::InMemory, &key).await.unwrap();
    let backend = MlsCryptoProvider::builder().key_store(key_store).build();
    // For X.509 cases, register a test PKI chain so `random_generate` can be given its local
    // intermediate CA below.
    let x509_test_chain = if case.is_x509() {
        let x509_test_chain = crate::test_utils::x509::X509TestChain::init_empty(case.signature_scheme());
        x509_test_chain.register_with_provider(&backend).await;
        Some(x509_test_chain)
    } else {
        None
    };

    backend.new_transaction().await.unwrap();
    let session = session.session;
    session
        .random_generate(
            &case,
            x509_test_chain.as_ref().map(|chain| chain.find_local_intermediate_ca()),
            false,
        )
        .await
        .unwrap();

    // With the session's initial lifetime, a new key package is not expired.
    let kp_std_exp = session.generate_one_keypackage(&backend, cs, ct).await.unwrap();
    assert!(!Session::is_mls_keypackage_expired(&kp_std_exp));

    // With a 1 s lifetime, the key package must be reported expired once 2 s have passed.
    session
        .set_keypackage_lifetime(std::time::Duration::from_secs(1))
        .await
        .unwrap();
    let kp_1s_exp = session.generate_one_keypackage(&backend, cs, ct).await.unwrap();
    async_std::task::sleep(std::time::Duration::from_secs(2)).await;
    assert!(Session::is_mls_keypackage_expired(&kp_1s_exp));
}
387
388 #[apply(all_cred_cipher)]
389 async fn requesting_x509_key_packages_after_basic(case: TestContext) {
390 if !case.is_basic() {
392 return;
393 }
394
395 let [session_context] = case.sessions_basic_with_pki_env().await;
396 Box::pin(async move {
397 let signature_scheme = case.signature_scheme();
398 let cipher_suite = case.ciphersuite();
399
400 let _basic_key_packages = session_context
402 .transaction
403 .get_or_create_client_keypackages(cipher_suite, MlsCredentialType::Basic, 5)
404 .await
405 .unwrap();
406
407 let test_chain = session_context.x509_chain_unchecked();
409
410 let (mut enrollment, cert_chain) = e2ei_enrollment(
411 &session_context,
412 &case,
413 test_chain,
414 None,
415 false,
416 init_activation_or_rotation,
417 noop_restore,
418 )
419 .await
420 .unwrap();
421
422 let _rotate_bundle = session_context
423 .transaction
424 .save_x509_credential(&mut enrollment, cert_chain)
425 .await
426 .unwrap();
427
428 assert!(
430 session_context
431 .transaction
432 .e2ei_is_enabled(signature_scheme)
433 .await
434 .unwrap()
435 );
436
437 let x509_key_packages = session_context
439 .transaction
440 .get_or_create_client_keypackages(cipher_suite, MlsCredentialType::X509, 5)
441 .await
442 .unwrap();
443
444 assert!(
446 x509_key_packages.iter().all(|kp| MlsCredentialType::X509
447 == MlsCredentialType::from(kp.leaf_node().credential().credential_type()))
448 );
449 })
450 .await
451 }
452
/// Over `N` rounds, requests `COUNT + 1` key packages and deletes all but one. It checks entity
/// counts after each round, checks that deleted key packages are never handed out again, and checks
/// that deleting the last key package leaves no keying material or credential behind.
#[apply(all_cred_cipher)]
async fn generates_correct_number_of_kpbs(case: TestContext) {
    let [cc] = case.sessions().await;
    Box::pin(async move {
        // Number of request/delete rounds, and how many key packages are deleted per round.
        const N: usize = 2;
        const COUNT: usize = 109;

        // A freshly created session holds the initial batch of keying material, plus a single
        // credential and signature keypair.
        let init = cc.transaction.count_entities().await;
        assert_eq!(init.key_package, INITIAL_KEYING_MATERIAL_COUNT);
        assert_eq!(init.encryption_keypair, INITIAL_KEYING_MATERIAL_COUNT);
        assert_eq!(init.hpke_private_key, INITIAL_KEYING_MATERIAL_COUNT);
        assert_eq!(init.credential, 1);
        assert_eq!(init.signature_keypair, 1);

        let transactional_provider = cc.transaction.mls_provider().await.unwrap();
        let crypto_provider = transactional_provider.crypto();
        // The key package held back from deletion in each round.
        let mut pinned_kp = None;

        let mut prev_kps: Option<Vec<KeyPackage>> = None;
        for _ in 0..N {
            // Request one more than will be deleted and set the last one aside.
            let mut kps = cc
                .transaction
                .get_or_create_client_keypackages(case.ciphersuite(), case.credential_type, COUNT + 1)
                .await
                .unwrap();

            pinned_kp = Some(kps.pop().unwrap());

            // The store holds exactly COUNT + 1 key packages (with matching private material) and
            // still a single credential.
            assert_eq!(kps.len(), COUNT);
            let after_creation = cc.transaction.count_entities().await;
            assert_eq!(after_creation.key_package, COUNT + 1);
            assert_eq!(after_creation.encryption_keypair, COUNT + 1);
            assert_eq!(after_creation.hpke_private_key, COUNT + 1);
            assert_eq!(after_creation.credential, 1);

            let kpbs_refs = kps
                .iter()
                .map(|kp| kp.hash_ref(crypto_provider).unwrap())
                .collect::<Vec<KeyPackageRef>>();

            // Key packages deleted in the previous round must not be returned again.
            if let Some(pkpbs) = prev_kps.replace(kps) {
                let pkpbs_refs = pkpbs
                    .into_iter()
                    .map(|kpb| kpb.hash_ref(crypto_provider).unwrap())
                    .collect::<Vec<KeyPackageRef>>();

                let has_duplicates = kpbs_refs.iter().any(|href| pkpbs_refs.contains(href));
                assert!(!has_duplicates);
            }
            cc.transaction.delete_keypackages(kpbs_refs).await.unwrap();
        }

        // Only the pinned key package is left.
        let count = cc
            .transaction
            .client_valid_key_packages_count(case.ciphersuite(), case.credential_type)
            .await
            .unwrap();
        assert_eq!(count, 1);

        // Deleting the last key package must also remove its private material and the credential.
        let pinned_kpr = pinned_kp.unwrap().hash_ref(crypto_provider).unwrap();
        cc.transaction.delete_keypackages([pinned_kpr]).await.unwrap();
        let count = cc
            .transaction
            .client_valid_key_packages_count(case.ciphersuite(), case.credential_type)
            .await
            .unwrap();
        assert_eq!(count, 0);
        let after_delete = cc.transaction.count_entities().await;
        assert_eq!(after_delete.key_package, 0);
        assert_eq!(after_delete.encryption_keypair, 0);
        assert_eq!(after_delete.hpke_private_key, 0);
        assert_eq!(after_delete.credential, 0);
    })
    .await
}
532
/// Checks that `request_key_packages` drops key packages whose lifetime has elapsed and keeps
/// returning the long-lived ones.
#[apply(all_cred_cipher)]
async fn automatically_prunes_lifetime_expired_keypackages(case: TestContext) {
    let [session] = case.sessions().await;
    const UNEXPIRED_COUNT: usize = 125;
    const EXPIRED_COUNT: usize = 200;
    // Standalone in-memory keystore and provider, separate from the session context's transaction.
    let key = DatabaseKey::generate();
    let key_store = CryptoKeystore::open(ConnectionType::InMemory, &key).await.unwrap();
    let backend = MlsCryptoProvider::builder().key_store(key_store).build();
    let x509_test_chain = if case.is_x509() {
        let x509_test_chain = crate::test_utils::x509::X509TestChain::init_empty(case.signature_scheme());
        x509_test_chain.register_with_provider(&backend).await;
        Some(x509_test_chain)
    } else {
        None
    };
    backend.new_transaction().await.unwrap();
    let session = session.session().await;
    session
        .random_generate(
            &case,
            x509_test_chain.as_ref().map(|chain| chain.find_local_intermediate_ca()),
            false,
        )
        .await
        .unwrap();

    // Step 1: key packages built with the session's initial lifetime are all counted as valid.
    let unexpired_kpbs = session
        .request_key_packages(UNEXPIRED_COUNT, case.ciphersuite(), case.credential_type, &backend)
        .await
        .unwrap();
    let len = session
        .valid_keypackages_count(&backend, case.ciphersuite(), case.credential_type)
        .await
        .unwrap();
    assert_eq!(len, unexpired_kpbs.len());
    assert_eq!(len, UNEXPIRED_COUNT);

    // Step 2: shorten the lifetime so the key packages generated next expire after 10 s.
    session
        .set_keypackage_lifetime(std::time::Duration::from_secs(10))
        .await
        .unwrap();

    // Expected to reuse the stored long-lived key packages and generate the rest with the short lifetime.
    let partially_expired_kpbs = session
        .request_key_packages(EXPIRED_COUNT, case.ciphersuite(), case.credential_type, &backend)
        .await
        .unwrap();
    assert_eq!(partially_expired_kpbs.len(), EXPIRED_COUNT);

    // Wait for the short-lived key packages to expire.
    async_std::task::sleep(std::time::Duration::from_secs(10)).await;

    // Step 3: the next request prunes the expired key packages and tops up with new ones.
    let fresh_kpbs = session
        .request_key_packages(EXPIRED_COUNT, case.ciphersuite(), case.credential_type, &backend)
        .await
        .unwrap();
    let len = session
        .valid_keypackages_count(&backend, case.ciphersuite(), case.credential_type)
        .await
        .unwrap();
    assert_eq!(len, fresh_kpbs.len());
    assert_eq!(len, EXPIRED_COUNT);

    // Classify the fresh result: long-lived key packages from step 1 vs short-lived ones from step 2
    // (checked second, because step 2's result also contains the reused long-lived ones).
    let (unexpired_match, expired_match) =
        fresh_kpbs
            .iter()
            .fold((0usize, 0usize), |(mut unexpired_match, mut expired_match), fresh| {
                if unexpired_kpbs.iter().any(|kp| kp == fresh) {
                    unexpired_match += 1;
                } else if partially_expired_kpbs.iter().any(|kpb| kpb == fresh) {
                    expired_match += 1;
                }

                (unexpired_match, expired_match)
            });

    // Every long-lived key package survives; no expired one is handed out.
    assert_eq!(unexpired_match, UNEXPIRED_COUNT);
    assert_eq!(expired_match, 0);
}
618
619 #[apply(all_cred_cipher)]
620 async fn new_keypackage_has_correct_extensions(case: TestContext) {
621 let [cc] = case.sessions().await;
622 Box::pin(async move {
623 let kps = cc
624 .transaction
625 .get_or_create_client_keypackages(case.ciphersuite(), case.credential_type, 1)
626 .await
627 .unwrap();
628 let kp = kps.first().unwrap();
629
630 let _ = KeyPackageIn::from(kp.clone())
632 .standalone_validate(
633 &cc.transaction.mls_provider().await.unwrap(),
634 ProtocolVersion::Mls10,
635 true,
636 )
637 .await
638 .unwrap();
639
640 assert!(kp.extensions().is_empty());
642
643 assert_eq!(kp.leaf_node().capabilities().versions(), &[ProtocolVersion::Mls10]);
644 assert_eq!(
645 kp.leaf_node().capabilities().ciphersuites().to_vec(),
646 MlsConversationConfiguration::DEFAULT_SUPPORTED_CIPHERSUITES
647 .iter()
648 .map(|c| VerifiableCiphersuite::from(*c))
649 .collect::<Vec<_>>()
650 );
651 assert!(kp.leaf_node().capabilities().proposals().is_empty());
652 assert!(kp.leaf_node().capabilities().extensions().is_empty());
653 assert_eq!(
654 kp.leaf_node().capabilities().credentials(),
655 MlsConversationConfiguration::DEFAULT_SUPPORTED_CREDENTIALS
656 );
657 })
658 .await
659 }
660}