1pub(crate) mod e2e_identity;
2mod epoch_observer;
3mod error;
4pub(crate) mod id;
5pub(crate) mod identifier;
6pub(crate) mod identities;
7pub(crate) mod key_package;
8pub(crate) mod user_id;
9
10use crate::{
11 CoreCrypto, KeystoreError, LeafError, MlsError, MlsTransport, RecursiveError,
12 group_store::GroupStore,
13 mls::{
14 self, HasSessionAndCrypto,
15 conversation::ImmutableConversation,
16 credential::{CredentialBundle, ext::CredentialExt},
17 },
18 prelude::{
19 CertificateBundle, ClientId, ConversationId, HistorySecret, INITIAL_KEYING_MATERIAL_COUNT, MlsCiphersuite,
20 MlsClientConfiguration, MlsCredentialType, identifier::ClientIdentifier,
21 key_package::KEYPACKAGE_DEFAULT_LIFETIME,
22 },
23};
24use async_lock::RwLock;
25use core_crypto_keystore::{
26 Connection, CryptoKeystoreError,
27 connection::FetchFromDatabase,
28 entities::{EntityFindParams, MlsCredential, MlsSignatureKeyPair},
29};
30pub use epoch_observer::EpochObserver;
31pub(crate) use error::{Error, Result};
32use identities::Identities;
33use log::debug;
34use mls_crypto_provider::{EntropySeed, MlsCryptoProvider, MlsCryptoProviderConfiguration};
35use openmls::prelude::{Credential, CredentialType};
36use openmls_basic_credential::SignatureKeyPair;
37use openmls_traits::{OpenMlsCryptoProvider, crypto::OpenMlsCrypto, types::SignatureScheme};
38use openmls_x509_credential::CertificateKeyPair;
39use std::ops::Deref;
40use std::sync::Arc;
41use std::{collections::HashSet, fmt};
42use tls_codec::{Deserialize, Serialize};
43
/// An MLS client session: the (optional) initialized client state, the crypto
/// provider it operates on, and an optional transport.
#[derive(Clone, Debug)]
pub struct Session {
    /// Initialized client state. `None` until one of the initialization paths
    /// (`init`, `generate`, `load`, ...) stores a `SessionInner` via `replace_inner`.
    pub(crate) inner: Arc<RwLock<Option<SessionInner>>>,
    /// Crypto provider whose keystore holds credentials, signature keypairs and groups.
    pub(crate) crypto_provider: MlsCryptoProvider,
    /// Transport registered through `provide_transport`; `None` until one is provided.
    pub(crate) transport: Arc<RwLock<Option<Arc<dyn MlsTransport + 'static>>>>,
}
60
61#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
62#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
63impl HasSessionAndCrypto for Session {
64 async fn session(&self) -> mls::Result<Session> {
65 Ok(self.clone())
66 }
67
68 async fn crypto_provider(&self) -> mls::Result<MlsCryptoProvider> {
69 Ok(self.crypto_provider.clone())
70 }
71}
72
/// State of an initialized [`Session`]: its client id, credential identities,
/// key package lifetime and optional epoch observer.
#[derive(Clone)]
pub(crate) struct SessionInner {
    /// Client id the session was initialized with.
    id: ClientId,
    /// Credential bundles, pushed per signature scheme via `push_credential_bundle`.
    pub(crate) identities: Identities,
    /// Key package lifetime; every constructor in this file sets `KEYPACKAGE_DEFAULT_LIFETIME`.
    keypackage_lifetime: std::time::Duration,
    /// Optional epoch observer (presumably notified on epoch changes — confirm in
    /// `epoch_observer`); every constructor in this file sets it to `None`.
    epoch_observer: Option<Arc<dyn EpochObserver>>,
}
80
81impl fmt::Debug for SessionInner {
82 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
83 let observer_debug = if self.epoch_observer.is_some() {
84 "Some(Arc<dyn EpochObserver>)"
85 } else {
86 "None"
87 };
88 f.debug_struct("ClientInner")
89 .field("id", &self.id)
90 .field("identities", &self.identities)
91 .field("keypackage_lifetime", &self.keypackage_lifetime)
92 .field("epoch_observer", &observer_debug)
93 .finish()
94 }
95}
96
97impl Session {
98 pub async fn try_new(configuration: MlsClientConfiguration) -> crate::mls::Result<Self> {
114 let mls_backend = MlsCryptoProvider::try_new_with_configuration(MlsCryptoProviderConfiguration {
116 db_path: &configuration.store_path,
117 db_key: configuration.database_key.clone(),
118 in_memory: false,
119 entropy_seed: configuration.external_entropy.clone(),
120 })
121 .await
122 .map_err(MlsError::wrap("trying to initialize mls crypto provider object"))?;
123 Self::new_with_backend(mls_backend, configuration).await
124 }
125
126 pub async fn try_new_in_memory(configuration: MlsClientConfiguration) -> crate::mls::Result<Self> {
129 let mls_backend = MlsCryptoProvider::try_new_with_configuration(MlsCryptoProviderConfiguration {
130 db_path: &configuration.store_path,
131 db_key: configuration.database_key.clone(),
132 in_memory: true,
133 entropy_seed: configuration.external_entropy.clone(),
134 })
135 .await
136 .map_err(MlsError::wrap(
137 "trying to initialize mls crypto provider object (in memory)",
138 ))?;
139 Self::new_with_backend(mls_backend, configuration).await
140 }
141
142 async fn new_with_backend(
143 mls_backend: MlsCryptoProvider,
144 configuration: MlsClientConfiguration,
145 ) -> crate::mls::Result<Self> {
146 let client = Self {
150 crypto_provider: mls_backend.clone(),
151 inner: Default::default(),
152 transport: Arc::new(None.into()),
153 };
154
155 let cc = CoreCrypto::from(client.clone());
156 let context = cc
157 .new_transaction()
158 .await
159 .map_err(RecursiveError::transaction("starting new transaction"))?;
160
161 if let Some(id) = configuration.client_id {
162 client
163 .init(
164 ClientIdentifier::Basic(id),
165 configuration.ciphersuites.as_slice(),
166 &mls_backend,
167 configuration
168 .nb_init_key_packages
169 .unwrap_or(INITIAL_KEYING_MATERIAL_COUNT),
170 )
171 .await
172 .map_err(RecursiveError::mls_client("initializing mls client"))?
173 }
174
175 let central = cc.mls;
176 context
177 .init_pki_env()
178 .await
179 .map_err(RecursiveError::transaction("initializing pki environment"))?;
180 context
181 .finish()
182 .await
183 .map_err(RecursiveError::transaction("finishing transaction"))?;
184 Ok(central)
185 }
186
187 pub async fn provide_transport(&self, transport: Arc<dyn MlsTransport>) {
190 self.transport.write().await.replace(transport);
191 }
192
193 pub async fn init(
205 &self,
206 identifier: ClientIdentifier,
207 ciphersuites: &[MlsCiphersuite],
208 backend: &MlsCryptoProvider,
209 nb_key_package: usize,
210 ) -> Result<()> {
211 self.ensure_unready().await?;
212 let id = identifier.get_id()?;
213
214 let credentials = backend
215 .key_store()
216 .find_all::<MlsCredential>(EntityFindParams::default())
217 .await
218 .map_err(KeystoreError::wrap("finding all mls credentials"))?;
219
220 let credentials = credentials
221 .into_iter()
222 .filter(|mls_credential| mls_credential.id.as_slice() == id.as_slice())
223 .map(|mls_credential| -> Result<_> {
224 let credential = Credential::tls_deserialize(&mut mls_credential.credential.as_slice())
225 .map_err(Error::tls_deserialize("mls credential"))?;
226 Ok((credential, mls_credential.created_at))
227 })
228 .collect::<Result<Vec<_>>>()?;
229
230 if credentials.is_empty() {
231 debug!(count = nb_key_package, ciphersuites:? = ciphersuites; "Generating client");
232 self.generate(identifier, backend, ciphersuites, nb_key_package).await?;
233 } else {
234 let signature_schemes = ciphersuites
235 .iter()
236 .map(|cs| cs.signature_algorithm())
237 .collect::<HashSet<_>>();
238 let load_result = self.load(backend, id.as_ref(), credentials, signature_schemes).await;
239 if let Err(Error::ClientSignatureNotFound) = load_result {
240 debug!(count = nb_key_package, ciphersuites:? = ciphersuites; "Client signature not found. Generating client");
241 self.generate(identifier, backend, ciphersuites, nb_key_package).await?;
242 } else {
243 load_result?;
244 }
245 };
246
247 Ok(())
248 }
249
#[cfg(test)]
/// Test-only: drops the initialized state, returning the session to "unready".
pub(crate) async fn reset(&self) {
    *self.inner.write().await = None;
}
256
257 pub(crate) async fn is_ready(&self) -> bool {
258 let inner_lock = self.inner.read().await;
259 inner_lock.is_some()
260 }
261
262 async fn ensure_unready(&self) -> Result<()> {
263 if self.is_ready().await {
264 Err(Error::UnexpectedlyReady)
265 } else {
266 Ok(())
267 }
268 }
269
270 async fn replace_inner(&self, new_inner: SessionInner) {
271 let mut inner_lock = self.inner.write().await;
272 *inner_lock = Some(new_inner);
273 }
274
275 pub async fn get_raw_conversation(&self, id: &ConversationId) -> Result<ImmutableConversation> {
280 let raw_conversation = GroupStore::fetch_from_keystore(id, &self.crypto_provider.keystore(), None)
281 .await
282 .map_err(RecursiveError::root("getting conversation by id"))?
283 .ok_or_else(|| LeafError::ConversationNotFound(id.clone()))?;
284 Ok(ImmutableConversation::new(raw_conversation, self.clone()))
285 }
286
287 pub async fn public_key(
294 &self,
295 ciphersuite: MlsCiphersuite,
296 credential_type: MlsCredentialType,
297 ) -> crate::mls::Result<Vec<u8>> {
298 let cb = self
299 .find_most_recent_credential_bundle(ciphersuite.signature_algorithm(), credential_type)
300 .await
301 .map_err(RecursiveError::mls_client("finding most recent credential bundle"))?;
302 Ok(cb.signature_key.to_public_vec())
303 }
304
305 pub(crate) fn new_basic_credential_bundle(
306 id: &ClientId,
307 sc: SignatureScheme,
308 backend: &MlsCryptoProvider,
309 ) -> Result<CredentialBundle> {
310 let (sk, pk) = backend
311 .crypto()
312 .signature_key_gen(sc)
313 .map_err(MlsError::wrap("generating a signature key"))?;
314
315 let signature_key = SignatureKeyPair::from_raw(sc, sk, pk);
316 let credential = Credential::new_basic(id.to_vec());
317 let cb = CredentialBundle {
318 credential,
319 signature_key,
320 created_at: 0,
321 };
322
323 Ok(cb)
324 }
325
326 pub(crate) fn new_x509_credential_bundle(cert: CertificateBundle) -> Result<CredentialBundle> {
327 let created_at = cert
328 .get_created_at()
329 .map_err(RecursiveError::mls_credential("getting credetntial created at"))?;
330 let (sk, ..) = cert.private_key.into_parts();
331 let chain = cert.certificate_chain;
332
333 let kp = CertificateKeyPair::new(sk, chain.clone()).map_err(MlsError::wrap("creating certificate key pair"))?;
334
335 let credential = Credential::new_x509(chain).map_err(MlsError::wrap("creating x509 credential"))?;
336
337 let cb = CredentialBundle {
338 credential,
339 signature_key: kp.0,
340 created_at,
341 };
342 Ok(cb)
343 }
344
345 pub async fn conversation_exists(&self, id: &ConversationId) -> Result<bool> {
347 match self.get_raw_conversation(id).await {
348 Ok(_) => Ok(true),
349 Err(Error::Leaf(LeafError::ConversationNotFound(_))) => Ok(false),
350 Err(e) => Err(e),
351 }
352 }
353
354 pub fn random_bytes(&self, len: usize) -> crate::mls::Result<Vec<u8>> {
356 use openmls_traits::random::OpenMlsRand as _;
357 self.crypto_provider
358 .rand()
359 .random_vec(len)
360 .map_err(MlsError::wrap("generating random vector"))
361 .map_err(Into::into)
362 }
363
364 pub async fn can_close(&self) -> bool {
368 self.crypto_provider.can_close().await
369 }
370
371 pub async fn close(self) -> crate::mls::Result<()> {
376 self.crypto_provider
377 .close()
378 .await
379 .map_err(MlsError::wrap("closing connection with keystore"))
380 .map_err(Into::into)
381 }
382
383 pub async fn reseed(&self, seed: Option<EntropySeed>) -> crate::mls::Result<()> {
385 self.crypto_provider
386 .reseed(seed)
387 .map_err(MlsError::wrap("reseeding mls backend"))
388 .map_err(Into::into)
389 }
390
391 pub async fn generate_raw_keypairs(
401 &self,
402 ciphersuites: &[MlsCiphersuite],
403 backend: &MlsCryptoProvider,
404 ) -> Result<Vec<ClientId>> {
405 self.ensure_unready().await?;
406 const TEMP_KEY_SIZE: usize = 16;
407
408 let credentials = Self::find_all_basic_credentials(backend).await?;
409 if !credentials.is_empty() {
410 return Err(Error::IdentityAlreadyPresent);
411 }
412
413 use openmls_traits::random::OpenMlsRand as _;
414 let mut tmp_client_ids = Vec::with_capacity(ciphersuites.len());
416 for cs in ciphersuites {
417 let tmp_client_id: ClientId = backend
418 .rand()
419 .random_vec(TEMP_KEY_SIZE)
420 .map_err(MlsError::wrap("generating random client id"))?
421 .into();
422
423 let cb = Self::new_basic_credential_bundle(&tmp_client_id, cs.signature_algorithm(), backend)?;
424
425 let sign_kp = MlsSignatureKeyPair::new(
426 cs.signature_algorithm(),
427 cb.signature_key.to_public_vec(),
428 cb.signature_key
429 .tls_serialize_detached()
430 .map_err(Error::tls_serialize("signature key"))?,
431 tmp_client_id.clone().into(),
432 );
433 backend
434 .key_store()
435 .save(sign_kp)
436 .await
437 .map_err(KeystoreError::wrap("save signature keypair in keystore"))?;
438
439 tmp_client_ids.push(tmp_client_id);
440 }
441
442 Ok(tmp_client_ids)
443 }
444
445 pub async fn init_with_external_client_id(
455 &self,
456 client_id: ClientId,
457 tmp_ids: Vec<ClientId>,
458 ciphersuites: &[MlsCiphersuite],
459 backend: &MlsCryptoProvider,
460 ) -> Result<()> {
461 self.ensure_unready().await?;
462 let stored_skp = backend
464 .key_store()
465 .find_all::<MlsSignatureKeyPair>(EntityFindParams::default())
466 .await
467 .map_err(KeystoreError::wrap("finding all mls signature keypairs"))?;
468
469 match stored_skp.len().cmp(&tmp_ids.len()) {
470 std::cmp::Ordering::Less => return Err(Error::NoProvisionalIdentityFound),
471 std::cmp::Ordering::Greater => return Err(Error::TooManyIdentitiesPresent),
472 _ => {}
473 }
474
475 let all_tmp_ids_exist = stored_skp
477 .iter()
478 .all(|kp| tmp_ids.contains(&kp.credential_id.as_slice().into()));
479 if !all_tmp_ids_exist {
480 return Err(Error::NoProvisionalIdentityFound);
481 }
482
483 let identities = stored_skp.iter().zip(ciphersuites);
484
485 self.replace_inner(SessionInner {
486 id: client_id.clone(),
487 identities: Identities::new(stored_skp.len()),
488 keypackage_lifetime: KEYPACKAGE_DEFAULT_LIFETIME,
489 epoch_observer: None,
490 })
491 .await;
492
493 let id = &client_id;
494
495 for (tmp_kp, &cs) in identities {
496 let scheme = tmp_kp
497 .signature_scheme
498 .try_into()
499 .map_err(|_| Error::InvalidSignatureScheme)?;
500 let new_keypair =
501 MlsSignatureKeyPair::new(scheme, tmp_kp.pk.clone(), tmp_kp.keypair.clone(), id.clone().into());
502
503 let new_credential = MlsCredential {
504 id: id.clone().into(),
505 credential: tmp_kp.credential_id.clone(),
506 created_at: 0,
507 };
508
509 backend
511 .key_store()
512 .remove::<MlsSignatureKeyPair, &[u8]>(&new_keypair.pk)
513 .await
514 .map_err(KeystoreError::wrap("removing mls signature keypair"))?;
515
516 let signature_key = SignatureKeyPair::tls_deserialize(&mut new_keypair.keypair.as_slice())
517 .map_err(Error::tls_deserialize("signature key"))?;
518 let cb = CredentialBundle {
519 credential: Credential::new_basic(new_credential.credential.clone()),
520 signature_key,
521 created_at: 0, };
523
524 self.save_identity(&backend.keystore(), Some(id), cs.signature_algorithm(), cb)
526 .await?;
527 }
528
529 Ok(())
530 }
531
532 pub(crate) async fn generate(
534 &self,
535 identifier: ClientIdentifier,
536 backend: &MlsCryptoProvider,
537 ciphersuites: &[MlsCiphersuite],
538 nb_key_package: usize,
539 ) -> Result<()> {
540 self.ensure_unready().await?;
541 let id = identifier.get_id()?;
542 let signature_schemes = ciphersuites
543 .iter()
544 .map(|cs| cs.signature_algorithm())
545 .collect::<HashSet<_>>();
546 self.replace_inner(SessionInner {
547 id: id.into_owned(),
548 identities: Identities::new(signature_schemes.len()),
549 keypackage_lifetime: KEYPACKAGE_DEFAULT_LIFETIME,
550 epoch_observer: None,
551 })
552 .await;
553
554 let identities = identifier.generate_credential_bundles(backend, signature_schemes)?;
555
556 for (sc, id, cb) in identities {
557 self.save_identity(&backend.keystore(), Some(&id), sc, cb).await?;
558 }
559
560 let guard = self.inner.read().await;
561 let SessionInner { identities, .. } = guard.as_ref().ok_or(Error::MlsNotInitialized)?;
562
563 if nb_key_package != 0 {
564 for ciphersuite in ciphersuites.iter().copied() {
565 let ciphersuite_signature_scheme = ciphersuite.signature_algorithm();
566 for credential_bundle in identities.iter().filter_map(|(signature_scheme, credential_bundle)| {
567 (signature_scheme == ciphersuite_signature_scheme).then_some(credential_bundle)
568 }) {
569 let credential_type = credential_bundle.credential.credential_type().into();
570 self.request_key_packages(nb_key_package, ciphersuite, credential_type, backend)
571 .await?;
572 }
573 }
574 }
575
576 Ok(())
577 }
578
579 pub(crate) async fn load(
581 &self,
582 backend: &MlsCryptoProvider,
583 id: &ClientId,
584 mut credentials: Vec<(Credential, u64)>,
585 signature_schemes: HashSet<SignatureScheme>,
586 ) -> Result<()> {
587 self.ensure_unready().await?;
588 let mut identities = Identities::new(signature_schemes.len());
589
590 credentials.sort_by_key(|(_, timestamp)| *timestamp);
592
593 let stored_signature_keypairs = backend
594 .key_store()
595 .find_all::<MlsSignatureKeyPair>(EntityFindParams::default())
596 .await
597 .map_err(KeystoreError::wrap("finding all mls signature keypairs"))?;
598
599 for signature_scheme in signature_schemes {
600 let signature_keypair = stored_signature_keypairs
601 .iter()
602 .find(|skp| skp.signature_scheme == (signature_scheme as u16));
603
604 let signature_key = if let Some(kp) = signature_keypair {
605 SignatureKeyPair::tls_deserialize(&mut kp.keypair.as_slice())
606 .map_err(Error::tls_deserialize("signature keypair"))?
607 } else {
608 let (private_key, public_key) = backend
609 .crypto()
610 .signature_key_gen(signature_scheme)
611 .map_err(MlsError::wrap("generating signature key"))?;
612 let keypair = SignatureKeyPair::from_raw(signature_scheme, private_key, public_key.clone());
613 let raw_keypair = keypair
614 .tls_serialize_detached()
615 .map_err(Error::tls_serialize("raw keypair"))?;
616 let store_keypair =
617 MlsSignatureKeyPair::new(signature_scheme, public_key, raw_keypair, id.as_slice().into());
618 backend
619 .key_store()
620 .save(store_keypair.clone())
621 .await
622 .map_err(KeystoreError::wrap("storing keypairs in keystore"))?;
623 SignatureKeyPair::tls_deserialize(&mut store_keypair.keypair.as_slice())
624 .map_err(Error::tls_deserialize("signature keypair"))?
625 };
626
627 for (credential, created_at) in &credentials {
628 match credential.mls_credential() {
629 openmls::prelude::MlsCredentialType::Basic(_) => {
630 if id.as_slice() != credential.identity() {
631 return Err(Error::WrongCredential);
632 }
633 }
634 openmls::prelude::MlsCredentialType::X509(cert) => {
635 let spk = cert
636 .extract_public_key()
637 .map_err(RecursiveError::mls_credential("extracting public key"))?
638 .ok_or(LeafError::InternalMlsError)?;
639 if signature_key.public() != spk {
640 return Err(Error::WrongCredential);
641 }
642 }
643 };
644 let cb = CredentialBundle {
645 credential: credential.clone(),
646 signature_key: signature_key.clone(),
647 created_at: *created_at,
648 };
649 identities.push_credential_bundle(signature_scheme, cb).await?;
650 }
651 }
652 self.replace_inner(SessionInner {
653 id: id.clone(),
654 identities,
655 keypackage_lifetime: KEYPACKAGE_DEFAULT_LIFETIME,
656 epoch_observer: None,
657 })
658 .await;
659 Ok(())
660 }
661
662 pub(crate) async fn restore_from_history_secret(&self, history_secret: HistorySecret) -> Result<()> {
664 self.ensure_unready().await?;
665
666 self.replace_inner(SessionInner {
668 id: history_secret.client_id.clone(),
669 identities: Identities::new(1),
670 keypackage_lifetime: KEYPACKAGE_DEFAULT_LIFETIME,
671 epoch_observer: None,
672 })
673 .await;
674
675 let key_package = history_secret
677 .key_package
678 .store(&self.crypto_provider)
679 .await
680 .map_err(MlsError::wrap("storing key package encapsulation"))?;
681
682 let keystore = self.crypto_provider.key_store();
683
684 self.save_identity(
686 keystore,
687 Some(&history_secret.client_id),
688 key_package.ciphersuite().signature_algorithm(),
689 history_secret.credential_bundle,
690 )
691 .await?;
692
693 Ok(())
694 }
695
696 async fn find_all_basic_credentials(backend: &MlsCryptoProvider) -> Result<Vec<Credential>> {
697 let store_credentials = backend
698 .key_store()
699 .find_all::<MlsCredential>(EntityFindParams::default())
700 .await
701 .map_err(KeystoreError::wrap("finding all mls credentialss"))?;
702 let mut credentials = Vec::with_capacity(store_credentials.len());
703 for store_credential in store_credentials.into_iter() {
704 let credential = Credential::tls_deserialize(&mut store_credential.credential.as_slice())
705 .map_err(Error::tls_deserialize("credential"))?;
706 if !matches!(credential.credential_type(), CredentialType::Basic) {
707 continue;
708 }
709 credentials.push(credential);
710 }
711
712 Ok(credentials)
713 }
714
715 pub(crate) async fn save_identity(
716 &self,
717 keystore: &Connection,
718 id: Option<&ClientId>,
719 signature_scheme: SignatureScheme,
720 mut credential_bundle: CredentialBundle,
721 ) -> Result<CredentialBundle> {
722 let mut guard = self.inner.write().await;
723 let SessionInner {
724 id: existing_id,
725 identities,
726 ..
727 } = guard.as_mut().ok_or(Error::MlsNotInitialized)?;
728
729 let id = id.unwrap_or(existing_id);
730
731 let credential = credential_bundle
732 .credential
733 .tls_serialize_detached()
734 .map_err(Error::tls_serialize("credential bundle"))?;
735 let credential = MlsCredential {
736 id: id.clone().into(),
737 credential,
738 created_at: 0,
739 };
740
741 let credential = keystore
742 .save(credential)
743 .await
744 .map_err(KeystoreError::wrap("saving credential"))?;
745
746 let sign_kp = MlsSignatureKeyPair::new(
747 signature_scheme,
748 credential_bundle.signature_key.to_public_vec(),
749 credential_bundle
750 .signature_key
751 .tls_serialize_detached()
752 .map_err(Error::tls_serialize("signature keypair"))?,
753 id.clone().into(),
754 );
755 keystore.save(sign_kp).await.map_err(|e| match e {
756 CryptoKeystoreError::AlreadyExists => Error::CredentialBundleConflict,
757 _ => KeystoreError::wrap("saving mls signature key pair")(e).into(),
758 })?;
759
760 credential_bundle.created_at = credential.created_at;
762
763 identities
764 .push_credential_bundle(signature_scheme, credential_bundle.clone())
765 .await?;
766
767 Ok(credential_bundle)
768 }
769
770 pub async fn id(&self) -> Result<ClientId> {
772 match self.inner.read().await.deref() {
773 None => Err(Error::MlsNotInitialized),
774 Some(SessionInner { id, .. }) => Ok(id.clone()),
775 }
776 }
777
778 pub async fn is_e2ei_capable(&self) -> bool {
780 match self.inner.read().await.deref() {
781 None => false,
782 Some(SessionInner { identities, .. }) => identities
783 .iter()
784 .any(|(_, cred)| cred.credential().credential_type() == CredentialType::X509),
785 }
786 }
787
788 pub(crate) async fn get_most_recent_or_create_credential_bundle(
789 &self,
790 backend: &MlsCryptoProvider,
791 sc: SignatureScheme,
792 ct: MlsCredentialType,
793 ) -> Result<Arc<CredentialBundle>> {
794 match ct {
795 MlsCredentialType::Basic => {
796 self.init_basic_credential_bundle_if_missing(backend, sc).await?;
797 self.find_most_recent_credential_bundle(sc, ct).await
798 }
799 MlsCredentialType::X509 => self
800 .find_most_recent_credential_bundle(sc, ct)
801 .await
802 .map_err(|e| match e {
803 Error::CredentialNotFound(_) => LeafError::E2eiEnrollmentNotDone.into(),
804 _ => e,
805 }),
806 }
807 }
808
809 pub(crate) async fn init_basic_credential_bundle_if_missing(
810 &self,
811 backend: &MlsCryptoProvider,
812 sc: SignatureScheme,
813 ) -> Result<()> {
814 let existing_cb = self
815 .find_most_recent_credential_bundle(sc, MlsCredentialType::Basic)
816 .await;
817 if matches!(existing_cb, Err(Error::CredentialNotFound(_))) {
818 let id = self.id().await?;
819 debug!(id:% = &id; "Initializing basic credential bundle");
820 let cb = Self::new_basic_credential_bundle(&id, sc, backend)?;
821 self.save_identity(&backend.keystore(), None, sc, cb).await?;
822 }
823 Ok(())
824 }
825
826 pub(crate) async fn save_new_x509_credential_bundle(
827 &self,
828 keystore: &Connection,
829 sc: SignatureScheme,
830 cb: CertificateBundle,
831 ) -> Result<CredentialBundle> {
832 let id = cb
833 .get_client_id()
834 .map_err(RecursiveError::mls_credential("getting client id"))?;
835 let cb = Self::new_x509_credential_bundle(cb)?;
836 self.save_identity(keystore, Some(&id), sc, cb).await
837 }
838}
839
840#[cfg(test)]
841mod tests {
842 use super::*;
843 use crate::prelude::ClientId;
844 use crate::test_utils::*;
845 use crate::transaction_context::test_utils::EntitiesCount;
846 use core_crypto_keystore::connection::{DatabaseKey, FetchFromDatabase};
847 use core_crypto_keystore::entities::*;
848 use mls_crypto_provider::MlsCryptoProvider;
849 use wasm_bindgen_test::*;
850
851 impl Session {
852 #![allow(missing_docs)]
854
855 pub async fn random_generate(
856 &self,
857 case: &crate::test_utils::TestContext,
858 signer: Option<&crate::test_utils::x509::X509Certificate>,
859 provision: bool,
860 ) -> Result<()> {
861 self.reset().await;
862 let user_uuid = uuid::Uuid::new_v4();
863 let rnd_id = rand::random::<usize>();
864 let client_id = format!("{}:{rnd_id:x}@members.wire.com", user_uuid.hyphenated());
865 let identity = match case.credential_type {
866 MlsCredentialType::Basic => ClientIdentifier::Basic(client_id.as_str().into()),
867 MlsCredentialType::X509 => {
868 let signer = signer.expect("Missing intermediate CA");
869 CertificateBundle::rand_identifier(&client_id, &[signer])
870 }
871 };
872 let nb_key_package = if provision {
873 crate::prelude::INITIAL_KEYING_MATERIAL_COUNT
874 } else {
875 0
876 };
877 let backend = self.crypto_provider.clone();
878 self.generate(identity, &backend, &[case.ciphersuite()], nb_key_package)
879 .await?;
880 Ok(())
881 }
882
883 pub async fn find_keypackages(&self, backend: &MlsCryptoProvider) -> Result<Vec<openmls::prelude::KeyPackage>> {
884 use core_crypto_keystore::CryptoKeystoreMls as _;
885 let kps = backend
886 .key_store()
887 .mls_fetch_keypackages::<openmls::prelude::KeyPackage>(u32::MAX)
888 .await
889 .map_err(KeystoreError::wrap("fetching mls keypackages"))?;
890 Ok(kps)
891 }
892
893 pub(crate) async fn init_x509_credential_bundle_if_missing(
894 &self,
895 backend: &MlsCryptoProvider,
896 sc: SignatureScheme,
897 cb: CertificateBundle,
898 ) -> Result<()> {
899 let existing_cb = self
900 .find_most_recent_credential_bundle(sc, MlsCredentialType::X509)
901 .await
902 .is_err();
903 if existing_cb {
904 self.save_new_x509_credential_bundle(&backend.keystore(), sc, cb)
905 .await?;
906 }
907 Ok(())
908 }
909
910 pub(crate) async fn generate_one_keypackage(
911 &self,
912 backend: &MlsCryptoProvider,
913 cs: MlsCiphersuite,
914 ct: MlsCredentialType,
915 ) -> Result<openmls::prelude::KeyPackage> {
916 let cb = self
917 .find_most_recent_credential_bundle(cs.signature_algorithm(), ct)
918 .await?;
919 self.generate_one_keypackage_from_credential_bundle(backend, cs, &cb)
920 .await
921 }
922
923 pub async fn count_entities(&self) -> EntitiesCount {
925 let keystore = self.crypto_provider.keystore();
926 let credential = keystore.count::<MlsCredential>().await.unwrap();
927 let encryption_keypair = keystore.count::<MlsEncryptionKeyPair>().await.unwrap();
928 let epoch_encryption_keypair = keystore.count::<MlsEpochEncryptionKeyPair>().await.unwrap();
929 let enrollment = keystore.count::<E2eiEnrollment>().await.unwrap();
930 let group = keystore.count::<PersistedMlsGroup>().await.unwrap();
931 let hpke_private_key = keystore.count::<MlsHpkePrivateKey>().await.unwrap();
932 let key_package = keystore.count::<MlsKeyPackage>().await.unwrap();
933 let pending_group = keystore.count::<PersistedMlsPendingGroup>().await.unwrap();
934 let pending_messages = keystore.count::<MlsPendingMessage>().await.unwrap();
935 let psk_bundle = keystore.count::<MlsPskBundle>().await.unwrap();
936 let signature_keypair = keystore.count::<MlsSignatureKeyPair>().await.unwrap();
937 EntitiesCount {
938 credential,
939 encryption_keypair,
940 epoch_encryption_keypair,
941 enrollment,
942 group,
943 hpke_private_key,
944 key_package,
945 pending_group,
946 pending_messages,
947 psk_bundle,
948 signature_keypair,
949 }
950 }
951 }
952 wasm_bindgen_test_configure!(run_in_browser);
953
954 #[apply(all_cred_cipher)]
955 #[wasm_bindgen_test]
956 async fn can_generate_session(case: TestContext) {
957 let [alice] = case.sessions().await;
958 let key = DatabaseKey::generate();
959 let backend = MlsCryptoProvider::try_new_in_memory(&key).await.unwrap();
960 let x509_test_chain = if case.is_x509() {
961 let x509_test_chain = crate::test_utils::x509::X509TestChain::init_empty(case.signature_scheme());
962 x509_test_chain.register_with_provider(&backend).await;
963 Some(x509_test_chain)
964 } else {
965 None
966 };
967 backend.new_transaction().await.unwrap();
968 let session = alice.session().await;
969 session
970 .random_generate(
971 &case,
972 x509_test_chain.as_ref().map(|chain| chain.find_local_intermediate_ca()),
973 false,
974 )
975 .await
976 .unwrap();
977 }
978
979 #[apply(all_cred_cipher)]
980 #[wasm_bindgen_test]
981 async fn can_externally_generate_client(case: TestContext) {
982 let [alice] = case.sessions().await;
983 if !case.is_basic() {
984 return;
985 }
986 run_tests(move |[tmp_dir_argument]| {
987 Box::pin(async move {
988 let key = DatabaseKey::generate();
989 let backend = MlsCryptoProvider::try_new(tmp_dir_argument, &key).await.unwrap();
990 backend.new_transaction().await.unwrap();
991 let client_id: ClientId = b"whatever:my:client:is@world.com".to_vec().into();
993 let alice = alice.session().await;
994 alice.reset().await;
995 let handles = alice
997 .generate_raw_keypairs(&[case.ciphersuite()], &backend)
998 .await
999 .unwrap();
1000
1001 let mut identities = backend
1002 .keystore()
1003 .find_all::<MlsSignatureKeyPair>(EntityFindParams::default())
1004 .await
1005 .unwrap();
1006
1007 assert_eq!(identities.len(), 1);
1008
1009 let prov_identity = identities.pop().unwrap();
1010
1011 let prov_client_id: ClientId = prov_identity.credential_id.as_slice().into();
1014 assert_eq!(&prov_client_id, handles.first().unwrap());
1015
1016 alice
1018 .init_with_external_client_id(client_id.clone(), handles.clone(), &[case.ciphersuite()], &backend)
1019 .await
1020 .unwrap();
1021
1022 assert_eq!(alice.id().await.unwrap(), client_id);
1024 let cb = alice
1025 .find_most_recent_credential_bundle(case.signature_scheme(), case.credential_type)
1026 .await
1027 .unwrap();
1028 let client_id: ClientId = cb.credential().identity().into();
1029 assert_eq!(&client_id, handles.first().unwrap());
1030 })
1031 })
1032 .await
1033 }
1034}