// core_crypto/group_store.rs

1use std::sync::Arc;
2
3use crate::{KeystoreError, RecursiveError, Result, prelude::MlsConversation};
4use core_crypto_keystore::connection::FetchFromDatabase;
5
6#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
7#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
8pub(crate) trait GroupStoreEntity: std::fmt::Debug {
9    type RawStoreValue: core_crypto_keystore::entities::Entity;
10    type IdentityType;
11
12    async fn fetch_from_id(
13        id: &[u8],
14        identity: Option<Self::IdentityType>,
15        keystore: &impl FetchFromDatabase,
16    ) -> Result<Option<Self>>
17    where
18        Self: Sized;
19}
20
21#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
22#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
23impl GroupStoreEntity for MlsConversation {
24    type RawStoreValue = core_crypto_keystore::entities::PersistedMlsGroup;
25    type IdentityType = ();
26
27    async fn fetch_from_id(
28        id: &[u8],
29        _: Option<Self::IdentityType>,
30        keystore: &impl FetchFromDatabase,
31    ) -> crate::Result<Option<Self>> {
32        let result = keystore
33            .find::<Self::RawStoreValue>(id)
34            .await
35            .map_err(KeystoreError::wrap("finding mls conversation from keystore by id"))?;
36        let Some(store_value) = result else {
37            return Ok(None);
38        };
39
40        let conversation = Self::from_serialized_state(store_value.state.clone(), store_value.parent_id.clone())
41            .map_err(RecursiveError::mls_conversation("deserializing mls conversation"))?;
42        // If the conversation is not active, pretend it doesn't exist
43        Ok(conversation.group.is_active().then_some(conversation))
44    }
45}
46
47pub(crate) type GroupStoreValue<V> = Arc<async_lock::RwLock<V>>;
48
49pub(crate) type LruMap<V> = schnellru::LruMap<Vec<u8>, GroupStoreValue<V>, HybridMemoryLimiter>;
50
51/// LRU-cache based group/session store
52/// Uses a hybrid memory limiter based on both amount of elements and total memory usage
53/// As with all LRU caches, eviction is based on oldest elements
54pub(crate) struct GroupStore<V: GroupStoreEntity>(LruMap<V>);
55
56impl<V: GroupStoreEntity> std::fmt::Debug for GroupStore<V> {
57    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58        f.debug_struct("GroupStore")
59            .field("length", &self.0.len())
60            .field("memory_usage", &self.0.memory_usage())
61            .field(
62                "entries",
63                &self
64                    .0
65                    .iter()
66                    .map(|(k, v)| format!("{k:?}={v:?}"))
67                    .collect::<Vec<String>>()
68                    .join("\n"),
69            )
70            .finish()
71    }
72}
73
74impl<V: GroupStoreEntity> Default for GroupStore<V> {
75    fn default() -> Self {
76        Self(schnellru::LruMap::default())
77    }
78}
79
#[cfg(test)]
impl<V: GroupStoreEntity> std::ops::Deref for GroupStore<V> {
    type Target = LruMap<V>;

    /// Test-only read access to the underlying LRU map.
    fn deref(&self) -> &Self::Target {
        let Self(inner) = self;
        inner
    }
}
88
#[cfg(test)]
impl<V: GroupStoreEntity> std::ops::DerefMut for GroupStore<V> {
    /// Test-only mutable access to the underlying LRU map.
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self(inner) = self;
        inner
    }
}
95
96impl<V: GroupStoreEntity> GroupStore<V> {
97    pub(crate) fn new_with_limit(len: u32) -> Self {
98        let limiter = HybridMemoryLimiter::new(Some(len), None);
99        let store = schnellru::LruMap::new(limiter);
100        Self(store)
101    }
102
103    #[cfg(test)]
104    fn new(count: Option<u32>, memory: Option<usize>) -> Self {
105        let limiter = HybridMemoryLimiter::new(count, memory);
106        let store = schnellru::LruMap::new(limiter);
107        Self(store)
108    }
109
110    #[cfg(test)]
111    fn contains_key(&self, k: &[u8]) -> bool {
112        self.0.peek(k).is_some()
113    }
114
115    pub(crate) async fn get_fetch(
116        &mut self,
117        k: &[u8],
118        keystore: &impl FetchFromDatabase,
119        identity: Option<V::IdentityType>,
120    ) -> crate::Result<Option<GroupStoreValue<V>>> {
121        // Optimistic cache lookup
122        if let Some(value) = self.0.get(k) {
123            return Ok(Some(value.clone()));
124        }
125
126        // Not in store, fetch the thing in the keystore
127        let inserted_value = V::fetch_from_id(k, identity, keystore).await?.map(|value| {
128            let value_to_insert = Arc::new(async_lock::RwLock::new(value));
129            self.insert_prepped(k.to_vec(), value_to_insert.clone());
130            value_to_insert
131        });
132        Ok(inserted_value)
133    }
134
135    /// Returns the value from the keystore.
136    /// WARNING: the returned value is not attached to the keystore and mutations on it will be
137    /// lost when the object is dropped
138    pub(crate) async fn fetch_from_keystore(
139        k: &[u8],
140        keystore: &impl FetchFromDatabase,
141        identity: Option<V::IdentityType>,
142    ) -> crate::Result<Option<V>> {
143        V::fetch_from_id(k, identity, keystore).await
144    }
145
146    fn insert_prepped(&mut self, k: Vec<u8>, prepped_entity: GroupStoreValue<V>) {
147        self.0.insert(k, prepped_entity);
148    }
149
150    pub(crate) fn insert(&mut self, k: Vec<u8>, entity: V) {
151        let value_to_insert = Arc::new(async_lock::RwLock::new(entity));
152        self.insert_prepped(k, value_to_insert)
153    }
154
155    pub(crate) fn try_insert(&mut self, k: Vec<u8>, entity: V) -> Result<(), V> {
156        let value_to_insert = Arc::new(async_lock::RwLock::new(entity));
157
158        if self.0.insert(k, value_to_insert.clone()) {
159            Ok(())
160        } else {
161            // This is safe because we just built the value
162            Err(Arc::into_inner(value_to_insert).unwrap().into_inner())
163        }
164    }
165
166    pub(crate) fn remove(&mut self, k: &[u8]) -> Option<GroupStoreValue<V>> {
167        self.0.remove(k)
168    }
169
170    pub(crate) fn get(&mut self, k: &[u8]) -> Option<&mut GroupStoreValue<V>> {
171        self.0.get(k)
172    }
173}
174
175pub(crate) struct HybridMemoryLimiter {
176    mem: schnellru::ByMemoryUsage,
177    len: schnellru::ByLength,
178}
179
/// Default memory budget handed to `schnellru::ByMemoryUsage` (in bytes): 100 MB.
pub(crate) const MEMORY_LIMIT: usize = 100 * 1_000_000;
/// Default maximum number of cached entries.
pub(crate) const ITEM_LIMIT: u32 = 100;
182
183impl HybridMemoryLimiter {
184    pub(crate) fn new(count: Option<u32>, memory: Option<usize>) -> Self {
185        let memory_limit = memory.unwrap_or(MEMORY_LIMIT);
186        let mem = schnellru::ByMemoryUsage::new(memory_limit);
187        let len = schnellru::ByLength::new(count.unwrap_or(ITEM_LIMIT));
188
189        Self { mem, len }
190    }
191}
192
193impl Default for HybridMemoryLimiter {
194    fn default() -> Self {
195        Self::new(None, None)
196    }
197}
198
199impl<K, V> schnellru::Limiter<K, V> for HybridMemoryLimiter {
200    type KeyToInsert<'a> = K;
201    type LinkType = u32;
202
203    fn is_over_the_limit(&self, length: usize) -> bool {
204        <schnellru::ByLength as schnellru::Limiter<K, V>>::is_over_the_limit(&self.len, length)
205    }
206
207    fn on_insert(&mut self, length: usize, key: Self::KeyToInsert<'_>, value: V) -> Option<(K, V)> {
208        <schnellru::ByLength as schnellru::Limiter<K, V>>::on_insert(&mut self.len, length, key, value)
209    }
210
211    // Both underlying limiters have dummy implementations here
212    fn on_replace(
213        &mut self,
214        _length: usize,
215        _old_key: &mut K,
216        _new_key: Self::KeyToInsert<'_>,
217        _old_value: &mut V,
218        _new_value: &mut V,
219    ) -> bool {
220        true
221    }
222    fn on_removed(&mut self, _key: &mut K, _value: &mut V) {}
223    fn on_cleared(&mut self) {}
224
225    fn on_grow(&mut self, new_memory_usage: usize) -> bool {
226        <schnellru::ByMemoryUsage as schnellru::Limiter<K, V>>::on_grow(&mut self.mem, new_memory_usage)
227    }
228}
229
#[cfg(test)]
mod tests {
    use core_crypto_keystore::dummy_entity::{DummyStoreValue, DummyValue};
    use wasm_bindgen_test::*;

    use super::*;

    wasm_bindgen_test_configure!(run_in_browser);

    // Test-only entity. "Fetching" never reads the keystore: the id bytes are read as
    // UTF-8 and converted directly into a `DummyValue`.
    #[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
    #[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
    impl GroupStoreEntity for DummyValue {
        type RawStoreValue = DummyStoreValue;

        type IdentityType = ();

        async fn fetch_from_id(
            id: &[u8],
            _identity: Option<Self::IdentityType>,
            _keystore: &impl FetchFromDatabase,
        ) -> crate::Result<Option<Self>> {
            // it's not worth adding a variant to the Error type here to handle test dummy values
            let id = std::str::from_utf8(id).expect("dummy value ids are strings");
            Ok(Some(id.into()))
        }
    }

    type TestGroupStore = GroupStore<DummyValue>;

    // Every constructor / limit combination, including zero limits, starts out empty.
    #[async_std::test]
    #[wasm_bindgen_test]
    async fn group_store_init() {
        let store = TestGroupStore::new_with_limit(1);
        assert_eq!(store.len(), 0);
        let store = TestGroupStore::new_with_limit(0);
        assert_eq!(store.len(), 0);
        let store = TestGroupStore::new(Some(0), Some(0));
        assert_eq!(store.len(), 0);
        let store = TestGroupStore::new(Some(0), Some(1));
        assert_eq!(store.len(), 0);
        let store = TestGroupStore::new(Some(1), Some(0));
        assert_eq!(store.len(), 0);
        let store = TestGroupStore::new(Some(1), Some(1));
        assert_eq!(store.len(), 0);
    }

    // With maximal limits, both `try_insert` (1..=3) and `insert` (4..=6) keep every
    // entry, and all six keys remain present.
    #[async_std::test]
    #[wasm_bindgen_test]
    async fn group_store_common_ops() {
        let mut store = TestGroupStore::new(Some(u32::MAX), Some(usize::MAX));
        for i in 1..=3 {
            let i_str = i.to_string();
            assert!(
                store
                    .try_insert(i_str.as_bytes().to_vec(), i_str.as_str().into())
                    .is_ok()
            );
            assert_eq!(store.len(), i);
        }
        for i in 4..=6 {
            let i_str = i.to_string();
            store.insert(i_str.as_bytes().to_vec(), i_str.as_str().into());
            assert_eq!(store.len(), i);
        }

        for i in 1..=6 {
            assert!(store.contains_key(i.to_string().as_bytes()));
        }
    }

    // With a length limit of 2, inserting a third entry still succeeds but evicts the
    // least recently used one ("1"); the length never exceeds 2.
    #[async_std::test]
    #[wasm_bindgen_test]
    async fn group_store_operations_len_limiter() {
        let mut store = TestGroupStore::new_with_limit(2);
        assert!(store.try_insert(b"1".to_vec(), "1".into()).is_ok());
        assert_eq!(store.len(), 1);
        assert!(store.try_insert(b"2".to_vec(), "2".into()).is_ok());
        assert_eq!(store.len(), 2);
        assert!(store.try_insert(b"3".to_vec(), "3".into()).is_ok());
        assert_eq!(store.len(), 2);
        assert!(!store.contains_key(b"1"));
        assert!(store.contains_key(b"2"));
        assert!(store.contains_key(b"3"));
        store.insert(b"4".to_vec(), "4".into());
        assert_eq!(store.len(), 2);
    }

    // Memory-limited store. First, an unlimited schnellru map measures the footprint
    // after 1 entry (step 1) and after 4 entries (step 2). The step-2 footprint then becomes
    // the memory budget of a `TestGroupStore`, and the test checks capacity, retrieval and
    // LRU eviction under that budget.
    #[async_std::test]
    #[wasm_bindgen_test]
    async fn group_store_operations_mem_limiter() {
        use schnellru::{LruMap, UnlimitedCompact};
        let mut lru: LruMap<Vec<u8>, DummyValue, UnlimitedCompact> =
            LruMap::<Vec<u8>, DummyValue, UnlimitedCompact>::new(UnlimitedCompact);
        assert_eq!(lru.guaranteed_capacity(), 0);
        assert_eq!(lru.memory_usage(), 0);
        lru.insert(1usize.to_le_bytes().to_vec(), "10".into());
        let memory_usage_step_1 = lru.memory_usage();
        lru.insert(2usize.to_le_bytes().to_vec(), "20".into());
        lru.insert(3usize.to_le_bytes().to_vec(), "30".into());
        lru.insert(4usize.to_le_bytes().to_vec(), "40".into());
        let memory_usage_step_2 = lru.memory_usage();
        assert_ne!(memory_usage_step_1, memory_usage_step_2);

        // Store capped only by memory (count falls back to the default item limit).
        let mut store = TestGroupStore::new(None, Some(memory_usage_step_2));
        assert_eq!(store.guaranteed_capacity(), 0);
        assert_eq!(store.memory_usage(), 0);
        store.try_insert(1usize.to_le_bytes().to_vec(), "10".into()).unwrap();
        assert_eq!(store.guaranteed_capacity(), 3);
        assert!(store.memory_usage() <= memory_usage_step_1);
        store.try_insert(2usize.to_le_bytes().to_vec(), "20".into()).unwrap();
        store.try_insert(3usize.to_le_bytes().to_vec(), "30".into()).unwrap();
        for i in 1..=3usize {
            assert_eq!(
                *(store.get(i.to_le_bytes().as_ref()).unwrap().read().await),
                DummyValue::from(format!("{}", i * 10).as_str())
            );
        }
        assert_eq!(store.guaranteed_capacity(), 3);
        assert!(store.memory_usage() <= memory_usage_step_1);
        // A 4th entry exceeds the initial capacity of 3; per the assertions below the map
        // grows (capacity 7) instead of evicting, and all of 1..=4 stay retrievable.
        assert!(store.try_insert(4usize.to_le_bytes().to_vec(), "40".into()).is_ok());
        for i in (1usize..=4).rev() {
            assert_eq!(
                *(store.get(i.to_le_bytes().as_ref()).unwrap().read().await),
                DummyValue::from(format!("{}", i * 10).as_str())
            );
        }
        assert_eq!(store.guaranteed_capacity(), 7);
        assert!(store.memory_usage() <= memory_usage_step_2);
        store.try_insert(5usize.to_le_bytes().to_vec(), "50".into()).unwrap();
        store.try_insert(6usize.to_le_bytes().to_vec(), "60".into()).unwrap();
        store.try_insert(7usize.to_le_bytes().to_vec(), "70".into()).unwrap();
        // Touch 7, 6, 5 so they are more recently used than 1..=4.
        for i in (5usize..=7).rev() {
            store.get(i.to_le_bytes().as_ref()).unwrap();
        }

        // The map is full at capacity 7 and may not grow past the memory budget, so
        // inserting 8 must evict an older entry while 5..=8 all survive.
        store.insert(8usize.to_le_bytes().to_vec(), "80".into());
        for i in [8usize, 7, 6, 5].iter() {
            assert_eq!(
                *(store
                    .get(i.to_le_bytes().as_ref())
                    .unwrap_or_else(|| panic!("couldn't find index {i}"))
                    .read()
                    .await),
                DummyValue::from(format!("{}", i * 10).as_str())
            );
        }

        assert_eq!(store.guaranteed_capacity(), 7);
        assert!(store.memory_usage() <= memory_usage_step_2);
        // schnellru's self-check of the map's internal invariants (reached through Deref).
        store.assert_check_internal_state();
    }
}