core_crypto/transaction_context/
mod.rs

1//! This module contains the primitives to enable transactional support on a higher level within the
2//! [Session]. All mutating operations need to be done through a [TransactionContext].
3
4use std::sync::Arc;
5
6#[cfg(feature = "proteus")]
7use async_lock::Mutex;
8use async_lock::{RwLock, RwLockWriteGuardArc};
9use core_crypto_keystore::{CryptoKeystoreError, entities::ConsumerData, traits::FetchFromDatabase as _};
10pub use error::{Error, Result};
11use mls_crypto_provider::{Database, MlsCryptoProvider};
12use openmls_traits::OpenMlsCryptoProvider as _;
13
14#[cfg(feature = "proteus")]
15use crate::proteus::ProteusCentral;
16use crate::{
17    Ciphersuite, ClientId, ClientIdentifier, CoreCrypto, Credential, CredentialFindFilters, CredentialRef,
18    CredentialType, KeystoreError, MlsConversation, MlsError, MlsTransport, RecursiveError, Session,
19    group_store::GroupStore,
20    mls::{
21        self, HasSessionAndCrypto,
22        session::{Error as SessionError, identities::Identities},
23    },
24};
25pub mod conversation;
26pub mod e2e_identity;
27mod error;
28pub mod key_package;
29#[cfg(feature = "proteus")]
30pub mod proteus;
31#[cfg(test)]
32pub mod test_utils;
33
/// This struct provides transactional support for Core Crypto.
///
/// This struct provides mutable access to the internals of Core Crypto. Every operation that
/// causes data to be persisted needs to be done through this struct. This struct will buffer all
/// operations in memory and when [TransactionContext::finish] is called, it will persist the data into
/// the keystore. [TransactionContext::abort] discards the buffered operations instead.
///
/// Cloning is cheap: every clone shares the same inner state through an `Arc`, so finishing or
/// aborting via one clone invalidates all of them.
#[derive(Debug, Clone)]
pub struct TransactionContext {
    // Shared, lockable state; switched to `Invalid` once the transaction is finished or aborted.
    inner: Arc<RwLock<TransactionContextInner>>,
}
44
/// Internal state of a [TransactionContext].
///
/// Due to uniffi's design, we can't force the context to be dropped after the transaction is
/// committed. To work around that, we switch the value to `Invalid` when the context is finished
/// (or aborted), and every subsequent call returns [Error::InvalidTransactionContext].
#[derive(Debug, Clone)]
enum TransactionContextInner {
    /// The transaction is live and operations may be performed.
    Valid {
        /// Keystore handle on which the database transaction was opened.
        keystore: Database,
        /// MLS session slot shared with the [CoreCrypto] instance that created this transaction;
        /// `None` means MLS has not been initialized yet.
        mls_session: Arc<RwLock<Option<Session>>>,
        /// Group store used by this transaction; it starts out empty for every new transaction.
        mls_groups: Arc<RwLock<GroupStore<MlsConversation>>>,
        /// Proteus state shared with the [CoreCrypto] instance that created this transaction.
        #[cfg(feature = "proteus")]
        proteus_central: Arc<Mutex<Option<ProteusCentral>>>,
    },
    /// The transaction was finished or aborted; every operation now fails.
    Invalid,
}
59
60impl CoreCrypto {
61    /// Creates a new transaction. All operations that persist data will be
62    /// buffered in memory and when [TransactionContext::finish] is called, the data will be persisted
63    /// in a single database transaction.
64    pub async fn new_transaction(&self) -> Result<TransactionContext> {
65        TransactionContext::new(
66            self.database.clone(),
67            self.mls.clone(),
68            #[cfg(feature = "proteus")]
69            self.proteus.clone(),
70        )
71        .await
72    }
73}
74
75#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
76#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
77impl HasSessionAndCrypto for TransactionContext {
78    async fn session(&self) -> crate::mls::Result<Session> {
79        self.session()
80            .await
81            .map_err(RecursiveError::transaction("getting mls client"))
82            .map_err(Into::into)
83    }
84
85    async fn crypto_provider(&self) -> crate::mls::Result<MlsCryptoProvider> {
86        self.mls_provider()
87            .await
88            .map_err(RecursiveError::transaction("getting mls provider"))
89            .map_err(Into::into)
90    }
91}
92
93impl TransactionContext {
94    async fn new(
95        keystore: Database,
96        mls_session: Arc<RwLock<Option<Session>>>,
97        #[cfg(feature = "proteus")] proteus_central: Arc<Mutex<Option<ProteusCentral>>>,
98    ) -> Result<Self> {
99        keystore
100            .new_transaction()
101            .await
102            .map_err(MlsError::wrap("creating new transaction"))?;
103        let mls_groups = Arc::new(RwLock::new(Default::default()));
104        Ok(Self {
105            inner: Arc::new(
106                TransactionContextInner::Valid {
107                    keystore,
108                    mls_session: mls_session.clone(),
109                    mls_groups,
110                    #[cfg(feature = "proteus")]
111                    proteus_central,
112                }
113                .into(),
114            ),
115        })
116    }
117
118    pub(crate) async fn session(&self) -> Result<Session> {
119        match &*self.inner.read().await {
120            TransactionContextInner::Valid { mls_session, .. } => {
121                if let Some(session) = mls_session.read().await.as_ref() {
122                    return Ok(session.clone());
123                }
124                Err(mls::session::Error::MlsNotInitialized)
125                    .map_err(RecursiveError::mls_client(
126                        "Getting mls session from transaction context",
127                    ))
128                    .map_err(Into::into)
129            }
130            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
131        }
132    }
133
134    #[cfg(test)]
135    pub(crate) async fn set_session_if_exists(&self, new_session: Session) {
136        match &*self.inner.read().await {
137            TransactionContextInner::Valid { mls_session, .. } => {
138                let mut guard = mls_session.write().await;
139
140                if guard.as_ref().is_some() {
141                    *guard = Some(new_session)
142                }
143            }
144            TransactionContextInner::Invalid => {}
145        }
146    }
147
148    pub(crate) async fn mls_transport(&self) -> Result<Arc<dyn MlsTransport + 'static>> {
149        match &*self.inner.read().await {
150            TransactionContextInner::Valid { mls_session, .. } => {
151                if let Some(session) = mls_session.read().await.as_ref() {
152                    let transport = session.transport.clone();
153                    return Ok(transport);
154                }
155                Err(mls::session::Error::MlsNotInitialized)
156                    .map_err(RecursiveError::mls_client(
157                        "Getting mls session from transaction context",
158                    ))
159                    .map_err(Into::into)
160            }
161
162            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
163        }
164    }
165
166    /// Clones all references that the [MlsCryptoProvider] comprises.
167    pub async fn mls_provider(&self) -> Result<MlsCryptoProvider> {
168        match &*self.inner.read().await {
169            TransactionContextInner::Valid { mls_session, .. } => {
170                if let Some(session) = mls_session.read().await.as_ref() {
171                    return Ok(session.crypto_provider.clone());
172                }
173                Err(mls::session::Error::MlsNotInitialized)
174                    .map_err(RecursiveError::mls_client(
175                        "Getting mls session from transaction context",
176                    ))
177                    .map_err(Into::into)
178            }
179            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
180        }
181    }
182
183    pub(crate) async fn keystore(&self) -> Result<Database> {
184        match &*self.inner.read().await {
185            TransactionContextInner::Valid { keystore, .. } => Ok(keystore.clone()),
186            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
187        }
188    }
189
190    pub(crate) async fn mls_groups(&self) -> Result<RwLockWriteGuardArc<GroupStore<MlsConversation>>> {
191        match &*self.inner.read().await {
192            TransactionContextInner::Valid { mls_groups, .. } => Ok(mls_groups.write_arc().await),
193            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
194        }
195    }
196
197    #[cfg(feature = "proteus")]
198    pub(crate) async fn proteus_central(&self) -> Result<Arc<Mutex<Option<ProteusCentral>>>> {
199        match &*self.inner.read().await {
200            TransactionContextInner::Valid { proteus_central, .. } => Ok(proteus_central.clone()),
201            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
202        }
203    }
204
205    /// Commits the transaction, meaning it takes all the enqueued operations and persist them into
206    /// the keystore. After that the internal state is switched to invalid, causing errors if
207    /// something is called from this object.
208    pub async fn finish(&self) -> Result<()> {
209        let mut guard = self.inner.write().await;
210        let TransactionContextInner::Valid { keystore, .. } = &*guard else {
211            return Err(Error::InvalidTransactionContext);
212        };
213
214        let commit_result = keystore
215            .commit_transaction()
216            .await
217            .map_err(KeystoreError::wrap("commiting transaction"))
218            .map_err(Into::into);
219
220        *guard = TransactionContextInner::Invalid;
221        commit_result
222    }
223
224    /// Aborts the transaction, meaning it discards all the enqueued operations.
225    /// After that the internal state is switched to invalid, causing errors if
226    /// something is called from this object.
227    pub async fn abort(&self) -> Result<()> {
228        let mut guard = self.inner.write().await;
229
230        let TransactionContextInner::Valid { keystore, .. } = &*guard else {
231            return Err(Error::InvalidTransactionContext);
232        };
233
234        let result = keystore
235            .rollback_transaction()
236            .await
237            .map_err(KeystoreError::wrap("rolling back transaction"))
238            .map_err(Into::into);
239
240        *guard = TransactionContextInner::Invalid;
241        result
242    }
243
244    /// Loads any cryptographic material already present in the keystore, but does not create any.
245    /// If no credentials are present in the keystore, then one _must_ be created and added to the
246    /// session before it can be used.
247    async fn init(&self, identifier: ClientIdentifier, ciphersuites: &[Ciphersuite]) -> Result<(ClientId, Identities)> {
248        let database = self.keystore().await?;
249        let client_id = identifier
250            .get_id()
251            .map_err(RecursiveError::mls_client("getting client id"))?
252            .into_owned();
253
254        let signature_schemes = &ciphersuites
255            .iter()
256            .map(|ciphersuite| ciphersuite.signature_algorithm())
257            .collect::<Vec<_>>();
258
259        // we want to find all credentials matching this identifier, having a valid signature scheme.
260        // the `CredentialRef::find` API doesn't allow us to easily find those credentials having
261        // one of a set of signature schemes, meaning we have two paths here:
262        // we could either search unbound by signature schemes and then filter for valid ones here,
263        // or we could iterate over the list of signature schemes and build up a set of credential refs.
264        // as there are only a few signature schemes possible and the cost of a find operation is non-trivial,
265        // we choose the first option.
266        // we might revisit this choice after WPB-20844 and WPB-21819.
267        let mut credential_refs = CredentialRef::find(
268            &database,
269            CredentialFindFilters::builder().client_id(&client_id).build(),
270        )
271        .await
272        .map_err(RecursiveError::mls_credential_ref(
273            "loading matching credential refs while initializing a client",
274        ))?;
275        credential_refs.retain(|credential_ref| signature_schemes.contains(&credential_ref.signature_scheme()));
276
277        let mut identities = Identities::new(credential_refs.len());
278        let credentials_cache =
279            CredentialRef::load_stored_credentials(&database)
280                .await
281                .map_err(RecursiveError::mls_credential_ref(
282                    "loading credential ref cache while initializing session",
283                ))?;
284
285        for credential_ref in credential_refs {
286            if let Some(credential) =
287                credential_ref
288                    .load_from_cache(&credentials_cache)
289                    .map_err(RecursiveError::mls_credential_ref(
290                        "loading credential list in session init",
291                    ))?
292            {
293                match identities.push_credential(credential).await {
294                    Err(SessionError::CredentialConflict) => {
295                        // this is what we get for not having real primary keys in our DB
296                        // no harm done though; no need to propagate this error
297                    }
298                    Ok(_) => {}
299                    Err(err) => {
300                        return Err(RecursiveError::MlsClient {
301                            context: "adding credential to identities in init",
302                            source: Box::new(err),
303                        }
304                        .into());
305                    }
306                }
307            }
308        }
309
310        Ok((client_id, identities))
311    }
312
313    /// Initializes the MLS client of [super::CoreCrypto].
314    pub async fn mls_init(
315        &self,
316        identifier: ClientIdentifier,
317        ciphersuites: &[Ciphersuite],
318        transport: Arc<dyn MlsTransport>,
319    ) -> Result<()> {
320        let database = self.keystore().await?;
321        let (client_id, identities) = self.init(identifier, ciphersuites).await?;
322
323        let mls_backend = MlsCryptoProvider::new(database);
324        let session = Session::new(client_id.clone(), identities, mls_backend, transport);
325
326        if session.is_e2ei_capable().await {
327            log::trace!(client_id:% = client_id; "Initializing PKI environment");
328            self.init_pki_env().await?;
329        }
330
331        self.set_mls_session(session).await?;
332
333        Ok(())
334    }
335
336    /// Set the `mls_session` Arc (also sets it on the transaction's CoreCrypto instance)
337    pub(crate) async fn set_mls_session(&self, session: Session) -> Result<()> {
338        match &*self.inner.read().await {
339            TransactionContextInner::Valid { mls_session, .. } => {
340                let mut guard = mls_session.write().await;
341                *guard = Some(session);
342                Ok(())
343            }
344            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
345        }
346    }
347
348    /// Returns the client's public key.
349    pub async fn client_public_key(
350        &self,
351        ciphersuite: Ciphersuite,
352        credential_type: CredentialType,
353    ) -> Result<Vec<u8>> {
354        let cb = self
355            .session()
356            .await?
357            .find_most_recent_credential(ciphersuite.signature_algorithm(), credential_type)
358            .await
359            .map_err(RecursiveError::mls_client("finding most recent credential"))?;
360        Ok(cb.signature_key_pair.to_public_vec())
361    }
362
363    /// see [Session::id]
364    pub async fn client_id(&self) -> Result<ClientId> {
365        let session = self.session().await?;
366        Ok(session.id())
367    }
368
369    /// Generates a random byte array of the specified size
370    pub async fn random_bytes(&self, len: usize) -> Result<Vec<u8>> {
371        use openmls_traits::random::OpenMlsRand as _;
372        self.mls_provider()
373            .await?
374            .rand()
375            .random_vec(len)
376            .map_err(MlsError::wrap("generating random vector"))
377            .map_err(Into::into)
378    }
379
380    /// Set arbitrary data to be retrieved by [TransactionContext::get_data].
381    /// This is meant to be used as a check point at the end of a transaction.
382    /// The data should be limited to a reasonable size.
383    pub async fn set_data(&self, data: Vec<u8>) -> Result<()> {
384        self.keystore()
385            .await?
386            .save(ConsumerData::from(data))
387            .await
388            .map_err(KeystoreError::wrap("saving consumer data"))?;
389        Ok(())
390    }
391
392    /// Get the data that has previously been set by [TransactionContext::set_data].
393    /// This is meant to be used as a check point at the end of a transaction.
394    pub async fn get_data(&self) -> Result<Option<Vec<u8>>> {
395        match self.keystore().await?.get_unique::<ConsumerData>().await {
396            Ok(maybe_data) => Ok(maybe_data.map(Into::into)),
397            Err(CryptoKeystoreError::NotFound(..)) => Ok(None),
398            Err(err) => Err(KeystoreError::wrap("finding unique consumer data")(err).into()),
399        }
400    }
401
402    /// Add a credential to the identities of this session.
403    ///
404    /// As a side effect, stores the credential in the keystore.
405    pub async fn add_credential(&self, credential: Credential) -> Result<CredentialRef> {
406        self.session()
407            .await?
408            .add_credential(credential)
409            .await
410            .map_err(RecursiveError::mls_client("adding credential to session"))
411            .map_err(Into::into)
412    }
413
414    /// Remove a credential from the identities of this session.
415    ///
416    /// As a side effect, delete the credential from the keystore.
417    ///
418    /// Removes both the credential itself and also any key packages which were generated from it.
419    pub async fn remove_credential(&self, credential_ref: &CredentialRef) -> Result<()> {
420        self.session()
421            .await?
422            .remove_credential(credential_ref)
423            .await
424            .map_err(RecursiveError::mls_client("removing credential from session"))
425            .map_err(Into::into)
426    }
427
428    /// Find credentials matching the find filters among the identities of this session
429    ///
430    /// Note that finding credentials with no filters set is equivalent to [`Self::get_credentials`].
431    pub async fn find_credentials(&self, find_filters: CredentialFindFilters<'_>) -> Result<Vec<CredentialRef>> {
432        self.session()
433            .await?
434            .find_credentials(find_filters)
435            .await
436            .map_err(RecursiveError::mls_client("finding credentials by filter"))
437            .map_err(Into::into)
438    }
439
440    /// Get all credentials from the identities of this session.
441    ///
442    /// To get specific credentials, it can be more efficient to use [`Self::find_credentials`].
443    pub async fn get_credentials(&self) -> Result<Vec<CredentialRef>> {
444        self.session()
445            .await?
446            .get_credentials()
447            .await
448            .map_err(RecursiveError::mls_client("getting all credentials"))
449            .map_err(Into::into)
450    }
451}