core_crypto_keystore/
mls.rs

1use crate::connection::FetchFromDatabase;
2use crate::entities::MlsEpochEncryptionKeyPair;
3use crate::{
4    CryptoKeystoreError, CryptoKeystoreResult, MissingKeyErrorKind,
5    entities::{
6        E2eiEnrollment, EntityFindParams, MlsEncryptionKeyPair, MlsHpkePrivateKey, MlsKeyPackage, MlsPskBundle,
7        MlsSignatureKeyPair, PersistedMlsGroup, PersistedMlsPendingGroup,
8    },
9};
10use openmls_basic_credential::SignatureKeyPair;
11use openmls_traits::key_store::{MlsEntity, MlsEntityId};
12
13/// An interface for the specialized queries in the KeyStore
14#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
15#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
16pub trait CryptoKeystoreMls: Sized {
17    /// Fetches Keypackages
18    ///
19    /// # Arguments
20    /// * `count` - amount of entries to be returned
21    ///
22    /// # Errors
23    /// Any common error that can happen during a database connection. IoError being a common error
24    /// for example.
25    async fn mls_fetch_keypackages<V: MlsEntity>(&self, count: u32) -> CryptoKeystoreResult<Vec<V>>;
26
27    /// Checks if the given MLS group id exists in the keystore
28    /// Note: in case of any error, this will return false
29    ///
30    /// # Arguments
31    /// * `group_id` - group/conversation id
32    async fn mls_group_exists(&self, group_id: &[u8]) -> bool;
33
34    /// Persists a `MlsGroup`
35    ///
36    /// # Arguments
37    /// * `group_id` - group/conversation id
38    /// * `state` - the group state
39    ///
40    /// # Errors
41    /// Any common error that can happen during a database connection. IoError being a common error
42    /// for example.
43    async fn mls_group_persist(
44        &self,
45        group_id: &[u8],
46        state: &[u8],
47        parent_group_id: Option<&[u8]>,
48    ) -> CryptoKeystoreResult<()>;
49
50    /// Loads `MlsGroups` from the database. It will be returned as a `HashMap` where the key is
51    /// the group/conversation id and the value the group state
52    ///
53    /// # Errors
54    /// Any common error that can happen during a database connection. IoError being a common error
55    /// for example.
56    async fn mls_groups_restore(
57        &self,
58    ) -> CryptoKeystoreResult<std::collections::HashMap<Vec<u8>, (Option<Vec<u8>>, Vec<u8>)>>;
59
60    /// Deletes `MlsGroups` from the database.
61    /// # Errors
62    /// Any common error that can happen during a database connection. IoError being a common error
63    /// for example.
64    async fn mls_group_delete(&self, group_id: &[u8]) -> CryptoKeystoreResult<()>;
65
66    /// Saves a `MlsGroup` in a temporary table (typically used in scenarios where the group cannot
67    /// be committed until the backend acknowledges it, like external commits)
68    ///
69    /// # Arguments
70    /// * `group_id` - group/conversation id
71    /// * `mls_group` - the group/conversation state
72    /// * `custom_configuration` - local group configuration
73    ///
74    /// # Errors
75    /// Any common error that can happen during a database connection. IoError being a common error
76    /// for example.
77    async fn mls_pending_groups_save(
78        &self,
79        group_id: &[u8],
80        mls_group: &[u8],
81        custom_configuration: &[u8],
82        parent_group_id: Option<&[u8]>,
83    ) -> CryptoKeystoreResult<()>;
84
85    /// Loads a temporary `MlsGroup` and its configuration from the database
86    ///
87    /// # Arguments
88    /// * `id` - group/conversation id
89    ///
90    /// # Errors
91    /// Any common error that can happen during a database connection. IoError being a common error
92    /// for example.
93    async fn mls_pending_groups_load(&self, group_id: &[u8]) -> CryptoKeystoreResult<(Vec<u8>, Vec<u8>)>;
94
95    /// Deletes a temporary `MlsGroup` from the database
96    ///
97    /// # Arguments
98    /// * `id` - group/conversation id
99    ///
100    /// # Errors
101    /// Any common error that can happen during a database connection. IoError being a common error
102    /// for example.
103    async fn mls_pending_groups_delete(&self, group_id: &[u8]) -> CryptoKeystoreResult<()>;
104
105    /// Persists an enrollment instance
106    ///
107    /// # Arguments
108    /// * `id` - hash of the enrollment and unique identifier
109    /// * `content` - serialized enrollment
110    async fn save_e2ei_enrollment(&self, id: &[u8], content: &[u8]) -> CryptoKeystoreResult<()>;
111
112    /// Fetches and delete the enrollment instance
113    ///
114    /// # Arguments
115    /// * `id` - hash of the enrollment and unique identifier
116    async fn pop_e2ei_enrollment(&self, id: &[u8]) -> CryptoKeystoreResult<Vec<u8>>;
117}
118
119#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
120#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
121impl CryptoKeystoreMls for crate::Connection {
122    async fn mls_fetch_keypackages<V: MlsEntity>(&self, count: u32) -> CryptoKeystoreResult<Vec<V>> {
123        cfg_if::cfg_if! {
124            if #[cfg(not(target_family = "wasm"))] {
125                let reverse = true;
126            } else {
127                let reverse = false;
128            }
129        }
130        let keypackages = self
131            .find_all::<MlsKeyPackage>(EntityFindParams {
132                limit: Some(count),
133                offset: None,
134                reverse,
135            })
136            .await?;
137
138        Ok(keypackages
139            .into_iter()
140            .filter_map(|kpb| postcard::from_bytes(&kpb.keypackage).ok())
141            .collect())
142    }
143
144    async fn mls_group_exists(&self, group_id: &[u8]) -> bool {
145        matches!(self.find::<PersistedMlsGroup>(group_id).await, Ok(Some(_)))
146    }
147
148    async fn mls_group_persist(
149        &self,
150        group_id: &[u8],
151        state: &[u8],
152        parent_group_id: Option<&[u8]>,
153    ) -> CryptoKeystoreResult<()> {
154        self.save(PersistedMlsGroup {
155            id: group_id.into(),
156            state: state.into(),
157            parent_id: parent_group_id.map(Into::into),
158        })
159        .await?;
160
161        Ok(())
162    }
163
164    async fn mls_groups_restore(
165        &self,
166    ) -> CryptoKeystoreResult<std::collections::HashMap<Vec<u8>, (Option<Vec<u8>>, Vec<u8>)>> {
167        let groups = self.find_all::<PersistedMlsGroup>(EntityFindParams::default()).await?;
168        Ok(groups
169            .into_iter()
170            .map(|group: PersistedMlsGroup| (group.id.clone(), (group.parent_id.clone(), group.state.clone())))
171            .collect())
172    }
173
174    async fn mls_group_delete(&self, group_id: &[u8]) -> CryptoKeystoreResult<()> {
175        self.remove::<PersistedMlsGroup, _>(group_id).await?;
176
177        Ok(())
178    }
179
180    async fn mls_pending_groups_save(
181        &self,
182        group_id: &[u8],
183        mls_group: &[u8],
184        custom_configuration: &[u8],
185        parent_group_id: Option<&[u8]>,
186    ) -> CryptoKeystoreResult<()> {
187        self.save(PersistedMlsPendingGroup {
188            id: group_id.into(),
189            state: mls_group.into(),
190            custom_configuration: custom_configuration.into(),
191            parent_id: parent_group_id.map(Into::into),
192        })
193        .await?;
194        Ok(())
195    }
196
197    async fn mls_pending_groups_load(&self, group_id: &[u8]) -> CryptoKeystoreResult<(Vec<u8>, Vec<u8>)> {
198        self.find(group_id)
199            .await?
200            .map(|r: PersistedMlsPendingGroup| (r.state.clone(), r.custom_configuration.clone()))
201            .ok_or(CryptoKeystoreError::MissingKeyInStore(
202                MissingKeyErrorKind::MlsPendingGroup,
203            ))
204    }
205
206    async fn mls_pending_groups_delete(&self, group_id: &[u8]) -> CryptoKeystoreResult<()> {
207        self.remove::<PersistedMlsPendingGroup, _>(group_id).await
208    }
209
210    async fn save_e2ei_enrollment(&self, id: &[u8], content: &[u8]) -> CryptoKeystoreResult<()> {
211        self.save(E2eiEnrollment {
212            id: id.into(),
213            content: content.into(),
214        })
215        .await?;
216        Ok(())
217    }
218
219    async fn pop_e2ei_enrollment(&self, id: &[u8]) -> CryptoKeystoreResult<Vec<u8>> {
220        // someone who has time could try to optimize this but honestly it's really on the cold path
221        let enrollment = self
222            .find::<E2eiEnrollment>(id)
223            .await?
224            .ok_or(CryptoKeystoreError::MissingKeyInStore(
225                MissingKeyErrorKind::E2eiEnrollment,
226            ))?;
227        self.remove::<E2eiEnrollment, _>(id).await?;
228        Ok(enrollment.content.clone())
229    }
230}
231
232#[inline(always)]
233pub fn deser<T: MlsEntity>(bytes: &[u8]) -> Result<T, CryptoKeystoreError> {
234    Ok(postcard::from_bytes(bytes)?)
235}
236
237#[inline(always)]
238pub fn ser<T: MlsEntity>(value: &T) -> Result<Vec<u8>, CryptoKeystoreError> {
239    Ok(postcard::to_stdvec(value)?)
240}
241
242#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
243#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
244impl openmls_traits::key_store::OpenMlsKeyStore for crate::connection::Connection {
245    type Error = CryptoKeystoreError;
246
247    async fn store<V: MlsEntity + Sync>(&self, k: &[u8], v: &V) -> Result<(), Self::Error>
248    where
249        Self: Sized,
250    {
251        if k.is_empty() {
252            return Err(CryptoKeystoreError::MlsKeyStoreError(
253                "The provided key is empty".into(),
254            ));
255        }
256
257        let data = ser(v)?;
258
259        match V::ID {
260            MlsEntityId::GroupState => {
261                return Err(CryptoKeystoreError::IncorrectApiUsage(
262                    "Groups must not be saved using OpenMLS's APIs. You should use the keystore's provided methods",
263                ));
264            }
265            MlsEntityId::SignatureKeyPair => {
266                let concrete_signature_keypair: &SignatureKeyPair = v
267                    .downcast()
268                    .expect("There's an implementation issue in OpenMLS. This shouln't be happening.");
269
270                // Having an empty credential id seems tolerable, since the SignatureKeyPair type is retrieved from the key store via its public key.
271                let credential_id = vec![];
272                let kp = MlsSignatureKeyPair::new(
273                    concrete_signature_keypair.signature_scheme(),
274                    k.into(),
275                    data,
276                    credential_id,
277                );
278                self.save(kp).await?;
279            }
280            MlsEntityId::KeyPackage => {
281                let kp = MlsKeyPackage {
282                    keypackage_ref: k.into(),
283                    keypackage: data,
284                };
285                self.save(kp).await?;
286            }
287            MlsEntityId::HpkePrivateKey => {
288                let kp = MlsHpkePrivateKey { pk: k.into(), sk: data };
289                self.save(kp).await?;
290            }
291            MlsEntityId::PskBundle => {
292                let kp = MlsPskBundle {
293                    psk_id: k.into(),
294                    psk: data,
295                };
296                self.save(kp).await?;
297            }
298            MlsEntityId::EncryptionKeyPair => {
299                let kp = MlsEncryptionKeyPair { pk: k.into(), sk: data };
300                self.save(kp).await?;
301            }
302            MlsEntityId::EpochEncryptionKeyPair => {
303                let kp = MlsEpochEncryptionKeyPair {
304                    id: k.into(),
305                    keypairs: data,
306                };
307                self.save(kp).await?;
308            }
309        }
310
311        Ok(())
312    }
313
314    async fn read<V: MlsEntity>(&self, k: &[u8]) -> Option<V>
315    where
316        Self: Sized,
317    {
318        if k.is_empty() {
319            return None;
320        }
321
322        match V::ID {
323            MlsEntityId::GroupState => {
324                let group: PersistedMlsGroup = self.find(k).await.ok().flatten()?;
325                deser(&group.state).ok()
326            }
327            MlsEntityId::SignatureKeyPair => {
328                let sig: MlsSignatureKeyPair = self.find(k).await.ok().flatten()?;
329                deser(&sig.keypair).ok()
330            }
331            MlsEntityId::KeyPackage => {
332                let kp: MlsKeyPackage = self.find(k).await.ok().flatten()?;
333                deser(&kp.keypackage).ok()
334            }
335            MlsEntityId::HpkePrivateKey => {
336                let hpke_pk: MlsHpkePrivateKey = self.find(k).await.ok().flatten()?;
337                deser(&hpke_pk.sk).ok()
338            }
339            MlsEntityId::PskBundle => {
340                let psk_bundle: MlsPskBundle = self.find(k).await.ok().flatten()?;
341                deser(&psk_bundle.psk).ok()
342            }
343            MlsEntityId::EncryptionKeyPair => {
344                let kp: MlsEncryptionKeyPair = self.find(k).await.ok().flatten()?;
345                deser(&kp.sk).ok()
346            }
347            MlsEntityId::EpochEncryptionKeyPair => {
348                let kp: MlsEpochEncryptionKeyPair = self.find(k).await.ok().flatten()?;
349                deser(&kp.keypairs).ok()
350            }
351        }
352    }
353
354    async fn delete<V: MlsEntity>(&self, k: &[u8]) -> Result<(), Self::Error> {
355        match V::ID {
356            MlsEntityId::GroupState => self.remove::<PersistedMlsGroup, _>(k).await?,
357            MlsEntityId::SignatureKeyPair => self.remove::<MlsSignatureKeyPair, _>(k).await?,
358            MlsEntityId::HpkePrivateKey => self.remove::<MlsHpkePrivateKey, _>(k).await?,
359            MlsEntityId::KeyPackage => self.remove::<MlsKeyPackage, _>(k).await?,
360            MlsEntityId::PskBundle => self.remove::<MlsPskBundle, _>(k).await?,
361            MlsEntityId::EncryptionKeyPair => self.remove::<MlsEncryptionKeyPair, _>(k).await?,
362            MlsEntityId::EpochEncryptionKeyPair => self.remove::<MlsEpochEncryptionKeyPair, _>(k).await?,
363        }
364
365        Ok(())
366    }
367}