core_crypto_keystore/transaction/mod.rs

1pub mod dynamic_dispatch;
2
3use crate::connection::ConnectionType;
4use crate::entities::mls::*;
5#[cfg(feature = "proteus-keystore")]
6use crate::entities::proteus::*;
7use crate::entities::{ConsumerData, EntityBase, EntityFindParams, EntityTransactionExt, UniqueEntity};
8use crate::transaction::dynamic_dispatch::EntityId;
9use crate::{
10    CryptoKeystoreError, CryptoKeystoreResult,
11    connection::{Connection, DatabaseKey, FetchFromDatabase, KeystoreDatabaseConnection},
12};
13use async_lock::{RwLock, SemaphoreGuardArc};
14use itertools::Itertools;
15use std::{ops::DerefMut, sync::Arc};
16
17/// This represents a transaction, where all operations will be done in memory and committed at the
18/// end
19#[derive(Debug, Clone)]
20pub(crate) struct KeystoreTransaction {
21    /// In-memory cache
22    cache: Connection,
23    deleted: Arc<RwLock<Vec<EntityId>>>,
24    deleted_credentials: Arc<RwLock<Vec<Vec<u8>>>>,
25    _semaphore_guard: Arc<SemaphoreGuardArc>,
26}
27
28impl KeystoreTransaction {
29    pub(crate) async fn new(semaphore_guard: SemaphoreGuardArc) -> CryptoKeystoreResult<Self> {
30        // We don't really care about the key and we're not going to store it anywhere.
31        let key = DatabaseKey::from([0u8; 32]);
32        Ok(Self {
33            cache: Connection::open(ConnectionType::InMemory, &key).await?,
34            deleted: Arc::new(Default::default()),
35            deleted_credentials: Arc::new(Default::default()),
36            _semaphore_guard: Arc::new(semaphore_guard),
37        })
38    }
39
40    pub(crate) async fn save_mut<
41        E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection> + EntityTransactionExt + Sync,
42    >(
43        &self,
44        mut entity: E,
45    ) -> CryptoKeystoreResult<E> {
46        entity.pre_save().await?;
47        let conn = self.cache.borrow_conn().await?;
48        let mut conn = conn.conn().await;
49        #[cfg(target_family = "wasm")]
50        let transaction = conn.new_transaction(&[E::COLLECTION_NAME]).await?;
51        #[cfg(not(target_family = "wasm"))]
52        let transaction = conn.transaction()?.into();
53        entity.save(&transaction).await?;
54        transaction.commit_tx().await?;
55        Ok(entity)
56    }
57
58    pub(crate) async fn remove<
59        E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection> + EntityTransactionExt,
60        S: AsRef<[u8]>,
61    >(
62        &self,
63        id: S,
64    ) -> CryptoKeystoreResult<()> {
65        let conn = self.cache.borrow_conn().await?;
66        let mut conn = conn.conn().await;
67        #[cfg(target_family = "wasm")]
68        let transaction = conn.new_transaction(&[E::COLLECTION_NAME]).await?;
69        #[cfg(not(target_family = "wasm"))]
70        let transaction = conn.transaction()?.into();
71        E::delete(&transaction, id.as_ref().into()).await?;
72        transaction.commit_tx().await?;
73        let mut deleted_list = self.deleted.write().await;
74        deleted_list.push(EntityId::from_collection_name(E::COLLECTION_NAME, id.as_ref())?);
75        Ok(())
76    }
77
78    pub(crate) async fn child_groups<
79        E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection> + PersistedMlsGroupExt + Sync,
80    >(
81        &self,
82        entity: E,
83        persisted_records: Vec<E>,
84    ) -> CryptoKeystoreResult<Vec<E>> {
85        let mut conn = self.cache.borrow_conn().await?;
86        let cached_records = entity.child_groups(conn.deref_mut()).await?;
87        Ok(self
88            .merge_records(cached_records, persisted_records, EntityFindParams::default())
89            .await)
90    }
91
92    pub(crate) async fn cred_delete_by_credential(&self, cred: Vec<u8>) -> CryptoKeystoreResult<()> {
93        let conn = self.cache.borrow_conn().await?;
94        let mut conn = conn.conn().await;
95        #[cfg(target_family = "wasm")]
96        let transaction = conn.new_transaction(&[MlsCredential::COLLECTION_NAME]).await?;
97        #[cfg(not(target_family = "wasm"))]
98        let transaction = conn.transaction()?.into();
99        MlsCredential::delete_by_credential(&transaction, cred.clone()).await?;
100        transaction.commit_tx().await?;
101        let mut deleted_list = self.deleted_credentials.write().await;
102        deleted_list.push(cred);
103        Ok(())
104    }
105
106    /// The result of this function will have different contents for different scenarios:
107    /// * `Some(Some(E))` - the transaction cache contains the record
108    /// * `Some(None)` - the deletion of the record has been cached
109    /// * `None` - there is no information about the record in the cache
110    pub(crate) async fn find<E>(&self, id: &[u8]) -> CryptoKeystoreResult<Option<Option<E>>>
111    where
112        E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>,
113    {
114        let cache_result = self.cache.find(id).await?;
115        match cache_result {
116            Some(cache_result) => Ok(Some(Some(cache_result))),
117            _ => {
118                let deleted_list = self.deleted.read().await;
119                if deleted_list.contains(&EntityId::from_collection_name(E::COLLECTION_NAME, id)?) {
120                    Ok(Some(None))
121                } else {
122                    Ok(None)
123                }
124            }
125        }
126    }
127
128    pub(crate) async fn find_unique<U: UniqueEntity<ConnectionType = KeystoreDatabaseConnection>>(
129        &self,
130    ) -> CryptoKeystoreResult<Option<U>> {
131        let cache_result = self.cache.find_unique().await;
132        match cache_result {
133            Ok(cache_result) => Ok(Some(cache_result)),
134            _ => {
135                // The deleted list doesn't have to be checked because unique entities don't implement
136                // deletion, just replace. So we can directly return None.
137                Ok(None)
138            }
139        }
140    }
141
142    pub(crate) async fn find_all<E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>>(
143        &self,
144        persisted_records: Vec<E>,
145        params: EntityFindParams,
146    ) -> CryptoKeystoreResult<Vec<E>> {
147        let cached_records: Vec<E> = self.cache.find_all(params.clone()).await?;
148        Ok(self.merge_records(cached_records, persisted_records, params).await)
149    }
150
151    pub(crate) async fn find_many<E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>>(
152        &self,
153        persisted_records: Vec<E>,
154        ids: &[Vec<u8>],
155    ) -> CryptoKeystoreResult<Vec<E>> {
156        let cached_records: Vec<E> = self.cache.find_many(ids).await?;
157        Ok(self
158            .merge_records(cached_records, persisted_records, EntityFindParams::default())
159            .await)
160    }
161
162    /// Build a single list of unique records from two potentially overlapping lists.
163    /// In case of overlap, records in `records_a` are prioritized.
164    /// Identity from the perspective of this function is determined by the output of [crate::entities::Entity::merge_key].
165    ///
166    /// Further, the output list of records is built with respect to the provided [EntityFindParams]
167    /// and the deleted records cached in this [Self] instance.
168    async fn merge_records<E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>>(
169        &self,
170        records_a: Vec<E>,
171        records_b: Vec<E>,
172        params: EntityFindParams,
173    ) -> Vec<E> {
174        let merged = records_a.into_iter().chain(records_b).unique_by(|e| e.merge_key());
175
176        // We are consuming the iterator here to keep types of the `if` and `else` block consistent.
177        // The alternative to giving up laziness here would be to use a dynamically
178        // typed iterator Box<dyn Iterator<Item = E>> assigned to `merged`. The below approach
179        // trades stack allocation instead of heap allocation for laziness.
180        //
181        // Also, we have to do this before filtering by deleted records since filter map does not
182        // return an iterator that is double ended.
183        let merged: Vec<E> = if params.reverse {
184            merged.rev().collect()
185        } else {
186            merged.collect()
187        };
188
189        if merged.is_empty() {
190            return merged;
191        }
192
193        let deleted_records = self.deleted.read().await;
194        let deleted_credentials = self.deleted_credentials.read().await;
195        let merged = if deleted_records.is_empty() && deleted_credentials.is_empty() {
196            merged
197        } else {
198            merged
199                .into_iter()
200                .filter(|record| {
201                    !Self::record_is_in_deleted_list(record, &deleted_records)
202                        && !Self::credential_is_in_deleted_list(record, &deleted_credentials)
203                })
204                .collect()
205        };
206
207        merged
208            .into_iter()
209            .skip(params.offset.unwrap_or(0) as usize)
210            .take(params.limit.unwrap_or(u32::MAX) as usize)
211            .collect()
212    }
213
214    fn record_is_in_deleted_list<E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>>(
215        record: &E,
216        deleted_records: &[EntityId],
217    ) -> bool {
218        let id = EntityId::from_collection_name(E::COLLECTION_NAME, record.id_raw());
219        let Ok(id) = id else { return false };
220        deleted_records.contains(&id)
221    }
222    fn credential_is_in_deleted_list<E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>>(
223        maybe_credential: &E,
224        deleted_credentials: &[Vec<u8>],
225    ) -> bool {
226        let Some(credential) = maybe_credential.downcast::<MlsCredential>() else {
227            return false;
228        };
229        deleted_credentials.contains(&credential.credential)
230    }
231}
232
/// Persist all records cached in `$keystore_transaction` (first argument),
/// using a transaction on `$db` (second argument).
/// Use the provided types to read from the cache and write to the `$db`.
///
/// Recorded deletions (by entity id and by MLS credential) are applied *before* the cached
/// records are saved. A record that was removed and then saved again within the same keystore
/// transaction therefore survives the commit.
///
/// The expansion uses `?` and may `return Ok(())` early when there is nothing to commit, so it
/// must be invoked inside an async function returning a compatible `Result`.
///
/// # Examples
/// ```rust,ignore
/// let transaction = KeystoreTransaction::new(semaphore_guard).await?;
/// let db: Connection = /* the persistent database connection */;
///
/// // Commit records of all provided types
/// commit_transaction!(
///     transaction, db,
///     [
///         (identifier_01, MlsCredential),
///         (identifier_02, MlsSignatureKeyPair)
///     ]
/// );
///
/// // Commit records of provided types in the first list. Commit records of types in the second
/// // list only if the "proteus-keystore" cargo feature is enabled.
/// commit_transaction!(
///     transaction, db,
///     [
///         (identifier_01, MlsCredential),
///         (identifier_02, MlsSignatureKeyPair)
///     ],
///     proteus_types: [
///         (identifier_03, ProteusPrekey),
///         (identifier_04, ProteusIdentity),
///         (identifier_05, ProteusSession)
///     ]
/// );
///```
macro_rules! commit_transaction {
    ($keystore_transaction:expr_2021, $db:expr_2021, [ $( ($records:ident, $entity:ty) ),*], proteus_types: [ $( ($conditional_records:ident, $conditional_entity:ty) ),*]) => {
        #[cfg(feature = "proteus-keystore")]
        commit_transaction!($keystore_transaction, $db, [ $( ($records, $entity) ),*], [ $( ($conditional_records, $conditional_entity) ),*]);

        #[cfg(not(feature = "proteus-keystore"))]
        commit_transaction!($keystore_transaction, $db, [ $( ($records, $entity) ),*]);
    };
    ($keystore_transaction:expr_2021, $db:expr_2021, $([ $( ($records:ident, $entity:ty) ),*]),*) => {
        // Read every cached collection; each `$records` identifier is bound to one of them.
        let cached_collections = ( $( $(
            $keystore_transaction.cache.find_all::<$entity>(Default::default()).await?,
        )* )* );

        let ( $( $( $records, )* )* ) = cached_collections;

        let conn = $db.borrow_conn().await?;
        let mut conn = conn.conn().await;
        let deleted_ids = $keystore_transaction.deleted.read().await;
        let deleted_credentials = $keystore_transaction.deleted_credentials.read().await;

        // Collect every table this commit touches (required to scope the wasm transaction).
        let mut tables = Vec::new();
        $( $(
            if !$records.is_empty() {
                tables.push(<$entity>::COLLECTION_NAME);
            }
        )* )*

        for deleted_id in deleted_ids.iter() {
            tables.push(deleted_id.collection_name());
        }

        // Deleting credentials by value touches the credentials table even when no credential
        // record is cached; without this a credential-only deletion would be dropped below.
        if !deleted_credentials.is_empty() {
            tables.push(MlsCredential::COLLECTION_NAME);
        }

        if tables.is_empty() {
            log::debug!("Empty transaction was committed.");
            return Ok(());
        }

        tables.sort_unstable();
        tables.dedup();

        #[cfg(target_family = "wasm")]
        let tx = conn.new_transaction(&tables).await?;
        #[cfg(not(target_family = "wasm"))]
        let tx = conn.transaction()?.into();

        // Deletions first: a record still present in the cache was saved after its last
        // deletion, so saving it afterwards yields the correct final state.
        for deleted_id in deleted_ids.iter() {
            dynamic_dispatch::execute_delete(&tx, deleted_id).await?;
        }

        for deleted_credential in deleted_credentials.iter() {
            MlsCredential::delete_by_credential(&tx, deleted_credential.to_owned()).await?;
        }

        $( $(
            for record in $records {
                dynamic_dispatch::execute_save(&tx, &record.to_transaction_entity()).await?;
            }
        )* )*

        tx.commit_tx().await?;
    };
}
326
327impl KeystoreTransaction {
328    /// Persists all the operations in the database. It will effectively open a transaction
329    /// internally, perform all the buffered operations and commit.
330    pub(crate) async fn commit(&self, db: &Connection) -> Result<(), CryptoKeystoreError> {
331        commit_transaction!(
332            self, db,
333            [
334                (identifier_01, MlsCredential),
335                (identifier_02, MlsSignatureKeyPair),
336                (identifier_03, MlsHpkePrivateKey),
337                (identifier_04, MlsEncryptionKeyPair),
338                (identifier_05, MlsEpochEncryptionKeyPair),
339                (identifier_06, MlsPskBundle),
340                (identifier_07, MlsKeyPackage),
341                (identifier_08, PersistedMlsGroup),
342                (identifier_09, PersistedMlsPendingGroup),
343                (identifier_10, MlsPendingMessage),
344                (identifier_11, E2eiEnrollment),
345                // (identifier_12, E2eiRefreshToken),
346                (identifier_13, E2eiAcmeCA),
347                (identifier_14, E2eiIntermediateCert),
348                (identifier_15, E2eiCrl),
349                (identifier_16, ConsumerData)
350            ],
351            proteus_types: [
352                (identifier_17, ProteusPrekey),
353                (identifier_18, ProteusIdentity),
354                (identifier_19, ProteusSession)
355            ]
356        );
357
358        Ok(())
359    }
360}