core_crypto_keystore/
mls.rs

1use openmls_basic_credential::SignatureKeyPair;
2use openmls_traits::key_store::{MlsEntity, MlsEntityId};
3
4use crate::{
5    CryptoKeystoreError, CryptoKeystoreResult, MissingKeyErrorKind,
6    connection::FetchFromDatabase,
7    entities::{
8        EntityFindParams, PersistedMlsGroup, PersistedMlsPendingGroup, StoredCredential, StoredE2eiEnrollment,
9        StoredEncryptionKeyPair, StoredEpochEncryptionKeypair, StoredHpkePrivateKey, StoredKeypackage, StoredPskBundle,
10    },
11};
12
13/// An interface for the specialized queries in the KeyStore
14#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
15#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
16pub trait CryptoKeystoreMls: Sized {
17    /// Fetches Keypackages
18    ///
19    /// # Arguments
20    /// * `count` - amount of entries to be returned
21    ///
22    /// # Errors
23    /// Any common error that can happen during a database connection. IoError being a common error
24    /// for example.
25    async fn mls_fetch_keypackages<V: MlsEntity>(&self, count: u32) -> CryptoKeystoreResult<Vec<V>>;
26
27    /// Checks if the given MLS group id exists in the keystore
28    /// Note: in case of any error, this will return false
29    ///
30    /// # Arguments
31    /// * `group_id` - group/conversation id
32    async fn mls_group_exists(&self, group_id: impl AsRef<[u8]> + Send) -> bool;
33
34    /// Persists a `MlsGroup`
35    ///
36    /// # Arguments
37    /// * `group_id` - group/conversation id
38    /// * `state` - the group state
39    ///
40    /// # Errors
41    /// Any common error that can happen during a database connection. IoError being a common error
42    /// for example.
43    async fn mls_group_persist(
44        &self,
45        group_id: impl AsRef<[u8]> + Send,
46        state: &[u8],
47        parent_group_id: Option<&[u8]>,
48    ) -> CryptoKeystoreResult<()>;
49
50    /// Loads `MlsGroups` from the database. It will be returned as a `HashMap` where the key is
51    /// the group/conversation id and the value the group state
52    ///
53    /// # Errors
54    /// Any common error that can happen during a database connection. IoError being a common error
55    /// for example.
56    async fn mls_groups_restore(
57        &self,
58    ) -> CryptoKeystoreResult<std::collections::HashMap<Vec<u8>, (Option<Vec<u8>>, Vec<u8>)>>;
59
60    /// Deletes `MlsGroups` from the database.
61    /// # Errors
62    /// Any common error that can happen during a database connection. IoError being a common error
63    /// for example.
64    async fn mls_group_delete(&self, group_id: impl AsRef<[u8]> + Send) -> CryptoKeystoreResult<()>;
65
66    /// Saves a `MlsGroup` in a temporary table (typically used in scenarios where the group cannot
67    /// be committed until the backend acknowledges it, like external commits)
68    ///
69    /// # Arguments
70    /// * `group_id` - group/conversation id
71    /// * `mls_group` - the group/conversation state
72    /// * `custom_configuration` - local group configuration
73    ///
74    /// # Errors
75    /// Any common error that can happen during a database connection. IoError being a common error
76    /// for example.
77    async fn mls_pending_groups_save(
78        &self,
79        group_id: impl AsRef<[u8]> + Send,
80        mls_group: &[u8],
81        custom_configuration: &[u8],
82        parent_group_id: Option<&[u8]>,
83    ) -> CryptoKeystoreResult<()>;
84
85    /// Loads a temporary `MlsGroup` and its configuration from the database
86    ///
87    /// # Arguments
88    /// * `id` - group/conversation id
89    ///
90    /// # Errors
91    /// Any common error that can happen during a database connection. IoError being a common error
92    /// for example.
93    async fn mls_pending_groups_load(
94        &self,
95        group_id: impl AsRef<[u8]> + Send,
96    ) -> CryptoKeystoreResult<(Vec<u8>, Vec<u8>)>;
97
98    /// Deletes a temporary `MlsGroup` from the database
99    ///
100    /// # Arguments
101    /// * `id` - group/conversation id
102    ///
103    /// # Errors
104    /// Any common error that can happen during a database connection. IoError being a common error
105    /// for example.
106    async fn mls_pending_groups_delete(&self, group_id: impl AsRef<[u8]> + Send) -> CryptoKeystoreResult<()>;
107
108    /// Persists an enrollment instance
109    ///
110    /// # Arguments
111    /// * `id` - hash of the enrollment and unique identifier
112    /// * `content` - serialized enrollment
113    async fn save_e2ei_enrollment(&self, id: &[u8], content: &[u8]) -> CryptoKeystoreResult<()>;
114
115    /// Fetches and delete the enrollment instance
116    ///
117    /// # Arguments
118    /// * `id` - hash of the enrollment and unique identifier
119    async fn pop_e2ei_enrollment(&self, id: &[u8]) -> CryptoKeystoreResult<Vec<u8>>;
120}
121
122#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
123#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
124impl CryptoKeystoreMls for crate::Database {
125    async fn mls_fetch_keypackages<V: MlsEntity>(&self, count: u32) -> CryptoKeystoreResult<Vec<V>> {
126        let reverse = !cfg!(target_family = "wasm");
127        let keypackages = self
128            .find_all::<StoredKeypackage>(EntityFindParams {
129                limit: Some(count),
130                offset: None,
131                reverse,
132            })
133            .await?;
134
135        Ok(keypackages
136            .into_iter()
137            .filter_map(|kpb| postcard::from_bytes(&kpb.keypackage).ok())
138            .collect())
139    }
140
141    async fn mls_group_exists(&self, group_id: impl AsRef<[u8]> + Send) -> bool {
142        matches!(self.find::<PersistedMlsGroup>(group_id).await, Ok(Some(_)))
143    }
144
145    async fn mls_group_persist(
146        &self,
147        group_id: impl AsRef<[u8]> + Send,
148        state: &[u8],
149        parent_group_id: Option<&[u8]>,
150    ) -> CryptoKeystoreResult<()> {
151        self.save(PersistedMlsGroup {
152            id: group_id.as_ref().to_owned(),
153            state: state.into(),
154            parent_id: parent_group_id.map(Into::into),
155        })
156        .await?;
157
158        Ok(())
159    }
160
161    async fn mls_groups_restore(
162        &self,
163    ) -> CryptoKeystoreResult<std::collections::HashMap<Vec<u8>, (Option<Vec<u8>>, Vec<u8>)>> {
164        let groups = self.find_all::<PersistedMlsGroup>(EntityFindParams::default()).await?;
165        Ok(groups
166            .into_iter()
167            .map(|group: PersistedMlsGroup| (group.id.clone(), (group.parent_id.clone(), group.state.clone())))
168            .collect())
169    }
170
171    async fn mls_group_delete(&self, group_id: impl AsRef<[u8]> + Send) -> CryptoKeystoreResult<()> {
172        self.remove::<PersistedMlsGroup, _>(group_id).await?;
173
174        Ok(())
175    }
176
177    async fn mls_pending_groups_save(
178        &self,
179        group_id: impl AsRef<[u8]> + Send,
180        mls_group: &[u8],
181        custom_configuration: &[u8],
182        parent_group_id: Option<&[u8]>,
183    ) -> CryptoKeystoreResult<()> {
184        self.save(PersistedMlsPendingGroup {
185            id: group_id.as_ref().to_owned(),
186            state: mls_group.into(),
187            custom_configuration: custom_configuration.into(),
188            parent_id: parent_group_id.map(Into::into),
189        })
190        .await?;
191        Ok(())
192    }
193
194    async fn mls_pending_groups_load(
195        &self,
196        group_id: impl AsRef<[u8]> + Send,
197    ) -> CryptoKeystoreResult<(Vec<u8>, Vec<u8>)> {
198        self.find(group_id)
199            .await?
200            .map(|r: PersistedMlsPendingGroup| (r.state.clone(), r.custom_configuration.clone()))
201            .ok_or(CryptoKeystoreError::MissingKeyInStore(
202                MissingKeyErrorKind::MlsPendingGroup,
203            ))
204    }
205
206    async fn mls_pending_groups_delete(&self, group_id: impl AsRef<[u8]> + Send) -> CryptoKeystoreResult<()> {
207        self.remove::<PersistedMlsPendingGroup, _>(group_id).await
208    }
209
210    async fn save_e2ei_enrollment(&self, id: &[u8], content: &[u8]) -> CryptoKeystoreResult<()> {
211        self.save(StoredE2eiEnrollment {
212            id: id.into(),
213            content: content.into(),
214        })
215        .await?;
216        Ok(())
217    }
218
219    async fn pop_e2ei_enrollment(&self, id: &[u8]) -> CryptoKeystoreResult<Vec<u8>> {
220        // someone who has time could try to optimize this but honestly it's really on the cold path
221        let enrollment = self
222            .find::<StoredE2eiEnrollment>(id)
223            .await?
224            .ok_or(CryptoKeystoreError::MissingKeyInStore(
225                MissingKeyErrorKind::StoredE2eiEnrollment,
226            ))?;
227        self.remove::<StoredE2eiEnrollment, _>(id).await?;
228        Ok(enrollment.content.clone())
229    }
230}
231
232#[inline(always)]
233pub fn deser<T: MlsEntity>(bytes: &[u8]) -> Result<T, CryptoKeystoreError> {
234    Ok(postcard::from_bytes(bytes)?)
235}
236
237#[inline(always)]
238pub fn ser<T: MlsEntity>(value: &T) -> Result<Vec<u8>, CryptoKeystoreError> {
239    Ok(postcard::to_stdvec(value)?)
240}
241
242#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
243#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
244impl openmls_traits::key_store::OpenMlsKeyStore for crate::connection::Database {
245    type Error = CryptoKeystoreError;
246
247    async fn store<V: MlsEntity + Sync>(&self, k: &[u8], v: &V) -> Result<(), Self::Error>
248    where
249        Self: Sized,
250    {
251        if k.is_empty() {
252            return Err(CryptoKeystoreError::MlsKeyStoreError(
253                "The provided key is empty".into(),
254            ));
255        }
256
257        let data = ser(v)?;
258
259        match V::ID {
260            MlsEntityId::GroupState => {
261                return Err(CryptoKeystoreError::IncorrectApiUsage(
262                    "Groups must not be saved using OpenMLS's APIs. You should use the keystore's provided methods",
263                ));
264            }
265            MlsEntityId::SignatureKeyPair => {
266                return Err(CryptoKeystoreError::IncorrectApiUsage(
267                    "Signature keys must not be saved using OpenMLS's APIs. Save a credential via the keystore API
268                    instead.",
269                ));
270            }
271            MlsEntityId::KeyPackage => {
272                let kp = StoredKeypackage {
273                    keypackage_ref: k.into(),
274                    keypackage: data,
275                };
276                self.save(kp).await?;
277            }
278            MlsEntityId::HpkePrivateKey => {
279                let kp = StoredHpkePrivateKey { pk: k.into(), sk: data };
280                self.save(kp).await?;
281            }
282            MlsEntityId::PskBundle => {
283                let kp = StoredPskBundle {
284                    psk_id: k.into(),
285                    psk: data,
286                };
287                self.save(kp).await?;
288            }
289            MlsEntityId::EncryptionKeyPair => {
290                let kp = StoredEncryptionKeyPair { pk: k.into(), sk: data };
291                self.save(kp).await?;
292            }
293            MlsEntityId::EpochEncryptionKeyPair => {
294                let kp = StoredEpochEncryptionKeypair {
295                    id: k.into(),
296                    keypairs: data,
297                };
298                self.save(kp).await?;
299            }
300        }
301
302        Ok(())
303    }
304
305    async fn read<V: MlsEntity>(&self, k: &[u8]) -> Option<V>
306    where
307        Self: Sized,
308    {
309        if k.is_empty() {
310            return None;
311        }
312
313        match V::ID {
314            MlsEntityId::GroupState => {
315                let group: PersistedMlsGroup = self.find(k).await.ok().flatten()?;
316                deser(&group.state).ok()
317            }
318            MlsEntityId::SignatureKeyPair => {
319                let stored_credential: StoredCredential = self.find(k).await.ok().flatten()?;
320                let signature_scheme = stored_credential.signature_scheme.try_into().ok()?;
321                let mls_keypair = SignatureKeyPair::from_raw(
322                    signature_scheme,
323                    stored_credential.secret_key.to_vec(),
324                    stored_credential.public_key.to_vec(),
325                );
326
327                // In a well designed interface, something like this should not be necessary. However, we don't have
328                // a well-designed interface.
329                let mls_keypair_serialized = ser(&mls_keypair).ok()?;
330                deser(&mls_keypair_serialized).ok()
331            }
332            MlsEntityId::KeyPackage => {
333                let kp: StoredKeypackage = self.find(k).await.ok().flatten()?;
334                deser(&kp.keypackage).ok()
335            }
336            MlsEntityId::HpkePrivateKey => {
337                let hpke_pk: StoredHpkePrivateKey = self.find(k).await.ok().flatten()?;
338                deser(&hpke_pk.sk).ok()
339            }
340            MlsEntityId::PskBundle => {
341                let psk_bundle: StoredPskBundle = self.find(k).await.ok().flatten()?;
342                deser(&psk_bundle.psk).ok()
343            }
344            MlsEntityId::EncryptionKeyPair => {
345                let kp: StoredEncryptionKeyPair = self.find(k).await.ok().flatten()?;
346                deser(&kp.sk).ok()
347            }
348            MlsEntityId::EpochEncryptionKeyPair => {
349                let kp: StoredEpochEncryptionKeypair = self.find(k).await.ok().flatten()?;
350                deser(&kp.keypairs).ok()
351            }
352        }
353    }
354
355    async fn delete<V: MlsEntity>(&self, k: &[u8]) -> Result<(), Self::Error> {
356        match V::ID {
357            MlsEntityId::GroupState => self.remove::<PersistedMlsGroup, _>(k).await?,
358            MlsEntityId::SignatureKeyPair => unimplemented!(
359                "Deleting a signature key pair should not be done through this API, any keypair should be deleted via
360                deleting a credential."
361            ),
362            MlsEntityId::HpkePrivateKey => self.remove::<StoredHpkePrivateKey, _>(k).await?,
363            MlsEntityId::KeyPackage => self.remove::<StoredKeypackage, _>(k).await?,
364            MlsEntityId::PskBundle => self.remove::<StoredPskBundle, _>(k).await?,
365            MlsEntityId::EncryptionKeyPair => self.remove::<StoredEncryptionKeyPair, _>(k).await?,
366            MlsEntityId::EpochEncryptionKeyPair => self.remove::<StoredEpochEncryptionKeypair, _>(k).await?,
367        }
368
369        Ok(())
370    }
371}