Skip to main content

core_crypto/transaction_context/
mod.rs

1//! This module contains the primitives to enable transactional support on a higher level within the
2//! [Session]. All mutating operations need to be done through a [TransactionContext].
3
4use std::sync::Arc;
5
6use async_lock::{Mutex, RwLock, RwLockWriteGuardArc};
7use core_crypto_keystore::{CryptoKeystoreError, entities::ConsumerData, traits::FetchFromDatabase as _};
8pub use error::{Error, Result};
9use openmls_traits::OpenMlsCryptoProvider as _;
10use wire_e2e_identity::pki_env::PkiEnvironment;
11
12#[cfg(feature = "proteus")]
13use crate::proteus::ProteusCentral;
14use crate::{
15    ClientId, ConversationId, CoreCrypto, CredentialFindFilters, CredentialRef, KeystoreError, MlsConversation,
16    MlsError, MlsTransport, RecursiveError, Session,
17    group_store::GroupStore,
18    mls::{self, HasSessionAndCrypto},
19    mls_provider::{Database, MlsCryptoProvider},
20};
21pub mod conversation;
22mod credential;
23pub mod e2e_identity;
24mod error;
25pub mod key_package;
26#[cfg(feature = "proteus")]
27pub mod proteus;
28#[cfg(test)]
29pub mod test_utils;
30
/// This struct provides transactional support for Core Crypto.
///
/// This struct provides mutable access to the internals of Core Crypto. Every operation that
/// causes data to be persisted needs to be done through this struct. This struct will buffer all
/// operations in memory and when [TransactionContext::finish] is called, it will persist the data into
/// the keystore.
///
/// Cloning is cheap: every clone shares the same inner state through an `Arc`, so finishing or
/// aborting through any one clone invalidates all of them.
#[derive(Debug, Clone)]
pub struct TransactionContext {
    // Shared, lock-protected state; replaced by `TransactionContextInner::Invalid` once the
    // transaction is finished or aborted.
    inner: Arc<RwLock<TransactionContextInner>>,
}
41
/// Due to uniffi's design, we can't force the context to be dropped after the transaction is
/// committed. To work around that we switch the value to `Invalid` when the context is finished
/// (or aborted) and return [Error::InvalidTransactionContext] if anything is called afterwards.
#[derive(Debug, Clone)]
enum TransactionContextInner {
    /// The transaction is open and every field may be used.
    Valid {
        // Shared handle (cloned from `CoreCrypto` when the transaction is created).
        pki_environment: Arc<RwLock<Option<PkiEnvironment>>>,
        // Database handle on which the keystore transaction was opened.
        database: Database,
        // Shared handle (cloned from `CoreCrypto`); `None` until a session is stored,
        // e.g. through `mls_init`.
        mls_session: Arc<RwLock<Option<Session<Database>>>>,
        // Group store private to this transaction; created empty in `TransactionContext::new`.
        mls_groups: Arc<RwLock<GroupStore<MlsConversation>>>,
        // `(conversation, epoch)` pairs handed to `Session::notify_epoch_changed` after a
        // successful commit.
        pending_epoch_changes: Arc<Mutex<Vec<(ConversationId, u64)>>>,
        // Shared handle (cloned from `CoreCrypto`) to the Proteus state.
        #[cfg(feature = "proteus")]
        proteus_central: Arc<Mutex<Option<ProteusCentral>>>,
    },
    /// The transaction has been finished or aborted; every operation now fails.
    Invalid,
}
58
59impl CoreCrypto {
60    /// Creates a new transaction. All operations that persist data will be
61    /// buffered in memory and when [TransactionContext::finish] is called, the data will be persisted
62    /// in a single database transaction.
63    pub async fn new_transaction(&self) -> Result<TransactionContext> {
64        TransactionContext::new(
65            self.database.clone(),
66            self.pki_environment.clone(),
67            self.mls.clone(),
68            #[cfg(feature = "proteus")]
69            self.proteus.clone(),
70        )
71        .await
72    }
73}
74
75#[cfg_attr(target_os = "unknown", async_trait::async_trait(?Send))]
76#[cfg_attr(not(target_os = "unknown"), async_trait::async_trait)]
77impl HasSessionAndCrypto for TransactionContext {
78    async fn session(&self) -> crate::mls::Result<Session<Database>> {
79        self.session()
80            .await
81            .map_err(RecursiveError::transaction("getting mls client"))
82            .map_err(Into::into)
83    }
84
85    async fn crypto_provider(&self) -> crate::mls::Result<MlsCryptoProvider> {
86        self.mls_provider()
87            .await
88            .map_err(RecursiveError::transaction("getting mls provider"))
89            .map_err(Into::into)
90    }
91}
92
93impl TransactionContext {
94    async fn new(
95        keystore: Database,
96        pki_environment: Arc<RwLock<Option<PkiEnvironment>>>,
97        mls_session: Arc<RwLock<Option<Session<Database>>>>,
98        #[cfg(feature = "proteus")] proteus_central: Arc<Mutex<Option<ProteusCentral>>>,
99    ) -> Result<Self> {
100        keystore
101            .new_transaction()
102            .await
103            .map_err(MlsError::wrap("creating new transaction"))?;
104        let mls_groups = Arc::new(RwLock::new(Default::default()));
105        Ok(Self {
106            inner: Arc::new(
107                TransactionContextInner::Valid {
108                    database: keystore,
109                    pki_environment,
110                    mls_session: mls_session.clone(),
111                    mls_groups,
112                    pending_epoch_changes: Default::default(),
113                    #[cfg(feature = "proteus")]
114                    proteus_central,
115                }
116                .into(),
117            ),
118        })
119    }
120
121    pub(crate) async fn session(&self) -> Result<Session<Database>> {
122        match &*self.inner.read().await {
123            TransactionContextInner::Valid { mls_session, .. } => mls_session.read().await.as_ref().cloned().ok_or(
124                RecursiveError::mls_client("Getting mls session from transaction context")(
125                    mls::session::Error::MlsNotInitialized,
126                )
127                .into(),
128            ),
129            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
130        }
131    }
132
133    #[cfg(test)]
134    pub(crate) async fn set_session_if_exists(&self, new_session: Session<Database>) {
135        match &*self.inner.read().await {
136            TransactionContextInner::Valid { mls_session, .. } => {
137                let mut guard = mls_session.write().await;
138
139                if guard.as_ref().is_some() {
140                    *guard = Some(new_session)
141                }
142            }
143            TransactionContextInner::Invalid => {}
144        }
145    }
146
147    pub(crate) async fn mls_transport(&self) -> Result<Arc<dyn MlsTransport + 'static>> {
148        match &*self.inner.read().await {
149            TransactionContextInner::Valid { mls_session, .. } => {
150                mls_session.read().await.as_ref().map(|s| s.transport.clone()).ok_or(
151                    RecursiveError::mls_client("Getting mls session from transaction context")(
152                        mls::session::Error::MlsNotInitialized,
153                    )
154                    .into(),
155                )
156            }
157
158            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
159        }
160    }
161
162    /// Clones all references that the [MlsCryptoProvider] comprises.
163    pub async fn mls_provider(&self) -> Result<MlsCryptoProvider> {
164        match &*self.inner.read().await {
165            TransactionContextInner::Valid { mls_session, .. } => mls_session
166                .read()
167                .await
168                .as_ref()
169                .map(|s| s.crypto_provider.clone())
170                .ok_or(
171                    RecursiveError::mls_client("Getting mls session from transaction context")(
172                        mls::session::Error::MlsNotInitialized,
173                    )
174                    .into(),
175                ),
176            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
177        }
178    }
179
180    pub(crate) async fn database(&self) -> Result<Database> {
181        match &*self.inner.read().await {
182            TransactionContextInner::Valid { database, .. } => Ok(database.clone()),
183            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
184        }
185    }
186
187    pub(crate) async fn pki_environment(&self) -> Result<PkiEnvironment> {
188        match &*self.inner.read().await {
189            TransactionContextInner::Valid { pki_environment, .. } => {
190                pki_environment.read().await.as_ref().map(Clone::clone).ok_or(
191                    RecursiveError::transaction("Getting PKI environment from transaction context")(
192                        e2e_identity::Error::PkiEnvironmentUnset,
193                    )
194                    .into(),
195                )
196            }
197            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
198        }
199    }
200
201    pub(crate) async fn pki_environment_option(&self) -> Result<Option<PkiEnvironment>> {
202        match &*self.inner.read().await {
203            TransactionContextInner::Valid { pki_environment, .. } => Ok(pki_environment.read().await.clone()),
204
205            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
206        }
207    }
208
209    pub(crate) async fn mls_groups(&self) -> Result<RwLockWriteGuardArc<GroupStore<MlsConversation>>> {
210        match &*self.inner.read().await {
211            TransactionContextInner::Valid { mls_groups, .. } => Ok(mls_groups.write_arc().await),
212            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
213        }
214    }
215
216    pub(crate) async fn queue_epoch_changed(&self, conversation_id: ConversationId, epoch: u64) -> Result<()> {
217        match &*self.inner.read().await {
218            TransactionContextInner::Valid {
219                pending_epoch_changes, ..
220            } => {
221                pending_epoch_changes.lock().await.push((conversation_id, epoch));
222                Ok(())
223            }
224            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
225        }
226    }
227
228    #[cfg(feature = "proteus")]
229    pub(crate) async fn proteus_central(&self) -> Result<Arc<Mutex<Option<ProteusCentral>>>> {
230        match &*self.inner.read().await {
231            TransactionContextInner::Valid { proteus_central, .. } => Ok(proteus_central.clone()),
232            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
233        }
234    }
235
236    /// Commits the transaction, meaning it takes all the enqueued operations and persist them into
237    /// the keystore. After that the internal state is switched to invalid, causing errors if
238    /// something is called from this object.
239    pub async fn finish(&self) -> Result<()> {
240        let mut guard = self.inner.write().await;
241        let TransactionContextInner::Valid {
242            database,
243            pending_epoch_changes,
244            mls_session,
245            ..
246        } = &*guard
247        else {
248            return Err(Error::InvalidTransactionContext);
249        };
250
251        let commit_result = database
252            .commit_transaction()
253            .await
254            .map_err(KeystoreError::wrap("commiting transaction"))
255            .map_err(Into::into);
256
257        if let Some(session) = mls_session.read_arc().await.clone()
258            && commit_result.is_ok()
259        {
260            // We need owned values, so we could just clone the conversation ids, but we don't need the events anymore,
261            // so draining the vector works, too.
262            let mut epoch_changes = pending_epoch_changes.lock().await;
263            let epoch_changes = epoch_changes.drain(..);
264            for (conversation_id, epoch) in epoch_changes {
265                session.notify_epoch_changed(conversation_id, epoch).await;
266            }
267        }
268
269        *guard = TransactionContextInner::Invalid;
270        commit_result
271    }
272
273    /// Aborts the transaction, meaning it discards all the enqueued operations.
274    /// After that the internal state is switched to invalid, causing errors if
275    /// something is called from this object.
276    pub async fn abort(&self) -> Result<()> {
277        let mut guard = self.inner.write().await;
278
279        let TransactionContextInner::Valid { database: keystore, .. } = &*guard else {
280            return Err(Error::InvalidTransactionContext);
281        };
282
283        let result = keystore
284            .rollback_transaction()
285            .await
286            .map_err(KeystoreError::wrap("rolling back transaction"))
287            .map_err(Into::into);
288
289        *guard = TransactionContextInner::Invalid;
290        result
291    }
292
293    /// Initializes the MLS client of [super::CoreCrypto].
294    pub async fn mls_init(&self, session_id: ClientId, transport: Arc<dyn MlsTransport>) -> Result<()> {
295        let database = self.database().await?;
296
297        let pki_env_provider = self
298            .pki_environment_option()
299            .await?
300            .map(|pki_env| pki_env.mls_pki_env_provider())
301            .unwrap_or_default();
302
303        let crypto_provider = MlsCryptoProvider::new_with_pki_env(database.clone(), pki_env_provider);
304        let session = Session::new(session_id.clone(), crypto_provider, database, transport);
305        self.set_mls_session(session).await?;
306
307        Ok(())
308    }
309
310    /// Set the `mls_session` Arc (also sets it on the transaction's CoreCrypto instance)
311    pub(crate) async fn set_mls_session(&self, session: Session<Database>) -> Result<()> {
312        match &*self.inner.read().await {
313            TransactionContextInner::Valid { mls_session, .. } => {
314                let mut guard = mls_session.write().await;
315                *guard = Some(session);
316                Ok(())
317            }
318            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
319        }
320    }
321
322    /// see [Session::id]
323    pub async fn client_id(&self) -> Result<ClientId> {
324        let session = self.session().await?;
325        Ok(session.id())
326    }
327
328    /// Generates a random byte array of the specified size
329    pub async fn random_bytes(&self, len: usize) -> Result<Vec<u8>> {
330        use openmls_traits::random::OpenMlsRand as _;
331        self.mls_provider()
332            .await?
333            .rand()
334            .random_vec(len)
335            .map_err(MlsError::wrap("generating random vector"))
336            .map_err(Into::into)
337    }
338
339    /// Set arbitrary data to be retrieved by [TransactionContext::get_data].
340    /// This is meant to be used as a check point at the end of a transaction.
341    /// The data should be limited to a reasonable size.
342    pub async fn set_data(&self, data: Vec<u8>) -> Result<()> {
343        self.database()
344            .await?
345            .save(ConsumerData::from(data))
346            .await
347            .map_err(KeystoreError::wrap("saving consumer data"))?;
348        Ok(())
349    }
350
351    /// Get the data that has previously been set by [TransactionContext::set_data].
352    /// This is meant to be used as a check point at the end of a transaction.
353    pub async fn get_data(&self) -> Result<Option<Vec<u8>>> {
354        match self.database().await?.get_unique::<ConsumerData>().await {
355            Ok(maybe_data) => Ok(maybe_data.map(Into::into)),
356            Err(CryptoKeystoreError::NotFound(..)) => Ok(None),
357            Err(err) => Err(KeystoreError::wrap("finding unique consumer data")(err).into()),
358        }
359    }
360
361    /// Find credentials matching the find filters among the identities of this session
362    ///
363    /// Note that finding credentials with no filters set is equivalent to [`Self::get_credentials`].
364    pub async fn find_credentials(&self, find_filters: CredentialFindFilters<'_>) -> Result<Vec<CredentialRef>> {
365        self.session()
366            .await?
367            .find_credentials(find_filters)
368            .await
369            .map_err(RecursiveError::mls_client("finding credentials by filter"))
370            .map_err(Into::into)
371    }
372
373    /// Get all credentials from the identities of this session.
374    ///
375    /// To get specific credentials, it can be more efficient to use [`Self::find_credentials`].
376    pub async fn get_credentials(&self) -> Result<Vec<CredentialRef>> {
377        self.session()
378            .await?
379            .get_credentials()
380            .await
381            .map_err(RecursiveError::mls_client("getting all credentials"))
382            .map_err(Into::into)
383    }
384}