core_crypto/
group_store.rs

1use std::sync::Arc;
2
3use crate::{KeystoreError, RecursiveError, Result, prelude::MlsConversation};
4use core_crypto_keystore::connection::FetchFromDatabase;
5#[cfg(test)]
6use core_crypto_keystore::entities::EntityFindParams;
7
/// A value that can be cached in a [`GroupStore`] and loaded on demand from the keystore.
///
/// The `Debug` supertrait is what lets `GroupStore`'s own `Debug` impl print cached values.
#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
pub(crate) trait GroupStoreEntity: std::fmt::Debug {
    /// Keystore entity type this value is persisted as.
    type RawStoreValue: core_crypto_keystore::entities::Entity;
    /// Extra identity data that `fetch_from_id` may need; implementations that need none use `()`.
    type IdentityType;

    /// Key under which this value is stored (test-only).
    #[cfg(test)]
    fn id(&self) -> &[u8];

    /// Loads the value stored under `id` from `keystore`.
    ///
    /// `Ok(None)` means there is no value to cache for `id`.
    async fn fetch_from_id(
        id: &[u8],
        identity: Option<Self::IdentityType>,
        keystore: &impl FetchFromDatabase,
    ) -> Result<Option<Self>>
    where
        Self: Sized;

    /// Loads the stored values of this type from `keystore` (test-only).
    #[cfg(test)]
    async fn fetch_all(keystore: &impl FetchFromDatabase) -> Result<Vec<Self>>
    where
        Self: Sized;
}
30
31#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
32#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
33impl GroupStoreEntity for MlsConversation {
34    type RawStoreValue = core_crypto_keystore::entities::PersistedMlsGroup;
35    type IdentityType = ();
36
37    #[cfg(test)]
38    fn id(&self) -> &[u8] {
39        self.id().as_slice()
40    }
41
42    async fn fetch_from_id(
43        id: &[u8],
44        _: Option<Self::IdentityType>,
45        keystore: &impl FetchFromDatabase,
46    ) -> crate::Result<Option<Self>> {
47        let result = keystore
48            .find::<Self::RawStoreValue>(id)
49            .await
50            .map_err(KeystoreError::wrap("finding mls conversation from keystore by id"))?;
51        let Some(store_value) = result else {
52            return Ok(None);
53        };
54
55        let conversation = Self::from_serialized_state(store_value.state.clone(), store_value.parent_id.clone())
56            .map_err(RecursiveError::mls_conversation("deserializing mls conversation"))?;
57        // If the conversation is not active, pretend it doesn't exist
58        Ok(conversation.group.is_active().then_some(conversation))
59    }
60
61    #[cfg(test)]
62    async fn fetch_all(keystore: &impl FetchFromDatabase) -> Result<Vec<Self>> {
63        let all_conversations = keystore
64            .find_all::<Self::RawStoreValue>(EntityFindParams::default())
65            .await
66            .map_err(KeystoreError::wrap("finding all mls conversations"))?;
67        Ok(all_conversations
68            .iter()
69            .filter_map(|c| {
70                let conversation = Self::from_serialized_state(c.state.clone(), c.parent_id.clone()).unwrap();
71                conversation.group.is_active().then_some(conversation)
72            })
73            .collect::<Vec<_>>())
74    }
75}
76
/// Shared, lockable handle to a cached value; cloning the handle shares the same underlying value.
pub(crate) type GroupStoreValue<V> = Arc<async_lock::RwLock<V>>;

/// Underlying LRU map: raw byte keys to shared values, bounded by a [`HybridMemoryLimiter`].
pub(crate) type LruMap<V> = schnellru::LruMap<Vec<u8>, GroupStoreValue<V>, HybridMemoryLimiter>;

/// LRU-cache based group/session store.
///
/// Uses a hybrid limiter that bounds both the number of entries and the map's total memory usage.
/// When a bound is hit, the least recently used entries are evicted first.
pub(crate) struct GroupStore<V: GroupStoreEntity>(LruMap<V>);
85
86impl<V: GroupStoreEntity> std::fmt::Debug for GroupStore<V> {
87    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88        f.debug_struct("GroupStore")
89            .field("length", &self.0.len())
90            .field("memory_usage", &self.0.memory_usage())
91            .field(
92                "entries",
93                &self
94                    .0
95                    .iter()
96                    .map(|(k, v)| format!("{k:?}={v:?}"))
97                    .collect::<Vec<String>>()
98                    .join("\n"),
99            )
100            .finish()
101    }
102}
103
104impl<V: GroupStoreEntity> Default for GroupStore<V> {
105    fn default() -> Self {
106        Self(schnellru::LruMap::default())
107    }
108}
109
/// Test-only read access to the underlying [`LruMap`] (e.g. `len`, `memory_usage`).
#[cfg(test)]
impl<V: GroupStoreEntity> std::ops::Deref for GroupStore<V> {
    type Target = LruMap<V>;

    fn deref(&self) -> &Self::Target {
        let Self(inner) = self;
        inner
    }
}
118
/// Test-only mutable access to the underlying [`LruMap`].
#[cfg(test)]
impl<V: GroupStoreEntity> std::ops::DerefMut for GroupStore<V> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self(inner) = self;
        inner
    }
}
125
126impl<V: GroupStoreEntity> GroupStore<V> {
127    pub(crate) fn new_with_limit(len: u32) -> Self {
128        let limiter = HybridMemoryLimiter::new(Some(len), None);
129        let store = schnellru::LruMap::new(limiter);
130        Self(store)
131    }
132
133    #[cfg(test)]
134    fn new(count: Option<u32>, memory: Option<usize>) -> Self {
135        let limiter = HybridMemoryLimiter::new(count, memory);
136        let store = schnellru::LruMap::new(limiter);
137        Self(store)
138    }
139
140    #[cfg(test)]
141    fn contains_key(&self, k: &[u8]) -> bool {
142        self.0.peek(k).is_some()
143    }
144
145    pub(crate) async fn get_fetch(
146        &mut self,
147        k: &[u8],
148        keystore: &impl FetchFromDatabase,
149        identity: Option<V::IdentityType>,
150    ) -> crate::Result<Option<GroupStoreValue<V>>> {
151        // Optimistic cache lookup
152        if let Some(value) = self.0.get(k) {
153            return Ok(Some(value.clone()));
154        }
155
156        // Not in store, fetch the thing in the keystore
157        let inserted_value = V::fetch_from_id(k, identity, keystore).await?.map(|value| {
158            let value_to_insert = Arc::new(async_lock::RwLock::new(value));
159            self.insert_prepped(k.to_vec(), value_to_insert.clone());
160            value_to_insert
161        });
162        Ok(inserted_value)
163    }
164
165    /// Returns the value from the keystore.
166    /// WARNING: the returned value is not attached to the keystore and mutations on it will be
167    /// lost when the object is dropped
168    pub(crate) async fn fetch_from_keystore(
169        k: &[u8],
170        keystore: &impl FetchFromDatabase,
171        identity: Option<V::IdentityType>,
172    ) -> crate::Result<Option<V>> {
173        V::fetch_from_id(k, identity, keystore).await
174    }
175
176    #[cfg(test)]
177    pub(crate) async fn get_fetch_all(&mut self, keystore: &impl FetchFromDatabase) -> Result<Vec<GroupStoreValue<V>>> {
178        let all = V::fetch_all(keystore)
179            .await?
180            .into_iter()
181            .map(|g| {
182                let id = g.id().to_vec();
183                let to_insert = Arc::new(async_lock::RwLock::new(g));
184                self.insert_prepped(id, to_insert.clone());
185                to_insert
186            })
187            .collect::<Vec<_>>();
188        Ok(all)
189    }
190
191    fn insert_prepped(&mut self, k: Vec<u8>, prepped_entity: GroupStoreValue<V>) {
192        self.0.insert(k, prepped_entity);
193    }
194
195    pub(crate) fn insert(&mut self, k: Vec<u8>, entity: V) {
196        let value_to_insert = Arc::new(async_lock::RwLock::new(entity));
197        self.insert_prepped(k, value_to_insert)
198    }
199
200    pub(crate) fn try_insert(&mut self, k: Vec<u8>, entity: V) -> Result<(), V> {
201        let value_to_insert = Arc::new(async_lock::RwLock::new(entity));
202
203        if self.0.insert(k, value_to_insert.clone()) {
204            Ok(())
205        } else {
206            // This is safe because we just built the value
207            Err(Arc::into_inner(value_to_insert).unwrap().into_inner())
208        }
209    }
210
211    pub(crate) fn remove(&mut self, k: &[u8]) -> Option<GroupStoreValue<V>> {
212        self.0.remove(k)
213    }
214
215    pub(crate) fn get(&mut self, k: &[u8]) -> Option<&mut GroupStoreValue<V>> {
216        self.0.get(k)
217    }
218}
219
/// LRU limiter combining two `schnellru` limiters.
///
/// `ByLength` caps the number of entries and `ByMemoryUsage` caps the map's memory usage.
pub(crate) struct HybridMemoryLimiter {
    /// Memory-usage bound.
    mem: schnellru::ByMemoryUsage,
    /// Entry-count bound.
    len: schnellru::ByLength,
}

/// Default memory bound, used when `HybridMemoryLimiter::new` gets `None`.
///
/// It is expressed in the same units as `LruMap::memory_usage()`.
pub(crate) const MEMORY_LIMIT: usize = 100_000_000;
/// Default entry-count bound, used when `HybridMemoryLimiter::new` gets `None`.
pub(crate) const ITEM_LIMIT: u32 = 100;
227
228impl HybridMemoryLimiter {
229    pub(crate) fn new(count: Option<u32>, memory: Option<usize>) -> Self {
230        let memory_limit = memory.unwrap_or(MEMORY_LIMIT);
231        let mem = schnellru::ByMemoryUsage::new(memory_limit);
232        let len = schnellru::ByLength::new(count.unwrap_or(ITEM_LIMIT));
233
234        Self { mem, len }
235    }
236}
237
238impl Default for HybridMemoryLimiter {
239    fn default() -> Self {
240        Self::new(None, None)
241    }
242}
243
244impl<K, V> schnellru::Limiter<K, V> for HybridMemoryLimiter {
245    type KeyToInsert<'a> = K;
246    type LinkType = u32;
247
248    fn is_over_the_limit(&self, length: usize) -> bool {
249        <schnellru::ByLength as schnellru::Limiter<K, V>>::is_over_the_limit(&self.len, length)
250    }
251
252    fn on_insert(&mut self, length: usize, key: Self::KeyToInsert<'_>, value: V) -> Option<(K, V)> {
253        <schnellru::ByLength as schnellru::Limiter<K, V>>::on_insert(&mut self.len, length, key, value)
254    }
255
256    // Both underlying limiters have dummy implementations here
257    fn on_replace(
258        &mut self,
259        _length: usize,
260        _old_key: &mut K,
261        _new_key: Self::KeyToInsert<'_>,
262        _old_value: &mut V,
263        _new_value: &mut V,
264    ) -> bool {
265        true
266    }
267    fn on_removed(&mut self, _key: &mut K, _value: &mut V) {}
268    fn on_cleared(&mut self) {}
269
270    fn on_grow(&mut self, new_memory_usage: usize) -> bool {
271        <schnellru::ByMemoryUsage as schnellru::Limiter<K, V>>::on_grow(&mut self.mem, new_memory_usage)
272    }
273}
274
#[cfg(test)]
mod tests {
    use core_crypto_keystore::dummy_entity::{DummyStoreValue, DummyValue};
    use wasm_bindgen_test::*;

    use super::*;

    wasm_bindgen_test_configure!(run_in_browser);

    // Minimal entity for exercising the store. `fetch_from_id` builds a `DummyValue` straight from
    // the (UTF-8) id and never reads the keystore.
    #[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
    #[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
    impl GroupStoreEntity for DummyValue {
        type RawStoreValue = DummyStoreValue;

        type IdentityType = ();

        fn id(&self) -> &[u8] {
            unreachable!()
        }

        async fn fetch_from_id(
            id: &[u8],
            _identity: Option<Self::IdentityType>,
            _keystore: &impl FetchFromDatabase,
        ) -> crate::Result<Option<Self>> {
            // it's not worth adding a variant to the Error type here to handle test dummy values
            let id = std::str::from_utf8(id).expect("dummy value ids are strings");
            Ok(Some(id.into()))
        }

        #[cfg(test)]
        async fn fetch_all(_keystore: &impl FetchFromDatabase) -> Result<Vec<Self>> {
            unreachable!()
        }
    }

    type TestGroupStore = GroupStore<DummyValue>;

    // Any combination of limits, including zero, yields an empty store.
    #[async_std::test]
    #[wasm_bindgen_test]
    async fn group_store_init() {
        let store = TestGroupStore::new_with_limit(1);
        assert_eq!(store.len(), 0);
        let store = TestGroupStore::new_with_limit(0);
        assert_eq!(store.len(), 0);
        let store = TestGroupStore::new(Some(0), Some(0));
        assert_eq!(store.len(), 0);
        let store = TestGroupStore::new(Some(0), Some(1));
        assert_eq!(store.len(), 0);
        let store = TestGroupStore::new(Some(1), Some(0));
        assert_eq!(store.len(), 0);
        let store = TestGroupStore::new(Some(1), Some(1));
        assert_eq!(store.len(), 0);
    }

    // With effectively unlimited bounds, both insert paths keep every entry.
    #[async_std::test]
    #[wasm_bindgen_test]
    async fn group_store_common_ops() {
        let mut store = TestGroupStore::new(Some(u32::MAX), Some(usize::MAX));
        for i in 1..=3 {
            let i_str = i.to_string();
            assert!(
                store
                    .try_insert(i_str.as_bytes().to_vec(), i_str.as_str().into())
                    .is_ok()
            );
            assert_eq!(store.len(), i);
        }
        for i in 4..=6 {
            let i_str = i.to_string();
            store.insert(i_str.as_bytes().to_vec(), i_str.as_str().into());
            assert_eq!(store.len(), i);
        }

        for i in 1..=6 {
            assert!(store.contains_key(i.to_string().as_bytes()));
        }
    }

    // With a length bound of 2, each insertion past capacity evicts the least recently used key.
    #[async_std::test]
    #[wasm_bindgen_test]
    async fn group_store_operations_len_limiter() {
        let mut store = TestGroupStore::new_with_limit(2);
        assert!(store.try_insert(b"1".to_vec(), "1".into()).is_ok());
        assert_eq!(store.len(), 1);
        assert!(store.try_insert(b"2".to_vec(), "2".into()).is_ok());
        assert_eq!(store.len(), 2);
        assert!(store.try_insert(b"3".to_vec(), "3".into()).is_ok());
        assert_eq!(store.len(), 2);
        assert!(!store.contains_key(b"1"));
        assert!(store.contains_key(b"2"));
        assert!(store.contains_key(b"3"));
        // `contains_key` only peeks, so "2" is still the least recently used entry and goes next.
        store.insert(b"4".to_vec(), "4".into());
        assert_eq!(store.len(), 2);
        assert!(!store.contains_key(b"2"));
        assert!(store.contains_key(b"3"));
        assert!(store.contains_key(b"4"));
    }

    // The memory bound is calibrated from an unlimited map holding four entries.
    #[async_std::test]
    #[wasm_bindgen_test]
    async fn group_store_operations_mem_limiter() {
        use schnellru::{LruMap, UnlimitedCompact};
        let mut lru: LruMap<Vec<u8>, DummyValue, UnlimitedCompact> =
            LruMap::<Vec<u8>, DummyValue, UnlimitedCompact>::new(UnlimitedCompact);
        assert_eq!(lru.guaranteed_capacity(), 0);
        assert_eq!(lru.memory_usage(), 0);
        lru.insert(1usize.to_le_bytes().to_vec(), "10".into());
        let memory_usage_step_1 = lru.memory_usage();
        lru.insert(2usize.to_le_bytes().to_vec(), "20".into());
        lru.insert(3usize.to_le_bytes().to_vec(), "30".into());
        lru.insert(4usize.to_le_bytes().to_vec(), "40".into());
        let memory_usage_step_2 = lru.memory_usage();
        assert_ne!(memory_usage_step_1, memory_usage_step_2);

        let mut store = TestGroupStore::new(None, Some(memory_usage_step_2));
        assert_eq!(store.guaranteed_capacity(), 0);
        assert_eq!(store.memory_usage(), 0);
        store.try_insert(1usize.to_le_bytes().to_vec(), "10".into()).unwrap();
        assert_eq!(store.guaranteed_capacity(), 3);
        assert!(store.memory_usage() <= memory_usage_step_1);
        store.try_insert(2usize.to_le_bytes().to_vec(), "20".into()).unwrap();
        store.try_insert(3usize.to_le_bytes().to_vec(), "30".into()).unwrap();
        for i in 1..=3usize {
            assert_eq!(
                *(store.get(i.to_le_bytes().as_ref()).unwrap().read().await),
                DummyValue::from(format!("{}", i * 10).as_str())
            );
        }
        assert_eq!(store.guaranteed_capacity(), 3);
        assert!(store.memory_usage() <= memory_usage_step_1);
        assert!(store.try_insert(4usize.to_le_bytes().to_vec(), "40".into()).is_ok());
        for i in (1usize..=4).rev() {
            assert_eq!(
                *(store.get(i.to_le_bytes().as_ref()).unwrap().read().await),
                DummyValue::from(format!("{}", i * 10).as_str())
            );
        }
        assert_eq!(store.guaranteed_capacity(), 7);
        assert!(store.memory_usage() <= memory_usage_step_2);
        store.try_insert(5usize.to_le_bytes().to_vec(), "50".into()).unwrap();
        store.try_insert(6usize.to_le_bytes().to_vec(), "60".into()).unwrap();
        store.try_insert(7usize.to_le_bytes().to_vec(), "70".into()).unwrap();
        // Touch 5..=7 via `get` so they are more recently used than 1..=4 before inserting 8.
        for i in (5usize..=7).rev() {
            store.get(i.to_le_bytes().as_ref()).unwrap();
        }

        store.insert(8usize.to_le_bytes().to_vec(), "80".into());
        for i in [8usize, 7, 6, 5].iter() {
            assert_eq!(
                *(store
                    .get(i.to_le_bytes().as_ref())
                    .unwrap_or_else(|| panic!("couldn't find index {i}"))
                    .read()
                    .await),
                DummyValue::from(format!("{}", i * 10).as_str())
            );
        }

        assert_eq!(store.guaranteed_capacity(), 7);
        assert!(store.memory_usage() <= memory_usage_step_2);
        store.assert_check_internal_state();
    }
}