// core_crypto/mls/session/mod.rs

1mod credential;
2pub(crate) mod e2e_identity;
3mod epoch_observer;
4mod error;
5mod history_observer;
6pub(crate) mod id;
7pub(crate) mod identifier;
8pub(crate) mod identities;
9pub(crate) mod key_package;
10pub(crate) mod user_id;
11
12use std::sync::Arc;
13
14use async_lock::RwLock;
15use core_crypto_keystore::Database;
16pub use epoch_observer::EpochObserver;
17pub(crate) use error::{Error, Result};
18pub use history_observer::HistoryObserver;
19use identities::Identities;
20use key_package::KEYPACKAGE_DEFAULT_LIFETIME;
21use mls_crypto_provider::{EntropySeed, MlsCryptoProvider};
22use openmls_traits::{OpenMlsCryptoProvider, types::SignatureScheme};
23
24use crate::{
25    Ciphersuite, ClientId, ClientIdentifier, CoreCrypto, CredentialFindFilters, CredentialRef, CredentialType,
26    HistorySecret, LeafError, MlsError, MlsTransport, RecursiveError,
27    group_store::GroupStore,
28    mls::{
29        self, HasSessionAndCrypto,
30        conversation::{ConversationIdRef, ImmutableConversation},
31    },
32};
33
34/// A MLS Session enables a user device to communicate via the MLS protocol.
35///
36/// This closely maps to the `Client` term in [RFC 9720], but we avoid that term to avoid ambiguity;
37/// `Client` is very overloaded with distinct meanings.
38///
39/// There is one `Session` per user per device. A session can contain many MLS groups/conversations.
40///
41/// It is cheap to clone a `Session` because everything heavy is wrapped inside an [Arc].
42///
43/// [RFC 9720]: https://www.rfc-editor.org/rfc/rfc9420.html
44#[derive(Clone, derive_more::Debug)]
45pub struct Session {
46    pub(crate) inner: Arc<RwLock<Option<SessionInner>>>,
47    pub(crate) crypto_provider: MlsCryptoProvider,
48    pub(crate) transport: Arc<RwLock<Option<Arc<dyn MlsTransport + 'static>>>>,
49    #[debug("EpochObserver")]
50    pub(crate) epoch_observer: Arc<RwLock<Option<Arc<dyn EpochObserver + 'static>>>>,
51    #[debug("HistoryObserver")]
52    pub(crate) history_observer: Arc<RwLock<Option<Arc<dyn HistoryObserver + 'static>>>>,
53}
54
55#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
56#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
57impl HasSessionAndCrypto for Session {
58    async fn session(&self) -> mls::Result<Session> {
59        Ok(self.clone())
60    }
61
62    async fn crypto_provider(&self) -> mls::Result<MlsCryptoProvider> {
63        Ok(self.crypto_provider.clone())
64    }
65}
66
67#[derive(Clone, Debug)]
68pub(crate) struct SessionInner {
69    id: ClientId,
70    pub(crate) identities: Identities,
71    keypackage_lifetime: std::time::Duration,
72}
73
74impl Session {
75    /// Creates a new [Session]. Does not initialize MLS or Proteus.
76    ///
77    /// ## Errors
78    ///
79    /// Failures in the initialization of the KeyStore can cause errors, such as IO, the same kind
80    /// of errors can happen when the groups are being restored from the KeyStore or even during
81    /// the client initialization (to fetch the identity signature).
82    pub async fn try_new(database: &Database) -> crate::mls::Result<Self> {
83        // cloning a database is relatively cheap; it's all arcs inside
84        let database = database.to_owned();
85        // Init backend (crypto + rand + keystore)
86        let mls_backend = MlsCryptoProvider::new(database);
87
88        // We create the core crypto instance first to enable creating a transaction from it and
89        // doing all subsequent actions inside a single transaction, though it forces us to clone
90        // a few Arcs and locks.
91        let session = Self {
92            crypto_provider: mls_backend,
93            inner: Default::default(),
94            transport: Arc::new(None.into()),
95            epoch_observer: Arc::new(None.into()),
96            history_observer: Arc::new(None.into()),
97        };
98
99        let cc = CoreCrypto::from(session);
100        let context = cc
101            .new_transaction()
102            .await
103            .map_err(RecursiveError::transaction("starting new transaction"))?;
104
105        context
106            .init_pki_env()
107            .await
108            .map_err(RecursiveError::transaction("initializing pki environment"))?;
109        context
110            .finish()
111            .await
112            .map_err(RecursiveError::transaction("finishing transaction"))?;
113
114        Ok(cc.mls)
115    }
116
117    /// Provide the implementation of functions to communicate with the delivery service
118    /// (see [MlsTransport]).
119    pub async fn provide_transport(&self, transport: Arc<dyn MlsTransport>) {
120        self.transport.write().await.replace(transport);
121    }
122
123    /// Initializes the client.
124    ///
125    /// Loads any cryptographic material already present in the keystore, but does not create any.
126    /// If no credentials are present in the keystore, then one _must_ be created and added to the
127    /// session before it can be used.
128    pub async fn init(&self, identifier: ClientIdentifier, signature_schemes: &[SignatureScheme]) -> Result<()> {
129        self.ensure_unready().await?;
130        let client_id = identifier.get_id()?.into_owned();
131
132        // we want to find all credentials matching this identifier, having a valid signature scheme.
133        // the `CredentialRef::find` API doesn't allow us to easily find those credentials having
134        // one of a set of signature schemes, meaning we have two paths here:
135        // we could either search unbound by signature schemes and then filter for valid ones here,
136        // or we could iterate over the list of signature schemes and build up a set of credential refs.
137        // as there are only a few signature schemes possible and the cost of a find operation is non-trivial,
138        // we choose the first option.
139        // we might revisit this choice after WPB-20844 and WPB-21819.
140        let mut credential_refs = CredentialRef::find(
141            &self.crypto_provider.keystore(),
142            CredentialFindFilters::builder().client_id(&client_id).build(),
143        )
144        .await
145        .map_err(RecursiveError::mls_credential_ref(
146            "loading matching credential refs while initializing a client",
147        ))?;
148        credential_refs.retain(|credential_ref| signature_schemes.contains(&credential_ref.signature_scheme()));
149
150        let mut identities = Identities::new(credential_refs.len());
151        let credentials_cache = CredentialRef::load_stored_credentials(&self.crypto_provider.keystore())
152            .await
153            .map_err(RecursiveError::mls_credential_ref(
154                "loading credential ref cache while initializing session",
155            ))?;
156
157        for credential_ref in credential_refs {
158            for credential_result in
159                credential_ref
160                    .load_from_cache(&credentials_cache)
161                    .map_err(RecursiveError::mls_credential_ref(
162                        "loading credential list in session init",
163                    ))?
164            {
165                let credential = credential_result
166                    .map_err(RecursiveError::mls_credential_ref("loading credential in session init"))?;
167
168                match identities.push_credential(credential).await {
169                    Err(Error::CredentialConflict) => {
170                        // this is what we get for not having real primary keys in our DB
171                        // no harm done though; no need to propagate this error
172                    }
173                    Ok(_) => {}
174                    Err(err) => {
175                        return Err(RecursiveError::MlsClient {
176                            context: "adding credential to identities in init",
177                            source: Box::new(err),
178                        }
179                        .into());
180                    }
181                }
182            }
183        }
184
185        self.replace_inner(SessionInner {
186            id: client_id,
187            identities,
188            keypackage_lifetime: KEYPACKAGE_DEFAULT_LIFETIME,
189        })
190        .await;
191
192        Ok(())
193    }
194
195    /// Resets the client to an uninitialized state.
196    #[cfg(test)]
197    pub(crate) async fn reset(&self) {
198        let mut inner_lock = self.inner.write().await;
199        *inner_lock = None;
200    }
201
202    pub(crate) async fn is_ready(&self) -> bool {
203        let inner_lock = self.inner.read().await;
204        inner_lock.is_some()
205    }
206
207    async fn ensure_unready(&self) -> Result<()> {
208        if self.is_ready().await {
209            Err(Error::UnexpectedlyReady)
210        } else {
211            Ok(())
212        }
213    }
214
215    async fn replace_inner(&self, new_inner: SessionInner) {
216        let mut inner_lock = self.inner.write().await;
217        *inner_lock = Some(new_inner);
218    }
219
220    /// Get an immutable view of an `MlsConversation`.
221    ///
222    /// Because it operates on the raw conversation type, this may be faster than
223    /// [crate::transaction_context::TransactionContext::conversation]. for transient and immutable
224    /// purposes. For long-lived or mutable purposes, prefer the other method.
225    pub async fn get_raw_conversation(&self, id: &ConversationIdRef) -> Result<ImmutableConversation> {
226        let raw_conversation = GroupStore::fetch_from_keystore(id, &self.crypto_provider.keystore(), None)
227            .await
228            .map_err(RecursiveError::root("getting conversation by id"))?
229            .ok_or_else(|| LeafError::ConversationNotFound(id.to_owned()))?;
230        Ok(ImmutableConversation::new(raw_conversation, self.clone()))
231    }
232
233    /// Returns the client's most recent public signature key as a buffer.
234    /// Used to upload a public key to the server in order to verify client's messages signature.
235    ///
236    /// # Arguments
237    /// * `ciphersuite` - a callback to be called to perform authorization
238    /// * `credential_type` - of the credential to look for
239    pub async fn public_key(
240        &self,
241        ciphersuite: Ciphersuite,
242        credential_type: CredentialType,
243    ) -> crate::mls::Result<Vec<u8>> {
244        let cb = self
245            .find_most_recent_credential(ciphersuite.signature_algorithm(), credential_type)
246            .await
247            .map_err(RecursiveError::mls_client("finding most recent credential"))?;
248        Ok(cb.signature_key_pair.to_public_vec())
249    }
250
251    /// Checks if a given conversation id exists locally
252    pub async fn conversation_exists(&self, id: &ConversationIdRef) -> Result<bool> {
253        match self.get_raw_conversation(id).await {
254            Ok(_) => Ok(true),
255            Err(Error::Leaf(LeafError::ConversationNotFound(_))) => Ok(false),
256            Err(e) => Err(e),
257        }
258    }
259
260    /// Generates a random byte array of the specified size
261    pub fn random_bytes(&self, len: usize) -> crate::mls::Result<Vec<u8>> {
262        use openmls_traits::random::OpenMlsRand as _;
263        self.crypto_provider
264            .rand()
265            .random_vec(len)
266            .map_err(MlsError::wrap("generating random vector"))
267            .map_err(Into::into)
268    }
269
270    /// Waits for running transactions to finish, then closes the connection with the local KeyStore.
271    ///
272    /// # Errors
273    /// KeyStore errors, such as IO, and if there is more than one strong reference
274    /// to the connection.
275    pub async fn close(&self) -> crate::mls::Result<()> {
276        self.crypto_provider
277            .close()
278            .await
279            .map_err(MlsError::wrap("closing connection with keystore"))
280            .map_err(Into::into)
281    }
282
283    /// see [mls_crypto_provider::MlsCryptoProvider::reseed]
284    pub async fn reseed(&self, seed: Option<EntropySeed>) -> crate::mls::Result<()> {
285        self.crypto_provider
286            .reseed(seed)
287            .map_err(MlsError::wrap("reseeding mls backend"))
288            .map_err(Into::into)
289    }
290
291    /// Restore from an external [`HistorySecret`].
292    pub(crate) async fn restore_from_history_secret(&self, history_secret: HistorySecret) -> Result<()> {
293        self.ensure_unready().await?;
294
295        // store the client id (with some other stuff)
296        self.replace_inner(SessionInner {
297            id: history_secret.client_id.clone(),
298            identities: Identities::new(0),
299            keypackage_lifetime: KEYPACKAGE_DEFAULT_LIFETIME,
300        })
301        .await;
302
303        // store the key package
304        history_secret
305            .key_package
306            .store(&self.crypto_provider)
307            .await
308            .map_err(MlsError::wrap("storing key package encapsulation"))?;
309
310        Ok(())
311    }
312
313    /// Retrieves the client's client id. This is free-form and not inspected.
314    pub async fn id(&self) -> Result<ClientId> {
315        match &*self.inner.read().await {
316            None => Err(Error::MlsNotInitialized),
317            Some(SessionInner { id, .. }) => Ok(id.clone()),
318        }
319    }
320
321    /// Returns whether this client is E2EI capable
322    pub async fn is_e2ei_capable(&self) -> bool {
323        match &*self.inner.read().await {
324            None => false,
325            Some(SessionInner { identities, .. }) => identities
326                .iter()
327                .any(|cred| cred.credential_type() == CredentialType::X509),
328        }
329    }
330}
331
#[cfg(test)]
mod tests {
    use core_crypto_keystore::{connection::FetchFromDatabase as _, entities::*};
    use mls_crypto_provider::MlsCryptoProvider;

    use super::*;
    use crate::{
        CertificateBundle, Credential, KeystoreError, test_utils::*, transaction_context::test_utils::EntitiesCount,
    };

    impl Session {
        // test functions are not held to the same documentation standard as proper functions
        #![allow(missing_docs)]

        /// Replace any existing credentials, identities, client_id, and similar with newly generated ones.
        ///
        /// The session is reset, initialized with a freshly generated random client id, and then
        /// given a single new credential of the type requested by `case`.
        /// `signer` is required for X509 test cases and is ignored for basic ones.
        pub async fn random_generate(
            &self,
            case: &crate::test_utils::TestContext,
            signer: Option<&crate::test_utils::x509::X509Certificate>,
        ) -> Result<()> {
            self.reset().await;
            let user_uuid = uuid::Uuid::new_v4();
            let rnd_id = rand::random::<usize>();
            let client_id = format!("{}:{rnd_id:x}@members.wire.com", user_uuid.hyphenated());
            let client_id = ClientId(client_id.into_bytes());

            let (identifier, credential) = match case.credential_type {
                CredentialType::Basic => {
                    let identifier = ClientIdentifier::Basic(client_id.clone());
                    let credential =
                        Credential::basic(case.signature_scheme(), client_id, &self.crypto_provider).unwrap();
                    (identifier, credential)
                }
                CredentialType::X509 => {
                    let signer = signer.expect("Missing intermediate CA");
                    let cert = CertificateBundle::rand(&client_id, signer);
                    let identifier = ClientIdentifier::X509([(case.signature_scheme(), cert.clone())].into());
                    let credential = Credential::x509(cert).unwrap();
                    (identifier, credential)
                }
            };

            self.init(identifier, &[case.signature_scheme()]).await.unwrap();

            self.add_credential(credential).await.unwrap();

            Ok(())
        }

        /// Fetches the key packages stored in the keystore of `backend` (at most `u32::MAX`).
        pub async fn find_keypackages(&self, backend: &MlsCryptoProvider) -> Result<Vec<openmls::prelude::KeyPackage>> {
            use core_crypto_keystore::CryptoKeystoreMls as _;
            let kps = backend
                .key_store()
                .mls_fetch_keypackages::<openmls::prelude::KeyPackage>(u32::MAX)
                .await
                .map_err(KeystoreError::wrap("fetching mls keypackages"))?;
            Ok(kps)
        }

        /// Generates a single key package from the most recent credential matching `cs` and `ct`.
        pub(crate) async fn generate_one_keypackage(
            &self,
            backend: &MlsCryptoProvider,
            cs: Ciphersuite,
            ct: CredentialType,
        ) -> Result<openmls::prelude::KeyPackage> {
            let cb = self.find_most_recent_credential(cs.signature_algorithm(), ct).await?;
            self.generate_one_keypackage_from_credential(backend, cs, &cb).await
        }

        /// Count the entities
        ///
        /// Panics if any of the keystore count queries fails.
        pub async fn count_entities(&self) -> EntitiesCount {
            let keystore = self.crypto_provider.keystore();
            let credential = keystore.count::<StoredCredential>().await.unwrap();
            let encryption_keypair = keystore.count::<StoredEncryptionKeyPair>().await.unwrap();
            let epoch_encryption_keypair = keystore.count::<StoredEpochEncryptionKeypair>().await.unwrap();
            let enrollment = keystore.count::<StoredE2eiEnrollment>().await.unwrap();
            let group = keystore.count::<PersistedMlsGroup>().await.unwrap();
            let hpke_private_key = keystore.count::<StoredHpkePrivateKey>().await.unwrap();
            let key_package = keystore.count::<StoredKeypackage>().await.unwrap();
            let pending_group = keystore.count::<PersistedMlsPendingGroup>().await.unwrap();
            let pending_messages = keystore.count::<MlsPendingMessage>().await.unwrap();
            let psk_bundle = keystore.count::<StoredPskBundle>().await.unwrap();
            EntitiesCount {
                credential,
                encryption_keypair,
                epoch_encryption_keypair,
                enrollment,
                group,
                hpke_private_key,
                key_package,
                pending_group,
                pending_messages,
                psk_bundle,
            }
        }
    }

    /// A session can be re-generated with random credentials for every credential type / ciphersuite.
    #[apply(all_cred_cipher)]
    async fn can_generate_session(mut case: TestContext) {
        let [alice] = case.sessions().await;
        let key_store = case.create_in_memory_database().await;
        let backend = MlsCryptoProvider::new(key_store);
        // X509 cases need a certificate chain to sign the generated certificate bundle
        let x509_test_chain = if case.is_x509() {
            let x509_test_chain = crate::test_utils::x509::X509TestChain::init_empty(case.signature_scheme());
            x509_test_chain.register_with_provider(&backend).await;
            Some(x509_test_chain)
        } else {
            None
        };
        backend.new_transaction().await.unwrap();
        let session = alice.session().await;
        session
            .random_generate(
                &case,
                x509_test_chain.as_ref().map(|chain| chain.find_local_intermediate_ca()),
            )
            .await
            .unwrap();
    }
}