core_crypto_keystore/transaction/
mod.rs

1use std::{
2    borrow::Cow,
3    collections::{HashMap, HashSet, hash_map::Entry},
4    sync::Arc,
5};
6
7use async_lock::{RwLock, SemaphoreGuardArc};
8use itertools::Itertools;
9
10use crate::{
11    CryptoKeystoreError, CryptoKeystoreResult,
12    connection::{Database, KeystoreDatabaseConnection},
13    entities::{MlsPendingMessage, PersistedMlsGroup},
14    traits::{
15        BorrowPrimaryKey, Entity, EntityBase as _, EntityDatabaseMutation, EntityDeleteBorrowed, KeyType,
16        SearchableEntity,
17    },
18    transaction::dynamic_dispatch::EntityId,
19};
20
21pub(crate) mod dynamic_dispatch;
22
23/// table: primary key -> entity reference
24type InMemoryTable = HashMap<EntityId, dynamic_dispatch::Entity>;
25/// collection: collection name -> table
26type InMemoryCollection = Arc<RwLock<HashMap<&'static str, InMemoryTable>>>;
27
28/// This represents a transaction, where all operations will be done in memory and committed at the
29/// end
30#[derive(Debug, Clone)]
31pub(crate) struct KeystoreTransaction {
32    cache: InMemoryCollection,
33    deleted: Arc<RwLock<HashSet<EntityId>>>,
34    _semaphore_guard: Arc<SemaphoreGuardArc>,
35}
36
37impl KeystoreTransaction {
38    /// Instantiate a new transaction.
39    ///
40    /// Requires a semaphore guard to ensure that only one exists at a time.
41    pub(crate) async fn new(semaphore_guard: SemaphoreGuardArc) -> CryptoKeystoreResult<Self> {
42        Ok(Self {
43            cache: Default::default(),
44            deleted: Arc::new(Default::default()),
45            _semaphore_guard: Arc::new(semaphore_guard),
46        })
47    }
48
49    /// Save an entity into this transaction.
50    ///
51    /// This is a multi-step process:
52    ///
53    /// - Adjust the entity by calling its [`pre_save()`][Entity::pre_save] method.
54    /// - Store the entity in an internal map.
55    ///   - Remove the entity from the set of deleted entities, if it was there.
56    /// - On [`Self::commit`], actually persist the entity into the supplied database.
57    pub(crate) async fn save<'a, E>(&self, mut entity: E) -> CryptoKeystoreResult<E::AutoGeneratedFields>
58    where
59        E: Entity<ConnectionType = KeystoreDatabaseConnection> + EntityDatabaseMutation<'a> + Send + Sync,
60    {
61        let auto_generated_fields = entity.pre_save().await?;
62
63        let entity_id =
64            EntityId::from_entity(&entity).ok_or(CryptoKeystoreError::UnknownCollectionName(E::COLLECTION_NAME))?;
65        {
66            // start by adding the entity
67            let mut cache_guard = self.cache.write().await;
68            let table = cache_guard.entry(E::COLLECTION_NAME).or_default();
69            table.insert(entity_id.clone(), entity.to_transaction_entity());
70        }
71        {
72            // at this point remove the entity from the set of deleted entities to ensure that
73            // this new data gets propagated
74            let mut cache_guard = self.deleted.write().await;
75            cache_guard.remove(&entity_id);
76        }
77
78        Ok(auto_generated_fields)
79    }
80
81    async fn remove_by_entity_id<'a, E>(&self, entity_id: EntityId) -> CryptoKeystoreResult<()>
82    where
83        E: Entity + EntityDatabaseMutation<'a>,
84    {
85        // rm this entity from the set of added/modified items
86        // it might never touch the real db at all
87        let mut cache_guard = self.cache.write().await;
88        if let Entry::Occupied(mut table) = cache_guard.entry(E::COLLECTION_NAME)
89            && let Entry::Occupied(cached_record) = table.get_mut().entry(entity_id.clone())
90        {
91            cached_record.remove_entry();
92        };
93
94        // add this entity to the set of items which should be deleted from the persisted db
95        let mut deleted_set = self.deleted.write().await;
96        deleted_set.insert(entity_id);
97        Ok(())
98    }
99
100    /// Remove an entity by its primary key.
101    ///
102    /// Where the primary key has a distinct borrowed form, consider [`Self::remove_borrowed`].
103    ///
104    /// Note that this doesn't return whether or not anything was actually removed because
105    /// that won't happen until the transaction is committed.
106    pub(crate) async fn remove<'a, E>(&self, id: &E::PrimaryKey) -> CryptoKeystoreResult<()>
107    where
108        E: Entity + EntityDatabaseMutation<'a>,
109    {
110        let entity_id = EntityId::from_primary_key::<E>(id)
111            .ok_or(CryptoKeystoreError::UnknownCollectionName(E::COLLECTION_NAME))?;
112        self.remove_by_entity_id::<E>(entity_id).await
113    }
114
115    /// Remove an entity by the borrowed form of its primary key.
116    ///
117    /// Note that this doesn't return whether or not anything was actually removed because
118    /// that won't happen until the transaction is committed.
119    pub(crate) async fn remove_borrowed<'a, E>(&self, id: &E::BorrowedPrimaryKey) -> CryptoKeystoreResult<()>
120    where
121        E: EntityDeleteBorrowed<'a> + BorrowPrimaryKey,
122    {
123        let entity_id = EntityId::from_borrowed_primary_key::<E>(id)
124            .ok_or(CryptoKeystoreError::UnknownCollectionName(E::COLLECTION_NAME))?;
125        self.remove_by_entity_id::<E>(entity_id).await
126    }
127
128    pub(crate) async fn child_groups(
129        &self,
130        entity: PersistedMlsGroup,
131        persisted_records: impl IntoIterator<Item = PersistedMlsGroup>,
132    ) -> CryptoKeystoreResult<Vec<PersistedMlsGroup>> {
133        // First get all raw groups from the cache, then filter by their parent id
134        let cached_records = self.find_all_in_cache::<PersistedMlsGroup>().await;
135        let cached_records = cached_records
136            .iter()
137            .filter(|maybe_child| {
138                maybe_child
139                    .parent_id
140                    .as_deref()
141                    .map(|parent_id| parent_id == entity.borrow_primary_key().bytes().as_ref())
142                    .unwrap_or_default()
143            })
144            .map(Arc::as_ref)
145            .map(Cow::Borrowed);
146
147        let persisted_records = persisted_records.into_iter().map(Cow::Owned);
148
149        Ok(self.merge_records(cached_records, persisted_records).await)
150    }
151
152    pub(crate) async fn remove_pending_messages_by_conversation_id(&self, conversation_id: impl AsRef<[u8]> + Send) {
153        let conversation_id = conversation_id.as_ref();
154
155        let mut cache_guard = self.cache.write().await;
156        if let Entry::Occupied(mut table) = cache_guard.entry(MlsPendingMessage::COLLECTION_NAME) {
157            table.get_mut().retain(|_key, entity| {
158                let pending_message = entity
159                    .downcast::<MlsPendingMessage>()
160                    .expect("table for MlsPendingMessage contains only that type");
161                pending_message.foreign_id != conversation_id
162            });
163        }
164        drop(cache_guard);
165
166        let mut deleted_set = self.deleted.write().await;
167        deleted_set.insert(
168            EntityId::from_key::<MlsPendingMessage>(conversation_id.into())
169                .expect("mls pending messages are proper entities which can be parsed"),
170        );
171    }
172
173    pub(crate) async fn find_pending_messages_by_conversation_id(
174        &self,
175        conversation_id: &[u8],
176        persisted_records: impl IntoIterator<Item = MlsPendingMessage>,
177    ) -> CryptoKeystoreResult<Vec<MlsPendingMessage>> {
178        let persisted_records = persisted_records.into_iter().map(Cow::Owned);
179
180        let cached_records = self.find_all_in_cache::<MlsPendingMessage>().await;
181        let cached_records = cached_records
182            .iter()
183            .filter(|pending_message| pending_message.foreign_id == conversation_id)
184            .map(Arc::as_ref)
185            .map(Cow::Borrowed);
186
187        let merged_records = self.merge_records(cached_records, persisted_records).await;
188        Ok(merged_records)
189    }
190
191    async fn find_in_cache<E>(&self, entity_id: &EntityId) -> Option<Arc<E>>
192    where
193        E: Entity + Send + Sync,
194    {
195        let cache_guard = self.cache.read().await;
196        cache_guard
197            .get(E::COLLECTION_NAME)
198            .and_then(|table| table.get(entity_id).and_then(|entity| entity.downcast()))
199    }
200
201    /// The result of this function will have different contents for different scenarios:
202    /// * `Some(Some(E))` - the transaction cache contains the record
203    /// * `Some(None)` - the deletion of the record has been cached
204    /// * `None` - there is no information about the record in the cache
205    async fn get_by_entity_id<E>(&self, entity_id: &EntityId) -> Option<Option<Arc<E>>>
206    where
207        E: Entity + Send + Sync,
208    {
209        // when applying our transaction to the real database, we delete after inserting,
210        // so here we have to check for deletion before we check for existing values
211        let deleted_list = self.deleted.read().await;
212        if deleted_list.contains(entity_id) {
213            return Some(None);
214        }
215
216        self.find_in_cache::<E>(entity_id).await.map(Some)
217    }
218
219    /// The result of this function will have different contents for different scenarios:
220    /// * `Some(Some(E))` - the transaction cache contains the record
221    /// * `Some(None)` - the deletion of the record has been cached
222    /// * `None` - there is no information about the record in the cache
223    pub(crate) async fn get<E>(&self, id: &E::PrimaryKey) -> Option<Option<Arc<E>>>
224    where
225        E: Entity + Send + Sync,
226    {
227        let entity_id = EntityId::from_primary_key::<E>(id)?;
228        self.get_by_entity_id(&entity_id).await
229    }
230
231    /// The result of this function will have different contents for different scenarios:
232    /// * `Some(Some(E))` - the transaction cache contains the record
233    /// * `Some(None)` - the deletion of the record has been cached
234    /// * `None` - there is no information about the record in the cache
235    pub(crate) async fn get_borrowed<E>(&self, id: &E::BorrowedPrimaryKey) -> Option<Option<Arc<E>>>
236    where
237        E: Entity + BorrowPrimaryKey + Send + Sync,
238    {
239        let entity_id = EntityId::from_borrowed_primary_key::<E>(id)?;
240        self.get_by_entity_id(&entity_id).await
241    }
242
243    async fn find_all_in_cache<E>(&self) -> Vec<Arc<E>>
244    where
245        E: Entity + Send + Sync,
246    {
247        let cache_guard = self.cache.read().await;
248        cache_guard
249            .get(E::COLLECTION_NAME)
250            .map(|table| {
251                table
252                    .values()
253                    .map(|record: &dynamic_dispatch::Entity| {
254                        record
255                            .downcast::<E>()
256                            .expect("all entries in this table are of this type")
257                            .clone()
258                    })
259                    .collect::<Vec<_>>()
260            })
261            .unwrap_or_default()
262    }
263
264    async fn search_in_cache<E, SearchKey>(&self, search_key: &SearchKey) -> Vec<Arc<E>>
265    where
266        E: Entity + SearchableEntity<SearchKey> + Send + Sync,
267        SearchKey: KeyType,
268    {
269        let cache_guard = self.cache.read().await;
270        cache_guard
271            .get(E::COLLECTION_NAME)
272            .map(|table| {
273                table
274                    .values()
275                    .filter_map(|record: &dynamic_dispatch::Entity| {
276                        let entity = record
277                            .downcast::<E>()
278                            .expect("all entries in this table are of this type")
279                            .clone();
280                        entity.matches(search_key).then_some(entity)
281                    })
282                    .collect()
283            })
284            .unwrap_or_default()
285    }
286
287    pub(crate) async fn find_all<E>(&self, persisted_records: Vec<E>) -> CryptoKeystoreResult<Vec<E>>
288    where
289        E: Clone + Entity + Send + Sync,
290    {
291        let cached_records = self.find_all_in_cache().await;
292        let merged_records = self
293            .merge_records(
294                cached_records.iter().map(Arc::as_ref).map(Cow::Borrowed),
295                persisted_records.into_iter().map(Cow::Owned),
296            )
297            .await;
298        Ok(merged_records)
299    }
300
301    pub(crate) async fn search<E, SearchKey>(
302        &self,
303        persisted_records: Vec<E>,
304        search_key: &SearchKey,
305    ) -> CryptoKeystoreResult<Vec<E>>
306    where
307        E: Clone + Entity + SearchableEntity<SearchKey> + Send + Sync,
308        SearchKey: KeyType,
309    {
310        let cached_records = self.search_in_cache(search_key).await;
311        let merged_records = self
312            .merge_records(
313                cached_records.iter().map(Arc::as_ref).map(Cow::Borrowed),
314                persisted_records.into_iter().map(Cow::Owned),
315            )
316            .await;
317        Ok(merged_records)
318    }
319
320    /// Build a single list of unique records from two potentially overlapping lists.
321    /// In case of overlap, records in `records_a` are prioritized.
322    /// Identity from the perspective of this function is determined by the output of
323    /// [Entity::merge_key].
324    ///
325    /// Further, the output list of records is built with respect to the provided [EntityFindParams]
326    /// and the deleted records cached in this [Self] instance.
327    async fn merge_records<'a, E>(
328        &self,
329        records_a: impl IntoIterator<Item = Cow<'a, E>>,
330        records_b: impl IntoIterator<Item = Cow<'a, E>>,
331    ) -> Vec<E>
332    where
333        E: Clone + Entity,
334    {
335        let deleted_records = self.deleted.read().await;
336
337        records_a
338            .into_iter()
339            .chain(records_b)
340            .unique_by(|e| e.primary_key().bytes().into_owned())
341            .filter_map(|record| {
342                let id = EntityId::from_entity(record.as_ref())?;
343                (!deleted_records.contains(&id)).then_some(record.into_owned())
344            })
345            .collect()
346    }
347
348    /// Persists all the operations in the database. It will effectively open a transaction
349    /// internally, perform all the buffered operations and commit.
350    pub(crate) async fn commit(&self, db: &Database) -> Result<(), CryptoKeystoreError> {
351        let conn = db.conn().await?;
352        let mut conn = conn.conn().await;
353
354        let cache = self.cache.read().await;
355        let deleted_ids = self.deleted.read().await;
356
357        let table_names_with_deletion = deleted_ids.iter().map(|entity_id| entity_id.collection_name());
358        let table_names_with_save = cache
359            .values()
360            .flat_map(|table| table.keys())
361            .map(|entity_id| entity_id.collection_name());
362        let mut tables = table_names_with_deletion
363            .chain(table_names_with_save)
364            .collect::<Vec<_>>();
365
366        if tables.is_empty() {
367            log::debug!("Empty transaction was committed.");
368            return Ok(());
369        }
370
371        tables.sort_unstable();
372        tables.dedup();
373
374        // open a database transaction
375        #[cfg(target_family = "wasm")]
376        let tx = conn.new_transaction(&tables).await?;
377        #[cfg(not(target_family = "wasm"))]
378        let tx = conn.transaction()?.into();
379
380        for entity in cache.values().flat_map(|table| table.values()) {
381            entity.execute_save(&tx).await?;
382        }
383
384        for deleted_id in deleted_ids.iter() {
385            deleted_id.execute_delete(&tx).await?;
386        }
387
388        // and commit everything
389        tx.commit_tx().await?;
390
391        Ok(())
392    }
393}