// core_crypto_keystore/transaction/mod.rs

1pub mod dynamic_dispatch;
2
3use crate::entities::mls::*;
4#[cfg(feature = "proteus-keystore")]
5use crate::entities::proteus::*;
6use crate::entities::{ConsumerData, EntityBase, EntityFindParams, EntityTransactionExt, UniqueEntity};
7use crate::transaction::dynamic_dispatch::EntityId;
8use crate::{
9    CryptoKeystoreError, CryptoKeystoreResult,
10    connection::{Connection, DatabaseKey, FetchFromDatabase, KeystoreDatabaseConnection},
11};
12use async_lock::{RwLock, SemaphoreGuardArc};
13use itertools::Itertools;
14use std::{ops::DerefMut, sync::Arc};
15
16/// This represents a transaction, where all operations will be done in memory and committed at the
17/// end
18#[derive(Debug, Clone)]
19pub(crate) struct KeystoreTransaction {
20    /// In-memory cache
21    cache: Connection,
22    deleted: Arc<RwLock<Vec<EntityId>>>,
23    deleted_credentials: Arc<RwLock<Vec<Vec<u8>>>>,
24    _semaphore_guard: Arc<SemaphoreGuardArc>,
25}
26
27impl KeystoreTransaction {
28    pub(crate) async fn new(semaphore_guard: SemaphoreGuardArc) -> CryptoKeystoreResult<Self> {
29        // We don't really care about the key and we're not going to store it anywhere.
30        let key = DatabaseKey::from([0u8; 32]);
31        Ok(Self {
32            cache: Connection::open_in_memory_with_key("core_crypto_transaction_cache", &key).await?,
33            deleted: Arc::new(Default::default()),
34            deleted_credentials: Arc::new(Default::default()),
35            _semaphore_guard: Arc::new(semaphore_guard),
36        })
37    }
38
39    pub(crate) async fn save_mut<
40        E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection> + EntityTransactionExt + Sync,
41    >(
42        &self,
43        mut entity: E,
44    ) -> CryptoKeystoreResult<E> {
45        entity.pre_save().await?;
46        let conn = self.cache.borrow_conn().await?;
47        let mut conn = conn.conn().await;
48        #[cfg(target_family = "wasm")]
49        let transaction = conn.new_transaction(&[E::COLLECTION_NAME]).await?;
50        #[cfg(not(target_family = "wasm"))]
51        let transaction = conn.transaction()?.into();
52        entity.save(&transaction).await?;
53        transaction.commit_tx().await?;
54        Ok(entity)
55    }
56
57    pub(crate) async fn remove<
58        E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection> + EntityTransactionExt,
59        S: AsRef<[u8]>,
60    >(
61        &self,
62        id: S,
63    ) -> CryptoKeystoreResult<()> {
64        let conn = self.cache.borrow_conn().await?;
65        let mut conn = conn.conn().await;
66        #[cfg(target_family = "wasm")]
67        let transaction = conn.new_transaction(&[E::COLLECTION_NAME]).await?;
68        #[cfg(not(target_family = "wasm"))]
69        let transaction = conn.transaction()?.into();
70        E::delete(&transaction, id.as_ref().into()).await?;
71        transaction.commit_tx().await?;
72        let mut deleted_list = self.deleted.write().await;
73        deleted_list.push(EntityId::from_collection_name(E::COLLECTION_NAME, id.as_ref())?);
74        Ok(())
75    }
76
77    pub(crate) async fn child_groups<
78        E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection> + PersistedMlsGroupExt + Sync,
79    >(
80        &self,
81        entity: E,
82        persisted_records: Vec<E>,
83    ) -> CryptoKeystoreResult<Vec<E>> {
84        let mut conn = self.cache.borrow_conn().await?;
85        let cached_records = entity.child_groups(conn.deref_mut()).await?;
86        Ok(self
87            .merge_records(cached_records, persisted_records, EntityFindParams::default())
88            .await)
89    }
90
91    pub(crate) async fn cred_delete_by_credential(&self, cred: Vec<u8>) -> CryptoKeystoreResult<()> {
92        let conn = self.cache.borrow_conn().await?;
93        let mut conn = conn.conn().await;
94        #[cfg(target_family = "wasm")]
95        let transaction = conn.new_transaction(&[MlsCredential::COLLECTION_NAME]).await?;
96        #[cfg(not(target_family = "wasm"))]
97        let transaction = conn.transaction()?.into();
98        MlsCredential::delete_by_credential(&transaction, cred.clone()).await?;
99        transaction.commit_tx().await?;
100        let mut deleted_list = self.deleted_credentials.write().await;
101        deleted_list.push(cred);
102        Ok(())
103    }
104
105    /// The result of this function will have different contents for different scenarios:
106    /// * `Some(Some(E))` - the transaction cache contains the record
107    /// * `Some(None)` - the deletion of the record has been cached
108    /// * `None` - there is no information about the record in the cache
109    pub(crate) async fn find<E>(&self, id: &[u8]) -> CryptoKeystoreResult<Option<Option<E>>>
110    where
111        E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>,
112    {
113        let cache_result = self.cache.find(id).await?;
114        match cache_result {
115            Some(cache_result) => Ok(Some(Some(cache_result))),
116            _ => {
117                let deleted_list = self.deleted.read().await;
118                if deleted_list.contains(&EntityId::from_collection_name(E::COLLECTION_NAME, id)?) {
119                    Ok(Some(None))
120                } else {
121                    Ok(None)
122                }
123            }
124        }
125    }
126
127    pub(crate) async fn find_unique<U: UniqueEntity<ConnectionType = KeystoreDatabaseConnection>>(
128        &self,
129    ) -> CryptoKeystoreResult<Option<U>> {
130        let cache_result = self.cache.find_unique().await;
131        match cache_result {
132            Ok(cache_result) => Ok(Some(cache_result)),
133            _ => {
134                // The deleted list doesn't have to be checked because unique entities don't implement
135                // deletion, just replace. So we can directly return None.
136                Ok(None)
137            }
138        }
139    }
140
141    pub(crate) async fn find_all<E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>>(
142        &self,
143        persisted_records: Vec<E>,
144        params: EntityFindParams,
145    ) -> CryptoKeystoreResult<Vec<E>> {
146        let cached_records: Vec<E> = self.cache.find_all(params.clone()).await?;
147        Ok(self.merge_records(cached_records, persisted_records, params).await)
148    }
149
150    pub(crate) async fn find_many<E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>>(
151        &self,
152        persisted_records: Vec<E>,
153        ids: &[Vec<u8>],
154    ) -> CryptoKeystoreResult<Vec<E>> {
155        let cached_records: Vec<E> = self.cache.find_many(ids).await?;
156        Ok(self
157            .merge_records(cached_records, persisted_records, EntityFindParams::default())
158            .await)
159    }
160
161    /// Build a single list of unique records from two potentially overlapping lists.
162    /// In case of overlap, records in `records_a` are prioritized.
163    /// Identity from the perspective of this function is determined by the output of [crate::entities::Entity::merge_key].
164    ///
165    /// Further, the output list of records is built with respect to the provided [EntityFindParams]
166    /// and the deleted records cached in this [Self] instance.
167    async fn merge_records<E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>>(
168        &self,
169        records_a: Vec<E>,
170        records_b: Vec<E>,
171        params: EntityFindParams,
172    ) -> Vec<E> {
173        let merged = records_a.into_iter().chain(records_b).unique_by(|e| e.merge_key());
174
175        // We are consuming the iterator here to keep types of the `if` and `else` block consistent.
176        // The alternative to giving up laziness here would be to use a dynamically
177        // typed iterator Box<dyn Iterator<Item = E>> assigned to `merged`. The below approach
178        // trades stack allocation instead of heap allocation for laziness.
179        //
180        // Also, we have to do this before filtering by deleted records since filter map does not
181        // return an iterator that is double ended.
182        let merged: Vec<E> = if params.reverse {
183            merged.rev().collect()
184        } else {
185            merged.collect()
186        };
187
188        if merged.is_empty() {
189            return merged;
190        }
191
192        let deleted_records = self.deleted.read().await;
193        let deleted_credentials = self.deleted_credentials.read().await;
194        let merged = if deleted_records.is_empty() && deleted_credentials.is_empty() {
195            merged
196        } else {
197            merged
198                .into_iter()
199                .filter(|record| {
200                    !Self::record_is_in_deleted_list(record, &deleted_records)
201                        && !Self::credential_is_in_deleted_list(record, &deleted_credentials)
202                })
203                .collect()
204        };
205
206        merged
207            .into_iter()
208            .skip(params.offset.unwrap_or(0) as usize)
209            .take(params.limit.unwrap_or(u32::MAX) as usize)
210            .collect()
211    }
212
213    fn record_is_in_deleted_list<E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>>(
214        record: &E,
215        deleted_records: &[EntityId],
216    ) -> bool {
217        let id = EntityId::from_collection_name(E::COLLECTION_NAME, record.id_raw());
218        let Ok(id) = id else { return false };
219        deleted_records.contains(&id)
220    }
221    fn credential_is_in_deleted_list<E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>>(
222        maybe_credential: &E,
223        deleted_credentials: &[Vec<u8>],
224    ) -> bool {
225        let Some(credential) = maybe_credential.downcast::<MlsCredential>() else {
226            return false;
227        };
228        deleted_credentials.contains(&credential.credential)
229    }
230}
231
/// Persist all records cached in `$keystore_transaction` (first argument),
/// using a transaction on `$db` (second argument).
/// Use the provided types to read from the cache and write to the `$db`.
///
/// Deletions buffered in the keystore transaction are applied after all cached records have
/// been saved: first the records removed by id, then the credentials removed by value.
///
/// If there is nothing to save and nothing to delete, no database transaction is opened and
/// the *enclosing function* returns `Ok(())` early.
///
/// # Examples
/// ```rust,ignore
/// let transaction = KeystoreTransaction::new(semaphore_guard).await?;
/// let db = Connection::new();
///
/// // Commit records of all provided types
/// commit_transaction!(
///     transaction, db,
///     [
///         (identifier_01, MlsCredential),
///         (identifier_02, MlsSignatureKeyPair),
///     ],
/// );
///
/// // Commit records of provided types in the first list. Commit records of types in the second
/// // list only if the "proteus-keystore" cargo feature is enabled.
/// commit_transaction!(
///     transaction, db,
///     [
///         (identifier_01, MlsCredential),
///         (identifier_02, MlsSignatureKeyPair),
///     ],
///     proteus_types: [
///         (identifier_03, ProteusPrekey),
///         (identifier_04, ProteusIdentity),
///         (identifier_05, ProteusSession)
///     ]
/// );
///```
macro_rules! commit_transaction {
    ($keystore_transaction:expr_2021, $db:expr_2021, [ $( ($records:ident, $entity:ty) ),*], proteus_types: [ $( ($conditional_records:ident, $conditional_entity:ty) ),*]) => {
        #[cfg(feature = "proteus-keystore")]
        commit_transaction!($keystore_transaction, $db, [ $( ($records, $entity) ),*], [ $( ($conditional_records, $conditional_entity) ),*]);

        #[cfg(not(feature = "proteus-keystore"))]
        commit_transaction!($keystore_transaction, $db, [ $( ($records, $entity) ),*]);
    };
    ($keystore_transaction:expr_2021, $db:expr_2021, $([ $( ($records:ident, $entity:ty) ),*]),*) => {
        let cached_collections = ( $( $(
            $keystore_transaction.cache.find_all::<$entity>(Default::default()).await?,
        )* )* );

        let ( $( $( $records, )* )* ) = cached_collections;

        let conn = $db.borrow_conn().await?;
        let mut conn = conn.conn().await;
        let deleted_ids = $keystore_transaction.deleted.read().await;
        let deleted_credentials = $keystore_transaction.deleted_credentials.read().await;

        // Every collection touched by this commit must be in the database transaction's scope.
        let mut tables: Vec<&str> = Vec::new();
        $( $(
            if !$records.is_empty() {
                tables.push(<$entity>::COLLECTION_NAME);
            }
        )* )*

        for deleted_id in deleted_ids.iter() {
            let collection_name = deleted_id.collection_name();
            // Many deletions usually target the same few collections; list each one once.
            if !tables.contains(&collection_name) {
                tables.push(collection_name);
            }
        }

        // Credentials deleted by value live in the credential collection as well. Without this,
        // a commit whose only change is such a deletion would be treated as empty and dropped.
        if !deleted_credentials.is_empty() && !tables.contains(&MlsCredential::COLLECTION_NAME) {
            tables.push(MlsCredential::COLLECTION_NAME);
        }

        if tables.is_empty() {
            log::debug!("Empty transaction was committed.");
            return Ok(());
        }

        #[cfg(target_family = "wasm")]
        let tx = conn.new_transaction(&tables).await?;
        #[cfg(not(target_family = "wasm"))]
        let tx = conn.transaction()?.into();

        $( $(
            for record in $records {
                dynamic_dispatch::execute_save(&tx, &record.to_transaction_entity()).await?;
            }
        )* )*

        for deleted_id in deleted_ids.iter() {
            dynamic_dispatch::execute_delete(&tx, deleted_id).await?;
        }

        for deleted_credential in deleted_credentials.iter() {
            MlsCredential::delete_by_credential(&tx, deleted_credential.to_owned()).await?;
        }

        tx.commit_tx().await?;
    };
}
325
326impl KeystoreTransaction {
327    /// Persists all the operations in the database. It will effectively open a transaction
328    /// internally, perform all the buffered operations and commit.
329    pub(crate) async fn commit(&self, db: &Connection) -> Result<(), CryptoKeystoreError> {
330        commit_transaction!(
331            self, db,
332            [
333                (identifier_01, MlsCredential),
334                (identifier_02, MlsSignatureKeyPair),
335                (identifier_03, MlsHpkePrivateKey),
336                (identifier_04, MlsEncryptionKeyPair),
337                (identifier_05, MlsEpochEncryptionKeyPair),
338                (identifier_06, MlsPskBundle),
339                (identifier_07, MlsKeyPackage),
340                (identifier_08, PersistedMlsGroup),
341                (identifier_09, PersistedMlsPendingGroup),
342                (identifier_10, MlsPendingMessage),
343                (identifier_11, E2eiEnrollment),
344                (identifier_12, E2eiRefreshToken),
345                (identifier_13, E2eiAcmeCA),
346                (identifier_14, E2eiIntermediateCert),
347                (identifier_15, E2eiCrl),
348                (identifier_16, ConsumerData)
349            ],
350            proteus_types: [
351                (identifier_17, ProteusPrekey),
352                (identifier_18, ProteusIdentity),
353                (identifier_19, ProteusSession)
354            ]
355        );
356
357        Ok(())
358    }
359}