1use std::collections::{HashMap, HashSet};
2
3use core_crypto_keystore::{
4 connection::FetchFromDatabase,
5 entities::{EntityFindParams, MlsEncryptionKeyPair, MlsHpkePrivateKey, MlsKeyPackage},
6};
7use mls_crypto_provider::{Database, MlsCryptoProvider};
8use openmls::prelude::{Credential, CredentialWithKey, CryptoConfig, KeyPackage, KeyPackageRef, Lifetime};
9use openmls_traits::OpenMlsCryptoProvider;
10use tls_codec::{Deserialize, Serialize};
11
12use super::{Error, Result};
13use crate::{
14 KeystoreError, MlsCiphersuite, MlsConversationConfiguration, MlsCredentialType, MlsError, Session,
15 mls::{credential::CredentialBundle, session::SessionInner},
16};
17
/// Initial keying-material count: 100 in regular builds, 10 in test builds.
///
/// NOTE(review): presumably the number of key packages generated up-front when a client is
/// initialized — confirm at the call sites. Tests use a much smaller value to stay fast.
pub const INITIAL_KEYING_MATERIAL_COUNT: usize = if cfg!(test) { 10 } else { 100 };
24
25pub(crate) const KEYPACKAGE_DEFAULT_LIFETIME: std::time::Duration =
27 std::time::Duration::from_secs(60 * 60 * 24 * 28 * 3); impl Session {
30 pub async fn generate_one_keypackage_from_credential_bundle(
38 &self,
39 backend: &MlsCryptoProvider,
40 cs: MlsCiphersuite,
41 cb: &CredentialBundle,
42 ) -> Result<KeyPackage> {
43 let guard = self.inner.read().await;
44 let SessionInner {
45 keypackage_lifetime, ..
46 } = guard.as_ref().ok_or(Error::MlsNotInitialized)?;
47
48 let keypackage = KeyPackage::builder()
49 .leaf_node_capabilities(MlsConversationConfiguration::default_leaf_capabilities())
50 .key_package_lifetime(Lifetime::new(keypackage_lifetime.as_secs()))
51 .build(
52 CryptoConfig {
53 ciphersuite: cs.into(),
54 version: openmls::versions::ProtocolVersion::default(),
55 },
56 backend,
57 &cb.signature_key,
58 CredentialWithKey {
59 credential: cb.credential.clone(),
60 signature_key: cb.signature_key.public().into(),
61 },
62 )
63 .await
64 .map_err(KeystoreError::wrap("building keypackage"))?;
65
66 Ok(keypackage)
67 }
68
69 pub async fn request_key_packages(
80 &self,
81 count: usize,
82 ciphersuite: MlsCiphersuite,
83 credential_type: MlsCredentialType,
84 backend: &MlsCryptoProvider,
85 ) -> Result<Vec<KeyPackage>> {
86 self.prune_keypackages(backend, std::iter::empty()).await?;
88 use core_crypto_keystore::CryptoKeystoreMls as _;
89
90 let mut existing_kps = backend
91 .key_store()
92 .mls_fetch_keypackages::<KeyPackage>(count as u32)
93 .await.map_err(KeystoreError::wrap("fetching mls keypackages"))?
94 .into_iter()
95 .filter(|kp|
97 kp.ciphersuite() == ciphersuite.0 && MlsCredentialType::from(kp.leaf_node().credential().credential_type()) == credential_type)
98 .collect::<Vec<_>>();
99
100 let kpb_count = existing_kps.len();
101 let mut kps = if count > kpb_count {
102 let to_generate = count - kpb_count;
103 let cb = self
104 .find_most_recent_credential_bundle(ciphersuite.signature_algorithm(), credential_type)
105 .await?;
106 self.generate_new_keypackages(backend, ciphersuite, &cb, to_generate)
107 .await?
108 } else {
109 vec![]
110 };
111
112 existing_kps.reverse();
113
114 kps.append(&mut existing_kps);
115 Ok(kps)
116 }
117
118 pub(crate) async fn generate_new_keypackages(
119 &self,
120 backend: &MlsCryptoProvider,
121 ciphersuite: MlsCiphersuite,
122 cb: &CredentialBundle,
123 count: usize,
124 ) -> Result<Vec<KeyPackage>> {
125 let mut kps = Vec::with_capacity(count);
126
127 for _ in 0..count {
128 let kp = self
129 .generate_one_keypackage_from_credential_bundle(backend, ciphersuite, cb)
130 .await?;
131 kps.push(kp);
132 }
133
134 Ok(kps)
135 }
136
137 pub async fn valid_keypackages_count(
139 &self,
140 backend: &MlsCryptoProvider,
141 ciphersuite: MlsCiphersuite,
142 credential_type: MlsCredentialType,
143 ) -> Result<usize> {
144 let kps: Vec<MlsKeyPackage> = backend
145 .key_store()
146 .find_all(EntityFindParams::default())
147 .await
148 .map_err(KeystoreError::wrap("finding all key packages"))?;
149
150 let mut valid_count = 0;
151 for kp in kps
152 .into_iter()
153 .map(|kp| core_crypto_keystore::deser::<KeyPackage>(&kp.keypackage))
154 .filter(|kp| {
156 kp.as_ref()
157 .map(|b| b.ciphersuite() == ciphersuite.0 && MlsCredentialType::from(b.leaf_node().credential().credential_type()) == credential_type)
158 .unwrap_or_default()
159 })
160 {
161 let kp = kp.map_err(KeystoreError::wrap("counting valid keypackages"))?;
162 if !Self::is_mls_keypackage_expired(&kp) {
163 valid_count += 1;
164 }
165 }
166
167 Ok(valid_count)
168 }
169
170 fn is_mls_keypackage_expired(kp: &KeyPackage) -> bool {
173 let Some(lifetime) = kp.leaf_node().life_time() else {
174 return false;
175 };
176
177 !(lifetime.has_acceptable_range() && lifetime.is_valid())
178 }
179
180 pub async fn prune_keypackages(
186 &self,
187 backend: &MlsCryptoProvider,
188 refs: impl IntoIterator<Item = KeyPackageRef>,
189 ) -> Result<()> {
190 let keystore = backend.keystore();
191 let kps = self.find_all_keypackages(&keystore).await?;
192 let _ = self._prune_keypackages(&kps, &keystore, refs).await?;
193 Ok(())
194 }
195
196 pub(crate) async fn prune_keypackages_and_credential(
197 &mut self,
198 backend: &MlsCryptoProvider,
199 refs: impl IntoIterator<Item = KeyPackageRef>,
200 ) -> Result<()> {
201 let mut guard = self.inner.write().await;
202 let SessionInner { identities, .. } = guard.as_mut().ok_or(Error::MlsNotInitialized)?;
203
204 let keystore = backend.key_store();
205 let kps = self.find_all_keypackages(keystore).await?;
206 let kp_to_delete = self._prune_keypackages(&kps, keystore, refs).await?;
207
208 let mut grouped_kps = HashMap::<Vec<u8>, Vec<KeyPackageRef>>::new();
210 for (_, kp) in &kps {
211 let cred = kp
212 .leaf_node()
213 .credential()
214 .tls_serialize_detached()
215 .map_err(Error::tls_serialize("keypackage"))?;
216 let kp_ref = kp
217 .hash_ref(backend.crypto())
218 .map_err(MlsError::wrap("computing keypackage hashref"))?;
219 grouped_kps
220 .entry(cred)
221 .and_modify(|kprfs| kprfs.push(kp_ref.clone()))
222 .or_insert(vec![kp_ref]);
223 }
224
225 for (credential, kps) in &grouped_kps {
226 let all_to_delete = kps.iter().all(|kpr| kp_to_delete.contains(&kpr.as_slice()));
228 if all_to_delete {
229 backend
231 .keystore()
232 .cred_delete_by_credential(credential.clone())
233 .await
234 .map_err(KeystoreError::wrap("deleting credential"))?;
235 let credential = Credential::tls_deserialize(&mut credential.as_slice())
236 .map_err(Error::tls_deserialize("credential"))?;
237 identities.remove(&credential).await?;
238 }
239 }
240
241 Ok(())
242 }
243
244 async fn _prune_keypackages<'a>(
249 &self,
250 kps: &'a [(MlsKeyPackage, KeyPackage)],
251 keystore: &Database,
252 refs: impl IntoIterator<Item = KeyPackageRef>,
253 ) -> Result<HashSet<&'a [u8]>, Error> {
254 let refs = refs
255 .into_iter()
256 .map(|kp| {
257 kp.as_slice().to_owned()
267 })
268 .collect::<HashSet<_>>();
269
270 let kp_to_delete = kps.iter().filter_map(|(store_kp, kp)| {
271 let is_expired = Self::is_mls_keypackage_expired(kp);
272 let to_delete = is_expired || refs.contains(store_kp.keypackage_ref.as_slice());
273 to_delete.then_some((kp, &store_kp.keypackage_ref))
274 });
275
276 for (kp, kp_ref) in kp_to_delete.clone() {
278 keystore
279 .remove::<MlsKeyPackage, &[u8]>(kp_ref.as_slice())
280 .await
281 .map_err(KeystoreError::wrap("removing key package from keystore"))?;
282 keystore
283 .remove::<MlsHpkePrivateKey, &[u8]>(kp.hpke_init_key().as_slice())
284 .await
285 .map_err(KeystoreError::wrap("removing private key from keystore"))?;
286 keystore
287 .remove::<MlsEncryptionKeyPair, &[u8]>(kp.leaf_node().encryption_key().as_slice())
288 .await
289 .map_err(KeystoreError::wrap("removing encryption keypair from keystore"))?;
290 }
291
292 Ok(kp_to_delete.map(|(_, kpref)| kpref.as_slice()).collect())
293 }
294
295 async fn find_all_keypackages(&self, keystore: &Database) -> Result<Vec<(MlsKeyPackage, KeyPackage)>> {
296 let kps: Vec<MlsKeyPackage> = keystore
297 .find_all(EntityFindParams::default())
298 .await
299 .map_err(KeystoreError::wrap("finding all keypackages"))?;
300
301 let kps = kps
302 .into_iter()
303 .map(|raw_kp| -> Result<_> {
304 let kp = core_crypto_keystore::deser::<KeyPackage>(&raw_kp.keypackage)
305 .map_err(KeystoreError::wrap("deserializing keypackage"))?;
306 Ok((raw_kp, kp))
307 })
308 .collect::<Result<Vec<_>, _>>()?;
309
310 Ok(kps)
311 }
312
313 #[cfg(test)]
316 pub async fn set_keypackage_lifetime(&self, duration: std::time::Duration) -> Result<()> {
317 use std::ops::DerefMut;
318 match self.inner.write().await.deref_mut() {
319 None => Err(Error::MlsNotInitialized),
320 Some(SessionInner {
321 keypackage_lifetime, ..
322 }) => {
323 *keypackage_lifetime = duration;
324 Ok(())
325 }
326 }
327 }
328}
329
#[cfg(test)]
mod tests {
    use core_crypto_keystore::{ConnectionType, DatabaseKey};
    use mls_crypto_provider::{Database, MlsCryptoProvider};
    use openmls::prelude::{KeyPackage, KeyPackageIn, KeyPackageRef, ProtocolVersion};
    use openmls_traits::{OpenMlsCryptoProvider, types::VerifiableCiphersuite};

    use super::Session;
    use crate::{
        MlsConversationConfiguration,
        e2e_identity::enrollment::test_utils::{e2ei_enrollment, init_activation_or_rotation, noop_restore},
        test_utils::*,
    };

    /// Creates a crypto provider backed by a brand new in-memory keystore.
    async fn in_memory_backend() -> MlsCryptoProvider {
        let db_key = DatabaseKey::generate();
        let database = Database::open(ConnectionType::InMemory, &db_key).await.unwrap();
        MlsCryptoProvider::new(database)
    }

    /// For x509 cases, builds an empty test PKI chain registered with `backend`.
    /// Returns `None` for every other credential type.
    async fn x509_chain_for(
        case: &TestContext,
        backend: &MlsCryptoProvider,
    ) -> Option<crate::test_utils::x509::X509TestChain> {
        if !case.is_x509() {
            return None;
        }
        let chain = crate::test_utils::x509::X509TestChain::init_empty(case.signature_scheme());
        chain.register_with_provider(backend).await;
        Some(chain)
    }

    #[apply(all_cred_cipher)]
    async fn can_assess_keypackage_expiration(case: TestContext) {
        let [session_context] = case.sessions().await;
        let ciphersuite = case.ciphersuite();
        let credential_type = case.credential_type;
        let backend = in_memory_backend().await;
        let x509_test_chain = x509_chain_for(&case, &backend).await;

        backend.new_transaction().await.unwrap();
        let session = session_context.session;
        session
            .random_generate(
                &case,
                x509_test_chain.as_ref().map(|chain| chain.find_local_intermediate_ca()),
            )
            .await
            .unwrap();

        // The default lifetime yields a key package that is valid straight away.
        let long_lived = session
            .generate_one_keypackage(&backend, ciphersuite, credential_type)
            .await
            .unwrap();
        assert!(!Session::is_mls_keypackage_expired(&long_lived));

        // A 1s lifetime yields a key package that is expired once we waited past it.
        session
            .set_keypackage_lifetime(std::time::Duration::from_secs(1))
            .await
            .unwrap();
        let short_lived = session
            .generate_one_keypackage(&backend, ciphersuite, credential_type)
            .await
            .unwrap();
        smol::Timer::after(std::time::Duration::from_secs(2)).await;
        assert!(Session::is_mls_keypackage_expired(&short_lived));
    }

    #[apply(all_cred_cipher)]
    async fn requesting_x509_key_packages_after_basic(case: TestContext) {
        // This scenario starts from a basic credential, so only basic cases apply.
        if !case.is_basic() {
            return;
        }

        let [session_context] = case.sessions_basic_with_pki_env().await;
        Box::pin(async move {
            let signature_scheme = case.signature_scheme();
            let ciphersuite = case.ciphersuite();

            let _basic_key_packages = session_context
                .transaction
                .get_or_create_client_keypackages(ciphersuite, MlsCredentialType::Basic, 5)
                .await
                .unwrap();

            let test_chain = session_context.x509_chain_unchecked();

            let (mut enrollment, cert_chain) = e2ei_enrollment(
                &session_context,
                &case,
                test_chain,
                None,
                false,
                init_activation_or_rotation,
                noop_restore,
            )
            .await
            .unwrap();

            let _rotate_bundle = session_context
                .transaction
                .save_x509_credential(&mut enrollment, cert_chain)
                .await
                .unwrap();

            let e2ei_enabled = session_context
                .transaction
                .e2ei_is_enabled(signature_scheme)
                .await
                .unwrap();
            assert!(e2ei_enabled);

            let x509_key_packages = session_context
                .transaction
                .get_or_create_client_keypackages(ciphersuite, MlsCredentialType::X509, 5)
                .await
                .unwrap();

            let all_x509 = x509_key_packages
                .iter()
                .map(|kp| MlsCredentialType::from(kp.leaf_node().credential().credential_type()))
                .all(|kp_credential_type| kp_credential_type == MlsCredentialType::X509);
            assert!(all_x509);
        })
        .await
    }

    #[apply(all_cred_cipher)]
    async fn generates_correct_number_of_kpbs(case: TestContext) {
        let [cc] = case.sessions().await;
        Box::pin(async move {
            const N: usize = 2;
            const COUNT: usize = 109;

            let initial = cc.transaction.count_entities().await;
            assert_eq!(initial.key_package, 0);
            assert_eq!(initial.encryption_keypair, 0);
            assert_eq!(initial.hpke_private_key, 0);
            assert_eq!(initial.credential, 1);
            assert_eq!(initial.signature_keypair, 1);

            let transactional_provider = cc.transaction.mls_provider().await.unwrap();
            let crypto_provider = transactional_provider.crypto();
            let mut pinned_kp = None;
            let mut previous_round: Option<Vec<KeyPackage>> = None;

            for _ in 0..N {
                let mut round_kps = cc
                    .transaction
                    .get_or_create_client_keypackages(case.ciphersuite(), case.credential_type, COUNT + 1)
                    .await
                    .unwrap();

                // Keep one key package aside; it must survive every deletion round below.
                pinned_kp = Some(round_kps.pop().unwrap());

                assert_eq!(round_kps.len(), COUNT);
                let after_creation = cc.transaction.count_entities().await;
                assert_eq!(after_creation.key_package, COUNT + 1);
                assert_eq!(after_creation.encryption_keypair, COUNT + 1);
                assert_eq!(after_creation.hpke_private_key, COUNT + 1);
                assert_eq!(after_creation.credential, 1);

                let round_refs = round_kps
                    .iter()
                    .map(|kp| kp.hash_ref(crypto_provider).unwrap())
                    .collect::<Vec<KeyPackageRef>>();

                if let Some(previous_kps) = previous_round.replace(round_kps) {
                    let previous_refs = previous_kps
                        .into_iter()
                        .map(|kp| kp.hash_ref(crypto_provider).unwrap())
                        .collect::<Vec<KeyPackageRef>>();

                    // Deleted key packages must never be handed out again.
                    let reused = round_refs.iter().any(|kp_ref| previous_refs.contains(kp_ref));
                    assert!(!reused);
                }
                cc.transaction.delete_keypackages(round_refs).await.unwrap();
            }

            let remaining = cc
                .transaction
                .client_valid_key_packages_count(case.ciphersuite(), case.credential_type)
                .await
                .unwrap();
            assert_eq!(remaining, 1);

            let pinned_ref = pinned_kp.unwrap().hash_ref(crypto_provider).unwrap();
            cc.transaction.delete_keypackages([pinned_ref]).await.unwrap();
            let remaining = cc
                .transaction
                .client_valid_key_packages_count(case.ciphersuite(), case.credential_type)
                .await
                .unwrap();
            assert_eq!(remaining, 0);
            let after_delete = cc.transaction.count_entities().await;
            assert_eq!(after_delete.key_package, 0);
            assert_eq!(after_delete.encryption_keypair, 0);
            assert_eq!(after_delete.hpke_private_key, 0);
            assert_eq!(after_delete.credential, 0);
        })
        .await
    }

    #[apply(all_cred_cipher)]
    async fn automatically_prunes_lifetime_expired_keypackages(case: TestContext) {
        let [session_context] = case.sessions().await;
        const UNEXPIRED_COUNT: usize = 125;
        const EXPIRED_COUNT: usize = 200;
        let backend = in_memory_backend().await;
        let x509_test_chain = x509_chain_for(&case, &backend).await;
        backend.new_transaction().await.unwrap();
        let session = session_context.session().await;
        session
            .random_generate(
                &case,
                x509_test_chain.as_ref().map(|chain| chain.find_local_intermediate_ca()),
            )
            .await
            .unwrap();

        // First batch: default lifetime, stays valid for the whole test.
        let unexpired_kpbs = session
            .request_key_packages(UNEXPIRED_COUNT, case.ciphersuite(), case.credential_type, &backend)
            .await
            .unwrap();
        let valid = session
            .valid_keypackages_count(&backend, case.ciphersuite(), case.credential_type)
            .await
            .unwrap();
        assert_eq!(valid, unexpired_kpbs.len());
        assert_eq!(valid, UNEXPIRED_COUNT);

        // Top up to EXPIRED_COUNT with short-lived key packages.
        session
            .set_keypackage_lifetime(std::time::Duration::from_secs(10))
            .await
            .unwrap();

        let partially_expired_kpbs = session
            .request_key_packages(EXPIRED_COUNT, case.ciphersuite(), case.credential_type, &backend)
            .await
            .unwrap();
        assert_eq!(partially_expired_kpbs.len(), EXPIRED_COUNT);

        smol::Timer::after(std::time::Duration::from_secs(10)).await;

        // The short-lived ones are now expired and must be pruned and replaced.
        let fresh_kpbs = session
            .request_key_packages(EXPIRED_COUNT, case.ciphersuite(), case.credential_type, &backend)
            .await
            .unwrap();
        let valid = session
            .valid_keypackages_count(&backend, case.ciphersuite(), case.credential_type)
            .await
            .unwrap();
        assert_eq!(valid, fresh_kpbs.len());
        assert_eq!(valid, EXPIRED_COUNT);

        let mut unexpired_match = 0usize;
        let mut expired_match = 0usize;
        for fresh in &fresh_kpbs {
            if unexpired_kpbs.contains(fresh) {
                unexpired_match += 1;
            } else if partially_expired_kpbs.contains(fresh) {
                expired_match += 1;
            }
        }

        assert_eq!(unexpired_match, UNEXPIRED_COUNT);
        assert_eq!(expired_match, 0);
    }

    #[apply(all_cred_cipher)]
    async fn new_keypackage_has_correct_extensions(case: TestContext) {
        let [cc] = case.sessions().await;
        Box::pin(async move {
            let kps = cc
                .transaction
                .get_or_create_client_keypackages(case.ciphersuite(), case.credential_type, 1)
                .await
                .unwrap();
            let kp = kps.first().unwrap();

            let _ = KeyPackageIn::from(kp.clone())
                .standalone_validate(
                    &cc.transaction.mls_provider().await.unwrap(),
                    ProtocolVersion::Mls10,
                    true,
                )
                .await
                .unwrap();

            assert!(kp.extensions().is_empty());

            let capabilities = kp.leaf_node().capabilities();
            assert_eq!(capabilities.versions(), &[ProtocolVersion::Mls10]);

            let expected_ciphersuites = MlsConversationConfiguration::DEFAULT_SUPPORTED_CIPHERSUITES
                .iter()
                .map(|c| VerifiableCiphersuite::from(*c))
                .collect::<Vec<_>>();
            assert_eq!(capabilities.ciphersuites().to_vec(), expected_ciphersuites);

            assert!(capabilities.proposals().is_empty());
            assert!(capabilities.extensions().is_empty());
            assert_eq!(
                capabilities.credentials(),
                MlsConversationConfiguration::DEFAULT_SUPPORTED_CREDENTIALS
            );
        })
        .await
    }
}