1use std::collections::{HashMap, HashSet};
2
3use core_crypto_keystore::{
4 connection::FetchFromDatabase,
5 entities::{EntityFindParams, StoredEncryptionKeyPair, StoredHpkePrivateKey, StoredKeypackage},
6};
7use mls_crypto_provider::{Database, MlsCryptoProvider};
8use openmls::prelude::{
9 Credential as MlsCredential, CredentialWithKey, CryptoConfig, KeyPackage, KeyPackageRef, Lifetime,
10};
11use openmls_traits::OpenMlsCryptoProvider;
12use tls_codec::{Deserialize, Serialize};
13
14use super::{Error, Result};
15use crate::{
16 Ciphersuite, Credential, CredentialType, KeystoreError, MlsConversationConfiguration, MlsError, Session,
17 mls::session::SessionInner,
18};
19
/// Number of keying-material items created up-front (presumably KeyPackages generated when a
/// client is first initialized — TODO confirm against callers).
#[cfg(not(test))]
pub const INITIAL_KEYING_MATERIAL_COUNT: usize = 100;
/// Smaller value used in test builds (presumably to keep test setup cheap).
#[cfg(test)]
pub const INITIAL_KEYING_MATERIAL_COUNT: usize = 10;

/// Default KeyPackage lifetime: 3 × 28 days = 84 days.
/// NOTE(review): presumably the initial value of the session's `keypackage_lifetime` — confirm.
pub(crate) const KEYPACKAGE_DEFAULT_LIFETIME: std::time::Duration =
    std::time::Duration::from_secs(60 * 60 * 24 * 28 * 3);

// KeyPackage generation, counting and pruning for a `Session`.
impl Session {
32 pub async fn generate_one_keypackage_from_credential(
40 &self,
41 backend: &MlsCryptoProvider,
42 cs: Ciphersuite,
43 cb: &Credential,
44 ) -> Result<KeyPackage> {
45 let guard = self.inner.read().await;
46 let SessionInner {
47 keypackage_lifetime, ..
48 } = guard.as_ref().ok_or(Error::MlsNotInitialized)?;
49
50 let keypackage = KeyPackage::builder()
51 .leaf_node_capabilities(MlsConversationConfiguration::default_leaf_capabilities())
52 .key_package_lifetime(Lifetime::new(keypackage_lifetime.as_secs()))
53 .build(
54 CryptoConfig {
55 ciphersuite: cs.into(),
56 version: openmls::versions::ProtocolVersion::default(),
57 },
58 backend,
59 &cb.signature_key_pair,
60 CredentialWithKey {
61 credential: cb.mls_credential.clone(),
62 signature_key: cb.signature_key_pair.public().into(),
63 },
64 )
65 .await
66 .map_err(KeystoreError::wrap("building keypackage"))?;
67
68 Ok(keypackage)
69 }
70
71 pub async fn request_key_packages(
82 &self,
83 count: usize,
84 ciphersuite: Ciphersuite,
85 credential_type: CredentialType,
86 backend: &MlsCryptoProvider,
87 ) -> Result<Vec<KeyPackage>> {
88 self.prune_keypackages(backend, std::iter::empty()).await?;
90 use core_crypto_keystore::CryptoKeystoreMls as _;
91
92 let mut existing_kps = backend
93 .key_store()
94 .mls_fetch_keypackages::<KeyPackage>(count as u32)
95 .await.map_err(KeystoreError::wrap("fetching mls keypackages"))?
96 .into_iter()
97 .filter(|kp|
99 ciphersuite == kp.ciphersuite() && credential_type == kp.leaf_node().credential().credential_type() )
100 .collect::<Vec<_>>();
101
102 let kpb_count = existing_kps.len();
103 let mut kps = if count > kpb_count {
104 let to_generate = count - kpb_count;
105 let cb = self
106 .find_most_recent_credential(ciphersuite.signature_algorithm(), credential_type)
107 .await?;
108 self.generate_new_keypackages(backend, ciphersuite, &cb, to_generate)
109 .await?
110 } else {
111 vec![]
112 };
113
114 existing_kps.reverse();
115
116 kps.append(&mut existing_kps);
117 Ok(kps)
118 }
119
120 pub(crate) async fn generate_new_keypackages(
121 &self,
122 backend: &MlsCryptoProvider,
123 ciphersuite: Ciphersuite,
124 cb: &Credential,
125 count: usize,
126 ) -> Result<Vec<KeyPackage>> {
127 let mut kps = Vec::with_capacity(count);
128
129 for _ in 0..count {
130 let kp = self
131 .generate_one_keypackage_from_credential(backend, ciphersuite, cb)
132 .await?;
133 kps.push(kp);
134 }
135
136 Ok(kps)
137 }
138
139 pub async fn valid_keypackages_count(
141 &self,
142 backend: &MlsCryptoProvider,
143 ciphersuite: Ciphersuite,
144 credential_type: CredentialType,
145 ) -> Result<usize> {
146 let kps: Vec<StoredKeypackage> = backend
147 .key_store()
148 .find_all(EntityFindParams::default())
149 .await
150 .map_err(KeystoreError::wrap("finding all key packages"))?;
151
152 let mut valid_count = 0;
153 for kp in kps
154 .into_iter()
155 .map(|kp| core_crypto_keystore::deser::<KeyPackage>(&kp.keypackage))
156 .filter(|kp_result| {
158 kp_result.as_ref().ok().is_none_or(|key_package| ciphersuite == key_package.ciphersuite() && credential_type == key_package.leaf_node().credential().credential_type())
159 })
160 {
161 let kp = kp.map_err(KeystoreError::wrap("counting valid keypackages"))?;
162 if !Self::is_mls_keypackage_expired(&kp) {
163 valid_count += 1;
164 }
165 }
166
167 Ok(valid_count)
168 }
169
170 fn is_mls_keypackage_expired(kp: &KeyPackage) -> bool {
173 let Some(lifetime) = kp.leaf_node().life_time() else {
174 return false;
175 };
176
177 !(lifetime.has_acceptable_range() && lifetime.is_valid())
178 }
179
180 pub async fn prune_keypackages(
187 &self,
188 backend: &MlsCryptoProvider,
189 refs: impl IntoIterator<Item = KeyPackageRef>,
190 ) -> Result<()> {
191 let keystore = backend.keystore();
192 let kps = self.find_all_keypackages(&keystore).await?;
193 let _ = self._prune_keypackages(&kps, &keystore, refs).await?;
194 Ok(())
195 }
196
/// Prunes KeyPackages, then deletes every credential whose KeyPackages were *all* pruned.
///
/// Pruned KeyPackages are the expired ones plus those listed in `refs`. Each fully-pruned
/// credential is removed from the keystore and from the session's in-memory `identities`.
///
/// Grouping is computed over every KeyPackage found in the keystore. A credential is therefore
/// only removed when no surviving KeyPackage still references it.
///
/// # Errors
/// Returns `Error::MlsNotInitialized` if the session is not initialized. Otherwise it propagates
/// keystore, TLS (de)serialization, and hash-reference errors.
pub(crate) async fn prune_keypackages_and_credential(
    &mut self,
    backend: &MlsCryptoProvider,
    refs: impl IntoIterator<Item = KeyPackageRef>,
) -> Result<()> {
    // The write guard is held for the whole operation, while `identities` is updated alongside
    // the keystore deletions below.
    let mut guard = self.inner.write().await;
    let SessionInner { identities, .. } = guard.as_mut().ok_or(Error::MlsNotInitialized)?;

    let keystore = backend.key_store();
    // Snapshot taken before pruning, so it still includes the KeyPackages about to be deleted.
    let kps = self.find_all_keypackages(keystore).await?;
    let kp_to_delete = self._prune_keypackages(&kps, keystore, refs).await?;

    // Group KeyPackage references by their TLS-serialized credential.
    let mut grouped_kps = HashMap::<Vec<u8>, Vec<KeyPackageRef>>::new();
    for (_, kp) in &kps {
        let cred = kp
            .leaf_node()
            .credential()
            .tls_serialize_detached()
            .map_err(Error::tls_serialize("keypackage"))?;
        // NOTE(review): this computed hash_ref is compared against the *stored* keypackage_ref
        // bytes returned by `_prune_keypackages`; this assumes both encodings match — confirm.
        let kp_ref = kp
            .hash_ref(backend.crypto())
            .map_err(MlsError::wrap("computing keypackage hashref"))?;
        grouped_kps.entry(cred).or_default().push(kp_ref);
    }

    for (credential, kps) in &grouped_kps {
        // Only drop a credential when none of its KeyPackages survived the prune.
        let all_to_delete = kps.iter().all(|kpr| kp_to_delete.contains(&kpr.as_slice()));
        if all_to_delete {
            backend
                .keystore()
                .cred_delete_by_credential(credential.clone())
                .await
                .map_err(KeystoreError::wrap("deleting credential"))?;
            // Mirror the keystore deletion in the in-memory identities.
            let credential = MlsCredential::tls_deserialize(&mut credential.as_slice())
                .map_err(Error::tls_deserialize("credential"))?;
            identities.remove_by_mls_credential(&credential);
        }
    }

    Ok(())
}
241
242 async fn _prune_keypackages<'a>(
247 &self,
248 kps: &'a [(StoredKeypackage, KeyPackage)],
249 keystore: &Database,
250 refs: impl IntoIterator<Item = KeyPackageRef>,
251 ) -> Result<HashSet<&'a [u8]>, Error> {
252 let refs = refs
253 .into_iter()
254 .map(|kp| {
255 kp.as_slice().to_owned()
265 })
266 .collect::<HashSet<_>>();
267
268 let kp_to_delete = kps.iter().filter_map(|(store_kp, kp)| {
269 let is_expired = Self::is_mls_keypackage_expired(kp);
270 let to_delete = is_expired || refs.contains(store_kp.keypackage_ref.as_slice());
271 to_delete.then_some((kp, &store_kp.keypackage_ref))
272 });
273
274 for (kp, kp_ref) in kp_to_delete.clone() {
276 keystore
277 .remove::<StoredKeypackage, &[u8]>(kp_ref.as_slice())
278 .await
279 .map_err(KeystoreError::wrap("removing key package from keystore"))?;
280 keystore
281 .remove::<StoredHpkePrivateKey, &[u8]>(kp.hpke_init_key().as_slice())
282 .await
283 .map_err(KeystoreError::wrap("removing private key from keystore"))?;
284 keystore
285 .remove::<StoredEncryptionKeyPair, &[u8]>(kp.leaf_node().encryption_key().as_slice())
286 .await
287 .map_err(KeystoreError::wrap("removing encryption keypair from keystore"))?;
288 }
289
290 Ok(kp_to_delete.map(|(_, kpref)| kpref.as_slice()).collect())
291 }
292
293 pub(super) async fn find_all_keypackages(
294 &self,
295 keystore: &Database,
296 ) -> Result<Vec<(StoredKeypackage, KeyPackage)>> {
297 let kps: Vec<StoredKeypackage> = keystore
298 .find_all(EntityFindParams::default())
299 .await
300 .map_err(KeystoreError::wrap("finding all keypackages"))?;
301
302 let kps = kps
303 .into_iter()
304 .map(|raw_kp| -> Result<_> {
305 let kp = core_crypto_keystore::deser::<KeyPackage>(&raw_kp.keypackage)
306 .map_err(KeystoreError::wrap("deserializing keypackage"))?;
307 Ok((raw_kp, kp))
308 })
309 .collect::<Result<Vec<_>, _>>()?;
310
311 Ok(kps)
312 }
313
/// Test-only: overrides the lifetime given to KeyPackages generated after this call.
///
/// # Errors
/// `Error::MlsNotInitialized` if the session has not been initialized.
#[cfg(test)]
pub async fn set_keypackage_lifetime(&self, duration: std::time::Duration) -> Result<()> {
    let mut guard = self.inner.write().await;
    let inner = guard.as_mut().ok_or(Error::MlsNotInitialized)?;
    inner.keypackage_lifetime = duration;
    Ok(())
}
328}
329
#[cfg(test)]
mod tests {
    use core_crypto_keystore::{ConnectionType, DatabaseKey};
    use mls_crypto_provider::{Database, MlsCryptoProvider};
    use openmls::prelude::{KeyPackage, KeyPackageIn, KeyPackageRef, ProtocolVersion};
    use openmls_traits::{OpenMlsCryptoProvider, types::VerifiableCiphersuite};

    use super::Session;
    use crate::{
        MlsConversationConfiguration,
        e2e_identity::enrollment::test_utils::{e2ei_enrollment, init_activation_or_rotation, noop_restore},
        test_utils::*,
    };

    /// Checks the expiry of two KeyPackages built by the session:
    /// - one built with the session's current lifetime is not expired;
    /// - one built with a 1-second lifetime is expired once 2 seconds have passed.
    #[apply(all_cred_cipher)]
    async fn can_assess_keypackage_expiration(case: TestContext) {
        let [session_context] = case.sessions().await;
        let (cs, ct) = (case.ciphersuite(), case.credential_type);
        // A fresh in-memory keystore used as the backend for this test.
        let key = DatabaseKey::generate();
        let database = Database::open(ConnectionType::InMemory, &key).await.unwrap();
        let backend = MlsCryptoProvider::new(database);
        let x509_test_chain = if case.is_x509() {
            let x509_test_chain = crate::test_utils::x509::X509TestChain::init_empty(case.signature_scheme());
            x509_test_chain.register_with_provider(&backend).await;
            Some(x509_test_chain)
        } else {
            None
        };

        backend.new_transaction().await.unwrap();
        let session = session_context.session;
        session
            .random_generate(
                &case,
                x509_test_chain.as_ref().map(|chain| chain.find_local_intermediate_ca()),
            )
            .await
            .unwrap();

        // Built before this test changes the lifetime: must not be expired.
        let kp_std_exp = session.generate_one_keypackage(&backend, cs, ct).await.unwrap();
        assert!(!Session::is_mls_keypackage_expired(&kp_std_exp));

        // Shrink the lifetime to 1s, build a KeyPackage, then wait past its expiry.
        session
            .set_keypackage_lifetime(std::time::Duration::from_secs(1))
            .await
            .unwrap();
        let kp_1s_exp = session.generate_one_keypackage(&backend, cs, ct).await.unwrap();
        smol::Timer::after(std::time::Duration::from_secs(2)).await;
        assert!(Session::is_mls_keypackage_expired(&kp_1s_exp));
    }

    /// Checks that after Basic KeyPackages exist and an X509 credential is enrolled, requesting
    /// X509 KeyPackages returns only X509 ones.
    #[apply(all_cred_cipher)]
    async fn requesting_x509_key_packages_after_basic(case: TestContext) {
        // Only meaningful when the case starts from a Basic credential.
        if !case.is_basic() {
            return;
        }

        let [session_context] = case.sessions_basic_with_pki_env().await;
        Box::pin(async move {
            let signature_scheme = case.signature_scheme();
            let cipher_suite = case.ciphersuite();

            // Populate the store with Basic KeyPackages first.
            let _basic_key_packages = session_context
                .transaction
                .get_or_create_client_keypackages(cipher_suite, CredentialType::Basic, 5)
                .await
                .unwrap();

            let test_chain = session_context.x509_chain_unchecked();

            // Enroll an X509 credential via E2EI.
            let (mut enrollment, cert_chain) = e2ei_enrollment(
                &session_context,
                &case,
                test_chain,
                None,
                false,
                init_activation_or_rotation,
                noop_restore,
            )
            .await
            .unwrap();

            let _rotate_bundle = session_context
                .transaction
                .save_x509_credential(&mut enrollment, cert_chain)
                .await
                .unwrap();

            assert!(
                session_context
                    .transaction
                    .e2ei_is_enabled(signature_scheme)
                    .await
                    .unwrap()
            );

            let x509_key_packages = session_context
                .transaction
                .get_or_create_client_keypackages(cipher_suite, CredentialType::X509, 5)
                .await
                .unwrap();

            // None of the previously created Basic KeyPackages may be returned here.
            assert!(
                x509_key_packages
                    .iter()
                    .all(|kp| CredentialType::X509 == kp.leaf_node().credential().credential_type())
            );
        })
        .await
    }

    /// Each round requests `COUNT + 1` KeyPackages, keeps one aside ("pinned") and deletes the
    /// rest. The test then checks:
    /// - entity counts after each request;
    /// - that KeyPackages are not reused across rounds;
    /// - that deleting the last KeyPackage also removes the credential.
    #[apply(all_cred_cipher)]
    async fn generates_correct_number_of_kpbs(case: TestContext) {
        let [cc] = case.sessions().await;
        Box::pin(async move {
            const N: usize = 2;
            const COUNT: usize = 109;

            // Starting state: one credential and no KeyPackage material.
            let init = cc.transaction.count_entities().await;
            assert_eq!(init.key_package, 0);
            assert_eq!(init.encryption_keypair, 0);
            assert_eq!(init.hpke_private_key, 0);
            assert_eq!(init.credential, 1);

            let transactional_provider = cc.transaction.mls_provider().await.unwrap();
            let crypto_provider = transactional_provider.crypto();
            let mut pinned_kp = None;

            let mut prev_kps: Option<Vec<KeyPackage>> = None;
            for _ in 0..N {
                let mut kps = cc
                    .transaction
                    .get_or_create_client_keypackages(case.ciphersuite(), case.credential_type, COUNT + 1)
                    .await
                    .unwrap();

                // Set one KeyPackage aside; it is not deleted at the end of the round.
                pinned_kp = Some(kps.pop().unwrap());

                assert_eq!(kps.len(), COUNT);
                // One stored entry, one encryption keypair and one HPKE key per KeyPackage.
                let after_creation = cc.transaction.count_entities().await;
                assert_eq!(after_creation.key_package, COUNT + 1);
                assert_eq!(after_creation.encryption_keypair, COUNT + 1);
                assert_eq!(after_creation.hpke_private_key, COUNT + 1);
                assert_eq!(after_creation.credential, 1);

                let kpbs_refs = kps
                    .iter()
                    .map(|kp| kp.hash_ref(crypto_provider).unwrap())
                    .collect::<Vec<KeyPackageRef>>();

                // This round's KeyPackages must not repeat the previous round's (deleted) ones.
                if let Some(pkpbs) = prev_kps.replace(kps) {
                    let pkpbs_refs = pkpbs
                        .into_iter()
                        .map(|kpb| kpb.hash_ref(crypto_provider).unwrap())
                        .collect::<Vec<KeyPackageRef>>();

                    let has_duplicates = kpbs_refs.iter().any(|href| pkpbs_refs.contains(href));
                    assert!(!has_duplicates);
                }
                cc.transaction.delete_keypackages(kpbs_refs).await.unwrap();
            }

            // Only the pinned KeyPackage should be left.
            let count = cc
                .transaction
                .client_valid_key_packages_count(case.ciphersuite(), case.credential_type)
                .await
                .unwrap();
            assert_eq!(count, 1);

            // Deleting the last KeyPackage removes its key material and the credential as well.
            let pinned_kpr = pinned_kp.unwrap().hash_ref(crypto_provider).unwrap();
            cc.transaction.delete_keypackages([pinned_kpr]).await.unwrap();
            let count = cc
                .transaction
                .client_valid_key_packages_count(case.ciphersuite(), case.credential_type)
                .await
                .unwrap();
            assert_eq!(count, 0);
            let after_delete = cc.transaction.count_entities().await;
            assert_eq!(after_delete.key_package, 0);
            assert_eq!(after_delete.encryption_keypair, 0);
            assert_eq!(after_delete.hpke_private_key, 0);
            assert_eq!(after_delete.credential, 0);
        })
        .await
    }

    /// Checks that `request_key_packages` prunes lifetime-expired KeyPackages:
    /// - long-lived KeyPackages keep being returned;
    /// - short-lived KeyPackages are not returned once they have expired.
    #[apply(all_cred_cipher)]
    async fn automatically_prunes_lifetime_expired_keypackages(case: TestContext) {
        let [session_context] = case.sessions().await;
        const UNEXPIRED_COUNT: usize = 125;
        const EXPIRED_COUNT: usize = 200;
        let key = DatabaseKey::generate();
        let key_store = Database::open(ConnectionType::InMemory, &key).await.unwrap();
        let backend = MlsCryptoProvider::new(key_store);
        let x509_test_chain = if case.is_x509() {
            let x509_test_chain = crate::test_utils::x509::X509TestChain::init_empty(case.signature_scheme());
            x509_test_chain.register_with_provider(&backend).await;
            Some(x509_test_chain)
        } else {
            None
        };
        backend.new_transaction().await.unwrap();
        let session = session_context.session().await;
        session
            .random_generate(
                &case,
                x509_test_chain.as_ref().map(|chain| chain.find_local_intermediate_ca()),
            )
            .await
            .unwrap();

        // First batch is built before the lifetime is shortened.
        let unexpired_kpbs = session
            .request_key_packages(UNEXPIRED_COUNT, case.ciphersuite(), case.credential_type, &backend)
            .await
            .unwrap();
        let len = session
            .valid_keypackages_count(&backend, case.ciphersuite(), case.credential_type)
            .await
            .unwrap();
        assert_eq!(len, unexpired_kpbs.len());
        assert_eq!(len, UNEXPIRED_COUNT);

        // KeyPackages generated from here on get a 10s lifetime.
        session
            .set_keypackage_lifetime(std::time::Duration::from_secs(10))
            .await
            .unwrap();

        // Reuses the stored long-lived ones and generates the shortfall with the 10s lifetime.
        let partially_expired_kpbs = session
            .request_key_packages(EXPIRED_COUNT, case.ciphersuite(), case.credential_type, &backend)
            .await
            .unwrap();
        assert_eq!(partially_expired_kpbs.len(), EXPIRED_COUNT);

        // Wait for the 10s KeyPackages to expire.
        smol::Timer::after(std::time::Duration::from_secs(10)).await;

        // The expired ones are pruned on request and the shortfall is regenerated.
        let fresh_kpbs = session
            .request_key_packages(EXPIRED_COUNT, case.ciphersuite(), case.credential_type, &backend)
            .await
            .unwrap();
        let len = session
            .valid_keypackages_count(&backend, case.ciphersuite(), case.credential_type)
            .await
            .unwrap();
        assert_eq!(len, fresh_kpbs.len());
        assert_eq!(len, EXPIRED_COUNT);

        // KeyPackages present in both earlier batches are counted as unexpired (checked first),
        // so `expired_match` only counts the short-lived ones.
        let (unexpired_match, expired_match) =
            fresh_kpbs
                .iter()
                .fold((0usize, 0usize), |(mut unexpired_match, mut expired_match), fresh| {
                    if unexpired_kpbs.iter().any(|kp| kp == fresh) {
                        unexpired_match += 1;
                    } else if partially_expired_kpbs.iter().any(|kpb| kpb == fresh) {
                        expired_match += 1;
                    }

                    (unexpired_match, expired_match)
                });

        // All long-lived KeyPackages are still served; none of the expired ones are.
        assert_eq!(unexpired_match, UNEXPIRED_COUNT);
        assert_eq!(expired_match, 0);
    }

    /// Checks that a freshly created KeyPackage passes openmls standalone validation and
    /// advertises the expected capabilities and (empty) extensions.
    #[apply(all_cred_cipher)]
    async fn new_keypackage_has_correct_extensions(case: TestContext) {
        let [cc] = case.sessions().await;
        Box::pin(async move {
            let kps = cc
                .transaction
                .get_or_create_client_keypackages(case.ciphersuite(), case.credential_type, 1)
                .await
                .unwrap();
            let kp = kps.first().unwrap();

            // Must pass openmls standalone validation for MLS 1.0.
            let _ = KeyPackageIn::from(kp.clone())
                .standalone_validate(
                    &cc.transaction.mls_provider().await.unwrap(),
                    ProtocolVersion::Mls10,
                    true,
                )
                .await
                .unwrap();

            // No KeyPackage-level extensions are set.
            assert!(kp.extensions().is_empty());

            // Leaf-node capabilities match the conversation configuration defaults.
            assert_eq!(kp.leaf_node().capabilities().versions(), &[ProtocolVersion::Mls10]);
            assert_eq!(
                kp.leaf_node().capabilities().ciphersuites().to_vec(),
                MlsConversationConfiguration::DEFAULT_SUPPORTED_CIPHERSUITES
                    .iter()
                    .map(|c| VerifiableCiphersuite::from(*c))
                    .collect::<Vec<_>>()
            );
            assert!(kp.leaf_node().capabilities().proposals().is_empty());
            assert!(kp.leaf_node().capabilities().extensions().is_empty());
            assert_eq!(
                kp.leaf_node().capabilities().credentials(),
                MlsConversationConfiguration::DEFAULT_SUPPORTED_CREDENTIALS
            );
        })
        .await
    }
}