1use std::sync::Arc;
2
3use crate::{KeystoreError, RecursiveError, Result, prelude::MlsConversation};
4use core_crypto_keystore::connection::FetchFromDatabase;
5
/// An entity that can be cached in a [`GroupStore`].
///
/// Bridges the in-memory representation (`Self`) and the raw keystore row
/// (`Self::RawStoreValue`) it is loaded from.
#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
pub(crate) trait GroupStoreEntity: std::fmt::Debug {
    /// Keystore entity type this value is persisted as.
    type RawStoreValue: core_crypto_keystore::entities::Entity;
    /// Extra identity/context needed to rebuild the value; use `()` when none.
    type IdentityType;

    /// Loads the entity identified by `id` from the keystore.
    ///
    /// Returns `Ok(None)` when no matching row exists; implementations may
    /// also map existing-but-unusable rows to `None` (see the
    /// `MlsConversation` impl, which hides inactive groups).
    async fn fetch_from_id(
        id: &[u8],
        identity: Option<Self::IdentityType>,
        keystore: &impl FetchFromDatabase,
    ) -> Result<Option<Self>>
    where
        Self: Sized;
}
20
21#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
22#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
23impl GroupStoreEntity for MlsConversation {
24 type RawStoreValue = core_crypto_keystore::entities::PersistedMlsGroup;
25 type IdentityType = ();
26
27 async fn fetch_from_id(
28 id: &[u8],
29 _: Option<Self::IdentityType>,
30 keystore: &impl FetchFromDatabase,
31 ) -> crate::Result<Option<Self>> {
32 let result = keystore
33 .find::<Self::RawStoreValue>(id)
34 .await
35 .map_err(KeystoreError::wrap("finding mls conversation from keystore by id"))?;
36 let Some(store_value) = result else {
37 return Ok(None);
38 };
39
40 let conversation = Self::from_serialized_state(store_value.state.clone(), store_value.parent_id.clone())
41 .map_err(RecursiveError::mls_conversation("deserializing mls conversation"))?;
42 Ok(conversation.group.is_active().then_some(conversation))
44 }
45}
46
/// Shared, lockable handle to a cached entity; cloning is a cheap `Arc` bump.
pub(crate) type GroupStoreValue<V> = Arc<async_lock::RwLock<V>>;

/// LRU map from raw group-id bytes to cached entities, bounded by
/// [`HybridMemoryLimiter`] (entry count plus approximate memory usage).
pub(crate) type LruMap<V> = schnellru::LruMap<Vec<u8>, GroupStoreValue<V>, HybridMemoryLimiter>;

/// LRU cache of group entities keyed by raw group id.
pub(crate) struct GroupStore<V: GroupStoreEntity>(LruMap<V>);
55
56impl<V: GroupStoreEntity> std::fmt::Debug for GroupStore<V> {
57 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58 f.debug_struct("GroupStore")
59 .field("length", &self.0.len())
60 .field("memory_usage", &self.0.memory_usage())
61 .field(
62 "entries",
63 &self
64 .0
65 .iter()
66 .map(|(k, v)| format!("{k:?}={v:?}"))
67 .collect::<Vec<String>>()
68 .join("\n"),
69 )
70 .finish()
71 }
72}
73
74impl<V: GroupStoreEntity> Default for GroupStore<V> {
75 fn default() -> Self {
76 Self(schnellru::LruMap::default())
77 }
78}
79
// Test-only transparent access to the inner `LruMap`, so tests can call map
// methods (`len`, `peek`, `memory_usage`, …) directly on the store.
#[cfg(test)]
impl<V: GroupStoreEntity> std::ops::Deref for GroupStore<V> {
    type Target = LruMap<V>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

#[cfg(test)]
impl<V: GroupStoreEntity> std::ops::DerefMut for GroupStore<V> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}
95
96impl<V: GroupStoreEntity> GroupStore<V> {
97 pub(crate) fn new_with_limit(len: u32) -> Self {
98 let limiter = HybridMemoryLimiter::new(Some(len), None);
99 let store = schnellru::LruMap::new(limiter);
100 Self(store)
101 }
102
103 #[cfg(test)]
104 fn new(count: Option<u32>, memory: Option<usize>) -> Self {
105 let limiter = HybridMemoryLimiter::new(count, memory);
106 let store = schnellru::LruMap::new(limiter);
107 Self(store)
108 }
109
110 #[cfg(test)]
111 fn contains_key(&self, k: &[u8]) -> bool {
112 self.0.peek(k).is_some()
113 }
114
115 pub(crate) async fn get_fetch(
116 &mut self,
117 k: &[u8],
118 keystore: &impl FetchFromDatabase,
119 identity: Option<V::IdentityType>,
120 ) -> crate::Result<Option<GroupStoreValue<V>>> {
121 if let Some(value) = self.0.get(k) {
123 return Ok(Some(value.clone()));
124 }
125
126 let inserted_value = V::fetch_from_id(k, identity, keystore).await?.map(|value| {
128 let value_to_insert = Arc::new(async_lock::RwLock::new(value));
129 self.insert_prepped(k.to_vec(), value_to_insert.clone());
130 value_to_insert
131 });
132 Ok(inserted_value)
133 }
134
135 pub(crate) async fn fetch_from_keystore(
139 k: &[u8],
140 keystore: &impl FetchFromDatabase,
141 identity: Option<V::IdentityType>,
142 ) -> crate::Result<Option<V>> {
143 V::fetch_from_id(k, identity, keystore).await
144 }
145
146 fn insert_prepped(&mut self, k: Vec<u8>, prepped_entity: GroupStoreValue<V>) {
147 self.0.insert(k, prepped_entity);
148 }
149
150 pub(crate) fn insert(&mut self, k: Vec<u8>, entity: V) {
151 let value_to_insert = Arc::new(async_lock::RwLock::new(entity));
152 self.insert_prepped(k, value_to_insert)
153 }
154
155 pub(crate) fn try_insert(&mut self, k: Vec<u8>, entity: V) -> Result<(), V> {
156 let value_to_insert = Arc::new(async_lock::RwLock::new(entity));
157
158 if self.0.insert(k, value_to_insert.clone()) {
159 Ok(())
160 } else {
161 Err(Arc::into_inner(value_to_insert).unwrap().into_inner())
163 }
164 }
165
166 pub(crate) fn remove(&mut self, k: &[u8]) -> Option<GroupStoreValue<V>> {
167 self.0.remove(k)
168 }
169
170 pub(crate) fn get(&mut self, k: &[u8]) -> Option<&mut GroupStoreValue<V>> {
171 self.0.get(k)
172 }
173}
174
/// LRU limiter combining an entry-count bound with a memory-usage bound.
///
/// The two bounds are enforced at different points: entry count on insert,
/// memory usage when the map wants to grow its allocation (see the
/// `schnellru::Limiter` impl below).
pub(crate) struct HybridMemoryLimiter {
    mem: schnellru::ByMemoryUsage,
    len: schnellru::ByLength,
}

/// Default memory budget for a [`GroupStore`] (~100 MB).
pub(crate) const MEMORY_LIMIT: usize = 100_000_000;
/// Default maximum number of cached entries.
pub(crate) const ITEM_LIMIT: u32 = 100;
182
183impl HybridMemoryLimiter {
184 pub(crate) fn new(count: Option<u32>, memory: Option<usize>) -> Self {
185 let memory_limit = memory.unwrap_or(MEMORY_LIMIT);
186 let mem = schnellru::ByMemoryUsage::new(memory_limit);
187 let len = schnellru::ByLength::new(count.unwrap_or(ITEM_LIMIT));
188
189 Self { mem, len }
190 }
191}
192
impl Default for HybridMemoryLimiter {
    /// Default limiter: [`ITEM_LIMIT`] entries / [`MEMORY_LIMIT`] bytes.
    fn default() -> Self {
        Self::new(None, None)
    }
}
198
// Hybrid delegation: entry-count questions go to the `ByLength` half, while
// allocation growth is vetoed by the `ByMemoryUsage` half. The fully-qualified
// `<ByLength as Limiter<K, V>>::…` syntax is required because `Limiter` is
// generic over the key/value types.
impl<K, V> schnellru::Limiter<K, V> for HybridMemoryLimiter {
    type KeyToInsert<'a> = K;
    type LinkType = u32;

    // Over-the-limit (i.e. "evict now") is decided purely by entry count.
    fn is_over_the_limit(&self, length: usize) -> bool {
        <schnellru::ByLength as schnellru::Limiter<K, V>>::is_over_the_limit(&self.len, length)
    }

    fn on_insert(&mut self, length: usize, key: Self::KeyToInsert<'_>, value: V) -> Option<(K, V)> {
        <schnellru::ByLength as schnellru::Limiter<K, V>>::on_insert(&mut self.len, length, key, value)
    }

    // Replacing an existing key never changes the entry count, so always allow it.
    fn on_replace(
        &mut self,
        _length: usize,
        _old_key: &mut K,
        _new_key: Self::KeyToInsert<'_>,
        _old_value: &mut V,
        _new_value: &mut V,
    ) -> bool {
        true
    }
    fn on_removed(&mut self, _key: &mut K, _value: &mut V) {}
    fn on_cleared(&mut self) {}

    // Growth of the backing allocation is bounded by the memory limiter.
    fn on_grow(&mut self, new_memory_usage: usize) -> bool {
        <schnellru::ByMemoryUsage as schnellru::Limiter<K, V>>::on_grow(&mut self.mem, new_memory_usage)
    }
}
229
#[cfg(test)]
mod tests {
    use core_crypto_keystore::dummy_entity::{DummyStoreValue, DummyValue};
    use wasm_bindgen_test::*;

    use super::*;

    wasm_bindgen_test_configure!(run_in_browser);

    // Minimal `GroupStoreEntity` impl backed by the keystore's dummy types:
    // "fetching" simply echoes the id back as a value, so these tests never
    // need a real database.
    #[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
    #[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
    impl GroupStoreEntity for DummyValue {
        type RawStoreValue = DummyStoreValue;

        type IdentityType = ();

        async fn fetch_from_id(
            id: &[u8],
            _identity: Option<Self::IdentityType>,
            _keystore: &impl FetchFromDatabase,
        ) -> crate::Result<Option<Self>> {
            let id = std::str::from_utf8(id).expect("dummy value ids are strings");
            Ok(Some(id.into()))
        }
    }

    type TestGroupStore = GroupStore<DummyValue>;

    // A freshly-built store is empty regardless of the configured limits,
    // including the degenerate zero-limit configurations.
    #[async_std::test]
    #[wasm_bindgen_test]
    async fn group_store_init() {
        let store = TestGroupStore::new_with_limit(1);
        assert_eq!(store.len(), 0);
        let store = TestGroupStore::new_with_limit(0);
        assert_eq!(store.len(), 0);
        let store = TestGroupStore::new(Some(0), Some(0));
        assert_eq!(store.len(), 0);
        let store = TestGroupStore::new(Some(0), Some(1));
        assert_eq!(store.len(), 0);
        let store = TestGroupStore::new(Some(1), Some(0));
        assert_eq!(store.len(), 0);
        let store = TestGroupStore::new(Some(1), Some(1));
        assert_eq!(store.len(), 0);
    }

    // With effectively unlimited bounds, both `try_insert` and `insert`
    // grow the store monotonically and every key stays resident.
    #[async_std::test]
    #[wasm_bindgen_test]
    async fn group_store_common_ops() {
        let mut store = TestGroupStore::new(Some(u32::MAX), Some(usize::MAX));
        for i in 1..=3 {
            let i_str = i.to_string();
            assert!(
                store
                    .try_insert(i_str.as_bytes().to_vec(), i_str.as_str().into())
                    .is_ok()
            );
            assert_eq!(store.len(), i);
        }
        for i in 4..=6 {
            let i_str = i.to_string();
            store.insert(i_str.as_bytes().to_vec(), i_str.as_str().into());
            assert_eq!(store.len(), i);
        }

        for i in 1..=6 {
            assert!(store.contains_key(i.to_string().as_bytes()));
        }
    }

    // With a length limit of 2, the third insert evicts the least-recently
    // used entry ("1") while the newer keys survive.
    #[async_std::test]
    #[wasm_bindgen_test]
    async fn group_store_operations_len_limiter() {
        let mut store = TestGroupStore::new_with_limit(2);
        assert!(store.try_insert(b"1".to_vec(), "1".into()).is_ok());
        assert_eq!(store.len(), 1);
        assert!(store.try_insert(b"2".to_vec(), "2".into()).is_ok());
        assert_eq!(store.len(), 2);
        assert!(store.try_insert(b"3".to_vec(), "3".into()).is_ok());
        assert_eq!(store.len(), 2);
        assert!(!store.contains_key(b"1"));
        assert!(store.contains_key(b"2"));
        assert!(store.contains_key(b"3"));
        store.insert(b"4".to_vec(), "4".into());
        assert_eq!(store.len(), 2);
    }

    // Memory-limited behavior: first measure the real memory footprint of 1
    // and 4 entries using an unlimited map, then configure the store with
    // that budget and verify capacity/usage track it as entries are added
    // and touched. NOTE(review): the expected `guaranteed_capacity` values
    // (3, 7) encode schnellru's internal growth steps — they may need
    // updating if the schnellru version changes.
    #[async_std::test]
    #[wasm_bindgen_test]
    async fn group_store_operations_mem_limiter() {
        use schnellru::{LruMap, UnlimitedCompact};
        let mut lru: LruMap<Vec<u8>, DummyValue, UnlimitedCompact> =
            LruMap::<Vec<u8>, DummyValue, UnlimitedCompact>::new(UnlimitedCompact);
        assert_eq!(lru.guaranteed_capacity(), 0);
        assert_eq!(lru.memory_usage(), 0);
        lru.insert(1usize.to_le_bytes().to_vec(), "10".into());
        let memory_usage_step_1 = lru.memory_usage();
        lru.insert(2usize.to_le_bytes().to_vec(), "20".into());
        lru.insert(3usize.to_le_bytes().to_vec(), "30".into());
        lru.insert(4usize.to_le_bytes().to_vec(), "40".into());
        let memory_usage_step_2 = lru.memory_usage();
        assert_ne!(memory_usage_step_1, memory_usage_step_2);

        let mut store = TestGroupStore::new(None, Some(memory_usage_step_2));
        assert_eq!(store.guaranteed_capacity(), 0);
        assert_eq!(store.memory_usage(), 0);
        store.try_insert(1usize.to_le_bytes().to_vec(), "10".into()).unwrap();
        assert_eq!(store.guaranteed_capacity(), 3);
        assert!(store.memory_usage() <= memory_usage_step_1);
        store.try_insert(2usize.to_le_bytes().to_vec(), "20".into()).unwrap();
        store.try_insert(3usize.to_le_bytes().to_vec(), "30".into()).unwrap();
        for i in 1..=3usize {
            assert_eq!(
                *(store.get(i.to_le_bytes().as_ref()).unwrap().read().await),
                DummyValue::from(format!("{}", i * 10).as_str())
            );
        }
        assert_eq!(store.guaranteed_capacity(), 3);
        assert!(store.memory_usage() <= memory_usage_step_1);
        assert!(store.try_insert(4usize.to_le_bytes().to_vec(), "40".into()).is_ok());
        for i in (1usize..=4).rev() {
            assert_eq!(
                *(store.get(i.to_le_bytes().as_ref()).unwrap().read().await),
                DummyValue::from(format!("{}", i * 10).as_str())
            );
        }
        assert_eq!(store.guaranteed_capacity(), 7);
        assert!(store.memory_usage() <= memory_usage_step_2);
        store.try_insert(5usize.to_le_bytes().to_vec(), "50".into()).unwrap();
        store.try_insert(6usize.to_le_bytes().to_vec(), "60".into()).unwrap();
        store.try_insert(7usize.to_le_bytes().to_vec(), "70".into()).unwrap();
        // Touch 7..=5 so they become the most-recently-used entries before
        // the next insert forces eviction decisions.
        for i in (5usize..=7).rev() {
            store.get(i.to_le_bytes().as_ref()).unwrap();
        }

        store.insert(8usize.to_le_bytes().to_vec(), "80".into());
        for i in [8usize, 7, 6, 5].iter() {
            assert_eq!(
                *(store
                    .get(i.to_le_bytes().as_ref())
                    .unwrap_or_else(|| panic!("couldn't find index {i}"))
                    .read()
                    .await),
                DummyValue::from(format!("{}", i * 10).as_str())
            );
        }

        assert_eq!(store.guaranteed_capacity(), 7);
        assert!(store.memory_usage() <= memory_usage_step_2);
        store.assert_check_internal_state();
    }
}
381}