core_crypto/mls/session/
key_package.rs1use std::{sync::Arc, time::Duration};
2
3use core_crypto_keystore::{
4 entities::{StoredEncryptionKeyPair, StoredHpkePrivateKey, StoredKeypackage},
5 traits::FetchFromDatabase,
6};
7use openmls::prelude::{CryptoConfig, Lifetime};
8
9use super::{Error, Result};
10use crate::{
11 Credential, CredentialRef, Keypackage, KeypackageRef, KeystoreError, MlsConversationConfiguration, Session,
12 mls::key_package::KeypackageExt,
13};
14
/// Default number of keying-material items (presumably key packages) to produce initially.
///
/// Test builds (`cfg(test)`) use a much smaller count (10 instead of 100), presumably to keep
/// test setup cheap.
pub const INITIAL_KEYING_MATERIAL_COUNT: u32 = if cfg!(test) { 10 } else { 100 };
21
22pub const KEYPACKAGE_DEFAULT_LIFETIME: Duration = Duration::from_secs(60 * 60 * 24 * 28 * 3); fn from_stored(stored_keypackage: &StoredKeypackage) -> Result<Keypackage> {
26 core_crypto_keystore::deser::<Keypackage>(&stored_keypackage.keypackage)
27 .map_err(KeystoreError::wrap("deserializing keypackage"))
28 .map_err(Into::into)
29}
30
31impl Session {
32 async fn credential_from_ref(&self, credential_ref: &CredentialRef) -> Result<Arc<Credential>> {
34 let identities = self.identities.read().await;
35 identities
36 .find_credential_by_public_key(
37 credential_ref.signature_scheme(),
38 credential_ref.r#type(),
39 &credential_ref.public_key().into(),
40 )
41 .await
42 .ok_or(Error::CredentialNotFound(
43 credential_ref.r#type(),
44 credential_ref.signature_scheme(),
45 ))
46 }
47
48 pub(crate) async fn generate_keypackage(
60 &self,
61 credential_ref: &CredentialRef,
62 lifetime: Option<Duration>,
63 ) -> Result<Keypackage> {
64 let lifetime = Lifetime::new(lifetime.unwrap_or(KEYPACKAGE_DEFAULT_LIFETIME).as_secs());
65 let credential = self.credential_from_ref(credential_ref).await?;
66
67 let config = CryptoConfig {
68 ciphersuite: credential.ciphersuite.into(),
69 version: openmls::versions::ProtocolVersion::default(),
70 };
71
72 Keypackage::builder()
73 .leaf_node_capabilities(MlsConversationConfiguration::default_leaf_capabilities())
74 .key_package_lifetime(lifetime)
75 .build(
76 config,
77 &self.crypto_provider,
78 &credential.signature_key_pair,
79 credential.to_mls_credential_with_key(),
80 )
81 .await
82 .map_err(Error::keypackage_new())
83 }
84
85 pub(crate) async fn get_keypackages(&self) -> Result<Vec<Keypackage>> {
87 let stored_keypackages: Vec<StoredKeypackage> = self
88 .crypto_provider
89 .keystore()
90 .load_all()
91 .await
92 .map_err(KeystoreError::wrap("finding all keypackages"))?;
93
94 let keypackages = stored_keypackages
95 .iter()
96 .map(from_stored)
97 .filter_map(|kp| kp.ok())
100 .collect();
101
102 Ok(keypackages)
103 }
104
105 pub async fn get_keypackage_refs(&self) -> Result<Vec<KeypackageRef>> {
107 self.get_keypackages()
108 .await?
109 .iter()
110 .map(|keypackage| keypackage.make_ref().map_err(Into::into))
111 .collect()
112 }
113
114 pub(crate) async fn load_keypackage(&self, kp_ref: &KeypackageRef) -> Result<Option<Keypackage>> {
116 self.crypto_provider
117 .keystore()
118 .get_borrowed::<StoredKeypackage>(kp_ref.hash_ref())
119 .await
120 .map_err(KeystoreError::wrap("loading keypackage from database"))?
121 .map(|stored_keypackage| from_stored(&stored_keypackage))
122 .transpose()
123 }
124
125 pub(crate) async fn remove_keypackage(&self, kp_ref: &KeypackageRef) -> Result<()> {
132 let Some(kp) = self.load_keypackage(kp_ref).await? else {
133 return Ok(());
134 };
135
136 let db = self.crypto_provider.keystore();
137 db.remove_borrowed::<StoredKeypackage>(kp_ref.hash_ref())
138 .await
139 .map_err(KeystoreError::wrap("removing key package from keystore"))?;
140 db.remove_borrowed::<StoredHpkePrivateKey>(kp.hpke_init_key().as_slice())
141 .await
142 .map_err(KeystoreError::wrap("removing private key from keystore"))?;
143 db.remove_borrowed::<StoredEncryptionKeyPair>(kp.leaf_node().encryption_key().as_slice())
144 .await
145 .map_err(KeystoreError::wrap("removing encryption keypair from keystore"))?;
146
147 Ok(())
148 }
149
150 pub(crate) async fn remove_keypackages_for(&self, credential_ref: &CredentialRef) -> Result<()> {
158 let credential = self.credential_from_ref(credential_ref).await?;
159 let signature_public_key = credential.signature_key_pair.public();
160
161 let mut first_err = None;
162 macro_rules! try_retain_err {
163 ($e:expr) => {
164 match $e {
165 Err(err) => {
166 if first_err.is_none() {
167 first_err = Some(Error::from(err));
168 }
169 continue;
170 }
171 Ok(val) => val,
172 }
173 };
174 }
175
176 for keypackage in self
177 .get_keypackages()
178 .await?
179 .into_iter()
180 .filter(|keypackage| keypackage.leaf_node().signature_key().as_slice() == signature_public_key)
181 {
182 let kp_ref = try_retain_err!(keypackage.make_ref());
183 try_retain_err!(self.remove_keypackage(&kp_ref).await);
184 }
185
186 match first_err {
187 None => Ok(()),
188 Some(err) => Err(err),
189 }
190 }
191}
192
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use core_crypto_keystore::{ConnectionType, DatabaseKey};
    use mls_crypto_provider::{Database, MlsCryptoProvider};
    use openmls::prelude::{KeyPackageIn, ProtocolVersion};
    use openmls_traits::types::VerifiableCiphersuite;

    use crate::{
        MlsConversationConfiguration,
        e2e_identity::enrollment::test_utils::{e2ei_enrollment, init_activation_or_rotation, noop_restore},
        mls::key_package::KeypackageExt as _,
        test_utils::*,
    };

    /// A default-lifetime key package is valid; a key package with a 1-second lifetime is no
    /// longer valid once 2 seconds have elapsed.
    #[apply(all_cred_cipher)]
    async fn can_assess_keypackage_expiration(case: TestContext) {
        let [session_context] = case.sessions().await;
        let database_key = DatabaseKey::generate();
        let database = Database::open(ConnectionType::InMemory, &database_key)
            .await
            .unwrap();
        let backend = MlsCryptoProvider::new(database);

        let mut x509_test_chain = None;
        if case.is_x509() {
            let chain = crate::test_utils::x509::X509TestChain::init_empty(case.signature_scheme());
            chain.register_with_provider(&backend).await;
            x509_test_chain = Some(chain);
        }

        backend.new_transaction().await.unwrap();
        let intermediate_ca = x509_test_chain
            .as_ref()
            .map(|chain| chain.find_local_intermediate_ca());
        session_context
            .random_generate(&case, intermediate_ca)
            .await
            .unwrap();

        let default_lifetime_kp = session_context.new_keypackage(&case).await;
        assert!(default_lifetime_kp.is_valid());

        let short_lived_kp = session_context
            .new_keypackage_with_lifetime(&case, Some(Duration::from_secs(1)))
            .await;

        smol::Timer::after(Duration::from_secs(2)).await;
        assert!(!short_lived_kp.is_valid());
    }

    /// After rotating a basic session to an x509 credential, the previously generated basic key
    /// packages are still listed, and every other listed key package is x509.
    #[apply(all_cred_cipher)]
    async fn requesting_x509_key_packages_after_basic(case: TestContext) {
        // Only meaningful when the session starts out with a basic credential.
        if !case.is_basic() {
            return;
        }

        let [session_context] = case.sessions_basic_with_pki_env().await;
        Box::pin(async move {
            let signature_scheme = case.signature_scheme();

            let mut basic_refs = Vec::new();
            for _ in 0..5 {
                let generated = session_context.new_keypackage(&case).await;
                basic_refs.push(generated.make_ref().unwrap());
            }
            basic_refs.sort_by_key(|kp_ref| kp_ref.hash_ref().to_owned());

            let test_chain = session_context.x509_chain_unchecked();

            let (mut enrollment, cert_chain) = e2ei_enrollment(
                &session_context.transaction,
                &case,
                test_chain,
                &session_context.get_e2ei_client_id().await.to_uri(),
                false,
                init_activation_or_rotation,
                noop_restore,
            )
            .await
            .unwrap();

            let _rotate_bundle = session_context
                .transaction
                .save_x509_credential(&mut enrollment, cert_chain)
                .await
                .unwrap();

            let e2ei_enabled = session_context
                .transaction
                .e2ei_is_enabled(signature_scheme)
                .await
                .unwrap();
            assert!(e2ei_enabled);

            let listed_refs = session_context.transaction.get_keypackage_refs().await.unwrap();
            let (mut still_basic, rotated) = listed_refs
                .into_iter()
                .partition::<Vec<_>, _>(|kp_ref| basic_refs.contains(kp_ref));

            still_basic.sort_by_key(|kp_ref| kp_ref.hash_ref().to_owned());
            assert_eq!(basic_refs, still_basic);

            let all_x509 = rotated
                .iter()
                .all(|kp| CredentialType::X509 == kp.credential_type());
            assert!(all_x509);
        })
        .await
    }

    /// A freshly generated key package validates standalone and carries exactly the default
    /// extensions/capabilities.
    #[apply(all_cred_cipher)]
    async fn new_keypackage_has_correct_extensions(case: TestContext) {
        let [context] = case.sessions().await;
        Box::pin(async move {
            let generated = context.new_keypackage(&case).await;

            let provider = context.transaction.mls_provider().await.unwrap();
            let _ = KeyPackageIn::from(generated.clone())
                .standalone_validate(&provider, ProtocolVersion::Mls10, true)
                .await
                .unwrap();

            assert!(generated.extensions().is_empty());

            let capabilities = generated.leaf_node().capabilities();
            assert_eq!(capabilities.versions(), &[ProtocolVersion::Mls10]);

            let expected_ciphersuites: Vec<_> = MlsConversationConfiguration::DEFAULT_SUPPORTED_CIPHERSUITES
                .iter()
                .map(|c| VerifiableCiphersuite::from(*c))
                .collect();
            assert_eq!(capabilities.ciphersuites().to_vec(), expected_ciphersuites);

            assert!(capabilities.proposals().is_empty());
            assert!(capabilities.extensions().is_empty());
            assert_eq!(
                capabilities.credentials(),
                MlsConversationConfiguration::DEFAULT_SUPPORTED_CREDENTIALS
            );
        })
        .await
    }

    /// A generated key package can be found both by listing and by looking it up by reference.
    #[apply(all_cred_cipher)]
    async fn can_store_and_load_key_packages(case: TestContext) {
        let [context] = case.sessions().await;

        let generated = context.new_keypackage(&case).await;

        let listed = context.session.read().await.get_keypackages().await.unwrap();
        assert_eq!(listed[0], generated);

        let generated_ref = generated.make_ref().unwrap();
        let loaded = context
            .session
            .read()
            .await
            .load_keypackage(&generated_ref)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(generated, loaded);
    }
}