core_crypto/transaction_context/
mod.rs

1//! This module contains the primitives to enable transactional support on a higher level within the
2//! [Session]. All mutating operations need to be done through a [TransactionContext].
3
4use std::{ops::Deref, sync::Arc};
5
6#[cfg(feature = "proteus")]
7use async_lock::Mutex;
8use async_lock::{RwLock, RwLockReadGuardArc, RwLockWriteGuardArc};
9use core_crypto_keystore::{CryptoKeystoreError, connection::FetchFromDatabase, entities::ConsumerData};
10pub use error::{Error, Result};
11use mls_crypto_provider::{Database, MlsCryptoProvider};
12use openmls_traits::OpenMlsCryptoProvider as _;
13
14#[cfg(feature = "proteus")]
15use crate::proteus::ProteusCentral;
16use crate::{
17    ClientId, ClientIdentifier, CoreCrypto, KeystoreError, MlsCiphersuite, MlsConversation, MlsCredentialType,
18    MlsError, MlsTransport, RecursiveError, Session, group_store::GroupStore, mls::HasSessionAndCrypto,
19};
20pub mod conversation;
21pub mod e2e_identity;
22mod error;
23pub mod key_package;
24#[cfg(feature = "proteus")]
25pub mod proteus;
26#[cfg(test)]
27pub mod test_utils;
28
29/// This struct provides transactional support for Core Crypto.
30///
31/// This struct provides mutable access to the internals of Core Crypto. Every operation that
32/// causes data to be persisted needs to be done through this struct. This struct will buffer all
33/// operations in memory and when [TransactionContext::finish] is called, it will persist the data into
34/// the keystore.
35#[derive(Debug, Clone)]
36pub struct TransactionContext {
37    inner: Arc<RwLock<TransactionContextInner>>,
38}
39
40/// Due to uniffi's design, we can't force the context to be dropped after the transaction is
41/// committed. To work around that we switch the value to `Invalid` when the context is finished
42/// and throw errors if something is called
43#[derive(Debug, Clone)]
44enum TransactionContextInner {
45    Valid {
46        provider: MlsCryptoProvider,
47        transport: Arc<RwLock<Option<Arc<dyn MlsTransport + 'static>>>>,
48        mls_client: Session,
49        mls_groups: Arc<RwLock<GroupStore<MlsConversation>>>,
50        #[cfg(feature = "proteus")]
51        proteus_central: Arc<Mutex<Option<ProteusCentral>>>,
52    },
53    Invalid,
54}
55
56impl CoreCrypto {
57    /// Creates a new transaction. All operations that persist data will be
58    /// buffered in memory and when [TransactionContext::finish] is called, the data will be persisted
59    /// in a single database transaction.
60    pub async fn new_transaction(&self) -> Result<TransactionContext> {
61        TransactionContext::new(
62            &self.mls,
63            #[cfg(feature = "proteus")]
64            self.proteus.clone(),
65        )
66        .await
67    }
68}
69
70#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
71#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
72impl HasSessionAndCrypto for TransactionContext {
73    async fn session(&self) -> crate::mls::Result<Session> {
74        self.session()
75            .await
76            .map_err(RecursiveError::transaction("getting mls client"))
77            .map_err(Into::into)
78    }
79
80    async fn crypto_provider(&self) -> crate::mls::Result<MlsCryptoProvider> {
81        self.mls_provider()
82            .await
83            .map_err(RecursiveError::transaction("getting mls provider"))
84            .map_err(Into::into)
85    }
86}
87
88impl TransactionContext {
89    async fn new(
90        client: &Session,
91        #[cfg(feature = "proteus")] proteus_central: Arc<Mutex<Option<ProteusCentral>>>,
92    ) -> Result<Self> {
93        client
94            .crypto_provider
95            .new_transaction()
96            .await
97            .map_err(MlsError::wrap("creating new transaction"))?;
98        let mls_groups = Arc::new(RwLock::new(Default::default()));
99        let callbacks = client.transport.clone();
100        let mls_client = client.clone();
101        Ok(Self {
102            inner: Arc::new(
103                TransactionContextInner::Valid {
104                    mls_client,
105                    transport: callbacks,
106                    provider: client.crypto_provider.clone(),
107                    mls_groups,
108                    #[cfg(feature = "proteus")]
109                    proteus_central,
110                }
111                .into(),
112            ),
113        })
114    }
115
116    pub(crate) async fn session(&self) -> Result<Session> {
117        match self.inner.read().await.deref() {
118            TransactionContextInner::Valid { mls_client, .. } => Ok(mls_client.clone()),
119            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
120        }
121    }
122
123    pub(crate) async fn mls_transport(&self) -> Result<RwLockReadGuardArc<Option<Arc<dyn MlsTransport + 'static>>>> {
124        match self.inner.read().await.deref() {
125            TransactionContextInner::Valid {
126                transport: callbacks, ..
127            } => Ok(callbacks.read_arc().await),
128            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
129        }
130    }
131
132    #[cfg(test)]
133    pub(crate) async fn set_transport_callbacks(
134        &self,
135        callbacks: Option<Arc<dyn MlsTransport + 'static>>,
136    ) -> Result<()> {
137        match self.inner.read().await.deref() {
138            TransactionContextInner::Valid { transport: cbs, .. } => {
139                *cbs.write_arc().await = callbacks;
140                Ok(())
141            }
142            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
143        }
144    }
145
146    /// Clones all references that the [MlsCryptoProvider] comprises.
147    pub async fn mls_provider(&self) -> Result<MlsCryptoProvider> {
148        match self.inner.read().await.deref() {
149            TransactionContextInner::Valid { provider, .. } => Ok(provider.clone()),
150            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
151        }
152    }
153
154    pub(crate) async fn keystore(&self) -> Result<Database> {
155        match self.inner.read().await.deref() {
156            TransactionContextInner::Valid { provider, .. } => Ok(provider.keystore()),
157            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
158        }
159    }
160
161    pub(crate) async fn mls_groups(&self) -> Result<RwLockWriteGuardArc<GroupStore<MlsConversation>>> {
162        match self.inner.read().await.deref() {
163            TransactionContextInner::Valid { mls_groups, .. } => Ok(mls_groups.write_arc().await),
164            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
165        }
166    }
167
168    #[cfg(feature = "proteus")]
169    pub(crate) async fn proteus_central(&self) -> Result<Arc<Mutex<Option<ProteusCentral>>>> {
170        match self.inner.read().await.deref() {
171            TransactionContextInner::Valid { proteus_central, .. } => Ok(proteus_central.clone()),
172            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
173        }
174    }
175
176    /// Commits the transaction, meaning it takes all the enqueued operations and persist them into
177    /// the keystore. After that the internal state is switched to invalid, causing errors if
178    /// something is called from this object.
179    pub async fn finish(&self) -> Result<()> {
180        let mut guard = self.inner.write().await;
181        let TransactionContextInner::Valid { provider, .. } = guard.deref() else {
182            return Err(Error::InvalidTransactionContext);
183        };
184
185        let commit_result = provider
186            .keystore()
187            .commit_transaction()
188            .await
189            .map_err(KeystoreError::wrap("commiting transaction"))
190            .map_err(Into::into);
191
192        *guard = TransactionContextInner::Invalid;
193        commit_result
194    }
195
196    /// Aborts the transaction, meaning it discards all the enqueued operations.
197    /// After that the internal state is switched to invalid, causing errors if
198    /// something is called from this object.
199    pub async fn abort(&self) -> Result<()> {
200        let mut guard = self.inner.write().await;
201
202        let TransactionContextInner::Valid { provider, .. } = guard.deref() else {
203            return Err(Error::InvalidTransactionContext);
204        };
205
206        let result = provider
207            .keystore()
208            .rollback_transaction()
209            .await
210            .map_err(KeystoreError::wrap("rolling back transaction"))
211            .map_err(Into::into);
212
213        *guard = TransactionContextInner::Invalid;
214        result
215    }
216
217    /// Initializes the MLS client of [super::CoreCrypto].
218    pub async fn mls_init(&self, identifier: ClientIdentifier, ciphersuites: &[MlsCiphersuite]) -> Result<()> {
219        let mls_client = self.session().await?;
220        mls_client
221            .init(identifier, ciphersuites, &self.mls_provider().await?)
222            .await
223            .map_err(RecursiveError::mls_client("initializing mls client"))?;
224
225        if mls_client.is_e2ei_capable().await {
226            let client_id = mls_client
227                .id()
228                .await
229                .map_err(RecursiveError::mls_client("getting client id"))?;
230            log::trace!(client_id:% = client_id; "Initializing PKI environment");
231            self.init_pki_env().await?;
232        }
233
234        Ok(())
235    }
236
237    /// Returns the client's public key.
238    pub async fn client_public_key(
239        &self,
240        ciphersuite: MlsCiphersuite,
241        credential_type: MlsCredentialType,
242    ) -> Result<Vec<u8>> {
243        let cb = self
244            .session()
245            .await?
246            .find_most_recent_credential_bundle(ciphersuite.signature_algorithm(), credential_type)
247            .await
248            .map_err(RecursiveError::mls_client("finding most recent credential bundle"))?;
249        Ok(cb.signature_key.to_public_vec())
250    }
251
252    /// see [Session::id]
253    pub async fn client_id(&self) -> Result<ClientId> {
254        self.session()
255            .await?
256            .id()
257            .await
258            .map_err(RecursiveError::mls_client("getting client id"))
259            .map_err(Into::into)
260    }
261
262    /// Generates a random byte array of the specified size
263    pub async fn random_bytes(&self, len: usize) -> Result<Vec<u8>> {
264        use openmls_traits::random::OpenMlsRand as _;
265        self.mls_provider()
266            .await?
267            .rand()
268            .random_vec(len)
269            .map_err(MlsError::wrap("generating random vector"))
270            .map_err(Into::into)
271    }
272
273    /// Set arbitrary data to be retrieved by [TransactionContext::get_data].
274    /// This is meant to be used as a check point at the end of a transaction.
275    /// The data should be limited to a reasonable size.
276    pub async fn set_data(&self, data: Vec<u8>) -> Result<()> {
277        self.keystore()
278            .await?
279            .save(ConsumerData::from(data))
280            .await
281            .map_err(KeystoreError::wrap("saving consumer data"))?;
282        Ok(())
283    }
284
285    /// Get the data that has previously been set by [TransactionContext::set_data].
286    /// This is meant to be used as a check point at the end of a transaction.
287    pub async fn get_data(&self) -> Result<Option<Vec<u8>>> {
288        match self.keystore().await?.find_unique::<ConsumerData>().await {
289            Ok(data) => Ok(Some(data.into())),
290            Err(CryptoKeystoreError::NotFound(..)) => Ok(None),
291            Err(err) => Err(KeystoreError::wrap("finding unique consumer data")(err).into()),
292        }
293    }
294}