core_crypto/mls/session/
mod.rs

1mod credential;
2pub(crate) mod e2e_identity;
3mod epoch_observer;
4mod error;
5mod history_observer;
6pub(crate) mod id;
7pub(crate) mod identifier;
8pub(crate) mod identities;
9pub(crate) mod key_package;
10pub(crate) mod user_id;
11
12use std::sync::Arc;
13
14use async_lock::RwLock;
15pub use epoch_observer::EpochObserver;
16pub(crate) use error::{Error, Result};
17pub use history_observer::HistoryObserver;
18use identities::Identities;
19use mls_crypto_provider::{EntropySeed, MlsCryptoProvider};
20use openmls_traits::OpenMlsCryptoProvider;
21
22use crate::{
23    Ciphersuite, ClientId, CredentialType, HistorySecret, LeafError, MlsConversation, MlsError, MlsTransport,
24    RecursiveError,
25    mls::{
26        self, HasSessionAndCrypto,
27        conversation::{ConversationIdRef, ImmutableConversation},
28    },
29};
30
31/// A MLS Session enables a user device to communicate via the MLS protocol.
32///
33/// This closely maps to the `Client` term in [RFC 9720], but we avoid that term to avoid ambiguity;
34/// `Client` is very overloaded with distinct meanings.
35///
36/// There is one `Session` per user per device. A session can contain many MLS groups/conversations.
37///
38/// It is cheap to clone a `Session` because everything heavy is wrapped inside an [Arc].
39///
40/// [RFC 9720]: https://www.rfc-editor.org/rfc/rfc9420.html
41#[derive(Clone, derive_more::Debug)]
42pub struct Session {
43    id: ClientId,
44    identities: Arc<RwLock<Identities>>,
45    pub(crate) crypto_provider: MlsCryptoProvider,
46    pub(crate) transport: Arc<dyn MlsTransport + 'static>,
47    #[debug("EpochObserver")]
48    pub(crate) epoch_observer: Arc<RwLock<Option<Arc<dyn EpochObserver + 'static>>>>,
49    #[debug("HistoryObserver")]
50    pub(crate) history_observer: Arc<RwLock<Option<Arc<dyn HistoryObserver + 'static>>>>,
51}
52
53#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
54#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
55impl HasSessionAndCrypto for Session {
56    async fn session(&self) -> mls::Result<Session> {
57        Ok(self.clone())
58    }
59
60    async fn crypto_provider(&self) -> mls::Result<MlsCryptoProvider> {
61        Ok(self.crypto_provider.clone())
62    }
63}
64
65impl Session {
66    /// Create a new `Session`
67    pub fn new(
68        id: ClientId,
69        identities: Identities,
70        crypto_provider: MlsCryptoProvider,
71        transport: Arc<dyn MlsTransport>,
72    ) -> Self {
73        Self {
74            id,
75            identities: Arc::new(RwLock::new(identities)),
76            crypto_provider,
77            transport,
78            epoch_observer: Arc::new(RwLock::new(None)),
79            history_observer: Arc::new(RwLock::new(None)),
80        }
81    }
82
83    /// Get an immutable view of an `MlsConversation`.
84    ///
85    /// Because it operates on the raw conversation type, this may be faster than
86    /// [crate::transaction_context::TransactionContext::conversation] for transient and immutable
87    /// purposes. For long-lived or mutable purposes, prefer the other method.
88    pub async fn get_raw_conversation(&self, id: &ConversationIdRef) -> Result<ImmutableConversation> {
89        let raw_conversation = MlsConversation::load(&self.crypto_provider.keystore(), id)
90            .await
91            .map_err(RecursiveError::mls_conversation("getting raw conversation by id"))?
92            .ok_or_else(|| LeafError::ConversationNotFound(id.to_owned()))?;
93        Ok(ImmutableConversation::new(raw_conversation, self.clone()))
94    }
95
96    /// Returns the client's most recent public signature key as a buffer.
97    /// Used to upload a public key to the server in order to verify client's messages signature.
98    ///
99    /// # Arguments
100    /// * `ciphersuite` - a callback to be called to perform authorization
101    /// * `credential_type` - of the credential to look for
102    pub async fn public_key(
103        &self,
104        ciphersuite: Ciphersuite,
105        credential_type: CredentialType,
106    ) -> crate::mls::Result<Vec<u8>> {
107        let cb = self
108            .find_most_recent_credential(ciphersuite.signature_algorithm(), credential_type)
109            .await
110            .map_err(RecursiveError::mls_client("finding most recent credential"))?;
111        Ok(cb.signature_key_pair.to_public_vec())
112    }
113
114    /// Checks if a given conversation id exists locally
115    pub async fn conversation_exists(&self, id: &ConversationIdRef) -> Result<bool> {
116        match self.get_raw_conversation(id).await {
117            Ok(_) => Ok(true),
118            Err(Error::Leaf(LeafError::ConversationNotFound(_))) => Ok(false),
119            Err(e) => Err(e),
120        }
121    }
122
123    /// Generates a random byte array of the specified size
124    pub fn random_bytes(&self, len: usize) -> crate::mls::Result<Vec<u8>> {
125        use openmls_traits::random::OpenMlsRand as _;
126        self.crypto_provider
127            .rand()
128            .random_vec(len)
129            .map_err(MlsError::wrap("generating random vector"))
130            .map_err(Into::into)
131    }
132
133    /// Waits for running transactions to finish, then closes the connection with the local KeyStore.
134    ///
135    /// # Errors
136    /// KeyStore errors, such as IO, and if there is more than one strong reference
137    /// to the connection.
138    pub async fn close(&self) -> crate::mls::Result<()> {
139        self.crypto_provider
140            .close()
141            .await
142            .map_err(MlsError::wrap("closing connection with keystore"))
143            .map_err(Into::into)
144    }
145
146    /// see [mls_crypto_provider::MlsCryptoProvider::reseed]
147    pub async fn reseed(&self, seed: Option<EntropySeed>) -> crate::mls::Result<()> {
148        self.crypto_provider
149            .reseed(seed)
150            .map_err(MlsError::wrap("reseeding mls backend"))
151            .map_err(Into::into)
152    }
153
154    /// Restore from an external [`HistorySecret`].
155    pub(crate) async fn restore_from_history_secret(&self, history_secret: HistorySecret) -> Result<()> {
156        // store the key package
157        history_secret
158            .key_package
159            .store(&self.crypto_provider)
160            .await
161            .map_err(MlsError::wrap("storing key package encapsulation"))?;
162
163        Ok(())
164    }
165
166    /// Retrieves the client's client id. This is free-form and not inspected.
167    pub fn id(&self) -> ClientId {
168        self.id.clone()
169    }
170
171    /// Returns whether this client is E2EI capable
172    pub async fn is_e2ei_capable(&self) -> bool {
173        self.identities
174            .read()
175            .await
176            .iter()
177            .any(|cred| cred.credential_type() == CredentialType::X509)
178    }
179}
180
181#[cfg(test)]
182mod tests {
183    use core_crypto_keystore::{entities::*, traits::FetchFromDatabase as _};
184    use mls_crypto_provider::MlsCryptoProvider;
185
186    use super::*;
187    use crate::{KeystoreError, test_utils::*, transaction_context::test_utils::EntitiesCount};
188
189    impl Session {
190        // test functions are not held to the same documentation standard as proper functions
191        #![allow(missing_docs)]
192
193        pub async fn identities(&self) -> Identities {
194            self.identities.read().await.clone()
195        }
196
197        pub async fn find_keypackages(&self, backend: &MlsCryptoProvider) -> Result<Vec<openmls::prelude::KeyPackage>> {
198            use core_crypto_keystore::CryptoKeystoreMls as _;
199            let kps = backend
200                .key_store()
201                .mls_fetch_keypackages::<openmls::prelude::KeyPackage>(u32::MAX)
202                .await
203                .map_err(KeystoreError::wrap("fetching mls keypackages"))?;
204            Ok(kps)
205        }
206
207        /// Count the entities
208        pub async fn count_entities(&self) -> EntitiesCount {
209            let keystore = self.crypto_provider.keystore();
210            let credential = keystore.count::<StoredCredential>().await.unwrap();
211            let encryption_keypair = keystore.count::<StoredEncryptionKeyPair>().await.unwrap();
212            let epoch_encryption_keypair = keystore.count::<StoredEpochEncryptionKeypair>().await.unwrap();
213            let enrollment = keystore.count::<StoredE2eiEnrollment>().await.unwrap();
214            let group = keystore.count::<PersistedMlsGroup>().await.unwrap();
215            let hpke_private_key = keystore.count::<StoredHpkePrivateKey>().await.unwrap();
216            let key_package = keystore.count::<StoredKeypackage>().await.unwrap();
217            let pending_group = keystore.count::<PersistedMlsPendingGroup>().await.unwrap();
218            let pending_messages = keystore.count::<MlsPendingMessage>().await.unwrap();
219            let psk_bundle = keystore.count::<StoredPskBundle>().await.unwrap();
220            EntitiesCount {
221                credential,
222                encryption_keypair,
223                epoch_encryption_keypair,
224                enrollment,
225                group,
226                hpke_private_key,
227                key_package,
228                pending_group,
229                pending_messages,
230                psk_bundle,
231            }
232        }
233    }
234
235    #[apply(all_cred_cipher)]
236    async fn can_generate_session(mut case: TestContext) {
237        let [alice] = case.sessions().await;
238        let key_store = case.create_in_memory_database().await;
239        let backend = MlsCryptoProvider::new(key_store);
240        let x509_test_chain = if case.is_x509() {
241            let x509_test_chain = crate::test_utils::x509::X509TestChain::init_empty(case.signature_scheme());
242            x509_test_chain.register_with_provider(&backend).await;
243            Some(x509_test_chain)
244        } else {
245            None
246        };
247        backend.new_transaction().await.unwrap();
248        alice
249            .random_generate(
250                &case,
251                x509_test_chain.as_ref().map(|chain| chain.find_local_intermediate_ca()),
252            )
253            .await
254            .unwrap();
255    }
256}