core_crypto/mls/session/
mod.rs

1mod credential;
2pub(crate) mod e2e_identity;
3mod epoch_observer;
4mod error;
5mod history_observer;
6pub(crate) mod id;
7pub(crate) mod identifier;
8pub(crate) mod key_package;
9pub(crate) mod user_id;
10
11use std::sync::Arc;
12
13use async_lock::RwLock;
14use core_crypto_keystore::traits::FetchFromDatabase;
15pub use epoch_observer::EpochObserver;
16pub(crate) use error::{Error, Result};
17pub use history_observer::HistoryObserver;
18use openmls_traits::OpenMlsCryptoProvider;
19
20use crate::{
21    ClientId, HistorySecret, LeafError, MlsConversation, MlsError, MlsTransport, RecursiveError,
22    mls::{
23        self, HasSessionAndCrypto,
24        conversation::{ConversationIdRef, ImmutableConversation},
25    },
26    mls_provider::{EntropySeed, MlsCryptoProvider},
27};
28
29/// A MLS Session enables a user device to communicate via the MLS protocol.
30///
31/// This closely maps to the `Client` term in [RFC 9720], but we avoid that term to avoid ambiguity;
32/// `Client` is very overloaded with distinct meanings.
33///
34/// There is one `Session` per user per device. A session can contain many MLS groups/conversations.
35///
36/// It is cheap to clone a `Session` because everything heavy is wrapped inside an [Arc].
37///
38/// [RFC 9720]: https://www.rfc-editor.org/rfc/rfc9420.html
39///
40/// ## Why does the session have a generic parameter?
41///
42/// The reason is to ensure at compile time that from inside the session, we don't have the full interface of the
43/// database, just the read-only one, so we don't have to remember that the session should only have API that doesn't
44/// require writing to the DB – it won't compile if we forget and try to do that. All API requiring to write to the DB
45/// should live on the transaction context.
46///
47/// Ideally, we'd like to have the database as an `Arc<dyn FetchFromDatabase>>` field on the session. However, we'd
48/// need to refactor the FetchFromDatabase trait, and thus the Entity trait to be dyn-compatible, which would require us
49/// to rewrite the entire keystore crate.
50#[derive(Clone, derive_more::Debug)]
51pub struct Session<D> {
52    id: ClientId,
53    pub(crate) crypto_provider: MlsCryptoProvider,
54    pub(crate) transport: Arc<dyn MlsTransport + 'static>,
55    database: D,
56    #[debug("EpochObserver")]
57    pub(crate) epoch_observer: Arc<RwLock<Option<Arc<dyn EpochObserver + 'static>>>>,
58    #[debug("HistoryObserver")]
59    pub(crate) history_observer: Arc<RwLock<Option<Arc<dyn HistoryObserver + 'static>>>>,
60}
61
62#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
63#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
64impl HasSessionAndCrypto for Session<core_crypto_keystore::Database> {
65    async fn session(&self) -> mls::Result<Session<core_crypto_keystore::Database>> {
66        Ok(self.clone())
67    }
68
69    async fn crypto_provider(&self) -> mls::Result<MlsCryptoProvider> {
70        Ok(self.crypto_provider.clone())
71    }
72}
73
74impl<D: FetchFromDatabase> Session<D> {
75    /// Create a new `Session`
76    pub fn new(
77        id: ClientId,
78        crypto_provider: MlsCryptoProvider,
79        database: D,
80        transport: Arc<dyn MlsTransport>,
81    ) -> Self {
82        Self {
83            id,
84            crypto_provider,
85            transport,
86            database,
87            epoch_observer: Arc::new(RwLock::new(None)),
88            history_observer: Arc::new(RwLock::new(None)),
89        }
90    }
91
92    /// Get an immutable view of an `MlsConversation`.
93    ///
94    /// Because it operates on the raw conversation type, this may be faster than
95    /// [crate::transaction_context::TransactionContext::conversation] for transient and immutable
96    /// purposes. For long-lived or mutable purposes, prefer the other method.
97    pub async fn get_raw_conversation(&self, id: &ConversationIdRef) -> Result<ImmutableConversation<D>>
98    where
99        D: FetchFromDatabase + Clone,
100    {
101        let raw_conversation = MlsConversation::load(&self.database, id)
102            .await
103            .map_err(RecursiveError::mls_conversation("getting raw conversation by id"))?
104            .ok_or_else(|| LeafError::ConversationNotFound(id.to_owned()))?;
105        Ok(ImmutableConversation::new(raw_conversation, self.clone()))
106    }
107
108    /// Checks if a given conversation id exists locally
109    pub async fn conversation_exists(&self, id: &ConversationIdRef) -> Result<bool>
110    where
111        D: Clone + FetchFromDatabase,
112    {
113        match self.get_raw_conversation(id).await {
114            Ok(_) => Ok(true),
115            Err(Error::Leaf(LeafError::ConversationNotFound(_))) => Ok(false),
116            Err(e) => Err(e),
117        }
118    }
119
120    /// Generates a random byte array of the specified size
121    pub fn random_bytes(&self, len: usize) -> crate::mls::Result<Vec<u8>> {
122        use openmls_traits::random::OpenMlsRand as _;
123        self.crypto_provider
124            .rand()
125            .random_vec(len)
126            .map_err(MlsError::wrap("generating random vector"))
127            .map_err(Into::into)
128    }
129
130    /// Waits for running transactions to finish, then closes the connection with the local KeyStore.
131    ///
132    /// # Errors
133    /// KeyStore errors, such as IO, and if there is more than one strong reference
134    /// to the connection.
135    pub async fn close(&self) -> crate::mls::Result<()> {
136        self.crypto_provider
137            .close()
138            .await
139            .map_err(MlsError::wrap("closing connection with keystore"))
140            .map_err(Into::into)
141    }
142
143    /// Get read-only access to the database.
144    pub fn database(&self) -> &impl FetchFromDatabase {
145        &self.database
146    }
147
148    /// see [crate::mls_provider::MlsCryptoProvider::reseed]
149    pub async fn reseed(&self, seed: Option<EntropySeed>) -> crate::mls::Result<()> {
150        self.crypto_provider
151            .reseed(seed)
152            .map_err(MlsError::wrap("reseeding mls backend"))
153            .map_err(Into::into)
154    }
155
156    /// Restore from an external [`HistorySecret`].
157    pub(crate) async fn restore_from_history_secret(&self, history_secret: HistorySecret) -> Result<()> {
158        // store the key package
159        history_secret
160            .key_package
161            .store(&self.crypto_provider)
162            .await
163            .map_err(MlsError::wrap("storing key package encapsulation"))?;
164
165        Ok(())
166    }
167
168    /// Retrieves the client's client id. This is free-form and not inspected.
169    pub fn id(&self) -> ClientId {
170        self.id.clone()
171    }
172}
173
#[cfg(test)]
mod tests {
    use core_crypto_keystore::{entities::*, traits::FetchFromDatabase};

    use super::*;
    use crate::{
        KeystoreError, mls_provider::MlsCryptoProvider, test_utils::*, transaction_context::test_utils::EntitiesCount,
    };

    impl<D: FetchFromDatabase> Session<D> {
        // test functions are not held to the same documentation standard as proper functions
        #![allow(missing_docs)]

        /// Fetch the key packages held by `backend`'s keystore (`u32::MAX` as the fetch cap).
        pub async fn find_keypackages(&self, backend: &MlsCryptoProvider) -> Result<Vec<openmls::prelude::KeyPackage>> {
            use core_crypto_keystore::CryptoKeystoreMls as _;

            let store = backend.key_store();
            let key_packages = store
                .mls_fetch_keypackages::<openmls::prelude::KeyPackage>(u32::MAX)
                .await
                .map_err(KeystoreError::wrap("fetching mls keypackages"))?;
            Ok(key_packages)
        }

        /// Count the stored entities of each type in this session's database.
        pub async fn count_entities(&self) -> EntitiesCount {
            let db = &self.database;
            // Field initializers are evaluated in the order written, so the counts run sequentially.
            EntitiesCount {
                credential: db.count::<StoredCredential>().await.unwrap(),
                encryption_keypair: db.count::<StoredEncryptionKeyPair>().await.unwrap(),
                epoch_encryption_keypair: db.count::<StoredEpochEncryptionKeypair>().await.unwrap(),
                enrollment: db.count::<StoredE2eiEnrollment>().await.unwrap(),
                group: db.count::<PersistedMlsGroup>().await.unwrap(),
                hpke_private_key: db.count::<StoredHpkePrivateKey>().await.unwrap(),
                key_package: db.count::<StoredKeypackage>().await.unwrap(),
                pending_group: db.count::<PersistedMlsPendingGroup>().await.unwrap(),
                pending_messages: db.count::<MlsPendingMessage>().await.unwrap(),
                psk_bundle: db.count::<StoredPskBundle>().await.unwrap(),
            }
        }
    }

    #[apply(all_cred_cipher)]
    async fn can_generate_session(mut case: TestContext) {
        let [alice] = case.sessions().await;
        let backend = MlsCryptoProvider::new(case.create_in_memory_database().await);

        // Only X.509 cases need a test PKI; register it with the fresh backend.
        let mut x509_test_chain = None;
        if case.is_x509() {
            let chain = crate::test_utils::x509::X509TestChain::init_empty(case.signature_scheme());
            chain.register_with_provider(&backend).await;
            x509_test_chain = Some(chain);
        }

        backend.new_transaction().await.unwrap();
        let intermediate_ca = x509_test_chain.as_ref().map(|chain| chain.find_local_intermediate_ca());
        alice.random_generate(&case, intermediate_ca).await.unwrap();
    }
}