core_crypto/transaction_context/
mod.rs1#[cfg(feature = "proteus")]
5use crate::proteus::ProteusCentral;
6use crate::{
7 CoreCrypto, KeystoreError, MlsError, MlsTransport, RecursiveError,
8 group_store::GroupStore,
9 prelude::{ClientId, ConversationId, INITIAL_KEYING_MATERIAL_COUNT, MlsConversation, MlsCredentialType, Session},
10};
11use crate::{
12 mls::HasSessionAndCrypto,
13 prelude::{ClientIdentifier, MlsCiphersuite},
14};
15use async_lock::{Mutex, RwLock, RwLockReadGuardArc, RwLockWriteGuardArc};
16use core_crypto_keystore::{CryptoKeystoreError, connection::FetchFromDatabase, entities::ConsumerData};
17pub use error::{Error, Result};
18use mls_crypto_provider::{CryptoKeystore, MlsCryptoProvider};
19use openmls_traits::OpenMlsCryptoProvider as _;
20use std::{ops::Deref, sync::Arc};
21pub mod conversation;
22pub mod e2e_identity;
23mod error;
24pub mod key_package;
25#[cfg(feature = "proteus")]
26pub mod proteus;
27#[cfg(test)]
28pub mod test_utils;
29
30#[derive(Debug, Clone)]
37pub struct TransactionContext {
38 inner: Arc<RwLock<TransactionContextInner>>,
39}
40
41#[derive(Debug, Clone)]
45enum TransactionContextInner {
46 Valid {
47 provider: MlsCryptoProvider,
48 transport: Arc<RwLock<Option<Arc<dyn MlsTransport + 'static>>>>,
49 mls_client: Session,
50 mls_groups: Arc<RwLock<GroupStore<MlsConversation>>>,
51 pending_epoch_changes: Arc<Mutex<Vec<(ConversationId, u64)>>>,
52 #[cfg(feature = "proteus")]
53 proteus_central: Arc<Mutex<Option<ProteusCentral>>>,
54 },
55 Invalid,
56}
57
58impl CoreCrypto {
59 pub async fn new_transaction(&self) -> Result<TransactionContext> {
63 TransactionContext::new(
64 &self.mls,
65 #[cfg(feature = "proteus")]
66 self.proteus.clone(),
67 )
68 .await
69 }
70}
71
72#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
73#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
74impl HasSessionAndCrypto for TransactionContext {
75 async fn session(&self) -> crate::mls::Result<Session> {
76 self.session()
77 .await
78 .map_err(RecursiveError::transaction("getting mls client"))
79 .map_err(Into::into)
80 }
81
82 async fn crypto_provider(&self) -> crate::mls::Result<MlsCryptoProvider> {
83 self.mls_provider()
84 .await
85 .map_err(RecursiveError::transaction("getting mls provider"))
86 .map_err(Into::into)
87 }
88}
89
90impl TransactionContext {
91 async fn new(
92 client: &Session,
93 #[cfg(feature = "proteus")] proteus_central: Arc<Mutex<Option<ProteusCentral>>>,
94 ) -> Result<Self> {
95 client
96 .crypto_provider
97 .new_transaction()
98 .await
99 .map_err(MlsError::wrap("creating new transaction"))?;
100 let mls_groups = Arc::new(RwLock::new(Default::default()));
101 let callbacks = client.transport.clone();
102 let mls_client = client.clone();
103 Ok(Self {
104 inner: Arc::new(
105 TransactionContextInner::Valid {
106 mls_client,
107 transport: callbacks,
108 provider: client.crypto_provider.clone(),
109 mls_groups,
110 pending_epoch_changes: Default::default(),
111 #[cfg(feature = "proteus")]
112 proteus_central,
113 }
114 .into(),
115 ),
116 })
117 }
118
119 pub(crate) async fn session(&self) -> Result<Session> {
120 match self.inner.read().await.deref() {
121 TransactionContextInner::Valid { mls_client, .. } => Ok(mls_client.clone()),
122 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
123 }
124 }
125
126 pub(crate) async fn mls_transport(&self) -> Result<RwLockReadGuardArc<Option<Arc<dyn MlsTransport + 'static>>>> {
127 match self.inner.read().await.deref() {
128 TransactionContextInner::Valid {
129 transport: callbacks, ..
130 } => Ok(callbacks.read_arc().await),
131 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
132 }
133 }
134
135 #[cfg(test)]
136 pub(crate) async fn set_transport_callbacks(
137 &self,
138 callbacks: Option<Arc<dyn MlsTransport + 'static>>,
139 ) -> Result<()> {
140 match self.inner.read().await.deref() {
141 TransactionContextInner::Valid { transport: cbs, .. } => {
142 *cbs.write_arc().await = callbacks;
143 Ok(())
144 }
145 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
146 }
147 }
148
149 pub async fn mls_provider(&self) -> Result<MlsCryptoProvider> {
151 match self.inner.read().await.deref() {
152 TransactionContextInner::Valid { provider, .. } => Ok(provider.clone()),
153 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
154 }
155 }
156
157 pub(crate) async fn keystore(&self) -> Result<CryptoKeystore> {
158 match self.inner.read().await.deref() {
159 TransactionContextInner::Valid { provider, .. } => Ok(provider.keystore()),
160 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
161 }
162 }
163
164 pub(crate) async fn mls_groups(&self) -> Result<RwLockWriteGuardArc<GroupStore<MlsConversation>>> {
165 match self.inner.read().await.deref() {
166 TransactionContextInner::Valid { mls_groups, .. } => Ok(mls_groups.write_arc().await),
167 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
168 }
169 }
170
171 pub(crate) async fn queue_epoch_changed(&self, conversation_id: ConversationId, epoch: u64) -> Result<()> {
172 match self.inner.read().await.deref() {
173 TransactionContextInner::Valid {
174 pending_epoch_changes, ..
175 } => {
176 pending_epoch_changes.lock().await.push((conversation_id, epoch));
177 Ok(())
178 }
179 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
180 }
181 }
182
183 #[cfg(feature = "proteus")]
184 pub(crate) async fn proteus_central(&self) -> Result<Arc<Mutex<Option<ProteusCentral>>>> {
185 match self.inner.read().await.deref() {
186 TransactionContextInner::Valid { proteus_central, .. } => Ok(proteus_central.clone()),
187 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
188 }
189 }
190
191 pub async fn finish(&self) -> Result<()> {
195 let mut guard = self.inner.write().await;
196 let TransactionContextInner::Valid {
197 provider,
198 mls_client,
199 pending_epoch_changes,
200 ..
201 } = guard.deref()
202 else {
203 return Err(Error::InvalidTransactionContext);
204 };
205
206 let commit_result = provider
207 .keystore()
208 .commit_transaction()
209 .await
210 .map_err(KeystoreError::wrap("commiting transaction"))
211 .map_err(Into::into);
212
213 if commit_result.is_ok() {
214 let mut epoch_changes = pending_epoch_changes.lock().await;
217 let epoch_changes = epoch_changes.drain(..);
218 for (conversation_id, epoch) in epoch_changes {
219 mls_client.notify_epoch_changed(conversation_id, epoch).await;
220 }
221 }
222
223 *guard = TransactionContextInner::Invalid;
224 commit_result
225 }
226
227 pub async fn abort(&self) -> Result<()> {
231 let mut guard = self.inner.write().await;
232
233 let TransactionContextInner::Valid { provider, .. } = guard.deref() else {
234 return Err(Error::InvalidTransactionContext);
235 };
236
237 let result = provider
238 .keystore()
239 .rollback_transaction()
240 .await
241 .map_err(KeystoreError::wrap("rolling back transaction"))
242 .map_err(Into::into);
243
244 *guard = TransactionContextInner::Invalid;
245 result
246 }
247
248 pub async fn mls_init(
252 &self,
253 identifier: ClientIdentifier,
254 ciphersuites: Vec<MlsCiphersuite>,
255 nb_init_key_packages: Option<usize>,
256 ) -> Result<()> {
257 let nb_key_package = nb_init_key_packages.unwrap_or(INITIAL_KEYING_MATERIAL_COUNT);
258 let mls_client = self.session().await?;
259 mls_client
260 .init(identifier, &ciphersuites, &self.mls_provider().await?, nb_key_package)
261 .await
262 .map_err(RecursiveError::mls_client("initializing mls client"))?;
263
264 if mls_client.is_e2ei_capable().await {
265 let client_id = mls_client
266 .id()
267 .await
268 .map_err(RecursiveError::mls_client("getting client id"))?;
269 log::trace!(client_id:% = client_id; "Initializing PKI environment");
270 self.init_pki_env().await?;
271 }
272
273 Ok(())
274 }
275
276 pub async fn client_public_key(
278 &self,
279 ciphersuite: MlsCiphersuite,
280 credential_type: MlsCredentialType,
281 ) -> Result<Vec<u8>> {
282 let cb = self
283 .session()
284 .await?
285 .find_most_recent_credential_bundle(ciphersuite.signature_algorithm(), credential_type)
286 .await
287 .map_err(RecursiveError::mls_client("finding most recent credential bundle"))?;
288 Ok(cb.signature_key.to_public_vec())
289 }
290
291 pub async fn client_id(&self) -> Result<ClientId> {
293 self.session()
294 .await?
295 .id()
296 .await
297 .map_err(RecursiveError::mls_client("getting client id"))
298 .map_err(Into::into)
299 }
300
301 pub async fn random_bytes(&self, len: usize) -> Result<Vec<u8>> {
303 use openmls_traits::random::OpenMlsRand as _;
304 self.mls_provider()
305 .await?
306 .rand()
307 .random_vec(len)
308 .map_err(MlsError::wrap("generating random vector"))
309 .map_err(Into::into)
310 }
311
312 pub async fn set_data(&self, data: Vec<u8>) -> Result<()> {
316 self.keystore()
317 .await?
318 .save(ConsumerData::from(data))
319 .await
320 .map_err(KeystoreError::wrap("saving consumer data"))?;
321 Ok(())
322 }
323
324 pub async fn get_data(&self) -> Result<Option<Vec<u8>>> {
327 match self.keystore().await?.find_unique::<ConsumerData>().await {
328 Ok(data) => Ok(Some(data.into())),
329 Err(CryptoKeystoreError::NotFound(..)) => Ok(None),
330 Err(err) => Err(KeystoreError::wrap("finding unique consumer data")(err).into()),
331 }
332 }
333}