// core_crypto_keystore/transaction/mod.rs

1use std::collections::HashMap;
2use std::collections::hash_map::Entry;
3use std::sync::Arc;
4
5use async_lock::{RwLock, SemaphoreGuardArc};
6use itertools::Itertools;
7use zeroize::Zeroizing;
8
9use crate::connection::KeystoreDatabaseConnection;
10use crate::entities::mls::*;
11#[cfg(feature = "proteus-keystore")]
12use crate::entities::proteus::*;
13use crate::entities::{ConsumerData, EntityBase, EntityFindParams, EntityTransactionExt, UniqueEntity};
14use crate::transaction::dynamic_dispatch::EntityId;
15use crate::{CryptoKeystoreError, CryptoKeystoreResult, connection::Connection};
16
17pub mod dynamic_dispatch;
18
19#[derive(Debug, Default, derive_more::Deref, derive_more::DerefMut)]
20struct InMemoryTable(HashMap<Vec<u8>, Zeroizing<Vec<u8>>>);
21
22type InMemoryCache = Arc<RwLock<HashMap<String, InMemoryTable>>>;
23
24/// This represents a transaction, where all operations will be done in memory and committed at the
25/// end
26#[derive(Debug, Clone)]
27pub(crate) struct KeystoreTransaction {
28    cache: InMemoryCache,
29    deleted: Arc<RwLock<Vec<EntityId>>>,
30    deleted_credentials: Arc<RwLock<Vec<Vec<u8>>>>,
31    _semaphore_guard: Arc<SemaphoreGuardArc>,
32}
33
34impl KeystoreTransaction {
35    pub(crate) async fn new(semaphore_guard: SemaphoreGuardArc) -> CryptoKeystoreResult<Self> {
36        Ok(Self {
37            cache: Default::default(),
38            deleted: Arc::new(Default::default()),
39            deleted_credentials: Arc::new(Default::default()),
40            _semaphore_guard: Arc::new(semaphore_guard),
41        })
42    }
43
44    pub(crate) async fn save_mut<
45        E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection> + EntityTransactionExt + Sync,
46    >(
47        &self,
48        mut entity: E,
49    ) -> CryptoKeystoreResult<E> {
50        entity.pre_save().await?;
51        let mut cache_guard = self.cache.write().await;
52        let table = cache_guard.entry(E::COLLECTION_NAME.to_string()).or_default();
53        let serialized = postcard::to_stdvec(&entity)?;
54        // Use merge_key() because `id_raw()` is not always unique for records.
55        // For `MlsCredential`, `id_raw()` is the `CLientId`.
56        // For `MlsPendingMessage` it's the id of the group it belongs to.
57        table.insert(entity.merge_key(), Zeroizing::new(serialized));
58        Ok(entity)
59    }
60
61    pub(crate) async fn remove<
62        E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection> + EntityTransactionExt,
63        S: AsRef<[u8]>,
64    >(
65        &self,
66        id: S,
67    ) -> CryptoKeystoreResult<()> {
68        let mut cache_guard = self.cache.write().await;
69        if let Entry::Occupied(mut table) = cache_guard.entry(E::COLLECTION_NAME.to_string())
70            && let Entry::Occupied(cached_record) = table.get_mut().entry(id.as_ref().to_vec())
71        {
72            cached_record.remove_entry();
73        };
74
75        let mut deleted_list = self.deleted.write().await;
76        deleted_list.push(EntityId::from_collection_name(E::COLLECTION_NAME, id.as_ref())?);
77        Ok(())
78    }
79
80    pub(crate) async fn child_groups<E>(&self, entity: E, persisted_records: Vec<E>) -> CryptoKeystoreResult<Vec<E>>
81    where
82        E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection> + PersistedMlsGroupExt + Sync,
83    {
84        // First get all raw groups from the cache, then deserialize them to enable filtering by there parent id
85        // matching `entity.id_raw()`.
86        let cached_records = self
87            .find_all_in_cache()
88            .await?
89            .into_iter()
90            .filter(|maybe_child: &E| {
91                maybe_child
92                    .parent_id()
93                    .map(|parent_id| parent_id == entity.id_raw())
94                    .unwrap_or_default()
95            })
96            .collect();
97
98        Ok(self
99            .merge_records(cached_records, persisted_records, EntityFindParams::default())
100            .await)
101    }
102
103    pub(crate) async fn cred_delete_by_credential(&self, cred: Vec<u8>) -> CryptoKeystoreResult<()> {
104        let mut cache_guard = self.cache.write().await;
105        if let Entry::Occupied(mut table) = cache_guard.entry(MlsCredential::COLLECTION_NAME.to_string()) {
106            table.get_mut().retain(|_, value| **value != cred);
107        }
108
109        let mut deleted_list = self.deleted_credentials.write().await;
110        deleted_list.push(cred);
111        Ok(())
112    }
113
114    pub(crate) async fn remove_pending_messages_by_conversation_id(
115        &self,
116        conversation_id: &[u8],
117    ) -> CryptoKeystoreResult<()> {
118        // We cannot return an error from `retain()`, so we've got to do this dance with a mutable result.
119        let mut result = Ok(());
120
121        let mut cache_guard = self.cache.write().await;
122        if let Entry::Occupied(mut table) = cache_guard.entry(MlsPendingMessage::COLLECTION_NAME.to_string()) {
123            table.get_mut().retain(|_key, record_bytes| {
124                postcard::from_bytes::<MlsPendingMessage>(record_bytes)
125                    .map(|pending_message| pending_message.foreign_id != conversation_id)
126                    .inspect_err(|err| result = Err(err.clone()))
127                    .unwrap_or(false)
128            });
129        }
130
131        let mut deleted_list = self.deleted.write().await;
132        deleted_list.push(EntityId::from_collection_name(
133            MlsPendingMessage::COLLECTION_NAME,
134            conversation_id,
135        )?);
136        result.map_err(Into::into)
137    }
138
139    pub(crate) async fn find_pending_messages_by_conversation_id(
140        &self,
141        conversation_id: &[u8],
142        persisted_records: Vec<MlsPendingMessage>,
143    ) -> CryptoKeystoreResult<Vec<MlsPendingMessage>> {
144        let cached_records = self
145            .find_all_in_cache::<MlsPendingMessage>()
146            .await?
147            .into_iter()
148            .filter(|pending_message| pending_message.foreign_id == conversation_id)
149            .collect();
150        let merged_records = self
151            .merge_records(cached_records, persisted_records, Default::default())
152            .await;
153        Ok(merged_records)
154    }
155
156    async fn find_in_cache<E>(&self, id: &[u8]) -> CryptoKeystoreResult<Option<E>>
157    where
158        E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>,
159    {
160        let cache_guard = self.cache.read().await;
161        cache_guard
162            .get(E::COLLECTION_NAME)
163            .and_then(|table| {
164                table
165                    .get(id)
166                    .map(|record| -> CryptoKeystoreResult<_> { postcard::from_bytes::<E>(record).map_err(Into::into) })
167            })
168            .transpose()
169    }
170
171    /// The result of this function will have different contents for different scenarios:
172    /// * `Some(Some(E))` - the transaction cache contains the record
173    /// * `Some(None)` - the deletion of the record has been cached
174    /// * `None` - there is no information about the record in the cache
175    pub(crate) async fn find<E>(&self, id: &[u8]) -> CryptoKeystoreResult<Option<Option<E>>>
176    where
177        E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>,
178    {
179        let maybe_cached_record = self.find_in_cache(id).await?;
180        if let Some(cached_record) = maybe_cached_record {
181            return Ok(Some(Some(cached_record)));
182        }
183
184        let deleted_list = self.deleted.read().await;
185        if deleted_list.contains(&EntityId::from_collection_name(E::COLLECTION_NAME, id)?) {
186            return Ok(Some(None));
187        }
188
189        Ok(None)
190    }
191
192    pub(crate) async fn find_unique<U: UniqueEntity<ConnectionType = KeystoreDatabaseConnection>>(
193        &self,
194    ) -> CryptoKeystoreResult<Option<U>> {
195        #[cfg(target_family = "wasm")]
196        let id = &U::ID;
197        #[cfg(not(target_family = "wasm"))]
198        let id = &[U::ID as u8];
199        let maybe_cached_record = self.find_in_cache::<U>(id).await?;
200        match maybe_cached_record {
201            Some(cached_record) => Ok(Some(cached_record)),
202            _ => {
203                // The deleted list doesn't have to be checked because unique entities don't implement
204                // deletion, just replace. So we can directly return None.
205                Ok(None)
206            }
207        }
208    }
209
210    async fn find_all_in_cache<E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>>(
211        &self,
212    ) -> CryptoKeystoreResult<Vec<E>> {
213        let cache_guard = self.cache.read().await;
214        let cached_records = cache_guard
215            .get(E::COLLECTION_NAME)
216            .map(|table| {
217                table
218                    .values()
219                    .map(|record| postcard::from_bytes::<E>(record).map_err(Into::into))
220                    .collect::<CryptoKeystoreResult<Vec<_>>>()
221            })
222            .transpose()?
223            .unwrap_or_default();
224        Ok(cached_records)
225    }
226
227    pub(crate) async fn find_all<E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>>(
228        &self,
229        persisted_records: Vec<E>,
230        params: EntityFindParams,
231    ) -> CryptoKeystoreResult<Vec<E>> {
232        let cached_records = self.find_all_in_cache().await?;
233        let merged_records = self.merge_records(cached_records, persisted_records, params).await;
234        Ok(merged_records)
235    }
236
237    pub(crate) async fn find_many<E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>>(
238        &self,
239        persisted_records: Vec<E>,
240        ids: &[Vec<u8>],
241    ) -> CryptoKeystoreResult<Vec<E>> {
242        let records = self
243            .find_all(persisted_records, EntityFindParams::default())
244            .await?
245            .into_iter()
246            .filter(|record| ids.contains(&record.id_raw().to_vec()))
247            .collect();
248        Ok(records)
249    }
250
251    /// Build a single list of unique records from two potentially overlapping lists.
252    /// In case of overlap, records in `records_a` are prioritized.
253    /// Identity from the perspective of this function is determined by the output of [crate::entities::Entity::merge_key].
254    ///
255    /// Further, the output list of records is built with respect to the provided [EntityFindParams]
256    /// and the deleted records cached in this [Self] instance.
257    async fn merge_records<E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>>(
258        &self,
259        records_a: Vec<E>,
260        records_b: Vec<E>,
261        params: EntityFindParams,
262    ) -> Vec<E> {
263        let mut merged = records_a.into_iter().chain(records_b).unique_by(|e| e.merge_key());
264
265        let deleted_records = self.deleted.read().await;
266        let deleted_credentials = self.deleted_credentials.read().await;
267
268        let merged: &mut dyn Iterator<Item = E> = if params.reverse { &mut merged.rev() } else { &mut merged };
269
270        merged
271            .filter(|record| {
272                !Self::record_is_in_deleted_list(record, &deleted_records)
273                    && !Self::credential_is_in_deleted_list(record, &deleted_credentials)
274            })
275            .skip(params.offset.unwrap_or(0) as usize)
276            .take(params.limit.unwrap_or(u32::MAX) as usize)
277            .collect()
278    }
279
280    fn record_is_in_deleted_list<E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>>(
281        record: &E,
282        deleted_records: &[EntityId],
283    ) -> bool {
284        let id = EntityId::from_collection_name(E::COLLECTION_NAME, record.id_raw());
285        let Ok(id) = id else { return false };
286        deleted_records.contains(&id)
287    }
288
289    fn credential_is_in_deleted_list<E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>>(
290        maybe_credential: &E,
291        deleted_credentials: &[Vec<u8>],
292    ) -> bool {
293        let Some(credential) = maybe_credential.downcast::<MlsCredential>() else {
294            return false;
295        };
296        deleted_credentials.contains(&credential.credential)
297    }
298}
299
/// Persist all records cached in `$keystore_transaction` (first argument),
/// using a transaction on `$db` (second argument).
/// Use the provided types to read from the cache and write to the `$db`.
///
/// Operations are applied in this order:
/// 1. cached records are saved,
/// 2. deletions by id are executed,
/// 3. credentials are deleted by value.
///
/// If there is nothing to save or delete, the macro logs at debug level and `return`s `Ok(())`
/// from the enclosing function without opening a database transaction.
///
/// # Examples
/// ```rust,ignore
/// let transaction = KeystoreTransaction::new(semaphore_guard).await?;
/// let db = Connection::new();
///
/// // Commit records of all provided types
/// commit_transaction!(
///     transaction, db,
///     [
///         (identifier_01, MlsCredential),
///         (identifier_02, MlsSignatureKeyPair),
///     ],
/// );
///
/// // Commit records of provided types in the first list. Commit records of types in the second
/// // list only if the "proteus-keystore" cargo feature is enabled.
/// commit_transaction!(
///     transaction, db,
///     [
///         (identifier_01, MlsCredential),
///         (identifier_02, MlsSignatureKeyPair),
///     ],
///     proteus_types: [
///         (identifier_03, ProteusPrekey),
///         (identifier_04, ProteusIdentity),
///         (identifier_05, ProteusSession)
///     ]
/// );
///```
macro_rules! commit_transaction {
    ($keystore_transaction:expr_2021, $db:expr_2021, [ $( ($records:ident, $entity:ty) ),*], proteus_types: [ $( ($conditional_records:ident, $conditional_entity:ty) ),*]) => {
        #[cfg(feature = "proteus-keystore")]
        commit_transaction!($keystore_transaction, $db, [ $( ($records, $entity) ),*], [ $( ($conditional_records, $conditional_entity) ),*]);

        #[cfg(not(feature = "proteus-keystore"))]
        commit_transaction!($keystore_transaction, $db, [ $( ($records, $entity) ),*]);
    };
    ($keystore_transaction:expr_2021, $db:expr_2021, $([ $( ($records:ident, $entity:ty) ),*]),*) => {
        let cached_collections = ( $( $(
            $keystore_transaction.find_all_in_cache::<$entity>().await?,
        )* )* );

        let ( $( $( $records, )* )* ) = cached_collections;

        let conn = $db.borrow_conn().await?;
        let mut conn = conn.conn().await;
        let deleted_ids = $keystore_transaction.deleted.read().await;
        let deleted_credentials = $keystore_transaction.deleted_credentials.read().await;

        let mut tables = Vec::new();
        $( $(
            if !$records.is_empty() {
                tables.push(<$entity>::COLLECTION_NAME);
            }
        )* )*

        for deleted_id in deleted_ids.iter() {
            tables.push(deleted_id.collection_name());
        }

        // Deleting credentials by value touches the credential table too. Without this, a
        // transaction whose only operation is such a deletion would be treated as empty and the
        // deletion silently dropped (and on wasm the store would be missing from the scope).
        if !deleted_credentials.is_empty() {
            tables.push(MlsCredential::COLLECTION_NAME);
        }

        if tables.is_empty() {
            log::debug!("Empty transaction was committed.");
            return Ok(());
        }

        #[cfg(target_family = "wasm")]
        let tx = conn.new_transaction(&tables).await?;
        #[cfg(not(target_family = "wasm"))]
        let tx = conn.transaction()?.into();

        $( $(
            if !$records.is_empty() {
                for record in $records {
                    dynamic_dispatch::execute_save(&tx, &record.to_transaction_entity()).await?;
                }
            }
        )* )*

        for deleted_id in deleted_ids.iter() {
            dynamic_dispatch::execute_delete(&tx, deleted_id).await?;
        }

        for deleted_credential in deleted_credentials.iter() {
            MlsCredential::delete_by_credential(&tx, deleted_credential.to_owned()).await?;
        }

        tx.commit_tx().await?;
    };
}
393
394impl KeystoreTransaction {
395    /// Persists all the operations in the database. It will effectively open a transaction
396    /// internally, perform all the buffered operations and commit.
397    pub(crate) async fn commit(&self, db: &Connection) -> Result<(), CryptoKeystoreError> {
398        commit_transaction!(
399            self, db,
400            [
401                (identifier_01, MlsCredential),
402                (identifier_02, MlsSignatureKeyPair),
403                (identifier_03, MlsHpkePrivateKey),
404                (identifier_04, MlsEncryptionKeyPair),
405                (identifier_05, MlsEpochEncryptionKeyPair),
406                (identifier_06, MlsPskBundle),
407                (identifier_07, MlsKeyPackage),
408                (identifier_08, PersistedMlsGroup),
409                (identifier_09, PersistedMlsPendingGroup),
410                (identifier_10, MlsPendingMessage),
411                (identifier_11, E2eiEnrollment),
412                // (identifier_12, E2eiRefreshToken),
413                (identifier_13, E2eiAcmeCA),
414                (identifier_14, E2eiIntermediateCert),
415                (identifier_15, E2eiCrl),
416                (identifier_16, ConsumerData)
417            ],
418            proteus_types: [
419                (identifier_17, ProteusPrekey),
420                (identifier_18, ProteusIdentity),
421                (identifier_19, ProteusSession)
422            ]
423        );
424
425        Ok(())
426    }
427}