core_crypto_keystore/transaction/
mod.rs1use std::collections::HashMap;
2use std::collections::hash_map::Entry;
3use std::sync::Arc;
4
5use async_lock::{RwLock, SemaphoreGuardArc};
6use itertools::Itertools;
7use zeroize::Zeroizing;
8
9use crate::connection::KeystoreDatabaseConnection;
10use crate::entities::mls::*;
11#[cfg(feature = "proteus-keystore")]
12use crate::entities::proteus::*;
13use crate::entities::{ConsumerData, EntityBase, EntityFindParams, EntityTransactionExt, UniqueEntity};
14use crate::transaction::dynamic_dispatch::EntityId;
15use crate::{CryptoKeystoreError, CryptoKeystoreResult, connection::Connection};
16
17pub mod dynamic_dispatch;
18
19#[derive(Debug, Default, derive_more::Deref, derive_more::DerefMut)]
20struct InMemoryTable(HashMap<Vec<u8>, Zeroizing<Vec<u8>>>);
21
22type InMemoryCache = Arc<RwLock<HashMap<String, InMemoryTable>>>;
23
/// Staged, not-yet-committed changes of a keystore transaction.
///
/// Saved entities are kept serialized in `cache`. Removals are recorded in `deleted` (by entity
/// id) and `deleted_credentials` (by raw credential bytes). Both are only written to the
/// database by `commit`. Every field is behind an `Arc`, so clones of this struct share the
/// same staged state.
#[derive(Debug, Clone)]
pub(crate) struct KeystoreTransaction {
    /// Serialized entities saved during this transaction, grouped by collection name.
    cache: InMemoryCache,
    /// Entity ids removed during this transaction.
    deleted: Arc<RwLock<Vec<EntityId>>>,
    /// Raw credential bytes whose `MlsCredential` records are deleted on commit.
    deleted_credentials: Arc<RwLock<Vec<Vec<u8>>>>,
    // Held for as long as any clone of the transaction is alive; released when the last `Arc`
    // is dropped. Fields drop in declaration order, so this guard goes after the staged data.
    // NOTE(review): presumably this guard serializes concurrent transactions — confirm with the caller.
    _semaphore_guard: Arc<SemaphoreGuardArc>,
}
33
34impl KeystoreTransaction {
35 pub(crate) async fn new(semaphore_guard: SemaphoreGuardArc) -> CryptoKeystoreResult<Self> {
36 Ok(Self {
37 cache: Default::default(),
38 deleted: Arc::new(Default::default()),
39 deleted_credentials: Arc::new(Default::default()),
40 _semaphore_guard: Arc::new(semaphore_guard),
41 })
42 }
43
44 pub(crate) async fn save_mut<
45 E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection> + EntityTransactionExt + Sync,
46 >(
47 &self,
48 mut entity: E,
49 ) -> CryptoKeystoreResult<E> {
50 entity.pre_save().await?;
51 let mut cache_guard = self.cache.write().await;
52 let table = cache_guard.entry(E::COLLECTION_NAME.to_string()).or_default();
53 let serialized = postcard::to_stdvec(&entity)?;
54 table.insert(entity.merge_key(), Zeroizing::new(serialized));
58 Ok(entity)
59 }
60
61 pub(crate) async fn remove<
62 E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection> + EntityTransactionExt,
63 S: AsRef<[u8]>,
64 >(
65 &self,
66 id: S,
67 ) -> CryptoKeystoreResult<()> {
68 let mut cache_guard = self.cache.write().await;
69 if let Entry::Occupied(mut table) = cache_guard.entry(E::COLLECTION_NAME.to_string())
70 && let Entry::Occupied(cached_record) = table.get_mut().entry(id.as_ref().to_vec())
71 {
72 cached_record.remove_entry();
73 };
74
75 let mut deleted_list = self.deleted.write().await;
76 deleted_list.push(EntityId::from_collection_name(E::COLLECTION_NAME, id.as_ref())?);
77 Ok(())
78 }
79
80 pub(crate) async fn child_groups<E>(&self, entity: E, persisted_records: Vec<E>) -> CryptoKeystoreResult<Vec<E>>
81 where
82 E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection> + PersistedMlsGroupExt + Sync,
83 {
84 let cached_records = self
87 .find_all_in_cache()
88 .await?
89 .into_iter()
90 .filter(|maybe_child: &E| {
91 maybe_child
92 .parent_id()
93 .map(|parent_id| parent_id == entity.id_raw())
94 .unwrap_or_default()
95 })
96 .collect();
97
98 Ok(self
99 .merge_records(cached_records, persisted_records, EntityFindParams::default())
100 .await)
101 }
102
103 pub(crate) async fn cred_delete_by_credential(&self, cred: Vec<u8>) -> CryptoKeystoreResult<()> {
104 let mut cache_guard = self.cache.write().await;
105 if let Entry::Occupied(mut table) = cache_guard.entry(MlsCredential::COLLECTION_NAME.to_string()) {
106 table.get_mut().retain(|_, value| **value != cred);
107 }
108
109 let mut deleted_list = self.deleted_credentials.write().await;
110 deleted_list.push(cred);
111 Ok(())
112 }
113
114 pub(crate) async fn remove_pending_messages_by_conversation_id(
115 &self,
116 conversation_id: &[u8],
117 ) -> CryptoKeystoreResult<()> {
118 let mut result = Ok(());
120
121 let mut cache_guard = self.cache.write().await;
122 if let Entry::Occupied(mut table) = cache_guard.entry(MlsPendingMessage::COLLECTION_NAME.to_string()) {
123 table.get_mut().retain(|_key, record_bytes| {
124 postcard::from_bytes::<MlsPendingMessage>(record_bytes)
125 .map(|pending_message| pending_message.foreign_id != conversation_id)
126 .inspect_err(|err| result = Err(err.clone()))
127 .unwrap_or(false)
128 });
129 }
130
131 let mut deleted_list = self.deleted.write().await;
132 deleted_list.push(EntityId::from_collection_name(
133 MlsPendingMessage::COLLECTION_NAME,
134 conversation_id,
135 )?);
136 result.map_err(Into::into)
137 }
138
139 pub(crate) async fn find_pending_messages_by_conversation_id(
140 &self,
141 conversation_id: &[u8],
142 persisted_records: Vec<MlsPendingMessage>,
143 ) -> CryptoKeystoreResult<Vec<MlsPendingMessage>> {
144 let cached_records = self
145 .find_all_in_cache::<MlsPendingMessage>()
146 .await?
147 .into_iter()
148 .filter(|pending_message| pending_message.foreign_id == conversation_id)
149 .collect();
150 let merged_records = self
151 .merge_records(cached_records, persisted_records, Default::default())
152 .await;
153 Ok(merged_records)
154 }
155
156 async fn find_in_cache<E>(&self, id: &[u8]) -> CryptoKeystoreResult<Option<E>>
157 where
158 E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>,
159 {
160 let cache_guard = self.cache.read().await;
161 cache_guard
162 .get(E::COLLECTION_NAME)
163 .and_then(|table| {
164 table
165 .get(id)
166 .map(|record| -> CryptoKeystoreResult<_> { postcard::from_bytes::<E>(record).map_err(Into::into) })
167 })
168 .transpose()
169 }
170
171 pub(crate) async fn find<E>(&self, id: &[u8]) -> CryptoKeystoreResult<Option<Option<E>>>
176 where
177 E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>,
178 {
179 let maybe_cached_record = self.find_in_cache(id).await?;
180 if let Some(cached_record) = maybe_cached_record {
181 return Ok(Some(Some(cached_record)));
182 }
183
184 let deleted_list = self.deleted.read().await;
185 if deleted_list.contains(&EntityId::from_collection_name(E::COLLECTION_NAME, id)?) {
186 return Ok(Some(None));
187 }
188
189 Ok(None)
190 }
191
192 pub(crate) async fn find_unique<U: UniqueEntity<ConnectionType = KeystoreDatabaseConnection>>(
193 &self,
194 ) -> CryptoKeystoreResult<Option<U>> {
195 #[cfg(target_family = "wasm")]
196 let id = &U::ID;
197 #[cfg(not(target_family = "wasm"))]
198 let id = &[U::ID as u8];
199 let maybe_cached_record = self.find_in_cache::<U>(id).await?;
200 match maybe_cached_record {
201 Some(cached_record) => Ok(Some(cached_record)),
202 _ => {
203 Ok(None)
206 }
207 }
208 }
209
210 async fn find_all_in_cache<E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>>(
211 &self,
212 ) -> CryptoKeystoreResult<Vec<E>> {
213 let cache_guard = self.cache.read().await;
214 let cached_records = cache_guard
215 .get(E::COLLECTION_NAME)
216 .map(|table| {
217 table
218 .values()
219 .map(|record| postcard::from_bytes::<E>(record).map_err(Into::into))
220 .collect::<CryptoKeystoreResult<Vec<_>>>()
221 })
222 .transpose()?
223 .unwrap_or_default();
224 Ok(cached_records)
225 }
226
227 pub(crate) async fn find_all<E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>>(
228 &self,
229 persisted_records: Vec<E>,
230 params: EntityFindParams,
231 ) -> CryptoKeystoreResult<Vec<E>> {
232 let cached_records = self.find_all_in_cache().await?;
233 let merged_records = self.merge_records(cached_records, persisted_records, params).await;
234 Ok(merged_records)
235 }
236
237 pub(crate) async fn find_many<E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>>(
238 &self,
239 persisted_records: Vec<E>,
240 ids: &[Vec<u8>],
241 ) -> CryptoKeystoreResult<Vec<E>> {
242 let records = self
243 .find_all(persisted_records, EntityFindParams::default())
244 .await?
245 .into_iter()
246 .filter(|record| ids.contains(&record.id_raw().to_vec()))
247 .collect();
248 Ok(records)
249 }
250
251 async fn merge_records<E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>>(
258 &self,
259 records_a: Vec<E>,
260 records_b: Vec<E>,
261 params: EntityFindParams,
262 ) -> Vec<E> {
263 let mut merged = records_a.into_iter().chain(records_b).unique_by(|e| e.merge_key());
264
265 let deleted_records = self.deleted.read().await;
266 let deleted_credentials = self.deleted_credentials.read().await;
267
268 let merged: &mut dyn Iterator<Item = E> = if params.reverse { &mut merged.rev() } else { &mut merged };
269
270 merged
271 .filter(|record| {
272 !Self::record_is_in_deleted_list(record, &deleted_records)
273 && !Self::credential_is_in_deleted_list(record, &deleted_credentials)
274 })
275 .skip(params.offset.unwrap_or(0) as usize)
276 .take(params.limit.unwrap_or(u32::MAX) as usize)
277 .collect()
278 }
279
280 fn record_is_in_deleted_list<E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>>(
281 record: &E,
282 deleted_records: &[EntityId],
283 ) -> bool {
284 let id = EntityId::from_collection_name(E::COLLECTION_NAME, record.id_raw());
285 let Ok(id) = id else { return false };
286 deleted_records.contains(&id)
287 }
288
289 fn credential_is_in_deleted_list<E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>>(
290 maybe_credential: &E,
291 deleted_credentials: &[Vec<u8>],
292 ) -> bool {
293 let Some(credential) = maybe_credential.downcast::<MlsCredential>() else {
294 return false;
295 };
296 deleted_credentials.contains(&credential.credential)
297 }
298}
299
/// Writes everything staged in a `KeystoreTransaction` to the database in one database transaction.
///
/// Usage: `commit_transaction!(transaction, db, [(ident, Entity), ...], proteus_types: [(ident, Entity), ...])`.
/// Each `(ident, Entity)` pair names the local binding that holds the staged records of
/// `Entity`. The `proteus_types` group is only included when the `proteus-keystore` feature is
/// enabled.
///
/// The expansion, in order:
/// 1. decodes the staged records of every listed entity type;
/// 2. collects the collections the transaction touches, returning `Ok(())` from the enclosing
///    function if there are none;
/// 3. saves every staged record;
/// 4. deletes every removed entity id;
/// 5. deletes credentials staged via `cred_delete_by_credential`;
/// 6. commits.
///
/// It uses `.await`, `?` and `return Ok(())`, so it must be expanded inside an `async fn` that
/// returns a compatible `Result`.
macro_rules! commit_transaction {
    ($keystore_transaction:expr_2021, $db:expr_2021, [ $( ($records:ident, $entity:ty) ),*], proteus_types: [ $( ($conditional_records:ident, $conditional_entity:ty) ),*]) => {
        #[cfg(feature = "proteus-keystore")]
        commit_transaction!($keystore_transaction, $db, [ $( ($records, $entity) ),*], [ $( ($conditional_records, $conditional_entity) ),*]);

        #[cfg(not(feature = "proteus-keystore"))]
        commit_transaction!($keystore_transaction, $db, [ $( ($records, $entity) ),*]);
    };
    ($keystore_transaction:expr_2021, $db:expr_2021, $([ $( ($records:ident, $entity:ty) ),*]),*) => {
        let cached_collections = ( $( $(
            $keystore_transaction.find_all_in_cache::<$entity>().await?,
        )* )* );

        let ( $( $( $records, )* )* ) = cached_collections;

        let conn = $db.borrow_conn().await?;
        let mut conn = conn.conn().await;
        let deleted_ids = $keystore_transaction.deleted.read().await;
        let deleted_credentials = $keystore_transaction.deleted_credentials.read().await;

        let mut tables = Vec::new();
        $( $(
            if !$records.is_empty() {
                tables.push(<$entity>::COLLECTION_NAME);
            }
        )* )*

        for deleted_id in deleted_ids.iter() {
            tables.push(deleted_id.collection_name());
        }

        // Credential deletions are tracked separately from `deleted_ids`. Without this, a
        // transaction that only deletes credentials would be treated as empty and silently
        // dropped, and on wasm the credential store would be missing from the transaction scope.
        if !deleted_credentials.is_empty() {
            tables.push(MlsCredential::COLLECTION_NAME);
        }

        if tables.is_empty() {
            log::debug!("Empty transaction was committed.");
            return Ok(());
        }

        #[cfg(target_family = "wasm")]
        let tx = conn.new_transaction(&tables).await?;
        #[cfg(not(target_family = "wasm"))]
        let tx = conn.transaction()?.into();

        $( $(
            for record in $records {
                dynamic_dispatch::execute_save(&tx, &record.to_transaction_entity()).await?;
            }
        )* )*

        for deleted_id in deleted_ids.iter() {
            dynamic_dispatch::execute_delete(&tx, deleted_id).await?
        }

        for deleted_credential in deleted_credentials.iter() {
            MlsCredential::delete_by_credential(&tx, deleted_credential.to_owned()).await?;
        }

        tx.commit_tx().await?;
    };
}
393
394impl KeystoreTransaction {
395 pub(crate) async fn commit(&self, db: &Connection) -> Result<(), CryptoKeystoreError> {
398 commit_transaction!(
399 self, db,
400 [
401 (identifier_01, MlsCredential),
402 (identifier_02, MlsSignatureKeyPair),
403 (identifier_03, MlsHpkePrivateKey),
404 (identifier_04, MlsEncryptionKeyPair),
405 (identifier_05, MlsEpochEncryptionKeyPair),
406 (identifier_06, MlsPskBundle),
407 (identifier_07, MlsKeyPackage),
408 (identifier_08, PersistedMlsGroup),
409 (identifier_09, PersistedMlsPendingGroup),
410 (identifier_10, MlsPendingMessage),
411 (identifier_11, E2eiEnrollment),
412 (identifier_13, E2eiAcmeCA),
414 (identifier_14, E2eiIntermediateCert),
415 (identifier_15, E2eiCrl),
416 (identifier_16, ConsumerData)
417 ],
418 proteus_types: [
419 (identifier_17, ProteusPrekey),
420 (identifier_18, ProteusIdentity),
421 (identifier_19, ProteusSession)
422 ]
423 );
424
425 Ok(())
426 }
427}