core_crypto_keystore/transaction/
mod.rs1use std::{
2 collections::{HashMap, hash_map::Entry},
3 sync::Arc,
4};
5
6use async_lock::{RwLock, SemaphoreGuardArc};
7use itertools::Itertools;
8use zeroize::Zeroizing;
9
10#[cfg(feature = "proteus-keystore")]
11use crate::entities::proteus::*;
12use crate::{
13 CryptoKeystoreError, CryptoKeystoreResult,
14 connection::{Database, KeystoreDatabaseConnection},
15 entities::{ConsumerData, EntityBase, EntityFindParams, EntityTransactionExt, UniqueEntity, mls::*},
16 transaction::dynamic_dispatch::EntityId,
17};
18
19pub mod dynamic_dispatch;
20
/// One in-memory "table": serialized records indexed by a byte key.
///
/// Values are wrapped in [`Zeroizing`]. `Deref`/`DerefMut` are derived so the
/// wrapper can be used directly as the inner `HashMap`.
#[derive(Debug, Default, derive_more::Deref, derive_more::DerefMut)]
struct InMemoryTable(HashMap<Vec<u8>, Zeroizing<Vec<u8>>>);

/// Shared, lock-protected map from a collection name to its [`InMemoryTable`].
type InMemoryCache = Arc<RwLock<HashMap<String, InMemoryTable>>>;
25
/// Buffer of pending keystore operations.
///
/// Saves are kept as serialized records in `cache`, removals are recorded as
/// [`EntityId`]s in `deleted`, and both are written to the database by `commit`.
/// All fields are behind `Arc`, so clones share the same state.
#[derive(Debug, Clone)]
pub(crate) struct KeystoreTransaction {
    /// Records saved during this transaction, grouped by collection name.
    cache: InMemoryCache,
    /// Ids of entities removed during this transaction.
    deleted: Arc<RwLock<Vec<EntityId>>>,
    /// Never read; kept only so the guard lives as long as any clone of the transaction.
    /// NOTE(review): presumably this is what serializes transactions — confirm at the call site.
    _semaphore_guard: Arc<SemaphoreGuardArc>,
}
34
35impl KeystoreTransaction {
36 pub(crate) async fn new(semaphore_guard: SemaphoreGuardArc) -> CryptoKeystoreResult<Self> {
37 Ok(Self {
38 cache: Default::default(),
39 deleted: Arc::new(Default::default()),
40 _semaphore_guard: Arc::new(semaphore_guard),
41 })
42 }
43
44 pub(crate) async fn save_mut<
45 E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection> + EntityTransactionExt + Sync,
46 >(
47 &self,
48 mut entity: E,
49 ) -> CryptoKeystoreResult<E> {
50 entity.pre_save().await?;
51 let mut cache_guard = self.cache.write().await;
52 let table = cache_guard.entry(E::COLLECTION_NAME.to_string()).or_default();
53 let serialized = postcard::to_stdvec(&entity)?;
54 table.insert(entity.merge_key(), Zeroizing::new(serialized));
57 Ok(entity)
58 }
59
60 pub(crate) async fn remove<
61 E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection> + EntityTransactionExt,
62 S: AsRef<[u8]>,
63 >(
64 &self,
65 id: S,
66 ) -> CryptoKeystoreResult<()> {
67 let mut cache_guard = self.cache.write().await;
68 if let Entry::Occupied(mut table) = cache_guard.entry(E::COLLECTION_NAME.to_string())
69 && let Entry::Occupied(cached_record) = table.get_mut().entry(id.as_ref().to_vec())
70 {
71 cached_record.remove_entry();
72 };
73
74 let mut deleted_list = self.deleted.write().await;
75 deleted_list.push(EntityId::from_collection_name(E::COLLECTION_NAME, id.as_ref())?);
76 Ok(())
77 }
78
79 pub(crate) async fn child_groups<E>(&self, entity: E, persisted_records: Vec<E>) -> CryptoKeystoreResult<Vec<E>>
80 where
81 E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection> + PersistedMlsGroupExt + Sync,
82 {
83 let cached_records = self
86 .find_all_in_cache()
87 .await?
88 .into_iter()
89 .filter(|maybe_child: &E| {
90 maybe_child
91 .parent_id()
92 .map(|parent_id| parent_id == entity.id_raw())
93 .unwrap_or_default()
94 })
95 .collect();
96
97 Ok(self
98 .merge_records(cached_records, persisted_records, EntityFindParams::default())
99 .await)
100 }
101
102 pub(crate) async fn remove_pending_messages_by_conversation_id(
103 &self,
104 conversation_id: impl AsRef<[u8]> + Send,
105 ) -> CryptoKeystoreResult<()> {
106 let mut result = Ok(());
108
109 let mut cache_guard = self.cache.write().await;
110 if let Entry::Occupied(mut table) = cache_guard.entry(MlsPendingMessage::COLLECTION_NAME.to_string()) {
111 table.get_mut().retain(|_key, record_bytes| {
112 postcard::from_bytes::<MlsPendingMessage>(record_bytes)
113 .map(|pending_message| pending_message.foreign_id != conversation_id.as_ref())
114 .inspect_err(|err| result = Err(err.clone()))
115 .unwrap_or(false)
116 });
117 }
118
119 let mut deleted_list = self.deleted.write().await;
120 deleted_list.push(EntityId::from_collection_name(
121 MlsPendingMessage::COLLECTION_NAME,
122 conversation_id.as_ref(),
123 )?);
124 result.map_err(Into::into)
125 }
126
127 pub(crate) async fn find_pending_messages_by_conversation_id(
128 &self,
129 conversation_id: &[u8],
130 persisted_records: Vec<MlsPendingMessage>,
131 ) -> CryptoKeystoreResult<Vec<MlsPendingMessage>> {
132 let cached_records = self
133 .find_all_in_cache::<MlsPendingMessage>()
134 .await?
135 .into_iter()
136 .filter(|pending_message| pending_message.foreign_id == conversation_id)
137 .collect();
138 let merged_records = self
139 .merge_records(cached_records, persisted_records, Default::default())
140 .await;
141 Ok(merged_records)
142 }
143
144 async fn find_in_cache<E>(&self, id: &[u8]) -> CryptoKeystoreResult<Option<E>>
145 where
146 E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>,
147 {
148 let cache_guard = self.cache.read().await;
149 cache_guard
150 .get(E::COLLECTION_NAME)
151 .and_then(|table| {
152 table
153 .get(id)
154 .map(|record| -> CryptoKeystoreResult<_> { postcard::from_bytes::<E>(record).map_err(Into::into) })
155 })
156 .transpose()
157 }
158
159 pub(crate) async fn find<E>(&self, id: &[u8]) -> CryptoKeystoreResult<Option<Option<E>>>
164 where
165 E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>,
166 {
167 let maybe_cached_record = self.find_in_cache(id).await?;
168 if let Some(cached_record) = maybe_cached_record {
169 return Ok(Some(Some(cached_record)));
170 }
171
172 let deleted_list = self.deleted.read().await;
173 if deleted_list.contains(&EntityId::from_collection_name(E::COLLECTION_NAME, id)?) {
174 return Ok(Some(None));
175 }
176
177 Ok(None)
178 }
179
180 pub(crate) async fn find_unique<U: UniqueEntity<ConnectionType = KeystoreDatabaseConnection>>(
181 &self,
182 ) -> CryptoKeystoreResult<Option<U>> {
183 #[cfg(target_family = "wasm")]
184 let id = &U::ID;
185 #[cfg(not(target_family = "wasm"))]
186 let id = &[U::ID as u8];
187 let maybe_cached_record = self.find_in_cache::<U>(id).await?;
188 match maybe_cached_record {
189 Some(cached_record) => Ok(Some(cached_record)),
190 _ => {
191 Ok(None)
194 }
195 }
196 }
197
198 async fn find_all_in_cache<E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>>(
199 &self,
200 ) -> CryptoKeystoreResult<Vec<E>> {
201 let cache_guard = self.cache.read().await;
202 let cached_records = cache_guard
203 .get(E::COLLECTION_NAME)
204 .map(|table| {
205 table
206 .values()
207 .map(|record| postcard::from_bytes::<E>(record).map_err(Into::into))
208 .collect::<CryptoKeystoreResult<Vec<_>>>()
209 })
210 .transpose()?
211 .unwrap_or_default();
212 Ok(cached_records)
213 }
214
215 pub(crate) async fn find_all<E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>>(
216 &self,
217 persisted_records: Vec<E>,
218 params: EntityFindParams,
219 ) -> CryptoKeystoreResult<Vec<E>> {
220 let cached_records = self.find_all_in_cache().await?;
221 let merged_records = self.merge_records(cached_records, persisted_records, params).await;
222 Ok(merged_records)
223 }
224
225 pub(crate) async fn find_many<E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>>(
226 &self,
227 persisted_records: Vec<E>,
228 ids: &[Vec<u8>],
229 ) -> CryptoKeystoreResult<Vec<E>> {
230 let records = self
231 .find_all(persisted_records, EntityFindParams::default())
232 .await?
233 .into_iter()
234 .filter(|record| ids.contains(&record.id_raw().to_vec()))
235 .collect();
236 Ok(records)
237 }
238
239 async fn merge_records<E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>>(
247 &self,
248 records_a: Vec<E>,
249 records_b: Vec<E>,
250 params: EntityFindParams,
251 ) -> Vec<E> {
252 let mut merged = records_a.into_iter().chain(records_b).unique_by(|e| e.merge_key());
253
254 let deleted_records = self.deleted.read().await;
255
256 let merged: &mut dyn Iterator<Item = E> = if params.reverse { &mut merged.rev() } else { &mut merged };
257
258 merged
259 .filter(|record| !Self::record_is_in_deleted_list(record, &deleted_records))
260 .skip(params.offset.unwrap_or(0) as usize)
261 .take(params.limit.unwrap_or(u32::MAX) as usize)
262 .collect()
263 }
264
265 fn record_is_in_deleted_list<E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>>(
266 record: &E,
267 deleted_records: &[EntityId],
268 ) -> bool {
269 let id = EntityId::from_collection_name(E::COLLECTION_NAME, record.id_raw());
270 let Ok(id) = id else { return false };
271 deleted_records.contains(&id)
272 }
273}
274
/// Writes everything buffered in `$keystore_transaction` (first argument) to `$db`
/// (second argument) inside a single database transaction.
///
/// Each `(binding, EntityType)` pair names an entity type whose cached records are to be
/// persisted; `binding` is only a fresh identifier used to hold that type's records.
///
/// The macro expands to statements, not an expression. It uses `?` and `return Ok(())`,
/// so it must be invoked inside an `async` function whose return type accepts both.
///
/// Arms:
/// - With a trailing `proteus_types: [...]` list: the extra pairs are only included when
///   the `proteus-keystore` feature is enabled.
/// - With one or more plain lists: performs the actual commit.
macro_rules! commit_transaction {
    ($keystore_transaction:expr_2021, $db:expr_2021, [ $( ($records:ident, $entity:ty) ),*], proteus_types: [ $( ($conditional_records:ident, $conditional_entity:ty) ),*]) => {
        #[cfg(feature = "proteus-keystore")]
        commit_transaction!($keystore_transaction, $db, [ $( ($records, $entity) ),*], [ $( ($conditional_records, $conditional_entity) ),*]);

        #[cfg(not(feature = "proteus-keystore"))]
        commit_transaction!($keystore_transaction, $db, [ $( ($records, $entity) ),*]);
    };
    ($keystore_transaction:expr_2021, $db:expr_2021, $([ $( ($records:ident, $entity:ty) ),*]),*) => {
        // Read the cached records of every listed entity type into one tuple ...
        let cached_collections = ( $( $(
            $keystore_transaction.find_all_in_cache::<$entity>().await?,
        )* )* );

        // ... and destructure it so each collection is bound to its caller-supplied identifier.
        let ( $( $( $records, )* )* ) = cached_collections;

        let conn = $db.conn().await?;
        let mut conn = conn.conn().await;
        let deleted_ids = $keystore_transaction.deleted.read().await;

        // Names of all collections touched by this transaction (saves first, then deletions).
        // A name may appear more than once.
        let mut tables = Vec::new();
        $( $(
            if !$records.is_empty() {
                tables.push(<$entity>::COLLECTION_NAME);
            }
        )* )*

        for deleted_id in deleted_ids.iter() {
            tables.push(deleted_id.collection_name());
        }

        // Nothing was saved or deleted: leave the *enclosing function* early.
        if tables.is_empty() {
            log::debug!("Empty transaction was committed.");
            return Ok(());
        }

        // Only the wasm backend is given the list of tables when opening the transaction.
        #[cfg(target_family = "wasm")]
        let tx = conn.new_transaction(&tables).await?;
        #[cfg(not(target_family = "wasm"))]
        let tx = conn.transaction()?.into();

        // Saves are executed before deletions.
        $( $(
            if !$records.is_empty() {
                for record in $records {
                    dynamic_dispatch::execute_save(&tx, &record.to_transaction_entity()).await?;
                }
            }
        )* )*


        for deleted_id in deleted_ids.iter() {
            dynamic_dispatch::execute_delete(&tx, deleted_id).await?
        }

        tx.commit_tx().await?;
    };
}
364
impl KeystoreTransaction {
    /// Persists every buffered save and deletion of this transaction to `db`.
    ///
    /// The work is done by `commit_transaction!`, which is handed every entity type that
    /// may be present in the cache. The proteus types are only included when the
    /// `proteus-keystore` feature is enabled. The `identifier_NN` names are placeholder
    /// bindings required by the macro and carry no meaning of their own.
    ///
    /// Returns `Ok(())` without opening a database transaction when nothing was saved or deleted.
    ///
    /// # Errors
    /// Propagates failures from reading the cache, acquiring the connection, and executing
    /// or committing the database transaction.
    pub(crate) async fn commit(&self, db: &Database) -> Result<(), CryptoKeystoreError> {
        commit_transaction!(
            self, db,
            [
                (identifier_01, StoredCredential),
                (identifier_03, StoredHpkePrivateKey),
                (identifier_04, StoredEncryptionKeyPair),
                (identifier_05, StoredEpochEncryptionKeypair),
                (identifier_06, StoredPskBundle),
                (identifier_07, StoredKeypackage),
                (identifier_08, PersistedMlsGroup),
                (identifier_09, PersistedMlsPendingGroup),
                (identifier_10, MlsPendingMessage),
                (identifier_11, StoredE2eiEnrollment),
                (identifier_13, E2eiAcmeCA),
                (identifier_14, E2eiIntermediateCert),
                (identifier_15, E2eiCrl),
                (identifier_16, ConsumerData)
            ],
            proteus_types: [
                (identifier_17, ProteusPrekey),
                (identifier_18, ProteusIdentity),
                (identifier_19, ProteusSession)
            ]
        );

        Ok(())
    }
}