core_crypto/transaction_context/
mod.rs1#[cfg(feature = "proteus")]
5use crate::proteus::ProteusCentral;
6use crate::{
7 CoreCrypto, KeystoreError, MlsError, MlsTransport, RecursiveError,
8 group_store::GroupStore,
9 prelude::{ClientId, INITIAL_KEYING_MATERIAL_COUNT, MlsConversation, MlsCredentialType, Session},
10};
11use crate::{
12 mls::HasSessionAndCrypto,
13 prelude::{ClientIdentifier, MlsCiphersuite},
14};
15use async_lock::{Mutex, RwLock, RwLockReadGuardArc, RwLockWriteGuardArc};
16use core_crypto_keystore::{CryptoKeystoreError, connection::FetchFromDatabase, entities::ConsumerData};
17pub use error::{Error, Result};
18use mls_crypto_provider::{CryptoKeystore, MlsCryptoProvider};
19use openmls_traits::OpenMlsCryptoProvider as _;
20use std::{ops::Deref, sync::Arc};
21pub mod conversation;
22pub mod e2e_identity;
23mod error;
24pub mod key_package;
25#[cfg(feature = "proteus")]
26pub mod proteus;
27#[cfg(test)]
28pub mod test_utils;
29
30#[derive(Debug, Clone)]
37pub struct TransactionContext {
38 inner: Arc<RwLock<TransactionContextInner>>,
39}
40
41#[derive(Debug, Clone)]
45enum TransactionContextInner {
46 Valid {
47 provider: MlsCryptoProvider,
48 transport: Arc<RwLock<Option<Arc<dyn MlsTransport + 'static>>>>,
49 mls_client: Session,
50 mls_groups: Arc<RwLock<GroupStore<MlsConversation>>>,
51 #[cfg(feature = "proteus")]
52 proteus_central: Arc<Mutex<Option<ProteusCentral>>>,
53 },
54 Invalid,
55}
56
57impl CoreCrypto {
58 pub async fn new_transaction(&self) -> Result<TransactionContext> {
62 TransactionContext::new(
63 &self.mls,
64 #[cfg(feature = "proteus")]
65 self.proteus.clone(),
66 )
67 .await
68 }
69}
70
71#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
72#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
73impl HasSessionAndCrypto for TransactionContext {
74 async fn session(&self) -> crate::mls::Result<Session> {
75 self.session()
76 .await
77 .map_err(RecursiveError::transaction("getting mls client"))
78 .map_err(Into::into)
79 }
80
81 async fn crypto_provider(&self) -> crate::mls::Result<MlsCryptoProvider> {
82 self.mls_provider()
83 .await
84 .map_err(RecursiveError::transaction("getting mls provider"))
85 .map_err(Into::into)
86 }
87}
88
89impl TransactionContext {
90 async fn new(
91 client: &Session,
92 #[cfg(feature = "proteus")] proteus_central: Arc<Mutex<Option<ProteusCentral>>>,
93 ) -> Result<Self> {
94 client
95 .crypto_provider
96 .new_transaction()
97 .await
98 .map_err(MlsError::wrap("creating new transaction"))?;
99 let mls_groups = Arc::new(RwLock::new(Default::default()));
100 let callbacks = client.transport.clone();
101 let mls_client = client.clone();
102 Ok(Self {
103 inner: Arc::new(
104 TransactionContextInner::Valid {
105 mls_client,
106 transport: callbacks,
107 provider: client.crypto_provider.clone(),
108 mls_groups,
109 #[cfg(feature = "proteus")]
110 proteus_central,
111 }
112 .into(),
113 ),
114 })
115 }
116
117 pub(crate) async fn session(&self) -> Result<Session> {
118 match self.inner.read().await.deref() {
119 TransactionContextInner::Valid { mls_client, .. } => Ok(mls_client.clone()),
120 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
121 }
122 }
123
124 pub(crate) async fn mls_transport(&self) -> Result<RwLockReadGuardArc<Option<Arc<dyn MlsTransport + 'static>>>> {
125 match self.inner.read().await.deref() {
126 TransactionContextInner::Valid {
127 transport: callbacks, ..
128 } => Ok(callbacks.read_arc().await),
129 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
130 }
131 }
132
133 #[cfg(test)]
134 pub(crate) async fn set_transport_callbacks(
135 &self,
136 callbacks: Option<Arc<dyn MlsTransport + 'static>>,
137 ) -> Result<()> {
138 match self.inner.read().await.deref() {
139 TransactionContextInner::Valid { transport: cbs, .. } => {
140 *cbs.write_arc().await = callbacks;
141 Ok(())
142 }
143 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
144 }
145 }
146
147 pub async fn mls_provider(&self) -> Result<MlsCryptoProvider> {
149 match self.inner.read().await.deref() {
150 TransactionContextInner::Valid { provider, .. } => Ok(provider.clone()),
151 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
152 }
153 }
154
155 pub(crate) async fn keystore(&self) -> Result<CryptoKeystore> {
156 match self.inner.read().await.deref() {
157 TransactionContextInner::Valid { provider, .. } => Ok(provider.keystore()),
158 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
159 }
160 }
161
162 pub(crate) async fn mls_groups(&self) -> Result<RwLockWriteGuardArc<GroupStore<MlsConversation>>> {
163 match self.inner.read().await.deref() {
164 TransactionContextInner::Valid { mls_groups, .. } => Ok(mls_groups.write_arc().await),
165 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
166 }
167 }
168
169 #[cfg(feature = "proteus")]
170 pub(crate) async fn proteus_central(&self) -> Result<Arc<Mutex<Option<ProteusCentral>>>> {
171 match self.inner.read().await.deref() {
172 TransactionContextInner::Valid { proteus_central, .. } => Ok(proteus_central.clone()),
173 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
174 }
175 }
176
177 pub async fn finish(&self) -> Result<()> {
181 let mut guard = self.inner.write().await;
182 let TransactionContextInner::Valid { provider, .. } = guard.deref() else {
183 return Err(Error::InvalidTransactionContext);
184 };
185
186 let commit_result = provider
187 .keystore()
188 .commit_transaction()
189 .await
190 .map_err(KeystoreError::wrap("commiting transaction"))
191 .map_err(Into::into);
192
193 *guard = TransactionContextInner::Invalid;
194 commit_result
195 }
196
197 pub async fn abort(&self) -> Result<()> {
201 let mut guard = self.inner.write().await;
202
203 let TransactionContextInner::Valid { provider, .. } = guard.deref() else {
204 return Err(Error::InvalidTransactionContext);
205 };
206
207 let result = provider
208 .keystore()
209 .rollback_transaction()
210 .await
211 .map_err(KeystoreError::wrap("rolling back transaction"))
212 .map_err(Into::into);
213
214 *guard = TransactionContextInner::Invalid;
215 result
216 }
217
218 pub async fn mls_init(
222 &self,
223 identifier: ClientIdentifier,
224 ciphersuites: Vec<MlsCiphersuite>,
225 nb_init_key_packages: Option<usize>,
226 ) -> Result<()> {
227 let nb_key_package = nb_init_key_packages.unwrap_or(INITIAL_KEYING_MATERIAL_COUNT);
228 let mls_client = self.session().await?;
229 mls_client
230 .init(identifier, &ciphersuites, &self.mls_provider().await?, nb_key_package)
231 .await
232 .map_err(RecursiveError::mls_client("initializing mls client"))?;
233
234 if mls_client.is_e2ei_capable().await {
235 let client_id = mls_client
236 .id()
237 .await
238 .map_err(RecursiveError::mls_client("getting client id"))?;
239 log::trace!(client_id:% = client_id; "Initializing PKI environment");
240 self.init_pki_env().await?;
241 }
242
243 Ok(())
244 }
245
246 pub async fn client_public_key(
248 &self,
249 ciphersuite: MlsCiphersuite,
250 credential_type: MlsCredentialType,
251 ) -> Result<Vec<u8>> {
252 let cb = self
253 .session()
254 .await?
255 .find_most_recent_credential_bundle(ciphersuite.signature_algorithm(), credential_type)
256 .await
257 .map_err(RecursiveError::mls_client("finding most recent credential bundle"))?;
258 Ok(cb.signature_key.to_public_vec())
259 }
260
261 pub async fn client_id(&self) -> Result<ClientId> {
263 self.session()
264 .await?
265 .id()
266 .await
267 .map_err(RecursiveError::mls_client("getting client id"))
268 .map_err(Into::into)
269 }
270
271 pub async fn random_bytes(&self, len: usize) -> Result<Vec<u8>> {
273 use openmls_traits::random::OpenMlsRand as _;
274 self.mls_provider()
275 .await?
276 .rand()
277 .random_vec(len)
278 .map_err(MlsError::wrap("generating random vector"))
279 .map_err(Into::into)
280 }
281
282 pub async fn set_data(&self, data: Vec<u8>) -> Result<()> {
286 self.keystore()
287 .await?
288 .save(ConsumerData::from(data))
289 .await
290 .map_err(KeystoreError::wrap("saving consumer data"))?;
291 Ok(())
292 }
293
294 pub async fn get_data(&self) -> Result<Option<Vec<u8>>> {
297 match self.keystore().await?.find_unique::<ConsumerData>().await {
298 Ok(data) => Ok(Some(data.into())),
299 Err(CryptoKeystoreError::NotFound(..)) => Ok(None),
300 Err(err) => Err(KeystoreError::wrap("finding unique consumer data")(err).into()),
301 }
302 }
303}