core_crypto_keystore/
mls.rs

1use openmls::prelude::Ciphersuite;
2use openmls_basic_credential::SignatureKeyPair;
3use openmls_traits::key_store::{MlsEntity, MlsEntityId};
4
5use crate::{
6    CryptoKeystoreError, CryptoKeystoreResult, MissingKeyErrorKind,
7    connection::FetchFromDatabase,
8    entities::{
9        EntityFindParams, PersistedMlsGroup, PersistedMlsPendingGroup, StoredCredential, StoredE2eiEnrollment,
10        StoredEncryptionKeyPair, StoredEpochEncryptionKeypair, StoredHpkePrivateKey, StoredKeypackage, StoredPskBundle,
11    },
12};
13
14/// An interface for the specialized queries in the KeyStore
15#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
16#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
17pub trait CryptoKeystoreMls: Sized {
18    /// Fetches Keypackages
19    ///
20    /// # Arguments
21    /// * `count` - amount of entries to be returned
22    ///
23    /// # Errors
24    /// Any common error that can happen during a database connection. IoError being a common error
25    /// for example.
26    async fn mls_fetch_keypackages<V: MlsEntity>(&self, count: u32) -> CryptoKeystoreResult<Vec<V>>;
27
28    /// Checks if the given MLS group id exists in the keystore
29    /// Note: in case of any error, this will return false
30    ///
31    /// # Arguments
32    /// * `group_id` - group/conversation id
33    async fn mls_group_exists(&self, group_id: impl AsRef<[u8]> + Send) -> bool;
34
35    /// Persists a `MlsGroup`
36    ///
37    /// # Arguments
38    /// * `group_id` - group/conversation id
39    /// * `state` - the group state
40    ///
41    /// # Errors
42    /// Any common error that can happen during a database connection. IoError being a common error
43    /// for example.
44    async fn mls_group_persist(
45        &self,
46        group_id: impl AsRef<[u8]> + Send,
47        state: &[u8],
48        parent_group_id: Option<&[u8]>,
49    ) -> CryptoKeystoreResult<()>;
50
51    /// Loads `MlsGroups` from the database. It will be returned as a `HashMap` where the key is
52    /// the group/conversation id and the value the group state
53    ///
54    /// # Errors
55    /// Any common error that can happen during a database connection. IoError being a common error
56    /// for example.
57    async fn mls_groups_restore(
58        &self,
59    ) -> CryptoKeystoreResult<std::collections::HashMap<Vec<u8>, (Option<Vec<u8>>, Vec<u8>)>>;
60
61    /// Deletes `MlsGroups` from the database.
62    /// # Errors
63    /// Any common error that can happen during a database connection. IoError being a common error
64    /// for example.
65    async fn mls_group_delete(&self, group_id: impl AsRef<[u8]> + Send) -> CryptoKeystoreResult<()>;
66
67    /// Saves a `MlsGroup` in a temporary table (typically used in scenarios where the group cannot
68    /// be committed until the backend acknowledges it, like external commits)
69    ///
70    /// # Arguments
71    /// * `group_id` - group/conversation id
72    /// * `mls_group` - the group/conversation state
73    /// * `custom_configuration` - local group configuration
74    ///
75    /// # Errors
76    /// Any common error that can happen during a database connection. IoError being a common error
77    /// for example.
78    async fn mls_pending_groups_save(
79        &self,
80        group_id: impl AsRef<[u8]> + Send,
81        mls_group: &[u8],
82        custom_configuration: &[u8],
83        parent_group_id: Option<&[u8]>,
84    ) -> CryptoKeystoreResult<()>;
85
86    /// Loads a temporary `MlsGroup` and its configuration from the database
87    ///
88    /// # Arguments
89    /// * `id` - group/conversation id
90    ///
91    /// # Errors
92    /// Any common error that can happen during a database connection. IoError being a common error
93    /// for example.
94    async fn mls_pending_groups_load(
95        &self,
96        group_id: impl AsRef<[u8]> + Send,
97    ) -> CryptoKeystoreResult<(Vec<u8>, Vec<u8>)>;
98
99    /// Deletes a temporary `MlsGroup` from the database
100    ///
101    /// # Arguments
102    /// * `id` - group/conversation id
103    ///
104    /// # Errors
105    /// Any common error that can happen during a database connection. IoError being a common error
106    /// for example.
107    async fn mls_pending_groups_delete(&self, group_id: impl AsRef<[u8]> + Send) -> CryptoKeystoreResult<()>;
108
109    /// Persists an enrollment instance
110    ///
111    /// # Arguments
112    /// * `id` - hash of the enrollment and unique identifier
113    /// * `content` - serialized enrollment
114    async fn save_e2ei_enrollment(&self, id: &[u8], content: &[u8]) -> CryptoKeystoreResult<()>;
115
116    /// Fetches and delete the enrollment instance
117    ///
118    /// # Arguments
119    /// * `id` - hash of the enrollment and unique identifier
120    async fn pop_e2ei_enrollment(&self, id: &[u8]) -> CryptoKeystoreResult<Vec<u8>>;
121}
122
123#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
124#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
125impl CryptoKeystoreMls for crate::Database {
126    async fn mls_fetch_keypackages<V: MlsEntity>(&self, count: u32) -> CryptoKeystoreResult<Vec<V>> {
127        let reverse = !cfg!(target_family = "wasm");
128        let keypackages = self
129            .find_all::<StoredKeypackage>(EntityFindParams {
130                limit: Some(count),
131                offset: None,
132                reverse,
133            })
134            .await?;
135
136        Ok(keypackages
137            .into_iter()
138            .filter_map(|kpb| postcard::from_bytes(&kpb.keypackage).ok())
139            .collect())
140    }
141
142    async fn mls_group_exists(&self, group_id: impl AsRef<[u8]> + Send) -> bool {
143        matches!(self.find::<PersistedMlsGroup>(group_id).await, Ok(Some(_)))
144    }
145
146    async fn mls_group_persist(
147        &self,
148        group_id: impl AsRef<[u8]> + Send,
149        state: &[u8],
150        parent_group_id: Option<&[u8]>,
151    ) -> CryptoKeystoreResult<()> {
152        self.save(PersistedMlsGroup {
153            id: group_id.as_ref().to_owned(),
154            state: state.into(),
155            parent_id: parent_group_id.map(Into::into),
156        })
157        .await?;
158
159        Ok(())
160    }
161
162    async fn mls_groups_restore(
163        &self,
164    ) -> CryptoKeystoreResult<std::collections::HashMap<Vec<u8>, (Option<Vec<u8>>, Vec<u8>)>> {
165        let groups = self.find_all::<PersistedMlsGroup>(EntityFindParams::default()).await?;
166        Ok(groups
167            .into_iter()
168            .map(|group: PersistedMlsGroup| (group.id.clone(), (group.parent_id.clone(), group.state.clone())))
169            .collect())
170    }
171
172    async fn mls_group_delete(&self, group_id: impl AsRef<[u8]> + Send) -> CryptoKeystoreResult<()> {
173        self.remove::<PersistedMlsGroup, _>(group_id).await?;
174
175        Ok(())
176    }
177
178    async fn mls_pending_groups_save(
179        &self,
180        group_id: impl AsRef<[u8]> + Send,
181        mls_group: &[u8],
182        custom_configuration: &[u8],
183        parent_group_id: Option<&[u8]>,
184    ) -> CryptoKeystoreResult<()> {
185        self.save(PersistedMlsPendingGroup {
186            id: group_id.as_ref().to_owned(),
187            state: mls_group.into(),
188            custom_configuration: custom_configuration.into(),
189            parent_id: parent_group_id.map(Into::into),
190        })
191        .await?;
192        Ok(())
193    }
194
195    async fn mls_pending_groups_load(
196        &self,
197        group_id: impl AsRef<[u8]> + Send,
198    ) -> CryptoKeystoreResult<(Vec<u8>, Vec<u8>)> {
199        self.find(group_id)
200            .await?
201            .map(|r: PersistedMlsPendingGroup| (r.state.clone(), r.custom_configuration.clone()))
202            .ok_or(CryptoKeystoreError::MissingKeyInStore(
203                MissingKeyErrorKind::MlsPendingGroup,
204            ))
205    }
206
207    async fn mls_pending_groups_delete(&self, group_id: impl AsRef<[u8]> + Send) -> CryptoKeystoreResult<()> {
208        self.remove::<PersistedMlsPendingGroup, _>(group_id).await
209    }
210
211    async fn save_e2ei_enrollment(&self, id: &[u8], content: &[u8]) -> CryptoKeystoreResult<()> {
212        self.save(StoredE2eiEnrollment {
213            id: id.into(),
214            content: content.into(),
215        })
216        .await?;
217        Ok(())
218    }
219
220    async fn pop_e2ei_enrollment(&self, id: &[u8]) -> CryptoKeystoreResult<Vec<u8>> {
221        // someone who has time could try to optimize this but honestly it's really on the cold path
222        let enrollment = self
223            .find::<StoredE2eiEnrollment>(id)
224            .await?
225            .ok_or(CryptoKeystoreError::MissingKeyInStore(
226                MissingKeyErrorKind::StoredE2eiEnrollment,
227            ))?;
228        self.remove::<StoredE2eiEnrollment, _>(id).await?;
229        Ok(enrollment.content.clone())
230    }
231}
232
233#[inline(always)]
234pub fn deser<T: MlsEntity>(bytes: &[u8]) -> Result<T, CryptoKeystoreError> {
235    Ok(postcard::from_bytes(bytes)?)
236}
237
238#[inline(always)]
239pub fn ser<T: MlsEntity>(value: &T) -> Result<Vec<u8>, CryptoKeystoreError> {
240    Ok(postcard::to_stdvec(value)?)
241}
242
243#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
244#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
245impl openmls_traits::key_store::OpenMlsKeyStore for crate::connection::Database {
246    type Error = CryptoKeystoreError;
247
248    async fn store<V: MlsEntity + Sync>(&self, k: &[u8], v: &V) -> Result<(), Self::Error>
249    where
250        Self: Sized,
251    {
252        if k.is_empty() {
253            return Err(CryptoKeystoreError::MlsKeyStoreError(
254                "The provided key is empty".into(),
255            ));
256        }
257
258        let data = ser(v)?;
259
260        match V::ID {
261            MlsEntityId::GroupState => {
262                return Err(CryptoKeystoreError::IncorrectApiUsage(
263                    "Groups must not be saved using OpenMLS's APIs. You should use the keystore's provided methods",
264                ));
265            }
266            MlsEntityId::SignatureKeyPair => {
267                return Err(CryptoKeystoreError::IncorrectApiUsage(
268                    "Signature keys must not be saved using OpenMLS's APIs. Save a credential via the keystore API
269                    instead.",
270                ));
271            }
272            MlsEntityId::KeyPackage => {
273                let kp = StoredKeypackage {
274                    keypackage_ref: k.into(),
275                    keypackage: data,
276                };
277                self.save(kp).await?;
278            }
279            MlsEntityId::HpkePrivateKey => {
280                let kp = StoredHpkePrivateKey { pk: k.into(), sk: data };
281                self.save(kp).await?;
282            }
283            MlsEntityId::PskBundle => {
284                let kp = StoredPskBundle {
285                    psk_id: k.into(),
286                    psk: data,
287                };
288                self.save(kp).await?;
289            }
290            MlsEntityId::EncryptionKeyPair => {
291                let kp = StoredEncryptionKeyPair { pk: k.into(), sk: data };
292                self.save(kp).await?;
293            }
294            MlsEntityId::EpochEncryptionKeyPair => {
295                let kp = StoredEpochEncryptionKeypair {
296                    id: k.into(),
297                    keypairs: data,
298                };
299                self.save(kp).await?;
300            }
301        }
302
303        Ok(())
304    }
305
306    async fn read<V: MlsEntity>(&self, k: &[u8]) -> Option<V>
307    where
308        Self: Sized,
309    {
310        if k.is_empty() {
311            return None;
312        }
313
314        match V::ID {
315            MlsEntityId::GroupState => {
316                let group: PersistedMlsGroup = self.find(k).await.ok().flatten()?;
317                deser(&group.state).ok()
318            }
319            MlsEntityId::SignatureKeyPair => {
320                let stored_credential = self.find::<StoredCredential>(k).await.ok().flatten()?;
321                let ciphersuite = Ciphersuite::try_from(stored_credential.ciphersuite).ok()?;
322                let signature_scheme = ciphersuite.signature_algorithm();
323
324                let mls_keypair = SignatureKeyPair::from_raw(
325                    signature_scheme,
326                    stored_credential.private_key.to_vec(),
327                    stored_credential.public_key.to_vec(),
328                );
329
330                // In a well designed interface, something like this should not be necessary. However, we don't have
331                // a well-designed interface.
332                let mls_keypair_serialized = ser(&mls_keypair).ok()?;
333                deser(&mls_keypair_serialized).ok()
334            }
335            MlsEntityId::KeyPackage => {
336                let kp: StoredKeypackage = self.find(k).await.ok().flatten()?;
337                deser(&kp.keypackage).ok()
338            }
339            MlsEntityId::HpkePrivateKey => {
340                let hpke_pk: StoredHpkePrivateKey = self.find(k).await.ok().flatten()?;
341                deser(&hpke_pk.sk).ok()
342            }
343            MlsEntityId::PskBundle => {
344                let psk_bundle: StoredPskBundle = self.find(k).await.ok().flatten()?;
345                deser(&psk_bundle.psk).ok()
346            }
347            MlsEntityId::EncryptionKeyPair => {
348                let kp: StoredEncryptionKeyPair = self.find(k).await.ok().flatten()?;
349                deser(&kp.sk).ok()
350            }
351            MlsEntityId::EpochEncryptionKeyPair => {
352                let kp: StoredEpochEncryptionKeypair = self.find(k).await.ok().flatten()?;
353                deser(&kp.keypairs).ok()
354            }
355        }
356    }
357
358    async fn delete<V: MlsEntity>(&self, k: &[u8]) -> Result<(), Self::Error> {
359        match V::ID {
360            MlsEntityId::GroupState => self.remove::<PersistedMlsGroup, _>(k).await?,
361            MlsEntityId::SignatureKeyPair => unimplemented!(
362                "Deleting a signature key pair should not be done through this API, any keypair should be deleted via
363                deleting a credential."
364            ),
365            MlsEntityId::HpkePrivateKey => self.remove::<StoredHpkePrivateKey, _>(k).await?,
366            MlsEntityId::KeyPackage => self.remove::<StoredKeypackage, _>(k).await?,
367            MlsEntityId::PskBundle => self.remove::<StoredPskBundle, _>(k).await?,
368            MlsEntityId::EncryptionKeyPair => self.remove::<StoredEncryptionKeyPair, _>(k).await?,
369            MlsEntityId::EpochEncryptionKeyPair => self.remove::<StoredEpochEncryptionKeypair, _>(k).await?,
370        }
371
372        Ok(())
373    }
374}