core_crypto_keystore/transaction/mod.rs

1pub mod dynamic_dispatch;
2
3use crate::entities::mls::*;
4#[cfg(feature = "proteus-keystore")]
5use crate::entities::proteus::*;
6use crate::entities::{ConsumerData, EntityBase, EntityFindParams, EntityTransactionExt, UniqueEntity};
7use crate::transaction::dynamic_dispatch::EntityId;
8use crate::{
9    connection::{Connection, DatabaseConnection, FetchFromDatabase, KeystoreDatabaseConnection},
10    CryptoKeystoreError, CryptoKeystoreResult,
11};
12use async_lock::{RwLock, SemaphoreGuardArc};
13use itertools::Itertools;
14use std::{ops::DerefMut, sync::Arc};
15
16/// This represents a transaction, where all operations will be done in memory and committed at the
17/// end
18#[derive(Debug, Clone)]
19pub(crate) struct KeystoreTransaction {
20    /// In-memory cache
21    cache: Connection,
22    deleted: Arc<RwLock<Vec<EntityId>>>,
23    deleted_credentials: Arc<RwLock<Vec<Vec<u8>>>>,
24    _semaphore_guard: Arc<SemaphoreGuardArc>,
25}
26
27impl KeystoreTransaction {
28    pub(crate) async fn new(semaphore_guard: SemaphoreGuardArc) -> CryptoKeystoreResult<Self> {
29        Ok(Self {
30            // We're not using a proper key because we're not using the DB for security (memory is unencrypted).
31            // We're using it for its API.
32            cache: Connection::open_in_memory_with_key("core_crypto_transaction_cache", "").await?,
33            deleted: Arc::new(Default::default()),
34            deleted_credentials: Arc::new(Default::default()),
35            _semaphore_guard: Arc::new(semaphore_guard),
36        })
37    }
38
39    pub(crate) async fn save_mut<
40        E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection> + EntityTransactionExt + Sync,
41    >(
42        &self,
43        mut entity: E,
44    ) -> CryptoKeystoreResult<E> {
45        entity.pre_save().await?;
46        let mut conn = self.cache.borrow_conn().await?;
47        #[cfg(target_family = "wasm")]
48        let transaction = conn.new_transaction(&[E::COLLECTION_NAME]).await?;
49        #[cfg(not(target_family = "wasm"))]
50        let transaction = conn.new_transaction().await?;
51        entity.save(&transaction).await?;
52        transaction.commit_tx().await?;
53        Ok(entity)
54    }
55
56    pub(crate) async fn remove<
57        E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection> + EntityTransactionExt,
58        S: AsRef<[u8]>,
59    >(
60        &self,
61        id: S,
62    ) -> CryptoKeystoreResult<()> {
63        let mut conn = self.cache.borrow_conn().await?;
64        #[cfg(target_family = "wasm")]
65        let transaction = conn.new_transaction(&[E::COLLECTION_NAME]).await?;
66        #[cfg(not(target_family = "wasm"))]
67        let transaction = conn.new_transaction().await?;
68        E::delete(&transaction, id.as_ref().into()).await?;
69        transaction.commit_tx().await?;
70        let mut deleted_list = self.deleted.write().await;
71        deleted_list.push(EntityId::from_collection_name(E::COLLECTION_NAME, id.as_ref())?);
72        Ok(())
73    }
74
75    pub(crate) async fn child_groups<
76        E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection> + PersistedMlsGroupExt + Sync,
77    >(
78        &self,
79        entity: E,
80        persisted_records: Vec<E>,
81    ) -> CryptoKeystoreResult<Vec<E>> {
82        let mut conn = self.cache.borrow_conn().await?;
83        let cached_records = entity.child_groups(conn.deref_mut()).await?;
84        Ok(self
85            .merge_records(cached_records, persisted_records, EntityFindParams::default())
86            .await)
87    }
88
89    pub(crate) async fn cred_delete_by_credential(&self, cred: Vec<u8>) -> CryptoKeystoreResult<()> {
90        let mut conn = self.cache.borrow_conn().await?;
91        #[cfg(target_family = "wasm")]
92        let transaction = conn.new_transaction(&[MlsCredential::COLLECTION_NAME]).await?;
93        #[cfg(not(target_family = "wasm"))]
94        let transaction = conn.new_transaction().await?;
95        MlsCredential::delete_by_credential(&transaction, cred.clone()).await?;
96        transaction.commit_tx().await?;
97        let mut deleted_list = self.deleted_credentials.write().await;
98        deleted_list.push(cred);
99        Ok(())
100    }
101
102    /// The result of this function will have different contents for different scenarios:
103    /// * `Some(Some(E))` - the transaction cache contains the record
104    /// * `Some(None)` - the deletion of the record has been cached
105    /// * `None` - there is no information about the record in the cache
106    pub(crate) async fn find<E>(&self, id: &[u8]) -> CryptoKeystoreResult<Option<Option<E>>>
107    where
108        E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>,
109    {
110        let cache_result = self.cache.find(id).await?;
111        if let Some(cache_result) = cache_result {
112            Ok(Some(Some(cache_result)))
113        } else {
114            let deleted_list = self.deleted.read().await;
115            if deleted_list.contains(&EntityId::from_collection_name(E::COLLECTION_NAME, id)?) {
116                Ok(Some(None))
117            } else {
118                Ok(None)
119            }
120        }
121    }
122
123    pub(crate) async fn find_unique<U: UniqueEntity<ConnectionType = KeystoreDatabaseConnection>>(
124        &self,
125    ) -> CryptoKeystoreResult<Option<U>> {
126        let cache_result = self.cache.find_unique().await;
127        if let Ok(cache_result) = cache_result {
128            Ok(Some(cache_result))
129        } else {
130            // The deleted list doesn't have to be checked because unique entities don't implement
131            // deletion, just replace. So we can directly return None.
132            Ok(None)
133        }
134    }
135
136    pub(crate) async fn find_all<E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>>(
137        &self,
138        persisted_records: Vec<E>,
139        params: EntityFindParams,
140    ) -> CryptoKeystoreResult<Vec<E>> {
141        let cached_records: Vec<E> = self.cache.find_all(params.clone()).await?;
142        Ok(self.merge_records(cached_records, persisted_records, params).await)
143    }
144
145    pub(crate) async fn find_many<E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>>(
146        &self,
147        persisted_records: Vec<E>,
148        ids: &[Vec<u8>],
149    ) -> CryptoKeystoreResult<Vec<E>> {
150        let cached_records: Vec<E> = self.cache.find_many(ids).await?;
151        Ok(self
152            .merge_records(cached_records, persisted_records, EntityFindParams::default())
153            .await)
154    }
155
156    /// Build a single list of unique records from two potentially overlapping lists.
157    /// In case of overlap, records in `records_a` are prioritized.
158    /// Identity from the perspective of this function is determined by the output of [crate::entities::Entity::merge_key].
159    ///
160    /// Further, the output list of records is built with respect to the provided [EntityFindParams]
161    /// and the deleted records cached in this [Self] instance.
162    async fn merge_records<E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>>(
163        &self,
164        records_a: Vec<E>,
165        records_b: Vec<E>,
166        params: EntityFindParams,
167    ) -> Vec<E> {
168        let merged = records_a.into_iter().chain(records_b).unique_by(|e| e.merge_key());
169
170        // We are consuming the iterator here to keep types of the `if` and `else` block consistent.
171        // The alternative to giving up laziness here would be to use a dynamically
172        // typed iterator Box<dyn Iterator<Item = E>> assigned to `merged`. The below approach
173        // trades stack allocation instead of heap allocation for laziness.
174        //
175        // Also, we have to do this before filtering by deleted records since filter map does not
176        // return an iterator that is double ended.
177        let merged: Vec<E> = if params.reverse {
178            merged.rev().collect()
179        } else {
180            merged.collect()
181        };
182
183        if merged.is_empty() {
184            return merged;
185        }
186
187        let deleted_records = self.deleted.read().await;
188        let deleted_credentials = self.deleted_credentials.read().await;
189        let merged = if deleted_records.is_empty() && deleted_credentials.is_empty() {
190            merged
191        } else {
192            merged
193                .into_iter()
194                .filter(|record| {
195                    !Self::record_is_in_deleted_list(record, &deleted_records)
196                        && !Self::credential_is_in_deleted_list(record, &deleted_credentials)
197                })
198                .collect()
199        };
200
201        merged
202            .into_iter()
203            .skip(params.offset.unwrap_or(0) as usize)
204            .take(params.limit.unwrap_or(u32::MAX) as usize)
205            .collect()
206    }
207
208    fn record_is_in_deleted_list<E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>>(
209        record: &E,
210        deleted_records: &[EntityId],
211    ) -> bool {
212        let id = EntityId::from_collection_name(E::COLLECTION_NAME, record.id_raw());
213        let Ok(id) = id else { return false };
214        deleted_records.contains(&id)
215    }
216    fn credential_is_in_deleted_list<E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>>(
217        maybe_credential: &E,
218        deleted_credentials: &[Vec<u8>],
219    ) -> bool {
220        let Some(credential) = maybe_credential.downcast::<MlsCredential>() else {
221            return false;
222        };
223        deleted_credentials.contains(&credential.credential)
224    }
225}
226
/// Persist all records cached in `$keystore_transaction` (first argument),
/// using a transaction on `$db` (second argument).
/// Use the provided types to read from the cache and write to the `$db`.
///
/// Pending deletions are applied *before* cached records are written. These are both entity ids
/// removed in the transaction and credentials deleted by value. A record that was deleted and
/// then saved again within the same keystore transaction therefore ends up persisted.
///
/// If there is nothing to write and nothing to delete, a warning is logged and the enclosing
/// function returns `Ok(())` early without opening a database transaction.
///
/// # Examples
/// ```rust,ignore
/// let transaction = KeystoreTransaction::new(semaphore_guard).await?;
/// let db: Connection = /* the persistent keystore connection */;
///
/// // Commit records of all provided types
/// commit_transaction!(
///     transaction, db,
///     [
///         (identifier_01, MlsCredential),
///         (identifier_02, MlsSignatureKeyPair),
///     ],
/// );
///
/// // Commit records of provided types in the first list. Commit records of types in the second
/// // list only if the "proteus-keystore" cargo feature is enabled.
/// commit_transaction!(
///     transaction, db,
///     [
///         (identifier_01, MlsCredential),
///         (identifier_02, MlsSignatureKeyPair),
///     ],
///     proteus_types: [
///         (identifier_03, ProteusPrekey),
///         (identifier_04, ProteusIdentity),
///         (identifier_05, ProteusSession)
///     ]
/// );
///```
macro_rules! commit_transaction {
    ($keystore_transaction:expr, $db:expr, [ $( ($records:ident, $entity:ty) ),*], proteus_types: [ $( ($conditional_records:ident, $conditional_entity:ty) ),*]) => {
        #[cfg(feature = "proteus-keystore")]
        commit_transaction!($keystore_transaction, $db, [ $( ($records, $entity) ),*], [ $( ($conditional_records, $conditional_entity) ),*]);

        #[cfg(not(feature = "proteus-keystore"))]
        commit_transaction!($keystore_transaction, $db, [ $( ($records, $entity) ),*]);
    };
    ($keystore_transaction:expr, $db:expr, $([ $( ($records:ident, $entity:ty) ),*]),*) => {
        let cached_collections = ( $( $(
            $keystore_transaction.cache.find_all::<$entity>(Default::default()).await?,
        )* )* );

        let ( $( $( $records, )* )* ) = cached_collections;

        let mut conn = $db.borrow_conn().await?;
        let deleted_ids = $keystore_transaction.deleted.read().await;
        let deleted_credentials = $keystore_transaction.deleted_credentials.read().await;

        let mut tables = Vec::new();
        $( $(
            if !$records.is_empty() {
                tables.push(<$entity>::COLLECTION_NAME);
            }
        )* )*

        for deleted_id in deleted_ids.iter() {
            tables.push(deleted_id.collection_name());
        }

        // Credentials deleted by value are real work too: without this, a transaction whose only
        // change is a credential deletion would be treated as empty and silently dropped. On
        // wasm, the credential store would also be missing from the transaction scope.
        if !deleted_credentials.is_empty() {
            tables.push(MlsCredential::COLLECTION_NAME);
        }

        if tables.is_empty() {
            log::warn!("Empty transaction was committed, this could be an indication of a programming error");
            return Ok(());
        }

        #[cfg(target_family = "wasm")]
        let tx = conn.new_transaction(&tables).await?;
        #[cfg(not(target_family = "wasm"))]
        let tx = conn.new_transaction().await?;

        // Deletions first: the cache never holds a record that was deleted *after* being saved
        // (the cache deletion removes it), so any cached record sharing a key with a pending
        // deletion was re-saved afterwards and must survive the commit.
        for deleted_id in deleted_ids.iter() {
            dynamic_dispatch::execute_delete(&tx, deleted_id).await?;
        }

        for deleted_credential in deleted_credentials.iter() {
            MlsCredential::delete_by_credential(&tx, deleted_credential.to_owned()).await?;
        }

        $( $(
            for record in $records {
                dynamic_dispatch::execute_save(&tx, &record.to_transaction_entity()).await?;
            }
        )* )*

        tx.commit_tx().await?;
    };
}
319
320impl KeystoreTransaction {
321    /// Persists all the operations in the database. It will effectively open a transaction
322    /// internally, perform all the buffered operations and commit.
323    pub(crate) async fn commit(&self, db: &Connection) -> Result<(), CryptoKeystoreError> {
324        commit_transaction!(
325            self, db,
326            [
327                (identifier_01, MlsCredential),
328                (identifier_02, MlsSignatureKeyPair),
329                (identifier_03, MlsHpkePrivateKey),
330                (identifier_04, MlsEncryptionKeyPair),
331                (identifier_05, MlsEpochEncryptionKeyPair),
332                (identifier_06, MlsPskBundle),
333                (identifier_07, MlsKeyPackage),
334                (identifier_08, PersistedMlsGroup),
335                (identifier_09, PersistedMlsPendingGroup),
336                (identifier_10, MlsPendingMessage),
337                (identifier_11, E2eiEnrollment),
338                (identifier_12, E2eiRefreshToken),
339                (identifier_13, E2eiAcmeCA),
340                (identifier_14, E2eiIntermediateCert),
341                (identifier_15, E2eiCrl),
342                (identifier_16, ConsumerData)
343            ],
344            proteus_types: [
345                (identifier_17, ProteusPrekey),
346                (identifier_18, ProteusIdentity),
347                (identifier_19, ProteusSession)
348            ]
349        );
350
351        Ok(())
352    }
353}