core_crypto/transaction_context/
mod.rs1use std::{ops::Deref, sync::Arc};
5
6#[cfg(feature = "proteus")]
7use async_lock::Mutex;
8use async_lock::{RwLock, RwLockReadGuardArc, RwLockWriteGuardArc};
9use core_crypto_keystore::{CryptoKeystoreError, connection::FetchFromDatabase, entities::ConsumerData};
10pub use error::{Error, Result};
11use mls_crypto_provider::{Database, MlsCryptoProvider};
12use openmls_traits::OpenMlsCryptoProvider as _;
13
14#[cfg(feature = "proteus")]
15use crate::proteus::ProteusCentral;
16use crate::{
17 ClientId, ClientIdentifier, CoreCrypto, KeystoreError, MlsCiphersuite, MlsConversation, MlsCredentialType,
18 MlsError, MlsTransport, RecursiveError, Session, group_store::GroupStore, mls::HasSessionAndCrypto,
19};
20pub mod conversation;
21pub mod e2e_identity;
22mod error;
23pub mod key_package;
24#[cfg(feature = "proteus")]
25pub mod proteus;
26#[cfg(test)]
27pub mod test_utils;
28
29#[derive(Debug, Clone)]
36pub struct TransactionContext {
37 inner: Arc<RwLock<TransactionContextInner>>,
38}
39
40#[derive(Debug, Clone)]
44enum TransactionContextInner {
45 Valid {
46 provider: MlsCryptoProvider,
47 transport: Arc<RwLock<Option<Arc<dyn MlsTransport + 'static>>>>,
48 mls_client: Session,
49 mls_groups: Arc<RwLock<GroupStore<MlsConversation>>>,
50 #[cfg(feature = "proteus")]
51 proteus_central: Arc<Mutex<Option<ProteusCentral>>>,
52 },
53 Invalid,
54}
55
56impl CoreCrypto {
57 pub async fn new_transaction(&self) -> Result<TransactionContext> {
61 TransactionContext::new(
62 &self.mls,
63 #[cfg(feature = "proteus")]
64 self.proteus.clone(),
65 )
66 .await
67 }
68}
69
70#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
71#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
72impl HasSessionAndCrypto for TransactionContext {
73 async fn session(&self) -> crate::mls::Result<Session> {
74 self.session()
75 .await
76 .map_err(RecursiveError::transaction("getting mls client"))
77 .map_err(Into::into)
78 }
79
80 async fn crypto_provider(&self) -> crate::mls::Result<MlsCryptoProvider> {
81 self.mls_provider()
82 .await
83 .map_err(RecursiveError::transaction("getting mls provider"))
84 .map_err(Into::into)
85 }
86}
87
88impl TransactionContext {
89 async fn new(
90 client: &Session,
91 #[cfg(feature = "proteus")] proteus_central: Arc<Mutex<Option<ProteusCentral>>>,
92 ) -> Result<Self> {
93 client
94 .crypto_provider
95 .new_transaction()
96 .await
97 .map_err(MlsError::wrap("creating new transaction"))?;
98 let mls_groups = Arc::new(RwLock::new(Default::default()));
99 let callbacks = client.transport.clone();
100 let mls_client = client.clone();
101 Ok(Self {
102 inner: Arc::new(
103 TransactionContextInner::Valid {
104 mls_client,
105 transport: callbacks,
106 provider: client.crypto_provider.clone(),
107 mls_groups,
108 #[cfg(feature = "proteus")]
109 proteus_central,
110 }
111 .into(),
112 ),
113 })
114 }
115
116 pub(crate) async fn session(&self) -> Result<Session> {
117 match self.inner.read().await.deref() {
118 TransactionContextInner::Valid { mls_client, .. } => Ok(mls_client.clone()),
119 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
120 }
121 }
122
123 pub(crate) async fn mls_transport(&self) -> Result<RwLockReadGuardArc<Option<Arc<dyn MlsTransport + 'static>>>> {
124 match self.inner.read().await.deref() {
125 TransactionContextInner::Valid {
126 transport: callbacks, ..
127 } => Ok(callbacks.read_arc().await),
128 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
129 }
130 }
131
132 #[cfg(test)]
133 pub(crate) async fn set_transport_callbacks(
134 &self,
135 callbacks: Option<Arc<dyn MlsTransport + 'static>>,
136 ) -> Result<()> {
137 match self.inner.read().await.deref() {
138 TransactionContextInner::Valid { transport: cbs, .. } => {
139 *cbs.write_arc().await = callbacks;
140 Ok(())
141 }
142 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
143 }
144 }
145
146 pub async fn mls_provider(&self) -> Result<MlsCryptoProvider> {
148 match self.inner.read().await.deref() {
149 TransactionContextInner::Valid { provider, .. } => Ok(provider.clone()),
150 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
151 }
152 }
153
154 pub(crate) async fn keystore(&self) -> Result<Database> {
155 match self.inner.read().await.deref() {
156 TransactionContextInner::Valid { provider, .. } => Ok(provider.keystore()),
157 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
158 }
159 }
160
161 pub(crate) async fn mls_groups(&self) -> Result<RwLockWriteGuardArc<GroupStore<MlsConversation>>> {
162 match self.inner.read().await.deref() {
163 TransactionContextInner::Valid { mls_groups, .. } => Ok(mls_groups.write_arc().await),
164 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
165 }
166 }
167
168 #[cfg(feature = "proteus")]
169 pub(crate) async fn proteus_central(&self) -> Result<Arc<Mutex<Option<ProteusCentral>>>> {
170 match self.inner.read().await.deref() {
171 TransactionContextInner::Valid { proteus_central, .. } => Ok(proteus_central.clone()),
172 TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
173 }
174 }
175
176 pub async fn finish(&self) -> Result<()> {
180 let mut guard = self.inner.write().await;
181 let TransactionContextInner::Valid { provider, .. } = guard.deref() else {
182 return Err(Error::InvalidTransactionContext);
183 };
184
185 let commit_result = provider
186 .keystore()
187 .commit_transaction()
188 .await
189 .map_err(KeystoreError::wrap("commiting transaction"))
190 .map_err(Into::into);
191
192 *guard = TransactionContextInner::Invalid;
193 commit_result
194 }
195
196 pub async fn abort(&self) -> Result<()> {
200 let mut guard = self.inner.write().await;
201
202 let TransactionContextInner::Valid { provider, .. } = guard.deref() else {
203 return Err(Error::InvalidTransactionContext);
204 };
205
206 let result = provider
207 .keystore()
208 .rollback_transaction()
209 .await
210 .map_err(KeystoreError::wrap("rolling back transaction"))
211 .map_err(Into::into);
212
213 *guard = TransactionContextInner::Invalid;
214 result
215 }
216
217 pub async fn mls_init(&self, identifier: ClientIdentifier, ciphersuites: &[MlsCiphersuite]) -> Result<()> {
219 let mls_client = self.session().await?;
220 mls_client
221 .init(identifier, ciphersuites, &self.mls_provider().await?)
222 .await
223 .map_err(RecursiveError::mls_client("initializing mls client"))?;
224
225 if mls_client.is_e2ei_capable().await {
226 let client_id = mls_client
227 .id()
228 .await
229 .map_err(RecursiveError::mls_client("getting client id"))?;
230 log::trace!(client_id:% = client_id; "Initializing PKI environment");
231 self.init_pki_env().await?;
232 }
233
234 Ok(())
235 }
236
237 pub async fn client_public_key(
239 &self,
240 ciphersuite: MlsCiphersuite,
241 credential_type: MlsCredentialType,
242 ) -> Result<Vec<u8>> {
243 let cb = self
244 .session()
245 .await?
246 .find_most_recent_credential_bundle(ciphersuite.signature_algorithm(), credential_type)
247 .await
248 .map_err(RecursiveError::mls_client("finding most recent credential bundle"))?;
249 Ok(cb.signature_key.to_public_vec())
250 }
251
252 pub async fn client_id(&self) -> Result<ClientId> {
254 self.session()
255 .await?
256 .id()
257 .await
258 .map_err(RecursiveError::mls_client("getting client id"))
259 .map_err(Into::into)
260 }
261
262 pub async fn random_bytes(&self, len: usize) -> Result<Vec<u8>> {
264 use openmls_traits::random::OpenMlsRand as _;
265 self.mls_provider()
266 .await?
267 .rand()
268 .random_vec(len)
269 .map_err(MlsError::wrap("generating random vector"))
270 .map_err(Into::into)
271 }
272
273 pub async fn set_data(&self, data: Vec<u8>) -> Result<()> {
277 self.keystore()
278 .await?
279 .save(ConsumerData::from(data))
280 .await
281 .map_err(KeystoreError::wrap("saving consumer data"))?;
282 Ok(())
283 }
284
285 pub async fn get_data(&self) -> Result<Option<Vec<u8>>> {
288 match self.keystore().await?.find_unique::<ConsumerData>().await {
289 Ok(data) => Ok(Some(data.into())),
290 Err(CryptoKeystoreError::NotFound(..)) => Ok(None),
291 Err(err) => Err(KeystoreError::wrap("finding unique consumer data")(err).into()),
292 }
293 }
294}