core_crypto/mls/session/
key_package.rs

1use std::{sync::Arc, time::Duration};
2
3use core_crypto_keystore::{
4    entities::{StoredEncryptionKeyPair, StoredHpkePrivateKey, StoredKeypackage},
5    traits::FetchFromDatabase,
6};
7use openmls::prelude::{CryptoConfig, Lifetime};
8
9use super::{Error, Result};
10use crate::{
11    Credential, CredentialRef, Keypackage, KeypackageRef, KeystoreError, MlsConversationConfiguration, Session,
12    mls::key_package::KeypackageExt,
13};
14
/// How many Keypackages a client generates when it is first created.
///
/// Test builds use a smaller count; see the `cfg(test)` variant below.
#[cfg(not(test))]
pub const INITIAL_KEYING_MATERIAL_COUNT: u32 = 100;
/// How many Keypackages a client generates when it is first created (reduced count for test builds).
#[cfg(test)]
pub const INITIAL_KEYING_MATERIAL_COUNT: u32 = 10;

/// Default lifetime of every generated Keypackage: 12 weeks (84 days, roughly three months).
/// Matches the limit defined in openmls.
pub const KEYPACKAGE_DEFAULT_LIFETIME: Duration = Duration::from_secs(12 * 7 * 24 * 60 * 60);
24
25fn from_stored(stored_keypackage: &StoredKeypackage) -> Result<Keypackage> {
26    core_crypto_keystore::deser::<Keypackage>(&stored_keypackage.keypackage)
27        .map_err(KeystoreError::wrap("deserializing keypackage"))
28        .map_err(Into::into)
29}
30
31impl Session {
32    /// Get an unambiguous credential for the provided ref from the currently-loaded set.
33    async fn credential_from_ref(&self, credential_ref: &CredentialRef) -> Result<Arc<Credential>> {
34        let identities = self.identities.read().await;
35        identities
36            .find_credential_by_public_key(
37                credential_ref.signature_scheme(),
38                credential_ref.r#type(),
39                &credential_ref.public_key().into(),
40            )
41            .await
42            .ok_or(Error::CredentialNotFound(
43                credential_ref.r#type(),
44                credential_ref.signature_scheme(),
45            ))
46    }
47
48    /// Generate a [Keypackage] from the referenced credential.
49    ///
50    /// Makes no attempt to look up or prune existing keypackges.
51    ///
52    /// If `lifetime` is set, the keypackages will expire that span into the future.
53    /// If it is unset, [`KEYPACKAGE_DEFAULT_LIFETIME`] is used.
54    ///
55    /// As a side effect, stores the keypackages and some related data in the keystore.
56    ///
57    /// Must not be fully public, only crate-public, because as it mutates the keystore it must only ever happen within
58    /// a transaction.
59    pub(crate) async fn generate_keypackage(
60        &self,
61        credential_ref: &CredentialRef,
62        lifetime: Option<Duration>,
63    ) -> Result<Keypackage> {
64        let lifetime = Lifetime::new(lifetime.unwrap_or(KEYPACKAGE_DEFAULT_LIFETIME).as_secs());
65        let credential = self.credential_from_ref(credential_ref).await?;
66
67        let config = CryptoConfig {
68            ciphersuite: credential.ciphersuite.into(),
69            version: openmls::versions::ProtocolVersion::default(),
70        };
71
72        Keypackage::builder()
73            .leaf_node_capabilities(MlsConversationConfiguration::default_leaf_capabilities())
74            .key_package_lifetime(lifetime)
75            .build(
76                config,
77                &self.crypto_provider,
78                &credential.signature_key_pair,
79                credential.to_mls_credential_with_key(),
80            )
81            .await
82            .map_err(Error::keypackage_new())
83    }
84
85    /// Get all [`Keypackage`]s in the database.
86    pub(crate) async fn get_keypackages(&self) -> Result<Vec<Keypackage>> {
87        let stored_keypackages: Vec<StoredKeypackage> = self
88            .crypto_provider
89            .keystore()
90            .load_all()
91            .await
92            .map_err(KeystoreError::wrap("finding all keypackages"))?;
93
94        let keypackages = stored_keypackages
95            .iter()
96            .map(from_stored)
97            // if any ref from loading all fails to load now, skip it
98            // strictly we could panic, but this is safer--maybe someone removed it concurrently
99            .filter_map(|kp| kp.ok())
100            .collect();
101
102        Ok(keypackages)
103    }
104
105    /// Get all [`KeypackageRef`]s in the database.
106    pub async fn get_keypackage_refs(&self) -> Result<Vec<KeypackageRef>> {
107        self.get_keypackages()
108            .await?
109            .iter()
110            .map(|keypackage| keypackage.make_ref().map_err(Into::into))
111            .collect()
112    }
113
114    /// Load one [`Keypackage`] from its [`KeypackageRef`]
115    pub(crate) async fn load_keypackage(&self, kp_ref: &KeypackageRef) -> Result<Option<Keypackage>> {
116        self.crypto_provider
117            .keystore()
118            .get_borrowed::<StoredKeypackage>(kp_ref.hash_ref())
119            .await
120            .map_err(KeystoreError::wrap("loading keypackage from database"))?
121            .map(|stored_keypackage| from_stored(&stored_keypackage))
122            .transpose()
123    }
124
125    /// Remove one [`Keypackage`] from the database.
126    ///
127    /// Succeeds silently if the keypackage does not exist in the database.
128    ///
129    /// Implementation note: this must first load and deserialize the keypackage,
130    /// then remove items from three distinct tables.
131    pub(crate) async fn remove_keypackage(&self, kp_ref: &KeypackageRef) -> Result<()> {
132        let Some(kp) = self.load_keypackage(kp_ref).await? else {
133            return Ok(());
134        };
135
136        let db = self.crypto_provider.keystore();
137        db.remove_borrowed::<StoredKeypackage>(kp_ref.hash_ref())
138            .await
139            .map_err(KeystoreError::wrap("removing key package from keystore"))?;
140        db.remove_borrowed::<StoredHpkePrivateKey>(kp.hpke_init_key().as_slice())
141            .await
142            .map_err(KeystoreError::wrap("removing private key from keystore"))?;
143        db.remove_borrowed::<StoredEncryptionKeyPair>(kp.leaf_node().encryption_key().as_slice())
144            .await
145            .map_err(KeystoreError::wrap("removing encryption keypair from keystore"))?;
146
147        Ok(())
148    }
149
150    /// Remove all keypackages associated with this credential.
151    ///
152    /// This is fairly expensive as it must first load all keypackages, then delete those matching the credential.
153    ///
154    /// Implementation note: once it makes it as far as having a list of keypackages, does _not_ short-circuit
155    /// if removing one returns an error. In that case, only the first produced error is returned.
156    /// This helps ensure that as many keypackages for the given credential ref are removed as possible.
157    pub(crate) async fn remove_keypackages_for(&self, credential_ref: &CredentialRef) -> Result<()> {
158        let credential = self.credential_from_ref(credential_ref).await?;
159        let signature_public_key = credential.signature_key_pair.public();
160
161        let mut first_err = None;
162        macro_rules! try_retain_err {
163            ($e:expr) => {
164                match $e {
165                    Err(err) => {
166                        if first_err.is_none() {
167                            first_err = Some(Error::from(err));
168                        }
169                        continue;
170                    }
171                    Ok(val) => val,
172                }
173            };
174        }
175
176        for keypackage in self
177            .get_keypackages()
178            .await?
179            .into_iter()
180            .filter(|keypackage| keypackage.leaf_node().signature_key().as_slice() == signature_public_key)
181        {
182            let kp_ref = try_retain_err!(keypackage.make_ref());
183            try_retain_err!(self.remove_keypackage(&kp_ref).await);
184        }
185
186        match first_err {
187            None => Ok(()),
188            Some(err) => Err(err),
189        }
190    }
191}
192
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use core_crypto_keystore::{ConnectionType, DatabaseKey};
    use mls_crypto_provider::{Database, MlsCryptoProvider};
    use openmls::prelude::{KeyPackageIn, ProtocolVersion};
    use openmls_traits::types::VerifiableCiphersuite;

    use crate::{
        MlsConversationConfiguration,
        e2e_identity::enrollment::test_utils::{e2ei_enrollment, init_activation_or_rotation, noop_restore},
        mls::key_package::KeypackageExt as _,
        test_utils::*,
    };

    #[apply(all_cred_cipher)]
    async fn can_assess_keypackage_expiration(case: TestContext) {
        let [session_context] = case.sessions().await;
        let key = DatabaseKey::generate();
        let database = Database::open(ConnectionType::InMemory, &key).await.unwrap();
        let backend = MlsCryptoProvider::new(database);
        let x509_test_chain = if case.is_x509() {
            let x509_test_chain = crate::test_utils::x509::X509TestChain::init_empty(case.signature_scheme());
            x509_test_chain.register_with_provider(&backend).await;
            Some(x509_test_chain)
        } else {
            None
        };

        backend.new_transaction().await.unwrap();
        session_context
            .random_generate(
                &case,
                x509_test_chain.as_ref().map(|chain| chain.find_local_intermediate_ca()),
            )
            .await
            .unwrap();

        // Default expiration (KEYPACKAGE_DEFAULT_LIFETIME: 84 days)
        let kp_std_exp = session_context.new_keypackage(&case).await;
        assert!(kp_std_exp.is_valid());

        // 1-second expiration
        let kp_1s_exp = session_context
            .new_keypackage_with_lifetime(&case, Some(Duration::from_secs(1)))
            .await;

        // Sleep 2 seconds to make sure the kp has expired
        smol::Timer::after(Duration::from_secs(2)).await;
        assert!(!kp_1s_exp.is_valid());
    }

    #[apply(all_cred_cipher)]
    async fn requesting_x509_key_packages_after_basic(case: TestContext) {
        // Only meaningful when starting from a basic credential
        if !case.is_basic() {
            return;
        }

        let [session_context] = case.sessions_basic_with_pki_env().await;
        Box::pin(async move {
            let signature_scheme = case.signature_scheme();

            // Generate 5 Basic key packages first
            let mut initial_kp_refs = Vec::new();
            for _ in 0..5 {
                let kp = session_context.new_keypackage(&case).await;
                initial_kp_refs.push(kp.make_ref().unwrap());
            }
            initial_kp_refs.sort_by_key(|kp_ref| kp_ref.hash_ref().to_owned());

            // Set up E2E identity
            let test_chain = session_context.x509_chain_unchecked();

            let (mut enrollment, cert_chain) = e2ei_enrollment(
                &session_context.transaction,
                &case,
                test_chain,
                &session_context.get_e2ei_client_id().await.to_uri(),
                false,
                init_activation_or_rotation,
                noop_restore,
            )
            .await
            .unwrap();

            let _rotate_bundle = session_context
                .transaction
                .save_x509_credential(&mut enrollment, cert_chain)
                .await
                .unwrap();

            // E2E identity has been set up correctly
            assert!(
                session_context
                    .transaction
                    .e2ei_is_enabled(signature_scheme)
                    .await
                    .unwrap()
            );

            // Fetch every stored keypackage ref: the initial basic ones must still be present,
            // and everything else must be X509
            let key_packages = session_context.transaction.get_keypackage_refs().await.unwrap();
            let (mut from_initial_set, x509_key_packages) = key_packages
                .into_iter()
                .partition::<Vec<_>, _>(|kp_ref| initial_kp_refs.contains(kp_ref));

            from_initial_set.sort_by_key(|kp_ref| kp_ref.hash_ref().to_owned());
            assert_eq!(initial_kp_refs, from_initial_set);

            // Verify that the key packages are X509
            assert!(
                x509_key_packages
                    .iter()
                    .all(|kp| CredentialType::X509 == kp.credential_type())
            );
        })
        .await
    }

    #[apply(all_cred_cipher)]
    async fn new_keypackage_has_correct_extensions(case: TestContext) {
        let [cc] = case.sessions().await;
        Box::pin(async move {
            let kp = cc.new_keypackage(&case).await;

            // make sure it's valid
            let _ = KeyPackageIn::from(kp.clone())
                .standalone_validate(
                    &cc.transaction.mls_provider().await.unwrap(),
                    ProtocolVersion::Mls10,
                    true,
                )
                .await
                .unwrap();

            // see https://www.rfc-editor.org/rfc/rfc9420.html#section-10-10
            assert!(kp.extensions().is_empty());

            assert_eq!(kp.leaf_node().capabilities().versions(), &[ProtocolVersion::Mls10]);
            assert_eq!(
                kp.leaf_node().capabilities().ciphersuites().to_vec(),
                MlsConversationConfiguration::DEFAULT_SUPPORTED_CIPHERSUITES
                    .iter()
                    .map(|c| VerifiableCiphersuite::from(*c))
                    .collect::<Vec<_>>()
            );
            assert!(kp.leaf_node().capabilities().proposals().is_empty());
            assert!(kp.leaf_node().capabilities().extensions().is_empty());
            assert_eq!(
                kp.leaf_node().capabilities().credentials(),
                MlsConversationConfiguration::DEFAULT_SUPPORTED_CREDENTIALS
            );
        })
        .await
    }

    #[apply(all_cred_cipher)]
    async fn can_store_and_load_key_packages(case: TestContext) {
        let [cc] = case.sessions().await;

        // generate a keypackage; automatically saves it
        let kp = cc.new_keypackage(&case).await;

        // don't depend on storage order or on this being the only stored keypackage
        let all_keypackages = cc.session.read().await.get_keypackages().await.unwrap();
        assert!(all_keypackages.contains(&kp));

        let kp_ref = kp.make_ref().unwrap();
        let by_ref = cc.session.read().await.load_keypackage(&kp_ref).await.unwrap().unwrap();
        assert_eq!(kp, by_ref);
    }
}
365}