core_crypto/mls/session/key_package.rs

1use std::{sync::Arc, time::Duration};
2
3use core_crypto_keystore::{
4    connection::FetchFromDatabase,
5    entities::{EntityFindParams, StoredEncryptionKeyPair, StoredHpkePrivateKey, StoredKeypackage},
6};
7use openmls::prelude::{CryptoConfig, Lifetime};
8
9use super::{Error, Result};
10use crate::{
11    Credential, CredentialRef, Keypackage, KeypackageRef, KeystoreError, MlsConversationConfiguration, Session,
12    mls::key_package::KeypackageExt,
13};
14
/// Default number of Keypackages a client generates the first time it's created.
///
/// Regular builds generate 100; test builds (`cfg(test)`) generate only 10.
pub const INITIAL_KEYING_MATERIAL_COUNT: usize = if cfg!(test) { 10 } else { 100 };
21
/// Default lifetime of all generated Keypackages: 3 × 28 days = 84 days (roughly three months).
///
/// Matches the limit defined in openmls.
pub const KEYPACKAGE_DEFAULT_LIFETIME: Duration = Duration::from_secs(3 * 28 * 24 * 60 * 60);
24
25fn from_stored(stored_keypackage: &StoredKeypackage) -> Result<Keypackage> {
26    core_crypto_keystore::deser::<Keypackage>(&stored_keypackage.keypackage)
27        .map_err(KeystoreError::wrap("deserializing keypackage"))
28        .map_err(Into::into)
29}
30
31impl Session {
32    /// Get an unambiguous credential for the provided ref from the currently-loaded set.
33    async fn credential_from_ref(&self, credential_ref: &CredentialRef) -> Result<Arc<Credential>> {
34        let guard = self.inner.read().await;
35        let identities = &guard.as_ref().ok_or(Error::MlsNotInitialized)?.identities;
36        identities
37            .find_credential_by_public_key(
38                credential_ref.signature_scheme(),
39                credential_ref.r#type(),
40                &credential_ref.public_key().into(),
41            )
42            .await
43            .ok_or(Error::CredentialNotFound(
44                credential_ref.r#type(),
45                credential_ref.signature_scheme(),
46            ))
47    }
48
49    /// Generate a [Keypackage] from the referenced credential.
50    ///
51    /// Makes no attempt to look up or prune existing keypackges.
52    ///
53    /// If `lifetime` is set, the keypackages will expire that span into the future.
54    /// If it is unset, [`KEYPACKAGE_DEFAULT_LIFETIME`] is used.
55    ///
56    /// As a side effect, stores the keypackages and some related data in the keystore.
57    ///
58    /// Must not be fully public, only crate-public, because as it mutates the keystore it must only ever happen within
59    /// a transaction.
60    pub(crate) async fn generate_keypackage(
61        &self,
62        credential_ref: &CredentialRef,
63        lifetime: Option<Duration>,
64    ) -> Result<Keypackage> {
65        let lifetime = Lifetime::new(lifetime.unwrap_or(KEYPACKAGE_DEFAULT_LIFETIME).as_secs());
66        let credential = self.credential_from_ref(credential_ref).await?;
67
68        let config = CryptoConfig {
69            ciphersuite: credential.ciphersuite.into(),
70            version: openmls::versions::ProtocolVersion::default(),
71        };
72
73        Keypackage::builder()
74            .leaf_node_capabilities(MlsConversationConfiguration::default_leaf_capabilities())
75            .key_package_lifetime(lifetime)
76            .build(
77                config,
78                &self.crypto_provider,
79                &credential.signature_key_pair,
80                credential.to_mls_credential_with_key(),
81            )
82            .await
83            .map_err(Error::keypackage_new())
84    }
85
86    /// Get all [`Keypackage`]s in the database.
87    pub(crate) async fn get_keypackages(&self) -> Result<Vec<Keypackage>> {
88        let stored_keypackages: Vec<StoredKeypackage> = self
89            .crypto_provider
90            .keystore()
91            .find_all(EntityFindParams::default())
92            .await
93            .map_err(KeystoreError::wrap("finding all keypackages"))?;
94
95        let keypackages = stored_keypackages
96            .iter()
97            .map(from_stored)
98            // if any ref from loading all fails to load now, skip it
99            // strictly we could panic, but this is safer--maybe someone removed it concurrently
100            .filter_map(|kp| kp.ok())
101            .collect();
102
103        Ok(keypackages)
104    }
105
106    /// Get all [`KeypackageRef`]s in the database.
107    pub async fn get_keypackage_refs(&self) -> Result<Vec<KeypackageRef>> {
108        self.get_keypackages()
109            .await?
110            .iter()
111            .map(|keypackage| keypackage.make_ref().map_err(Into::into))
112            .collect()
113    }
114
115    /// Load one [`Keypackage`] from its [`KeypackageRef`]
116    pub(crate) async fn load_keypackage(&self, kp_ref: &KeypackageRef) -> Result<Option<Keypackage>> {
117        self.crypto_provider
118            .keystore()
119            .find::<StoredKeypackage>(kp_ref.hash_ref())
120            .await
121            .map_err(KeystoreError::wrap("loading keypackage from database"))?
122            .map(|stored_keypackage| from_stored(&stored_keypackage))
123            .transpose()
124    }
125
126    /// Remove one [`Keypackage`] from the database.
127    ///
128    /// Succeeds silently if the keypackage does not exist in the database.
129    ///
130    /// Implementation note: this must first load and deserialize the keypackage,
131    /// then remove items from three distinct tables.
132    pub(crate) async fn remove_keypackage(&self, kp_ref: &KeypackageRef) -> Result<()> {
133        let Some(kp) = self.load_keypackage(kp_ref).await? else {
134            return Ok(());
135        };
136
137        let db = self.crypto_provider.keystore();
138        db.remove::<StoredKeypackage, _>(kp_ref.hash_ref())
139            .await
140            .map_err(KeystoreError::wrap("removing key package from keystore"))?;
141        db.remove::<StoredHpkePrivateKey, _>(kp.hpke_init_key().as_slice())
142            .await
143            .map_err(KeystoreError::wrap("removing private key from keystore"))?;
144        db.remove::<StoredEncryptionKeyPair, _>(kp.leaf_node().encryption_key().as_slice())
145            .await
146            .map_err(KeystoreError::wrap("removing encryption keypair from keystore"))?;
147
148        Ok(())
149    }
150
151    /// Remove all keypackages associated with this credential.
152    ///
153    /// This is fairly expensive as it must first load all keypackages, then delete those matching the credential.
154    ///
155    /// Implementation note: once it makes it as far as having a list of keypackages, does _not_ short-circuit
156    /// if removing one returns an error. In that case, only the first produced error is returned.
157    /// This helps ensure that as many keypackages for the given credential ref are removed as possible.
158    pub(crate) async fn remove_keypackages_for(&self, credential_ref: &CredentialRef) -> Result<()> {
159        let credential = self.credential_from_ref(credential_ref).await?;
160        let signature_public_key = credential.signature_key_pair.public();
161
162        let mut first_err = None;
163        macro_rules! try_retain_err {
164            ($e:expr) => {
165                match $e {
166                    Err(err) => {
167                        if first_err.is_none() {
168                            first_err = Some(Error::from(err));
169                        }
170                        continue;
171                    }
172                    Ok(val) => val,
173                }
174            };
175        }
176
177        for keypackage in self
178            .get_keypackages()
179            .await?
180            .into_iter()
181            .filter(|keypackage| keypackage.leaf_node().signature_key().as_slice() == signature_public_key)
182        {
183            let kp_ref = try_retain_err!(keypackage.make_ref());
184            try_retain_err!(self.remove_keypackage(&kp_ref).await);
185        }
186
187        match first_err {
188            None => Ok(()),
189            Some(err) => Err(err),
190        }
191    }
192}
193
#[cfg(test)]
mod tests {
    // Tests for keypackage generation, expiration, listing after an E2EI credential is
    // added, and the capabilities/extensions advertised by freshly generated keypackages.
    use std::time::Duration;

    use core_crypto_keystore::{ConnectionType, DatabaseKey};
    use mls_crypto_provider::{Database, MlsCryptoProvider};
    use openmls::prelude::{KeyPackageIn, ProtocolVersion};
    use openmls_traits::types::VerifiableCiphersuite;

    use crate::{
        MlsConversationConfiguration,
        e2e_identity::enrollment::test_utils::{e2ei_enrollment, init_activation_or_rotation, noop_restore},
        mls::key_package::KeypackageExt as _,
        test_utils::*,
    };

    // Checks that a keypackage with the default lifetime is valid, and that one with a
    // 1-second lifetime is no longer valid after 2 seconds.
    #[apply(all_cred_cipher)]
    async fn can_assess_keypackage_expiration(case: TestContext) {
        let [session_context] = case.sessions().await;
        // NOTE(review): `backend` is a standalone in-memory provider. The session below is
        // driven through `session_context` and never receives `backend`. Only the X509 test
        // chain is registered with it. Confirm whether this provider setup (including
        // `new_transaction`) is still needed.
        let key = DatabaseKey::generate();
        let database = Database::open(ConnectionType::InMemory, &key).await.unwrap();
        let backend = MlsCryptoProvider::new(database);
        let x509_test_chain = if case.is_x509() {
            let x509_test_chain = crate::test_utils::x509::X509TestChain::init_empty(case.signature_scheme());
            x509_test_chain.register_with_provider(&backend).await;
            Some(x509_test_chain)
        } else {
            None
        };

        backend.new_transaction().await.unwrap();
        // For X509 cases, the session is generated against the test chain's local intermediate CA.
        session_context
            .session
            .random_generate(
                &case,
                x509_test_chain.as_ref().map(|chain| chain.find_local_intermediate_ca()),
            )
            .await
            .unwrap();

        // Default lifetime (KEYPACKAGE_DEFAULT_LIFETIME, 84 days): must be valid right away.
        let kp_std_exp = session_context.new_keypackage(&case).await;
        assert!(kp_std_exp.is_valid());

        // 1-second expiration
        let kp_1s_exp = session_context
            .new_keypackage_with_lifetime(&case, Some(Duration::from_secs(1)))
            .await;

        // Sleep 2 seconds, longer than the 1-second lifetime, so the keypackage has expired
        smol::Timer::after(std::time::Duration::from_secs(2)).await;
        assert!(!kp_1s_exp.is_valid());
    }

    // After E2EI enrollment on a session that already has basic keypackages, listing refs
    // must still return the original basic set. Every other ref must be an X509 keypackage.
    #[apply(all_cred_cipher)]
    async fn requesting_x509_key_packages_after_basic(case: TestContext) {
        // Only meaningful for basic-credential cases; all other cases return early.
        if !case.is_basic() {
            return;
        }

        let [session_context] = case.sessions_basic_with_pki_env().await;
        Box::pin(async move {
            let signature_scheme = case.signature_scheme();

            // Generate 5 Basic key packages first, and keep their refs sorted by hash ref
            // so they can be compared against the listing later.
            let mut initial_kp_refs = Vec::new();
            for _ in 0..5 {
                let kp = session_context.new_keypackage(&case).await;
                initial_kp_refs.push(kp.make_ref().unwrap());
            }
            initial_kp_refs.sort_by_key(|kp_ref| kp_ref.hash_ref().to_owned());

            // Set up E2E identity
            let test_chain = session_context.x509_chain_unchecked();

            let (mut enrollment, cert_chain) = e2ei_enrollment(
                &session_context,
                &case,
                test_chain,
                None,
                false,
                init_activation_or_rotation,
                noop_restore,
            )
            .await
            .unwrap();

            let _rotate_bundle = session_context
                .transaction
                .save_x509_credential(&mut enrollment, cert_chain)
                .await
                .unwrap();

            // E2E identity has been set up correctly
            assert!(
                session_context
                    .transaction
                    .e2ei_is_enabled(signature_scheme)
                    .await
                    .unwrap()
            );

            // List all keypackage refs and split them into those from the initial basic set
            // and everything else.
            let key_packages = session_context.transaction.get_keypackage_refs().await.unwrap();
            let (mut from_initial_set, x509_key_packages) = key_packages
                .into_iter()
                .partition::<Vec<_>, _>(|kp_ref| initial_kp_refs.contains(kp_ref));

            // All initial basic keypackages must still be present (compared in the same sorted order).
            from_initial_set.sort_by_key(|kp_ref| kp_ref.hash_ref().to_owned());
            assert_eq!(initial_kp_refs, from_initial_set);

            // Verify that the key packages are X509
            // NOTE(review): if no new keypackages exist, `x509_key_packages` is empty and this
            // `all(..)` check passes vacuously; consider also asserting it is non-empty.
            assert!(
                x509_key_packages
                    .iter()
                    .all(|kp| CredentialType::X509 == kp.credential_type())
            );
        })
        .await
    }

    // A freshly generated keypackage must pass standalone validation. It must also carry no
    // keypackage extensions and advertise exactly the default leaf-node capabilities.
    #[apply(all_cred_cipher)]
    async fn new_keypackage_has_correct_extensions(case: TestContext) {
        let [cc] = case.sessions().await;
        Box::pin(async move {
            let kp = cc.new_keypackage(&case).await;

            // make sure it's valid
            let _ = KeyPackageIn::from(kp.clone())
                .standalone_validate(
                    &cc.transaction.mls_provider().await.unwrap(),
                    ProtocolVersion::Mls10,
                    true,
                )
                .await
                .unwrap();

            // see https://www.rfc-editor.org/rfc/rfc9420.html#section-10-10
            assert!(kp.extensions().is_empty());

            // Leaf-node capabilities must match MlsConversationConfiguration's defaults.
            assert_eq!(kp.leaf_node().capabilities().versions(), &[ProtocolVersion::Mls10]);
            assert_eq!(
                kp.leaf_node().capabilities().ciphersuites().to_vec(),
                MlsConversationConfiguration::DEFAULT_SUPPORTED_CIPHERSUITES
                    .iter()
                    .map(|c| VerifiableCiphersuite::from(*c))
                    .collect::<Vec<_>>()
            );
            assert!(kp.leaf_node().capabilities().proposals().is_empty());
            assert!(kp.leaf_node().capabilities().extensions().is_empty());
            assert_eq!(
                kp.leaf_node().capabilities().credentials(),
                MlsConversationConfiguration::DEFAULT_SUPPORTED_CREDENTIALS
            );
        })
        .await
    }
}