core_crypto/mls/session/
mod.rs

1pub(crate) mod e2e_identity;
2mod epoch_observer;
3mod error;
4pub(crate) mod id;
5pub(crate) mod identifier;
6pub(crate) mod identities;
7pub(crate) mod key_package;
8pub(crate) mod user_id;
9
10use crate::{
11    CoreCrypto, KeystoreError, LeafError, MlsError, MlsTransport, RecursiveError,
12    group_store::GroupStore,
13    mls::{
14        self, HasSessionAndCrypto,
15        conversation::ImmutableConversation,
16        credential::{CredentialBundle, ext::CredentialExt},
17    },
18    prelude::{
19        CertificateBundle, ClientId, ConversationId, HistorySecret, INITIAL_KEYING_MATERIAL_COUNT, MlsCiphersuite,
20        MlsClientConfiguration, MlsCredentialType, identifier::ClientIdentifier,
21        key_package::KEYPACKAGE_DEFAULT_LIFETIME,
22    },
23};
24use async_lock::RwLock;
25use core_crypto_keystore::{
26    Connection, CryptoKeystoreError,
27    connection::FetchFromDatabase,
28    entities::{EntityFindParams, MlsCredential, MlsSignatureKeyPair},
29};
30pub use epoch_observer::EpochObserver;
31pub(crate) use error::{Error, Result};
32use identities::Identities;
33use log::debug;
34use mls_crypto_provider::{EntropySeed, MlsCryptoProvider, MlsCryptoProviderConfiguration};
35use openmls::prelude::{Credential, CredentialType};
36use openmls_basic_credential::SignatureKeyPair;
37use openmls_traits::{OpenMlsCryptoProvider, crypto::OpenMlsCrypto, types::SignatureScheme};
38use openmls_x509_credential::CertificateKeyPair;
39use std::ops::Deref;
40use std::sync::Arc;
41use std::{collections::HashSet, fmt};
42use tls_codec::{Deserialize, Serialize};
43
44/// A MLS Session enables a user device to communicate via the MLS protocol.
45///
46/// This closely maps to the `Client` term in [RFC 9720], but we avoid that term to avoid ambiguity;
47/// `Client` is very overloaded with distinct meanings.
48///
49/// There is one `Session` per user per device. A session can contain many MLS groups/conversations.
50///
51/// It is cheap to clone a `Session` because everything heavy is wrapped inside an [Arc].
52///
53/// [RFC 9720]: https://www.rfc-editor.org/rfc/rfc9420.html
54#[derive(Clone, Debug)]
55pub struct Session {
56    pub(crate) inner: Arc<RwLock<Option<SessionInner>>>,
57    pub(crate) crypto_provider: MlsCryptoProvider,
58    pub(crate) transport: Arc<RwLock<Option<Arc<dyn MlsTransport + 'static>>>>,
59}
60
61#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
62#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
63impl HasSessionAndCrypto for Session {
64    async fn session(&self) -> mls::Result<Session> {
65        Ok(self.clone())
66    }
67
68    async fn crypto_provider(&self) -> mls::Result<MlsCryptoProvider> {
69        Ok(self.crypto_provider.clone())
70    }
71}
72
73#[derive(Clone)]
74pub(crate) struct SessionInner {
75    id: ClientId,
76    pub(crate) identities: Identities,
77    keypackage_lifetime: std::time::Duration,
78    epoch_observer: Option<Arc<dyn EpochObserver>>,
79}
80
81impl fmt::Debug for SessionInner {
82    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
83        let observer_debug = if self.epoch_observer.is_some() {
84            "Some(Arc<dyn EpochObserver>)"
85        } else {
86            "None"
87        };
88        f.debug_struct("ClientInner")
89            .field("id", &self.id)
90            .field("identities", &self.identities)
91            .field("keypackage_lifetime", &self.keypackage_lifetime)
92            .field("epoch_observer", &observer_debug)
93            .finish()
94    }
95}
96
97impl Session {
98    /// Tries to initialize the [Client].
99    /// Takes a store path (i.e. Disk location of the embedded database, should be consistent between messaging sessions)
100    /// And a root identity key (i.e. enclaved encryption key for this device)
101    ///
102    /// # Arguments
103    /// * `configuration` - the configuration for the `MlsCentral`
104    ///
105    /// # Errors
106    /// Failures in the initialization of the KeyStore can cause errors, such as IO, the same kind
107    /// of errors can happen when the groups are being restored from the KeyStore or even during
108    /// the client initialization (to fetch the identity signature). Other than that, `MlsError`
109    /// can be caused by group deserialization or during the initialization of the credentials:
110    /// * for x509 Credentials if the certificate chain length is lower than 2
111    /// * for Basic Credentials if the signature key cannot be generated either by not supported
112    ///   scheme or the key generation fails
113    pub async fn try_new(configuration: MlsClientConfiguration) -> crate::mls::Result<Self> {
114        // Init backend (crypto + rand + keystore)
115        let mls_backend = MlsCryptoProvider::try_new_with_configuration(MlsCryptoProviderConfiguration {
116            db_path: &configuration.store_path,
117            db_key: configuration.database_key.clone(),
118            in_memory: false,
119            entropy_seed: configuration.external_entropy.clone(),
120        })
121        .await
122        .map_err(MlsError::wrap("trying to initialize mls crypto provider object"))?;
123        Self::new_with_backend(mls_backend, configuration).await
124    }
125
126    /// Same as the [Client::try_new] but instead, it uses an in memory KeyStore.
127    /// Although required, the `store_path` parameter from the `MlsClientConfiguration` won't be used here.
128    pub async fn try_new_in_memory(configuration: MlsClientConfiguration) -> crate::mls::Result<Self> {
129        let mls_backend = MlsCryptoProvider::try_new_with_configuration(MlsCryptoProviderConfiguration {
130            db_path: &configuration.store_path,
131            db_key: configuration.database_key.clone(),
132            in_memory: true,
133            entropy_seed: configuration.external_entropy.clone(),
134        })
135        .await
136        .map_err(MlsError::wrap(
137            "trying to initialize mls crypto provider object (in memory)",
138        ))?;
139        Self::new_with_backend(mls_backend, configuration).await
140    }
141
142    async fn new_with_backend(
143        mls_backend: MlsCryptoProvider,
144        configuration: MlsClientConfiguration,
145    ) -> crate::mls::Result<Self> {
146        // We create the core crypto instance first to enable creating a transaction from it and
147        // doing all subsequent actions inside a single transaction, though it forces us to clone
148        // a few Arcs and locks.
149        let client = Self {
150            crypto_provider: mls_backend.clone(),
151            inner: Default::default(),
152            transport: Arc::new(None.into()),
153        };
154
155        let cc = CoreCrypto::from(client.clone());
156        let context = cc
157            .new_transaction()
158            .await
159            .map_err(RecursiveError::transaction("starting new transaction"))?;
160
161        if let Some(id) = configuration.client_id {
162            client
163                .init(
164                    ClientIdentifier::Basic(id),
165                    configuration.ciphersuites.as_slice(),
166                    &mls_backend,
167                    configuration
168                        .nb_init_key_packages
169                        .unwrap_or(INITIAL_KEYING_MATERIAL_COUNT),
170                )
171                .await
172                .map_err(RecursiveError::mls_client("initializing mls client"))?
173        }
174
175        let central = cc.mls;
176        context
177            .init_pki_env()
178            .await
179            .map_err(RecursiveError::transaction("initializing pki environment"))?;
180        context
181            .finish()
182            .await
183            .map_err(RecursiveError::transaction("finishing transaction"))?;
184        Ok(central)
185    }
186
187    /// Provide the implementation of functions to communicate with the delivery service
188    /// (see [MlsTransport]).
189    pub async fn provide_transport(&self, transport: Arc<dyn MlsTransport>) {
190        self.transport.write().await.replace(transport);
191    }
192
193    /// Initializes the client.
194    /// If the client's cryptographic material is already stored in the keystore, it loads it
195    /// Otherwise, it is being created.
196    ///
197    /// # Arguments
198    /// * `identifier` - client identifier ; either a [ClientId] or a x509 certificate chain
199    /// * `ciphersuites` - all ciphersuites this client is supposed to support
200    /// * `backend` - the KeyStore and crypto provider to read identities from
201    ///
202    /// # Errors
203    /// KeyStore and OpenMls errors can happen
204    pub async fn init(
205        &self,
206        identifier: ClientIdentifier,
207        ciphersuites: &[MlsCiphersuite],
208        backend: &MlsCryptoProvider,
209        nb_key_package: usize,
210    ) -> Result<()> {
211        self.ensure_unready().await?;
212        let id = identifier.get_id()?;
213
214        let credentials = backend
215            .key_store()
216            .find_all::<MlsCredential>(EntityFindParams::default())
217            .await
218            .map_err(KeystoreError::wrap("finding all mls credentials"))?;
219
220        let credentials = credentials
221            .into_iter()
222            .filter(|mls_credential| mls_credential.id.as_slice() == id.as_slice())
223            .map(|mls_credential| -> Result<_> {
224                let credential = Credential::tls_deserialize(&mut mls_credential.credential.as_slice())
225                    .map_err(Error::tls_deserialize("mls credential"))?;
226                Ok((credential, mls_credential.created_at))
227            })
228            .collect::<Result<Vec<_>>>()?;
229
230        if credentials.is_empty() {
231            debug!(count = nb_key_package, ciphersuites:? = ciphersuites; "Generating client");
232            self.generate(identifier, backend, ciphersuites, nb_key_package).await?;
233        } else {
234            let signature_schemes = ciphersuites
235                .iter()
236                .map(|cs| cs.signature_algorithm())
237                .collect::<HashSet<_>>();
238            let load_result = self.load(backend, id.as_ref(), credentials, signature_schemes).await;
239            if let Err(Error::ClientSignatureNotFound) = load_result {
240                debug!(count = nb_key_package, ciphersuites:? = ciphersuites; "Client signature not found. Generating client");
241                self.generate(identifier, backend, ciphersuites, nb_key_package).await?;
242            } else {
243                load_result?;
244            }
245        };
246
247        Ok(())
248    }
249
250    /// Resets the client to an uninitialized state.
251    #[cfg(test)]
252    pub(crate) async fn reset(&self) {
253        let mut inner_lock = self.inner.write().await;
254        *inner_lock = None;
255    }
256
257    pub(crate) async fn is_ready(&self) -> bool {
258        let inner_lock = self.inner.read().await;
259        inner_lock.is_some()
260    }
261
262    async fn ensure_unready(&self) -> Result<()> {
263        if self.is_ready().await {
264            Err(Error::UnexpectedlyReady)
265        } else {
266            Ok(())
267        }
268    }
269
270    async fn replace_inner(&self, new_inner: SessionInner) {
271        let mut inner_lock = self.inner.write().await;
272        *inner_lock = Some(new_inner);
273    }
274
275    /// Get an immutable view of an `MlsConversation`.
276    ///
277    /// Because it operates on the raw conversation type, this may be faster than [crate::mls::TransactionContext::conversation].
278    /// for transient and immutable purposes. For long-lived or mutable purposes, prefer the other method.
279    pub async fn get_raw_conversation(&self, id: &ConversationId) -> Result<ImmutableConversation> {
280        let raw_conversation = GroupStore::fetch_from_keystore(id, &self.crypto_provider.keystore(), None)
281            .await
282            .map_err(RecursiveError::root("getting conversation by id"))?
283            .ok_or_else(|| LeafError::ConversationNotFound(id.clone()))?;
284        Ok(ImmutableConversation::new(raw_conversation, self.clone()))
285    }
286
287    /// Returns the client's most recent public signature key as a buffer.
288    /// Used to upload a public key to the server in order to verify client's messages signature.
289    ///
290    /// # Arguments
291    /// * `ciphersuite` - a callback to be called to perform authorization
292    /// * `credential_type` - of the credential to look for
293    pub async fn public_key(
294        &self,
295        ciphersuite: MlsCiphersuite,
296        credential_type: MlsCredentialType,
297    ) -> crate::mls::Result<Vec<u8>> {
298        let cb = self
299            .find_most_recent_credential_bundle(ciphersuite.signature_algorithm(), credential_type)
300            .await
301            .map_err(RecursiveError::mls_client("finding most recent credential bundle"))?;
302        Ok(cb.signature_key.to_public_vec())
303    }
304
305    pub(crate) fn new_basic_credential_bundle(
306        id: &ClientId,
307        sc: SignatureScheme,
308        backend: &MlsCryptoProvider,
309    ) -> Result<CredentialBundle> {
310        let (sk, pk) = backend
311            .crypto()
312            .signature_key_gen(sc)
313            .map_err(MlsError::wrap("generating a signature key"))?;
314
315        let signature_key = SignatureKeyPair::from_raw(sc, sk, pk);
316        let credential = Credential::new_basic(id.to_vec());
317        let cb = CredentialBundle {
318            credential,
319            signature_key,
320            created_at: 0,
321        };
322
323        Ok(cb)
324    }
325
326    pub(crate) fn new_x509_credential_bundle(cert: CertificateBundle) -> Result<CredentialBundle> {
327        let created_at = cert
328            .get_created_at()
329            .map_err(RecursiveError::mls_credential("getting credetntial created at"))?;
330        let (sk, ..) = cert.private_key.into_parts();
331        let chain = cert.certificate_chain;
332
333        let kp = CertificateKeyPair::new(sk, chain.clone()).map_err(MlsError::wrap("creating certificate key pair"))?;
334
335        let credential = Credential::new_x509(chain).map_err(MlsError::wrap("creating x509 credential"))?;
336
337        let cb = CredentialBundle {
338            credential,
339            signature_key: kp.0,
340            created_at,
341        };
342        Ok(cb)
343    }
344
345    /// Checks if a given conversation id exists locally
346    pub async fn conversation_exists(&self, id: &ConversationId) -> Result<bool> {
347        match self.get_raw_conversation(id).await {
348            Ok(_) => Ok(true),
349            Err(Error::Leaf(LeafError::ConversationNotFound(_))) => Ok(false),
350            Err(e) => Err(e),
351        }
352    }
353
354    /// Generates a random byte array of the specified size
355    pub fn random_bytes(&self, len: usize) -> crate::mls::Result<Vec<u8>> {
356        use openmls_traits::random::OpenMlsRand as _;
357        self.crypto_provider
358            .rand()
359            .random_vec(len)
360            .map_err(MlsError::wrap("generating random vector"))
361            .map_err(Into::into)
362    }
363
364    /// Reports whether the local KeyStore believes that it can currently close.
365    ///
366    /// Beware TOCTOU!
367    pub async fn can_close(&self) -> bool {
368        self.crypto_provider.can_close().await
369    }
370
371    /// Closes the connection with the local KeyStore
372    ///
373    /// # Errors
374    /// KeyStore errors, such as IO
375    pub async fn close(self) -> crate::mls::Result<()> {
376        self.crypto_provider
377            .close()
378            .await
379            .map_err(MlsError::wrap("closing connection with keystore"))
380            .map_err(Into::into)
381    }
382
383    /// see [mls_crypto_provider::MlsCryptoProvider::reseed]
384    pub async fn reseed(&self, seed: Option<EntropySeed>) -> crate::mls::Result<()> {
385        self.crypto_provider
386            .reseed(seed)
387            .map_err(MlsError::wrap("reseeding mls backend"))
388            .map_err(Into::into)
389    }
390
391    /// Initializes a raw MLS keypair without an associated client ID
392    /// Returns a random ClientId to bind later in [Client::init_with_external_client_id]
393    ///
394    /// # Arguments
395    /// * `ciphersuites` - all ciphersuites this client is supposed to support
396    /// * `backend` - the KeyStore and crypto provider to read identities from
397    ///
398    /// # Errors
399    /// KeyStore and OpenMls errors can happen
400    pub async fn generate_raw_keypairs(
401        &self,
402        ciphersuites: &[MlsCiphersuite],
403        backend: &MlsCryptoProvider,
404    ) -> Result<Vec<ClientId>> {
405        self.ensure_unready().await?;
406        const TEMP_KEY_SIZE: usize = 16;
407
408        let credentials = Self::find_all_basic_credentials(backend).await?;
409        if !credentials.is_empty() {
410            return Err(Error::IdentityAlreadyPresent);
411        }
412
413        use openmls_traits::random::OpenMlsRand as _;
414        // Here we generate a provisional, random, uuid-like random Client ID for no purpose other than database/store constraints
415        let mut tmp_client_ids = Vec::with_capacity(ciphersuites.len());
416        for cs in ciphersuites {
417            let tmp_client_id: ClientId = backend
418                .rand()
419                .random_vec(TEMP_KEY_SIZE)
420                .map_err(MlsError::wrap("generating random client id"))?
421                .into();
422
423            let cb = Self::new_basic_credential_bundle(&tmp_client_id, cs.signature_algorithm(), backend)?;
424
425            let sign_kp = MlsSignatureKeyPair::new(
426                cs.signature_algorithm(),
427                cb.signature_key.to_public_vec(),
428                cb.signature_key
429                    .tls_serialize_detached()
430                    .map_err(Error::tls_serialize("signature key"))?,
431                tmp_client_id.clone().into(),
432            );
433            backend
434                .key_store()
435                .save(sign_kp)
436                .await
437                .map_err(KeystoreError::wrap("save signature keypair in keystore"))?;
438
439            tmp_client_ids.push(tmp_client_id);
440        }
441
442        Ok(tmp_client_ids)
443    }
444
445    /// Finalizes initialization using a 2-step process of uploading first a public key and then associating a new Client ID to that keypair
446    ///
447    /// # Arguments
448    /// * `client_id` - The client ID you have fetched from the MLS Authentication Service
449    /// * `tmp_ids` - The temporary random client ids generated in the previous step [Client::generate_raw_keypairs]
450    /// * `ciphersuites` - To initialize the Client with
451    /// * `backend` - the KeyStore and crypto provider to read identities from
452    ///
453    /// **WARNING**: You have absolutely NO reason to call this if you didn't call [Client::generate_raw_keypairs] first. You have been warned!
454    pub async fn init_with_external_client_id(
455        &self,
456        client_id: ClientId,
457        tmp_ids: Vec<ClientId>,
458        ciphersuites: &[MlsCiphersuite],
459        backend: &MlsCryptoProvider,
460    ) -> Result<()> {
461        self.ensure_unready().await?;
462        // Find all the keypairs, get the ones that exist (or bail), then insert new ones + delete the provisional ones
463        let stored_skp = backend
464            .key_store()
465            .find_all::<MlsSignatureKeyPair>(EntityFindParams::default())
466            .await
467            .map_err(KeystoreError::wrap("finding all mls signature keypairs"))?;
468
469        match stored_skp.len().cmp(&tmp_ids.len()) {
470            std::cmp::Ordering::Less => return Err(Error::NoProvisionalIdentityFound),
471            std::cmp::Ordering::Greater => return Err(Error::TooManyIdentitiesPresent),
472            _ => {}
473        }
474
475        // we verify that the supplied temporary ids are all present in the keypairs we have in store
476        let all_tmp_ids_exist = stored_skp
477            .iter()
478            .all(|kp| tmp_ids.contains(&kp.credential_id.as_slice().into()));
479        if !all_tmp_ids_exist {
480            return Err(Error::NoProvisionalIdentityFound);
481        }
482
483        let identities = stored_skp.iter().zip(ciphersuites);
484
485        self.replace_inner(SessionInner {
486            id: client_id.clone(),
487            identities: Identities::new(stored_skp.len()),
488            keypackage_lifetime: KEYPACKAGE_DEFAULT_LIFETIME,
489            epoch_observer: None,
490        })
491        .await;
492
493        let id = &client_id;
494
495        for (tmp_kp, &cs) in identities {
496            let scheme = tmp_kp
497                .signature_scheme
498                .try_into()
499                .map_err(|_| Error::InvalidSignatureScheme)?;
500            let new_keypair =
501                MlsSignatureKeyPair::new(scheme, tmp_kp.pk.clone(), tmp_kp.keypair.clone(), id.clone().into());
502
503            let new_credential = MlsCredential {
504                id: id.clone().into(),
505                credential: tmp_kp.credential_id.clone(),
506                created_at: 0,
507            };
508
509            // Delete the old identity optimistically
510            backend
511                .key_store()
512                .remove::<MlsSignatureKeyPair, &[u8]>(&new_keypair.pk)
513                .await
514                .map_err(KeystoreError::wrap("removing mls signature keypair"))?;
515
516            let signature_key = SignatureKeyPair::tls_deserialize(&mut new_keypair.keypair.as_slice())
517                .map_err(Error::tls_deserialize("signature key"))?;
518            let cb = CredentialBundle {
519                credential: Credential::new_basic(new_credential.credential.clone()),
520                signature_key,
521                created_at: 0, // this is fine setting a default value here, this will be set in `save_identity` to the current timestamp
522            };
523
524            // And now we save the new one
525            self.save_identity(&backend.keystore(), Some(id), cs.signature_algorithm(), cb)
526                .await?;
527        }
528
529        Ok(())
530    }
531
532    /// Generates a brand new client from scratch
533    pub(crate) async fn generate(
534        &self,
535        identifier: ClientIdentifier,
536        backend: &MlsCryptoProvider,
537        ciphersuites: &[MlsCiphersuite],
538        nb_key_package: usize,
539    ) -> Result<()> {
540        self.ensure_unready().await?;
541        let id = identifier.get_id()?;
542        let signature_schemes = ciphersuites
543            .iter()
544            .map(|cs| cs.signature_algorithm())
545            .collect::<HashSet<_>>();
546        self.replace_inner(SessionInner {
547            id: id.into_owned(),
548            identities: Identities::new(signature_schemes.len()),
549            keypackage_lifetime: KEYPACKAGE_DEFAULT_LIFETIME,
550            epoch_observer: None,
551        })
552        .await;
553
554        let identities = identifier.generate_credential_bundles(backend, signature_schemes)?;
555
556        for (sc, id, cb) in identities {
557            self.save_identity(&backend.keystore(), Some(&id), sc, cb).await?;
558        }
559
560        let guard = self.inner.read().await;
561        let SessionInner { identities, .. } = guard.as_ref().ok_or(Error::MlsNotInitialized)?;
562
563        if nb_key_package != 0 {
564            for ciphersuite in ciphersuites.iter().copied() {
565                let ciphersuite_signature_scheme = ciphersuite.signature_algorithm();
566                for credential_bundle in identities.iter().filter_map(|(signature_scheme, credential_bundle)| {
567                    (signature_scheme == ciphersuite_signature_scheme).then_some(credential_bundle)
568                }) {
569                    let credential_type = credential_bundle.credential.credential_type().into();
570                    self.request_key_packages(nb_key_package, ciphersuite, credential_type, backend)
571                        .await?;
572                }
573            }
574        }
575
576        Ok(())
577    }
578
579    /// Loads the client from the keystore.
580    pub(crate) async fn load(
581        &self,
582        backend: &MlsCryptoProvider,
583        id: &ClientId,
584        mut credentials: Vec<(Credential, u64)>,
585        signature_schemes: HashSet<SignatureScheme>,
586    ) -> Result<()> {
587        self.ensure_unready().await?;
588        let mut identities = Identities::new(signature_schemes.len());
589
590        // ensures we load credentials in chronological order
591        credentials.sort_by_key(|(_, timestamp)| *timestamp);
592
593        let stored_signature_keypairs = backend
594            .key_store()
595            .find_all::<MlsSignatureKeyPair>(EntityFindParams::default())
596            .await
597            .map_err(KeystoreError::wrap("finding all mls signature keypairs"))?;
598
599        for signature_scheme in signature_schemes {
600            let signature_keypair = stored_signature_keypairs
601                .iter()
602                .find(|skp| skp.signature_scheme == (signature_scheme as u16));
603
604            let signature_key = if let Some(kp) = signature_keypair {
605                SignatureKeyPair::tls_deserialize(&mut kp.keypair.as_slice())
606                    .map_err(Error::tls_deserialize("signature keypair"))?
607            } else {
608                let (private_key, public_key) = backend
609                    .crypto()
610                    .signature_key_gen(signature_scheme)
611                    .map_err(MlsError::wrap("generating signature key"))?;
612                let keypair = SignatureKeyPair::from_raw(signature_scheme, private_key, public_key.clone());
613                let raw_keypair = keypair
614                    .tls_serialize_detached()
615                    .map_err(Error::tls_serialize("raw keypair"))?;
616                let store_keypair =
617                    MlsSignatureKeyPair::new(signature_scheme, public_key, raw_keypair, id.as_slice().into());
618                backend
619                    .key_store()
620                    .save(store_keypair.clone())
621                    .await
622                    .map_err(KeystoreError::wrap("storing keypairs in keystore"))?;
623                SignatureKeyPair::tls_deserialize(&mut store_keypair.keypair.as_slice())
624                    .map_err(Error::tls_deserialize("signature keypair"))?
625            };
626
627            for (credential, created_at) in &credentials {
628                match credential.mls_credential() {
629                    openmls::prelude::MlsCredentialType::Basic(_) => {
630                        if id.as_slice() != credential.identity() {
631                            return Err(Error::WrongCredential);
632                        }
633                    }
634                    openmls::prelude::MlsCredentialType::X509(cert) => {
635                        let spk = cert
636                            .extract_public_key()
637                            .map_err(RecursiveError::mls_credential("extracting public key"))?
638                            .ok_or(LeafError::InternalMlsError)?;
639                        if signature_key.public() != spk {
640                            return Err(Error::WrongCredential);
641                        }
642                    }
643                };
644                let cb = CredentialBundle {
645                    credential: credential.clone(),
646                    signature_key: signature_key.clone(),
647                    created_at: *created_at,
648                };
649                identities.push_credential_bundle(signature_scheme, cb).await?;
650            }
651        }
652        self.replace_inner(SessionInner {
653            id: id.clone(),
654            identities,
655            keypackage_lifetime: KEYPACKAGE_DEFAULT_LIFETIME,
656            epoch_observer: None,
657        })
658        .await;
659        Ok(())
660    }
661
662    /// Restore from an external [`HistorySecret`].
663    pub(crate) async fn restore_from_history_secret(&self, history_secret: HistorySecret) -> Result<()> {
664        self.ensure_unready().await?;
665
666        // store the client id (with some other stuff)
667        self.replace_inner(SessionInner {
668            id: history_secret.client_id.clone(),
669            identities: Identities::new(1),
670            keypackage_lifetime: KEYPACKAGE_DEFAULT_LIFETIME,
671            epoch_observer: None,
672        })
673        .await;
674
675        // store the key package
676        let key_package = history_secret
677            .key_package
678            .store(&self.crypto_provider)
679            .await
680            .map_err(MlsError::wrap("storing key package encapsulation"))?;
681
682        let keystore = self.crypto_provider.key_store();
683
684        // store the credential bundle (with some other stuff)
685        self.save_identity(
686            keystore,
687            Some(&history_secret.client_id),
688            key_package.ciphersuite().signature_algorithm(),
689            history_secret.credential_bundle,
690        )
691        .await?;
692
693        Ok(())
694    }
695
696    async fn find_all_basic_credentials(backend: &MlsCryptoProvider) -> Result<Vec<Credential>> {
697        let store_credentials = backend
698            .key_store()
699            .find_all::<MlsCredential>(EntityFindParams::default())
700            .await
701            .map_err(KeystoreError::wrap("finding all mls credentialss"))?;
702        let mut credentials = Vec::with_capacity(store_credentials.len());
703        for store_credential in store_credentials.into_iter() {
704            let credential = Credential::tls_deserialize(&mut store_credential.credential.as_slice())
705                .map_err(Error::tls_deserialize("credential"))?;
706            if !matches!(credential.credential_type(), CredentialType::Basic) {
707                continue;
708            }
709            credentials.push(credential);
710        }
711
712        Ok(credentials)
713    }
714
715    pub(crate) async fn save_identity(
716        &self,
717        keystore: &Connection,
718        id: Option<&ClientId>,
719        signature_scheme: SignatureScheme,
720        mut credential_bundle: CredentialBundle,
721    ) -> Result<CredentialBundle> {
722        let mut guard = self.inner.write().await;
723        let SessionInner {
724            id: existing_id,
725            identities,
726            ..
727        } = guard.as_mut().ok_or(Error::MlsNotInitialized)?;
728
729        let id = id.unwrap_or(existing_id);
730
731        let credential = credential_bundle
732            .credential
733            .tls_serialize_detached()
734            .map_err(Error::tls_serialize("credential bundle"))?;
735        let credential = MlsCredential {
736            id: id.clone().into(),
737            credential,
738            created_at: 0,
739        };
740
741        let credential = keystore
742            .save(credential)
743            .await
744            .map_err(KeystoreError::wrap("saving credential"))?;
745
746        let sign_kp = MlsSignatureKeyPair::new(
747            signature_scheme,
748            credential_bundle.signature_key.to_public_vec(),
749            credential_bundle
750                .signature_key
751                .tls_serialize_detached()
752                .map_err(Error::tls_serialize("signature keypair"))?,
753            id.clone().into(),
754        );
755        keystore.save(sign_kp).await.map_err(|e| match e {
756            CryptoKeystoreError::AlreadyExists => Error::CredentialBundleConflict,
757            _ => KeystoreError::wrap("saving mls signature key pair")(e).into(),
758        })?;
759
760        // set the creation date of the signature keypair which is the same for the CredentialBundle
761        credential_bundle.created_at = credential.created_at;
762
763        identities
764            .push_credential_bundle(signature_scheme, credential_bundle.clone())
765            .await?;
766
767        Ok(credential_bundle)
768    }
769
770    /// Retrieves the client's client id. This is free-form and not inspected.
771    pub async fn id(&self) -> Result<ClientId> {
772        match self.inner.read().await.deref() {
773            None => Err(Error::MlsNotInitialized),
774            Some(SessionInner { id, .. }) => Ok(id.clone()),
775        }
776    }
777
778    /// Returns whether this client is E2EI capable
779    pub async fn is_e2ei_capable(&self) -> bool {
780        match self.inner.read().await.deref() {
781            None => false,
782            Some(SessionInner { identities, .. }) => identities
783                .iter()
784                .any(|(_, cred)| cred.credential().credential_type() == CredentialType::X509),
785        }
786    }
787
788    pub(crate) async fn get_most_recent_or_create_credential_bundle(
789        &self,
790        backend: &MlsCryptoProvider,
791        sc: SignatureScheme,
792        ct: MlsCredentialType,
793    ) -> Result<Arc<CredentialBundle>> {
794        match ct {
795            MlsCredentialType::Basic => {
796                self.init_basic_credential_bundle_if_missing(backend, sc).await?;
797                self.find_most_recent_credential_bundle(sc, ct).await
798            }
799            MlsCredentialType::X509 => self
800                .find_most_recent_credential_bundle(sc, ct)
801                .await
802                .map_err(|e| match e {
803                    Error::CredentialNotFound(_) => LeafError::E2eiEnrollmentNotDone.into(),
804                    _ => e,
805                }),
806        }
807    }
808
809    pub(crate) async fn init_basic_credential_bundle_if_missing(
810        &self,
811        backend: &MlsCryptoProvider,
812        sc: SignatureScheme,
813    ) -> Result<()> {
814        let existing_cb = self
815            .find_most_recent_credential_bundle(sc, MlsCredentialType::Basic)
816            .await;
817        if matches!(existing_cb, Err(Error::CredentialNotFound(_))) {
818            let id = self.id().await?;
819            debug!(id:% = &id; "Initializing basic credential bundle");
820            let cb = Self::new_basic_credential_bundle(&id, sc, backend)?;
821            self.save_identity(&backend.keystore(), None, sc, cb).await?;
822        }
823        Ok(())
824    }
825
826    pub(crate) async fn save_new_x509_credential_bundle(
827        &self,
828        keystore: &Connection,
829        sc: SignatureScheme,
830        cb: CertificateBundle,
831    ) -> Result<CredentialBundle> {
832        let id = cb
833            .get_client_id()
834            .map_err(RecursiveError::mls_credential("getting client id"))?;
835        let cb = Self::new_x509_credential_bundle(cb)?;
836        self.save_identity(keystore, Some(&id), sc, cb).await
837    }
838}
839
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prelude::ClientId;
    use crate::test_utils::*;
    use crate::transaction_context::test_utils::EntitiesCount;
    use core_crypto_keystore::connection::{DatabaseKey, FetchFromDatabase};
    use core_crypto_keystore::entities::*;
    use mls_crypto_provider::MlsCryptoProvider;
    use wasm_bindgen_test::*;

    impl Session {
        // test functions are not held to the same documentation standard as proper functions
        #![allow(missing_docs)]

        /// Resets this session and re-initializes it with a randomly generated client id.
        ///
        /// For X509 test cases `signer` is used as the issuing intermediate CA and must be
        /// provided, otherwise this panics. When `provision` is true,
        /// `INITIAL_KEYING_MATERIAL_COUNT` is passed to `generate` as the key package count;
        /// otherwise 0.
        pub async fn random_generate(
            &self,
            case: &crate::test_utils::TestContext,
            signer: Option<&crate::test_utils::x509::X509Certificate>,
            provision: bool,
        ) -> Result<()> {
            self.reset().await;
            let user_uuid = uuid::Uuid::new_v4();
            let rnd_id = rand::random::<usize>();
            let client_id = format!("{}:{rnd_id:x}@members.wire.com", user_uuid.hyphenated());
            let identity = match case.credential_type {
                MlsCredentialType::Basic => ClientIdentifier::Basic(client_id.as_str().into()),
                MlsCredentialType::X509 => {
                    let signer = signer.expect("Missing intermediate CA");
                    CertificateBundle::rand_identifier(&client_id, &[signer])
                }
            };
            let nb_key_package = if provision {
                crate::prelude::INITIAL_KEYING_MATERIAL_COUNT
            } else {
                0
            };
            let backend = self.crypto_provider.clone();
            self.generate(identity, &backend, &[case.ciphersuite()], nb_key_package)
                .await?;
            Ok(())
        }

        /// Fetches the MLS key packages stored in `backend`'s keystore (limit `u32::MAX`).
        pub async fn find_keypackages(&self, backend: &MlsCryptoProvider) -> Result<Vec<openmls::prelude::KeyPackage>> {
            use core_crypto_keystore::CryptoKeystoreMls as _;
            let kps = backend
                .key_store()
                .mls_fetch_keypackages::<openmls::prelude::KeyPackage>(u32::MAX)
                .await
                .map_err(KeystoreError::wrap("fetching mls keypackages"))?;
            Ok(kps)
        }

        /// Saves `cb` as a new X509 credential bundle unless one already exists for `sc`.
        ///
        /// Any lookup error, not only `CredentialNotFound`, is treated as "missing".
        pub(crate) async fn init_x509_credential_bundle_if_missing(
            &self,
            backend: &MlsCryptoProvider,
            sc: SignatureScheme,
            cb: CertificateBundle,
        ) -> Result<()> {
            let is_missing = self
                .find_most_recent_credential_bundle(sc, MlsCredentialType::X509)
                .await
                .is_err();
            if is_missing {
                self.save_new_x509_credential_bundle(&backend.keystore(), sc, cb)
                    .await?;
            }
            Ok(())
        }

        /// Generates a single key package for `cs` from the most recent credential bundle of type `ct`.
        pub(crate) async fn generate_one_keypackage(
            &self,
            backend: &MlsCryptoProvider,
            cs: MlsCiphersuite,
            ct: MlsCredentialType,
        ) -> Result<openmls::prelude::KeyPackage> {
            let cb = self
                .find_most_recent_credential_bundle(cs.signature_algorithm(), ct)
                .await?;
            self.generate_one_keypackage_from_credential_bundle(backend, cs, &cb)
                .await
        }

        /// Counts the stored keystore entities of each MLS-related type. Panics if any count fails.
        pub async fn count_entities(&self) -> EntitiesCount {
            let keystore = self.crypto_provider.keystore();
            let credential = keystore.count::<MlsCredential>().await.unwrap();
            let encryption_keypair = keystore.count::<MlsEncryptionKeyPair>().await.unwrap();
            let epoch_encryption_keypair = keystore.count::<MlsEpochEncryptionKeyPair>().await.unwrap();
            let enrollment = keystore.count::<E2eiEnrollment>().await.unwrap();
            let group = keystore.count::<PersistedMlsGroup>().await.unwrap();
            let hpke_private_key = keystore.count::<MlsHpkePrivateKey>().await.unwrap();
            let key_package = keystore.count::<MlsKeyPackage>().await.unwrap();
            let pending_group = keystore.count::<PersistedMlsPendingGroup>().await.unwrap();
            let pending_messages = keystore.count::<MlsPendingMessage>().await.unwrap();
            let psk_bundle = keystore.count::<MlsPskBundle>().await.unwrap();
            let signature_keypair = keystore.count::<MlsSignatureKeyPair>().await.unwrap();
            EntitiesCount {
                credential,
                encryption_keypair,
                epoch_encryption_keypair,
                enrollment,
                group,
                hpke_private_key,
                key_package,
                pending_group,
                pending_messages,
                psk_bundle,
                signature_keypair,
            }
        }
    }
    wasm_bindgen_test_configure!(run_in_browser);

    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    async fn can_generate_session(case: TestContext) {
        let [alice] = case.sessions().await;
        // NOTE(review): `backend` is separate from alice's own crypto provider, which is what
        // `random_generate` uses. Here the X509 chain only supplies the intermediate CA.
        let key = DatabaseKey::generate();
        let backend = MlsCryptoProvider::try_new_in_memory(&key).await.unwrap();
        let x509_test_chain = if case.is_x509() {
            let x509_test_chain = crate::test_utils::x509::X509TestChain::init_empty(case.signature_scheme());
            x509_test_chain.register_with_provider(&backend).await;
            Some(x509_test_chain)
        } else {
            None
        };
        backend.new_transaction().await.unwrap();
        let session = alice.session().await;
        session
            .random_generate(
                &case,
                x509_test_chain.as_ref().map(|chain| chain.find_local_intermediate_ca()),
                false,
            )
            .await
            .unwrap();
    }

    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    async fn can_externally_generate_client(case: TestContext) {
        // Only basic credentials are exercised; bail out before creating any sessions.
        if !case.is_basic() {
            return;
        }
        let [alice] = case.sessions().await;
        run_tests(move |[tmp_dir_argument]| {
            Box::pin(async move {
                let key = DatabaseKey::generate();
                let backend = MlsCryptoProvider::try_new(tmp_dir_argument, &key).await.unwrap();
                backend.new_transaction().await.unwrap();
                // phase 1: generate standalone keypair
                let client_id: ClientId = b"whatever:my:client:is@world.com".to_vec().into();
                let alice = alice.session().await;
                alice.reset().await;
                // TODO: test with multi-ciphersuite. Tracking issue: WPB-9601
                let handles = alice
                    .generate_raw_keypairs(&[case.ciphersuite()], &backend)
                    .await
                    .unwrap();

                let mut identities = backend
                    .keystore()
                    .find_all::<MlsSignatureKeyPair>(EntityFindParams::default())
                    .await
                    .unwrap();

                assert_eq!(identities.len(), 1);

                let prov_identity = identities.pop().unwrap();

                // Make sure we are actually returning the clientId
                // TODO: test with multi-ciphersuite. Tracking issue: WPB-9601
                let prov_client_id: ClientId = prov_identity.credential_id.as_slice().into();
                assert_eq!(&prov_client_id, handles.first().unwrap());

                // phase 2: pretend we have a new client ID from the backend, and try to init the client this way
                alice
                    .init_with_external_client_id(client_id.clone(), handles.clone(), &[case.ciphersuite()], &backend)
                    .await
                    .unwrap();

                // Make sure both client id and PK are intact
                assert_eq!(alice.id().await.unwrap(), client_id);
                let cb = alice
                    .find_most_recent_credential_bundle(case.signature_scheme(), case.credential_type)
                    .await
                    .unwrap();
                let client_id: ClientId = cb.credential().identity().into();
                assert_eq!(&client_id, handles.first().unwrap());
            })
        })
        .await
    }
}