1use crate::connection::FetchFromDatabase;
18use crate::entities::MlsEpochEncryptionKeyPair;
19use crate::{
20 CryptoKeystoreError, CryptoKeystoreResult, MissingKeyErrorKind,
21 entities::{
22 E2eiEnrollment, EntityFindParams, MlsEncryptionKeyPair, MlsHpkePrivateKey, MlsKeyPackage, MlsPskBundle,
23 MlsSignatureKeyPair, PersistedMlsGroup, PersistedMlsPendingGroup,
24 },
25};
26use openmls_basic_credential::SignatureKeyPair;
27use openmls_traits::key_store::{MlsEntity, MlsEntityId};
28
29#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
31#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
32pub trait CryptoKeystoreMls: Sized {
33 async fn mls_fetch_keypackages<V: MlsEntity>(&self, count: u32) -> CryptoKeystoreResult<Vec<V>>;
42
43 async fn mls_group_exists(&self, group_id: &[u8]) -> bool;
49
50 async fn mls_group_persist(
60 &self,
61 group_id: &[u8],
62 state: &[u8],
63 parent_group_id: Option<&[u8]>,
64 ) -> CryptoKeystoreResult<()>;
65
66 async fn mls_groups_restore(
73 &self,
74 ) -> CryptoKeystoreResult<std::collections::HashMap<Vec<u8>, (Option<Vec<u8>>, Vec<u8>)>>;
75
76 async fn mls_group_delete(&self, group_id: &[u8]) -> CryptoKeystoreResult<()>;
81
82 async fn mls_pending_groups_save(
94 &self,
95 group_id: &[u8],
96 mls_group: &[u8],
97 custom_configuration: &[u8],
98 parent_group_id: Option<&[u8]>,
99 ) -> CryptoKeystoreResult<()>;
100
101 async fn mls_pending_groups_load(&self, group_id: &[u8]) -> CryptoKeystoreResult<(Vec<u8>, Vec<u8>)>;
110
111 async fn mls_pending_groups_delete(&self, group_id: &[u8]) -> CryptoKeystoreResult<()>;
120
121 async fn save_e2ei_enrollment(&self, id: &[u8], content: &[u8]) -> CryptoKeystoreResult<()>;
127
128 async fn pop_e2ei_enrollment(&self, id: &[u8]) -> CryptoKeystoreResult<Vec<u8>>;
133}
134
135#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
136#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
137impl CryptoKeystoreMls for crate::Connection {
138 async fn mls_fetch_keypackages<V: MlsEntity>(&self, count: u32) -> CryptoKeystoreResult<Vec<V>> {
139 cfg_if::cfg_if! {
140 if #[cfg(not(target_family = "wasm"))] {
141 let reverse = true;
142 } else {
143 let reverse = false;
144 }
145 }
146 let keypackages = self
147 .find_all::<MlsKeyPackage>(EntityFindParams {
148 limit: Some(count),
149 offset: None,
150 reverse,
151 })
152 .await?;
153
154 Ok(keypackages
155 .into_iter()
156 .filter_map(|kpb| postcard::from_bytes(&kpb.keypackage).ok())
157 .collect())
158 }
159
160 async fn mls_group_exists(&self, group_id: &[u8]) -> bool {
161 matches!(self.find::<PersistedMlsGroup>(group_id).await, Ok(Some(_)))
162 }
163
164 async fn mls_group_persist(
165 &self,
166 group_id: &[u8],
167 state: &[u8],
168 parent_group_id: Option<&[u8]>,
169 ) -> CryptoKeystoreResult<()> {
170 self.save(PersistedMlsGroup {
171 id: group_id.into(),
172 state: state.into(),
173 parent_id: parent_group_id.map(Into::into),
174 })
175 .await?;
176
177 Ok(())
178 }
179
180 async fn mls_groups_restore(
181 &self,
182 ) -> CryptoKeystoreResult<std::collections::HashMap<Vec<u8>, (Option<Vec<u8>>, Vec<u8>)>> {
183 let groups = self.find_all::<PersistedMlsGroup>(EntityFindParams::default()).await?;
184 Ok(groups
185 .into_iter()
186 .map(|group: PersistedMlsGroup| (group.id.clone(), (group.parent_id.clone(), group.state.clone())))
187 .collect())
188 }
189
190 async fn mls_group_delete(&self, group_id: &[u8]) -> CryptoKeystoreResult<()> {
191 self.remove::<PersistedMlsGroup, _>(group_id).await?;
192
193 Ok(())
194 }
195
196 async fn mls_pending_groups_save(
197 &self,
198 group_id: &[u8],
199 mls_group: &[u8],
200 custom_configuration: &[u8],
201 parent_group_id: Option<&[u8]>,
202 ) -> CryptoKeystoreResult<()> {
203 self.save(PersistedMlsPendingGroup {
204 id: group_id.into(),
205 state: mls_group.into(),
206 custom_configuration: custom_configuration.into(),
207 parent_id: parent_group_id.map(Into::into),
208 })
209 .await?;
210 Ok(())
211 }
212
213 async fn mls_pending_groups_load(&self, group_id: &[u8]) -> CryptoKeystoreResult<(Vec<u8>, Vec<u8>)> {
214 self.find(group_id)
215 .await?
216 .map(|r: PersistedMlsPendingGroup| (r.state.clone(), r.custom_configuration.clone()))
217 .ok_or(CryptoKeystoreError::MissingKeyInStore(
218 MissingKeyErrorKind::MlsPendingGroup,
219 ))
220 }
221
222 async fn mls_pending_groups_delete(&self, group_id: &[u8]) -> CryptoKeystoreResult<()> {
223 self.remove::<PersistedMlsPendingGroup, _>(group_id).await
224 }
225
226 async fn save_e2ei_enrollment(&self, id: &[u8], content: &[u8]) -> CryptoKeystoreResult<()> {
227 self.save(E2eiEnrollment {
228 id: id.into(),
229 content: content.into(),
230 })
231 .await?;
232 Ok(())
233 }
234
235 async fn pop_e2ei_enrollment(&self, id: &[u8]) -> CryptoKeystoreResult<Vec<u8>> {
236 let enrollment = self
238 .find::<E2eiEnrollment>(id)
239 .await?
240 .ok_or(CryptoKeystoreError::MissingKeyInStore(
241 MissingKeyErrorKind::E2eiEnrollment,
242 ))?;
243 self.remove::<E2eiEnrollment, _>(id).await?;
244 Ok(enrollment.content.clone())
245 }
246}
247
248#[inline(always)]
249pub fn deser<T: MlsEntity>(bytes: &[u8]) -> Result<T, CryptoKeystoreError> {
250 Ok(postcard::from_bytes(bytes)?)
251}
252
253#[inline(always)]
254pub fn ser<T: MlsEntity>(value: &T) -> Result<Vec<u8>, CryptoKeystoreError> {
255 Ok(postcard::to_stdvec(value)?)
256}
257
258#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
259#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
260impl openmls_traits::key_store::OpenMlsKeyStore for crate::connection::Connection {
261 type Error = CryptoKeystoreError;
262
263 async fn store<V: MlsEntity + Sync>(&self, k: &[u8], v: &V) -> Result<(), Self::Error>
264 where
265 Self: Sized,
266 {
267 if k.is_empty() {
268 return Err(CryptoKeystoreError::MlsKeyStoreError(
269 "The provided key is empty".into(),
270 ));
271 }
272
273 let data = ser(v)?;
274
275 match V::ID {
276 MlsEntityId::GroupState => {
277 return Err(CryptoKeystoreError::IncorrectApiUsage(
278 "Groups must not be saved using OpenMLS's APIs. You should use the keystore's provided methods",
279 ));
280 }
281 MlsEntityId::SignatureKeyPair => {
282 let concrete_signature_keypair: &SignatureKeyPair = v
283 .downcast()
284 .expect("There's an implementation issue in OpenMLS. This shouln't be happening.");
285
286 let credential_id = vec![];
288 let kp = MlsSignatureKeyPair::new(
289 concrete_signature_keypair.signature_scheme(),
290 k.into(),
291 data,
292 credential_id,
293 );
294 self.save(kp).await?;
295 }
296 MlsEntityId::KeyPackage => {
297 let kp = MlsKeyPackage {
298 keypackage_ref: k.into(),
299 keypackage: data,
300 };
301 self.save(kp).await?;
302 }
303 MlsEntityId::HpkePrivateKey => {
304 let kp = MlsHpkePrivateKey { pk: k.into(), sk: data };
305 self.save(kp).await?;
306 }
307 MlsEntityId::PskBundle => {
308 let kp = MlsPskBundle {
309 psk_id: k.into(),
310 psk: data,
311 };
312 self.save(kp).await?;
313 }
314 MlsEntityId::EncryptionKeyPair => {
315 let kp = MlsEncryptionKeyPair { pk: k.into(), sk: data };
316 self.save(kp).await?;
317 }
318 MlsEntityId::EpochEncryptionKeyPair => {
319 let kp = MlsEpochEncryptionKeyPair {
320 id: k.into(),
321 keypairs: data,
322 };
323 self.save(kp).await?;
324 }
325 }
326
327 Ok(())
328 }
329
330 async fn read<V: MlsEntity>(&self, k: &[u8]) -> Option<V>
331 where
332 Self: Sized,
333 {
334 if k.is_empty() {
335 return None;
336 }
337
338 match V::ID {
339 MlsEntityId::GroupState => {
340 let group: PersistedMlsGroup = self.find(k).await.ok().flatten()?;
341 deser(&group.state).ok()
342 }
343 MlsEntityId::SignatureKeyPair => {
344 let sig: MlsSignatureKeyPair = self.find(k).await.ok().flatten()?;
345 deser(&sig.keypair).ok()
346 }
347 MlsEntityId::KeyPackage => {
348 let kp: MlsKeyPackage = self.find(k).await.ok().flatten()?;
349 deser(&kp.keypackage).ok()
350 }
351 MlsEntityId::HpkePrivateKey => {
352 let hpke_pk: MlsHpkePrivateKey = self.find(k).await.ok().flatten()?;
353 deser(&hpke_pk.sk).ok()
354 }
355 MlsEntityId::PskBundle => {
356 let psk_bundle: MlsPskBundle = self.find(k).await.ok().flatten()?;
357 deser(&psk_bundle.psk).ok()
358 }
359 MlsEntityId::EncryptionKeyPair => {
360 let kp: MlsEncryptionKeyPair = self.find(k).await.ok().flatten()?;
361 deser(&kp.sk).ok()
362 }
363 MlsEntityId::EpochEncryptionKeyPair => {
364 let kp: MlsEpochEncryptionKeyPair = self.find(k).await.ok().flatten()?;
365 deser(&kp.keypairs).ok()
366 }
367 }
368 }
369
370 async fn delete<V: MlsEntity>(&self, k: &[u8]) -> Result<(), Self::Error> {
371 match V::ID {
372 MlsEntityId::GroupState => self.remove::<PersistedMlsGroup, _>(k).await?,
373 MlsEntityId::SignatureKeyPair => self.remove::<MlsSignatureKeyPair, _>(k).await?,
374 MlsEntityId::HpkePrivateKey => self.remove::<MlsHpkePrivateKey, _>(k).await?,
375 MlsEntityId::KeyPackage => self.remove::<MlsKeyPackage, _>(k).await?,
376 MlsEntityId::PskBundle => self.remove::<MlsPskBundle, _>(k).await?,
377 MlsEntityId::EncryptionKeyPair => self.remove::<MlsEncryptionKeyPair, _>(k).await?,
378 MlsEntityId::EpochEncryptionKeyPair => self.remove::<MlsEpochEncryptionKeyPair, _>(k).await?,
379 }
380
381 Ok(())
382 }
383}