core_crypto_keystore/transaction/
mod.rs1use std::{
2 borrow::Cow,
3 collections::{HashMap, HashSet, hash_map::Entry},
4 sync::Arc,
5};
6
7use async_lock::{RwLock, SemaphoreGuardArc};
8use itertools::Itertools;
9
10use crate::{
11 CryptoKeystoreError, CryptoKeystoreResult,
12 connection::{Database, KeystoreDatabaseConnection},
13 entities::{MlsPendingMessage, MlsPendingMessagePrimaryKey, PersistedMlsGroupExt},
14 traits::{BorrowPrimaryKey, Entity, EntityBase as _, EntityDatabaseMutation, EntityDeleteBorrowed, KeyType},
15 transaction::dynamic_dispatch::EntityId,
16};
17
18pub(crate) mod dynamic_dispatch;
19
20type InMemoryTable = HashMap<EntityId, dynamic_dispatch::Entity>;
22type InMemoryCollection = Arc<RwLock<HashMap<&'static str, InMemoryTable>>>;
24
25#[derive(Debug, Clone)]
28pub(crate) struct KeystoreTransaction {
29 cache: InMemoryCollection,
30 deleted: Arc<RwLock<HashSet<EntityId>>>,
31 _semaphore_guard: Arc<SemaphoreGuardArc>,
32}
33
34impl KeystoreTransaction {
35 pub(crate) async fn new(semaphore_guard: SemaphoreGuardArc) -> CryptoKeystoreResult<Self> {
39 Ok(Self {
40 cache: Default::default(),
41 deleted: Arc::new(Default::default()),
42 _semaphore_guard: Arc::new(semaphore_guard),
43 })
44 }
45
46 pub(crate) async fn save<'a, E>(&self, mut entity: E) -> CryptoKeystoreResult<E::AutoGeneratedFields>
55 where
56 E: Entity<ConnectionType = KeystoreDatabaseConnection> + EntityDatabaseMutation<'a> + Send + Sync,
57 {
58 let auto_generated_fields = entity.pre_save().await?;
59
60 let entity_id =
61 EntityId::from_entity(&entity).ok_or(CryptoKeystoreError::UnknownCollectionName(E::COLLECTION_NAME))?;
62 {
63 let mut cache_guard = self.cache.write().await;
65 let table = cache_guard.entry(E::COLLECTION_NAME).or_default();
66 table.insert(entity_id.clone(), entity.to_transaction_entity());
67 }
68 {
69 let mut cache_guard = self.deleted.write().await;
72 cache_guard.remove(&entity_id);
73 }
74
75 Ok(auto_generated_fields)
76 }
77
78 async fn remove_by_entity_id<'a, E>(&self, entity_id: EntityId) -> CryptoKeystoreResult<()>
79 where
80 E: Entity + EntityDatabaseMutation<'a>,
81 {
82 let mut cache_guard = self.cache.write().await;
85 if let Entry::Occupied(mut table) = cache_guard.entry(E::COLLECTION_NAME)
86 && let Entry::Occupied(cached_record) = table.get_mut().entry(entity_id.clone())
87 {
88 cached_record.remove_entry();
89 };
90
91 let mut deleted_set = self.deleted.write().await;
93 deleted_set.insert(entity_id);
94 Ok(())
95 }
96
97 pub(crate) async fn remove<'a, E>(&self, id: &E::PrimaryKey) -> CryptoKeystoreResult<()>
104 where
105 E: Entity + EntityDatabaseMutation<'a>,
106 {
107 let entity_id = EntityId::from_primary_key::<E>(id)
108 .ok_or(CryptoKeystoreError::UnknownCollectionName(E::COLLECTION_NAME))?;
109 self.remove_by_entity_id::<E>(entity_id).await
110 }
111
112 pub(crate) async fn remove_borrowed<'a, E>(&self, id: &E::BorrowedPrimaryKey) -> CryptoKeystoreResult<()>
117 where
118 E: EntityDeleteBorrowed<'a> + BorrowPrimaryKey,
119 {
120 let entity_id = EntityId::from_borrowed_primary_key::<E>(id)
121 .ok_or(CryptoKeystoreError::UnknownCollectionName(E::COLLECTION_NAME))?;
122 self.remove_by_entity_id::<E>(entity_id).await
123 }
124
125 pub(crate) async fn child_groups<E>(
126 &self,
127 entity: E,
128 persisted_records: impl IntoIterator<Item = E>,
129 ) -> CryptoKeystoreResult<Vec<E>>
130 where
131 E: Clone + Entity + BorrowPrimaryKey + PersistedMlsGroupExt + Send + Sync,
132 for<'pk> &'pk <E as BorrowPrimaryKey>::BorrowedPrimaryKey: KeyType,
133 {
134 let cached_records = self.find_all_in_cache::<E>().await;
136 let cached_records = cached_records
137 .iter()
138 .filter(|maybe_child| {
139 maybe_child
140 .parent_id()
141 .map(|parent_id| parent_id == entity.borrow_primary_key().bytes().as_ref())
142 .unwrap_or_default()
143 })
144 .map(Arc::as_ref)
145 .map(Cow::Borrowed);
146
147 let persisted_records = persisted_records.into_iter().map(Cow::Owned);
148
149 Ok(self.merge_records(cached_records, persisted_records).await)
150 }
151
152 pub(crate) async fn remove_pending_messages_by_conversation_id(&self, conversation_id: impl AsRef<[u8]> + Send) {
153 let conversation_id = conversation_id.as_ref();
154
155 let mut cache_guard = self.cache.write().await;
156 if let Entry::Occupied(mut table) = cache_guard.entry(MlsPendingMessage::COLLECTION_NAME) {
157 table.get_mut().retain(|_key, entity| {
158 let pending_message = entity
159 .downcast::<MlsPendingMessage>()
160 .expect("table for MlsPendingMessage contains only that type");
161 pending_message.foreign_id != conversation_id
162 });
163 }
164 drop(cache_guard);
165
166 let mut deleted_set = self.deleted.write().await;
167 deleted_set.insert(
168 EntityId::from_primary_key::<MlsPendingMessage>(&MlsPendingMessagePrimaryKey::from_conversation_id(
169 conversation_id,
170 ))
171 .expect("mls pending messages are proper entities which can be parsed"),
172 );
173 }
174
175 pub(crate) async fn find_pending_messages_by_conversation_id(
176 &self,
177 conversation_id: &[u8],
178 persisted_records: impl IntoIterator<Item = MlsPendingMessage>,
179 ) -> CryptoKeystoreResult<Vec<MlsPendingMessage>> {
180 let persisted_records = persisted_records.into_iter().map(Cow::Owned);
181
182 let cached_records = self.find_all_in_cache::<MlsPendingMessage>().await;
183 let cached_records = cached_records
184 .iter()
185 .filter(|pending_message| pending_message.foreign_id == conversation_id)
186 .map(Arc::as_ref)
187 .map(Cow::Borrowed);
188
189 let merged_records = self.merge_records(cached_records, persisted_records).await;
190 Ok(merged_records)
191 }
192
193 async fn find_in_cache<E>(&self, entity_id: &EntityId) -> Option<Arc<E>>
194 where
195 E: Entity + Send + Sync,
196 {
197 let cache_guard = self.cache.read().await;
198 cache_guard
199 .get(E::COLLECTION_NAME)
200 .and_then(|table| table.get(entity_id).and_then(|entity| entity.downcast()))
201 }
202
203 async fn get_by_entity_id<E>(&self, entity_id: &EntityId) -> Option<Option<Arc<E>>>
208 where
209 E: Entity + Send + Sync,
210 {
211 let deleted_list = self.deleted.read().await;
214 if deleted_list.contains(entity_id) {
215 return Some(None);
216 }
217
218 self.find_in_cache::<E>(entity_id).await.map(Some)
219 }
220
221 pub(crate) async fn get<E>(&self, id: &E::PrimaryKey) -> Option<Option<Arc<E>>>
226 where
227 E: Entity + Send + Sync,
228 {
229 let entity_id = EntityId::from_primary_key::<E>(id)?;
230 self.get_by_entity_id(&entity_id).await
231 }
232
233 pub(crate) async fn get_borrowed<E>(&self, id: &E::BorrowedPrimaryKey) -> Option<Option<Arc<E>>>
238 where
239 E: Entity + BorrowPrimaryKey + Send + Sync,
240 {
241 let entity_id = EntityId::from_borrowed_primary_key::<E>(id)?;
242 self.get_by_entity_id(&entity_id).await
243 }
244
245 async fn find_all_in_cache<E>(&self) -> Vec<Arc<E>>
246 where
247 E: Entity + Send + Sync,
248 {
249 let cache_guard = self.cache.read().await;
250 cache_guard
251 .get(E::COLLECTION_NAME)
252 .map(|table| {
253 table
254 .values()
255 .map(|record: &dynamic_dispatch::Entity| {
256 record
257 .downcast::<E>()
258 .expect("all entries in this table are of this type")
259 .clone()
260 })
261 .collect::<Vec<_>>()
262 })
263 .unwrap_or_default()
264 }
265
266 pub(crate) async fn find_all<E>(&self, persisted_records: Vec<E>) -> CryptoKeystoreResult<Vec<E>>
267 where
268 E: Clone + Entity + Send + Sync,
269 {
270 let cached_records = self.find_all_in_cache().await;
271 let merged_records = self
272 .merge_records(
273 cached_records.iter().map(Arc::as_ref).map(Cow::Borrowed),
274 persisted_records.into_iter().map(Cow::Owned),
275 )
276 .await;
277 Ok(merged_records)
278 }
279
280 async fn merge_records<'a, E>(
288 &self,
289 records_a: impl IntoIterator<Item = Cow<'a, E>>,
290 records_b: impl IntoIterator<Item = Cow<'a, E>>,
291 ) -> Vec<E>
292 where
293 E: Clone + Entity,
294 {
295 let deleted_records = self.deleted.read().await;
296
297 records_a
298 .into_iter()
299 .chain(records_b)
300 .unique_by(|e| e.primary_key().bytes().into_owned())
301 .filter_map(|record| {
302 let id = EntityId::from_entity(record.as_ref())?;
303 (!deleted_records.contains(&id)).then_some(record.into_owned())
304 })
305 .collect()
306 }
307
308 pub(crate) async fn commit(&self, db: &Database) -> Result<(), CryptoKeystoreError> {
311 let conn = db.conn().await?;
312 let mut conn = conn.conn().await;
313
314 let cache = self.cache.read().await;
315 let deleted_ids = self.deleted.read().await;
316
317 let table_names_with_deletion = deleted_ids.iter().map(|entity_id| entity_id.collection_name());
318 let table_names_with_save = cache
319 .values()
320 .flat_map(|table| table.keys())
321 .map(|entity_id| entity_id.collection_name());
322 let mut tables = table_names_with_deletion
323 .chain(table_names_with_save)
324 .collect::<Vec<_>>();
325
326 if tables.is_empty() {
327 log::debug!("Empty transaction was committed.");
328 return Ok(());
329 }
330
331 tables.sort_unstable();
332 tables.dedup();
333
334 #[cfg(target_family = "wasm")]
336 let tx = conn.new_transaction(&tables).await?;
337 #[cfg(not(target_family = "wasm"))]
338 let tx = conn.transaction()?.into();
339
340 for entity in cache.values().flat_map(|table| table.values()) {
341 entity.execute_save(&tx).await?;
342 }
343
344 for deleted_id in deleted_ids.iter() {
345 deleted_id.execute_delete(&tx).await?;
346 }
347
348 tx.commit_tx().await?;
350
351 Ok(())
352 }
353}