core_crypto/mls/session/
mod.rs

1pub(crate) mod e2e_identity;
2mod epoch_observer;
3mod error;
4pub(crate) mod id;
5pub(crate) mod identifier;
6pub(crate) mod identities;
7pub(crate) mod key_package;
8pub(crate) mod user_id;
9
10use crate::{
11    CoreCrypto, KeystoreError, LeafError, MlsError, MlsTransport, RecursiveError,
12    group_store::GroupStore,
13    mls::{
14        self, HasSessionAndCrypto,
15        conversation::ImmutableConversation,
16        credential::{CredentialBundle, ext::CredentialExt},
17    },
18    prelude::{
19        CertificateBundle, ClientId, ConversationId, HistorySecret, INITIAL_KEYING_MATERIAL_COUNT, MlsCiphersuite,
20        MlsClientConfiguration, MlsCredentialType, identifier::ClientIdentifier,
21        key_package::KEYPACKAGE_DEFAULT_LIFETIME,
22    },
23};
24use async_lock::RwLock;
25use core_crypto_keystore::{
26    Connection, CryptoKeystoreError,
27    connection::FetchFromDatabase,
28    entities::{EntityFindParams, MlsCredential, MlsSignatureKeyPair},
29};
30pub use epoch_observer::EpochObserver;
31pub(crate) use error::{Error, Result};
32use identities::Identities;
33use log::debug;
34use mls_crypto_provider::{EntropySeed, MlsCryptoProvider, MlsCryptoProviderConfiguration};
35use openmls::prelude::{Credential, CredentialType};
36use openmls_basic_credential::SignatureKeyPair;
37use openmls_traits::{OpenMlsCryptoProvider, crypto::OpenMlsCrypto, types::SignatureScheme};
38use openmls_x509_credential::CertificateKeyPair;
39use std::collections::HashSet;
40use std::ops::Deref;
41use std::sync::Arc;
42use tls_codec::{Deserialize, Serialize};
43
44/// A MLS Session enables a user device to communicate via the MLS protocol.
45///
46/// This closely maps to the `Client` term in [RFC 9720], but we avoid that term to avoid ambiguity;
47/// `Client` is very overloaded with distinct meanings.
48///
49/// There is one `Session` per user per device. A session can contain many MLS groups/conversations.
50///
51/// It is cheap to clone a `Session` because everything heavy is wrapped inside an [Arc].
52///
53/// [RFC 9720]: https://www.rfc-editor.org/rfc/rfc9420.html
54#[derive(Clone, derive_more::Debug)]
55pub struct Session {
56    pub(crate) inner: Arc<RwLock<Option<SessionInner>>>,
57    pub(crate) crypto_provider: MlsCryptoProvider,
58    pub(crate) transport: Arc<RwLock<Option<Arc<dyn MlsTransport + 'static>>>>,
59    #[debug("EpochObserver")]
60    pub(crate) epoch_observer: Arc<RwLock<Option<Arc<dyn EpochObserver + 'static>>>>,
61}
62
63#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
64#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
65impl HasSessionAndCrypto for Session {
66    async fn session(&self) -> mls::Result<Session> {
67        Ok(self.clone())
68    }
69
70    async fn crypto_provider(&self) -> mls::Result<MlsCryptoProvider> {
71        Ok(self.crypto_provider.clone())
72    }
73}
74
75#[derive(Clone, Debug)]
76pub(crate) struct SessionInner {
77    id: ClientId,
78    pub(crate) identities: Identities,
79    keypackage_lifetime: std::time::Duration,
80}
81
82impl Session {
83    /// Creates a new [Session].
84    /// Takes a store path (i.e. Disk location of the embedded database, should be consistent between messaging sessions)
85    /// And a root identity key (i.e. enclaved encryption key for this device)
86    ///
87    /// # Arguments
88    /// * `configuration` - the configuration for the `MlsCentral`
89    ///
90    /// # Errors
91    /// Failures in the initialization of the KeyStore can cause errors, such as IO, the same kind
92    /// of errors can happen when the groups are being restored from the KeyStore or even during
93    /// the client initialization (to fetch the identity signature). Other than that, `MlsError`
94    /// can be caused by group deserialization or during the initialization of the credentials:
95    /// * for x509 Credentials if the certificate chain length is lower than 2
96    /// * for Basic Credentials if the signature key cannot be generated either by not supported
97    ///   scheme or the key generation fails
98    pub async fn try_new(configuration: MlsClientConfiguration) -> crate::mls::Result<Self> {
99        // Init backend (crypto + rand + keystore)
100        let mls_backend = MlsCryptoProvider::try_new_with_configuration(MlsCryptoProviderConfiguration {
101            db_path: &configuration.store_path,
102            db_key: configuration.database_key.clone(),
103            in_memory: false,
104            entropy_seed: configuration.external_entropy.clone(),
105        })
106        .await
107        .map_err(MlsError::wrap("trying to initialize mls crypto provider object"))?;
108        Self::new_with_backend(mls_backend, configuration).await
109    }
110
111    /// Same as the [Self::try_new] but instead, it uses an in memory KeyStore.
112    /// Although required, the `store_path` parameter from the `MlsClientConfiguration` won't be used here.
113    pub async fn try_new_in_memory(configuration: MlsClientConfiguration) -> crate::mls::Result<Self> {
114        let mls_backend = MlsCryptoProvider::try_new_with_configuration(MlsCryptoProviderConfiguration {
115            db_path: &configuration.store_path,
116            db_key: configuration.database_key.clone(),
117            in_memory: true,
118            entropy_seed: configuration.external_entropy.clone(),
119        })
120        .await
121        .map_err(MlsError::wrap(
122            "trying to initialize mls crypto provider object (in memory)",
123        ))?;
124        Self::new_with_backend(mls_backend, configuration).await
125    }
126
127    async fn new_with_backend(
128        mls_backend: MlsCryptoProvider,
129        configuration: MlsClientConfiguration,
130    ) -> crate::mls::Result<Self> {
131        // We create the core crypto instance first to enable creating a transaction from it and
132        // doing all subsequent actions inside a single transaction, though it forces us to clone
133        // a few Arcs and locks.
134        let client = Self {
135            crypto_provider: mls_backend.clone(),
136            inner: Default::default(),
137            transport: Arc::new(None.into()),
138            epoch_observer: Arc::new(None.into()),
139        };
140
141        let cc = CoreCrypto::from(client.clone());
142        let context = cc
143            .new_transaction()
144            .await
145            .map_err(RecursiveError::transaction("starting new transaction"))?;
146
147        if let Some(id) = configuration.client_id {
148            client
149                .init(
150                    ClientIdentifier::Basic(id),
151                    configuration.ciphersuites.as_slice(),
152                    &mls_backend,
153                    configuration
154                        .nb_init_key_packages
155                        .unwrap_or(INITIAL_KEYING_MATERIAL_COUNT),
156                )
157                .await
158                .map_err(RecursiveError::mls_client("initializing mls client"))?
159        }
160
161        let central = cc.mls;
162        context
163            .init_pki_env()
164            .await
165            .map_err(RecursiveError::transaction("initializing pki environment"))?;
166        context
167            .finish()
168            .await
169            .map_err(RecursiveError::transaction("finishing transaction"))?;
170        Ok(central)
171    }
172
173    /// Provide the implementation of functions to communicate with the delivery service
174    /// (see [MlsTransport]).
175    pub async fn provide_transport(&self, transport: Arc<dyn MlsTransport>) {
176        self.transport.write().await.replace(transport);
177    }
178
179    /// Initializes the client.
180    /// If the client's cryptographic material is already stored in the keystore, it loads it
181    /// Otherwise, it is being created.
182    ///
183    /// # Arguments
184    /// * `identifier` - client identifier ; either a [ClientId] or a x509 certificate chain
185    /// * `ciphersuites` - all ciphersuites this client is supposed to support
186    /// * `backend` - the KeyStore and crypto provider to read identities from
187    ///
188    /// # Errors
189    /// KeyStore and OpenMls errors can happen
190    pub async fn init(
191        &self,
192        identifier: ClientIdentifier,
193        ciphersuites: &[MlsCiphersuite],
194        backend: &MlsCryptoProvider,
195        nb_key_package: usize,
196    ) -> Result<()> {
197        self.ensure_unready().await?;
198        let id = identifier.get_id()?;
199
200        let credentials = backend
201            .key_store()
202            .find_all::<MlsCredential>(EntityFindParams::default())
203            .await
204            .map_err(KeystoreError::wrap("finding all mls credentials"))?;
205
206        let credentials = credentials
207            .into_iter()
208            .filter(|mls_credential| mls_credential.id.as_slice() == id.as_slice())
209            .map(|mls_credential| -> Result<_> {
210                let credential = Credential::tls_deserialize(&mut mls_credential.credential.as_slice())
211                    .map_err(Error::tls_deserialize("mls credential"))?;
212                Ok((credential, mls_credential.created_at))
213            })
214            .collect::<Result<Vec<_>>>()?;
215
216        if credentials.is_empty() {
217            debug!(count = nb_key_package, ciphersuites:? = ciphersuites; "Generating client");
218            self.generate(identifier, backend, ciphersuites, nb_key_package).await?;
219        } else {
220            let signature_schemes = ciphersuites
221                .iter()
222                .map(|cs| cs.signature_algorithm())
223                .collect::<HashSet<_>>();
224            let load_result = self.load(backend, id.as_ref(), credentials, signature_schemes).await;
225            if let Err(Error::ClientSignatureNotFound) = load_result {
226                debug!(count = nb_key_package, ciphersuites:? = ciphersuites; "Client signature not found. Generating client");
227                self.generate(identifier, backend, ciphersuites, nb_key_package).await?;
228            } else {
229                load_result?;
230            }
231        };
232
233        Ok(())
234    }
235
236    /// Resets the client to an uninitialized state.
237    #[cfg(test)]
238    pub(crate) async fn reset(&self) {
239        let mut inner_lock = self.inner.write().await;
240        *inner_lock = None;
241    }
242
243    pub(crate) async fn is_ready(&self) -> bool {
244        let inner_lock = self.inner.read().await;
245        inner_lock.is_some()
246    }
247
248    async fn ensure_unready(&self) -> Result<()> {
249        if self.is_ready().await {
250            Err(Error::UnexpectedlyReady)
251        } else {
252            Ok(())
253        }
254    }
255
256    async fn replace_inner(&self, new_inner: SessionInner) {
257        let mut inner_lock = self.inner.write().await;
258        *inner_lock = Some(new_inner);
259    }
260
261    /// Get an immutable view of an `MlsConversation`.
262    ///
263    /// Because it operates on the raw conversation type, this may be faster than
264    /// [crate::transaction_context::TransactionContext::conversation]. for transient and immutable
265    /// purposes. For long-lived or mutable purposes, prefer the other method.
266    pub async fn get_raw_conversation(&self, id: &ConversationId) -> Result<ImmutableConversation> {
267        let raw_conversation = GroupStore::fetch_from_keystore(id, &self.crypto_provider.keystore(), None)
268            .await
269            .map_err(RecursiveError::root("getting conversation by id"))?
270            .ok_or_else(|| LeafError::ConversationNotFound(id.clone()))?;
271        Ok(ImmutableConversation::new(raw_conversation, self.clone()))
272    }
273
274    /// Returns the client's most recent public signature key as a buffer.
275    /// Used to upload a public key to the server in order to verify client's messages signature.
276    ///
277    /// # Arguments
278    /// * `ciphersuite` - a callback to be called to perform authorization
279    /// * `credential_type` - of the credential to look for
280    pub async fn public_key(
281        &self,
282        ciphersuite: MlsCiphersuite,
283        credential_type: MlsCredentialType,
284    ) -> crate::mls::Result<Vec<u8>> {
285        let cb = self
286            .find_most_recent_credential_bundle(ciphersuite.signature_algorithm(), credential_type)
287            .await
288            .map_err(RecursiveError::mls_client("finding most recent credential bundle"))?;
289        Ok(cb.signature_key.to_public_vec())
290    }
291
292    pub(crate) fn new_basic_credential_bundle(
293        id: &ClientId,
294        sc: SignatureScheme,
295        backend: &MlsCryptoProvider,
296    ) -> Result<CredentialBundle> {
297        let (sk, pk) = backend
298            .crypto()
299            .signature_key_gen(sc)
300            .map_err(MlsError::wrap("generating a signature key"))?;
301
302        let signature_key = SignatureKeyPair::from_raw(sc, sk, pk);
303        let credential = Credential::new_basic(id.to_vec());
304        let cb = CredentialBundle {
305            credential,
306            signature_key,
307            created_at: 0,
308        };
309
310        Ok(cb)
311    }
312
313    pub(crate) fn new_x509_credential_bundle(cert: CertificateBundle) -> Result<CredentialBundle> {
314        let created_at = cert
315            .get_created_at()
316            .map_err(RecursiveError::mls_credential("getting credetntial created at"))?;
317        let (sk, ..) = cert.private_key.into_parts();
318        let chain = cert.certificate_chain;
319
320        let kp = CertificateKeyPair::new(sk, chain.clone()).map_err(MlsError::wrap("creating certificate key pair"))?;
321
322        let credential = Credential::new_x509(chain).map_err(MlsError::wrap("creating x509 credential"))?;
323
324        let cb = CredentialBundle {
325            credential,
326            signature_key: kp.0,
327            created_at,
328        };
329        Ok(cb)
330    }
331
332    /// Checks if a given conversation id exists locally
333    pub async fn conversation_exists(&self, id: &ConversationId) -> Result<bool> {
334        match self.get_raw_conversation(id).await {
335            Ok(_) => Ok(true),
336            Err(Error::Leaf(LeafError::ConversationNotFound(_))) => Ok(false),
337            Err(e) => Err(e),
338        }
339    }
340
341    /// Generates a random byte array of the specified size
342    pub fn random_bytes(&self, len: usize) -> crate::mls::Result<Vec<u8>> {
343        use openmls_traits::random::OpenMlsRand as _;
344        self.crypto_provider
345            .rand()
346            .random_vec(len)
347            .map_err(MlsError::wrap("generating random vector"))
348            .map_err(Into::into)
349    }
350
351    /// Reports whether the local KeyStore believes that it can currently close.
352    ///
353    /// Beware TOCTOU!
354    pub async fn can_close(&self) -> bool {
355        self.crypto_provider.can_close().await
356    }
357
358    /// Closes the connection with the local KeyStore
359    ///
360    /// # Errors
361    /// KeyStore errors, such as IO
362    pub async fn close(self) -> crate::mls::Result<()> {
363        self.crypto_provider
364            .close()
365            .await
366            .map_err(MlsError::wrap("closing connection with keystore"))
367            .map_err(Into::into)
368    }
369
370    /// see [mls_crypto_provider::MlsCryptoProvider::reseed]
371    pub async fn reseed(&self, seed: Option<EntropySeed>) -> crate::mls::Result<()> {
372        self.crypto_provider
373            .reseed(seed)
374            .map_err(MlsError::wrap("reseeding mls backend"))
375            .map_err(Into::into)
376    }
377
378    // Initializes a raw MLS keypair without an associated client ID
379    // Returns a random ClientId to bind later in [Session::init_with_external_client_id]
380    //
381    // # Arguments
382    // * `ciphersuites` - all ciphersuites this client is supposed to support
383    // * `backend` - the KeyStore and crypto provider to read identities from
384    //
385    // # Errors
386    // KeyStore and OpenMls errors can happen
387    pub(crate) async fn generate_raw_keypairs(
388        &self,
389        ciphersuites: &[MlsCiphersuite],
390        backend: &MlsCryptoProvider,
391    ) -> Result<Vec<ClientId>> {
392        self.ensure_unready().await?;
393        const TEMP_KEY_SIZE: usize = 16;
394
395        let credentials = Self::find_all_basic_credentials(backend).await?;
396        if !credentials.is_empty() {
397            return Err(Error::IdentityAlreadyPresent);
398        }
399
400        use openmls_traits::random::OpenMlsRand as _;
401        // Here we generate a provisional, random, uuid-like random Client ID for no purpose other than database/store constraints
402        let mut tmp_client_ids = Vec::with_capacity(ciphersuites.len());
403        for cs in ciphersuites {
404            let tmp_client_id: ClientId = backend
405                .rand()
406                .random_vec(TEMP_KEY_SIZE)
407                .map_err(MlsError::wrap("generating random client id"))?
408                .into();
409
410            let cb = Self::new_basic_credential_bundle(&tmp_client_id, cs.signature_algorithm(), backend)?;
411
412            let sign_kp = MlsSignatureKeyPair::new(
413                cs.signature_algorithm(),
414                cb.signature_key.to_public_vec(),
415                cb.signature_key
416                    .tls_serialize_detached()
417                    .map_err(Error::tls_serialize("signature key"))?,
418                tmp_client_id.clone().into(),
419            );
420            backend
421                .key_store()
422                .save(sign_kp)
423                .await
424                .map_err(KeystoreError::wrap("save signature keypair in keystore"))?;
425
426            tmp_client_ids.push(tmp_client_id);
427        }
428
429        Ok(tmp_client_ids)
430    }
431
432    // Finalizes initialization using a 2-step process of uploading first a public key and then associating a new Client ID to that keypair
433    //
434    // # Arguments
435    // * `client_id` - The client ID you have fetched from the MLS Authentication Service
436    // * `tmp_ids` - The temporary random client ids generated in the previous step [Session::generate_raw_keypairs]
437    // * `ciphersuites` - To initialize the Client with
438    // * `backend` - the KeyStore and crypto provider to read identities from
439    //
440    // **WARNING**: You have absolutely NO reason to call this if you didn't call [Session::generate_raw_keypairs] first. You have been warned!
441    pub(crate) async fn init_with_external_client_id(
442        &self,
443        client_id: ClientId,
444        tmp_ids: Vec<ClientId>,
445        ciphersuites: &[MlsCiphersuite],
446        backend: &MlsCryptoProvider,
447    ) -> Result<()> {
448        self.ensure_unready().await?;
449        // Find all the keypairs, get the ones that exist (or bail), then insert new ones + delete the provisional ones
450        let stored_skp = backend
451            .key_store()
452            .find_all::<MlsSignatureKeyPair>(EntityFindParams::default())
453            .await
454            .map_err(KeystoreError::wrap("finding all mls signature keypairs"))?;
455
456        match stored_skp.len().cmp(&tmp_ids.len()) {
457            std::cmp::Ordering::Less => return Err(Error::NoProvisionalIdentityFound),
458            std::cmp::Ordering::Greater => return Err(Error::TooManyIdentitiesPresent),
459            _ => {}
460        }
461
462        // we verify that the supplied temporary ids are all present in the keypairs we have in store
463        let all_tmp_ids_exist = stored_skp
464            .iter()
465            .all(|kp| tmp_ids.contains(&kp.credential_id.as_slice().into()));
466        if !all_tmp_ids_exist {
467            return Err(Error::NoProvisionalIdentityFound);
468        }
469
470        let identities = stored_skp.iter().zip(ciphersuites);
471
472        self.replace_inner(SessionInner {
473            id: client_id.clone(),
474            identities: Identities::new(stored_skp.len()),
475            keypackage_lifetime: KEYPACKAGE_DEFAULT_LIFETIME,
476        })
477        .await;
478
479        let id = &client_id;
480
481        for (tmp_kp, &cs) in identities {
482            let scheme = tmp_kp
483                .signature_scheme
484                .try_into()
485                .map_err(|_| Error::InvalidSignatureScheme)?;
486            let new_keypair =
487                MlsSignatureKeyPair::new(scheme, tmp_kp.pk.clone(), tmp_kp.keypair.clone(), id.clone().into());
488
489            let new_credential = MlsCredential {
490                id: id.clone().into(),
491                credential: tmp_kp.credential_id.clone(),
492                created_at: 0,
493            };
494
495            // Delete the old identity optimistically
496            backend
497                .key_store()
498                .remove::<MlsSignatureKeyPair, &[u8]>(&new_keypair.pk)
499                .await
500                .map_err(KeystoreError::wrap("removing mls signature keypair"))?;
501
502            let signature_key = SignatureKeyPair::tls_deserialize(&mut new_keypair.keypair.as_slice())
503                .map_err(Error::tls_deserialize("signature key"))?;
504            let cb = CredentialBundle {
505                credential: Credential::new_basic(new_credential.credential.clone()),
506                signature_key,
507                created_at: 0, // this is fine setting a default value here, this will be set in `save_identity` to the current timestamp
508            };
509
510            // And now we save the new one
511            self.save_identity(&backend.keystore(), Some(id), cs.signature_algorithm(), cb)
512                .await?;
513        }
514
515        Ok(())
516    }
517
518    /// Generates a brand new client from scratch
519    pub(crate) async fn generate(
520        &self,
521        identifier: ClientIdentifier,
522        backend: &MlsCryptoProvider,
523        ciphersuites: &[MlsCiphersuite],
524        nb_key_package: usize,
525    ) -> Result<()> {
526        self.ensure_unready().await?;
527        let id = identifier.get_id()?;
528        let signature_schemes = ciphersuites
529            .iter()
530            .map(|cs| cs.signature_algorithm())
531            .collect::<HashSet<_>>();
532        self.replace_inner(SessionInner {
533            id: id.into_owned(),
534            identities: Identities::new(signature_schemes.len()),
535            keypackage_lifetime: KEYPACKAGE_DEFAULT_LIFETIME,
536        })
537        .await;
538
539        let identities = identifier.generate_credential_bundles(backend, signature_schemes)?;
540
541        for (sc, id, cb) in identities {
542            self.save_identity(&backend.keystore(), Some(&id), sc, cb).await?;
543        }
544
545        let guard = self.inner.read().await;
546        let SessionInner { identities, .. } = guard.as_ref().ok_or(Error::MlsNotInitialized)?;
547
548        if nb_key_package != 0 {
549            for ciphersuite in ciphersuites.iter().copied() {
550                let ciphersuite_signature_scheme = ciphersuite.signature_algorithm();
551                for credential_bundle in identities.iter().filter_map(|(signature_scheme, credential_bundle)| {
552                    (signature_scheme == ciphersuite_signature_scheme).then_some(credential_bundle)
553                }) {
554                    let credential_type = credential_bundle.credential.credential_type().into();
555                    self.request_key_packages(nb_key_package, ciphersuite, credential_type, backend)
556                        .await?;
557                }
558            }
559        }
560
561        Ok(())
562    }
563
    /// Loads the client from the keystore.
    ///
    /// For each scheme in `signature_schemes`, the method uses the first stored signature keypair
    /// with that scheme. If none is stored, it generates a new keypair and saves it under `id`.
    ///
    /// Every entry of `credentials` is then processed in ascending `created_at` order:
    /// * basic credentials must carry `id` as their identity;
    /// * x509 credentials must embed that keypair's public key.
    ///
    /// Each entry is registered as a credential bundle for the scheme. Finally, the session state
    /// is replaced with the loaded identities.
    ///
    /// # Errors
    /// * [Error::WrongCredential] if one of the checks above fails
    /// * keystore, key generation and TLS (de)serialization errors are propagated
    pub(crate) async fn load(
        &self,
        backend: &MlsCryptoProvider,
        id: &ClientId,
        mut credentials: Vec<(Credential, u64)>,
        signature_schemes: HashSet<SignatureScheme>,
    ) -> Result<()> {
        self.ensure_unready().await?;
        let mut identities = Identities::new(signature_schemes.len());

        // ensures we load credentials in chronological order
        credentials.sort_by_key(|(_, timestamp)| *timestamp);

        let stored_signature_keypairs = backend
            .key_store()
            .find_all::<MlsSignatureKeyPair>(EntityFindParams::default())
            .await
            .map_err(KeystoreError::wrap("finding all mls signature keypairs"))?;

        for signature_scheme in signature_schemes {
            // NOTE(review): the lookup matches on scheme only, not on `id`; presumably the store only
            // holds this client's keypairs — verify.
            let signature_keypair = stored_signature_keypairs
                .iter()
                .find(|skp| skp.signature_scheme == (signature_scheme as u16));

            let signature_key = if let Some(kp) = signature_keypair {
                SignatureKeyPair::tls_deserialize(&mut kp.keypair.as_slice())
                    .map_err(Error::tls_deserialize("signature keypair"))?
            } else {
                // No keypair stored for this scheme: generate one and persist it under `id`.
                let (private_key, public_key) = backend
                    .crypto()
                    .signature_key_gen(signature_scheme)
                    .map_err(MlsError::wrap("generating signature key"))?;
                let keypair = SignatureKeyPair::from_raw(signature_scheme, private_key, public_key.clone());
                let raw_keypair = keypair
                    .tls_serialize_detached()
                    .map_err(Error::tls_serialize("raw keypair"))?;
                let store_keypair =
                    MlsSignatureKeyPair::new(signature_scheme, public_key, raw_keypair, id.as_slice().into());
                backend
                    .key_store()
                    .save(store_keypair.clone())
                    .await
                    .map_err(KeystoreError::wrap("storing keypairs in keystore"))?;
                SignatureKeyPair::tls_deserialize(&mut store_keypair.keypair.as_slice())
                    .map_err(Error::tls_deserialize("signature keypair"))?
            };

            // NOTE(review): every credential is checked against every requested scheme. An x509
            // credential whose key belongs to a different scheme therefore appears to fail with
            // `WrongCredential` — confirm this is intended when several schemes are configured.
            for (credential, created_at) in &credentials {
                match credential.mls_credential() {
                    openmls::prelude::MlsCredentialType::Basic(_) => {
                        // Basic credentials must name this client.
                        if id.as_slice() != credential.identity() {
                            return Err(Error::WrongCredential);
                        }
                    }
                    openmls::prelude::MlsCredentialType::X509(cert) => {
                        // x509 credentials must certify the signature key selected above.
                        let spk = cert
                            .extract_public_key()
                            .map_err(RecursiveError::mls_credential("extracting public key"))?
                            .ok_or(LeafError::InternalMlsError)?;
                        if signature_key.public() != spk {
                            return Err(Error::WrongCredential);
                        }
                    }
                };
                let cb = CredentialBundle {
                    credential: credential.clone(),
                    signature_key: signature_key.clone(),
                    created_at: *created_at,
                };
                identities.push_credential_bundle(signature_scheme, cb).await?;
            }
        }
        self.replace_inner(SessionInner {
            id: id.clone(),
            identities,
            keypackage_lifetime: KEYPACKAGE_DEFAULT_LIFETIME,
        })
        .await;
        Ok(())
    }
645
646    /// Restore from an external [`HistorySecret`].
647    pub(crate) async fn restore_from_history_secret(&self, history_secret: HistorySecret) -> Result<()> {
648        self.ensure_unready().await?;
649
650        // store the client id (with some other stuff)
651        self.replace_inner(SessionInner {
652            id: history_secret.client_id.clone(),
653            identities: Identities::new(1),
654            keypackage_lifetime: KEYPACKAGE_DEFAULT_LIFETIME,
655        })
656        .await;
657
658        // store the key package
659        let key_package = history_secret
660            .key_package
661            .store(&self.crypto_provider)
662            .await
663            .map_err(MlsError::wrap("storing key package encapsulation"))?;
664
665        let keystore = self.crypto_provider.key_store();
666
667        // store the credential bundle (with some other stuff)
668        self.save_identity(
669            keystore,
670            Some(&history_secret.client_id),
671            key_package.ciphersuite().signature_algorithm(),
672            history_secret.credential_bundle,
673        )
674        .await?;
675
676        Ok(())
677    }
678
679    async fn find_all_basic_credentials(backend: &MlsCryptoProvider) -> Result<Vec<Credential>> {
680        let store_credentials = backend
681            .key_store()
682            .find_all::<MlsCredential>(EntityFindParams::default())
683            .await
684            .map_err(KeystoreError::wrap("finding all mls credentialss"))?;
685        let mut credentials = Vec::with_capacity(store_credentials.len());
686        for store_credential in store_credentials.into_iter() {
687            let credential = Credential::tls_deserialize(&mut store_credential.credential.as_slice())
688                .map_err(Error::tls_deserialize("credential"))?;
689            if !matches!(credential.credential_type(), CredentialType::Basic) {
690                continue;
691            }
692            credentials.push(credential);
693        }
694
695        Ok(credentials)
696    }
697
698    pub(crate) async fn save_identity(
699        &self,
700        keystore: &Connection,
701        id: Option<&ClientId>,
702        signature_scheme: SignatureScheme,
703        mut credential_bundle: CredentialBundle,
704    ) -> Result<CredentialBundle> {
705        let mut guard = self.inner.write().await;
706        let SessionInner {
707            id: existing_id,
708            identities,
709            ..
710        } = guard.as_mut().ok_or(Error::MlsNotInitialized)?;
711
712        let id = id.unwrap_or(existing_id);
713
714        let credential = credential_bundle
715            .credential
716            .tls_serialize_detached()
717            .map_err(Error::tls_serialize("credential bundle"))?;
718        let credential = MlsCredential {
719            id: id.clone().into(),
720            credential,
721            created_at: 0,
722        };
723
724        let credential = keystore
725            .save(credential)
726            .await
727            .map_err(KeystoreError::wrap("saving credential"))?;
728
729        let sign_kp = MlsSignatureKeyPair::new(
730            signature_scheme,
731            credential_bundle.signature_key.to_public_vec(),
732            credential_bundle
733                .signature_key
734                .tls_serialize_detached()
735                .map_err(Error::tls_serialize("signature keypair"))?,
736            id.clone().into(),
737        );
738        keystore.save(sign_kp).await.map_err(|e| match e {
739            CryptoKeystoreError::AlreadyExists => Error::CredentialBundleConflict,
740            _ => KeystoreError::wrap("saving mls signature key pair")(e).into(),
741        })?;
742
743        // set the creation date of the signature keypair which is the same for the CredentialBundle
744        credential_bundle.created_at = credential.created_at;
745
746        identities
747            .push_credential_bundle(signature_scheme, credential_bundle.clone())
748            .await?;
749
750        Ok(credential_bundle)
751    }
752
753    /// Retrieves the client's client id. This is free-form and not inspected.
754    pub async fn id(&self) -> Result<ClientId> {
755        match self.inner.read().await.deref() {
756            None => Err(Error::MlsNotInitialized),
757            Some(SessionInner { id, .. }) => Ok(id.clone()),
758        }
759    }
760
761    /// Returns whether this client is E2EI capable
762    pub async fn is_e2ei_capable(&self) -> bool {
763        match self.inner.read().await.deref() {
764            None => false,
765            Some(SessionInner { identities, .. }) => identities
766                .iter()
767                .any(|(_, cred)| cred.credential().credential_type() == CredentialType::X509),
768        }
769    }
770
771    pub(crate) async fn get_most_recent_or_create_credential_bundle(
772        &self,
773        backend: &MlsCryptoProvider,
774        sc: SignatureScheme,
775        ct: MlsCredentialType,
776    ) -> Result<Arc<CredentialBundle>> {
777        match ct {
778            MlsCredentialType::Basic => {
779                self.init_basic_credential_bundle_if_missing(backend, sc).await?;
780                self.find_most_recent_credential_bundle(sc, ct).await
781            }
782            MlsCredentialType::X509 => self
783                .find_most_recent_credential_bundle(sc, ct)
784                .await
785                .map_err(|e| match e {
786                    Error::CredentialNotFound(_) => LeafError::E2eiEnrollmentNotDone.into(),
787                    _ => e,
788                }),
789        }
790    }
791
792    pub(crate) async fn init_basic_credential_bundle_if_missing(
793        &self,
794        backend: &MlsCryptoProvider,
795        sc: SignatureScheme,
796    ) -> Result<()> {
797        let existing_cb = self
798            .find_most_recent_credential_bundle(sc, MlsCredentialType::Basic)
799            .await;
800        if matches!(existing_cb, Err(Error::CredentialNotFound(_))) {
801            let id = self.id().await?;
802            debug!(id:% = &id; "Initializing basic credential bundle");
803            let cb = Self::new_basic_credential_bundle(&id, sc, backend)?;
804            self.save_identity(&backend.keystore(), None, sc, cb).await?;
805        }
806        Ok(())
807    }
808
809    pub(crate) async fn save_new_x509_credential_bundle(
810        &self,
811        keystore: &Connection,
812        sc: SignatureScheme,
813        cb: CertificateBundle,
814    ) -> Result<CredentialBundle> {
815        let id = cb
816            .get_client_id()
817            .map_err(RecursiveError::mls_credential("getting client id"))?;
818        let cb = Self::new_x509_credential_bundle(cb)?;
819        self.save_identity(keystore, Some(&id), sc, cb).await
820    }
821}
822
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prelude::ClientId;
    use crate::test_utils::*;
    use crate::transaction_context::test_utils::EntitiesCount;
    use core_crypto_keystore::connection::{DatabaseKey, FetchFromDatabase};
    use core_crypto_keystore::entities::*;
    use mls_crypto_provider::MlsCryptoProvider;
    use wasm_bindgen_test::*;

    // Test-only helpers on `Session`.
    impl Session {
        // test functions are not held to the same documentation standard as proper functions
        #![allow(missing_docs)]

        /// Resets this session and re-initializes it with a random client id of the form
        /// `<uuid>:<hex>@members.wire.com`, using the test case's credential type and ciphersuite.
        ///
        /// For X509 cases `signer` must be `Some`. It is passed to
        /// `CertificateBundle::rand_identifier` as the intermediate CA. When `provision` is true,
        /// `INITIAL_KEYING_MATERIAL_COUNT` is passed to `generate` as the key package count;
        /// otherwise 0.
        pub async fn random_generate(
            &self,
            case: &crate::test_utils::TestContext,
            signer: Option<&crate::test_utils::x509::X509Certificate>,
            provision: bool,
        ) -> Result<()> {
            self.reset().await;
            let user_uuid = uuid::Uuid::new_v4();
            let rnd_id = rand::random::<usize>();
            let client_id = format!("{}:{rnd_id:x}@members.wire.com", user_uuid.hyphenated());
            let identity = match case.credential_type {
                MlsCredentialType::Basic => ClientIdentifier::Basic(client_id.as_str().into()),
                MlsCredentialType::X509 => {
                    let signer = signer.expect("Missing intermediate CA");
                    CertificateBundle::rand_identifier(&client_id, &[signer])
                }
            };
            let nb_key_package = if provision {
                crate::prelude::INITIAL_KEYING_MATERIAL_COUNT
            } else {
                0
            };
            let backend = self.crypto_provider.clone();
            self.generate(identity, &backend, &[case.ciphersuite()], nb_key_package)
                .await?;
            Ok(())
        }

        /// Fetches every MLS key package stored in `backend`'s keystore.
        pub async fn find_keypackages(&self, backend: &MlsCryptoProvider) -> Result<Vec<openmls::prelude::KeyPackage>> {
            use core_crypto_keystore::CryptoKeystoreMls as _;
            let kps = backend
                .key_store()
                .mls_fetch_keypackages::<openmls::prelude::KeyPackage>(u32::MAX)
                .await
                .map_err(KeystoreError::wrap("fetching mls keypackages"))?;
            Ok(kps)
        }

        /// Creates and persists an X509 credential bundle from `cb` when the session has none
        /// for `sc`.
        ///
        /// Creation happens only when the lookup yields `Error::CredentialNotFound`. An existing
        /// bundle leaves the session untouched. Any other lookup error is propagated rather
        /// than being mistaken for absence.
        pub(crate) async fn init_x509_credential_bundle_if_missing(
            &self,
            backend: &MlsCryptoProvider,
            sc: SignatureScheme,
            cb: CertificateBundle,
        ) -> Result<()> {
            match self
                .find_most_recent_credential_bundle(sc, MlsCredentialType::X509)
                .await
            {
                Ok(_) => Ok(()),
                Err(Error::CredentialNotFound(_)) => {
                    self.save_new_x509_credential_bundle(&backend.keystore(), sc, cb)
                        .await?;
                    Ok(())
                }
                Err(e) => Err(e),
            }
        }

        /// Generates a single key package for `cs` from the most recent credential bundle of
        /// type `ct`. Fails if no such bundle exists.
        pub(crate) async fn generate_one_keypackage(
            &self,
            backend: &MlsCryptoProvider,
            cs: MlsCiphersuite,
            ct: MlsCredentialType,
        ) -> Result<openmls::prelude::KeyPackage> {
            let cb = self
                .find_most_recent_credential_bundle(cs.signature_algorithm(), ct)
                .await?;
            self.generate_one_keypackage_from_credential_bundle(backend, cs, &cb)
                .await
        }

        /// Counts each kind of MLS entity stored in this session's keystore.
        ///
        /// Panics if any count query fails.
        pub async fn count_entities(&self) -> EntitiesCount {
            let keystore = self.crypto_provider.keystore();
            let credential = keystore.count::<MlsCredential>().await.unwrap();
            let encryption_keypair = keystore.count::<MlsEncryptionKeyPair>().await.unwrap();
            let epoch_encryption_keypair = keystore.count::<MlsEpochEncryptionKeyPair>().await.unwrap();
            let enrollment = keystore.count::<E2eiEnrollment>().await.unwrap();
            let group = keystore.count::<PersistedMlsGroup>().await.unwrap();
            let hpke_private_key = keystore.count::<MlsHpkePrivateKey>().await.unwrap();
            let key_package = keystore.count::<MlsKeyPackage>().await.unwrap();
            let pending_group = keystore.count::<PersistedMlsPendingGroup>().await.unwrap();
            let pending_messages = keystore.count::<MlsPendingMessage>().await.unwrap();
            let psk_bundle = keystore.count::<MlsPskBundle>().await.unwrap();
            let signature_keypair = keystore.count::<MlsSignatureKeyPair>().await.unwrap();
            EntitiesCount {
                credential,
                encryption_keypair,
                epoch_encryption_keypair,
                enrollment,
                group,
                hpke_private_key,
                key_package,
                pending_group,
                pending_messages,
                psk_bundle,
                signature_keypair,
            }
        }
    }
    wasm_bindgen_test_configure!(run_in_browser);

    /// A session can be (re)generated with a random identity for every credential/ciphersuite case.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    async fn can_generate_session(case: TestContext) {
        let [alice] = case.sessions().await;
        let key = DatabaseKey::generate();
        let backend = MlsCryptoProvider::try_new_in_memory(&key).await.unwrap();
        // NOTE(review): `random_generate` uses the session's own `crypto_provider`, so this
        // standalone `backend` only receives the X509 chain registration and a transaction.
        // Confirm that is intended.
        let x509_test_chain = if case.is_x509() {
            let x509_test_chain = crate::test_utils::x509::X509TestChain::init_empty(case.signature_scheme());
            x509_test_chain.register_with_provider(&backend).await;
            Some(x509_test_chain)
        } else {
            None
        };
        backend.new_transaction().await.unwrap();
        let session = alice.session().await;
        session
            .random_generate(
                &case,
                x509_test_chain.as_ref().map(|chain| chain.find_local_intermediate_ca()),
                false,
            )
            .await
            .unwrap();
    }

    /// Raw keypairs generated up front can later be bound to an externally provided client id.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    async fn can_externally_generate_client(mut case: TestContext) {
        let [alice] = case.sessions().await;
        if !case.is_basic() {
            return;
        }
        let tmp_dir = case.tmp_dir().await;
        Box::pin(async move {
            let key = DatabaseKey::generate();
            let backend = MlsCryptoProvider::try_new(tmp_dir, &key).await.unwrap();
            backend.new_transaction().await.unwrap();
            // phase 1: generate standalone keypair
            let client_id: ClientId = b"whatever:my:client:is@world.com".to_vec().into();
            let alice = alice.session().await;
            alice.reset().await;
            // TODO: test with multi-ciphersuite. Tracking issue: WPB-9601
            let handles = alice
                .generate_raw_keypairs(&[case.ciphersuite()], &backend)
                .await
                .unwrap();

            let mut identities = backend
                .keystore()
                .find_all::<MlsSignatureKeyPair>(EntityFindParams::default())
                .await
                .unwrap();

            assert_eq!(identities.len(), 1);

            let prov_identity = identities.pop().unwrap();

            // Make sure we are actually returning the clientId
            // TODO: test with multi-ciphersuite. Tracking issue: WPB-9601
            let prov_client_id: ClientId = prov_identity.credential_id.as_slice().into();
            assert_eq!(&prov_client_id, handles.first().unwrap());

            // phase 2: pretend we have a new client ID from the backend, and try to init the client this way
            alice
                .init_with_external_client_id(client_id.clone(), handles.clone(), &[case.ciphersuite()], &backend)
                .await
                .unwrap();

            // Make sure both client id and PK are intact
            assert_eq!(alice.id().await.unwrap(), client_id);
            let cb = alice
                .find_most_recent_credential_bundle(case.signature_scheme(), case.credential_type)
                .await
                .unwrap();
            // The credential still carries the provisional id returned in phase 1.
            let client_id: ClientId = cb.credential().identity().into();
            assert_eq!(&client_id, handles.first().unwrap());
        })
        .await
    }
}