core_crypto/transaction_context/
mod.rs1#[cfg(feature = "proteus")]
5use crate::proteus::ProteusCentral;
6use crate::{
7 CoreCrypto, KeystoreError, MlsError, MlsTransport, RecursiveError,
8 group_store::GroupStore,
9 prelude::{ClientId, INITIAL_KEYING_MATERIAL_COUNT, MlsConversation, MlsCredentialType, Session},
10};
11use crate::{
12 mls::HasSessionAndCrypto,
13 prelude::{ClientIdentifier, MlsCiphersuite},
14};
15use async_lock::{Mutex, RwLock, RwLockReadGuardArc, RwLockWriteGuardArc};
16use core_crypto_keystore::{CryptoKeystoreError, connection::FetchFromDatabase, entities::ConsumerData};
17pub use error::{Error, Result};
18use mls_crypto_provider::{CryptoKeystore, MlsCryptoProvider};
19use openmls_traits::OpenMlsCryptoProvider as _;
20use std::{ops::Deref, sync::Arc};
21pub mod conversation;
22pub mod e2e_identity;
23mod error;
24pub mod key_package;
25#[cfg(feature = "proteus")]
26pub mod proteus;
27#[cfg(test)]
28pub mod test_utils;
29
30#[derive(Debug, Clone)]
37pub struct TransactionContext {
38 inner: Arc<RwLock<TransactionContextInner>>,
39}
40
41#[derive(Debug, Clone)]
45enum TransactionContextInner {
46 Valid {
47 provider: MlsCryptoProvider,
48 transport: Arc<RwLock<Option<Arc<dyn MlsTransport + 'static>>>>,
49 mls_client: Session,
50 mls_groups: Arc<RwLock<GroupStore<MlsConversation>>>,
51 #[cfg(feature = "proteus")]
52 proteus_central: Arc<Mutex<Option<ProteusCentral>>>,
53 },
54 Invalid,
55}
56
57impl CoreCrypto {
58 pub async fn new_transaction(&self) -> Result<TransactionContext> {
62 TransactionContext::new(
63 &self.mls,
64 #[cfg(feature = "proteus")]
65 self.proteus.clone(),
66 )
67 .await
68 }
69}
70
71#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
72#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
73impl HasSessionAndCrypto for TransactionContext {
74 async fn session(&self) -> crate::mls::Result<Session> {
75 self.session()
76 .await
77 .map_err(RecursiveError::transaction("getting mls client"))
78 .map_err(Into::into)
79 }
80
81 async fn crypto_provider(&self) -> crate::mls::Result<MlsCryptoProvider> {
82 self.mls_provider()
83 .await
84 .map_err(RecursiveError::transaction("getting mls provider"))
85 .map_err(Into::into)
86 }
87}
88
89impl TransactionContext {
90 async fn new(
91 client: &Session,
92 #[cfg(feature = "proteus")] proteus_central: Arc<Mutex<Option<ProteusCentral>>>,
93 ) -> Result<Self> {
94 client
95 .crypto_provider
96 .new_transaction()
97 .await
98 .map_err(MlsError::wrap("creating new transaction"))?;
99 let mls_groups = Arc::new(RwLock::new(Default::default()));
100 let callbacks = client.transport.clone();
101 let mls_client = client.clone();
102 Ok(Self {
103 inner: Arc::new(
104 TransactionContextInner::Valid {
105 mls_client,
106 transport: callbacks,
107 provider: client.crypto_provider.clone(),
108 mls_groups,
109 #[cfg(feature = "proteus")]
110 proteus_central,
111 }
112 .into(),
113 ),
114 })
115 }
116
117 pub(crate) async fn session(&self) -> Result<Session> {
118 match self.inner.read().await.deref() {
119 TransactionContextInner::Valid { mls_client, .. } => Ok(mls_client.clone()),
120 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
121 }
122 }
123
124 pub(crate) async fn mls_transport(&self) -> Result<RwLockReadGuardArc<Option<Arc<dyn MlsTransport + 'static>>>> {
125 match self.inner.read().await.deref() {
126 TransactionContextInner::Valid {
127 transport: callbacks, ..
128 } => Ok(callbacks.read_arc().await),
129 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
130 }
131 }
132
133 #[cfg(test)]
134 pub(crate) async fn set_transport_callbacks(
135 &self,
136 callbacks: Option<Arc<dyn MlsTransport + 'static>>,
137 ) -> Result<()> {
138 match self.inner.read().await.deref() {
139 TransactionContextInner::Valid { transport: cbs, .. } => {
140 *cbs.write_arc().await = callbacks;
141 Ok(())
142 }
143 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
144 }
145 }
146
147 pub async fn mls_provider(&self) -> Result<MlsCryptoProvider> {
149 match self.inner.read().await.deref() {
150 TransactionContextInner::Valid { provider, .. } => Ok(provider.clone()),
151 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
152 }
153 }
154
155 pub(crate) async fn keystore(&self) -> Result<CryptoKeystore> {
156 match self.inner.read().await.deref() {
157 TransactionContextInner::Valid { provider, .. } => Ok(provider.keystore()),
158 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
159 }
160 }
161
162 pub(crate) async fn mls_groups(&self) -> Result<RwLockWriteGuardArc<GroupStore<MlsConversation>>> {
163 match self.inner.read().await.deref() {
164 TransactionContextInner::Valid { mls_groups, .. } => Ok(mls_groups.write_arc().await),
165 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
166 }
167 }
168
169 #[cfg(feature = "proteus")]
170 pub(crate) async fn proteus_central(&self) -> Result<Arc<Mutex<Option<ProteusCentral>>>> {
171 match self.inner.read().await.deref() {
172 TransactionContextInner::Valid { proteus_central, .. } => Ok(proteus_central.clone()),
173 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
174 }
175 }
176
177 pub async fn finish(&self) -> Result<()> {
181 let mut guard = self.inner.write().await;
182 let TransactionContextInner::Valid { provider, .. } = guard.deref() else {
183 return Err(Error::InvalidTransactionContext);
184 };
185
186 let commit_result = provider
187 .keystore()
188 .commit_transaction()
189 .await
190 .map_err(KeystoreError::wrap("commiting transaction"))
191 .map_err(Into::into);
192
193 *guard = TransactionContextInner::Invalid;
194 commit_result
195 }
196
197 pub async fn abort(&self) -> Result<()> {
201 let mut guard = self.inner.write().await;
202
203 let TransactionContextInner::Valid { provider, .. } = guard.deref() else {
204 return Err(Error::InvalidTransactionContext);
205 };
206
207 let result = provider
208 .keystore()
209 .rollback_transaction()
210 .await
211 .map_err(KeystoreError::wrap("rolling back transaction"))
212 .map_err(Into::into);
213
214 *guard = TransactionContextInner::Invalid;
215 result
216 }
217
218 pub async fn mls_init(
222 &self,
223 identifier: ClientIdentifier,
224 ciphersuites: Vec<MlsCiphersuite>,
225 nb_init_key_packages: Option<usize>,
226 ) -> Result<()> {
227 let nb_key_package = nb_init_key_packages.unwrap_or(INITIAL_KEYING_MATERIAL_COUNT);
228 let mls_client = self.session().await?;
229 mls_client
230 .init(identifier, &ciphersuites, &self.mls_provider().await?, nb_key_package)
231 .await
232 .map_err(RecursiveError::mls_client("initializing mls client"))?;
233
234 if mls_client.is_e2ei_capable().await {
235 let client_id = mls_client
236 .id()
237 .await
238 .map_err(RecursiveError::mls_client("getting client id"))?;
239 log::trace!(client_id:% = client_id; "Initializing PKI environment");
240 self.init_pki_env().await?;
241 }
242
243 Ok(())
244 }
245
246 #[cfg_attr(test, crate::dispotent)]
251 pub async fn mls_generate_keypairs(&self, ciphersuites: Vec<MlsCiphersuite>) -> Result<Vec<ClientId>> {
252 self.session()
253 .await?
254 .generate_raw_keypairs(&ciphersuites, &self.mls_provider().await?)
255 .await
256 .map_err(RecursiveError::mls_client("generating raw keypairs"))
257 .map_err(Into::into)
258 }
259
260 #[cfg_attr(test, crate::dispotent)]
264 pub async fn mls_init_with_client_id(
265 &self,
266 client_id: ClientId,
267 tmp_client_ids: Vec<ClientId>,
268 ciphersuites: Vec<MlsCiphersuite>,
269 ) -> Result<()> {
270 self.session()
271 .await?
272 .init_with_external_client_id(client_id, tmp_client_ids, &ciphersuites, &self.mls_provider().await?)
273 .await
274 .map_err(RecursiveError::mls_client(
275 "initializing mls client with external client id",
276 ))
277 .map_err(Into::into)
278 }
279
280 pub async fn client_public_key(
282 &self,
283 ciphersuite: MlsCiphersuite,
284 credential_type: MlsCredentialType,
285 ) -> Result<Vec<u8>> {
286 let cb = self
287 .session()
288 .await?
289 .find_most_recent_credential_bundle(ciphersuite.signature_algorithm(), credential_type)
290 .await
291 .map_err(RecursiveError::mls_client("finding most recent credential bundle"))?;
292 Ok(cb.signature_key.to_public_vec())
293 }
294
295 pub async fn client_id(&self) -> Result<ClientId> {
297 self.session()
298 .await?
299 .id()
300 .await
301 .map_err(RecursiveError::mls_client("getting client id"))
302 .map_err(Into::into)
303 }
304
305 pub async fn random_bytes(&self, len: usize) -> Result<Vec<u8>> {
307 use openmls_traits::random::OpenMlsRand as _;
308 self.mls_provider()
309 .await?
310 .rand()
311 .random_vec(len)
312 .map_err(MlsError::wrap("generating random vector"))
313 .map_err(Into::into)
314 }
315
316 pub async fn set_data(&self, data: Vec<u8>) -> Result<()> {
320 self.keystore()
321 .await?
322 .save(ConsumerData::from(data))
323 .await
324 .map_err(KeystoreError::wrap("saving consumer data"))?;
325 Ok(())
326 }
327
328 pub async fn get_data(&self) -> Result<Option<Vec<u8>>> {
331 match self.keystore().await?.find_unique::<ConsumerData>().await {
332 Ok(data) => Ok(Some(data.into())),
333 Err(CryptoKeystoreError::NotFound(..)) => Ok(None),
334 Err(err) => Err(KeystoreError::wrap("finding unique consumer data")(err).into()),
335 }
336 }
337}