core_crypto/transaction_context/
mod.rs

1//! This module contains the primitives to enable transactional support on a higher level within the
2//! [Client]. All mutating operations need to be done through a [TransactionContext].
3
4#[cfg(feature = "proteus")]
5use crate::proteus::ProteusCentral;
6use crate::{
7    CoreCrypto, KeystoreError, MlsError, MlsTransport, RecursiveError,
8    group_store::GroupStore,
9    prelude::{ClientId, INITIAL_KEYING_MATERIAL_COUNT, MlsConversation, MlsCredentialType, Session},
10};
11use crate::{
12    mls::HasSessionAndCrypto,
13    prelude::{ClientIdentifier, MlsCiphersuite},
14};
15use async_lock::{Mutex, RwLock, RwLockReadGuardArc, RwLockWriteGuardArc};
16use core_crypto_keystore::{CryptoKeystoreError, connection::FetchFromDatabase, entities::ConsumerData};
17pub use error::{Error, Result};
18use mls_crypto_provider::{CryptoKeystore, MlsCryptoProvider};
19use openmls_traits::OpenMlsCryptoProvider as _;
20use std::{ops::Deref, sync::Arc};
21pub mod conversation;
22pub mod e2e_identity;
23mod error;
24pub mod key_package;
25#[cfg(feature = "proteus")]
26pub mod proteus;
27#[cfg(test)]
28pub mod test_utils;
29
30/// This struct provides transactional support for Core Crypto.
31///
32/// This is struct provides mutable access to the internals of Core Crypto. Every operation that
33/// causes data to be persisted needs to be done through this struct. This struct will buffer all
34/// operations in memory and when [TransactionContext::finish] is called, it will persist the data into
35/// the keystore.
36#[derive(Debug, Clone)]
37pub struct TransactionContext {
38    inner: Arc<RwLock<TransactionContextInner>>,
39}
40
41/// Due to uniffi's design, we can't force the context to be dropped after the transaction is
42/// committed. To work around that we switch the value to `Invalid` when the context is finished
43/// and throw errors if something is called
44#[derive(Debug, Clone)]
45enum TransactionContextInner {
46    Valid {
47        provider: MlsCryptoProvider,
48        transport: Arc<RwLock<Option<Arc<dyn MlsTransport + 'static>>>>,
49        mls_client: Session,
50        mls_groups: Arc<RwLock<GroupStore<MlsConversation>>>,
51        #[cfg(feature = "proteus")]
52        proteus_central: Arc<Mutex<Option<ProteusCentral>>>,
53    },
54    Invalid,
55}
56
57impl CoreCrypto {
58    /// Creates a new transaction. All operations that persist data will be
59    /// buffered in memory and when [TransactionContext::finish] is called, the data will be persisted
60    /// in a single database transaction.
61    pub async fn new_transaction(&self) -> Result<TransactionContext> {
62        TransactionContext::new(
63            &self.mls,
64            #[cfg(feature = "proteus")]
65            self.proteus.clone(),
66        )
67        .await
68    }
69}
70
71#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
72#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
73impl HasSessionAndCrypto for TransactionContext {
74    async fn session(&self) -> crate::mls::Result<Session> {
75        self.session()
76            .await
77            .map_err(RecursiveError::transaction("getting mls client"))
78            .map_err(Into::into)
79    }
80
81    async fn crypto_provider(&self) -> crate::mls::Result<MlsCryptoProvider> {
82        self.mls_provider()
83            .await
84            .map_err(RecursiveError::transaction("getting mls provider"))
85            .map_err(Into::into)
86    }
87}
88
89impl TransactionContext {
90    async fn new(
91        client: &Session,
92        #[cfg(feature = "proteus")] proteus_central: Arc<Mutex<Option<ProteusCentral>>>,
93    ) -> Result<Self> {
94        client
95            .crypto_provider
96            .new_transaction()
97            .await
98            .map_err(MlsError::wrap("creating new transaction"))?;
99        let mls_groups = Arc::new(RwLock::new(Default::default()));
100        let callbacks = client.transport.clone();
101        let mls_client = client.clone();
102        Ok(Self {
103            inner: Arc::new(
104                TransactionContextInner::Valid {
105                    mls_client,
106                    transport: callbacks,
107                    provider: client.crypto_provider.clone(),
108                    mls_groups,
109                    #[cfg(feature = "proteus")]
110                    proteus_central,
111                }
112                .into(),
113            ),
114        })
115    }
116
117    pub(crate) async fn session(&self) -> Result<Session> {
118        match self.inner.read().await.deref() {
119            TransactionContextInner::Valid { mls_client, .. } => Ok(mls_client.clone()),
120            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
121        }
122    }
123
124    pub(crate) async fn mls_transport(&self) -> Result<RwLockReadGuardArc<Option<Arc<dyn MlsTransport + 'static>>>> {
125        match self.inner.read().await.deref() {
126            TransactionContextInner::Valid {
127                transport: callbacks, ..
128            } => Ok(callbacks.read_arc().await),
129            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
130        }
131    }
132
133    #[cfg(test)]
134    pub(crate) async fn set_transport_callbacks(
135        &self,
136        callbacks: Option<Arc<dyn MlsTransport + 'static>>,
137    ) -> Result<()> {
138        match self.inner.read().await.deref() {
139            TransactionContextInner::Valid { transport: cbs, .. } => {
140                *cbs.write_arc().await = callbacks;
141                Ok(())
142            }
143            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
144        }
145    }
146
147    /// Clones all references that the [MlsCryptoProvider] comprises.
148    pub async fn mls_provider(&self) -> Result<MlsCryptoProvider> {
149        match self.inner.read().await.deref() {
150            TransactionContextInner::Valid { provider, .. } => Ok(provider.clone()),
151            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
152        }
153    }
154
155    pub(crate) async fn keystore(&self) -> Result<CryptoKeystore> {
156        match self.inner.read().await.deref() {
157            TransactionContextInner::Valid { provider, .. } => Ok(provider.keystore()),
158            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
159        }
160    }
161
162    pub(crate) async fn mls_groups(&self) -> Result<RwLockWriteGuardArc<GroupStore<MlsConversation>>> {
163        match self.inner.read().await.deref() {
164            TransactionContextInner::Valid { mls_groups, .. } => Ok(mls_groups.write_arc().await),
165            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
166        }
167    }
168
169    #[cfg(feature = "proteus")]
170    pub(crate) async fn proteus_central(&self) -> Result<Arc<Mutex<Option<ProteusCentral>>>> {
171        match self.inner.read().await.deref() {
172            TransactionContextInner::Valid { proteus_central, .. } => Ok(proteus_central.clone()),
173            TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
174        }
175    }
176
177    /// Commits the transaction, meaning it takes all the enqueued operations and persist them into
178    /// the keystore. After that the internal state is switched to invalid, causing errors if
179    /// something is called from this object.
180    pub async fn finish(&self) -> Result<()> {
181        let mut guard = self.inner.write().await;
182        let TransactionContextInner::Valid { provider, .. } = guard.deref() else {
183            return Err(Error::InvalidTransactionContext);
184        };
185
186        let commit_result = provider
187            .keystore()
188            .commit_transaction()
189            .await
190            .map_err(KeystoreError::wrap("commiting transaction"))
191            .map_err(Into::into);
192
193        *guard = TransactionContextInner::Invalid;
194        commit_result
195    }
196
197    /// Aborts the transaction, meaning it discards all the enqueued operations.
198    /// After that the internal state is switched to invalid, causing errors if
199    /// something is called from this object.
200    pub async fn abort(&self) -> Result<()> {
201        let mut guard = self.inner.write().await;
202
203        let TransactionContextInner::Valid { provider, .. } = guard.deref() else {
204            return Err(Error::InvalidTransactionContext);
205        };
206
207        let result = provider
208            .keystore()
209            .rollback_transaction()
210            .await
211            .map_err(KeystoreError::wrap("rolling back transaction"))
212            .map_err(Into::into);
213
214        *guard = TransactionContextInner::Invalid;
215        result
216    }
217
218    /// Initializes the MLS client if [super::CoreCrypto] has previously been initialized with
219    /// `CoreCrypto::deferred_init` instead of `CoreCrypto::new`.
220    /// This should stay as long as proteus is supported. Then it should be removed.
221    pub async fn mls_init(
222        &self,
223        identifier: ClientIdentifier,
224        ciphersuites: Vec<MlsCiphersuite>,
225        nb_init_key_packages: Option<usize>,
226    ) -> Result<()> {
227        let nb_key_package = nb_init_key_packages.unwrap_or(INITIAL_KEYING_MATERIAL_COUNT);
228        let mls_client = self.session().await?;
229        mls_client
230            .init(identifier, &ciphersuites, &self.mls_provider().await?, nb_key_package)
231            .await
232            .map_err(RecursiveError::mls_client("initializing mls client"))?;
233
234        if mls_client.is_e2ei_capable().await {
235            let client_id = mls_client
236                .id()
237                .await
238                .map_err(RecursiveError::mls_client("getting client id"))?;
239            log::trace!(client_id:% = client_id; "Initializing PKI environment");
240            self.init_pki_env().await?;
241        }
242
243        Ok(())
244    }
245
246    /// Generates MLS KeyPairs/CredentialBundle with a temporary, random client ID.
247    /// This method is designed to be used in conjunction with [TransactionContext::mls_init_with_client_id] and represents the first step in this process.
248    ///
249    /// This returns the TLS-serialized identity keys (i.e. the signature keypair's public key)
250    #[cfg_attr(test, crate::dispotent)]
251    pub async fn mls_generate_keypairs(&self, ciphersuites: Vec<MlsCiphersuite>) -> Result<Vec<ClientId>> {
252        self.session()
253            .await?
254            .generate_raw_keypairs(&ciphersuites, &self.mls_provider().await?)
255            .await
256            .map_err(RecursiveError::mls_client("generating raw keypairs"))
257            .map_err(Into::into)
258    }
259
260    /// Updates the current temporary Client ID with the newly provided one. This is the second step in the externally-generated clients process
261    ///
262    /// Important: This is designed to be called after [TransactionContext::mls_generate_keypairs]
263    #[cfg_attr(test, crate::dispotent)]
264    pub async fn mls_init_with_client_id(
265        &self,
266        client_id: ClientId,
267        tmp_client_ids: Vec<ClientId>,
268        ciphersuites: Vec<MlsCiphersuite>,
269    ) -> Result<()> {
270        self.session()
271            .await?
272            .init_with_external_client_id(client_id, tmp_client_ids, &ciphersuites, &self.mls_provider().await?)
273            .await
274            .map_err(RecursiveError::mls_client(
275                "initializing mls client with external client id",
276            ))
277            .map_err(Into::into)
278    }
279
280    /// see [Client::client_public_key]
281    pub async fn client_public_key(
282        &self,
283        ciphersuite: MlsCiphersuite,
284        credential_type: MlsCredentialType,
285    ) -> Result<Vec<u8>> {
286        let cb = self
287            .session()
288            .await?
289            .find_most_recent_credential_bundle(ciphersuite.signature_algorithm(), credential_type)
290            .await
291            .map_err(RecursiveError::mls_client("finding most recent credential bundle"))?;
292        Ok(cb.signature_key.to_public_vec())
293    }
294
295    /// see [Client::id]
296    pub async fn client_id(&self) -> Result<ClientId> {
297        self.session()
298            .await?
299            .id()
300            .await
301            .map_err(RecursiveError::mls_client("getting client id"))
302            .map_err(Into::into)
303    }
304
305    /// Generates a random byte array of the specified size
306    pub async fn random_bytes(&self, len: usize) -> Result<Vec<u8>> {
307        use openmls_traits::random::OpenMlsRand as _;
308        self.mls_provider()
309            .await?
310            .rand()
311            .random_vec(len)
312            .map_err(MlsError::wrap("generating random vector"))
313            .map_err(Into::into)
314    }
315
316    /// Set arbitrary data to be retrieved by [TransactionContext::get_data].
317    /// This is meant to be used as a check point at the end of a transaction.
318    /// The data should be limited to a reasonable size.
319    pub async fn set_data(&self, data: Vec<u8>) -> Result<()> {
320        self.keystore()
321            .await?
322            .save(ConsumerData::from(data))
323            .await
324            .map_err(KeystoreError::wrap("saving consumer data"))?;
325        Ok(())
326    }
327
328    /// Get the data that has previously been set by [TransactionContext::set_data].
329    /// This is meant to be used as a check point at the end of a transaction.
330    pub async fn get_data(&self) -> Result<Option<Vec<u8>>> {
331        match self.keystore().await?.find_unique::<ConsumerData>().await {
332            Ok(data) => Ok(Some(data.into())),
333            Err(CryptoKeystoreError::NotFound(..)) => Ok(None),
334            Err(err) => Err(KeystoreError::wrap("finding unique consumer data")(err).into()),
335        }
336    }
337}