1pub(crate) mod e2e_identity;
2mod epoch_observer;
3mod error;
4pub(crate) mod id;
5pub(crate) mod identifier;
6pub(crate) mod identities;
7pub(crate) mod key_package;
8pub(crate) mod user_id;
9
10use crate::{
11 CoreCrypto, KeystoreError, LeafError, MlsError, MlsTransport, RecursiveError,
12 group_store::GroupStore,
13 mls::{
14 self, HasSessionAndCrypto,
15 conversation::ImmutableConversation,
16 credential::{CredentialBundle, ext::CredentialExt},
17 },
18 prelude::{
19 CertificateBundle, ClientId, ConversationId, INITIAL_KEYING_MATERIAL_COUNT, MlsCiphersuite,
20 MlsClientConfiguration, MlsCredentialType, identifier::ClientIdentifier,
21 key_package::KEYPACKAGE_DEFAULT_LIFETIME,
22 },
23};
24use async_lock::RwLock;
25use core_crypto_keystore::{
26 Connection, CryptoKeystoreError,
27 connection::FetchFromDatabase,
28 entities::{EntityFindParams, MlsCredential, MlsSignatureKeyPair},
29};
30pub use epoch_observer::EpochObserver;
31pub(crate) use error::{Error, Result};
32use identities::Identities;
33use log::debug;
34use mls_crypto_provider::{EntropySeed, MlsCryptoProvider, MlsCryptoProviderConfiguration};
35use openmls::prelude::{Credential, CredentialType};
36use openmls_basic_credential::SignatureKeyPair;
37use openmls_traits::{OpenMlsCryptoProvider, crypto::OpenMlsCrypto, types::SignatureScheme};
38use openmls_x509_credential::CertificateKeyPair;
39use std::ops::{Deref, DerefMut};
40use std::sync::Arc;
41use std::{collections::HashSet, fmt};
42use tls_codec::{Deserialize, Serialize};
43
44#[derive(Clone, Debug)]
55pub struct Session {
56 pub(crate) inner: Arc<RwLock<Option<SessionInner>>>,
57 pub(crate) crypto_provider: MlsCryptoProvider,
58 pub(crate) transport: Arc<RwLock<Option<Arc<dyn MlsTransport + 'static>>>>,
59}
60
61#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
62#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
63impl HasSessionAndCrypto for Session {
64 async fn session(&self) -> mls::Result<Session> {
65 Ok(self.clone())
66 }
67
68 async fn crypto_provider(&self) -> mls::Result<MlsCryptoProvider> {
69 Ok(self.crypto_provider.clone())
70 }
71}
72
73#[derive(Clone)]
74pub(crate) struct SessionInner {
75 id: ClientId,
76 pub(crate) identities: Identities,
77 keypackage_lifetime: std::time::Duration,
78 epoch_observer: Option<Arc<dyn EpochObserver>>,
79}
80
81impl fmt::Debug for SessionInner {
82 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
83 let observer_debug = if self.epoch_observer.is_some() {
84 "Some(Arc<dyn EpochObserver>)"
85 } else {
86 "None"
87 };
88 f.debug_struct("ClientInner")
89 .field("id", &self.id)
90 .field("identities", &self.identities)
91 .field("keypackage_lifetime", &self.keypackage_lifetime)
92 .field("epoch_observer", &observer_debug)
93 .finish()
94 }
95}
96
97impl Session {
98 pub async fn try_new(configuration: MlsClientConfiguration) -> crate::mls::Result<Self> {
114 let mls_backend = MlsCryptoProvider::try_new_with_configuration(MlsCryptoProviderConfiguration {
116 db_path: &configuration.store_path,
117 db_key: configuration.database_key.clone(),
118 in_memory: false,
119 entropy_seed: configuration.external_entropy.clone(),
120 })
121 .await
122 .map_err(MlsError::wrap("trying to initialize mls crypto provider object"))?;
123 Self::new_with_backend(mls_backend, configuration).await
124 }
125
126 pub async fn try_new_in_memory(configuration: MlsClientConfiguration) -> crate::mls::Result<Self> {
129 let mls_backend = MlsCryptoProvider::try_new_with_configuration(MlsCryptoProviderConfiguration {
130 db_path: &configuration.store_path,
131 db_key: configuration.database_key.clone(),
132 in_memory: true,
133 entropy_seed: configuration.external_entropy.clone(),
134 })
135 .await
136 .map_err(MlsError::wrap(
137 "trying to initialize mls crypto provider object (in memory)",
138 ))?;
139 Self::new_with_backend(mls_backend, configuration).await
140 }
141
142 async fn new_with_backend(
143 mls_backend: MlsCryptoProvider,
144 configuration: MlsClientConfiguration,
145 ) -> crate::mls::Result<Self> {
146 let client = Self {
150 crypto_provider: mls_backend.clone(),
151 inner: Default::default(),
152 transport: Arc::new(None.into()),
153 };
154
155 let cc = CoreCrypto::from(client.clone());
156 let context = cc
157 .new_transaction()
158 .await
159 .map_err(RecursiveError::transaction("starting new transaction"))?;
160
161 if let Some(id) = configuration.client_id {
162 client
163 .init(
164 ClientIdentifier::Basic(id),
165 configuration.ciphersuites.as_slice(),
166 &mls_backend,
167 configuration
168 .nb_init_key_packages
169 .unwrap_or(INITIAL_KEYING_MATERIAL_COUNT),
170 )
171 .await
172 .map_err(RecursiveError::mls_client("initializing mls client"))?
173 }
174
175 let central = cc.mls;
176 context
177 .init_pki_env()
178 .await
179 .map_err(RecursiveError::transaction("initializing pki environment"))?;
180 context
181 .finish()
182 .await
183 .map_err(RecursiveError::transaction("finishing transaction"))?;
184 Ok(central)
185 }
186
187 pub async fn provide_transport(&self, transport: Arc<dyn MlsTransport>) {
190 self.transport.write().await.replace(transport);
191 }
192
193 pub async fn init(
205 &self,
206 identifier: ClientIdentifier,
207 ciphersuites: &[MlsCiphersuite],
208 backend: &MlsCryptoProvider,
209 nb_key_package: usize,
210 ) -> Result<()> {
211 self.ensure_unready().await?;
212 let id = identifier.get_id()?;
213
214 let credentials = backend
215 .key_store()
216 .find_all::<MlsCredential>(EntityFindParams::default())
217 .await
218 .map_err(KeystoreError::wrap("finding all mls credentials"))?;
219
220 let credentials = credentials
221 .into_iter()
222 .filter(|mls_credential| &mls_credential.id[..] == id.as_slice())
223 .map(|mls_credential| -> Result<_> {
224 let credential = Credential::tls_deserialize(&mut mls_credential.credential.as_slice())
225 .map_err(Error::tls_deserialize("mls credential"))?;
226 Ok((credential, mls_credential.created_at))
227 })
228 .collect::<Result<Vec<_>>>()?;
229
230 if !credentials.is_empty() {
231 let signature_schemes = ciphersuites
232 .iter()
233 .map(|cs| cs.signature_algorithm())
234 .collect::<HashSet<_>>();
235 match self.load(backend, id.as_ref(), credentials, signature_schemes).await {
236 Ok(client) => client,
237 Err(Error::ClientSignatureNotFound) => {
238 debug!(count = nb_key_package, ciphersuites:? = ciphersuites; "Client signature not found. Generating client");
239 self.generate(identifier, backend, ciphersuites, nb_key_package).await?
240 }
241 Err(e) => return Err(e),
242 }
243 } else {
244 debug!(count = nb_key_package, ciphersuites:? = ciphersuites; "Generating client");
245 self.generate(identifier, backend, ciphersuites, nb_key_package).await?
246 };
247
248 Ok(())
249 }
250
251 #[cfg(test)]
253 pub(crate) async fn reset(&self) {
254 let mut inner_lock = self.inner.write().await;
255 *inner_lock = None;
256 }
257
258 pub(crate) async fn is_ready(&self) -> bool {
259 let inner_lock = self.inner.read().await;
260 inner_lock.is_some()
261 }
262
263 async fn ensure_unready(&self) -> Result<()> {
264 if self.is_ready().await {
265 Err(Error::UnexpectedlyReady)
266 } else {
267 Ok(())
268 }
269 }
270
271 async fn replace_inner(&self, new_inner: SessionInner) {
272 let mut inner_lock = self.inner.write().await;
273 *inner_lock = Some(new_inner);
274 }
275
276 pub async fn get_raw_conversation(&self, id: &ConversationId) -> Result<ImmutableConversation> {
281 let raw_conversation = GroupStore::fetch_from_keystore(id, &self.crypto_provider.keystore(), None)
282 .await
283 .map_err(RecursiveError::root("getting conversation by id"))?
284 .ok_or_else(|| LeafError::ConversationNotFound(id.clone()))?;
285 Ok(ImmutableConversation::new(raw_conversation, self.clone()))
286 }
287
288 pub async fn public_key(
295 &self,
296 ciphersuite: MlsCiphersuite,
297 credential_type: MlsCredentialType,
298 ) -> crate::mls::Result<Vec<u8>> {
299 let cb = self
300 .find_most_recent_credential_bundle(ciphersuite.signature_algorithm(), credential_type)
301 .await
302 .map_err(RecursiveError::mls_client("finding most recent credential bundle"))?;
303 Ok(cb.signature_key.to_public_vec())
304 }
305
306 pub(crate) fn new_basic_credential_bundle(
307 id: &ClientId,
308 sc: SignatureScheme,
309 backend: &MlsCryptoProvider,
310 ) -> Result<CredentialBundle> {
311 let (sk, pk) = backend
312 .crypto()
313 .signature_key_gen(sc)
314 .map_err(MlsError::wrap("generating a signature key"))?;
315
316 let signature_key = SignatureKeyPair::from_raw(sc, sk, pk);
317 let credential = Credential::new_basic(id.to_vec());
318 let cb = CredentialBundle {
319 credential,
320 signature_key,
321 created_at: 0,
322 };
323
324 Ok(cb)
325 }
326
327 pub(crate) fn new_x509_credential_bundle(cert: CertificateBundle) -> Result<CredentialBundle> {
328 let created_at = cert
329 .get_created_at()
330 .map_err(RecursiveError::mls_credential("getting credetntial created at"))?;
331 let (sk, ..) = cert.private_key.into_parts();
332 let chain = cert.certificate_chain;
333
334 let kp = CertificateKeyPair::new(sk, chain.clone()).map_err(MlsError::wrap("creating certificate key pair"))?;
335
336 let credential = Credential::new_x509(chain).map_err(MlsError::wrap("creating x509 credential"))?;
337
338 let cb = CredentialBundle {
339 credential,
340 signature_key: kp.0,
341 created_at,
342 };
343 Ok(cb)
344 }
345
346 pub async fn conversation_exists(&self, id: &ConversationId) -> Result<bool> {
348 match self.get_raw_conversation(id).await {
349 Ok(_) => Ok(true),
350 Err(Error::Leaf(LeafError::ConversationNotFound(_))) => Ok(false),
351 Err(e) => Err(e),
352 }
353 }
354
355 pub fn random_bytes(&self, len: usize) -> crate::mls::Result<Vec<u8>> {
357 use openmls_traits::random::OpenMlsRand as _;
358 self.crypto_provider
359 .rand()
360 .random_vec(len)
361 .map_err(MlsError::wrap("generating random vector"))
362 .map_err(Into::into)
363 }
364
365 pub async fn can_close(&self) -> bool {
369 self.crypto_provider.can_close().await
370 }
371
372 pub async fn close(self) -> crate::mls::Result<()> {
377 self.crypto_provider
378 .close()
379 .await
380 .map_err(MlsError::wrap("closing connection with keystore"))
381 .map_err(Into::into)
382 }
383
384 pub async fn reseed(&self, seed: Option<EntropySeed>) -> crate::mls::Result<()> {
386 self.crypto_provider
387 .reseed(seed)
388 .map_err(MlsError::wrap("reseeding mls backend"))
389 .map_err(Into::into)
390 }
391
392 pub async fn generate_raw_keypairs(
402 &self,
403 ciphersuites: &[MlsCiphersuite],
404 backend: &MlsCryptoProvider,
405 ) -> Result<Vec<ClientId>> {
406 self.ensure_unready().await?;
407 const TEMP_KEY_SIZE: usize = 16;
408
409 let credentials = Self::find_all_basic_credentials(backend).await?;
410 if !credentials.is_empty() {
411 return Err(Error::IdentityAlreadyPresent);
412 }
413
414 use openmls_traits::random::OpenMlsRand as _;
415 let mut tmp_client_ids = Vec::with_capacity(ciphersuites.len());
417 for cs in ciphersuites {
418 let tmp_client_id: ClientId = backend
419 .rand()
420 .random_vec(TEMP_KEY_SIZE)
421 .map_err(MlsError::wrap("generating random client id"))?
422 .into();
423
424 let cb = Self::new_basic_credential_bundle(&tmp_client_id, cs.signature_algorithm(), backend)?;
425
426 let sign_kp = MlsSignatureKeyPair::new(
427 cs.signature_algorithm(),
428 cb.signature_key.to_public_vec(),
429 cb.signature_key
430 .tls_serialize_detached()
431 .map_err(Error::tls_serialize("signature key"))?,
432 tmp_client_id.clone().into(),
433 );
434 backend
435 .key_store()
436 .save(sign_kp)
437 .await
438 .map_err(KeystoreError::wrap("save signature keypair in keystore"))?;
439
440 tmp_client_ids.push(tmp_client_id);
441 }
442
443 Ok(tmp_client_ids)
444 }
445
446 pub async fn init_with_external_client_id(
456 &self,
457 client_id: ClientId,
458 tmp_ids: Vec<ClientId>,
459 ciphersuites: &[MlsCiphersuite],
460 backend: &MlsCryptoProvider,
461 ) -> Result<()> {
462 self.ensure_unready().await?;
463 let stored_skp = backend
465 .key_store()
466 .find_all::<MlsSignatureKeyPair>(EntityFindParams::default())
467 .await
468 .map_err(KeystoreError::wrap("finding all mls signature keypairs"))?;
469
470 match stored_skp.len().cmp(&tmp_ids.len()) {
471 std::cmp::Ordering::Less => return Err(Error::NoProvisionalIdentityFound),
472 std::cmp::Ordering::Greater => return Err(Error::TooManyIdentitiesPresent),
473 _ => {}
474 }
475
476 let all_tmp_ids_exist = stored_skp
478 .iter()
479 .all(|kp| tmp_ids.contains(&kp.credential_id.as_slice().into()));
480 if !all_tmp_ids_exist {
481 return Err(Error::NoProvisionalIdentityFound);
482 }
483
484 let identities = stored_skp.iter().zip(ciphersuites);
485
486 self.replace_inner(SessionInner {
487 id: client_id.clone(),
488 identities: Identities::new(stored_skp.len()),
489 keypackage_lifetime: KEYPACKAGE_DEFAULT_LIFETIME,
490 epoch_observer: None,
491 })
492 .await;
493
494 let id = &client_id;
495
496 for (tmp_kp, &cs) in identities {
497 let scheme = tmp_kp
498 .signature_scheme
499 .try_into()
500 .map_err(|_| Error::InvalidSignatureScheme)?;
501 let new_keypair =
502 MlsSignatureKeyPair::new(scheme, tmp_kp.pk.clone(), tmp_kp.keypair.clone(), id.clone().into());
503
504 let new_credential = MlsCredential {
505 id: id.clone().into(),
506 credential: tmp_kp.credential_id.clone(),
507 created_at: 0,
508 };
509
510 backend
512 .key_store()
513 .remove::<MlsSignatureKeyPair, &[u8]>(&new_keypair.pk)
514 .await
515 .map_err(KeystoreError::wrap("removing mls signature keypair"))?;
516
517 let signature_key = SignatureKeyPair::tls_deserialize(&mut new_keypair.keypair.as_slice())
518 .map_err(Error::tls_deserialize("signature key"))?;
519 let cb = CredentialBundle {
520 credential: Credential::new_basic(new_credential.credential.clone()),
521 signature_key,
522 created_at: 0, };
524
525 self.save_identity(&backend.keystore(), Some(id), cs.signature_algorithm(), cb)
527 .await?;
528 }
529
530 Ok(())
531 }
532
533 pub(crate) async fn generate(
535 &self,
536 identifier: ClientIdentifier,
537 backend: &MlsCryptoProvider,
538 ciphersuites: &[MlsCiphersuite],
539 nb_key_package: usize,
540 ) -> Result<()> {
541 self.ensure_unready().await?;
542 let id = identifier.get_id()?;
543 let signature_schemes = ciphersuites
544 .iter()
545 .map(|cs| cs.signature_algorithm())
546 .collect::<HashSet<_>>();
547 self.replace_inner(SessionInner {
548 id: id.into_owned(),
549 identities: Identities::new(signature_schemes.len()),
550 keypackage_lifetime: KEYPACKAGE_DEFAULT_LIFETIME,
551 epoch_observer: None,
552 })
553 .await;
554
555 let identities = identifier.generate_credential_bundles(backend, signature_schemes)?;
556
557 for (sc, id, cb) in identities {
558 self.save_identity(&backend.keystore(), Some(&id), sc, cb).await?;
559 }
560
561 let identities = match self.inner.read().await.deref() {
562 None => return Err(Error::MlsNotInitialized),
563 Some(SessionInner { identities, .. }) => identities.clone(),
567 };
568
569 if nb_key_package != 0 {
570 for cs in ciphersuites {
571 let sc = cs.signature_algorithm();
572 let identity = identities.iter().filter(|(id_sc, _)| id_sc == &sc);
573 for (_, cb) in identity {
574 self.request_key_packages(nb_key_package, *cs, cb.credential.credential_type().into(), backend)
575 .await?;
576 }
577 }
578 }
579
580 Ok(())
581 }
582
583 pub(crate) async fn load(
585 &self,
586 backend: &MlsCryptoProvider,
587 id: &ClientId,
588 mut credentials: Vec<(Credential, u64)>,
589 signature_schemes: HashSet<SignatureScheme>,
590 ) -> Result<()> {
591 self.ensure_unready().await?;
592 let mut identities = Identities::new(signature_schemes.len());
593
594 credentials.sort_by_key(|(_, timestamp)| *timestamp);
596
597 let store_skps = backend
598 .key_store()
599 .find_all::<MlsSignatureKeyPair>(EntityFindParams::default())
600 .await
601 .map_err(KeystoreError::wrap("finding all mls signature keypairs"))?;
602
603 for sc in signature_schemes {
604 let kp = store_skps.iter().find(|skp| skp.signature_scheme == (sc as u16));
605
606 let signature_key = if let Some(kp) = kp {
607 SignatureKeyPair::tls_deserialize(&mut kp.keypair.as_slice())
608 .map_err(Error::tls_deserialize("signature keypair"))?
609 } else {
610 let (sk, pk) = backend
611 .crypto()
612 .signature_key_gen(sc)
613 .map_err(MlsError::wrap("generating signature key"))?;
614 let keypair = SignatureKeyPair::from_raw(sc, sk, pk.clone());
615 let raw_keypair = keypair
616 .tls_serialize_detached()
617 .map_err(Error::tls_serialize("raw keypair"))?;
618 let store_keypair = MlsSignatureKeyPair::new(sc, pk, raw_keypair, id.as_slice().into());
619 backend
620 .key_store()
621 .save(store_keypair.clone())
622 .await
623 .map_err(KeystoreError::wrap("storing keypairs in keystore"))?;
624 SignatureKeyPair::tls_deserialize(&mut store_keypair.keypair.as_slice())
625 .map_err(Error::tls_deserialize("signature keypair"))?
626 };
627
628 for (credential, created_at) in &credentials {
629 match credential.mls_credential() {
630 openmls::prelude::MlsCredentialType::Basic(_) => {
631 if id.as_slice() != credential.identity() {
632 return Err(Error::WrongCredential);
633 }
634 }
635 openmls::prelude::MlsCredentialType::X509(cert) => {
636 let spk = cert
637 .extract_public_key()
638 .map_err(RecursiveError::mls_credential("extracting public key"))?
639 .ok_or(LeafError::InternalMlsError)?;
640 if signature_key.public() != spk {
641 return Err(Error::WrongCredential);
642 }
643 }
644 };
645 let cb = CredentialBundle {
646 credential: credential.clone(),
647 signature_key: signature_key.clone(),
648 created_at: *created_at,
649 };
650 identities.push_credential_bundle(sc, cb).await?;
651 }
652 }
653 self.replace_inner(SessionInner {
654 id: id.clone(),
655 identities,
656 keypackage_lifetime: KEYPACKAGE_DEFAULT_LIFETIME,
657 epoch_observer: None,
658 })
659 .await;
660 Ok(())
661 }
662
663 async fn find_all_basic_credentials(backend: &MlsCryptoProvider) -> Result<Vec<Credential>> {
664 let store_credentials = backend
665 .key_store()
666 .find_all::<MlsCredential>(EntityFindParams::default())
667 .await
668 .map_err(KeystoreError::wrap("finding all mls credentialss"))?;
669 let mut credentials = Vec::with_capacity(store_credentials.len());
670 for store_credential in store_credentials.into_iter() {
671 let credential = Credential::tls_deserialize(&mut store_credential.credential.as_slice())
672 .map_err(Error::tls_deserialize("credential"))?;
673 if !matches!(credential.credential_type(), CredentialType::Basic) {
674 continue;
675 }
676 credentials.push(credential);
677 }
678
679 Ok(credentials)
680 }
681
682 pub(crate) async fn save_identity(
683 &self,
684 keystore: &Connection,
685 id: Option<&ClientId>,
686 sc: SignatureScheme,
687 mut cb: CredentialBundle,
688 ) -> Result<CredentialBundle> {
689 match self.inner.write().await.deref_mut() {
690 None => Err(Error::MlsNotInitialized),
691 Some(SessionInner {
692 id: existing_id,
693 identities,
694 ..
695 }) => {
696 let id = id.unwrap_or(existing_id);
697
698 let credential = cb
699 .credential
700 .tls_serialize_detached()
701 .map_err(Error::tls_serialize("credential bundle"))?;
702 let credential = MlsCredential {
703 id: id.clone().into(),
704 credential,
705 created_at: 0,
706 };
707
708 let credential = keystore
709 .save(credential)
710 .await
711 .map_err(KeystoreError::wrap("saving credential"))?;
712
713 let sign_kp = MlsSignatureKeyPair::new(
714 sc,
715 cb.signature_key.to_public_vec(),
716 cb.signature_key
717 .tls_serialize_detached()
718 .map_err(Error::tls_serialize("signature keypair"))?,
719 id.clone().into(),
720 );
721 keystore.save(sign_kp).await.map_err(|e| match e {
722 CryptoKeystoreError::AlreadyExists => Error::CredentialBundleConflict,
723 _ => KeystoreError::wrap("saving mls signature key pair")(e).into(),
724 })?;
725
726 cb.created_at = credential.created_at;
728
729 identities.push_credential_bundle(sc, cb.clone()).await?;
730
731 Ok(cb)
732 }
733 }
734 }
735
736 pub async fn id(&self) -> Result<ClientId> {
738 match self.inner.read().await.deref() {
739 None => Err(Error::MlsNotInitialized),
740 Some(SessionInner { id, .. }) => Ok(id.clone()),
741 }
742 }
743
744 pub async fn is_e2ei_capable(&self) -> bool {
746 match self.inner.read().await.deref() {
747 None => false,
748 Some(SessionInner { identities, .. }) => identities
749 .iter()
750 .any(|(_, cred)| cred.credential().credential_type() == CredentialType::X509),
751 }
752 }
753
754 pub(crate) async fn get_most_recent_or_create_credential_bundle(
755 &self,
756 backend: &MlsCryptoProvider,
757 sc: SignatureScheme,
758 ct: MlsCredentialType,
759 ) -> Result<Arc<CredentialBundle>> {
760 match ct {
761 MlsCredentialType::Basic => {
762 self.init_basic_credential_bundle_if_missing(backend, sc).await?;
763 self.find_most_recent_credential_bundle(sc, ct).await
764 }
765 MlsCredentialType::X509 => self
766 .find_most_recent_credential_bundle(sc, ct)
767 .await
768 .map_err(|e| match e {
769 Error::CredentialNotFound(_) => LeafError::E2eiEnrollmentNotDone.into(),
770 _ => e,
771 }),
772 }
773 }
774
775 pub(crate) async fn init_basic_credential_bundle_if_missing(
776 &self,
777 backend: &MlsCryptoProvider,
778 sc: SignatureScheme,
779 ) -> Result<()> {
780 let existing_cb = self
781 .find_most_recent_credential_bundle(sc, MlsCredentialType::Basic)
782 .await;
783 if matches!(existing_cb, Err(Error::CredentialNotFound(_))) {
784 let id = self.id().await?;
785 debug!(id:% = &id; "Initializing basic credential bundle");
786 let cb = Self::new_basic_credential_bundle(&id, sc, backend)?;
787 self.save_identity(&backend.keystore(), None, sc, cb).await?;
788 }
789 Ok(())
790 }
791
792 pub(crate) async fn save_new_x509_credential_bundle(
793 &self,
794 keystore: &Connection,
795 sc: SignatureScheme,
796 cb: CertificateBundle,
797 ) -> Result<CredentialBundle> {
798 let id = cb
799 .get_client_id()
800 .map_err(RecursiveError::mls_credential("getting client id"))?;
801 let cb = Self::new_x509_credential_bundle(cb)?;
802 self.save_identity(keystore, Some(&id), sc, cb).await
803 }
804}
805
806#[cfg(test)]
807mod tests {
808 use super::*;
809 use crate::prelude::ClientId;
810 use crate::test_utils::*;
811 use crate::transaction_context::test_utils::EntitiesCount;
812 use core_crypto_keystore::connection::{DatabaseKey, FetchFromDatabase};
813 use core_crypto_keystore::entities::*;
814 use mls_crypto_provider::MlsCryptoProvider;
815 use wasm_bindgen_test::*;
816
817 impl Session {
818 #![allow(missing_docs)]
820
821 pub async fn random_generate(
822 &self,
823 case: &crate::test_utils::TestCase,
824 signer: Option<&crate::test_utils::x509::X509Certificate>,
825 provision: bool,
826 ) -> Result<()> {
827 self.reset().await;
828 let user_uuid = uuid::Uuid::new_v4();
829 let rnd_id = rand::random::<usize>();
830 let client_id = format!("{}:{rnd_id:x}@members.wire.com", user_uuid.hyphenated());
831 let identity = match case.credential_type {
832 MlsCredentialType::Basic => ClientIdentifier::Basic(client_id.as_str().into()),
833 MlsCredentialType::X509 => {
834 let signer = signer.expect("Missing intermediate CA");
835 CertificateBundle::rand_identifier(&client_id, &[signer])
836 }
837 };
838 let nb_key_package = if provision {
839 crate::prelude::INITIAL_KEYING_MATERIAL_COUNT
840 } else {
841 0
842 };
843 let backend = self.crypto_provider.clone();
844 self.generate(identity, &backend, &[case.ciphersuite()], nb_key_package)
845 .await?;
846 Ok(())
847 }
848
849 pub async fn find_keypackages(&self, backend: &MlsCryptoProvider) -> Result<Vec<openmls::prelude::KeyPackage>> {
850 use core_crypto_keystore::CryptoKeystoreMls as _;
851 let kps = backend
852 .key_store()
853 .mls_fetch_keypackages::<openmls::prelude::KeyPackage>(u32::MAX)
854 .await
855 .map_err(KeystoreError::wrap("fetching mls keypackages"))?;
856 Ok(kps)
857 }
858
859 pub(crate) async fn init_x509_credential_bundle_if_missing(
860 &self,
861 backend: &MlsCryptoProvider,
862 sc: SignatureScheme,
863 cb: CertificateBundle,
864 ) -> Result<()> {
865 let existing_cb = self
866 .find_most_recent_credential_bundle(sc, MlsCredentialType::X509)
867 .await
868 .is_err();
869 if existing_cb {
870 self.save_new_x509_credential_bundle(&backend.keystore(), sc, cb)
871 .await?;
872 }
873 Ok(())
874 }
875
876 pub(crate) async fn generate_one_keypackage(
877 &self,
878 backend: &MlsCryptoProvider,
879 cs: MlsCiphersuite,
880 ct: MlsCredentialType,
881 ) -> Result<openmls::prelude::KeyPackage> {
882 let cb = self
883 .find_most_recent_credential_bundle(cs.signature_algorithm(), ct)
884 .await?;
885 self.generate_one_keypackage_from_credential_bundle(backend, cs, &cb)
886 .await
887 }
888
889 pub async fn count_entities(&self) -> EntitiesCount {
891 let keystore = self.crypto_provider.keystore();
892 let credential = keystore.count::<MlsCredential>().await.unwrap();
893 let encryption_keypair = keystore.count::<MlsEncryptionKeyPair>().await.unwrap();
894 let epoch_encryption_keypair = keystore.count::<MlsEpochEncryptionKeyPair>().await.unwrap();
895 let enrollment = keystore.count::<E2eiEnrollment>().await.unwrap();
896 let group = keystore.count::<PersistedMlsGroup>().await.unwrap();
897 let hpke_private_key = keystore.count::<MlsHpkePrivateKey>().await.unwrap();
898 let key_package = keystore.count::<MlsKeyPackage>().await.unwrap();
899 let pending_group = keystore.count::<PersistedMlsPendingGroup>().await.unwrap();
900 let pending_messages = keystore.count::<MlsPendingMessage>().await.unwrap();
901 let psk_bundle = keystore.count::<MlsPskBundle>().await.unwrap();
902 let signature_keypair = keystore.count::<MlsSignatureKeyPair>().await.unwrap();
903 EntitiesCount {
904 credential,
905 encryption_keypair,
906 epoch_encryption_keypair,
907 enrollment,
908 group,
909 hpke_private_key,
910 key_package,
911 pending_group,
912 pending_messages,
913 psk_bundle,
914 signature_keypair,
915 }
916 }
917 }
918 wasm_bindgen_test_configure!(run_in_browser);
919
920 #[apply(all_cred_cipher)]
921 #[wasm_bindgen_test]
922 async fn can_generate_client(case: TestCase) {
923 run_test_with_central(case.clone(), move |[alice]| {
924 Box::pin(async move {
925 let key = DatabaseKey::generate();
926 let backend = MlsCryptoProvider::try_new_in_memory(&key).await.unwrap();
927 let x509_test_chain = if case.is_x509() {
928 let x509_test_chain = crate::test_utils::x509::X509TestChain::init_empty(case.signature_scheme());
929 x509_test_chain.register_with_provider(&backend).await;
930 Some(x509_test_chain)
931 } else {
932 None
933 };
934 backend.new_transaction().await.unwrap();
935 let client = alice.session().await;
936 client
937 .random_generate(
938 &case,
939 x509_test_chain.as_ref().map(|chain| chain.find_local_intermediate_ca()),
940 false,
941 )
942 .await
943 .unwrap();
944 })
945 })
946 .await
947 }
948
949 #[apply(all_cred_cipher)]
950 #[wasm_bindgen_test]
951 async fn can_externally_generate_client(case: TestCase) {
952 run_test_with_central(case.clone(), move |[alice]| {
953 Box::pin(async move {
954 if case.is_basic() {
955 run_tests(move |[tmp_dir_argument]| {
956 Box::pin(async move {
957 let key = DatabaseKey::generate();
958 let backend = MlsCryptoProvider::try_new(tmp_dir_argument, &key).await.unwrap();
959 backend.new_transaction().await.unwrap();
960 let client_id: ClientId = b"whatever:my:client:is@world.com".to_vec().into();
962 let alice = alice.session().await;
963 alice.reset().await;
964 let handles = alice
966 .generate_raw_keypairs(&[case.ciphersuite()], &backend)
967 .await
968 .unwrap();
969
970 let mut identities = backend
971 .keystore()
972 .find_all::<MlsSignatureKeyPair>(EntityFindParams::default())
973 .await
974 .unwrap();
975
976 assert_eq!(identities.len(), 1);
977
978 let prov_identity = identities.pop().unwrap();
979
980 let prov_client_id: ClientId = prov_identity.credential_id.as_slice().into();
983 assert_eq!(&prov_client_id, handles.first().unwrap());
984
985 alice
987 .init_with_external_client_id(
988 client_id.clone(),
989 handles.clone(),
990 &[case.ciphersuite()],
991 &backend,
992 )
993 .await
994 .unwrap();
995
996 assert_eq!(alice.id().await.unwrap(), client_id);
998 let cb = alice
999 .find_most_recent_credential_bundle(case.signature_scheme(), case.credential_type)
1000 .await
1001 .unwrap();
1002 let client_id: ClientId = cb.credential().identity().into();
1003 assert_eq!(&client_id, handles.first().unwrap());
1004 })
1005 })
1006 .await
1007 }
1008 })
1009 })
1010 .await
1011 }
1012}