core_crypto_keystore/connection/
mod.rs1use std::{borrow::Borrow, fmt, ops::Deref};
2
3use async_trait::async_trait;
4use sha2::{Digest as _, Sha256};
5use zeroize::{Zeroize, ZeroizeOnDrop};
6
7pub mod platform {
8 cfg_if::cfg_if! {
9 if #[cfg(target_family = "wasm")] {
10 mod wasm;
11 pub use self::wasm::WasmConnection as KeystoreDatabaseConnection;
12 pub use wasm::storage;
13 pub use self::wasm::storage::WasmStorageTransaction as TransactionWrapper;
14 } else {
15 mod generic;
16 pub use self::generic::SqlCipherConnection as KeystoreDatabaseConnection;
17 pub use self::generic::TransactionWrapper;
18 #[cfg(test)]
19 pub(crate) use generic::MigrationTarget;
20
21
22 }
23 }
24}
25
26use std::{ops::DerefMut, sync::Arc};
27
28use async_lock::{Mutex, MutexGuard, Semaphore};
29
30pub use self::platform::*;
31use crate::{
32 CryptoKeystoreError, CryptoKeystoreResult,
33 entities::{MlsPendingMessage, PersistedMlsGroupExt},
34 traits::{
35 BorrowPrimaryKey, Entity, EntityDatabaseMutation, EntityDeleteBorrowed, EntityGetBorrowed, FetchFromDatabase,
36 KeyType,
37 },
38 transaction::KeystoreTransaction,
39};
40
/// Blob size (in bytes) at or above which `DatabaseConnection::check_buffer_size` rejects a
/// buffer: 10^9 bytes.
pub const MAX_BLOB_LEN: usize = 1_000 * 1_000 * 1_000;

/// Marker bounds every database connection type must satisfy on native targets: the
/// connection must be `Sized` and may be moved across threads (`Send`).
#[cfg(not(target_family = "wasm"))]
pub trait DatabaseConnectionRequirements
where
    Self: Sized + Send,
{
}

/// Marker bounds every database connection type must satisfy on WebAssembly: only `Sized`,
/// because `Send` is not required there.
#[cfg(target_family = "wasm")]
pub trait DatabaseConnectionRequirements
where
    Self: Sized,
{
}
58
59#[derive(Clone, Zeroize, ZeroizeOnDrop, derive_more::From, PartialEq, Eq)]
61pub struct DatabaseKey([u8; Self::LEN]);
62
63impl DatabaseKey {
64 pub const LEN: usize = 32;
65
66 pub fn generate() -> DatabaseKey {
67 DatabaseKey(rand::random::<[u8; Self::LEN]>())
68 }
69}
70
71impl fmt::Debug for DatabaseKey {
72 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
73 f.write_str("DatabaseKey(hash=")?;
74 for x in Sha256::digest(self).as_slice().iter().take(10) {
75 fmt::LowerHex::fmt(x, f)?
76 }
77 f.write_str("...)")
78 }
79}
80
81impl AsRef<[u8]> for DatabaseKey {
82 fn as_ref(&self) -> &[u8] {
83 &self.0
84 }
85}
86
87impl Deref for DatabaseKey {
88 type Target = [u8];
89
90 fn deref(&self) -> &Self::Target {
91 &self.0
92 }
93}
94
95impl TryFrom<&[u8]> for DatabaseKey {
96 type Error = CryptoKeystoreError;
97
98 fn try_from(buf: &[u8]) -> Result<Self, Self::Error> {
99 if buf.len() != Self::LEN {
100 Err(CryptoKeystoreError::InvalidDbKeySize {
101 expected: Self::LEN,
102 actual: buf.len(),
103 })
104 } else {
105 Ok(Self(buf.try_into().unwrap()))
106 }
107 }
108}
109
110#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
111#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
112pub trait DatabaseConnection<'a>: DatabaseConnectionRequirements {
113 type Connection: 'a;
114
115 async fn open(name: &str, key: &DatabaseKey) -> CryptoKeystoreResult<Self>;
116
117 async fn open_in_memory(key: &DatabaseKey) -> CryptoKeystoreResult<Self>;
118
119 async fn update_key(&mut self, new_key: &DatabaseKey) -> CryptoKeystoreResult<()>;
120
121 async fn wipe(self) -> CryptoKeystoreResult<()>;
123
124 fn check_buffer_size(size: usize) -> CryptoKeystoreResult<()> {
125 #[cfg(not(target_family = "wasm"))]
126 if size > i32::MAX as usize {
127 return Err(CryptoKeystoreError::BlobTooBig);
128 }
129
130 if size >= MAX_BLOB_LEN {
131 return Err(CryptoKeystoreError::BlobTooBig);
132 }
133
134 Ok(())
135 }
136}
137
/// Cloneable handle to the keystore: the backend connection, the transaction currently in
/// progress (if any), and the semaphore that gates transactions.
///
/// Clones share the same connection and transaction state through `Arc`s.
#[derive(Debug, Clone)]
pub struct Database {
    // `None` once `close`/`wipe` has taken the connection; `conn()` then fails with `Closed`.
    pub(crate) conn: Arc<Mutex<Option<KeystoreDatabaseConnection>>>,
    // The transaction opened by `new_transaction`, cleared on commit or rollback.
    pub(crate) transaction: Arc<Mutex<Option<KeystoreTransaction>>>,
    // Permits are handed to each new transaction; `take` also waits on it before removing the connection.
    transaction_semaphore: Arc<Semaphore>,
}

/// Number of permits in `Database::transaction_semaphore`, i.e. transactions are serialized.
const ALLOWED_CONCURRENT_TRANSACTIONS_COUNT: usize = 1;
146
/// Lookup interface over keystore entities, keyed by raw byte ids.
///
/// NOTE(review): the `Old` prefix suggests it is being superseded by `FetchFromDatabase`;
/// confirm before adding new callers.
#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
pub trait OldFetchFromDatabase: Send + Sync {
    /// Looks up a single `E` by `id`; `None` presumably means no such record.
    async fn find<E: Entity<ConnectionType = KeystoreDatabaseConnection>>(
        &self,
        id: impl AsRef<[u8]> + Send,
    ) -> CryptoKeystoreResult<Option<E>>;

    /// Loads the single record of a unique entity type `U`.
    async fn find_unique<U: crate::entities::UniqueEntity<ConnectionType = KeystoreDatabaseConnection>>(
        &self,
    ) -> CryptoKeystoreResult<U>;

    /// Loads the records of type `E` selected by `params`.
    async fn find_all<E: Entity<ConnectionType = KeystoreDatabaseConnection>>(
        &self,
        params: crate::entities::EntityFindParams,
    ) -> CryptoKeystoreResult<Vec<E>>;

    /// Loads the records of type `E` whose ids are listed in `ids`.
    async fn find_many<E: Entity<ConnectionType = KeystoreDatabaseConnection>>(
        &self,
        ids: &[Vec<u8>],
    ) -> CryptoKeystoreResult<Vec<E>>;
    /// Counts the records of type `E`.
    async fn count<E: Entity<ConnectionType = KeystoreDatabaseConnection>>(&self) -> CryptoKeystoreResult<usize>;
}
172
// SAFETY: NOTE(review): no justification is recorded here. Every field of `Database` is an
// `Arc<Mutex<..>>` or an `Arc<Semaphore>`, so these impls only matter when the wrapped
// connection/transaction types are not themselves `Send`. The
// `clippy::arc_with_non_send_sync` allowance in `Database::open` suggests that happens on some
// target (presumably wasm). Confirm those values are never actually shared across threads there.
unsafe impl Send for Database {}
// SAFETY: see the note on the `Send` impl directly above.
unsafe impl Sync for Database {}
177
/// Where `Database::open` should open the keystore.
#[derive(Debug, Clone)]
pub enum ConnectionType<'a> {
    /// A named database; the name is passed to the backend's `open`.
    Persistent(&'a str),
    /// A database created through the backend's `open_in_memory`.
    InMemory,
}
186
187pub struct ConnectionGuard<'a> {
192 guard: MutexGuard<'a, Option<KeystoreDatabaseConnection>>,
193}
194
195impl<'a> TryFrom<MutexGuard<'a, Option<KeystoreDatabaseConnection>>> for ConnectionGuard<'a> {
196 type Error = CryptoKeystoreError;
197
198 fn try_from(guard: MutexGuard<'a, Option<KeystoreDatabaseConnection>>) -> Result<Self, Self::Error> {
199 guard
200 .is_some()
201 .then_some(Self { guard })
202 .ok_or(CryptoKeystoreError::Closed)
203 }
204}
205
206impl Deref for ConnectionGuard<'_> {
207 type Target = KeystoreDatabaseConnection;
208
209 fn deref(&self) -> &Self::Target {
210 self.guard
211 .as_ref()
212 .expect("we have exclusive access and already checked that the connection exists")
213 }
214}
215
216impl DerefMut for ConnectionGuard<'_> {
217 fn deref_mut(&mut self) -> &mut Self::Target {
218 self.guard
219 .as_mut()
220 .expect("we have exclusive access and already checked that the connection exists")
221 }
222}
223
224impl Database {
226 pub async fn open(location: ConnectionType<'_>, key: &DatabaseKey) -> CryptoKeystoreResult<Self> {
227 let conn = match location {
228 ConnectionType::Persistent(name) => KeystoreDatabaseConnection::open(name, key).await?,
229 ConnectionType::InMemory => KeystoreDatabaseConnection::open_in_memory(key).await?,
230 };
231 let conn = Mutex::new(Some(conn));
232 #[allow(clippy::arc_with_non_send_sync)] let conn = Arc::new(conn);
234 Ok(Self {
235 conn,
236 transaction: Default::default(),
237 transaction_semaphore: Arc::new(Semaphore::new(ALLOWED_CONCURRENT_TRANSACTIONS_COUNT)),
238 })
239 }
240
241 #[cfg(all(test, not(target_family = "wasm")))]
242 pub(crate) async fn open_at_schema_version(
243 name: &str,
244 key: &DatabaseKey,
245 version: MigrationTarget,
246 ) -> CryptoKeystoreResult<Self> {
247 let conn = KeystoreDatabaseConnection::init_with_key_at_schema_version(name, key, version)?;
248 let conn = Mutex::new(Some(conn));
249 let conn = Arc::new(conn);
250 Ok(Self {
251 conn,
252 transaction: Default::default(),
253 transaction_semaphore: Arc::new(Semaphore::new(ALLOWED_CONCURRENT_TRANSACTIONS_COUNT)),
254 })
255 }
256
257 pub async fn conn(&self) -> CryptoKeystoreResult<ConnectionGuard<'_>> {
259 self.conn.lock().await.try_into()
260 }
261
262 async fn take(&self) -> CryptoKeystoreResult<KeystoreDatabaseConnection> {
265 let _semaphore = self.transaction_semaphore.acquire_arc().await;
266
267 let mut guard = self.conn.lock().await;
268 guard.take().ok_or(CryptoKeystoreError::Closed)
269 }
270
271 pub async fn close(&self) -> CryptoKeystoreResult<()> {
273 #[cfg(not(target_family = "wasm"))]
274 self.take().await?;
275
276 #[cfg(target_family = "wasm")]
277 {
278 let conn = self.take().await?;
279 conn.close().await?;
280 }
281 Ok(())
282 }
283
284 pub async fn wipe(&self) -> CryptoKeystoreResult<()> {
286 self.take().await?.wipe().await
287 }
288
289 pub async fn migrate_db_key_type_to_bytes(
290 name: &str,
291 old_key: &str,
292 new_key: &DatabaseKey,
293 ) -> CryptoKeystoreResult<()> {
294 KeystoreDatabaseConnection::migrate_db_key_type_to_bytes(name, old_key, new_key).await
295 }
296
297 pub async fn update_key(&mut self, new_key: &DatabaseKey) -> CryptoKeystoreResult<()> {
298 self.conn().await?.update_key(new_key).await
299 }
300
301 pub async fn new_transaction(&self) -> CryptoKeystoreResult<()> {
303 let semaphore = self.transaction_semaphore.acquire_arc().await;
304 let mut transaction_guard = self.transaction.lock().await;
305 *transaction_guard = Some(KeystoreTransaction::new(semaphore).await?);
306 Ok(())
307 }
308
309 pub async fn commit_transaction(&self) -> CryptoKeystoreResult<()> {
310 let mut transaction_guard = self.transaction.lock().await;
311 let Some(transaction) = transaction_guard.as_ref() else {
312 return Err(CryptoKeystoreError::MutatingOperationWithoutTransaction);
313 };
314 transaction.commit(self).await?;
315 *transaction_guard = None;
316 Ok(())
317 }
318
319 pub async fn rollback_transaction(&self) -> CryptoKeystoreResult<()> {
320 let mut transaction_guard = self.transaction.lock().await;
321 if transaction_guard.is_none() {
322 return Err(CryptoKeystoreError::MutatingOperationWithoutTransaction);
323 };
324 *transaction_guard = None;
325 Ok(())
326 }
327
328 pub async fn child_groups<'a, E>(&self, entity: E) -> CryptoKeystoreResult<Vec<E>>
329 where
330 E: Clone + Entity + EntityDatabaseMutation<'a> + BorrowPrimaryKey + PersistedMlsGroupExt + Send + Sync,
331 for<'pk> &'pk <E as BorrowPrimaryKey>::BorrowedPrimaryKey: KeyType,
332 {
333 let mut conn = self.conn().await?;
334 let persisted_records = entity.child_groups(conn.deref_mut()).await?;
335
336 let transaction_guard = self.transaction.lock().await;
337 let Some(transaction) = transaction_guard.as_ref() else {
338 return Ok(persisted_records);
339 };
340 transaction.child_groups(entity, persisted_records).await
341 }
342
343 pub async fn save<'a, E>(&self, entity: E) -> CryptoKeystoreResult<E::AutoGeneratedFields>
344 where
345 E: Entity + EntityDatabaseMutation<'a> + Send + Sync,
346 {
347 let transaction_guard = self.transaction.lock().await;
348 let Some(transaction) = transaction_guard.as_ref() else {
349 return Err(CryptoKeystoreError::MutatingOperationWithoutTransaction);
350 };
351 transaction.save(entity).await
352 }
353
354 pub async fn remove<'a, E>(&self, id: &E::PrimaryKey) -> CryptoKeystoreResult<()>
355 where
356 E: Entity + EntityDatabaseMutation<'a>,
357 {
358 let transaction_guard = self.transaction.lock().await;
359 let Some(transaction) = transaction_guard.as_ref() else {
360 return Err(CryptoKeystoreError::MutatingOperationWithoutTransaction);
361 };
362 transaction.remove::<E>(id).await
363 }
364
365 pub async fn remove_borrowed<'a, E>(&self, id: &E::BorrowedPrimaryKey) -> CryptoKeystoreResult<()>
366 where
367 E: Entity + EntityDatabaseMutation<'a> + BorrowPrimaryKey + EntityDeleteBorrowed<'a>,
368 {
369 let transaction_guard = self.transaction.lock().await;
370 let Some(transaction) = transaction_guard.as_ref() else {
371 return Err(CryptoKeystoreError::MutatingOperationWithoutTransaction);
372 };
373 transaction.remove_borrowed::<E>(id).await
374 }
375
376 pub async fn find_pending_messages_by_conversation_id(
377 &self,
378 conversation_id: &[u8],
379 ) -> CryptoKeystoreResult<Vec<MlsPendingMessage>> {
380 let mut conn = self.conn().await?;
381 let persisted_records =
382 MlsPendingMessage::find_all_by_conversation_id(&mut conn, conversation_id, Default::default()).await?;
383
384 let transaction_guard = self.transaction.lock().await;
385 let Some(transaction) = transaction_guard.as_ref() else {
386 return Ok(persisted_records);
387 };
388 transaction
389 .find_pending_messages_by_conversation_id(conversation_id, persisted_records)
390 .await
391 }
392
393 pub async fn remove_pending_messages_by_conversation_id(
394 &self,
395 conversation_id: impl AsRef<[u8]> + Send,
396 ) -> CryptoKeystoreResult<()> {
397 let transaction_guard = self.transaction.lock().await;
398 let Some(transaction) = transaction_guard.as_ref() else {
399 return Err(CryptoKeystoreError::MutatingOperationWithoutTransaction);
400 };
401 transaction
402 .remove_pending_messages_by_conversation_id(conversation_id)
403 .await;
404 Ok(())
405 }
406}
407
408#[cfg_attr(target_family = "wasm", async_trait(?Send))]
409#[cfg_attr(not(target_family = "wasm"), async_trait)]
410impl FetchFromDatabase for Database {
411 async fn get<E>(&self, id: &E::PrimaryKey) -> CryptoKeystoreResult<Option<E>>
412 where
413 E: Entity<ConnectionType = KeystoreDatabaseConnection> + Clone + Send + Sync,
414 {
415 if let Some(transaction) = self.transaction.lock().await.as_ref()
417 && let Some(cached_record) = transaction.get(id).await
419 {
420 return Ok(cached_record.map(Arc::unwrap_or_clone));
421 }
422
423 let mut conn = self.conn().await?;
425 E::get(&mut conn, id).await
426 }
427
428 async fn get_borrowed<E>(&self, id: &<E as BorrowPrimaryKey>::BorrowedPrimaryKey) -> CryptoKeystoreResult<Option<E>>
429 where
430 E: EntityGetBorrowed<ConnectionType = KeystoreDatabaseConnection> + Clone + Send + Sync,
431 E::PrimaryKey: Borrow<E::BorrowedPrimaryKey>,
432 for<'a> &'a E::BorrowedPrimaryKey: KeyType,
433 {
434 if let Some(transaction) = self.transaction.lock().await.as_ref()
436 && let Some(cached_record) = transaction.get_borrowed(id).await
438 {
439 return Ok(cached_record.map(Arc::unwrap_or_clone));
440 }
441
442 let mut conn = self.conn().await?;
444 E::get_borrowed(&mut conn, id).await
445 }
446
447 async fn count<E>(&self) -> CryptoKeystoreResult<u32>
448 where
449 E: Entity<ConnectionType = KeystoreDatabaseConnection> + Clone + Send + Sync,
450 {
451 if self.transaction.lock().await.is_some() {
452 let count = self.load_all::<E>().await?.len();
455 Ok(count as _)
456 } else {
457 let mut conn = self.conn().await?;
458 E::count(&mut conn).await
459 }
460 }
461
462 async fn load_all<E>(&self) -> CryptoKeystoreResult<Vec<E>>
463 where
464 E: Entity<ConnectionType = KeystoreDatabaseConnection> + Clone + Send + Sync,
465 {
466 let mut conn = self.conn().await?;
467 let persisted_records = E::load_all(&mut conn).await?;
468
469 let transaction_guard = self.transaction.lock().await;
470 let Some(transaction) = transaction_guard.as_ref() else {
471 return Ok(persisted_records);
472 };
473 transaction.find_all(persisted_records).await
474 }
475}