core_crypto_keystore/transaction/
mod.rs

1use std::{
2    collections::{HashMap, hash_map::Entry},
3    sync::Arc,
4};
5
6use async_lock::{RwLock, SemaphoreGuardArc};
7use itertools::Itertools;
8use zeroize::Zeroizing;
9
10#[cfg(feature = "proteus-keystore")]
11use crate::entities::proteus::*;
12use crate::{
13    CryptoKeystoreError, CryptoKeystoreResult,
14    connection::{Database, KeystoreDatabaseConnection},
15    entities::{ConsumerData, EntityBase, EntityFindParams, EntityTransactionExt, UniqueEntity, mls::*},
16    transaction::dynamic_dispatch::EntityId,
17};
18
19pub mod dynamic_dispatch;
20
21#[derive(Debug, Default, derive_more::Deref, derive_more::DerefMut)]
22struct InMemoryTable(HashMap<Vec<u8>, Zeroizing<Vec<u8>>>);
23
24type InMemoryCache = Arc<RwLock<HashMap<String, InMemoryTable>>>;
25
26/// This represents a transaction, where all operations will be done in memory and committed at the
27/// end
28#[derive(Debug, Clone)]
29pub(crate) struct KeystoreTransaction {
30    cache: InMemoryCache,
31    deleted: Arc<RwLock<Vec<EntityId>>>,
32    deleted_credentials: Arc<RwLock<Vec<Vec<u8>>>>,
33    _semaphore_guard: Arc<SemaphoreGuardArc>,
34}
35
36impl KeystoreTransaction {
37    pub(crate) async fn new(semaphore_guard: SemaphoreGuardArc) -> CryptoKeystoreResult<Self> {
38        Ok(Self {
39            cache: Default::default(),
40            deleted: Arc::new(Default::default()),
41            deleted_credentials: Arc::new(Default::default()),
42            _semaphore_guard: Arc::new(semaphore_guard),
43        })
44    }
45
46    pub(crate) async fn save_mut<
47        E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection> + EntityTransactionExt + Sync,
48    >(
49        &self,
50        mut entity: E,
51    ) -> CryptoKeystoreResult<E> {
52        entity.pre_save().await?;
53        let mut cache_guard = self.cache.write().await;
54        let table = cache_guard.entry(E::COLLECTION_NAME.to_string()).or_default();
55        let serialized = postcard::to_stdvec(&entity)?;
56        // Use merge_key() because `id_raw()` is not always unique for records.
57        // For `MlsCredential`, `id_raw()` is the `CLientId`.
58        // For `MlsPendingMessage` it's the id of the group it belongs to.
59        table.insert(entity.merge_key(), Zeroizing::new(serialized));
60        Ok(entity)
61    }
62
63    pub(crate) async fn remove<
64        E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection> + EntityTransactionExt,
65        S: AsRef<[u8]>,
66    >(
67        &self,
68        id: S,
69    ) -> CryptoKeystoreResult<()> {
70        let mut cache_guard = self.cache.write().await;
71        if let Entry::Occupied(mut table) = cache_guard.entry(E::COLLECTION_NAME.to_string())
72            && let Entry::Occupied(cached_record) = table.get_mut().entry(id.as_ref().to_vec())
73        {
74            cached_record.remove_entry();
75        };
76
77        let mut deleted_list = self.deleted.write().await;
78        deleted_list.push(EntityId::from_collection_name(E::COLLECTION_NAME, id.as_ref())?);
79        Ok(())
80    }
81
82    pub(crate) async fn child_groups<E>(&self, entity: E, persisted_records: Vec<E>) -> CryptoKeystoreResult<Vec<E>>
83    where
84        E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection> + PersistedMlsGroupExt + Sync,
85    {
86        // First get all raw groups from the cache, then deserialize them to enable filtering by there parent id
87        // matching `entity.id_raw()`.
88        let cached_records = self
89            .find_all_in_cache()
90            .await?
91            .into_iter()
92            .filter(|maybe_child: &E| {
93                maybe_child
94                    .parent_id()
95                    .map(|parent_id| parent_id == entity.id_raw())
96                    .unwrap_or_default()
97            })
98            .collect();
99
100        Ok(self
101            .merge_records(cached_records, persisted_records, EntityFindParams::default())
102            .await)
103    }
104
105    pub(crate) async fn cred_delete_by_credential(&self, cred: Vec<u8>) -> CryptoKeystoreResult<()> {
106        let mut cache_guard = self.cache.write().await;
107        if let Entry::Occupied(mut table) = cache_guard.entry(MlsCredential::COLLECTION_NAME.to_string()) {
108            table.get_mut().retain(|_, value| **value != cred);
109        }
110
111        let mut deleted_list = self.deleted_credentials.write().await;
112        deleted_list.push(cred);
113        Ok(())
114    }
115
116    pub(crate) async fn remove_pending_messages_by_conversation_id(
117        &self,
118        conversation_id: impl AsRef<[u8]> + Send,
119    ) -> CryptoKeystoreResult<()> {
120        // We cannot return an error from `retain()`, so we've got to do this dance with a mutable result.
121        let mut result = Ok(());
122
123        let mut cache_guard = self.cache.write().await;
124        if let Entry::Occupied(mut table) = cache_guard.entry(MlsPendingMessage::COLLECTION_NAME.to_string()) {
125            table.get_mut().retain(|_key, record_bytes| {
126                postcard::from_bytes::<MlsPendingMessage>(record_bytes)
127                    .map(|pending_message| pending_message.foreign_id != conversation_id.as_ref())
128                    .inspect_err(|err| result = Err(err.clone()))
129                    .unwrap_or(false)
130            });
131        }
132
133        let mut deleted_list = self.deleted.write().await;
134        deleted_list.push(EntityId::from_collection_name(
135            MlsPendingMessage::COLLECTION_NAME,
136            conversation_id.as_ref(),
137        )?);
138        result.map_err(Into::into)
139    }
140
141    pub(crate) async fn find_pending_messages_by_conversation_id(
142        &self,
143        conversation_id: &[u8],
144        persisted_records: Vec<MlsPendingMessage>,
145    ) -> CryptoKeystoreResult<Vec<MlsPendingMessage>> {
146        let cached_records = self
147            .find_all_in_cache::<MlsPendingMessage>()
148            .await?
149            .into_iter()
150            .filter(|pending_message| pending_message.foreign_id == conversation_id)
151            .collect();
152        let merged_records = self
153            .merge_records(cached_records, persisted_records, Default::default())
154            .await;
155        Ok(merged_records)
156    }
157
158    async fn find_in_cache<E>(&self, id: &[u8]) -> CryptoKeystoreResult<Option<E>>
159    where
160        E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>,
161    {
162        let cache_guard = self.cache.read().await;
163        cache_guard
164            .get(E::COLLECTION_NAME)
165            .and_then(|table| {
166                table
167                    .get(id)
168                    .map(|record| -> CryptoKeystoreResult<_> { postcard::from_bytes::<E>(record).map_err(Into::into) })
169            })
170            .transpose()
171    }
172
173    /// The result of this function will have different contents for different scenarios:
174    /// * `Some(Some(E))` - the transaction cache contains the record
175    /// * `Some(None)` - the deletion of the record has been cached
176    /// * `None` - there is no information about the record in the cache
177    pub(crate) async fn find<E>(&self, id: &[u8]) -> CryptoKeystoreResult<Option<Option<E>>>
178    where
179        E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>,
180    {
181        let maybe_cached_record = self.find_in_cache(id).await?;
182        if let Some(cached_record) = maybe_cached_record {
183            return Ok(Some(Some(cached_record)));
184        }
185
186        let deleted_list = self.deleted.read().await;
187        if deleted_list.contains(&EntityId::from_collection_name(E::COLLECTION_NAME, id)?) {
188            return Ok(Some(None));
189        }
190
191        Ok(None)
192    }
193
194    pub(crate) async fn find_unique<U: UniqueEntity<ConnectionType = KeystoreDatabaseConnection>>(
195        &self,
196    ) -> CryptoKeystoreResult<Option<U>> {
197        #[cfg(target_family = "wasm")]
198        let id = &U::ID;
199        #[cfg(not(target_family = "wasm"))]
200        let id = &[U::ID as u8];
201        let maybe_cached_record = self.find_in_cache::<U>(id).await?;
202        match maybe_cached_record {
203            Some(cached_record) => Ok(Some(cached_record)),
204            _ => {
205                // The deleted list doesn't have to be checked because unique entities don't implement
206                // deletion, just replace. So we can directly return None.
207                Ok(None)
208            }
209        }
210    }
211
212    async fn find_all_in_cache<E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>>(
213        &self,
214    ) -> CryptoKeystoreResult<Vec<E>> {
215        let cache_guard = self.cache.read().await;
216        let cached_records = cache_guard
217            .get(E::COLLECTION_NAME)
218            .map(|table| {
219                table
220                    .values()
221                    .map(|record| postcard::from_bytes::<E>(record).map_err(Into::into))
222                    .collect::<CryptoKeystoreResult<Vec<_>>>()
223            })
224            .transpose()?
225            .unwrap_or_default();
226        Ok(cached_records)
227    }
228
229    pub(crate) async fn find_all<E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>>(
230        &self,
231        persisted_records: Vec<E>,
232        params: EntityFindParams,
233    ) -> CryptoKeystoreResult<Vec<E>> {
234        let cached_records = self.find_all_in_cache().await?;
235        let merged_records = self.merge_records(cached_records, persisted_records, params).await;
236        Ok(merged_records)
237    }
238
239    pub(crate) async fn find_many<E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>>(
240        &self,
241        persisted_records: Vec<E>,
242        ids: &[Vec<u8>],
243    ) -> CryptoKeystoreResult<Vec<E>> {
244        let records = self
245            .find_all(persisted_records, EntityFindParams::default())
246            .await?
247            .into_iter()
248            .filter(|record| ids.contains(&record.id_raw().to_vec()))
249            .collect();
250        Ok(records)
251    }
252
253    /// Build a single list of unique records from two potentially overlapping lists.
254    /// In case of overlap, records in `records_a` are prioritized.
255    /// Identity from the perspective of this function is determined by the output of [crate::entities::Entity::merge_key].
256    ///
257    /// Further, the output list of records is built with respect to the provided [EntityFindParams]
258    /// and the deleted records cached in this [Self] instance.
259    async fn merge_records<E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>>(
260        &self,
261        records_a: Vec<E>,
262        records_b: Vec<E>,
263        params: EntityFindParams,
264    ) -> Vec<E> {
265        let mut merged = records_a.into_iter().chain(records_b).unique_by(|e| e.merge_key());
266
267        let deleted_records = self.deleted.read().await;
268        let deleted_credentials = self.deleted_credentials.read().await;
269
270        let merged: &mut dyn Iterator<Item = E> = if params.reverse { &mut merged.rev() } else { &mut merged };
271
272        merged
273            .filter(|record| {
274                !Self::record_is_in_deleted_list(record, &deleted_records)
275                    && !Self::credential_is_in_deleted_list(record, &deleted_credentials)
276            })
277            .skip(params.offset.unwrap_or(0) as usize)
278            .take(params.limit.unwrap_or(u32::MAX) as usize)
279            .collect()
280    }
281
282    fn record_is_in_deleted_list<E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>>(
283        record: &E,
284        deleted_records: &[EntityId],
285    ) -> bool {
286        let id = EntityId::from_collection_name(E::COLLECTION_NAME, record.id_raw());
287        let Ok(id) = id else { return false };
288        deleted_records.contains(&id)
289    }
290
291    fn credential_is_in_deleted_list<E: crate::entities::Entity<ConnectionType = KeystoreDatabaseConnection>>(
292        maybe_credential: &E,
293        deleted_credentials: &[Vec<u8>],
294    ) -> bool {
295        let Some(credential) = maybe_credential.downcast::<MlsCredential>() else {
296            return false;
297        };
298        deleted_credentials.contains(&credential.credential)
299    }
300}
301
/// Persist all records cached in `$keystore_transaction` (first argument),
/// using a transaction on `$db` (second argument).
/// Use the provided types to read from the cache and write to the `$db`.
///
/// Within one database transaction, the macro does the following in order:
/// 1. saves the cached records of the listed types;
/// 2. applies the transaction's deleted list by id;
/// 3. removes credentials that were deleted by value;
/// 4. commits the database transaction.
///
/// If there is nothing to save or delete, no database transaction is opened and the macro makes
/// the enclosing function return `Ok(())` early. The identifier in each pair only names the local
/// binding that holds that type's cached records.
///
/// # Examples
/// ```rust,ignore
/// let transaction = KeystoreTransaction::new(semaphore_guard).await?;
/// let db = Connection::new();
///
/// // Commit records of all provided types
/// commit_transaction!(
///     transaction, db,
///     [
///         (identifier_01, MlsCredential),
///         (identifier_02, MlsSignatureKeyPair)
///     ]
/// );
///
/// // Commit records of provided types in the first list. Commit records of types in the second
/// // list only if the "proteus-keystore" cargo feature is enabled.
/// commit_transaction!(
///     transaction, db,
///     [
///         (identifier_01, MlsCredential),
///         (identifier_02, MlsSignatureKeyPair)
///     ],
///     proteus_types: [
///         (identifier_03, ProteusPrekey),
///         (identifier_04, ProteusIdentity),
///         (identifier_05, ProteusSession)
///     ]
/// );
///```
macro_rules! commit_transaction {
    ($keystore_transaction:expr_2021, $db:expr_2021, [ $( ($records:ident, $entity:ty) ),*], proteus_types: [ $( ($conditional_records:ident, $conditional_entity:ty) ),*]) => {
        #[cfg(feature = "proteus-keystore")]
        commit_transaction!($keystore_transaction, $db, [ $( ($records, $entity) ),*], [ $( ($conditional_records, $conditional_entity) ),*]);

        #[cfg(not(feature = "proteus-keystore"))]
        commit_transaction!($keystore_transaction, $db, [ $( ($records, $entity) ),*]);
    };
    ($keystore_transaction:expr_2021, $db:expr_2021, $([ $( ($records:ident, $entity:ty) ),*]),*) => {
        // Read every cached collection up front; each binding is named by the caller.
        let cached_collections = ( $( $(
            $keystore_transaction.find_all_in_cache::<$entity>().await?,
        )* )* );

        let ( $( $( $records, )* )* ) = cached_collections;

        let conn = $db.borrow_conn().await?;
        let mut conn = conn.conn().await;
        let deleted_ids = $keystore_transaction.deleted.read().await;
        let deleted_credentials = $keystore_transaction.deleted_credentials.read().await;

        // Names of the collections this transaction touches.
        let mut tables = Vec::new();
        $( $(
            if !$records.is_empty() {
                tables.push(<$entity>::COLLECTION_NAME);
            }
        )* )*

        for deleted_id in deleted_ids.iter() {
            tables.push(deleted_id.collection_name());
        }

        // Deletions made via `cred_delete_by_credential` live in their own list, not in `deleted`.
        // Without this, a transaction whose only change is such a deletion would be treated as
        // empty and its deletions would never reach the database. On wasm, the credential
        // collection would also be missing from the table list passed to `new_transaction`.
        if !deleted_credentials.is_empty() {
            tables.push(MlsCredential::COLLECTION_NAME);
        }

        if tables.is_empty() {
            log::debug!("Empty transaction was committed.");
            return Ok(());
        }

        #[cfg(target_family = "wasm")]
        let tx = conn.new_transaction(&tables).await?;
        #[cfg(not(target_family = "wasm"))]
        let tx = conn.transaction()?.into();

        // Iterating an empty collection is a no-op, so no emptiness check is needed here.
        $( $(
            for record in $records {
                dynamic_dispatch::execute_save(&tx, &record.to_transaction_entity()).await?;
            }
        )* )*

        for deleted_id in deleted_ids.iter() {
            dynamic_dispatch::execute_delete(&tx, deleted_id).await?
        }

        for deleted_credential in deleted_credentials.iter() {
            MlsCredential::delete_by_credential(&tx, deleted_credential.to_owned()).await?;
        }

        tx.commit_tx().await?;
    };
}
395
396impl KeystoreTransaction {
397    /// Persists all the operations in the database. It will effectively open a transaction
398    /// internally, perform all the buffered operations and commit.
399    pub(crate) async fn commit(&self, db: &Database) -> Result<(), CryptoKeystoreError> {
400        commit_transaction!(
401            self, db,
402            [
403                (identifier_01, MlsCredential),
404                (identifier_02, MlsSignatureKeyPair),
405                (identifier_03, MlsHpkePrivateKey),
406                (identifier_04, MlsEncryptionKeyPair),
407                (identifier_05, MlsEpochEncryptionKeyPair),
408                (identifier_06, MlsPskBundle),
409                (identifier_07, MlsKeyPackage),
410                (identifier_08, PersistedMlsGroup),
411                (identifier_09, PersistedMlsPendingGroup),
412                (identifier_10, MlsPendingMessage),
413                (identifier_11, E2eiEnrollment),
414                // (identifier_12, E2eiRefreshToken),
415                (identifier_13, E2eiAcmeCA),
416                (identifier_14, E2eiIntermediateCert),
417                (identifier_15, E2eiCrl),
418                (identifier_16, ConsumerData)
419            ],
420            proteus_types: [
421                (identifier_17, ProteusPrekey),
422                (identifier_18, ProteusIdentity),
423                (identifier_19, ProteusSession)
424            ]
425        );
426
427        Ok(())
428    }
429}