core_crypto_keystore/
mls.rs

1use openmls::prelude::Ciphersuite;
2use openmls_basic_credential::SignatureKeyPair;
3use openmls_traits::key_store::{MlsEntity, MlsEntityId};
4
5use crate::{
6    CryptoKeystoreError, CryptoKeystoreResult, Sha256Hash,
7    entities::{
8        PersistedMlsGroup, PersistedMlsPendingGroup, StoredCredential, StoredE2eiEnrollment, StoredEncryptionKeyPair,
9        StoredEpochEncryptionKeypair, StoredHpkePrivateKey, StoredKeypackage, StoredPskBundle,
10    },
11    traits::FetchFromDatabase,
12};
13
14/// An interface for the specialized queries in the KeyStore
15#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
16#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
17pub trait CryptoKeystoreMls: Sized {
18    /// Fetches Keypackages
19    ///
20    /// # Arguments
21    /// * `count` - amount of entries to be returned
22    ///
23    /// # Errors
24    /// Any common error that can happen during a database connection. IoError being a common error
25    /// for example.
26    async fn mls_fetch_keypackages<V: MlsEntity>(&self, count: u32) -> CryptoKeystoreResult<Vec<V>>;
27
28    /// Checks if the given MLS group id exists in the keystore
29    /// Note: in case of any error, this will return false
30    ///
31    /// # Arguments
32    /// * `group_id` - group/conversation id
33    async fn mls_group_exists(&self, group_id: impl AsRef<[u8]> + Send) -> bool;
34
35    /// Persists a `MlsGroup`
36    ///
37    /// # Arguments
38    /// * `group_id` - group/conversation id
39    /// * `state` - the group state
40    ///
41    /// # Errors
42    /// Any common error that can happen during a database connection. IoError being a common error
43    /// for example.
44    async fn mls_group_persist(
45        &self,
46        group_id: impl AsRef<[u8]> + Send,
47        state: &[u8],
48        parent_group_id: Option<&[u8]>,
49    ) -> CryptoKeystoreResult<()>;
50
51    /// Loads `MlsGroups` from the database. It will be returned as a `HashMap` where the key is
52    /// the group/conversation id and the value the group state
53    ///
54    /// # Errors
55    /// Any common error that can happen during a database connection. IoError being a common error
56    /// for example.
57    async fn mls_groups_restore(
58        &self,
59    ) -> CryptoKeystoreResult<std::collections::HashMap<Vec<u8>, (Option<Vec<u8>>, Vec<u8>)>>;
60
61    /// Deletes `MlsGroups` from the database.
62    /// # Errors
63    /// Any common error that can happen during a database connection. IoError being a common error
64    /// for example.
65    async fn mls_group_delete(&self, group_id: impl AsRef<[u8]> + Send) -> CryptoKeystoreResult<()>;
66
67    /// Saves a `MlsGroup` in a temporary table (typically used in scenarios where the group cannot
68    /// be committed until the backend acknowledges it, like external commits)
69    ///
70    /// # Arguments
71    /// * `group_id` - group/conversation id
72    /// * `mls_group` - the group/conversation state
73    /// * `custom_configuration` - local group configuration
74    ///
75    /// # Errors
76    /// Any common error that can happen during a database connection. IoError being a common error
77    /// for example.
78    async fn mls_pending_groups_save(
79        &self,
80        group_id: impl AsRef<[u8]> + Send,
81        mls_group: &[u8],
82        custom_configuration: &[u8],
83        parent_group_id: Option<&[u8]>,
84    ) -> CryptoKeystoreResult<()>;
85
86    /// Loads a temporary `MlsGroup` and its configuration from the database
87    ///
88    /// # Arguments
89    /// * `id` - group/conversation id
90    ///
91    /// # Errors
92    /// Any common error that can happen during a database connection. IoError being a common error
93    /// for example.
94    async fn mls_pending_groups_load(
95        &self,
96        group_id: impl AsRef<[u8]> + Send,
97    ) -> CryptoKeystoreResult<Option<(Vec<u8>, Vec<u8>)>>;
98
99    /// Deletes a temporary `MlsGroup` from the database
100    ///
101    /// # Arguments
102    /// * `id` - group/conversation id
103    ///
104    /// # Errors
105    /// Any common error that can happen during a database connection. IoError being a common error
106    /// for example.
107    async fn mls_pending_groups_delete(&self, group_id: impl AsRef<[u8]> + Send) -> CryptoKeystoreResult<()>;
108
109    /// Persists an enrollment instance
110    ///
111    /// # Arguments
112    /// * `id` - hash of the enrollment and unique identifier
113    /// * `content` - serialized enrollment
114    async fn save_e2ei_enrollment(&self, id: &[u8], content: &[u8]) -> CryptoKeystoreResult<()>;
115
116    /// Fetches and delete the enrollment instance
117    ///
118    /// # Arguments
119    /// * `id` - hash of the enrollment and unique identifier
120    async fn pop_e2ei_enrollment(&self, id: &[u8]) -> CryptoKeystoreResult<Option<Vec<u8>>>;
121}
122
123#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
124#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
125impl CryptoKeystoreMls for crate::Database {
126    async fn mls_fetch_keypackages<V: MlsEntity>(&self, count: u32) -> CryptoKeystoreResult<Vec<V>> {
127        let keypackages = self.load_all::<StoredKeypackage>().await?;
128        Ok(keypackages
129            .into_iter()
130            .filter_map(|kpb| postcard::from_bytes(&kpb.keypackage).ok())
131            .take(count as _)
132            .collect())
133    }
134
135    async fn mls_group_exists(&self, group_id: impl AsRef<[u8]> + Send) -> bool {
136        matches!(
137            self.get_borrowed::<PersistedMlsGroup>(group_id.as_ref()).await,
138            Ok(Some(_))
139        )
140    }
141
142    async fn mls_group_persist(
143        &self,
144        group_id: impl AsRef<[u8]> + Send,
145        state: &[u8],
146        parent_group_id: Option<&[u8]>,
147    ) -> CryptoKeystoreResult<()> {
148        self.save(PersistedMlsGroup {
149            id: group_id.as_ref().to_owned(),
150            state: state.into(),
151            parent_id: parent_group_id.map(Into::into),
152        })
153        .await?;
154
155        Ok(())
156    }
157
158    async fn mls_groups_restore(
159        &self,
160    ) -> CryptoKeystoreResult<std::collections::HashMap<Vec<u8>, (Option<Vec<u8>>, Vec<u8>)>> {
161        let groups = self.load_all::<PersistedMlsGroup>().await?;
162        Ok(groups
163            .into_iter()
164            .map(|mut group: PersistedMlsGroup| {
165                let id = std::mem::take(&mut group.id);
166                let parent_id = std::mem::take(&mut group.parent_id);
167                let state = std::mem::take(&mut group.state);
168                (id, (parent_id, state))
169            })
170            .collect())
171    }
172
173    async fn mls_group_delete(&self, group_id: impl AsRef<[u8]> + Send) -> CryptoKeystoreResult<()> {
174        self.remove_borrowed::<PersistedMlsGroup>(group_id.as_ref()).await?;
175        Ok(())
176    }
177
178    async fn mls_pending_groups_save(
179        &self,
180        group_id: impl AsRef<[u8]> + Send,
181        mls_group: &[u8],
182        custom_configuration: &[u8],
183        parent_group_id: Option<&[u8]>,
184    ) -> CryptoKeystoreResult<()> {
185        self.save(PersistedMlsPendingGroup {
186            id: group_id.as_ref().to_owned(),
187            state: mls_group.into(),
188            custom_configuration: custom_configuration.into(),
189            parent_id: parent_group_id.map(Into::into),
190        })
191        .await?;
192        Ok(())
193    }
194
195    async fn mls_pending_groups_load(
196        &self,
197        group_id: impl AsRef<[u8]> + Send,
198    ) -> CryptoKeystoreResult<Option<(Vec<u8>, Vec<u8>)>> {
199        self.get_borrowed::<PersistedMlsPendingGroup>(group_id.as_ref())
200            .await
201            .map(|optional| {
202                optional.map(|pending_group| (pending_group.state.clone(), pending_group.custom_configuration.clone()))
203            })
204    }
205
206    async fn mls_pending_groups_delete(&self, group_id: impl AsRef<[u8]> + Send) -> CryptoKeystoreResult<()> {
207        self.remove_borrowed::<PersistedMlsPendingGroup>(group_id.as_ref())
208            .await
209    }
210
211    async fn save_e2ei_enrollment(&self, id: &[u8], content: &[u8]) -> CryptoKeystoreResult<()> {
212        self.save(StoredE2eiEnrollment {
213            id: id.into(),
214            content: content.into(),
215        })
216        .await?;
217        Ok(())
218    }
219
220    async fn pop_e2ei_enrollment(&self, id: &[u8]) -> CryptoKeystoreResult<Option<Vec<u8>>> {
221        // someone who has time could try to optimize this but honestly it's really on the cold path
222        let Some(mut enrollment) = self.get_borrowed::<StoredE2eiEnrollment>(id).await? else {
223            return Ok(None);
224        };
225        self.remove_borrowed::<StoredE2eiEnrollment>(id).await?;
226        Ok(Some(std::mem::take(&mut enrollment.content)))
227    }
228}
229
230#[inline(always)]
231pub fn deser<T: MlsEntity>(bytes: &[u8]) -> Result<T, CryptoKeystoreError> {
232    Ok(postcard::from_bytes(bytes)?)
233}
234
235#[inline(always)]
236pub fn ser<T: MlsEntity>(value: &T) -> Result<Vec<u8>, CryptoKeystoreError> {
237    Ok(postcard::to_stdvec(value)?)
238}
239
240#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
241#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
242impl openmls_traits::key_store::OpenMlsKeyStore for crate::connection::Database {
243    type Error = CryptoKeystoreError;
244
245    async fn store<V: MlsEntity + Sync>(&self, k: &[u8], v: &V) -> Result<(), Self::Error>
246    where
247        Self: Sized,
248    {
249        if k.is_empty() {
250            return Err(CryptoKeystoreError::MlsKeyStoreError(
251                "The provided key is empty".into(),
252            ));
253        }
254
255        let data = ser(v)?;
256
257        match V::ID {
258            MlsEntityId::GroupState => {
259                return Err(CryptoKeystoreError::IncorrectApiUsage(
260                    "Groups must not be saved using OpenMLS's APIs. You should use the keystore's provided methods",
261                ));
262            }
263            MlsEntityId::SignatureKeyPair => {
264                return Err(CryptoKeystoreError::IncorrectApiUsage(
265                    "Signature keys must not be saved using OpenMLS's APIs. Save a credential via the keystore API
266                    instead.",
267                ));
268            }
269            MlsEntityId::KeyPackage => {
270                let kp = StoredKeypackage {
271                    keypackage_ref: k.into(),
272                    keypackage: data,
273                };
274                self.save(kp).await?;
275            }
276            MlsEntityId::HpkePrivateKey => {
277                let kp = StoredHpkePrivateKey { pk: k.into(), sk: data };
278                self.save(kp).await?;
279            }
280            MlsEntityId::PskBundle => {
281                let kp = StoredPskBundle {
282                    psk_id: k.into(),
283                    psk: data,
284                };
285                self.save(kp).await?;
286            }
287            MlsEntityId::EncryptionKeyPair => {
288                let kp = StoredEncryptionKeyPair { pk: k.into(), sk: data };
289                self.save(kp).await?;
290            }
291            MlsEntityId::EpochEncryptionKeyPair => {
292                let kp = StoredEpochEncryptionKeypair {
293                    id: k.into(),
294                    keypairs: data,
295                };
296                self.save(kp).await?;
297            }
298        }
299
300        Ok(())
301    }
302
303    async fn read<V: MlsEntity>(&self, k: &[u8]) -> Option<V>
304    where
305        Self: Sized,
306    {
307        if k.is_empty() {
308            return None;
309        }
310
311        match V::ID {
312            MlsEntityId::GroupState => {
313                let group: PersistedMlsGroup = self.get_borrowed(k).await.ok().flatten()?;
314                deser(&group.state).ok()
315            }
316            MlsEntityId::SignatureKeyPair => {
317                let hash = Sha256Hash::from_existing_hash(k).ok()?;
318                let stored_credential = self.get::<StoredCredential>(&hash).await.ok().flatten()?;
319                let ciphersuite = Ciphersuite::try_from(stored_credential.ciphersuite).ok()?;
320                let signature_scheme = ciphersuite.signature_algorithm();
321
322                let mls_keypair = SignatureKeyPair::from_raw(
323                    signature_scheme,
324                    stored_credential.private_key.to_vec(),
325                    stored_credential.public_key.to_vec(),
326                );
327
328                // In a well designed interface, something like this should not be necessary. However, we don't have
329                // a well-designed interface.
330                let mls_keypair_serialized = ser(&mls_keypair).ok()?;
331                deser(&mls_keypair_serialized).ok()
332            }
333            MlsEntityId::KeyPackage => {
334                let kp: StoredKeypackage = self.get_borrowed(k).await.ok().flatten()?;
335                deser(&kp.keypackage).ok()
336            }
337            MlsEntityId::HpkePrivateKey => {
338                let hpke_pk: StoredHpkePrivateKey = self.get_borrowed(k).await.ok().flatten()?;
339                deser(&hpke_pk.sk).ok()
340            }
341            MlsEntityId::PskBundle => {
342                let psk_bundle: StoredPskBundle = self.get_borrowed(k).await.ok().flatten()?;
343                deser(&psk_bundle.psk).ok()
344            }
345            MlsEntityId::EncryptionKeyPair => {
346                let kp: StoredEncryptionKeyPair = self.get_borrowed(k).await.ok().flatten()?;
347                deser(&kp.sk).ok()
348            }
349            MlsEntityId::EpochEncryptionKeyPair => {
350                let kp: StoredEpochEncryptionKeypair = self.get_borrowed(k).await.ok().flatten()?;
351                deser(&kp.keypairs).ok()
352            }
353        }
354    }
355
356    async fn delete<V: MlsEntity>(&self, k: &[u8]) -> Result<(), Self::Error> {
357        match V::ID {
358            MlsEntityId::GroupState => self.remove_borrowed::<PersistedMlsGroup>(k).await?,
359            MlsEntityId::SignatureKeyPair => unimplemented!(
360                "Deleting a signature key pair should not be done through this API, any keypair should be deleted via
361                deleting a credential."
362            ),
363            MlsEntityId::HpkePrivateKey => self.remove_borrowed::<StoredHpkePrivateKey>(k).await?,
364            MlsEntityId::KeyPackage => self.remove_borrowed::<StoredKeypackage>(k).await?,
365            MlsEntityId::PskBundle => self.remove_borrowed::<StoredPskBundle>(k).await?,
366            MlsEntityId::EncryptionKeyPair => self.remove_borrowed::<StoredEncryptionKeyPair>(k).await?,
367            MlsEntityId::EpochEncryptionKeyPair => self.remove_borrowed::<StoredEpochEncryptionKeypair>(k).await?,
368        }
369
370        Ok(())
371    }
372}