1use std::sync::Arc;
2
3use crate::{KeystoreError, RecursiveError, Result, prelude::MlsConversation};
4use core_crypto_keystore::connection::FetchFromDatabase;
5#[cfg(test)]
6use core_crypto_keystore::entities::EntityFindParams;
7
8#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
9#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
10pub(crate) trait GroupStoreEntity: std::fmt::Debug {
11 type RawStoreValue: core_crypto_keystore::entities::Entity;
12 type IdentityType;
13
14 #[cfg(test)]
15 fn id(&self) -> &[u8];
16
17 async fn fetch_from_id(
18 id: &[u8],
19 identity: Option<Self::IdentityType>,
20 keystore: &impl FetchFromDatabase,
21 ) -> Result<Option<Self>>
22 where
23 Self: Sized;
24
25 #[cfg(test)]
26 async fn fetch_all(keystore: &impl FetchFromDatabase) -> Result<Vec<Self>>
27 where
28 Self: Sized;
29}
30
31#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
32#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
33impl GroupStoreEntity for MlsConversation {
34 type RawStoreValue = core_crypto_keystore::entities::PersistedMlsGroup;
35 type IdentityType = ();
36
37 #[cfg(test)]
38 fn id(&self) -> &[u8] {
39 self.id().as_slice()
40 }
41
42 async fn fetch_from_id(
43 id: &[u8],
44 _: Option<Self::IdentityType>,
45 keystore: &impl FetchFromDatabase,
46 ) -> crate::Result<Option<Self>> {
47 let result = keystore
48 .find::<Self::RawStoreValue>(id)
49 .await
50 .map_err(KeystoreError::wrap("finding mls conversation from keystore by id"))?;
51 let Some(store_value) = result else {
52 return Ok(None);
53 };
54
55 let conversation = Self::from_serialized_state(store_value.state.clone(), store_value.parent_id.clone())
56 .map_err(RecursiveError::mls_conversation("deserializing mls conversation"))?;
57 Ok(conversation.group.is_active().then_some(conversation))
59 }
60
61 #[cfg(test)]
62 async fn fetch_all(keystore: &impl FetchFromDatabase) -> Result<Vec<Self>> {
63 let all_conversations = keystore
64 .find_all::<Self::RawStoreValue>(EntityFindParams::default())
65 .await
66 .map_err(KeystoreError::wrap("finding all mls conversations"))?;
67 Ok(all_conversations
68 .iter()
69 .filter_map(|c| {
70 let conversation = Self::from_serialized_state(c.state.clone(), c.parent_id.clone()).unwrap();
71 conversation.group.is_active().then_some(conversation)
72 })
73 .collect::<Vec<_>>())
74 }
75}
76
77pub(crate) type GroupStoreValue<V> = Arc<async_lock::RwLock<V>>;
78
79pub(crate) type LruMap<V> = schnellru::LruMap<Vec<u8>, GroupStoreValue<V>, HybridMemoryLimiter>;
80
81pub(crate) struct GroupStore<V: GroupStoreEntity>(LruMap<V>);
85
86impl<V: GroupStoreEntity> std::fmt::Debug for GroupStore<V> {
87 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88 f.debug_struct("GroupStore")
89 .field("length", &self.0.len())
90 .field("memory_usage", &self.0.memory_usage())
91 .field(
92 "entries",
93 &self
94 .0
95 .iter()
96 .map(|(k, v)| format!("{k:?}={v:?}"))
97 .collect::<Vec<String>>()
98 .join("\n"),
99 )
100 .finish()
101 }
102}
103
104impl<V: GroupStoreEntity> Default for GroupStore<V> {
105 fn default() -> Self {
106 Self(schnellru::LruMap::default())
107 }
108}
109
/// Test-only read access to the underlying LRU map.
#[cfg(test)]
impl<V: GroupStoreEntity> std::ops::Deref for GroupStore<V> {
    type Target = LruMap<V>;

    fn deref(&self) -> &Self::Target {
        let Self(inner) = self;
        inner
    }
}
118
/// Test-only mutable access to the underlying LRU map.
#[cfg(test)]
impl<V: GroupStoreEntity> std::ops::DerefMut for GroupStore<V> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self(inner) = self;
        inner
    }
}
125
126impl<V: GroupStoreEntity> GroupStore<V> {
127 pub(crate) fn new_with_limit(len: u32) -> Self {
128 let limiter = HybridMemoryLimiter::new(Some(len), None);
129 let store = schnellru::LruMap::new(limiter);
130 Self(store)
131 }
132
133 #[cfg(test)]
134 fn new(count: Option<u32>, memory: Option<usize>) -> Self {
135 let limiter = HybridMemoryLimiter::new(count, memory);
136 let store = schnellru::LruMap::new(limiter);
137 Self(store)
138 }
139
140 #[cfg(test)]
141 fn contains_key(&self, k: &[u8]) -> bool {
142 self.0.peek(k).is_some()
143 }
144
145 pub(crate) async fn get_fetch(
146 &mut self,
147 k: &[u8],
148 keystore: &impl FetchFromDatabase,
149 identity: Option<V::IdentityType>,
150 ) -> crate::Result<Option<GroupStoreValue<V>>> {
151 if let Some(value) = self.0.get(k) {
153 return Ok(Some(value.clone()));
154 }
155
156 let inserted_value = V::fetch_from_id(k, identity, keystore).await?.map(|value| {
158 let value_to_insert = Arc::new(async_lock::RwLock::new(value));
159 self.insert_prepped(k.to_vec(), value_to_insert.clone());
160 value_to_insert
161 });
162 Ok(inserted_value)
163 }
164
165 pub(crate) async fn fetch_from_keystore(
169 k: &[u8],
170 keystore: &impl FetchFromDatabase,
171 identity: Option<V::IdentityType>,
172 ) -> crate::Result<Option<V>> {
173 V::fetch_from_id(k, identity, keystore).await
174 }
175
176 #[cfg(test)]
177 pub(crate) async fn get_fetch_all(&mut self, keystore: &impl FetchFromDatabase) -> Result<Vec<GroupStoreValue<V>>> {
178 let all = V::fetch_all(keystore)
179 .await?
180 .into_iter()
181 .map(|g| {
182 let id = g.id().to_vec();
183 let to_insert = Arc::new(async_lock::RwLock::new(g));
184 self.insert_prepped(id, to_insert.clone());
185 to_insert
186 })
187 .collect::<Vec<_>>();
188 Ok(all)
189 }
190
191 fn insert_prepped(&mut self, k: Vec<u8>, prepped_entity: GroupStoreValue<V>) {
192 self.0.insert(k, prepped_entity);
193 }
194
195 pub(crate) fn insert(&mut self, k: Vec<u8>, entity: V) {
196 let value_to_insert = Arc::new(async_lock::RwLock::new(entity));
197 self.insert_prepped(k, value_to_insert)
198 }
199
200 pub(crate) fn try_insert(&mut self, k: Vec<u8>, entity: V) -> Result<(), V> {
201 let value_to_insert = Arc::new(async_lock::RwLock::new(entity));
202
203 if self.0.insert(k, value_to_insert.clone()) {
204 Ok(())
205 } else {
206 Err(Arc::into_inner(value_to_insert).unwrap().into_inner())
208 }
209 }
210
211 pub(crate) fn remove(&mut self, k: &[u8]) -> Option<GroupStoreValue<V>> {
212 self.0.remove(k)
213 }
214
215 pub(crate) fn get(&mut self, k: &[u8]) -> Option<&mut GroupStoreValue<V>> {
216 self.0.get(k)
217 }
218}
219
220pub(crate) struct HybridMemoryLimiter {
221 mem: schnellru::ByMemoryUsage,
222 len: schnellru::ByLength,
223}
224
/// Default memory cap, in bytes (100 MB), used by `HybridMemoryLimiter::new` when no
/// explicit memory limit is given.
pub(crate) const MEMORY_LIMIT: usize = 100 * 1_000_000;
/// Default maximum number of entries, used by `HybridMemoryLimiter::new` when no explicit
/// count limit is given.
pub(crate) const ITEM_LIMIT: u32 = 100;
227
228impl HybridMemoryLimiter {
229 pub(crate) fn new(count: Option<u32>, memory: Option<usize>) -> Self {
230 let memory_limit = memory.unwrap_or(MEMORY_LIMIT);
231 let mem = schnellru::ByMemoryUsage::new(memory_limit);
232 let len = schnellru::ByLength::new(count.unwrap_or(ITEM_LIMIT));
233
234 Self { mem, len }
235 }
236}
237
238impl Default for HybridMemoryLimiter {
239 fn default() -> Self {
240 Self::new(None, None)
241 }
242}
243
244impl<K, V> schnellru::Limiter<K, V> for HybridMemoryLimiter {
245 type KeyToInsert<'a> = K;
246 type LinkType = u32;
247
248 fn is_over_the_limit(&self, length: usize) -> bool {
249 <schnellru::ByLength as schnellru::Limiter<K, V>>::is_over_the_limit(&self.len, length)
250 }
251
252 fn on_insert(&mut self, length: usize, key: Self::KeyToInsert<'_>, value: V) -> Option<(K, V)> {
253 <schnellru::ByLength as schnellru::Limiter<K, V>>::on_insert(&mut self.len, length, key, value)
254 }
255
256 fn on_replace(
258 &mut self,
259 _length: usize,
260 _old_key: &mut K,
261 _new_key: Self::KeyToInsert<'_>,
262 _old_value: &mut V,
263 _new_value: &mut V,
264 ) -> bool {
265 true
266 }
267 fn on_removed(&mut self, _key: &mut K, _value: &mut V) {}
268 fn on_cleared(&mut self) {}
269
270 fn on_grow(&mut self, new_memory_usage: usize) -> bool {
271 <schnellru::ByMemoryUsage as schnellru::Limiter<K, V>>::on_grow(&mut self.mem, new_memory_usage)
272 }
273}
274
/// Unit tests for `GroupStore`, using the keystore's `DummyValue` test entity.
#[cfg(test)]
mod tests {
    use core_crypto_keystore::dummy_entity::{DummyStoreValue, DummyValue};
    use wasm_bindgen_test::*;

    use super::*;

    wasm_bindgen_test_configure!(run_in_browser);

    // Minimal entity for the tests: "fetching" just converts the UTF-8 id itself into a
    // `DummyValue` without consulting the keystore. `id` and `fetch_all` are never
    // exercised by these tests.
    #[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
    #[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
    impl GroupStoreEntity for DummyValue {
        type RawStoreValue = DummyStoreValue;

        type IdentityType = ();

        fn id(&self) -> &[u8] {
            unreachable!()
        }

        async fn fetch_from_id(
            id: &[u8],
            _identity: Option<Self::IdentityType>,
            _keystore: &impl FetchFromDatabase,
        ) -> crate::Result<Option<Self>> {
            let id = std::str::from_utf8(id).expect("dummy value ids are strings");
            Ok(Some(id.into()))
        }

        // NOTE(review): this `#[cfg(test)]` is redundant inside a `#[cfg(test)]` module; it
        // mirrors the attribute on the trait declaration.
        #[cfg(test)]
        async fn fetch_all(_keystore: &impl FetchFromDatabase) -> Result<Vec<Self>> {
            unreachable!()
        }
    }

    type TestGroupStore = GroupStore<DummyValue>;

    // Every constructor, including degenerate zero limits, yields an empty store.
    #[async_std::test]
    #[wasm_bindgen_test]
    async fn group_store_init() {
        let store = TestGroupStore::new_with_limit(1);
        assert_eq!(store.len(), 0);
        let store = TestGroupStore::new_with_limit(0);
        assert_eq!(store.len(), 0);
        let store = TestGroupStore::new(Some(0), Some(0));
        assert_eq!(store.len(), 0);
        let store = TestGroupStore::new(Some(0), Some(1));
        assert_eq!(store.len(), 0);
        let store = TestGroupStore::new(Some(1), Some(0));
        assert_eq!(store.len(), 0);
        let store = TestGroupStore::new(Some(1), Some(1));
        assert_eq!(store.len(), 0);
    }

    // With effectively unlimited caps, both `try_insert` and `insert` keep every entry.
    #[async_std::test]
    #[wasm_bindgen_test]
    async fn group_store_common_ops() {
        let mut store = TestGroupStore::new(Some(u32::MAX), Some(usize::MAX));
        for i in 1..=3 {
            let i_str = i.to_string();
            assert!(
                store
                    .try_insert(i_str.as_bytes().to_vec(), i_str.as_str().into())
                    .is_ok()
            );
            assert_eq!(store.len(), i);
        }
        for i in 4..=6 {
            let i_str = i.to_string();
            store.insert(i_str.as_bytes().to_vec(), i_str.as_str().into());
            assert_eq!(store.len(), i);
        }

        for i in 1..=6 {
            assert!(store.contains_key(i.to_string().as_bytes()));
        }
    }

    // A length cap of 2 evicts the least recently used entry ("1") when a third is
    // inserted, and the insertion itself still reports success.
    #[async_std::test]
    #[wasm_bindgen_test]
    async fn group_store_operations_len_limiter() {
        let mut store = TestGroupStore::new_with_limit(2);
        assert!(store.try_insert(b"1".to_vec(), "1".into()).is_ok());
        assert_eq!(store.len(), 1);
        assert!(store.try_insert(b"2".to_vec(), "2".into()).is_ok());
        assert_eq!(store.len(), 2);
        assert!(store.try_insert(b"3".to_vec(), "3".into()).is_ok());
        assert_eq!(store.len(), 2);
        assert!(!store.contains_key(b"1"));
        assert!(store.contains_key(b"2"));
        assert!(store.contains_key(b"3"));
        store.insert(b"4".to_vec(), "4".into());
        assert_eq!(store.len(), 2);
    }

    // Memory-limit behaviour. First, an unbounded schnellru map measures the map's own
    // memory usage after 1 and after 4 entries. Then a store capped at the 4-entry figure
    // must stay within those measurements while entries are added and LRU-reordered.
    // NOTE(review): the exact `guaranteed_capacity` values (0, 3, 7) reflect schnellru's
    // internal growth steps; confirm them when upgrading schnellru.
    #[async_std::test]
    #[wasm_bindgen_test]
    async fn group_store_operations_mem_limiter() {
        use schnellru::{LruMap, UnlimitedCompact};
        let mut lru: LruMap<Vec<u8>, DummyValue, UnlimitedCompact> =
            LruMap::<Vec<u8>, DummyValue, UnlimitedCompact>::new(UnlimitedCompact);
        assert_eq!(lru.guaranteed_capacity(), 0);
        assert_eq!(lru.memory_usage(), 0);
        lru.insert(1usize.to_le_bytes().to_vec(), "10".into());
        let memory_usage_step_1 = lru.memory_usage();
        lru.insert(2usize.to_le_bytes().to_vec(), "20".into());
        lru.insert(3usize.to_le_bytes().to_vec(), "30".into());
        lru.insert(4usize.to_le_bytes().to_vec(), "40".into());
        let memory_usage_step_2 = lru.memory_usage();
        assert_ne!(memory_usage_step_1, memory_usage_step_2);

        // Only the memory cap is set; the entry count uses the default limit.
        let mut store = TestGroupStore::new(None, Some(memory_usage_step_2));
        assert_eq!(store.guaranteed_capacity(), 0);
        assert_eq!(store.memory_usage(), 0);
        store.try_insert(1usize.to_le_bytes().to_vec(), "10".into()).unwrap();
        assert_eq!(store.guaranteed_capacity(), 3);
        assert!(store.memory_usage() <= memory_usage_step_1);
        store.try_insert(2usize.to_le_bytes().to_vec(), "20".into()).unwrap();
        store.try_insert(3usize.to_le_bytes().to_vec(), "30".into()).unwrap();
        for i in 1..=3usize {
            assert_eq!(
                *(store.get(i.to_le_bytes().as_ref()).unwrap().read().await),
                DummyValue::from(format!("{}", i * 10).as_str())
            );
        }
        assert_eq!(store.guaranteed_capacity(), 3);
        assert!(store.memory_usage() <= memory_usage_step_1);
        assert!(store.try_insert(4usize.to_le_bytes().to_vec(), "40".into()).is_ok());
        for i in (1usize..=4).rev() {
            assert_eq!(
                *(store.get(i.to_le_bytes().as_ref()).unwrap().read().await),
                DummyValue::from(format!("{}", i * 10).as_str())
            );
        }
        assert_eq!(store.guaranteed_capacity(), 7);
        assert!(store.memory_usage() <= memory_usage_step_2);
        store.try_insert(5usize.to_le_bytes().to_vec(), "50".into()).unwrap();
        store.try_insert(6usize.to_le_bytes().to_vec(), "60".into()).unwrap();
        store.try_insert(7usize.to_le_bytes().to_vec(), "70".into()).unwrap();
        // Touch 7, 6 and 5 (in that order) so they are the most recently used entries
        // before the next insertion.
        for i in (5usize..=7).rev() {
            store.get(i.to_le_bytes().as_ref()).unwrap();
        }

        // Inserting an eighth entry must keep the four most recently used ones (8, 7, 6, 5).
        store.insert(8usize.to_le_bytes().to_vec(), "80".into());
        for i in [8usize, 7, 6, 5].iter() {
            assert_eq!(
                *(store
                    .get(i.to_le_bytes().as_ref())
                    .unwrap_or_else(|| panic!("couldn't find index {i}"))
                    .read()
                    .await),
                DummyValue::from(format!("{}", i * 10).as_str())
            );
        }

        assert_eq!(store.guaranteed_capacity(), 7);
        assert!(store.memory_usage() <= memory_usage_step_2);
        store.assert_check_internal_state();
    }
}