core_crypto/
group_store.rs

1// Wire
2// Copyright (C) 2022 Wire Swiss GmbH
3
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8
9// This program is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU General Public License for more details.
13
14// You should have received a copy of the GNU General Public License
15// along with this program. If not, see http://www.gnu.org/licenses/.
16
17use crate::prelude::{CryptoResult, MlsConversation};
18use core_crypto_keystore::{connection::FetchFromDatabase, entities::EntityFindParams};
19
20#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
21#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
22pub(crate) trait GroupStoreEntity: std::fmt::Debug {
23    type RawStoreValue: core_crypto_keystore::entities::Entity;
24    type IdentityType;
25
26    fn id(&self) -> &[u8];
27
28    async fn fetch_from_id(
29        id: &[u8],
30        identity: Option<Self::IdentityType>,
31        keystore: &impl FetchFromDatabase,
32    ) -> CryptoResult<Option<Self>>
33    where
34        Self: Sized;
35
36    async fn fetch_all(keystore: &impl FetchFromDatabase) -> CryptoResult<Vec<Self>>
37    where
38        Self: Sized;
39}
40
41#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
42#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
43impl GroupStoreEntity for MlsConversation {
44    type RawStoreValue = core_crypto_keystore::entities::PersistedMlsGroup;
45    type IdentityType = ();
46
47    fn id(&self) -> &[u8] {
48        self.id().as_slice()
49    }
50
51    async fn fetch_from_id(
52        id: &[u8],
53        _: Option<Self::IdentityType>,
54        keystore: &impl FetchFromDatabase,
55    ) -> crate::CryptoResult<Option<Self>> {
56        let result = keystore.find::<Self::RawStoreValue>(id).await?;
57        let Some(store_value) = result else {
58            return Ok(None);
59        };
60
61        let conversation = Self::from_serialized_state(store_value.state.clone(), store_value.parent_id.clone())?;
62        // If the conversation is not active, pretend it doesn't exist
63        Ok(if conversation.group.is_active() {
64            Some(conversation)
65        } else {
66            None
67        })
68    }
69
70    async fn fetch_all(keystore: &impl FetchFromDatabase) -> CryptoResult<Vec<Self>> {
71        let all_conversations = keystore
72            .find_all::<Self::RawStoreValue>(EntityFindParams::default())
73            .await?;
74        Ok(all_conversations
75            .iter()
76            .filter_map(|c| {
77                let conversation = Self::from_serialized_state(c.state.clone(), c.parent_id.clone()).unwrap();
78                conversation.group.is_active().then_some(conversation)
79            })
80            .collect::<Vec<_>>())
81    }
82}
83
/// Lets Proteus sessions be cached in a [`GroupStore`], backed by `ProteusSession` rows.
///
/// Decoding a stored session requires the local identity keypair.
#[cfg(feature = "proteus")]
#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
impl GroupStoreEntity for crate::proteus::ProteusConversationSession {
    type RawStoreValue = core_crypto_keystore::entities::ProteusSession;
    type IdentityType = std::sync::Arc<proteus_wasm::keys::IdentityKeyPair>;

    fn id(&self) -> &[u8] {
        // Within this module `id()` is only needed by `GroupStore::get_fetch_all`, which relies on
        // `fetch_all`. That is unsupported for Proteus sessions as well.
        unreachable!()
    }

    /// Loads the session stored under `id` and decodes it with `identity`.
    ///
    /// Returns `Ok(None)` when nothing is stored under `id`. The keystore lookup happens before
    /// the identity check, so a missing session yields `Ok(None)` even without an identity.
    ///
    /// # Errors
    /// `ProteusNotInitialized` when a session exists but `identity` is `None`. Also propagates
    /// keystore and session-deserialization failures.
    async fn fetch_from_id(
        id: &[u8],
        identity: Option<Self::IdentityType>,
        keystore: &impl FetchFromDatabase,
    ) -> crate::CryptoResult<Option<Self>> {
        let store_value = match keystore.find::<Self::RawStoreValue>(id).await? {
            Some(found) => found,
            None => return Ok(None),
        };

        let identity = identity.ok_or(crate::CryptoError::ProteusNotInitialized)?;

        let session = proteus_wasm::session::Session::deserialise(identity, &store_value.session)
            .map_err(crate::ProteusError::from)?;

        Ok(Some(Self {
            identifier: store_value.id.clone(),
            session,
        }))
    }

    async fn fetch_all(_keystore: &impl FetchFromDatabase) -> CryptoResult<Vec<Self>>
    where
        Self: Sized,
    {
        // Bulk loading is not supported: it would need an identity that this signature doesn't carry.
        unreachable!()
    }
}
125
126pub(crate) type GroupStoreValue<V> = std::sync::Arc<async_lock::RwLock<V>>;
127
128pub(crate) type LruMap<V> = schnellru::LruMap<Vec<u8>, GroupStoreValue<V>, HybridMemoryLimiter>;
129
130/// LRU-cache based group/session store
131/// Uses a hybrid memory limiter based on both amount of elements and total memory usage
132/// As with all LRU caches, eviction is based on oldest elements
133pub(crate) struct GroupStore<V: GroupStoreEntity>(LruMap<V>);
134
135impl<V: GroupStoreEntity> std::fmt::Debug for GroupStore<V> {
136    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
137        f.debug_struct("GroupStore")
138            .field("length", &self.0.len())
139            .field("memory_usage", &self.0.memory_usage())
140            .field(
141                "entries",
142                &self
143                    .0
144                    .iter()
145                    .map(|(k, v)| format!("{k:?}={v:?}"))
146                    .collect::<Vec<String>>()
147                    .join("\n"),
148            )
149            .finish()
150    }
151}
152
153impl<V: GroupStoreEntity> Default for GroupStore<V> {
154    fn default() -> Self {
155        Self(schnellru::LruMap::default())
156    }
157}
158
/// Test-only read access to the underlying LRU map (length, capacity, memory usage, ...).
#[cfg(test)]
impl<V: GroupStoreEntity> std::ops::Deref for GroupStore<V> {
    type Target = LruMap<V>;

    fn deref(&self) -> &Self::Target {
        let Self(inner) = self;
        inner
    }
}

/// Test-only mutable access to the underlying LRU map.
#[cfg(test)]
impl<V: GroupStoreEntity> std::ops::DerefMut for GroupStore<V> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self(inner) = self;
        inner
    }
}
174
175impl<V: GroupStoreEntity> GroupStore<V> {
176    #[allow(dead_code)]
177    pub(crate) fn new_with_limit(len: u32) -> Self {
178        let limiter = HybridMemoryLimiter::new(Some(len), None);
179        let store = schnellru::LruMap::new(limiter);
180        Self(store)
181    }
182
183    #[allow(dead_code)]
184    pub(crate) fn new(count: Option<u32>, memory: Option<usize>) -> Self {
185        let limiter = HybridMemoryLimiter::new(count, memory);
186        let store = schnellru::LruMap::new(limiter);
187        Self(store)
188    }
189
190    #[allow(dead_code)]
191    pub(crate) fn contains_key(&self, k: &[u8]) -> bool {
192        self.0.peek(k).is_some()
193    }
194
195    pub(crate) async fn get_fetch(
196        &mut self,
197        k: &[u8],
198        keystore: &impl FetchFromDatabase,
199        identity: Option<V::IdentityType>,
200    ) -> crate::CryptoResult<Option<GroupStoreValue<V>>> {
201        // Optimistic cache lookup
202        if let Some(value) = self.0.get(k) {
203            return Ok(Some(value.clone()));
204        }
205
206        // Not in store, fetch the thing in the keystore
207        let mut value = V::fetch_from_id(k, identity, keystore).await?;
208        if let Some(value) = value.take() {
209            let value_to_insert = std::sync::Arc::new(async_lock::RwLock::new(value));
210            self.insert_prepped(k.to_vec(), value_to_insert.clone());
211
212            Ok(Some(value_to_insert))
213        } else {
214            Ok(None)
215        }
216    }
217
218    /// Returns the value from the keystore.
219    /// WARNING: the returned value is not attached to the keystore and mutations on it will be
220    /// lost when the object is dropped
221    pub(crate) async fn fetch_from_keystore(
222        k: &[u8],
223        keystore: &impl FetchFromDatabase,
224        identity: Option<V::IdentityType>,
225    ) -> crate::CryptoResult<Option<V>> {
226        V::fetch_from_id(k, identity, keystore).await
227    }
228
229    pub(crate) async fn get_fetch_all(
230        &mut self,
231        keystore: &impl FetchFromDatabase,
232    ) -> CryptoResult<Vec<GroupStoreValue<V>>> {
233        let all = V::fetch_all(keystore)
234            .await?
235            .into_iter()
236            .map(|g| {
237                let id = g.id().to_vec();
238                let to_insert = std::sync::Arc::new(async_lock::RwLock::new(g));
239                self.insert_prepped(id, to_insert.clone());
240                to_insert
241            })
242            .collect::<Vec<_>>();
243        Ok(all)
244    }
245
246    fn insert_prepped(&mut self, k: Vec<u8>, prepped_entity: GroupStoreValue<V>) {
247        self.0.insert(k, prepped_entity);
248    }
249
250    pub(crate) fn insert(&mut self, k: Vec<u8>, entity: V) {
251        let value_to_insert = std::sync::Arc::new(async_lock::RwLock::new(entity));
252        self.insert_prepped(k, value_to_insert)
253    }
254
255    pub(crate) fn try_insert(&mut self, k: Vec<u8>, entity: V) -> Result<(), V> {
256        let value_to_insert = std::sync::Arc::new(async_lock::RwLock::new(entity));
257
258        if self.0.insert(k, value_to_insert.clone()) {
259            Ok(())
260        } else {
261            // This is safe because we just built the value
262            Err(std::sync::Arc::into_inner(value_to_insert).unwrap().into_inner())
263        }
264    }
265
266    pub(crate) fn remove(&mut self, k: &[u8]) -> Option<GroupStoreValue<V>> {
267        self.0.remove(k)
268    }
269
270    pub(crate) fn get(&mut self, k: &[u8]) -> Option<&mut GroupStoreValue<V>> {
271        self.0.get(k)
272    }
273}
274
275pub(crate) struct HybridMemoryLimiter {
276    mem: schnellru::ByMemoryUsage,
277    len: schnellru::ByLength,
278}
279
280pub(crate) const MEMORY_LIMIT: usize = 100_000_000;
281pub(crate) const ITEM_LIMIT: u32 = 100;
282
283impl HybridMemoryLimiter {
284    pub(crate) fn new(count: Option<u32>, memory: Option<usize>) -> Self {
285        // false positive. We want to fetch system metrics lazily
286        #[allow(clippy::unnecessary_lazy_evaluations)]
287        let maybe_memory_limit = memory.or_else(|| {
288            cfg_if::cfg_if! {
289                if #[cfg(target_family = "wasm")] {
290                    None
291                } else {
292                    let system = sysinfo::System::new_with_specifics(sysinfo::RefreshKind::new().with_memory(sysinfo::MemoryRefreshKind::new().with_ram()));
293                    let available_sys_memory = system.available_memory();
294                    if available_sys_memory > 0 {
295                        Some(available_sys_memory as usize)
296                    } else {
297                        None
298                    }
299                }
300            }
301        });
302
303        let mem = schnellru::ByMemoryUsage::new(maybe_memory_limit.unwrap_or(MEMORY_LIMIT));
304        let len = schnellru::ByLength::new(count.unwrap_or(ITEM_LIMIT));
305
306        Self { mem, len }
307    }
308}
309
310impl Default for HybridMemoryLimiter {
311    fn default() -> Self {
312        Self::new(None, None)
313    }
314}
315
316impl<K, V> schnellru::Limiter<K, V> for HybridMemoryLimiter {
317    type KeyToInsert<'a> = K;
318    type LinkType = u32;
319
320    fn is_over_the_limit(&self, length: usize) -> bool {
321        <schnellru::ByLength as schnellru::Limiter<K, V>>::is_over_the_limit(&self.len, length)
322    }
323
324    fn on_insert(&mut self, length: usize, key: Self::KeyToInsert<'_>, value: V) -> Option<(K, V)> {
325        <schnellru::ByLength as schnellru::Limiter<K, V>>::on_insert(&mut self.len, length, key, value)
326    }
327
328    // Both underlying limiters have dummy implementations here
329    fn on_replace(
330        &mut self,
331        _length: usize,
332        _old_key: &mut K,
333        _new_key: Self::KeyToInsert<'_>,
334        _old_value: &mut V,
335        _new_value: &mut V,
336    ) -> bool {
337        true
338    }
339    fn on_removed(&mut self, _key: &mut K, _value: &mut V) {}
340    fn on_cleared(&mut self) {}
341
342    fn on_grow(&mut self, new_memory_usage: usize) -> bool {
343        <schnellru::ByMemoryUsage as schnellru::Limiter<K, V>>::on_grow(&mut self.mem, new_memory_usage)
344    }
345}
346
#[cfg(test)]
mod tests {
    use core_crypto_keystore::dummy_entity::{DummyStoreValue, DummyValue};
    use wasm_bindgen_test::*;

    use super::*;

    wasm_bindgen_test_configure!(run_in_browser);

    // Minimal cacheable entity: "fetching" just decodes the key as UTF-8 without touching the keystore.
    #[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
    #[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
    impl GroupStoreEntity for DummyValue {
        type RawStoreValue = DummyStoreValue;

        type IdentityType = ();

        fn id(&self) -> &[u8] {
            unreachable!()
        }

        async fn fetch_from_id(
            id: &[u8],
            _identity: Option<Self::IdentityType>,
            _keystore: &impl FetchFromDatabase,
        ) -> crate::CryptoResult<Option<Self>> {
            let decoded = std::str::from_utf8(id)?;
            Ok(Some(Self::from(decoded)))
        }

        async fn fetch_all(_keystore: &impl FetchFromDatabase) -> CryptoResult<Vec<Self>> {
            unreachable!()
        }
    }

    type TestGroupStore = GroupStore<DummyValue>;

    /// Little-endian byte key used by the memory-limiter test.
    fn le_key(i: usize) -> Vec<u8> {
        i.to_le_bytes().to_vec()
    }

    /// Value stored under `le_key(i)`: the decimal text of `i * 10`.
    fn tenfold(i: usize) -> DummyValue {
        DummyValue::from(format!("{}", i * 10).as_str())
    }

    #[async_std::test]
    #[wasm_bindgen_test]
    async fn group_store_init() {
        for limit in [1, 0] {
            let store = TestGroupStore::new_with_limit(limit);
            assert_eq!(store.len(), 0);
        }
        for (count, memory) in [(0, 0), (0, 1), (1, 0), (1, 1)] {
            let store = TestGroupStore::new(Some(count), Some(memory));
            assert_eq!(store.len(), 0);
        }
    }

    #[async_std::test]
    #[wasm_bindgen_test]
    async fn group_store_common_ops() {
        let mut store = TestGroupStore::new(Some(u32::MAX), Some(usize::MAX));
        // First half goes through `try_insert`, second half through `insert`.
        for n in 1..=6usize {
            let label = n.to_string();
            if n <= 3 {
                assert!(store
                    .try_insert(label.as_bytes().to_vec(), label.as_str().into())
                    .is_ok());
            } else {
                store.insert(label.as_bytes().to_vec(), label.as_str().into());
            }
            assert_eq!(store.len(), n);
        }

        for n in 1..=6usize {
            assert!(store.contains_key(n.to_string().as_bytes()));
        }
    }

    #[async_std::test]
    #[wasm_bindgen_test]
    async fn group_store_operations_len_limiter() {
        let mut store = TestGroupStore::new_with_limit(2);
        for (key, expected_len) in [("1", 1), ("2", 2), ("3", 2)] {
            assert!(store.try_insert(key.as_bytes().to_vec(), key.into()).is_ok());
            assert_eq!(store.len(), expected_len);
        }
        // With a cap of two, inserting "3" pushed out the oldest entry.
        assert!(!store.contains_key(b"1"));
        assert!(store.contains_key(b"2"));
        assert!(store.contains_key(b"3"));
        store.insert(b"4".to_vec(), "4".into());
        assert_eq!(store.len(), 2);
    }

    #[async_std::test]
    #[wasm_bindgen_test]
    async fn group_store_operations_mem_limiter() {
        use schnellru::{LruMap, UnlimitedCompact};

        // Measure an unbounded map's footprint to derive a realistic memory cap.
        let mut reference: LruMap<Vec<u8>, DummyValue, UnlimitedCompact> =
            LruMap::<Vec<u8>, DummyValue, UnlimitedCompact>::new(UnlimitedCompact);
        assert_eq!(reference.guaranteed_capacity(), 0);
        assert_eq!(reference.memory_usage(), 0);
        reference.insert(le_key(1), tenfold(1));
        let one_entry_usage = reference.memory_usage();
        for i in 2..=4 {
            reference.insert(le_key(i), tenfold(i));
        }
        let four_entries_usage = reference.memory_usage();
        assert_ne!(one_entry_usage, four_entries_usage);

        let mut store = TestGroupStore::new(None, Some(four_entries_usage));
        assert_eq!(store.guaranteed_capacity(), 0);
        assert_eq!(store.memory_usage(), 0);
        store.try_insert(le_key(1), tenfold(1)).unwrap();
        assert_eq!(store.guaranteed_capacity(), 3);
        assert!(store.memory_usage() <= one_entry_usage);
        for i in 2..=3 {
            store.try_insert(le_key(i), tenfold(i)).unwrap();
        }
        for i in 1..=3 {
            assert_eq!(*store.get(&le_key(i)).unwrap().read().await, tenfold(i));
        }
        assert_eq!(store.guaranteed_capacity(), 3);
        assert!(store.memory_usage() <= one_entry_usage);

        assert!(store.try_insert(le_key(4), tenfold(4)).is_ok());
        for i in (1..=4).rev() {
            assert_eq!(*store.get(&le_key(i)).unwrap().read().await, tenfold(i));
        }
        assert_eq!(store.guaranteed_capacity(), 7);
        assert!(store.memory_usage() <= four_entries_usage);

        for i in 5..=7 {
            store.try_insert(le_key(i), tenfold(i)).unwrap();
        }
        // Touch 7, 6, 5 so they count as recently used before the next insertion.
        for i in (5..=7).rev() {
            store.get(&le_key(i)).unwrap();
        }

        store.insert(le_key(8), tenfold(8));
        for i in [8usize, 7, 6, 5] {
            let entry = store
                .get(&le_key(i))
                .unwrap_or_else(|| panic!("couldn't find index {i}"));
            assert_eq!(*entry.read().await, tenfold(i));
        }

        assert_eq!(store.guaranteed_capacity(), 7);
        assert!(store.memory_usage() <= four_entries_usage);
        store.assert_check_internal_state();
    }
}