1use openmls_basic_credential::SignatureKeyPair;
2use openmls_traits::key_store::{MlsEntity, MlsEntityId};
3
4use crate::{
5 CryptoKeystoreError, CryptoKeystoreResult, MissingKeyErrorKind,
6 connection::FetchFromDatabase,
7 entities::{
8 E2eiEnrollment, EntityFindParams, MlsEncryptionKeyPair, MlsEpochEncryptionKeyPair, MlsHpkePrivateKey,
9 MlsKeyPackage, MlsPskBundle, MlsSignatureKeyPair, PersistedMlsGroup, PersistedMlsPendingGroup,
10 },
11};
12
13#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
15#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
16pub trait CryptoKeystoreMls: Sized {
17 async fn mls_fetch_keypackages<V: MlsEntity>(&self, count: u32) -> CryptoKeystoreResult<Vec<V>>;
26
27 async fn mls_group_exists(&self, group_id: impl AsRef<[u8]> + Send) -> bool;
33
34 async fn mls_group_persist(
44 &self,
45 group_id: impl AsRef<[u8]> + Send,
46 state: &[u8],
47 parent_group_id: Option<&[u8]>,
48 ) -> CryptoKeystoreResult<()>;
49
50 async fn mls_groups_restore(
57 &self,
58 ) -> CryptoKeystoreResult<std::collections::HashMap<Vec<u8>, (Option<Vec<u8>>, Vec<u8>)>>;
59
60 async fn mls_group_delete(&self, group_id: impl AsRef<[u8]> + Send) -> CryptoKeystoreResult<()>;
65
66 async fn mls_pending_groups_save(
78 &self,
79 group_id: impl AsRef<[u8]> + Send,
80 mls_group: &[u8],
81 custom_configuration: &[u8],
82 parent_group_id: Option<&[u8]>,
83 ) -> CryptoKeystoreResult<()>;
84
85 async fn mls_pending_groups_load(
94 &self,
95 group_id: impl AsRef<[u8]> + Send,
96 ) -> CryptoKeystoreResult<(Vec<u8>, Vec<u8>)>;
97
98 async fn mls_pending_groups_delete(&self, group_id: impl AsRef<[u8]> + Send) -> CryptoKeystoreResult<()>;
107
108 async fn save_e2ei_enrollment(&self, id: &[u8], content: &[u8]) -> CryptoKeystoreResult<()>;
114
115 async fn pop_e2ei_enrollment(&self, id: &[u8]) -> CryptoKeystoreResult<Vec<u8>>;
120}
121
122#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
123#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
124impl CryptoKeystoreMls for crate::Database {
125 async fn mls_fetch_keypackages<V: MlsEntity>(&self, count: u32) -> CryptoKeystoreResult<Vec<V>> {
126 cfg_if::cfg_if! {
127 if #[cfg(not(target_family = "wasm"))] {
128 let reverse = true;
129 } else {
130 let reverse = false;
131 }
132 }
133 let keypackages = self
134 .find_all::<MlsKeyPackage>(EntityFindParams {
135 limit: Some(count),
136 offset: None,
137 reverse,
138 })
139 .await?;
140
141 Ok(keypackages
142 .into_iter()
143 .filter_map(|kpb| postcard::from_bytes(&kpb.keypackage).ok())
144 .collect())
145 }
146
147 async fn mls_group_exists(&self, group_id: impl AsRef<[u8]> + Send) -> bool {
148 matches!(self.find::<PersistedMlsGroup>(group_id).await, Ok(Some(_)))
149 }
150
151 async fn mls_group_persist(
152 &self,
153 group_id: impl AsRef<[u8]> + Send,
154 state: &[u8],
155 parent_group_id: Option<&[u8]>,
156 ) -> CryptoKeystoreResult<()> {
157 self.save(PersistedMlsGroup {
158 id: group_id.as_ref().to_owned(),
159 state: state.into(),
160 parent_id: parent_group_id.map(Into::into),
161 })
162 .await?;
163
164 Ok(())
165 }
166
167 async fn mls_groups_restore(
168 &self,
169 ) -> CryptoKeystoreResult<std::collections::HashMap<Vec<u8>, (Option<Vec<u8>>, Vec<u8>)>> {
170 let groups = self.find_all::<PersistedMlsGroup>(EntityFindParams::default()).await?;
171 Ok(groups
172 .into_iter()
173 .map(|group: PersistedMlsGroup| (group.id.clone(), (group.parent_id.clone(), group.state.clone())))
174 .collect())
175 }
176
177 async fn mls_group_delete(&self, group_id: impl AsRef<[u8]> + Send) -> CryptoKeystoreResult<()> {
178 self.remove::<PersistedMlsGroup, _>(group_id).await?;
179
180 Ok(())
181 }
182
183 async fn mls_pending_groups_save(
184 &self,
185 group_id: impl AsRef<[u8]> + Send,
186 mls_group: &[u8],
187 custom_configuration: &[u8],
188 parent_group_id: Option<&[u8]>,
189 ) -> CryptoKeystoreResult<()> {
190 self.save(PersistedMlsPendingGroup {
191 id: group_id.as_ref().to_owned(),
192 state: mls_group.into(),
193 custom_configuration: custom_configuration.into(),
194 parent_id: parent_group_id.map(Into::into),
195 })
196 .await?;
197 Ok(())
198 }
199
200 async fn mls_pending_groups_load(
201 &self,
202 group_id: impl AsRef<[u8]> + Send,
203 ) -> CryptoKeystoreResult<(Vec<u8>, Vec<u8>)> {
204 self.find(group_id)
205 .await?
206 .map(|r: PersistedMlsPendingGroup| (r.state.clone(), r.custom_configuration.clone()))
207 .ok_or(CryptoKeystoreError::MissingKeyInStore(
208 MissingKeyErrorKind::MlsPendingGroup,
209 ))
210 }
211
212 async fn mls_pending_groups_delete(&self, group_id: impl AsRef<[u8]> + Send) -> CryptoKeystoreResult<()> {
213 self.remove::<PersistedMlsPendingGroup, _>(group_id).await
214 }
215
216 async fn save_e2ei_enrollment(&self, id: &[u8], content: &[u8]) -> CryptoKeystoreResult<()> {
217 self.save(E2eiEnrollment {
218 id: id.into(),
219 content: content.into(),
220 })
221 .await?;
222 Ok(())
223 }
224
225 async fn pop_e2ei_enrollment(&self, id: &[u8]) -> CryptoKeystoreResult<Vec<u8>> {
226 let enrollment = self
228 .find::<E2eiEnrollment>(id)
229 .await?
230 .ok_or(CryptoKeystoreError::MissingKeyInStore(
231 MissingKeyErrorKind::E2eiEnrollment,
232 ))?;
233 self.remove::<E2eiEnrollment, _>(id).await?;
234 Ok(enrollment.content.clone())
235 }
236}
237
238#[inline(always)]
239pub fn deser<T: MlsEntity>(bytes: &[u8]) -> Result<T, CryptoKeystoreError> {
240 Ok(postcard::from_bytes(bytes)?)
241}
242
243#[inline(always)]
244pub fn ser<T: MlsEntity>(value: &T) -> Result<Vec<u8>, CryptoKeystoreError> {
245 Ok(postcard::to_stdvec(value)?)
246}
247
248#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
249#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
250impl openmls_traits::key_store::OpenMlsKeyStore for crate::connection::Database {
251 type Error = CryptoKeystoreError;
252
253 async fn store<V: MlsEntity + Sync>(&self, k: &[u8], v: &V) -> Result<(), Self::Error>
254 where
255 Self: Sized,
256 {
257 if k.is_empty() {
258 return Err(CryptoKeystoreError::MlsKeyStoreError(
259 "The provided key is empty".into(),
260 ));
261 }
262
263 let data = ser(v)?;
264
265 match V::ID {
266 MlsEntityId::GroupState => {
267 return Err(CryptoKeystoreError::IncorrectApiUsage(
268 "Groups must not be saved using OpenMLS's APIs. You should use the keystore's provided methods",
269 ));
270 }
271 MlsEntityId::SignatureKeyPair => {
272 let concrete_signature_keypair: &SignatureKeyPair = v
273 .downcast()
274 .expect("There's an implementation issue in OpenMLS. This shouln't be happening.");
275
276 let credential_id = vec![];
278 let kp = MlsSignatureKeyPair::new(
279 concrete_signature_keypair.signature_scheme(),
280 k.into(),
281 data,
282 credential_id,
283 );
284 self.save(kp).await?;
285 }
286 MlsEntityId::KeyPackage => {
287 let kp = MlsKeyPackage {
288 keypackage_ref: k.into(),
289 keypackage: data,
290 };
291 self.save(kp).await?;
292 }
293 MlsEntityId::HpkePrivateKey => {
294 let kp = MlsHpkePrivateKey { pk: k.into(), sk: data };
295 self.save(kp).await?;
296 }
297 MlsEntityId::PskBundle => {
298 let kp = MlsPskBundle {
299 psk_id: k.into(),
300 psk: data,
301 };
302 self.save(kp).await?;
303 }
304 MlsEntityId::EncryptionKeyPair => {
305 let kp = MlsEncryptionKeyPair { pk: k.into(), sk: data };
306 self.save(kp).await?;
307 }
308 MlsEntityId::EpochEncryptionKeyPair => {
309 let kp = MlsEpochEncryptionKeyPair {
310 id: k.into(),
311 keypairs: data,
312 };
313 self.save(kp).await?;
314 }
315 }
316
317 Ok(())
318 }
319
320 async fn read<V: MlsEntity>(&self, k: &[u8]) -> Option<V>
321 where
322 Self: Sized,
323 {
324 if k.is_empty() {
325 return None;
326 }
327
328 match V::ID {
329 MlsEntityId::GroupState => {
330 let group: PersistedMlsGroup = self.find(k).await.ok().flatten()?;
331 deser(&group.state).ok()
332 }
333 MlsEntityId::SignatureKeyPair => {
334 let sig: MlsSignatureKeyPair = self.find(k).await.ok().flatten()?;
335 deser(&sig.keypair).ok()
336 }
337 MlsEntityId::KeyPackage => {
338 let kp: MlsKeyPackage = self.find(k).await.ok().flatten()?;
339 deser(&kp.keypackage).ok()
340 }
341 MlsEntityId::HpkePrivateKey => {
342 let hpke_pk: MlsHpkePrivateKey = self.find(k).await.ok().flatten()?;
343 deser(&hpke_pk.sk).ok()
344 }
345 MlsEntityId::PskBundle => {
346 let psk_bundle: MlsPskBundle = self.find(k).await.ok().flatten()?;
347 deser(&psk_bundle.psk).ok()
348 }
349 MlsEntityId::EncryptionKeyPair => {
350 let kp: MlsEncryptionKeyPair = self.find(k).await.ok().flatten()?;
351 deser(&kp.sk).ok()
352 }
353 MlsEntityId::EpochEncryptionKeyPair => {
354 let kp: MlsEpochEncryptionKeyPair = self.find(k).await.ok().flatten()?;
355 deser(&kp.keypairs).ok()
356 }
357 }
358 }
359
360 async fn delete<V: MlsEntity>(&self, k: &[u8]) -> Result<(), Self::Error> {
361 match V::ID {
362 MlsEntityId::GroupState => self.remove::<PersistedMlsGroup, _>(k).await?,
363 MlsEntityId::SignatureKeyPair => self.remove::<MlsSignatureKeyPair, _>(k).await?,
364 MlsEntityId::HpkePrivateKey => self.remove::<MlsHpkePrivateKey, _>(k).await?,
365 MlsEntityId::KeyPackage => self.remove::<MlsKeyPackage, _>(k).await?,
366 MlsEntityId::PskBundle => self.remove::<MlsPskBundle, _>(k).await?,
367 MlsEntityId::EncryptionKeyPair => self.remove::<MlsEncryptionKeyPair, _>(k).await?,
368 MlsEntityId::EpochEncryptionKeyPair => self.remove::<MlsEpochEncryptionKeyPair, _>(k).await?,
369 }
370
371 Ok(())
372 }
373}