core_crypto_keystore/transaction/
mod.rs1use std::{
2 collections::{HashMap, hash_map::Entry},
3 sync::Arc,
4};
5
6use async_lock::{RwLock, SemaphoreGuardArc};
7use itertools::Itertools;
8use zeroize::Zeroizing;
9
10#[cfg(feature = "proteus-keystore")]
11use crate::entities::proteus::*;
12use crate::{
13 CryptoKeystoreError, CryptoKeystoreResult,
14 connection::{Database, KeystoreDatabaseConnection},
15 entities::{ConsumerData, EntityBase, EntityFindParams, EntityTransactionExt, UniqueEntity, mls::*},
16 transaction::dynamic_dispatch::EntityId,
17};
18
19pub mod dynamic_dispatch;
20
21#[derive(Debug, Default, derive_more::Deref, derive_more::DerefMut)]
22struct InMemoryTable(HashMap<Vec<u8>, Zeroizing<Vec<u8>>>);
23
24type InMemoryCache = Arc<RwLock<HashMap<String, InMemoryTable>>>;
25
26#[derive(Debug, Clone)]
29pub(crate) struct KeystoreTransaction {
30 cache: InMemoryCache,
31 deleted: Arc<RwLock<Vec<EntityId>>>,
32 deleted_credentials: Arc<RwLock<Vec<Vec<u8>>>>,
33 _semaphore_guard: Arc<SemaphoreGuardArc>,
34}
35
36impl KeystoreTransaction {
37 pub(crate) async fn new(semaphore_guard: SemaphoreGuardArc) -> CryptoKeystoreResult<Self> {
38 Ok(Self {
39 cache: Default::default(),
40 deleted: Arc::new(Default::default()),
41 deleted_credentials: Arc::new(Default::default()),
42 _semaphore_guard: Arc::new(semaphore_guard),
43 })
44 }
45
46 pub(crate) async fn save_mut<
47 E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection> + EntityTransactionExt + Sync,
48 >(
49 &self,
50 mut entity: E,
51 ) -> CryptoKeystoreResult<E> {
52 entity.pre_save().await?;
53 let mut cache_guard = self.cache.write().await;
54 let table = cache_guard.entry(E::COLLECTION_NAME.to_string()).or_default();
55 let serialized = postcard::to_stdvec(&entity)?;
56 table.insert(entity.merge_key(), Zeroizing::new(serialized));
60 Ok(entity)
61 }
62
63 pub(crate) async fn remove<
64 E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection> + EntityTransactionExt,
65 S: AsRef<[u8]>,
66 >(
67 &self,
68 id: S,
69 ) -> CryptoKeystoreResult<()> {
70 let mut cache_guard = self.cache.write().await;
71 if let Entry::Occupied(mut table) = cache_guard.entry(E::COLLECTION_NAME.to_string())
72 && let Entry::Occupied(cached_record) = table.get_mut().entry(id.as_ref().to_vec())
73 {
74 cached_record.remove_entry();
75 };
76
77 let mut deleted_list = self.deleted.write().await;
78 deleted_list.push(EntityId::from_collection_name(E::COLLECTION_NAME, id.as_ref())?);
79 Ok(())
80 }
81
82 pub(crate) async fn child_groups<E>(&self, entity: E, persisted_records: Vec<E>) -> CryptoKeystoreResult<Vec<E>>
83 where
84 E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection> + PersistedMlsGroupExt + Sync,
85 {
86 let cached_records = self
89 .find_all_in_cache()
90 .await?
91 .into_iter()
92 .filter(|maybe_child: &E| {
93 maybe_child
94 .parent_id()
95 .map(|parent_id| parent_id == entity.id_raw())
96 .unwrap_or_default()
97 })
98 .collect();
99
100 Ok(self
101 .merge_records(cached_records, persisted_records, EntityFindParams::default())
102 .await)
103 }
104
105 pub(crate) async fn cred_delete_by_credential(&self, cred: Vec<u8>) -> CryptoKeystoreResult<()> {
106 let mut cache_guard = self.cache.write().await;
107 if let Entry::Occupied(mut table) = cache_guard.entry(MlsCredential::COLLECTION_NAME.to_string()) {
108 table.get_mut().retain(|_, value| **value != cred);
109 }
110
111 let mut deleted_list = self.deleted_credentials.write().await;
112 deleted_list.push(cred);
113 Ok(())
114 }
115
116 pub(crate) async fn remove_pending_messages_by_conversation_id(
117 &self,
118 conversation_id: impl AsRef<[u8]> + Send,
119 ) -> CryptoKeystoreResult<()> {
120 let mut result = Ok(());
122
123 let mut cache_guard = self.cache.write().await;
124 if let Entry::Occupied(mut table) = cache_guard.entry(MlsPendingMessage::COLLECTION_NAME.to_string()) {
125 table.get_mut().retain(|_key, record_bytes| {
126 postcard::from_bytes::<MlsPendingMessage>(record_bytes)
127 .map(|pending_message| pending_message.foreign_id != conversation_id.as_ref())
128 .inspect_err(|err| result = Err(err.clone()))
129 .unwrap_or(false)
130 });
131 }
132
133 let mut deleted_list = self.deleted.write().await;
134 deleted_list.push(EntityId::from_collection_name(
135 MlsPendingMessage::COLLECTION_NAME,
136 conversation_id.as_ref(),
137 )?);
138 result.map_err(Into::into)
139 }
140
141 pub(crate) async fn find_pending_messages_by_conversation_id(
142 &self,
143 conversation_id: &[u8],
144 persisted_records: Vec<MlsPendingMessage>,
145 ) -> CryptoKeystoreResult<Vec<MlsPendingMessage>> {
146 let cached_records = self
147 .find_all_in_cache::<MlsPendingMessage>()
148 .await?
149 .into_iter()
150 .filter(|pending_message| pending_message.foreign_id == conversation_id)
151 .collect();
152 let merged_records = self
153 .merge_records(cached_records, persisted_records, Default::default())
154 .await;
155 Ok(merged_records)
156 }
157
158 async fn find_in_cache<E>(&self, id: &[u8]) -> CryptoKeystoreResult<Option<E>>
159 where
160 E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>,
161 {
162 let cache_guard = self.cache.read().await;
163 cache_guard
164 .get(E::COLLECTION_NAME)
165 .and_then(|table| {
166 table
167 .get(id)
168 .map(|record| -> CryptoKeystoreResult<_> { postcard::from_bytes::<E>(record).map_err(Into::into) })
169 })
170 .transpose()
171 }
172
173 pub(crate) async fn find<E>(&self, id: &[u8]) -> CryptoKeystoreResult<Option<Option<E>>>
178 where
179 E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>,
180 {
181 let maybe_cached_record = self.find_in_cache(id).await?;
182 if let Some(cached_record) = maybe_cached_record {
183 return Ok(Some(Some(cached_record)));
184 }
185
186 let deleted_list = self.deleted.read().await;
187 if deleted_list.contains(&EntityId::from_collection_name(E::COLLECTION_NAME, id)?) {
188 return Ok(Some(None));
189 }
190
191 Ok(None)
192 }
193
194 pub(crate) async fn find_unique<U: UniqueEntity<ConnectionType = KeystoreDatabaseConnection>>(
195 &self,
196 ) -> CryptoKeystoreResult<Option<U>> {
197 #[cfg(target_family = "wasm")]
198 let id = &U::ID;
199 #[cfg(not(target_family = "wasm"))]
200 let id = &[U::ID as u8];
201 let maybe_cached_record = self.find_in_cache::<U>(id).await?;
202 match maybe_cached_record {
203 Some(cached_record) => Ok(Some(cached_record)),
204 _ => {
205 Ok(None)
208 }
209 }
210 }
211
212 async fn find_all_in_cache<E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>>(
213 &self,
214 ) -> CryptoKeystoreResult<Vec<E>> {
215 let cache_guard = self.cache.read().await;
216 let cached_records = cache_guard
217 .get(E::COLLECTION_NAME)
218 .map(|table| {
219 table
220 .values()
221 .map(|record| postcard::from_bytes::<E>(record).map_err(Into::into))
222 .collect::<CryptoKeystoreResult<Vec<_>>>()
223 })
224 .transpose()?
225 .unwrap_or_default();
226 Ok(cached_records)
227 }
228
229 pub(crate) async fn find_all<E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>>(
230 &self,
231 persisted_records: Vec<E>,
232 params: EntityFindParams,
233 ) -> CryptoKeystoreResult<Vec<E>> {
234 let cached_records = self.find_all_in_cache().await?;
235 let merged_records = self.merge_records(cached_records, persisted_records, params).await;
236 Ok(merged_records)
237 }
238
239 pub(crate) async fn find_many<E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>>(
240 &self,
241 persisted_records: Vec<E>,
242 ids: &[Vec<u8>],
243 ) -> CryptoKeystoreResult<Vec<E>> {
244 let records = self
245 .find_all(persisted_records, EntityFindParams::default())
246 .await?
247 .into_iter()
248 .filter(|record| ids.contains(&record.id_raw().to_vec()))
249 .collect();
250 Ok(records)
251 }
252
253 async fn merge_records<E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>>(
260 &self,
261 records_a: Vec<E>,
262 records_b: Vec<E>,
263 params: EntityFindParams,
264 ) -> Vec<E> {
265 let mut merged = records_a.into_iter().chain(records_b).unique_by(|e| e.merge_key());
266
267 let deleted_records = self.deleted.read().await;
268 let deleted_credentials = self.deleted_credentials.read().await;
269
270 let merged: &mut dyn Iterator<Item = E> = if params.reverse { &mut merged.rev() } else { &mut merged };
271
272 merged
273 .filter(|record| {
274 !Self::record_is_in_deleted_list(record, &deleted_records)
275 && !Self::credential_is_in_deleted_list(record, &deleted_credentials)
276 })
277 .skip(params.offset.unwrap_or(0) as usize)
278 .take(params.limit.unwrap_or(u32::MAX) as usize)
279 .collect()
280 }
281
282 fn record_is_in_deleted_list<E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>>(
283 record: &E,
284 deleted_records: &[EntityId],
285 ) -> bool {
286 let id = EntityId::from_collection_name(E::COLLECTION_NAME, record.id_raw());
287 let Ok(id) = id else { return false };
288 deleted_records.contains(&id)
289 }
290
291 fn credential_is_in_deleted_list<E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>>(
292 maybe_credential: &E,
293 deleted_credentials: &[Vec<u8>],
294 ) -> bool {
295 let Some(credential) = maybe_credential.downcast::<MlsCredential>() else {
296 return false;
297 };
298 deleted_credentials.contains(&credential.credential)
299 }
300}
301
/// Writes everything staged in a `KeystoreTransaction` to the database inside one database
/// transaction.
///
/// Invocation forms:
/// - `commit_transaction!(tx, db, [(ident, Type), ...], proteus_types: [(ident, Type), ...])`:
///   the `proteus_types` list is only included when the `proteus-keystore` feature is enabled.
/// - `commit_transaction!(tx, db, [(ident, Type), ...], [...], ...)`: every list is included.
///
/// Each `ident` is bound to the cached records of the paired entity type. The expansion uses `?`
/// and `return Ok(())`, so it must be expanded inside a function returning a compatible `Result`.
///
/// Order of operations on commit: cached records are saved first, then deletions by id are
/// executed, then credential deletions.
macro_rules! commit_transaction {
    ($keystore_transaction:expr_2021, $db:expr_2021, [ $( ($records:ident, $entity:ty) ),*], proteus_types: [ $( ($conditional_records:ident, $conditional_entity:ty) ),*]) => {
        #[cfg(feature = "proteus-keystore")]
        commit_transaction!($keystore_transaction, $db, [ $( ($records, $entity) ),*], [ $( ($conditional_records, $conditional_entity) ),*]);

        #[cfg(not(feature = "proteus-keystore"))]
        commit_transaction!($keystore_transaction, $db, [ $( ($records, $entity) ),*]);
    };
    ($keystore_transaction:expr_2021, $db:expr_2021, $([ $( ($records:ident, $entity:ty) ),*]),*) => {
        let cached_collections = ( $( $(
            $keystore_transaction.find_all_in_cache::<$entity>().await?,
        )* )* );

        let ( $( $( $records, )* )* ) = cached_collections;

        let conn = $db.borrow_conn().await?;
        let mut conn = conn.conn().await;
        let deleted_ids = $keystore_transaction.deleted.read().await;
        let deleted_credentials = $keystore_transaction.deleted_credentials.read().await;

        // Every collection this commit touches.
        let mut tables = Vec::new();
        $( $(
            if !$records.is_empty() {
                tables.push(<$entity>::COLLECTION_NAME);
            }
        )* )*

        for deleted_id in deleted_ids.iter() {
            tables.push(deleted_id.collection_name());
        }

        // Credential deletions write to the credential collection too. Without this entry, a
        // transaction whose only change is such a deletion would take the "empty transaction" early
        // return below and the deletion would be lost. On wasm, the credential store could also be
        // missing from the scope passed to `new_transaction`.
        if !deleted_credentials.is_empty() {
            tables.push(MlsCredential::COLLECTION_NAME);
        }

        // Several deletions in the same collection push its name repeatedly; list each table once.
        tables.sort_unstable();
        tables.dedup();

        if tables.is_empty() {
            log::debug!("Empty transaction was committed.");
            return Ok(());
        }

        #[cfg(target_family = "wasm")]
        let tx = conn.new_transaction(&tables).await?;
        #[cfg(not(target_family = "wasm"))]
        let tx = conn.transaction()?.into();

        $( $(
            for record in $records {
                dynamic_dispatch::execute_save(&tx, &record.to_transaction_entity()).await?;
            }
        )* )*

        for deleted_id in deleted_ids.iter() {
            dynamic_dispatch::execute_delete(&tx, deleted_id).await?;
        }

        for deleted_credential in deleted_credentials.iter() {
            MlsCredential::delete_by_credential(&tx, deleted_credential.to_owned()).await?;
        }

        tx.commit_tx().await?;
    };
}

396impl KeystoreTransaction {
397 pub(crate) async fn commit(&self, db: &Database) -> Result<(), CryptoKeystoreError> {
400 commit_transaction!(
401 self, db,
402 [
403 (identifier_01, MlsCredential),
404 (identifier_02, MlsSignatureKeyPair),
405 (identifier_03, MlsHpkePrivateKey),
406 (identifier_04, MlsEncryptionKeyPair),
407 (identifier_05, MlsEpochEncryptionKeyPair),
408 (identifier_06, MlsPskBundle),
409 (identifier_07, MlsKeyPackage),
410 (identifier_08, PersistedMlsGroup),
411 (identifier_09, PersistedMlsPendingGroup),
412 (identifier_10, MlsPendingMessage),
413 (identifier_11, E2eiEnrollment),
414 (identifier_13, E2eiAcmeCA),
416 (identifier_14, E2eiIntermediateCert),
417 (identifier_15, E2eiCrl),
418 (identifier_16, ConsumerData)
419 ],
420 proteus_types: [
421 (identifier_17, ProteusPrekey),
422 (identifier_18, ProteusIdentity),
423 (identifier_19, ProteusSession)
424 ]
425 );
426
427 Ok(())
428 }
429}