core_crypto_keystore/
mls.rs

1// Wire
2// Copyright (C) 2022 Wire Swiss GmbH
3
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8
9// This program is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU General Public License for more details.
13
14// You should have received a copy of the GNU General Public License
15// along with this program. If not, see http://www.gnu.org/licenses/.
16
17use crate::connection::FetchFromDatabase;
18use crate::entities::MlsEpochEncryptionKeyPair;
19use crate::{
20    entities::{
21        E2eiEnrollment, EntityFindParams, MlsEncryptionKeyPair, MlsHpkePrivateKey, MlsKeyPackage, MlsPskBundle,
22        MlsSignatureKeyPair, PersistedMlsGroup, PersistedMlsPendingGroup,
23    },
24    CryptoKeystoreError, CryptoKeystoreResult, MissingKeyErrorKind,
25};
26use openmls_basic_credential::SignatureKeyPair;
27use openmls_traits::key_store::{MlsEntity, MlsEntityId};
28
29/// An interface for the specialized queries in the KeyStore
30#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
31#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
32pub trait CryptoKeystoreMls: Sized {
33    /// Fetches Keypackages
34    ///
35    /// # Arguments
36    /// * `count` - amount of entries to be returned
37    ///
38    /// # Errors
39    /// Any common error that can happen during a database connection. IoError being a common error
40    /// for example.
41    async fn mls_fetch_keypackages<V: MlsEntity>(&self, count: u32) -> CryptoKeystoreResult<Vec<V>>;
42
43    /// Checks if the given MLS group id exists in the keystore
44    /// Note: in case of any error, this will return false
45    ///
46    /// # Arguments
47    /// * `group_id` - group/conversation id
48    async fn mls_group_exists(&self, group_id: &[u8]) -> bool;
49
50    /// Persists a `MlsGroup`
51    ///
52    /// # Arguments
53    /// * `group_id` - group/conversation id
54    /// * `state` - the group state
55    ///
56    /// # Errors
57    /// Any common error that can happen during a database connection. IoError being a common error
58    /// for example.
59    async fn mls_group_persist(
60        &self,
61        group_id: &[u8],
62        state: &[u8],
63        parent_group_id: Option<&[u8]>,
64    ) -> CryptoKeystoreResult<()>;
65
66    /// Loads `MlsGroups` from the database. It will be returned as a `HashMap` where the key is
67    /// the group/conversation id and the value the group state
68    ///
69    /// # Errors
70    /// Any common error that can happen during a database connection. IoError being a common error
71    /// for example.
72    async fn mls_groups_restore(
73        &self,
74    ) -> CryptoKeystoreResult<std::collections::HashMap<Vec<u8>, (Option<Vec<u8>>, Vec<u8>)>>;
75
76    /// Deletes `MlsGroups` from the database.
77    /// # Errors
78    /// Any common error that can happen during a database connection. IoError being a common error
79    /// for example.
80    async fn mls_group_delete(&self, group_id: &[u8]) -> CryptoKeystoreResult<()>;
81
82    /// Saves a `MlsGroup` in a temporary table (typically used in scenarios where the group cannot
83    /// be committed until the backend acknowledges it, like external commits)
84    ///
85    /// # Arguments
86    /// * `group_id` - group/conversation id
87    /// * `mls_group` - the group/conversation state
88    /// * `custom_configuration` - local group configuration
89    ///
90    /// # Errors
91    /// Any common error that can happen during a database connection. IoError being a common error
92    /// for example.
93    async fn mls_pending_groups_save(
94        &self,
95        group_id: &[u8],
96        mls_group: &[u8],
97        custom_configuration: &[u8],
98        parent_group_id: Option<&[u8]>,
99    ) -> CryptoKeystoreResult<()>;
100
101    /// Loads a temporary `MlsGroup` and its configuration from the database
102    ///
103    /// # Arguments
104    /// * `id` - group/conversation id
105    ///
106    /// # Errors
107    /// Any common error that can happen during a database connection. IoError being a common error
108    /// for example.
109    async fn mls_pending_groups_load(&self, group_id: &[u8]) -> CryptoKeystoreResult<(Vec<u8>, Vec<u8>)>;
110
111    /// Deletes a temporary `MlsGroup` from the database
112    ///
113    /// # Arguments
114    /// * `id` - group/conversation id
115    ///
116    /// # Errors
117    /// Any common error that can happen during a database connection. IoError being a common error
118    /// for example.
119    async fn mls_pending_groups_delete(&self, group_id: &[u8]) -> CryptoKeystoreResult<()>;
120
121    /// Persists an enrollment instance
122    ///
123    /// # Arguments
124    /// * `id` - hash of the enrollment and unique identifier
125    /// * `content` - serialized enrollment
126    async fn save_e2ei_enrollment(&self, id: &[u8], content: &[u8]) -> CryptoKeystoreResult<()>;
127
128    /// Fetches and delete the enrollment instance
129    ///
130    /// # Arguments
131    /// * `id` - hash of the enrollment and unique identifier
132    async fn pop_e2ei_enrollment(&self, id: &[u8]) -> CryptoKeystoreResult<Vec<u8>>;
133}
134
135#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
136#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
137impl CryptoKeystoreMls for crate::Connection {
138    async fn mls_fetch_keypackages<V: MlsEntity>(&self, count: u32) -> CryptoKeystoreResult<Vec<V>> {
139        cfg_if::cfg_if! {
140            if #[cfg(not(target_family = "wasm"))] {
141                let reverse = true;
142            } else {
143                let reverse = false;
144            }
145        }
146        let keypackages = self
147            .find_all::<MlsKeyPackage>(EntityFindParams {
148                limit: Some(count),
149                offset: None,
150                reverse,
151            })
152            .await?;
153
154        Ok(keypackages
155            .into_iter()
156            .filter_map(|kpb| postcard::from_bytes(&kpb.keypackage).ok())
157            .collect())
158    }
159
160    async fn mls_group_exists(&self, group_id: &[u8]) -> bool {
161        matches!(self.find::<PersistedMlsGroup>(group_id).await, Ok(Some(_)))
162    }
163
164    async fn mls_group_persist(
165        &self,
166        group_id: &[u8],
167        state: &[u8],
168        parent_group_id: Option<&[u8]>,
169    ) -> CryptoKeystoreResult<()> {
170        self.save(PersistedMlsGroup {
171            id: group_id.into(),
172            state: state.into(),
173            parent_id: parent_group_id.map(Into::into),
174        })
175        .await?;
176
177        Ok(())
178    }
179
180    async fn mls_groups_restore(
181        &self,
182    ) -> CryptoKeystoreResult<std::collections::HashMap<Vec<u8>, (Option<Vec<u8>>, Vec<u8>)>> {
183        let groups = self.find_all::<PersistedMlsGroup>(EntityFindParams::default()).await?;
184        Ok(groups
185            .into_iter()
186            .map(|group: PersistedMlsGroup| (group.id.clone(), (group.parent_id.clone(), group.state.clone())))
187            .collect())
188    }
189
190    async fn mls_group_delete(&self, group_id: &[u8]) -> CryptoKeystoreResult<()> {
191        self.remove::<PersistedMlsGroup, _>(group_id).await?;
192
193        Ok(())
194    }
195
196    async fn mls_pending_groups_save(
197        &self,
198        group_id: &[u8],
199        mls_group: &[u8],
200        custom_configuration: &[u8],
201        parent_group_id: Option<&[u8]>,
202    ) -> CryptoKeystoreResult<()> {
203        self.save(PersistedMlsPendingGroup {
204            id: group_id.into(),
205            state: mls_group.into(),
206            custom_configuration: custom_configuration.into(),
207            parent_id: parent_group_id.map(Into::into),
208        })
209        .await?;
210        Ok(())
211    }
212
213    async fn mls_pending_groups_load(&self, group_id: &[u8]) -> CryptoKeystoreResult<(Vec<u8>, Vec<u8>)> {
214        self.find(group_id)
215            .await?
216            .map(|r: PersistedMlsPendingGroup| (r.state.clone(), r.custom_configuration.clone()))
217            .ok_or(CryptoKeystoreError::MissingKeyInStore(
218                MissingKeyErrorKind::MlsPendingGroup,
219            ))
220    }
221
222    async fn mls_pending_groups_delete(&self, group_id: &[u8]) -> CryptoKeystoreResult<()> {
223        self.remove::<PersistedMlsPendingGroup, _>(group_id).await
224    }
225
226    async fn save_e2ei_enrollment(&self, id: &[u8], content: &[u8]) -> CryptoKeystoreResult<()> {
227        self.save(E2eiEnrollment {
228            id: id.into(),
229            content: content.into(),
230        })
231        .await?;
232        Ok(())
233    }
234
235    async fn pop_e2ei_enrollment(&self, id: &[u8]) -> CryptoKeystoreResult<Vec<u8>> {
236        // someone who has time could try to optimize this but honestly it's really on the cold path
237        let enrollment = self
238            .find::<E2eiEnrollment>(id)
239            .await?
240            .ok_or(CryptoKeystoreError::MissingKeyInStore(
241                MissingKeyErrorKind::E2eiEnrollment,
242            ))?;
243        self.remove::<E2eiEnrollment, _>(id).await?;
244        Ok(enrollment.content.clone())
245    }
246}
247
248#[inline(always)]
249pub fn deser<T: MlsEntity>(bytes: &[u8]) -> Result<T, CryptoKeystoreError> {
250    Ok(postcard::from_bytes(bytes)?)
251}
252
253#[inline(always)]
254pub fn ser<T: MlsEntity>(value: &T) -> Result<Vec<u8>, CryptoKeystoreError> {
255    Ok(postcard::to_stdvec(value)?)
256}
257
258#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
259#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
260impl openmls_traits::key_store::OpenMlsKeyStore for crate::connection::Connection {
261    type Error = CryptoKeystoreError;
262
263    async fn store<V: MlsEntity + Sync>(&self, k: &[u8], v: &V) -> Result<(), Self::Error>
264    where
265        Self: Sized,
266    {
267        if k.is_empty() {
268            return Err(CryptoKeystoreError::MlsKeyStoreError(
269                "The provided key is empty".into(),
270            ));
271        }
272
273        let data = ser(v)?;
274
275        match V::ID {
276            MlsEntityId::GroupState => {
277                return Err(CryptoKeystoreError::IncorrectApiUsage(
278                    "Groups must not be saved using OpenMLS's APIs. You should use the keystore's provided methods",
279                ));
280            }
281            MlsEntityId::SignatureKeyPair => {
282                let concrete_signature_keypair: &SignatureKeyPair = v
283                    .downcast()
284                    .expect("There's an implementation issue in OpenMLS. This shouln't be happening.");
285
286                // Having an empty credential id seems tolerable, since the SignatureKeyPair type is retrieved from the key store via its public key.
287                let credential_id = vec![];
288                let kp = MlsSignatureKeyPair::new(
289                    concrete_signature_keypair.signature_scheme(),
290                    k.into(),
291                    data,
292                    credential_id,
293                );
294                self.save(kp).await?;
295            }
296            MlsEntityId::KeyPackage => {
297                let kp = MlsKeyPackage {
298                    keypackage_ref: k.into(),
299                    keypackage: data,
300                };
301                self.save(kp).await?;
302            }
303            MlsEntityId::HpkePrivateKey => {
304                let kp = MlsHpkePrivateKey { pk: k.into(), sk: data };
305                self.save(kp).await?;
306            }
307            MlsEntityId::PskBundle => {
308                let kp = MlsPskBundle {
309                    psk_id: k.into(),
310                    psk: data,
311                };
312                self.save(kp).await?;
313            }
314            MlsEntityId::EncryptionKeyPair => {
315                let kp = MlsEncryptionKeyPair { pk: k.into(), sk: data };
316                self.save(kp).await?;
317            }
318            MlsEntityId::EpochEncryptionKeyPair => {
319                let kp = MlsEpochEncryptionKeyPair {
320                    id: k.into(),
321                    keypairs: data,
322                };
323                self.save(kp).await?;
324            }
325        }
326
327        Ok(())
328    }
329
330    async fn read<V: MlsEntity>(&self, k: &[u8]) -> Option<V>
331    where
332        Self: Sized,
333    {
334        if k.is_empty() {
335            return None;
336        }
337
338        match V::ID {
339            MlsEntityId::GroupState => {
340                let group: PersistedMlsGroup = self.find(k).await.ok().flatten()?;
341                deser(&group.state).ok()
342            }
343            MlsEntityId::SignatureKeyPair => {
344                let sig: MlsSignatureKeyPair = self.find(k).await.ok().flatten()?;
345                deser(&sig.keypair).ok()
346            }
347            MlsEntityId::KeyPackage => {
348                let kp: MlsKeyPackage = self.find(k).await.ok().flatten()?;
349                deser(&kp.keypackage).ok()
350            }
351            MlsEntityId::HpkePrivateKey => {
352                let hpke_pk: MlsHpkePrivateKey = self.find(k).await.ok().flatten()?;
353                deser(&hpke_pk.sk).ok()
354            }
355            MlsEntityId::PskBundle => {
356                let psk_bundle: MlsPskBundle = self.find(k).await.ok().flatten()?;
357                deser(&psk_bundle.psk).ok()
358            }
359            MlsEntityId::EncryptionKeyPair => {
360                let kp: MlsEncryptionKeyPair = self.find(k).await.ok().flatten()?;
361                deser(&kp.sk).ok()
362            }
363            MlsEntityId::EpochEncryptionKeyPair => {
364                let kp: MlsEpochEncryptionKeyPair = self.find(k).await.ok().flatten()?;
365                deser(&kp.keypairs).ok()
366            }
367        }
368    }
369
370    async fn delete<V: MlsEntity>(&self, k: &[u8]) -> Result<(), Self::Error> {
371        match V::ID {
372            MlsEntityId::GroupState => self.remove::<PersistedMlsGroup, _>(k).await?,
373            MlsEntityId::SignatureKeyPair => self.remove::<MlsSignatureKeyPair, _>(k).await?,
374            MlsEntityId::HpkePrivateKey => self.remove::<MlsHpkePrivateKey, _>(k).await?,
375            MlsEntityId::KeyPackage => self.remove::<MlsKeyPackage, _>(k).await?,
376            MlsEntityId::PskBundle => self.remove::<MlsPskBundle, _>(k).await?,
377            MlsEntityId::EncryptionKeyPair => self.remove::<MlsEncryptionKeyPair, _>(k).await?,
378            MlsEntityId::EpochEncryptionKeyPair => self.remove::<MlsEpochEncryptionKeyPair, _>(k).await?,
379        }
380
381        Ok(())
382    }
383}