core_crypto_keystore/transaction/
mod.rs1use std::{
2 borrow::Cow,
3 collections::{HashMap, HashSet, hash_map::Entry},
4 sync::Arc,
5};
6
7use async_lock::{RwLock, SemaphoreGuardArc};
8use itertools::Itertools;
9
10use crate::{
11 CryptoKeystoreError, CryptoKeystoreResult,
12 connection::{Database, KeystoreDatabaseConnection},
13 entities::{MlsPendingMessage, PersistedMlsGroup},
14 traits::{
15 BorrowPrimaryKey, Entity, EntityBase as _, EntityDatabaseMutation, EntityDeleteBorrowed, KeyType,
16 SearchableEntity,
17 },
18 transaction::dynamic_dispatch::EntityId,
19};
20
21pub(crate) mod dynamic_dispatch;
22
23type InMemoryTable = HashMap<EntityId, dynamic_dispatch::Entity>;
25type InMemoryCollection = Arc<RwLock<HashMap<&'static str, InMemoryTable>>>;
27
28#[derive(Debug, Clone)]
31pub(crate) struct KeystoreTransaction {
32 cache: InMemoryCollection,
33 deleted: Arc<RwLock<HashSet<EntityId>>>,
34 _semaphore_guard: Arc<SemaphoreGuardArc>,
35}
36
37impl KeystoreTransaction {
38 pub(crate) async fn new(semaphore_guard: SemaphoreGuardArc) -> CryptoKeystoreResult<Self> {
42 Ok(Self {
43 cache: Default::default(),
44 deleted: Arc::new(Default::default()),
45 _semaphore_guard: Arc::new(semaphore_guard),
46 })
47 }
48
49 pub(crate) async fn save<'a, E>(&self, mut entity: E) -> CryptoKeystoreResult<E::AutoGeneratedFields>
58 where
59 E: Entity<ConnectionType = KeystoreDatabaseConnection> + EntityDatabaseMutation<'a> + Send + Sync,
60 {
61 let auto_generated_fields = entity.pre_save().await?;
62
63 let entity_id =
64 EntityId::from_entity(&entity).ok_or(CryptoKeystoreError::UnknownCollectionName(E::COLLECTION_NAME))?;
65 {
66 let mut cache_guard = self.cache.write().await;
68 let table = cache_guard.entry(E::COLLECTION_NAME).or_default();
69 table.insert(entity_id.clone(), entity.to_transaction_entity());
70 }
71 {
72 let mut cache_guard = self.deleted.write().await;
75 cache_guard.remove(&entity_id);
76 }
77
78 Ok(auto_generated_fields)
79 }
80
81 async fn remove_by_entity_id<'a, E>(&self, entity_id: EntityId) -> CryptoKeystoreResult<()>
82 where
83 E: Entity + EntityDatabaseMutation<'a>,
84 {
85 let mut cache_guard = self.cache.write().await;
88 if let Entry::Occupied(mut table) = cache_guard.entry(E::COLLECTION_NAME)
89 && let Entry::Occupied(cached_record) = table.get_mut().entry(entity_id.clone())
90 {
91 cached_record.remove_entry();
92 };
93
94 let mut deleted_set = self.deleted.write().await;
96 deleted_set.insert(entity_id);
97 Ok(())
98 }
99
100 pub(crate) async fn remove<'a, E>(&self, id: &E::PrimaryKey) -> CryptoKeystoreResult<()>
107 where
108 E: Entity + EntityDatabaseMutation<'a>,
109 {
110 let entity_id = EntityId::from_primary_key::<E>(id)
111 .ok_or(CryptoKeystoreError::UnknownCollectionName(E::COLLECTION_NAME))?;
112 self.remove_by_entity_id::<E>(entity_id).await
113 }
114
115 pub(crate) async fn remove_borrowed<'a, E>(&self, id: &E::BorrowedPrimaryKey) -> CryptoKeystoreResult<()>
120 where
121 E: EntityDeleteBorrowed<'a> + BorrowPrimaryKey,
122 {
123 let entity_id = EntityId::from_borrowed_primary_key::<E>(id)
124 .ok_or(CryptoKeystoreError::UnknownCollectionName(E::COLLECTION_NAME))?;
125 self.remove_by_entity_id::<E>(entity_id).await
126 }
127
128 pub(crate) async fn child_groups(
129 &self,
130 entity: PersistedMlsGroup,
131 persisted_records: impl IntoIterator<Item = PersistedMlsGroup>,
132 ) -> CryptoKeystoreResult<Vec<PersistedMlsGroup>> {
133 let cached_records = self.find_all_in_cache::<PersistedMlsGroup>().await;
135 let cached_records = cached_records
136 .iter()
137 .filter(|maybe_child| {
138 maybe_child
139 .parent_id
140 .as_deref()
141 .map(|parent_id| parent_id == entity.borrow_primary_key().bytes().as_ref())
142 .unwrap_or_default()
143 })
144 .map(Arc::as_ref)
145 .map(Cow::Borrowed);
146
147 let persisted_records = persisted_records.into_iter().map(Cow::Owned);
148
149 Ok(self.merge_records(cached_records, persisted_records).await)
150 }
151
152 pub(crate) async fn remove_pending_messages_by_conversation_id(&self, conversation_id: impl AsRef<[u8]> + Send) {
153 let conversation_id = conversation_id.as_ref();
154
155 let mut cache_guard = self.cache.write().await;
156 if let Entry::Occupied(mut table) = cache_guard.entry(MlsPendingMessage::COLLECTION_NAME) {
157 table.get_mut().retain(|_key, entity| {
158 let pending_message = entity
159 .downcast::<MlsPendingMessage>()
160 .expect("table for MlsPendingMessage contains only that type");
161 pending_message.foreign_id != conversation_id
162 });
163 }
164 drop(cache_guard);
165
166 let mut deleted_set = self.deleted.write().await;
167 deleted_set.insert(
168 EntityId::from_key::<MlsPendingMessage>(conversation_id.into())
169 .expect("mls pending messages are proper entities which can be parsed"),
170 );
171 }
172
173 pub(crate) async fn find_pending_messages_by_conversation_id(
174 &self,
175 conversation_id: &[u8],
176 persisted_records: impl IntoIterator<Item = MlsPendingMessage>,
177 ) -> CryptoKeystoreResult<Vec<MlsPendingMessage>> {
178 let persisted_records = persisted_records.into_iter().map(Cow::Owned);
179
180 let cached_records = self.find_all_in_cache::<MlsPendingMessage>().await;
181 let cached_records = cached_records
182 .iter()
183 .filter(|pending_message| pending_message.foreign_id == conversation_id)
184 .map(Arc::as_ref)
185 .map(Cow::Borrowed);
186
187 let merged_records = self.merge_records(cached_records, persisted_records).await;
188 Ok(merged_records)
189 }
190
191 async fn find_in_cache<E>(&self, entity_id: &EntityId) -> Option<Arc<E>>
192 where
193 E: Entity + Send + Sync,
194 {
195 let cache_guard = self.cache.read().await;
196 cache_guard
197 .get(E::COLLECTION_NAME)
198 .and_then(|table| table.get(entity_id).and_then(|entity| entity.downcast()))
199 }
200
201 async fn get_by_entity_id<E>(&self, entity_id: &EntityId) -> Option<Option<Arc<E>>>
206 where
207 E: Entity + Send + Sync,
208 {
209 let deleted_list = self.deleted.read().await;
212 if deleted_list.contains(entity_id) {
213 return Some(None);
214 }
215
216 self.find_in_cache::<E>(entity_id).await.map(Some)
217 }
218
219 pub(crate) async fn get<E>(&self, id: &E::PrimaryKey) -> Option<Option<Arc<E>>>
224 where
225 E: Entity + Send + Sync,
226 {
227 let entity_id = EntityId::from_primary_key::<E>(id)?;
228 self.get_by_entity_id(&entity_id).await
229 }
230
231 pub(crate) async fn get_borrowed<E>(&self, id: &E::BorrowedPrimaryKey) -> Option<Option<Arc<E>>>
236 where
237 E: Entity + BorrowPrimaryKey + Send + Sync,
238 {
239 let entity_id = EntityId::from_borrowed_primary_key::<E>(id)?;
240 self.get_by_entity_id(&entity_id).await
241 }
242
243 async fn find_all_in_cache<E>(&self) -> Vec<Arc<E>>
244 where
245 E: Entity + Send + Sync,
246 {
247 let cache_guard = self.cache.read().await;
248 cache_guard
249 .get(E::COLLECTION_NAME)
250 .map(|table| {
251 table
252 .values()
253 .map(|record: &dynamic_dispatch::Entity| {
254 record
255 .downcast::<E>()
256 .expect("all entries in this table are of this type")
257 .clone()
258 })
259 .collect::<Vec<_>>()
260 })
261 .unwrap_or_default()
262 }
263
264 async fn search_in_cache<E, SearchKey>(&self, search_key: &SearchKey) -> Vec<Arc<E>>
265 where
266 E: Entity + SearchableEntity<SearchKey> + Send + Sync,
267 SearchKey: KeyType,
268 {
269 let cache_guard = self.cache.read().await;
270 cache_guard
271 .get(E::COLLECTION_NAME)
272 .map(|table| {
273 table
274 .values()
275 .filter_map(|record: &dynamic_dispatch::Entity| {
276 let entity = record
277 .downcast::<E>()
278 .expect("all entries in this table are of this type")
279 .clone();
280 entity.matches(search_key).then_some(entity)
281 })
282 .collect()
283 })
284 .unwrap_or_default()
285 }
286
287 pub(crate) async fn find_all<E>(&self, persisted_records: Vec<E>) -> CryptoKeystoreResult<Vec<E>>
288 where
289 E: Clone + Entity + Send + Sync,
290 {
291 let cached_records = self.find_all_in_cache().await;
292 let merged_records = self
293 .merge_records(
294 cached_records.iter().map(Arc::as_ref).map(Cow::Borrowed),
295 persisted_records.into_iter().map(Cow::Owned),
296 )
297 .await;
298 Ok(merged_records)
299 }
300
301 pub(crate) async fn search<E, SearchKey>(
302 &self,
303 persisted_records: Vec<E>,
304 search_key: &SearchKey,
305 ) -> CryptoKeystoreResult<Vec<E>>
306 where
307 E: Clone + Entity + SearchableEntity<SearchKey> + Send + Sync,
308 SearchKey: KeyType,
309 {
310 let cached_records = self.search_in_cache(search_key).await;
311 let merged_records = self
312 .merge_records(
313 cached_records.iter().map(Arc::as_ref).map(Cow::Borrowed),
314 persisted_records.into_iter().map(Cow::Owned),
315 )
316 .await;
317 Ok(merged_records)
318 }
319
320 async fn merge_records<'a, E>(
328 &self,
329 records_a: impl IntoIterator<Item = Cow<'a, E>>,
330 records_b: impl IntoIterator<Item = Cow<'a, E>>,
331 ) -> Vec<E>
332 where
333 E: Clone + Entity,
334 {
335 let deleted_records = self.deleted.read().await;
336
337 records_a
338 .into_iter()
339 .chain(records_b)
340 .unique_by(|e| e.primary_key().bytes().into_owned())
341 .filter_map(|record| {
342 let id = EntityId::from_entity(record.as_ref())?;
343 (!deleted_records.contains(&id)).then_some(record.into_owned())
344 })
345 .collect()
346 }
347
348 pub(crate) async fn commit(&self, db: &Database) -> Result<(), CryptoKeystoreError> {
351 let conn = db.conn().await?;
352 let mut conn = conn.conn().await;
353
354 let cache = self.cache.read().await;
355 let deleted_ids = self.deleted.read().await;
356
357 let table_names_with_deletion = deleted_ids.iter().map(|entity_id| entity_id.collection_name());
358 let table_names_with_save = cache
359 .values()
360 .flat_map(|table| table.keys())
361 .map(|entity_id| entity_id.collection_name());
362 let mut tables = table_names_with_deletion
363 .chain(table_names_with_save)
364 .collect::<Vec<_>>();
365
366 if tables.is_empty() {
367 log::debug!("Empty transaction was committed.");
368 return Ok(());
369 }
370
371 tables.sort_unstable();
372 tables.dedup();
373
374 #[cfg(target_family = "wasm")]
376 let tx = conn.new_transaction(&tables).await?;
377 #[cfg(not(target_family = "wasm"))]
378 let tx = conn.transaction()?.into();
379
380 for entity in cache.values().flat_map(|table| table.values()) {
381 entity.execute_save(&tx).await?;
382 }
383
384 for deleted_id in deleted_ids.iter() {
385 deleted_id.execute_delete(&tx).await?;
386 }
387
388 tx.commit_tx().await?;
390
391 Ok(())
392 }
393}