1use openmls::prelude::Ciphersuite;
2use openmls_basic_credential::SignatureKeyPair;
3use openmls_traits::key_store::{MlsEntity, MlsEntityId};
4
5use crate::{
6 CryptoKeystoreError, CryptoKeystoreResult, MissingKeyErrorKind,
7 connection::FetchFromDatabase,
8 entities::{
9 EntityFindParams, PersistedMlsGroup, PersistedMlsPendingGroup, StoredCredential, StoredE2eiEnrollment,
10 StoredEncryptionKeyPair, StoredEpochEncryptionKeypair, StoredHpkePrivateKey, StoredKeypackage, StoredPskBundle,
11 },
12};
13
/// MLS-specific persistence operations provided by the keystore.
///
/// Covers group state (which must not be written through OpenMLS's own key-store API),
/// pending groups, E2EI enrollments, and bulk key-package retrieval.
#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
pub trait CryptoKeystoreMls: Sized {
    /// Fetches up to `count` stored key packages, each deserialized into `V`.
    ///
    /// `count` is an upper bound on how many entries are returned.
    async fn mls_fetch_keypackages<V: MlsEntity>(&self, count: u32) -> CryptoKeystoreResult<Vec<V>>;

    /// Returns `true` if a persisted group with `group_id` exists.
    ///
    /// NOTE(review): the `crate::Database` implementation also returns `false` when the
    /// lookup itself fails, so `false` does not strictly mean "absent".
    async fn mls_group_exists(&self, group_id: impl AsRef<[u8]> + Send) -> bool;

    /// Persists the serialized `state` of group `group_id`, optionally recording the id of
    /// its parent group.
    async fn mls_group_persist(
        &self,
        group_id: impl AsRef<[u8]> + Send,
        state: &[u8],
        parent_group_id: Option<&[u8]>,
    ) -> CryptoKeystoreResult<()>;

    /// Loads every persisted group.
    ///
    /// The returned map is keyed by group id; each value is `(parent_group_id, state)`.
    async fn mls_groups_restore(
        &self,
    ) -> CryptoKeystoreResult<std::collections::HashMap<Vec<u8>, (Option<Vec<u8>>, Vec<u8>)>>;

    /// Deletes the persisted group `group_id`.
    async fn mls_group_delete(&self, group_id: impl AsRef<[u8]> + Send) -> CryptoKeystoreResult<()>;

    /// Saves a pending group: its serialized `mls_group`, its `custom_configuration` bytes,
    /// and an optional parent group id.
    async fn mls_pending_groups_save(
        &self,
        group_id: impl AsRef<[u8]> + Send,
        mls_group: &[u8],
        custom_configuration: &[u8],
        parent_group_id: Option<&[u8]>,
    ) -> CryptoKeystoreResult<()>;

    /// Loads the pending group `group_id` as `(state, custom_configuration)`.
    ///
    /// # Errors
    /// The `crate::Database` implementation returns
    /// `CryptoKeystoreError::MissingKeyInStore(MissingKeyErrorKind::MlsPendingGroup)` when no
    /// pending group with that id is stored.
    async fn mls_pending_groups_load(
        &self,
        group_id: impl AsRef<[u8]> + Send,
    ) -> CryptoKeystoreResult<(Vec<u8>, Vec<u8>)>;

    /// Deletes the pending group `group_id`.
    async fn mls_pending_groups_delete(&self, group_id: impl AsRef<[u8]> + Send) -> CryptoKeystoreResult<()>;

    /// Stores the serialized E2EI enrollment `content` under `id`.
    async fn save_e2ei_enrollment(&self, id: &[u8], content: &[u8]) -> CryptoKeystoreResult<()>;

    /// Removes the E2EI enrollment stored under `id` and returns its content.
    ///
    /// # Errors
    /// The `crate::Database` implementation returns
    /// `CryptoKeystoreError::MissingKeyInStore(MissingKeyErrorKind::StoredE2eiEnrollment)` when
    /// nothing is stored under `id`.
    async fn pop_e2ei_enrollment(&self, id: &[u8]) -> CryptoKeystoreResult<Vec<u8>>;
}
122
123#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
124#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
125impl CryptoKeystoreMls for crate::Database {
126 async fn mls_fetch_keypackages<V: MlsEntity>(&self, count: u32) -> CryptoKeystoreResult<Vec<V>> {
127 let reverse = !cfg!(target_family = "wasm");
128 let keypackages = self
129 .find_all::<StoredKeypackage>(EntityFindParams {
130 limit: Some(count),
131 offset: None,
132 reverse,
133 })
134 .await?;
135
136 Ok(keypackages
137 .into_iter()
138 .filter_map(|kpb| postcard::from_bytes(&kpb.keypackage).ok())
139 .collect())
140 }
141
142 async fn mls_group_exists(&self, group_id: impl AsRef<[u8]> + Send) -> bool {
143 matches!(self.find::<PersistedMlsGroup>(group_id).await, Ok(Some(_)))
144 }
145
146 async fn mls_group_persist(
147 &self,
148 group_id: impl AsRef<[u8]> + Send,
149 state: &[u8],
150 parent_group_id: Option<&[u8]>,
151 ) -> CryptoKeystoreResult<()> {
152 self.save(PersistedMlsGroup {
153 id: group_id.as_ref().to_owned(),
154 state: state.into(),
155 parent_id: parent_group_id.map(Into::into),
156 })
157 .await?;
158
159 Ok(())
160 }
161
162 async fn mls_groups_restore(
163 &self,
164 ) -> CryptoKeystoreResult<std::collections::HashMap<Vec<u8>, (Option<Vec<u8>>, Vec<u8>)>> {
165 let groups = self.find_all::<PersistedMlsGroup>(EntityFindParams::default()).await?;
166 Ok(groups
167 .into_iter()
168 .map(|group: PersistedMlsGroup| (group.id.clone(), (group.parent_id.clone(), group.state.clone())))
169 .collect())
170 }
171
172 async fn mls_group_delete(&self, group_id: impl AsRef<[u8]> + Send) -> CryptoKeystoreResult<()> {
173 self.remove::<PersistedMlsGroup, _>(group_id).await?;
174
175 Ok(())
176 }
177
178 async fn mls_pending_groups_save(
179 &self,
180 group_id: impl AsRef<[u8]> + Send,
181 mls_group: &[u8],
182 custom_configuration: &[u8],
183 parent_group_id: Option<&[u8]>,
184 ) -> CryptoKeystoreResult<()> {
185 self.save(PersistedMlsPendingGroup {
186 id: group_id.as_ref().to_owned(),
187 state: mls_group.into(),
188 custom_configuration: custom_configuration.into(),
189 parent_id: parent_group_id.map(Into::into),
190 })
191 .await?;
192 Ok(())
193 }
194
195 async fn mls_pending_groups_load(
196 &self,
197 group_id: impl AsRef<[u8]> + Send,
198 ) -> CryptoKeystoreResult<(Vec<u8>, Vec<u8>)> {
199 self.find(group_id)
200 .await?
201 .map(|r: PersistedMlsPendingGroup| (r.state.clone(), r.custom_configuration.clone()))
202 .ok_or(CryptoKeystoreError::MissingKeyInStore(
203 MissingKeyErrorKind::MlsPendingGroup,
204 ))
205 }
206
207 async fn mls_pending_groups_delete(&self, group_id: impl AsRef<[u8]> + Send) -> CryptoKeystoreResult<()> {
208 self.remove::<PersistedMlsPendingGroup, _>(group_id).await
209 }
210
211 async fn save_e2ei_enrollment(&self, id: &[u8], content: &[u8]) -> CryptoKeystoreResult<()> {
212 self.save(StoredE2eiEnrollment {
213 id: id.into(),
214 content: content.into(),
215 })
216 .await?;
217 Ok(())
218 }
219
220 async fn pop_e2ei_enrollment(&self, id: &[u8]) -> CryptoKeystoreResult<Vec<u8>> {
221 let enrollment = self
223 .find::<StoredE2eiEnrollment>(id)
224 .await?
225 .ok_or(CryptoKeystoreError::MissingKeyInStore(
226 MissingKeyErrorKind::StoredE2eiEnrollment,
227 ))?;
228 self.remove::<StoredE2eiEnrollment, _>(id).await?;
229 Ok(enrollment.content.clone())
230 }
231}
232
233#[inline(always)]
234pub fn deser<T: MlsEntity>(bytes: &[u8]) -> Result<T, CryptoKeystoreError> {
235 Ok(postcard::from_bytes(bytes)?)
236}
237
238#[inline(always)]
239pub fn ser<T: MlsEntity>(value: &T) -> Result<Vec<u8>, CryptoKeystoreError> {
240 Ok(postcard::to_stdvec(value)?)
241}
242
243#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
244#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
245impl openmls_traits::key_store::OpenMlsKeyStore for crate::connection::Database {
246 type Error = CryptoKeystoreError;
247
248 async fn store<V: MlsEntity + Sync>(&self, k: &[u8], v: &V) -> Result<(), Self::Error>
249 where
250 Self: Sized,
251 {
252 if k.is_empty() {
253 return Err(CryptoKeystoreError::MlsKeyStoreError(
254 "The provided key is empty".into(),
255 ));
256 }
257
258 let data = ser(v)?;
259
260 match V::ID {
261 MlsEntityId::GroupState => {
262 return Err(CryptoKeystoreError::IncorrectApiUsage(
263 "Groups must not be saved using OpenMLS's APIs. You should use the keystore's provided methods",
264 ));
265 }
266 MlsEntityId::SignatureKeyPair => {
267 return Err(CryptoKeystoreError::IncorrectApiUsage(
268 "Signature keys must not be saved using OpenMLS's APIs. Save a credential via the keystore API
269 instead.",
270 ));
271 }
272 MlsEntityId::KeyPackage => {
273 let kp = StoredKeypackage {
274 keypackage_ref: k.into(),
275 keypackage: data,
276 };
277 self.save(kp).await?;
278 }
279 MlsEntityId::HpkePrivateKey => {
280 let kp = StoredHpkePrivateKey { pk: k.into(), sk: data };
281 self.save(kp).await?;
282 }
283 MlsEntityId::PskBundle => {
284 let kp = StoredPskBundle {
285 psk_id: k.into(),
286 psk: data,
287 };
288 self.save(kp).await?;
289 }
290 MlsEntityId::EncryptionKeyPair => {
291 let kp = StoredEncryptionKeyPair { pk: k.into(), sk: data };
292 self.save(kp).await?;
293 }
294 MlsEntityId::EpochEncryptionKeyPair => {
295 let kp = StoredEpochEncryptionKeypair {
296 id: k.into(),
297 keypairs: data,
298 };
299 self.save(kp).await?;
300 }
301 }
302
303 Ok(())
304 }
305
306 async fn read<V: MlsEntity>(&self, k: &[u8]) -> Option<V>
307 where
308 Self: Sized,
309 {
310 if k.is_empty() {
311 return None;
312 }
313
314 match V::ID {
315 MlsEntityId::GroupState => {
316 let group: PersistedMlsGroup = self.find(k).await.ok().flatten()?;
317 deser(&group.state).ok()
318 }
319 MlsEntityId::SignatureKeyPair => {
320 let stored_credential = self.find::<StoredCredential>(k).await.ok().flatten()?;
321 let ciphersuite = Ciphersuite::try_from(stored_credential.ciphersuite).ok()?;
322 let signature_scheme = ciphersuite.signature_algorithm();
323
324 let mls_keypair = SignatureKeyPair::from_raw(
325 signature_scheme,
326 stored_credential.private_key.to_vec(),
327 stored_credential.public_key.to_vec(),
328 );
329
330 let mls_keypair_serialized = ser(&mls_keypair).ok()?;
333 deser(&mls_keypair_serialized).ok()
334 }
335 MlsEntityId::KeyPackage => {
336 let kp: StoredKeypackage = self.find(k).await.ok().flatten()?;
337 deser(&kp.keypackage).ok()
338 }
339 MlsEntityId::HpkePrivateKey => {
340 let hpke_pk: StoredHpkePrivateKey = self.find(k).await.ok().flatten()?;
341 deser(&hpke_pk.sk).ok()
342 }
343 MlsEntityId::PskBundle => {
344 let psk_bundle: StoredPskBundle = self.find(k).await.ok().flatten()?;
345 deser(&psk_bundle.psk).ok()
346 }
347 MlsEntityId::EncryptionKeyPair => {
348 let kp: StoredEncryptionKeyPair = self.find(k).await.ok().flatten()?;
349 deser(&kp.sk).ok()
350 }
351 MlsEntityId::EpochEncryptionKeyPair => {
352 let kp: StoredEpochEncryptionKeypair = self.find(k).await.ok().flatten()?;
353 deser(&kp.keypairs).ok()
354 }
355 }
356 }
357
358 async fn delete<V: MlsEntity>(&self, k: &[u8]) -> Result<(), Self::Error> {
359 match V::ID {
360 MlsEntityId::GroupState => self.remove::<PersistedMlsGroup, _>(k).await?,
361 MlsEntityId::SignatureKeyPair => unimplemented!(
362 "Deleting a signature key pair should not be done through this API, any keypair should be deleted via
363 deleting a credential."
364 ),
365 MlsEntityId::HpkePrivateKey => self.remove::<StoredHpkePrivateKey, _>(k).await?,
366 MlsEntityId::KeyPackage => self.remove::<StoredKeypackage, _>(k).await?,
367 MlsEntityId::PskBundle => self.remove::<StoredPskBundle, _>(k).await?,
368 MlsEntityId::EncryptionKeyPair => self.remove::<StoredEncryptionKeyPair, _>(k).await?,
369 MlsEntityId::EpochEncryptionKeyPair => self.remove::<StoredEpochEncryptionKeypair, _>(k).await?,
370 }
371
372 Ok(())
373 }
374}