core_crypto/transaction_context/
mod.rs1#[cfg(feature = "proteus")]
5use crate::proteus::ProteusCentral;
6use crate::{
7 CoreCrypto, KeystoreError, MlsError, MlsTransport, RecursiveError,
8 group_store::GroupStore,
9 prelude::{ClientId, INITIAL_KEYING_MATERIAL_COUNT, MlsConversation, MlsCredentialType, Session},
10};
11use crate::{
12 mls::HasSessionAndCrypto,
13 prelude::{ClientIdentifier, MlsCiphersuite},
14};
15#[cfg(feature = "proteus")]
16use async_lock::Mutex;
17use async_lock::{RwLock, RwLockReadGuardArc, RwLockWriteGuardArc};
18use core_crypto_keystore::{CryptoKeystoreError, connection::FetchFromDatabase, entities::ConsumerData};
19pub use error::{Error, Result};
20use mls_crypto_provider::{CryptoKeystore, MlsCryptoProvider};
21use openmls_traits::OpenMlsCryptoProvider as _;
22use std::{ops::Deref, sync::Arc};
23pub mod conversation;
24pub mod e2e_identity;
25mod error;
26pub mod key_package;
27#[cfg(feature = "proteus")]
28pub mod proteus;
29#[cfg(test)]
30pub mod test_utils;
31
32#[derive(Debug, Clone)]
39pub struct TransactionContext {
40 inner: Arc<RwLock<TransactionContextInner>>,
41}
42
43#[derive(Debug, Clone)]
47enum TransactionContextInner {
48 Valid {
49 provider: MlsCryptoProvider,
50 transport: Arc<RwLock<Option<Arc<dyn MlsTransport + 'static>>>>,
51 mls_client: Session,
52 mls_groups: Arc<RwLock<GroupStore<MlsConversation>>>,
53 #[cfg(feature = "proteus")]
54 proteus_central: Arc<Mutex<Option<ProteusCentral>>>,
55 },
56 Invalid,
57}
58
59impl CoreCrypto {
60 pub async fn new_transaction(&self) -> Result<TransactionContext> {
64 TransactionContext::new(
65 &self.mls,
66 #[cfg(feature = "proteus")]
67 self.proteus.clone(),
68 )
69 .await
70 }
71}
72
73#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
74#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
75impl HasSessionAndCrypto for TransactionContext {
76 async fn session(&self) -> crate::mls::Result<Session> {
77 self.session()
78 .await
79 .map_err(RecursiveError::transaction("getting mls client"))
80 .map_err(Into::into)
81 }
82
83 async fn crypto_provider(&self) -> crate::mls::Result<MlsCryptoProvider> {
84 self.mls_provider()
85 .await
86 .map_err(RecursiveError::transaction("getting mls provider"))
87 .map_err(Into::into)
88 }
89}
90
91impl TransactionContext {
92 async fn new(
93 client: &Session,
94 #[cfg(feature = "proteus")] proteus_central: Arc<Mutex<Option<ProteusCentral>>>,
95 ) -> Result<Self> {
96 client
97 .crypto_provider
98 .new_transaction()
99 .await
100 .map_err(MlsError::wrap("creating new transaction"))?;
101 let mls_groups = Arc::new(RwLock::new(Default::default()));
102 let callbacks = client.transport.clone();
103 let mls_client = client.clone();
104 Ok(Self {
105 inner: Arc::new(
106 TransactionContextInner::Valid {
107 mls_client,
108 transport: callbacks,
109 provider: client.crypto_provider.clone(),
110 mls_groups,
111 #[cfg(feature = "proteus")]
112 proteus_central,
113 }
114 .into(),
115 ),
116 })
117 }
118
119 pub(crate) async fn session(&self) -> Result<Session> {
120 match self.inner.read().await.deref() {
121 TransactionContextInner::Valid { mls_client, .. } => Ok(mls_client.clone()),
122 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
123 }
124 }
125
126 pub(crate) async fn mls_transport(&self) -> Result<RwLockReadGuardArc<Option<Arc<dyn MlsTransport + 'static>>>> {
127 match self.inner.read().await.deref() {
128 TransactionContextInner::Valid {
129 transport: callbacks, ..
130 } => Ok(callbacks.read_arc().await),
131 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
132 }
133 }
134
/// Replaces the MLS transport stored in this context (test helper).
///
/// The context state is released before waiting for the transport write lock, so waiting for
/// that lock never blocks `finish`/`abort`.
///
/// # Errors
///
/// `Error::InvalidTransactionContext` when the transaction was already finished or aborted.
#[cfg(test)]
pub(crate) async fn set_transport_callbacks(
    &self,
    callbacks: Option<Arc<dyn MlsTransport + 'static>>,
) -> Result<()> {
    let transport_lock = {
        let state = self.inner.read().await;
        let TransactionContextInner::Valid { transport, .. } = state.deref() else {
            return Err(Error::InvalidTransactionContext);
        };
        Arc::clone(transport)
        // `state` is dropped here, before the await below.
    };
    *transport_lock.write_arc().await = callbacks;
    Ok(())
}
148
149 pub async fn mls_provider(&self) -> Result<MlsCryptoProvider> {
151 match self.inner.read().await.deref() {
152 TransactionContextInner::Valid { provider, .. } => Ok(provider.clone()),
153 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
154 }
155 }
156
157 pub(crate) async fn keystore(&self) -> Result<CryptoKeystore> {
158 match self.inner.read().await.deref() {
159 TransactionContextInner::Valid { provider, .. } => Ok(provider.keystore()),
160 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
161 }
162 }
163
164 pub(crate) async fn mls_groups(&self) -> Result<RwLockWriteGuardArc<GroupStore<MlsConversation>>> {
165 match self.inner.read().await.deref() {
166 TransactionContextInner::Valid { mls_groups, .. } => Ok(mls_groups.write_arc().await),
167 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
168 }
169 }
170
/// Returns the shared handle to the (optional) proteus central of this context.
///
/// # Errors
///
/// `Error::InvalidTransactionContext` when the transaction was already finished or aborted.
#[cfg(feature = "proteus")]
pub(crate) async fn proteus_central(&self) -> Result<Arc<Mutex<Option<ProteusCentral>>>> {
    let state = self.inner.read().await;
    let TransactionContextInner::Valid { proteus_central, .. } = state.deref() else {
        return Err(Error::InvalidTransactionContext);
    };
    Ok(Arc::clone(proteus_central))
}
178
179 pub async fn finish(&self) -> Result<()> {
183 let mut guard = self.inner.write().await;
184 let TransactionContextInner::Valid { provider, .. } = guard.deref() else {
185 return Err(Error::InvalidTransactionContext);
186 };
187
188 let commit_result = provider
189 .keystore()
190 .commit_transaction()
191 .await
192 .map_err(KeystoreError::wrap("commiting transaction"))
193 .map_err(Into::into);
194
195 *guard = TransactionContextInner::Invalid;
196 commit_result
197 }
198
199 pub async fn abort(&self) -> Result<()> {
203 let mut guard = self.inner.write().await;
204
205 let TransactionContextInner::Valid { provider, .. } = guard.deref() else {
206 return Err(Error::InvalidTransactionContext);
207 };
208
209 let result = provider
210 .keystore()
211 .rollback_transaction()
212 .await
213 .map_err(KeystoreError::wrap("rolling back transaction"))
214 .map_err(Into::into);
215
216 *guard = TransactionContextInner::Invalid;
217 result
218 }
219
220 pub async fn mls_init(
224 &self,
225 identifier: ClientIdentifier,
226 ciphersuites: Vec<MlsCiphersuite>,
227 nb_init_key_packages: Option<usize>,
228 ) -> Result<()> {
229 let nb_key_package = nb_init_key_packages.unwrap_or(INITIAL_KEYING_MATERIAL_COUNT);
230 let mls_client = self.session().await?;
231 mls_client
232 .init(identifier, &ciphersuites, &self.mls_provider().await?, nb_key_package)
233 .await
234 .map_err(RecursiveError::mls_client("initializing mls client"))?;
235
236 if mls_client.is_e2ei_capable().await {
237 let client_id = mls_client
238 .id()
239 .await
240 .map_err(RecursiveError::mls_client("getting client id"))?;
241 log::trace!(client_id:% = client_id; "Initializing PKI environment");
242 self.init_pki_env().await?;
243 }
244
245 Ok(())
246 }
247
248 pub async fn client_public_key(
250 &self,
251 ciphersuite: MlsCiphersuite,
252 credential_type: MlsCredentialType,
253 ) -> Result<Vec<u8>> {
254 let cb = self
255 .session()
256 .await?
257 .find_most_recent_credential_bundle(ciphersuite.signature_algorithm(), credential_type)
258 .await
259 .map_err(RecursiveError::mls_client("finding most recent credential bundle"))?;
260 Ok(cb.signature_key.to_public_vec())
261 }
262
263 pub async fn client_id(&self) -> Result<ClientId> {
265 self.session()
266 .await?
267 .id()
268 .await
269 .map_err(RecursiveError::mls_client("getting client id"))
270 .map_err(Into::into)
271 }
272
273 pub async fn random_bytes(&self, len: usize) -> Result<Vec<u8>> {
275 use openmls_traits::random::OpenMlsRand as _;
276 self.mls_provider()
277 .await?
278 .rand()
279 .random_vec(len)
280 .map_err(MlsError::wrap("generating random vector"))
281 .map_err(Into::into)
282 }
283
284 pub async fn set_data(&self, data: Vec<u8>) -> Result<()> {
288 self.keystore()
289 .await?
290 .save(ConsumerData::from(data))
291 .await
292 .map_err(KeystoreError::wrap("saving consumer data"))?;
293 Ok(())
294 }
295
296 pub async fn get_data(&self) -> Result<Option<Vec<u8>>> {
299 match self.keystore().await?.find_unique::<ConsumerData>().await {
300 Ok(data) => Ok(Some(data.into())),
301 Err(CryptoKeystoreError::NotFound(..)) => Ok(None),
302 Err(err) => Err(KeystoreError::wrap("finding unique consumer data")(err).into()),
303 }
304 }
305}