1use openmls::prelude::{Credential, CredentialWithKey, CryptoConfig, KeyPackage, KeyPackageRef, Lifetime};
2use openmls_traits::OpenMlsCryptoProvider;
3use std::collections::HashMap;
4use std::ops::DerefMut;
5use tls_codec::{Deserialize, Serialize};
6
7use core_crypto_keystore::{
8 connection::FetchFromDatabase,
9 entities::{EntityFindParams, MlsEncryptionKeyPair, MlsHpkePrivateKey, MlsKeyPackage},
10};
11use mls_crypto_provider::{CryptoKeystore, MlsCryptoProvider};
12
13use super::{Error, Result};
14use crate::{
15 KeystoreError, MlsError,
16 mls::{credential::CredentialBundle, session::SessionInner},
17 prelude::{MlsCiphersuite, MlsConversationConfiguration, MlsCredentialType, Session},
18};
19
/// Number of key packages (and matching HPKE private keys / encryption key pairs) a freshly
/// initialized client starts with.
///
/// Test builds use a smaller value (10 instead of 100).
pub const INITIAL_KEYING_MATERIAL_COUNT: usize = if cfg!(test) { 10 } else { 100 };
26
27pub(crate) const KEYPACKAGE_DEFAULT_LIFETIME: std::time::Duration =
29 std::time::Duration::from_secs(60 * 60 * 24 * 28 * 3); impl Session {
32 pub async fn generate_one_keypackage_from_credential_bundle(
40 &self,
41 backend: &MlsCryptoProvider,
42 cs: MlsCiphersuite,
43 cb: &CredentialBundle,
44 ) -> Result<KeyPackage> {
45 let guard = self.inner.read().await;
46 let SessionInner {
47 keypackage_lifetime, ..
48 } = guard.as_ref().ok_or(Error::MlsNotInitialized)?;
49
50 let keypackage = KeyPackage::builder()
51 .leaf_node_capabilities(MlsConversationConfiguration::default_leaf_capabilities())
52 .key_package_lifetime(Lifetime::new(keypackage_lifetime.as_secs()))
53 .build(
54 CryptoConfig {
55 ciphersuite: cs.into(),
56 version: openmls::versions::ProtocolVersion::default(),
57 },
58 backend,
59 &cb.signature_key,
60 CredentialWithKey {
61 credential: cb.credential.clone(),
62 signature_key: cb.signature_key.public().into(),
63 },
64 )
65 .await
66 .map_err(KeystoreError::wrap("building keypackage"))?;
67
68 Ok(keypackage)
69 }
70
71 pub async fn request_key_packages(
82 &self,
83 count: usize,
84 ciphersuite: MlsCiphersuite,
85 credential_type: MlsCredentialType,
86 backend: &MlsCryptoProvider,
87 ) -> Result<Vec<KeyPackage>> {
88 self.prune_keypackages(backend, &[]).await?;
90 use core_crypto_keystore::CryptoKeystoreMls as _;
91
92 let mut existing_kps = backend
93 .key_store()
94 .mls_fetch_keypackages::<KeyPackage>(count as u32)
95 .await.map_err(KeystoreError::wrap("fetching mls keypackages"))?
96 .into_iter()
97 .filter(|kp|
99 kp.ciphersuite() == ciphersuite.0 && MlsCredentialType::from(kp.leaf_node().credential().credential_type()) == credential_type)
100 .collect::<Vec<_>>();
101
102 let kpb_count = existing_kps.len();
103 let mut kps = if count > kpb_count {
104 let to_generate = count - kpb_count;
105 let cb = self
106 .find_most_recent_credential_bundle(ciphersuite.signature_algorithm(), credential_type)
107 .await?;
108 self.generate_new_keypackages(backend, ciphersuite, &cb, to_generate)
109 .await?
110 } else {
111 vec![]
112 };
113
114 existing_kps.reverse();
115
116 kps.append(&mut existing_kps);
117 Ok(kps)
118 }
119
120 pub(crate) async fn generate_new_keypackages(
121 &self,
122 backend: &MlsCryptoProvider,
123 ciphersuite: MlsCiphersuite,
124 cb: &CredentialBundle,
125 count: usize,
126 ) -> Result<Vec<KeyPackage>> {
127 let mut kps = Vec::with_capacity(count);
128
129 for _ in 0..count {
130 let kp = self
131 .generate_one_keypackage_from_credential_bundle(backend, ciphersuite, cb)
132 .await?;
133 kps.push(kp);
134 }
135
136 Ok(kps)
137 }
138
139 pub async fn valid_keypackages_count(
141 &self,
142 backend: &MlsCryptoProvider,
143 ciphersuite: MlsCiphersuite,
144 credential_type: MlsCredentialType,
145 ) -> Result<usize> {
146 let kps: Vec<MlsKeyPackage> = backend
147 .key_store()
148 .find_all(EntityFindParams::default())
149 .await
150 .map_err(KeystoreError::wrap("finding all key packages"))?;
151
152 let mut valid_count = 0;
153 for kp in kps
154 .into_iter()
155 .map(|kp| core_crypto_keystore::deser::<KeyPackage>(&kp.keypackage))
156 .filter(|kp| {
158 kp.as_ref()
159 .map(|b| b.ciphersuite() == ciphersuite.0 && MlsCredentialType::from(b.leaf_node().credential().credential_type()) == credential_type)
160 .unwrap_or_default()
161 })
162 {
163 let kp = kp.map_err(KeystoreError::wrap("counting valid keypackages"))?;
164 if !Self::is_mls_keypackage_expired(&kp) {
165 valid_count += 1;
166 }
167 }
168
169 Ok(valid_count)
170 }
171
172 fn is_mls_keypackage_expired(kp: &KeyPackage) -> bool {
175 let Some(lifetime) = kp.leaf_node().life_time() else {
176 return false;
177 };
178
179 !(lifetime.has_acceptable_range() && lifetime.is_valid())
180 }
181
182 pub async fn prune_keypackages(&self, backend: &MlsCryptoProvider, refs: &[KeyPackageRef]) -> Result<()> {
188 let keystore = backend.keystore();
189 let kps = self.find_all_keypackages(&keystore).await?;
190 let _ = self._prune_keypackages(&kps, &keystore, refs).await?;
191 Ok(())
192 }
193
194 pub(crate) async fn prune_keypackages_and_credential(
195 &mut self,
196 backend: &MlsCryptoProvider,
197 refs: &[KeyPackageRef],
198 ) -> Result<()> {
199 match self.inner.write().await.deref_mut() {
200 None => Err(Error::MlsNotInitialized),
201 Some(SessionInner { identities, .. }) => {
202 let keystore = backend.key_store();
203 let kps = self.find_all_keypackages(keystore).await?;
204 let kp_to_delete = self._prune_keypackages(&kps, keystore, refs).await?;
205
206 let mut grouped_kps = HashMap::<Vec<u8>, Vec<KeyPackageRef>>::new();
208 for (_, kp) in &kps {
209 let cred = kp
210 .leaf_node()
211 .credential()
212 .tls_serialize_detached()
213 .map_err(Error::tls_serialize("keypackage"))?;
214 let kp_ref = kp
215 .hash_ref(backend.crypto())
216 .map_err(MlsError::wrap("computing keypackage hashref"))?;
217 grouped_kps
218 .entry(cred)
219 .and_modify(|kprfs| kprfs.push(kp_ref.clone()))
220 .or_insert(vec![kp_ref]);
221 }
222
223 for (credential, kps) in &grouped_kps {
224 let all_to_delete = kps.iter().all(|kpr| kp_to_delete.contains(&kpr.as_slice()));
226 if all_to_delete {
227 backend
229 .keystore()
230 .cred_delete_by_credential(credential.clone())
231 .await
232 .map_err(KeystoreError::wrap("deleting credential"))?;
233 let credential = Credential::tls_deserialize(&mut credential.as_slice())
234 .map_err(Error::tls_deserialize("credential"))?;
235 identities.remove(&credential).await?;
236 }
237 }
238 Ok(())
239 }
240 }
241 }
242
243 async fn _prune_keypackages<'a>(
248 &self,
249 kps: &'a [(MlsKeyPackage, KeyPackage)],
250 keystore: &CryptoKeystore,
251 refs: &[KeyPackageRef],
252 ) -> Result<Vec<&'a [u8]>, Error> {
253 let kp_to_delete: Vec<_> = kps
254 .iter()
255 .filter_map(|(store_kp, kp)| {
256 let is_expired = Self::is_mls_keypackage_expired(kp);
257 let mut to_delete = is_expired;
258 if !(is_expired || refs.is_empty()) {
259 to_delete = refs.iter().any(|r| r.as_slice() == store_kp.keypackage_ref);
262 }
263
264 to_delete.then_some((kp, &store_kp.keypackage_ref))
265 })
266 .collect();
267
268 for (kp, kp_ref) in &kp_to_delete {
269 keystore
271 .remove::<MlsKeyPackage, &[u8]>(kp_ref.as_slice())
272 .await
273 .map_err(KeystoreError::wrap("removing key package from keystore"))?;
274 keystore
275 .remove::<MlsHpkePrivateKey, &[u8]>(kp.hpke_init_key().as_slice())
276 .await
277 .map_err(KeystoreError::wrap("removing private key from keystore"))?;
278 keystore
279 .remove::<MlsEncryptionKeyPair, &[u8]>(kp.leaf_node().encryption_key().as_slice())
280 .await
281 .map_err(KeystoreError::wrap("removing encryption keypair from keystore"))?;
282 }
283
284 let kp_to_delete = kp_to_delete
285 .into_iter()
286 .map(|(_, kpref)| &kpref[..])
287 .collect::<Vec<_>>();
288
289 Ok(kp_to_delete)
290 }
291
292 async fn find_all_keypackages(&self, keystore: &CryptoKeystore) -> Result<Vec<(MlsKeyPackage, KeyPackage)>> {
293 let kps: Vec<MlsKeyPackage> = keystore
294 .find_all(EntityFindParams::default())
295 .await
296 .map_err(KeystoreError::wrap("finding all keypackages"))?;
297
298 let kps = kps
299 .into_iter()
300 .map(|raw_kp| -> Result<_> {
301 let kp = core_crypto_keystore::deser::<KeyPackage>(&raw_kp.keypackage)
302 .map_err(KeystoreError::wrap("deserializing keypackage"))?;
303 Ok((raw_kp, kp))
304 })
305 .collect::<Result<Vec<_>, _>>()?;
306
307 Ok(kps)
308 }
309
310 #[cfg(test)]
313 pub async fn set_keypackage_lifetime(&self, duration: std::time::Duration) -> Result<()> {
314 use std::ops::DerefMut;
315 match self.inner.write().await.deref_mut() {
316 None => Err(Error::MlsNotInitialized),
317 Some(SessionInner {
318 keypackage_lifetime, ..
319 }) => {
320 *keypackage_lifetime = duration;
321 Ok(())
322 }
323 }
324 }
325}
326
/// Tests for key-package generation, counting, expiration and pruning on `Session`.
#[cfg(test)]
mod tests {
    use openmls::prelude::{KeyPackage, KeyPackageIn, KeyPackageRef, ProtocolVersion};
    use openmls_traits::OpenMlsCryptoProvider;
    use openmls_traits::types::VerifiableCiphersuite;
    use wasm_bindgen_test::*;

    use mls_crypto_provider::MlsCryptoProvider;

    use crate::e2e_identity::enrollment::test_utils::{e2ei_enrollment, init_activation_or_rotation, noop_restore};
    use crate::prelude::MlsConversationConfiguration;
    use crate::prelude::key_package::INITIAL_KEYING_MATERIAL_COUNT;
    use crate::test_utils::*;
    use core_crypto_keystore::DatabaseKey;

    use super::Session;

    wasm_bindgen_test_configure!(run_in_browser);

    /// A key package built with the session's lifetime is not expired; one built with a 1-second
    /// lifetime is expired after sleeping 2 seconds.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    async fn can_assess_keypackage_expiration(case: TestContext) {
        let [session] = case.sessions().await;
        let (cs, ct) = (case.ciphersuite(), case.credential_type);
        // Fresh in-memory keystore, separate from the one behind `session`'s own setup.
        let key = DatabaseKey::generate();
        let backend = MlsCryptoProvider::try_new_in_memory(&key).await.unwrap();
        // X509 cases need a test chain registered with the backend before generating credentials.
        let x509_test_chain = if case.is_x509() {
            let x509_test_chain = crate::test_utils::x509::X509TestChain::init_empty(case.signature_scheme());
            x509_test_chain.register_with_provider(&backend).await;
            Some(x509_test_chain)
        } else {
            None
        };

        backend.new_transaction().await.unwrap();
        // NOTE(review): `automatically_prunes_lifetime_expired_keypackages` obtains the session via
        // `session.session().await` instead of this field access — confirm which accessor is current.
        let session = session.session;
        session
            .random_generate(
                &case,
                x509_test_chain.as_ref().map(|chain| chain.find_local_intermediate_ca()),
                false,
            )
            .await
            .unwrap();

        // Session's current (presumably default) lifetime: must not be expired right away.
        let kp_std_exp = session.generate_one_keypackage(&backend, cs, ct).await.unwrap();
        assert!(!Session::is_mls_keypackage_expired(&kp_std_exp));

        // 1-second lifetime, then wait past it.
        session
            .set_keypackage_lifetime(std::time::Duration::from_secs(1))
            .await
            .unwrap();
        let kp_1s_exp = session.generate_one_keypackage(&backend, cs, ct).await.unwrap();
        // Sleep 2 seconds (longer than the lifetime) so the key package is guaranteed to expire.
        async_std::task::sleep(std::time::Duration::from_secs(2)).await;
        assert!(Session::is_mls_keypackage_expired(&kp_1s_exp));
    }

    /// After rotating from a basic credential to an x509 one (via e2ei enrollment), requesting
    /// X509 key packages yields only X509 key packages.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    async fn requesting_x509_key_packages_after_basic(case: TestContext) {
        // The scenario starts from a basic client, so only basic cases are exercised.
        if !case.is_basic() {
            return;
        }
        run_test_with_client_ids(case.clone(), ["alice"], move |[mut session_context]| {
            Box::pin(async move {
                let signature_scheme = case.signature_scheme();
                let cipher_suite = case.ciphersuite();

                // Create some Basic key packages first.
                let _basic_key_packages = session_context
                    .transaction
                    .get_or_create_client_keypackages(cipher_suite, MlsCredentialType::Basic, 5)
                    .await
                    .unwrap();

                let test_chain = x509::X509TestChain::init_for_random_clients(signature_scheme, 1);

                let (mut enrollment, cert_chain) = e2ei_enrollment(
                    &mut session_context,
                    &case,
                    &test_chain,
                    None,
                    false,
                    init_activation_or_rotation,
                    noop_restore,
                )
                .await
                .unwrap();

                let _rotate_bundle = session_context
                    .transaction
                    .save_x509_credential(&mut enrollment, cert_chain)
                    .await
                    .unwrap();

                assert!(
                    session_context
                        .transaction
                        .e2ei_is_enabled(signature_scheme)
                        .await
                        .unwrap()
                );

                let x509_key_packages = session_context
                    .transaction
                    .get_or_create_client_keypackages(cipher_suite, MlsCredentialType::X509, 5)
                    .await
                    .unwrap();

                // None of the previously created Basic key packages may leak into this result.
                assert!(x509_key_packages.iter().all(|kp| MlsCredentialType::X509
                    == MlsCredentialType::from(kp.leaf_node().credential().credential_type())));
            })
        })
        .await
    }

    /// Repeatedly requesting key packages yields the requested number, never re-serves deleted
    /// ones, and keystore entity counts track creations and deletions exactly.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    async fn generates_correct_number_of_kpbs(case: TestContext) {
        run_test_with_client_ids(case.clone(), ["alice"], move |[cc]| {
            Box::pin(async move {
                const N: usize = 2;
                const COUNT: usize = 109;

                // A freshly initialized client holds INITIAL_KEYING_MATERIAL_COUNT key packages and
                // a single credential / signature key pair.
                let init = cc.transaction.count_entities().await;
                assert_eq!(init.key_package, INITIAL_KEYING_MATERIAL_COUNT);
                assert_eq!(init.encryption_keypair, INITIAL_KEYING_MATERIAL_COUNT);
                assert_eq!(init.hpke_private_key, INITIAL_KEYING_MATERIAL_COUNT);
                assert_eq!(init.credential, 1);
                assert_eq!(init.signature_keypair, 1);

                let transactional_provider = cc.transaction.mls_provider().await.unwrap();
                let crypto_provider = transactional_provider.crypto();
                // One key package per round is kept out of the deletion so it survives the loop.
                let mut pinned_kp = None;

                let mut prev_kps: Option<Vec<KeyPackage>> = None;
                for _ in 0..N {
                    let mut kps = cc
                        .transaction
                        .get_or_create_client_keypackages(case.ciphersuite(), case.credential_type, COUNT + 1)
                        .await
                        .unwrap();

                    pinned_kp = Some(kps.pop().unwrap());

                    assert_eq!(kps.len(), COUNT);
                    // COUNT + 1 entities: the COUNT in `kps` plus the pinned one.
                    let after_creation = cc.transaction.count_entities().await;
                    assert_eq!(after_creation.key_package, COUNT + 1);
                    assert_eq!(after_creation.encryption_keypair, COUNT + 1);
                    assert_eq!(after_creation.hpke_private_key, COUNT + 1);
                    assert_eq!(after_creation.credential, 1);

                    let kpbs_refs = kps
                        .iter()
                        .map(|kp| kp.hash_ref(crypto_provider).unwrap())
                        .collect::<Vec<KeyPackageRef>>();

                    // Key packages deleted in the previous round must never be handed out again.
                    if let Some(pkpbs) = prev_kps.replace(kps) {
                        let pkpbs_refs = pkpbs
                            .into_iter()
                            .map(|kpb| kpb.hash_ref(crypto_provider).unwrap())
                            .collect::<Vec<KeyPackageRef>>();

                        let has_duplicates = kpbs_refs.iter().any(|href| pkpbs_refs.contains(href));
                        assert!(!has_duplicates);
                    }
                    cc.transaction.delete_keypackages(&kpbs_refs).await.unwrap();
                }

                // Only the pinned key package remains.
                let count = cc
                    .transaction
                    .client_valid_key_packages_count(case.ciphersuite(), case.credential_type)
                    .await
                    .unwrap();
                assert_eq!(count, 1);

                // Deleting the last key package also removes the credential it was issued for.
                let pinned_kpr = pinned_kp.unwrap().hash_ref(crypto_provider).unwrap();
                cc.transaction.delete_keypackages(&[pinned_kpr]).await.unwrap();
                let count = cc
                    .transaction
                    .client_valid_key_packages_count(case.ciphersuite(), case.credential_type)
                    .await
                    .unwrap();
                assert_eq!(count, 0);
                let after_delete = cc.transaction.count_entities().await;
                assert_eq!(after_delete.key_package, 0);
                assert_eq!(after_delete.encryption_keypair, 0);
                assert_eq!(after_delete.hpke_private_key, 0);
                assert_eq!(after_delete.credential, 0);
            })
        })
        .await
    }

    /// `request_key_packages` prunes lifetime-expired key packages and replaces them with fresh
    /// ones, while still reusing the unexpired ones.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    async fn automatically_prunes_lifetime_expired_keypackages(case: TestContext) {
        let [session] = case.sessions().await;
        const UNEXPIRED_COUNT: usize = 125;
        const EXPIRED_COUNT: usize = 200;
        let key = DatabaseKey::generate();
        let backend = MlsCryptoProvider::try_new_in_memory(&key).await.unwrap();
        let x509_test_chain = if case.is_x509() {
            let x509_test_chain = crate::test_utils::x509::X509TestChain::init_empty(case.signature_scheme());
            x509_test_chain.register_with_provider(&backend).await;
            Some(x509_test_chain)
        } else {
            None
        };
        backend.new_transaction().await.unwrap();
        let session = session.session().await;
        session
            .random_generate(
                &case,
                x509_test_chain.as_ref().map(|chain| chain.find_local_intermediate_ca()),
                false,
            )
            .await
            .unwrap();

        // UNEXPIRED_COUNT key packages with the session's current (long) lifetime.
        let unexpired_kpbs = session
            .request_key_packages(UNEXPIRED_COUNT, case.ciphersuite(), case.credential_type, &backend)
            .await
            .unwrap();
        let len = session
            .valid_keypackages_count(&backend, case.ciphersuite(), case.credential_type)
            .await
            .unwrap();
        assert_eq!(len, unexpired_kpbs.len());
        assert_eq!(len, UNEXPIRED_COUNT);

        // From now on, generated key packages expire after 10 seconds.
        session
            .set_keypackage_lifetime(std::time::Duration::from_secs(10))
            .await
            .unwrap();

        // Reuses the UNEXPIRED_COUNT existing ones and generates the remaining
        // EXPIRED_COUNT - UNEXPIRED_COUNT with the 10-second lifetime.
        let partially_expired_kpbs = session
            .request_key_packages(EXPIRED_COUNT, case.ciphersuite(), case.credential_type, &backend)
            .await
            .unwrap();
        assert_eq!(partially_expired_kpbs.len(), EXPIRED_COUNT);

        // Wait for the short-lived key packages to expire.
        async_std::task::sleep(std::time::Duration::from_secs(10)).await;

        // The expired key packages are pruned and replaced by freshly generated ones.
        let fresh_kpbs = session
            .request_key_packages(EXPIRED_COUNT, case.ciphersuite(), case.credential_type, &backend)
            .await
            .unwrap();
        let len = session
            .valid_keypackages_count(&backend, case.ciphersuite(), case.credential_type)
            .await
            .unwrap();
        assert_eq!(len, fresh_kpbs.len());
        assert_eq!(len, EXPIRED_COUNT);

        // Classify each returned key package. The `else if` ensures the long-lived ones, which
        // appear in both earlier results, are counted only as "unexpired".
        let (unexpired_match, expired_match) =
            fresh_kpbs
                .iter()
                .fold((0usize, 0usize), |(mut unexpired_match, mut expired_match), fresh| {
                    if unexpired_kpbs.iter().any(|kp| kp == fresh) {
                        unexpired_match += 1;
                    } else if partially_expired_kpbs.iter().any(|kpb| kpb == fresh) {
                        expired_match += 1;
                    }

                    (unexpired_match, expired_match)
                });

        // All long-lived key packages were kept; none of the expired short-lived ones were reused.
        assert_eq!(unexpired_match, UNEXPIRED_COUNT);
        assert_eq!(expired_match, 0);
    }

    /// A newly created key package validates, carries no key-package extensions, and advertises
    /// the default leaf capabilities (MLS 1.0, default ciphersuites and credentials).
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    async fn new_keypackage_has_correct_extensions(case: TestContext) {
        run_test_with_client_ids(case.clone(), ["alice"], move |[cc]| {
            Box::pin(async move {
                let kps = cc
                    .transaction
                    .get_or_create_client_keypackages(case.ciphersuite(), case.credential_type, 1)
                    .await
                    .unwrap();
                let kp = kps.first().unwrap();

                // The key package must pass openmls' standalone validation.
                let _ = KeyPackageIn::from(kp.clone())
                    .standalone_validate(
                        &cc.transaction.mls_provider().await.unwrap(),
                        ProtocolVersion::Mls10,
                        true,
                    )
                    .await
                    .unwrap();

                assert!(kp.extensions().is_empty());

                assert_eq!(kp.leaf_node().capabilities().versions(), &[ProtocolVersion::Mls10]);
                assert_eq!(
                    kp.leaf_node().capabilities().ciphersuites().to_vec(),
                    MlsConversationConfiguration::DEFAULT_SUPPORTED_CIPHERSUITES
                        .iter()
                        .map(|c| VerifiableCiphersuite::from(*c))
                        .collect::<Vec<_>>()
                );
                assert!(kp.leaf_node().capabilities().proposals().is_empty());
                assert!(kp.leaf_node().capabilities().extensions().is_empty());
                assert_eq!(
                    kp.leaf_node().capabilities().credentials(),
                    MlsConversationConfiguration::DEFAULT_SUPPORTED_CREDENTIALS
                );
            })
        })
        .await
    }
}