core_crypto/mls/session/
mod.rs

1pub(crate) mod e2e_identity;
2mod epoch_observer;
3mod error;
4pub(crate) mod id;
5pub(crate) mod identifier;
6pub(crate) mod identities;
7pub(crate) mod key_package;
8pub(crate) mod user_id;
9
10use crate::{
11    CoreCrypto, KeystoreError, LeafError, MlsError, MlsTransport, RecursiveError,
12    group_store::GroupStore,
13    mls::{
14        self, HasSessionAndCrypto,
15        conversation::ImmutableConversation,
16        credential::{CredentialBundle, ext::CredentialExt},
17    },
18    prelude::{
19        CertificateBundle, ClientId, ConversationId, INITIAL_KEYING_MATERIAL_COUNT, MlsCiphersuite,
20        MlsClientConfiguration, MlsCredentialType, identifier::ClientIdentifier,
21        key_package::KEYPACKAGE_DEFAULT_LIFETIME,
22    },
23};
24use async_lock::RwLock;
25use core_crypto_keystore::{
26    Connection, CryptoKeystoreError,
27    connection::FetchFromDatabase,
28    entities::{EntityFindParams, MlsCredential, MlsSignatureKeyPair},
29};
30pub use epoch_observer::EpochObserver;
31pub(crate) use error::{Error, Result};
32use identities::Identities;
33use log::debug;
34use mls_crypto_provider::{EntropySeed, MlsCryptoProvider, MlsCryptoProviderConfiguration};
35use openmls::prelude::{Credential, CredentialType};
36use openmls_basic_credential::SignatureKeyPair;
37use openmls_traits::{OpenMlsCryptoProvider, crypto::OpenMlsCrypto, types::SignatureScheme};
38use openmls_x509_credential::CertificateKeyPair;
39use std::ops::{Deref, DerefMut};
40use std::sync::Arc;
41use std::{collections::HashSet, fmt};
42use tls_codec::{Deserialize, Serialize};
43
44/// A MLS Session enables a user device to communicate via the MLS protocol.
45///
46/// This closely maps to the `Client` term in [RFC 9720], but we avoid that term to avoid ambiguity;
47/// `Client` is very overloaded with distinct meanings.
48///
49/// There is one `Session` per user per device. A session can contain many MLS groups/conversations.
50///
51/// It is cheap to clone a `Session` because everything heavy is wrapped inside an [Arc].
52///
53/// [RFC 9720]: https://www.rfc-editor.org/rfc/rfc9420.html
54#[derive(Clone, Debug)]
55pub struct Session {
56    pub(crate) inner: Arc<RwLock<Option<SessionInner>>>,
57    pub(crate) crypto_provider: MlsCryptoProvider,
58    pub(crate) transport: Arc<RwLock<Option<Arc<dyn MlsTransport + 'static>>>>,
59}
60
61#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
62#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
63impl HasSessionAndCrypto for Session {
64    async fn session(&self) -> mls::Result<Session> {
65        Ok(self.clone())
66    }
67
68    async fn crypto_provider(&self) -> mls::Result<MlsCryptoProvider> {
69        Ok(self.crypto_provider.clone())
70    }
71}
72
73#[derive(Clone)]
74pub(crate) struct SessionInner {
75    id: ClientId,
76    pub(crate) identities: Identities,
77    keypackage_lifetime: std::time::Duration,
78    epoch_observer: Option<Arc<dyn EpochObserver>>,
79}
80
81impl fmt::Debug for SessionInner {
82    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
83        let observer_debug = if self.epoch_observer.is_some() {
84            "Some(Arc<dyn EpochObserver>)"
85        } else {
86            "None"
87        };
88        f.debug_struct("ClientInner")
89            .field("id", &self.id)
90            .field("identities", &self.identities)
91            .field("keypackage_lifetime", &self.keypackage_lifetime)
92            .field("epoch_observer", &observer_debug)
93            .finish()
94    }
95}
96
97impl Session {
98    /// Tries to initialize the [Client].
99    /// Takes a store path (i.e. Disk location of the embedded database, should be consistent between messaging sessions)
100    /// And a root identity key (i.e. enclaved encryption key for this device)
101    ///
102    /// # Arguments
103    /// * `configuration` - the configuration for the `MlsCentral`
104    ///
105    /// # Errors
106    /// Failures in the initialization of the KeyStore can cause errors, such as IO, the same kind
107    /// of errors can happen when the groups are being restored from the KeyStore or even during
108    /// the client initialization (to fetch the identity signature). Other than that, `MlsError`
109    /// can be caused by group deserialization or during the initialization of the credentials:
110    /// * for x509 Credentials if the certificate chain length is lower than 2
111    /// * for Basic Credentials if the signature key cannot be generated either by not supported
112    ///   scheme or the key generation fails
113    pub async fn try_new(configuration: MlsClientConfiguration) -> crate::mls::Result<Self> {
114        // Init backend (crypto + rand + keystore)
115        let mls_backend = MlsCryptoProvider::try_new_with_configuration(MlsCryptoProviderConfiguration {
116            db_path: &configuration.store_path,
117            db_key: configuration.database_key.clone(),
118            in_memory: false,
119            entropy_seed: configuration.external_entropy.clone(),
120        })
121        .await
122        .map_err(MlsError::wrap("trying to initialize mls crypto provider object"))?;
123        Self::new_with_backend(mls_backend, configuration).await
124    }
125
126    /// Same as the [Client::try_new] but instead, it uses an in memory KeyStore.
127    /// Although required, the `store_path` parameter from the `MlsClientConfiguration` won't be used here.
128    pub async fn try_new_in_memory(configuration: MlsClientConfiguration) -> crate::mls::Result<Self> {
129        let mls_backend = MlsCryptoProvider::try_new_with_configuration(MlsCryptoProviderConfiguration {
130            db_path: &configuration.store_path,
131            db_key: configuration.database_key.clone(),
132            in_memory: true,
133            entropy_seed: configuration.external_entropy.clone(),
134        })
135        .await
136        .map_err(MlsError::wrap(
137            "trying to initialize mls crypto provider object (in memory)",
138        ))?;
139        Self::new_with_backend(mls_backend, configuration).await
140    }
141
142    async fn new_with_backend(
143        mls_backend: MlsCryptoProvider,
144        configuration: MlsClientConfiguration,
145    ) -> crate::mls::Result<Self> {
146        // We create the core crypto instance first to enable creating a transaction from it and
147        // doing all subsequent actions inside a single transaction, though it forces us to clone
148        // a few Arcs and locks.
149        let client = Self {
150            crypto_provider: mls_backend.clone(),
151            inner: Default::default(),
152            transport: Arc::new(None.into()),
153        };
154
155        let cc = CoreCrypto::from(client.clone());
156        let context = cc
157            .new_transaction()
158            .await
159            .map_err(RecursiveError::transaction("starting new transaction"))?;
160
161        if let Some(id) = configuration.client_id {
162            client
163                .init(
164                    ClientIdentifier::Basic(id),
165                    configuration.ciphersuites.as_slice(),
166                    &mls_backend,
167                    configuration
168                        .nb_init_key_packages
169                        .unwrap_or(INITIAL_KEYING_MATERIAL_COUNT),
170                )
171                .await
172                .map_err(RecursiveError::mls_client("initializing mls client"))?
173        }
174
175        let central = cc.mls;
176        context
177            .init_pki_env()
178            .await
179            .map_err(RecursiveError::transaction("initializing pki environment"))?;
180        context
181            .finish()
182            .await
183            .map_err(RecursiveError::transaction("finishing transaction"))?;
184        Ok(central)
185    }
186
187    /// Provide the implementation of functions to communicate with the delivery service
188    /// (see [MlsTransport]).
189    pub async fn provide_transport(&self, transport: Arc<dyn MlsTransport>) {
190        self.transport.write().await.replace(transport);
191    }
192
193    /// Initializes the client.
194    /// If the client's cryptographic material is already stored in the keystore, it loads it
195    /// Otherwise, it is being created.
196    ///
197    /// # Arguments
198    /// * `identifier` - client identifier ; either a [ClientId] or a x509 certificate chain
199    /// * `ciphersuites` - all ciphersuites this client is supposed to support
200    /// * `backend` - the KeyStore and crypto provider to read identities from
201    ///
202    /// # Errors
203    /// KeyStore and OpenMls errors can happen
204    pub async fn init(
205        &self,
206        identifier: ClientIdentifier,
207        ciphersuites: &[MlsCiphersuite],
208        backend: &MlsCryptoProvider,
209        nb_key_package: usize,
210    ) -> Result<()> {
211        self.ensure_unready().await?;
212        let id = identifier.get_id()?;
213
214        let credentials = backend
215            .key_store()
216            .find_all::<MlsCredential>(EntityFindParams::default())
217            .await
218            .map_err(KeystoreError::wrap("finding all mls credentials"))?;
219
220        let credentials = credentials
221            .into_iter()
222            .filter(|mls_credential| &mls_credential.id[..] == id.as_slice())
223            .map(|mls_credential| -> Result<_> {
224                let credential = Credential::tls_deserialize(&mut mls_credential.credential.as_slice())
225                    .map_err(Error::tls_deserialize("mls credential"))?;
226                Ok((credential, mls_credential.created_at))
227            })
228            .collect::<Result<Vec<_>>>()?;
229
230        if !credentials.is_empty() {
231            let signature_schemes = ciphersuites
232                .iter()
233                .map(|cs| cs.signature_algorithm())
234                .collect::<HashSet<_>>();
235            match self.load(backend, id.as_ref(), credentials, signature_schemes).await {
236                Ok(client) => client,
237                Err(Error::ClientSignatureNotFound) => {
238                    debug!(count = nb_key_package, ciphersuites:? = ciphersuites; "Client signature not found. Generating client");
239                    self.generate(identifier, backend, ciphersuites, nb_key_package).await?
240                }
241                Err(e) => return Err(e),
242            }
243        } else {
244            debug!(count = nb_key_package, ciphersuites:? = ciphersuites; "Generating client");
245            self.generate(identifier, backend, ciphersuites, nb_key_package).await?
246        };
247
248        Ok(())
249    }
250
251    /// Resets the client to an uninitialized state.
252    #[cfg(test)]
253    pub(crate) async fn reset(&self) {
254        let mut inner_lock = self.inner.write().await;
255        *inner_lock = None;
256    }
257
258    pub(crate) async fn is_ready(&self) -> bool {
259        let inner_lock = self.inner.read().await;
260        inner_lock.is_some()
261    }
262
263    async fn ensure_unready(&self) -> Result<()> {
264        if self.is_ready().await {
265            Err(Error::UnexpectedlyReady)
266        } else {
267            Ok(())
268        }
269    }
270
271    async fn replace_inner(&self, new_inner: SessionInner) {
272        let mut inner_lock = self.inner.write().await;
273        *inner_lock = Some(new_inner);
274    }
275
276    /// Get an immutable view of an `MlsConversation`.
277    ///
278    /// Because it operates on the raw conversation type, this may be faster than [crate::mls::TransactionContext::conversation].
279    /// for transient and immutable purposes. For long-lived or mutable purposes, prefer the other method.
280    pub async fn get_raw_conversation(&self, id: &ConversationId) -> Result<ImmutableConversation> {
281        let raw_conversation = GroupStore::fetch_from_keystore(id, &self.crypto_provider.keystore(), None)
282            .await
283            .map_err(RecursiveError::root("getting conversation by id"))?
284            .ok_or_else(|| LeafError::ConversationNotFound(id.clone()))?;
285        Ok(ImmutableConversation::new(raw_conversation, self.clone()))
286    }
287
288    /// Returns the client's most recent public signature key as a buffer.
289    /// Used to upload a public key to the server in order to verify client's messages signature.
290    ///
291    /// # Arguments
292    /// * `ciphersuite` - a callback to be called to perform authorization
293    /// * `credential_type` - of the credential to look for
294    pub async fn public_key(
295        &self,
296        ciphersuite: MlsCiphersuite,
297        credential_type: MlsCredentialType,
298    ) -> crate::mls::Result<Vec<u8>> {
299        let cb = self
300            .find_most_recent_credential_bundle(ciphersuite.signature_algorithm(), credential_type)
301            .await
302            .map_err(RecursiveError::mls_client("finding most recent credential bundle"))?;
303        Ok(cb.signature_key.to_public_vec())
304    }
305
306    pub(crate) fn new_basic_credential_bundle(
307        id: &ClientId,
308        sc: SignatureScheme,
309        backend: &MlsCryptoProvider,
310    ) -> Result<CredentialBundle> {
311        let (sk, pk) = backend
312            .crypto()
313            .signature_key_gen(sc)
314            .map_err(MlsError::wrap("generating a signature key"))?;
315
316        let signature_key = SignatureKeyPair::from_raw(sc, sk, pk);
317        let credential = Credential::new_basic(id.to_vec());
318        let cb = CredentialBundle {
319            credential,
320            signature_key,
321            created_at: 0,
322        };
323
324        Ok(cb)
325    }
326
327    pub(crate) fn new_x509_credential_bundle(cert: CertificateBundle) -> Result<CredentialBundle> {
328        let created_at = cert
329            .get_created_at()
330            .map_err(RecursiveError::mls_credential("getting credetntial created at"))?;
331        let (sk, ..) = cert.private_key.into_parts();
332        let chain = cert.certificate_chain;
333
334        let kp = CertificateKeyPair::new(sk, chain.clone()).map_err(MlsError::wrap("creating certificate key pair"))?;
335
336        let credential = Credential::new_x509(chain).map_err(MlsError::wrap("creating x509 credential"))?;
337
338        let cb = CredentialBundle {
339            credential,
340            signature_key: kp.0,
341            created_at,
342        };
343        Ok(cb)
344    }
345
346    /// Checks if a given conversation id exists locally
347    pub async fn conversation_exists(&self, id: &ConversationId) -> Result<bool> {
348        match self.get_raw_conversation(id).await {
349            Ok(_) => Ok(true),
350            Err(Error::Leaf(LeafError::ConversationNotFound(_))) => Ok(false),
351            Err(e) => Err(e),
352        }
353    }
354
355    /// Generates a random byte array of the specified size
356    pub fn random_bytes(&self, len: usize) -> crate::mls::Result<Vec<u8>> {
357        use openmls_traits::random::OpenMlsRand as _;
358        self.crypto_provider
359            .rand()
360            .random_vec(len)
361            .map_err(MlsError::wrap("generating random vector"))
362            .map_err(Into::into)
363    }
364
365    /// Reports whether the local KeyStore believes that it can currently close.
366    ///
367    /// Beware TOCTOU!
368    pub async fn can_close(&self) -> bool {
369        self.crypto_provider.can_close().await
370    }
371
372    /// Closes the connection with the local KeyStore
373    ///
374    /// # Errors
375    /// KeyStore errors, such as IO
376    pub async fn close(self) -> crate::mls::Result<()> {
377        self.crypto_provider
378            .close()
379            .await
380            .map_err(MlsError::wrap("closing connection with keystore"))
381            .map_err(Into::into)
382    }
383
384    /// see [mls_crypto_provider::MlsCryptoProvider::reseed]
385    pub async fn reseed(&self, seed: Option<EntropySeed>) -> crate::mls::Result<()> {
386        self.crypto_provider
387            .reseed(seed)
388            .map_err(MlsError::wrap("reseeding mls backend"))
389            .map_err(Into::into)
390    }
391
392    /// Initializes a raw MLS keypair without an associated client ID
393    /// Returns a random ClientId to bind later in [Client::init_with_external_client_id]
394    ///
395    /// # Arguments
396    /// * `ciphersuites` - all ciphersuites this client is supposed to support
397    /// * `backend` - the KeyStore and crypto provider to read identities from
398    ///
399    /// # Errors
400    /// KeyStore and OpenMls errors can happen
401    pub async fn generate_raw_keypairs(
402        &self,
403        ciphersuites: &[MlsCiphersuite],
404        backend: &MlsCryptoProvider,
405    ) -> Result<Vec<ClientId>> {
406        self.ensure_unready().await?;
407        const TEMP_KEY_SIZE: usize = 16;
408
409        let credentials = Self::find_all_basic_credentials(backend).await?;
410        if !credentials.is_empty() {
411            return Err(Error::IdentityAlreadyPresent);
412        }
413
414        use openmls_traits::random::OpenMlsRand as _;
415        // Here we generate a provisional, random, uuid-like random Client ID for no purpose other than database/store constraints
416        let mut tmp_client_ids = Vec::with_capacity(ciphersuites.len());
417        for cs in ciphersuites {
418            let tmp_client_id: ClientId = backend
419                .rand()
420                .random_vec(TEMP_KEY_SIZE)
421                .map_err(MlsError::wrap("generating random client id"))?
422                .into();
423
424            let cb = Self::new_basic_credential_bundle(&tmp_client_id, cs.signature_algorithm(), backend)?;
425
426            let sign_kp = MlsSignatureKeyPair::new(
427                cs.signature_algorithm(),
428                cb.signature_key.to_public_vec(),
429                cb.signature_key
430                    .tls_serialize_detached()
431                    .map_err(Error::tls_serialize("signature key"))?,
432                tmp_client_id.clone().into(),
433            );
434            backend
435                .key_store()
436                .save(sign_kp)
437                .await
438                .map_err(KeystoreError::wrap("save signature keypair in keystore"))?;
439
440            tmp_client_ids.push(tmp_client_id);
441        }
442
443        Ok(tmp_client_ids)
444    }
445
446    /// Finalizes initialization using a 2-step process of uploading first a public key and then associating a new Client ID to that keypair
447    ///
448    /// # Arguments
449    /// * `client_id` - The client ID you have fetched from the MLS Authentication Service
450    /// * `tmp_ids` - The temporary random client ids generated in the previous step [Client::generate_raw_keypairs]
451    /// * `ciphersuites` - To initialize the Client with
452    /// * `backend` - the KeyStore and crypto provider to read identities from
453    ///
454    /// **WARNING**: You have absolutely NO reason to call this if you didn't call [Client::generate_raw_keypairs] first. You have been warned!
455    pub async fn init_with_external_client_id(
456        &self,
457        client_id: ClientId,
458        tmp_ids: Vec<ClientId>,
459        ciphersuites: &[MlsCiphersuite],
460        backend: &MlsCryptoProvider,
461    ) -> Result<()> {
462        self.ensure_unready().await?;
463        // Find all the keypairs, get the ones that exist (or bail), then insert new ones + delete the provisional ones
464        let stored_skp = backend
465            .key_store()
466            .find_all::<MlsSignatureKeyPair>(EntityFindParams::default())
467            .await
468            .map_err(KeystoreError::wrap("finding all mls signature keypairs"))?;
469
470        match stored_skp.len().cmp(&tmp_ids.len()) {
471            std::cmp::Ordering::Less => return Err(Error::NoProvisionalIdentityFound),
472            std::cmp::Ordering::Greater => return Err(Error::TooManyIdentitiesPresent),
473            _ => {}
474        }
475
476        // we verify that the supplied temporary ids are all present in the keypairs we have in store
477        let all_tmp_ids_exist = stored_skp
478            .iter()
479            .all(|kp| tmp_ids.contains(&kp.credential_id.as_slice().into()));
480        if !all_tmp_ids_exist {
481            return Err(Error::NoProvisionalIdentityFound);
482        }
483
484        let identities = stored_skp.iter().zip(ciphersuites);
485
486        self.replace_inner(SessionInner {
487            id: client_id.clone(),
488            identities: Identities::new(stored_skp.len()),
489            keypackage_lifetime: KEYPACKAGE_DEFAULT_LIFETIME,
490            epoch_observer: None,
491        })
492        .await;
493
494        let id = &client_id;
495
496        for (tmp_kp, &cs) in identities {
497            let scheme = tmp_kp
498                .signature_scheme
499                .try_into()
500                .map_err(|_| Error::InvalidSignatureScheme)?;
501            let new_keypair =
502                MlsSignatureKeyPair::new(scheme, tmp_kp.pk.clone(), tmp_kp.keypair.clone(), id.clone().into());
503
504            let new_credential = MlsCredential {
505                id: id.clone().into(),
506                credential: tmp_kp.credential_id.clone(),
507                created_at: 0,
508            };
509
510            // Delete the old identity optimistically
511            backend
512                .key_store()
513                .remove::<MlsSignatureKeyPair, &[u8]>(&new_keypair.pk)
514                .await
515                .map_err(KeystoreError::wrap("removing mls signature keypair"))?;
516
517            let signature_key = SignatureKeyPair::tls_deserialize(&mut new_keypair.keypair.as_slice())
518                .map_err(Error::tls_deserialize("signature key"))?;
519            let cb = CredentialBundle {
520                credential: Credential::new_basic(new_credential.credential.clone()),
521                signature_key,
522                created_at: 0, // this is fine setting a default value here, this will be set in `save_identity` to the current timestamp
523            };
524
525            // And now we save the new one
526            self.save_identity(&backend.keystore(), Some(id), cs.signature_algorithm(), cb)
527                .await?;
528        }
529
530        Ok(())
531    }
532
533    /// Generates a brand new client from scratch
534    pub(crate) async fn generate(
535        &self,
536        identifier: ClientIdentifier,
537        backend: &MlsCryptoProvider,
538        ciphersuites: &[MlsCiphersuite],
539        nb_key_package: usize,
540    ) -> Result<()> {
541        self.ensure_unready().await?;
542        let id = identifier.get_id()?;
543        let signature_schemes = ciphersuites
544            .iter()
545            .map(|cs| cs.signature_algorithm())
546            .collect::<HashSet<_>>();
547        self.replace_inner(SessionInner {
548            id: id.into_owned(),
549            identities: Identities::new(signature_schemes.len()),
550            keypackage_lifetime: KEYPACKAGE_DEFAULT_LIFETIME,
551            epoch_observer: None,
552        })
553        .await;
554
555        let identities = identifier.generate_credential_bundles(backend, signature_schemes)?;
556
557        for (sc, id, cb) in identities {
558            self.save_identity(&backend.keystore(), Some(&id), sc, cb).await?;
559        }
560
561        let identities = match self.inner.read().await.deref() {
562            None => return Err(Error::MlsNotInitialized),
563            // Cloning is fine because identities is an arc internally.
564            // We can't keep the lock for longer because requesting the key packages below will also
565            // acquire it.
566            Some(SessionInner { identities, .. }) => identities.clone(),
567        };
568
569        if nb_key_package != 0 {
570            for cs in ciphersuites {
571                let sc = cs.signature_algorithm();
572                let identity = identities.iter().filter(|(id_sc, _)| id_sc == &sc);
573                for (_, cb) in identity {
574                    self.request_key_packages(nb_key_package, *cs, cb.credential.credential_type().into(), backend)
575                        .await?;
576                }
577            }
578        }
579
580        Ok(())
581    }
582
583    /// Loads the client from the keystore.
584    pub(crate) async fn load(
585        &self,
586        backend: &MlsCryptoProvider,
587        id: &ClientId,
588        mut credentials: Vec<(Credential, u64)>,
589        signature_schemes: HashSet<SignatureScheme>,
590    ) -> Result<()> {
591        self.ensure_unready().await?;
592        let mut identities = Identities::new(signature_schemes.len());
593
594        // ensures we load credentials in chronological order
595        credentials.sort_by_key(|(_, timestamp)| *timestamp);
596
597        let store_skps = backend
598            .key_store()
599            .find_all::<MlsSignatureKeyPair>(EntityFindParams::default())
600            .await
601            .map_err(KeystoreError::wrap("finding all mls signature keypairs"))?;
602
603        for sc in signature_schemes {
604            let kp = store_skps.iter().find(|skp| skp.signature_scheme == (sc as u16));
605
606            let signature_key = if let Some(kp) = kp {
607                SignatureKeyPair::tls_deserialize(&mut kp.keypair.as_slice())
608                    .map_err(Error::tls_deserialize("signature keypair"))?
609            } else {
610                let (sk, pk) = backend
611                    .crypto()
612                    .signature_key_gen(sc)
613                    .map_err(MlsError::wrap("generating signature key"))?;
614                let keypair = SignatureKeyPair::from_raw(sc, sk, pk.clone());
615                let raw_keypair = keypair
616                    .tls_serialize_detached()
617                    .map_err(Error::tls_serialize("raw keypair"))?;
618                let store_keypair = MlsSignatureKeyPair::new(sc, pk, raw_keypair, id.as_slice().into());
619                backend
620                    .key_store()
621                    .save(store_keypair.clone())
622                    .await
623                    .map_err(KeystoreError::wrap("storing keypairs in keystore"))?;
624                SignatureKeyPair::tls_deserialize(&mut store_keypair.keypair.as_slice())
625                    .map_err(Error::tls_deserialize("signature keypair"))?
626            };
627
628            for (credential, created_at) in &credentials {
629                match credential.mls_credential() {
630                    openmls::prelude::MlsCredentialType::Basic(_) => {
631                        if id.as_slice() != credential.identity() {
632                            return Err(Error::WrongCredential);
633                        }
634                    }
635                    openmls::prelude::MlsCredentialType::X509(cert) => {
636                        let spk = cert
637                            .extract_public_key()
638                            .map_err(RecursiveError::mls_credential("extracting public key"))?
639                            .ok_or(LeafError::InternalMlsError)?;
640                        if signature_key.public() != spk {
641                            return Err(Error::WrongCredential);
642                        }
643                    }
644                };
645                let cb = CredentialBundle {
646                    credential: credential.clone(),
647                    signature_key: signature_key.clone(),
648                    created_at: *created_at,
649                };
650                identities.push_credential_bundle(sc, cb).await?;
651            }
652        }
653        self.replace_inner(SessionInner {
654            id: id.clone(),
655            identities,
656            keypackage_lifetime: KEYPACKAGE_DEFAULT_LIFETIME,
657            epoch_observer: None,
658        })
659        .await;
660        Ok(())
661    }
662
663    async fn find_all_basic_credentials(backend: &MlsCryptoProvider) -> Result<Vec<Credential>> {
664        let store_credentials = backend
665            .key_store()
666            .find_all::<MlsCredential>(EntityFindParams::default())
667            .await
668            .map_err(KeystoreError::wrap("finding all mls credentialss"))?;
669        let mut credentials = Vec::with_capacity(store_credentials.len());
670        for store_credential in store_credentials.into_iter() {
671            let credential = Credential::tls_deserialize(&mut store_credential.credential.as_slice())
672                .map_err(Error::tls_deserialize("credential"))?;
673            if !matches!(credential.credential_type(), CredentialType::Basic) {
674                continue;
675            }
676            credentials.push(credential);
677        }
678
679        Ok(credentials)
680    }
681
682    pub(crate) async fn save_identity(
683        &self,
684        keystore: &Connection,
685        id: Option<&ClientId>,
686        sc: SignatureScheme,
687        mut cb: CredentialBundle,
688    ) -> Result<CredentialBundle> {
689        match self.inner.write().await.deref_mut() {
690            None => Err(Error::MlsNotInitialized),
691            Some(SessionInner {
692                id: existing_id,
693                identities,
694                ..
695            }) => {
696                let id = id.unwrap_or(existing_id);
697
698                let credential = cb
699                    .credential
700                    .tls_serialize_detached()
701                    .map_err(Error::tls_serialize("credential bundle"))?;
702                let credential = MlsCredential {
703                    id: id.clone().into(),
704                    credential,
705                    created_at: 0,
706                };
707
708                let credential = keystore
709                    .save(credential)
710                    .await
711                    .map_err(KeystoreError::wrap("saving credential"))?;
712
713                let sign_kp = MlsSignatureKeyPair::new(
714                    sc,
715                    cb.signature_key.to_public_vec(),
716                    cb.signature_key
717                        .tls_serialize_detached()
718                        .map_err(Error::tls_serialize("signature keypair"))?,
719                    id.clone().into(),
720                );
721                keystore.save(sign_kp).await.map_err(|e| match e {
722                    CryptoKeystoreError::AlreadyExists => Error::CredentialBundleConflict,
723                    _ => KeystoreError::wrap("saving mls signature key pair")(e).into(),
724                })?;
725
726                // set the creation date of the signature keypair which is the same for the CredentialBundle
727                cb.created_at = credential.created_at;
728
729                identities.push_credential_bundle(sc, cb.clone()).await?;
730
731                Ok(cb)
732            }
733        }
734    }
735
736    /// Retrieves the client's client id. This is free-form and not inspected.
737    pub async fn id(&self) -> Result<ClientId> {
738        match self.inner.read().await.deref() {
739            None => Err(Error::MlsNotInitialized),
740            Some(SessionInner { id, .. }) => Ok(id.clone()),
741        }
742    }
743
744    /// Returns whether this client is E2EI capable
745    pub async fn is_e2ei_capable(&self) -> bool {
746        match self.inner.read().await.deref() {
747            None => false,
748            Some(SessionInner { identities, .. }) => identities
749                .iter()
750                .any(|(_, cred)| cred.credential().credential_type() == CredentialType::X509),
751        }
752    }
753
754    pub(crate) async fn get_most_recent_or_create_credential_bundle(
755        &self,
756        backend: &MlsCryptoProvider,
757        sc: SignatureScheme,
758        ct: MlsCredentialType,
759    ) -> Result<Arc<CredentialBundle>> {
760        match ct {
761            MlsCredentialType::Basic => {
762                self.init_basic_credential_bundle_if_missing(backend, sc).await?;
763                self.find_most_recent_credential_bundle(sc, ct).await
764            }
765            MlsCredentialType::X509 => self
766                .find_most_recent_credential_bundle(sc, ct)
767                .await
768                .map_err(|e| match e {
769                    Error::CredentialNotFound(_) => LeafError::E2eiEnrollmentNotDone.into(),
770                    _ => e,
771                }),
772        }
773    }
774
775    pub(crate) async fn init_basic_credential_bundle_if_missing(
776        &self,
777        backend: &MlsCryptoProvider,
778        sc: SignatureScheme,
779    ) -> Result<()> {
780        let existing_cb = self
781            .find_most_recent_credential_bundle(sc, MlsCredentialType::Basic)
782            .await;
783        if matches!(existing_cb, Err(Error::CredentialNotFound(_))) {
784            let id = self.id().await?;
785            debug!(id:% = &id; "Initializing basic credential bundle");
786            let cb = Self::new_basic_credential_bundle(&id, sc, backend)?;
787            self.save_identity(&backend.keystore(), None, sc, cb).await?;
788        }
789        Ok(())
790    }
791
792    pub(crate) async fn save_new_x509_credential_bundle(
793        &self,
794        keystore: &Connection,
795        sc: SignatureScheme,
796        cb: CertificateBundle,
797    ) -> Result<CredentialBundle> {
798        let id = cb
799            .get_client_id()
800            .map_err(RecursiveError::mls_credential("getting client id"))?;
801        let cb = Self::new_x509_credential_bundle(cb)?;
802        self.save_identity(keystore, Some(&id), sc, cb).await
803    }
804}
805
806#[cfg(test)]
807mod tests {
808    use super::*;
809    use crate::prelude::ClientId;
810    use crate::test_utils::*;
811    use crate::transaction_context::test_utils::EntitiesCount;
812    use core_crypto_keystore::connection::{DatabaseKey, FetchFromDatabase};
813    use core_crypto_keystore::entities::*;
814    use mls_crypto_provider::MlsCryptoProvider;
815    use wasm_bindgen_test::*;
816
817    impl Session {
818        // test functions are not held to the same documentation standard as proper functions
819        #![allow(missing_docs)]
820
821        pub async fn random_generate(
822            &self,
823            case: &crate::test_utils::TestCase,
824            signer: Option<&crate::test_utils::x509::X509Certificate>,
825            provision: bool,
826        ) -> Result<()> {
827            self.reset().await;
828            let user_uuid = uuid::Uuid::new_v4();
829            let rnd_id = rand::random::<usize>();
830            let client_id = format!("{}:{rnd_id:x}@members.wire.com", user_uuid.hyphenated());
831            let identity = match case.credential_type {
832                MlsCredentialType::Basic => ClientIdentifier::Basic(client_id.as_str().into()),
833                MlsCredentialType::X509 => {
834                    let signer = signer.expect("Missing intermediate CA");
835                    CertificateBundle::rand_identifier(&client_id, &[signer])
836                }
837            };
838            let nb_key_package = if provision {
839                crate::prelude::INITIAL_KEYING_MATERIAL_COUNT
840            } else {
841                0
842            };
843            let backend = self.crypto_provider.clone();
844            self.generate(identity, &backend, &[case.ciphersuite()], nb_key_package)
845                .await?;
846            Ok(())
847        }
848
849        pub async fn find_keypackages(&self, backend: &MlsCryptoProvider) -> Result<Vec<openmls::prelude::KeyPackage>> {
850            use core_crypto_keystore::CryptoKeystoreMls as _;
851            let kps = backend
852                .key_store()
853                .mls_fetch_keypackages::<openmls::prelude::KeyPackage>(u32::MAX)
854                .await
855                .map_err(KeystoreError::wrap("fetching mls keypackages"))?;
856            Ok(kps)
857        }
858
859        pub(crate) async fn init_x509_credential_bundle_if_missing(
860            &self,
861            backend: &MlsCryptoProvider,
862            sc: SignatureScheme,
863            cb: CertificateBundle,
864        ) -> Result<()> {
865            let existing_cb = self
866                .find_most_recent_credential_bundle(sc, MlsCredentialType::X509)
867                .await
868                .is_err();
869            if existing_cb {
870                self.save_new_x509_credential_bundle(&backend.keystore(), sc, cb)
871                    .await?;
872            }
873            Ok(())
874        }
875
876        pub(crate) async fn generate_one_keypackage(
877            &self,
878            backend: &MlsCryptoProvider,
879            cs: MlsCiphersuite,
880            ct: MlsCredentialType,
881        ) -> Result<openmls::prelude::KeyPackage> {
882            let cb = self
883                .find_most_recent_credential_bundle(cs.signature_algorithm(), ct)
884                .await?;
885            self.generate_one_keypackage_from_credential_bundle(backend, cs, &cb)
886                .await
887        }
888
889        /// Count the entities
890        pub async fn count_entities(&self) -> EntitiesCount {
891            let keystore = self.crypto_provider.keystore();
892            let credential = keystore.count::<MlsCredential>().await.unwrap();
893            let encryption_keypair = keystore.count::<MlsEncryptionKeyPair>().await.unwrap();
894            let epoch_encryption_keypair = keystore.count::<MlsEpochEncryptionKeyPair>().await.unwrap();
895            let enrollment = keystore.count::<E2eiEnrollment>().await.unwrap();
896            let group = keystore.count::<PersistedMlsGroup>().await.unwrap();
897            let hpke_private_key = keystore.count::<MlsHpkePrivateKey>().await.unwrap();
898            let key_package = keystore.count::<MlsKeyPackage>().await.unwrap();
899            let pending_group = keystore.count::<PersistedMlsPendingGroup>().await.unwrap();
900            let pending_messages = keystore.count::<MlsPendingMessage>().await.unwrap();
901            let psk_bundle = keystore.count::<MlsPskBundle>().await.unwrap();
902            let signature_keypair = keystore.count::<MlsSignatureKeyPair>().await.unwrap();
903            EntitiesCount {
904                credential,
905                encryption_keypair,
906                epoch_encryption_keypair,
907                enrollment,
908                group,
909                hpke_private_key,
910                key_package,
911                pending_group,
912                pending_messages,
913                psk_bundle,
914                signature_keypair,
915            }
916        }
917    }
918    wasm_bindgen_test_configure!(run_in_browser);
919
920    #[apply(all_cred_cipher)]
921    #[wasm_bindgen_test]
922    async fn can_generate_client(case: TestCase) {
923        run_test_with_central(case.clone(), move |[alice]| {
924            Box::pin(async move {
925                let key = DatabaseKey::generate();
926                let backend = MlsCryptoProvider::try_new_in_memory(&key).await.unwrap();
927                let x509_test_chain = if case.is_x509() {
928                    let x509_test_chain = crate::test_utils::x509::X509TestChain::init_empty(case.signature_scheme());
929                    x509_test_chain.register_with_provider(&backend).await;
930                    Some(x509_test_chain)
931                } else {
932                    None
933                };
934                backend.new_transaction().await.unwrap();
935                let client = alice.session().await;
936                client
937                    .random_generate(
938                        &case,
939                        x509_test_chain.as_ref().map(|chain| chain.find_local_intermediate_ca()),
940                        false,
941                    )
942                    .await
943                    .unwrap();
944            })
945        })
946        .await
947    }
948
949    #[apply(all_cred_cipher)]
950    #[wasm_bindgen_test]
951    async fn can_externally_generate_client(case: TestCase) {
952        run_test_with_central(case.clone(), move |[alice]| {
953            Box::pin(async move {
954                if case.is_basic() {
955                    run_tests(move |[tmp_dir_argument]| {
956                        Box::pin(async move {
957                            let key = DatabaseKey::generate();
958                            let backend = MlsCryptoProvider::try_new(tmp_dir_argument, &key).await.unwrap();
959                            backend.new_transaction().await.unwrap();
960                            // phase 1: generate standalone keypair
961                            let client_id: ClientId = b"whatever:my:client:is@world.com".to_vec().into();
962                            let alice = alice.session().await;
963                            alice.reset().await;
964                            // TODO: test with multi-ciphersuite. Tracking issue: WPB-9601
965                            let handles = alice
966                                .generate_raw_keypairs(&[case.ciphersuite()], &backend)
967                                .await
968                                .unwrap();
969
970                            let mut identities = backend
971                                .keystore()
972                                .find_all::<MlsSignatureKeyPair>(EntityFindParams::default())
973                                .await
974                                .unwrap();
975
976                            assert_eq!(identities.len(), 1);
977
978                            let prov_identity = identities.pop().unwrap();
979
980                            // Make sure we are actually returning the clientId
981                            // TODO: test with multi-ciphersuite. Tracking issue: WPB-9601
982                            let prov_client_id: ClientId = prov_identity.credential_id.as_slice().into();
983                            assert_eq!(&prov_client_id, handles.first().unwrap());
984
985                            // phase 2: pretend we have a new client ID from the backend, and try to init the client this way
986                            alice
987                                .init_with_external_client_id(
988                                    client_id.clone(),
989                                    handles.clone(),
990                                    &[case.ciphersuite()],
991                                    &backend,
992                                )
993                                .await
994                                .unwrap();
995
996                            // Make sure both client id and PK are intact
997                            assert_eq!(alice.id().await.unwrap(), client_id);
998                            let cb = alice
999                                .find_most_recent_credential_bundle(case.signature_scheme(), case.credential_type)
1000                                .await
1001                                .unwrap();
1002                            let client_id: ClientId = cb.credential().identity().into();
1003                            assert_eq!(&client_id, handles.first().unwrap());
1004                        })
1005                    })
1006                    .await
1007                }
1008            })
1009        })
1010        .await
1011    }
1012}