core_crypto/
context.rs

1//! This module contains the primitives to enable transactional support on a higher level within the
2//! [Client]. All mutating operations need to be done through a [CentralContext].
3
4use crate::mls::HasClientAndProvider;
5#[cfg(feature = "proteus")]
6use crate::proteus::ProteusCentral;
7use crate::{
8    CoreCrypto, Error, KeystoreError, MlsError, MlsTransport, RecursiveError, Result,
9    group_store::GroupStore,
10    prelude::{Client, MlsConversation},
11};
12use async_lock::{Mutex, RwLock, RwLockReadGuardArc, RwLockWriteGuardArc};
13use core_crypto_keystore::{CryptoKeystoreError, connection::FetchFromDatabase, entities::ConsumerData};
14use mls_crypto_provider::{CryptoKeystore, MlsCryptoProvider};
15use std::{ops::Deref, sync::Arc};
16
17/// This struct provides transactional support for Core Crypto.
18///
19/// This is struct provides mutable access to the internals of Core Crypto. Every operation that
20/// causes data to be persisted needs to be done through this struct. This struct will buffer all
21/// operations in memory and when [CentralContext::finish] is called, it will persist the data into
22/// the keystore.
23#[derive(Debug, Clone)]
24pub struct CentralContext {
25    state: Arc<RwLock<ContextState>>,
26}
27
28/// Due to uniffi's design, we can't force the context to be dropped after the transaction is
29/// committed. To work around that we switch the value to `Invalid` when the context is finished
30/// and throw errors if something is called
31#[derive(Debug, Clone)]
32enum ContextState {
33    Valid {
34        provider: MlsCryptoProvider,
35        transport: Arc<RwLock<Option<Arc<dyn MlsTransport + 'static>>>>,
36        mls_client: Client,
37        mls_groups: Arc<RwLock<GroupStore<MlsConversation>>>,
38        #[cfg(feature = "proteus")]
39        proteus_central: Arc<Mutex<Option<ProteusCentral>>>,
40    },
41    Invalid,
42}
43
44impl CoreCrypto {
45    /// Creates a new transaction. All operations that persist data will be
46    /// buffered in memory and when [CentralContext::finish] is called, the data will be persisted
47    /// in a single database transaction.
48    pub async fn new_transaction(&self) -> Result<CentralContext> {
49        CentralContext::new(
50            &self.mls,
51            #[cfg(feature = "proteus")]
52            self.proteus.clone(),
53        )
54        .await
55    }
56}
57
58#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
59#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
60impl HasClientAndProvider for CentralContext {
61    async fn client(&self) -> crate::mls::Result<Client> {
62        self.mls_client()
63            .await
64            .map_err(RecursiveError::root("getting mls client"))
65            .map_err(Into::into)
66    }
67
68    async fn mls_provider(&self) -> crate::mls::Result<MlsCryptoProvider> {
69        self.mls_provider()
70            .await
71            .map_err(RecursiveError::root("getting mls provider"))
72            .map_err(Into::into)
73    }
74}
75
76impl CentralContext {
77    async fn new(
78        client: &Client,
79        #[cfg(feature = "proteus")] proteus_central: Arc<Mutex<Option<ProteusCentral>>>,
80    ) -> Result<Self> {
81        client
82            .mls_backend
83            .new_transaction()
84            .await
85            .map_err(MlsError::wrap("creating new transaction"))?;
86        let mls_groups = Arc::new(RwLock::new(Default::default()));
87        let callbacks = client.transport.clone();
88        let mls_client = client.clone();
89        Ok(Self {
90            state: Arc::new(
91                ContextState::Valid {
92                    mls_client,
93                    transport: callbacks,
94                    provider: client.mls_backend.clone(),
95                    mls_groups,
96                    #[cfg(feature = "proteus")]
97                    proteus_central,
98                }
99                .into(),
100            ),
101        })
102    }
103
104    pub(crate) async fn mls_client(&self) -> Result<Client> {
105        match self.state.read().await.deref() {
106            ContextState::Valid { mls_client, .. } => Ok(mls_client.clone()),
107            ContextState::Invalid => Err(Error::InvalidContext),
108        }
109    }
110
111    pub(crate) async fn mls_transport(&self) -> Result<RwLockReadGuardArc<Option<Arc<dyn MlsTransport + 'static>>>> {
112        match self.state.read().await.deref() {
113            ContextState::Valid {
114                transport: callbacks, ..
115            } => Ok(callbacks.read_arc().await),
116            ContextState::Invalid => Err(Error::InvalidContext),
117        }
118    }
119
120    #[cfg(test)]
121    pub(crate) async fn set_transport_callbacks(
122        &self,
123        callbacks: Option<Arc<dyn MlsTransport + 'static>>,
124    ) -> Result<()> {
125        match self.state.read().await.deref() {
126            ContextState::Valid { transport: cbs, .. } => {
127                *cbs.write_arc().await = callbacks;
128                Ok(())
129            }
130            ContextState::Invalid => Err(Error::InvalidContext),
131        }
132    }
133
134    /// Clones all references that the [MlsCryptoProvider] comprises.
135    pub async fn mls_provider(&self) -> Result<MlsCryptoProvider> {
136        match self.state.read().await.deref() {
137            ContextState::Valid { provider, .. } => Ok(provider.clone()),
138            ContextState::Invalid => Err(Error::InvalidContext),
139        }
140    }
141
142    pub(crate) async fn keystore(&self) -> Result<CryptoKeystore> {
143        match self.state.read().await.deref() {
144            ContextState::Valid { provider, .. } => Ok(provider.keystore()),
145            ContextState::Invalid => Err(Error::InvalidContext),
146        }
147    }
148
149    pub(crate) async fn mls_groups(&self) -> Result<RwLockWriteGuardArc<GroupStore<MlsConversation>>> {
150        match self.state.read().await.deref() {
151            ContextState::Valid { mls_groups, .. } => Ok(mls_groups.write_arc().await),
152            ContextState::Invalid => Err(Error::InvalidContext),
153        }
154    }
155
156    #[cfg(feature = "proteus")]
157    pub(crate) async fn proteus_central(&self) -> Result<Arc<Mutex<Option<ProteusCentral>>>> {
158        match self.state.read().await.deref() {
159            ContextState::Valid { proteus_central, .. } => Ok(proteus_central.clone()),
160            ContextState::Invalid => Err(Error::InvalidContext),
161        }
162    }
163
164    /// Commits the transaction, meaning it takes all the enqueued operations and persist them into
165    /// the keystore. After that the internal state is switched to invalid, causing errors if
166    /// something is called from this object.
167    pub async fn finish(&self) -> Result<()> {
168        let mut guard = self.state.write().await;
169        let ContextState::Valid { provider, .. } = guard.deref() else {
170            return Err(Error::InvalidContext);
171        };
172
173        let commit_result = provider
174            .keystore()
175            .commit_transaction()
176            .await
177            .map_err(KeystoreError::wrap("commiting transaction"))
178            .map_err(Into::into);
179
180        *guard = ContextState::Invalid;
181        commit_result
182    }
183
184    /// Aborts the transaction, meaning it discards all the enqueued operations.
185    /// After that the internal state is switched to invalid, causing errors if
186    /// something is called from this object.
187    pub async fn abort(&self) -> Result<()> {
188        let mut guard = self.state.write().await;
189
190        let ContextState::Valid { provider, .. } = guard.deref() else {
191            return Err(Error::InvalidContext);
192        };
193
194        let result = provider
195            .keystore()
196            .rollback_transaction()
197            .await
198            .map_err(KeystoreError::wrap("rolling back transaction"))
199            .map_err(Into::into);
200
201        *guard = ContextState::Invalid;
202        result
203    }
204
205    /// Set arbitrary data to be retrieved by [CentralContext::get_data].
206    /// This is meant to be used as a check point at the end of a transaction.
207    /// The data should be limited to a reasonable size.
208    pub async fn set_data(&self, data: Vec<u8>) -> Result<()> {
209        self.keystore()
210            .await?
211            .save(ConsumerData::from(data))
212            .await
213            .map_err(KeystoreError::wrap("saving consumer data"))?;
214        Ok(())
215    }
216
217    /// Get the data that has previously been set by [CentralContext::set_data].
218    /// This is meant to be used as a check point at the end of a transaction.
219    pub async fn get_data(&self) -> Result<Option<Vec<u8>>> {
220        match self.keystore().await?.find_unique::<ConsumerData>().await {
221            Ok(data) => Ok(Some(data.into())),
222            Err(CryptoKeystoreError::NotFound(..)) => Ok(None),
223            Err(err) => Err(KeystoreError::wrap("finding unique consumer data")(err).into()),
224        }
225    }
226}