core_crypto/transaction_context/
mod.rs

1//! This module contains the primitives to enable transactional support on a higher level within the
2//! [Session]. All mutating operations need to be done through a [TransactionContext].
3
4use std::sync::Arc;
5
6#[cfg(feature = "proteus")]
7use async_lock::Mutex;
8use async_lock::{RwLock, RwLockWriteGuardArc};
9use core_crypto_keystore::{CryptoKeystoreError, entities::ConsumerData, traits::FetchFromDatabase as _};
10pub use error::{Error, Result};
11use openmls_traits::OpenMlsCryptoProvider as _;
12use wire_e2e_identity::pki_env::PkiEnvironment;
13
14#[cfg(feature = "proteus")]
15use crate::proteus::ProteusCentral;
16use crate::{
17    ClientId, CoreCrypto, CredentialFindFilters, CredentialRef, KeystoreError, MlsConversation, MlsError, MlsTransport,
18    RecursiveError, Session,
19    group_store::GroupStore,
20    mls::{self, HasSessionAndCrypto},
21    mls_provider::{Database, MlsCryptoProvider},
22};
23pub mod conversation;
24mod credential;
25pub mod e2e_identity;
26mod error;
27pub mod key_package;
28#[cfg(feature = "proteus")]
29pub mod proteus;
30#[cfg(test)]
31pub mod test_utils;
32
33/// This struct provides transactional support for Core Crypto.
34///
35/// This struct provides mutable access to the internals of Core Crypto. Every operation that
36/// causes data to be persisted needs to be done through this struct. This struct will buffer all
37/// operations in memory and when [TransactionContext::finish] is called, it will persist the data into
38/// the keystore.
39#[derive(Debug, Clone)]
40pub struct TransactionContext {
41    inner: Arc<RwLock<TransactionContextInner>>,
42}
43
44/// Due to uniffi's design, we can't force the context to be dropped after the transaction is
45/// committed. To work around that we switch the value to `Invalid` when the context is finished
46/// and throw errors if something is called
47#[derive(Debug, Clone)]
48enum TransactionContextInner {
49    Valid {
50        pki_environment: Arc<RwLock<Option<PkiEnvironment>>>,
51        database: Database,
52        mls_session: Arc<RwLock<Option<Session<Database>>>>,
53        mls_groups: Arc<RwLock<GroupStore<MlsConversation>>>,
54        #[cfg(feature = "proteus")]
55        proteus_central: Arc<Mutex<Option<ProteusCentral>>>,
56    },
57    Invalid,
58}
59
60impl CoreCrypto {
61    /// Creates a new transaction. All operations that persist data will be
62    /// buffered in memory and when [TransactionContext::finish] is called, the data will be persisted
63    /// in a single database transaction.
64    pub async fn new_transaction(&self) -> Result<TransactionContext> {
65        TransactionContext::new(
66            self.database.clone(),
67            self.pki_environment.clone(),
68            self.mls.clone(),
69            #[cfg(feature = "proteus")]
70            self.proteus.clone(),
71        )
72        .await
73    }
74}
75
76#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
77#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
78impl HasSessionAndCrypto for TransactionContext {
79    async fn session(&self) -> crate::mls::Result<Session<Database>> {
80        self.session()
81            .await
82            .map_err(RecursiveError::transaction("getting mls client"))
83            .map_err(Into::into)
84    }
85
86    async fn crypto_provider(&self) -> crate::mls::Result<MlsCryptoProvider> {
87        self.mls_provider()
88            .await
89            .map_err(RecursiveError::transaction("getting mls provider"))
90            .map_err(Into::into)
91    }
92}
93
94impl TransactionContext {
95    async fn new(
96        keystore: Database,
97        pki_environment: Arc<RwLock<Option<PkiEnvironment>>>,
98        mls_session: Arc<RwLock<Option<Session<Database>>>>,
99        #[cfg(feature = "proteus")] proteus_central: Arc<Mutex<Option<ProteusCentral>>>,
100    ) -> Result<Self> {
101        keystore
102            .new_transaction()
103            .await
104            .map_err(MlsError::wrap("creating new transaction"))?;
105        let mls_groups = Arc::new(RwLock::new(Default::default()));
106        Ok(Self {
107            inner: Arc::new(
108                TransactionContextInner::Valid {
109                    database: keystore,
110                    pki_environment,
111                    mls_session: mls_session.clone(),
112                    mls_groups,
113                    #[cfg(feature = "proteus")]
114                    proteus_central,
115                }
116                .into(),
117            ),
118        })
119    }
120
121    pub(crate) async fn session(&self) -> Result<Session<Database>> {
122        match &*self.inner.read().await {
123            TransactionContextInner::Valid { mls_session, .. } => mls_session.read().await.as_ref().cloned().ok_or(
124                RecursiveError::mls_client("Getting mls session from transaction context")(
125                    mls::session::Error::MlsNotInitialized,
126                )
127                .into(),
128            ),
129            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
130        }
131    }
132
133    #[cfg(test)]
134    pub(crate) async fn set_session_if_exists(&self, new_session: Session<Database>) {
135        match &*self.inner.read().await {
136            TransactionContextInner::Valid { mls_session, .. } => {
137                let mut guard = mls_session.write().await;
138
139                if guard.as_ref().is_some() {
140                    *guard = Some(new_session)
141                }
142            }
143            TransactionContextInner::Invalid => {}
144        }
145    }
146
147    pub(crate) async fn mls_transport(&self) -> Result<Arc<dyn MlsTransport + 'static>> {
148        match &*self.inner.read().await {
149            TransactionContextInner::Valid { mls_session, .. } => {
150                mls_session.read().await.as_ref().map(|s| s.transport.clone()).ok_or(
151                    RecursiveError::mls_client("Getting mls session from transaction context")(
152                        mls::session::Error::MlsNotInitialized,
153                    )
154                    .into(),
155                )
156            }
157
158            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
159        }
160    }
161
162    /// Clones all references that the [MlsCryptoProvider] comprises.
163    pub async fn mls_provider(&self) -> Result<MlsCryptoProvider> {
164        match &*self.inner.read().await {
165            TransactionContextInner::Valid { mls_session, .. } => mls_session
166                .read()
167                .await
168                .as_ref()
169                .map(|s| s.crypto_provider.clone())
170                .ok_or(
171                    RecursiveError::mls_client("Getting mls session from transaction context")(
172                        mls::session::Error::MlsNotInitialized,
173                    )
174                    .into(),
175                ),
176            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
177        }
178    }
179
180    pub(crate) async fn database(&self) -> Result<Database> {
181        match &*self.inner.read().await {
182            TransactionContextInner::Valid { database, .. } => Ok(database.clone()),
183            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
184        }
185    }
186
187    pub(crate) async fn pki_environment(&self) -> Result<PkiEnvironment> {
188        match &*self.inner.read().await {
189            TransactionContextInner::Valid { pki_environment, .. } => {
190                pki_environment.read().await.as_ref().map(Clone::clone).ok_or(
191                    RecursiveError::transaction("Getting PKI environment from transaction context")(
192                        e2e_identity::Error::PkiEnvironmentUnset,
193                    )
194                    .into(),
195                )
196            }
197            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
198        }
199    }
200
201    pub(crate) async fn pki_environment_option(&self) -> Result<Option<PkiEnvironment>> {
202        match &*self.inner.read().await {
203            TransactionContextInner::Valid { pki_environment, .. } => Ok(pki_environment.read().await.clone()),
204
205            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
206        }
207    }
208
209    pub(crate) async fn mls_groups(&self) -> Result<RwLockWriteGuardArc<GroupStore<MlsConversation>>> {
210        match &*self.inner.read().await {
211            TransactionContextInner::Valid { mls_groups, .. } => Ok(mls_groups.write_arc().await),
212            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
213        }
214    }
215
216    #[cfg(feature = "proteus")]
217    pub(crate) async fn proteus_central(&self) -> Result<Arc<Mutex<Option<ProteusCentral>>>> {
218        match &*self.inner.read().await {
219            TransactionContextInner::Valid { proteus_central, .. } => Ok(proteus_central.clone()),
220            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
221        }
222    }
223
224    /// Commits the transaction, meaning it takes all the enqueued operations and persist them into
225    /// the keystore. After that the internal state is switched to invalid, causing errors if
226    /// something is called from this object.
227    pub async fn finish(&self) -> Result<()> {
228        let mut guard = self.inner.write().await;
229        let TransactionContextInner::Valid { database: keystore, .. } = &*guard else {
230            return Err(Error::InvalidTransactionContext);
231        };
232
233        let commit_result = keystore
234            .commit_transaction()
235            .await
236            .map_err(KeystoreError::wrap("commiting transaction"))
237            .map_err(Into::into);
238
239        *guard = TransactionContextInner::Invalid;
240        commit_result
241    }
242
243    /// Aborts the transaction, meaning it discards all the enqueued operations.
244    /// After that the internal state is switched to invalid, causing errors if
245    /// something is called from this object.
246    pub async fn abort(&self) -> Result<()> {
247        let mut guard = self.inner.write().await;
248
249        let TransactionContextInner::Valid { database: keystore, .. } = &*guard else {
250            return Err(Error::InvalidTransactionContext);
251        };
252
253        let result = keystore
254            .rollback_transaction()
255            .await
256            .map_err(KeystoreError::wrap("rolling back transaction"))
257            .map_err(Into::into);
258
259        *guard = TransactionContextInner::Invalid;
260        result
261    }
262
263    /// Initializes the MLS client of [super::CoreCrypto].
264    pub async fn mls_init(&self, session_id: ClientId, transport: Arc<dyn MlsTransport>) -> Result<()> {
265        let database = self.database().await?;
266
267        let pki_env_provider = self
268            .pki_environment_option()
269            .await?
270            .map(|pki_env| pki_env.mls_pki_env_provider())
271            .unwrap_or_default();
272
273        let crypto_provider = MlsCryptoProvider::new_with_pki_env(database, pki_env_provider);
274        let database = self.database().await?;
275        let session = Session::new(session_id.clone(), crypto_provider, database, transport);
276        self.set_mls_session(session).await?;
277
278        Ok(())
279    }
280
281    /// Set the `mls_session` Arc (also sets it on the transaction's CoreCrypto instance)
282    pub(crate) async fn set_mls_session(&self, session: Session<Database>) -> Result<()> {
283        match &*self.inner.read().await {
284            TransactionContextInner::Valid { mls_session, .. } => {
285                let mut guard = mls_session.write().await;
286                *guard = Some(session);
287                Ok(())
288            }
289            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
290        }
291    }
292
293    /// see [Session::id]
294    pub async fn client_id(&self) -> Result<ClientId> {
295        let session = self.session().await?;
296        Ok(session.id())
297    }
298
299    /// Generates a random byte array of the specified size
300    pub async fn random_bytes(&self, len: usize) -> Result<Vec<u8>> {
301        use openmls_traits::random::OpenMlsRand as _;
302        self.mls_provider()
303            .await?
304            .rand()
305            .random_vec(len)
306            .map_err(MlsError::wrap("generating random vector"))
307            .map_err(Into::into)
308    }
309
310    /// Set arbitrary data to be retrieved by [TransactionContext::get_data].
311    /// This is meant to be used as a check point at the end of a transaction.
312    /// The data should be limited to a reasonable size.
313    pub async fn set_data(&self, data: Vec<u8>) -> Result<()> {
314        self.database()
315            .await?
316            .save(ConsumerData::from(data))
317            .await
318            .map_err(KeystoreError::wrap("saving consumer data"))?;
319        Ok(())
320    }
321
322    /// Get the data that has previously been set by [TransactionContext::set_data].
323    /// This is meant to be used as a check point at the end of a transaction.
324    pub async fn get_data(&self) -> Result<Option<Vec<u8>>> {
325        match self.database().await?.get_unique::<ConsumerData>().await {
326            Ok(maybe_data) => Ok(maybe_data.map(Into::into)),
327            Err(CryptoKeystoreError::NotFound(..)) => Ok(None),
328            Err(err) => Err(KeystoreError::wrap("finding unique consumer data")(err).into()),
329        }
330    }
331
332    /// Find credentials matching the find filters among the identities of this session
333    ///
334    /// Note that finding credentials with no filters set is equivalent to [`Self::get_credentials`].
335    pub async fn find_credentials(&self, find_filters: CredentialFindFilters<'_>) -> Result<Vec<CredentialRef>> {
336        self.session()
337            .await?
338            .find_credentials(find_filters)
339            .await
340            .map_err(RecursiveError::mls_client("finding credentials by filter"))
341            .map_err(Into::into)
342    }
343
344    /// Get all credentials from the identities of this session.
345    ///
346    /// To get specific credentials, it can be more efficient to use [`Self::find_credentials`].
347    pub async fn get_credentials(&self) -> Result<Vec<CredentialRef>> {
348        self.session()
349            .await?
350            .get_credentials()
351            .await
352            .map_err(RecursiveError::mls_client("getting all credentials"))
353            .map_err(Into::into)
354    }
355}