core_crypto/transaction_context/
mod.rs1use crate::mls::HasSessionAndCrypto;
5#[cfg(feature = "proteus")]
6use crate::proteus::ProteusCentral;
7use crate::{
8 CoreCrypto, KeystoreError, MlsError, MlsTransport, RecursiveError,
9 group_store::GroupStore,
10 prelude::{MlsConversation, Session},
11};
12use async_lock::{Mutex, RwLock, RwLockReadGuardArc, RwLockWriteGuardArc};
13use core_crypto_keystore::{CryptoKeystoreError, connection::FetchFromDatabase, entities::ConsumerData};
14pub use error::{Error, Result};
15use mls_crypto_provider::{CryptoKeystore, MlsCryptoProvider};
16use std::{ops::Deref, sync::Arc};
17pub mod conversation;
18pub mod e2e_identity;
19mod error;
20#[cfg(test)]
21pub mod test_utils;
22
23#[derive(Debug, Clone)]
30pub struct TransactionContext {
31 inner: Arc<RwLock<TransactionContextInner>>,
32}
33
34#[derive(Debug, Clone)]
38enum TransactionContextInner {
39 Valid {
40 provider: MlsCryptoProvider,
41 transport: Arc<RwLock<Option<Arc<dyn MlsTransport + 'static>>>>,
42 mls_client: Session,
43 mls_groups: Arc<RwLock<GroupStore<MlsConversation>>>,
44 #[cfg(feature = "proteus")]
45 proteus_central: Arc<Mutex<Option<ProteusCentral>>>,
46 },
47 Invalid,
48}
49
50impl CoreCrypto {
51 pub async fn new_transaction(&self) -> Result<TransactionContext> {
55 TransactionContext::new(
56 &self.mls,
57 #[cfg(feature = "proteus")]
58 self.proteus.clone(),
59 )
60 .await
61 }
62}
63
64#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
65#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
66impl HasSessionAndCrypto for TransactionContext {
67 async fn session(&self) -> crate::mls::Result<Session> {
68 self.session()
69 .await
70 .map_err(RecursiveError::transaction("getting mls client"))
71 .map_err(Into::into)
72 }
73
74 async fn crypto_provider(&self) -> crate::mls::Result<MlsCryptoProvider> {
75 self.mls_provider()
76 .await
77 .map_err(RecursiveError::transaction("getting mls provider"))
78 .map_err(Into::into)
79 }
80}
81
82impl TransactionContext {
83 async fn new(
84 client: &Session,
85 #[cfg(feature = "proteus")] proteus_central: Arc<Mutex<Option<ProteusCentral>>>,
86 ) -> Result<Self> {
87 client
88 .crypto_provider
89 .new_transaction()
90 .await
91 .map_err(MlsError::wrap("creating new transaction"))?;
92 let mls_groups = Arc::new(RwLock::new(Default::default()));
93 let callbacks = client.transport.clone();
94 let mls_client = client.clone();
95 Ok(Self {
96 inner: Arc::new(
97 TransactionContextInner::Valid {
98 mls_client,
99 transport: callbacks,
100 provider: client.crypto_provider.clone(),
101 mls_groups,
102 #[cfg(feature = "proteus")]
103 proteus_central,
104 }
105 .into(),
106 ),
107 })
108 }
109
110 pub(crate) async fn session(&self) -> Result<Session> {
111 match self.inner.read().await.deref() {
112 TransactionContextInner::Valid { mls_client, .. } => Ok(mls_client.clone()),
113 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
114 }
115 }
116
117 pub(crate) async fn mls_transport(&self) -> Result<RwLockReadGuardArc<Option<Arc<dyn MlsTransport + 'static>>>> {
118 match self.inner.read().await.deref() {
119 TransactionContextInner::Valid {
120 transport: callbacks, ..
121 } => Ok(callbacks.read_arc().await),
122 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
123 }
124 }
125
126 #[cfg(test)]
127 pub(crate) async fn set_transport_callbacks(
128 &self,
129 callbacks: Option<Arc<dyn MlsTransport + 'static>>,
130 ) -> Result<()> {
131 match self.inner.read().await.deref() {
132 TransactionContextInner::Valid { transport: cbs, .. } => {
133 *cbs.write_arc().await = callbacks;
134 Ok(())
135 }
136 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
137 }
138 }
139
140 pub async fn mls_provider(&self) -> Result<MlsCryptoProvider> {
142 match self.inner.read().await.deref() {
143 TransactionContextInner::Valid { provider, .. } => Ok(provider.clone()),
144 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
145 }
146 }
147
148 pub(crate) async fn keystore(&self) -> Result<CryptoKeystore> {
149 match self.inner.read().await.deref() {
150 TransactionContextInner::Valid { provider, .. } => Ok(provider.keystore()),
151 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
152 }
153 }
154
155 pub(crate) async fn mls_groups(&self) -> Result<RwLockWriteGuardArc<GroupStore<MlsConversation>>> {
156 match self.inner.read().await.deref() {
157 TransactionContextInner::Valid { mls_groups, .. } => Ok(mls_groups.write_arc().await),
158 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
159 }
160 }
161
162 #[cfg(feature = "proteus")]
163 pub(crate) async fn proteus_central(&self) -> Result<Arc<Mutex<Option<ProteusCentral>>>> {
164 match self.inner.read().await.deref() {
165 TransactionContextInner::Valid { proteus_central, .. } => Ok(proteus_central.clone()),
166 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
167 }
168 }
169
170 pub async fn finish(&self) -> Result<()> {
174 let mut guard = self.inner.write().await;
175 let TransactionContextInner::Valid { provider, .. } = guard.deref() else {
176 return Err(Error::InvalidTransactionContext);
177 };
178
179 let commit_result = provider
180 .keystore()
181 .commit_transaction()
182 .await
183 .map_err(KeystoreError::wrap("commiting transaction"))
184 .map_err(Into::into);
185
186 *guard = TransactionContextInner::Invalid;
187 commit_result
188 }
189
190 pub async fn abort(&self) -> Result<()> {
194 let mut guard = self.inner.write().await;
195
196 let TransactionContextInner::Valid { provider, .. } = guard.deref() else {
197 return Err(Error::InvalidTransactionContext);
198 };
199
200 let result = provider
201 .keystore()
202 .rollback_transaction()
203 .await
204 .map_err(KeystoreError::wrap("rolling back transaction"))
205 .map_err(Into::into);
206
207 *guard = TransactionContextInner::Invalid;
208 result
209 }
210
211 pub async fn set_data(&self, data: Vec<u8>) -> Result<()> {
215 self.keystore()
216 .await?
217 .save(ConsumerData::from(data))
218 .await
219 .map_err(KeystoreError::wrap("saving consumer data"))?;
220 Ok(())
221 }
222
223 pub async fn get_data(&self) -> Result<Option<Vec<u8>>> {
226 match self.keystore().await?.find_unique::<ConsumerData>().await {
227 Ok(data) => Ok(Some(data.into())),
228 Err(CryptoKeystoreError::NotFound(..)) => Ok(None),
229 Err(err) => Err(KeystoreError::wrap("finding unique consumer data")(err).into()),
230 }
231 }
232}