core_crypto_keystore/
mls.rs

1use openmls::prelude::Ciphersuite;
2use openmls_basic_credential::SignatureKeyPair;
3use openmls_traits::key_store::{MlsEntity, MlsEntityId};
4
5use crate::{
6    CryptoKeystoreError, CryptoKeystoreResult, MissingKeyErrorKind, Sha256Hash,
7    entities::{
8        PersistedMlsGroup, PersistedMlsPendingGroup, StoredCredential, StoredE2eiEnrollment, StoredEncryptionKeyPair,
9        StoredEpochEncryptionKeypair, StoredHpkePrivateKey, StoredKeypackage, StoredPskBundle,
10    },
11    traits::FetchFromDatabase,
12};
13
14/// An interface for the specialized queries in the KeyStore
15#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
16#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
17pub trait CryptoKeystoreMls: Sized {
18    /// Fetches Keypackages
19    ///
20    /// # Arguments
21    /// * `count` - amount of entries to be returned
22    ///
23    /// # Errors
24    /// Any common error that can happen during a database connection. IoError being a common error
25    /// for example.
26    async fn mls_fetch_keypackages<V: MlsEntity>(&self, count: u32) -> CryptoKeystoreResult<Vec<V>>;
27
28    /// Checks if the given MLS group id exists in the keystore
29    /// Note: in case of any error, this will return false
30    ///
31    /// # Arguments
32    /// * `group_id` - group/conversation id
33    async fn mls_group_exists(&self, group_id: impl AsRef<[u8]> + Send) -> bool;
34
35    /// Persists a `MlsGroup`
36    ///
37    /// # Arguments
38    /// * `group_id` - group/conversation id
39    /// * `state` - the group state
40    ///
41    /// # Errors
42    /// Any common error that can happen during a database connection. IoError being a common error
43    /// for example.
44    async fn mls_group_persist(
45        &self,
46        group_id: impl AsRef<[u8]> + Send,
47        state: &[u8],
48        parent_group_id: Option<&[u8]>,
49    ) -> CryptoKeystoreResult<()>;
50
51    /// Loads `MlsGroups` from the database. It will be returned as a `HashMap` where the key is
52    /// the group/conversation id and the value the group state
53    ///
54    /// # Errors
55    /// Any common error that can happen during a database connection. IoError being a common error
56    /// for example.
57    async fn mls_groups_restore(
58        &self,
59    ) -> CryptoKeystoreResult<std::collections::HashMap<Vec<u8>, (Option<Vec<u8>>, Vec<u8>)>>;
60
61    /// Deletes `MlsGroups` from the database.
62    /// # Errors
63    /// Any common error that can happen during a database connection. IoError being a common error
64    /// for example.
65    async fn mls_group_delete(&self, group_id: impl AsRef<[u8]> + Send) -> CryptoKeystoreResult<()>;
66
67    /// Saves a `MlsGroup` in a temporary table (typically used in scenarios where the group cannot
68    /// be committed until the backend acknowledges it, like external commits)
69    ///
70    /// # Arguments
71    /// * `group_id` - group/conversation id
72    /// * `mls_group` - the group/conversation state
73    /// * `custom_configuration` - local group configuration
74    ///
75    /// # Errors
76    /// Any common error that can happen during a database connection. IoError being a common error
77    /// for example.
78    async fn mls_pending_groups_save(
79        &self,
80        group_id: impl AsRef<[u8]> + Send,
81        mls_group: &[u8],
82        custom_configuration: &[u8],
83        parent_group_id: Option<&[u8]>,
84    ) -> CryptoKeystoreResult<()>;
85
86    /// Loads a temporary `MlsGroup` and its configuration from the database
87    ///
88    /// # Arguments
89    /// * `id` - group/conversation id
90    ///
91    /// # Errors
92    /// Any common error that can happen during a database connection. IoError being a common error
93    /// for example.
94    async fn mls_pending_groups_load(
95        &self,
96        group_id: impl AsRef<[u8]> + Send,
97    ) -> CryptoKeystoreResult<(Vec<u8>, Vec<u8>)>;
98
99    /// Deletes a temporary `MlsGroup` from the database
100    ///
101    /// # Arguments
102    /// * `id` - group/conversation id
103    ///
104    /// # Errors
105    /// Any common error that can happen during a database connection. IoError being a common error
106    /// for example.
107    async fn mls_pending_groups_delete(&self, group_id: impl AsRef<[u8]> + Send) -> CryptoKeystoreResult<()>;
108
109    /// Persists an enrollment instance
110    ///
111    /// # Arguments
112    /// * `id` - hash of the enrollment and unique identifier
113    /// * `content` - serialized enrollment
114    async fn save_e2ei_enrollment(&self, id: &[u8], content: &[u8]) -> CryptoKeystoreResult<()>;
115
116    /// Fetches and delete the enrollment instance
117    ///
118    /// # Arguments
119    /// * `id` - hash of the enrollment and unique identifier
120    async fn pop_e2ei_enrollment(&self, id: &[u8]) -> CryptoKeystoreResult<Vec<u8>>;
121}
122
123#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
124#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
125impl CryptoKeystoreMls for crate::Database {
126    async fn mls_fetch_keypackages<V: MlsEntity>(&self, count: u32) -> CryptoKeystoreResult<Vec<V>> {
127        let keypackages = self.load_all::<StoredKeypackage>().await?;
128        Ok(keypackages
129            .into_iter()
130            .filter_map(|kpb| postcard::from_bytes(&kpb.keypackage).ok())
131            .take(count as _)
132            .collect())
133    }
134
135    async fn mls_group_exists(&self, group_id: impl AsRef<[u8]> + Send) -> bool {
136        matches!(
137            self.get_borrowed::<PersistedMlsGroup>(group_id.as_ref()).await,
138            Ok(Some(_))
139        )
140    }
141
142    async fn mls_group_persist(
143        &self,
144        group_id: impl AsRef<[u8]> + Send,
145        state: &[u8],
146        parent_group_id: Option<&[u8]>,
147    ) -> CryptoKeystoreResult<()> {
148        self.save(PersistedMlsGroup {
149            id: group_id.as_ref().to_owned(),
150            state: state.into(),
151            parent_id: parent_group_id.map(Into::into),
152        })
153        .await?;
154
155        Ok(())
156    }
157
158    async fn mls_groups_restore(
159        &self,
160    ) -> CryptoKeystoreResult<std::collections::HashMap<Vec<u8>, (Option<Vec<u8>>, Vec<u8>)>> {
161        let groups = self.load_all::<PersistedMlsGroup>().await?;
162        Ok(groups
163            .into_iter()
164            .map(|mut group: PersistedMlsGroup| {
165                let id = std::mem::take(&mut group.id);
166                let parent_id = std::mem::take(&mut group.parent_id);
167                let state = std::mem::take(&mut group.state);
168                (id, (parent_id, state))
169            })
170            .collect())
171    }
172
173    async fn mls_group_delete(&self, group_id: impl AsRef<[u8]> + Send) -> CryptoKeystoreResult<()> {
174        self.remove_borrowed::<PersistedMlsGroup>(group_id.as_ref()).await?;
175        Ok(())
176    }
177
178    async fn mls_pending_groups_save(
179        &self,
180        group_id: impl AsRef<[u8]> + Send,
181        mls_group: &[u8],
182        custom_configuration: &[u8],
183        parent_group_id: Option<&[u8]>,
184    ) -> CryptoKeystoreResult<()> {
185        self.save(PersistedMlsPendingGroup {
186            id: group_id.as_ref().to_owned(),
187            state: mls_group.into(),
188            custom_configuration: custom_configuration.into(),
189            parent_id: parent_group_id.map(Into::into),
190        })
191        .await?;
192        Ok(())
193    }
194
195    async fn mls_pending_groups_load(
196        &self,
197        group_id: impl AsRef<[u8]> + Send,
198    ) -> CryptoKeystoreResult<(Vec<u8>, Vec<u8>)> {
199        self.get_borrowed(group_id.as_ref())
200            .await?
201            .map(|r: PersistedMlsPendingGroup| (r.state.clone(), r.custom_configuration.clone()))
202            .ok_or(CryptoKeystoreError::MissingKeyInStore(
203                MissingKeyErrorKind::MlsPendingGroup,
204            ))
205    }
206
207    async fn mls_pending_groups_delete(&self, group_id: impl AsRef<[u8]> + Send) -> CryptoKeystoreResult<()> {
208        self.remove_borrowed::<PersistedMlsPendingGroup>(group_id.as_ref())
209            .await
210    }
211
212    async fn save_e2ei_enrollment(&self, id: &[u8], content: &[u8]) -> CryptoKeystoreResult<()> {
213        self.save(StoredE2eiEnrollment {
214            id: id.into(),
215            content: content.into(),
216        })
217        .await?;
218        Ok(())
219    }
220
221    async fn pop_e2ei_enrollment(&self, id: &[u8]) -> CryptoKeystoreResult<Vec<u8>> {
222        // someone who has time could try to optimize this but honestly it's really on the cold path
223        let enrollment =
224            self.get_borrowed::<StoredE2eiEnrollment>(id)
225                .await?
226                .ok_or(CryptoKeystoreError::MissingKeyInStore(
227                    MissingKeyErrorKind::StoredE2eiEnrollment,
228                ))?;
229        self.remove_borrowed::<StoredE2eiEnrollment>(id).await?;
230        Ok(enrollment.content.clone())
231    }
232}
233
234#[inline(always)]
235pub fn deser<T: MlsEntity>(bytes: &[u8]) -> Result<T, CryptoKeystoreError> {
236    Ok(postcard::from_bytes(bytes)?)
237}
238
239#[inline(always)]
240pub fn ser<T: MlsEntity>(value: &T) -> Result<Vec<u8>, CryptoKeystoreError> {
241    Ok(postcard::to_stdvec(value)?)
242}
243
244#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
245#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
246impl openmls_traits::key_store::OpenMlsKeyStore for crate::connection::Database {
247    type Error = CryptoKeystoreError;
248
249    async fn store<V: MlsEntity + Sync>(&self, k: &[u8], v: &V) -> Result<(), Self::Error>
250    where
251        Self: Sized,
252    {
253        if k.is_empty() {
254            return Err(CryptoKeystoreError::MlsKeyStoreError(
255                "The provided key is empty".into(),
256            ));
257        }
258
259        let data = ser(v)?;
260
261        match V::ID {
262            MlsEntityId::GroupState => {
263                return Err(CryptoKeystoreError::IncorrectApiUsage(
264                    "Groups must not be saved using OpenMLS's APIs. You should use the keystore's provided methods",
265                ));
266            }
267            MlsEntityId::SignatureKeyPair => {
268                return Err(CryptoKeystoreError::IncorrectApiUsage(
269                    "Signature keys must not be saved using OpenMLS's APIs. Save a credential via the keystore API
270                    instead.",
271                ));
272            }
273            MlsEntityId::KeyPackage => {
274                let kp = StoredKeypackage {
275                    keypackage_ref: k.into(),
276                    keypackage: data,
277                };
278                self.save(kp).await?;
279            }
280            MlsEntityId::HpkePrivateKey => {
281                let kp = StoredHpkePrivateKey { pk: k.into(), sk: data };
282                self.save(kp).await?;
283            }
284            MlsEntityId::PskBundle => {
285                let kp = StoredPskBundle {
286                    psk_id: k.into(),
287                    psk: data,
288                };
289                self.save(kp).await?;
290            }
291            MlsEntityId::EncryptionKeyPair => {
292                let kp = StoredEncryptionKeyPair { pk: k.into(), sk: data };
293                self.save(kp).await?;
294            }
295            MlsEntityId::EpochEncryptionKeyPair => {
296                let kp = StoredEpochEncryptionKeypair {
297                    id: k.into(),
298                    keypairs: data,
299                };
300                self.save(kp).await?;
301            }
302        }
303
304        Ok(())
305    }
306
307    async fn read<V: MlsEntity>(&self, k: &[u8]) -> Option<V>
308    where
309        Self: Sized,
310    {
311        if k.is_empty() {
312            return None;
313        }
314
315        match V::ID {
316            MlsEntityId::GroupState => {
317                let group: PersistedMlsGroup = self.get_borrowed(k).await.ok().flatten()?;
318                deser(&group.state).ok()
319            }
320            MlsEntityId::SignatureKeyPair => {
321                let hash = Sha256Hash::from_existing_hash(k).ok()?;
322                let stored_credential = self.get::<StoredCredential>(&hash).await.ok().flatten()?;
323                let ciphersuite = Ciphersuite::try_from(stored_credential.ciphersuite).ok()?;
324                let signature_scheme = ciphersuite.signature_algorithm();
325
326                let mls_keypair = SignatureKeyPair::from_raw(
327                    signature_scheme,
328                    stored_credential.private_key.to_vec(),
329                    stored_credential.public_key.to_vec(),
330                );
331
332                // In a well designed interface, something like this should not be necessary. However, we don't have
333                // a well-designed interface.
334                let mls_keypair_serialized = ser(&mls_keypair).ok()?;
335                deser(&mls_keypair_serialized).ok()
336            }
337            MlsEntityId::KeyPackage => {
338                let kp: StoredKeypackage = self.get_borrowed(k).await.ok().flatten()?;
339                deser(&kp.keypackage).ok()
340            }
341            MlsEntityId::HpkePrivateKey => {
342                let hpke_pk: StoredHpkePrivateKey = self.get_borrowed(k).await.ok().flatten()?;
343                deser(&hpke_pk.sk).ok()
344            }
345            MlsEntityId::PskBundle => {
346                let psk_bundle: StoredPskBundle = self.get_borrowed(k).await.ok().flatten()?;
347                deser(&psk_bundle.psk).ok()
348            }
349            MlsEntityId::EncryptionKeyPair => {
350                let kp: StoredEncryptionKeyPair = self.get_borrowed(k).await.ok().flatten()?;
351                deser(&kp.sk).ok()
352            }
353            MlsEntityId::EpochEncryptionKeyPair => {
354                let kp: StoredEpochEncryptionKeypair = self.get_borrowed(k).await.ok().flatten()?;
355                deser(&kp.keypairs).ok()
356            }
357        }
358    }
359
360    async fn delete<V: MlsEntity>(&self, k: &[u8]) -> Result<(), Self::Error> {
361        match V::ID {
362            MlsEntityId::GroupState => self.remove_borrowed::<PersistedMlsGroup>(k).await?,
363            MlsEntityId::SignatureKeyPair => unimplemented!(
364                "Deleting a signature key pair should not be done through this API, any keypair should be deleted via
365                deleting a credential."
366            ),
367            MlsEntityId::HpkePrivateKey => self.remove_borrowed::<StoredHpkePrivateKey>(k).await?,
368            MlsEntityId::KeyPackage => self.remove_borrowed::<StoredKeypackage>(k).await?,
369            MlsEntityId::PskBundle => self.remove_borrowed::<StoredPskBundle>(k).await?,
370            MlsEntityId::EncryptionKeyPair => self.remove_borrowed::<StoredEncryptionKeyPair>(k).await?,
371            MlsEntityId::EpochEncryptionKeyPair => self.remove_borrowed::<StoredEpochEncryptionKeypair>(k).await?,
372        }
373
374        Ok(())
375    }
376}