core_crypto/
context.rs

1//! This module contains the primitives to enable transactional support on a higher level within the
2//! [MlsCentral]. All mutating operations need to be done through a [CentralContext].
3
4use crate::mls::MlsCentral;
5#[cfg(feature = "proteus")]
6use crate::proteus::ProteusCentral;
7use crate::{
8    group_store::GroupStore,
9    prelude::{Client, MlsConversation},
10    CoreCrypto, CoreCryptoCallbacks, CryptoError, CryptoResult,
11};
12use async_lock::{Mutex, RwLock, RwLockReadGuardArc, RwLockWriteGuardArc};
13use core_crypto_keystore::connection::FetchFromDatabase;
14use core_crypto_keystore::entities::ConsumerData;
15use core_crypto_keystore::CryptoKeystoreError;
16use mls_crypto_provider::{CryptoKeystore, MlsCryptoProvider};
17use std::{ops::Deref, sync::Arc};
18
19/// This struct provides transactional support for Core Crypto.
20///
21/// This is struct provides mutable access to the internals of Core Crypto. Every operation that
22/// causes data to be persisted needs to be done through this struct. This struct will buffer all
23/// operations in memory and when [CentralContext::finish] is called, it will persist the data into
24/// the keystore.
25#[derive(Debug, Clone)]
26pub struct CentralContext {
27    state: Arc<RwLock<ContextState>>,
28}
29
30/// Due to uniffi's design, we can't force the context to be dropped after the transaction is
31/// committed. To work around that we switch the value to `Invalid` when the context is finished
32/// and throw errors if something is called
33#[derive(Debug, Clone)]
34enum ContextState {
35    Valid {
36        provider: MlsCryptoProvider,
37        callbacks: Arc<RwLock<Option<std::sync::Arc<dyn CoreCryptoCallbacks + 'static>>>>,
38        mls_client: Client,
39        mls_groups: Arc<RwLock<GroupStore<MlsConversation>>>,
40        #[cfg(feature = "proteus")]
41        proteus_central: Arc<Mutex<Option<ProteusCentral>>>,
42    },
43    Invalid,
44}
45
46impl CoreCrypto {
47    /// Creates a new transaction. All operations that persist data will be
48    /// buffered in memory and when [CentralContext::finish] is called, the data will be persisted
49    /// in a single database transaction.
50    pub async fn new_transaction(&self) -> CryptoResult<CentralContext> {
51        CentralContext::new(
52            &self.mls,
53            #[cfg(feature = "proteus")]
54            self.proteus.clone(),
55        )
56        .await
57    }
58}
59
60impl CentralContext {
61    async fn new(
62        mls_central: &MlsCentral,
63        #[cfg(feature = "proteus")] proteus_central: Arc<Mutex<Option<ProteusCentral>>>,
64    ) -> CryptoResult<Self> {
65        mls_central.mls_backend.new_transaction().await?;
66        let mls_groups = Arc::new(RwLock::new(Default::default()));
67        let callbacks = mls_central.callbacks.clone();
68        let mls_client = mls_central.mls_client.clone();
69        Ok(Self {
70            state: Arc::new(
71                ContextState::Valid {
72                    mls_client,
73                    callbacks,
74                    provider: mls_central.mls_backend.clone(),
75                    mls_groups,
76                    #[cfg(feature = "proteus")]
77                    proteus_central,
78                }
79                .into(),
80            ),
81        })
82    }
83
84    pub(crate) async fn mls_client(&self) -> CryptoResult<Client> {
85        match self.state.read().await.deref() {
86            ContextState::Valid { mls_client, .. } => Ok(mls_client.clone()),
87            ContextState::Invalid => Err(CryptoError::InvalidContext),
88        }
89    }
90
91    pub(crate) async fn callbacks(
92        &self,
93    ) -> CryptoResult<RwLockReadGuardArc<Option<Arc<dyn CoreCryptoCallbacks + 'static>>>> {
94        match self.state.read().await.deref() {
95            ContextState::Valid { callbacks, .. } => Ok(callbacks.read_arc().await),
96            ContextState::Invalid => Err(CryptoError::InvalidContext),
97        }
98    }
99
100    #[cfg(test)]
101    pub(crate) async fn set_callbacks(
102        &self,
103        callbacks: Option<Arc<dyn CoreCryptoCallbacks + 'static>>,
104    ) -> CryptoResult<()> {
105        match self.state.read().await.deref() {
106            ContextState::Valid { callbacks: cbs, .. } => {
107                *cbs.write_arc().await = callbacks;
108                Ok(())
109            }
110            ContextState::Invalid => Err(CryptoError::InvalidContext),
111        }
112    }
113
114    /// Clones all references that the [MlsCryptoProvider] comprises.
115    pub async fn mls_provider(&self) -> CryptoResult<MlsCryptoProvider> {
116        match self.state.read().await.deref() {
117            ContextState::Valid { provider, .. } => Ok(provider.clone()),
118            ContextState::Invalid => Err(CryptoError::InvalidContext),
119        }
120    }
121
122    pub(crate) async fn keystore(&self) -> CryptoResult<CryptoKeystore> {
123        match self.state.read().await.deref() {
124            ContextState::Valid { provider, .. } => Ok(provider.keystore()),
125            ContextState::Invalid => Err(CryptoError::InvalidContext),
126        }
127    }
128
129    pub(crate) async fn mls_groups(&self) -> CryptoResult<RwLockWriteGuardArc<GroupStore<MlsConversation>>> {
130        match self.state.read().await.deref() {
131            ContextState::Valid { mls_groups, .. } => Ok(mls_groups.write_arc().await),
132            ContextState::Invalid => Err(CryptoError::InvalidContext),
133        }
134    }
135
136    #[cfg(feature = "proteus")]
137    pub(crate) async fn proteus_central(&self) -> CryptoResult<Arc<Mutex<Option<ProteusCentral>>>> {
138        match self.state.read().await.deref() {
139            ContextState::Valid { proteus_central, .. } => Ok(proteus_central.clone()),
140            ContextState::Invalid => Err(CryptoError::InvalidContext),
141        }
142    }
143
144    /// Commits the transaction, meaning it takes all the enqueued operations and persist them into
145    /// the keystore. After that the internal state is switched to invalid, causing errors if
146    /// something is called from this object.
147    pub async fn finish(&self) -> CryptoResult<()> {
148        let mut guard = self.state.write().await;
149        let commit_result = match guard.deref() {
150            ContextState::Valid { provider, .. } => provider.keystore().commit_transaction().await,
151            ContextState::Invalid => return Err(CryptoError::InvalidContext),
152        };
153        *guard = ContextState::Invalid;
154        commit_result.map_err(Into::into)
155    }
156
157    /// Aborts the transaction, meaning it discards all the enqueued operations.
158    /// After that the internal state is switched to invalid, causing errors if
159    /// something is called from this object.
160    pub async fn abort(&self) -> CryptoResult<()> {
161        let mut guard = self.state.write().await;
162        let rollback_result = match guard.deref() {
163            ContextState::Valid { provider, .. } => provider.keystore().rollback_transaction().await,
164            ContextState::Invalid => return Err(CryptoError::InvalidContext),
165        };
166        *guard = ContextState::Invalid;
167        rollback_result.map_err(Into::into)
168    }
169
170    /// Set arbitrary data to be retrieved by [CentralContext::get_data].
171    /// This is meant to be used as a check point at the end of a transaction.
172    /// The data should be limited to a reasonable size.
173    pub async fn set_data(&self, data: Vec<u8>) -> CryptoResult<()> {
174        self.keystore().await?.save(ConsumerData::from(data)).await?;
175        Ok(())
176    }
177
178    /// Get the data that has previously been set by [CentralContext::set_data].
179    /// This is meant to be used as a check point at the end of a transaction.
180    pub async fn get_data(&self) -> CryptoResult<Option<Vec<u8>>> {
181        match self.keystore().await?.find_unique::<ConsumerData>().await {
182            Ok(data) => Ok(Some(data.into())),
183            Err(CryptoKeystoreError::NotFound(..)) => Ok(None),
184            Err(err) => Err(err.into()),
185        }
186    }
187}