// core_crypto/mls/session/key_package.rs

1use openmls::prelude::{Credential, CredentialWithKey, CryptoConfig, KeyPackage, KeyPackageRef, Lifetime};
2use openmls_traits::OpenMlsCryptoProvider;
3use std::collections::HashMap;
4use std::ops::{Deref, DerefMut};
5use tls_codec::{Deserialize, Serialize};
6
7use core_crypto_keystore::{
8    connection::FetchFromDatabase,
9    entities::{EntityFindParams, MlsEncryptionKeyPair, MlsHpkePrivateKey, MlsKeyPackage},
10};
11use mls_crypto_provider::{CryptoKeystore, MlsCryptoProvider};
12
13use super::{Error, Result};
14use crate::{
15    KeystoreError, MlsError,
16    mls::{credential::CredentialBundle, session::SessionInner},
17    prelude::{MlsCiphersuite, MlsConversationConfiguration, MlsCredentialType, Session},
18};
19
/// Default number of KeyPackages a client generates the first time it's created.
#[cfg(not(test))]
pub const INITIAL_KEYING_MATERIAL_COUNT: usize = 100;
/// Default number of KeyPackages a client generates the first time it's created.
///
/// Test builds use a smaller count than production builds.
#[cfg(test)]
pub const INITIAL_KEYING_MATERIAL_COUNT: usize = 10;

/// Default lifetime of all generated KeyPackages: 12 weeks (84 days, i.e. ~3 months).
/// Matches the limit defined in openmls.
pub(crate) const KEYPACKAGE_DEFAULT_LIFETIME: std::time::Duration = {
    const SECONDS_PER_DAY: u64 = 24 * 60 * 60;
    std::time::Duration::from_secs(SECONDS_PER_DAY * 7 * 12)
};
30
31impl Session {
32    /// Generates a single new keypackage
33    ///
34    /// # Arguments
35    /// * `backend` - the KeyStorage to load the keypackages from
36    ///
37    /// # Errors
38    /// KeyStore and OpenMls errors
39    pub async fn generate_one_keypackage_from_credential_bundle(
40        &self,
41        backend: &MlsCryptoProvider,
42        cs: MlsCiphersuite,
43        cb: &CredentialBundle,
44    ) -> Result<KeyPackage> {
45        match self.inner.read().await.deref() {
46            None => Err(Error::MlsNotInitialized),
47            Some(SessionInner {
48                keypackage_lifetime, ..
49            }) => {
50                let keypackage = KeyPackage::builder()
51                    .leaf_node_capabilities(MlsConversationConfiguration::default_leaf_capabilities())
52                    .key_package_lifetime(Lifetime::new(keypackage_lifetime.as_secs()))
53                    .build(
54                        CryptoConfig {
55                            ciphersuite: cs.into(),
56                            version: openmls::versions::ProtocolVersion::default(),
57                        },
58                        backend,
59                        &cb.signature_key,
60                        CredentialWithKey {
61                            credential: cb.credential.clone(),
62                            signature_key: cb.signature_key.public().into(),
63                        },
64                    )
65                    .await
66                    .map_err(KeystoreError::wrap("building keypackage"))?;
67
68                Ok(keypackage)
69            }
70        }
71    }
72
73    /// Requests `count` keying material to be present and returns
74    /// a reference to it for the consumer to copy/clone.
75    ///
76    /// # Arguments
77    /// * `count` - number of [openmls::key_packages::KeyPackage] to generate
78    /// * `ciphersuite` - of [openmls::key_packages::KeyPackage] to generate
79    /// * `backend` - the KeyStorage to load the keypackages from
80    ///
81    /// # Errors
82    /// KeyStore and OpenMls errors
83    pub async fn request_key_packages(
84        &self,
85        count: usize,
86        ciphersuite: MlsCiphersuite,
87        credential_type: MlsCredentialType,
88        backend: &MlsCryptoProvider,
89    ) -> Result<Vec<KeyPackage>> {
90        // Auto-prune expired keypackages on request
91        self.prune_keypackages(backend, &[]).await?;
92        use core_crypto_keystore::CryptoKeystoreMls as _;
93
94        let mut existing_kps = backend
95            .key_store()
96            .mls_fetch_keypackages::<KeyPackage>(count as u32)
97            .await.map_err(KeystoreError::wrap("fetching mls keypackages"))?
98            .into_iter()
99            // TODO: do this filtering in SQL when the schema is updated. Tracking issue: WPB-9599
100            .filter(|kp|
101                kp.ciphersuite() == ciphersuite.0 && MlsCredentialType::from(kp.leaf_node().credential().credential_type()) == credential_type)
102            .collect::<Vec<_>>();
103
104        let kpb_count = existing_kps.len();
105        let mut kps = if count > kpb_count {
106            let to_generate = count - kpb_count;
107            let cb = self
108                .find_most_recent_credential_bundle(ciphersuite.signature_algorithm(), credential_type)
109                .await?;
110            self.generate_new_keypackages(backend, ciphersuite, &cb, to_generate)
111                .await?
112        } else {
113            vec![]
114        };
115
116        existing_kps.reverse();
117
118        kps.append(&mut existing_kps);
119        Ok(kps)
120    }
121
122    pub(crate) async fn generate_new_keypackages(
123        &self,
124        backend: &MlsCryptoProvider,
125        ciphersuite: MlsCiphersuite,
126        cb: &CredentialBundle,
127        count: usize,
128    ) -> Result<Vec<KeyPackage>> {
129        let mut kps = Vec::with_capacity(count);
130
131        for _ in 0..count {
132            let kp = self
133                .generate_one_keypackage_from_credential_bundle(backend, ciphersuite, cb)
134                .await?;
135            kps.push(kp);
136        }
137
138        Ok(kps)
139    }
140
141    /// Returns the count of valid, non-expired, unclaimed keypackages in store
142    pub async fn valid_keypackages_count(
143        &self,
144        backend: &MlsCryptoProvider,
145        ciphersuite: MlsCiphersuite,
146        credential_type: MlsCredentialType,
147    ) -> Result<usize> {
148        let kps: Vec<MlsKeyPackage> = backend
149            .key_store()
150            .find_all(EntityFindParams::default())
151            .await
152            .map_err(KeystoreError::wrap("finding all key packages"))?;
153
154        let mut valid_count = 0;
155        for kp in kps
156            .into_iter()
157            .map(|kp| core_crypto_keystore::deser::<KeyPackage>(&kp.keypackage))
158            // TODO: do this filtering in SQL when the schema is updated. Tracking issue: WPB-9599
159            .filter(|kp| {
160                kp.as_ref()
161                    .map(|b| b.ciphersuite() == ciphersuite.0 && MlsCredentialType::from(b.leaf_node().credential().credential_type()) == credential_type)
162                    .unwrap_or_default()
163            })
164        {
165            let kp = kp.map_err(KeystoreError::wrap("counting valid keypackages"))?;
166            if !Self::is_mls_keypackage_expired(&kp) {
167                valid_count += 1;
168            }
169        }
170
171        Ok(valid_count)
172    }
173
174    /// Checks if a given OpenMLS [`KeyPackage`] is expired by looking through its extensions,
175    /// finding a lifetime extension and checking if it's valid.
176    fn is_mls_keypackage_expired(kp: &KeyPackage) -> bool {
177        let Some(lifetime) = kp.leaf_node().life_time() else {
178            return false;
179        };
180
181        !(lifetime.has_acceptable_range() && lifetime.is_valid())
182    }
183
184    /// Prune the provided KeyPackageRefs from the keystore
185    ///
186    /// Warning: Despite this API being public, the caller should know what they're doing.
187    /// Provided KeypackageRefs **will** be purged regardless of their expiration state, so please be wary of what you are doing if you directly call this API.
188    /// This could result in still valid, uploaded keypackages being pruned from the system and thus being impossible to find when referenced in a future Welcome message.
189    pub async fn prune_keypackages(&self, backend: &MlsCryptoProvider, refs: &[KeyPackageRef]) -> Result<()> {
190        let keystore = backend.keystore();
191        let kps = self.find_all_keypackages(&keystore).await?;
192        let _ = self._prune_keypackages(&kps, &keystore, refs).await?;
193        Ok(())
194    }
195
196    pub(crate) async fn prune_keypackages_and_credential(
197        &mut self,
198        backend: &MlsCryptoProvider,
199        refs: &[KeyPackageRef],
200    ) -> Result<()> {
201        match self.inner.write().await.deref_mut() {
202            None => Err(Error::MlsNotInitialized),
203            Some(SessionInner { identities, .. }) => {
204                let keystore = backend.key_store();
205                let kps = self.find_all_keypackages(keystore).await?;
206                let kp_to_delete = self._prune_keypackages(&kps, keystore, refs).await?;
207
208                // Let's group KeyPackages by Credential
209                let mut grouped_kps = HashMap::<Vec<u8>, Vec<KeyPackageRef>>::new();
210                for (_, kp) in &kps {
211                    let cred = kp
212                        .leaf_node()
213                        .credential()
214                        .tls_serialize_detached()
215                        .map_err(Error::tls_serialize("keypackage"))?;
216                    let kp_ref = kp
217                        .hash_ref(backend.crypto())
218                        .map_err(MlsError::wrap("computing keypackage hashref"))?;
219                    grouped_kps
220                        .entry(cred)
221                        .and_modify(|kprfs| kprfs.push(kp_ref.clone()))
222                        .or_insert(vec![kp_ref]);
223                }
224
225                for (credential, kps) in &grouped_kps {
226                    // If all KeyPackages are to be deleted for this given Credential
227                    let all_to_delete = kps.iter().all(|kpr| kp_to_delete.contains(&kpr.as_slice()));
228                    if all_to_delete {
229                        // then delete this Credential
230                        backend
231                            .keystore()
232                            .cred_delete_by_credential(credential.clone())
233                            .await
234                            .map_err(KeystoreError::wrap("deleting credential"))?;
235                        let credential = Credential::tls_deserialize(&mut credential.as_slice())
236                            .map_err(Error::tls_deserialize("credential"))?;
237                        identities.remove(&credential).await?;
238                    }
239                }
240                Ok(())
241            }
242        }
243    }
244
245    /// Deletes all expired KeyPackages plus the ones in `refs`. It also deletes all associated:
246    /// * HPKE private keys
247    /// * HPKE Encryption KeyPairs
248    /// * Signature KeyPairs & Credentials (use [Self::prune_keypackages_and_credential])
249    async fn _prune_keypackages<'a>(
250        &self,
251        kps: &'a [(MlsKeyPackage, KeyPackage)],
252        keystore: &CryptoKeystore,
253        refs: &[KeyPackageRef],
254    ) -> Result<Vec<&'a [u8]>, Error> {
255        let kp_to_delete: Vec<_> = kps
256            .iter()
257            .filter_map(|(store_kp, kp)| {
258                let is_expired = Self::is_mls_keypackage_expired(kp);
259                let mut to_delete = is_expired;
260                if !(is_expired || refs.is_empty()) {
261                    // not expired and there are some refs to check
262                    // then delete it when it's found in the refs
263                    to_delete = refs.iter().any(|r| r.as_slice() == store_kp.keypackage_ref);
264                }
265
266                to_delete.then_some((kp, &store_kp.keypackage_ref))
267            })
268            .collect();
269
270        for (kp, kp_ref) in &kp_to_delete {
271            // TODO: maybe rewrite this to optimize it. But honestly it's called so rarely and on a so tiny amount of data. Tacking issue: WPB-9600
272            keystore
273                .remove::<MlsKeyPackage, &[u8]>(kp_ref.as_slice())
274                .await
275                .map_err(KeystoreError::wrap("removing key package from keystore"))?;
276            keystore
277                .remove::<MlsHpkePrivateKey, &[u8]>(kp.hpke_init_key().as_slice())
278                .await
279                .map_err(KeystoreError::wrap("removing private key from keystore"))?;
280            keystore
281                .remove::<MlsEncryptionKeyPair, &[u8]>(kp.leaf_node().encryption_key().as_slice())
282                .await
283                .map_err(KeystoreError::wrap("removing encryption keypair from keystore"))?;
284        }
285
286        let kp_to_delete = kp_to_delete
287            .into_iter()
288            .map(|(_, kpref)| &kpref[..])
289            .collect::<Vec<_>>();
290
291        Ok(kp_to_delete)
292    }
293
294    async fn find_all_keypackages(&self, keystore: &CryptoKeystore) -> Result<Vec<(MlsKeyPackage, KeyPackage)>> {
295        let kps: Vec<MlsKeyPackage> = keystore
296            .find_all(EntityFindParams::default())
297            .await
298            .map_err(KeystoreError::wrap("finding all keypackages"))?;
299
300        let kps = kps
301            .into_iter()
302            .map(|raw_kp| -> Result<_> {
303                let kp = core_crypto_keystore::deser::<KeyPackage>(&raw_kp.keypackage)
304                    .map_err(KeystoreError::wrap("deserializing keypackage"))?;
305                Ok((raw_kp, kp))
306            })
307            .collect::<Result<Vec<_>, _>>()?;
308
309        Ok(kps)
310    }
311
312    /// Allows to set the current default keypackage lifetime extension duration.
313    /// It will be embedded in the [openmls::key_packages::KeyPackage]'s [openmls::extensions::LifetimeExtension]
314    #[cfg(test)]
315    pub async fn set_keypackage_lifetime(&self, duration: std::time::Duration) -> Result<()> {
316        use std::ops::DerefMut;
317        match self.inner.write().await.deref_mut() {
318            None => Err(Error::MlsNotInitialized),
319            Some(SessionInner {
320                keypackage_lifetime, ..
321            }) => {
322                *keypackage_lifetime = duration;
323                Ok(())
324            }
325        }
326    }
327}
328
#[cfg(test)]
mod tests {
    // Tests for keypackage generation, expiration checks, counting and pruning.
    use openmls::prelude::{KeyPackage, KeyPackageIn, KeyPackageRef, ProtocolVersion};
    use openmls_traits::OpenMlsCryptoProvider;
    use openmls_traits::types::VerifiableCiphersuite;
    use wasm_bindgen_test::*;

    use mls_crypto_provider::MlsCryptoProvider;

    use crate::e2e_identity::enrollment::test_utils::{e2ei_enrollment, init_activation_or_rotation, noop_restore};
    use crate::prelude::MlsConversationConfiguration;
    use crate::prelude::key_package::INITIAL_KEYING_MATERIAL_COUNT;
    use crate::test_utils::*;
    use core_crypto_keystore::DatabaseKey;

    use super::Session;

    wasm_bindgen_test_configure!(run_in_browser);

    // A keypackage with the default lifetime is not expired; one with a 1s lifetime is
    // expired once 2s have elapsed.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    async fn can_assess_keypackage_expiration(case: TestContext) {
        let [session] = case.sessions().await;
        let (cs, ct) = (case.ciphersuite(), case.credential_type);
        let key = DatabaseKey::generate();
        let backend = MlsCryptoProvider::try_new_in_memory(&key).await.unwrap();
        let x509_test_chain = if case.is_x509() {
            let x509_test_chain = crate::test_utils::x509::X509TestChain::init_empty(case.signature_scheme());
            x509_test_chain.register_with_provider(&backend).await;
            Some(x509_test_chain)
        } else {
            None
        };

        backend.new_transaction().await.unwrap();
        let session = session.session;
        session
            .random_generate(
                &case,
                x509_test_chain.as_ref().map(|chain| chain.find_local_intermediate_ca()),
                false,
            )
            .await
            .unwrap();

        // Default (~3 months) expiration
        let kp_std_exp = session.generate_one_keypackage(&backend, cs, ct).await.unwrap();
        assert!(!Session::is_mls_keypackage_expired(&kp_std_exp));

        // 1-second expiration
        session
            .set_keypackage_lifetime(std::time::Duration::from_secs(1))
            .await
            .unwrap();
        let kp_1s_exp = session.generate_one_keypackage(&backend, cs, ct).await.unwrap();
        // Sleep 2 seconds to make sure we make the kp expire
        async_std::task::sleep(std::time::Duration::from_secs(2)).await;
        assert!(Session::is_mls_keypackage_expired(&kp_1s_exp));
    }

    // After generating Basic keypackages and enrolling an X509 credential, requesting X509
    // keypackages must only return keypackages carrying an X509 credential.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    async fn requesting_x509_key_packages_after_basic(case: TestContext) {
        // Basic test case
        if !case.is_basic() {
            return;
        }
        run_test_with_client_ids(case.clone(), ["alice"], move |[mut session_context]| {
            Box::pin(async move {
                let signature_scheme = case.signature_scheme();
                let cipher_suite = case.ciphersuite();

                // Generate 5 Basic key packages first
                let _basic_key_packages = session_context
                    .transaction
                    .get_or_create_client_keypackages(cipher_suite, MlsCredentialType::Basic, 5)
                    .await
                    .unwrap();

                // Set up E2E identity
                let test_chain = x509::X509TestChain::init_for_random_clients(signature_scheme, 1);

                let (mut enrollment, cert_chain) = e2ei_enrollment(
                    &mut session_context,
                    &case,
                    &test_chain,
                    None,
                    false,
                    init_activation_or_rotation,
                    noop_restore,
                )
                .await
                .unwrap();

                let _rotate_bundle = session_context
                    .transaction
                    .save_x509_credential(&mut enrollment, cert_chain)
                    .await
                    .unwrap();

                // E2E identity has been set up correctly
                assert!(
                    session_context
                        .transaction
                        .e2ei_is_enabled(signature_scheme)
                        .await
                        .unwrap()
                );

                // Request X509 key packages
                let x509_key_packages = session_context
                    .transaction
                    .get_or_create_client_keypackages(cipher_suite, MlsCredentialType::X509, 5)
                    .await
                    .unwrap();

                // Verify that the key packages are X509
                assert!(x509_key_packages.iter().all(|kp| MlsCredentialType::X509
                    == MlsCredentialType::from(kp.leaf_node().credential().credential_type())));
            })
        })
        .await
    }

    // Checks keystore entity counts across repeated generate/delete cycles and that deleted
    // keypackages are never handed out again.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    async fn generates_correct_number_of_kpbs(case: TestContext) {
        run_test_with_client_ids(case.clone(), ["alice"], move |[cc]| {
            Box::pin(async move {
                const N: usize = 2;
                const COUNT: usize = 109;

                let init = cc.transaction.count_entities().await;
                assert_eq!(init.key_package, INITIAL_KEYING_MATERIAL_COUNT);
                assert_eq!(init.encryption_keypair, INITIAL_KEYING_MATERIAL_COUNT);
                assert_eq!(init.hpke_private_key, INITIAL_KEYING_MATERIAL_COUNT);
                assert_eq!(init.credential, 1);
                assert_eq!(init.signature_keypair, 1);

                // since 'delete_keypackages' will evict all Credentials unlinked to a KeyPackage, each iteration
                // generates 1 extra KeyPackage in order for this Credential not to be evicted and for the next iteration to succeed.
                let transactional_provider = cc.transaction.mls_provider().await.unwrap();
                let crypto_provider = transactional_provider.crypto();
                let mut pinned_kp = None;

                let mut prev_kps: Option<Vec<KeyPackage>> = None;
                for _ in 0..N {
                    let mut kps = cc
                        .transaction
                        .get_or_create_client_keypackages(case.ciphersuite(), case.credential_type, COUNT + 1)
                        .await
                        .unwrap();

                    // this will always be the same, first KeyPackage
                    pinned_kp = Some(kps.pop().unwrap());

                    assert_eq!(kps.len(), COUNT);
                    let after_creation = cc.transaction.count_entities().await;
                    assert_eq!(after_creation.key_package, COUNT + 1);
                    assert_eq!(after_creation.encryption_keypair, COUNT + 1);
                    assert_eq!(after_creation.hpke_private_key, COUNT + 1);
                    assert_eq!(after_creation.credential, 1);

                    let kpbs_refs = kps
                        .iter()
                        .map(|kp| kp.hash_ref(crypto_provider).unwrap())
                        .collect::<Vec<KeyPackageRef>>();

                    if let Some(pkpbs) = prev_kps.replace(kps) {
                        let pkpbs_refs = pkpbs
                            .into_iter()
                            .map(|kpb| kpb.hash_ref(crypto_provider).unwrap())
                            .collect::<Vec<KeyPackageRef>>();

                        let has_duplicates = kpbs_refs.iter().any(|href| pkpbs_refs.contains(href));
                        // Make sure we have no previous keypackages found (that were pruned) in our new batch of KPs
                        assert!(!has_duplicates);
                    }
                    cc.transaction.delete_keypackages(&kpbs_refs).await.unwrap();
                }

                let count = cc
                    .transaction
                    .client_valid_key_packages_count(case.ciphersuite(), case.credential_type)
                    .await
                    .unwrap();
                assert_eq!(count, 1);

                let pinned_kpr = pinned_kp.unwrap().hash_ref(crypto_provider).unwrap();
                cc.transaction.delete_keypackages(&[pinned_kpr]).await.unwrap();
                let count = cc
                    .transaction
                    .client_valid_key_packages_count(case.ciphersuite(), case.credential_type)
                    .await
                    .unwrap();
                assert_eq!(count, 0);
                let after_delete = cc.transaction.count_entities().await;
                assert_eq!(after_delete.key_package, 0);
                assert_eq!(after_delete.encryption_keypair, 0);
                assert_eq!(after_delete.hpke_private_key, 0);
                assert_eq!(after_delete.credential, 0);
            })
        })
        .await
    }

    // Short-lived keypackages are pruned on the next request and replaced by fresh ones,
    // while unexpired keypackages are kept and handed out again.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    async fn automatically_prunes_lifetime_expired_keypackages(case: TestContext) {
        let [session] = case.sessions().await;
        const UNEXPIRED_COUNT: usize = 125;
        const EXPIRED_COUNT: usize = 200;
        let key = DatabaseKey::generate();
        let backend = MlsCryptoProvider::try_new_in_memory(&key).await.unwrap();
        let x509_test_chain = if case.is_x509() {
            let x509_test_chain = crate::test_utils::x509::X509TestChain::init_empty(case.signature_scheme());
            x509_test_chain.register_with_provider(&backend).await;
            Some(x509_test_chain)
        } else {
            None
        };
        backend.new_transaction().await.unwrap();
        let session = session.session().await;
        session
            .random_generate(
                &case,
                x509_test_chain.as_ref().map(|chain| chain.find_local_intermediate_ca()),
                false,
            )
            .await
            .unwrap();

        // Generate `UNEXPIRED_COUNT` kpbs that are with default 3 months expiration. We *should* keep them for the duration of the test
        let unexpired_kpbs = session
            .request_key_packages(UNEXPIRED_COUNT, case.ciphersuite(), case.credential_type, &backend)
            .await
            .unwrap();
        let len = session
            .valid_keypackages_count(&backend, case.ciphersuite(), case.credential_type)
            .await
            .unwrap();
        assert_eq!(len, unexpired_kpbs.len());
        assert_eq!(len, UNEXPIRED_COUNT);

        // Set the keypackage expiration to be in 10 seconds
        session
            .set_keypackage_lifetime(std::time::Duration::from_secs(10))
            .await
            .unwrap();

        // Top up to `EXPIRED_COUNT` keypackages: the stored unexpired ones are reused and only the
        // missing ones are generated, with the 10s lifetime (hence "partially" expired later on)
        let partially_expired_kpbs = session
            .request_key_packages(EXPIRED_COUNT, case.ciphersuite(), case.credential_type, &backend)
            .await
            .unwrap();
        assert_eq!(partially_expired_kpbs.len(), EXPIRED_COUNT);

        // Sleep for the whole 10s lifetime to trigger the expiration
        async_std::task::sleep(std::time::Duration::from_secs(10)).await;

        // Request the same number of keypackages. The automatic lifetime-based expiration should take
        // place and remove old expired keypackages and generate fresh ones instead
        let fresh_kpbs = session
            .request_key_packages(EXPIRED_COUNT, case.ciphersuite(), case.credential_type, &backend)
            .await
            .unwrap();
        let len = session
            .valid_keypackages_count(&backend, case.ciphersuite(), case.credential_type)
            .await
            .unwrap();
        assert_eq!(len, fresh_kpbs.len());
        assert_eq!(len, EXPIRED_COUNT);

        // Try to deep compare and find kps matching expired and non-expired ones
        let (unexpired_match, expired_match) =
            fresh_kpbs
                .iter()
                .fold((0usize, 0usize), |(mut unexpired_match, mut expired_match), fresh| {
                    if unexpired_kpbs.iter().any(|kp| kp == fresh) {
                        unexpired_match += 1;
                    } else if partially_expired_kpbs.iter().any(|kpb| kpb == fresh) {
                        expired_match += 1;
                    }

                    (unexpired_match, expired_match)
                });

        // TADA!
        assert_eq!(unexpired_match, UNEXPIRED_COUNT);
        assert_eq!(expired_match, 0);
    }

    // A freshly created keypackage validates and carries exactly the expected capabilities and
    // no keypackage extensions.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    async fn new_keypackage_has_correct_extensions(case: TestContext) {
        run_test_with_client_ids(case.clone(), ["alice"], move |[cc]| {
            Box::pin(async move {
                let kps = cc
                    .transaction
                    .get_or_create_client_keypackages(case.ciphersuite(), case.credential_type, 1)
                    .await
                    .unwrap();
                let kp = kps.first().unwrap();

                // make sure it's valid
                let _ = KeyPackageIn::from(kp.clone())
                    .standalone_validate(
                        &cc.transaction.mls_provider().await.unwrap(),
                        ProtocolVersion::Mls10,
                        true,
                    )
                    .await
                    .unwrap();

                // see https://www.rfc-editor.org/rfc/rfc9420.html#section-10-10
                assert!(kp.extensions().is_empty());

                assert_eq!(kp.leaf_node().capabilities().versions(), &[ProtocolVersion::Mls10]);
                assert_eq!(
                    kp.leaf_node().capabilities().ciphersuites().to_vec(),
                    MlsConversationConfiguration::DEFAULT_SUPPORTED_CIPHERSUITES
                        .iter()
                        .map(|c| VerifiableCiphersuite::from(*c))
                        .collect::<Vec<_>>()
                );
                assert!(kp.leaf_node().capabilities().proposals().is_empty());
                assert!(kp.leaf_node().capabilities().extensions().is_empty());
                assert_eq!(
                    kp.leaf_node().capabilities().credentials(),
                    MlsConversationConfiguration::DEFAULT_SUPPORTED_CREDENTIALS
                );
            })
        })
        .await
    }
}