1mod credential;
2pub(crate) mod e2e_identity;
3mod epoch_observer;
4mod error;
5mod history_observer;
6pub(crate) mod id;
7pub(crate) mod identifier;
8pub(crate) mod identities;
9pub(crate) mod key_package;
10pub(crate) mod user_id;
11
12use std::sync::Arc;
13
14use async_lock::RwLock;
15use core_crypto_keystore::Database;
16pub use epoch_observer::EpochObserver;
17pub(crate) use error::{Error, Result};
18pub use history_observer::HistoryObserver;
19use identities::Identities;
20use mls_crypto_provider::{EntropySeed, MlsCryptoProvider};
21use openmls_traits::{OpenMlsCryptoProvider, types::SignatureScheme};
22
23use crate::{
24 Ciphersuite, ClientId, ClientIdentifier, CoreCrypto, CredentialFindFilters, CredentialRef, CredentialType,
25 HistorySecret, LeafError, MlsError, MlsTransport, RecursiveError,
26 group_store::GroupStore,
27 mls::{
28 self, HasSessionAndCrypto,
29 conversation::{ConversationIdRef, ImmutableConversation},
30 },
31};
32
33#[derive(Clone, derive_more::Debug)]
44pub struct Session {
45 pub(crate) inner: Arc<RwLock<Option<SessionInner>>>,
46 pub(crate) crypto_provider: MlsCryptoProvider,
47 pub(crate) transport: Arc<RwLock<Option<Arc<dyn MlsTransport + 'static>>>>,
48 #[debug("EpochObserver")]
49 pub(crate) epoch_observer: Arc<RwLock<Option<Arc<dyn EpochObserver + 'static>>>>,
50 #[debug("HistoryObserver")]
51 pub(crate) history_observer: Arc<RwLock<Option<Arc<dyn HistoryObserver + 'static>>>>,
52}
53
54#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
55#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
56impl HasSessionAndCrypto for Session {
57 async fn session(&self) -> mls::Result<Session> {
58 Ok(self.clone())
59 }
60
61 async fn crypto_provider(&self) -> mls::Result<MlsCryptoProvider> {
62 Ok(self.crypto_provider.clone())
63 }
64}
65
66#[derive(Clone, Debug)]
67pub(crate) struct SessionInner {
68 id: ClientId,
69 pub(crate) identities: Identities,
70}
71
72impl Session {
73 pub async fn try_new(database: &Database) -> crate::mls::Result<Self> {
81 let database = database.to_owned();
83 let mls_backend = MlsCryptoProvider::new(database);
85
86 let session = Self {
90 crypto_provider: mls_backend,
91 inner: Default::default(),
92 transport: Arc::new(None.into()),
93 epoch_observer: Arc::new(None.into()),
94 history_observer: Arc::new(None.into()),
95 };
96
97 let cc = CoreCrypto::from(session);
98 let context = cc
99 .new_transaction()
100 .await
101 .map_err(RecursiveError::transaction("starting new transaction"))?;
102
103 context
104 .init_pki_env()
105 .await
106 .map_err(RecursiveError::transaction("initializing pki environment"))?;
107 context
108 .finish()
109 .await
110 .map_err(RecursiveError::transaction("finishing transaction"))?;
111
112 Ok(cc.mls)
113 }
114
115 pub async fn provide_transport(&self, transport: Arc<dyn MlsTransport>) {
118 self.transport.write().await.replace(transport);
119 }
120
121 pub async fn init(&self, identifier: ClientIdentifier, signature_schemes: &[SignatureScheme]) -> Result<()> {
127 self.ensure_unready().await?;
128 let client_id = identifier.get_id()?.into_owned();
129
130 let mut credential_refs = CredentialRef::find(
139 &self.crypto_provider.keystore(),
140 CredentialFindFilters::builder().client_id(&client_id).build(),
141 )
142 .await
143 .map_err(RecursiveError::mls_credential_ref(
144 "loading matching credential refs while initializing a client",
145 ))?;
146 credential_refs.retain(|credential_ref| signature_schemes.contains(&credential_ref.signature_scheme()));
147
148 let mut identities = Identities::new(credential_refs.len());
149 let credentials_cache = CredentialRef::load_stored_credentials(&self.crypto_provider.keystore())
150 .await
151 .map_err(RecursiveError::mls_credential_ref(
152 "loading credential ref cache while initializing session",
153 ))?;
154
155 for credential_ref in credential_refs {
156 if let Some(credential) =
157 credential_ref
158 .load_from_cache(&credentials_cache)
159 .map_err(RecursiveError::mls_credential_ref(
160 "loading credential list in session init",
161 ))?
162 {
163 match identities.push_credential(credential).await {
164 Err(Error::CredentialConflict) => {
165 }
168 Ok(_) => {}
169 Err(err) => {
170 return Err(RecursiveError::MlsClient {
171 context: "adding credential to identities in init",
172 source: Box::new(err),
173 }
174 .into());
175 }
176 }
177 }
178 }
179
180 self.replace_inner(SessionInner {
181 id: client_id,
182 identities,
183 })
184 .await;
185
186 Ok(())
187 }
188
189 #[cfg(test)]
191 pub(crate) async fn reset(&self) {
192 let mut inner_lock = self.inner.write().await;
193 *inner_lock = None;
194 }
195
196 pub(crate) async fn is_ready(&self) -> bool {
197 let inner_lock = self.inner.read().await;
198 inner_lock.is_some()
199 }
200
201 async fn ensure_unready(&self) -> Result<()> {
202 if self.is_ready().await {
203 Err(Error::UnexpectedlyReady)
204 } else {
205 Ok(())
206 }
207 }
208
209 async fn replace_inner(&self, new_inner: SessionInner) {
210 let mut inner_lock = self.inner.write().await;
211 *inner_lock = Some(new_inner);
212 }
213
214 pub async fn get_raw_conversation(&self, id: &ConversationIdRef) -> Result<ImmutableConversation> {
220 let raw_conversation = GroupStore::fetch_from_keystore(id, &self.crypto_provider.keystore(), None)
221 .await
222 .map_err(RecursiveError::root("getting conversation by id"))?
223 .ok_or_else(|| LeafError::ConversationNotFound(id.to_owned()))?;
224 Ok(ImmutableConversation::new(raw_conversation, self.clone()))
225 }
226
227 pub async fn public_key(
234 &self,
235 ciphersuite: Ciphersuite,
236 credential_type: CredentialType,
237 ) -> crate::mls::Result<Vec<u8>> {
238 let cb = self
239 .find_most_recent_credential(ciphersuite.signature_algorithm(), credential_type)
240 .await
241 .map_err(RecursiveError::mls_client("finding most recent credential"))?;
242 Ok(cb.signature_key_pair.to_public_vec())
243 }
244
245 pub async fn conversation_exists(&self, id: &ConversationIdRef) -> Result<bool> {
247 match self.get_raw_conversation(id).await {
248 Ok(_) => Ok(true),
249 Err(Error::Leaf(LeafError::ConversationNotFound(_))) => Ok(false),
250 Err(e) => Err(e),
251 }
252 }
253
254 pub fn random_bytes(&self, len: usize) -> crate::mls::Result<Vec<u8>> {
256 use openmls_traits::random::OpenMlsRand as _;
257 self.crypto_provider
258 .rand()
259 .random_vec(len)
260 .map_err(MlsError::wrap("generating random vector"))
261 .map_err(Into::into)
262 }
263
264 pub async fn close(&self) -> crate::mls::Result<()> {
270 self.crypto_provider
271 .close()
272 .await
273 .map_err(MlsError::wrap("closing connection with keystore"))
274 .map_err(Into::into)
275 }
276
277 pub async fn reseed(&self, seed: Option<EntropySeed>) -> crate::mls::Result<()> {
279 self.crypto_provider
280 .reseed(seed)
281 .map_err(MlsError::wrap("reseeding mls backend"))
282 .map_err(Into::into)
283 }
284
285 pub(crate) async fn restore_from_history_secret(&self, history_secret: HistorySecret) -> Result<()> {
287 self.ensure_unready().await?;
288
289 self.replace_inner(SessionInner {
291 id: history_secret.client_id.clone(),
292 identities: Identities::new(0),
293 })
294 .await;
295
296 history_secret
298 .key_package
299 .store(&self.crypto_provider)
300 .await
301 .map_err(MlsError::wrap("storing key package encapsulation"))?;
302
303 Ok(())
304 }
305
306 pub async fn id(&self) -> Result<ClientId> {
308 match &*self.inner.read().await {
309 None => Err(Error::MlsNotInitialized),
310 Some(SessionInner { id, .. }) => Ok(id.clone()),
311 }
312 }
313
314 pub async fn is_e2ei_capable(&self) -> bool {
316 match &*self.inner.read().await {
317 None => false,
318 Some(SessionInner { identities, .. }) => identities
319 .iter()
320 .any(|cred| cred.credential_type() == CredentialType::X509),
321 }
322 }
323}
324
#[cfg(test)]
mod tests {
    use core_crypto_keystore::{connection::FetchFromDatabase as _, entities::*};
    use mls_crypto_provider::MlsCryptoProvider;

    use super::*;
    use crate::{
        CertificateBundle, Credential, KeystoreError, test_utils::*, transaction_context::test_utils::EntitiesCount,
    };

    impl Session {
        // Silences `missing_docs` for these test-only helpers.
        #![allow(missing_docs)]

        /// Resets this session and re-initializes it with a random client id. A fresh
        /// credential of the case's credential type is then added to it.
        ///
        /// `signer` must be `Some` for X509 cases; it is used to issue the certificate.
        pub async fn random_generate(
            &self,
            case: &crate::test_utils::TestContext,
            signer: Option<&crate::test_utils::x509::X509Certificate>,
        ) -> Result<()> {
            self.reset().await;

            let user_uuid = uuid::Uuid::new_v4();
            let random_suffix = rand::random::<usize>();
            let client_id = ClientId(
                format!("{}:{random_suffix:x}@members.wire.com", user_uuid.hyphenated()).into_bytes(),
            );

            let (identifier, credential) = match case.credential_type {
                CredentialType::Basic => {
                    let identifier = ClientIdentifier::Basic(client_id.clone());
                    let credential =
                        Credential::basic(case.ciphersuite(), client_id, &self.crypto_provider).unwrap();
                    (identifier, credential)
                }
                CredentialType::X509 => {
                    let issuer = signer.expect("Missing intermediate CA").to_owned();
                    let cert = CertificateBundle::rand(&client_id, &issuer);
                    let identifier = ClientIdentifier::X509([(case.signature_scheme(), cert.clone())].into());
                    let credential = Credential::x509(case.ciphersuite(), cert).unwrap();
                    (identifier, credential)
                }
            };

            self.init(identifier, &[case.signature_scheme()]).await.unwrap();
            self.add_credential(credential).await.unwrap();

            Ok(())
        }

        /// Fetches up to `u32::MAX` MLS key packages from `backend`'s keystore.
        pub async fn find_keypackages(&self, backend: &MlsCryptoProvider) -> Result<Vec<openmls::prelude::KeyPackage>> {
            use core_crypto_keystore::CryptoKeystoreMls as _;

            backend
                .key_store()
                .mls_fetch_keypackages::<openmls::prelude::KeyPackage>(u32::MAX)
                .await
                .map_err(KeystoreError::wrap("fetching mls keypackages"))
                .map_err(Into::into)
        }

        /// Counts the stored keystore entities of each kind listed in [`EntitiesCount`].
        pub async fn count_entities(&self) -> EntitiesCount {
            let keystore = self.crypto_provider.keystore();
            // Field initializers are evaluated top to bottom, i.e. one count after another.
            EntitiesCount {
                credential: keystore.count::<StoredCredential>().await.unwrap(),
                encryption_keypair: keystore.count::<StoredEncryptionKeyPair>().await.unwrap(),
                epoch_encryption_keypair: keystore.count::<StoredEpochEncryptionKeypair>().await.unwrap(),
                enrollment: keystore.count::<StoredE2eiEnrollment>().await.unwrap(),
                group: keystore.count::<PersistedMlsGroup>().await.unwrap(),
                hpke_private_key: keystore.count::<StoredHpkePrivateKey>().await.unwrap(),
                key_package: keystore.count::<StoredKeypackage>().await.unwrap(),
                pending_group: keystore.count::<PersistedMlsPendingGroup>().await.unwrap(),
                pending_messages: keystore.count::<MlsPendingMessage>().await.unwrap(),
                psk_bundle: keystore.count::<StoredPskBundle>().await.unwrap(),
            }
        }
    }

    #[apply(all_cred_cipher)]
    async fn can_generate_session(mut case: TestContext) {
        let [alice] = case.sessions().await;
        let key_store = case.create_in_memory_database().await;
        let backend = MlsCryptoProvider::new(key_store);

        // X509 cases need a test PKI chain registered with the backend.
        let mut x509_test_chain = None;
        if case.is_x509() {
            let chain = crate::test_utils::x509::X509TestChain::init_empty(case.signature_scheme());
            chain.register_with_provider(&backend).await;
            x509_test_chain = Some(chain);
        }

        backend.new_transaction().await.unwrap();

        let signer = x509_test_chain
            .as_ref()
            .map(|chain| chain.find_local_intermediate_ca());
        let session = alice.session().await;
        session.random_generate(&case, signer).await.unwrap();
    }
}