core_crypto/transaction_context/
mod.rs

1//! This module contains the primitives to enable transactional support on a higher level within the
2//! [Session]. All mutating operations need to be done through a [TransactionContext].
3
4#[cfg(feature = "proteus")]
5use crate::proteus::ProteusCentral;
6use crate::{
7    CoreCrypto, KeystoreError, MlsError, MlsTransport, RecursiveError,
8    group_store::GroupStore,
9    prelude::{ClientId, ConversationId, INITIAL_KEYING_MATERIAL_COUNT, MlsConversation, MlsCredentialType, Session},
10};
11use crate::{
12    mls::HasSessionAndCrypto,
13    prelude::{ClientIdentifier, MlsCiphersuite},
14};
15use async_lock::{Mutex, RwLock, RwLockReadGuardArc, RwLockWriteGuardArc};
16use core_crypto_keystore::{CryptoKeystoreError, connection::FetchFromDatabase, entities::ConsumerData};
17pub use error::{Error, Result};
18use mls_crypto_provider::{CryptoKeystore, MlsCryptoProvider};
19use openmls_traits::OpenMlsCryptoProvider as _;
20use std::{ops::Deref, sync::Arc};
21pub mod conversation;
22pub mod e2e_identity;
23mod error;
24pub mod key_package;
25#[cfg(feature = "proteus")]
26pub mod proteus;
27#[cfg(test)]
28pub mod test_utils;
29
/// This struct provides transactional support for Core Crypto.
///
/// This struct provides mutable access to the internals of Core Crypto. Every operation that
/// causes data to be persisted needs to be done through this struct. This struct will buffer all
/// operations in memory and when [TransactionContext::finish] is called, it will persist the data into
/// the keystore.
///
/// Once [TransactionContext::finish] or [TransactionContext::abort] has run, the shared state is
/// switched to invalid and every accessor on this struct returns
/// [Error::InvalidTransactionContext].
#[derive(Debug, Clone)]
pub struct TransactionContext {
    // Behind an `Arc`, so clones of the context observe the same state, including the switch
    // to invalid performed by `finish` / `abort`.
    inner: Arc<RwLock<TransactionContextInner>>,
}
40
/// Due to uniffi's design, we can't force the context to be dropped after the transaction is
/// committed. To work around that we switch the value to `Invalid` when the context is finished
/// and throw errors if something is called
#[derive(Debug, Clone)]
enum TransactionContextInner {
    /// The transaction is still open; all accessors of [TransactionContext] hand out (clones of
    /// or guards on) these fields.
    Valid {
        /// Crypto provider cloned from the session when the transaction was created; its
        /// keystore is what `finish` commits and `abort` rolls back.
        provider: MlsCryptoProvider,
        /// Transport handle cloned from the session when the transaction was created.
        transport: Arc<RwLock<Option<Arc<dyn MlsTransport + 'static>>>>,
        /// Clone of the session this transaction was created from.
        mls_client: Session,
        /// Group store created empty for this transaction; handed out behind a write guard.
        mls_groups: Arc<RwLock<GroupStore<MlsConversation>>>,
        /// `(conversation id, epoch)` pairs queued via `queue_epoch_changed`; they are passed to
        /// `Session::notify_epoch_changed` by `finish` only if the keystore commit succeeded.
        pending_epoch_changes: Arc<Mutex<Vec<(ConversationId, u64)>>>,
        /// Proteus handle passed in by [CoreCrypto::new_transaction].
        #[cfg(feature = "proteus")]
        proteus_central: Arc<Mutex<Option<ProteusCentral>>>,
    },
    /// The transaction has been finished or aborted; every accessor returns
    /// [Error::InvalidTransactionContext].
    Invalid,
}
57
58impl CoreCrypto {
59    /// Creates a new transaction. All operations that persist data will be
60    /// buffered in memory and when [TransactionContext::finish] is called, the data will be persisted
61    /// in a single database transaction.
62    pub async fn new_transaction(&self) -> Result<TransactionContext> {
63        TransactionContext::new(
64            &self.mls,
65            #[cfg(feature = "proteus")]
66            self.proteus.clone(),
67        )
68        .await
69    }
70}
71
72#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
73#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
74impl HasSessionAndCrypto for TransactionContext {
75    async fn session(&self) -> crate::mls::Result<Session> {
76        self.session()
77            .await
78            .map_err(RecursiveError::transaction("getting mls client"))
79            .map_err(Into::into)
80    }
81
82    async fn crypto_provider(&self) -> crate::mls::Result<MlsCryptoProvider> {
83        self.mls_provider()
84            .await
85            .map_err(RecursiveError::transaction("getting mls provider"))
86            .map_err(Into::into)
87    }
88}
89
90impl TransactionContext {
91    async fn new(
92        client: &Session,
93        #[cfg(feature = "proteus")] proteus_central: Arc<Mutex<Option<ProteusCentral>>>,
94    ) -> Result<Self> {
95        client
96            .crypto_provider
97            .new_transaction()
98            .await
99            .map_err(MlsError::wrap("creating new transaction"))?;
100        let mls_groups = Arc::new(RwLock::new(Default::default()));
101        let callbacks = client.transport.clone();
102        let mls_client = client.clone();
103        Ok(Self {
104            inner: Arc::new(
105                TransactionContextInner::Valid {
106                    mls_client,
107                    transport: callbacks,
108                    provider: client.crypto_provider.clone(),
109                    mls_groups,
110                    pending_epoch_changes: Default::default(),
111                    #[cfg(feature = "proteus")]
112                    proteus_central,
113                }
114                .into(),
115            ),
116        })
117    }
118
119    pub(crate) async fn session(&self) -> Result<Session> {
120        match self.inner.read().await.deref() {
121            TransactionContextInner::Valid { mls_client, .. } => Ok(mls_client.clone()),
122            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
123        }
124    }
125
126    pub(crate) async fn mls_transport(&self) -> Result<RwLockReadGuardArc<Option<Arc<dyn MlsTransport + 'static>>>> {
127        match self.inner.read().await.deref() {
128            TransactionContextInner::Valid {
129                transport: callbacks, ..
130            } => Ok(callbacks.read_arc().await),
131            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
132        }
133    }
134
135    #[cfg(test)]
136    pub(crate) async fn set_transport_callbacks(
137        &self,
138        callbacks: Option<Arc<dyn MlsTransport + 'static>>,
139    ) -> Result<()> {
140        match self.inner.read().await.deref() {
141            TransactionContextInner::Valid { transport: cbs, .. } => {
142                *cbs.write_arc().await = callbacks;
143                Ok(())
144            }
145            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
146        }
147    }
148
149    /// Clones all references that the [MlsCryptoProvider] comprises.
150    pub async fn mls_provider(&self) -> Result<MlsCryptoProvider> {
151        match self.inner.read().await.deref() {
152            TransactionContextInner::Valid { provider, .. } => Ok(provider.clone()),
153            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
154        }
155    }
156
157    pub(crate) async fn keystore(&self) -> Result<CryptoKeystore> {
158        match self.inner.read().await.deref() {
159            TransactionContextInner::Valid { provider, .. } => Ok(provider.keystore()),
160            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
161        }
162    }
163
164    pub(crate) async fn mls_groups(&self) -> Result<RwLockWriteGuardArc<GroupStore<MlsConversation>>> {
165        match self.inner.read().await.deref() {
166            TransactionContextInner::Valid { mls_groups, .. } => Ok(mls_groups.write_arc().await),
167            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
168        }
169    }
170
171    pub(crate) async fn queue_epoch_changed(&self, conversation_id: ConversationId, epoch: u64) -> Result<()> {
172        match self.inner.read().await.deref() {
173            TransactionContextInner::Valid {
174                pending_epoch_changes, ..
175            } => {
176                pending_epoch_changes.lock().await.push((conversation_id, epoch));
177                Ok(())
178            }
179            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
180        }
181    }
182
183    #[cfg(feature = "proteus")]
184    pub(crate) async fn proteus_central(&self) -> Result<Arc<Mutex<Option<ProteusCentral>>>> {
185        match self.inner.read().await.deref() {
186            TransactionContextInner::Valid { proteus_central, .. } => Ok(proteus_central.clone()),
187            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
188        }
189    }
190
191    /// Commits the transaction, meaning it takes all the enqueued operations and persist them into
192    /// the keystore. After that the internal state is switched to invalid, causing errors if
193    /// something is called from this object.
194    pub async fn finish(&self) -> Result<()> {
195        let mut guard = self.inner.write().await;
196        let TransactionContextInner::Valid {
197            provider,
198            mls_client,
199            pending_epoch_changes,
200            ..
201        } = guard.deref()
202        else {
203            return Err(Error::InvalidTransactionContext);
204        };
205
206        let commit_result = provider
207            .keystore()
208            .commit_transaction()
209            .await
210            .map_err(KeystoreError::wrap("commiting transaction"))
211            .map_err(Into::into);
212
213        if commit_result.is_ok() {
214            // We need owned values, so we could just clone the conversation ids, but we don't need the events anymore,
215            // so draining the vector works, too.
216            let mut epoch_changes = pending_epoch_changes.lock().await;
217            let epoch_changes = epoch_changes.drain(..);
218            for (conversation_id, epoch) in epoch_changes {
219                mls_client.notify_epoch_changed(conversation_id, epoch).await;
220            }
221        }
222
223        *guard = TransactionContextInner::Invalid;
224        commit_result
225    }
226
227    /// Aborts the transaction, meaning it discards all the enqueued operations.
228    /// After that the internal state is switched to invalid, causing errors if
229    /// something is called from this object.
230    pub async fn abort(&self) -> Result<()> {
231        let mut guard = self.inner.write().await;
232
233        let TransactionContextInner::Valid { provider, .. } = guard.deref() else {
234            return Err(Error::InvalidTransactionContext);
235        };
236
237        let result = provider
238            .keystore()
239            .rollback_transaction()
240            .await
241            .map_err(KeystoreError::wrap("rolling back transaction"))
242            .map_err(Into::into);
243
244        *guard = TransactionContextInner::Invalid;
245        result
246    }
247
248    /// Initializes the MLS client if [super::CoreCrypto] has previously been initialized with
249    /// `CoreCrypto::deferred_init` instead of `CoreCrypto::new`.
250    /// This should stay as long as proteus is supported. Then it should be removed.
251    pub async fn mls_init(
252        &self,
253        identifier: ClientIdentifier,
254        ciphersuites: Vec<MlsCiphersuite>,
255        nb_init_key_packages: Option<usize>,
256    ) -> Result<()> {
257        let nb_key_package = nb_init_key_packages.unwrap_or(INITIAL_KEYING_MATERIAL_COUNT);
258        let mls_client = self.session().await?;
259        mls_client
260            .init(identifier, &ciphersuites, &self.mls_provider().await?, nb_key_package)
261            .await
262            .map_err(RecursiveError::mls_client("initializing mls client"))?;
263
264        if mls_client.is_e2ei_capable().await {
265            let client_id = mls_client
266                .id()
267                .await
268                .map_err(RecursiveError::mls_client("getting client id"))?;
269            log::trace!(client_id:% = client_id; "Initializing PKI environment");
270            self.init_pki_env().await?;
271        }
272
273        Ok(())
274    }
275
276    /// Returns the client's public key.
277    pub async fn client_public_key(
278        &self,
279        ciphersuite: MlsCiphersuite,
280        credential_type: MlsCredentialType,
281    ) -> Result<Vec<u8>> {
282        let cb = self
283            .session()
284            .await?
285            .find_most_recent_credential_bundle(ciphersuite.signature_algorithm(), credential_type)
286            .await
287            .map_err(RecursiveError::mls_client("finding most recent credential bundle"))?;
288        Ok(cb.signature_key.to_public_vec())
289    }
290
291    /// see [Session::id]
292    pub async fn client_id(&self) -> Result<ClientId> {
293        self.session()
294            .await?
295            .id()
296            .await
297            .map_err(RecursiveError::mls_client("getting client id"))
298            .map_err(Into::into)
299    }
300
301    /// Generates a random byte array of the specified size
302    pub async fn random_bytes(&self, len: usize) -> Result<Vec<u8>> {
303        use openmls_traits::random::OpenMlsRand as _;
304        self.mls_provider()
305            .await?
306            .rand()
307            .random_vec(len)
308            .map_err(MlsError::wrap("generating random vector"))
309            .map_err(Into::into)
310    }
311
312    /// Set arbitrary data to be retrieved by [TransactionContext::get_data].
313    /// This is meant to be used as a check point at the end of a transaction.
314    /// The data should be limited to a reasonable size.
315    pub async fn set_data(&self, data: Vec<u8>) -> Result<()> {
316        self.keystore()
317            .await?
318            .save(ConsumerData::from(data))
319            .await
320            .map_err(KeystoreError::wrap("saving consumer data"))?;
321        Ok(())
322    }
323
324    /// Get the data that has previously been set by [TransactionContext::set_data].
325    /// This is meant to be used as a check point at the end of a transaction.
326    pub async fn get_data(&self) -> Result<Option<Vec<u8>>> {
327        match self.keystore().await?.find_unique::<ConsumerData>().await {
328            Ok(data) => Ok(Some(data.into())),
329            Err(CryptoKeystoreError::NotFound(..)) => Ok(None),
330            Err(err) => Err(KeystoreError::wrap("finding unique consumer data")(err).into()),
331        }
332    }
333}