core_crypto/transaction_context/
mod.rs

1//! This module contains the primitives to enable transactional support on a higher level within the
2//! [Session]. All mutating operations need to be done through a [TransactionContext].
3
4#[cfg(feature = "proteus")]
5use crate::proteus::ProteusCentral;
6use crate::{
7    CoreCrypto, KeystoreError, MlsError, MlsTransport, RecursiveError,
8    group_store::GroupStore,
9    prelude::{ClientId, INITIAL_KEYING_MATERIAL_COUNT, MlsConversation, MlsCredentialType, Session},
10};
11use crate::{
12    mls::HasSessionAndCrypto,
13    prelude::{ClientIdentifier, MlsCiphersuite},
14};
15#[cfg(feature = "proteus")]
16use async_lock::Mutex;
17use async_lock::{RwLock, RwLockReadGuardArc, RwLockWriteGuardArc};
18use core_crypto_keystore::{CryptoKeystoreError, connection::FetchFromDatabase, entities::ConsumerData};
19pub use error::{Error, Result};
20use mls_crypto_provider::{CryptoKeystore, MlsCryptoProvider};
21use openmls_traits::OpenMlsCryptoProvider as _;
22use std::{ops::Deref, sync::Arc};
23pub mod conversation;
24pub mod e2e_identity;
25mod error;
26pub mod key_package;
27#[cfg(feature = "proteus")]
28pub mod proteus;
29#[cfg(test)]
30pub mod test_utils;
31
32/// This struct provides transactional support for Core Crypto.
33///
34/// This struct provides mutable access to the internals of Core Crypto. Every operation that
35/// causes data to be persisted needs to be done through this struct. This struct will buffer all
36/// operations in memory and when [TransactionContext::finish] is called, it will persist the data into
37/// the keystore.
38#[derive(Debug, Clone)]
39pub struct TransactionContext {
40    inner: Arc<RwLock<TransactionContextInner>>,
41}
42
43/// Due to uniffi's design, we can't force the context to be dropped after the transaction is
44/// committed. To work around that we switch the value to `Invalid` when the context is finished
45/// and throw errors if something is called
46#[derive(Debug, Clone)]
47enum TransactionContextInner {
48    Valid {
49        provider: MlsCryptoProvider,
50        transport: Arc<RwLock<Option<Arc<dyn MlsTransport + 'static>>>>,
51        mls_client: Session,
52        mls_groups: Arc<RwLock<GroupStore<MlsConversation>>>,
53        #[cfg(feature = "proteus")]
54        proteus_central: Arc<Mutex<Option<ProteusCentral>>>,
55    },
56    Invalid,
57}
58
59impl CoreCrypto {
60    /// Creates a new transaction. All operations that persist data will be
61    /// buffered in memory and when [TransactionContext::finish] is called, the data will be persisted
62    /// in a single database transaction.
63    pub async fn new_transaction(&self) -> Result<TransactionContext> {
64        TransactionContext::new(
65            &self.mls,
66            #[cfg(feature = "proteus")]
67            self.proteus.clone(),
68        )
69        .await
70    }
71}
72
73#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
74#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
75impl HasSessionAndCrypto for TransactionContext {
76    async fn session(&self) -> crate::mls::Result<Session> {
77        self.session()
78            .await
79            .map_err(RecursiveError::transaction("getting mls client"))
80            .map_err(Into::into)
81    }
82
83    async fn crypto_provider(&self) -> crate::mls::Result<MlsCryptoProvider> {
84        self.mls_provider()
85            .await
86            .map_err(RecursiveError::transaction("getting mls provider"))
87            .map_err(Into::into)
88    }
89}
90
91impl TransactionContext {
92    async fn new(
93        client: &Session,
94        #[cfg(feature = "proteus")] proteus_central: Arc<Mutex<Option<ProteusCentral>>>,
95    ) -> Result<Self> {
96        client
97            .crypto_provider
98            .new_transaction()
99            .await
100            .map_err(MlsError::wrap("creating new transaction"))?;
101        let mls_groups = Arc::new(RwLock::new(Default::default()));
102        let callbacks = client.transport.clone();
103        let mls_client = client.clone();
104        Ok(Self {
105            inner: Arc::new(
106                TransactionContextInner::Valid {
107                    mls_client,
108                    transport: callbacks,
109                    provider: client.crypto_provider.clone(),
110                    mls_groups,
111                    #[cfg(feature = "proteus")]
112                    proteus_central,
113                }
114                .into(),
115            ),
116        })
117    }
118
119    pub(crate) async fn session(&self) -> Result<Session> {
120        match self.inner.read().await.deref() {
121            TransactionContextInner::Valid { mls_client, .. } => Ok(mls_client.clone()),
122            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
123        }
124    }
125
126    pub(crate) async fn mls_transport(&self) -> Result<RwLockReadGuardArc<Option<Arc<dyn MlsTransport + 'static>>>> {
127        match self.inner.read().await.deref() {
128            TransactionContextInner::Valid {
129                transport: callbacks, ..
130            } => Ok(callbacks.read_arc().await),
131            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
132        }
133    }
134
135    #[cfg(test)]
136    pub(crate) async fn set_transport_callbacks(
137        &self,
138        callbacks: Option<Arc<dyn MlsTransport + 'static>>,
139    ) -> Result<()> {
140        match self.inner.read().await.deref() {
141            TransactionContextInner::Valid { transport: cbs, .. } => {
142                *cbs.write_arc().await = callbacks;
143                Ok(())
144            }
145            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
146        }
147    }
148
149    /// Clones all references that the [MlsCryptoProvider] comprises.
150    pub async fn mls_provider(&self) -> Result<MlsCryptoProvider> {
151        match self.inner.read().await.deref() {
152            TransactionContextInner::Valid { provider, .. } => Ok(provider.clone()),
153            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
154        }
155    }
156
157    pub(crate) async fn keystore(&self) -> Result<CryptoKeystore> {
158        match self.inner.read().await.deref() {
159            TransactionContextInner::Valid { provider, .. } => Ok(provider.keystore()),
160            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
161        }
162    }
163
164    pub(crate) async fn mls_groups(&self) -> Result<RwLockWriteGuardArc<GroupStore<MlsConversation>>> {
165        match self.inner.read().await.deref() {
166            TransactionContextInner::Valid { mls_groups, .. } => Ok(mls_groups.write_arc().await),
167            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
168        }
169    }
170
171    #[cfg(feature = "proteus")]
172    pub(crate) async fn proteus_central(&self) -> Result<Arc<Mutex<Option<ProteusCentral>>>> {
173        match self.inner.read().await.deref() {
174            TransactionContextInner::Valid { proteus_central, .. } => Ok(proteus_central.clone()),
175            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
176        }
177    }
178
179    /// Commits the transaction, meaning it takes all the enqueued operations and persist them into
180    /// the keystore. After that the internal state is switched to invalid, causing errors if
181    /// something is called from this object.
182    pub async fn finish(&self) -> Result<()> {
183        let mut guard = self.inner.write().await;
184        let TransactionContextInner::Valid { provider, .. } = guard.deref() else {
185            return Err(Error::InvalidTransactionContext);
186        };
187
188        let commit_result = provider
189            .keystore()
190            .commit_transaction()
191            .await
192            .map_err(KeystoreError::wrap("commiting transaction"))
193            .map_err(Into::into);
194
195        *guard = TransactionContextInner::Invalid;
196        commit_result
197    }
198
199    /// Aborts the transaction, meaning it discards all the enqueued operations.
200    /// After that the internal state is switched to invalid, causing errors if
201    /// something is called from this object.
202    pub async fn abort(&self) -> Result<()> {
203        let mut guard = self.inner.write().await;
204
205        let TransactionContextInner::Valid { provider, .. } = guard.deref() else {
206            return Err(Error::InvalidTransactionContext);
207        };
208
209        let result = provider
210            .keystore()
211            .rollback_transaction()
212            .await
213            .map_err(KeystoreError::wrap("rolling back transaction"))
214            .map_err(Into::into);
215
216        *guard = TransactionContextInner::Invalid;
217        result
218    }
219
220    /// Initializes the MLS client if [super::CoreCrypto] has previously been initialized with
221    /// `CoreCrypto::deferred_init` instead of `CoreCrypto::new`.
222    /// This should stay as long as proteus is supported. Then it should be removed.
223    pub async fn mls_init(
224        &self,
225        identifier: ClientIdentifier,
226        ciphersuites: Vec<MlsCiphersuite>,
227        nb_init_key_packages: Option<usize>,
228    ) -> Result<()> {
229        let nb_key_package = nb_init_key_packages.unwrap_or(INITIAL_KEYING_MATERIAL_COUNT);
230        let mls_client = self.session().await?;
231        mls_client
232            .init(identifier, &ciphersuites, &self.mls_provider().await?, nb_key_package)
233            .await
234            .map_err(RecursiveError::mls_client("initializing mls client"))?;
235
236        if mls_client.is_e2ei_capable().await {
237            let client_id = mls_client
238                .id()
239                .await
240                .map_err(RecursiveError::mls_client("getting client id"))?;
241            log::trace!(client_id:% = client_id; "Initializing PKI environment");
242            self.init_pki_env().await?;
243        }
244
245        Ok(())
246    }
247
248    /// Returns the client's public key.
249    pub async fn client_public_key(
250        &self,
251        ciphersuite: MlsCiphersuite,
252        credential_type: MlsCredentialType,
253    ) -> Result<Vec<u8>> {
254        let cb = self
255            .session()
256            .await?
257            .find_most_recent_credential_bundle(ciphersuite.signature_algorithm(), credential_type)
258            .await
259            .map_err(RecursiveError::mls_client("finding most recent credential bundle"))?;
260        Ok(cb.signature_key.to_public_vec())
261    }
262
263    /// see [Session::id]
264    pub async fn client_id(&self) -> Result<ClientId> {
265        self.session()
266            .await?
267            .id()
268            .await
269            .map_err(RecursiveError::mls_client("getting client id"))
270            .map_err(Into::into)
271    }
272
273    /// Generates a random byte array of the specified size
274    pub async fn random_bytes(&self, len: usize) -> Result<Vec<u8>> {
275        use openmls_traits::random::OpenMlsRand as _;
276        self.mls_provider()
277            .await?
278            .rand()
279            .random_vec(len)
280            .map_err(MlsError::wrap("generating random vector"))
281            .map_err(Into::into)
282    }
283
284    /// Set arbitrary data to be retrieved by [TransactionContext::get_data].
285    /// This is meant to be used as a check point at the end of a transaction.
286    /// The data should be limited to a reasonable size.
287    pub async fn set_data(&self, data: Vec<u8>) -> Result<()> {
288        self.keystore()
289            .await?
290            .save(ConsumerData::from(data))
291            .await
292            .map_err(KeystoreError::wrap("saving consumer data"))?;
293        Ok(())
294    }
295
296    /// Get the data that has previously been set by [TransactionContext::set_data].
297    /// This is meant to be used as a check point at the end of a transaction.
298    pub async fn get_data(&self) -> Result<Option<Vec<u8>>> {
299        match self.keystore().await?.find_unique::<ConsumerData>().await {
300            Ok(data) => Ok(Some(data.into())),
301            Err(CryptoKeystoreError::NotFound(..)) => Ok(None),
302            Err(err) => Err(KeystoreError::wrap("finding unique consumer data")(err).into()),
303        }
304    }
305}