core_crypto/
group_store.rs

1// Wire
2// Copyright (C) 2022 Wire Swiss GmbH
3
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8
9// This program is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU General Public License for more details.
13
14// You should have received a copy of the GNU General Public License
15// along with this program. If not, see http://www.gnu.org/licenses/.
16
17use std::sync::Arc;
18
19use crate::{KeystoreError, ProteusError, RecursiveError, Result, prelude::MlsConversation};
20use core_crypto_keystore::connection::FetchFromDatabase;
21#[cfg(test)]
22use core_crypto_keystore::entities::EntityFindParams;
23
/// An entity that a [`GroupStore`] can cache and load on demand from the keystore.
///
/// The trait is declared through `async_trait`; on `wasm` targets the `?Send` flavour is used.
#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
pub(crate) trait GroupStoreEntity: std::fmt::Debug {
    /// The keystore entity type the implementor is loaded from.
    type RawStoreValue: core_crypto_keystore::entities::Entity;
    /// Additional data handed to [`Self::fetch_from_id`]; implementors that need none use `()`.
    type IdentityType;

    /// Returns the identifier of this entity as raw bytes.
    ///
    /// Only available in test builds.
    #[cfg(test)]
    fn id(&self) -> &[u8];

    /// Fetches the entity identified by `id` from `keystore`.
    ///
    /// `identity` is the optional implementor-specific [`Self::IdentityType`] value.
    /// `Ok(None)` means the implementor has nothing to return for `id`.
    async fn fetch_from_id(
        id: &[u8],
        identity: Option<Self::IdentityType>,
        keystore: &impl FetchFromDatabase,
    ) -> Result<Option<Self>>
    where
        Self: Sized;

    /// Fetches every entity of this type from `keystore`.
    ///
    /// Only available in test builds.
    #[cfg(test)]
    async fn fetch_all(keystore: &impl FetchFromDatabase) -> Result<Vec<Self>>
    where
        Self: Sized;
}
46
47#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
48#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
49impl GroupStoreEntity for MlsConversation {
50    type RawStoreValue = core_crypto_keystore::entities::PersistedMlsGroup;
51    type IdentityType = ();
52
53    #[cfg(test)]
54    fn id(&self) -> &[u8] {
55        self.id().as_slice()
56    }
57
58    async fn fetch_from_id(
59        id: &[u8],
60        _: Option<Self::IdentityType>,
61        keystore: &impl FetchFromDatabase,
62    ) -> crate::Result<Option<Self>> {
63        let result = keystore
64            .find::<Self::RawStoreValue>(id)
65            .await
66            .map_err(KeystoreError::wrap("finding mls conversation from keystore by id"))?;
67        let Some(store_value) = result else {
68            return Ok(None);
69        };
70
71        let conversation = Self::from_serialized_state(store_value.state.clone(), store_value.parent_id.clone())
72            .map_err(RecursiveError::mls_conversation("deserializing mls conversation"))?;
73        // If the conversation is not active, pretend it doesn't exist
74        Ok(conversation.group.is_active().then_some(conversation))
75    }
76
77    #[cfg(test)]
78    async fn fetch_all(keystore: &impl FetchFromDatabase) -> Result<Vec<Self>> {
79        let all_conversations = keystore
80            .find_all::<Self::RawStoreValue>(EntityFindParams::default())
81            .await
82            .map_err(KeystoreError::wrap("finding all mls conversations"))?;
83        Ok(all_conversations
84            .iter()
85            .filter_map(|c| {
86                let conversation = Self::from_serialized_state(c.state.clone(), c.parent_id.clone()).unwrap();
87                conversation.group.is_active().then_some(conversation)
88            })
89            .collect::<Vec<_>>())
90    }
91}
92
/// Loads Proteus sessions from the keystore; deserializing one requires the local identity.
#[cfg(feature = "proteus")]
#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
impl GroupStoreEntity for crate::proteus::ProteusConversationSession {
    type RawStoreValue = core_crypto_keystore::entities::ProteusSession;
    type IdentityType = Arc<proteus_wasm::keys::IdentityKeyPair>;

    /// Not implemented for Proteus sessions; calling it panics.
    #[cfg(test)]
    fn id(&self) -> &[u8] {
        unreachable!()
    }

    /// Fetches and deserializes the session stored under `id`.
    ///
    /// Yields `Ok(None)` if the keystore holds no such session. If a session exists but no
    /// `identity` was supplied, the call fails with `Error::ProteusNotInitialized`.
    async fn fetch_from_id(
        id: &[u8],
        identity: Option<Self::IdentityType>,
        keystore: &impl FetchFromDatabase,
    ) -> crate::Result<Option<Self>> {
        let lookup = keystore
            .find::<Self::RawStoreValue>(id)
            .await
            .map_err(KeystoreError::wrap("finding raw group store entity by id"))?;

        // The keystore is consulted first: a missing session is not an error,
        // even when no identity is available.
        let stored = match lookup {
            Some(stored) => stored,
            None => return Ok(None),
        };

        let identity = identity.ok_or(crate::Error::ProteusNotInitialized)?;

        let session = proteus_wasm::session::Session::deserialise(identity, &stored.session)
            .map_err(ProteusError::wrap("deserializing session"))?;

        let identifier = stored.id.clone();
        Ok(Some(Self { identifier, session }))
    }

    /// Not implemented for Proteus sessions; calling it panics.
    #[cfg(test)]
    async fn fetch_all(_keystore: &impl FetchFromDatabase) -> Result<Vec<Self>>
    where
        Self: Sized,
    {
        unreachable!()
    }
}
139
/// A stored entity wrapped in an `Arc` around an async read/write lock.
pub(crate) type GroupStoreValue<V> = Arc<async_lock::RwLock<V>>;

/// The map type backing [`GroupStore`]: byte-vector keys, [`GroupStoreValue`] values,
/// limited by a [`HybridMemoryLimiter`].
pub(crate) type LruMap<V> = schnellru::LruMap<Vec<u8>, GroupStoreValue<V>, HybridMemoryLimiter>;
143
/// LRU-cache based group/session store.
///
/// Uses a hybrid memory limiter ([`HybridMemoryLimiter`]) based on both the number of
/// elements and the total memory usage.
/// As with all LRU caches, the least recently used elements are evicted first.
pub(crate) struct GroupStore<V: GroupStoreEntity>(LruMap<V>);
148
149impl<V: GroupStoreEntity> std::fmt::Debug for GroupStore<V> {
150    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
151        f.debug_struct("GroupStore")
152            .field("length", &self.0.len())
153            .field("memory_usage", &self.0.memory_usage())
154            .field(
155                "entries",
156                &self
157                    .0
158                    .iter()
159                    .map(|(k, v)| format!("{k:?}={v:?}"))
160                    .collect::<Vec<String>>()
161                    .join("\n"),
162            )
163            .finish()
164    }
165}
166
167impl<V: GroupStoreEntity> Default for GroupStore<V> {
168    fn default() -> Self {
169        Self(schnellru::LruMap::default())
170    }
171}
172
/// Test-only read access to the wrapped [`LruMap`].
#[cfg(test)]
impl<V: GroupStoreEntity> std::ops::Deref for GroupStore<V> {
    type Target = LruMap<V>;

    fn deref(&self) -> &Self::Target {
        let Self(map) = self;
        map
    }
}
181
/// Test-only mutable access to the wrapped [`LruMap`].
#[cfg(test)]
impl<V: GroupStoreEntity> std::ops::DerefMut for GroupStore<V> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self(map) = self;
        map
    }
}
188
189impl<V: GroupStoreEntity> GroupStore<V> {
190    #[allow(dead_code)]
191    pub(crate) fn new_with_limit(len: u32) -> Self {
192        let limiter = HybridMemoryLimiter::new(Some(len), None);
193        let store = schnellru::LruMap::new(limiter);
194        Self(store)
195    }
196
197    #[allow(dead_code)]
198    pub(crate) fn new(count: Option<u32>, memory: Option<usize>) -> Self {
199        let limiter = HybridMemoryLimiter::new(count, memory);
200        let store = schnellru::LruMap::new(limiter);
201        Self(store)
202    }
203
204    #[allow(dead_code)]
205    pub(crate) fn contains_key(&self, k: &[u8]) -> bool {
206        self.0.peek(k).is_some()
207    }
208
209    pub(crate) async fn get_fetch(
210        &mut self,
211        k: &[u8],
212        keystore: &impl FetchFromDatabase,
213        identity: Option<V::IdentityType>,
214    ) -> crate::Result<Option<GroupStoreValue<V>>> {
215        // Optimistic cache lookup
216        if let Some(value) = self.0.get(k) {
217            return Ok(Some(value.clone()));
218        }
219
220        // Not in store, fetch the thing in the keystore
221        let inserted_value = V::fetch_from_id(k, identity, keystore).await?.map(|value| {
222            let value_to_insert = Arc::new(async_lock::RwLock::new(value));
223            self.insert_prepped(k.to_vec(), value_to_insert.clone());
224            value_to_insert
225        });
226        Ok(inserted_value)
227    }
228
229    /// Returns the value from the keystore.
230    /// WARNING: the returned value is not attached to the keystore and mutations on it will be
231    /// lost when the object is dropped
232    pub(crate) async fn fetch_from_keystore(
233        k: &[u8],
234        keystore: &impl FetchFromDatabase,
235        identity: Option<V::IdentityType>,
236    ) -> crate::Result<Option<V>> {
237        V::fetch_from_id(k, identity, keystore).await
238    }
239
240    #[cfg(test)]
241    pub(crate) async fn get_fetch_all(&mut self, keystore: &impl FetchFromDatabase) -> Result<Vec<GroupStoreValue<V>>> {
242        let all = V::fetch_all(keystore)
243            .await?
244            .into_iter()
245            .map(|g| {
246                let id = g.id().to_vec();
247                let to_insert = Arc::new(async_lock::RwLock::new(g));
248                self.insert_prepped(id, to_insert.clone());
249                to_insert
250            })
251            .collect::<Vec<_>>();
252        Ok(all)
253    }
254
255    fn insert_prepped(&mut self, k: Vec<u8>, prepped_entity: GroupStoreValue<V>) {
256        self.0.insert(k, prepped_entity);
257    }
258
259    pub(crate) fn insert(&mut self, k: Vec<u8>, entity: V) {
260        let value_to_insert = Arc::new(async_lock::RwLock::new(entity));
261        self.insert_prepped(k, value_to_insert)
262    }
263
264    pub(crate) fn try_insert(&mut self, k: Vec<u8>, entity: V) -> Result<(), V> {
265        let value_to_insert = Arc::new(async_lock::RwLock::new(entity));
266
267        if self.0.insert(k, value_to_insert.clone()) {
268            Ok(())
269        } else {
270            // This is safe because we just built the value
271            Err(Arc::into_inner(value_to_insert).unwrap().into_inner())
272        }
273    }
274
275    pub(crate) fn remove(&mut self, k: &[u8]) -> Option<GroupStoreValue<V>> {
276        self.0.remove(k)
277    }
278
279    pub(crate) fn get(&mut self, k: &[u8]) -> Option<&mut GroupStoreValue<V>> {
280        self.0.get(k)
281    }
282}
283
/// A limiter combining a memory-usage limiter and an entry-count limiter.
pub(crate) struct HybridMemoryLimiter {
    /// Limits by memory usage.
    mem: schnellru::ByMemoryUsage,
    /// Limits by number of entries.
    len: schnellru::ByLength,
}
288
/// Memory limit always used on `wasm` targets; elsewhere it is the fallback when no limit is
/// given and the available system memory is reported as 0.
/// NOTE(review): presumably in bytes, as it is passed to `schnellru::ByMemoryUsage::new` — confirm.
pub(crate) const MEMORY_LIMIT: usize = 100_000_000;
/// Entry-count limit used when no count is given to [`HybridMemoryLimiter::new`].
pub(crate) const ITEM_LIMIT: u32 = 100;
291
292impl HybridMemoryLimiter {
293    // in the wasm case, we ignore the suggested memory limit
294    #[cfg_attr(target_family = "wasm", expect(unused_variables))]
295    pub(crate) fn new(count: Option<u32>, memory: Option<usize>) -> Self {
296        #[cfg(target_family = "wasm")]
297        let memory_limit = MEMORY_LIMIT;
298
299        #[cfg(not(target_family = "wasm"))]
300        let memory_limit = memory
301            .or_else(|| {
302                let system = sysinfo::System::new_with_specifics(
303                    sysinfo::RefreshKind::nothing().with_memory(sysinfo::MemoryRefreshKind::nothing().with_ram()),
304                );
305
306                let available_sys_memory = system.available_memory();
307                (available_sys_memory > 0).then_some(available_sys_memory as usize)
308            })
309            .unwrap_or(MEMORY_LIMIT);
310
311        let mem = schnellru::ByMemoryUsage::new(memory_limit);
312        let len = schnellru::ByLength::new(count.unwrap_or(ITEM_LIMIT));
313
314        Self { mem, len }
315    }
316}
317
318impl Default for HybridMemoryLimiter {
319    fn default() -> Self {
320        Self::new(None, None)
321    }
322}
323
324impl<K, V> schnellru::Limiter<K, V> for HybridMemoryLimiter {
325    type KeyToInsert<'a> = K;
326    type LinkType = u32;
327
328    fn is_over_the_limit(&self, length: usize) -> bool {
329        <schnellru::ByLength as schnellru::Limiter<K, V>>::is_over_the_limit(&self.len, length)
330    }
331
332    fn on_insert(&mut self, length: usize, key: Self::KeyToInsert<'_>, value: V) -> Option<(K, V)> {
333        <schnellru::ByLength as schnellru::Limiter<K, V>>::on_insert(&mut self.len, length, key, value)
334    }
335
336    // Both underlying limiters have dummy implementations here
337    fn on_replace(
338        &mut self,
339        _length: usize,
340        _old_key: &mut K,
341        _new_key: Self::KeyToInsert<'_>,
342        _old_value: &mut V,
343        _new_value: &mut V,
344    ) -> bool {
345        true
346    }
347    fn on_removed(&mut self, _key: &mut K, _value: &mut V) {}
348    fn on_cleared(&mut self) {}
349
350    fn on_grow(&mut self, new_memory_usage: usize) -> bool {
351        <schnellru::ByMemoryUsage as schnellru::Limiter<K, V>>::on_grow(&mut self.mem, new_memory_usage)
352    }
353}
354
// Tests for `GroupStore` using the keystore's dummy entity as the stored value.
#[cfg(test)]
mod tests {
    use core_crypto_keystore::dummy_entity::{DummyStoreValue, DummyValue};
    use wasm_bindgen_test::*;

    use super::*;

    wasm_bindgen_test_configure!(run_in_browser);

    // Minimal `GroupStoreEntity` impl so that `DummyValue` can be put into a `GroupStore`.
    #[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
    #[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
    impl GroupStoreEntity for DummyValue {
        type RawStoreValue = DummyStoreValue;

        type IdentityType = ();

        // Not expected to be called by the tests below.
        fn id(&self) -> &[u8] {
            unreachable!()
        }

        // Builds the dummy value directly from the id bytes; the keystore is not consulted.
        async fn fetch_from_id(
            id: &[u8],
            _identity: Option<Self::IdentityType>,
            _keystore: &impl FetchFromDatabase,
        ) -> crate::Result<Option<Self>> {
            // it's not worth adding a variant to the Error type here to handle test dummy values
            let id = std::str::from_utf8(id).expect("dummy value ids are strings");
            Ok(Some(id.into()))
        }

        // Not expected to be called by the tests below.
        #[cfg(test)]
        async fn fetch_all(_keystore: &impl FetchFromDatabase) -> Result<Vec<Self>> {
            unreachable!()
        }
    }

    type TestGroupStore = GroupStore<DummyValue>;

    // A freshly constructed store is empty, whatever limits it was given.
    #[async_std::test]
    #[wasm_bindgen_test]
    async fn group_store_init() {
        let store = TestGroupStore::new_with_limit(1);
        assert_eq!(store.len(), 0);
        let store = TestGroupStore::new_with_limit(0);
        assert_eq!(store.len(), 0);
        let store = TestGroupStore::new(Some(0), Some(0));
        assert_eq!(store.len(), 0);
        let store = TestGroupStore::new(Some(0), Some(1));
        assert_eq!(store.len(), 0);
        let store = TestGroupStore::new(Some(1), Some(0));
        assert_eq!(store.len(), 0);
        let store = TestGroupStore::new(Some(1), Some(1));
        assert_eq!(store.len(), 0);
    }

    // With maximal limits, both `try_insert` and `insert` grow the store and every key stays present.
    #[async_std::test]
    #[wasm_bindgen_test]
    async fn group_store_common_ops() {
        let mut store = TestGroupStore::new(Some(u32::MAX), Some(usize::MAX));
        for i in 1..=3 {
            let i_str = i.to_string();
            assert!(
                store
                    .try_insert(i_str.as_bytes().to_vec(), i_str.as_str().into())
                    .is_ok()
            );
            assert_eq!(store.len(), i);
        }
        for i in 4..=6 {
            let i_str = i.to_string();
            store.insert(i_str.as_bytes().to_vec(), i_str.as_str().into());
            assert_eq!(store.len(), i);
        }

        for i in 1..=6 {
            assert!(store.contains_key(i.to_string().as_bytes()));
        }
    }

    // With an entry limit of 2, a third insertion evicts the first key and the length stays at 2.
    #[async_std::test]
    #[wasm_bindgen_test]
    async fn group_store_operations_len_limiter() {
        let mut store = TestGroupStore::new_with_limit(2);
        assert!(store.try_insert(b"1".to_vec(), "1".into()).is_ok());
        assert_eq!(store.len(), 1);
        assert!(store.try_insert(b"2".to_vec(), "2".into()).is_ok());
        assert_eq!(store.len(), 2);
        assert!(store.try_insert(b"3".to_vec(), "3".into()).is_ok());
        assert_eq!(store.len(), 2);
        assert!(!store.contains_key(b"1"));
        assert!(store.contains_key(b"2"));
        assert!(store.contains_key(b"3"));
        store.insert(b"4".to_vec(), "4".into());
        assert_eq!(store.len(), 2);
    }

    // Exercises the memory limit: the limit is taken from the measured memory usage of an
    // unlimited map, so the asserted capacities depend on the map's growth steps.
    #[async_std::test]
    #[wasm_bindgen_test]
    async fn group_store_operations_mem_limiter() {
        use schnellru::{LruMap, UnlimitedCompact};
        // Baseline: record the memory usage of an unlimited map after one and after four insertions.
        let mut lru: LruMap<Vec<u8>, DummyValue, UnlimitedCompact> =
            LruMap::<Vec<u8>, DummyValue, UnlimitedCompact>::new(UnlimitedCompact);
        assert_eq!(lru.guaranteed_capacity(), 0);
        assert_eq!(lru.memory_usage(), 0);
        lru.insert(1usize.to_le_bytes().to_vec(), "10".into());
        let memory_usage_step_1 = lru.memory_usage();
        lru.insert(2usize.to_le_bytes().to_vec(), "20".into());
        lru.insert(3usize.to_le_bytes().to_vec(), "30".into());
        lru.insert(4usize.to_le_bytes().to_vec(), "40".into());
        let memory_usage_step_2 = lru.memory_usage();
        assert_ne!(memory_usage_step_1, memory_usage_step_2);

        // Store under test: no explicit entry count, memory limited to the four-entry baseline usage.
        let mut store = TestGroupStore::new(None, Some(memory_usage_step_2));
        assert_eq!(store.guaranteed_capacity(), 0);
        assert_eq!(store.memory_usage(), 0);
        store.try_insert(1usize.to_le_bytes().to_vec(), "10".into()).unwrap();
        assert_eq!(store.guaranteed_capacity(), 3);
        assert!(store.memory_usage() <= memory_usage_step_1);
        store.try_insert(2usize.to_le_bytes().to_vec(), "20".into()).unwrap();
        store.try_insert(3usize.to_le_bytes().to_vec(), "30".into()).unwrap();
        for i in 1..=3usize {
            assert_eq!(
                *(store.get(i.to_le_bytes().as_ref()).unwrap().read().await),
                DummyValue::from(format!("{}", i * 10).as_str())
            );
        }
        assert_eq!(store.guaranteed_capacity(), 3);
        assert!(store.memory_usage() <= memory_usage_step_1);
        // A fourth entry is accepted; afterwards the guaranteed capacity is 7.
        assert!(store.try_insert(4usize.to_le_bytes().to_vec(), "40".into()).is_ok());
        for i in (1usize..=4).rev() {
            assert_eq!(
                *(store.get(i.to_le_bytes().as_ref()).unwrap().read().await),
                DummyValue::from(format!("{}", i * 10).as_str())
            );
        }
        assert_eq!(store.guaranteed_capacity(), 7);
        assert!(store.memory_usage() <= memory_usage_step_2);
        store.try_insert(5usize.to_le_bytes().to_vec(), "50".into()).unwrap();
        store.try_insert(6usize.to_le_bytes().to_vec(), "60".into()).unwrap();
        store.try_insert(7usize.to_le_bytes().to_vec(), "70".into()).unwrap();
        // Access 7, 6, 5 (in that order) before inserting an eighth entry.
        for i in (5usize..=7).rev() {
            store.get(i.to_le_bytes().as_ref()).unwrap();
        }

        // Entries 5..=8 must all still be retrievable after the eighth insertion.
        store.insert(8usize.to_le_bytes().to_vec(), "80".into());
        for i in [8usize, 7, 6, 5].iter() {
            assert_eq!(
                *(store
                    .get(i.to_le_bytes().as_ref())
                    .unwrap_or_else(|| panic!("couldn't find index {i}"))
                    .read()
                    .await),
                DummyValue::from(format!("{}", i * 10).as_str())
            );
        }

        // The capacity did not grow past 7 and the memory usage stayed within the limit.
        assert_eq!(store.guaranteed_capacity(), 7);
        assert!(store.memory_usage() <= memory_usage_step_2);
        store.assert_check_internal_state();
    }
}
515}