1use crate::prelude::{CryptoResult, MlsConversation};
18use core_crypto_keystore::{connection::FetchFromDatabase, entities::EntityFindParams};
19
20#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
21#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
22pub(crate) trait GroupStoreEntity: std::fmt::Debug {
23 type RawStoreValue: core_crypto_keystore::entities::Entity;
24 type IdentityType;
25
26 fn id(&self) -> &[u8];
27
28 async fn fetch_from_id(
29 id: &[u8],
30 identity: Option<Self::IdentityType>,
31 keystore: &impl FetchFromDatabase,
32 ) -> CryptoResult<Option<Self>>
33 where
34 Self: Sized;
35
36 async fn fetch_all(keystore: &impl FetchFromDatabase) -> CryptoResult<Vec<Self>>
37 where
38 Self: Sized;
39}
40
41#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
42#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
43impl GroupStoreEntity for MlsConversation {
44 type RawStoreValue = core_crypto_keystore::entities::PersistedMlsGroup;
45 type IdentityType = ();
46
47 fn id(&self) -> &[u8] {
48 self.id().as_slice()
49 }
50
51 async fn fetch_from_id(
52 id: &[u8],
53 _: Option<Self::IdentityType>,
54 keystore: &impl FetchFromDatabase,
55 ) -> crate::CryptoResult<Option<Self>> {
56 let result = keystore.find::<Self::RawStoreValue>(id).await?;
57 let Some(store_value) = result else {
58 return Ok(None);
59 };
60
61 let conversation = Self::from_serialized_state(store_value.state.clone(), store_value.parent_id.clone())?;
62 Ok(if conversation.group.is_active() {
64 Some(conversation)
65 } else {
66 None
67 })
68 }
69
70 async fn fetch_all(keystore: &impl FetchFromDatabase) -> CryptoResult<Vec<Self>> {
71 let all_conversations = keystore
72 .find_all::<Self::RawStoreValue>(EntityFindParams::default())
73 .await?;
74 Ok(all_conversations
75 .iter()
76 .filter_map(|c| {
77 let conversation = Self::from_serialized_state(c.state.clone(), c.parent_id.clone()).unwrap();
78 conversation.group.is_active().then_some(conversation)
79 })
80 .collect::<Vec<_>>())
81 }
82}
83
/// Keystore-backed loading of Proteus sessions.
///
/// Only [`GroupStoreEntity::fetch_from_id`] is supported; `id` and `fetch_all`
/// are marked unreachable and panic if called.
#[cfg(feature = "proteus")]
#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
impl GroupStoreEntity for crate::proteus::ProteusConversationSession {
    type RawStoreValue = core_crypto_keystore::entities::ProteusSession;
    type IdentityType = std::sync::Arc<proteus_wasm::keys::IdentityKeyPair>;

    /// Not supported for Proteus sessions; panics if called.
    fn id(&self) -> &[u8] {
        unreachable!()
    }

    /// Loads and deserializes the session persisted under `id`.
    ///
    /// The keystore is queried first: a missing record yields `Ok(None)` even
    /// when no identity was supplied.
    ///
    /// # Errors
    /// * keystore lookup failures,
    /// * `CryptoError::ProteusNotInitialized` when a record exists but
    ///   `identity` is `None`,
    /// * session deserialization failures (converted through `ProteusError`).
    async fn fetch_from_id(
        id: &[u8],
        identity: Option<Self::IdentityType>,
        keystore: &impl FetchFromDatabase,
    ) -> crate::CryptoResult<Option<Self>> {
        let store_value = match keystore.find::<Self::RawStoreValue>(id).await? {
            Some(found) => found,
            None => return Ok(None),
        };

        let identity = identity.ok_or(crate::CryptoError::ProteusNotInitialized)?;

        let session = proteus_wasm::session::Session::deserialise(identity, &store_value.session)
            .map_err(crate::ProteusError::from)?;

        let identifier = store_value.id.clone();
        Ok(Some(Self { identifier, session }))
    }

    /// Not supported for Proteus sessions; panics if called.
    async fn fetch_all(_keystore: &impl FetchFromDatabase) -> CryptoResult<Vec<Self>>
    where
        Self: Sized,
    {
        unreachable!()
    }
}
125
126pub(crate) type GroupStoreValue<V> = std::sync::Arc<async_lock::RwLock<V>>;
127
128pub(crate) type LruMap<V> = schnellru::LruMap<Vec<u8>, GroupStoreValue<V>, HybridMemoryLimiter>;
129
130pub(crate) struct GroupStore<V: GroupStoreEntity>(LruMap<V>);
134
135impl<V: GroupStoreEntity> std::fmt::Debug for GroupStore<V> {
136 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
137 f.debug_struct("GroupStore")
138 .field("length", &self.0.len())
139 .field("memory_usage", &self.0.memory_usage())
140 .field(
141 "entries",
142 &self
143 .0
144 .iter()
145 .map(|(k, v)| format!("{k:?}={v:?}"))
146 .collect::<Vec<String>>()
147 .join("\n"),
148 )
149 .finish()
150 }
151}
152
153impl<V: GroupStoreEntity> Default for GroupStore<V> {
154 fn default() -> Self {
155 Self(schnellru::LruMap::default())
156 }
157}
158
/// Test-only: gives tests read access to the underlying LRU map.
#[cfg(test)]
impl<V: GroupStoreEntity> std::ops::Deref for GroupStore<V> {
    type Target = LruMap<V>;

    fn deref(&self) -> &Self::Target {
        let Self(cache) = self;
        cache
    }
}
167
/// Test-only: gives tests mutable access to the underlying LRU map.
#[cfg(test)]
impl<V: GroupStoreEntity> std::ops::DerefMut for GroupStore<V> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self(cache) = self;
        cache
    }
}
174
175impl<V: GroupStoreEntity> GroupStore<V> {
176 #[allow(dead_code)]
177 pub(crate) fn new_with_limit(len: u32) -> Self {
178 let limiter = HybridMemoryLimiter::new(Some(len), None);
179 let store = schnellru::LruMap::new(limiter);
180 Self(store)
181 }
182
183 #[allow(dead_code)]
184 pub(crate) fn new(count: Option<u32>, memory: Option<usize>) -> Self {
185 let limiter = HybridMemoryLimiter::new(count, memory);
186 let store = schnellru::LruMap::new(limiter);
187 Self(store)
188 }
189
190 #[allow(dead_code)]
191 pub(crate) fn contains_key(&self, k: &[u8]) -> bool {
192 self.0.peek(k).is_some()
193 }
194
195 pub(crate) async fn get_fetch(
196 &mut self,
197 k: &[u8],
198 keystore: &impl FetchFromDatabase,
199 identity: Option<V::IdentityType>,
200 ) -> crate::CryptoResult<Option<GroupStoreValue<V>>> {
201 if let Some(value) = self.0.get(k) {
203 return Ok(Some(value.clone()));
204 }
205
206 let mut value = V::fetch_from_id(k, identity, keystore).await?;
208 if let Some(value) = value.take() {
209 let value_to_insert = std::sync::Arc::new(async_lock::RwLock::new(value));
210 self.insert_prepped(k.to_vec(), value_to_insert.clone());
211
212 Ok(Some(value_to_insert))
213 } else {
214 Ok(None)
215 }
216 }
217
218 pub(crate) async fn fetch_from_keystore(
222 k: &[u8],
223 keystore: &impl FetchFromDatabase,
224 identity: Option<V::IdentityType>,
225 ) -> crate::CryptoResult<Option<V>> {
226 V::fetch_from_id(k, identity, keystore).await
227 }
228
229 pub(crate) async fn get_fetch_all(
230 &mut self,
231 keystore: &impl FetchFromDatabase,
232 ) -> CryptoResult<Vec<GroupStoreValue<V>>> {
233 let all = V::fetch_all(keystore)
234 .await?
235 .into_iter()
236 .map(|g| {
237 let id = g.id().to_vec();
238 let to_insert = std::sync::Arc::new(async_lock::RwLock::new(g));
239 self.insert_prepped(id, to_insert.clone());
240 to_insert
241 })
242 .collect::<Vec<_>>();
243 Ok(all)
244 }
245
246 fn insert_prepped(&mut self, k: Vec<u8>, prepped_entity: GroupStoreValue<V>) {
247 self.0.insert(k, prepped_entity);
248 }
249
250 pub(crate) fn insert(&mut self, k: Vec<u8>, entity: V) {
251 let value_to_insert = std::sync::Arc::new(async_lock::RwLock::new(entity));
252 self.insert_prepped(k, value_to_insert)
253 }
254
255 pub(crate) fn try_insert(&mut self, k: Vec<u8>, entity: V) -> Result<(), V> {
256 let value_to_insert = std::sync::Arc::new(async_lock::RwLock::new(entity));
257
258 if self.0.insert(k, value_to_insert.clone()) {
259 Ok(())
260 } else {
261 Err(std::sync::Arc::into_inner(value_to_insert).unwrap().into_inner())
263 }
264 }
265
266 pub(crate) fn remove(&mut self, k: &[u8]) -> Option<GroupStoreValue<V>> {
267 self.0.remove(k)
268 }
269
270 pub(crate) fn get(&mut self, k: &[u8]) -> Option<&mut GroupStoreValue<V>> {
271 self.0.get(k)
272 }
273}
274
275pub(crate) struct HybridMemoryLimiter {
276 mem: schnellru::ByMemoryUsage,
277 len: schnellru::ByLength,
278}
279
280pub(crate) const MEMORY_LIMIT: usize = 100_000_000;
281pub(crate) const ITEM_LIMIT: u32 = 100;
282
283impl HybridMemoryLimiter {
284 pub(crate) fn new(count: Option<u32>, memory: Option<usize>) -> Self {
285 #[allow(clippy::unnecessary_lazy_evaluations)]
287 let maybe_memory_limit = memory.or_else(|| {
288 cfg_if::cfg_if! {
289 if #[cfg(target_family = "wasm")] {
290 None
291 } else {
292 let system = sysinfo::System::new_with_specifics(sysinfo::RefreshKind::new().with_memory(sysinfo::MemoryRefreshKind::new().with_ram()));
293 let available_sys_memory = system.available_memory();
294 if available_sys_memory > 0 {
295 Some(available_sys_memory as usize)
296 } else {
297 None
298 }
299 }
300 }
301 });
302
303 let mem = schnellru::ByMemoryUsage::new(maybe_memory_limit.unwrap_or(MEMORY_LIMIT));
304 let len = schnellru::ByLength::new(count.unwrap_or(ITEM_LIMIT));
305
306 Self { mem, len }
307 }
308}
309
310impl Default for HybridMemoryLimiter {
311 fn default() -> Self {
312 Self::new(None, None)
313 }
314}
315
316impl<K, V> schnellru::Limiter<K, V> for HybridMemoryLimiter {
317 type KeyToInsert<'a> = K;
318 type LinkType = u32;
319
320 fn is_over_the_limit(&self, length: usize) -> bool {
321 <schnellru::ByLength as schnellru::Limiter<K, V>>::is_over_the_limit(&self.len, length)
322 }
323
324 fn on_insert(&mut self, length: usize, key: Self::KeyToInsert<'_>, value: V) -> Option<(K, V)> {
325 <schnellru::ByLength as schnellru::Limiter<K, V>>::on_insert(&mut self.len, length, key, value)
326 }
327
328 fn on_replace(
330 &mut self,
331 _length: usize,
332 _old_key: &mut K,
333 _new_key: Self::KeyToInsert<'_>,
334 _old_value: &mut V,
335 _new_value: &mut V,
336 ) -> bool {
337 true
338 }
339 fn on_removed(&mut self, _key: &mut K, _value: &mut V) {}
340 fn on_cleared(&mut self) {}
341
342 fn on_grow(&mut self, new_memory_usage: usize) -> bool {
343 <schnellru::ByMemoryUsage as schnellru::Limiter<K, V>>::on_grow(&mut self.mem, new_memory_usage)
344 }
345}
346
347#[cfg(test)]
348mod tests {
349 use core_crypto_keystore::dummy_entity::{DummyStoreValue, DummyValue};
350 use wasm_bindgen_test::*;
351
352 use super::*;
353
354 wasm_bindgen_test_configure!(run_in_browser);
355
356 #[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
357 #[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
358 impl GroupStoreEntity for DummyValue {
359 type RawStoreValue = DummyStoreValue;
360
361 type IdentityType = ();
362
363 fn id(&self) -> &[u8] {
364 unreachable!()
365 }
366
367 async fn fetch_from_id(
368 id: &[u8],
369 _identity: Option<Self::IdentityType>,
370 _keystore: &impl FetchFromDatabase,
371 ) -> crate::CryptoResult<Option<Self>> {
372 let id = std::str::from_utf8(id)?;
373 Ok(Some(id.into()))
374 }
375
376 async fn fetch_all(_keystore: &impl FetchFromDatabase) -> CryptoResult<Vec<Self>> {
377 unreachable!()
378 }
379 }
380
381 type TestGroupStore = GroupStore<DummyValue>;
382
383 #[async_std::test]
384 #[wasm_bindgen_test]
385 async fn group_store_init() {
386 let store = TestGroupStore::new_with_limit(1);
387 assert_eq!(store.len(), 0);
388 let store = TestGroupStore::new_with_limit(0);
389 assert_eq!(store.len(), 0);
390 let store = TestGroupStore::new(Some(0), Some(0));
391 assert_eq!(store.len(), 0);
392 let store = TestGroupStore::new(Some(0), Some(1));
393 assert_eq!(store.len(), 0);
394 let store = TestGroupStore::new(Some(1), Some(0));
395 assert_eq!(store.len(), 0);
396 let store = TestGroupStore::new(Some(1), Some(1));
397 assert_eq!(store.len(), 0);
398 }
399
400 #[async_std::test]
401 #[wasm_bindgen_test]
402 async fn group_store_common_ops() {
403 let mut store = TestGroupStore::new(Some(u32::MAX), Some(usize::MAX));
404 for i in 1..=3 {
405 let i_str = i.to_string();
406 assert!(store
407 .try_insert(i_str.as_bytes().to_vec(), i_str.as_str().into())
408 .is_ok());
409 assert_eq!(store.len(), i);
410 }
411 for i in 4..=6 {
412 let i_str = i.to_string();
413 store.insert(i_str.as_bytes().to_vec(), i_str.as_str().into());
414 assert_eq!(store.len(), i);
415 }
416
417 for i in 1..=6 {
418 assert!(store.contains_key(i.to_string().as_bytes()));
419 }
420 }
421
422 #[async_std::test]
423 #[wasm_bindgen_test]
424 async fn group_store_operations_len_limiter() {
425 let mut store = TestGroupStore::new_with_limit(2);
426 assert!(store.try_insert(b"1".to_vec(), "1".into()).is_ok());
427 assert_eq!(store.len(), 1);
428 assert!(store.try_insert(b"2".to_vec(), "2".into()).is_ok());
429 assert_eq!(store.len(), 2);
430 assert!(store.try_insert(b"3".to_vec(), "3".into()).is_ok());
431 assert_eq!(store.len(), 2);
432 assert!(!store.contains_key(b"1"));
433 assert!(store.contains_key(b"2"));
434 assert!(store.contains_key(b"3"));
435 store.insert(b"4".to_vec(), "4".into());
436 assert_eq!(store.len(), 2);
437 }
438
439 #[async_std::test]
440 #[wasm_bindgen_test]
441 async fn group_store_operations_mem_limiter() {
442 use schnellru::{LruMap, UnlimitedCompact};
443 let mut lru: LruMap<Vec<u8>, DummyValue, UnlimitedCompact> =
444 LruMap::<Vec<u8>, DummyValue, UnlimitedCompact>::new(UnlimitedCompact);
445 assert_eq!(lru.guaranteed_capacity(), 0);
446 assert_eq!(lru.memory_usage(), 0);
447 lru.insert(1usize.to_le_bytes().to_vec(), "10".into());
448 let memory_usage_step_1 = lru.memory_usage();
449 lru.insert(2usize.to_le_bytes().to_vec(), "20".into());
450 lru.insert(3usize.to_le_bytes().to_vec(), "30".into());
451 lru.insert(4usize.to_le_bytes().to_vec(), "40".into());
452 let memory_usage_step_2 = lru.memory_usage();
453 assert_ne!(memory_usage_step_1, memory_usage_step_2);
454
455 let mut store = TestGroupStore::new(None, Some(memory_usage_step_2));
456 assert_eq!(store.guaranteed_capacity(), 0);
457 assert_eq!(store.memory_usage(), 0);
458 store.try_insert(1usize.to_le_bytes().to_vec(), "10".into()).unwrap();
459 assert_eq!(store.guaranteed_capacity(), 3);
460 assert!(store.memory_usage() <= memory_usage_step_1);
461 store.try_insert(2usize.to_le_bytes().to_vec(), "20".into()).unwrap();
462 store.try_insert(3usize.to_le_bytes().to_vec(), "30".into()).unwrap();
463 for i in 1..=3usize {
464 assert_eq!(
465 *(store.get(i.to_le_bytes().as_ref()).unwrap().read().await),
466 DummyValue::from(format!("{}", i * 10).as_str())
467 );
468 }
469 assert_eq!(store.guaranteed_capacity(), 3);
470 assert!(store.memory_usage() <= memory_usage_step_1);
471 assert!(store.try_insert(4usize.to_le_bytes().to_vec(), "40".into()).is_ok());
472 for i in (1usize..=4).rev() {
473 assert_eq!(
474 *(store.get(i.to_le_bytes().as_ref()).unwrap().read().await),
475 DummyValue::from(format!("{}", i * 10).as_str())
476 );
477 }
478 assert_eq!(store.guaranteed_capacity(), 7);
479 assert!(store.memory_usage() <= memory_usage_step_2);
480 store.try_insert(5usize.to_le_bytes().to_vec(), "50".into()).unwrap();
481 store.try_insert(6usize.to_le_bytes().to_vec(), "60".into()).unwrap();
482 store.try_insert(7usize.to_le_bytes().to_vec(), "70".into()).unwrap();
483 for i in (5usize..=7).rev() {
484 store.get(i.to_le_bytes().as_ref()).unwrap();
485 }
486
487 store.insert(8usize.to_le_bytes().to_vec(), "80".into());
488 for i in [8usize, 7, 6, 5].iter() {
489 assert_eq!(
490 *(store
491 .get(i.to_le_bytes().as_ref())
492 .unwrap_or_else(|| panic!("couldn't find index {i}"))
493 .read()
494 .await),
495 DummyValue::from(format!("{}", i * 10).as_str())
496 );
497 }
498
499 assert_eq!(store.guaranteed_capacity(), 7);
500 assert!(store.memory_usage() <= memory_usage_step_2);
501 store.assert_check_internal_state();
502 }
503}