core_crypto_keystore/connection/mod.rs

1use std::{borrow::Borrow, fmt, ops::Deref};
2
3use async_trait::async_trait;
4use sha2::{Digest as _, Sha256};
5use zeroize::{Zeroize, ZeroizeOnDrop};
6
7pub mod platform {
8    cfg_if::cfg_if! {
9        if #[cfg(target_family = "wasm")] {
10            mod wasm;
11            pub use self::wasm::WasmConnection as KeystoreDatabaseConnection;
12            pub use wasm::storage;
13            pub use self::wasm::storage::WasmStorageTransaction as TransactionWrapper;
14        } else {
15            mod generic;
16            pub use self::generic::SqlCipherConnection as KeystoreDatabaseConnection;
17            pub use self::generic::TransactionWrapper;
18            #[cfg(test)]
19            pub(crate) use generic::MigrationTarget;
20
21
22        }
23    }
24}
25
26use std::{ops::DerefMut, sync::Arc};
27
28use async_lock::{Mutex, MutexGuard, Semaphore};
29
30pub use self::platform::*;
31use crate::{
32    CryptoKeystoreError, CryptoKeystoreResult,
33    entities::{MlsPendingMessage, PersistedMlsGroupExt},
34    traits::{
35        BorrowPrimaryKey, Entity, EntityDatabaseMutation, EntityDeleteBorrowed, EntityGetBorrowed, FetchFromDatabase,
36        KeyType,
37    },
38    transaction::KeystoreTransaction,
39};
40
/// Upper bound (exclusive) on the length of a blob stored in the database, in bytes.
///
/// [`DatabaseConnection::check_buffer_size`] rejects any size greater than or equal to this value.
/// The same limit is used for SQLCipher-backed stores and for WASM. It is deliberately
/// conservative on WASM: the lowest limit among browsers is Safari's, at 1GB per origin.
///
/// See: [SQLite limits](https://www.sqlite.org/limits.html)
/// See: [IndexedDB limits](https://stackoverflow.com/a/63019999/1934177)
pub const MAX_BLOB_LEN: usize = 1_000 * 1_000 * 1_000;
50
/// Bounds every platform connection type has to satisfy.
///
/// Native builds also demand `Send`: UniFFI's async support requires the keystore to be `Send`.
#[cfg(not(target_family = "wasm"))]
pub trait DatabaseConnectionRequirements: Sized + Send {}

/// Bounds every platform connection type has to satisfy.
///
/// On WASM nothing can be `Send` because of platform restrictions (everything is copied across
/// the FFI), so only `Sized` is demanded.
#[cfg(target_family = "wasm")]
pub trait DatabaseConnectionRequirements: Sized {}
58
59/// The key used to encrypt the database.
60#[derive(Clone, Zeroize, ZeroizeOnDrop, derive_more::From, PartialEq, Eq)]
61pub struct DatabaseKey([u8; Self::LEN]);
62
63impl DatabaseKey {
64    pub const LEN: usize = 32;
65
66    pub fn generate() -> DatabaseKey {
67        DatabaseKey(rand::random::<[u8; Self::LEN]>())
68    }
69}
70
71impl fmt::Debug for DatabaseKey {
72    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
73        f.write_str("DatabaseKey(hash=")?;
74        for x in Sha256::digest(self).as_slice().iter().take(10) {
75            fmt::LowerHex::fmt(x, f)?
76        }
77        f.write_str("...)")
78    }
79}
80
81impl AsRef<[u8]> for DatabaseKey {
82    fn as_ref(&self) -> &[u8] {
83        &self.0
84    }
85}
86
87impl Deref for DatabaseKey {
88    type Target = [u8];
89
90    fn deref(&self) -> &Self::Target {
91        &self.0
92    }
93}
94
95impl TryFrom<&[u8]> for DatabaseKey {
96    type Error = CryptoKeystoreError;
97
98    fn try_from(buf: &[u8]) -> Result<Self, Self::Error> {
99        if buf.len() != Self::LEN {
100            Err(CryptoKeystoreError::InvalidDbKeySize {
101                expected: Self::LEN,
102                actual: buf.len(),
103            })
104        } else {
105            Ok(Self(buf.try_into().unwrap()))
106        }
107    }
108}
109
110#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
111#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
112pub trait DatabaseConnection<'a>: DatabaseConnectionRequirements {
113    type Connection: 'a;
114
115    async fn open(name: &str, key: &DatabaseKey) -> CryptoKeystoreResult<Self>;
116
117    async fn open_in_memory(key: &DatabaseKey) -> CryptoKeystoreResult<Self>;
118
119    async fn update_key(&mut self, new_key: &DatabaseKey) -> CryptoKeystoreResult<()>;
120
121    /// Clear all data from the database and close it.
122    async fn wipe(self) -> CryptoKeystoreResult<()>;
123
124    fn check_buffer_size(size: usize) -> CryptoKeystoreResult<()> {
125        #[cfg(not(target_family = "wasm"))]
126        if size > i32::MAX as usize {
127            return Err(CryptoKeystoreError::BlobTooBig);
128        }
129
130        if size >= MAX_BLOB_LEN {
131            return Err(CryptoKeystoreError::BlobTooBig);
132        }
133
134        Ok(())
135    }
136}
137
138#[derive(Debug, Clone)]
139pub struct Database {
140    pub(crate) conn: Arc<Mutex<Option<KeystoreDatabaseConnection>>>,
141    pub(crate) transaction: Arc<Mutex<Option<KeystoreTransaction>>>,
142    transaction_semaphore: Arc<Semaphore>,
143}
144
145const ALLOWED_CONCURRENT_TRANSACTIONS_COUNT: usize = 1;
146
147/// Interface to fetch from the database either from the connection directly or through a
148/// transaaction
149#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
150#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
151pub trait OldFetchFromDatabase: Send + Sync {
152    async fn find<E: Entity<ConnectionType = KeystoreDatabaseConnection>>(
153        &self,
154        id: impl AsRef<[u8]> + Send,
155    ) -> CryptoKeystoreResult<Option<E>>;
156
157    async fn find_unique<U: crate::entities::UniqueEntity<ConnectionType = KeystoreDatabaseConnection>>(
158        &self,
159    ) -> CryptoKeystoreResult<U>;
160
161    async fn find_all<E: Entity<ConnectionType = KeystoreDatabaseConnection>>(
162        &self,
163        params: crate::entities::EntityFindParams,
164    ) -> CryptoKeystoreResult<Vec<E>>;
165
166    async fn find_many<E: Entity<ConnectionType = KeystoreDatabaseConnection>>(
167        &self,
168        ids: &[Vec<u8>],
169    ) -> CryptoKeystoreResult<Vec<E>>;
170    async fn count<E: Entity<ConnectionType = KeystoreDatabaseConnection>>(&self) -> CryptoKeystoreResult<usize>;
171}
172
173// SAFETY: this has mutexes and atomics protecting underlying data so this is safe to share between threads
174unsafe impl Send for Database {}
175// SAFETY: this has mutexes and atomics protecting underlying data so this is safe to share between threads
176unsafe impl Sync for Database {}
177
/// Location of the database to open.
#[derive(Debug, Clone)]
pub enum ConnectionType<'a> {
    /// A persistent database, stored under the given path.
    Persistent(&'a str),
    /// A transient database that only exists in memory.
    InMemory,
}
186
187/// Exclusive access to the database connection
188///
189/// Note that this is only ever constructed when we already hold exclusive access,
190/// and the connection has already been tested to ensure that it is non-empty.
191pub struct ConnectionGuard<'a> {
192    guard: MutexGuard<'a, Option<KeystoreDatabaseConnection>>,
193}
194
195impl<'a> TryFrom<MutexGuard<'a, Option<KeystoreDatabaseConnection>>> for ConnectionGuard<'a> {
196    type Error = CryptoKeystoreError;
197
198    fn try_from(guard: MutexGuard<'a, Option<KeystoreDatabaseConnection>>) -> Result<Self, Self::Error> {
199        guard
200            .is_some()
201            .then_some(Self { guard })
202            .ok_or(CryptoKeystoreError::Closed)
203    }
204}
205
206impl Deref for ConnectionGuard<'_> {
207    type Target = KeystoreDatabaseConnection;
208
209    fn deref(&self) -> &Self::Target {
210        self.guard
211            .as_ref()
212            .expect("we have exclusive access and already checked that the connection exists")
213    }
214}
215
216impl DerefMut for ConnectionGuard<'_> {
217    fn deref_mut(&mut self) -> &mut Self::Target {
218        self.guard
219            .as_mut()
220            .expect("we have exclusive access and already checked that the connection exists")
221    }
222}
223
224// Only the functions in this impl block directly mess with `self.conn`
225impl Database {
226    pub async fn open(location: ConnectionType<'_>, key: &DatabaseKey) -> CryptoKeystoreResult<Self> {
227        let conn = match location {
228            ConnectionType::Persistent(name) => KeystoreDatabaseConnection::open(name, key).await?,
229            ConnectionType::InMemory => KeystoreDatabaseConnection::open_in_memory(key).await?,
230        };
231        let conn = Mutex::new(Some(conn));
232        #[allow(clippy::arc_with_non_send_sync)] // see https://github.com/rustwasm/wasm-bindgen/pull/955
233        let conn = Arc::new(conn);
234        Ok(Self {
235            conn,
236            transaction: Default::default(),
237            transaction_semaphore: Arc::new(Semaphore::new(ALLOWED_CONCURRENT_TRANSACTIONS_COUNT)),
238        })
239    }
240
241    #[cfg(all(test, not(target_family = "wasm")))]
242    pub(crate) async fn open_at_schema_version(
243        name: &str,
244        key: &DatabaseKey,
245        version: MigrationTarget,
246    ) -> CryptoKeystoreResult<Self> {
247        let conn = KeystoreDatabaseConnection::init_with_key_at_schema_version(name, key, version)?;
248        let conn = Mutex::new(Some(conn));
249        let conn = Arc::new(conn);
250        Ok(Self {
251            conn,
252            transaction: Default::default(),
253            transaction_semaphore: Arc::new(Semaphore::new(ALLOWED_CONCURRENT_TRANSACTIONS_COUNT)),
254        })
255    }
256
257    /// Get direct exclusive access to the connection.
258    pub async fn conn(&self) -> CryptoKeystoreResult<ConnectionGuard<'_>> {
259        self.conn.lock().await.try_into()
260    }
261
262    /// Wait for any running transaction to finish, then take the connection out of this database,
263    /// preventing this database from being used again.
264    async fn take(&self) -> CryptoKeystoreResult<KeystoreDatabaseConnection> {
265        let _semaphore = self.transaction_semaphore.acquire_arc().await;
266
267        let mut guard = self.conn.lock().await;
268        guard.take().ok_or(CryptoKeystoreError::Closed)
269    }
270
271    // Close this database connection
272    pub async fn close(&self) -> CryptoKeystoreResult<()> {
273        #[cfg(not(target_family = "wasm"))]
274        self.take().await?;
275
276        #[cfg(target_family = "wasm")]
277        {
278            let conn = self.take().await?;
279            conn.close().await?;
280        }
281        Ok(())
282    }
283
284    /// Close this database and delete its contents.
285    pub async fn wipe(&self) -> CryptoKeystoreResult<()> {
286        self.take().await?.wipe().await
287    }
288
289    pub async fn migrate_db_key_type_to_bytes(
290        name: &str,
291        old_key: &str,
292        new_key: &DatabaseKey,
293    ) -> CryptoKeystoreResult<()> {
294        KeystoreDatabaseConnection::migrate_db_key_type_to_bytes(name, old_key, new_key).await
295    }
296
297    pub async fn update_key(&mut self, new_key: &DatabaseKey) -> CryptoKeystoreResult<()> {
298        self.conn().await?.update_key(new_key).await
299    }
300
301    /// Waits for the current transaction to be committed or rolled back, then starts a new one.
302    pub async fn new_transaction(&self) -> CryptoKeystoreResult<()> {
303        let semaphore = self.transaction_semaphore.acquire_arc().await;
304        let mut transaction_guard = self.transaction.lock().await;
305        *transaction_guard = Some(KeystoreTransaction::new(semaphore).await?);
306        Ok(())
307    }
308
309    pub async fn commit_transaction(&self) -> CryptoKeystoreResult<()> {
310        let mut transaction_guard = self.transaction.lock().await;
311        let Some(transaction) = transaction_guard.as_ref() else {
312            return Err(CryptoKeystoreError::MutatingOperationWithoutTransaction);
313        };
314        transaction.commit(self).await?;
315        *transaction_guard = None;
316        Ok(())
317    }
318
319    pub async fn rollback_transaction(&self) -> CryptoKeystoreResult<()> {
320        let mut transaction_guard = self.transaction.lock().await;
321        if transaction_guard.is_none() {
322            return Err(CryptoKeystoreError::MutatingOperationWithoutTransaction);
323        };
324        *transaction_guard = None;
325        Ok(())
326    }
327
328    pub async fn child_groups<'a, E>(&self, entity: E) -> CryptoKeystoreResult<Vec<E>>
329    where
330        E: Clone + Entity + EntityDatabaseMutation<'a> + BorrowPrimaryKey + PersistedMlsGroupExt + Send + Sync,
331        for<'pk> &'pk <E as BorrowPrimaryKey>::BorrowedPrimaryKey: KeyType,
332    {
333        let mut conn = self.conn().await?;
334        let persisted_records = entity.child_groups(conn.deref_mut()).await?;
335
336        let transaction_guard = self.transaction.lock().await;
337        let Some(transaction) = transaction_guard.as_ref() else {
338            return Ok(persisted_records);
339        };
340        transaction.child_groups(entity, persisted_records).await
341    }
342
343    pub async fn save<'a, E>(&self, entity: E) -> CryptoKeystoreResult<E::AutoGeneratedFields>
344    where
345        E: Entity + EntityDatabaseMutation<'a> + Send + Sync,
346    {
347        let transaction_guard = self.transaction.lock().await;
348        let Some(transaction) = transaction_guard.as_ref() else {
349            return Err(CryptoKeystoreError::MutatingOperationWithoutTransaction);
350        };
351        transaction.save(entity).await
352    }
353
354    pub async fn remove<'a, E>(&self, id: &E::PrimaryKey) -> CryptoKeystoreResult<()>
355    where
356        E: Entity + EntityDatabaseMutation<'a>,
357    {
358        let transaction_guard = self.transaction.lock().await;
359        let Some(transaction) = transaction_guard.as_ref() else {
360            return Err(CryptoKeystoreError::MutatingOperationWithoutTransaction);
361        };
362        transaction.remove::<E>(id).await
363    }
364
365    pub async fn remove_borrowed<'a, E>(&self, id: &E::BorrowedPrimaryKey) -> CryptoKeystoreResult<()>
366    where
367        E: Entity + EntityDatabaseMutation<'a> + BorrowPrimaryKey + EntityDeleteBorrowed<'a>,
368    {
369        let transaction_guard = self.transaction.lock().await;
370        let Some(transaction) = transaction_guard.as_ref() else {
371            return Err(CryptoKeystoreError::MutatingOperationWithoutTransaction);
372        };
373        transaction.remove_borrowed::<E>(id).await
374    }
375
376    pub async fn find_pending_messages_by_conversation_id(
377        &self,
378        conversation_id: &[u8],
379    ) -> CryptoKeystoreResult<Vec<MlsPendingMessage>> {
380        let mut conn = self.conn().await?;
381        let persisted_records =
382            MlsPendingMessage::find_all_by_conversation_id(&mut conn, conversation_id, Default::default()).await?;
383
384        let transaction_guard = self.transaction.lock().await;
385        let Some(transaction) = transaction_guard.as_ref() else {
386            return Ok(persisted_records);
387        };
388        transaction
389            .find_pending_messages_by_conversation_id(conversation_id, persisted_records)
390            .await
391    }
392
393    pub async fn remove_pending_messages_by_conversation_id(
394        &self,
395        conversation_id: impl AsRef<[u8]> + Send,
396    ) -> CryptoKeystoreResult<()> {
397        let transaction_guard = self.transaction.lock().await;
398        let Some(transaction) = transaction_guard.as_ref() else {
399            return Err(CryptoKeystoreError::MutatingOperationWithoutTransaction);
400        };
401        transaction
402            .remove_pending_messages_by_conversation_id(conversation_id)
403            .await;
404        Ok(())
405    }
406}
407
408#[cfg_attr(target_family = "wasm", async_trait(?Send))]
409#[cfg_attr(not(target_family = "wasm"), async_trait)]
410impl FetchFromDatabase for Database {
411    async fn get<E>(&self, id: &E::PrimaryKey) -> CryptoKeystoreResult<Option<E>>
412    where
413        E: Entity<ConnectionType = KeystoreDatabaseConnection> + Clone + Send + Sync,
414    {
415        // If a transaction is in progress...
416        if let Some(transaction) = self.transaction.lock().await.as_ref()
417            //... and it has information about this entity, ...
418            && let Some(cached_record) = transaction.get(id).await
419        {
420            return Ok(cached_record.map(Arc::unwrap_or_clone));
421        }
422
423        // Otherwise get it from the database
424        let mut conn = self.conn().await?;
425        E::get(&mut conn, id).await
426    }
427
428    async fn get_borrowed<E>(&self, id: &<E as BorrowPrimaryKey>::BorrowedPrimaryKey) -> CryptoKeystoreResult<Option<E>>
429    where
430        E: EntityGetBorrowed<ConnectionType = KeystoreDatabaseConnection> + Clone + Send + Sync,
431        E::PrimaryKey: Borrow<E::BorrowedPrimaryKey>,
432        for<'a> &'a E::BorrowedPrimaryKey: KeyType,
433    {
434        // If a transaction is in progress...
435        if let Some(transaction) = self.transaction.lock().await.as_ref()
436            //... and it has information about this entity, ...
437            && let Some(cached_record) = transaction.get_borrowed(id).await
438        {
439            return Ok(cached_record.map(Arc::unwrap_or_clone));
440        }
441
442        // Otherwise get it from the database
443        let mut conn = self.conn().await?;
444        E::get_borrowed(&mut conn, id).await
445    }
446
447    async fn count<E>(&self) -> CryptoKeystoreResult<u32>
448    where
449        E: Entity<ConnectionType = KeystoreDatabaseConnection> + Clone + Send + Sync,
450    {
451        if self.transaction.lock().await.is_some() {
452            // Unfortunately, we have to do this because of possible record id overlap
453            // between cache and db.
454            let count = self.load_all::<E>().await?.len();
455            Ok(count as _)
456        } else {
457            let mut conn = self.conn().await?;
458            E::count(&mut conn).await
459        }
460    }
461
462    async fn load_all<E>(&self) -> CryptoKeystoreResult<Vec<E>>
463    where
464        E: Entity<ConnectionType = KeystoreDatabaseConnection> + Clone + Send + Sync,
465    {
466        let mut conn = self.conn().await?;
467        let persisted_records = E::load_all(&mut conn).await?;
468
469        let transaction_guard = self.transaction.lock().await;
470        let Some(transaction) = transaction_guard.as_ref() else {
471            return Ok(persisted_records);
472        };
473        transaction.find_all(persisted_records).await
474    }
475}