core_crypto/transaction_context/
mod.rs

1//! This module contains the primitives to enable transactional support on a higher level within the
2//! [Client]. All mutating operations need to be done through a [TransactionContext].
3
4use crate::mls::HasSessionAndCrypto;
5#[cfg(feature = "proteus")]
6use crate::proteus::ProteusCentral;
7use crate::{
8    CoreCrypto, KeystoreError, MlsError, MlsTransport, RecursiveError,
9    group_store::GroupStore,
10    prelude::{MlsConversation, Session},
11};
12use async_lock::{Mutex, RwLock, RwLockReadGuardArc, RwLockWriteGuardArc};
13use core_crypto_keystore::{CryptoKeystoreError, connection::FetchFromDatabase, entities::ConsumerData};
14pub use error::{Error, Result};
15use mls_crypto_provider::{CryptoKeystore, MlsCryptoProvider};
16use std::{ops::Deref, sync::Arc};
17pub mod conversation;
18pub mod e2e_identity;
19mod error;
20#[cfg(test)]
21pub mod test_utils;
22
23/// This struct provides transactional support for Core Crypto.
24///
25/// This is struct provides mutable access to the internals of Core Crypto. Every operation that
26/// causes data to be persisted needs to be done through this struct. This struct will buffer all
27/// operations in memory and when [TransactionContext::finish] is called, it will persist the data into
28/// the keystore.
29#[derive(Debug, Clone)]
30pub struct TransactionContext {
31    inner: Arc<RwLock<TransactionContextInner>>,
32}
33
34/// Due to uniffi's design, we can't force the context to be dropped after the transaction is
35/// committed. To work around that we switch the value to `Invalid` when the context is finished
36/// and throw errors if something is called
37#[derive(Debug, Clone)]
38enum TransactionContextInner {
39    Valid {
40        provider: MlsCryptoProvider,
41        transport: Arc<RwLock<Option<Arc<dyn MlsTransport + 'static>>>>,
42        mls_client: Session,
43        mls_groups: Arc<RwLock<GroupStore<MlsConversation>>>,
44        #[cfg(feature = "proteus")]
45        proteus_central: Arc<Mutex<Option<ProteusCentral>>>,
46    },
47    Invalid,
48}
49
50impl CoreCrypto {
51    /// Creates a new transaction. All operations that persist data will be
52    /// buffered in memory and when [TransactionContext::finish] is called, the data will be persisted
53    /// in a single database transaction.
54    pub async fn new_transaction(&self) -> Result<TransactionContext> {
55        TransactionContext::new(
56            &self.mls,
57            #[cfg(feature = "proteus")]
58            self.proteus.clone(),
59        )
60        .await
61    }
62}
63
64#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
65#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
66impl HasSessionAndCrypto for TransactionContext {
67    async fn session(&self) -> crate::mls::Result<Session> {
68        self.session()
69            .await
70            .map_err(RecursiveError::transaction("getting mls client"))
71            .map_err(Into::into)
72    }
73
74    async fn crypto_provider(&self) -> crate::mls::Result<MlsCryptoProvider> {
75        self.mls_provider()
76            .await
77            .map_err(RecursiveError::transaction("getting mls provider"))
78            .map_err(Into::into)
79    }
80}
81
82impl TransactionContext {
83    async fn new(
84        client: &Session,
85        #[cfg(feature = "proteus")] proteus_central: Arc<Mutex<Option<ProteusCentral>>>,
86    ) -> Result<Self> {
87        client
88            .crypto_provider
89            .new_transaction()
90            .await
91            .map_err(MlsError::wrap("creating new transaction"))?;
92        let mls_groups = Arc::new(RwLock::new(Default::default()));
93        let callbacks = client.transport.clone();
94        let mls_client = client.clone();
95        Ok(Self {
96            inner: Arc::new(
97                TransactionContextInner::Valid {
98                    mls_client,
99                    transport: callbacks,
100                    provider: client.crypto_provider.clone(),
101                    mls_groups,
102                    #[cfg(feature = "proteus")]
103                    proteus_central,
104                }
105                .into(),
106            ),
107        })
108    }
109
110    pub(crate) async fn session(&self) -> Result<Session> {
111        match self.inner.read().await.deref() {
112            TransactionContextInner::Valid { mls_client, .. } => Ok(mls_client.clone()),
113            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
114        }
115    }
116
117    pub(crate) async fn mls_transport(&self) -> Result<RwLockReadGuardArc<Option<Arc<dyn MlsTransport + 'static>>>> {
118        match self.inner.read().await.deref() {
119            TransactionContextInner::Valid {
120                transport: callbacks, ..
121            } => Ok(callbacks.read_arc().await),
122            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
123        }
124    }
125
126    #[cfg(test)]
127    pub(crate) async fn set_transport_callbacks(
128        &self,
129        callbacks: Option<Arc<dyn MlsTransport + 'static>>,
130    ) -> Result<()> {
131        match self.inner.read().await.deref() {
132            TransactionContextInner::Valid { transport: cbs, .. } => {
133                *cbs.write_arc().await = callbacks;
134                Ok(())
135            }
136            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
137        }
138    }
139
140    /// Clones all references that the [MlsCryptoProvider] comprises.
141    pub async fn mls_provider(&self) -> Result<MlsCryptoProvider> {
142        match self.inner.read().await.deref() {
143            TransactionContextInner::Valid { provider, .. } => Ok(provider.clone()),
144            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
145        }
146    }
147
148    pub(crate) async fn keystore(&self) -> Result<CryptoKeystore> {
149        match self.inner.read().await.deref() {
150            TransactionContextInner::Valid { provider, .. } => Ok(provider.keystore()),
151            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
152        }
153    }
154
155    pub(crate) async fn mls_groups(&self) -> Result<RwLockWriteGuardArc<GroupStore<MlsConversation>>> {
156        match self.inner.read().await.deref() {
157            TransactionContextInner::Valid { mls_groups, .. } => Ok(mls_groups.write_arc().await),
158            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
159        }
160    }
161
162    #[cfg(feature = "proteus")]
163    pub(crate) async fn proteus_central(&self) -> Result<Arc<Mutex<Option<ProteusCentral>>>> {
164        match self.inner.read().await.deref() {
165            TransactionContextInner::Valid { proteus_central, .. } => Ok(proteus_central.clone()),
166            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
167        }
168    }
169
170    /// Commits the transaction, meaning it takes all the enqueued operations and persist them into
171    /// the keystore. After that the internal state is switched to invalid, causing errors if
172    /// something is called from this object.
173    pub async fn finish(&self) -> Result<()> {
174        let mut guard = self.inner.write().await;
175        let TransactionContextInner::Valid { provider, .. } = guard.deref() else {
176            return Err(Error::InvalidTransactionContext);
177        };
178
179        let commit_result = provider
180            .keystore()
181            .commit_transaction()
182            .await
183            .map_err(KeystoreError::wrap("commiting transaction"))
184            .map_err(Into::into);
185
186        *guard = TransactionContextInner::Invalid;
187        commit_result
188    }
189
190    /// Aborts the transaction, meaning it discards all the enqueued operations.
191    /// After that the internal state is switched to invalid, causing errors if
192    /// something is called from this object.
193    pub async fn abort(&self) -> Result<()> {
194        let mut guard = self.inner.write().await;
195
196        let TransactionContextInner::Valid { provider, .. } = guard.deref() else {
197            return Err(Error::InvalidTransactionContext);
198        };
199
200        let result = provider
201            .keystore()
202            .rollback_transaction()
203            .await
204            .map_err(KeystoreError::wrap("rolling back transaction"))
205            .map_err(Into::into);
206
207        *guard = TransactionContextInner::Invalid;
208        result
209    }
210
211    /// Set arbitrary data to be retrieved by [TransactionContext::get_data].
212    /// This is meant to be used as a check point at the end of a transaction.
213    /// The data should be limited to a reasonable size.
214    pub async fn set_data(&self, data: Vec<u8>) -> Result<()> {
215        self.keystore()
216            .await?
217            .save(ConsumerData::from(data))
218            .await
219            .map_err(KeystoreError::wrap("saving consumer data"))?;
220        Ok(())
221    }
222
223    /// Get the data that has previously been set by [TransactionContext::set_data].
224    /// This is meant to be used as a check point at the end of a transaction.
225    pub async fn get_data(&self) -> Result<Option<Vec<u8>>> {
226        match self.keystore().await?.find_unique::<ConsumerData>().await {
227            Ok(data) => Ok(Some(data.into())),
228            Err(CryptoKeystoreError::NotFound(..)) => Ok(None),
229            Err(err) => Err(KeystoreError::wrap("finding unique consumer data")(err).into()),
230        }
231    }
232}