core_crypto/transaction_context/
mod.rs1use std::sync::Arc;
5
6#[cfg(feature = "proteus")]
7use async_lock::Mutex;
8use async_lock::{RwLock, RwLockWriteGuardArc};
9use core_crypto_keystore::{CryptoKeystoreError, entities::ConsumerData, traits::FetchFromDatabase as _};
10pub use error::{Error, Result};
11use mls_crypto_provider::{Database, MlsCryptoProvider};
12use openmls_traits::OpenMlsCryptoProvider as _;
13
14#[cfg(feature = "proteus")]
15use crate::proteus::ProteusCentral;
16use crate::{
17 Ciphersuite, ClientId, ClientIdentifier, CoreCrypto, Credential, CredentialFindFilters, CredentialRef,
18 CredentialType, KeystoreError, MlsConversation, MlsError, MlsTransport, RecursiveError, Session,
19 group_store::GroupStore,
20 mls::{
21 self, HasSessionAndCrypto,
22 session::{Error as SessionError, identities::Identities},
23 },
24};
25pub mod conversation;
26pub mod e2e_identity;
27mod error;
28pub mod key_package;
29#[cfg(feature = "proteus")]
30pub mod proteus;
31#[cfg(test)]
32pub mod test_utils;
33
34#[derive(Debug, Clone)]
41pub struct TransactionContext {
42 inner: Arc<RwLock<TransactionContextInner>>,
43}
44
45#[derive(Debug, Clone)]
49enum TransactionContextInner {
50 Valid {
51 keystore: Database,
52 mls_session: Arc<RwLock<Option<Session>>>,
53 mls_groups: Arc<RwLock<GroupStore<MlsConversation>>>,
54 #[cfg(feature = "proteus")]
55 proteus_central: Arc<Mutex<Option<ProteusCentral>>>,
56 },
57 Invalid,
58}
59
60impl CoreCrypto {
61 pub async fn new_transaction(&self) -> Result<TransactionContext> {
65 TransactionContext::new(
66 self.database.clone(),
67 self.mls.clone(),
68 #[cfg(feature = "proteus")]
69 self.proteus.clone(),
70 )
71 .await
72 }
73}
74
75#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
76#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
77impl HasSessionAndCrypto for TransactionContext {
78 async fn session(&self) -> crate::mls::Result<Session> {
79 self.session()
80 .await
81 .map_err(RecursiveError::transaction("getting mls client"))
82 .map_err(Into::into)
83 }
84
85 async fn crypto_provider(&self) -> crate::mls::Result<MlsCryptoProvider> {
86 self.mls_provider()
87 .await
88 .map_err(RecursiveError::transaction("getting mls provider"))
89 .map_err(Into::into)
90 }
91}
92
93impl TransactionContext {
94 async fn new(
95 keystore: Database,
96 mls_session: Arc<RwLock<Option<Session>>>,
97 #[cfg(feature = "proteus")] proteus_central: Arc<Mutex<Option<ProteusCentral>>>,
98 ) -> Result<Self> {
99 keystore
100 .new_transaction()
101 .await
102 .map_err(MlsError::wrap("creating new transaction"))?;
103 let mls_groups = Arc::new(RwLock::new(Default::default()));
104 Ok(Self {
105 inner: Arc::new(
106 TransactionContextInner::Valid {
107 keystore,
108 mls_session: mls_session.clone(),
109 mls_groups,
110 #[cfg(feature = "proteus")]
111 proteus_central,
112 }
113 .into(),
114 ),
115 })
116 }
117
118 pub(crate) async fn session(&self) -> Result<Session> {
119 match &*self.inner.read().await {
120 TransactionContextInner::Valid { mls_session, .. } => {
121 if let Some(session) = mls_session.read().await.as_ref() {
122 return Ok(session.clone());
123 }
124 Err(mls::session::Error::MlsNotInitialized)
125 .map_err(RecursiveError::mls_client(
126 "Getting mls session from transaction context",
127 ))
128 .map_err(Into::into)
129 }
130 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
131 }
132 }
133
134 #[cfg(test)]
135 pub(crate) async fn set_session_if_exists(&self, new_session: Session) {
136 match &*self.inner.read().await {
137 TransactionContextInner::Valid { mls_session, .. } => {
138 let mut guard = mls_session.write().await;
139
140 if guard.as_ref().is_some() {
141 *guard = Some(new_session)
142 }
143 }
144 TransactionContextInner::Invalid => {}
145 }
146 }
147
148 pub(crate) async fn mls_transport(&self) -> Result<Arc<dyn MlsTransport + 'static>> {
149 match &*self.inner.read().await {
150 TransactionContextInner::Valid { mls_session, .. } => {
151 if let Some(session) = mls_session.read().await.as_ref() {
152 let transport = session.transport.clone();
153 return Ok(transport);
154 }
155 Err(mls::session::Error::MlsNotInitialized)
156 .map_err(RecursiveError::mls_client(
157 "Getting mls session from transaction context",
158 ))
159 .map_err(Into::into)
160 }
161
162 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
163 }
164 }
165
166 pub async fn mls_provider(&self) -> Result<MlsCryptoProvider> {
168 match &*self.inner.read().await {
169 TransactionContextInner::Valid { mls_session, .. } => {
170 if let Some(session) = mls_session.read().await.as_ref() {
171 return Ok(session.crypto_provider.clone());
172 }
173 Err(mls::session::Error::MlsNotInitialized)
174 .map_err(RecursiveError::mls_client(
175 "Getting mls session from transaction context",
176 ))
177 .map_err(Into::into)
178 }
179 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
180 }
181 }
182
183 pub(crate) async fn keystore(&self) -> Result<Database> {
184 match &*self.inner.read().await {
185 TransactionContextInner::Valid { keystore, .. } => Ok(keystore.clone()),
186 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
187 }
188 }
189
190 pub(crate) async fn mls_groups(&self) -> Result<RwLockWriteGuardArc<GroupStore<MlsConversation>>> {
191 match &*self.inner.read().await {
192 TransactionContextInner::Valid { mls_groups, .. } => Ok(mls_groups.write_arc().await),
193 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
194 }
195 }
196
197 #[cfg(feature = "proteus")]
198 pub(crate) async fn proteus_central(&self) -> Result<Arc<Mutex<Option<ProteusCentral>>>> {
199 match &*self.inner.read().await {
200 TransactionContextInner::Valid { proteus_central, .. } => Ok(proteus_central.clone()),
201 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
202 }
203 }
204
205 pub async fn finish(&self) -> Result<()> {
209 let mut guard = self.inner.write().await;
210 let TransactionContextInner::Valid { keystore, .. } = &*guard else {
211 return Err(Error::InvalidTransactionContext);
212 };
213
214 let commit_result = keystore
215 .commit_transaction()
216 .await
217 .map_err(KeystoreError::wrap("commiting transaction"))
218 .map_err(Into::into);
219
220 *guard = TransactionContextInner::Invalid;
221 commit_result
222 }
223
224 pub async fn abort(&self) -> Result<()> {
228 let mut guard = self.inner.write().await;
229
230 let TransactionContextInner::Valid { keystore, .. } = &*guard else {
231 return Err(Error::InvalidTransactionContext);
232 };
233
234 let result = keystore
235 .rollback_transaction()
236 .await
237 .map_err(KeystoreError::wrap("rolling back transaction"))
238 .map_err(Into::into);
239
240 *guard = TransactionContextInner::Invalid;
241 result
242 }
243
244 async fn init(&self, identifier: ClientIdentifier, ciphersuites: &[Ciphersuite]) -> Result<(ClientId, Identities)> {
248 let database = self.keystore().await?;
249 let client_id = identifier
250 .get_id()
251 .map_err(RecursiveError::mls_client("getting client id"))?
252 .into_owned();
253
254 let signature_schemes = &ciphersuites
255 .iter()
256 .map(|ciphersuite| ciphersuite.signature_algorithm())
257 .collect::<Vec<_>>();
258
259 let mut credential_refs = CredentialRef::find(
268 &database,
269 CredentialFindFilters::builder().client_id(&client_id).build(),
270 )
271 .await
272 .map_err(RecursiveError::mls_credential_ref(
273 "loading matching credential refs while initializing a client",
274 ))?;
275 credential_refs.retain(|credential_ref| signature_schemes.contains(&credential_ref.signature_scheme()));
276
277 let mut identities = Identities::new(credential_refs.len());
278 let credentials_cache =
279 CredentialRef::load_stored_credentials(&database)
280 .await
281 .map_err(RecursiveError::mls_credential_ref(
282 "loading credential ref cache while initializing session",
283 ))?;
284
285 for credential_ref in credential_refs {
286 if let Some(credential) =
287 credential_ref
288 .load_from_cache(&credentials_cache)
289 .map_err(RecursiveError::mls_credential_ref(
290 "loading credential list in session init",
291 ))?
292 {
293 match identities.push_credential(credential).await {
294 Err(SessionError::CredentialConflict) => {
295 }
298 Ok(_) => {}
299 Err(err) => {
300 return Err(RecursiveError::MlsClient {
301 context: "adding credential to identities in init",
302 source: Box::new(err),
303 }
304 .into());
305 }
306 }
307 }
308 }
309
310 Ok((client_id, identities))
311 }
312
313 pub async fn mls_init(
315 &self,
316 identifier: ClientIdentifier,
317 ciphersuites: &[Ciphersuite],
318 transport: Arc<dyn MlsTransport>,
319 ) -> Result<()> {
320 let database = self.keystore().await?;
321 let (client_id, identities) = self.init(identifier, ciphersuites).await?;
322
323 let mls_backend = MlsCryptoProvider::new(database);
324 let session = Session::new(client_id.clone(), identities, mls_backend, transport);
325
326 if session.is_e2ei_capable().await {
327 log::trace!(client_id:% = client_id; "Initializing PKI environment");
328 self.init_pki_env().await?;
329 }
330
331 self.set_mls_session(session).await?;
332
333 Ok(())
334 }
335
336 pub(crate) async fn set_mls_session(&self, session: Session) -> Result<()> {
338 match &*self.inner.read().await {
339 TransactionContextInner::Valid { mls_session, .. } => {
340 let mut guard = mls_session.write().await;
341 *guard = Some(session);
342 Ok(())
343 }
344 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
345 }
346 }
347
348 pub async fn client_public_key(
350 &self,
351 ciphersuite: Ciphersuite,
352 credential_type: CredentialType,
353 ) -> Result<Vec<u8>> {
354 let cb = self
355 .session()
356 .await?
357 .find_most_recent_credential(ciphersuite.signature_algorithm(), credential_type)
358 .await
359 .map_err(RecursiveError::mls_client("finding most recent credential"))?;
360 Ok(cb.signature_key_pair.to_public_vec())
361 }
362
363 pub async fn client_id(&self) -> Result<ClientId> {
365 let session = self.session().await?;
366 Ok(session.id())
367 }
368
369 pub async fn random_bytes(&self, len: usize) -> Result<Vec<u8>> {
371 use openmls_traits::random::OpenMlsRand as _;
372 self.mls_provider()
373 .await?
374 .rand()
375 .random_vec(len)
376 .map_err(MlsError::wrap("generating random vector"))
377 .map_err(Into::into)
378 }
379
380 pub async fn set_data(&self, data: Vec<u8>) -> Result<()> {
384 self.keystore()
385 .await?
386 .save(ConsumerData::from(data))
387 .await
388 .map_err(KeystoreError::wrap("saving consumer data"))?;
389 Ok(())
390 }
391
392 pub async fn get_data(&self) -> Result<Option<Vec<u8>>> {
395 match self.keystore().await?.get_unique::<ConsumerData>().await {
396 Ok(maybe_data) => Ok(maybe_data.map(Into::into)),
397 Err(CryptoKeystoreError::NotFound(..)) => Ok(None),
398 Err(err) => Err(KeystoreError::wrap("finding unique consumer data")(err).into()),
399 }
400 }
401
402 pub async fn add_credential(&self, credential: Credential) -> Result<CredentialRef> {
406 self.session()
407 .await?
408 .add_credential(credential)
409 .await
410 .map_err(RecursiveError::mls_client("adding credential to session"))
411 .map_err(Into::into)
412 }
413
414 pub async fn remove_credential(&self, credential_ref: &CredentialRef) -> Result<()> {
420 self.session()
421 .await?
422 .remove_credential(credential_ref)
423 .await
424 .map_err(RecursiveError::mls_client("removing credential from session"))
425 .map_err(Into::into)
426 }
427
428 pub async fn find_credentials(&self, find_filters: CredentialFindFilters<'_>) -> Result<Vec<CredentialRef>> {
432 self.session()
433 .await?
434 .find_credentials(find_filters)
435 .await
436 .map_err(RecursiveError::mls_client("finding credentials by filter"))
437 .map_err(Into::into)
438 }
439
440 pub async fn get_credentials(&self) -> Result<Vec<CredentialRef>> {
444 self.session()
445 .await?
446 .get_credentials()
447 .await
448 .map_err(RecursiveError::mls_client("getting all credentials"))
449 .map_err(Into::into)
450 }
451}