core_crypto/mls/session/
mod.rs1mod credential;
2pub(crate) mod e2e_identity;
3mod epoch_observer;
4mod error;
5mod history_observer;
6pub(crate) mod id;
7pub(crate) mod identifier;
8pub(crate) mod key_package;
9pub(crate) mod user_id;
10
11use std::sync::Arc;
12
13use async_lock::{Mutex, RwLock};
14pub use epoch_observer::EpochObserver;
15pub(crate) use error::{Error, Result};
16pub use history_observer::HistoryObserver;
17use openmls_traits::OpenMlsCryptoProvider;
18
19use crate::{
20 ClientId, HistorySecret, ImmutableDatabase, LeafError, MlsTransport, OpenMlsError, RecursiveError,
21 mls::{
22 conversation::{Conversation, ConversationIdRef},
23 conversation_cache::ConversationCache,
24 },
25 mls_provider::{CryptoProvider, EntropySeed},
26};
27
28#[derive(Clone, derive_more::Debug)]
39pub struct Session {
40 id: ClientId,
41 pub(crate) crypto_provider: CryptoProvider,
42 pub(crate) transport: Arc<dyn MlsTransport + 'static>,
43 database: ImmutableDatabase,
44 #[debug("EpochObserver")]
45 pub(crate) epoch_observer: Arc<RwLock<Option<Arc<dyn EpochObserver + 'static>>>>,
46 #[debug("HistoryObserver")]
47 pub(crate) history_observer: Arc<RwLock<Option<Arc<dyn HistoryObserver + 'static>>>>,
48 pub(crate) conversation_cache: Arc<Mutex<ConversationCache>>,
53}
54
55impl Session {
56 pub fn new(
58 id: ClientId,
59 crypto_provider: CryptoProvider,
60 database: ImmutableDatabase,
61 transport: Arc<dyn MlsTransport>,
62 ) -> Self {
63 Self {
64 id,
65 crypto_provider,
66 transport,
67 database,
68 epoch_observer: Arc::new(RwLock::new(None)),
69 history_observer: Arc::new(RwLock::new(None)),
70 conversation_cache: Arc::new(Mutex::new(ConversationCache::new())),
71 }
72 }
73
74 pub async fn get_raw_conversation(&self, id: &ConversationIdRef) -> Result<Conversation> {
79 Conversation::load(self.clone(), id)
80 .await
81 .map_err(RecursiveError::mls_conversation("getting raw conversation by id"))?
82 .ok_or_else(|| LeafError::ConversationNotFound(id.to_owned()))
83 .map_err(Into::into)
84 }
85
86 pub async fn conversation_exists(&self, id: &ConversationIdRef) -> Result<bool> {
88 match self.get_raw_conversation(id).await {
89 Ok(_) => Ok(true),
90 Err(Error::Leaf(LeafError::ConversationNotFound(_))) => Ok(false),
91 Err(e) => Err(e),
92 }
93 }
94
95 pub fn random_bytes(&self, len: usize) -> crate::mls::Result<Vec<u8>> {
97 use openmls_traits::random::OpenMlsRand as _;
98 self.crypto_provider
99 .rand()
100 .random_vec(len)
101 .map_err(OpenMlsError::wrap("generating random vector"))
102 .map_err(Into::into)
103 }
104
105 pub async fn close(&self) -> crate::mls::Result<()> {
111 self.crypto_provider
112 .close()
113 .await
114 .map_err(OpenMlsError::wrap("closing connection with keystore"))
115 .map_err(Into::into)
116 }
117
118 pub fn database(&self) -> &ImmutableDatabase {
120 &self.database
121 }
122
123 pub async fn reseed(&self, seed: Option<EntropySeed>) -> crate::mls::Result<()> {
125 self.crypto_provider
126 .reseed(seed)
127 .map_err(OpenMlsError::wrap("reseeding mls backend"))
128 .map_err(Into::into)
129 }
130
131 pub(crate) async fn restore_from_history_secret(&self, history_secret: HistorySecret) -> Result<()> {
133 history_secret
135 .key_package
136 .store(&self.crypto_provider)
137 .await
138 .map_err(OpenMlsError::wrap("storing key package encapsulation"))?;
139
140 Ok(())
141 }
142
143 pub fn id(&self) -> ClientId {
145 self.id.clone()
146 }
147}
148
#[cfg(test)]
mod tests {
    use core_crypto_keystore::{entities::*, traits::FetchFromDatabase};

    use super::*;
    use crate::{KeystoreError, mls_provider::CryptoProvider, transaction_context::test_utils::EntitiesCount};

    /// Test-only helpers for inspecting what a `Session` has persisted.
    impl Session {
        #![allow(missing_docs)]

        /// Fetches MLS key packages from `backend`'s key store, passing
        /// `u32::MAX` as the count limit. `self` is not consulted.
        pub async fn find_keypackages(&self, backend: &CryptoProvider) -> Result<Vec<openmls::prelude::KeyPackage>> {
            use core_crypto_keystore::CryptoKeystoreMls as _;

            backend
                .key_store()
                .mls_fetch_key_packages::<openmls::prelude::KeyPackage>(u32::MAX)
                .await
                .map_err(KeystoreError::wrap("fetching mls keypackages"))
                .map_err(Into::into)
        }

        /// Counts the stored rows of each tracked entity type in this session's
        /// database. Panics if any count query fails (test helper).
        pub async fn count_entities(&self) -> EntitiesCount {
            let db = &self.database;
            // Struct-literal fields are evaluated in the order written, so the
            // queries run one after another in this exact sequence.
            EntitiesCount {
                credential: db.count::<StoredCredential>().await.unwrap(),
                encryption_keypair: db.count::<StoredEncryptionKeyPair>().await.unwrap(),
                epoch_encryption_keypair: db.count::<StoredEpochEncryptionKeypair>().await.unwrap(),
                enrollment: db.count::<StoredE2eiEnrollment>().await.unwrap(),
                group: db.count::<PersistedMlsGroup>().await.unwrap(),
                hpke_private_key: db.count::<StoredHpkePrivateKey>().await.unwrap(),
                key_package: db.count::<StoredKeypackage>().await.unwrap(),
                pending_group: db.count::<PersistedMlsPendingGroup>().await.unwrap(),
                pending_messages: db.count::<MlsPendingMessage>().await.unwrap(),
                psk_bundle: db.count::<StoredPskBundle>().await.unwrap(),
            }
        }
    }
}