// core_crypto_keystore/transaction/mod.rs

1use std::{
2    borrow::Cow,
3    collections::{HashMap, HashSet, hash_map::Entry},
4    sync::Arc,
5};
6
7use async_lock::{RwLock, SemaphoreGuardArc};
8use itertools::Itertools;
9
10use crate::{
11    CryptoKeystoreError, CryptoKeystoreResult,
12    connection::{Database, KeystoreDatabaseConnection},
13    entities::{MlsPendingMessage, MlsPendingMessagePrimaryKey, PersistedMlsGroupExt},
14    traits::{BorrowPrimaryKey, Entity, EntityBase as _, EntityDatabaseMutation, EntityDeleteBorrowed, KeyType},
15    transaction::dynamic_dispatch::EntityId,
16};
17
18pub(crate) mod dynamic_dispatch;
19
20/// table: primary key -> entity reference
21type InMemoryTable = HashMap<EntityId, dynamic_dispatch::Entity>;
22/// collection: collection name -> table
23type InMemoryCollection = Arc<RwLock<HashMap<&'static str, InMemoryTable>>>;
24
25/// This represents a transaction, where all operations will be done in memory and committed at the
26/// end
27#[derive(Debug, Clone)]
28pub(crate) struct KeystoreTransaction {
29    cache: InMemoryCollection,
30    deleted: Arc<RwLock<HashSet<EntityId>>>,
31    _semaphore_guard: Arc<SemaphoreGuardArc>,
32}
33
34impl KeystoreTransaction {
35    /// Instantiate a new transaction.
36    ///
37    /// Requires a semaphore guard to ensure that only one exists at a time.
38    pub(crate) async fn new(semaphore_guard: SemaphoreGuardArc) -> CryptoKeystoreResult<Self> {
39        Ok(Self {
40            cache: Default::default(),
41            deleted: Arc::new(Default::default()),
42            _semaphore_guard: Arc::new(semaphore_guard),
43        })
44    }
45
46    /// Save an entity into this transaction.
47    ///
48    /// This is a multi-step process:
49    ///
50    /// - Adjust the entity by calling its [`pre_save()`][Entity::pre_save] method.
51    /// - Store the entity in an internal map.
52    ///   - Remove the entity from the set of deleted entities, if it was there.
53    /// - On [`Self::commit`], actually persist the entity into the supplied database.
54    pub(crate) async fn save<'a, E>(&self, mut entity: E) -> CryptoKeystoreResult<E::AutoGeneratedFields>
55    where
56        E: Entity<ConnectionType = KeystoreDatabaseConnection> + EntityDatabaseMutation<'a> + Send + Sync,
57    {
58        let auto_generated_fields = entity.pre_save().await?;
59
60        let entity_id =
61            EntityId::from_entity(&entity).ok_or(CryptoKeystoreError::UnknownCollectionName(E::COLLECTION_NAME))?;
62        {
63            // start by adding the entity
64            let mut cache_guard = self.cache.write().await;
65            let table = cache_guard.entry(E::COLLECTION_NAME).or_default();
66            table.insert(entity_id.clone(), entity.to_transaction_entity());
67        }
68        {
69            // at this point remove the entity from the set of deleted entities to ensure that
70            // this new data gets propagated
71            let mut cache_guard = self.deleted.write().await;
72            cache_guard.remove(&entity_id);
73        }
74
75        Ok(auto_generated_fields)
76    }
77
78    async fn remove_by_entity_id<'a, E>(&self, entity_id: EntityId) -> CryptoKeystoreResult<()>
79    where
80        E: Entity + EntityDatabaseMutation<'a>,
81    {
82        // rm this entity from the set of added/modified items
83        // it might never touch the real db at all
84        let mut cache_guard = self.cache.write().await;
85        if let Entry::Occupied(mut table) = cache_guard.entry(E::COLLECTION_NAME)
86            && let Entry::Occupied(cached_record) = table.get_mut().entry(entity_id.clone())
87        {
88            cached_record.remove_entry();
89        };
90
91        // add this entity to the set of items which should be deleted from the persisted db
92        let mut deleted_set = self.deleted.write().await;
93        deleted_set.insert(entity_id);
94        Ok(())
95    }
96
97    /// Remove an entity by its primary key.
98    ///
99    /// Where the primary key has a distinct borrowed form, consider [`Self::remove_borrowed`].
100    ///
101    /// Note that this doesn't return whether or not anything was actually removed because
102    /// that won't happen until the transaction is committed.
103    pub(crate) async fn remove<'a, E>(&self, id: &E::PrimaryKey) -> CryptoKeystoreResult<()>
104    where
105        E: Entity + EntityDatabaseMutation<'a>,
106    {
107        let entity_id = EntityId::from_primary_key::<E>(id)
108            .ok_or(CryptoKeystoreError::UnknownCollectionName(E::COLLECTION_NAME))?;
109        self.remove_by_entity_id::<E>(entity_id).await
110    }
111
112    /// Remove an entity by the borrowed form of its primary key.
113    ///
114    /// Note that this doesn't return whether or not anything was actually removed because
115    /// that won't happen until the transaction is committed.
116    pub(crate) async fn remove_borrowed<'a, E>(&self, id: &E::BorrowedPrimaryKey) -> CryptoKeystoreResult<()>
117    where
118        E: EntityDeleteBorrowed<'a> + BorrowPrimaryKey,
119    {
120        let entity_id = EntityId::from_borrowed_primary_key::<E>(id)
121            .ok_or(CryptoKeystoreError::UnknownCollectionName(E::COLLECTION_NAME))?;
122        self.remove_by_entity_id::<E>(entity_id).await
123    }
124
125    pub(crate) async fn child_groups<E>(
126        &self,
127        entity: E,
128        persisted_records: impl IntoIterator<Item = E>,
129    ) -> CryptoKeystoreResult<Vec<E>>
130    where
131        E: Clone + Entity + BorrowPrimaryKey + PersistedMlsGroupExt + Send + Sync,
132        for<'pk> &'pk <E as BorrowPrimaryKey>::BorrowedPrimaryKey: KeyType,
133    {
134        // First get all raw groups from the cache, then filter by their parent id
135        let cached_records = self.find_all_in_cache::<E>().await;
136        let cached_records = cached_records
137            .iter()
138            .filter(|maybe_child| {
139                maybe_child
140                    .parent_id()
141                    .map(|parent_id| parent_id == entity.borrow_primary_key().bytes().as_ref())
142                    .unwrap_or_default()
143            })
144            .map(Arc::as_ref)
145            .map(Cow::Borrowed);
146
147        let persisted_records = persisted_records.into_iter().map(Cow::Owned);
148
149        Ok(self.merge_records(cached_records, persisted_records).await)
150    }
151
152    pub(crate) async fn remove_pending_messages_by_conversation_id(&self, conversation_id: impl AsRef<[u8]> + Send) {
153        let conversation_id = conversation_id.as_ref();
154
155        let mut cache_guard = self.cache.write().await;
156        if let Entry::Occupied(mut table) = cache_guard.entry(MlsPendingMessage::COLLECTION_NAME) {
157            table.get_mut().retain(|_key, entity| {
158                let pending_message = entity
159                    .downcast::<MlsPendingMessage>()
160                    .expect("table for MlsPendingMessage contains only that type");
161                pending_message.foreign_id != conversation_id
162            });
163        }
164        drop(cache_guard);
165
166        let mut deleted_set = self.deleted.write().await;
167        deleted_set.insert(
168            EntityId::from_primary_key::<MlsPendingMessage>(&MlsPendingMessagePrimaryKey::from_conversation_id(
169                conversation_id,
170            ))
171            .expect("mls pending messages are proper entities which can be parsed"),
172        );
173    }
174
175    pub(crate) async fn find_pending_messages_by_conversation_id(
176        &self,
177        conversation_id: &[u8],
178        persisted_records: impl IntoIterator<Item = MlsPendingMessage>,
179    ) -> CryptoKeystoreResult<Vec<MlsPendingMessage>> {
180        let persisted_records = persisted_records.into_iter().map(Cow::Owned);
181
182        let cached_records = self.find_all_in_cache::<MlsPendingMessage>().await;
183        let cached_records = cached_records
184            .iter()
185            .filter(|pending_message| pending_message.foreign_id == conversation_id)
186            .map(Arc::as_ref)
187            .map(Cow::Borrowed);
188
189        let merged_records = self.merge_records(cached_records, persisted_records).await;
190        Ok(merged_records)
191    }
192
193    async fn find_in_cache<E>(&self, entity_id: &EntityId) -> Option<Arc<E>>
194    where
195        E: Entity + Send + Sync,
196    {
197        let cache_guard = self.cache.read().await;
198        cache_guard
199            .get(E::COLLECTION_NAME)
200            .and_then(|table| table.get(entity_id).and_then(|entity| entity.downcast()))
201    }
202
203    /// The result of this function will have different contents for different scenarios:
204    /// * `Some(Some(E))` - the transaction cache contains the record
205    /// * `Some(None)` - the deletion of the record has been cached
206    /// * `None` - there is no information about the record in the cache
207    async fn get_by_entity_id<E>(&self, entity_id: &EntityId) -> Option<Option<Arc<E>>>
208    where
209        E: Entity + Send + Sync,
210    {
211        // when applying our transaction to the real database, we delete after inserting,
212        // so here we have to check for deletion before we check for existing values
213        let deleted_list = self.deleted.read().await;
214        if deleted_list.contains(entity_id) {
215            return Some(None);
216        }
217
218        self.find_in_cache::<E>(entity_id).await.map(Some)
219    }
220
221    /// The result of this function will have different contents for different scenarios:
222    /// * `Some(Some(E))` - the transaction cache contains the record
223    /// * `Some(None)` - the deletion of the record has been cached
224    /// * `None` - there is no information about the record in the cache
225    pub(crate) async fn get<E>(&self, id: &E::PrimaryKey) -> Option<Option<Arc<E>>>
226    where
227        E: Entity + Send + Sync,
228    {
229        let entity_id = EntityId::from_primary_key::<E>(id)?;
230        self.get_by_entity_id(&entity_id).await
231    }
232
233    /// The result of this function will have different contents for different scenarios:
234    /// * `Some(Some(E))` - the transaction cache contains the record
235    /// * `Some(None)` - the deletion of the record has been cached
236    /// * `None` - there is no information about the record in the cache
237    pub(crate) async fn get_borrowed<E>(&self, id: &E::BorrowedPrimaryKey) -> Option<Option<Arc<E>>>
238    where
239        E: Entity + BorrowPrimaryKey + Send + Sync,
240    {
241        let entity_id = EntityId::from_borrowed_primary_key::<E>(id)?;
242        self.get_by_entity_id(&entity_id).await
243    }
244
245    async fn find_all_in_cache<E>(&self) -> Vec<Arc<E>>
246    where
247        E: Entity + Send + Sync,
248    {
249        let cache_guard = self.cache.read().await;
250        cache_guard
251            .get(E::COLLECTION_NAME)
252            .map(|table| {
253                table
254                    .values()
255                    .map(|record: &dynamic_dispatch::Entity| {
256                        record
257                            .downcast::<E>()
258                            .expect("all entries in this table are of this type")
259                            .clone()
260                    })
261                    .collect::<Vec<_>>()
262            })
263            .unwrap_or_default()
264    }
265
266    pub(crate) async fn find_all<E>(&self, persisted_records: Vec<E>) -> CryptoKeystoreResult<Vec<E>>
267    where
268        E: Clone + Entity + Send + Sync,
269    {
270        let cached_records = self.find_all_in_cache().await;
271        let merged_records = self
272            .merge_records(
273                cached_records.iter().map(Arc::as_ref).map(Cow::Borrowed),
274                persisted_records.into_iter().map(Cow::Owned),
275            )
276            .await;
277        Ok(merged_records)
278    }
279
280    /// Build a single list of unique records from two potentially overlapping lists.
281    /// In case of overlap, records in `records_a` are prioritized.
282    /// Identity from the perspective of this function is determined by the output of
283    /// [Entity::merge_key].
284    ///
285    /// Further, the output list of records is built with respect to the provided [EntityFindParams]
286    /// and the deleted records cached in this [Self] instance.
287    async fn merge_records<'a, E>(
288        &self,
289        records_a: impl IntoIterator<Item = Cow<'a, E>>,
290        records_b: impl IntoIterator<Item = Cow<'a, E>>,
291    ) -> Vec<E>
292    where
293        E: Clone + Entity,
294    {
295        let deleted_records = self.deleted.read().await;
296
297        records_a
298            .into_iter()
299            .chain(records_b)
300            .unique_by(|e| e.primary_key().bytes().into_owned())
301            .filter_map(|record| {
302                let id = EntityId::from_entity(record.as_ref())?;
303                (!deleted_records.contains(&id)).then_some(record.into_owned())
304            })
305            .collect()
306    }
307
308    /// Persists all the operations in the database. It will effectively open a transaction
309    /// internally, perform all the buffered operations and commit.
310    pub(crate) async fn commit(&self, db: &Database) -> Result<(), CryptoKeystoreError> {
311        let conn = db.conn().await?;
312        let mut conn = conn.conn().await;
313
314        let cache = self.cache.read().await;
315        let deleted_ids = self.deleted.read().await;
316
317        let table_names_with_deletion = deleted_ids.iter().map(|entity_id| entity_id.collection_name());
318        let table_names_with_save = cache
319            .values()
320            .flat_map(|table| table.keys())
321            .map(|entity_id| entity_id.collection_name());
322        let mut tables = table_names_with_deletion
323            .chain(table_names_with_save)
324            .collect::<Vec<_>>();
325
326        if tables.is_empty() {
327            log::debug!("Empty transaction was committed.");
328            return Ok(());
329        }
330
331        tables.sort_unstable();
332        tables.dedup();
333
334        // open a database transaction
335        #[cfg(target_family = "wasm")]
336        let tx = conn.new_transaction(&tables).await?;
337        #[cfg(not(target_family = "wasm"))]
338        let tx = conn.transaction()?.into();
339
340        for entity in cache.values().flat_map(|table| table.values()) {
341            entity.execute_save(&tx).await?;
342        }
343
344        for deleted_id in deleted_ids.iter() {
345            deleted_id.execute_delete(&tx).await?;
346        }
347
348        // and commit everything
349        tx.commit_tx().await?;
350
351        Ok(())
352    }
353}