1use openmls::prelude::{Credential, CredentialWithKey, CryptoConfig, KeyPackage, KeyPackageRef, Lifetime};
2use openmls_traits::OpenMlsCryptoProvider;
3use std::collections::HashMap;
4use std::ops::{Deref, DerefMut};
5use tls_codec::{Deserialize, Serialize};
6
7use core_crypto_keystore::{
8 connection::FetchFromDatabase,
9 entities::{EntityFindParams, MlsEncryptionKeyPair, MlsHpkePrivateKey, MlsKeyPackage},
10};
11use mls_crypto_provider::{CryptoKeystore, MlsCryptoProvider};
12
13use super::{Error, Result};
14use crate::{
15 KeystoreError, MlsError,
16 mls::{credential::CredentialBundle, session::SessionInner},
17 prelude::{MlsCiphersuite, MlsConversationConfiguration, MlsCredentialType, Session},
18};
19
20#[cfg(not(test))]
22pub const INITIAL_KEYING_MATERIAL_COUNT: usize = 100;
23#[cfg(test)]
25pub const INITIAL_KEYING_MATERIAL_COUNT: usize = 10;
26
27pub(crate) const KEYPACKAGE_DEFAULT_LIFETIME: std::time::Duration =
29 std::time::Duration::from_secs(60 * 60 * 24 * 28 * 3); impl Session {
32 pub async fn generate_one_keypackage_from_credential_bundle(
40 &self,
41 backend: &MlsCryptoProvider,
42 cs: MlsCiphersuite,
43 cb: &CredentialBundle,
44 ) -> Result<KeyPackage> {
45 match self.inner.read().await.deref() {
46 None => Err(Error::MlsNotInitialized),
47 Some(SessionInner {
48 keypackage_lifetime, ..
49 }) => {
50 let keypackage = KeyPackage::builder()
51 .leaf_node_capabilities(MlsConversationConfiguration::default_leaf_capabilities())
52 .key_package_lifetime(Lifetime::new(keypackage_lifetime.as_secs()))
53 .build(
54 CryptoConfig {
55 ciphersuite: cs.into(),
56 version: openmls::versions::ProtocolVersion::default(),
57 },
58 backend,
59 &cb.signature_key,
60 CredentialWithKey {
61 credential: cb.credential.clone(),
62 signature_key: cb.signature_key.public().into(),
63 },
64 )
65 .await
66 .map_err(KeystoreError::wrap("building keypackage"))?;
67
68 Ok(keypackage)
69 }
70 }
71 }
72
73 pub async fn request_key_packages(
84 &self,
85 count: usize,
86 ciphersuite: MlsCiphersuite,
87 credential_type: MlsCredentialType,
88 backend: &MlsCryptoProvider,
89 ) -> Result<Vec<KeyPackage>> {
90 self.prune_keypackages(backend, &[]).await?;
92 use core_crypto_keystore::CryptoKeystoreMls as _;
93
94 let mut existing_kps = backend
95 .key_store()
96 .mls_fetch_keypackages::<KeyPackage>(count as u32)
97 .await.map_err(KeystoreError::wrap("fetching mls keypackages"))?
98 .into_iter()
99 .filter(|kp|
101 kp.ciphersuite() == ciphersuite.0 && MlsCredentialType::from(kp.leaf_node().credential().credential_type()) == credential_type)
102 .collect::<Vec<_>>();
103
104 let kpb_count = existing_kps.len();
105 let mut kps = if count > kpb_count {
106 let to_generate = count - kpb_count;
107 let cb = self
108 .find_most_recent_credential_bundle(ciphersuite.signature_algorithm(), credential_type)
109 .await?;
110 self.generate_new_keypackages(backend, ciphersuite, &cb, to_generate)
111 .await?
112 } else {
113 vec![]
114 };
115
116 existing_kps.reverse();
117
118 kps.append(&mut existing_kps);
119 Ok(kps)
120 }
121
122 pub(crate) async fn generate_new_keypackages(
123 &self,
124 backend: &MlsCryptoProvider,
125 ciphersuite: MlsCiphersuite,
126 cb: &CredentialBundle,
127 count: usize,
128 ) -> Result<Vec<KeyPackage>> {
129 let mut kps = Vec::with_capacity(count);
130
131 for _ in 0..count {
132 let kp = self
133 .generate_one_keypackage_from_credential_bundle(backend, ciphersuite, cb)
134 .await?;
135 kps.push(kp);
136 }
137
138 Ok(kps)
139 }
140
141 pub async fn valid_keypackages_count(
143 &self,
144 backend: &MlsCryptoProvider,
145 ciphersuite: MlsCiphersuite,
146 credential_type: MlsCredentialType,
147 ) -> Result<usize> {
148 let kps: Vec<MlsKeyPackage> = backend
149 .key_store()
150 .find_all(EntityFindParams::default())
151 .await
152 .map_err(KeystoreError::wrap("finding all key packages"))?;
153
154 let mut valid_count = 0;
155 for kp in kps
156 .into_iter()
157 .map(|kp| core_crypto_keystore::deser::<KeyPackage>(&kp.keypackage))
158 .filter(|kp| {
160 kp.as_ref()
161 .map(|b| b.ciphersuite() == ciphersuite.0 && MlsCredentialType::from(b.leaf_node().credential().credential_type()) == credential_type)
162 .unwrap_or_default()
163 })
164 {
165 let kp = kp.map_err(KeystoreError::wrap("counting valid keypackages"))?;
166 if !Self::is_mls_keypackage_expired(&kp) {
167 valid_count += 1;
168 }
169 }
170
171 Ok(valid_count)
172 }
173
174 fn is_mls_keypackage_expired(kp: &KeyPackage) -> bool {
177 let Some(lifetime) = kp.leaf_node().life_time() else {
178 return false;
179 };
180
181 !(lifetime.has_acceptable_range() && lifetime.is_valid())
182 }
183
184 pub async fn prune_keypackages(&self, backend: &MlsCryptoProvider, refs: &[KeyPackageRef]) -> Result<()> {
190 let keystore = backend.keystore();
191 let kps = self.find_all_keypackages(&keystore).await?;
192 let _ = self._prune_keypackages(&kps, &keystore, refs).await?;
193 Ok(())
194 }
195
196 pub(crate) async fn prune_keypackages_and_credential(
197 &mut self,
198 backend: &MlsCryptoProvider,
199 refs: &[KeyPackageRef],
200 ) -> Result<()> {
201 match self.inner.write().await.deref_mut() {
202 None => Err(Error::MlsNotInitialized),
203 Some(SessionInner { identities, .. }) => {
204 let keystore = backend.key_store();
205 let kps = self.find_all_keypackages(keystore).await?;
206 let kp_to_delete = self._prune_keypackages(&kps, keystore, refs).await?;
207
208 let mut grouped_kps = HashMap::<Vec<u8>, Vec<KeyPackageRef>>::new();
210 for (_, kp) in &kps {
211 let cred = kp
212 .leaf_node()
213 .credential()
214 .tls_serialize_detached()
215 .map_err(Error::tls_serialize("keypackage"))?;
216 let kp_ref = kp
217 .hash_ref(backend.crypto())
218 .map_err(MlsError::wrap("computing keypackage hashref"))?;
219 grouped_kps
220 .entry(cred)
221 .and_modify(|kprfs| kprfs.push(kp_ref.clone()))
222 .or_insert(vec![kp_ref]);
223 }
224
225 for (credential, kps) in &grouped_kps {
226 let all_to_delete = kps.iter().all(|kpr| kp_to_delete.contains(&kpr.as_slice()));
228 if all_to_delete {
229 backend
231 .keystore()
232 .cred_delete_by_credential(credential.clone())
233 .await
234 .map_err(KeystoreError::wrap("deleting credential"))?;
235 let credential = Credential::tls_deserialize(&mut credential.as_slice())
236 .map_err(Error::tls_deserialize("credential"))?;
237 identities.remove(&credential).await?;
238 }
239 }
240 Ok(())
241 }
242 }
243 }
244
245 async fn _prune_keypackages<'a>(
250 &self,
251 kps: &'a [(MlsKeyPackage, KeyPackage)],
252 keystore: &CryptoKeystore,
253 refs: &[KeyPackageRef],
254 ) -> Result<Vec<&'a [u8]>, Error> {
255 let kp_to_delete: Vec<_> = kps
256 .iter()
257 .filter_map(|(store_kp, kp)| {
258 let is_expired = Self::is_mls_keypackage_expired(kp);
259 let mut to_delete = is_expired;
260 if !(is_expired || refs.is_empty()) {
261 to_delete = refs.iter().any(|r| r.as_slice() == store_kp.keypackage_ref);
264 }
265
266 to_delete.then_some((kp, &store_kp.keypackage_ref))
267 })
268 .collect();
269
270 for (kp, kp_ref) in &kp_to_delete {
271 keystore
273 .remove::<MlsKeyPackage, &[u8]>(kp_ref.as_slice())
274 .await
275 .map_err(KeystoreError::wrap("removing key package from keystore"))?;
276 keystore
277 .remove::<MlsHpkePrivateKey, &[u8]>(kp.hpke_init_key().as_slice())
278 .await
279 .map_err(KeystoreError::wrap("removing private key from keystore"))?;
280 keystore
281 .remove::<MlsEncryptionKeyPair, &[u8]>(kp.leaf_node().encryption_key().as_slice())
282 .await
283 .map_err(KeystoreError::wrap("removing encryption keypair from keystore"))?;
284 }
285
286 let kp_to_delete = kp_to_delete
287 .into_iter()
288 .map(|(_, kpref)| &kpref[..])
289 .collect::<Vec<_>>();
290
291 Ok(kp_to_delete)
292 }
293
294 async fn find_all_keypackages(&self, keystore: &CryptoKeystore) -> Result<Vec<(MlsKeyPackage, KeyPackage)>> {
295 let kps: Vec<MlsKeyPackage> = keystore
296 .find_all(EntityFindParams::default())
297 .await
298 .map_err(KeystoreError::wrap("finding all keypackages"))?;
299
300 let kps = kps
301 .into_iter()
302 .map(|raw_kp| -> Result<_> {
303 let kp = core_crypto_keystore::deser::<KeyPackage>(&raw_kp.keypackage)
304 .map_err(KeystoreError::wrap("deserializing keypackage"))?;
305 Ok((raw_kp, kp))
306 })
307 .collect::<Result<Vec<_>, _>>()?;
308
309 Ok(kps)
310 }
311
312 #[cfg(test)]
315 pub async fn set_keypackage_lifetime(&self, duration: std::time::Duration) -> Result<()> {
316 use std::ops::DerefMut;
317 match self.inner.write().await.deref_mut() {
318 None => Err(Error::MlsNotInitialized),
319 Some(SessionInner {
320 keypackage_lifetime, ..
321 }) => {
322 *keypackage_lifetime = duration;
323 Ok(())
324 }
325 }
326 }
327}
328
/// Tests for key package generation, counting, expiry and pruning.
#[cfg(test)]
mod tests {
    use openmls::prelude::{KeyPackage, KeyPackageIn, KeyPackageRef, ProtocolVersion};
    use openmls_traits::OpenMlsCryptoProvider;
    use openmls_traits::types::VerifiableCiphersuite;
    use wasm_bindgen_test::*;

    use mls_crypto_provider::MlsCryptoProvider;

    use crate::e2e_identity::enrollment::test_utils::{e2ei_enrollment, init_activation_or_rotation, noop_restore};
    use crate::prelude::MlsConversationConfiguration;
    use crate::prelude::key_package::INITIAL_KEYING_MATERIAL_COUNT;
    use crate::test_utils::*;
    use core_crypto_keystore::DatabaseKey;

    use super::Session;

    wasm_bindgen_test_configure!(run_in_browser);

    /// A key package built with the default lifetime must not be expired. One built with a
    /// 1-second lifetime must be reported as expired once 2 seconds have passed.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    async fn can_assess_keypackage_expiration(case: TestContext) {
        let [session] = case.sessions().await;
        let (cs, ct) = (case.ciphersuite(), case.credential_type);
        let key = DatabaseKey::generate();
        let backend = MlsCryptoProvider::try_new_in_memory(&key).await.unwrap();
        // For x509 cases, set up a test certificate chain registered with the provider; its
        // local intermediate CA is passed to `random_generate` below.
        let x509_test_chain = if case.is_x509() {
            let x509_test_chain = crate::test_utils::x509::X509TestChain::init_empty(case.signature_scheme());
            x509_test_chain.register_with_provider(&backend).await;
            Some(x509_test_chain)
        } else {
            None
        };

        backend.new_transaction().await.unwrap();
        let session = session.session;
        session
            .random_generate(
                &case,
                x509_test_chain.as_ref().map(|chain| chain.find_local_intermediate_ca()),
                false,
            )
            .await
            .unwrap();

        // Default lifetime: the key package must be valid right away.
        let kp_std_exp = session.generate_one_keypackage(&backend, cs, ct).await.unwrap();
        assert!(!Session::is_mls_keypackage_expired(&kp_std_exp));

        // Shrink the lifetime to 1s for the next key package.
        session
            .set_keypackage_lifetime(std::time::Duration::from_secs(1))
            .await
            .unwrap();
        let kp_1s_exp = session.generate_one_keypackage(&backend, cs, ct).await.unwrap();
        // Wait longer than the 1s lifetime so the key package is past its expiry.
        async_std::task::sleep(std::time::Duration::from_secs(2)).await;

        assert!(Session::is_mls_keypackage_expired(&kp_1s_exp));
    }

    /// After rotating from a basic credential to an x509 one via e2ei enrollment, requesting
    /// x509 key packages must only yield key packages carrying an x509 credential.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    async fn requesting_x509_key_packages_after_basic(case: TestContext) {
        // The scenario starts from a basic credential, so only basic cases are relevant.
        if !case.is_basic() {
            return;
        }
        run_test_with_client_ids(case.clone(), ["alice"], move |[mut session_context]| {
            Box::pin(async move {
                let signature_scheme = case.signature_scheme();
                let cipher_suite = case.ciphersuite();

                // Create some basic key packages before the rotation.
                let _basic_key_packages = session_context
                    .transaction
                    .get_or_create_client_keypackages(cipher_suite, MlsCredentialType::Basic, 5)
                    .await
                    .unwrap();

                // Test certificate chain used for the enrollment (presumably for a single client).
                let test_chain = x509::X509TestChain::init_for_random_clients(signature_scheme, 1);

                let (mut enrollment, cert_chain) = e2ei_enrollment(
                    &mut session_context,
                    &case,
                    &test_chain,
                    None,
                    false,
                    init_activation_or_rotation,
                    noop_restore,
                )
                .await
                .unwrap();

                let _rotate_bundle = session_context
                    .transaction
                    .save_x509_credential(&mut enrollment, cert_chain)
                    .await
                    .unwrap();

                // Once the x509 credential is saved, e2ei must be reported as enabled.
                assert!(
                    session_context
                        .transaction
                        .e2ei_is_enabled(signature_scheme)
                        .await
                        .unwrap()
                );

                let x509_key_packages = session_context
                    .transaction
                    .get_or_create_client_keypackages(cipher_suite, MlsCredentialType::X509, 5)
                    .await
                    .unwrap();

                // Every returned key package must carry an x509 credential.
                assert!(x509_key_packages.iter().all(|kp| MlsCredentialType::X509
                    == MlsCredentialType::from(kp.leaf_node().credential().credential_type())));
            })
        })
        .await
    }

    /// Checks the stored entity counts across repeated request/delete rounds:
    /// - key packages, encryption key pairs, HPKE private keys and credentials match what was
    ///   requested and deleted;
    /// - deleted key packages are never handed out again;
    /// - deleting the last key package also removes the now-unused credential.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    async fn generates_correct_number_of_kpbs(case: TestContext) {
        run_test_with_client_ids(case.clone(), ["alice"], move |[cc]| {
            Box::pin(async move {
                const N: usize = 2;
                const COUNT: usize = 109;

                // Entity counts of a freshly set-up client.
                let init = cc.transaction.count_entities().await;
                assert_eq!(init.key_package, INITIAL_KEYING_MATERIAL_COUNT);
                assert_eq!(init.encryption_keypair, INITIAL_KEYING_MATERIAL_COUNT);
                assert_eq!(init.hpke_private_key, INITIAL_KEYING_MATERIAL_COUNT);
                assert_eq!(init.credential, 1);
                assert_eq!(init.signature_keypair, 1);

                // Crypto provider used to compute key package refs.
                let transactional_provider = cc.transaction.mls_provider().await.unwrap();
                let crypto_provider = transactional_provider.crypto();
                let mut pinned_kp = None;

                let mut prev_kps: Option<Vec<KeyPackage>> = None;
                for _ in 0..N {
                    let mut kps = cc
                        .transaction
                        .get_or_create_client_keypackages(case.ciphersuite(), case.credential_type, COUNT + 1)
                        .await
                        .unwrap();

                    // Hold one key package back: it is excluded from the deletion below.
                    pinned_kp = Some(kps.pop().unwrap());

                    assert_eq!(kps.len(), COUNT);
                    let after_creation = cc.transaction.count_entities().await;
                    assert_eq!(after_creation.key_package, COUNT + 1);
                    assert_eq!(after_creation.encryption_keypair, COUNT + 1);
                    assert_eq!(after_creation.hpke_private_key, COUNT + 1);
                    assert_eq!(after_creation.credential, 1);

                    let kpbs_refs = kps
                        .iter()
                        .map(|kp| kp.hash_ref(crypto_provider).unwrap())
                        .collect::<Vec<KeyPackageRef>>();

                    // From the second round on, compare against the previous round's key packages.
                    if let Some(pkpbs) = prev_kps.replace(kps) {
                        let pkpbs_refs = pkpbs
                            .into_iter()
                            .map(|kpb| kpb.hash_ref(crypto_provider).unwrap())
                            .collect::<Vec<KeyPackageRef>>();

                        // Key packages deleted in the previous round must not be handed out again.
                        let has_duplicates = kpbs_refs.iter().any(|href| pkpbs_refs.contains(href));
                        assert!(!has_duplicates);
                    }
                    cc.transaction.delete_keypackages(&kpbs_refs).await.unwrap();
                }

                // Only the pinned key package is left.
                let count = cc
                    .transaction
                    .client_valid_key_packages_count(case.ciphersuite(), case.credential_type)
                    .await
                    .unwrap();
                assert_eq!(count, 1);

                let pinned_kpr = pinned_kp.unwrap().hash_ref(crypto_provider).unwrap();
                cc.transaction.delete_keypackages(&[pinned_kpr]).await.unwrap();
                let count = cc
                    .transaction
                    .client_valid_key_packages_count(case.ciphersuite(), case.credential_type)
                    .await
                    .unwrap();
                assert_eq!(count, 0);
                // With no key package left, the associated keys and the credential are gone too.
                let after_delete = cc.transaction.count_entities().await;
                assert_eq!(after_delete.key_package, 0);
                assert_eq!(after_delete.encryption_keypair, 0);
                assert_eq!(after_delete.hpke_private_key, 0);
                assert_eq!(after_delete.credential, 0);
            })
        })
        .await
    }

    /// Key packages whose lifetime has elapsed are pruned on the next request and never
    /// returned again, while long-lived ones keep being reused.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    async fn automatically_prunes_lifetime_expired_keypackages(case: TestContext) {
        let [session] = case.sessions().await;
        const UNEXPIRED_COUNT: usize = 125;
        const EXPIRED_COUNT: usize = 200;
        let key = DatabaseKey::generate();
        let backend = MlsCryptoProvider::try_new_in_memory(&key).await.unwrap();
        // For x509 cases, set up a test certificate chain registered with the provider.
        let x509_test_chain = if case.is_x509() {
            let x509_test_chain = crate::test_utils::x509::X509TestChain::init_empty(case.signature_scheme());
            x509_test_chain.register_with_provider(&backend).await;
            Some(x509_test_chain)
        } else {
            None
        };
        backend.new_transaction().await.unwrap();
        let session = session.session().await;
        session
            .random_generate(
                &case,
                x509_test_chain.as_ref().map(|chain| chain.find_local_intermediate_ca()),
                false,
            )
            .await
            .unwrap();

        // Key packages with the default lifetime: all of them must count as valid.
        let unexpired_kpbs = session
            .request_key_packages(UNEXPIRED_COUNT, case.ciphersuite(), case.credential_type, &backend)
            .await
            .unwrap();
        let len = session
            .valid_keypackages_count(&backend, case.ciphersuite(), case.credential_type)
            .await
            .unwrap();
        assert_eq!(len, unexpired_kpbs.len());
        assert_eq!(len, UNEXPIRED_COUNT);

        // Key packages generated from now on expire after 10 seconds.
        session
            .set_keypackage_lifetime(std::time::Duration::from_secs(10))
            .await
            .unwrap();

        // Stored key packages are reused; only the shortfall up to EXPIRED_COUNT is generated,
        // with the short lifetime.
        let partially_expired_kpbs = session
            .request_key_packages(EXPIRED_COUNT, case.ciphersuite(), case.credential_type, &backend)
            .await
            .unwrap();
        assert_eq!(partially_expired_kpbs.len(), EXPIRED_COUNT);

        // Wait for the short-lived key packages to expire.
        async_std::task::sleep(std::time::Duration::from_secs(10)).await;

        // This request prunes the expired key packages before fetching/generating.
        let fresh_kpbs = session
            .request_key_packages(EXPIRED_COUNT, case.ciphersuite(), case.credential_type, &backend)
            .await
            .unwrap();
        let len = session
            .valid_keypackages_count(&backend, case.ciphersuite(), case.credential_type)
            .await
            .unwrap();
        assert_eq!(len, fresh_kpbs.len());
        assert_eq!(len, EXPIRED_COUNT);

        // Classify each fresh key package. A match in `unexpired_kpbs` takes precedence (those
        // may also appear in `partially_expired_kpbs`); only otherwise is it counted as one of
        // the short-lived ones.
        let (unexpired_match, expired_match) =
            fresh_kpbs
                .iter()
                .fold((0usize, 0usize), |(mut unexpired_match, mut expired_match), fresh| {
                    if unexpired_kpbs.iter().any(|kp| kp == fresh) {
                        unexpired_match += 1;
                    } else if partially_expired_kpbs.iter().any(|kpb| kpb == fresh) {
                        expired_match += 1;
                    }

                    (unexpired_match, expired_match)
                });

        // All long-lived key packages are still handed out; none of the expired ones are.
        assert_eq!(unexpired_match, UNEXPIRED_COUNT);
        assert_eq!(expired_match, 0);
    }

    /// A newly created key package must pass standalone validation and advertise exactly the
    /// default leaf capabilities, with no key package extensions.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    async fn new_keypackage_has_correct_extensions(case: TestContext) {
        run_test_with_client_ids(case.clone(), ["alice"], move |[cc]| {
            Box::pin(async move {
                let kps = cc
                    .transaction
                    .get_or_create_client_keypackages(case.ciphersuite(), case.credential_type, 1)
                    .await
                    .unwrap();
                let kp = kps.first().unwrap();

                // The key package must pass openmls' standalone validation for MLS 1.0.
                let _ = KeyPackageIn::from(kp.clone())
                    .standalone_validate(
                        &cc.transaction.mls_provider().await.unwrap(),
                        ProtocolVersion::Mls10,
                        true,
                    )
                    .await
                    .unwrap();

                // No key package extensions are set.
                assert!(kp.extensions().is_empty());

                // Leaf capabilities match the conversation defaults.
                assert_eq!(kp.leaf_node().capabilities().versions(), &[ProtocolVersion::Mls10]);
                assert_eq!(
                    kp.leaf_node().capabilities().ciphersuites().to_vec(),
                    MlsConversationConfiguration::DEFAULT_SUPPORTED_CIPHERSUITES
                        .iter()
                        .map(|c| VerifiableCiphersuite::from(*c))
                        .collect::<Vec<_>>()
                );
                assert!(kp.leaf_node().capabilities().proposals().is_empty());
                assert!(kp.leaf_node().capabilities().extensions().is_empty());
                assert_eq!(
                    kp.leaf_node().capabilities().credentials(),
                    MlsConversationConfiguration::DEFAULT_SUPPORTED_CREDENTIALS
                );
            })
        })
        .await
    }
}