core_crypto_keystore/transaction/mod.rs

1use std::{
2    collections::{HashMap, hash_map::Entry},
3    sync::Arc,
4};
5
6use async_lock::{RwLock, SemaphoreGuardArc};
7use itertools::Itertools;
8use zeroize::Zeroizing;
9
10#[cfg(feature = "proteus-keystore")]
11use crate::entities::proteus::*;
12use crate::{
13    CryptoKeystoreError, CryptoKeystoreResult,
14    connection::{Database, KeystoreDatabaseConnection},
15    entities::{ConsumerData, EntityBase, EntityFindParams, EntityTransactionExt, UniqueEntity, mls::*},
16    transaction::dynamic_dispatch::EntityId,
17};
18
19pub mod dynamic_dispatch;
20
21#[derive(Debug, Default, derive_more::Deref, derive_more::DerefMut)]
22struct InMemoryTable(HashMap<Vec<u8>, Zeroizing<Vec<u8>>>);
23
24type InMemoryCache = Arc<RwLock<HashMap<String, InMemoryTable>>>;
25
26/// This represents a transaction, where all operations will be done in memory and committed at the
27/// end
28#[derive(Debug, Clone)]
29pub(crate) struct KeystoreTransaction {
30    cache: InMemoryCache,
31    deleted: Arc<RwLock<Vec<EntityId>>>,
32    _semaphore_guard: Arc<SemaphoreGuardArc>,
33}
34
35impl KeystoreTransaction {
36    pub(crate) async fn new(semaphore_guard: SemaphoreGuardArc) -> CryptoKeystoreResult<Self> {
37        Ok(Self {
38            cache: Default::default(),
39            deleted: Arc::new(Default::default()),
40            _semaphore_guard: Arc::new(semaphore_guard),
41        })
42    }
43
44    pub(crate) async fn save_mut<
45        E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection> + EntityTransactionExt + Sync,
46    >(
47        &self,
48        mut entity: E,
49    ) -> CryptoKeystoreResult<E> {
50        entity.pre_save().await?;
51        let mut cache_guard = self.cache.write().await;
52        let table = cache_guard.entry(E::COLLECTION_NAME.to_string()).or_default();
53        let serialized = postcard::to_stdvec(&entity)?;
54        // Use merge_key() because `id_raw()` is not always unique for records.
55        // For `MlsPendingMessage` it's the id of the group it belongs to.
56        table.insert(entity.merge_key(), Zeroizing::new(serialized));
57        Ok(entity)
58    }
59
60    pub(crate) async fn remove<
61        E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection> + EntityTransactionExt,
62        S: AsRef<[u8]>,
63    >(
64        &self,
65        id: S,
66    ) -> CryptoKeystoreResult<()> {
67        let mut cache_guard = self.cache.write().await;
68        if let Entry::Occupied(mut table) = cache_guard.entry(E::COLLECTION_NAME.to_string())
69            && let Entry::Occupied(cached_record) = table.get_mut().entry(id.as_ref().to_vec())
70        {
71            cached_record.remove_entry();
72        };
73
74        let mut deleted_list = self.deleted.write().await;
75        deleted_list.push(EntityId::from_collection_name(E::COLLECTION_NAME, id.as_ref())?);
76        Ok(())
77    }
78
79    pub(crate) async fn child_groups<E>(&self, entity: E, persisted_records: Vec<E>) -> CryptoKeystoreResult<Vec<E>>
80    where
81        E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection> + PersistedMlsGroupExt + Sync,
82    {
83        // First get all raw groups from the cache, then deserialize them to enable filtering by there parent id
84        // matching `entity.id_raw()`.
85        let cached_records = self
86            .find_all_in_cache()
87            .await?
88            .into_iter()
89            .filter(|maybe_child: &E| {
90                maybe_child
91                    .parent_id()
92                    .map(|parent_id| parent_id == entity.id_raw())
93                    .unwrap_or_default()
94            })
95            .collect();
96
97        Ok(self
98            .merge_records(cached_records, persisted_records, EntityFindParams::default())
99            .await)
100    }
101
102    pub(crate) async fn remove_pending_messages_by_conversation_id(
103        &self,
104        conversation_id: impl AsRef<[u8]> + Send,
105    ) -> CryptoKeystoreResult<()> {
106        // We cannot return an error from `retain()`, so we've got to do this dance with a mutable result.
107        let mut result = Ok(());
108
109        let mut cache_guard = self.cache.write().await;
110        if let Entry::Occupied(mut table) = cache_guard.entry(MlsPendingMessage::COLLECTION_NAME.to_string()) {
111            table.get_mut().retain(|_key, record_bytes| {
112                postcard::from_bytes::<MlsPendingMessage>(record_bytes)
113                    .map(|pending_message| pending_message.foreign_id != conversation_id.as_ref())
114                    .inspect_err(|err| result = Err(err.clone()))
115                    .unwrap_or(false)
116            });
117        }
118
119        let mut deleted_list = self.deleted.write().await;
120        deleted_list.push(EntityId::from_collection_name(
121            MlsPendingMessage::COLLECTION_NAME,
122            conversation_id.as_ref(),
123        )?);
124        result.map_err(Into::into)
125    }
126
127    pub(crate) async fn find_pending_messages_by_conversation_id(
128        &self,
129        conversation_id: &[u8],
130        persisted_records: Vec<MlsPendingMessage>,
131    ) -> CryptoKeystoreResult<Vec<MlsPendingMessage>> {
132        let cached_records = self
133            .find_all_in_cache::<MlsPendingMessage>()
134            .await?
135            .into_iter()
136            .filter(|pending_message| pending_message.foreign_id == conversation_id)
137            .collect();
138        let merged_records = self
139            .merge_records(cached_records, persisted_records, Default::default())
140            .await;
141        Ok(merged_records)
142    }
143
144    async fn find_in_cache<E>(&self, id: &[u8]) -> CryptoKeystoreResult<Option<E>>
145    where
146        E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>,
147    {
148        let cache_guard = self.cache.read().await;
149        cache_guard
150            .get(E::COLLECTION_NAME)
151            .and_then(|table| {
152                table
153                    .get(id)
154                    .map(|record| -> CryptoKeystoreResult<_> { postcard::from_bytes::<E>(record).map_err(Into::into) })
155            })
156            .transpose()
157    }
158
159    /// The result of this function will have different contents for different scenarios:
160    /// * `Some(Some(E))` - the transaction cache contains the record
161    /// * `Some(None)` - the deletion of the record has been cached
162    /// * `None` - there is no information about the record in the cache
163    pub(crate) async fn find<E>(&self, id: &[u8]) -> CryptoKeystoreResult<Option<Option<E>>>
164    where
165        E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>,
166    {
167        let maybe_cached_record = self.find_in_cache(id).await?;
168        if let Some(cached_record) = maybe_cached_record {
169            return Ok(Some(Some(cached_record)));
170        }
171
172        let deleted_list = self.deleted.read().await;
173        if deleted_list.contains(&EntityId::from_collection_name(E::COLLECTION_NAME, id)?) {
174            return Ok(Some(None));
175        }
176
177        Ok(None)
178    }
179
180    pub(crate) async fn find_unique<U: UniqueEntity<ConnectionType = KeystoreDatabaseConnection>>(
181        &self,
182    ) -> CryptoKeystoreResult<Option<U>> {
183        #[cfg(target_family = "wasm")]
184        let id = &U::ID;
185        #[cfg(not(target_family = "wasm"))]
186        let id = &[U::ID as u8];
187        let maybe_cached_record = self.find_in_cache::<U>(id).await?;
188        match maybe_cached_record {
189            Some(cached_record) => Ok(Some(cached_record)),
190            _ => {
191                // The deleted list doesn't have to be checked because unique entities don't implement
192                // deletion, just replace. So we can directly return None.
193                Ok(None)
194            }
195        }
196    }
197
198    async fn find_all_in_cache<E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>>(
199        &self,
200    ) -> CryptoKeystoreResult<Vec<E>> {
201        let cache_guard = self.cache.read().await;
202        let cached_records = cache_guard
203            .get(E::COLLECTION_NAME)
204            .map(|table| {
205                table
206                    .values()
207                    .map(|record| postcard::from_bytes::<E>(record).map_err(Into::into))
208                    .collect::<CryptoKeystoreResult<Vec<_>>>()
209            })
210            .transpose()?
211            .unwrap_or_default();
212        Ok(cached_records)
213    }
214
215    pub(crate) async fn find_all<E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>>(
216        &self,
217        persisted_records: Vec<E>,
218        params: EntityFindParams,
219    ) -> CryptoKeystoreResult<Vec<E>> {
220        let cached_records = self.find_all_in_cache().await?;
221        let merged_records = self.merge_records(cached_records, persisted_records, params).await;
222        Ok(merged_records)
223    }
224
225    pub(crate) async fn find_many<E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>>(
226        &self,
227        persisted_records: Vec<E>,
228        ids: &[Vec<u8>],
229    ) -> CryptoKeystoreResult<Vec<E>> {
230        let records = self
231            .find_all(persisted_records, EntityFindParams::default())
232            .await?
233            .into_iter()
234            .filter(|record| ids.contains(&record.id_raw().to_vec()))
235            .collect();
236        Ok(records)
237    }
238
239    /// Build a single list of unique records from two potentially overlapping lists.
240    /// In case of overlap, records in `records_a` are prioritized.
241    /// Identity from the perspective of this function is determined by the output of
242    /// [crate::entities::Entity::merge_key].
243    ///
244    /// Further, the output list of records is built with respect to the provided [EntityFindParams]
245    /// and the deleted records cached in this [Self] instance.
246    async fn merge_records<E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>>(
247        &self,
248        records_a: Vec<E>,
249        records_b: Vec<E>,
250        params: EntityFindParams,
251    ) -> Vec<E> {
252        let mut merged = records_a.into_iter().chain(records_b).unique_by(|e| e.merge_key());
253
254        let deleted_records = self.deleted.read().await;
255
256        let merged: &mut dyn Iterator<Item = E> = if params.reverse { &mut merged.rev() } else { &mut merged };
257
258        merged
259            .filter(|record| !Self::record_is_in_deleted_list(record, &deleted_records))
260            .skip(params.offset.unwrap_or(0) as usize)
261            .take(params.limit.unwrap_or(u32::MAX) as usize)
262            .collect()
263    }
264
265    fn record_is_in_deleted_list<E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>>(
266        record: &E,
267        deleted_records: &[EntityId],
268    ) -> bool {
269        let id = EntityId::from_collection_name(E::COLLECTION_NAME, record.id_raw());
270        let Ok(id) = id else { return false };
271        deleted_records.contains(&id)
272    }
273}
274
/// Persist all records cached in `$keystore_transaction` (first argument),
/// using a transaction on `$db` (second argument).
/// Use the provided types to read from the cache and write to the `$db`.
///
/// Steps performed:
/// 1. The cached records of every listed type are snapshotted.
/// 2. If there is nothing to save and nothing to delete, a debug message is logged and
///    `Ok(())` is `return`ed from the *enclosing* function. No database transaction is opened.
/// 3. Otherwise all cached records are saved, then every recorded deletion is executed, and
///    finally the database transaction is committed.
///
/// The expansion uses `.await?` and `return Ok(())`. It must therefore be invoked inside an
/// `async fn` returning a `Result<(), _>`. Each `identifier` only names the local variable
/// that holds that type's cached records, so it must be unique within one invocation.
///
/// # Examples
/// ```rust,ignore
/// let transaction = KeystoreTransaction::new(semaphore_guard).await?;
/// let db: &Database = /* an open keystore database */;
///
/// // Commit records of all provided types
/// commit_transaction!(
///     transaction, db,
///     [
///         (identifier_01, StoredCredential),
///         (identifier_02, StoredHpkePrivateKey),
///     ],
/// );
///
/// // Commit records of provided types in the first list. Commit records of types in the second
/// // list only if the "proteus-keystore" cargo feature is enabled.
/// commit_transaction!(
///     transaction, db,
///     [
///         (identifier_01, StoredCredential),
///         (identifier_02, StoredHpkePrivateKey),
///     ],
///     proteus_types: [
///         (identifier_03, ProteusPrekey),
///         (identifier_04, ProteusIdentity),
///         (identifier_05, ProteusSession)
///     ]
/// );
/// ```
macro_rules! commit_transaction {
    ($keystore_transaction:expr_2021, $db:expr_2021, [ $( ($records:ident, $entity:ty) ),*], proteus_types: [ $( ($conditional_records:ident, $conditional_entity:ty) ),*]) => {
        #[cfg(feature = "proteus-keystore")]
        commit_transaction!($keystore_transaction, $db, [ $( ($records, $entity) ),*], [ $( ($conditional_records, $conditional_entity) ),*]);

        #[cfg(not(feature = "proteus-keystore"))]
        commit_transaction!($keystore_transaction, $db, [ $( ($records, $entity) ),*]);
    };
    ($keystore_transaction:expr_2021, $db:expr_2021, $([ $( ($records:ident, $entity:ty) ),*]),*) => {
        // Snapshot the cached records of every listed type.
        let cached_collections = ( $( $(
            $keystore_transaction.find_all_in_cache::<$entity>().await?,
        )* )* );

        let ( $( $( $records, )* )* ) = cached_collections;

        let conn = $db.conn().await?;
        let mut conn = conn.conn().await;
        let deleted_ids = $keystore_transaction.deleted.read().await;

        // Collections touched by this transaction. Each collection is listed once, even when
        // many of its records are deleted.
        let mut tables = Vec::new();
        $( $(
            if !$records.is_empty() {
                tables.push(<$entity>::COLLECTION_NAME);
            }
        )* )*

        for deleted_id in deleted_ids.iter() {
            let collection_name = deleted_id.collection_name();
            if !tables.contains(&collection_name) {
                tables.push(collection_name);
            }
        }

        if tables.is_empty() {
            log::debug!("Empty transaction was committed.");
            return Ok(());
        }

        #[cfg(target_family = "wasm")]
        let tx = conn.new_transaction(&tables).await?;
        #[cfg(not(target_family = "wasm"))]
        let tx = conn.transaction()?.into();

        // Saves are applied first, deletions afterwards.
        $( $(
            for record in $records {
                dynamic_dispatch::execute_save(&tx, &record.to_transaction_entity()).await?;
            }
        )* )*

        for deleted_id in deleted_ids.iter() {
            dynamic_dispatch::execute_delete(&tx, deleted_id).await?;
        }

        tx.commit_tx().await?;
    };
}
364
365impl KeystoreTransaction {
366    /// Persists all the operations in the database. It will effectively open a transaction
367    /// internally, perform all the buffered operations and commit.
368    pub(crate) async fn commit(&self, db: &Database) -> Result<(), CryptoKeystoreError> {
369        commit_transaction!(
370            self, db,
371            [
372                (identifier_01, StoredCredential),
373                // (identifier_02, StoredSignatureKeypair),
374                (identifier_03, StoredHpkePrivateKey),
375                (identifier_04, StoredEncryptionKeyPair),
376                (identifier_05, StoredEpochEncryptionKeypair),
377                (identifier_06, StoredPskBundle),
378                (identifier_07, StoredKeypackage),
379                (identifier_08, PersistedMlsGroup),
380                (identifier_09, PersistedMlsPendingGroup),
381                (identifier_10, MlsPendingMessage),
382                (identifier_11, StoredE2eiEnrollment),
383                // (identifier_12, E2eiRefreshToken),
384                (identifier_13, E2eiAcmeCA),
385                (identifier_14, E2eiIntermediateCert),
386                (identifier_15, E2eiCrl),
387                (identifier_16, ConsumerData)
388            ],
389            proteus_types: [
390                (identifier_17, ProteusPrekey),
391                (identifier_18, ProteusIdentity),
392                (identifier_19, ProteusSession)
393            ]
394        );
395
396        Ok(())
397    }
398}