core_crypto/transaction_context/
mod.rs1use std::sync::Arc;
5
6#[cfg(feature = "proteus")]
7use async_lock::Mutex;
8use async_lock::{RwLock, RwLockWriteGuardArc};
9use core_crypto_keystore::{CryptoKeystoreError, entities::ConsumerData, traits::FetchFromDatabase as _};
10pub use error::{Error, Result};
11use openmls_traits::OpenMlsCryptoProvider as _;
12use wire_e2e_identity::pki_env::PkiEnvironment;
13
14#[cfg(feature = "proteus")]
15use crate::proteus::ProteusCentral;
16use crate::{
17 ClientId, CoreCrypto, CredentialFindFilters, CredentialRef, KeystoreError, MlsConversation, MlsError, MlsTransport,
18 RecursiveError, Session,
19 group_store::GroupStore,
20 mls::{self, HasSessionAndCrypto},
21 mls_provider::{Database, MlsCryptoProvider},
22};
23pub mod conversation;
24mod credential;
25pub mod e2e_identity;
26mod error;
27pub mod key_package;
28#[cfg(feature = "proteus")]
29pub mod proteus;
30#[cfg(test)]
31pub mod test_utils;
32
33#[derive(Debug, Clone)]
40pub struct TransactionContext {
41 inner: Arc<RwLock<TransactionContextInner>>,
42}
43
44#[derive(Debug, Clone)]
48enum TransactionContextInner {
49 Valid {
50 pki_environment: Arc<RwLock<Option<PkiEnvironment>>>,
51 database: Database,
52 mls_session: Arc<RwLock<Option<Session<Database>>>>,
53 mls_groups: Arc<RwLock<GroupStore<MlsConversation>>>,
54 #[cfg(feature = "proteus")]
55 proteus_central: Arc<Mutex<Option<ProteusCentral>>>,
56 },
57 Invalid,
58}
59
60impl CoreCrypto {
61 pub async fn new_transaction(&self) -> Result<TransactionContext> {
65 TransactionContext::new(
66 self.database.clone(),
67 self.pki_environment.clone(),
68 self.mls.clone(),
69 #[cfg(feature = "proteus")]
70 self.proteus.clone(),
71 )
72 .await
73 }
74}
75
76#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
77#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
78impl HasSessionAndCrypto for TransactionContext {
79 async fn session(&self) -> crate::mls::Result<Session<Database>> {
80 self.session()
81 .await
82 .map_err(RecursiveError::transaction("getting mls client"))
83 .map_err(Into::into)
84 }
85
86 async fn crypto_provider(&self) -> crate::mls::Result<MlsCryptoProvider> {
87 self.mls_provider()
88 .await
89 .map_err(RecursiveError::transaction("getting mls provider"))
90 .map_err(Into::into)
91 }
92}
93
94impl TransactionContext {
95 async fn new(
96 keystore: Database,
97 pki_environment: Arc<RwLock<Option<PkiEnvironment>>>,
98 mls_session: Arc<RwLock<Option<Session<Database>>>>,
99 #[cfg(feature = "proteus")] proteus_central: Arc<Mutex<Option<ProteusCentral>>>,
100 ) -> Result<Self> {
101 keystore
102 .new_transaction()
103 .await
104 .map_err(MlsError::wrap("creating new transaction"))?;
105 let mls_groups = Arc::new(RwLock::new(Default::default()));
106 Ok(Self {
107 inner: Arc::new(
108 TransactionContextInner::Valid {
109 database: keystore,
110 pki_environment,
111 mls_session: mls_session.clone(),
112 mls_groups,
113 #[cfg(feature = "proteus")]
114 proteus_central,
115 }
116 .into(),
117 ),
118 })
119 }
120
121 pub(crate) async fn session(&self) -> Result<Session<Database>> {
122 match &*self.inner.read().await {
123 TransactionContextInner::Valid { mls_session, .. } => mls_session.read().await.as_ref().cloned().ok_or(
124 RecursiveError::mls_client("Getting mls session from transaction context")(
125 mls::session::Error::MlsNotInitialized,
126 )
127 .into(),
128 ),
129 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
130 }
131 }
132
133 #[cfg(test)]
134 pub(crate) async fn set_session_if_exists(&self, new_session: Session<Database>) {
135 match &*self.inner.read().await {
136 TransactionContextInner::Valid { mls_session, .. } => {
137 let mut guard = mls_session.write().await;
138
139 if guard.as_ref().is_some() {
140 *guard = Some(new_session)
141 }
142 }
143 TransactionContextInner::Invalid => {}
144 }
145 }
146
147 pub(crate) async fn mls_transport(&self) -> Result<Arc<dyn MlsTransport + 'static>> {
148 match &*self.inner.read().await {
149 TransactionContextInner::Valid { mls_session, .. } => {
150 mls_session.read().await.as_ref().map(|s| s.transport.clone()).ok_or(
151 RecursiveError::mls_client("Getting mls session from transaction context")(
152 mls::session::Error::MlsNotInitialized,
153 )
154 .into(),
155 )
156 }
157
158 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
159 }
160 }
161
162 pub async fn mls_provider(&self) -> Result<MlsCryptoProvider> {
164 match &*self.inner.read().await {
165 TransactionContextInner::Valid { mls_session, .. } => mls_session
166 .read()
167 .await
168 .as_ref()
169 .map(|s| s.crypto_provider.clone())
170 .ok_or(
171 RecursiveError::mls_client("Getting mls session from transaction context")(
172 mls::session::Error::MlsNotInitialized,
173 )
174 .into(),
175 ),
176 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
177 }
178 }
179
180 pub(crate) async fn database(&self) -> Result<Database> {
181 match &*self.inner.read().await {
182 TransactionContextInner::Valid { database, .. } => Ok(database.clone()),
183 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
184 }
185 }
186
187 pub(crate) async fn pki_environment(&self) -> Result<PkiEnvironment> {
188 match &*self.inner.read().await {
189 TransactionContextInner::Valid { pki_environment, .. } => {
190 pki_environment.read().await.as_ref().map(Clone::clone).ok_or(
191 RecursiveError::transaction("Getting PKI environment from transaction context")(
192 e2e_identity::Error::PkiEnvironmentUnset,
193 )
194 .into(),
195 )
196 }
197 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
198 }
199 }
200
201 pub(crate) async fn pki_environment_option(&self) -> Result<Option<PkiEnvironment>> {
202 match &*self.inner.read().await {
203 TransactionContextInner::Valid { pki_environment, .. } => Ok(pki_environment.read().await.clone()),
204
205 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
206 }
207 }
208
209 pub(crate) async fn mls_groups(&self) -> Result<RwLockWriteGuardArc<GroupStore<MlsConversation>>> {
210 match &*self.inner.read().await {
211 TransactionContextInner::Valid { mls_groups, .. } => Ok(mls_groups.write_arc().await),
212 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
213 }
214 }
215
216 #[cfg(feature = "proteus")]
217 pub(crate) async fn proteus_central(&self) -> Result<Arc<Mutex<Option<ProteusCentral>>>> {
218 match &*self.inner.read().await {
219 TransactionContextInner::Valid { proteus_central, .. } => Ok(proteus_central.clone()),
220 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
221 }
222 }
223
224 pub async fn finish(&self) -> Result<()> {
228 let mut guard = self.inner.write().await;
229 let TransactionContextInner::Valid { database: keystore, .. } = &*guard else {
230 return Err(Error::InvalidTransactionContext);
231 };
232
233 let commit_result = keystore
234 .commit_transaction()
235 .await
236 .map_err(KeystoreError::wrap("commiting transaction"))
237 .map_err(Into::into);
238
239 *guard = TransactionContextInner::Invalid;
240 commit_result
241 }
242
243 pub async fn abort(&self) -> Result<()> {
247 let mut guard = self.inner.write().await;
248
249 let TransactionContextInner::Valid { database: keystore, .. } = &*guard else {
250 return Err(Error::InvalidTransactionContext);
251 };
252
253 let result = keystore
254 .rollback_transaction()
255 .await
256 .map_err(KeystoreError::wrap("rolling back transaction"))
257 .map_err(Into::into);
258
259 *guard = TransactionContextInner::Invalid;
260 result
261 }
262
263 pub async fn mls_init(&self, session_id: ClientId, transport: Arc<dyn MlsTransport>) -> Result<()> {
265 let database = self.database().await?;
266
267 let pki_env_provider = self
268 .pki_environment_option()
269 .await?
270 .map(|pki_env| pki_env.mls_pki_env_provider())
271 .unwrap_or_default();
272
273 let crypto_provider = MlsCryptoProvider::new_with_pki_env(database, pki_env_provider);
274 let database = self.database().await?;
275 let session = Session::new(session_id.clone(), crypto_provider, database, transport);
276 self.set_mls_session(session).await?;
277
278 Ok(())
279 }
280
281 pub(crate) async fn set_mls_session(&self, session: Session<Database>) -> Result<()> {
283 match &*self.inner.read().await {
284 TransactionContextInner::Valid { mls_session, .. } => {
285 let mut guard = mls_session.write().await;
286 *guard = Some(session);
287 Ok(())
288 }
289 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
290 }
291 }
292
293 pub async fn client_id(&self) -> Result<ClientId> {
295 let session = self.session().await?;
296 Ok(session.id())
297 }
298
299 pub async fn random_bytes(&self, len: usize) -> Result<Vec<u8>> {
301 use openmls_traits::random::OpenMlsRand as _;
302 self.mls_provider()
303 .await?
304 .rand()
305 .random_vec(len)
306 .map_err(MlsError::wrap("generating random vector"))
307 .map_err(Into::into)
308 }
309
310 pub async fn set_data(&self, data: Vec<u8>) -> Result<()> {
314 self.database()
315 .await?
316 .save(ConsumerData::from(data))
317 .await
318 .map_err(KeystoreError::wrap("saving consumer data"))?;
319 Ok(())
320 }
321
322 pub async fn get_data(&self) -> Result<Option<Vec<u8>>> {
325 match self.database().await?.get_unique::<ConsumerData>().await {
326 Ok(maybe_data) => Ok(maybe_data.map(Into::into)),
327 Err(CryptoKeystoreError::NotFound(..)) => Ok(None),
328 Err(err) => Err(KeystoreError::wrap("finding unique consumer data")(err).into()),
329 }
330 }
331
332 pub async fn find_credentials(&self, find_filters: CredentialFindFilters<'_>) -> Result<Vec<CredentialRef>> {
336 self.session()
337 .await?
338 .find_credentials(find_filters)
339 .await
340 .map_err(RecursiveError::mls_client("finding credentials by filter"))
341 .map_err(Into::into)
342 }
343
344 pub async fn get_credentials(&self) -> Result<Vec<CredentialRef>> {
348 self.session()
349 .await?
350 .get_credentials()
351 .await
352 .map_err(RecursiveError::mls_client("getting all credentials"))
353 .map_err(Into::into)
354 }
355}