use std::collections::HashMap;
use openmls::prelude::{Credential, CredentialWithKey, CryptoConfig, KeyPackage, KeyPackageRef, Lifetime};
use openmls_traits::OpenMlsCryptoProvider;
use tls_codec::{Deserialize, Serialize};
use core_crypto_keystore::{
connection::KeystoreDatabaseConnection,
entities::{
EntityBase, EntityFindParams, MlsCredential, MlsCredentialExt, MlsEncryptionKeyPair, MlsHpkePrivateKey,
MlsKeyPackage,
},
};
use mls_crypto_provider::MlsCryptoProvider;
use crate::{
mls::credential::CredentialBundle,
prelude::{
Client, CryptoError, CryptoResult, MlsCentral, MlsCiphersuite, MlsConversationConfiguration, MlsCredentialType,
MlsError,
},
};
/// Number of KeyPackages (with their HPKE private keys and leaf encryption keypairs)
/// expected to be provisioned initially: 100 in regular builds, 10 under `cfg(test)`.
pub const INITIAL_KEYING_MATERIAL_COUNT: usize = if cfg!(test) { 10 } else { 100 };

/// Default lifetime of newly generated KeyPackages: 3 periods of 28 days (84 days).
pub(crate) const KEYPACKAGE_DEFAULT_LIFETIME: std::time::Duration =
    std::time::Duration::from_secs(3 * 28 * 24 * 60 * 60);

// KeyPackage generation, counting and pruning for a single MLS client.
impl Client {
pub async fn generate_one_keypackage_from_credential_bundle(
&self,
backend: &MlsCryptoProvider,
cs: MlsCiphersuite,
cb: &CredentialBundle,
) -> CryptoResult<KeyPackage> {
let keypackage = KeyPackage::builder()
.leaf_node_capabilities(MlsConversationConfiguration::default_leaf_capabilities())
.key_package_lifetime(Lifetime::new(self.keypackage_lifetime.as_secs()))
.build(
CryptoConfig {
ciphersuite: cs.into(),
version: openmls::versions::ProtocolVersion::default(),
},
backend,
&cb.signature_key,
CredentialWithKey {
credential: cb.credential.clone(),
signature_key: cb.signature_key.public().into(),
},
)
.await
.map_err(MlsError::from)?;
Ok(keypackage)
}
pub async fn request_key_packages(
&self,
count: usize,
ciphersuite: MlsCiphersuite,
credential_type: MlsCredentialType,
backend: &MlsCryptoProvider,
) -> CryptoResult<Vec<KeyPackage>> {
self.prune_keypackages(backend, &[]).await?;
use core_crypto_keystore::CryptoKeystoreMls as _;
let mut existing_kps = backend
.key_store()
.mls_fetch_keypackages::<KeyPackage>(count as u32)
.await?
.into_iter()
.filter(|kp|
kp.ciphersuite() == ciphersuite.0 && MlsCredentialType::from(kp.leaf_node().credential().credential_type()) == credential_type)
.collect::<Vec<_>>();
let kpb_count = existing_kps.len();
let mut kps = if count > kpb_count {
let to_generate = count - kpb_count;
let cb = self
.find_most_recent_credential_bundle(ciphersuite.signature_algorithm(), credential_type)
.ok_or(CryptoError::MlsNotInitialized)?;
self.generate_new_keypackages(backend, ciphersuite, cb, to_generate)
.await?
} else {
vec![]
};
existing_kps.reverse();
kps.append(&mut existing_kps);
Ok(kps)
}
pub(crate) async fn generate_new_keypackages(
&self,
backend: &MlsCryptoProvider,
ciphersuite: MlsCiphersuite,
cb: &CredentialBundle,
count: usize,
) -> CryptoResult<Vec<KeyPackage>> {
let mut kps = Vec::with_capacity(count);
for _ in 0..count {
let kp = self
.generate_one_keypackage_from_credential_bundle(backend, ciphersuite, cb)
.await?;
kps.push(kp);
}
Ok(kps)
}
pub async fn valid_keypackages_count(
&self,
backend: &MlsCryptoProvider,
ciphersuite: MlsCiphersuite,
credential_type: MlsCredentialType,
) -> CryptoResult<usize> {
use core_crypto_keystore::entities::EntityBase as _;
let keystore = backend.key_store();
let mut conn = keystore.borrow_conn().await?;
let kps = MlsKeyPackage::find_all(&mut conn, EntityFindParams::default()).await?;
let valid_count = kps
.into_iter()
.map(|kp| core_crypto_keystore::deser::<KeyPackage>(&kp.keypackage))
.filter(|kp| {
kp.as_ref()
.map(|b| b.ciphersuite() == ciphersuite.0 && MlsCredentialType::from(b.leaf_node().credential().credential_type()) == credential_type)
.unwrap_or_default()
})
.try_fold(0usize, |mut valid_count, kp| {
if !Self::is_mls_keypackage_expired(&kp?) {
valid_count += 1;
}
CryptoResult::Ok(valid_count)
})?;
Ok(valid_count)
}
fn is_mls_keypackage_expired(kp: &KeyPackage) -> bool {
let Some(lifetime) = kp.leaf_node().life_time() else {
return false;
};
!(lifetime.has_acceptable_range() && lifetime.is_valid())
}
pub async fn prune_keypackages(&self, backend: &MlsCryptoProvider, refs: &[KeyPackageRef]) -> CryptoResult<()> {
let mut conn = backend.key_store().borrow_conn().await?;
let kps = self.find_all_keypackages(&mut conn).await?;
let _ = self._prune_keypackages(&kps, &mut conn, refs).await?;
Ok(())
}
pub(crate) async fn prune_keypackages_and_credential(
&mut self,
backend: &MlsCryptoProvider,
refs: &[KeyPackageRef],
) -> CryptoResult<()> {
let mut conn = backend.key_store().borrow_conn().await?;
let kps = self.find_all_keypackages(&mut conn).await?;
let kp_to_delete = self._prune_keypackages(&kps, &mut conn, refs).await?;
let mut grouped_kps = HashMap::<Vec<u8>, Vec<KeyPackageRef>>::new();
for (_, kp) in &kps {
let cred = kp
.leaf_node()
.credential()
.tls_serialize_detached()
.map_err(MlsError::from)?;
let kp_ref = kp.hash_ref(backend.crypto()).map_err(MlsError::from)?;
grouped_kps
.entry(cred)
.and_modify(|kprfs| kprfs.push(kp_ref.clone()))
.or_insert(vec![kp_ref]);
}
for (credential, kps) in &grouped_kps {
let all_to_delete = kps.iter().all(|kpr| kp_to_delete.contains(&kpr.as_slice()));
if all_to_delete {
MlsCredential::delete_by_credential(&mut conn, credential.clone()).await?;
let credential = Credential::tls_deserialize(&mut credential.as_slice()).map_err(MlsError::from)?;
self.identities.remove(&credential)?;
}
}
Ok(())
}
async fn _prune_keypackages<'a>(
&self,
kps: &'a [(MlsKeyPackage, KeyPackage)],
conn: &mut KeystoreDatabaseConnection,
refs: &[KeyPackageRef],
) -> Result<Vec<&'a [u8]>, CryptoError> {
use core_crypto_keystore::entities::EntityBase as _;
let kp_to_delete: Vec<_> = kps
.iter()
.filter_map(|(store_kp, kp)| {
let is_expired = Self::is_mls_keypackage_expired(kp);
let mut to_delete = is_expired;
if !(is_expired || refs.is_empty()) {
to_delete = refs.iter().any(|r| r.as_slice() == store_kp.keypackage_ref);
}
to_delete.then_some((kp, &store_kp.keypackage_ref))
})
.collect();
for (kp, kp_ref) in &kp_to_delete {
MlsKeyPackage::delete(conn, &[kp_ref.as_slice().into()]).await?;
MlsHpkePrivateKey::delete(conn, &[kp.hpke_init_key().as_slice().into()]).await?;
MlsEncryptionKeyPair::delete(conn, &[kp.leaf_node().encryption_key().as_slice().into()]).await?;
}
let kp_to_delete = kp_to_delete
.into_iter()
.map(|(_, kpref)| &kpref[..])
.collect::<Vec<_>>();
Ok(kp_to_delete)
}
async fn find_all_keypackages(
&self,
conn: &mut KeystoreDatabaseConnection,
) -> CryptoResult<Vec<(MlsKeyPackage, KeyPackage)>> {
let kps = MlsKeyPackage::find_all(conn, EntityFindParams::default()).await?;
let kps = kps.into_iter().try_fold(vec![], |mut acc, raw_kp| {
let kp = core_crypto_keystore::deser::<KeyPackage>(&raw_kp.keypackage)?;
acc.push((raw_kp, kp));
CryptoResult::Ok(acc)
})?;
Ok(kps)
}
#[cfg(test)]
pub fn set_keypackage_lifetime(&mut self, duration: std::time::Duration) {
self.keypackage_lifetime = duration;
}
}
impl MlsCentral {
    /// Returns `amount_requested` KeyPackages for `ciphersuite` and `credential_type`.
    ///
    /// Stored KeyPackages are reused where possible and the rest are generated
    /// (see [`Client::request_key_packages`]).
    ///
    /// # Errors
    /// Fails if the MLS client is not initialized, or with any error from the client.
    pub async fn get_or_create_client_keypackages(
        &self,
        ciphersuite: MlsCiphersuite,
        credential_type: MlsCredentialType,
        amount_requested: usize,
    ) -> CryptoResult<Vec<KeyPackage>> {
        let client = self.mls_client()?;
        let key_packages = client
            .request_key_packages(amount_requested, ciphersuite, credential_type, &self.mls_backend)
            .await?;
        Ok(key_packages)
    }

    /// Counts the client's stored, non-expired KeyPackages for `ciphersuite` and
    /// `credential_type`.
    #[cfg_attr(test, crate::idempotent)]
    pub async fn client_valid_key_packages_count(
        &self,
        ciphersuite: MlsCiphersuite,
        credential_type: MlsCredentialType,
    ) -> CryptoResult<usize> {
        let client = self.mls_client()?;
        let count = client
            .valid_keypackages_count(&self.mls_backend, ciphersuite, credential_type)
            .await?;
        Ok(count)
    }

    /// Deletes the KeyPackages identified by `refs`, along with any expired ones.
    ///
    /// Credentials left without any KeyPackage by this deletion are removed as well.
    ///
    /// # Errors
    /// - [`CryptoError::ConsumerError`] if `refs` is empty. This is checked first.
    /// - [`CryptoError::MlsNotInitialized`] if there is no MLS client.
    pub async fn delete_keypackages(&mut self, refs: &[KeyPackageRef]) -> CryptoResult<()> {
        if refs.is_empty() {
            return Err(CryptoError::ConsumerError);
        }
        let Some(client) = self.mls_client.as_mut() else {
            return Err(CryptoError::MlsNotInitialized);
        };
        client.prune_keypackages_and_credential(&self.mls_backend, refs).await?;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    // Tests for KeyPackage generation, counting, expiration and pruning.
    use openmls::prelude::{KeyPackage, KeyPackageIn, KeyPackageRef, ProtocolVersion};
    use openmls_traits::types::VerifiableCiphersuite;
    use openmls_traits::OpenMlsCryptoProvider;
    use wasm_bindgen_test::*;

    use mls_crypto_provider::MlsCryptoProvider;

    use crate::e2e_identity::tests::{e2ei_enrollment, init_activation_or_rotation, noop_restore};
    use crate::prelude::key_package::INITIAL_KEYING_MATERIAL_COUNT;
    use crate::prelude::MlsConversationConfiguration;
    use crate::test_utils::*;

    use super::Client;

    wasm_bindgen_test_configure!(run_in_browser);

    // A KeyPackage built with the default lifetime is not expired. One built with a
    // 1-second lifetime is reported as expired after sleeping 2 seconds.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    async fn can_assess_keypackage_expiration(case: TestCase) {
        let (cs, ct) = (case.ciphersuite(), case.credential_type);
        let backend = MlsCryptoProvider::try_new_in_memory("test").await.unwrap();
        // For X509 cases an empty test chain is registered with the provider. Its local
        // intermediate CA is then passed to `Client::random_generate`.
        let x509_test_chain = if case.is_x509() {
            let x509_test_chain = crate::test_utils::x509::X509TestChain::init_empty(case.signature_scheme());
            x509_test_chain.register_with_provider(&backend).await;
            Some(x509_test_chain)
        } else {
            None
        };
        let mut client = Client::random_generate(
            &case,
            &backend,
            x509_test_chain.as_ref().map(|chain| chain.find_local_intermediate_ca()),
            false,
        )
        .await
        .unwrap();
        let kp_std_exp = client.generate_one_keypackage(&backend, cs, ct).await.unwrap();
        assert!(!Client::is_mls_keypackage_expired(&kp_std_exp));
        client.set_keypackage_lifetime(std::time::Duration::from_secs(1));
        let kp_1s_exp = client.generate_one_keypackage(&backend, cs, ct).await.unwrap();
        async_std::task::sleep(std::time::Duration::from_secs(2)).await;
        assert!(Client::is_mls_keypackage_expired(&kp_1s_exp));
    }

    // Only runs for basic-credential cases. After requesting basic KeyPackages and then
    // enrolling via e2ei and rotating, every KeyPackage requested for the X509 credential
    // type must carry an X509 credential.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    async fn requesting_x509_key_packages_after_basic(case: TestCase) {
        if !case.is_basic() {
            return;
        }
        run_test_with_client_ids(case.clone(), ["alice"], move |[mut client_context]| {
            Box::pin(async move {
                let signature_scheme = case.signature_scheme();
                let cipher_suite = case.ciphersuite();
                let _basic_key_packages = client_context
                    .mls_central
                    .get_or_create_client_keypackages(cipher_suite, MlsCredentialType::Basic, 5)
                    .await
                    .unwrap();
                let test_chain = x509::X509TestChain::init_for_random_clients(signature_scheme, 1);
                let (mut enrollment, cert_chain) = e2ei_enrollment(
                    &mut client_context,
                    &case,
                    &test_chain,
                    None,
                    false,
                    init_activation_or_rotation,
                    noop_restore,
                )
                .await
                .unwrap();
                let _rotate_bundle = client_context
                    .mls_central
                    .e2ei_rotate_all(&mut enrollment, cert_chain, 5)
                    .await
                    .unwrap();
                assert!(client_context.mls_central.e2ei_is_enabled(signature_scheme).unwrap());
                let x509_key_packages = client_context
                    .mls_central
                    .get_or_create_client_keypackages(cipher_suite, MlsCredentialType::X509, 5)
                    .await
                    .unwrap();
                assert!(x509_key_packages.iter().all(|kp| MlsCredentialType::X509
                    == MlsCredentialType::from(kp.leaf_node().credential().credential_type())));
            })
        })
        .await
    }

    // Checks entity counts across repeated request/delete cycles. Each of the N rounds
    // requests COUNT + 1 KeyPackages, keeps the last one ("pinned") and deletes the other
    // COUNT. Consecutive rounds must not return the same deleted KeyPackages. At the end,
    // deleting the pinned one leaves no KeyPackages and no credential.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    async fn generates_correct_number_of_kpbs(case: TestCase) {
        run_test_with_client_ids(case.clone(), ["alice"], move |[mut cc]| {
            Box::pin(async move {
                const N: usize = 2;
                const COUNT: usize = 109;

                let init = cc.mls_central.count_entities().await;
                assert_eq!(init.key_package, INITIAL_KEYING_MATERIAL_COUNT);
                assert_eq!(init.encryption_keypair, INITIAL_KEYING_MATERIAL_COUNT);
                assert_eq!(init.hpke_private_key, INITIAL_KEYING_MATERIAL_COUNT);
                assert_eq!(init.credential, 1);
                assert_eq!(init.signature_keypair, 1);

                let mut pinned_kp = None;

                let mut prev_kps: Option<Vec<KeyPackage>> = None;
                for _ in 0..N {
                    let mut kps = cc
                        .mls_central
                        .get_or_create_client_keypackages(case.ciphersuite(), case.credential_type, COUNT + 1)
                        .await
                        .unwrap();

                    // `request_key_packages` appends reused stored KeyPackages after the
                    // newly generated ones, so the popped element is presumably a stored
                    // one (the previous pinned KeyPackage after the first round).
                    pinned_kp = Some(kps.pop().unwrap());

                    assert_eq!(kps.len(), COUNT);
                    let after_creation = cc.mls_central.count_entities().await;
                    assert_eq!(after_creation.key_package, COUNT + 1);
                    assert_eq!(after_creation.encryption_keypair, COUNT + 1);
                    assert_eq!(after_creation.hpke_private_key, COUNT + 1);
                    assert_eq!(after_creation.credential, 1);

                    let kpbs_refs = kps
                        .iter()
                        .map(|kp| kp.hash_ref(cc.mls_central.mls_backend.crypto()).unwrap())
                        .collect::<Vec<KeyPackageRef>>();

                    if let Some(pkpbs) = prev_kps.replace(kps) {
                        let pkpbs_refs = pkpbs
                            .into_iter()
                            .map(|kpb| kpb.hash_ref(cc.mls_central.mls_backend.crypto()).unwrap())
                            .collect::<Vec<KeyPackageRef>>();

                        let has_duplicates = kpbs_refs.iter().any(|href| pkpbs_refs.contains(href));
                        // Deleted KeyPackages must not be handed out again.
                        assert!(!has_duplicates);
                    }
                    cc.mls_central.delete_keypackages(&kpbs_refs).await.unwrap();
                }

                // Only the pinned KeyPackage is left.
                let count = cc
                    .mls_central
                    .client_valid_key_packages_count(case.ciphersuite(), case.credential_type)
                    .await
                    .unwrap();
                assert_eq!(count, 1);

                let pinned_kpr = pinned_kp
                    .unwrap()
                    .hash_ref(cc.mls_central.mls_backend.crypto())
                    .unwrap();
                cc.mls_central.delete_keypackages(&[pinned_kpr]).await.unwrap();
                let count = cc
                    .mls_central
                    .client_valid_key_packages_count(case.ciphersuite(), case.credential_type)
                    .await
                    .unwrap();
                assert_eq!(count, 0);
                // Once its last KeyPackage is deleted, the credential itself is gone too.
                let after_delete = cc.mls_central.count_entities().await;
                assert_eq!(after_delete.key_package, 0);
                assert_eq!(after_delete.encryption_keypair, 0);
                assert_eq!(after_delete.hpke_private_key, 0);
                assert_eq!(after_delete.credential, 0);
            })
        })
        .await
    }

    // Requests UNEXPIRED_COUNT KeyPackages with the default lifetime. It then switches to
    // a 10-second lifetime, requests EXPIRED_COUNT, and sleeps past that lifetime. The next
    // request must prune the short-lived ones: every previously unexpired KeyPackage is
    // reused and none of the short-lived ones is returned.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    async fn automatically_prunes_lifetime_expired_keypackages(case: TestCase) {
        const UNEXPIRED_COUNT: usize = 125;
        const EXPIRED_COUNT: usize = 200;
        let backend = MlsCryptoProvider::try_new_in_memory("test").await.unwrap();
        let x509_test_chain = if case.is_x509() {
            let x509_test_chain = crate::test_utils::x509::X509TestChain::init_empty(case.signature_scheme());
            x509_test_chain.register_with_provider(&backend).await;
            Some(x509_test_chain)
        } else {
            None
        };
        let mut client = Client::random_generate(
            &case,
            &backend,
            x509_test_chain.as_ref().map(|chain| chain.find_local_intermediate_ca()),
            false,
        )
        .await
        .unwrap();

        let unexpired_kpbs = client
            .request_key_packages(UNEXPIRED_COUNT, case.ciphersuite(), case.credential_type, &backend)
            .await
            .unwrap();
        let len = client
            .valid_keypackages_count(&backend, case.ciphersuite(), case.credential_type)
            .await
            .unwrap();
        assert_eq!(len, unexpired_kpbs.len());
        assert_eq!(len, UNEXPIRED_COUNT);

        // KeyPackages generated from here on get a 10-second lifetime.
        client.set_keypackage_lifetime(std::time::Duration::from_secs(10));
        let partially_expired_kpbs = client
            .request_key_packages(EXPIRED_COUNT, case.ciphersuite(), case.credential_type, &backend)
            .await
            .unwrap();
        assert_eq!(partially_expired_kpbs.len(), EXPIRED_COUNT);

        async_std::task::sleep(std::time::Duration::from_secs(10)).await;

        // This request prunes the now-expired KeyPackages before reusing/generating.
        let fresh_kpbs = client
            .request_key_packages(EXPIRED_COUNT, case.ciphersuite(), case.credential_type, &backend)
            .await
            .unwrap();
        let len = client
            .valid_keypackages_count(&backend, case.ciphersuite(), case.credential_type)
            .await
            .unwrap();
        assert_eq!(len, fresh_kpbs.len());
        assert_eq!(len, EXPIRED_COUNT);

        // Count fresh KeyPackages that were in the first batch versus only in the
        // short-lived batch. The `else if` means a KeyPackage present in both batches
        // is counted only as unexpired.
        let (unexpired_match, expired_match) =
            fresh_kpbs
                .iter()
                .fold((0usize, 0usize), |(mut unexpired_match, mut expired_match), fresh| {
                    if unexpired_kpbs.iter().any(|kp| kp == fresh) {
                        unexpired_match += 1;
                    } else if partially_expired_kpbs.iter().any(|kpb| kpb == fresh) {
                        expired_match += 1;
                    }

                    (unexpired_match, expired_match)
                });

        assert_eq!(unexpired_match, UNEXPIRED_COUNT);
        assert_eq!(expired_match, 0);
    }

    // A newly requested KeyPackage passes openmls standalone validation. It has no
    // KeyPackage extensions and advertises exactly the default leaf capabilities from
    // `MlsConversationConfiguration`.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    async fn new_keypackage_has_correct_extensions(case: TestCase) {
        run_test_with_client_ids(case.clone(), ["alice"], move |[cc]| {
            Box::pin(async move {
                let kps = cc
                    .mls_central
                    .get_or_create_client_keypackages(case.ciphersuite(), case.credential_type, 1)
                    .await
                    .unwrap();
                let kp = kps.first().unwrap();

                let _ = KeyPackageIn::from(kp.clone())
                    .standalone_validate(&cc.mls_central.mls_backend, ProtocolVersion::Mls10, true)
                    .await
                    .unwrap();

                assert!(kp.extensions().is_empty());

                assert_eq!(kp.leaf_node().capabilities().versions(), &[ProtocolVersion::Mls10]);
                assert_eq!(
                    kp.leaf_node().capabilities().ciphersuites().to_vec(),
                    MlsConversationConfiguration::DEFAULT_SUPPORTED_CIPHERSUITES
                        .iter()
                        .map(|c| VerifiableCiphersuite::from(*c))
                        .collect::<Vec<_>>()
                );
                assert!(kp.leaf_node().capabilities().proposals().is_empty());
                assert!(kp.leaf_node().capabilities().extensions().is_empty());
                assert_eq!(
                    kp.leaf_node().capabilities().credentials(),
                    MlsConversationConfiguration::DEFAULT_SUPPORTED_CREDENTIALS
                );
            })
        })
        .await
    }
}