core_crypto/mls/client/
key_package.rs

1// Wire
2// Copyright (C) 2022 Wire Swiss GmbH
3
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8
9// This program is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU General Public License for more details.
13
14// You should have received a copy of the GNU General Public License
15// along with this program. If not, see http://www.gnu.org/licenses/.
16
17use openmls::prelude::{Credential, CredentialWithKey, CryptoConfig, KeyPackage, KeyPackageRef, Lifetime};
18use openmls_traits::OpenMlsCryptoProvider;
19use std::collections::HashMap;
20use std::ops::{Deref, DerefMut};
21use tls_codec::{Deserialize, Serialize};
22
23use core_crypto_keystore::{
24    connection::FetchFromDatabase,
25    entities::{EntityFindParams, MlsEncryptionKeyPair, MlsHpkePrivateKey, MlsKeyPackage},
26};
27use mls_crypto_provider::{CryptoKeystore, MlsCryptoProvider};
28
29use crate::context::CentralContext;
30use crate::mls::client::ClientInner;
31use crate::{
32    mls::credential::CredentialBundle,
33    prelude::{
34        Client, CryptoError, CryptoResult, MlsCiphersuite, MlsConversationConfiguration, MlsCredentialType, MlsError,
35    },
36};
37
/// Default number of KeyPackages a client generates the first time it's created:
/// 10 when compiled for tests, 100 otherwise
pub const INITIAL_KEYING_MATERIAL_COUNT: usize = if cfg!(test) { 10 } else { 100 };
43
/// Default lifetime of all generated KeyPackages: 3 periods of 28 days (~3 months).
/// Matches the limit defined in openmls
pub(crate) const KEYPACKAGE_DEFAULT_LIFETIME: std::time::Duration = {
    const SECONDS_PER_DAY: u64 = 60 * 60 * 24;
    std::time::Duration::from_secs(SECONDS_PER_DAY * 28 * 3)
};
47
48impl Client {
49    /// Generates a single new keypackage
50    ///
51    /// # Arguments
52    /// * `backend` - the KeyStorage to load the keypackages from
53    ///
54    /// # Errors
55    /// KeyStore and OpenMls errors
56    pub async fn generate_one_keypackage_from_credential_bundle(
57        &self,
58        backend: &MlsCryptoProvider,
59        cs: MlsCiphersuite,
60        cb: &CredentialBundle,
61    ) -> CryptoResult<KeyPackage> {
62        match self.state.read().await.deref() {
63            None => Err(CryptoError::MlsNotInitialized),
64            Some(ClientInner {
65                keypackage_lifetime, ..
66            }) => {
67                let keypackage = KeyPackage::builder()
68                    .leaf_node_capabilities(MlsConversationConfiguration::default_leaf_capabilities())
69                    .key_package_lifetime(Lifetime::new(keypackage_lifetime.as_secs()))
70                    .build(
71                        CryptoConfig {
72                            ciphersuite: cs.into(),
73                            version: openmls::versions::ProtocolVersion::default(),
74                        },
75                        backend,
76                        &cb.signature_key,
77                        CredentialWithKey {
78                            credential: cb.credential.clone(),
79                            signature_key: cb.signature_key.public().into(),
80                        },
81                    )
82                    .await
83                    .map_err(MlsError::from)?;
84
85                Ok(keypackage)
86            }
87        }
88    }
89
90    /// Requests `count` keying material to be present and returns
91    /// a reference to it for the consumer to copy/clone.
92    ///
93    /// # Arguments
94    /// * `count` - number of [openmls::key_packages::KeyPackage] to generate
95    /// * `ciphersuite` - of [openmls::key_packages::KeyPackage] to generate
96    /// * `backend` - the KeyStorage to load the keypackages from
97    ///
98    /// # Errors
99    /// KeyStore and OpenMls errors
100    pub async fn request_key_packages(
101        &self,
102        count: usize,
103        ciphersuite: MlsCiphersuite,
104        credential_type: MlsCredentialType,
105        backend: &MlsCryptoProvider,
106    ) -> CryptoResult<Vec<KeyPackage>> {
107        // Auto-prune expired keypackages on request
108        self.prune_keypackages(backend, &[]).await?;
109        use core_crypto_keystore::CryptoKeystoreMls as _;
110
111        let mut existing_kps = backend
112            .key_store()
113            .mls_fetch_keypackages::<KeyPackage>(count as u32)
114            .await?
115            .into_iter()
116            // TODO: do this filtering in SQL when the schema is updated. Tracking issue: WPB-9599
117            .filter(|kp|
118                kp.ciphersuite() == ciphersuite.0 && MlsCredentialType::from(kp.leaf_node().credential().credential_type()) == credential_type)
119            .collect::<Vec<_>>();
120
121        let kpb_count = existing_kps.len();
122        let mut kps = if count > kpb_count {
123            let to_generate = count - kpb_count;
124            let cb = self
125                .find_most_recent_credential_bundle(ciphersuite.signature_algorithm(), credential_type)
126                .await?;
127            self.generate_new_keypackages(backend, ciphersuite, &cb, to_generate)
128                .await?
129        } else {
130            vec![]
131        };
132
133        existing_kps.reverse();
134
135        kps.append(&mut existing_kps);
136        Ok(kps)
137    }
138
139    pub(crate) async fn generate_new_keypackages(
140        &self,
141        backend: &MlsCryptoProvider,
142        ciphersuite: MlsCiphersuite,
143        cb: &CredentialBundle,
144        count: usize,
145    ) -> CryptoResult<Vec<KeyPackage>> {
146        let mut kps = Vec::with_capacity(count);
147
148        for _ in 0..count {
149            let kp = self
150                .generate_one_keypackage_from_credential_bundle(backend, ciphersuite, cb)
151                .await?;
152            kps.push(kp);
153        }
154
155        Ok(kps)
156    }
157
158    /// Returns the count of valid, non-expired, unclaimed keypackages in store
159    pub async fn valid_keypackages_count(
160        &self,
161        backend: &MlsCryptoProvider,
162        ciphersuite: MlsCiphersuite,
163        credential_type: MlsCredentialType,
164    ) -> CryptoResult<usize> {
165        let kps: Vec<MlsKeyPackage> = backend.key_store().find_all(EntityFindParams::default()).await?;
166
167        let valid_count = kps
168            .into_iter()
169            .map(|kp| core_crypto_keystore::deser::<KeyPackage>(&kp.keypackage))
170            // TODO: do this filtering in SQL when the schema is updated. Tracking issue: WPB-9599
171            .filter(|kp| {
172                kp.as_ref()
173                    .map(|b| b.ciphersuite() == ciphersuite.0 && MlsCredentialType::from(b.leaf_node().credential().credential_type()) == credential_type)
174                    .unwrap_or_default()
175            })
176            .try_fold(0usize, |mut valid_count, kp| {
177                if !Self::is_mls_keypackage_expired(&kp?) {
178                    valid_count += 1;
179                }
180                CryptoResult::Ok(valid_count)
181            })?;
182
183        Ok(valid_count)
184    }
185
186    /// Checks if a given OpenMLS [`KeyPackage`] is expired by looking through its extensions,
187    /// finding a lifetime extension and checking if it's valid.
188    fn is_mls_keypackage_expired(kp: &KeyPackage) -> bool {
189        let Some(lifetime) = kp.leaf_node().life_time() else {
190            return false;
191        };
192
193        !(lifetime.has_acceptable_range() && lifetime.is_valid())
194    }
195
196    /// Prune the provided KeyPackageRefs from the keystore
197    ///
198    /// Warning: Despite this API being public, the caller should know what they're doing.
199    /// Provided KeypackageRefs **will** be purged regardless of their expiration state, so please be wary of what you are doing if you directly call this API.
200    /// This could result in still valid, uploaded keypackages being pruned from the system and thus being impossible to find when referenced in a future Welcome message.
201    pub async fn prune_keypackages(&self, backend: &MlsCryptoProvider, refs: &[KeyPackageRef]) -> CryptoResult<()> {
202        let keystore = backend.keystore();
203        let kps = self.find_all_keypackages(&keystore).await?;
204        let _ = self._prune_keypackages(&kps, &keystore, refs).await?;
205        Ok(())
206    }
207
208    pub(crate) async fn prune_keypackages_and_credential(
209        &mut self,
210        backend: &MlsCryptoProvider,
211        refs: &[KeyPackageRef],
212    ) -> CryptoResult<()> {
213        match self.state.write().await.deref_mut() {
214            None => Err(CryptoError::MlsNotInitialized),
215            Some(ClientInner { identities, .. }) => {
216                let keystore = backend.key_store();
217                let kps = self.find_all_keypackages(keystore).await?;
218                let kp_to_delete = self._prune_keypackages(&kps, keystore, refs).await?;
219
220                // Let's group KeyPackages by Credential
221                let mut grouped_kps = HashMap::<Vec<u8>, Vec<KeyPackageRef>>::new();
222                for (_, kp) in &kps {
223                    let cred = kp
224                        .leaf_node()
225                        .credential()
226                        .tls_serialize_detached()
227                        .map_err(MlsError::from)?;
228                    let kp_ref = kp.hash_ref(backend.crypto()).map_err(MlsError::from)?;
229                    grouped_kps
230                        .entry(cred)
231                        .and_modify(|kprfs| kprfs.push(kp_ref.clone()))
232                        .or_insert(vec![kp_ref]);
233                }
234
235                for (credential, kps) in &grouped_kps {
236                    // If all KeyPackages are to be deleted for this given Credential
237                    let all_to_delete = kps.iter().all(|kpr| kp_to_delete.contains(&kpr.as_slice()));
238                    if all_to_delete {
239                        // then delete this Credential
240                        backend.keystore().cred_delete_by_credential(credential.clone()).await?;
241                        let credential =
242                            Credential::tls_deserialize(&mut credential.as_slice()).map_err(MlsError::from)?;
243                        identities.remove(&credential).await?;
244                    }
245                }
246                Ok(())
247            }
248        }
249    }
250
251    /// Deletes all expired KeyPackages plus the ones in `refs`. It also deletes all associated:
252    /// * HPKE private keys
253    /// * HPKE Encryption KeyPairs
254    /// * Signature KeyPairs & Credentials (use [Self::prune_keypackages_and_credential])
255    async fn _prune_keypackages<'a>(
256        &self,
257        kps: &'a [(MlsKeyPackage, KeyPackage)],
258        keystore: &CryptoKeystore,
259        refs: &[KeyPackageRef],
260    ) -> Result<Vec<&'a [u8]>, CryptoError> {
261        let kp_to_delete: Vec<_> = kps
262            .iter()
263            .filter_map(|(store_kp, kp)| {
264                let is_expired = Self::is_mls_keypackage_expired(kp);
265                let mut to_delete = is_expired;
266                if !(is_expired || refs.is_empty()) {
267                    // not expired and there are some refs to check
268                    // then delete it when it's found in the refs
269                    to_delete = refs.iter().any(|r| r.as_slice() == store_kp.keypackage_ref);
270                }
271
272                to_delete.then_some((kp, &store_kp.keypackage_ref))
273            })
274            .collect();
275
276        for (kp, kp_ref) in &kp_to_delete {
277            // TODO: maybe rewrite this to optimize it. But honestly it's called so rarely and on a so tiny amount of data. Tacking issue: WPB-9600
278            keystore.remove::<MlsKeyPackage, &[u8]>(kp_ref.as_slice()).await?;
279            keystore
280                .remove::<MlsHpkePrivateKey, &[u8]>(kp.hpke_init_key().as_slice())
281                .await?;
282            keystore
283                .remove::<MlsEncryptionKeyPair, &[u8]>(kp.leaf_node().encryption_key().as_slice())
284                .await?;
285        }
286
287        let kp_to_delete = kp_to_delete
288            .into_iter()
289            .map(|(_, kpref)| &kpref[..])
290            .collect::<Vec<_>>();
291
292        Ok(kp_to_delete)
293    }
294
295    async fn find_all_keypackages(&self, keystore: &CryptoKeystore) -> CryptoResult<Vec<(MlsKeyPackage, KeyPackage)>> {
296        let kps: Vec<MlsKeyPackage> = keystore.find_all(EntityFindParams::default()).await?;
297
298        let kps = kps.into_iter().try_fold(vec![], |mut acc, raw_kp| {
299            let kp = core_crypto_keystore::deser::<KeyPackage>(&raw_kp.keypackage)?;
300            acc.push((raw_kp, kp));
301            CryptoResult::Ok(acc)
302        })?;
303
304        Ok(kps)
305    }
306
307    /// Allows to set the current default keypackage lifetime extension duration.
308    /// It will be embedded in the [openmls::key_packages::KeyPackage]'s [openmls::extensions::LifetimeExtension]
309    #[cfg(test)]
310    pub async fn set_keypackage_lifetime(&self, duration: std::time::Duration) -> CryptoResult<()> {
311        use std::ops::DerefMut;
312        match self.state.write().await.deref_mut() {
313            None => Err(CryptoError::MlsNotInitialized),
314            Some(ClientInner {
315                ref mut keypackage_lifetime,
316                ..
317            }) => {
318                *keypackage_lifetime = duration;
319                Ok(())
320            }
321        }
322    }
323}
324
325impl CentralContext {
326    /// Returns `amount_requested` OpenMLS [openmls::key_packages::KeyPackage]s.
327    /// Will always return the requested amount as it will generate the necessary (lacking) amount on-the-fly
328    ///
329    /// Note: Keypackage pruning is performed as a first step
330    ///
331    /// # Arguments
332    /// * `amount_requested` - number of KeyPackages to request and fill the `KeyPackageBundle`
333    ///
334    /// # Return type
335    /// A vector of `KeyPackageBundle`
336    ///
337    /// # Errors
338    /// Errors can happen when accessing the KeyStore
339    pub async fn get_or_create_client_keypackages(
340        &self,
341        ciphersuite: MlsCiphersuite,
342        credential_type: MlsCredentialType,
343        amount_requested: usize,
344    ) -> CryptoResult<Vec<KeyPackage>> {
345        let client = self.mls_client().await?;
346        client
347            .request_key_packages(
348                amount_requested,
349                ciphersuite,
350                credential_type,
351                &self.mls_provider().await?,
352            )
353            .await
354    }
355
356    /// Returns the count of valid, non-expired, unclaimed keypackages in store for the given [MlsCiphersuite] and [MlsCredentialType]
357    #[cfg_attr(test, crate::idempotent)]
358    pub async fn client_valid_key_packages_count(
359        &self,
360        ciphersuite: MlsCiphersuite,
361        credential_type: MlsCredentialType,
362    ) -> CryptoResult<usize> {
363        let client = self.mls_client().await?;
364        client
365            .valid_keypackages_count(&self.mls_provider().await?, ciphersuite, credential_type)
366            .await
367    }
368
369    /// Prunes local KeyPackages after making sure they also have been deleted on the backend side
370    /// You should only use this after [CentralContext::e2ei_rotate_all]
371    #[cfg_attr(test, crate::dispotent)]
372    pub async fn delete_keypackages(&self, refs: &[KeyPackageRef]) -> CryptoResult<()> {
373        if refs.is_empty() {
374            return Err(CryptoError::ConsumerError);
375        }
376        let mut client = self.mls_client().await?;
377        client
378            .prune_keypackages_and_credential(&self.mls_provider().await?, refs)
379            .await
380    }
381}
382
383#[cfg(test)]
384mod tests {
385    use openmls::prelude::{KeyPackage, KeyPackageIn, KeyPackageRef, ProtocolVersion};
386    use openmls_traits::types::VerifiableCiphersuite;
387    use openmls_traits::OpenMlsCryptoProvider;
388    use wasm_bindgen_test::*;
389
390    use mls_crypto_provider::MlsCryptoProvider;
391
392    use crate::e2e_identity::tests::{e2ei_enrollment, init_activation_or_rotation, noop_restore};
393    use crate::prelude::key_package::INITIAL_KEYING_MATERIAL_COUNT;
394    use crate::prelude::MlsConversationConfiguration;
395    use crate::test_utils::*;
396
397    use super::Client;
398
399    wasm_bindgen_test_configure!(run_in_browser);
400
401    #[apply(all_cred_cipher)]
402    #[wasm_bindgen_test]
403    async fn can_assess_keypackage_expiration(case: TestCase) {
404        let (cs, ct) = (case.ciphersuite(), case.credential_type);
405        let backend = MlsCryptoProvider::try_new_in_memory("test").await.unwrap();
406        let x509_test_chain = if case.is_x509() {
407            let x509_test_chain = crate::test_utils::x509::X509TestChain::init_empty(case.signature_scheme());
408            x509_test_chain.register_with_provider(&backend).await;
409            Some(x509_test_chain)
410        } else {
411            None
412        };
413
414        backend.new_transaction().await.unwrap();
415        let client = Client::random_generate(
416            &case,
417            &backend,
418            x509_test_chain.as_ref().map(|chain| chain.find_local_intermediate_ca()),
419            false,
420        )
421        .await
422        .unwrap();
423
424        // 90-day standard expiration
425        let kp_std_exp = client.generate_one_keypackage(&backend, cs, ct).await.unwrap();
426        assert!(!Client::is_mls_keypackage_expired(&kp_std_exp));
427
428        // 1-second expiration
429        client
430            .set_keypackage_lifetime(std::time::Duration::from_secs(1))
431            .await
432            .unwrap();
433        let kp_1s_exp = client.generate_one_keypackage(&backend, cs, ct).await.unwrap();
434        // Sleep 2 seconds to make sure we make the kp expire
435        async_std::task::sleep(std::time::Duration::from_secs(2)).await;
436        assert!(Client::is_mls_keypackage_expired(&kp_1s_exp));
437    }
438
439    #[apply(all_cred_cipher)]
440    #[wasm_bindgen_test]
441    async fn requesting_x509_key_packages_after_basic(case: TestCase) {
442        // Basic test case
443        if !case.is_basic() {
444            return;
445        }
446        run_test_with_client_ids(case.clone(), ["alice"], move |[mut client_context]| {
447            Box::pin(async move {
448                let signature_scheme = case.signature_scheme();
449                let cipher_suite = case.ciphersuite();
450
451                // Generate 5 Basic key packages first
452                let _basic_key_packages = client_context
453                    .context
454                    .get_or_create_client_keypackages(cipher_suite, MlsCredentialType::Basic, 5)
455                    .await
456                    .unwrap();
457
458                // Set up E2E identity
459                let test_chain = x509::X509TestChain::init_for_random_clients(signature_scheme, 1);
460
461                let (mut enrollment, cert_chain) = e2ei_enrollment(
462                    &mut client_context,
463                    &case,
464                    &test_chain,
465                    None,
466                    false,
467                    init_activation_or_rotation,
468                    noop_restore,
469                )
470                .await
471                .unwrap();
472
473                let _rotate_bundle = client_context
474                    .context
475                    .e2ei_rotate_all(&mut enrollment, cert_chain, 5)
476                    .await
477                    .unwrap();
478
479                // E2E identity has been set up correctly
480                assert!(client_context.context.e2ei_is_enabled(signature_scheme).await.unwrap());
481
482                // Request X509 key packages
483                let x509_key_packages = client_context
484                    .context
485                    .get_or_create_client_keypackages(cipher_suite, MlsCredentialType::X509, 5)
486                    .await
487                    .unwrap();
488
489                // Verify that the key packages are X509
490                assert!(x509_key_packages.iter().all(|kp| MlsCredentialType::X509
491                    == MlsCredentialType::from(kp.leaf_node().credential().credential_type())));
492            })
493        })
494        .await
495    }
496
497    #[apply(all_cred_cipher)]
498    #[wasm_bindgen_test]
499    async fn generates_correct_number_of_kpbs(case: TestCase) {
500        run_test_with_client_ids(case.clone(), ["alice"], move |[cc]| {
501            Box::pin(async move {
502                const N: usize = 2;
503                const COUNT: usize = 109;
504
505                let init = cc.context.count_entities().await;
506                assert_eq!(init.key_package, INITIAL_KEYING_MATERIAL_COUNT);
507                assert_eq!(init.encryption_keypair, INITIAL_KEYING_MATERIAL_COUNT);
508                assert_eq!(init.hpke_private_key, INITIAL_KEYING_MATERIAL_COUNT);
509                assert_eq!(init.credential, 1);
510                assert_eq!(init.signature_keypair, 1);
511
512                // since 'delete_keypackages' will evict all Credentials unlinked to a KeyPackage, each iteration
513                // generates 1 extra KeyPackage in order for this Credential no to be evicted and next iteration sto succeed.
514                let transactional_provider = cc.context.mls_provider().await.unwrap();
515                let crypto_provider = transactional_provider.crypto();
516                let mut pinned_kp = None;
517
518                let mut prev_kps: Option<Vec<KeyPackage>> = None;
519                for _ in 0..N {
520                    let mut kps = cc
521                        .context
522                        .get_or_create_client_keypackages(case.ciphersuite(), case.credential_type, COUNT + 1)
523                        .await
524                        .unwrap();
525
526                    // this will always be the same, first KeyPackage
527                    pinned_kp = Some(kps.pop().unwrap());
528
529                    assert_eq!(kps.len(), COUNT);
530                    let after_creation = cc.context.count_entities().await;
531                    assert_eq!(after_creation.key_package, COUNT + 1);
532                    assert_eq!(after_creation.encryption_keypair, COUNT + 1);
533                    assert_eq!(after_creation.hpke_private_key, COUNT + 1);
534                    assert_eq!(after_creation.credential, 1);
535
536                    let kpbs_refs = kps
537                        .iter()
538                        .map(|kp| kp.hash_ref(crypto_provider).unwrap())
539                        .collect::<Vec<KeyPackageRef>>();
540
541                    if let Some(pkpbs) = prev_kps.replace(kps) {
542                        let pkpbs_refs = pkpbs
543                            .into_iter()
544                            .map(|kpb| kpb.hash_ref(crypto_provider).unwrap())
545                            .collect::<Vec<KeyPackageRef>>();
546
547                        let has_duplicates = kpbs_refs.iter().any(|href| pkpbs_refs.contains(href));
548                        // Make sure we have no previous keypackages found (that were pruned) in our new batch of KPs
549                        assert!(!has_duplicates);
550                    }
551                    cc.context.delete_keypackages(&kpbs_refs).await.unwrap();
552                }
553
554                let count = cc
555                    .context
556                    .client_valid_key_packages_count(case.ciphersuite(), case.credential_type)
557                    .await
558                    .unwrap();
559                assert_eq!(count, 1);
560
561                let pinned_kpr = pinned_kp.unwrap().hash_ref(crypto_provider).unwrap();
562                cc.context.delete_keypackages(&[pinned_kpr]).await.unwrap();
563                let count = cc
564                    .context
565                    .client_valid_key_packages_count(case.ciphersuite(), case.credential_type)
566                    .await
567                    .unwrap();
568                assert_eq!(count, 0);
569                let after_delete = cc.context.count_entities().await;
570                assert_eq!(after_delete.key_package, 0);
571                assert_eq!(after_delete.encryption_keypair, 0);
572                assert_eq!(after_delete.hpke_private_key, 0);
573                assert_eq!(after_delete.credential, 0);
574            })
575        })
576        .await
577    }
578
579    #[apply(all_cred_cipher)]
580    #[wasm_bindgen_test]
581    async fn automatically_prunes_lifetime_expired_keypackages(case: TestCase) {
582        const UNEXPIRED_COUNT: usize = 125;
583        const EXPIRED_COUNT: usize = 200;
584        let backend = MlsCryptoProvider::try_new_in_memory("test").await.unwrap();
585        let x509_test_chain = if case.is_x509() {
586            let x509_test_chain = crate::test_utils::x509::X509TestChain::init_empty(case.signature_scheme());
587            x509_test_chain.register_with_provider(&backend).await;
588            Some(x509_test_chain)
589        } else {
590            None
591        };
592        backend.new_transaction().await.unwrap();
593        let client = Client::random_generate(
594            &case,
595            &backend,
596            x509_test_chain.as_ref().map(|chain| chain.find_local_intermediate_ca()),
597            false,
598        )
599        .await
600        .unwrap();
601
602        // Generate `UNEXPIRED_COUNT` kpbs that are with default 3 months expiration. We *should* keep them for the duration of the test
603        let unexpired_kpbs = client
604            .request_key_packages(UNEXPIRED_COUNT, case.ciphersuite(), case.credential_type, &backend)
605            .await
606            .unwrap();
607        let len = client
608            .valid_keypackages_count(&backend, case.ciphersuite(), case.credential_type)
609            .await
610            .unwrap();
611        assert_eq!(len, unexpired_kpbs.len());
612        assert_eq!(len, UNEXPIRED_COUNT);
613
614        // Set the keypackage expiration to be in 2 seconds
615        client
616            .set_keypackage_lifetime(std::time::Duration::from_secs(10))
617            .await
618            .unwrap();
619
620        // Generate new keypackages that are normally partially expired 2s after they're requested
621        let partially_expired_kpbs = client
622            .request_key_packages(EXPIRED_COUNT, case.ciphersuite(), case.credential_type, &backend)
623            .await
624            .unwrap();
625        assert_eq!(partially_expired_kpbs.len(), EXPIRED_COUNT);
626
627        // Sleep to trigger the expiration
628        async_std::task::sleep(std::time::Duration::from_secs(10)).await;
629
630        // Request the same number of keypackages. The automatic lifetime-based expiration should take
631        // place and remove old expired keypackages and generate fresh ones instead
632        let fresh_kpbs = client
633            .request_key_packages(EXPIRED_COUNT, case.ciphersuite(), case.credential_type, &backend)
634            .await
635            .unwrap();
636        let len = client
637            .valid_keypackages_count(&backend, case.ciphersuite(), case.credential_type)
638            .await
639            .unwrap();
640        assert_eq!(len, fresh_kpbs.len());
641        assert_eq!(len, EXPIRED_COUNT);
642
643        // Try to deep compare and find kps matching expired and non-expired ones
644        let (unexpired_match, expired_match) =
645            fresh_kpbs
646                .iter()
647                .fold((0usize, 0usize), |(mut unexpired_match, mut expired_match), fresh| {
648                    if unexpired_kpbs.iter().any(|kp| kp == fresh) {
649                        unexpired_match += 1;
650                    } else if partially_expired_kpbs.iter().any(|kpb| kpb == fresh) {
651                        expired_match += 1;
652                    }
653
654                    (unexpired_match, expired_match)
655                });
656
657        // TADA!
658        assert_eq!(unexpired_match, UNEXPIRED_COUNT);
659        assert_eq!(expired_match, 0);
660    }
661
662    #[apply(all_cred_cipher)]
663    #[wasm_bindgen_test]
664    async fn new_keypackage_has_correct_extensions(case: TestCase) {
665        run_test_with_client_ids(case.clone(), ["alice"], move |[cc]| {
666            Box::pin(async move {
667                let kps = cc
668                    .context
669                    .get_or_create_client_keypackages(case.ciphersuite(), case.credential_type, 1)
670                    .await
671                    .unwrap();
672                let kp = kps.first().unwrap();
673
674                // make sure it's valid
675                let _ = KeyPackageIn::from(kp.clone())
676                    .standalone_validate(&cc.context.mls_provider().await.unwrap(), ProtocolVersion::Mls10, true)
677                    .await
678                    .unwrap();
679
680                // see https://www.rfc-editor.org/rfc/rfc9420.html#section-10-10
681                assert!(kp.extensions().is_empty());
682
683                assert_eq!(kp.leaf_node().capabilities().versions(), &[ProtocolVersion::Mls10]);
684                assert_eq!(
685                    kp.leaf_node().capabilities().ciphersuites().to_vec(),
686                    MlsConversationConfiguration::DEFAULT_SUPPORTED_CIPHERSUITES
687                        .iter()
688                        .map(|c| VerifiableCiphersuite::from(*c))
689                        .collect::<Vec<_>>()
690                );
691                assert!(kp.leaf_node().capabilities().proposals().is_empty());
692                assert!(kp.leaf_node().capabilities().extensions().is_empty());
693                assert_eq!(
694                    kp.leaf_node().capabilities().credentials(),
695                    MlsConversationConfiguration::DEFAULT_SUPPORTED_CREDENTIALS
696                );
697            })
698        })
699        .await
700    }
701}