Skip to main content

core_crypto/mls/session/
mod.rs

1mod credential;
2pub(crate) mod e2e_identity;
3mod epoch_observer;
4mod error;
5mod history_observer;
6pub(crate) mod id;
7pub(crate) mod identifier;
8pub(crate) mod key_package;
9pub(crate) mod user_id;
10
11use std::sync::Arc;
12
13use async_lock::{Mutex, RwLock};
14pub use epoch_observer::EpochObserver;
15pub(crate) use error::{Error, Result};
16pub use history_observer::HistoryObserver;
17use openmls_traits::OpenMlsCryptoProvider;
18
19use crate::{
20    ClientId, HistorySecret, ImmutableDatabase, LeafError, MlsTransport, OpenMlsError, RecursiveError,
21    mls::{
22        conversation::{Conversation, ConversationIdRef},
23        conversation_cache::ConversationCache,
24    },
25    mls_provider::{CryptoProvider, EntropySeed},
26};
27
28/// A MLS Session enables a user device to communicate via the MLS protocol.
29///
30/// This closely maps to the `Client` term in [RFC 9720], but we avoid that term to avoid ambiguity;
31/// `Client` is very overloaded with distinct meanings.
32///
33/// There is one `Session` per user per device. A session can contain many MLS groups/conversations.
34///
35/// It is cheap to clone a `Session` because everything heavy is wrapped inside an [Arc].
36///
37/// [RFC 9720]: https://www.rfc-editor.org/rfc/rfc9420.html
38#[derive(Clone, derive_more::Debug)]
39pub struct Session {
40    id: ClientId,
41    pub(crate) crypto_provider: CryptoProvider,
42    pub(crate) transport: Arc<dyn MlsTransport + 'static>,
43    database: ImmutableDatabase,
44    #[debug("EpochObserver")]
45    pub(crate) epoch_observer: Arc<RwLock<Option<Arc<dyn EpochObserver + 'static>>>>,
46    #[debug("HistoryObserver")]
47    pub(crate) history_observer: Arc<RwLock<Option<Arc<dyn HistoryObserver + 'static>>>>,
48    /// LRU cache of live MLS conversations.
49    ///
50    /// Shared across transactions for cache reuse;
51    /// cleared on transaction rollback to avoid serving stale state.
52    pub(crate) conversation_cache: Arc<Mutex<ConversationCache>>,
53}
54
55impl Session {
56    /// Create a new `Session`
57    pub fn new(
58        id: ClientId,
59        crypto_provider: CryptoProvider,
60        database: ImmutableDatabase,
61        transport: Arc<dyn MlsTransport>,
62    ) -> Self {
63        Self {
64            id,
65            crypto_provider,
66            transport,
67            database,
68            epoch_observer: Arc::new(RwLock::new(None)),
69            history_observer: Arc::new(RwLock::new(None)),
70            conversation_cache: Arc::new(Mutex::new(ConversationCache::new())),
71        }
72    }
73
74    /// Get an immutable view of an MLS conversation.
75    ///
76    /// This may be faster than
77    /// [crate::transaction_context::TransactionContext::conversation].
78    pub async fn get_raw_conversation(&self, id: &ConversationIdRef) -> Result<Conversation> {
79        Conversation::load(self.clone(), id)
80            .await
81            .map_err(RecursiveError::mls_conversation("getting raw conversation by id"))?
82            .ok_or_else(|| LeafError::ConversationNotFound(id.to_owned()))
83            .map_err(Into::into)
84    }
85
86    /// Checks if a given conversation id exists locally
87    pub async fn conversation_exists(&self, id: &ConversationIdRef) -> Result<bool> {
88        match self.get_raw_conversation(id).await {
89            Ok(_) => Ok(true),
90            Err(Error::Leaf(LeafError::ConversationNotFound(_))) => Ok(false),
91            Err(e) => Err(e),
92        }
93    }
94
95    /// Generates a random byte array of the specified size
96    pub fn random_bytes(&self, len: usize) -> crate::mls::Result<Vec<u8>> {
97        use openmls_traits::random::OpenMlsRand as _;
98        self.crypto_provider
99            .rand()
100            .random_vec(len)
101            .map_err(OpenMlsError::wrap("generating random vector"))
102            .map_err(Into::into)
103    }
104
105    /// Waits for running transactions to finish, then closes the connection with the local KeyStore.
106    ///
107    /// # Errors
108    /// KeyStore errors, such as IO, and if there is more than one strong reference
109    /// to the connection.
110    pub async fn close(&self) -> crate::mls::Result<()> {
111        self.crypto_provider
112            .close()
113            .await
114            .map_err(OpenMlsError::wrap("closing connection with keystore"))
115            .map_err(Into::into)
116    }
117
118    /// Get read-only access to the database.
119    pub fn database(&self) -> &ImmutableDatabase {
120        &self.database
121    }
122
123    /// see [crate::mls_provider::CryptoProvider::reseed]
124    pub async fn reseed(&self, seed: Option<EntropySeed>) -> crate::mls::Result<()> {
125        self.crypto_provider
126            .reseed(seed)
127            .map_err(OpenMlsError::wrap("reseeding mls backend"))
128            .map_err(Into::into)
129    }
130
131    /// Restore from an external [`HistorySecret`].
132    pub(crate) async fn restore_from_history_secret(&self, history_secret: HistorySecret) -> Result<()> {
133        // store the key package
134        history_secret
135            .key_package
136            .store(&self.crypto_provider)
137            .await
138            .map_err(OpenMlsError::wrap("storing key package encapsulation"))?;
139
140        Ok(())
141    }
142
143    /// Retrieves the client's client id. This is free-form and not inspected.
144    pub fn id(&self) -> ClientId {
145        self.id.clone()
146    }
147}
148
#[cfg(test)]
mod tests {
    use core_crypto_keystore::{entities::*, traits::FetchFromDatabase};

    use super::*;
    use crate::{KeystoreError, mls_provider::CryptoProvider, transaction_context::test_utils::EntitiesCount};

    impl Session {
        // Test-only helpers are not held to the crate's documentation standard.
        #![allow(missing_docs)]

        /// Fetch the MLS key packages held in `backend`'s keystore (up to `u32::MAX` of them).
        ///
        /// Note: reads from `backend`, not from `self`.
        pub async fn find_keypackages(&self, backend: &CryptoProvider) -> Result<Vec<openmls::prelude::KeyPackage>> {
            use core_crypto_keystore::CryptoKeystoreMls as _;
            backend
                .key_store()
                .mls_fetch_key_packages::<openmls::prelude::KeyPackage>(u32::MAX)
                .await
                .map_err(KeystoreError::wrap("fetching mls keypackages"))
                .map_err(Into::into)
        }

        /// Count the stored entities of each kind tracked by [`EntitiesCount`].
        ///
        /// # Panics
        /// Panics if any of the keystore count queries fails.
        pub async fn count_entities(&self) -> EntitiesCount {
            let db = &self.database;
            // Struct-literal fields are evaluated in source order, so the queries
            // still run one after another in this exact sequence.
            EntitiesCount {
                credential: db.count::<StoredCredential>().await.unwrap(),
                encryption_keypair: db.count::<StoredEncryptionKeyPair>().await.unwrap(),
                epoch_encryption_keypair: db.count::<StoredEpochEncryptionKeypair>().await.unwrap(),
                enrollment: db.count::<StoredE2eiEnrollment>().await.unwrap(),
                group: db.count::<PersistedMlsGroup>().await.unwrap(),
                hpke_private_key: db.count::<StoredHpkePrivateKey>().await.unwrap(),
                key_package: db.count::<StoredKeypackage>().await.unwrap(),
                pending_group: db.count::<PersistedMlsPendingGroup>().await.unwrap(),
                pending_messages: db.count::<MlsPendingMessage>().await.unwrap(),
                psk_bundle: db.count::<StoredPskBundle>().await.unwrap(),
            }
        }
    }
}