1pub(crate) mod e2e_identity;
2mod epoch_observer;
3mod error;
4pub(crate) mod id;
5pub(crate) mod identifier;
6pub(crate) mod identities;
7pub(crate) mod key_package;
8pub(crate) mod user_id;
9
10use crate::{
11 CoreCrypto, KeystoreError, LeafError, MlsError, MlsTransport, RecursiveError,
12 group_store::GroupStore,
13 mls::{
14 self, HasSessionAndCrypto,
15 conversation::ImmutableConversation,
16 credential::{CredentialBundle, ext::CredentialExt},
17 },
18 prelude::{
19 CertificateBundle, ClientId, ConversationId, HistorySecret, INITIAL_KEYING_MATERIAL_COUNT, MlsCiphersuite,
20 MlsClientConfiguration, MlsCredentialType, identifier::ClientIdentifier,
21 key_package::KEYPACKAGE_DEFAULT_LIFETIME,
22 },
23};
24use async_lock::RwLock;
25use core_crypto_keystore::{
26 Connection, CryptoKeystoreError,
27 connection::FetchFromDatabase,
28 entities::{EntityFindParams, MlsCredential, MlsSignatureKeyPair},
29};
30pub use epoch_observer::EpochObserver;
31pub(crate) use error::{Error, Result};
32use identities::Identities;
33use log::debug;
34use mls_crypto_provider::{EntropySeed, MlsCryptoProvider, MlsCryptoProviderConfiguration};
35use openmls::prelude::{Credential, CredentialType};
36use openmls_basic_credential::SignatureKeyPair;
37use openmls_traits::{OpenMlsCryptoProvider, crypto::OpenMlsCrypto, types::SignatureScheme};
38use openmls_x509_credential::CertificateKeyPair;
39use std::collections::HashSet;
40use std::ops::Deref;
41use std::sync::Arc;
42use tls_codec::{Deserialize, Serialize};
43
/// Handle to an MLS session.
///
/// The mutable parts are wrapped in `Arc<RwLock<..>>`, so clones of a `Session`
/// share the same inner state, transport and epoch observer.
#[derive(Clone, derive_more::Debug)]
pub struct Session {
    /// Lock-protected session state; `None` until one of the init/generate/load paths
    /// installs a [`SessionInner`].
    pub(crate) inner: Arc<RwLock<Option<SessionInner>>>,
    /// Crypto provider used for key generation, randomness and keystore access.
    pub(crate) crypto_provider: MlsCryptoProvider,
    /// Optional transport; starts out as `None` and is set by `provide_transport`.
    pub(crate) transport: Arc<RwLock<Option<Arc<dyn MlsTransport + 'static>>>>,
    /// Optional epoch observer; starts out as `None`.
    /// The `#[debug(..)]` attribute overrides how this field is rendered by the derived `Debug`.
    #[debug("EpochObserver")]
    pub(crate) epoch_observer: Arc<RwLock<Option<Arc<dyn EpochObserver + 'static>>>>,
}
62
63#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
64#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
65impl HasSessionAndCrypto for Session {
66 async fn session(&self) -> mls::Result<Session> {
67 Ok(self.clone())
68 }
69
70 async fn crypto_provider(&self) -> mls::Result<MlsCryptoProvider> {
71 Ok(self.crypto_provider.clone())
72 }
73}
74
/// State of an initialized session; stored inside `Session::inner`.
#[derive(Clone, Debug)]
pub(crate) struct SessionInner {
    /// Client id this session was initialized with.
    id: ClientId,
    /// Credential bundles known to this session.
    pub(crate) identities: Identities,
    /// Key package lifetime; every constructor in this file sets it to
    /// `KEYPACKAGE_DEFAULT_LIFETIME`.
    keypackage_lifetime: std::time::Duration,
}
81
82impl Session {
83 pub async fn try_new(configuration: MlsClientConfiguration) -> crate::mls::Result<Self> {
99 let mls_backend = MlsCryptoProvider::try_new_with_configuration(MlsCryptoProviderConfiguration {
101 db_path: &configuration.store_path,
102 db_key: configuration.database_key.clone(),
103 in_memory: false,
104 entropy_seed: configuration.external_entropy.clone(),
105 })
106 .await
107 .map_err(MlsError::wrap("trying to initialize mls crypto provider object"))?;
108 Self::new_with_backend(mls_backend, configuration).await
109 }
110
111 pub async fn try_new_in_memory(configuration: MlsClientConfiguration) -> crate::mls::Result<Self> {
114 let mls_backend = MlsCryptoProvider::try_new_with_configuration(MlsCryptoProviderConfiguration {
115 db_path: &configuration.store_path,
116 db_key: configuration.database_key.clone(),
117 in_memory: true,
118 entropy_seed: configuration.external_entropy.clone(),
119 })
120 .await
121 .map_err(MlsError::wrap(
122 "trying to initialize mls crypto provider object (in memory)",
123 ))?;
124 Self::new_with_backend(mls_backend, configuration).await
125 }
126
127 async fn new_with_backend(
128 mls_backend: MlsCryptoProvider,
129 configuration: MlsClientConfiguration,
130 ) -> crate::mls::Result<Self> {
131 let client = Self {
135 crypto_provider: mls_backend.clone(),
136 inner: Default::default(),
137 transport: Arc::new(None.into()),
138 epoch_observer: Arc::new(None.into()),
139 };
140
141 let cc = CoreCrypto::from(client.clone());
142 let context = cc
143 .new_transaction()
144 .await
145 .map_err(RecursiveError::transaction("starting new transaction"))?;
146
147 if let Some(id) = configuration.client_id {
148 client
149 .init(
150 ClientIdentifier::Basic(id),
151 configuration.ciphersuites.as_slice(),
152 &mls_backend,
153 configuration
154 .nb_init_key_packages
155 .unwrap_or(INITIAL_KEYING_MATERIAL_COUNT),
156 )
157 .await
158 .map_err(RecursiveError::mls_client("initializing mls client"))?
159 }
160
161 let central = cc.mls;
162 context
163 .init_pki_env()
164 .await
165 .map_err(RecursiveError::transaction("initializing pki environment"))?;
166 context
167 .finish()
168 .await
169 .map_err(RecursiveError::transaction("finishing transaction"))?;
170 Ok(central)
171 }
172
173 pub async fn provide_transport(&self, transport: Arc<dyn MlsTransport>) {
176 self.transport.write().await.replace(transport);
177 }
178
179 pub async fn init(
191 &self,
192 identifier: ClientIdentifier,
193 ciphersuites: &[MlsCiphersuite],
194 backend: &MlsCryptoProvider,
195 nb_key_package: usize,
196 ) -> Result<()> {
197 self.ensure_unready().await?;
198 let id = identifier.get_id()?;
199
200 let credentials = backend
201 .key_store()
202 .find_all::<MlsCredential>(EntityFindParams::default())
203 .await
204 .map_err(KeystoreError::wrap("finding all mls credentials"))?;
205
206 let credentials = credentials
207 .into_iter()
208 .filter(|mls_credential| mls_credential.id.as_slice() == id.as_slice())
209 .map(|mls_credential| -> Result<_> {
210 let credential = Credential::tls_deserialize(&mut mls_credential.credential.as_slice())
211 .map_err(Error::tls_deserialize("mls credential"))?;
212 Ok((credential, mls_credential.created_at))
213 })
214 .collect::<Result<Vec<_>>>()?;
215
216 if credentials.is_empty() {
217 debug!(count = nb_key_package, ciphersuites:? = ciphersuites; "Generating client");
218 self.generate(identifier, backend, ciphersuites, nb_key_package).await?;
219 } else {
220 let signature_schemes = ciphersuites
221 .iter()
222 .map(|cs| cs.signature_algorithm())
223 .collect::<HashSet<_>>();
224 let load_result = self.load(backend, id.as_ref(), credentials, signature_schemes).await;
225 if let Err(Error::ClientSignatureNotFound) = load_result {
226 debug!(count = nb_key_package, ciphersuites:? = ciphersuites; "Client signature not found. Generating client");
227 self.generate(identifier, backend, ciphersuites, nb_key_package).await?;
228 } else {
229 load_result?;
230 }
231 };
232
233 Ok(())
234 }
235
236 #[cfg(test)]
238 pub(crate) async fn reset(&self) {
239 let mut inner_lock = self.inner.write().await;
240 *inner_lock = None;
241 }
242
243 pub(crate) async fn is_ready(&self) -> bool {
244 let inner_lock = self.inner.read().await;
245 inner_lock.is_some()
246 }
247
248 async fn ensure_unready(&self) -> Result<()> {
249 if self.is_ready().await {
250 Err(Error::UnexpectedlyReady)
251 } else {
252 Ok(())
253 }
254 }
255
256 async fn replace_inner(&self, new_inner: SessionInner) {
257 let mut inner_lock = self.inner.write().await;
258 *inner_lock = Some(new_inner);
259 }
260
261 pub async fn get_raw_conversation(&self, id: &ConversationId) -> Result<ImmutableConversation> {
267 let raw_conversation = GroupStore::fetch_from_keystore(id, &self.crypto_provider.keystore(), None)
268 .await
269 .map_err(RecursiveError::root("getting conversation by id"))?
270 .ok_or_else(|| LeafError::ConversationNotFound(id.clone()))?;
271 Ok(ImmutableConversation::new(raw_conversation, self.clone()))
272 }
273
274 pub async fn public_key(
281 &self,
282 ciphersuite: MlsCiphersuite,
283 credential_type: MlsCredentialType,
284 ) -> crate::mls::Result<Vec<u8>> {
285 let cb = self
286 .find_most_recent_credential_bundle(ciphersuite.signature_algorithm(), credential_type)
287 .await
288 .map_err(RecursiveError::mls_client("finding most recent credential bundle"))?;
289 Ok(cb.signature_key.to_public_vec())
290 }
291
292 pub(crate) fn new_basic_credential_bundle(
293 id: &ClientId,
294 sc: SignatureScheme,
295 backend: &MlsCryptoProvider,
296 ) -> Result<CredentialBundle> {
297 let (sk, pk) = backend
298 .crypto()
299 .signature_key_gen(sc)
300 .map_err(MlsError::wrap("generating a signature key"))?;
301
302 let signature_key = SignatureKeyPair::from_raw(sc, sk, pk);
303 let credential = Credential::new_basic(id.to_vec());
304 let cb = CredentialBundle {
305 credential,
306 signature_key,
307 created_at: 0,
308 };
309
310 Ok(cb)
311 }
312
313 pub(crate) fn new_x509_credential_bundle(cert: CertificateBundle) -> Result<CredentialBundle> {
314 let created_at = cert
315 .get_created_at()
316 .map_err(RecursiveError::mls_credential("getting credetntial created at"))?;
317 let (sk, ..) = cert.private_key.into_parts();
318 let chain = cert.certificate_chain;
319
320 let kp = CertificateKeyPair::new(sk, chain.clone()).map_err(MlsError::wrap("creating certificate key pair"))?;
321
322 let credential = Credential::new_x509(chain).map_err(MlsError::wrap("creating x509 credential"))?;
323
324 let cb = CredentialBundle {
325 credential,
326 signature_key: kp.0,
327 created_at,
328 };
329 Ok(cb)
330 }
331
332 pub async fn conversation_exists(&self, id: &ConversationId) -> Result<bool> {
334 match self.get_raw_conversation(id).await {
335 Ok(_) => Ok(true),
336 Err(Error::Leaf(LeafError::ConversationNotFound(_))) => Ok(false),
337 Err(e) => Err(e),
338 }
339 }
340
341 pub fn random_bytes(&self, len: usize) -> crate::mls::Result<Vec<u8>> {
343 use openmls_traits::random::OpenMlsRand as _;
344 self.crypto_provider
345 .rand()
346 .random_vec(len)
347 .map_err(MlsError::wrap("generating random vector"))
348 .map_err(Into::into)
349 }
350
    /// Reports whether the session can be closed, as answered by the crypto provider's
    /// own `can_close`.
    pub async fn can_close(&self) -> bool {
        self.crypto_provider.can_close().await
    }
357
358 pub async fn close(self) -> crate::mls::Result<()> {
363 self.crypto_provider
364 .close()
365 .await
366 .map_err(MlsError::wrap("closing connection with keystore"))
367 .map_err(Into::into)
368 }
369
370 pub async fn reseed(&self, seed: Option<EntropySeed>) -> crate::mls::Result<()> {
372 self.crypto_provider
373 .reseed(seed)
374 .map_err(MlsError::wrap("reseeding mls backend"))
375 .map_err(Into::into)
376 }
377
378 pub(crate) async fn generate_raw_keypairs(
388 &self,
389 ciphersuites: &[MlsCiphersuite],
390 backend: &MlsCryptoProvider,
391 ) -> Result<Vec<ClientId>> {
392 self.ensure_unready().await?;
393 const TEMP_KEY_SIZE: usize = 16;
394
395 let credentials = Self::find_all_basic_credentials(backend).await?;
396 if !credentials.is_empty() {
397 return Err(Error::IdentityAlreadyPresent);
398 }
399
400 use openmls_traits::random::OpenMlsRand as _;
401 let mut tmp_client_ids = Vec::with_capacity(ciphersuites.len());
403 for cs in ciphersuites {
404 let tmp_client_id: ClientId = backend
405 .rand()
406 .random_vec(TEMP_KEY_SIZE)
407 .map_err(MlsError::wrap("generating random client id"))?
408 .into();
409
410 let cb = Self::new_basic_credential_bundle(&tmp_client_id, cs.signature_algorithm(), backend)?;
411
412 let sign_kp = MlsSignatureKeyPair::new(
413 cs.signature_algorithm(),
414 cb.signature_key.to_public_vec(),
415 cb.signature_key
416 .tls_serialize_detached()
417 .map_err(Error::tls_serialize("signature key"))?,
418 tmp_client_id.clone().into(),
419 );
420 backend
421 .key_store()
422 .save(sign_kp)
423 .await
424 .map_err(KeystoreError::wrap("save signature keypair in keystore"))?;
425
426 tmp_client_ids.push(tmp_client_id);
427 }
428
429 Ok(tmp_client_ids)
430 }
431
    /// Finishes an initialization that was started with provisional key pairs, binding
    /// them to the externally provided `client_id`.
    ///
    /// `tmp_ids` are the temporary client ids the provisional key pairs were stored under.
    ///
    /// # Errors
    /// - `Error::UnexpectedlyReady` if the session is already initialized.
    /// - `Error::NoProvisionalIdentityFound` if fewer key pairs are stored than `tmp_ids`,
    ///   or if a stored key pair's credential id is not among `tmp_ids`.
    /// - `Error::TooManyIdentitiesPresent` if more key pairs are stored than `tmp_ids`.
    /// - `Error::InvalidSignatureScheme` if a stored signature scheme cannot be converted.
    /// - Keystore, deserialization and `save_identity` errors are propagated.
    pub(crate) async fn init_with_external_client_id(
        &self,
        client_id: ClientId,
        tmp_ids: Vec<ClientId>,
        ciphersuites: &[MlsCiphersuite],
        backend: &MlsCryptoProvider,
    ) -> Result<()> {
        self.ensure_unready().await?;
        let stored_skp = backend
            .key_store()
            .find_all::<MlsSignatureKeyPair>(EntityFindParams::default())
            .await
            .map_err(KeystoreError::wrap("finding all mls signature keypairs"))?;

        // The number of stored key pairs must match the number of provisional ids exactly.
        match stored_skp.len().cmp(&tmp_ids.len()) {
            std::cmp::Ordering::Less => return Err(Error::NoProvisionalIdentityFound),
            std::cmp::Ordering::Greater => return Err(Error::TooManyIdentitiesPresent),
            _ => {}
        }

        // Every stored key pair must belong to one of the provisional ids.
        let all_tmp_ids_exist = stored_skp
            .iter()
            .all(|kp| tmp_ids.contains(&kp.credential_id.as_slice().into()));
        if !all_tmp_ids_exist {
            return Err(Error::NoProvisionalIdentityFound);
        }

        // Key pairs and ciphersuites are paired by position; `zip` stops at the shorter one.
        let identities = stored_skp.iter().zip(ciphersuites);

        // From here on the session counts as ready, even if a later step fails.
        self.replace_inner(SessionInner {
            id: client_id.clone(),
            identities: Identities::new(stored_skp.len()),
            keypackage_lifetime: KEYPACKAGE_DEFAULT_LIFETIME,
        })
        .await;

        let id = &client_id;

        for (tmp_kp, &cs) in identities {
            let scheme = tmp_kp
                .signature_scheme
                .try_into()
                .map_err(|_| Error::InvalidSignatureScheme)?;
            // Same key material as the provisional key pair, but tagged with the final client id.
            let new_keypair =
                MlsSignatureKeyPair::new(scheme, tmp_kp.pk.clone(), tmp_kp.keypair.clone(), id.clone().into());

            // Only the `credential` field (the provisional credential id) is read below.
            let new_credential = MlsCredential {
                id: id.clone().into(),
                credential: tmp_kp.credential_id.clone(),
                created_at: 0,
            };

            // Remove the provisional entry (looked up by public key) before `save_identity`
            // stores the key pair again under the final client id.
            backend
                .key_store()
                .remove::<MlsSignatureKeyPair, &[u8]>(&new_keypair.pk)
                .await
                .map_err(KeystoreError::wrap("removing mls signature keypair"))?;

            let signature_key = SignatureKeyPair::tls_deserialize(&mut new_keypair.keypair.as_slice())
                .map_err(Error::tls_deserialize("signature key"))?;
            // NOTE: the basic credential is built from the provisional credential id, not
            // from `client_id`.
            let cb = CredentialBundle {
                credential: Credential::new_basic(new_credential.credential.clone()),
                signature_key,
                created_at: 0,
            };

            self.save_identity(&backend.keystore(), Some(id), cs.signature_algorithm(), cb)
                .await?;
        }

        Ok(())
    }
517
518 pub(crate) async fn generate(
520 &self,
521 identifier: ClientIdentifier,
522 backend: &MlsCryptoProvider,
523 ciphersuites: &[MlsCiphersuite],
524 nb_key_package: usize,
525 ) -> Result<()> {
526 self.ensure_unready().await?;
527 let id = identifier.get_id()?;
528 let signature_schemes = ciphersuites
529 .iter()
530 .map(|cs| cs.signature_algorithm())
531 .collect::<HashSet<_>>();
532 self.replace_inner(SessionInner {
533 id: id.into_owned(),
534 identities: Identities::new(signature_schemes.len()),
535 keypackage_lifetime: KEYPACKAGE_DEFAULT_LIFETIME,
536 })
537 .await;
538
539 let identities = identifier.generate_credential_bundles(backend, signature_schemes)?;
540
541 for (sc, id, cb) in identities {
542 self.save_identity(&backend.keystore(), Some(&id), sc, cb).await?;
543 }
544
545 let guard = self.inner.read().await;
546 let SessionInner { identities, .. } = guard.as_ref().ok_or(Error::MlsNotInitialized)?;
547
548 if nb_key_package != 0 {
549 for ciphersuite in ciphersuites.iter().copied() {
550 let ciphersuite_signature_scheme = ciphersuite.signature_algorithm();
551 for credential_bundle in identities.iter().filter_map(|(signature_scheme, credential_bundle)| {
552 (signature_scheme == ciphersuite_signature_scheme).then_some(credential_bundle)
553 }) {
554 let credential_type = credential_bundle.credential.credential_type().into();
555 self.request_key_packages(nb_key_package, ciphersuite, credential_type, backend)
556 .await?;
557 }
558 }
559 }
560
561 Ok(())
562 }
563
    /// Rebuilds the session state for client `id` from already-stored credentials.
    ///
    /// For every requested signature scheme a signature key pair is taken from the
    /// keystore (or generated and stored when none exists for that scheme), and each
    /// credential is combined with it into a credential bundle. The session state is only
    /// installed after all bundles were built successfully.
    ///
    /// # Errors
    /// - `Error::UnexpectedlyReady` if the session is already initialized.
    /// - `Error::WrongCredential` if a basic credential's identity differs from `id`, or an
    ///   x509 credential's public key differs from the signature key's public key.
    /// - Keystore, key generation and (de)serialization errors are propagated.
    pub(crate) async fn load(
        &self,
        backend: &MlsCryptoProvider,
        id: &ClientId,
        mut credentials: Vec<(Credential, u64)>,
        signature_schemes: HashSet<SignatureScheme>,
    ) -> Result<()> {
        self.ensure_unready().await?;
        let mut identities = Identities::new(signature_schemes.len());

        // Ascending by timestamp, so bundles are pushed oldest first.
        credentials.sort_by_key(|(_, timestamp)| *timestamp);

        let stored_signature_keypairs = backend
            .key_store()
            .find_all::<MlsSignatureKeyPair>(EntityFindParams::default())
            .await
            .map_err(KeystoreError::wrap("finding all mls signature keypairs"))?;

        for signature_scheme in signature_schemes {
            // First stored key pair with this scheme; the lookup does not consider the client id.
            let signature_keypair = stored_signature_keypairs
                .iter()
                .find(|skp| skp.signature_scheme == (signature_scheme as u16));

            let signature_key = if let Some(kp) = signature_keypair {
                SignatureKeyPair::tls_deserialize(&mut kp.keypair.as_slice())
                    .map_err(Error::tls_deserialize("signature keypair"))?
            } else {
                // No key pair stored for this scheme: generate one and persist it under `id`.
                let (private_key, public_key) = backend
                    .crypto()
                    .signature_key_gen(signature_scheme)
                    .map_err(MlsError::wrap("generating signature key"))?;
                let keypair = SignatureKeyPair::from_raw(signature_scheme, private_key, public_key.clone());
                let raw_keypair = keypair
                    .tls_serialize_detached()
                    .map_err(Error::tls_serialize("raw keypair"))?;
                let store_keypair =
                    MlsSignatureKeyPair::new(signature_scheme, public_key, raw_keypair, id.as_slice().into());
                backend
                    .key_store()
                    .save(store_keypair.clone())
                    .await
                    .map_err(KeystoreError::wrap("storing keypairs in keystore"))?;
                // The key is re-read from the serialized form that was just stored.
                SignatureKeyPair::tls_deserialize(&mut store_keypair.keypair.as_slice())
                    .map_err(Error::tls_deserialize("signature keypair"))?
            };

            // Every credential is validated against, and bundled with, this scheme's key.
            for (credential, created_at) in &credentials {
                match credential.mls_credential() {
                    openmls::prelude::MlsCredentialType::Basic(_) => {
                        if id.as_slice() != credential.identity() {
                            return Err(Error::WrongCredential);
                        }
                    }
                    openmls::prelude::MlsCredentialType::X509(cert) => {
                        let spk = cert
                            .extract_public_key()
                            .map_err(RecursiveError::mls_credential("extracting public key"))?
                            .ok_or(LeafError::InternalMlsError)?;
                        if signature_key.public() != spk {
                            return Err(Error::WrongCredential);
                        }
                    }
                };
                let cb = CredentialBundle {
                    credential: credential.clone(),
                    signature_key: signature_key.clone(),
                    created_at: *created_at,
                };
                identities.push_credential_bundle(signature_scheme, cb).await?;
            }
        }
        self.replace_inner(SessionInner {
            id: id.clone(),
            identities,
            keypackage_lifetime: KEYPACKAGE_DEFAULT_LIFETIME,
        })
        .await;
        Ok(())
    }
645
646 pub(crate) async fn restore_from_history_secret(&self, history_secret: HistorySecret) -> Result<()> {
648 self.ensure_unready().await?;
649
650 self.replace_inner(SessionInner {
652 id: history_secret.client_id.clone(),
653 identities: Identities::new(1),
654 keypackage_lifetime: KEYPACKAGE_DEFAULT_LIFETIME,
655 })
656 .await;
657
658 let key_package = history_secret
660 .key_package
661 .store(&self.crypto_provider)
662 .await
663 .map_err(MlsError::wrap("storing key package encapsulation"))?;
664
665 let keystore = self.crypto_provider.key_store();
666
667 self.save_identity(
669 keystore,
670 Some(&history_secret.client_id),
671 key_package.ciphersuite().signature_algorithm(),
672 history_secret.credential_bundle,
673 )
674 .await?;
675
676 Ok(())
677 }
678
679 async fn find_all_basic_credentials(backend: &MlsCryptoProvider) -> Result<Vec<Credential>> {
680 let store_credentials = backend
681 .key_store()
682 .find_all::<MlsCredential>(EntityFindParams::default())
683 .await
684 .map_err(KeystoreError::wrap("finding all mls credentialss"))?;
685 let mut credentials = Vec::with_capacity(store_credentials.len());
686 for store_credential in store_credentials.into_iter() {
687 let credential = Credential::tls_deserialize(&mut store_credential.credential.as_slice())
688 .map_err(Error::tls_deserialize("credential"))?;
689 if !matches!(credential.credential_type(), CredentialType::Basic) {
690 continue;
691 }
692 credentials.push(credential);
693 }
694
695 Ok(credentials)
696 }
697
698 pub(crate) async fn save_identity(
699 &self,
700 keystore: &Connection,
701 id: Option<&ClientId>,
702 signature_scheme: SignatureScheme,
703 mut credential_bundle: CredentialBundle,
704 ) -> Result<CredentialBundle> {
705 let mut guard = self.inner.write().await;
706 let SessionInner {
707 id: existing_id,
708 identities,
709 ..
710 } = guard.as_mut().ok_or(Error::MlsNotInitialized)?;
711
712 let id = id.unwrap_or(existing_id);
713
714 let credential = credential_bundle
715 .credential
716 .tls_serialize_detached()
717 .map_err(Error::tls_serialize("credential bundle"))?;
718 let credential = MlsCredential {
719 id: id.clone().into(),
720 credential,
721 created_at: 0,
722 };
723
724 let credential = keystore
725 .save(credential)
726 .await
727 .map_err(KeystoreError::wrap("saving credential"))?;
728
729 let sign_kp = MlsSignatureKeyPair::new(
730 signature_scheme,
731 credential_bundle.signature_key.to_public_vec(),
732 credential_bundle
733 .signature_key
734 .tls_serialize_detached()
735 .map_err(Error::tls_serialize("signature keypair"))?,
736 id.clone().into(),
737 );
738 keystore.save(sign_kp).await.map_err(|e| match e {
739 CryptoKeystoreError::AlreadyExists => Error::CredentialBundleConflict,
740 _ => KeystoreError::wrap("saving mls signature key pair")(e).into(),
741 })?;
742
743 credential_bundle.created_at = credential.created_at;
745
746 identities
747 .push_credential_bundle(signature_scheme, credential_bundle.clone())
748 .await?;
749
750 Ok(credential_bundle)
751 }
752
753 pub async fn id(&self) -> Result<ClientId> {
755 match self.inner.read().await.deref() {
756 None => Err(Error::MlsNotInitialized),
757 Some(SessionInner { id, .. }) => Ok(id.clone()),
758 }
759 }
760
761 pub async fn is_e2ei_capable(&self) -> bool {
763 match self.inner.read().await.deref() {
764 None => false,
765 Some(SessionInner { identities, .. }) => identities
766 .iter()
767 .any(|(_, cred)| cred.credential().credential_type() == CredentialType::X509),
768 }
769 }
770
771 pub(crate) async fn get_most_recent_or_create_credential_bundle(
772 &self,
773 backend: &MlsCryptoProvider,
774 sc: SignatureScheme,
775 ct: MlsCredentialType,
776 ) -> Result<Arc<CredentialBundle>> {
777 match ct {
778 MlsCredentialType::Basic => {
779 self.init_basic_credential_bundle_if_missing(backend, sc).await?;
780 self.find_most_recent_credential_bundle(sc, ct).await
781 }
782 MlsCredentialType::X509 => self
783 .find_most_recent_credential_bundle(sc, ct)
784 .await
785 .map_err(|e| match e {
786 Error::CredentialNotFound(_) => LeafError::E2eiEnrollmentNotDone.into(),
787 _ => e,
788 }),
789 }
790 }
791
792 pub(crate) async fn init_basic_credential_bundle_if_missing(
793 &self,
794 backend: &MlsCryptoProvider,
795 sc: SignatureScheme,
796 ) -> Result<()> {
797 let existing_cb = self
798 .find_most_recent_credential_bundle(sc, MlsCredentialType::Basic)
799 .await;
800 if matches!(existing_cb, Err(Error::CredentialNotFound(_))) {
801 let id = self.id().await?;
802 debug!(id:% = &id; "Initializing basic credential bundle");
803 let cb = Self::new_basic_credential_bundle(&id, sc, backend)?;
804 self.save_identity(&backend.keystore(), None, sc, cb).await?;
805 }
806 Ok(())
807 }
808
809 pub(crate) async fn save_new_x509_credential_bundle(
810 &self,
811 keystore: &Connection,
812 sc: SignatureScheme,
813 cb: CertificateBundle,
814 ) -> Result<CredentialBundle> {
815 let id = cb
816 .get_client_id()
817 .map_err(RecursiveError::mls_credential("getting client id"))?;
818 let cb = Self::new_x509_credential_bundle(cb)?;
819 self.save_identity(keystore, Some(&id), sc, cb).await
820 }
821}
822
823#[cfg(test)]
824mod tests {
825 use super::*;
826 use crate::prelude::ClientId;
827 use crate::test_utils::*;
828 use crate::transaction_context::test_utils::EntitiesCount;
829 use core_crypto_keystore::connection::{DatabaseKey, FetchFromDatabase};
830 use core_crypto_keystore::entities::*;
831 use mls_crypto_provider::MlsCryptoProvider;
832 use wasm_bindgen_test::*;
833
834 impl Session {
835 #![allow(missing_docs)]
837
838 pub async fn random_generate(
839 &self,
840 case: &crate::test_utils::TestContext,
841 signer: Option<&crate::test_utils::x509::X509Certificate>,
842 provision: bool,
843 ) -> Result<()> {
844 self.reset().await;
845 let user_uuid = uuid::Uuid::new_v4();
846 let rnd_id = rand::random::<usize>();
847 let client_id = format!("{}:{rnd_id:x}@members.wire.com", user_uuid.hyphenated());
848 let identity = match case.credential_type {
849 MlsCredentialType::Basic => ClientIdentifier::Basic(client_id.as_str().into()),
850 MlsCredentialType::X509 => {
851 let signer = signer.expect("Missing intermediate CA");
852 CertificateBundle::rand_identifier(&client_id, &[signer])
853 }
854 };
855 let nb_key_package = if provision {
856 crate::prelude::INITIAL_KEYING_MATERIAL_COUNT
857 } else {
858 0
859 };
860 let backend = self.crypto_provider.clone();
861 self.generate(identity, &backend, &[case.ciphersuite()], nb_key_package)
862 .await?;
863 Ok(())
864 }
865
866 pub async fn find_keypackages(&self, backend: &MlsCryptoProvider) -> Result<Vec<openmls::prelude::KeyPackage>> {
867 use core_crypto_keystore::CryptoKeystoreMls as _;
868 let kps = backend
869 .key_store()
870 .mls_fetch_keypackages::<openmls::prelude::KeyPackage>(u32::MAX)
871 .await
872 .map_err(KeystoreError::wrap("fetching mls keypackages"))?;
873 Ok(kps)
874 }
875
876 pub(crate) async fn init_x509_credential_bundle_if_missing(
877 &self,
878 backend: &MlsCryptoProvider,
879 sc: SignatureScheme,
880 cb: CertificateBundle,
881 ) -> Result<()> {
882 let existing_cb = self
883 .find_most_recent_credential_bundle(sc, MlsCredentialType::X509)
884 .await
885 .is_err();
886 if existing_cb {
887 self.save_new_x509_credential_bundle(&backend.keystore(), sc, cb)
888 .await?;
889 }
890 Ok(())
891 }
892
893 pub(crate) async fn generate_one_keypackage(
894 &self,
895 backend: &MlsCryptoProvider,
896 cs: MlsCiphersuite,
897 ct: MlsCredentialType,
898 ) -> Result<openmls::prelude::KeyPackage> {
899 let cb = self
900 .find_most_recent_credential_bundle(cs.signature_algorithm(), ct)
901 .await?;
902 self.generate_one_keypackage_from_credential_bundle(backend, cs, &cb)
903 .await
904 }
905
906 pub async fn count_entities(&self) -> EntitiesCount {
908 let keystore = self.crypto_provider.keystore();
909 let credential = keystore.count::<MlsCredential>().await.unwrap();
910 let encryption_keypair = keystore.count::<MlsEncryptionKeyPair>().await.unwrap();
911 let epoch_encryption_keypair = keystore.count::<MlsEpochEncryptionKeyPair>().await.unwrap();
912 let enrollment = keystore.count::<E2eiEnrollment>().await.unwrap();
913 let group = keystore.count::<PersistedMlsGroup>().await.unwrap();
914 let hpke_private_key = keystore.count::<MlsHpkePrivateKey>().await.unwrap();
915 let key_package = keystore.count::<MlsKeyPackage>().await.unwrap();
916 let pending_group = keystore.count::<PersistedMlsPendingGroup>().await.unwrap();
917 let pending_messages = keystore.count::<MlsPendingMessage>().await.unwrap();
918 let psk_bundle = keystore.count::<MlsPskBundle>().await.unwrap();
919 let signature_keypair = keystore.count::<MlsSignatureKeyPair>().await.unwrap();
920 EntitiesCount {
921 credential,
922 encryption_keypair,
923 epoch_encryption_keypair,
924 enrollment,
925 group,
926 hpke_private_key,
927 key_package,
928 pending_group,
929 pending_messages,
930 psk_bundle,
931 signature_keypair,
932 }
933 }
934 }
935 wasm_bindgen_test_configure!(run_in_browser);
936
    /// Checks that `random_generate` succeeds for every credential type / ciphersuite case.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    async fn can_generate_session(case: TestContext) {
        let [alice] = case.sessions().await;
        let key = DatabaseKey::generate();
        let backend = MlsCryptoProvider::try_new_in_memory(&key).await.unwrap();
        // x509 cases need a certificate chain whose intermediate CA acts as the signer.
        let x509_test_chain = if case.is_x509() {
            let x509_test_chain = crate::test_utils::x509::X509TestChain::init_empty(case.signature_scheme());
            x509_test_chain.register_with_provider(&backend).await;
            Some(x509_test_chain)
        } else {
            None
        };
        backend.new_transaction().await.unwrap();
        let session = alice.session().await;
        // `random_generate` resets the session and works on the session's own crypto
        // provider; key packages are not provisioned (`false`).
        session
            .random_generate(
                &case,
                x509_test_chain.as_ref().map(|chain| chain.find_local_intermediate_ca()),
                false,
            )
            .await
            .unwrap();
    }
961
    /// Exercises the two-step external initialization: provisional key pairs are generated
    /// first, then bound to an externally provided client id. Only runs for basic
    /// credential cases.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    async fn can_externally_generate_client(mut case: TestContext) {
        let [alice] = case.sessions().await;
        if !case.is_basic() {
            return;
        }
        let tmp_dir = case.tmp_dir().await;
        Box::pin(async move {
            let key = DatabaseKey::generate();
            let backend = MlsCryptoProvider::try_new(tmp_dir, &key).await.unwrap();
            backend.new_transaction().await.unwrap();
            let client_id: ClientId = b"whatever:my:client:is@world.com".to_vec().into();
            let alice = alice.session().await;
            // Both steps below require a session that is not yet ready.
            alice.reset().await;
            // Step 1: generate the provisional key pair(s) and get their temporary ids.
            let handles = alice
                .generate_raw_keypairs(&[case.ciphersuite()], &backend)
                .await
                .unwrap();

            let mut identities = backend
                .keystore()
                .find_all::<MlsSignatureKeyPair>(EntityFindParams::default())
                .await
                .unwrap();

            // One ciphersuite was requested, so exactly one key pair must be stored.
            assert_eq!(identities.len(), 1);

            let prov_identity = identities.pop().unwrap();

            // The stored key pair is filed under the returned temporary id.
            let prov_client_id: ClientId = prov_identity.credential_id.as_slice().into();
            assert_eq!(&prov_client_id, handles.first().unwrap());

            // Step 2: bind the provisional key pair to the external client id.
            alice
                .init_with_external_client_id(client_id.clone(), handles.clone(), &[case.ciphersuite()], &backend)
                .await
                .unwrap();

            // The session now reports the external client id...
            assert_eq!(alice.id().await.unwrap(), client_id);
            let cb = alice
                .find_most_recent_credential_bundle(case.signature_scheme(), case.credential_type)
                .await
                .unwrap();
            // ...while the credential's identity is still the provisional id.
            let client_id: ClientId = cb.credential().identity().into();
            assert_eq!(&client_id, handles.first().unwrap());
        })
        .await
    }
1016}