core_crypto/transaction_context/
mod.rs1use std::sync::Arc;
5
6#[cfg(feature = "proteus")]
7use async_lock::Mutex;
8use async_lock::{RwLock, RwLockReadGuardArc, RwLockWriteGuardArc};
9use core_crypto_keystore::{CryptoKeystoreError, connection::FetchFromDatabase, entities::ConsumerData};
10pub use error::{Error, Result};
11use mls_crypto_provider::{Database, MlsCryptoProvider};
12use openmls_traits::OpenMlsCryptoProvider as _;
13
14#[cfg(feature = "proteus")]
15use crate::proteus::ProteusCentral;
16use crate::{
17 Ciphersuite, ClientId, ClientIdentifier, CoreCrypto, Credential, CredentialFindFilters, CredentialRef,
18 CredentialType, KeystoreError, MlsConversation, MlsError, MlsTransport, RecursiveError, Session,
19 group_store::GroupStore, mls::HasSessionAndCrypto,
20};
21pub mod conversation;
22pub mod e2e_identity;
23mod error;
24pub mod key_package;
25#[cfg(feature = "proteus")]
26pub mod proteus;
27#[cfg(test)]
28pub mod test_utils;
29
30#[derive(Debug, Clone)]
37pub struct TransactionContext {
38 inner: Arc<RwLock<TransactionContextInner>>,
39}
40
41#[derive(Debug, Clone)]
45enum TransactionContextInner {
46 Valid {
47 provider: MlsCryptoProvider,
48 transport: Arc<RwLock<Option<Arc<dyn MlsTransport + 'static>>>>,
49 mls_client: Session,
50 mls_groups: Arc<RwLock<GroupStore<MlsConversation>>>,
51 #[cfg(feature = "proteus")]
52 proteus_central: Arc<Mutex<Option<ProteusCentral>>>,
53 },
54 Invalid,
55}
56
57impl CoreCrypto {
58 pub async fn new_transaction(&self) -> Result<TransactionContext> {
62 TransactionContext::new(
63 &self.mls,
64 #[cfg(feature = "proteus")]
65 self.proteus.clone(),
66 )
67 .await
68 }
69}
70
71#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
72#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
73impl HasSessionAndCrypto for TransactionContext {
74 async fn session(&self) -> crate::mls::Result<Session> {
75 self.session()
76 .await
77 .map_err(RecursiveError::transaction("getting mls client"))
78 .map_err(Into::into)
79 }
80
81 async fn crypto_provider(&self) -> crate::mls::Result<MlsCryptoProvider> {
82 self.mls_provider()
83 .await
84 .map_err(RecursiveError::transaction("getting mls provider"))
85 .map_err(Into::into)
86 }
87}
88
89impl TransactionContext {
90 async fn new(
91 client: &Session,
92 #[cfg(feature = "proteus")] proteus_central: Arc<Mutex<Option<ProteusCentral>>>,
93 ) -> Result<Self> {
94 client
95 .crypto_provider
96 .new_transaction()
97 .await
98 .map_err(MlsError::wrap("creating new transaction"))?;
99 let mls_groups = Arc::new(RwLock::new(Default::default()));
100 let callbacks = client.transport.clone();
101 let mls_client = client.clone();
102 Ok(Self {
103 inner: Arc::new(
104 TransactionContextInner::Valid {
105 mls_client,
106 transport: callbacks,
107 provider: client.crypto_provider.clone(),
108 mls_groups,
109 #[cfg(feature = "proteus")]
110 proteus_central,
111 }
112 .into(),
113 ),
114 })
115 }
116
117 pub(crate) async fn session(&self) -> Result<Session> {
118 match &*self.inner.read().await {
119 TransactionContextInner::Valid { mls_client, .. } => Ok(mls_client.clone()),
120 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
121 }
122 }
123
124 pub(crate) async fn mls_transport(&self) -> Result<RwLockReadGuardArc<Option<Arc<dyn MlsTransport + 'static>>>> {
125 match &*self.inner.read().await {
126 TransactionContextInner::Valid {
127 transport: callbacks, ..
128 } => Ok(callbacks.read_arc().await),
129 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
130 }
131 }
132
/// Test-only helper that replaces the stored transport with `callbacks`.
///
/// Passing `None` clears the transport.
///
/// # Errors
///
/// Returns `Error::InvalidTransactionContext` if the context is no longer valid.
#[cfg(test)]
pub(crate) async fn set_transport_callbacks(
    &self,
    callbacks: Option<Arc<dyn MlsTransport + 'static>>,
) -> Result<()> {
    let state = self.inner.read().await;
    let TransactionContextInner::Valid { transport, .. } = &*state else {
        return Err(Error::InvalidTransactionContext);
    };

    let mut slot = transport.write_arc().await;
    *slot = callbacks;
    Ok(())
}
146
147 pub async fn mls_provider(&self) -> Result<MlsCryptoProvider> {
149 match &*self.inner.read().await {
150 TransactionContextInner::Valid { provider, .. } => Ok(provider.clone()),
151 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
152 }
153 }
154
155 pub(crate) async fn keystore(&self) -> Result<Database> {
156 match &*self.inner.read().await {
157 TransactionContextInner::Valid { provider, .. } => Ok(provider.keystore()),
158 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
159 }
160 }
161
162 pub(crate) async fn mls_groups(&self) -> Result<RwLockWriteGuardArc<GroupStore<MlsConversation>>> {
163 match &*self.inner.read().await {
164 TransactionContextInner::Valid { mls_groups, .. } => Ok(mls_groups.write_arc().await),
165 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
166 }
167 }
168
/// Returns another handle to the shared, optional Proteus state.
///
/// Only the `Arc` is cloned; the mutex is not locked here.
///
/// # Errors
///
/// Returns `Error::InvalidTransactionContext` if the context is no longer valid.
#[cfg(feature = "proteus")]
pub(crate) async fn proteus_central(&self) -> Result<Arc<Mutex<Option<ProteusCentral>>>> {
    let state = self.inner.read().await;
    let TransactionContextInner::Valid { proteus_central, .. } = &*state else {
        return Err(Error::InvalidTransactionContext);
    };
    Ok(Arc::clone(proteus_central))
}
176
177 pub async fn finish(&self) -> Result<()> {
181 let mut guard = self.inner.write().await;
182 let TransactionContextInner::Valid { provider, .. } = &*guard else {
183 return Err(Error::InvalidTransactionContext);
184 };
185
186 let commit_result = provider
187 .keystore()
188 .commit_transaction()
189 .await
190 .map_err(KeystoreError::wrap("commiting transaction"))
191 .map_err(Into::into);
192
193 *guard = TransactionContextInner::Invalid;
194 commit_result
195 }
196
197 pub async fn abort(&self) -> Result<()> {
201 let mut guard = self.inner.write().await;
202
203 let TransactionContextInner::Valid { provider, .. } = &*guard else {
204 return Err(Error::InvalidTransactionContext);
205 };
206
207 let result = provider
208 .keystore()
209 .rollback_transaction()
210 .await
211 .map_err(KeystoreError::wrap("rolling back transaction"))
212 .map_err(Into::into);
213
214 *guard = TransactionContextInner::Invalid;
215 result
216 }
217
218 pub async fn mls_init(&self, identifier: ClientIdentifier, ciphersuites: &[Ciphersuite]) -> Result<()> {
220 let mls_client = self.session().await?;
221 mls_client
222 .init(
223 identifier,
224 &ciphersuites
225 .iter()
226 .map(|ciphersuite| ciphersuite.signature_algorithm())
227 .collect::<Vec<_>>(),
228 )
229 .await
230 .map_err(RecursiveError::mls_client("initializing mls client"))?;
231
232 if mls_client.is_e2ei_capable().await {
233 let client_id = mls_client
234 .id()
235 .await
236 .map_err(RecursiveError::mls_client("getting client id"))?;
237 log::trace!(client_id:% = client_id; "Initializing PKI environment");
238 self.init_pki_env().await?;
239 }
240
241 Ok(())
242 }
243
244 pub async fn client_public_key(
246 &self,
247 ciphersuite: Ciphersuite,
248 credential_type: CredentialType,
249 ) -> Result<Vec<u8>> {
250 let cb = self
251 .session()
252 .await?
253 .find_most_recent_credential(ciphersuite.signature_algorithm(), credential_type)
254 .await
255 .map_err(RecursiveError::mls_client("finding most recent credential"))?;
256 Ok(cb.signature_key_pair.to_public_vec())
257 }
258
259 pub async fn client_id(&self) -> Result<ClientId> {
261 self.session()
262 .await?
263 .id()
264 .await
265 .map_err(RecursiveError::mls_client("getting client id"))
266 .map_err(Into::into)
267 }
268
269 pub async fn random_bytes(&self, len: usize) -> Result<Vec<u8>> {
271 use openmls_traits::random::OpenMlsRand as _;
272 self.mls_provider()
273 .await?
274 .rand()
275 .random_vec(len)
276 .map_err(MlsError::wrap("generating random vector"))
277 .map_err(Into::into)
278 }
279
280 pub async fn set_data(&self, data: Vec<u8>) -> Result<()> {
284 self.keystore()
285 .await?
286 .save(ConsumerData::from(data))
287 .await
288 .map_err(KeystoreError::wrap("saving consumer data"))?;
289 Ok(())
290 }
291
292 pub async fn get_data(&self) -> Result<Option<Vec<u8>>> {
295 match self.keystore().await?.find_unique::<ConsumerData>().await {
296 Ok(data) => Ok(Some(data.into())),
297 Err(CryptoKeystoreError::NotFound(..)) => Ok(None),
298 Err(err) => Err(KeystoreError::wrap("finding unique consumer data")(err).into()),
299 }
300 }
301
302 pub async fn add_credential(&self, credential: Credential) -> Result<CredentialRef> {
306 self.session()
307 .await?
308 .add_credential(credential)
309 .await
310 .map_err(RecursiveError::mls_client("adding credential to session"))
311 .map_err(Into::into)
312 }
313
314 pub async fn remove_credential(&self, credential_ref: &CredentialRef) -> Result<()> {
318 self.session()
319 .await?
320 .remove_credential(credential_ref)
321 .await
322 .map_err(RecursiveError::mls_client("removing credential from session"))
323 .map_err(Into::into)
324 }
325
326 pub async fn find_credentials(&self, find_filters: CredentialFindFilters<'_>) -> Result<Vec<CredentialRef>> {
330 self.session()
331 .await?
332 .find_credentials(find_filters)
333 .await
334 .map_err(RecursiveError::mls_client("finding credentials by filter"))
335 .map_err(Into::into)
336 }
337
338 pub async fn get_credentials(&self) -> Result<Vec<CredentialRef>> {
342 self.session()
343 .await?
344 .get_credentials()
345 .await
346 .map_err(RecursiveError::mls_client("getting all credentials"))
347 .map_err(Into::into)
348 }
349}