1use crate::connection::FetchFromDatabase;
2use crate::entities::MlsEpochEncryptionKeyPair;
3use crate::{
4 CryptoKeystoreError, CryptoKeystoreResult, MissingKeyErrorKind,
5 entities::{
6 E2eiEnrollment, EntityFindParams, MlsEncryptionKeyPair, MlsHpkePrivateKey, MlsKeyPackage, MlsPskBundle,
7 MlsSignatureKeyPair, PersistedMlsGroup, PersistedMlsPendingGroup,
8 },
9};
10use openmls_basic_credential::SignatureKeyPair;
11use openmls_traits::key_store::{MlsEntity, MlsEntityId};
12
13#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
15#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
16pub trait CryptoKeystoreMls: Sized {
17 async fn mls_fetch_keypackages<V: MlsEntity>(&self, count: u32) -> CryptoKeystoreResult<Vec<V>>;
26
27 async fn mls_group_exists(&self, group_id: &[u8]) -> bool;
33
34 async fn mls_group_persist(
44 &self,
45 group_id: &[u8],
46 state: &[u8],
47 parent_group_id: Option<&[u8]>,
48 ) -> CryptoKeystoreResult<()>;
49
50 async fn mls_groups_restore(
57 &self,
58 ) -> CryptoKeystoreResult<std::collections::HashMap<Vec<u8>, (Option<Vec<u8>>, Vec<u8>)>>;
59
60 async fn mls_group_delete(&self, group_id: &[u8]) -> CryptoKeystoreResult<()>;
65
66 async fn mls_pending_groups_save(
78 &self,
79 group_id: &[u8],
80 mls_group: &[u8],
81 custom_configuration: &[u8],
82 parent_group_id: Option<&[u8]>,
83 ) -> CryptoKeystoreResult<()>;
84
85 async fn mls_pending_groups_load(&self, group_id: &[u8]) -> CryptoKeystoreResult<(Vec<u8>, Vec<u8>)>;
94
95 async fn mls_pending_groups_delete(&self, group_id: &[u8]) -> CryptoKeystoreResult<()>;
104
105 async fn save_e2ei_enrollment(&self, id: &[u8], content: &[u8]) -> CryptoKeystoreResult<()>;
111
112 async fn pop_e2ei_enrollment(&self, id: &[u8]) -> CryptoKeystoreResult<Vec<u8>>;
117}
118
119#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
120#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
121impl CryptoKeystoreMls for crate::Connection {
122 async fn mls_fetch_keypackages<V: MlsEntity>(&self, count: u32) -> CryptoKeystoreResult<Vec<V>> {
123 cfg_if::cfg_if! {
124 if #[cfg(not(target_family = "wasm"))] {
125 let reverse = true;
126 } else {
127 let reverse = false;
128 }
129 }
130 let keypackages = self
131 .find_all::<MlsKeyPackage>(EntityFindParams {
132 limit: Some(count),
133 offset: None,
134 reverse,
135 })
136 .await?;
137
138 Ok(keypackages
139 .into_iter()
140 .filter_map(|kpb| postcard::from_bytes(&kpb.keypackage).ok())
141 .collect())
142 }
143
144 async fn mls_group_exists(&self, group_id: &[u8]) -> bool {
145 matches!(self.find::<PersistedMlsGroup>(group_id).await, Ok(Some(_)))
146 }
147
148 async fn mls_group_persist(
149 &self,
150 group_id: &[u8],
151 state: &[u8],
152 parent_group_id: Option<&[u8]>,
153 ) -> CryptoKeystoreResult<()> {
154 self.save(PersistedMlsGroup {
155 id: group_id.into(),
156 state: state.into(),
157 parent_id: parent_group_id.map(Into::into),
158 })
159 .await?;
160
161 Ok(())
162 }
163
164 async fn mls_groups_restore(
165 &self,
166 ) -> CryptoKeystoreResult<std::collections::HashMap<Vec<u8>, (Option<Vec<u8>>, Vec<u8>)>> {
167 let groups = self.find_all::<PersistedMlsGroup>(EntityFindParams::default()).await?;
168 Ok(groups
169 .into_iter()
170 .map(|group: PersistedMlsGroup| (group.id.clone(), (group.parent_id.clone(), group.state.clone())))
171 .collect())
172 }
173
174 async fn mls_group_delete(&self, group_id: &[u8]) -> CryptoKeystoreResult<()> {
175 self.remove::<PersistedMlsGroup, _>(group_id).await?;
176
177 Ok(())
178 }
179
180 async fn mls_pending_groups_save(
181 &self,
182 group_id: &[u8],
183 mls_group: &[u8],
184 custom_configuration: &[u8],
185 parent_group_id: Option<&[u8]>,
186 ) -> CryptoKeystoreResult<()> {
187 self.save(PersistedMlsPendingGroup {
188 id: group_id.into(),
189 state: mls_group.into(),
190 custom_configuration: custom_configuration.into(),
191 parent_id: parent_group_id.map(Into::into),
192 })
193 .await?;
194 Ok(())
195 }
196
197 async fn mls_pending_groups_load(&self, group_id: &[u8]) -> CryptoKeystoreResult<(Vec<u8>, Vec<u8>)> {
198 self.find(group_id)
199 .await?
200 .map(|r: PersistedMlsPendingGroup| (r.state.clone(), r.custom_configuration.clone()))
201 .ok_or(CryptoKeystoreError::MissingKeyInStore(
202 MissingKeyErrorKind::MlsPendingGroup,
203 ))
204 }
205
206 async fn mls_pending_groups_delete(&self, group_id: &[u8]) -> CryptoKeystoreResult<()> {
207 self.remove::<PersistedMlsPendingGroup, _>(group_id).await
208 }
209
210 async fn save_e2ei_enrollment(&self, id: &[u8], content: &[u8]) -> CryptoKeystoreResult<()> {
211 self.save(E2eiEnrollment {
212 id: id.into(),
213 content: content.into(),
214 })
215 .await?;
216 Ok(())
217 }
218
219 async fn pop_e2ei_enrollment(&self, id: &[u8]) -> CryptoKeystoreResult<Vec<u8>> {
220 let enrollment = self
222 .find::<E2eiEnrollment>(id)
223 .await?
224 .ok_or(CryptoKeystoreError::MissingKeyInStore(
225 MissingKeyErrorKind::E2eiEnrollment,
226 ))?;
227 self.remove::<E2eiEnrollment, _>(id).await?;
228 Ok(enrollment.content.clone())
229 }
230}
231
232#[inline(always)]
233pub fn deser<T: MlsEntity>(bytes: &[u8]) -> Result<T, CryptoKeystoreError> {
234 Ok(postcard::from_bytes(bytes)?)
235}
236
237#[inline(always)]
238pub fn ser<T: MlsEntity>(value: &T) -> Result<Vec<u8>, CryptoKeystoreError> {
239 Ok(postcard::to_stdvec(value)?)
240}
241
242#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
243#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
244impl openmls_traits::key_store::OpenMlsKeyStore for crate::connection::Connection {
245 type Error = CryptoKeystoreError;
246
247 async fn store<V: MlsEntity + Sync>(&self, k: &[u8], v: &V) -> Result<(), Self::Error>
248 where
249 Self: Sized,
250 {
251 if k.is_empty() {
252 return Err(CryptoKeystoreError::MlsKeyStoreError(
253 "The provided key is empty".into(),
254 ));
255 }
256
257 let data = ser(v)?;
258
259 match V::ID {
260 MlsEntityId::GroupState => {
261 return Err(CryptoKeystoreError::IncorrectApiUsage(
262 "Groups must not be saved using OpenMLS's APIs. You should use the keystore's provided methods",
263 ));
264 }
265 MlsEntityId::SignatureKeyPair => {
266 let concrete_signature_keypair: &SignatureKeyPair = v
267 .downcast()
268 .expect("There's an implementation issue in OpenMLS. This shouln't be happening.");
269
270 let credential_id = vec![];
272 let kp = MlsSignatureKeyPair::new(
273 concrete_signature_keypair.signature_scheme(),
274 k.into(),
275 data,
276 credential_id,
277 );
278 self.save(kp).await?;
279 }
280 MlsEntityId::KeyPackage => {
281 let kp = MlsKeyPackage {
282 keypackage_ref: k.into(),
283 keypackage: data,
284 };
285 self.save(kp).await?;
286 }
287 MlsEntityId::HpkePrivateKey => {
288 let kp = MlsHpkePrivateKey { pk: k.into(), sk: data };
289 self.save(kp).await?;
290 }
291 MlsEntityId::PskBundle => {
292 let kp = MlsPskBundle {
293 psk_id: k.into(),
294 psk: data,
295 };
296 self.save(kp).await?;
297 }
298 MlsEntityId::EncryptionKeyPair => {
299 let kp = MlsEncryptionKeyPair { pk: k.into(), sk: data };
300 self.save(kp).await?;
301 }
302 MlsEntityId::EpochEncryptionKeyPair => {
303 let kp = MlsEpochEncryptionKeyPair {
304 id: k.into(),
305 keypairs: data,
306 };
307 self.save(kp).await?;
308 }
309 }
310
311 Ok(())
312 }
313
314 async fn read<V: MlsEntity>(&self, k: &[u8]) -> Option<V>
315 where
316 Self: Sized,
317 {
318 if k.is_empty() {
319 return None;
320 }
321
322 match V::ID {
323 MlsEntityId::GroupState => {
324 let group: PersistedMlsGroup = self.find(k).await.ok().flatten()?;
325 deser(&group.state).ok()
326 }
327 MlsEntityId::SignatureKeyPair => {
328 let sig: MlsSignatureKeyPair = self.find(k).await.ok().flatten()?;
329 deser(&sig.keypair).ok()
330 }
331 MlsEntityId::KeyPackage => {
332 let kp: MlsKeyPackage = self.find(k).await.ok().flatten()?;
333 deser(&kp.keypackage).ok()
334 }
335 MlsEntityId::HpkePrivateKey => {
336 let hpke_pk: MlsHpkePrivateKey = self.find(k).await.ok().flatten()?;
337 deser(&hpke_pk.sk).ok()
338 }
339 MlsEntityId::PskBundle => {
340 let psk_bundle: MlsPskBundle = self.find(k).await.ok().flatten()?;
341 deser(&psk_bundle.psk).ok()
342 }
343 MlsEntityId::EncryptionKeyPair => {
344 let kp: MlsEncryptionKeyPair = self.find(k).await.ok().flatten()?;
345 deser(&kp.sk).ok()
346 }
347 MlsEntityId::EpochEncryptionKeyPair => {
348 let kp: MlsEpochEncryptionKeyPair = self.find(k).await.ok().flatten()?;
349 deser(&kp.keypairs).ok()
350 }
351 }
352 }
353
354 async fn delete<V: MlsEntity>(&self, k: &[u8]) -> Result<(), Self::Error> {
355 match V::ID {
356 MlsEntityId::GroupState => self.remove::<PersistedMlsGroup, _>(k).await?,
357 MlsEntityId::SignatureKeyPair => self.remove::<MlsSignatureKeyPair, _>(k).await?,
358 MlsEntityId::HpkePrivateKey => self.remove::<MlsHpkePrivateKey, _>(k).await?,
359 MlsEntityId::KeyPackage => self.remove::<MlsKeyPackage, _>(k).await?,
360 MlsEntityId::PskBundle => self.remove::<MlsPskBundle, _>(k).await?,
361 MlsEntityId::EncryptionKeyPair => self.remove::<MlsEncryptionKeyPair, _>(k).await?,
362 MlsEntityId::EpochEncryptionKeyPair => self.remove::<MlsEpochEncryptionKeyPair, _>(k).await?,
363 }
364
365 Ok(())
366 }
367}