core_crypto_keystore/
mls.rs

1use openmls_basic_credential::SignatureKeyPair;
2use openmls_traits::key_store::{MlsEntity, MlsEntityId};
3
4use crate::{
5    CryptoKeystoreError, CryptoKeystoreResult, MissingKeyErrorKind,
6    connection::FetchFromDatabase,
7    entities::{
8        E2eiEnrollment, EntityFindParams, MlsEncryptionKeyPair, MlsEpochEncryptionKeyPair, MlsHpkePrivateKey,
9        MlsKeyPackage, MlsPskBundle, MlsSignatureKeyPair, PersistedMlsGroup, PersistedMlsPendingGroup,
10    },
11};
12
13/// An interface for the specialized queries in the KeyStore
14#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
15#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
16pub trait CryptoKeystoreMls: Sized {
17    /// Fetches Keypackages
18    ///
19    /// # Arguments
20    /// * `count` - amount of entries to be returned
21    ///
22    /// # Errors
23    /// Any common error that can happen during a database connection. IoError being a common error
24    /// for example.
25    async fn mls_fetch_keypackages<V: MlsEntity>(&self, count: u32) -> CryptoKeystoreResult<Vec<V>>;
26
27    /// Checks if the given MLS group id exists in the keystore
28    /// Note: in case of any error, this will return false
29    ///
30    /// # Arguments
31    /// * `group_id` - group/conversation id
32    async fn mls_group_exists(&self, group_id: impl AsRef<[u8]> + Send) -> bool;
33
34    /// Persists a `MlsGroup`
35    ///
36    /// # Arguments
37    /// * `group_id` - group/conversation id
38    /// * `state` - the group state
39    ///
40    /// # Errors
41    /// Any common error that can happen during a database connection. IoError being a common error
42    /// for example.
43    async fn mls_group_persist(
44        &self,
45        group_id: impl AsRef<[u8]> + Send,
46        state: &[u8],
47        parent_group_id: Option<&[u8]>,
48    ) -> CryptoKeystoreResult<()>;
49
50    /// Loads `MlsGroups` from the database. It will be returned as a `HashMap` where the key is
51    /// the group/conversation id and the value the group state
52    ///
53    /// # Errors
54    /// Any common error that can happen during a database connection. IoError being a common error
55    /// for example.
56    async fn mls_groups_restore(
57        &self,
58    ) -> CryptoKeystoreResult<std::collections::HashMap<Vec<u8>, (Option<Vec<u8>>, Vec<u8>)>>;
59
60    /// Deletes `MlsGroups` from the database.
61    /// # Errors
62    /// Any common error that can happen during a database connection. IoError being a common error
63    /// for example.
64    async fn mls_group_delete(&self, group_id: impl AsRef<[u8]> + Send) -> CryptoKeystoreResult<()>;
65
66    /// Saves a `MlsGroup` in a temporary table (typically used in scenarios where the group cannot
67    /// be committed until the backend acknowledges it, like external commits)
68    ///
69    /// # Arguments
70    /// * `group_id` - group/conversation id
71    /// * `mls_group` - the group/conversation state
72    /// * `custom_configuration` - local group configuration
73    ///
74    /// # Errors
75    /// Any common error that can happen during a database connection. IoError being a common error
76    /// for example.
77    async fn mls_pending_groups_save(
78        &self,
79        group_id: impl AsRef<[u8]> + Send,
80        mls_group: &[u8],
81        custom_configuration: &[u8],
82        parent_group_id: Option<&[u8]>,
83    ) -> CryptoKeystoreResult<()>;
84
85    /// Loads a temporary `MlsGroup` and its configuration from the database
86    ///
87    /// # Arguments
88    /// * `id` - group/conversation id
89    ///
90    /// # Errors
91    /// Any common error that can happen during a database connection. IoError being a common error
92    /// for example.
93    async fn mls_pending_groups_load(
94        &self,
95        group_id: impl AsRef<[u8]> + Send,
96    ) -> CryptoKeystoreResult<(Vec<u8>, Vec<u8>)>;
97
98    /// Deletes a temporary `MlsGroup` from the database
99    ///
100    /// # Arguments
101    /// * `id` - group/conversation id
102    ///
103    /// # Errors
104    /// Any common error that can happen during a database connection. IoError being a common error
105    /// for example.
106    async fn mls_pending_groups_delete(&self, group_id: impl AsRef<[u8]> + Send) -> CryptoKeystoreResult<()>;
107
108    /// Persists an enrollment instance
109    ///
110    /// # Arguments
111    /// * `id` - hash of the enrollment and unique identifier
112    /// * `content` - serialized enrollment
113    async fn save_e2ei_enrollment(&self, id: &[u8], content: &[u8]) -> CryptoKeystoreResult<()>;
114
115    /// Fetches and delete the enrollment instance
116    ///
117    /// # Arguments
118    /// * `id` - hash of the enrollment and unique identifier
119    async fn pop_e2ei_enrollment(&self, id: &[u8]) -> CryptoKeystoreResult<Vec<u8>>;
120}
121
122#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
123#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
124impl CryptoKeystoreMls for crate::Database {
125    async fn mls_fetch_keypackages<V: MlsEntity>(&self, count: u32) -> CryptoKeystoreResult<Vec<V>> {
126        cfg_if::cfg_if! {
127            if #[cfg(not(target_family = "wasm"))] {
128                let reverse = true;
129            } else {
130                let reverse = false;
131            }
132        }
133        let keypackages = self
134            .find_all::<MlsKeyPackage>(EntityFindParams {
135                limit: Some(count),
136                offset: None,
137                reverse,
138            })
139            .await?;
140
141        Ok(keypackages
142            .into_iter()
143            .filter_map(|kpb| postcard::from_bytes(&kpb.keypackage).ok())
144            .collect())
145    }
146
147    async fn mls_group_exists(&self, group_id: impl AsRef<[u8]> + Send) -> bool {
148        matches!(self.find::<PersistedMlsGroup>(group_id).await, Ok(Some(_)))
149    }
150
151    async fn mls_group_persist(
152        &self,
153        group_id: impl AsRef<[u8]> + Send,
154        state: &[u8],
155        parent_group_id: Option<&[u8]>,
156    ) -> CryptoKeystoreResult<()> {
157        self.save(PersistedMlsGroup {
158            id: group_id.as_ref().to_owned(),
159            state: state.into(),
160            parent_id: parent_group_id.map(Into::into),
161        })
162        .await?;
163
164        Ok(())
165    }
166
167    async fn mls_groups_restore(
168        &self,
169    ) -> CryptoKeystoreResult<std::collections::HashMap<Vec<u8>, (Option<Vec<u8>>, Vec<u8>)>> {
170        let groups = self.find_all::<PersistedMlsGroup>(EntityFindParams::default()).await?;
171        Ok(groups
172            .into_iter()
173            .map(|group: PersistedMlsGroup| (group.id.clone(), (group.parent_id.clone(), group.state.clone())))
174            .collect())
175    }
176
177    async fn mls_group_delete(&self, group_id: impl AsRef<[u8]> + Send) -> CryptoKeystoreResult<()> {
178        self.remove::<PersistedMlsGroup, _>(group_id).await?;
179
180        Ok(())
181    }
182
183    async fn mls_pending_groups_save(
184        &self,
185        group_id: impl AsRef<[u8]> + Send,
186        mls_group: &[u8],
187        custom_configuration: &[u8],
188        parent_group_id: Option<&[u8]>,
189    ) -> CryptoKeystoreResult<()> {
190        self.save(PersistedMlsPendingGroup {
191            id: group_id.as_ref().to_owned(),
192            state: mls_group.into(),
193            custom_configuration: custom_configuration.into(),
194            parent_id: parent_group_id.map(Into::into),
195        })
196        .await?;
197        Ok(())
198    }
199
200    async fn mls_pending_groups_load(
201        &self,
202        group_id: impl AsRef<[u8]> + Send,
203    ) -> CryptoKeystoreResult<(Vec<u8>, Vec<u8>)> {
204        self.find(group_id)
205            .await?
206            .map(|r: PersistedMlsPendingGroup| (r.state.clone(), r.custom_configuration.clone()))
207            .ok_or(CryptoKeystoreError::MissingKeyInStore(
208                MissingKeyErrorKind::MlsPendingGroup,
209            ))
210    }
211
212    async fn mls_pending_groups_delete(&self, group_id: impl AsRef<[u8]> + Send) -> CryptoKeystoreResult<()> {
213        self.remove::<PersistedMlsPendingGroup, _>(group_id).await
214    }
215
216    async fn save_e2ei_enrollment(&self, id: &[u8], content: &[u8]) -> CryptoKeystoreResult<()> {
217        self.save(E2eiEnrollment {
218            id: id.into(),
219            content: content.into(),
220        })
221        .await?;
222        Ok(())
223    }
224
225    async fn pop_e2ei_enrollment(&self, id: &[u8]) -> CryptoKeystoreResult<Vec<u8>> {
226        // someone who has time could try to optimize this but honestly it's really on the cold path
227        let enrollment = self
228            .find::<E2eiEnrollment>(id)
229            .await?
230            .ok_or(CryptoKeystoreError::MissingKeyInStore(
231                MissingKeyErrorKind::E2eiEnrollment,
232            ))?;
233        self.remove::<E2eiEnrollment, _>(id).await?;
234        Ok(enrollment.content.clone())
235    }
236}
237
238#[inline(always)]
239pub fn deser<T: MlsEntity>(bytes: &[u8]) -> Result<T, CryptoKeystoreError> {
240    Ok(postcard::from_bytes(bytes)?)
241}
242
243#[inline(always)]
244pub fn ser<T: MlsEntity>(value: &T) -> Result<Vec<u8>, CryptoKeystoreError> {
245    Ok(postcard::to_stdvec(value)?)
246}
247
248#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
249#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
250impl openmls_traits::key_store::OpenMlsKeyStore for crate::connection::Database {
251    type Error = CryptoKeystoreError;
252
253    async fn store<V: MlsEntity + Sync>(&self, k: &[u8], v: &V) -> Result<(), Self::Error>
254    where
255        Self: Sized,
256    {
257        if k.is_empty() {
258            return Err(CryptoKeystoreError::MlsKeyStoreError(
259                "The provided key is empty".into(),
260            ));
261        }
262
263        let data = ser(v)?;
264
265        match V::ID {
266            MlsEntityId::GroupState => {
267                return Err(CryptoKeystoreError::IncorrectApiUsage(
268                    "Groups must not be saved using OpenMLS's APIs. You should use the keystore's provided methods",
269                ));
270            }
271            MlsEntityId::SignatureKeyPair => {
272                let concrete_signature_keypair: &SignatureKeyPair = v
273                    .downcast()
274                    .expect("There's an implementation issue in OpenMLS. This shouln't be happening.");
275
276                // Having an empty credential id seems tolerable, since the SignatureKeyPair type is retrieved from the key store via its public key.
277                let credential_id = vec![];
278                let kp = MlsSignatureKeyPair::new(
279                    concrete_signature_keypair.signature_scheme(),
280                    k.into(),
281                    data,
282                    credential_id,
283                );
284                self.save(kp).await?;
285            }
286            MlsEntityId::KeyPackage => {
287                let kp = MlsKeyPackage {
288                    keypackage_ref: k.into(),
289                    keypackage: data,
290                };
291                self.save(kp).await?;
292            }
293            MlsEntityId::HpkePrivateKey => {
294                let kp = MlsHpkePrivateKey { pk: k.into(), sk: data };
295                self.save(kp).await?;
296            }
297            MlsEntityId::PskBundle => {
298                let kp = MlsPskBundle {
299                    psk_id: k.into(),
300                    psk: data,
301                };
302                self.save(kp).await?;
303            }
304            MlsEntityId::EncryptionKeyPair => {
305                let kp = MlsEncryptionKeyPair { pk: k.into(), sk: data };
306                self.save(kp).await?;
307            }
308            MlsEntityId::EpochEncryptionKeyPair => {
309                let kp = MlsEpochEncryptionKeyPair {
310                    id: k.into(),
311                    keypairs: data,
312                };
313                self.save(kp).await?;
314            }
315        }
316
317        Ok(())
318    }
319
320    async fn read<V: MlsEntity>(&self, k: &[u8]) -> Option<V>
321    where
322        Self: Sized,
323    {
324        if k.is_empty() {
325            return None;
326        }
327
328        match V::ID {
329            MlsEntityId::GroupState => {
330                let group: PersistedMlsGroup = self.find(k).await.ok().flatten()?;
331                deser(&group.state).ok()
332            }
333            MlsEntityId::SignatureKeyPair => {
334                let sig: MlsSignatureKeyPair = self.find(k).await.ok().flatten()?;
335                deser(&sig.keypair).ok()
336            }
337            MlsEntityId::KeyPackage => {
338                let kp: MlsKeyPackage = self.find(k).await.ok().flatten()?;
339                deser(&kp.keypackage).ok()
340            }
341            MlsEntityId::HpkePrivateKey => {
342                let hpke_pk: MlsHpkePrivateKey = self.find(k).await.ok().flatten()?;
343                deser(&hpke_pk.sk).ok()
344            }
345            MlsEntityId::PskBundle => {
346                let psk_bundle: MlsPskBundle = self.find(k).await.ok().flatten()?;
347                deser(&psk_bundle.psk).ok()
348            }
349            MlsEntityId::EncryptionKeyPair => {
350                let kp: MlsEncryptionKeyPair = self.find(k).await.ok().flatten()?;
351                deser(&kp.sk).ok()
352            }
353            MlsEntityId::EpochEncryptionKeyPair => {
354                let kp: MlsEpochEncryptionKeyPair = self.find(k).await.ok().flatten()?;
355                deser(&kp.keypairs).ok()
356            }
357        }
358    }
359
360    async fn delete<V: MlsEntity>(&self, k: &[u8]) -> Result<(), Self::Error> {
361        match V::ID {
362            MlsEntityId::GroupState => self.remove::<PersistedMlsGroup, _>(k).await?,
363            MlsEntityId::SignatureKeyPair => self.remove::<MlsSignatureKeyPair, _>(k).await?,
364            MlsEntityId::HpkePrivateKey => self.remove::<MlsHpkePrivateKey, _>(k).await?,
365            MlsEntityId::KeyPackage => self.remove::<MlsKeyPackage, _>(k).await?,
366            MlsEntityId::PskBundle => self.remove::<MlsPskBundle, _>(k).await?,
367            MlsEntityId::EncryptionKeyPair => self.remove::<MlsEncryptionKeyPair, _>(k).await?,
368            MlsEntityId::EpochEncryptionKeyPair => self.remove::<MlsEpochEncryptionKeyPair, _>(k).await?,
369        }
370
371        Ok(())
372    }
373}