core_crypto/mls/session/
key_package.rs

1use openmls::prelude::{Credential, CredentialWithKey, CryptoConfig, KeyPackage, KeyPackageRef, Lifetime};
2use openmls_traits::OpenMlsCryptoProvider;
3use std::collections::HashMap;
4use std::ops::{Deref, DerefMut};
5use tls_codec::{Deserialize, Serialize};
6
7use core_crypto_keystore::{
8    connection::FetchFromDatabase,
9    entities::{EntityFindParams, MlsEncryptionKeyPair, MlsHpkePrivateKey, MlsKeyPackage},
10};
11use mls_crypto_provider::{CryptoKeystore, MlsCryptoProvider};
12
13use super::{Error, Result};
14use crate::{
15    KeystoreError, MlsError, RecursiveError,
16    mls::{credential::CredentialBundle, session::SessionInner},
17    prelude::{MlsCiphersuite, MlsConversationConfiguration, MlsCredentialType, Session},
18    transaction_context::TransactionContext,
19};
20
/// Default number of KeyPackages a client generates the first time it's created
#[cfg(not(test))]
pub const INITIAL_KEYING_MATERIAL_COUNT: usize = 100;
/// Default number of KeyPackages a client generates the first time it's created
/// (value used in test builds instead of the non-test one above)
#[cfg(test)]
pub const INITIAL_KEYING_MATERIAL_COUNT: usize = 10;

/// Default lifetime of all generated KeyPackages. Matches the limit defined in openmls
pub(crate) const KEYPACKAGE_DEFAULT_LIFETIME: std::time::Duration =
    std::time::Duration::from_secs(60 * 60 * 24 * 28 * 3); // 3 x 28 days = 84 days, i.e. ~3 months
31
32impl Session {
33    /// Generates a single new keypackage
34    ///
35    /// # Arguments
36    /// * `backend` - the KeyStorage to load the keypackages from
37    ///
38    /// # Errors
39    /// KeyStore and OpenMls errors
40    pub async fn generate_one_keypackage_from_credential_bundle(
41        &self,
42        backend: &MlsCryptoProvider,
43        cs: MlsCiphersuite,
44        cb: &CredentialBundle,
45    ) -> Result<KeyPackage> {
46        match self.inner.read().await.deref() {
47            None => Err(Error::MlsNotInitialized),
48            Some(SessionInner {
49                keypackage_lifetime, ..
50            }) => {
51                let keypackage = KeyPackage::builder()
52                    .leaf_node_capabilities(MlsConversationConfiguration::default_leaf_capabilities())
53                    .key_package_lifetime(Lifetime::new(keypackage_lifetime.as_secs()))
54                    .build(
55                        CryptoConfig {
56                            ciphersuite: cs.into(),
57                            version: openmls::versions::ProtocolVersion::default(),
58                        },
59                        backend,
60                        &cb.signature_key,
61                        CredentialWithKey {
62                            credential: cb.credential.clone(),
63                            signature_key: cb.signature_key.public().into(),
64                        },
65                    )
66                    .await
67                    .map_err(KeystoreError::wrap("building keypackage"))?;
68
69                Ok(keypackage)
70            }
71        }
72    }
73
74    /// Requests `count` keying material to be present and returns
75    /// a reference to it for the consumer to copy/clone.
76    ///
77    /// # Arguments
78    /// * `count` - number of [openmls::key_packages::KeyPackage] to generate
79    /// * `ciphersuite` - of [openmls::key_packages::KeyPackage] to generate
80    /// * `backend` - the KeyStorage to load the keypackages from
81    ///
82    /// # Errors
83    /// KeyStore and OpenMls errors
84    pub async fn request_key_packages(
85        &self,
86        count: usize,
87        ciphersuite: MlsCiphersuite,
88        credential_type: MlsCredentialType,
89        backend: &MlsCryptoProvider,
90    ) -> Result<Vec<KeyPackage>> {
91        // Auto-prune expired keypackages on request
92        self.prune_keypackages(backend, &[]).await?;
93        use core_crypto_keystore::CryptoKeystoreMls as _;
94
95        let mut existing_kps = backend
96            .key_store()
97            .mls_fetch_keypackages::<KeyPackage>(count as u32)
98            .await.map_err(KeystoreError::wrap("fetching mls keypackages"))?
99            .into_iter()
100            // TODO: do this filtering in SQL when the schema is updated. Tracking issue: WPB-9599
101            .filter(|kp|
102                kp.ciphersuite() == ciphersuite.0 && MlsCredentialType::from(kp.leaf_node().credential().credential_type()) == credential_type)
103            .collect::<Vec<_>>();
104
105        let kpb_count = existing_kps.len();
106        let mut kps = if count > kpb_count {
107            let to_generate = count - kpb_count;
108            let cb = self
109                .find_most_recent_credential_bundle(ciphersuite.signature_algorithm(), credential_type)
110                .await?;
111            self.generate_new_keypackages(backend, ciphersuite, &cb, to_generate)
112                .await?
113        } else {
114            vec![]
115        };
116
117        existing_kps.reverse();
118
119        kps.append(&mut existing_kps);
120        Ok(kps)
121    }
122
123    pub(crate) async fn generate_new_keypackages(
124        &self,
125        backend: &MlsCryptoProvider,
126        ciphersuite: MlsCiphersuite,
127        cb: &CredentialBundle,
128        count: usize,
129    ) -> Result<Vec<KeyPackage>> {
130        let mut kps = Vec::with_capacity(count);
131
132        for _ in 0..count {
133            let kp = self
134                .generate_one_keypackage_from_credential_bundle(backend, ciphersuite, cb)
135                .await?;
136            kps.push(kp);
137        }
138
139        Ok(kps)
140    }
141
142    /// Returns the count of valid, non-expired, unclaimed keypackages in store
143    pub async fn valid_keypackages_count(
144        &self,
145        backend: &MlsCryptoProvider,
146        ciphersuite: MlsCiphersuite,
147        credential_type: MlsCredentialType,
148    ) -> Result<usize> {
149        let kps: Vec<MlsKeyPackage> = backend
150            .key_store()
151            .find_all(EntityFindParams::default())
152            .await
153            .map_err(KeystoreError::wrap("finding all key packages"))?;
154
155        let mut valid_count = 0;
156        for kp in kps
157            .into_iter()
158            .map(|kp| core_crypto_keystore::deser::<KeyPackage>(&kp.keypackage))
159            // TODO: do this filtering in SQL when the schema is updated. Tracking issue: WPB-9599
160            .filter(|kp| {
161                kp.as_ref()
162                    .map(|b| b.ciphersuite() == ciphersuite.0 && MlsCredentialType::from(b.leaf_node().credential().credential_type()) == credential_type)
163                    .unwrap_or_default()
164            })
165        {
166            let kp = kp.map_err(KeystoreError::wrap("counting valid keypackages"))?;
167            if !Self::is_mls_keypackage_expired(&kp) {
168                valid_count += 1;
169            }
170        }
171
172        Ok(valid_count)
173    }
174
175    /// Checks if a given OpenMLS [`KeyPackage`] is expired by looking through its extensions,
176    /// finding a lifetime extension and checking if it's valid.
177    fn is_mls_keypackage_expired(kp: &KeyPackage) -> bool {
178        let Some(lifetime) = kp.leaf_node().life_time() else {
179            return false;
180        };
181
182        !(lifetime.has_acceptable_range() && lifetime.is_valid())
183    }
184
185    /// Prune the provided KeyPackageRefs from the keystore
186    ///
187    /// Warning: Despite this API being public, the caller should know what they're doing.
188    /// Provided KeypackageRefs **will** be purged regardless of their expiration state, so please be wary of what you are doing if you directly call this API.
189    /// This could result in still valid, uploaded keypackages being pruned from the system and thus being impossible to find when referenced in a future Welcome message.
190    pub async fn prune_keypackages(&self, backend: &MlsCryptoProvider, refs: &[KeyPackageRef]) -> Result<()> {
191        let keystore = backend.keystore();
192        let kps = self.find_all_keypackages(&keystore).await?;
193        let _ = self._prune_keypackages(&kps, &keystore, refs).await?;
194        Ok(())
195    }
196
197    pub(crate) async fn prune_keypackages_and_credential(
198        &mut self,
199        backend: &MlsCryptoProvider,
200        refs: &[KeyPackageRef],
201    ) -> Result<()> {
202        match self.inner.write().await.deref_mut() {
203            None => Err(Error::MlsNotInitialized),
204            Some(SessionInner { identities, .. }) => {
205                let keystore = backend.key_store();
206                let kps = self.find_all_keypackages(keystore).await?;
207                let kp_to_delete = self._prune_keypackages(&kps, keystore, refs).await?;
208
209                // Let's group KeyPackages by Credential
210                let mut grouped_kps = HashMap::<Vec<u8>, Vec<KeyPackageRef>>::new();
211                for (_, kp) in &kps {
212                    let cred = kp
213                        .leaf_node()
214                        .credential()
215                        .tls_serialize_detached()
216                        .map_err(Error::tls_serialize("keypackage"))?;
217                    let kp_ref = kp
218                        .hash_ref(backend.crypto())
219                        .map_err(MlsError::wrap("computing keypackage hashref"))?;
220                    grouped_kps
221                        .entry(cred)
222                        .and_modify(|kprfs| kprfs.push(kp_ref.clone()))
223                        .or_insert(vec![kp_ref]);
224                }
225
226                for (credential, kps) in &grouped_kps {
227                    // If all KeyPackages are to be deleted for this given Credential
228                    let all_to_delete = kps.iter().all(|kpr| kp_to_delete.contains(&kpr.as_slice()));
229                    if all_to_delete {
230                        // then delete this Credential
231                        backend
232                            .keystore()
233                            .cred_delete_by_credential(credential.clone())
234                            .await
235                            .map_err(KeystoreError::wrap("deleting credential"))?;
236                        let credential = Credential::tls_deserialize(&mut credential.as_slice())
237                            .map_err(Error::tls_deserialize("credential"))?;
238                        identities.remove(&credential).await?;
239                    }
240                }
241                Ok(())
242            }
243        }
244    }
245
246    /// Deletes all expired KeyPackages plus the ones in `refs`. It also deletes all associated:
247    /// * HPKE private keys
248    /// * HPKE Encryption KeyPairs
249    /// * Signature KeyPairs & Credentials (use [Self::prune_keypackages_and_credential])
250    async fn _prune_keypackages<'a>(
251        &self,
252        kps: &'a [(MlsKeyPackage, KeyPackage)],
253        keystore: &CryptoKeystore,
254        refs: &[KeyPackageRef],
255    ) -> Result<Vec<&'a [u8]>, Error> {
256        let kp_to_delete: Vec<_> = kps
257            .iter()
258            .filter_map(|(store_kp, kp)| {
259                let is_expired = Self::is_mls_keypackage_expired(kp);
260                let mut to_delete = is_expired;
261                if !(is_expired || refs.is_empty()) {
262                    // not expired and there are some refs to check
263                    // then delete it when it's found in the refs
264                    to_delete = refs.iter().any(|r| r.as_slice() == store_kp.keypackage_ref);
265                }
266
267                to_delete.then_some((kp, &store_kp.keypackage_ref))
268            })
269            .collect();
270
271        for (kp, kp_ref) in &kp_to_delete {
272            // TODO: maybe rewrite this to optimize it. But honestly it's called so rarely and on a so tiny amount of data. Tacking issue: WPB-9600
273            keystore
274                .remove::<MlsKeyPackage, &[u8]>(kp_ref.as_slice())
275                .await
276                .map_err(KeystoreError::wrap("removing key package from keystore"))?;
277            keystore
278                .remove::<MlsHpkePrivateKey, &[u8]>(kp.hpke_init_key().as_slice())
279                .await
280                .map_err(KeystoreError::wrap("removing private key from keystore"))?;
281            keystore
282                .remove::<MlsEncryptionKeyPair, &[u8]>(kp.leaf_node().encryption_key().as_slice())
283                .await
284                .map_err(KeystoreError::wrap("removing encryption keypair from keystore"))?;
285        }
286
287        let kp_to_delete = kp_to_delete
288            .into_iter()
289            .map(|(_, kpref)| &kpref[..])
290            .collect::<Vec<_>>();
291
292        Ok(kp_to_delete)
293    }
294
295    async fn find_all_keypackages(&self, keystore: &CryptoKeystore) -> Result<Vec<(MlsKeyPackage, KeyPackage)>> {
296        let kps: Vec<MlsKeyPackage> = keystore
297            .find_all(EntityFindParams::default())
298            .await
299            .map_err(KeystoreError::wrap("finding all keypackages"))?;
300
301        let kps = kps
302            .into_iter()
303            .map(|raw_kp| -> Result<_> {
304                let kp = core_crypto_keystore::deser::<KeyPackage>(&raw_kp.keypackage)
305                    .map_err(KeystoreError::wrap("deserializing keypackage"))?;
306                Ok((raw_kp, kp))
307            })
308            .collect::<Result<Vec<_>, _>>()?;
309
310        Ok(kps)
311    }
312
313    /// Allows to set the current default keypackage lifetime extension duration.
314    /// It will be embedded in the [openmls::key_packages::KeyPackage]'s [openmls::extensions::LifetimeExtension]
315    #[cfg(test)]
316    pub async fn set_keypackage_lifetime(&self, duration: std::time::Duration) -> Result<()> {
317        use std::ops::DerefMut;
318        match self.inner.write().await.deref_mut() {
319            None => Err(Error::MlsNotInitialized),
320            Some(SessionInner {
321                keypackage_lifetime, ..
322            }) => {
323                *keypackage_lifetime = duration;
324                Ok(())
325            }
326        }
327    }
328}
329
330impl TransactionContext {
331    /// Returns `amount_requested` OpenMLS [openmls::key_packages::KeyPackage]s.
332    /// Will always return the requested amount as it will generate the necessary (lacking) amount on-the-fly
333    ///
334    /// Note: Keypackage pruning is performed as a first step
335    ///
336    /// # Arguments
337    /// * `amount_requested` - number of KeyPackages to request and fill the `KeyPackageBundle`
338    ///
339    /// # Return type
340    /// A vector of `KeyPackageBundle`
341    ///
342    /// # Errors
343    /// Errors can happen when accessing the KeyStore
344    pub async fn get_or_create_client_keypackages(
345        &self,
346        ciphersuite: MlsCiphersuite,
347        credential_type: MlsCredentialType,
348        amount_requested: usize,
349    ) -> Result<Vec<KeyPackage>> {
350        let client = self
351            .session()
352            .await
353            .map_err(RecursiveError::transaction("getting mls client"))?;
354        client
355            .request_key_packages(
356                amount_requested,
357                ciphersuite,
358                credential_type,
359                &self
360                    .mls_provider()
361                    .await
362                    .map_err(RecursiveError::transaction("getting mls provider"))?,
363            )
364            .await
365    }
366
367    /// Returns the count of valid, non-expired, unclaimed keypackages in store for the given [MlsCiphersuite] and [MlsCredentialType]
368    #[cfg_attr(test, crate::idempotent)]
369    pub async fn client_valid_key_packages_count(
370        &self,
371        ciphersuite: MlsCiphersuite,
372        credential_type: MlsCredentialType,
373    ) -> Result<usize> {
374        let client = self
375            .session()
376            .await
377            .map_err(RecursiveError::transaction("getting mls client"))?;
378        client
379            .valid_keypackages_count(
380                &self
381                    .mls_provider()
382                    .await
383                    .map_err(RecursiveError::transaction("getting mls provider"))?,
384                ciphersuite,
385                credential_type,
386            )
387            .await
388    }
389
390    /// Prunes local KeyPackages after making sure they also have been deleted on the backend side
391    /// You should only use this after [TransactionContext::save_x509_credential]
392    #[cfg_attr(test, crate::dispotent)]
393    pub async fn delete_keypackages(&self, refs: &[KeyPackageRef]) -> Result<()> {
394        if refs.is_empty() {
395            return Err(Error::EmptyKeypackageList);
396        }
397        let mut client = self
398            .session()
399            .await
400            .map_err(RecursiveError::transaction("getting mls client"))?;
401        client
402            .prune_keypackages_and_credential(
403                &self
404                    .mls_provider()
405                    .await
406                    .map_err(RecursiveError::transaction("getting mls provider"))?,
407                refs,
408            )
409            .await
410    }
411}
412
413#[cfg(test)]
414mod tests {
415    use openmls::prelude::{KeyPackage, KeyPackageIn, KeyPackageRef, ProtocolVersion};
416    use openmls_traits::OpenMlsCryptoProvider;
417    use openmls_traits::types::VerifiableCiphersuite;
418    use wasm_bindgen_test::*;
419
420    use mls_crypto_provider::MlsCryptoProvider;
421
422    use crate::e2e_identity::enrollment::test_utils::{e2ei_enrollment, init_activation_or_rotation, noop_restore};
423    use crate::prelude::MlsConversationConfiguration;
424    use crate::prelude::key_package::INITIAL_KEYING_MATERIAL_COUNT;
425    use crate::test_utils::*;
426    use core_crypto_keystore::DatabaseKey;
427
428    use super::Session;
429
430    wasm_bindgen_test_configure!(run_in_browser);
431
    // Checks `Session::is_mls_keypackage_expired`: a KeyPackage generated with the session's
    // current lifetime is not expired, while one generated with a 1-second lifetime is reported
    // as expired after sleeping 2 seconds.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    async fn can_assess_keypackage_expiration(case: TestCase) {
        run_test_with_central(case.clone(), move |[central]| {
            Box::pin(async move {
                let (cs, ct) = (case.ciphersuite(), case.credential_type);
                let key = DatabaseKey::generate();
                let backend = MlsCryptoProvider::try_new_in_memory(&key).await.unwrap();
                // An X509 test chain is only set up (and registered with the provider) for x509 cases
                let x509_test_chain = if case.is_x509() {
                    let x509_test_chain = crate::test_utils::x509::X509TestChain::init_empty(case.signature_scheme());
                    x509_test_chain.register_with_provider(&backend).await;
                    Some(x509_test_chain)
                } else {
                    None
                };

                backend.new_transaction().await.unwrap();
                let client = central.session;
                client
                    .random_generate(
                        &case,
                        x509_test_chain.as_ref().map(|chain| chain.find_local_intermediate_ca()),
                        false,
                    )
                    .await
                    .unwrap();

                // Standard expiration (presumably KEYPACKAGE_DEFAULT_LIFETIME, i.e. 84 days - TODO confirm)
                let kp_std_exp = client.generate_one_keypackage(&backend, cs, ct).await.unwrap();
                assert!(!Session::is_mls_keypackage_expired(&kp_std_exp));

                // 1-second expiration
                client
                    .set_keypackage_lifetime(std::time::Duration::from_secs(1))
                    .await
                    .unwrap();
                let kp_1s_exp = client.generate_one_keypackage(&backend, cs, ct).await.unwrap();
                // Sleep 2 seconds to make sure we make the kp expire
                async_std::task::sleep(std::time::Duration::from_secs(2)).await;
                assert!(Session::is_mls_keypackage_expired(&kp_1s_exp));
            })
        })
        .await;
    }
476
    // After generating Basic KeyPackages and then saving an X509 credential, requesting X509
    // KeyPackages must only return KeyPackages carrying an X509 credential.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    async fn requesting_x509_key_packages_after_basic(case: TestCase) {
        // Only runs for Basic test cases; every other case returns immediately
        if !case.is_basic() {
            return;
        }
        run_test_with_client_ids(case.clone(), ["alice"], move |[mut session_context]| {
            Box::pin(async move {
                let signature_scheme = case.signature_scheme();
                let cipher_suite = case.ciphersuite();

                // Generate 5 Basic key packages first
                let _basic_key_packages = session_context
                    .context
                    .get_or_create_client_keypackages(cipher_suite, MlsCredentialType::Basic, 5)
                    .await
                    .unwrap();

                // Set up E2E identity
                let test_chain = x509::X509TestChain::init_for_random_clients(signature_scheme, 1);

                let (mut enrollment, cert_chain) = e2ei_enrollment(
                    &mut session_context,
                    &case,
                    &test_chain,
                    None,
                    false,
                    init_activation_or_rotation,
                    noop_restore,
                )
                .await
                .unwrap();

                let _rotate_bundle = session_context
                    .context
                    .save_x509_credential(&mut enrollment, cert_chain)
                    .await
                    .unwrap();

                // E2E identity has been set up correctly
                assert!(session_context.context.e2ei_is_enabled(signature_scheme).await.unwrap());

                // Request X509 key packages
                let x509_key_packages = session_context
                    .context
                    .get_or_create_client_keypackages(cipher_suite, MlsCredentialType::X509, 5)
                    .await
                    .unwrap();

                // Verify that the key packages are X509
                assert!(x509_key_packages.iter().all(|kp| MlsCredentialType::X509
                    == MlsCredentialType::from(kp.leaf_node().credential().credential_type())));
            })
        })
        .await
    }
534
    // Requests COUNT + 1 KeyPackages N times, deleting COUNT of them after each round, and checks
    // the entity counts in the store after creation and after the final deletion.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    async fn generates_correct_number_of_kpbs(case: TestCase) {
        run_test_with_client_ids(case.clone(), ["alice"], move |[cc]| {
            Box::pin(async move {
                const N: usize = 2;
                const COUNT: usize = 109;

                let init = cc.context.count_entities().await;
                assert_eq!(init.key_package, INITIAL_KEYING_MATERIAL_COUNT);
                assert_eq!(init.encryption_keypair, INITIAL_KEYING_MATERIAL_COUNT);
                assert_eq!(init.hpke_private_key, INITIAL_KEYING_MATERIAL_COUNT);
                assert_eq!(init.credential, 1);
                assert_eq!(init.signature_keypair, 1);

                // since 'delete_keypackages' will evict all Credentials unlinked to a KeyPackage, each iteration
                // generates 1 extra KeyPackage in order for this Credential not to be evicted and the next iteration to succeed.
                let transactional_provider = cc.context.mls_provider().await.unwrap();
                let crypto_provider = transactional_provider.crypto();
                let mut pinned_kp = None;

                let mut prev_kps: Option<Vec<KeyPackage>> = None;
                for _ in 0..N {
                    let mut kps = cc
                        .context
                        .get_or_create_client_keypackages(case.ciphersuite(), case.credential_type, COUNT + 1)
                        .await
                        .unwrap();

                    // this will always be the same, first KeyPackage
                    pinned_kp = Some(kps.pop().unwrap());

                    assert_eq!(kps.len(), COUNT);
                    let after_creation = cc.context.count_entities().await;
                    assert_eq!(after_creation.key_package, COUNT + 1);
                    assert_eq!(after_creation.encryption_keypair, COUNT + 1);
                    assert_eq!(after_creation.hpke_private_key, COUNT + 1);
                    assert_eq!(after_creation.credential, 1);

                    let kpbs_refs = kps
                        .iter()
                        .map(|kp| kp.hash_ref(crypto_provider).unwrap())
                        .collect::<Vec<KeyPackageRef>>();

                    // From the second round on, compare this batch against the previous (now deleted) one
                    if let Some(pkpbs) = prev_kps.replace(kps) {
                        let pkpbs_refs = pkpbs
                            .into_iter()
                            .map(|kpb| kpb.hash_ref(crypto_provider).unwrap())
                            .collect::<Vec<KeyPackageRef>>();

                        let has_duplicates = kpbs_refs.iter().any(|href| pkpbs_refs.contains(href));
                        // Make sure we have no previous keypackages found (that were pruned) in our new batch of KPs
                        assert!(!has_duplicates);
                    }
                    cc.context.delete_keypackages(&kpbs_refs).await.unwrap();
                }

                // Only the pinned KeyPackage is left at this point
                let count = cc
                    .context
                    .client_valid_key_packages_count(case.ciphersuite(), case.credential_type)
                    .await
                    .unwrap();
                assert_eq!(count, 1);

                // Deleting the pinned KeyPackage leaves no KeyPackage, key material or Credential behind
                let pinned_kpr = pinned_kp.unwrap().hash_ref(crypto_provider).unwrap();
                cc.context.delete_keypackages(&[pinned_kpr]).await.unwrap();
                let count = cc
                    .context
                    .client_valid_key_packages_count(case.ciphersuite(), case.credential_type)
                    .await
                    .unwrap();
                assert_eq!(count, 0);
                let after_delete = cc.context.count_entities().await;
                assert_eq!(after_delete.key_package, 0);
                assert_eq!(after_delete.encryption_keypair, 0);
                assert_eq!(after_delete.hpke_private_key, 0);
                assert_eq!(after_delete.credential, 0);
            })
        })
        .await
    }
616
    // Checks that `request_key_packages` prunes KeyPackages whose lifetime has elapsed and
    // replaces them with fresh ones, while the unexpired ones are kept and returned again.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    async fn automatically_prunes_lifetime_expired_keypackages(case: TestCase) {
        run_test_with_central(case.clone(), move |[central]| {
            Box::pin(async move {
                const UNEXPIRED_COUNT: usize = 125;
                const EXPIRED_COUNT: usize = 200;
                let key = DatabaseKey::generate();
                let backend = MlsCryptoProvider::try_new_in_memory(&key).await.unwrap();
                // An X509 test chain is only set up (and registered with the provider) for x509 cases
                let x509_test_chain = if case.is_x509() {
                    let x509_test_chain = crate::test_utils::x509::X509TestChain::init_empty(case.signature_scheme());
                    x509_test_chain.register_with_provider(&backend).await;
                    Some(x509_test_chain)
                } else {
                    None
                };
                backend.new_transaction().await.unwrap();
                let client = central.session().await;
                client
                    .random_generate(
                        &case,
                        x509_test_chain.as_ref().map(|chain| chain.find_local_intermediate_ca()),
                        false,
                    )
                    .await
                    .unwrap();

                // Generate `UNEXPIRED_COUNT` kpbs with the default (~3 months) expiration. We *should* keep them for the duration of the test
                let unexpired_kpbs = client
                    .request_key_packages(UNEXPIRED_COUNT, case.ciphersuite(), case.credential_type, &backend)
                    .await
                    .unwrap();
                let len = client
                    .valid_keypackages_count(&backend, case.ciphersuite(), case.credential_type)
                    .await
                    .unwrap();
                assert_eq!(len, unexpired_kpbs.len());
                assert_eq!(len, UNEXPIRED_COUNT);

                // Set the keypackage expiration to be in 10 seconds
                client
                    .set_keypackage_lifetime(std::time::Duration::from_secs(10))
                    .await
                    .unwrap();

                // Request more keypackages than are stored: the newly generated part of the batch expires 10s after generation
                let partially_expired_kpbs = client
                    .request_key_packages(EXPIRED_COUNT, case.ciphersuite(), case.credential_type, &backend)
                    .await
                    .unwrap();
                assert_eq!(partially_expired_kpbs.len(), EXPIRED_COUNT);

                // Sleep to trigger the expiration
                async_std::task::sleep(std::time::Duration::from_secs(10)).await;

                // Request the same number of keypackages. The automatic lifetime-based expiration should take
                // place and remove old expired keypackages and generate fresh ones instead
                let fresh_kpbs = client
                    .request_key_packages(EXPIRED_COUNT, case.ciphersuite(), case.credential_type, &backend)
                    .await
                    .unwrap();
                let len = client
                    .valid_keypackages_count(&backend, case.ciphersuite(), case.credential_type)
                    .await
                    .unwrap();
                assert_eq!(len, fresh_kpbs.len());
                assert_eq!(len, EXPIRED_COUNT);

                // Try to deep compare and find kps matching expired and non-expired ones
                let (unexpired_match, expired_match) =
                    fresh_kpbs
                        .iter()
                        .fold((0usize, 0usize), |(mut unexpired_match, mut expired_match), fresh| {
                            if unexpired_kpbs.iter().any(|kp| kp == fresh) {
                                unexpired_match += 1;
                            } else if partially_expired_kpbs.iter().any(|kpb| kpb == fresh) {
                                expired_match += 1;
                            }

                            (unexpired_match, expired_match)
                        });

                // All unexpired keypackages are still returned, and none of the expired ones
                assert_eq!(unexpired_match, UNEXPIRED_COUNT);
                assert_eq!(expired_match, 0);
            })
        })
        .await
    }
706
    // Checks that a freshly requested KeyPackage passes standalone validation, carries no
    // extensions, and advertises the default leaf node capabilities.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    async fn new_keypackage_has_correct_extensions(case: TestCase) {
        run_test_with_client_ids(case.clone(), ["alice"], move |[cc]| {
            Box::pin(async move {
                let kps = cc
                    .context
                    .get_or_create_client_keypackages(case.ciphersuite(), case.credential_type, 1)
                    .await
                    .unwrap();
                let kp = kps.first().unwrap();

                // make sure it's valid
                let _ = KeyPackageIn::from(kp.clone())
                    .standalone_validate(&cc.context.mls_provider().await.unwrap(), ProtocolVersion::Mls10, true)
                    .await
                    .unwrap();

                // see https://www.rfc-editor.org/rfc/rfc9420.html#section-10-10
                assert!(kp.extensions().is_empty());

                // The leaf node capabilities must match the defaults of `MlsConversationConfiguration`
                assert_eq!(kp.leaf_node().capabilities().versions(), &[ProtocolVersion::Mls10]);
                assert_eq!(
                    kp.leaf_node().capabilities().ciphersuites().to_vec(),
                    MlsConversationConfiguration::DEFAULT_SUPPORTED_CIPHERSUITES
                        .iter()
                        .map(|c| VerifiableCiphersuite::from(*c))
                        .collect::<Vec<_>>()
                );
                assert!(kp.leaf_node().capabilities().proposals().is_empty());
                assert!(kp.leaf_node().capabilities().extensions().is_empty());
                assert_eq!(
                    kp.leaf_node().capabilities().credentials(),
                    MlsConversationConfiguration::DEFAULT_SUPPORTED_CREDENTIALS
                );
            })
        })
        .await
    }
746}