1pub(crate) mod id;
18pub(crate) mod identifier;
19pub(crate) mod identities;
20pub(crate) mod key_package;
21pub(crate) mod user_id;
22
23use crate::{
24 mls::credential::{ext::CredentialExt, CredentialBundle},
25 prelude::{
26 identifier::ClientIdentifier, key_package::KEYPACKAGE_DEFAULT_LIFETIME, CertificateBundle, ClientId,
27 CryptoError, CryptoResult, MlsCiphersuite, MlsCredentialType, MlsError,
28 },
29};
30use async_lock::RwLock;
31use core_crypto_keystore::{connection::FetchFromDatabase, Connection, CryptoKeystoreError};
32use log::debug;
33use openmls::prelude::{Credential, CredentialType};
34use openmls_basic_credential::SignatureKeyPair;
35use openmls_traits::{crypto::OpenMlsCrypto, types::SignatureScheme, OpenMlsCryptoProvider};
36use std::collections::HashSet;
37use std::ops::{Deref, DerefMut};
38use std::sync::Arc;
39use tls_codec::{Deserialize, Serialize};
40
41use core_crypto_keystore::entities::{EntityFindParams, MlsCredential, MlsSignatureKeyPair};
42use identities::ClientIdentities;
43use mls_crypto_provider::MlsCryptoProvider;
44
45#[derive(Clone, Debug, Default)]
53pub struct Client {
54 state: Arc<RwLock<Option<ClientInner>>>,
55}
56
57#[derive(Debug, Clone)]
58struct ClientInner {
59 id: ClientId,
60 pub(crate) identities: ClientIdentities,
61 keypackage_lifetime: std::time::Duration,
62}
63
64impl Client {
65 pub async fn init(
77 &self,
78 identifier: ClientIdentifier,
79 ciphersuites: &[MlsCiphersuite],
80 backend: &MlsCryptoProvider,
81 nb_key_package: usize,
82 ) -> CryptoResult<()> {
83 self.ensure_unready().await?;
84 let id = identifier.get_id()?;
85
86 let credentials = backend
87 .key_store()
88 .find_all::<MlsCredential>(EntityFindParams::default())
89 .await?;
90
91 let credentials = credentials
92 .into_iter()
93 .filter(|c| &c.id[..] == id.as_slice())
94 .try_fold(vec![], |mut acc, c| {
95 let credential = Credential::tls_deserialize(&mut c.credential.as_slice()).map_err(MlsError::from)?;
96 acc.push((credential, c.created_at));
97 CryptoResult::Ok(acc)
98 })?;
99
100 if !credentials.is_empty() {
101 let signature_schemes = ciphersuites
102 .iter()
103 .map(|cs| cs.signature_algorithm())
104 .collect::<HashSet<_>>();
105 match self.load(backend, id.as_ref(), credentials, signature_schemes).await {
106 Ok(client) => client,
107 Err(CryptoError::ClientSignatureNotFound) => {
108 debug!(count = nb_key_package, ciphersuites:? = ciphersuites; "Client signature not found. Generating client");
109 self.generate(identifier, backend, ciphersuites, nb_key_package).await?
110 }
111 Err(e) => return Err(e),
112 }
113 } else {
114 debug!(count = nb_key_package, ciphersuites:? = ciphersuites; "Generating client");
115 self.generate(identifier, backend, ciphersuites, nb_key_package).await?
116 };
117
118 Ok(())
119 }
120
121 pub(crate) async fn is_ready(&self) -> bool {
122 let inner_lock = self.state.read().await;
123 inner_lock.is_some()
124 }
125
126 async fn ensure_unready(&self) -> CryptoResult<()> {
127 if self.is_ready().await {
128 Err(CryptoError::ConsumerError)
129 } else {
130 Ok(())
131 }
132 }
133
134 async fn replace_inner(&self, new_inner: ClientInner) {
135 let mut inner_lock = self.state.write().await;
136 *inner_lock = Some(new_inner);
137 }
138
139 pub async fn generate_raw_keypairs(
149 &self,
150 ciphersuites: &[MlsCiphersuite],
151 backend: &MlsCryptoProvider,
152 ) -> CryptoResult<Vec<ClientId>> {
153 self.ensure_unready().await?;
154 const TEMP_KEY_SIZE: usize = 16;
155
156 let credentials = Self::find_all_basic_credentials(backend).await?;
157 if !credentials.is_empty() {
158 return Err(CryptoError::IdentityAlreadyPresent);
159 }
160
161 use openmls_traits::random::OpenMlsRand as _;
162 let mut tmp_client_ids = Vec::with_capacity(ciphersuites.len());
164 for cs in ciphersuites {
165 let tmp_client_id: ClientId = backend.rand().random_vec(TEMP_KEY_SIZE)?.into();
166
167 let cb = Self::new_basic_credential_bundle(&tmp_client_id, cs.signature_algorithm(), backend)?;
168
169 let sign_kp = MlsSignatureKeyPair::new(
170 cs.signature_algorithm(),
171 cb.signature_key.to_public_vec(),
172 cb.signature_key.tls_serialize_detached().map_err(MlsError::from)?,
173 tmp_client_id.clone().into(),
174 );
175 backend.key_store().save(sign_kp).await?;
176
177 tmp_client_ids.push(tmp_client_id);
178 }
179
180 Ok(tmp_client_ids)
181 }
182
183 pub async fn init_with_external_client_id(
193 &self,
194 client_id: ClientId,
195 tmp_ids: Vec<ClientId>,
196 ciphersuites: &[MlsCiphersuite],
197 backend: &MlsCryptoProvider,
198 ) -> CryptoResult<()> {
199 self.ensure_unready().await?;
200 let stored_skp = backend
202 .key_store()
203 .find_all::<MlsSignatureKeyPair>(EntityFindParams::default())
204 .await?;
205
206 match stored_skp.len() {
207 i if i < tmp_ids.len() => return Err(CryptoError::NoProvisionalIdentityFound),
208 i if i > tmp_ids.len() => return Err(CryptoError::TooManyIdentitiesPresent),
209 _ => {}
210 }
211
212 let all_tmp_ids_exist = stored_skp
214 .iter()
215 .all(|kp| tmp_ids.contains(&kp.credential_id.as_slice().into()));
216 if !all_tmp_ids_exist {
217 return Err(CryptoError::NoProvisionalIdentityFound);
218 }
219
220 let identities = stored_skp.iter().zip(ciphersuites);
221
222 self.replace_inner(ClientInner {
223 id: client_id.clone(),
224 identities: ClientIdentities::new(stored_skp.len()),
225 keypackage_lifetime: KEYPACKAGE_DEFAULT_LIFETIME,
226 })
227 .await;
228
229 let id = &client_id;
230
231 for (tmp_kp, &cs) in identities {
232 let scheme = tmp_kp
233 .signature_scheme
234 .try_into()
235 .map_err(|_| CryptoError::ImplementationError)?;
236 let new_keypair =
237 MlsSignatureKeyPair::new(scheme, tmp_kp.pk.clone(), tmp_kp.keypair.clone(), id.clone().into());
238
239 let new_credential = MlsCredential {
240 id: id.clone().into(),
241 credential: tmp_kp.credential_id.clone(),
242 created_at: 0,
243 };
244
245 backend
247 .key_store()
248 .remove::<MlsSignatureKeyPair, &[u8]>(&new_keypair.pk)
249 .await?;
250
251 let signature_key =
252 SignatureKeyPair::tls_deserialize(&mut new_keypair.keypair.as_slice()).map_err(MlsError::from)?;
253 let cb = CredentialBundle {
254 credential: Credential::new_basic(new_credential.credential.clone()),
255 signature_key,
256 created_at: 0, };
258
259 self.save_identity(&backend.keystore(), Some(id), cs.signature_algorithm(), cb)
261 .await?;
262 }
263
264 Ok(())
265 }
266
267 pub(crate) async fn generate(
269 &self,
270 identifier: ClientIdentifier,
271 backend: &MlsCryptoProvider,
272 ciphersuites: &[MlsCiphersuite],
273 nb_key_package: usize,
274 ) -> CryptoResult<()> {
275 self.ensure_unready().await?;
276 let id = identifier.get_id()?;
277 let signature_schemes = ciphersuites
278 .iter()
279 .map(|cs| cs.signature_algorithm())
280 .collect::<HashSet<_>>();
281 self.replace_inner(ClientInner {
282 id: id.into_owned(),
283 identities: ClientIdentities::new(signature_schemes.len()),
284 keypackage_lifetime: KEYPACKAGE_DEFAULT_LIFETIME,
285 })
286 .await;
287
288 let identities = identifier.generate_credential_bundles(backend, signature_schemes)?;
289
290 for (sc, id, cb) in identities {
291 self.save_identity(&backend.keystore(), Some(&id), sc, cb).await?;
292 }
293
294 let identities = match self.state.read().await.deref() {
295 None => Err(CryptoError::MlsNotInitialized),
296 Some(ClientInner { identities, .. }) => Ok(identities.clone()),
300 }?;
301
302 if nb_key_package != 0 {
303 for cs in ciphersuites {
304 let sc = cs.signature_algorithm();
305 let identity = identities.iter().filter(|(id_sc, _)| id_sc == &sc);
306 for (_, cb) in identity {
307 self.request_key_packages(nb_key_package, *cs, cb.credential.credential_type().into(), backend)
308 .await?;
309 }
310 }
311 }
312
313 Ok(())
314 }
315
316 pub(crate) async fn load(
318 &self,
319 backend: &MlsCryptoProvider,
320 id: &ClientId,
321 mut credentials: Vec<(Credential, u64)>,
322 signature_schemes: HashSet<SignatureScheme>,
323 ) -> CryptoResult<()> {
324 self.ensure_unready().await?;
325 let mut identities = ClientIdentities::new(signature_schemes.len());
326
327 credentials.sort_by(|(_, a), (_, b)| a.cmp(b));
329
330 let store_skps = backend
331 .key_store()
332 .find_all::<MlsSignatureKeyPair>(EntityFindParams::default())
333 .await?;
334
335 for sc in signature_schemes {
336 let kp = store_skps.iter().find(|skp| skp.signature_scheme == (sc as u16));
337
338 let signature_key = if let Some(kp) = kp {
339 SignatureKeyPair::tls_deserialize(&mut kp.keypair.as_slice()).map_err(MlsError::from)?
340 } else {
341 let (sk, pk) = backend.crypto().signature_key_gen(sc).map_err(MlsError::from)?;
342 let keypair = SignatureKeyPair::from_raw(sc, sk, pk.clone());
343 let raw_keypair = keypair.tls_serialize_detached().map_err(MlsError::from)?;
344 let store_keypair = MlsSignatureKeyPair::new(sc, pk, raw_keypair, id.as_slice().into());
345 backend.key_store().save(store_keypair.clone()).await?;
346 SignatureKeyPair::tls_deserialize(&mut store_keypair.keypair.as_slice()).map_err(MlsError::from)?
347 };
348
349 for (credential, created_at) in &credentials {
350 match credential.mls_credential() {
351 openmls::prelude::MlsCredentialType::Basic(_) => {
352 if id.as_slice() != credential.identity() {
353 return Err(CryptoError::ImplementationError);
354 }
355 }
356 openmls::prelude::MlsCredentialType::X509(cert) => {
357 let spk = cert.extract_public_key()?.ok_or(CryptoError::InternalMlsError)?;
358 if signature_key.public() != spk {
359 return Err(CryptoError::ImplementationError);
360 }
361 }
362 };
363 let cb = CredentialBundle {
364 credential: credential.clone(),
365 signature_key: signature_key.clone(),
366 created_at: *created_at,
367 };
368 identities.push_credential_bundle(sc, cb).await?;
369 }
370 }
371 self.replace_inner(ClientInner {
372 id: id.clone(),
373 identities,
374 keypackage_lifetime: KEYPACKAGE_DEFAULT_LIFETIME,
375 })
376 .await;
377 Ok(())
378 }
379
380 async fn find_all_basic_credentials(backend: &MlsCryptoProvider) -> CryptoResult<Vec<Credential>> {
381 let store_credentials = backend
382 .key_store()
383 .find_all::<MlsCredential>(EntityFindParams::default())
384 .await?;
385 let mut credentials = Vec::with_capacity(store_credentials.len());
386 for store_credential in store_credentials.into_iter() {
387 let credential =
388 Credential::tls_deserialize(&mut store_credential.credential.as_slice()).map_err(MlsError::from)?;
389 if !matches!(credential.credential_type(), CredentialType::Basic) {
390 continue;
391 }
392 credentials.push(credential);
393 }
394
395 Ok(credentials)
396 }
397
398 pub(crate) async fn save_identity(
399 &self,
400 keystore: &Connection,
401 id: Option<&ClientId>,
402 sc: SignatureScheme,
403 mut cb: CredentialBundle,
404 ) -> CryptoResult<CredentialBundle> {
405 match self.state.write().await.deref_mut() {
406 None => Err(CryptoError::MlsNotInitialized),
407 Some(ClientInner {
408 id: existing_id,
409 identities,
410 ..
411 }) => {
412 let id = id.unwrap_or(existing_id);
413
414 let credential = cb.credential.tls_serialize_detached().map_err(MlsError::from)?;
415 let credential = MlsCredential {
416 id: id.clone().into(),
417 credential,
418 created_at: 0,
419 };
420
421 let credential = keystore.save(credential).await?;
422
423 let sign_kp = MlsSignatureKeyPair::new(
424 sc,
425 cb.signature_key.to_public_vec(),
426 cb.signature_key.tls_serialize_detached().map_err(MlsError::from)?,
427 id.clone().into(),
428 );
429 keystore.save(sign_kp).await.map_err(|e| match e {
430 CryptoKeystoreError::AlreadyExists => CryptoError::CredentialBundleConflict,
431 _ => e.into(),
432 })?;
433
434 cb.created_at = credential.created_at;
436
437 identities.push_credential_bundle(sc, cb.clone()).await?;
438
439 Ok(cb)
440 }
441 }
442 }
443
444 pub async fn id(&self) -> CryptoResult<ClientId> {
446 match self.state.read().await.deref() {
447 None => Err(CryptoError::MlsNotInitialized),
448 Some(ClientInner { id, .. }) => Ok(id.clone()),
449 }
450 }
451
452 pub async fn is_e2ei_capable(&self) -> bool {
454 match self.state.read().await.deref() {
455 None => false,
456 Some(ClientInner { identities, .. }) => identities
457 .iter()
458 .any(|(_, cred)| cred.credential().credential_type() == CredentialType::X509),
459 }
460 }
461
462 pub(crate) async fn get_most_recent_or_create_credential_bundle(
463 &self,
464 backend: &MlsCryptoProvider,
465 sc: SignatureScheme,
466 ct: MlsCredentialType,
467 ) -> CryptoResult<Arc<CredentialBundle>> {
468 match ct {
469 MlsCredentialType::Basic => {
470 self.init_basic_credential_bundle_if_missing(backend, sc).await?;
471 self.find_most_recent_credential_bundle(sc, ct).await
472 }
473 MlsCredentialType::X509 => self.find_most_recent_credential_bundle(sc, ct).await.map_err(|e| {
474 if matches!(e, CryptoError::CredentialNotFound(_)) {
475 CryptoError::E2eiEnrollmentNotDone
476 } else {
477 e
478 }
479 }),
480 }
481 }
482
483 pub(crate) async fn init_basic_credential_bundle_if_missing(
484 &self,
485 backend: &MlsCryptoProvider,
486 sc: SignatureScheme,
487 ) -> CryptoResult<()> {
488 let existing_cb = self
489 .find_most_recent_credential_bundle(sc, MlsCredentialType::Basic)
490 .await;
491 if matches!(existing_cb, Err(CryptoError::CredentialNotFound(_))) {
492 let id = self.id().await?;
493 debug!(id:% = &id; "Initializing basic credential bundle");
494 let cb = Self::new_basic_credential_bundle(&id, sc, backend)?;
495 self.save_identity(&backend.keystore(), None, sc, cb).await?;
496 }
497 Ok(())
498 }
499
500 pub(crate) async fn save_new_x509_credential_bundle(
501 &self,
502 keystore: &Connection,
503 sc: SignatureScheme,
504 cb: CertificateBundle,
505 ) -> CryptoResult<CredentialBundle> {
506 let id = cb.get_client_id()?;
507 let cb = Self::new_x509_credential_bundle(cb)?;
508 self.save_identity(keystore, Some(&id), sc, cb).await
509 }
510}
511
#[cfg(test)]
impl Client {
    /// Test helper: creates and generates a client with a random id of the form
    /// `<uuid>:<random hex>@members.wire.com`.
    ///
    /// X509 cases need the intermediate CA `signer` and panic without it. When `provision` is set,
    /// `INITIAL_KEYING_MATERIAL_COUNT` key packages are requested.
    pub async fn random_generate(
        case: &crate::test_utils::TestCase,
        backend: &MlsCryptoProvider,
        signer: Option<&crate::test_utils::x509::X509Certificate>,
        provision: bool,
    ) -> CryptoResult<Self> {
        let user_uuid = uuid::Uuid::new_v4();
        let random_suffix = rand::random::<usize>();
        let client_id = format!("{}:{:x}@members.wire.com", user_uuid.hyphenated(), random_suffix);

        let identity = match case.credential_type {
            MlsCredentialType::Basic => ClientIdentifier::Basic(client_id.as_str().into()),
            MlsCredentialType::X509 => {
                CertificateBundle::rand_identifier(&client_id, &[signer.expect("Missing intermediate CA")])
            }
        };

        let key_package_count = provision
            .then_some(crate::prelude::INITIAL_KEYING_MATERIAL_COUNT)
            .unwrap_or(0);

        let new_client = Self::default();
        new_client
            .generate(identity, backend, &[case.ciphersuite()], key_package_count)
            .await?;
        Ok(new_client)
    }

    /// Test helper: fetches every key package currently stored in the keystore.
    pub async fn find_keypackages(
        &self,
        backend: &MlsCryptoProvider,
    ) -> CryptoResult<Vec<openmls::prelude::KeyPackage>> {
        use core_crypto_keystore::CryptoKeystoreMls as _;
        Ok(backend
            .key_store()
            .mls_fetch_keypackages::<openmls::prelude::KeyPackage>(u32::MAX)
            .await?)
    }
}
554
#[cfg(test)]
mod tests {
    use super::Client;
    use crate::prelude::ClientId;
    use crate::test_utils::*;
    use core_crypto_keystore::connection::FetchFromDatabase;
    use core_crypto_keystore::entities::{EntityFindParams, MlsSignatureKeyPair};
    use mls_crypto_provider::MlsCryptoProvider;
    use wasm_bindgen_test::*;

    wasm_bindgen_test_configure!(run_in_browser);

    /// A client can be generated from scratch for every credential type / ciphersuite pair.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    async fn can_generate_client(case: TestCase) {
        let backend = MlsCryptoProvider::try_new_in_memory("test").await.unwrap();
        // X509 cases need a test PKI registered with the provider.
        let test_chain = if case.is_x509() {
            let test_chain = crate::test_utils::x509::X509TestChain::init_empty(case.signature_scheme());
            test_chain.register_with_provider(&backend).await;
            Some(test_chain)
        } else {
            None
        };
        backend.new_transaction().await.unwrap();
        let intermediate_ca = test_chain.as_ref().map(|tc| tc.find_local_intermediate_ca());
        let _ = Client::random_generate(&case, &backend, intermediate_ca, false)
            .await
            .unwrap();
    }

    /// Provisional keypairs can later be bound to an externally chosen client id.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    async fn can_externally_generate_client(case: TestCase) {
        // Only exercised for basic credentials.
        if !case.is_basic() {
            return;
        }
        run_tests(move |[tmp_dir]| {
            Box::pin(async move {
                let backend = MlsCryptoProvider::try_new(tmp_dir, "test").await.unwrap();
                backend.new_transaction().await.unwrap();
                let client_id: ClientId = b"whatever:my:client:is@world.com".to_vec().into();
                let alice = Client::default();
                let handles = alice
                    .generate_raw_keypairs(&[case.ciphersuite()], &backend)
                    .await
                    .unwrap();

                let mut stored_keypairs = backend
                    .keystore()
                    .find_all::<MlsSignatureKeyPair>(EntityFindParams::default())
                    .await
                    .unwrap();
                assert_eq!(stored_keypairs.len(), 1);

                // The single stored keypair belongs to the returned provisional id.
                let provisional_keypair = stored_keypairs.pop().unwrap();
                let provisional_id: ClientId = provisional_keypair.credential_id.as_slice().into();
                assert_eq!(&provisional_id, handles.first().unwrap());

                alice
                    .init_with_external_client_id(client_id.clone(), handles.clone(), &[case.ciphersuite()], &backend)
                    .await
                    .unwrap();

                assert_eq!(alice.id().await.unwrap(), client_id);

                // The basic credential still carries the provisional id as its identity.
                let bundle = alice
                    .find_most_recent_credential_bundle(case.signature_scheme(), case.credential_type)
                    .await
                    .unwrap();
                let credential_identity: ClientId = bundle.credential().identity().into();
                assert_eq!(&credential_identity, handles.first().unwrap());
            })
        })
        .await
    }
}