1pub(crate) mod config;
2pub(crate) mod e2e_identity;
3mod epoch_observer;
4mod error;
5mod history_observer;
6pub(crate) mod id;
7pub(crate) mod identifier;
8pub(crate) mod identities;
9pub(crate) mod key_package;
10pub(crate) mod user_id;
11
12use std::{collections::HashSet, ops::Deref, sync::Arc};
13
14use async_lock::RwLock;
15use core_crypto_keystore::{
16 CryptoKeystoreError, Database,
17 connection::FetchFromDatabase,
18 entities::{EntityFindParams, MlsCredential, MlsSignatureKeyPair},
19};
20pub use epoch_observer::EpochObserver;
21pub(crate) use error::{Error, Result};
22pub use history_observer::HistoryObserver;
23use identities::Identities;
24use key_package::KEYPACKAGE_DEFAULT_LIFETIME;
25use log::debug;
26use mls_crypto_provider::{EntropySeed, MlsCryptoProvider};
27use openmls::prelude::{Credential, CredentialType};
28use openmls_basic_credential::SignatureKeyPair;
29use openmls_traits::{OpenMlsCryptoProvider, crypto::OpenMlsCrypto, types::SignatureScheme};
30use openmls_x509_credential::CertificateKeyPair;
31use tls_codec::{Deserialize, Serialize};
32
33use crate::{
34 CertificateBundle, ClientId, ClientIdentifier, CoreCrypto, HistorySecret, KeystoreError, LeafError, MlsCiphersuite,
35 MlsCredentialType, MlsError, MlsTransport, RecursiveError, ValidatedSessionConfig,
36 group_store::GroupStore,
37 mls::{
38 self, HasSessionAndCrypto,
39 conversation::{ConversationIdRef, ImmutableConversation},
40 credential::{CredentialBundle, ext::CredentialExt},
41 },
42};
43
/// An MLS client session.
///
/// Holds the MLS crypto provider together with lock-guarded, optional state. The lock-guarded
/// fields are wrapped in `Arc`, so every clone of a `Session` shares the same inner state,
/// transport and observers.
#[derive(Clone, derive_more::Debug)]
pub struct Session {
    /// Client state; `None` until the session has been initialized (see `init`, `generate`,
    /// `load` and `restore_from_history_secret`). `is_ready` reports whether this is `Some`.
    pub(crate) inner: Arc<RwLock<Option<SessionInner>>>,
    /// MLS crypto provider; `try_new` builds it from the configured database.
    pub(crate) crypto_provider: MlsCryptoProvider,
    /// Transport installed through `provide_transport`, if any.
    pub(crate) transport: Arc<RwLock<Option<Arc<dyn MlsTransport + 'static>>>>,
    /// Optional epoch observer. Debug-formatted as a fixed placeholder string (presumably
    /// because the trait object is not `Debug`).
    #[debug("EpochObserver")]
    pub(crate) epoch_observer: Arc<RwLock<Option<Arc<dyn EpochObserver + 'static>>>>,
    /// Optional history observer. Debug-formatted as a fixed placeholder string (presumably
    /// because the trait object is not `Debug`).
    #[debug("HistoryObserver")]
    pub(crate) history_observer: Arc<RwLock<Option<Arc<dyn HistoryObserver + 'static>>>>,
}
64
65#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
66#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
67impl HasSessionAndCrypto for Session {
68 async fn session(&self) -> mls::Result<Session> {
69 Ok(self.clone())
70 }
71
72 async fn crypto_provider(&self) -> mls::Result<MlsCryptoProvider> {
73 Ok(self.crypto_provider.clone())
74 }
75}
76
/// Per-client state that only exists once a [`Session`] has been initialized.
#[derive(Clone, Debug)]
pub(crate) struct SessionInner {
    /// The client id this session acts as.
    id: ClientId,
    /// Credential bundles held by this client; bundles are pushed per signature scheme.
    pub(crate) identities: Identities,
    /// Key package lifetime; every constructor in this module sets it to
    /// `KEYPACKAGE_DEFAULT_LIFETIME`.
    keypackage_lifetime: std::time::Duration,
}
83
84impl Session {
85 pub async fn try_new(
99 ValidatedSessionConfig {
100 database,
101 client_id,
102 ciphersuites,
103 }: ValidatedSessionConfig,
104 ) -> crate::mls::Result<Self> {
105 let mls_backend = MlsCryptoProvider::new(database);
107
108 let session = Self {
112 crypto_provider: mls_backend.clone(),
113 inner: Default::default(),
114 transport: Arc::new(None.into()),
115 epoch_observer: Arc::new(None.into()),
116 history_observer: Arc::new(None.into()),
117 };
118
119 let cc = CoreCrypto::from(session);
120 let context = cc
121 .new_transaction()
122 .await
123 .map_err(RecursiveError::transaction("starting new transaction"))?;
124
125 if let Some(id) = client_id {
126 cc.mls
127 .init(ClientIdentifier::Basic(id), ciphersuites.as_slice(), &mls_backend)
128 .await
129 .map_err(RecursiveError::mls_client("initializing mls client"))?
130 }
131
132 context
133 .init_pki_env()
134 .await
135 .map_err(RecursiveError::transaction("initializing pki environment"))?;
136 context
137 .finish()
138 .await
139 .map_err(RecursiveError::transaction("finishing transaction"))?;
140
141 Ok(cc.mls)
142 }
143
144 pub async fn provide_transport(&self, transport: Arc<dyn MlsTransport>) {
147 self.transport.write().await.replace(transport);
148 }
149
150 pub async fn init(
162 &self,
163 identifier: ClientIdentifier,
164 ciphersuites: &[MlsCiphersuite],
165 backend: &MlsCryptoProvider,
166 ) -> Result<()> {
167 self.ensure_unready().await?;
168 let id = identifier.get_id()?;
169
170 let credentials = backend
171 .key_store()
172 .find_all::<MlsCredential>(EntityFindParams::default())
173 .await
174 .map_err(KeystoreError::wrap("finding all mls credentials"))?;
175
176 let credentials = credentials
177 .into_iter()
178 .filter(|mls_credential| mls_credential.id.as_slice() == id.as_slice())
179 .map(|mls_credential| -> Result<_> {
180 let credential = Credential::tls_deserialize(&mut mls_credential.credential.as_slice())
181 .map_err(Error::tls_deserialize("mls credential"))?;
182 Ok((credential, mls_credential.created_at))
183 })
184 .collect::<Result<Vec<_>>>()?;
185
186 if credentials.is_empty() {
187 debug!(ciphersuites:? = ciphersuites; "Generating client");
188 self.generate(identifier, backend, ciphersuites).await?;
189 } else {
190 let signature_schemes = ciphersuites
191 .iter()
192 .map(|cs| cs.signature_algorithm())
193 .collect::<HashSet<_>>();
194 let load_result = self.load(backend, id.as_ref(), credentials, signature_schemes).await;
195 if let Err(Error::ClientSignatureNotFound) = load_result {
196 debug!(ciphersuites:? = ciphersuites; "Client signature not found. Generating client");
197 self.generate(identifier, backend, ciphersuites).await?;
198 } else {
199 load_result?;
200 }
201 };
202
203 Ok(())
204 }
205
206 #[cfg(test)]
208 pub(crate) async fn reset(&self) {
209 let mut inner_lock = self.inner.write().await;
210 *inner_lock = None;
211 }
212
213 pub(crate) async fn is_ready(&self) -> bool {
214 let inner_lock = self.inner.read().await;
215 inner_lock.is_some()
216 }
217
218 async fn ensure_unready(&self) -> Result<()> {
219 if self.is_ready().await {
220 Err(Error::UnexpectedlyReady)
221 } else {
222 Ok(())
223 }
224 }
225
226 async fn replace_inner(&self, new_inner: SessionInner) {
227 let mut inner_lock = self.inner.write().await;
228 *inner_lock = Some(new_inner);
229 }
230
231 pub async fn get_raw_conversation(&self, id: &ConversationIdRef) -> Result<ImmutableConversation> {
237 let raw_conversation = GroupStore::fetch_from_keystore(id, &self.crypto_provider.keystore(), None)
238 .await
239 .map_err(RecursiveError::root("getting conversation by id"))?
240 .ok_or_else(|| LeafError::ConversationNotFound(id.to_owned()))?;
241 Ok(ImmutableConversation::new(raw_conversation, self.clone()))
242 }
243
244 pub async fn public_key(
251 &self,
252 ciphersuite: MlsCiphersuite,
253 credential_type: MlsCredentialType,
254 ) -> crate::mls::Result<Vec<u8>> {
255 let cb = self
256 .find_most_recent_credential_bundle(ciphersuite.signature_algorithm(), credential_type)
257 .await
258 .map_err(RecursiveError::mls_client("finding most recent credential bundle"))?;
259 Ok(cb.signature_key.to_public_vec())
260 }
261
262 pub(crate) fn new_basic_credential_bundle(
263 id: &ClientId,
264 sc: SignatureScheme,
265 backend: &MlsCryptoProvider,
266 ) -> Result<CredentialBundle> {
267 let (sk, pk) = backend
268 .crypto()
269 .signature_key_gen(sc)
270 .map_err(MlsError::wrap("generating a signature key"))?;
271
272 let signature_key = SignatureKeyPair::from_raw(sc, sk, pk);
273 let credential = Credential::new_basic(id.to_vec());
274 let cb = CredentialBundle {
275 credential,
276 signature_key,
277 created_at: 0,
278 };
279
280 Ok(cb)
281 }
282
283 pub(crate) fn new_x509_credential_bundle(cert: CertificateBundle) -> Result<CredentialBundle> {
284 let created_at = cert
285 .get_created_at()
286 .map_err(RecursiveError::mls_credential("getting credetntial created at"))?;
287 let (sk, ..) = cert.private_key.into_parts();
288 let chain = cert.certificate_chain;
289
290 let kp = CertificateKeyPair::new(sk, chain.clone()).map_err(MlsError::wrap("creating certificate key pair"))?;
291
292 let credential = Credential::new_x509(chain).map_err(MlsError::wrap("creating x509 credential"))?;
293
294 let cb = CredentialBundle {
295 credential,
296 signature_key: kp.0,
297 created_at,
298 };
299 Ok(cb)
300 }
301
302 pub async fn conversation_exists(&self, id: &ConversationIdRef) -> Result<bool> {
304 match self.get_raw_conversation(id).await {
305 Ok(_) => Ok(true),
306 Err(Error::Leaf(LeafError::ConversationNotFound(_))) => Ok(false),
307 Err(e) => Err(e),
308 }
309 }
310
311 pub fn random_bytes(&self, len: usize) -> crate::mls::Result<Vec<u8>> {
313 use openmls_traits::random::OpenMlsRand as _;
314 self.crypto_provider
315 .rand()
316 .random_vec(len)
317 .map_err(MlsError::wrap("generating random vector"))
318 .map_err(Into::into)
319 }
320
321 pub async fn close(self) -> crate::mls::Result<()> {
327 self.crypto_provider
328 .close()
329 .await
330 .map_err(MlsError::wrap("closing connection with keystore"))
331 .map_err(Into::into)
332 }
333
334 pub async fn reseed(&self, seed: Option<EntropySeed>) -> crate::mls::Result<()> {
336 self.crypto_provider
337 .reseed(seed)
338 .map_err(MlsError::wrap("reseeding mls backend"))
339 .map_err(Into::into)
340 }
341
342 pub(crate) async fn generate(
344 &self,
345 identifier: ClientIdentifier,
346 backend: &MlsCryptoProvider,
347 ciphersuites: &[MlsCiphersuite],
348 ) -> Result<()> {
349 self.ensure_unready().await?;
350 let id = identifier.get_id()?;
351 let signature_schemes = ciphersuites
352 .iter()
353 .map(|cs| cs.signature_algorithm())
354 .collect::<HashSet<_>>();
355 self.replace_inner(SessionInner {
356 id: id.into_owned(),
357 identities: Identities::new(signature_schemes.len()),
358 keypackage_lifetime: KEYPACKAGE_DEFAULT_LIFETIME,
359 })
360 .await;
361
362 let identities = identifier.generate_credential_bundles(backend, signature_schemes)?;
363
364 for (sc, id, cb) in identities {
365 self.save_identity(&backend.keystore(), Some(&id), sc, cb).await?;
366 }
367
368 Ok(())
369 }
370
371 pub(crate) async fn load(
373 &self,
374 backend: &MlsCryptoProvider,
375 id: &ClientId,
376 mut credentials: Vec<(Credential, u64)>,
377 signature_schemes: HashSet<SignatureScheme>,
378 ) -> Result<()> {
379 self.ensure_unready().await?;
380 let mut identities = Identities::new(signature_schemes.len());
381
382 credentials.sort_by_key(|(_, timestamp)| *timestamp);
384
385 let stored_signature_keypairs = backend
386 .key_store()
387 .find_all::<MlsSignatureKeyPair>(EntityFindParams::default())
388 .await
389 .map_err(KeystoreError::wrap("finding all mls signature keypairs"))?;
390
391 for signature_scheme in signature_schemes {
392 let signature_keypair = stored_signature_keypairs
393 .iter()
394 .find(|skp| skp.signature_scheme == (signature_scheme as u16));
395
396 let signature_key = if let Some(kp) = signature_keypair {
397 SignatureKeyPair::tls_deserialize(&mut kp.keypair.as_slice())
398 .map_err(Error::tls_deserialize("signature keypair"))?
399 } else {
400 let (private_key, public_key) = backend
401 .crypto()
402 .signature_key_gen(signature_scheme)
403 .map_err(MlsError::wrap("generating signature key"))?;
404 let keypair = SignatureKeyPair::from_raw(signature_scheme, private_key, public_key.clone());
405 let raw_keypair = keypair
406 .tls_serialize_detached()
407 .map_err(Error::tls_serialize("raw keypair"))?;
408 let store_keypair =
409 MlsSignatureKeyPair::new(signature_scheme, public_key, raw_keypair, id.as_slice().into());
410 backend
411 .key_store()
412 .save(store_keypair.clone())
413 .await
414 .map_err(KeystoreError::wrap("storing keypairs in keystore"))?;
415 SignatureKeyPair::tls_deserialize(&mut store_keypair.keypair.as_slice())
416 .map_err(Error::tls_deserialize("signature keypair"))?
417 };
418
419 for (credential, created_at) in &credentials {
420 match credential.mls_credential() {
421 openmls::prelude::MlsCredentialType::Basic(_) => {
422 if id.as_slice() != credential.identity() {
423 return Err(Error::WrongCredential);
424 }
425 }
426 openmls::prelude::MlsCredentialType::X509(cert) => {
427 let spk = cert
428 .extract_public_key()
429 .map_err(RecursiveError::mls_credential("extracting public key"))?
430 .ok_or(LeafError::InternalMlsError)?;
431 if signature_key.public() != spk {
432 return Err(Error::WrongCredential);
433 }
434 }
435 };
436 let cb = CredentialBundle {
437 credential: credential.clone(),
438 signature_key: signature_key.clone(),
439 created_at: *created_at,
440 };
441 identities.push_credential_bundle(signature_scheme, cb).await?;
442 }
443 }
444 self.replace_inner(SessionInner {
445 id: id.clone(),
446 identities,
447 keypackage_lifetime: KEYPACKAGE_DEFAULT_LIFETIME,
448 })
449 .await;
450 Ok(())
451 }
452
453 pub(crate) async fn restore_from_history_secret(&self, history_secret: HistorySecret) -> Result<()> {
455 self.ensure_unready().await?;
456
457 self.replace_inner(SessionInner {
459 id: history_secret.client_id.clone(),
460 identities: Identities::new(0),
461 keypackage_lifetime: KEYPACKAGE_DEFAULT_LIFETIME,
462 })
463 .await;
464
465 history_secret
467 .key_package
468 .store(&self.crypto_provider)
469 .await
470 .map_err(MlsError::wrap("storing key package encapsulation"))?;
471
472 Ok(())
473 }
474
475 pub(crate) async fn save_identity(
476 &self,
477 keystore: &Database,
478 id: Option<&ClientId>,
479 signature_scheme: SignatureScheme,
480 mut credential_bundle: CredentialBundle,
481 ) -> Result<CredentialBundle> {
482 let mut guard = self.inner.write().await;
483 let SessionInner {
484 id: existing_id,
485 identities,
486 ..
487 } = guard.as_mut().ok_or(Error::MlsNotInitialized)?;
488
489 let id = id.unwrap_or(existing_id);
490
491 let credential = credential_bundle
492 .credential
493 .tls_serialize_detached()
494 .map_err(Error::tls_serialize("credential bundle"))?;
495 let credential = MlsCredential {
496 id: id.clone().into(),
497 credential,
498 created_at: 0,
499 };
500
501 let credential = keystore
502 .save(credential)
503 .await
504 .map_err(KeystoreError::wrap("saving credential"))?;
505
506 let sign_kp = MlsSignatureKeyPair::new(
507 signature_scheme,
508 credential_bundle.signature_key.to_public_vec(),
509 credential_bundle
510 .signature_key
511 .tls_serialize_detached()
512 .map_err(Error::tls_serialize("signature keypair"))?,
513 id.clone().into(),
514 );
515 keystore.save(sign_kp).await.map_err(|e| match e {
516 CryptoKeystoreError::AlreadyExists(_) => Error::CredentialBundleConflict,
517 _ => KeystoreError::wrap("saving mls signature key pair")(e).into(),
518 })?;
519
520 credential_bundle.created_at = credential.created_at;
522
523 identities
524 .push_credential_bundle(signature_scheme, credential_bundle.clone())
525 .await?;
526
527 Ok(credential_bundle)
528 }
529
530 pub async fn id(&self) -> Result<ClientId> {
532 match self.inner.read().await.deref() {
533 None => Err(Error::MlsNotInitialized),
534 Some(SessionInner { id, .. }) => Ok(id.clone()),
535 }
536 }
537
538 pub async fn is_e2ei_capable(&self) -> bool {
540 match self.inner.read().await.deref() {
541 None => false,
542 Some(SessionInner { identities, .. }) => identities
543 .iter()
544 .any(|(_, cred)| cred.credential().credential_type() == CredentialType::X509),
545 }
546 }
547
548 pub(crate) async fn get_most_recent_or_create_credential_bundle(
549 &self,
550 backend: &MlsCryptoProvider,
551 sc: SignatureScheme,
552 ct: MlsCredentialType,
553 ) -> Result<Arc<CredentialBundle>> {
554 match ct {
555 MlsCredentialType::Basic => {
556 self.init_basic_credential_bundle_if_missing(backend, sc).await?;
557 self.find_most_recent_credential_bundle(sc, ct).await
558 }
559 MlsCredentialType::X509 => self
560 .find_most_recent_credential_bundle(sc, ct)
561 .await
562 .map_err(|e| match e {
563 Error::CredentialNotFound(_) => LeafError::E2eiEnrollmentNotDone.into(),
564 _ => e,
565 }),
566 }
567 }
568
569 pub(crate) async fn init_basic_credential_bundle_if_missing(
570 &self,
571 backend: &MlsCryptoProvider,
572 sc: SignatureScheme,
573 ) -> Result<()> {
574 let existing_cb = self
575 .find_most_recent_credential_bundle(sc, MlsCredentialType::Basic)
576 .await;
577 if matches!(existing_cb, Err(Error::CredentialNotFound(_))) {
578 let id = self.id().await?;
579 debug!(id:% = &id; "Initializing basic credential bundle");
580 let cb = Self::new_basic_credential_bundle(&id, sc, backend)?;
581 self.save_identity(&backend.keystore(), None, sc, cb).await?;
582 }
583 Ok(())
584 }
585
586 pub(crate) async fn save_new_x509_credential_bundle(
587 &self,
588 keystore: &Database,
589 sc: SignatureScheme,
590 cb: CertificateBundle,
591 ) -> Result<CredentialBundle> {
592 let id = cb
593 .get_client_id()
594 .map_err(RecursiveError::mls_credential("getting client id"))?;
595 let cb = Self::new_x509_credential_bundle(cb)?;
596 self.save_identity(keystore, Some(&id), sc, cb).await
597 }
598}
599
#[cfg(test)]
mod tests {
    use core_crypto_keystore::{connection::FetchFromDatabase as _, entities::*};
    use mls_crypto_provider::MlsCryptoProvider;

    use super::*;
    use crate::{test_utils::*, transaction_context::test_utils::EntitiesCount};

    // Test-only helpers on `Session`.
    impl Session {
        #![allow(missing_docs)]

        /// Resets the session and regenerates it under a random client id, using the test
        /// case's credential type and ciphersuite. X.509 cases require `signer`.
        pub async fn random_generate(
            &self,
            case: &crate::test_utils::TestContext,
            signer: Option<&crate::test_utils::x509::X509Certificate>,
        ) -> Result<()> {
            // `generate` refuses to run on an initialized session, so start from scratch.
            self.reset().await;
            let client_id = {
                let user_uuid = uuid::Uuid::new_v4();
                let rnd_id = rand::random::<usize>();
                format!("{}:{rnd_id:x}@members.wire.com", user_uuid.hyphenated())
            };
            let identity = match case.credential_type {
                MlsCredentialType::Basic => ClientIdentifier::Basic(client_id.as_str().into()),
                MlsCredentialType::X509 => {
                    let issuer = signer.expect("Missing intermediate CA");
                    CertificateBundle::rand_identifier(&client_id, &[issuer])
                }
            };
            let provider = self.crypto_provider.clone();
            self.generate(identity, &provider, &[case.ciphersuite()]).await
        }

        /// Returns every key package stored in `backend`'s keystore.
        pub async fn find_keypackages(&self, backend: &MlsCryptoProvider) -> Result<Vec<openmls::prelude::KeyPackage>> {
            use core_crypto_keystore::CryptoKeystoreMls as _;
            backend
                .key_store()
                .mls_fetch_keypackages::<openmls::prelude::KeyPackage>(u32::MAX)
                .await
                .map_err(KeystoreError::wrap("fetching mls keypackages"))
                .map_err(Into::into)
        }

        /// Generates one key package from the most recent credential bundle for `cs`/`ct`.
        pub(crate) async fn generate_one_keypackage(
            &self,
            backend: &MlsCryptoProvider,
            cs: MlsCiphersuite,
            ct: MlsCredentialType,
        ) -> Result<openmls::prelude::KeyPackage> {
            let scheme = cs.signature_algorithm();
            let bundle = self.find_most_recent_credential_bundle(scheme, ct).await?;
            self.generate_one_keypackage_from_credential_bundle(backend, cs, &bundle)
                .await
        }

        /// Counts every entity kind stored in this session's keystore; panics on keystore errors.
        pub async fn count_entities(&self) -> EntitiesCount {
            let store = self.crypto_provider.keystore();
            EntitiesCount {
                credential: store.count::<MlsCredential>().await.unwrap(),
                encryption_keypair: store.count::<MlsEncryptionKeyPair>().await.unwrap(),
                epoch_encryption_keypair: store.count::<MlsEpochEncryptionKeyPair>().await.unwrap(),
                enrollment: store.count::<E2eiEnrollment>().await.unwrap(),
                group: store.count::<PersistedMlsGroup>().await.unwrap(),
                hpke_private_key: store.count::<MlsHpkePrivateKey>().await.unwrap(),
                key_package: store.count::<MlsKeyPackage>().await.unwrap(),
                pending_group: store.count::<PersistedMlsPendingGroup>().await.unwrap(),
                pending_messages: store.count::<MlsPendingMessage>().await.unwrap(),
                psk_bundle: store.count::<MlsPskBundle>().await.unwrap(),
                signature_keypair: store.count::<MlsSignatureKeyPair>().await.unwrap(),
            }
        }
    }

    #[apply(all_cred_cipher)]
    async fn can_generate_session(mut case: TestContext) {
        let [alice] = case.sessions().await;
        let database = case.create_in_memory_database().await;
        let backend = MlsCryptoProvider::new(database);
        // X.509 cases need a test PKI chain registered with the provider.
        let x509_test_chain = if case.is_x509() {
            let chain = crate::test_utils::x509::X509TestChain::init_empty(case.signature_scheme());
            chain.register_with_provider(&backend).await;
            Some(chain)
        } else {
            None
        };
        backend.new_transaction().await.unwrap();
        let session = alice.session().await;
        let signer = x509_test_chain
            .as_ref()
            .map(|chain| chain.find_local_intermediate_ca());
        session.random_generate(&case, signer).await.unwrap();
    }
}