1mod epoch_observer;
18mod error;
19pub(crate) mod id;
20pub(crate) mod identifier;
21pub(crate) mod identities;
22pub(crate) mod key_package;
23pub(crate) mod user_id;
24
25use crate::{
26 KeystoreError, LeafError, MlsError, RecursiveError,
27 mls::credential::{CredentialBundle, ext::CredentialExt},
28 prelude::{
29 CertificateBundle, ClientId, MlsCiphersuite, MlsCredentialType, identifier::ClientIdentifier,
30 key_package::KEYPACKAGE_DEFAULT_LIFETIME,
31 },
32};
33pub use epoch_observer::EpochObserver;
34pub(crate) use error::{Error, Result};
35
36use async_lock::RwLock;
37use core_crypto_keystore::{Connection, CryptoKeystoreError, connection::FetchFromDatabase};
38use log::debug;
39use openmls::prelude::{Credential, CredentialType};
40use openmls_basic_credential::SignatureKeyPair;
41use openmls_traits::{OpenMlsCryptoProvider, crypto::OpenMlsCrypto, types::SignatureScheme};
42use std::ops::{Deref, DerefMut};
43use std::sync::Arc;
44use std::{collections::HashSet, fmt};
45use tls_codec::{Deserialize, Serialize};
46
47use core_crypto_keystore::entities::{EntityFindParams, MlsCredential, MlsSignatureKeyPair};
48use identities::ClientIdentities;
49use mls_crypto_provider::MlsCryptoProvider;
50
51#[derive(Clone, Debug, Default)]
59pub struct Client {
60 state: Arc<RwLock<Option<ClientInner>>>,
61}
62
63#[derive(Clone)]
64struct ClientInner {
65 id: ClientId,
66 pub(crate) identities: ClientIdentities,
67 keypackage_lifetime: std::time::Duration,
68 epoch_observer: Option<Arc<dyn EpochObserver>>,
69}
70
71impl fmt::Debug for ClientInner {
72 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
73 let observer_debug = if self.epoch_observer.is_some() {
74 "Some(Arc<dyn EpochObserver>)"
75 } else {
76 "None"
77 };
78 f.debug_struct("ClientInner")
79 .field("id", &self.id)
80 .field("identities", &self.identities)
81 .field("keypackage_lifetime", &self.keypackage_lifetime)
82 .field("epoch_observer", &observer_debug)
83 .finish()
84 }
85}
86
87impl Client {
88 pub async fn init(
100 &self,
101 identifier: ClientIdentifier,
102 ciphersuites: &[MlsCiphersuite],
103 backend: &MlsCryptoProvider,
104 nb_key_package: usize,
105 ) -> Result<()> {
106 self.ensure_unready().await?;
107 let id = identifier.get_id()?;
108
109 let credentials = backend
110 .key_store()
111 .find_all::<MlsCredential>(EntityFindParams::default())
112 .await
113 .map_err(KeystoreError::wrap("finding all mls credentials"))?;
114
115 let credentials = credentials
116 .into_iter()
117 .filter(|mls_credential| &mls_credential.id[..] == id.as_slice())
118 .map(|mls_credential| -> Result<_> {
119 let credential = Credential::tls_deserialize(&mut mls_credential.credential.as_slice())
120 .map_err(Error::tls_deserialize("mls credential"))?;
121 Ok((credential, mls_credential.created_at))
122 })
123 .collect::<Result<Vec<_>>>()?;
124
125 if !credentials.is_empty() {
126 let signature_schemes = ciphersuites
127 .iter()
128 .map(|cs| cs.signature_algorithm())
129 .collect::<HashSet<_>>();
130 match self.load(backend, id.as_ref(), credentials, signature_schemes).await {
131 Ok(client) => client,
132 Err(Error::ClientSignatureNotFound) => {
133 debug!(count = nb_key_package, ciphersuites:? = ciphersuites; "Client signature not found. Generating client");
134 self.generate(identifier, backend, ciphersuites, nb_key_package).await?
135 }
136 Err(e) => return Err(e),
137 }
138 } else {
139 debug!(count = nb_key_package, ciphersuites:? = ciphersuites; "Generating client");
140 self.generate(identifier, backend, ciphersuites, nb_key_package).await?
141 };
142
143 Ok(())
144 }
145
146 pub(crate) async fn is_ready(&self) -> bool {
147 let inner_lock = self.state.read().await;
148 inner_lock.is_some()
149 }
150
151 async fn ensure_unready(&self) -> Result<()> {
152 if self.is_ready().await {
153 Err(Error::UnexpectedlyReady)
154 } else {
155 Ok(())
156 }
157 }
158
159 async fn replace_inner(&self, new_inner: ClientInner) {
160 let mut inner_lock = self.state.write().await;
161 *inner_lock = Some(new_inner);
162 }
163
164 pub async fn generate_raw_keypairs(
174 &self,
175 ciphersuites: &[MlsCiphersuite],
176 backend: &MlsCryptoProvider,
177 ) -> Result<Vec<ClientId>> {
178 self.ensure_unready().await?;
179 const TEMP_KEY_SIZE: usize = 16;
180
181 let credentials = Self::find_all_basic_credentials(backend).await?;
182 if !credentials.is_empty() {
183 return Err(Error::IdentityAlreadyPresent);
184 }
185
186 use openmls_traits::random::OpenMlsRand as _;
187 let mut tmp_client_ids = Vec::with_capacity(ciphersuites.len());
189 for cs in ciphersuites {
190 let tmp_client_id: ClientId = backend
191 .rand()
192 .random_vec(TEMP_KEY_SIZE)
193 .map_err(MlsError::wrap("generating random client id"))?
194 .into();
195
196 let cb = Self::new_basic_credential_bundle(&tmp_client_id, cs.signature_algorithm(), backend)
197 .map_err(RecursiveError::mls_credential("creating new basic credential bundle"))?;
198
199 let sign_kp = MlsSignatureKeyPair::new(
200 cs.signature_algorithm(),
201 cb.signature_key.to_public_vec(),
202 cb.signature_key
203 .tls_serialize_detached()
204 .map_err(Error::tls_serialize("signature key"))?,
205 tmp_client_id.clone().into(),
206 );
207 backend
208 .key_store()
209 .save(sign_kp)
210 .await
211 .map_err(KeystoreError::wrap("save signature keypair in keystore"))?;
212
213 tmp_client_ids.push(tmp_client_id);
214 }
215
216 Ok(tmp_client_ids)
217 }
218
219 pub async fn init_with_external_client_id(
229 &self,
230 client_id: ClientId,
231 tmp_ids: Vec<ClientId>,
232 ciphersuites: &[MlsCiphersuite],
233 backend: &MlsCryptoProvider,
234 ) -> Result<()> {
235 self.ensure_unready().await?;
236 let stored_skp = backend
238 .key_store()
239 .find_all::<MlsSignatureKeyPair>(EntityFindParams::default())
240 .await
241 .map_err(KeystoreError::wrap("finding all mls signature keypairs"))?;
242
243 match stored_skp.len().cmp(&tmp_ids.len()) {
244 std::cmp::Ordering::Less => return Err(Error::NoProvisionalIdentityFound),
245 std::cmp::Ordering::Greater => return Err(Error::TooManyIdentitiesPresent),
246 _ => {}
247 }
248
249 let all_tmp_ids_exist = stored_skp
251 .iter()
252 .all(|kp| tmp_ids.contains(&kp.credential_id.as_slice().into()));
253 if !all_tmp_ids_exist {
254 return Err(Error::NoProvisionalIdentityFound);
255 }
256
257 let identities = stored_skp.iter().zip(ciphersuites);
258
259 self.replace_inner(ClientInner {
260 id: client_id.clone(),
261 identities: ClientIdentities::new(stored_skp.len()),
262 keypackage_lifetime: KEYPACKAGE_DEFAULT_LIFETIME,
263 epoch_observer: None,
264 })
265 .await;
266
267 let id = &client_id;
268
269 for (tmp_kp, &cs) in identities {
270 let scheme = tmp_kp
271 .signature_scheme
272 .try_into()
273 .map_err(|_| Error::InvalidSignatureScheme)?;
274 let new_keypair =
275 MlsSignatureKeyPair::new(scheme, tmp_kp.pk.clone(), tmp_kp.keypair.clone(), id.clone().into());
276
277 let new_credential = MlsCredential {
278 id: id.clone().into(),
279 credential: tmp_kp.credential_id.clone(),
280 created_at: 0,
281 };
282
283 backend
285 .key_store()
286 .remove::<MlsSignatureKeyPair, &[u8]>(&new_keypair.pk)
287 .await
288 .map_err(KeystoreError::wrap("removing mls signature keypair"))?;
289
290 let signature_key = SignatureKeyPair::tls_deserialize(&mut new_keypair.keypair.as_slice())
291 .map_err(Error::tls_deserialize("signature key"))?;
292 let cb = CredentialBundle {
293 credential: Credential::new_basic(new_credential.credential.clone()),
294 signature_key,
295 created_at: 0, };
297
298 self.save_identity(&backend.keystore(), Some(id), cs.signature_algorithm(), cb)
300 .await?;
301 }
302
303 Ok(())
304 }
305
306 pub(crate) async fn generate(
308 &self,
309 identifier: ClientIdentifier,
310 backend: &MlsCryptoProvider,
311 ciphersuites: &[MlsCiphersuite],
312 nb_key_package: usize,
313 ) -> Result<()> {
314 self.ensure_unready().await?;
315 let id = identifier.get_id()?;
316 let signature_schemes = ciphersuites
317 .iter()
318 .map(|cs| cs.signature_algorithm())
319 .collect::<HashSet<_>>();
320 self.replace_inner(ClientInner {
321 id: id.into_owned(),
322 identities: ClientIdentities::new(signature_schemes.len()),
323 keypackage_lifetime: KEYPACKAGE_DEFAULT_LIFETIME,
324 epoch_observer: None,
325 })
326 .await;
327
328 let identities = identifier.generate_credential_bundles(backend, signature_schemes)?;
329
330 for (sc, id, cb) in identities {
331 self.save_identity(&backend.keystore(), Some(&id), sc, cb).await?;
332 }
333
334 let identities = match self.state.read().await.deref() {
335 None => return Err(Error::MlsNotInitialized),
336 Some(ClientInner { identities, .. }) => identities.clone(),
340 };
341
342 if nb_key_package != 0 {
343 for cs in ciphersuites {
344 let sc = cs.signature_algorithm();
345 let identity = identities.iter().filter(|(id_sc, _)| id_sc == &sc);
346 for (_, cb) in identity {
347 self.request_key_packages(nb_key_package, *cs, cb.credential.credential_type().into(), backend)
348 .await?;
349 }
350 }
351 }
352
353 Ok(())
354 }
355
356 pub(crate) async fn load(
358 &self,
359 backend: &MlsCryptoProvider,
360 id: &ClientId,
361 mut credentials: Vec<(Credential, u64)>,
362 signature_schemes: HashSet<SignatureScheme>,
363 ) -> Result<()> {
364 self.ensure_unready().await?;
365 let mut identities = ClientIdentities::new(signature_schemes.len());
366
367 credentials.sort_by_key(|(_, timestamp)| *timestamp);
369
370 let store_skps = backend
371 .key_store()
372 .find_all::<MlsSignatureKeyPair>(EntityFindParams::default())
373 .await
374 .map_err(KeystoreError::wrap("finding all mls signature keypairs"))?;
375
376 for sc in signature_schemes {
377 let kp = store_skps.iter().find(|skp| skp.signature_scheme == (sc as u16));
378
379 let signature_key = if let Some(kp) = kp {
380 SignatureKeyPair::tls_deserialize(&mut kp.keypair.as_slice())
381 .map_err(Error::tls_deserialize("signature keypair"))?
382 } else {
383 let (sk, pk) = backend
384 .crypto()
385 .signature_key_gen(sc)
386 .map_err(MlsError::wrap("generating signature key"))?;
387 let keypair = SignatureKeyPair::from_raw(sc, sk, pk.clone());
388 let raw_keypair = keypair
389 .tls_serialize_detached()
390 .map_err(Error::tls_serialize("raw keypair"))?;
391 let store_keypair = MlsSignatureKeyPair::new(sc, pk, raw_keypair, id.as_slice().into());
392 backend
393 .key_store()
394 .save(store_keypair.clone())
395 .await
396 .map_err(KeystoreError::wrap("storing keypairs in keystore"))?;
397 SignatureKeyPair::tls_deserialize(&mut store_keypair.keypair.as_slice())
398 .map_err(Error::tls_deserialize("signature keypair"))?
399 };
400
401 for (credential, created_at) in &credentials {
402 match credential.mls_credential() {
403 openmls::prelude::MlsCredentialType::Basic(_) => {
404 if id.as_slice() != credential.identity() {
405 return Err(Error::WrongCredential);
406 }
407 }
408 openmls::prelude::MlsCredentialType::X509(cert) => {
409 let spk = cert
410 .extract_public_key()
411 .map_err(RecursiveError::mls_credential("extracting public key"))?
412 .ok_or(LeafError::InternalMlsError)?;
413 if signature_key.public() != spk {
414 return Err(Error::WrongCredential);
415 }
416 }
417 };
418 let cb = CredentialBundle {
419 credential: credential.clone(),
420 signature_key: signature_key.clone(),
421 created_at: *created_at,
422 };
423 identities.push_credential_bundle(sc, cb).await?;
424 }
425 }
426 self.replace_inner(ClientInner {
427 id: id.clone(),
428 identities,
429 keypackage_lifetime: KEYPACKAGE_DEFAULT_LIFETIME,
430 epoch_observer: None,
431 })
432 .await;
433 Ok(())
434 }
435
436 async fn find_all_basic_credentials(backend: &MlsCryptoProvider) -> Result<Vec<Credential>> {
437 let store_credentials = backend
438 .key_store()
439 .find_all::<MlsCredential>(EntityFindParams::default())
440 .await
441 .map_err(KeystoreError::wrap("finding all mls credentialss"))?;
442 let mut credentials = Vec::with_capacity(store_credentials.len());
443 for store_credential in store_credentials.into_iter() {
444 let credential = Credential::tls_deserialize(&mut store_credential.credential.as_slice())
445 .map_err(Error::tls_deserialize("credential"))?;
446 if !matches!(credential.credential_type(), CredentialType::Basic) {
447 continue;
448 }
449 credentials.push(credential);
450 }
451
452 Ok(credentials)
453 }
454
455 pub(crate) async fn save_identity(
456 &self,
457 keystore: &Connection,
458 id: Option<&ClientId>,
459 sc: SignatureScheme,
460 mut cb: CredentialBundle,
461 ) -> Result<CredentialBundle> {
462 match self.state.write().await.deref_mut() {
463 None => Err(Error::MlsNotInitialized),
464 Some(ClientInner {
465 id: existing_id,
466 identities,
467 ..
468 }) => {
469 let id = id.unwrap_or(existing_id);
470
471 let credential = cb
472 .credential
473 .tls_serialize_detached()
474 .map_err(Error::tls_serialize("credential bundle"))?;
475 let credential = MlsCredential {
476 id: id.clone().into(),
477 credential,
478 created_at: 0,
479 };
480
481 let credential = keystore
482 .save(credential)
483 .await
484 .map_err(KeystoreError::wrap("saving credential"))?;
485
486 let sign_kp = MlsSignatureKeyPair::new(
487 sc,
488 cb.signature_key.to_public_vec(),
489 cb.signature_key
490 .tls_serialize_detached()
491 .map_err(Error::tls_serialize("signature keypair"))?,
492 id.clone().into(),
493 );
494 keystore.save(sign_kp).await.map_err(|e| match e {
495 CryptoKeystoreError::AlreadyExists => Error::CredentialBundleConflict,
496 _ => KeystoreError::wrap("saving mls signature key pair")(e).into(),
497 })?;
498
499 cb.created_at = credential.created_at;
501
502 identities.push_credential_bundle(sc, cb.clone()).await?;
503
504 Ok(cb)
505 }
506 }
507 }
508
509 pub async fn id(&self) -> Result<ClientId> {
511 match self.state.read().await.deref() {
512 None => Err(Error::MlsNotInitialized),
513 Some(ClientInner { id, .. }) => Ok(id.clone()),
514 }
515 }
516
517 pub async fn is_e2ei_capable(&self) -> bool {
519 match self.state.read().await.deref() {
520 None => false,
521 Some(ClientInner { identities, .. }) => identities
522 .iter()
523 .any(|(_, cred)| cred.credential().credential_type() == CredentialType::X509),
524 }
525 }
526
527 pub(crate) async fn get_most_recent_or_create_credential_bundle(
528 &self,
529 backend: &MlsCryptoProvider,
530 sc: SignatureScheme,
531 ct: MlsCredentialType,
532 ) -> Result<Arc<CredentialBundle>> {
533 match ct {
534 MlsCredentialType::Basic => {
535 self.init_basic_credential_bundle_if_missing(backend, sc).await?;
536 self.find_most_recent_credential_bundle(sc, ct).await
537 }
538 MlsCredentialType::X509 => self
539 .find_most_recent_credential_bundle(sc, ct)
540 .await
541 .map_err(|e| match e {
542 Error::CredentialNotFound(_) => LeafError::E2eiEnrollmentNotDone.into(),
543 _ => e,
544 }),
545 }
546 }
547
548 pub(crate) async fn init_basic_credential_bundle_if_missing(
549 &self,
550 backend: &MlsCryptoProvider,
551 sc: SignatureScheme,
552 ) -> Result<()> {
553 let existing_cb = self
554 .find_most_recent_credential_bundle(sc, MlsCredentialType::Basic)
555 .await;
556 if matches!(existing_cb, Err(Error::CredentialNotFound(_))) {
557 let id = self.id().await?;
558 debug!(id:% = &id; "Initializing basic credential bundle");
559 let cb = Self::new_basic_credential_bundle(&id, sc, backend)
560 .map_err(RecursiveError::mls_credential("creating new basic credential bundle"))?;
561 self.save_identity(&backend.keystore(), None, sc, cb).await?;
562 }
563 Ok(())
564 }
565
566 pub(crate) async fn save_new_x509_credential_bundle(
567 &self,
568 keystore: &Connection,
569 sc: SignatureScheme,
570 cb: CertificateBundle,
571 ) -> Result<CredentialBundle> {
572 let id = cb
573 .get_client_id()
574 .map_err(RecursiveError::mls_credential("getting client id"))?;
575 let cb = Self::new_x509_credential_bundle(cb)
576 .map_err(RecursiveError::mls_credential("creating new x509 credential bundle"))?;
577 self.save_identity(keystore, Some(&id), sc, cb).await
578 }
579}
580
#[cfg(test)]
impl Client {
    #![allow(missing_docs)]

    /// Test helper: generates a client with a random `<uuid>:<hex>@members.wire.com` id.
    ///
    /// For X509 cases `signer` must be provided and is passed to
    /// `CertificateBundle::rand_identifier`; it panics otherwise. With `provision`,
    /// `INITIAL_KEYING_MATERIAL_COUNT` key packages are requested; otherwise none are.
    pub async fn random_generate(
        case: &crate::test_utils::TestCase,
        backend: &MlsCryptoProvider,
        signer: Option<&crate::test_utils::x509::X509Certificate>,
        provision: bool,
    ) -> Result<Self> {
        // The uuid is drawn before the random suffix, exactly as the arguments are listed.
        let client_id = format!(
            "{}:{:x}@members.wire.com",
            uuid::Uuid::new_v4().hyphenated(),
            rand::random::<usize>()
        );
        let identity = match case.credential_type {
            MlsCredentialType::X509 => {
                let intermediate = signer.expect("Missing intermediate CA");
                CertificateBundle::rand_identifier(&client_id, &[intermediate])
            }
            MlsCredentialType::Basic => ClientIdentifier::Basic(client_id.as_str().into()),
        };
        let key_package_count = match provision {
            true => crate::prelude::INITIAL_KEYING_MATERIAL_COUNT,
            false => 0,
        };

        let generated = Self::default();
        generated
            .generate(identity, backend, &[case.ciphersuite()], key_package_count)
            .await?;
        Ok(generated)
    }

    /// Test helper: fetches up to `u32::MAX` key packages (effectively all) from the keystore.
    pub async fn find_keypackages(&self, backend: &MlsCryptoProvider) -> Result<Vec<openmls::prelude::KeyPackage>> {
        use core_crypto_keystore::CryptoKeystoreMls as _;
        let key_packages = backend
            .key_store()
            .mls_fetch_keypackages::<openmls::prelude::KeyPackage>(u32::MAX)
            .await
            .map_err(KeystoreError::wrap("fetching mls keypackages"))?;
        Ok(key_packages)
    }
}
624
#[cfg(test)]
mod tests {
    use crate::prelude::ClientId;
    use crate::test_utils::*;
    use core_crypto_keystore::connection::FetchFromDatabase;
    use core_crypto_keystore::entities::{EntityFindParams, MlsSignatureKeyPair};
    use mls_crypto_provider::MlsCryptoProvider;
    use wasm_bindgen_test::*;

    use super::Client;

    wasm_bindgen_test_configure!(run_in_browser);

    /// A client can be generated for every credential type / ciphersuite combination.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    async fn can_generate_client(case: TestCase) {
        let backend = MlsCryptoProvider::try_new_in_memory("test").await.unwrap();
        // X509 clients need a test PKI registered with the provider; basic ones need nothing.
        let test_chain = match case.is_x509() {
            true => {
                let chain = crate::test_utils::x509::X509TestChain::init_empty(case.signature_scheme());
                chain.register_with_provider(&backend).await;
                Some(chain)
            }
            false => None,
        };
        backend.new_transaction().await.unwrap();
        let intermediate_ca = test_chain.as_ref().map(|chain| chain.find_local_intermediate_ca());
        Client::random_generate(&case, &backend, intermediate_ca, false)
            .await
            .unwrap();
    }

    /// Provisional keypairs can be generated first and bound to a definitive client id later.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    async fn can_externally_generate_client(case: TestCase) {
        // The provisional-keypair flow is only exercised with basic credentials.
        if !case.is_basic() {
            return;
        }
        run_tests(move |[tmp_dir_argument]| {
            Box::pin(async move {
                let backend = MlsCryptoProvider::try_new(tmp_dir_argument, "test").await.unwrap();
                backend.new_transaction().await.unwrap();
                let final_id: ClientId = b"whatever:my:client:is@world.com".to_vec().into();
                let alice = Client::default();
                let tmp_ids = alice
                    .generate_raw_keypairs(&[case.ciphersuite()], &backend)
                    .await
                    .unwrap();

                let mut stored_keypairs = backend
                    .keystore()
                    .find_all::<MlsSignatureKeyPair>(EntityFindParams::default())
                    .await
                    .unwrap();

                // One ciphersuite yields exactly one provisional keypair.
                assert_eq!(stored_keypairs.len(), 1);

                let provisional = stored_keypairs.pop().unwrap();
                let provisional_id: ClientId = provisional.credential_id.as_slice().into();
                assert_eq!(&provisional_id, tmp_ids.first().unwrap());

                alice
                    .init_with_external_client_id(
                        final_id.clone(),
                        tmp_ids.clone(),
                        &[case.ciphersuite()],
                        &backend,
                    )
                    .await
                    .unwrap();

                assert_eq!(alice.id().await.unwrap(), final_id);
                let bundle = alice
                    .find_most_recent_credential_bundle(case.signature_scheme(), case.credential_type)
                    .await
                    .unwrap();
                // The credential identity is still the provisional id.
                let credential_identity: ClientId = bundle.credential().identity().into();
                assert_eq!(&credential_identity, tmp_ids.first().unwrap());
            })
        })
        .await
    }
}