1use std::sync::Arc;
18
19use crate::{KeystoreError, ProteusError, RecursiveError, Result, prelude::MlsConversation};
20use core_crypto_keystore::connection::FetchFromDatabase;
21#[cfg(test)]
22use core_crypto_keystore::entities::EntityFindParams;
23
24#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
25#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
26pub(crate) trait GroupStoreEntity: std::fmt::Debug {
27 type RawStoreValue: core_crypto_keystore::entities::Entity;
28 type IdentityType;
29
30 #[cfg(test)]
31 fn id(&self) -> &[u8];
32
33 async fn fetch_from_id(
34 id: &[u8],
35 identity: Option<Self::IdentityType>,
36 keystore: &impl FetchFromDatabase,
37 ) -> Result<Option<Self>>
38 where
39 Self: Sized;
40
41 #[cfg(test)]
42 async fn fetch_all(keystore: &impl FetchFromDatabase) -> Result<Vec<Self>>
43 where
44 Self: Sized;
45}
46
47#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
48#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
49impl GroupStoreEntity for MlsConversation {
50 type RawStoreValue = core_crypto_keystore::entities::PersistedMlsGroup;
51 type IdentityType = ();
52
53 #[cfg(test)]
54 fn id(&self) -> &[u8] {
55 self.id().as_slice()
56 }
57
58 async fn fetch_from_id(
59 id: &[u8],
60 _: Option<Self::IdentityType>,
61 keystore: &impl FetchFromDatabase,
62 ) -> crate::Result<Option<Self>> {
63 let result = keystore
64 .find::<Self::RawStoreValue>(id)
65 .await
66 .map_err(KeystoreError::wrap("finding mls conversation from keystore by id"))?;
67 let Some(store_value) = result else {
68 return Ok(None);
69 };
70
71 let conversation = Self::from_serialized_state(store_value.state.clone(), store_value.parent_id.clone())
72 .map_err(RecursiveError::mls_conversation("deserializing mls conversation"))?;
73 Ok(conversation.group.is_active().then_some(conversation))
75 }
76
77 #[cfg(test)]
78 async fn fetch_all(keystore: &impl FetchFromDatabase) -> Result<Vec<Self>> {
79 let all_conversations = keystore
80 .find_all::<Self::RawStoreValue>(EntityFindParams::default())
81 .await
82 .map_err(KeystoreError::wrap("finding all mls conversations"))?;
83 Ok(all_conversations
84 .iter()
85 .filter_map(|c| {
86 let conversation = Self::from_serialized_state(c.state.clone(), c.parent_id.clone()).unwrap();
87 conversation.group.is_active().then_some(conversation)
88 })
89 .collect::<Vec<_>>())
90 }
91}
92
#[cfg(feature = "proteus")]
#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
impl GroupStoreEntity for crate::proteus::ProteusConversationSession {
    type RawStoreValue = core_crypto_keystore::entities::ProteusSession;
    type IdentityType = Arc<proteus_wasm::keys::IdentityKeyPair>;

    #[cfg(test)]
    fn id(&self) -> &[u8] {
        unreachable!()
    }

    /// Loads the Proteus session stored under `id` and deserializes it with `identity`.
    ///
    /// Returns `Ok(None)` when no session is stored under `id`. If a session exists but no
    /// `identity` was supplied, fails with `ProteusNotInitialized`.
    async fn fetch_from_id(
        id: &[u8],
        identity: Option<Self::IdentityType>,
        keystore: &impl FetchFromDatabase,
    ) -> crate::Result<Option<Self>> {
        let maybe_stored = keystore
            .find::<Self::RawStoreValue>(id)
            .await
            .map_err(KeystoreError::wrap("finding raw group store entity by id"))?;
        let stored = match maybe_stored {
            Some(stored) => stored,
            None => return Ok(None),
        };

        // The identity is only required once we know there is something to deserialize.
        let identity = identity.ok_or(crate::Error::ProteusNotInitialized)?;

        let session = proteus_wasm::session::Session::deserialise(identity, &stored.session)
            .map_err(ProteusError::wrap("deserializing session"))?;

        Ok(Some(Self {
            identifier: stored.id.clone(),
            session,
        }))
    }

    #[cfg(test)]
    async fn fetch_all(_keystore: &impl FetchFromDatabase) -> Result<Vec<Self>>
    where
        Self: Sized,
    {
        unreachable!()
    }
}
139
/// Shared handle to a cached entity: an `Arc` around an async `RwLock`, so clones of the
/// handle all point at the same value.
pub(crate) type GroupStoreValue<V> = Arc<async_lock::RwLock<V>>;

/// LRU map backing a [`GroupStore`]: keyed by the raw entity id bytes and bounded by a
/// [`HybridMemoryLimiter`].
pub(crate) type LruMap<V> = schnellru::LruMap<Vec<u8>, GroupStoreValue<V>, HybridMemoryLimiter>;

/// In-memory LRU cache of [`GroupStoreEntity`] values, filled on demand from the keystore.
pub(crate) struct GroupStore<V: GroupStoreEntity>(LruMap<V>);
148
149impl<V: GroupStoreEntity> std::fmt::Debug for GroupStore<V> {
150 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
151 f.debug_struct("GroupStore")
152 .field("length", &self.0.len())
153 .field("memory_usage", &self.0.memory_usage())
154 .field(
155 "entries",
156 &self
157 .0
158 .iter()
159 .map(|(k, v)| format!("{k:?}={v:?}"))
160 .collect::<Vec<String>>()
161 .join("\n"),
162 )
163 .finish()
164 }
165}
166
167impl<V: GroupStoreEntity> Default for GroupStore<V> {
168 fn default() -> Self {
169 Self(schnellru::LruMap::default())
170 }
171}
172
/// Test-only read access to the underlying LRU map (length, capacity, memory usage, ...).
#[cfg(test)]
impl<V: GroupStoreEntity> std::ops::Deref for GroupStore<V> {
    type Target = LruMap<V>;

    fn deref(&self) -> &LruMap<V> {
        let Self(inner) = self;
        inner
    }
}

/// Test-only mutable access to the underlying LRU map.
#[cfg(test)]
impl<V: GroupStoreEntity> std::ops::DerefMut for GroupStore<V> {
    fn deref_mut(&mut self) -> &mut LruMap<V> {
        let Self(inner) = self;
        inner
    }
}
188
189impl<V: GroupStoreEntity> GroupStore<V> {
190 #[allow(dead_code)]
191 pub(crate) fn new_with_limit(len: u32) -> Self {
192 let limiter = HybridMemoryLimiter::new(Some(len), None);
193 let store = schnellru::LruMap::new(limiter);
194 Self(store)
195 }
196
197 #[allow(dead_code)]
198 pub(crate) fn new(count: Option<u32>, memory: Option<usize>) -> Self {
199 let limiter = HybridMemoryLimiter::new(count, memory);
200 let store = schnellru::LruMap::new(limiter);
201 Self(store)
202 }
203
204 #[allow(dead_code)]
205 pub(crate) fn contains_key(&self, k: &[u8]) -> bool {
206 self.0.peek(k).is_some()
207 }
208
209 pub(crate) async fn get_fetch(
210 &mut self,
211 k: &[u8],
212 keystore: &impl FetchFromDatabase,
213 identity: Option<V::IdentityType>,
214 ) -> crate::Result<Option<GroupStoreValue<V>>> {
215 if let Some(value) = self.0.get(k) {
217 return Ok(Some(value.clone()));
218 }
219
220 let inserted_value = V::fetch_from_id(k, identity, keystore).await?.map(|value| {
222 let value_to_insert = Arc::new(async_lock::RwLock::new(value));
223 self.insert_prepped(k.to_vec(), value_to_insert.clone());
224 value_to_insert
225 });
226 Ok(inserted_value)
227 }
228
229 pub(crate) async fn fetch_from_keystore(
233 k: &[u8],
234 keystore: &impl FetchFromDatabase,
235 identity: Option<V::IdentityType>,
236 ) -> crate::Result<Option<V>> {
237 V::fetch_from_id(k, identity, keystore).await
238 }
239
240 #[cfg(test)]
241 pub(crate) async fn get_fetch_all(&mut self, keystore: &impl FetchFromDatabase) -> Result<Vec<GroupStoreValue<V>>> {
242 let all = V::fetch_all(keystore)
243 .await?
244 .into_iter()
245 .map(|g| {
246 let id = g.id().to_vec();
247 let to_insert = Arc::new(async_lock::RwLock::new(g));
248 self.insert_prepped(id, to_insert.clone());
249 to_insert
250 })
251 .collect::<Vec<_>>();
252 Ok(all)
253 }
254
255 fn insert_prepped(&mut self, k: Vec<u8>, prepped_entity: GroupStoreValue<V>) {
256 self.0.insert(k, prepped_entity);
257 }
258
259 pub(crate) fn insert(&mut self, k: Vec<u8>, entity: V) {
260 let value_to_insert = Arc::new(async_lock::RwLock::new(entity));
261 self.insert_prepped(k, value_to_insert)
262 }
263
264 pub(crate) fn try_insert(&mut self, k: Vec<u8>, entity: V) -> Result<(), V> {
265 let value_to_insert = Arc::new(async_lock::RwLock::new(entity));
266
267 if self.0.insert(k, value_to_insert.clone()) {
268 Ok(())
269 } else {
270 Err(Arc::into_inner(value_to_insert).unwrap().into_inner())
272 }
273 }
274
275 pub(crate) fn remove(&mut self, k: &[u8]) -> Option<GroupStoreValue<V>> {
276 self.0.remove(k)
277 }
278
279 pub(crate) fn get(&mut self, k: &[u8]) -> Option<&mut GroupStoreValue<V>> {
280 self.0.get(k)
281 }
282}
283
/// `schnellru` limiter combining an entry-count bound (`len`) with a memory bound (`mem`).
pub(crate) struct HybridMemoryLimiter {
    // Consulted when the map wants to grow its allocation.
    mem: schnellru::ByMemoryUsage,
    // Consulted on insertion and when checking whether the map is over its entry limit.
    len: schnellru::ByLength,
}
288
/// Fallback memory cap, in bytes, for a `HybridMemoryLimiter` (100 MB).
pub(crate) const MEMORY_LIMIT: usize = 100 * 1_000_000;
/// Default maximum number of entries kept by a `HybridMemoryLimiter`.
pub(crate) const ITEM_LIMIT: u32 = 100;
291
292impl HybridMemoryLimiter {
293 #[cfg_attr(target_family = "wasm", expect(unused_variables))]
295 pub(crate) fn new(count: Option<u32>, memory: Option<usize>) -> Self {
296 #[cfg(target_family = "wasm")]
297 let memory_limit = MEMORY_LIMIT;
298
299 #[cfg(not(target_family = "wasm"))]
300 let memory_limit = memory
301 .or_else(|| {
302 let system = sysinfo::System::new_with_specifics(
303 sysinfo::RefreshKind::nothing().with_memory(sysinfo::MemoryRefreshKind::nothing().with_ram()),
304 );
305
306 let available_sys_memory = system.available_memory();
307 (available_sys_memory > 0).then_some(available_sys_memory as usize)
308 })
309 .unwrap_or(MEMORY_LIMIT);
310
311 let mem = schnellru::ByMemoryUsage::new(memory_limit);
312 let len = schnellru::ByLength::new(count.unwrap_or(ITEM_LIMIT));
313
314 Self { mem, len }
315 }
316}
317
318impl Default for HybridMemoryLimiter {
319 fn default() -> Self {
320 Self::new(None, None)
321 }
322}
323
324impl<K, V> schnellru::Limiter<K, V> for HybridMemoryLimiter {
325 type KeyToInsert<'a> = K;
326 type LinkType = u32;
327
328 fn is_over_the_limit(&self, length: usize) -> bool {
329 <schnellru::ByLength as schnellru::Limiter<K, V>>::is_over_the_limit(&self.len, length)
330 }
331
332 fn on_insert(&mut self, length: usize, key: Self::KeyToInsert<'_>, value: V) -> Option<(K, V)> {
333 <schnellru::ByLength as schnellru::Limiter<K, V>>::on_insert(&mut self.len, length, key, value)
334 }
335
336 fn on_replace(
338 &mut self,
339 _length: usize,
340 _old_key: &mut K,
341 _new_key: Self::KeyToInsert<'_>,
342 _old_value: &mut V,
343 _new_value: &mut V,
344 ) -> bool {
345 true
346 }
347 fn on_removed(&mut self, _key: &mut K, _value: &mut V) {}
348 fn on_cleared(&mut self) {}
349
350 fn on_grow(&mut self, new_memory_usage: usize) -> bool {
351 <schnellru::ByMemoryUsage as schnellru::Limiter<K, V>>::on_grow(&mut self.mem, new_memory_usage)
352 }
353}
354
#[cfg(test)]
mod tests {
    use core_crypto_keystore::dummy_entity::{DummyStoreValue, DummyValue};
    use wasm_bindgen_test::*;

    use super::*;

    wasm_bindgen_test_configure!(run_in_browser);

    // Minimal entity for exercising the store: `fetch_from_id` ignores the keystore and
    // builds the value directly from the (UTF-8) id bytes.
    #[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
    #[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
    impl GroupStoreEntity for DummyValue {
        type RawStoreValue = DummyStoreValue;

        type IdentityType = ();

        fn id(&self) -> &[u8] {
            unreachable!()
        }

        async fn fetch_from_id(
            id: &[u8],
            _identity: Option<Self::IdentityType>,
            _keystore: &impl FetchFromDatabase,
        ) -> crate::Result<Option<Self>> {
            let id = std::str::from_utf8(id).expect("dummy value ids are strings");
            Ok(Some(id.into()))
        }

        #[cfg(test)]
        async fn fetch_all(_keystore: &impl FetchFromDatabase) -> Result<Vec<Self>> {
            unreachable!()
        }
    }

    type TestGroupStore = GroupStore<DummyValue>;

    // Any combination of limits, including zero, yields an empty store.
    #[async_std::test]
    #[wasm_bindgen_test]
    async fn group_store_init() {
        let store = TestGroupStore::new_with_limit(1);
        assert_eq!(store.len(), 0);
        let store = TestGroupStore::new_with_limit(0);
        assert_eq!(store.len(), 0);
        let store = TestGroupStore::new(Some(0), Some(0));
        assert_eq!(store.len(), 0);
        let store = TestGroupStore::new(Some(0), Some(1));
        assert_eq!(store.len(), 0);
        let store = TestGroupStore::new(Some(1), Some(0));
        assert_eq!(store.len(), 0);
        let store = TestGroupStore::new(Some(1), Some(1));
        assert_eq!(store.len(), 0);
    }

    // With effectively unlimited caps, both `try_insert` and `insert` grow the store by one
    // entry each, and every inserted key remains present.
    #[async_std::test]
    #[wasm_bindgen_test]
    async fn group_store_common_ops() {
        let mut store = TestGroupStore::new(Some(u32::MAX), Some(usize::MAX));
        for i in 1..=3 {
            let i_str = i.to_string();
            assert!(
                store
                    .try_insert(i_str.as_bytes().to_vec(), i_str.as_str().into())
                    .is_ok()
            );
            assert_eq!(store.len(), i);
        }
        for i in 4..=6 {
            let i_str = i.to_string();
            store.insert(i_str.as_bytes().to_vec(), i_str.as_str().into());
            assert_eq!(store.len(), i);
        }

        for i in 1..=6 {
            assert!(store.contains_key(i.to_string().as_bytes()));
        }
    }

    // With an entry cap of 2, inserting a third entry still succeeds but evicts the
    // least-recently-used one ("1"), keeping the length at 2.
    #[async_std::test]
    #[wasm_bindgen_test]
    async fn group_store_operations_len_limiter() {
        let mut store = TestGroupStore::new_with_limit(2);
        assert!(store.try_insert(b"1".to_vec(), "1".into()).is_ok());
        assert_eq!(store.len(), 1);
        assert!(store.try_insert(b"2".to_vec(), "2".into()).is_ok());
        assert_eq!(store.len(), 2);
        assert!(store.try_insert(b"3".to_vec(), "3".into()).is_ok());
        assert_eq!(store.len(), 2);
        assert!(!store.contains_key(b"1"));
        assert!(store.contains_key(b"2"));
        assert!(store.contains_key(b"3"));
        store.insert(b"4".to_vec(), "4".into());
        assert_eq!(store.len(), 2);
    }

    // Memory-cap behavior.
    #[async_std::test]
    #[wasm_bindgen_test]
    async fn group_store_operations_mem_limiter() {
        // Measure how much memory an unbounded map reports after 1 and after 4 insertions;
        // the second figure becomes the memory cap of the store under test.
        use schnellru::{LruMap, UnlimitedCompact};
        let mut lru: LruMap<Vec<u8>, DummyValue, UnlimitedCompact> =
            LruMap::<Vec<u8>, DummyValue, UnlimitedCompact>::new(UnlimitedCompact);
        assert_eq!(lru.guaranteed_capacity(), 0);
        assert_eq!(lru.memory_usage(), 0);
        lru.insert(1usize.to_le_bytes().to_vec(), "10".into());
        let memory_usage_step_1 = lru.memory_usage();
        lru.insert(2usize.to_le_bytes().to_vec(), "20".into());
        lru.insert(3usize.to_le_bytes().to_vec(), "30".into());
        lru.insert(4usize.to_le_bytes().to_vec(), "40".into());
        let memory_usage_step_2 = lru.memory_usage();
        assert_ne!(memory_usage_step_1, memory_usage_step_2);

        // No entry-count cap (default count), memory capped at `memory_usage_step_2`.
        let mut store = TestGroupStore::new(None, Some(memory_usage_step_2));
        assert_eq!(store.guaranteed_capacity(), 0);
        assert_eq!(store.memory_usage(), 0);
        store.try_insert(1usize.to_le_bytes().to_vec(), "10".into()).unwrap();
        assert_eq!(store.guaranteed_capacity(), 3);
        assert!(store.memory_usage() <= memory_usage_step_1);
        store.try_insert(2usize.to_le_bytes().to_vec(), "20".into()).unwrap();
        store.try_insert(3usize.to_le_bytes().to_vec(), "30".into()).unwrap();
        for i in 1..=3usize {
            assert_eq!(
                *(store.get(i.to_le_bytes().as_ref()).unwrap().read().await),
                DummyValue::from(format!("{}", i * 10).as_str())
            );
        }
        assert_eq!(store.guaranteed_capacity(), 3);
        assert!(store.memory_usage() <= memory_usage_step_1);
        // The 4th entry exceeds the initial capacity; growth is still within the memory cap.
        assert!(store.try_insert(4usize.to_le_bytes().to_vec(), "40".into()).is_ok());
        for i in (1usize..=4).rev() {
            assert_eq!(
                *(store.get(i.to_le_bytes().as_ref()).unwrap().read().await),
                DummyValue::from(format!("{}", i * 10).as_str())
            );
        }
        assert_eq!(store.guaranteed_capacity(), 7);
        assert!(store.memory_usage() <= memory_usage_step_2);
        store.try_insert(5usize.to_le_bytes().to_vec(), "50".into()).unwrap();
        store.try_insert(6usize.to_le_bytes().to_vec(), "60".into()).unwrap();
        store.try_insert(7usize.to_le_bytes().to_vec(), "70".into()).unwrap();
        // Touch 7, 6, 5 so they are the most recently used entries before inserting 8.
        for i in (5usize..=7).rev() {
            store.get(i.to_le_bytes().as_ref()).unwrap();
        }

        // Growing further is expected to be refused by the memory cap, so inserting 8
        // must evict an older entry rather than the recently touched 5..=7.
        store.insert(8usize.to_le_bytes().to_vec(), "80".into());
        for i in [8usize, 7, 6, 5].iter() {
            assert_eq!(
                *(store
                    .get(i.to_le_bytes().as_ref())
                    .unwrap_or_else(|| panic!("couldn't find index {i}"))
                    .read()
                    .await),
                DummyValue::from(format!("{}", i * 10).as_str())
            );
        }

        assert_eq!(store.guaranteed_capacity(), 7);
        assert!(store.memory_usage() <= memory_usage_step_2);
        store.assert_check_internal_state();
    }
}