pub(crate) mod id;
pub(crate) mod identifier;
pub(crate) mod identities;
pub(crate) mod key_package;
pub(crate) mod user_id;
use crate::{
mls::credential::{ext::CredentialExt, CredentialBundle},
prelude::{
identifier::ClientIdentifier, key_package::KEYPACKAGE_DEFAULT_LIFETIME, CertificateBundle, ClientId,
CryptoError, CryptoResult, MlsCiphersuite, MlsCredentialType, MlsError,
},
};
use async_lock::RwLock;
use core_crypto_keystore::{connection::FetchFromDatabase, Connection, CryptoKeystoreError};
use log::debug;
use openmls::prelude::{Credential, CredentialType};
use openmls_basic_credential::SignatureKeyPair;
use openmls_traits::{crypto::OpenMlsCrypto, types::SignatureScheme, OpenMlsCryptoProvider};
use std::collections::HashSet;
use std::ops::{Deref, DerefMut};
use std::sync::Arc;
use tls_codec::{Deserialize, Serialize};
use core_crypto_keystore::entities::{EntityFindParams, MlsCredential, MlsSignatureKeyPair};
use identities::ClientIdentities;
use mls_crypto_provider::MlsCryptoProvider;
/// Handle to the MLS client state.
///
/// The state is an `Option<ClientInner>` stored behind an `Arc<RwLock<..>>`. It holds `None`
/// until `replace_inner` installs a `ClientInner` (done by `generate`, `load` and
/// `init_with_external_client_id`). Because the state lives in an `Arc`, cloning a `Client`
/// yields a handle to the same shared state.
#[derive(Clone, Debug, Default)]
pub struct Client {
// `None` means the client has not been initialised yet (see `is_ready`).
state: Arc<RwLock<Option<ClientInner>>>,
}
/// State of an initialised client: its id, its credential bundles and the key package lifetime.
#[derive(Debug, Clone)]
struct ClientInner {
/// Identifier of this client.
id: ClientId,
/// Credential bundles of this client; bundles are pushed together with their signature scheme.
pub(crate) identities: ClientIdentities,
/// Key package lifetime; set to `KEYPACKAGE_DEFAULT_LIFETIME` wherever this struct is built
/// in this file.
keypackage_lifetime: std::time::Duration,
}
impl Client {
/// Initialises the client for `identifier`, either from the keystore or from scratch.
///
/// Steps:
/// 1. Fails with [`CryptoError::ConsumerError`] if the client is already initialised.
/// 2. Reads every stored `MlsCredential` and keeps those whose id equals the client id,
///    deserialising each into a `(Credential, created_at)` pair.
/// 3. If none is found, generates a new client with `nb_key_package` key packages per
///    credential.
/// 4. Otherwise loads the client from the stored credentials; when loading reports
///    [`CryptoError::ClientSignatureNotFound`] it falls back to generating a new client.
///
/// # Errors
/// Propagates keystore, deserialisation, `load` and `generate` errors.
pub async fn init(
    &self,
    identifier: ClientIdentifier,
    ciphersuites: &[MlsCiphersuite],
    backend: &MlsCryptoProvider,
    nb_key_package: usize,
) -> CryptoResult<()> {
    self.ensure_unready().await?;
    let id = identifier.get_id()?;

    let stored_credentials = backend
        .key_store()
        .find_all::<MlsCredential>(EntityFindParams::default())
        .await?;

    // Keep only the credentials belonging to this client, stopping at the first one that
    // cannot be deserialised.
    let mut credentials = Vec::new();
    for stored in stored_credentials {
        if &stored.id[..] != id.as_slice() {
            continue;
        }
        let credential = Credential::tls_deserialize(&mut stored.credential.as_slice()).map_err(MlsError::from)?;
        credentials.push((credential, stored.created_at));
    }

    if credentials.is_empty() {
        debug!(count = nb_key_package, ciphersuites:? = ciphersuites; "Generating client");
        return self.generate(identifier, backend, ciphersuites, nb_key_package).await;
    }

    let signature_schemes: HashSet<_> = ciphersuites.iter().map(|cs| cs.signature_algorithm()).collect();
    match self.load(backend, id.as_ref(), credentials, signature_schemes).await {
        Err(CryptoError::ClientSignatureNotFound) => {
            debug!(count = nb_key_package, ciphersuites:? = ciphersuites; "Client signature not found. Generating client");
            self.generate(identifier, backend, ciphersuites, nb_key_package).await
        }
        // Success and every other error are returned unchanged.
        outcome => outcome,
    }
}
/// Returns `true` once the shared state holds a `ClientInner`, i.e. the client is initialised.
pub(crate) async fn is_ready(&self) -> bool {
    self.state.read().await.is_some()
}
/// Guard used by the initialisation paths: succeeds only while the client is NOT initialised.
///
/// # Errors
/// Returns [`CryptoError::ConsumerError`] when the client is already initialised.
async fn ensure_unready(&self) -> CryptoResult<()> {
    if !self.is_ready().await {
        return Ok(());
    }
    Err(CryptoError::ConsumerError)
}
/// Installs `new_inner` as the client state under the write lock, overwriting any previous
/// state.
async fn replace_inner(&self, new_inner: ClientInner) {
    *self.state.write().await = Some(new_inner);
}
/// Generates one signature key pair per ciphersuite under a random, temporary client id and
/// stores each of them in the keystore.
///
/// The returned ids (one per ciphersuite, in the order of `ciphersuites`) are the temporary
/// ids the key pairs were saved under; `init_with_external_client_id` takes them as `tmp_ids`.
///
/// # Errors
/// * [`CryptoError::ConsumerError`] if the client is already initialised.
/// * [`CryptoError::IdentityAlreadyPresent`] if the keystore already holds a basic credential.
/// * Any randomness, serialisation or keystore error.
pub async fn generate_raw_keypairs(
    &self,
    ciphersuites: &[MlsCiphersuite],
    backend: &MlsCryptoProvider,
) -> CryptoResult<Vec<ClientId>> {
    use openmls_traits::random::OpenMlsRand as _;

    /// Number of random bytes in a temporary client id.
    const TEMP_KEY_SIZE: usize = 16;

    self.ensure_unready().await?;

    let existing_basic = Self::find_all_basic_credentials(backend).await?;
    if !existing_basic.is_empty() {
        return Err(CryptoError::IdentityAlreadyPresent);
    }

    let mut provisional_ids = Vec::with_capacity(ciphersuites.len());
    for ciphersuite in ciphersuites {
        let provisional_id: ClientId = backend.rand().random_vec(TEMP_KEY_SIZE)?.into();
        let scheme = ciphersuite.signature_algorithm();
        let bundle = Self::new_basic_credential_bundle(&provisional_id, scheme, backend)?;

        let public_key = bundle.signature_key.to_public_vec();
        let serialized_keypair = bundle.signature_key.tls_serialize_detached().map_err(MlsError::from)?;
        let entity = MlsSignatureKeyPair::new(scheme, public_key, serialized_keypair, provisional_id.clone().into());

        backend.key_store().save(entity).await?;
        provisional_ids.push(provisional_id);
    }

    Ok(provisional_ids)
}
/// Finishes an initialisation started with `generate_raw_keypairs`: re-registers the stored
/// provisional signature key pairs under the definitive `client_id`.
///
/// The stored key pairs must match `tmp_ids` one to one. Each stored key pair is paired with
/// the ciphersuite at the same position in `ciphersuites`. The provisional entry is removed
/// from the keystore (looked up by its public key) and a new identity is saved under
/// `client_id` through `save_identity`. The basic credential built for that identity carries
/// the provisional key pair's `credential_id` as its identity bytes.
///
/// # Errors
/// * [`CryptoError::ConsumerError`] if the client is already initialised.
/// * [`CryptoError::NoProvisionalIdentityFound`] if fewer key pairs are stored than
///   `tmp_ids`, or a stored key pair's credential id is not in `tmp_ids`.
/// * [`CryptoError::TooManyIdentitiesPresent`] if more key pairs are stored than `tmp_ids`.
/// * [`CryptoError::ImplementationError`] if a stored signature scheme cannot be converted.
/// * Any keystore or (de)serialisation error.
pub async fn init_with_external_client_id(
    &self,
    client_id: ClientId,
    tmp_ids: Vec<ClientId>,
    ciphersuites: &[MlsCiphersuite],
    backend: &MlsCryptoProvider,
) -> CryptoResult<()> {
    self.ensure_unready().await?;

    let provisional_keypairs = backend
        .key_store()
        .find_all::<MlsSignatureKeyPair>(EntityFindParams::default())
        .await?;

    // The keystore must contain exactly as many key pairs as provisional ids were handed out.
    let stored_count = provisional_keypairs.len();
    if stored_count < tmp_ids.len() {
        return Err(CryptoError::NoProvisionalIdentityFound);
    }
    if stored_count > tmp_ids.len() {
        return Err(CryptoError::TooManyIdentitiesPresent);
    }

    // ...and every stored key pair must belong to one of the supplied provisional ids.
    let has_unknown_keypair = provisional_keypairs
        .iter()
        .any(|kp| !tmp_ids.contains(&kp.credential_id.as_slice().into()));
    if has_unknown_keypair {
        return Err(CryptoError::NoProvisionalIdentityFound);
    }

    // The state is installed before the loop because `save_identity` requires an
    // initialised client.
    self.replace_inner(ClientInner {
        id: client_id.clone(),
        identities: ClientIdentities::new(stored_count),
        keypackage_lifetime: KEYPACKAGE_DEFAULT_LIFETIME,
    })
    .await;

    for (provisional, ciphersuite) in provisional_keypairs.iter().zip(ciphersuites) {
        let scheme = provisional
            .signature_scheme
            .try_into()
            .map_err(|_| CryptoError::ImplementationError)?;
        let rebound_keypair = MlsSignatureKeyPair::new(
            scheme,
            provisional.pk.clone(),
            provisional.keypair.clone(),
            client_id.clone().into(),
        );

        // Remove the provisional entry first; the definitive one is saved right after.
        backend
            .key_store()
            .remove::<MlsSignatureKeyPair, &[u8]>(&rebound_keypair.pk)
            .await?;

        let signature_key =
            SignatureKeyPair::tls_deserialize(&mut rebound_keypair.keypair.as_slice()).map_err(MlsError::from)?;
        let bundle = CredentialBundle {
            credential: Credential::new_basic(provisional.credential_id.clone()),
            signature_key,
            // Placeholder: `save_identity` overwrites it with the stored credential's value.
            created_at: 0,
        };

        self.save_identity(
            &backend.keystore(),
            Some(&client_id),
            ciphersuite.signature_algorithm(),
            bundle,
        )
        .await?;
    }

    Ok(())
}
/// Creates a brand new client for `identifier`.
///
/// Installs an empty state for the client id, generates one credential bundle per distinct
/// signature scheme of `ciphersuites`, and saves each through `save_identity`. When
/// `nb_key_package` is non-zero it then requests that many key packages for every
/// ciphersuite and every credential bundle registered under that ciphersuite's signature
/// scheme.
///
/// # Errors
/// * [`CryptoError::ConsumerError`] if the client is already initialised.
/// * [`CryptoError::MlsNotInitialized`] if the state is unexpectedly empty when read back.
/// * Any error from credential generation, `save_identity` or `request_key_packages`.
pub(crate) async fn generate(
    &self,
    identifier: ClientIdentifier,
    backend: &MlsCryptoProvider,
    ciphersuites: &[MlsCiphersuite],
    nb_key_package: usize,
) -> CryptoResult<()> {
    self.ensure_unready().await?;

    let client_id = identifier.get_id()?;
    let signature_schemes: HashSet<_> = ciphersuites.iter().map(|cs| cs.signature_algorithm()).collect();

    let blank_state = ClientInner {
        id: client_id.into_owned(),
        identities: ClientIdentities::new(signature_schemes.len()),
        keypackage_lifetime: KEYPACKAGE_DEFAULT_LIFETIME,
    };
    self.replace_inner(blank_state).await;

    for (sc, id, cb) in identifier.generate_credential_bundles(backend, signature_schemes)? {
        self.save_identity(&backend.keystore(), Some(&id), sc, cb).await?;
    }

    // Take a copy of the identities; the read guard is released at the end of this statement,
    // so it is not held while key packages are requested below.
    let identities = match self.state.read().await.deref() {
        Some(inner) => inner.identities.clone(),
        None => return Err(CryptoError::MlsNotInitialized),
    };

    if nb_key_package == 0 {
        return Ok(());
    }

    for cs in ciphersuites {
        let sc = cs.signature_algorithm();
        for (_, cb) in identities.iter().filter(|(candidate, _)| *candidate == sc) {
            self.request_key_packages(nb_key_package, *cs, cb.credential.credential_type().into(), backend)
                .await?;
        }
    }

    Ok(())
}
/// Rebuilds the client state from credentials already present in the keystore.
///
/// `credentials` are `(credential, created_at)` pairs and are processed in ascending
/// `created_at` order. For every requested signature scheme the first stored signature key
/// pair with that scheme is used. When there is none, a new key pair is generated and saved
/// under `id`. Every credential is then checked against the client and pushed as a
/// `CredentialBundle` for that scheme:
/// * a basic credential must carry `id` as its identity;
/// * an X509 credential's public key must equal the signature key's public key.
///
/// The state is only installed once every scheme and credential has been processed.
///
/// # Errors
/// * [`CryptoError::ConsumerError`] if the client is already initialised.
/// * [`CryptoError::ImplementationError`] if a credential fails the check above.
/// * [`CryptoError::InternalMlsError`] if no public key can be extracted from a certificate.
/// * Any keystore, key generation or (de)serialisation error.
pub(crate) async fn load(
    &self,
    backend: &MlsCryptoProvider,
    id: &ClientId,
    mut credentials: Vec<(Credential, u64)>,
    signature_schemes: HashSet<SignatureScheme>,
) -> CryptoResult<()> {
    self.ensure_unready().await?;

    let mut identities = ClientIdentities::new(signature_schemes.len());

    // Oldest first, so bundles are pushed in chronological order.
    credentials.sort_by_key(|(_, created_at)| *created_at);

    let stored_keypairs = backend
        .key_store()
        .find_all::<MlsSignatureKeyPair>(EntityFindParams::default())
        .await?;

    for sc in signature_schemes {
        let stored = stored_keypairs.iter().find(|kp| kp.signature_scheme == (sc as u16));

        let signature_key = match stored {
            Some(kp) => SignatureKeyPair::tls_deserialize(&mut kp.keypair.as_slice()).map_err(MlsError::from)?,
            None => {
                // No key pair for this scheme yet: create one and persist it under `id`.
                let (sk, pk) = backend.crypto().signature_key_gen(sc).map_err(MlsError::from)?;
                let fresh = SignatureKeyPair::from_raw(sc, sk, pk.clone());
                let serialized = fresh.tls_serialize_detached().map_err(MlsError::from)?;
                let entity = MlsSignatureKeyPair::new(sc, pk, serialized, id.as_slice().into());
                backend.key_store().save(entity.clone()).await?;
                SignatureKeyPair::tls_deserialize(&mut entity.keypair.as_slice()).map_err(MlsError::from)?
            }
        };

        for (credential, created_at) in &credentials {
            let belongs_to_client = match credential.mls_credential() {
                openmls::prelude::MlsCredentialType::Basic(_) => id.as_slice() == credential.identity(),
                openmls::prelude::MlsCredentialType::X509(cert) => {
                    let spk = cert.extract_public_key()?.ok_or(CryptoError::InternalMlsError)?;
                    signature_key.public() == spk
                }
            };
            if !belongs_to_client {
                return Err(CryptoError::ImplementationError);
            }

            let bundle = CredentialBundle {
                credential: credential.clone(),
                signature_key: signature_key.clone(),
                created_at: *created_at,
            };
            identities.push_credential_bundle(sc, bundle).await?;
        }
    }

    self.replace_inner(ClientInner {
        id: id.clone(),
        identities,
        keypackage_lifetime: KEYPACKAGE_DEFAULT_LIFETIME,
    })
    .await;

    Ok(())
}
async fn find_all_basic_credentials(backend: &MlsCryptoProvider) -> CryptoResult<Vec<Credential>> {
let store_credentials = backend
.key_store()
.find_all::<MlsCredential>(EntityFindParams::default())
.await?;
let mut credentials = Vec::with_capacity(store_credentials.len());
for store_credential in store_credentials.into_iter() {
let credential =
Credential::tls_deserialize(&mut store_credential.credential.as_slice()).map_err(MlsError::from)?;
if !matches!(credential.credential_type(), CredentialType::Basic) {
continue;
}
credentials.push(credential);
}
Ok(credentials)
}
/// Persists a credential bundle and registers it in the in-memory identities.
///
/// The credential and then the signature key pair are saved in the keystore under `id`, or
/// under the client's own id when `id` is `None`. The bundle's `created_at` is replaced by
/// the value of the stored credential before the bundle is pushed for signature scheme `sc`
/// and returned.
///
/// The write lock on the client state is held for the whole call, including the keystore
/// operations.
///
/// # Errors
/// * [`CryptoError::MlsNotInitialized`] if the client has no state yet.
/// * [`CryptoError::CredentialBundleConflict`] if the keystore reports that the signature
///   key pair already exists.
/// * Any other keystore, serialisation or `push_credential_bundle` error.
pub(crate) async fn save_identity(
    &self,
    keystore: &Connection,
    id: Option<&ClientId>,
    sc: SignatureScheme,
    mut cb: CredentialBundle,
) -> CryptoResult<CredentialBundle> {
    let mut state = self.state.write().await;
    let (existing_id, identities) = match state.deref_mut() {
        Some(ClientInner {
            id: existing_id,
            identities,
            ..
        }) => (existing_id, identities),
        None => return Err(CryptoError::MlsNotInitialized),
    };

    let owner = id.unwrap_or(existing_id);

    let serialized_credential = cb.credential.tls_serialize_detached().map_err(MlsError::from)?;
    let stored_credential = keystore
        .save(MlsCredential {
            id: owner.clone().into(),
            credential: serialized_credential,
            created_at: 0,
        })
        .await?;

    let public_key = cb.signature_key.to_public_vec();
    let serialized_keypair = cb.signature_key.tls_serialize_detached().map_err(MlsError::from)?;
    let keypair_entity = MlsSignatureKeyPair::new(sc, public_key, serialized_keypair, owner.clone().into());
    if let Err(err) = keystore.save(keypair_entity).await {
        return Err(match err {
            CryptoKeystoreError::AlreadyExists => CryptoError::CredentialBundleConflict,
            other => other.into(),
        });
    }

    // The bundle takes over the creation date of the credential as stored by the keystore.
    cb.created_at = stored_credential.created_at;
    identities.push_credential_bundle(sc, cb.clone()).await?;

    Ok(cb)
}
/// Returns a copy of the client id.
///
/// # Errors
/// [`CryptoError::MlsNotInitialized`] if the client has not been initialised.
pub async fn id(&self) -> CryptoResult<ClientId> {
    if let Some(inner) = self.state.read().await.deref() {
        Ok(inner.id.clone())
    } else {
        Err(CryptoError::MlsNotInitialized)
    }
}
pub async fn is_e2ei_capable(&self) -> bool {
match self.state.read().await.deref() {
None => false,
Some(ClientInner { identities, .. }) => identities
.iter()
.any(|(_, cred)| cred.credential().credential_type() == CredentialType::X509),
}
}
/// Returns the most recent credential bundle for signature scheme `sc` and credential type
/// `ct`.
///
/// * `Basic`: a bundle is created first when none exists yet, then looked up.
/// * `X509`: the bundle is only looked up, never created.
///
/// # Errors
/// For `X509`, a [`CryptoError::CredentialNotFound`] from the lookup is reported as
/// [`CryptoError::E2eiEnrollmentNotDone`]. Every other error is propagated unchanged.
pub(crate) async fn get_most_recent_or_create_credential_bundle(
    &self,
    backend: &MlsCryptoProvider,
    sc: SignatureScheme,
    ct: MlsCredentialType,
) -> CryptoResult<Arc<CredentialBundle>> {
    match ct {
        MlsCredentialType::X509 => {
            let found = self.find_most_recent_credential_bundle(sc, ct).await;
            found.map_err(|err| match err {
                CryptoError::CredentialNotFound(_) => CryptoError::E2eiEnrollmentNotDone,
                other => other,
            })
        }
        MlsCredentialType::Basic => {
            self.init_basic_credential_bundle_if_missing(backend, sc).await?;
            self.find_most_recent_credential_bundle(sc, ct).await
        }
    }
}
/// Creates and saves a basic credential bundle for `sc` when the lookup for one fails with
/// [`CryptoError::CredentialNotFound`].
///
/// Any other lookup outcome, whether a found bundle or a different error, leaves everything
/// untouched and returns `Ok(())`.
///
/// # Errors
/// Errors from reading the client id, creating the bundle or saving it.
pub(crate) async fn init_basic_credential_bundle_if_missing(
    &self,
    backend: &MlsCryptoProvider,
    sc: SignatureScheme,
) -> CryptoResult<()> {
    let lookup = self
        .find_most_recent_credential_bundle(sc, MlsCredentialType::Basic)
        .await;
    if !matches!(lookup, Err(CryptoError::CredentialNotFound(_))) {
        return Ok(());
    }

    let client_id = self.id().await?;
    debug!(id:% = &client_id; "Initializing basic credential bundle");
    let bundle = Self::new_basic_credential_bundle(&client_id, sc, backend)?;
    self.save_identity(&backend.keystore(), None, sc, bundle).await?;

    Ok(())
}
/// Converts a certificate bundle into an X509 credential bundle and saves it through
/// `save_identity`, under the client id read from the certificate bundle.
///
/// # Errors
/// Errors from extracting the client id, building the credential bundle or saving it.
pub(crate) async fn save_new_x509_credential_bundle(
    &self,
    keystore: &Connection,
    sc: SignatureScheme,
    cb: CertificateBundle,
) -> CryptoResult<CredentialBundle> {
    let certificate_client_id = cb.get_client_id()?;
    let credential_bundle = Self::new_x509_credential_bundle(cb)?;
    self.save_identity(keystore, Some(&certificate_client_id), sc, credential_bundle)
        .await
}
}
#[cfg(test)]
impl Client {
/// Test helper: generates a client with a random id of the form
/// `<uuid>:<hex>@members.wire.com` for the credential type and ciphersuite of `case`.
///
/// * `signer` is required for X509 cases; the helper panics without it.
/// * `provision` selects whether `INITIAL_KEYING_MATERIAL_COUNT` key packages are requested
///   or none at all.
pub async fn random_generate(
    case: &crate::test_utils::TestCase,
    backend: &MlsCryptoProvider,
    signer: Option<&crate::test_utils::x509::X509Certificate>,
    provision: bool,
) -> CryptoResult<Self> {
    let user_uuid = uuid::Uuid::new_v4();
    let rnd_id = rand::random::<usize>();
    let client_id = format!("{}:{rnd_id:x}@members.wire.com", user_uuid.hyphenated());

    let identity = match case.credential_type {
        MlsCredentialType::X509 => {
            let intermediate_ca = signer.expect("Missing intermediate CA");
            CertificateBundle::rand_identifier(&client_id, &[intermediate_ca])
        }
        MlsCredentialType::Basic => ClientIdentifier::Basic(client_id.as_str().into()),
    };

    let mut nb_key_package = 0;
    if provision {
        nb_key_package = crate::prelude::INITIAL_KEYING_MATERIAL_COUNT;
    }

    let client = Self::default();
    let ciphersuites = [case.ciphersuite()];
    client
        .generate(identity, backend, &ciphersuites, nb_key_package)
        .await?;

    Ok(client)
}
/// Test helper: fetches key packages from the keystore, asking for up to `u32::MAX` of them.
pub async fn find_keypackages(
    &self,
    backend: &MlsCryptoProvider,
) -> CryptoResult<Vec<openmls::prelude::KeyPackage>> {
    use core_crypto_keystore::CryptoKeystoreMls as _;

    let fetched = backend
        .key_store()
        .mls_fetch_keypackages::<openmls::prelude::KeyPackage>(u32::MAX)
        .await?;

    Ok(fetched)
}
}
// Tests for the client initialisation paths (`generate` and the two-step external
// initialisation).
#[cfg(test)]
mod tests {
use crate::prelude::ClientId;
use crate::test_utils::*;
use core_crypto_keystore::connection::FetchFromDatabase;
use core_crypto_keystore::entities::{EntityFindParams, MlsSignatureKeyPair};
use mls_crypto_provider::MlsCryptoProvider;
use wasm_bindgen_test::*;
use super::Client;
wasm_bindgen_test_configure!(run_in_browser);
// Generating a client with a random identity must succeed for the given test case.
// NOTE(review): `all_cred_cipher` is presumably a template covering every credential type /
// ciphersuite combination - confirm in `test_utils`.
#[apply(all_cred_cipher)]
#[wasm_bindgen_test]
async fn can_generate_client(case: TestCase) {
let backend = MlsCryptoProvider::try_new_in_memory("test").await.unwrap();
// X509 cases need a certificate chain registered with the provider; its local intermediate
// CA is passed to `random_generate` as the signer.
let x509_test_chain = if case.is_x509() {
let x509_test_chain = crate::test_utils::x509::X509TestChain::init_empty(case.signature_scheme());
x509_test_chain.register_with_provider(&backend).await;
Some(x509_test_chain)
} else {
None
};
backend.new_transaction().await.unwrap();
// `false`: no key packages are requested.
let _ = Client::random_generate(
&case,
&backend,
x509_test_chain.as_ref().map(|chain| chain.find_local_intermediate_ca()),
false,
)
.await
.unwrap();
}
// Two-step initialisation: generate provisional key pairs, then bind them to an externally
// supplied client id. Only exercised for basic credentials.
#[apply(all_cred_cipher)]
#[wasm_bindgen_test]
async fn can_externally_generate_client(case: TestCase) {
if case.is_basic() {
run_tests(move |[tmp_dir_argument]| {
Box::pin(async move {
let backend = MlsCryptoProvider::try_new(tmp_dir_argument, "test").await.unwrap();
backend.new_transaction().await.unwrap();
let client_id: ClientId = b"whatever:my:client:is@world.com".to_vec().into();
let alice = Client::default();
// Step 1: one provisional key pair per ciphersuite; `handles` are the temporary ids.
let handles = alice
.generate_raw_keypairs(&[case.ciphersuite()], &backend)
.await
.unwrap();
let mut identities = backend
.keystore()
.find_all::<MlsSignatureKeyPair>(EntityFindParams::default())
.await
.unwrap();
// Exactly one key pair is stored, and it is stored under the returned temporary id.
assert_eq!(identities.len(), 1);
let prov_identity = identities.pop().unwrap();
let prov_client_id: ClientId = prov_identity.credential_id.as_slice().into();
assert_eq!(&prov_client_id, handles.first().unwrap());
// Step 2: bind the provisional key pair to the definitive client id.
alice
.init_with_external_client_id(
client_id.clone(),
handles.clone(),
&[case.ciphersuite()],
&backend,
)
.await
.unwrap();
// The client now reports the definitive id...
assert_eq!(alice.id().await.unwrap(), client_id);
let cb = alice
.find_most_recent_credential_bundle(case.signature_scheme(), case.credential_type)
.await
.unwrap();
// ...while the identity embedded in the credential is still the provisional id.
let client_id: ClientId = cb.credential().identity().into();
assert_eq!(&client_id, handles.first().unwrap());
})
})
.await
}
}
}