1mod credential;
2pub(crate) mod e2e_identity;
3mod epoch_observer;
4mod error;
5mod history_observer;
6pub(crate) mod id;
7pub(crate) mod identifier;
8pub(crate) mod identities;
9pub(crate) mod key_package;
10pub(crate) mod user_id;
11
12use std::sync::Arc;
13
14use async_lock::RwLock;
15use core_crypto_keystore::Database;
16pub use epoch_observer::EpochObserver;
17pub(crate) use error::{Error, Result};
18pub use history_observer::HistoryObserver;
19use identities::Identities;
20use key_package::KEYPACKAGE_DEFAULT_LIFETIME;
21use mls_crypto_provider::{EntropySeed, MlsCryptoProvider};
22use openmls_traits::{OpenMlsCryptoProvider, types::SignatureScheme};
23
24use crate::{
25 Ciphersuite, ClientId, ClientIdentifier, CoreCrypto, CredentialFindFilters, CredentialRef, CredentialType,
26 HistorySecret, LeafError, MlsError, MlsTransport, RecursiveError,
27 group_store::GroupStore,
28 mls::{
29 self, HasSessionAndCrypto,
30 conversation::{ConversationIdRef, ImmutableConversation},
31 },
32};
33
34#[derive(Clone, derive_more::Debug)]
45pub struct Session {
46 pub(crate) inner: Arc<RwLock<Option<SessionInner>>>,
47 pub(crate) crypto_provider: MlsCryptoProvider,
48 pub(crate) transport: Arc<RwLock<Option<Arc<dyn MlsTransport + 'static>>>>,
49 #[debug("EpochObserver")]
50 pub(crate) epoch_observer: Arc<RwLock<Option<Arc<dyn EpochObserver + 'static>>>>,
51 #[debug("HistoryObserver")]
52 pub(crate) history_observer: Arc<RwLock<Option<Arc<dyn HistoryObserver + 'static>>>>,
53}
54
55#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
56#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
57impl HasSessionAndCrypto for Session {
58 async fn session(&self) -> mls::Result<Session> {
59 Ok(self.clone())
60 }
61
62 async fn crypto_provider(&self) -> mls::Result<MlsCryptoProvider> {
63 Ok(self.crypto_provider.clone())
64 }
65}
66
67#[derive(Clone, Debug)]
68pub(crate) struct SessionInner {
69 id: ClientId,
70 pub(crate) identities: Identities,
71 keypackage_lifetime: std::time::Duration,
72}
73
74impl Session {
75 pub async fn try_new(database: &Database) -> crate::mls::Result<Self> {
83 let database = database.to_owned();
85 let mls_backend = MlsCryptoProvider::new(database);
87
88 let session = Self {
92 crypto_provider: mls_backend,
93 inner: Default::default(),
94 transport: Arc::new(None.into()),
95 epoch_observer: Arc::new(None.into()),
96 history_observer: Arc::new(None.into()),
97 };
98
99 let cc = CoreCrypto::from(session);
100 let context = cc
101 .new_transaction()
102 .await
103 .map_err(RecursiveError::transaction("starting new transaction"))?;
104
105 context
106 .init_pki_env()
107 .await
108 .map_err(RecursiveError::transaction("initializing pki environment"))?;
109 context
110 .finish()
111 .await
112 .map_err(RecursiveError::transaction("finishing transaction"))?;
113
114 Ok(cc.mls)
115 }
116
117 pub async fn provide_transport(&self, transport: Arc<dyn MlsTransport>) {
120 self.transport.write().await.replace(transport);
121 }
122
123 pub async fn init(&self, identifier: ClientIdentifier, signature_schemes: &[SignatureScheme]) -> Result<()> {
129 self.ensure_unready().await?;
130 let client_id = identifier.get_id()?.into_owned();
131
132 let mut credential_refs = CredentialRef::find(
141 &self.crypto_provider.keystore(),
142 CredentialFindFilters::builder().client_id(&client_id).build(),
143 )
144 .await
145 .map_err(RecursiveError::mls_credential_ref(
146 "loading matching credential refs while initializing a client",
147 ))?;
148 credential_refs.retain(|credential_ref| signature_schemes.contains(&credential_ref.signature_scheme()));
149
150 let mut identities = Identities::new(credential_refs.len());
151 let credentials_cache = CredentialRef::load_stored_credentials(&self.crypto_provider.keystore())
152 .await
153 .map_err(RecursiveError::mls_credential_ref(
154 "loading credential ref cache while initializing session",
155 ))?;
156
157 for credential_ref in credential_refs {
158 for credential_result in
159 credential_ref
160 .load_from_cache(&credentials_cache)
161 .map_err(RecursiveError::mls_credential_ref(
162 "loading credential list in session init",
163 ))?
164 {
165 let credential = credential_result
166 .map_err(RecursiveError::mls_credential_ref("loading credential in session init"))?;
167
168 match identities.push_credential(credential).await {
169 Err(Error::CredentialConflict) => {
170 }
173 Ok(_) => {}
174 Err(err) => {
175 return Err(RecursiveError::MlsClient {
176 context: "adding credential to identities in init",
177 source: Box::new(err),
178 }
179 .into());
180 }
181 }
182 }
183 }
184
185 self.replace_inner(SessionInner {
186 id: client_id,
187 identities,
188 keypackage_lifetime: KEYPACKAGE_DEFAULT_LIFETIME,
189 })
190 .await;
191
192 Ok(())
193 }
194
195 #[cfg(test)]
197 pub(crate) async fn reset(&self) {
198 let mut inner_lock = self.inner.write().await;
199 *inner_lock = None;
200 }
201
202 pub(crate) async fn is_ready(&self) -> bool {
203 let inner_lock = self.inner.read().await;
204 inner_lock.is_some()
205 }
206
207 async fn ensure_unready(&self) -> Result<()> {
208 if self.is_ready().await {
209 Err(Error::UnexpectedlyReady)
210 } else {
211 Ok(())
212 }
213 }
214
215 async fn replace_inner(&self, new_inner: SessionInner) {
216 let mut inner_lock = self.inner.write().await;
217 *inner_lock = Some(new_inner);
218 }
219
220 pub async fn get_raw_conversation(&self, id: &ConversationIdRef) -> Result<ImmutableConversation> {
226 let raw_conversation = GroupStore::fetch_from_keystore(id, &self.crypto_provider.keystore(), None)
227 .await
228 .map_err(RecursiveError::root("getting conversation by id"))?
229 .ok_or_else(|| LeafError::ConversationNotFound(id.to_owned()))?;
230 Ok(ImmutableConversation::new(raw_conversation, self.clone()))
231 }
232
233 pub async fn public_key(
240 &self,
241 ciphersuite: Ciphersuite,
242 credential_type: CredentialType,
243 ) -> crate::mls::Result<Vec<u8>> {
244 let cb = self
245 .find_most_recent_credential(ciphersuite.signature_algorithm(), credential_type)
246 .await
247 .map_err(RecursiveError::mls_client("finding most recent credential"))?;
248 Ok(cb.signature_key_pair.to_public_vec())
249 }
250
251 pub async fn conversation_exists(&self, id: &ConversationIdRef) -> Result<bool> {
253 match self.get_raw_conversation(id).await {
254 Ok(_) => Ok(true),
255 Err(Error::Leaf(LeafError::ConversationNotFound(_))) => Ok(false),
256 Err(e) => Err(e),
257 }
258 }
259
260 pub fn random_bytes(&self, len: usize) -> crate::mls::Result<Vec<u8>> {
262 use openmls_traits::random::OpenMlsRand as _;
263 self.crypto_provider
264 .rand()
265 .random_vec(len)
266 .map_err(MlsError::wrap("generating random vector"))
267 .map_err(Into::into)
268 }
269
270 pub async fn close(&self) -> crate::mls::Result<()> {
276 self.crypto_provider
277 .close()
278 .await
279 .map_err(MlsError::wrap("closing connection with keystore"))
280 .map_err(Into::into)
281 }
282
283 pub async fn reseed(&self, seed: Option<EntropySeed>) -> crate::mls::Result<()> {
285 self.crypto_provider
286 .reseed(seed)
287 .map_err(MlsError::wrap("reseeding mls backend"))
288 .map_err(Into::into)
289 }
290
291 pub(crate) async fn restore_from_history_secret(&self, history_secret: HistorySecret) -> Result<()> {
293 self.ensure_unready().await?;
294
295 self.replace_inner(SessionInner {
297 id: history_secret.client_id.clone(),
298 identities: Identities::new(0),
299 keypackage_lifetime: KEYPACKAGE_DEFAULT_LIFETIME,
300 })
301 .await;
302
303 history_secret
305 .key_package
306 .store(&self.crypto_provider)
307 .await
308 .map_err(MlsError::wrap("storing key package encapsulation"))?;
309
310 Ok(())
311 }
312
313 pub async fn id(&self) -> Result<ClientId> {
315 match &*self.inner.read().await {
316 None => Err(Error::MlsNotInitialized),
317 Some(SessionInner { id, .. }) => Ok(id.clone()),
318 }
319 }
320
321 pub async fn is_e2ei_capable(&self) -> bool {
323 match &*self.inner.read().await {
324 None => false,
325 Some(SessionInner { identities, .. }) => identities
326 .iter()
327 .any(|cred| cred.credential_type() == CredentialType::X509),
328 }
329 }
330}
331
#[cfg(test)]
mod tests {
    use core_crypto_keystore::{connection::FetchFromDatabase as _, entities::*};
    use mls_crypto_provider::MlsCryptoProvider;

    use super::*;
    use crate::{
        CertificateBundle, Credential, KeystoreError, test_utils::*, transaction_context::test_utils::EntitiesCount,
    };

    impl Session {
        // Test-only helpers: exempt them from the missing-docs lint.
        #![allow(missing_docs)]

        /// Resets the session and re-initializes it with a freshly generated random client id
        /// and a credential of the type requested by `case`.
        ///
        /// For X509 cases `signer` must be provided; it is used to create the certificate bundle.
        pub async fn random_generate(
            &self,
            case: &crate::test_utils::TestContext,
            signer: Option<&crate::test_utils::x509::X509Certificate>,
        ) -> Result<()> {
            self.reset().await;
            let user_uuid = uuid::Uuid::new_v4();
            let rnd_id = rand::random::<usize>();
            let client_id = format!("{}:{rnd_id:x}@members.wire.com", user_uuid.hyphenated());
            let client_id = ClientId(client_id.into_bytes());

            // The identifier is built first in each arm; the credential then consumes the
            // remaining owned inputs.
            let (identifier, credential) = match case.credential_type {
                CredentialType::Basic => (
                    ClientIdentifier::Basic(client_id.clone()),
                    Credential::basic(case.signature_scheme(), client_id, &self.crypto_provider).unwrap(),
                ),
                CredentialType::X509 => {
                    let signer = signer.expect("Missing intermediate CA").to_owned();
                    let cert = CertificateBundle::rand(&client_id, &signer);
                    let identifier = ClientIdentifier::X509([(case.signature_scheme(), cert.clone())].into());
                    (identifier, Credential::x509(cert).unwrap())
                }
            };

            self.init(identifier, &[case.signature_scheme()]).await.unwrap();

            self.add_credential(credential).await.unwrap();

            Ok(())
        }

        /// Fetches up to `u32::MAX` key packages from `backend`'s key store.
        pub async fn find_keypackages(&self, backend: &MlsCryptoProvider) -> Result<Vec<openmls::prelude::KeyPackage>> {
            use core_crypto_keystore::CryptoKeystoreMls as _;

            let fetched = backend
                .key_store()
                .mls_fetch_keypackages::<openmls::prelude::KeyPackage>(u32::MAX)
                .await;
            fetched
                .map_err(KeystoreError::wrap("fetching mls keypackages"))
                .map_err(Into::into)
        }

        /// Generates one key package from the credential that `find_most_recent_credential`
        /// yields for `cs`'s signature algorithm and the credential type `ct`.
        pub(crate) async fn generate_one_keypackage(
            &self,
            backend: &MlsCryptoProvider,
            cs: Ciphersuite,
            ct: CredentialType,
        ) -> Result<openmls::prelude::KeyPackage> {
            let signature_scheme = cs.signature_algorithm();
            let credential = self.find_most_recent_credential(signature_scheme, ct).await?;
            self.generate_one_keypackage_from_credential(backend, cs, &credential)
                .await
        }

        /// Counts the stored entities of each kind tracked by [`EntitiesCount`].
        ///
        /// Panics if any count query fails.
        pub async fn count_entities(&self) -> EntitiesCount {
            let db = self.crypto_provider.keystore();
            // Struct fields are evaluated in written order, which matches the original query order.
            EntitiesCount {
                credential: db.count::<StoredCredential>().await.unwrap(),
                encryption_keypair: db.count::<StoredEncryptionKeyPair>().await.unwrap(),
                epoch_encryption_keypair: db.count::<StoredEpochEncryptionKeypair>().await.unwrap(),
                enrollment: db.count::<StoredE2eiEnrollment>().await.unwrap(),
                group: db.count::<PersistedMlsGroup>().await.unwrap(),
                hpke_private_key: db.count::<StoredHpkePrivateKey>().await.unwrap(),
                key_package: db.count::<StoredKeypackage>().await.unwrap(),
                pending_group: db.count::<PersistedMlsPendingGroup>().await.unwrap(),
                pending_messages: db.count::<MlsPendingMessage>().await.unwrap(),
                psk_bundle: db.count::<StoredPskBundle>().await.unwrap(),
            }
        }
    }

    /// A session can be (re)generated with random identity material for every
    /// credential type / ciphersuite combination.
    #[apply(all_cred_cipher)]
    async fn can_generate_session(mut case: TestContext) {
        let [alice] = case.sessions().await;
        let key_store = case.create_in_memory_database().await;
        let backend = MlsCryptoProvider::new(key_store);

        // Only X509 cases need a certificate chain registered with the provider.
        let x509_test_chain = if case.is_x509() {
            let chain = crate::test_utils::x509::X509TestChain::init_empty(case.signature_scheme());
            chain.register_with_provider(&backend).await;
            Some(chain)
        } else {
            None
        };
        backend.new_transaction().await.unwrap();

        let session = alice.session().await;
        let signer = x509_test_chain
            .as_ref()
            .map(|chain| chain.find_local_intermediate_ca());
        session.random_generate(&case, signer).await.unwrap();
    }
}