core_crypto_keystore/connection/
mod.rs

1use std::fmt;
2use std::ops::Deref;
3
4use sha2::{Digest as _, Sha256};
5use zeroize::{Zeroize, ZeroizeOnDrop};
6
/// Platform-specific database backends.
///
/// Exactly one backend is compiled in, selected by target family, and its types are
/// re-exported under the common names `KeystoreDatabaseConnection` and `TransactionWrapper`.
pub mod platform {
    cfg_if::cfg_if! {
        if #[cfg(target_family = "wasm")] {
            // WASM targets: connection and transaction types come from the `wasm` module,
            // whose `storage` submodule is re-exported as well.
            mod wasm;
            pub use self::wasm::WasmConnection as KeystoreDatabaseConnection;
            pub use wasm::storage;
            pub use self::wasm::storage::WasmStorageTransaction as TransactionWrapper;
        } else {
            // Every other target: connection and transaction types come from the `generic` module.
            mod generic;
            pub use self::generic::SqlCipherConnection as KeystoreDatabaseConnection;
            pub use self::generic::TransactionWrapper;
        }
    }
}
21
22pub use self::platform::*;
23use crate::entities::{Entity, EntityFindParams, MlsPendingMessage, StringEntityId};
24use std::ops::DerefMut;
25
26use crate::entities::{EntityTransactionExt, UniqueEntity};
27use crate::transaction::KeystoreTransaction;
28use crate::{CryptoKeystoreError, CryptoKeystoreResult};
29use async_lock::{Mutex, MutexGuard, Semaphore};
30use std::sync::Arc;
31
/// Limit on the length of a blob to be stored in the database.
///
/// This limit applies to both SQLCipher-backed stores and WASM.
/// This limit is conservative on purpose when targeting WASM, as the lower bound that exists is Safari with a limit of 1GB per origin.
///
/// `DatabaseConnection::check_buffer_size` rejects any size greater than *or equal to* this
/// value, so an accepted blob is always strictly smaller than the limit.
///
/// See: [SQLite limits](https://www.sqlite.org/limits.html)
/// See: [IndexedDB limits](https://stackoverflow.com/a/63019999/1934177)
pub const MAX_BLOB_LEN: usize = 1_000_000_000;
40
/// Marker trait collecting the bounds every database connection type must satisfy.
///
/// On non-WASM targets a connection must be `Sized + Send`.
#[cfg(not(target_family = "wasm"))]
// Because of UniFFI async requirements, we need our keystore to be Send as well now
pub trait DatabaseConnectionRequirements: Sized + Send {}
/// Marker trait collecting the bounds every database connection type must satisfy.
///
/// On WASM targets a connection only needs to be `Sized`.
#[cfg(target_family = "wasm")]
// On the other hand, things cannot be Send on WASM because of platform restrictions (all things are copied across the FFI)
pub trait DatabaseConnectionRequirements: Sized {}
47
/// The key used to encrypt the database.
///
/// A newtype around exactly [`DatabaseKey::LEN`] raw bytes. The `Zeroize` and `ZeroizeOnDrop`
/// derives ask for the bytes to be wiped when the key is dropped, and `derive_more::From`
/// provides a conversion from an array of the right length.
#[derive(Clone, Zeroize, ZeroizeOnDrop, derive_more::From, PartialEq, Eq)]
pub struct DatabaseKey([u8; Self::LEN]);
51
52impl DatabaseKey {
53    pub const LEN: usize = 32;
54
55    pub fn generate() -> DatabaseKey {
56        DatabaseKey(rand::random::<[u8; Self::LEN]>())
57    }
58}
59
60impl fmt::Debug for DatabaseKey {
61    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
62        f.write_str("DatabaseKey(hash=")?;
63        for x in Sha256::digest(self).as_slice().iter().take(10) {
64            fmt::LowerHex::fmt(x, f)?
65        }
66        f.write_str("...)")
67    }
68}
69
70impl AsRef<[u8]> for DatabaseKey {
71    fn as_ref(&self) -> &[u8] {
72        &self.0
73    }
74}
75
76impl Deref for DatabaseKey {
77    type Target = [u8];
78
79    fn deref(&self) -> &Self::Target {
80        &self.0
81    }
82}
83
84impl TryFrom<&[u8]> for DatabaseKey {
85    type Error = CryptoKeystoreError;
86
87    fn try_from(buf: &[u8]) -> Result<Self, Self::Error> {
88        if buf.len() != Self::LEN {
89            Err(CryptoKeystoreError::InvalidDbKeySize {
90                expected: Self::LEN,
91                actual: buf.len(),
92            })
93        } else {
94            Ok(Self(buf.try_into().unwrap()))
95        }
96    }
97}
98
99#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
100#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
101pub trait DatabaseConnection<'a>: DatabaseConnectionRequirements {
102    type Connection: 'a;
103
104    async fn open(name: &str, key: &DatabaseKey) -> CryptoKeystoreResult<Self>;
105
106    async fn open_in_memory(key: &DatabaseKey) -> CryptoKeystoreResult<Self>;
107
108    async fn update_key(&mut self, new_key: &DatabaseKey) -> CryptoKeystoreResult<()>;
109
110    async fn close(self) -> CryptoKeystoreResult<()>;
111
112    /// Default implementation of wipe
113    async fn wipe(self) -> CryptoKeystoreResult<()> {
114        self.close().await
115    }
116
117    fn check_buffer_size(size: usize) -> CryptoKeystoreResult<()> {
118        #[cfg(not(target_family = "wasm"))]
119        if size > i32::MAX as usize {
120            return Err(CryptoKeystoreError::BlobTooBig);
121        }
122
123        if size >= MAX_BLOB_LEN {
124            return Err(CryptoKeystoreError::BlobTooBig);
125        }
126
127        Ok(())
128    }
129}
130
/// Handle to the keystore database.
///
/// Cloning is cheap: every field is behind an `Arc`, so all clones share the same underlying
/// connection, transaction slot and semaphore.
#[derive(Debug, Clone)]
pub struct Connection {
    /// The platform database connection, guarded by an async mutex.
    pub(crate) conn: Arc<Mutex<KeystoreDatabaseConnection>>,
    /// The transaction currently in progress, or `None` when there is none.
    pub(crate) transaction: Arc<Mutex<Option<KeystoreTransaction>>>,
    /// Semaphore acquired before starting a transaction and before closing the connection.
    transaction_semaphore: Arc<Semaphore>,
}
137
/// Number of permits the transaction semaphore is created with in `Connection::open`,
/// i.e. how many transactions may be in progress at the same time.
const ALLOWED_CONCURRENT_TRANSACTIONS_COUNT: usize = 1;
139
/// Interface to fetch from the database either from the connection directly or through a
/// transaction
#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
pub trait FetchFromDatabase: Send + Sync {
    /// Fetches the entity of type `E` identified by `id`, if there is one.
    async fn find<E: Entity<ConnectionType = KeystoreDatabaseConnection>>(
        &self,
        id: &[u8],
    ) -> CryptoKeystoreResult<Option<E>>;

    /// Fetches the single instance of the unique entity type `U`.
    async fn find_unique<U: UniqueEntity<ConnectionType = KeystoreDatabaseConnection>>(
        &self,
    ) -> CryptoKeystoreResult<U>;

    /// Fetches all entities of type `E` selected by `params`.
    async fn find_all<E: Entity<ConnectionType = KeystoreDatabaseConnection>>(
        &self,
        params: EntityFindParams,
    ) -> CryptoKeystoreResult<Vec<E>>;

    /// Fetches the entities of type `E` whose ids are listed in `ids`.
    async fn find_many<E: Entity<ConnectionType = KeystoreDatabaseConnection>>(
        &self,
        ids: &[Vec<u8>],
    ) -> CryptoKeystoreResult<Vec<E>>;
    /// Counts the entities of type `E`.
    async fn count<E: Entity<ConnectionType = KeystoreDatabaseConnection>>(&self) -> CryptoKeystoreResult<usize>;
}
165
// SAFETY: every field of `Connection` is an `Arc` around an async `Mutex` or a `Semaphore`, so
// the underlying data is only reached through those synchronization primitives, which makes it
// safe to move a `Connection` to another thread.
// NOTE(review): this relies on the wrapped `KeystoreDatabaseConnection` tolerating access from
// a different thread than the one that created it — confirm for every platform backend.
unsafe impl Send for Connection {}
// SAFETY: every field of `Connection` is an `Arc` around an async `Mutex` or a `Semaphore`, so
// the underlying data is only reached through those synchronization primitives, which makes it
// safe to share a `Connection` between threads.
unsafe impl Sync for Connection {}
170
/// Where to open a connection
///
/// Passed to `Connection::open` to choose between a persistent and an in-memory database.
#[derive(Debug, Clone)]
pub enum ConnectionType<'a> {
    /// This connection is persistent at the provided path
    ///
    /// The string is borrowed for `'a` and handed to the backend's `open` as its `name`.
    Persistent(&'a str),
    /// This connection is transient and lives in memory
    InMemory,
}
179
180impl Connection {
181    pub async fn open(location: ConnectionType<'_>, key: &DatabaseKey) -> CryptoKeystoreResult<Self> {
182        let conn = match location {
183            ConnectionType::Persistent(name) => KeystoreDatabaseConnection::open(name, key).await?.into(),
184            ConnectionType::InMemory => KeystoreDatabaseConnection::open_in_memory(key).await?.into(),
185        };
186        #[allow(clippy::arc_with_non_send_sync)] // see https://github.com/rustwasm/wasm-bindgen/pull/955
187        let conn = Arc::new(conn);
188        Ok(Self {
189            conn,
190            transaction: Default::default(),
191            transaction_semaphore: Arc::new(Semaphore::new(ALLOWED_CONCURRENT_TRANSACTIONS_COUNT)),
192        })
193    }
194
195    pub async fn borrow_conn(&self) -> CryptoKeystoreResult<MutexGuard<'_, KeystoreDatabaseConnection>> {
196        Ok(self.conn.lock().await)
197    }
198
199    pub async fn migrate_db_key_type_to_bytes(
200        name: &str,
201        old_key: &str,
202        new_key: &DatabaseKey,
203    ) -> CryptoKeystoreResult<()> {
204        KeystoreDatabaseConnection::migrate_db_key_type_to_bytes(name, old_key, new_key).await
205    }
206
207    pub async fn update_key(&mut self, new_key: &DatabaseKey) -> CryptoKeystoreResult<()> {
208        self.conn.lock().await.update_key(new_key).await
209    }
210
211    pub async fn wipe(self) -> CryptoKeystoreResult<()> {
212        if self.transaction.lock().await.is_some() {
213            return Err(CryptoKeystoreError::TransactionInProgress {
214                attempted_operation: "wipe()".to_string(),
215            });
216        }
217        let conn: KeystoreDatabaseConnection = Arc::into_inner(self.conn).unwrap().into_inner();
218        conn.wipe().await?;
219        Ok(())
220    }
221
222    /// Wait for any running transaction to finish, then close the database connection.
223    pub async fn close(self) -> CryptoKeystoreResult<()> {
224        // Wait for any running transaction to finish
225        let _semaphore = self.transaction_semaphore.acquire_arc().await;
226        // Ensure that there's only one reference to the connection
227        let Some(conn) = Arc::into_inner(self.conn) else {
228            return Err(CryptoKeystoreError::CannotClose);
229        };
230        let conn = conn.into_inner();
231        conn.close().await?;
232        Ok(())
233    }
234
235    /// Waits for the current transaction to be committed or rolled back, then starts a new one.
236    pub async fn new_transaction(&self) -> CryptoKeystoreResult<()> {
237        let semaphore = self.transaction_semaphore.acquire_arc().await;
238        let mut transaction_guard = self.transaction.lock().await;
239        *transaction_guard = Some(KeystoreTransaction::new(semaphore).await?);
240        Ok(())
241    }
242
243    pub async fn commit_transaction(&self) -> CryptoKeystoreResult<()> {
244        let mut transaction_guard = self.transaction.lock().await;
245        let Some(transaction) = transaction_guard.as_ref() else {
246            return Err(CryptoKeystoreError::MutatingOperationWithoutTransaction);
247        };
248        transaction.commit(self).await?;
249        *transaction_guard = None;
250        Ok(())
251    }
252
253    pub async fn rollback_transaction(&self) -> CryptoKeystoreResult<()> {
254        let mut transaction_guard = self.transaction.lock().await;
255        if transaction_guard.is_none() {
256            return Err(CryptoKeystoreError::MutatingOperationWithoutTransaction);
257        };
258        *transaction_guard = None;
259        Ok(())
260    }
261
262    pub async fn child_groups<
263        E: Entity<ConnectionType = KeystoreDatabaseConnection> + crate::entities::PersistedMlsGroupExt + Sync,
264    >(
265        &self,
266        entity: E,
267    ) -> CryptoKeystoreResult<Vec<E>> {
268        let mut conn = self.conn.lock().await;
269        let persisted_records = entity.child_groups(conn.deref_mut()).await?;
270
271        let transaction_guard = self.transaction.lock().await;
272        let Some(transaction) = transaction_guard.as_ref() else {
273            return Ok(persisted_records);
274        };
275        transaction.child_groups(entity, persisted_records).await
276    }
277
278    pub async fn save<E: Entity<ConnectionType = KeystoreDatabaseConnection> + Sync + EntityTransactionExt>(
279        &self,
280        entity: E,
281    ) -> CryptoKeystoreResult<E> {
282        let transaction_guard = self.transaction.lock().await;
283        let Some(transaction) = transaction_guard.as_ref() else {
284            return Err(CryptoKeystoreError::MutatingOperationWithoutTransaction);
285        };
286        transaction.save_mut(entity).await
287    }
288
289    pub async fn remove<
290        E: Entity<ConnectionType = KeystoreDatabaseConnection> + EntityTransactionExt,
291        S: AsRef<[u8]>,
292    >(
293        &self,
294        id: S,
295    ) -> CryptoKeystoreResult<()> {
296        let transaction_guard = self.transaction.lock().await;
297        let Some(transaction) = transaction_guard.as_ref() else {
298            return Err(CryptoKeystoreError::MutatingOperationWithoutTransaction);
299        };
300        transaction.remove::<E, S>(id).await
301    }
302
303    pub async fn find_pending_messages_by_conversation_id(
304        &self,
305        conversation_id: &[u8],
306    ) -> CryptoKeystoreResult<Vec<MlsPendingMessage>> {
307        let mut conn = self.conn.lock().await;
308        let persisted_records =
309            MlsPendingMessage::find_all_by_conversation_id(&mut conn, conversation_id, Default::default()).await?;
310
311        let transaction_guard = self.transaction.lock().await;
312        let Some(transaction) = transaction_guard.as_ref() else {
313            return Ok(persisted_records);
314        };
315        transaction
316            .find_pending_messages_by_conversation_id(conversation_id, persisted_records)
317            .await
318    }
319
320    pub async fn remove_pending_messages_by_conversation_id(&self, conversation_id: &[u8]) -> CryptoKeystoreResult<()> {
321        let transaction_guard = self.transaction.lock().await;
322        let Some(transaction) = transaction_guard.as_ref() else {
323            return Err(CryptoKeystoreError::MutatingOperationWithoutTransaction);
324        };
325        transaction
326            .remove_pending_messages_by_conversation_id(conversation_id)
327            .await
328    }
329
330    pub async fn cred_delete_by_credential(&self, cred: Vec<u8>) -> CryptoKeystoreResult<()> {
331        let transaction_guard = self.transaction.lock().await;
332        let Some(transaction) = transaction_guard.as_ref() else {
333            return Err(CryptoKeystoreError::MutatingOperationWithoutTransaction);
334        };
335        transaction.cred_delete_by_credential(cred).await
336    }
337}
338
339#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
340#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
341impl FetchFromDatabase for Connection {
342    async fn find<E: Entity<ConnectionType = KeystoreDatabaseConnection>>(
343        &self,
344        id: &[u8],
345    ) -> CryptoKeystoreResult<Option<E>> {
346        // If a transaction is in progress...
347        if let Some(transaction) = self.transaction.lock().await.as_ref()
348            //... and it has information about this entity, ...
349            && let Some(cached_record) = transaction.find::<E>(id).await?
350        {
351            // ... return that result
352            return Ok(cached_record);
353        }
354
355        // Otherwise get it from the database
356        let mut conn = self.conn.lock().await;
357        E::find_one(&mut conn, &id.into()).await
358    }
359
360    async fn find_unique<U: UniqueEntity>(&self) -> CryptoKeystoreResult<U> {
361        // If a transaction is in progress...
362        if let Some(transaction) = self.transaction.lock().await.as_ref()
363            //... and it has information about this entity, ...
364            && let Some(cached_record) = transaction.find_unique::<U>().await?
365        {
366            // ... return that result
367            return Ok(cached_record);
368        }
369        // Otherwise get it from the database
370        let mut conn = self.conn.lock().await;
371        U::find_unique(&mut conn).await
372    }
373
374    async fn find_all<E: Entity<ConnectionType = KeystoreDatabaseConnection>>(
375        &self,
376        params: EntityFindParams,
377    ) -> CryptoKeystoreResult<Vec<E>> {
378        let mut conn = self.conn.lock().await;
379        let persisted_records = E::find_all(&mut conn, params.clone()).await?;
380
381        let transaction_guard = self.transaction.lock().await;
382        let Some(transaction) = transaction_guard.as_ref() else {
383            return Ok(persisted_records);
384        };
385        transaction.find_all(persisted_records, params).await
386    }
387
388    async fn find_many<E: Entity<ConnectionType = KeystoreDatabaseConnection>>(
389        &self,
390        ids: &[Vec<u8>],
391    ) -> CryptoKeystoreResult<Vec<E>> {
392        let entity_ids: Vec<StringEntityId> = ids.iter().map(|id| id.as_slice().into()).collect();
393        let mut conn = self.conn.lock().await;
394        let persisted_records = E::find_many(&mut conn, &entity_ids).await?;
395
396        let transaction_guard = self.transaction.lock().await;
397        let Some(transaction) = transaction_guard.as_ref() else {
398            return Ok(persisted_records);
399        };
400        transaction.find_many(persisted_records, ids).await
401    }
402
403    async fn count<E: Entity<ConnectionType = KeystoreDatabaseConnection>>(&self) -> CryptoKeystoreResult<usize> {
404        if self.transaction.lock().await.is_some() {
405            // Unfortunately, we have to do this because of possible record id overlap
406            // between cache and db.
407            return Ok(self.find_all::<E>(Default::default()).await?.len());
408        };
409        let mut conn = self.conn.lock().await;
410        E::count(&mut conn).await
411    }
412}