core_crypto/transaction_context/
mod.rs1use std::sync::Arc;
5
6use async_lock::{Mutex, RwLock, RwLockWriteGuardArc};
7use core_crypto_keystore::{CryptoKeystoreError, entities::ConsumerData, traits::FetchFromDatabase as _};
8pub use error::{Error, Result};
9use openmls_traits::OpenMlsCryptoProvider as _;
10use wire_e2e_identity::pki_env::PkiEnvironment;
11
12#[cfg(feature = "proteus")]
13use crate::proteus::ProteusCentral;
14use crate::{
15 ClientId, ConversationId, CoreCrypto, CredentialFindFilters, CredentialRef, KeystoreError, MlsConversation,
16 MlsError, MlsTransport, RecursiveError, Session,
17 group_store::GroupStore,
18 mls::{self, HasSessionAndCrypto},
19 mls_provider::{Database, MlsCryptoProvider},
20};
21pub mod conversation;
22mod credential;
23pub mod e2e_identity;
24mod error;
25pub mod key_package;
26#[cfg(feature = "proteus")]
27pub mod proteus;
28#[cfg(test)]
29pub mod test_utils;
30
31#[derive(Debug, Clone)]
38pub struct TransactionContext {
39 inner: Arc<RwLock<TransactionContextInner>>,
40}
41
42#[derive(Debug, Clone)]
46enum TransactionContextInner {
47 Valid {
48 pki_environment: Arc<RwLock<Option<PkiEnvironment>>>,
49 database: Database,
50 mls_session: Arc<RwLock<Option<Session<Database>>>>,
51 mls_groups: Arc<RwLock<GroupStore<MlsConversation>>>,
52 pending_epoch_changes: Arc<Mutex<Vec<(ConversationId, u64)>>>,
53 #[cfg(feature = "proteus")]
54 proteus_central: Arc<Mutex<Option<ProteusCentral>>>,
55 },
56 Invalid,
57}
58
59impl CoreCrypto {
60 pub async fn new_transaction(&self) -> Result<TransactionContext> {
64 TransactionContext::new(
65 self.database.clone(),
66 self.pki_environment.clone(),
67 self.mls.clone(),
68 #[cfg(feature = "proteus")]
69 self.proteus.clone(),
70 )
71 .await
72 }
73}
74
75#[cfg_attr(target_os = "unknown", async_trait::async_trait(?Send))]
76#[cfg_attr(not(target_os = "unknown"), async_trait::async_trait)]
77impl HasSessionAndCrypto for TransactionContext {
78 async fn session(&self) -> crate::mls::Result<Session<Database>> {
79 self.session()
80 .await
81 .map_err(RecursiveError::transaction("getting mls client"))
82 .map_err(Into::into)
83 }
84
85 async fn crypto_provider(&self) -> crate::mls::Result<MlsCryptoProvider> {
86 self.mls_provider()
87 .await
88 .map_err(RecursiveError::transaction("getting mls provider"))
89 .map_err(Into::into)
90 }
91}
92
93impl TransactionContext {
94 async fn new(
95 keystore: Database,
96 pki_environment: Arc<RwLock<Option<PkiEnvironment>>>,
97 mls_session: Arc<RwLock<Option<Session<Database>>>>,
98 #[cfg(feature = "proteus")] proteus_central: Arc<Mutex<Option<ProteusCentral>>>,
99 ) -> Result<Self> {
100 keystore
101 .new_transaction()
102 .await
103 .map_err(MlsError::wrap("creating new transaction"))?;
104 let mls_groups = Arc::new(RwLock::new(Default::default()));
105 Ok(Self {
106 inner: Arc::new(
107 TransactionContextInner::Valid {
108 database: keystore,
109 pki_environment,
110 mls_session: mls_session.clone(),
111 mls_groups,
112 pending_epoch_changes: Default::default(),
113 #[cfg(feature = "proteus")]
114 proteus_central,
115 }
116 .into(),
117 ),
118 })
119 }
120
121 pub(crate) async fn session(&self) -> Result<Session<Database>> {
122 match &*self.inner.read().await {
123 TransactionContextInner::Valid { mls_session, .. } => mls_session.read().await.as_ref().cloned().ok_or(
124 RecursiveError::mls_client("Getting mls session from transaction context")(
125 mls::session::Error::MlsNotInitialized,
126 )
127 .into(),
128 ),
129 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
130 }
131 }
132
133 #[cfg(test)]
134 pub(crate) async fn set_session_if_exists(&self, new_session: Session<Database>) {
135 match &*self.inner.read().await {
136 TransactionContextInner::Valid { mls_session, .. } => {
137 let mut guard = mls_session.write().await;
138
139 if guard.as_ref().is_some() {
140 *guard = Some(new_session)
141 }
142 }
143 TransactionContextInner::Invalid => {}
144 }
145 }
146
147 pub(crate) async fn mls_transport(&self) -> Result<Arc<dyn MlsTransport + 'static>> {
148 match &*self.inner.read().await {
149 TransactionContextInner::Valid { mls_session, .. } => {
150 mls_session.read().await.as_ref().map(|s| s.transport.clone()).ok_or(
151 RecursiveError::mls_client("Getting mls session from transaction context")(
152 mls::session::Error::MlsNotInitialized,
153 )
154 .into(),
155 )
156 }
157
158 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
159 }
160 }
161
162 pub async fn mls_provider(&self) -> Result<MlsCryptoProvider> {
164 match &*self.inner.read().await {
165 TransactionContextInner::Valid { mls_session, .. } => mls_session
166 .read()
167 .await
168 .as_ref()
169 .map(|s| s.crypto_provider.clone())
170 .ok_or(
171 RecursiveError::mls_client("Getting mls session from transaction context")(
172 mls::session::Error::MlsNotInitialized,
173 )
174 .into(),
175 ),
176 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
177 }
178 }
179
180 pub(crate) async fn database(&self) -> Result<Database> {
181 match &*self.inner.read().await {
182 TransactionContextInner::Valid { database, .. } => Ok(database.clone()),
183 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
184 }
185 }
186
187 pub(crate) async fn pki_environment(&self) -> Result<PkiEnvironment> {
188 match &*self.inner.read().await {
189 TransactionContextInner::Valid { pki_environment, .. } => {
190 pki_environment.read().await.as_ref().map(Clone::clone).ok_or(
191 RecursiveError::transaction("Getting PKI environment from transaction context")(
192 e2e_identity::Error::PkiEnvironmentUnset,
193 )
194 .into(),
195 )
196 }
197 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
198 }
199 }
200
201 pub(crate) async fn pki_environment_option(&self) -> Result<Option<PkiEnvironment>> {
202 match &*self.inner.read().await {
203 TransactionContextInner::Valid { pki_environment, .. } => Ok(pki_environment.read().await.clone()),
204
205 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
206 }
207 }
208
209 pub(crate) async fn mls_groups(&self) -> Result<RwLockWriteGuardArc<GroupStore<MlsConversation>>> {
210 match &*self.inner.read().await {
211 TransactionContextInner::Valid { mls_groups, .. } => Ok(mls_groups.write_arc().await),
212 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
213 }
214 }
215
216 pub(crate) async fn queue_epoch_changed(&self, conversation_id: ConversationId, epoch: u64) -> Result<()> {
217 match &*self.inner.read().await {
218 TransactionContextInner::Valid {
219 pending_epoch_changes, ..
220 } => {
221 pending_epoch_changes.lock().await.push((conversation_id, epoch));
222 Ok(())
223 }
224 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
225 }
226 }
227
228 #[cfg(feature = "proteus")]
229 pub(crate) async fn proteus_central(&self) -> Result<Arc<Mutex<Option<ProteusCentral>>>> {
230 match &*self.inner.read().await {
231 TransactionContextInner::Valid { proteus_central, .. } => Ok(proteus_central.clone()),
232 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
233 }
234 }
235
236 pub async fn finish(&self) -> Result<()> {
240 let mut guard = self.inner.write().await;
241 let TransactionContextInner::Valid {
242 database,
243 pending_epoch_changes,
244 mls_session,
245 ..
246 } = &*guard
247 else {
248 return Err(Error::InvalidTransactionContext);
249 };
250
251 let commit_result = database
252 .commit_transaction()
253 .await
254 .map_err(KeystoreError::wrap("commiting transaction"))
255 .map_err(Into::into);
256
257 if let Some(session) = mls_session.read_arc().await.clone()
258 && commit_result.is_ok()
259 {
260 let mut epoch_changes = pending_epoch_changes.lock().await;
263 let epoch_changes = epoch_changes.drain(..);
264 for (conversation_id, epoch) in epoch_changes {
265 session.notify_epoch_changed(conversation_id, epoch).await;
266 }
267 }
268
269 *guard = TransactionContextInner::Invalid;
270 commit_result
271 }
272
273 pub async fn abort(&self) -> Result<()> {
277 let mut guard = self.inner.write().await;
278
279 let TransactionContextInner::Valid { database: keystore, .. } = &*guard else {
280 return Err(Error::InvalidTransactionContext);
281 };
282
283 let result = keystore
284 .rollback_transaction()
285 .await
286 .map_err(KeystoreError::wrap("rolling back transaction"))
287 .map_err(Into::into);
288
289 *guard = TransactionContextInner::Invalid;
290 result
291 }
292
293 pub async fn mls_init(&self, session_id: ClientId, transport: Arc<dyn MlsTransport>) -> Result<()> {
295 let database = self.database().await?;
296
297 let pki_env_provider = self
298 .pki_environment_option()
299 .await?
300 .map(|pki_env| pki_env.mls_pki_env_provider())
301 .unwrap_or_default();
302
303 let crypto_provider = MlsCryptoProvider::new_with_pki_env(database.clone(), pki_env_provider);
304 let session = Session::new(session_id.clone(), crypto_provider, database, transport);
305 self.set_mls_session(session).await?;
306
307 Ok(())
308 }
309
310 pub(crate) async fn set_mls_session(&self, session: Session<Database>) -> Result<()> {
312 match &*self.inner.read().await {
313 TransactionContextInner::Valid { mls_session, .. } => {
314 let mut guard = mls_session.write().await;
315 *guard = Some(session);
316 Ok(())
317 }
318 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
319 }
320 }
321
322 pub async fn client_id(&self) -> Result<ClientId> {
324 let session = self.session().await?;
325 Ok(session.id())
326 }
327
328 pub async fn random_bytes(&self, len: usize) -> Result<Vec<u8>> {
330 use openmls_traits::random::OpenMlsRand as _;
331 self.mls_provider()
332 .await?
333 .rand()
334 .random_vec(len)
335 .map_err(MlsError::wrap("generating random vector"))
336 .map_err(Into::into)
337 }
338
339 pub async fn set_data(&self, data: Vec<u8>) -> Result<()> {
343 self.database()
344 .await?
345 .save(ConsumerData::from(data))
346 .await
347 .map_err(KeystoreError::wrap("saving consumer data"))?;
348 Ok(())
349 }
350
351 pub async fn get_data(&self) -> Result<Option<Vec<u8>>> {
354 match self.database().await?.get_unique::<ConsumerData>().await {
355 Ok(maybe_data) => Ok(maybe_data.map(Into::into)),
356 Err(CryptoKeystoreError::NotFound(..)) => Ok(None),
357 Err(err) => Err(KeystoreError::wrap("finding unique consumer data")(err).into()),
358 }
359 }
360
361 pub async fn find_credentials(&self, find_filters: CredentialFindFilters<'_>) -> Result<Vec<CredentialRef>> {
365 self.session()
366 .await?
367 .find_credentials(find_filters)
368 .await
369 .map_err(RecursiveError::mls_client("finding credentials by filter"))
370 .map_err(Into::into)
371 }
372
373 pub async fn get_credentials(&self) -> Result<Vec<CredentialRef>> {
377 self.session()
378 .await?
379 .get_credentials()
380 .await
381 .map_err(RecursiveError::mls_client("getting all credentials"))
382 .map_err(Into::into)
383 }
384}