core_crypto_keystore/connection/mod.rs

1use std::{fmt, ops::Deref};
2
3use sha2::{Digest as _, Sha256};
4use zeroize::{Zeroize, ZeroizeOnDrop};
5
6pub mod platform {
7    cfg_if::cfg_if! {
8        if #[cfg(target_family = "wasm")] {
9            mod wasm;
10            pub use self::wasm::WasmConnection as KeystoreDatabaseConnection;
11            pub use wasm::storage;
12            pub use self::wasm::storage::WasmStorageTransaction as TransactionWrapper;
13        } else {
14            mod generic;
15            pub use self::generic::SqlCipherConnection as KeystoreDatabaseConnection;
16            pub use self::generic::TransactionWrapper;
17        }
18    }
19}
20
21use std::{ops::DerefMut, sync::Arc};
22
23use async_lock::{Mutex, MutexGuard, Semaphore};
24
25pub use self::platform::*;
26use crate::{
27    CryptoKeystoreError, CryptoKeystoreResult,
28    entities::{Entity, EntityFindParams, EntityTransactionExt, MlsPendingMessage, StringEntityId, UniqueEntity},
29    transaction::KeystoreTransaction,
30};
31
/// Maximum size, in bytes, of a blob stored in the database.
///
/// The same limit is used for SQLCipher-backed stores and for WASM. It is deliberately
/// conservative for WASM: the smallest per-origin storage quota in existence is Safari's, at 1GB.
/// Blobs of this size or larger are rejected by `DatabaseConnection::check_buffer_size`.
///
/// See: [SQLite limits](https://www.sqlite.org/limits.html)
/// See: [IndexedDB limits](https://stackoverflow.com/a/63019999/1934177)
pub const MAX_BLOB_LEN: usize = 1_000 * 1_000 * 1_000;
40
/// Marker bounds shared by every platform database connection.
///
/// On wasm, connections are not required to be `Send`: values cannot be `Send` there because of
/// platform restrictions (everything is copied across the FFI).
#[cfg(target_family = "wasm")]
pub trait DatabaseConnectionRequirements: Sized {}

/// Marker bounds shared by every platform database connection.
///
/// On native targets, UniFFI's async requirements mean the keystore, and therefore its
/// connection, must be `Send`.
#[cfg(not(target_family = "wasm"))]
pub trait DatabaseConnectionRequirements: Send + Sized {}
47
48/// The key used to encrypt the database.
49#[derive(Clone, Zeroize, ZeroizeOnDrop, derive_more::From, PartialEq, Eq)]
50pub struct DatabaseKey([u8; Self::LEN]);
51
52impl DatabaseKey {
53    pub const LEN: usize = 32;
54
55    pub fn generate() -> DatabaseKey {
56        DatabaseKey(rand::random::<[u8; Self::LEN]>())
57    }
58}
59
60impl fmt::Debug for DatabaseKey {
61    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
62        f.write_str("DatabaseKey(hash=")?;
63        for x in Sha256::digest(self).as_slice().iter().take(10) {
64            fmt::LowerHex::fmt(x, f)?
65        }
66        f.write_str("...)")
67    }
68}
69
70impl AsRef<[u8]> for DatabaseKey {
71    fn as_ref(&self) -> &[u8] {
72        &self.0
73    }
74}
75
76impl Deref for DatabaseKey {
77    type Target = [u8];
78
79    fn deref(&self) -> &Self::Target {
80        &self.0
81    }
82}
83
84impl TryFrom<&[u8]> for DatabaseKey {
85    type Error = CryptoKeystoreError;
86
87    fn try_from(buf: &[u8]) -> Result<Self, Self::Error> {
88        if buf.len() != Self::LEN {
89            Err(CryptoKeystoreError::InvalidDbKeySize {
90                expected: Self::LEN,
91                actual: buf.len(),
92            })
93        } else {
94            Ok(Self(buf.try_into().unwrap()))
95        }
96    }
97}
98
99#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
100#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
101pub trait DatabaseConnection<'a>: DatabaseConnectionRequirements {
102    type Connection: 'a;
103
104    async fn open(name: &str, key: &DatabaseKey) -> CryptoKeystoreResult<Self>;
105
106    async fn open_in_memory(key: &DatabaseKey) -> CryptoKeystoreResult<Self>;
107
108    async fn update_key(&mut self, new_key: &DatabaseKey) -> CryptoKeystoreResult<()>;
109
110    async fn close(self) -> CryptoKeystoreResult<()>;
111
112    /// Default implementation of wipe
113    async fn wipe(self) -> CryptoKeystoreResult<()> {
114        self.close().await
115    }
116
117    fn check_buffer_size(size: usize) -> CryptoKeystoreResult<()> {
118        #[cfg(not(target_family = "wasm"))]
119        if size > i32::MAX as usize {
120            return Err(CryptoKeystoreError::BlobTooBig);
121        }
122
123        if size >= MAX_BLOB_LEN {
124            return Err(CryptoKeystoreError::BlobTooBig);
125        }
126
127        Ok(())
128    }
129}
130
/// Handle to the keystore database.
///
/// Cloning is cheap: every field is an `Arc`, so all clones share the same connection,
/// transaction slot and transaction semaphore.
#[derive(Debug, Clone)]
pub struct Database {
    /// The platform connection, guarded by an async mutex.
    pub(crate) conn: Arc<Mutex<KeystoreDatabaseConnection>>,
    /// The currently open transaction, if any.
    pub(crate) transaction: Arc<Mutex<Option<KeystoreTransaction>>>,
    /// Bounds how many transactions may be open at once; created with
    /// `ALLOWED_CONCURRENT_TRANSACTIONS_COUNT` permits.
    transaction_semaphore: Arc<Semaphore>,
}

/// Number of permits in `Database::transaction_semaphore`, i.e. one transaction at a time.
const ALLOWED_CONCURRENT_TRANSACTIONS_COUNT: usize = 1;
139
/// Interface to fetch from the database either from the connection directly or through a
/// transaction.
///
/// When a transaction is in progress, the `Database` implementation consults that transaction as
/// well as the persisted data.
#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
pub trait FetchFromDatabase: Send + Sync {
    /// Looks up the entity of type `E` identified by `id`.
    async fn find<E: Entity<ConnectionType = KeystoreDatabaseConnection>>(
        &self,
        id: impl AsRef<[u8]> + Send,
    ) -> CryptoKeystoreResult<Option<E>>;

    /// Fetches the stored instance of the unique entity type `U`.
    async fn find_unique<U: UniqueEntity<ConnectionType = KeystoreDatabaseConnection>>(
        &self,
    ) -> CryptoKeystoreResult<U>;

    /// Fetches all entities of type `E`, according to `params`.
    async fn find_all<E: Entity<ConnectionType = KeystoreDatabaseConnection>>(
        &self,
        params: EntityFindParams,
    ) -> CryptoKeystoreResult<Vec<E>>;

    /// Fetches the entities of type `E` identified by `ids`.
    async fn find_many<E: Entity<ConnectionType = KeystoreDatabaseConnection>>(
        &self,
        ids: &[Vec<u8>],
    ) -> CryptoKeystoreResult<Vec<E>>;
    /// Counts the entities of type `E`.
    async fn count<E: Entity<ConnectionType = KeystoreDatabaseConnection>>(&self) -> CryptoKeystoreResult<usize>;
}
165
// SAFETY: every field of `Database` is an `Arc` around an `async_lock` primitive (`Mutex` or
// `Semaphore`), so access to the connection and the open transaction is serialized by those
// locks.
// NOTE(review): on wasm the connection is not required to be `Send` (see
// `DatabaseConnectionRequirements`), so these impls assert a guarantee the compiler cannot check
// there. On native targets they may be redundant if `KeystoreTransaction` is `Send`; confirm.
unsafe impl Send for Database {}
// SAFETY: same locking argument as the `Send` impl directly above, applied to shared references.
unsafe impl Sync for Database {}
170
/// Where to open a connection.
///
/// The enum only borrows a `&str`, so it is `Copy` and can be compared for equality.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionType<'a> {
    /// A persistent database at the provided path.
    Persistent(&'a str),
    /// A transient database that lives in memory.
    InMemory,
}
179
180impl Database {
181    pub async fn open(location: ConnectionType<'_>, key: &DatabaseKey) -> CryptoKeystoreResult<Self> {
182        let conn = match location {
183            ConnectionType::Persistent(name) => KeystoreDatabaseConnection::open(name, key).await?.into(),
184            ConnectionType::InMemory => KeystoreDatabaseConnection::open_in_memory(key).await?.into(),
185        };
186        #[allow(clippy::arc_with_non_send_sync)] // see https://github.com/rustwasm/wasm-bindgen/pull/955
187        let conn = Arc::new(conn);
188        Ok(Self {
189            conn,
190            transaction: Default::default(),
191            transaction_semaphore: Arc::new(Semaphore::new(ALLOWED_CONCURRENT_TRANSACTIONS_COUNT)),
192        })
193    }
194
195    pub async fn borrow_conn(&self) -> CryptoKeystoreResult<MutexGuard<'_, KeystoreDatabaseConnection>> {
196        Ok(self.conn.lock().await)
197    }
198
199    pub async fn migrate_db_key_type_to_bytes(
200        name: &str,
201        old_key: &str,
202        new_key: &DatabaseKey,
203    ) -> CryptoKeystoreResult<()> {
204        KeystoreDatabaseConnection::migrate_db_key_type_to_bytes(name, old_key, new_key).await
205    }
206
207    pub async fn update_key(&mut self, new_key: &DatabaseKey) -> CryptoKeystoreResult<()> {
208        self.conn.lock().await.update_key(new_key).await
209    }
210
211    pub async fn wipe(self) -> CryptoKeystoreResult<()> {
212        if self.transaction.lock().await.is_some() {
213            return Err(CryptoKeystoreError::TransactionInProgress {
214                attempted_operation: "wipe()".to_string(),
215            });
216        }
217        let conn: KeystoreDatabaseConnection = Arc::into_inner(self.conn).unwrap().into_inner();
218        conn.wipe().await?;
219        Ok(())
220    }
221
222    /// Wait for any running transaction to finish, then close the database connection.
223    pub async fn close(self) -> CryptoKeystoreResult<()> {
224        // Wait for any running transaction to finish
225        let _semaphore = self.transaction_semaphore.acquire_arc().await;
226        // Ensure that there's only one reference to the connection
227        let Some(conn) = Arc::into_inner(self.conn) else {
228            return Err(CryptoKeystoreError::CannotClose);
229        };
230        let conn = conn.into_inner();
231        conn.close().await?;
232        Ok(())
233    }
234
235    /// Waits for the current transaction to be committed or rolled back, then starts a new one.
236    pub async fn new_transaction(&self) -> CryptoKeystoreResult<()> {
237        let semaphore = self.transaction_semaphore.acquire_arc().await;
238        let mut transaction_guard = self.transaction.lock().await;
239        *transaction_guard = Some(KeystoreTransaction::new(semaphore).await?);
240        Ok(())
241    }
242
243    pub async fn commit_transaction(&self) -> CryptoKeystoreResult<()> {
244        let mut transaction_guard = self.transaction.lock().await;
245        let Some(transaction) = transaction_guard.as_ref() else {
246            return Err(CryptoKeystoreError::MutatingOperationWithoutTransaction);
247        };
248        transaction.commit(self).await?;
249        *transaction_guard = None;
250        Ok(())
251    }
252
253    pub async fn rollback_transaction(&self) -> CryptoKeystoreResult<()> {
254        let mut transaction_guard = self.transaction.lock().await;
255        if transaction_guard.is_none() {
256            return Err(CryptoKeystoreError::MutatingOperationWithoutTransaction);
257        };
258        *transaction_guard = None;
259        Ok(())
260    }
261
262    pub async fn child_groups<
263        E: Entity<ConnectionType = KeystoreDatabaseConnection> + crate::entities::PersistedMlsGroupExt + Sync,
264    >(
265        &self,
266        entity: E,
267    ) -> CryptoKeystoreResult<Vec<E>> {
268        let mut conn = self.conn.lock().await;
269        let persisted_records = entity.child_groups(conn.deref_mut()).await?;
270
271        let transaction_guard = self.transaction.lock().await;
272        let Some(transaction) = transaction_guard.as_ref() else {
273            return Ok(persisted_records);
274        };
275        transaction.child_groups(entity, persisted_records).await
276    }
277
278    pub async fn save<E: Entity<ConnectionType = KeystoreDatabaseConnection> + Sync + EntityTransactionExt>(
279        &self,
280        entity: E,
281    ) -> CryptoKeystoreResult<E> {
282        let transaction_guard = self.transaction.lock().await;
283        let Some(transaction) = transaction_guard.as_ref() else {
284            return Err(CryptoKeystoreError::MutatingOperationWithoutTransaction);
285        };
286        transaction.save_mut(entity).await
287    }
288
289    pub async fn remove<
290        E: Entity<ConnectionType = KeystoreDatabaseConnection> + EntityTransactionExt,
291        S: AsRef<[u8]>,
292    >(
293        &self,
294        id: S,
295    ) -> CryptoKeystoreResult<()> {
296        let transaction_guard = self.transaction.lock().await;
297        let Some(transaction) = transaction_guard.as_ref() else {
298            return Err(CryptoKeystoreError::MutatingOperationWithoutTransaction);
299        };
300        transaction.remove::<E, S>(id).await
301    }
302
303    pub async fn find_pending_messages_by_conversation_id(
304        &self,
305        conversation_id: &[u8],
306    ) -> CryptoKeystoreResult<Vec<MlsPendingMessage>> {
307        let mut conn = self.conn.lock().await;
308        let persisted_records =
309            MlsPendingMessage::find_all_by_conversation_id(&mut conn, conversation_id, Default::default()).await?;
310
311        let transaction_guard = self.transaction.lock().await;
312        let Some(transaction) = transaction_guard.as_ref() else {
313            return Ok(persisted_records);
314        };
315        transaction
316            .find_pending_messages_by_conversation_id(conversation_id, persisted_records)
317            .await
318    }
319
320    pub async fn remove_pending_messages_by_conversation_id(
321        &self,
322        conversation_id: impl AsRef<[u8]> + Send,
323    ) -> CryptoKeystoreResult<()> {
324        let transaction_guard = self.transaction.lock().await;
325        let Some(transaction) = transaction_guard.as_ref() else {
326            return Err(CryptoKeystoreError::MutatingOperationWithoutTransaction);
327        };
328        transaction
329            .remove_pending_messages_by_conversation_id(conversation_id)
330            .await
331    }
332
333    pub async fn cred_delete_by_credential(&self, cred: Vec<u8>) -> CryptoKeystoreResult<()> {
334        let transaction_guard = self.transaction.lock().await;
335        let Some(transaction) = transaction_guard.as_ref() else {
336            return Err(CryptoKeystoreError::MutatingOperationWithoutTransaction);
337        };
338        transaction.cred_delete_by_credential(cred).await
339    }
340}
341
342#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
343#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
344impl FetchFromDatabase for Database {
345    async fn find<E: Entity<ConnectionType = KeystoreDatabaseConnection>>(
346        &self,
347        id: impl AsRef<[u8]> + Send,
348    ) -> CryptoKeystoreResult<Option<E>> {
349        // If a transaction is in progress...
350        if let Some(transaction) = self.transaction.lock().await.as_ref()
351            //... and it has information about this entity, ...
352            && let Some(cached_record) = transaction.find::<E>(id.as_ref()).await?
353        {
354            // ... return that result
355            return Ok(cached_record);
356        }
357
358        // Otherwise get it from the database
359        let mut conn = self.conn.lock().await;
360        E::find_one(&mut conn, &id.as_ref().into()).await
361    }
362
363    async fn find_unique<U: UniqueEntity>(&self) -> CryptoKeystoreResult<U> {
364        // If a transaction is in progress...
365        if let Some(transaction) = self.transaction.lock().await.as_ref()
366            //... and it has information about this entity, ...
367            && let Some(cached_record) = transaction.find_unique::<U>().await?
368        {
369            // ... return that result
370            return Ok(cached_record);
371        }
372        // Otherwise get it from the database
373        let mut conn = self.conn.lock().await;
374        U::find_unique(&mut conn).await
375    }
376
377    async fn find_all<E: Entity<ConnectionType = KeystoreDatabaseConnection>>(
378        &self,
379        params: EntityFindParams,
380    ) -> CryptoKeystoreResult<Vec<E>> {
381        let mut conn = self.conn.lock().await;
382        let persisted_records = E::find_all(&mut conn, params.clone()).await?;
383
384        let transaction_guard = self.transaction.lock().await;
385        let Some(transaction) = transaction_guard.as_ref() else {
386            return Ok(persisted_records);
387        };
388        transaction.find_all(persisted_records, params).await
389    }
390
391    async fn find_many<E: Entity<ConnectionType = KeystoreDatabaseConnection>>(
392        &self,
393        ids: &[Vec<u8>],
394    ) -> CryptoKeystoreResult<Vec<E>> {
395        let entity_ids: Vec<StringEntityId> = ids.iter().map(|id| id.as_slice().into()).collect();
396        let mut conn = self.conn.lock().await;
397        let persisted_records = E::find_many(&mut conn, &entity_ids).await?;
398
399        let transaction_guard = self.transaction.lock().await;
400        let Some(transaction) = transaction_guard.as_ref() else {
401            return Ok(persisted_records);
402        };
403        transaction.find_many(persisted_records, ids).await
404    }
405
406    async fn count<E: Entity<ConnectionType = KeystoreDatabaseConnection>>(&self) -> CryptoKeystoreResult<usize> {
407        if self.transaction.lock().await.is_some() {
408            // Unfortunately, we have to do this because of possible record id overlap
409            // between cache and db.
410            return Ok(self.find_all::<E>(Default::default()).await?.len());
411        };
412        let mut conn = self.conn.lock().await;
413        E::count(&mut conn).await
414    }
415}