core_crypto/mls/session/
mod.rs

1mod credential;
2pub(crate) mod e2e_identity;
3mod epoch_observer;
4mod error;
5mod history_observer;
6pub(crate) mod id;
7pub(crate) mod identifier;
8pub(crate) mod identities;
9pub(crate) mod key_package;
10pub(crate) mod user_id;
11
12use std::sync::Arc;
13
14use async_lock::RwLock;
15use core_crypto_keystore::Database;
16pub use epoch_observer::EpochObserver;
17pub(crate) use error::{Error, Result};
18pub use history_observer::HistoryObserver;
19use identities::Identities;
20use mls_crypto_provider::{EntropySeed, MlsCryptoProvider};
21use openmls_traits::{OpenMlsCryptoProvider, types::SignatureScheme};
22
23use crate::{
24    Ciphersuite, ClientId, ClientIdentifier, CoreCrypto, CredentialFindFilters, CredentialRef, CredentialType,
25    HistorySecret, LeafError, MlsError, MlsTransport, RecursiveError,
26    group_store::GroupStore,
27    mls::{
28        self, HasSessionAndCrypto,
29        conversation::{ConversationIdRef, ImmutableConversation},
30    },
31};
32
33/// A MLS Session enables a user device to communicate via the MLS protocol.
34///
35/// This closely maps to the `Client` term in [RFC 9720], but we avoid that term to avoid ambiguity;
36/// `Client` is very overloaded with distinct meanings.
37///
38/// There is one `Session` per user per device. A session can contain many MLS groups/conversations.
39///
40/// It is cheap to clone a `Session` because everything heavy is wrapped inside an [Arc].
41///
42/// [RFC 9720]: https://www.rfc-editor.org/rfc/rfc9420.html
43#[derive(Clone, derive_more::Debug)]
44pub struct Session {
45    pub(crate) inner: Arc<RwLock<Option<SessionInner>>>,
46    pub(crate) crypto_provider: MlsCryptoProvider,
47    pub(crate) transport: Arc<RwLock<Option<Arc<dyn MlsTransport + 'static>>>>,
48    #[debug("EpochObserver")]
49    pub(crate) epoch_observer: Arc<RwLock<Option<Arc<dyn EpochObserver + 'static>>>>,
50    #[debug("HistoryObserver")]
51    pub(crate) history_observer: Arc<RwLock<Option<Arc<dyn HistoryObserver + 'static>>>>,
52}
53
54#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
55#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
56impl HasSessionAndCrypto for Session {
57    async fn session(&self) -> mls::Result<Session> {
58        Ok(self.clone())
59    }
60
61    async fn crypto_provider(&self) -> mls::Result<MlsCryptoProvider> {
62        Ok(self.crypto_provider.clone())
63    }
64}
65
66#[derive(Clone, Debug)]
67pub(crate) struct SessionInner {
68    id: ClientId,
69    pub(crate) identities: Identities,
70}
71
72impl Session {
73    /// Creates a new [Session]. Does not initialize MLS or Proteus.
74    ///
75    /// ## Errors
76    ///
77    /// Failures in the initialization of the KeyStore can cause errors, such as IO, the same kind
78    /// of errors can happen when the groups are being restored from the KeyStore or even during
79    /// the client initialization (to fetch the identity signature).
80    pub async fn try_new(database: &Database) -> crate::mls::Result<Self> {
81        // cloning a database is relatively cheap; it's all arcs inside
82        let database = database.to_owned();
83        // Init backend (crypto + rand + keystore)
84        let mls_backend = MlsCryptoProvider::new(database);
85
86        // We create the core crypto instance first to enable creating a transaction from it and
87        // doing all subsequent actions inside a single transaction, though it forces us to clone
88        // a few Arcs and locks.
89        let session = Self {
90            crypto_provider: mls_backend,
91            inner: Default::default(),
92            transport: Arc::new(None.into()),
93            epoch_observer: Arc::new(None.into()),
94            history_observer: Arc::new(None.into()),
95        };
96
97        let cc = CoreCrypto::from(session);
98        let context = cc
99            .new_transaction()
100            .await
101            .map_err(RecursiveError::transaction("starting new transaction"))?;
102
103        context
104            .init_pki_env()
105            .await
106            .map_err(RecursiveError::transaction("initializing pki environment"))?;
107        context
108            .finish()
109            .await
110            .map_err(RecursiveError::transaction("finishing transaction"))?;
111
112        Ok(cc.mls)
113    }
114
115    /// Provide the implementation of functions to communicate with the delivery service
116    /// (see [MlsTransport]).
117    pub async fn provide_transport(&self, transport: Arc<dyn MlsTransport>) {
118        self.transport.write().await.replace(transport);
119    }
120
121    /// Initializes the client.
122    ///
123    /// Loads any cryptographic material already present in the keystore, but does not create any.
124    /// If no credentials are present in the keystore, then one _must_ be created and added to the
125    /// session before it can be used.
126    pub async fn init(&self, identifier: ClientIdentifier, signature_schemes: &[SignatureScheme]) -> Result<()> {
127        self.ensure_unready().await?;
128        let client_id = identifier.get_id()?.into_owned();
129
130        // we want to find all credentials matching this identifier, having a valid signature scheme.
131        // the `CredentialRef::find` API doesn't allow us to easily find those credentials having
132        // one of a set of signature schemes, meaning we have two paths here:
133        // we could either search unbound by signature schemes and then filter for valid ones here,
134        // or we could iterate over the list of signature schemes and build up a set of credential refs.
135        // as there are only a few signature schemes possible and the cost of a find operation is non-trivial,
136        // we choose the first option.
137        // we might revisit this choice after WPB-20844 and WPB-21819.
138        let mut credential_refs = CredentialRef::find(
139            &self.crypto_provider.keystore(),
140            CredentialFindFilters::builder().client_id(&client_id).build(),
141        )
142        .await
143        .map_err(RecursiveError::mls_credential_ref(
144            "loading matching credential refs while initializing a client",
145        ))?;
146        credential_refs.retain(|credential_ref| signature_schemes.contains(&credential_ref.signature_scheme()));
147
148        let mut identities = Identities::new(credential_refs.len());
149        let credentials_cache = CredentialRef::load_stored_credentials(&self.crypto_provider.keystore())
150            .await
151            .map_err(RecursiveError::mls_credential_ref(
152                "loading credential ref cache while initializing session",
153            ))?;
154
155        for credential_ref in credential_refs {
156            if let Some(credential) =
157                credential_ref
158                    .load_from_cache(&credentials_cache)
159                    .map_err(RecursiveError::mls_credential_ref(
160                        "loading credential list in session init",
161                    ))?
162            {
163                match identities.push_credential(credential).await {
164                    Err(Error::CredentialConflict) => {
165                        // this is what we get for not having real primary keys in our DB
166                        // no harm done though; no need to propagate this error
167                    }
168                    Ok(_) => {}
169                    Err(err) => {
170                        return Err(RecursiveError::MlsClient {
171                            context: "adding credential to identities in init",
172                            source: Box::new(err),
173                        }
174                        .into());
175                    }
176                }
177            }
178        }
179
180        self.replace_inner(SessionInner {
181            id: client_id,
182            identities,
183        })
184        .await;
185
186        Ok(())
187    }
188
189    /// Resets the client to an uninitialized state.
190    #[cfg(test)]
191    pub(crate) async fn reset(&self) {
192        let mut inner_lock = self.inner.write().await;
193        *inner_lock = None;
194    }
195
196    pub(crate) async fn is_ready(&self) -> bool {
197        let inner_lock = self.inner.read().await;
198        inner_lock.is_some()
199    }
200
201    async fn ensure_unready(&self) -> Result<()> {
202        if self.is_ready().await {
203            Err(Error::UnexpectedlyReady)
204        } else {
205            Ok(())
206        }
207    }
208
209    async fn replace_inner(&self, new_inner: SessionInner) {
210        let mut inner_lock = self.inner.write().await;
211        *inner_lock = Some(new_inner);
212    }
213
214    /// Get an immutable view of an `MlsConversation`.
215    ///
216    /// Because it operates on the raw conversation type, this may be faster than
217    /// [crate::transaction_context::TransactionContext::conversation]. for transient and immutable
218    /// purposes. For long-lived or mutable purposes, prefer the other method.
219    pub async fn get_raw_conversation(&self, id: &ConversationIdRef) -> Result<ImmutableConversation> {
220        let raw_conversation = GroupStore::fetch_from_keystore(id, &self.crypto_provider.keystore(), None)
221            .await
222            .map_err(RecursiveError::root("getting conversation by id"))?
223            .ok_or_else(|| LeafError::ConversationNotFound(id.to_owned()))?;
224        Ok(ImmutableConversation::new(raw_conversation, self.clone()))
225    }
226
227    /// Returns the client's most recent public signature key as a buffer.
228    /// Used to upload a public key to the server in order to verify client's messages signature.
229    ///
230    /// # Arguments
231    /// * `ciphersuite` - a callback to be called to perform authorization
232    /// * `credential_type` - of the credential to look for
233    pub async fn public_key(
234        &self,
235        ciphersuite: Ciphersuite,
236        credential_type: CredentialType,
237    ) -> crate::mls::Result<Vec<u8>> {
238        let cb = self
239            .find_most_recent_credential(ciphersuite.signature_algorithm(), credential_type)
240            .await
241            .map_err(RecursiveError::mls_client("finding most recent credential"))?;
242        Ok(cb.signature_key_pair.to_public_vec())
243    }
244
245    /// Checks if a given conversation id exists locally
246    pub async fn conversation_exists(&self, id: &ConversationIdRef) -> Result<bool> {
247        match self.get_raw_conversation(id).await {
248            Ok(_) => Ok(true),
249            Err(Error::Leaf(LeafError::ConversationNotFound(_))) => Ok(false),
250            Err(e) => Err(e),
251        }
252    }
253
254    /// Generates a random byte array of the specified size
255    pub fn random_bytes(&self, len: usize) -> crate::mls::Result<Vec<u8>> {
256        use openmls_traits::random::OpenMlsRand as _;
257        self.crypto_provider
258            .rand()
259            .random_vec(len)
260            .map_err(MlsError::wrap("generating random vector"))
261            .map_err(Into::into)
262    }
263
264    /// Waits for running transactions to finish, then closes the connection with the local KeyStore.
265    ///
266    /// # Errors
267    /// KeyStore errors, such as IO, and if there is more than one strong reference
268    /// to the connection.
269    pub async fn close(&self) -> crate::mls::Result<()> {
270        self.crypto_provider
271            .close()
272            .await
273            .map_err(MlsError::wrap("closing connection with keystore"))
274            .map_err(Into::into)
275    }
276
277    /// see [mls_crypto_provider::MlsCryptoProvider::reseed]
278    pub async fn reseed(&self, seed: Option<EntropySeed>) -> crate::mls::Result<()> {
279        self.crypto_provider
280            .reseed(seed)
281            .map_err(MlsError::wrap("reseeding mls backend"))
282            .map_err(Into::into)
283    }
284
285    /// Restore from an external [`HistorySecret`].
286    pub(crate) async fn restore_from_history_secret(&self, history_secret: HistorySecret) -> Result<()> {
287        self.ensure_unready().await?;
288
289        // store the client id (with some other stuff)
290        self.replace_inner(SessionInner {
291            id: history_secret.client_id.clone(),
292            identities: Identities::new(0),
293        })
294        .await;
295
296        // store the key package
297        history_secret
298            .key_package
299            .store(&self.crypto_provider)
300            .await
301            .map_err(MlsError::wrap("storing key package encapsulation"))?;
302
303        Ok(())
304    }
305
306    /// Retrieves the client's client id. This is free-form and not inspected.
307    pub async fn id(&self) -> Result<ClientId> {
308        match &*self.inner.read().await {
309            None => Err(Error::MlsNotInitialized),
310            Some(SessionInner { id, .. }) => Ok(id.clone()),
311        }
312    }
313
314    /// Returns whether this client is E2EI capable
315    pub async fn is_e2ei_capable(&self) -> bool {
316        match &*self.inner.read().await {
317            None => false,
318            Some(SessionInner { identities, .. }) => identities
319                .iter()
320                .any(|cred| cred.credential_type() == CredentialType::X509),
321        }
322    }
323}
324
#[cfg(test)]
mod tests {
    use core_crypto_keystore::{connection::FetchFromDatabase as _, entities::*};
    use mls_crypto_provider::MlsCryptoProvider;

    use super::*;
    use crate::{
        CertificateBundle, Credential, KeystoreError, test_utils::*, transaction_context::test_utils::EntitiesCount,
    };

    impl Session {
        // test functions are not held to the same documentation standard as proper functions
        #![allow(missing_docs)]

        /// Replace any existing credentials, identities, client_id, and similar with newly generated ones.
        ///
        /// Resets the session, derives a random client id, initializes the session with it and adds a
        /// fresh credential of the type selected by `case`. `signer` must be present for X509 cases.
        pub async fn random_generate(
            &self,
            case: &crate::test_utils::TestContext,
            signer: Option<&crate::test_utils::x509::X509Certificate>,
        ) -> Result<()> {
            self.reset().await;

            let user_uuid = uuid::Uuid::new_v4();
            let random_suffix = rand::random::<usize>();
            let client_id = ClientId(
                format!("{}:{random_suffix:x}@members.wire.com", user_uuid.hyphenated()).into_bytes(),
            );

            let (identifier, credential) = match case.credential_type {
                CredentialType::Basic => {
                    let identifier = ClientIdentifier::Basic(client_id.clone());
                    let credential =
                        Credential::basic(case.ciphersuite(), client_id, &self.crypto_provider).unwrap();
                    (identifier, credential)
                }
                CredentialType::X509 => {
                    let intermediate_ca = signer.expect("Missing intermediate CA").to_owned();
                    let bundle = CertificateBundle::rand(&client_id, &intermediate_ca);
                    let identifier = ClientIdentifier::X509([(case.signature_scheme(), bundle.clone())].into());
                    let credential = Credential::x509(case.ciphersuite(), bundle).unwrap();
                    (identifier, credential)
                }
            };

            self.init(identifier, &[case.signature_scheme()]).await.unwrap();
            self.add_credential(credential).await.unwrap();

            Ok(())
        }

        /// Fetch every MLS key package stored in `backend`'s keystore.
        pub async fn find_keypackages(&self, backend: &MlsCryptoProvider) -> Result<Vec<openmls::prelude::KeyPackage>> {
            use core_crypto_keystore::CryptoKeystoreMls as _;
            backend
                .key_store()
                .mls_fetch_keypackages::<openmls::prelude::KeyPackage>(u32::MAX)
                .await
                .map_err(KeystoreError::wrap("fetching mls keypackages"))
                .map_err(Into::into)
        }

        /// Count the entities
        pub async fn count_entities(&self) -> EntitiesCount {
            let keystore = self.crypto_provider.keystore();
            // struct-literal fields are evaluated in the order written, so the counts are
            // queried sequentially in this order
            EntitiesCount {
                credential: keystore.count::<StoredCredential>().await.unwrap(),
                encryption_keypair: keystore.count::<StoredEncryptionKeyPair>().await.unwrap(),
                epoch_encryption_keypair: keystore.count::<StoredEpochEncryptionKeypair>().await.unwrap(),
                enrollment: keystore.count::<StoredE2eiEnrollment>().await.unwrap(),
                group: keystore.count::<PersistedMlsGroup>().await.unwrap(),
                hpke_private_key: keystore.count::<StoredHpkePrivateKey>().await.unwrap(),
                key_package: keystore.count::<StoredKeypackage>().await.unwrap(),
                pending_group: keystore.count::<PersistedMlsPendingGroup>().await.unwrap(),
                pending_messages: keystore.count::<MlsPendingMessage>().await.unwrap(),
                psk_bundle: keystore.count::<StoredPskBundle>().await.unwrap(),
            }
        }
    }

    #[apply(all_cred_cipher)]
    async fn can_generate_session(mut case: TestContext) {
        let [alice] = case.sessions().await;
        let backend = MlsCryptoProvider::new(case.create_in_memory_database().await);

        let test_chain = case
            .is_x509()
            .then(|| crate::test_utils::x509::X509TestChain::init_empty(case.signature_scheme()));
        if let Some(chain) = &test_chain {
            chain.register_with_provider(&backend).await;
        }

        backend.new_transaction().await.unwrap();
        let session = alice.session().await;
        let intermediate_ca = test_chain.as_ref().map(|chain| chain.find_local_intermediate_ca());
        session.random_generate(&case, intermediate_ca).await.unwrap();
    }
}