Skip to main content

core_crypto/transaction_context/
mod.rs

1//! This module contains the primitives to enable transactional support on a higher level within the
2//! [Session]. All mutating operations need to be done through a [TransactionContext].
3
4use std::sync::Arc;
5
6use async_lock::{Mutex, MutexGuardArc, RwLock};
7use core_crypto_keystore::{CryptoKeystoreError, entities::ConsumerData, traits::FetchFromDatabase as _};
8pub use error::{Error, Result};
9use openmls_traits::OpenMlsCryptoProvider as _;
10use wire_e2e_identity::pki_env::PkiEnvironment;
11
12use crate::{
13    ClientId, ConversationId, CoreCrypto, KeystoreError, MlsTransport, OpenMlsError, RecursiveError, Session,
14    mls::{self, conversation_cache::ConversationCache},
15    mls_provider::{CryptoProvider, Database},
16};
17pub mod conversation;
18mod credential;
19pub mod e2e_identity;
20mod error;
21pub mod key_package;
22#[cfg(feature = "proteus")]
23pub mod proteus;
24#[cfg(test)]
25pub mod test_utils;
26
27/// This struct provides transactional support for Core Crypto.
28///
29/// This struct provides mutable access to the internals of Core Crypto. Every operation that
30/// causes data to be persisted needs to be done through this struct. This struct will buffer all
31/// operations in memory and when [TransactionContext::finish] is called, it will persist the data into
32/// the keystore.
33#[derive(Debug, Clone)]
34pub struct TransactionContext {
35    inner: Arc<RwLock<TransactionContextInner>>,
36}
37
38/// Due to uniffi's design, we can't force the context to be dropped after the transaction is
39/// committed. To work around that we switch the value to `Invalid` when the context is finished
40/// and throw errors if something is called
41#[derive(Debug, Clone)]
42enum TransactionContextInner {
43    Valid {
44        core_crypto: Arc<CoreCrypto>,
45        pending_epoch_changes: Arc<Mutex<Vec<(ConversationId, u64)>>>,
46    },
47    Invalid,
48}
49
50impl CoreCrypto {
51    /// Creates a new transaction. All operations that persist data will be
52    /// buffered in memory and when [TransactionContext::finish] is called, the data will be persisted
53    /// in a single database transaction.
54    pub async fn new_transaction(self: &Arc<Self>) -> Result<TransactionContext> {
55        TransactionContext::new(self.clone()).await
56    }
57}
58
59impl TransactionContext {
60    async fn new(core_crypto: Arc<CoreCrypto>) -> Result<Self> {
61        core_crypto
62            .database
63            .new_transaction()
64            .await
65            .map_err(OpenMlsError::wrap("creating new transaction"))?;
66        Ok(Self {
67            inner: Arc::new(
68                TransactionContextInner::Valid {
69                    core_crypto,
70                    pending_epoch_changes: Default::default(),
71                }
72                .into(),
73            ),
74        })
75    }
76
77    pub(crate) async fn session(&self) -> Result<Session> {
78        let TransactionContextInner::Valid { core_crypto, .. } = &*self.inner.read().await else {
79            return Err(Error::InvalidTransactionContext);
80        };
81        core_crypto.mls.read().await.as_ref().cloned().ok_or(
82            RecursiveError::mls_client("Getting mls session from transaction context")(
83                mls::session::Error::MlsNotInitialized,
84            )
85            .into(),
86        )
87    }
88
89    #[cfg(test)]
90    pub(crate) async fn set_session_if_exists(&self, new_session: Session) {
91        let TransactionContextInner::Valid { core_crypto, .. } = &*self.inner.read().await else {
92            return;
93        };
94
95        let mut guard = core_crypto.mls.write().await;
96        if guard.as_ref().is_some() {
97            *guard = Some(new_session)
98        }
99    }
100
101    pub(crate) async fn mls_transport(&self) -> Result<Arc<dyn MlsTransport + 'static>> {
102        match &*self.inner.read().await {
103            TransactionContextInner::Valid { core_crypto, .. } => core_crypto
104                .mls
105                .read()
106                .await
107                .as_ref()
108                .map(|s| s.transport.clone())
109                .ok_or(
110                    RecursiveError::mls_client("Getting mls session from transaction context")(
111                        mls::session::Error::MlsNotInitialized,
112                    )
113                    .into(),
114                ),
115
116            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
117        }
118    }
119
120    /// Clones the [CryptoProvider].
121    pub async fn crypto_provider(&self) -> Result<CryptoProvider> {
122        match &*self.inner.read().await {
123            TransactionContextInner::Valid { core_crypto, .. } => core_crypto
124                .mls
125                .read()
126                .await
127                .as_ref()
128                .map(|s| s.crypto_provider.clone())
129                .ok_or(
130                    RecursiveError::mls_client("Getting mls session from transaction context")(
131                        mls::session::Error::MlsNotInitialized,
132                    )
133                    .into(),
134                ),
135            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
136        }
137    }
138
139    pub(crate) async fn database(&self) -> Result<Database> {
140        match &*self.inner.read().await {
141            TransactionContextInner::Valid { core_crypto, .. } => Ok(core_crypto.database.clone()),
142            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
143        }
144    }
145
146    pub(crate) async fn pki_environment(&self) -> Result<Arc<PkiEnvironment>> {
147        match &*self.inner.read().await {
148            TransactionContextInner::Valid { core_crypto, .. } => core_crypto
149                .pki_environment
150                .read()
151                .await
152                .as_ref()
153                .map(Clone::clone)
154                .ok_or(
155                    RecursiveError::transaction("getting PKI environment from transaction context")(
156                        e2e_identity::Error::PkiEnvironmentUnset,
157                    )
158                    .into(),
159                ),
160            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
161        }
162    }
163
164    pub(crate) async fn mls_groups(&self) -> Result<MutexGuardArc<ConversationCache>> {
165        let guard = self.inner.read().await;
166        let TransactionContextInner::Valid { core_crypto, .. } = &*guard else {
167            return Err(Error::InvalidTransactionContext);
168        };
169        let cache = core_crypto
170            .mls
171            .read()
172            .await
173            .as_ref()
174            .map(|session| session.conversation_cache.clone())
175            .ok_or_else(|| {
176                RecursiveError::mls_client("getting mls session from transaction context")(
177                    mls::session::Error::MlsNotInitialized,
178                )
179            })?;
180
181        Ok(cache.lock_arc().await)
182    }
183
184    pub(crate) async fn queue_epoch_changed(&self, conversation_id: ConversationId, epoch: u64) -> Result<()> {
185        match &*self.inner.read().await {
186            TransactionContextInner::Valid {
187                pending_epoch_changes, ..
188            } => {
189                pending_epoch_changes.lock().await.push((conversation_id, epoch));
190                Ok(())
191            }
192            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
193        }
194    }
195
196    /// Commits the transaction, meaning it takes all the enqueued operations and persist them into
197    /// the keystore. After that the internal state is switched to invalid, causing errors if
198    /// something is called from this object.
199    pub async fn finish(&self) -> Result<()> {
200        let mut guard = self.inner.write().await;
201        let TransactionContextInner::Valid {
202            core_crypto,
203            pending_epoch_changes,
204            ..
205        } = &*guard
206        else {
207            return Err(Error::InvalidTransactionContext);
208        };
209
210        let commit_result = core_crypto
211            .database
212            .commit_transaction()
213            .await
214            .map_err(KeystoreError::wrap("commiting transaction"))
215            .map_err(Into::into);
216
217        if let Some(session) = core_crypto.mls.read().await.as_ref() {
218            if commit_result.is_ok() {
219                // We need owned values, so we could just clone the conversation ids, but we don't need the events
220                // anymore, so draining the vector works, too.
221                let mut epoch_changes = pending_epoch_changes.lock().await;
222                for (conversation_id, epoch) in epoch_changes.drain(..) {
223                    session.notify_epoch_changed(conversation_id, epoch).await;
224                }
225            } else {
226                // Commit failed: the keystore is back to its pre-transaction state, but the in-memory
227                // conversation cache may have absorbed mutations that never made it to disk. Clear them
228                // so subsequent reads load fresh state from the keystore.
229                session.conversation_cache.lock().await.clear();
230            }
231        }
232
233        *guard = TransactionContextInner::Invalid;
234        commit_result
235    }
236
237    /// Aborts the transaction, meaning it discards all the enqueued operations.
238    /// After that the internal state is switched to invalid, causing errors if
239    /// something is called from this object.
240    pub async fn abort(&self) -> Result<()> {
241        let mut guard = self.inner.write().await;
242
243        let TransactionContextInner::Valid { core_crypto, .. } = &*guard else {
244            return Err(Error::InvalidTransactionContext);
245        };
246
247        // Drop any in-memory conversation state mutated during this transaction; it never reached
248        // the keystore and would otherwise diverge from disk after rollback.
249        if let Some(session) = core_crypto.mls.read().await.as_ref() {
250            session.conversation_cache.lock().await.clear();
251        }
252
253        let result = core_crypto
254            .database
255            .rollback_transaction()
256            .await
257            .map_err(KeystoreError::wrap("rolling back transaction"))
258            .map_err(Into::into);
259
260        *guard = TransactionContextInner::Invalid;
261        result
262    }
263
264    /// Initializes the MLS client of [super::CoreCrypto].
265    pub async fn mls_init(&self, session_id: ClientId, transport: Arc<dyn MlsTransport>) -> Result<()> {
266        let database = self.database().await?;
267        let pki_env = self.pki_environment().await.ok();
268        let crypto_provider = CryptoProvider::new_with_pki_env(database.clone(), pki_env);
269        let session = Session::new(session_id.clone(), crypto_provider, database.into(), transport);
270        self.set_mls_session(session).await?;
271
272        Ok(())
273    }
274
275    /// Set the `mls_session` Arc (also sets it on the transaction's CoreCrypto instance)
276    pub(crate) async fn set_mls_session(&self, session: Session) -> Result<()> {
277        match &*self.inner.read().await {
278            TransactionContextInner::Valid { core_crypto, .. } => {
279                let mut guard = core_crypto.mls.write().await;
280                *guard = Some(session);
281                Ok(())
282            }
283            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
284        }
285    }
286
287    /// see [Session::id]
288    pub async fn client_id(&self) -> Result<ClientId> {
289        let session = self.session().await?;
290        Ok(session.id())
291    }
292
293    /// Generates a random byte array of the specified size
294    pub async fn random_bytes(&self, len: usize) -> Result<Vec<u8>> {
295        use openmls_traits::random::OpenMlsRand as _;
296        self.crypto_provider()
297            .await?
298            .rand()
299            .random_vec(len)
300            .map_err(OpenMlsError::wrap("generating random vector"))
301            .map_err(Into::into)
302    }
303
304    /// Set arbitrary data to be retrieved by [TransactionContext::get_data].
305    /// This is meant to be used as a check point at the end of a transaction.
306    /// The data should be limited to a reasonable size.
307    pub async fn set_data(&self, data: Vec<u8>) -> Result<()> {
308        self.database()
309            .await?
310            .save(ConsumerData::from(data))
311            .await
312            .map_err(KeystoreError::wrap("saving consumer data"))?;
313        Ok(())
314    }
315
316    /// Get the data that has previously been set by [TransactionContext::set_data].
317    /// This is meant to be used as a check point at the end of a transaction.
318    pub async fn get_data(&self) -> Result<Option<Vec<u8>>> {
319        match self.database().await?.get_unique::<ConsumerData>().await {
320            Ok(maybe_data) => Ok(maybe_data.map(Into::into)),
321            Err(CryptoKeystoreError::NotFound(..)) => Ok(None),
322            Err(err) => Err(KeystoreError::wrap("finding unique consumer data")(err).into()),
323        }
324    }
325}