core_crypto/
context.rs

1//! This module contains the primitives to enable transactional support on a higher level within the
2//! [MlsCentral]. All mutating operations need to be done through a [CentralContext].
3
4use crate::mls::HasClientAndProvider;
5#[cfg(feature = "proteus")]
6use crate::proteus::ProteusCentral;
7use crate::{
8    CoreCrypto, Error, KeystoreError, MlsError, MlsTransport, RecursiveError, Result,
9    group_store::GroupStore,
10    mls::MlsCentral,
11    prelude::{Client, MlsConversation},
12};
13use async_lock::{Mutex, RwLock, RwLockReadGuardArc, RwLockWriteGuardArc};
14use core_crypto_keystore::{CryptoKeystoreError, connection::FetchFromDatabase, entities::ConsumerData};
15use mls_crypto_provider::{CryptoKeystore, MlsCryptoProvider};
16use std::{ops::Deref, sync::Arc};
17
18/// This struct provides transactional support for Core Crypto.
19///
20/// This is struct provides mutable access to the internals of Core Crypto. Every operation that
21/// causes data to be persisted needs to be done through this struct. This struct will buffer all
22/// operations in memory and when [CentralContext::finish] is called, it will persist the data into
23/// the keystore.
24#[derive(Debug, Clone)]
25pub struct CentralContext {
26    state: Arc<RwLock<ContextState>>,
27}
28
29/// Due to uniffi's design, we can't force the context to be dropped after the transaction is
30/// committed. To work around that we switch the value to `Invalid` when the context is finished
31/// and throw errors if something is called
32#[derive(Debug, Clone)]
33enum ContextState {
34    Valid {
35        provider: MlsCryptoProvider,
36        transport: Arc<RwLock<Option<Arc<dyn MlsTransport + 'static>>>>,
37        mls_client: Client,
38        mls_groups: Arc<RwLock<GroupStore<MlsConversation>>>,
39        #[cfg(feature = "proteus")]
40        proteus_central: Arc<Mutex<Option<ProteusCentral>>>,
41    },
42    Invalid,
43}
44
45impl CoreCrypto {
46    /// Creates a new transaction. All operations that persist data will be
47    /// buffered in memory and when [CentralContext::finish] is called, the data will be persisted
48    /// in a single database transaction.
49    pub async fn new_transaction(&self) -> Result<CentralContext> {
50        CentralContext::new(
51            &self.mls,
52            #[cfg(feature = "proteus")]
53            self.proteus.clone(),
54        )
55        .await
56    }
57}
58
59#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
60#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
61impl HasClientAndProvider for CentralContext {
62    async fn client(&self) -> crate::mls::Result<Client> {
63        self.mls_client()
64            .await
65            .map_err(RecursiveError::root("getting mls client"))
66            .map_err(Into::into)
67    }
68
69    async fn mls_provider(&self) -> crate::mls::Result<MlsCryptoProvider> {
70        self.mls_provider()
71            .await
72            .map_err(RecursiveError::root("getting mls provider"))
73            .map_err(Into::into)
74    }
75}
76
77impl CentralContext {
78    async fn new(
79        mls_central: &MlsCentral,
80        #[cfg(feature = "proteus")] proteus_central: Arc<Mutex<Option<ProteusCentral>>>,
81    ) -> Result<Self> {
82        mls_central
83            .mls_backend
84            .new_transaction()
85            .await
86            .map_err(MlsError::wrap("creating new transaction"))?;
87        let mls_groups = Arc::new(RwLock::new(Default::default()));
88        let callbacks = mls_central.transport.clone();
89        let mls_client = mls_central.mls_client.clone();
90        Ok(Self {
91            state: Arc::new(
92                ContextState::Valid {
93                    mls_client,
94                    transport: callbacks,
95                    provider: mls_central.mls_backend.clone(),
96                    mls_groups,
97                    #[cfg(feature = "proteus")]
98                    proteus_central,
99                }
100                .into(),
101            ),
102        })
103    }
104
105    pub(crate) async fn mls_client(&self) -> Result<Client> {
106        match self.state.read().await.deref() {
107            ContextState::Valid { mls_client, .. } => Ok(mls_client.clone()),
108            ContextState::Invalid => Err(Error::InvalidContext),
109        }
110    }
111
112    pub(crate) async fn mls_transport(&self) -> Result<RwLockReadGuardArc<Option<Arc<dyn MlsTransport + 'static>>>> {
113        match self.state.read().await.deref() {
114            ContextState::Valid {
115                transport: callbacks, ..
116            } => Ok(callbacks.read_arc().await),
117            ContextState::Invalid => Err(Error::InvalidContext),
118        }
119    }
120
121    #[cfg(test)]
122    pub(crate) async fn set_transport_callbacks(
123        &self,
124        callbacks: Option<Arc<dyn MlsTransport + 'static>>,
125    ) -> Result<()> {
126        match self.state.read().await.deref() {
127            ContextState::Valid { transport: cbs, .. } => {
128                *cbs.write_arc().await = callbacks;
129                Ok(())
130            }
131            ContextState::Invalid => Err(Error::InvalidContext),
132        }
133    }
134
135    /// Clones all references that the [MlsCryptoProvider] comprises.
136    pub async fn mls_provider(&self) -> Result<MlsCryptoProvider> {
137        match self.state.read().await.deref() {
138            ContextState::Valid { provider, .. } => Ok(provider.clone()),
139            ContextState::Invalid => Err(Error::InvalidContext),
140        }
141    }
142
143    pub(crate) async fn keystore(&self) -> Result<CryptoKeystore> {
144        match self.state.read().await.deref() {
145            ContextState::Valid { provider, .. } => Ok(provider.keystore()),
146            ContextState::Invalid => Err(Error::InvalidContext),
147        }
148    }
149
150    pub(crate) async fn mls_groups(&self) -> Result<RwLockWriteGuardArc<GroupStore<MlsConversation>>> {
151        match self.state.read().await.deref() {
152            ContextState::Valid { mls_groups, .. } => Ok(mls_groups.write_arc().await),
153            ContextState::Invalid => Err(Error::InvalidContext),
154        }
155    }
156
157    #[cfg(feature = "proteus")]
158    pub(crate) async fn proteus_central(&self) -> Result<Arc<Mutex<Option<ProteusCentral>>>> {
159        match self.state.read().await.deref() {
160            ContextState::Valid { proteus_central, .. } => Ok(proteus_central.clone()),
161            ContextState::Invalid => Err(Error::InvalidContext),
162        }
163    }
164
165    /// Commits the transaction, meaning it takes all the enqueued operations and persist them into
166    /// the keystore. After that the internal state is switched to invalid, causing errors if
167    /// something is called from this object.
168    pub async fn finish(&self) -> Result<()> {
169        let mut guard = self.state.write().await;
170        let ContextState::Valid { provider, .. } = guard.deref() else {
171            return Err(Error::InvalidContext);
172        };
173
174        let commit_result = provider
175            .keystore()
176            .commit_transaction()
177            .await
178            .map_err(KeystoreError::wrap("commiting transaction"))
179            .map_err(Into::into);
180
181        *guard = ContextState::Invalid;
182        commit_result
183    }
184
185    /// Aborts the transaction, meaning it discards all the enqueued operations.
186    /// After that the internal state is switched to invalid, causing errors if
187    /// something is called from this object.
188    pub async fn abort(&self) -> Result<()> {
189        let mut guard = self.state.write().await;
190
191        let ContextState::Valid { provider, .. } = guard.deref() else {
192            return Err(Error::InvalidContext);
193        };
194
195        let result = provider
196            .keystore()
197            .rollback_transaction()
198            .await
199            .map_err(KeystoreError::wrap("rolling back transaction"))
200            .map_err(Into::into);
201
202        *guard = ContextState::Invalid;
203        result
204    }
205
206    /// Set arbitrary data to be retrieved by [CentralContext::get_data].
207    /// This is meant to be used as a check point at the end of a transaction.
208    /// The data should be limited to a reasonable size.
209    pub async fn set_data(&self, data: Vec<u8>) -> Result<()> {
210        self.keystore()
211            .await?
212            .save(ConsumerData::from(data))
213            .await
214            .map_err(KeystoreError::wrap("saving consumer data"))?;
215        Ok(())
216    }
217
218    /// Get the data that has previously been set by [CentralContext::set_data].
219    /// This is meant to be used as a check point at the end of a transaction.
220    pub async fn get_data(&self) -> Result<Option<Vec<u8>>> {
221        match self.keystore().await?.find_unique::<ConsumerData>().await {
222            Ok(data) => Ok(Some(data.into())),
223            Err(CryptoKeystoreError::NotFound(..)) => Ok(None),
224            Err(err) => Err(KeystoreError::wrap("finding unique consumer data")(err).into()),
225        }
226    }
227}