core_crypto/transaction_context/
mod.rs1use std::sync::Arc;
5
6use async_lock::{Mutex, MutexGuardArc, RwLock};
7use core_crypto_keystore::{CryptoKeystoreError, entities::ConsumerData, traits::FetchFromDatabase as _};
8pub use error::{Error, Result};
9use openmls_traits::OpenMlsCryptoProvider as _;
10use wire_e2e_identity::pki_env::PkiEnvironment;
11
12use crate::{
13 ClientId, ConversationId, CoreCrypto, KeystoreError, MlsTransport, OpenMlsError, RecursiveError, Session,
14 mls::{self, conversation_cache::ConversationCache},
15 mls_provider::{CryptoProvider, Database},
16};
17pub mod conversation;
18mod credential;
19pub mod e2e_identity;
20mod error;
21pub mod key_package;
22#[cfg(feature = "proteus")]
23pub mod proteus;
24#[cfg(test)]
25pub mod test_utils;
26
27#[derive(Debug, Clone)]
34pub struct TransactionContext {
35 inner: Arc<RwLock<TransactionContextInner>>,
36}
37
38#[derive(Debug, Clone)]
42enum TransactionContextInner {
43 Valid {
44 core_crypto: Arc<CoreCrypto>,
45 pending_epoch_changes: Arc<Mutex<Vec<(ConversationId, u64)>>>,
46 },
47 Invalid,
48}
49
50impl CoreCrypto {
51 pub async fn new_transaction(self: &Arc<Self>) -> Result<TransactionContext> {
55 TransactionContext::new(self.clone()).await
56 }
57}
58
59impl TransactionContext {
60 async fn new(core_crypto: Arc<CoreCrypto>) -> Result<Self> {
61 core_crypto
62 .database
63 .new_transaction()
64 .await
65 .map_err(OpenMlsError::wrap("creating new transaction"))?;
66 Ok(Self {
67 inner: Arc::new(
68 TransactionContextInner::Valid {
69 core_crypto,
70 pending_epoch_changes: Default::default(),
71 }
72 .into(),
73 ),
74 })
75 }
76
77 pub(crate) async fn session(&self) -> Result<Session> {
78 let TransactionContextInner::Valid { core_crypto, .. } = &*self.inner.read().await else {
79 return Err(Error::InvalidTransactionContext);
80 };
81 core_crypto.mls.read().await.as_ref().cloned().ok_or(
82 RecursiveError::mls_client("Getting mls session from transaction context")(
83 mls::session::Error::MlsNotInitialized,
84 )
85 .into(),
86 )
87 }
88
89 #[cfg(test)]
90 pub(crate) async fn set_session_if_exists(&self, new_session: Session) {
91 let TransactionContextInner::Valid { core_crypto, .. } = &*self.inner.read().await else {
92 return;
93 };
94
95 let mut guard = core_crypto.mls.write().await;
96 if guard.as_ref().is_some() {
97 *guard = Some(new_session)
98 }
99 }
100
101 pub(crate) async fn mls_transport(&self) -> Result<Arc<dyn MlsTransport + 'static>> {
102 match &*self.inner.read().await {
103 TransactionContextInner::Valid { core_crypto, .. } => core_crypto
104 .mls
105 .read()
106 .await
107 .as_ref()
108 .map(|s| s.transport.clone())
109 .ok_or(
110 RecursiveError::mls_client("Getting mls session from transaction context")(
111 mls::session::Error::MlsNotInitialized,
112 )
113 .into(),
114 ),
115
116 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
117 }
118 }
119
120 pub async fn crypto_provider(&self) -> Result<CryptoProvider> {
122 match &*self.inner.read().await {
123 TransactionContextInner::Valid { core_crypto, .. } => core_crypto
124 .mls
125 .read()
126 .await
127 .as_ref()
128 .map(|s| s.crypto_provider.clone())
129 .ok_or(
130 RecursiveError::mls_client("Getting mls session from transaction context")(
131 mls::session::Error::MlsNotInitialized,
132 )
133 .into(),
134 ),
135 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
136 }
137 }
138
139 pub(crate) async fn database(&self) -> Result<Database> {
140 match &*self.inner.read().await {
141 TransactionContextInner::Valid { core_crypto, .. } => Ok(core_crypto.database.clone()),
142 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
143 }
144 }
145
146 pub(crate) async fn pki_environment(&self) -> Result<Arc<PkiEnvironment>> {
147 match &*self.inner.read().await {
148 TransactionContextInner::Valid { core_crypto, .. } => core_crypto
149 .pki_environment
150 .read()
151 .await
152 .as_ref()
153 .map(Clone::clone)
154 .ok_or(
155 RecursiveError::transaction("getting PKI environment from transaction context")(
156 e2e_identity::Error::PkiEnvironmentUnset,
157 )
158 .into(),
159 ),
160 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
161 }
162 }
163
164 pub(crate) async fn mls_groups(&self) -> Result<MutexGuardArc<ConversationCache>> {
165 let guard = self.inner.read().await;
166 let TransactionContextInner::Valid { core_crypto, .. } = &*guard else {
167 return Err(Error::InvalidTransactionContext);
168 };
169 let cache = core_crypto
170 .mls
171 .read()
172 .await
173 .as_ref()
174 .map(|session| session.conversation_cache.clone())
175 .ok_or_else(|| {
176 RecursiveError::mls_client("getting mls session from transaction context")(
177 mls::session::Error::MlsNotInitialized,
178 )
179 })?;
180
181 Ok(cache.lock_arc().await)
182 }
183
184 pub(crate) async fn queue_epoch_changed(&self, conversation_id: ConversationId, epoch: u64) -> Result<()> {
185 match &*self.inner.read().await {
186 TransactionContextInner::Valid {
187 pending_epoch_changes, ..
188 } => {
189 pending_epoch_changes.lock().await.push((conversation_id, epoch));
190 Ok(())
191 }
192 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
193 }
194 }
195
196 pub async fn finish(&self) -> Result<()> {
200 let mut guard = self.inner.write().await;
201 let TransactionContextInner::Valid {
202 core_crypto,
203 pending_epoch_changes,
204 ..
205 } = &*guard
206 else {
207 return Err(Error::InvalidTransactionContext);
208 };
209
210 let commit_result = core_crypto
211 .database
212 .commit_transaction()
213 .await
214 .map_err(KeystoreError::wrap("commiting transaction"))
215 .map_err(Into::into);
216
217 if let Some(session) = core_crypto.mls.read().await.as_ref() {
218 if commit_result.is_ok() {
219 let mut epoch_changes = pending_epoch_changes.lock().await;
222 for (conversation_id, epoch) in epoch_changes.drain(..) {
223 session.notify_epoch_changed(conversation_id, epoch).await;
224 }
225 } else {
226 session.conversation_cache.lock().await.clear();
230 }
231 }
232
233 *guard = TransactionContextInner::Invalid;
234 commit_result
235 }
236
237 pub async fn abort(&self) -> Result<()> {
241 let mut guard = self.inner.write().await;
242
243 let TransactionContextInner::Valid { core_crypto, .. } = &*guard else {
244 return Err(Error::InvalidTransactionContext);
245 };
246
247 if let Some(session) = core_crypto.mls.read().await.as_ref() {
250 session.conversation_cache.lock().await.clear();
251 }
252
253 let result = core_crypto
254 .database
255 .rollback_transaction()
256 .await
257 .map_err(KeystoreError::wrap("rolling back transaction"))
258 .map_err(Into::into);
259
260 *guard = TransactionContextInner::Invalid;
261 result
262 }
263
264 pub async fn mls_init(&self, session_id: ClientId, transport: Arc<dyn MlsTransport>) -> Result<()> {
266 let database = self.database().await?;
267 let pki_env = self.pki_environment().await.ok();
268 let crypto_provider = CryptoProvider::new_with_pki_env(database.clone(), pki_env);
269 let session = Session::new(session_id.clone(), crypto_provider, database.into(), transport);
270 self.set_mls_session(session).await?;
271
272 Ok(())
273 }
274
275 pub(crate) async fn set_mls_session(&self, session: Session) -> Result<()> {
277 match &*self.inner.read().await {
278 TransactionContextInner::Valid { core_crypto, .. } => {
279 let mut guard = core_crypto.mls.write().await;
280 *guard = Some(session);
281 Ok(())
282 }
283 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
284 }
285 }
286
287 pub async fn client_id(&self) -> Result<ClientId> {
289 let session = self.session().await?;
290 Ok(session.id())
291 }
292
293 pub async fn random_bytes(&self, len: usize) -> Result<Vec<u8>> {
295 use openmls_traits::random::OpenMlsRand as _;
296 self.crypto_provider()
297 .await?
298 .rand()
299 .random_vec(len)
300 .map_err(OpenMlsError::wrap("generating random vector"))
301 .map_err(Into::into)
302 }
303
304 pub async fn set_data(&self, data: Vec<u8>) -> Result<()> {
308 self.database()
309 .await?
310 .save(ConsumerData::from(data))
311 .await
312 .map_err(KeystoreError::wrap("saving consumer data"))?;
313 Ok(())
314 }
315
316 pub async fn get_data(&self) -> Result<Option<Vec<u8>>> {
319 match self.database().await?.get_unique::<ConsumerData>().await {
320 Ok(maybe_data) => Ok(maybe_data.map(Into::into)),
321 Err(CryptoKeystoreError::NotFound(..)) => Ok(None),
322 Err(err) => Err(KeystoreError::wrap("finding unique consumer data")(err).into()),
323 }
324 }
325}