use crate::connection::FetchFromDatabase;
use crate::entities::MlsEpochEncryptionKeyPair;
use crate::{
entities::{
E2eiEnrollment, EntityFindParams, MlsEncryptionKeyPair, MlsHpkePrivateKey, MlsKeyPackage, MlsPskBundle,
MlsSignatureKeyPair, PersistedMlsGroup, PersistedMlsPendingGroup,
},
CryptoKeystoreError, CryptoKeystoreResult, MissingKeyErrorKind,
};
use openmls_basic_credential::SignatureKeyPair;
use openmls_traits::key_store::{MlsEntity, MlsEntityId};
/// MLS-specific persistence operations exposed by the keystore.
///
/// All payloads (group state, configurations, enrollments) are handled as opaque
/// byte slices; callers are responsible for their (de)serialization.
#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
pub trait CryptoKeystoreMls: Sized {
    /// Fetches up to `count` stored key packages, deserialized as `V`.
    ///
    /// The `Connection` implementation skips entries that fail to deserialize,
    /// so fewer than `count` items may be returned.
    async fn mls_fetch_keypackages<V: MlsEntity>(&self, count: u32) -> CryptoKeystoreResult<Vec<V>>;
    /// Returns whether a persisted group with `group_id` exists.
    ///
    /// The `Connection` implementation reports `false` when the lookup itself fails.
    async fn mls_group_exists(&self, group_id: &[u8]) -> bool;
    /// Persists the serialized `state` of group `group_id`, optionally linked to
    /// the group identified by `parent_group_id`.
    async fn mls_group_persist(
        &self,
        group_id: &[u8],
        state: &[u8],
        parent_group_id: Option<&[u8]>,
    ) -> CryptoKeystoreResult<()>;
    /// Loads every persisted group as a map of
    /// `group id -> (parent group id, serialized state)`.
    async fn mls_groups_restore(
        &self,
    ) -> CryptoKeystoreResult<std::collections::HashMap<Vec<u8>, (Option<Vec<u8>>, Vec<u8>)>>;
    /// Removes the persisted group `group_id`.
    async fn mls_group_delete(&self, group_id: &[u8]) -> CryptoKeystoreResult<()>;
    /// Persists a pending group: its serialized state (`mls_group`), its custom
    /// configuration bytes and an optional parent group id.
    async fn mls_pending_groups_save(
        &self,
        group_id: &[u8],
        mls_group: &[u8],
        custom_configuration: &[u8],
        parent_group_id: Option<&[u8]>,
    ) -> CryptoKeystoreResult<()>;
    /// Loads `(state, custom_configuration)` for the pending group `group_id`.
    ///
    /// The `Connection` implementation fails with
    /// `MissingKeyInStore(MissingKeyErrorKind::MlsPendingGroup)` when it is absent.
    async fn mls_pending_groups_load(&self, group_id: &[u8]) -> CryptoKeystoreResult<(Vec<u8>, Vec<u8>)>;
    /// Removes the pending group `group_id`.
    async fn mls_pending_groups_delete(&self, group_id: &[u8]) -> CryptoKeystoreResult<()>;
    /// Stores an opaque E2EI enrollment blob under `id`.
    async fn save_e2ei_enrollment(&self, id: &[u8], content: &[u8]) -> CryptoKeystoreResult<()>;
    /// Returns the E2EI enrollment stored under `id` and removes it from the store.
    ///
    /// The `Connection` implementation fails with
    /// `MissingKeyInStore(MissingKeyErrorKind::E2eiEnrollment)` when it is absent.
    async fn pop_e2ei_enrollment(&self, id: &[u8]) -> CryptoKeystoreResult<Vec<u8>>;
}
#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
impl CryptoKeystoreMls for crate::Connection {
    /// Fetches up to `count` stored key packages, deserialized as `V`.
    ///
    /// Rows whose bytes fail to deserialize are silently skipped, so the result may
    /// contain fewer than `count` items.
    async fn mls_fetch_keypackages<V: MlsEntity>(&self, count: u32) -> CryptoKeystoreResult<Vec<V>> {
        // Native targets ask the backend for reverse ordering, wasm does not.
        // NOTE(review): the reason for the per-target difference is not visible here
        // (presumably the backends' default orderings differ) — confirm.
        let reverse = !cfg!(target_family = "wasm");

        let keypackages = self
            .find_all::<MlsKeyPackage>(EntityFindParams {
                limit: Some(count),
                offset: None,
                reverse,
            })
            .await?;

        Ok(keypackages
            .into_iter()
            .filter_map(|kpb| postcard::from_bytes(&kpb.keypackage).ok())
            .collect())
    }

    /// Returns `true` if a persisted group with `group_id` exists.
    ///
    /// Lookup errors are reported as "does not exist".
    async fn mls_group_exists(&self, group_id: &[u8]) -> bool {
        matches!(self.find::<PersistedMlsGroup>(group_id).await, Ok(Some(_)))
    }

    /// Persists the serialized `state` of group `group_id`, optionally linked to
    /// `parent_group_id`.
    async fn mls_group_persist(
        &self,
        group_id: &[u8],
        state: &[u8],
        parent_group_id: Option<&[u8]>,
    ) -> CryptoKeystoreResult<()> {
        self.save(PersistedMlsGroup {
            id: group_id.into(),
            state: state.into(),
            parent_id: parent_group_id.map(Into::into),
        })
        .await?;
        Ok(())
    }

    /// Loads every persisted group as `group id -> (parent group id, serialized state)`.
    async fn mls_groups_restore(
        &self,
    ) -> CryptoKeystoreResult<std::collections::HashMap<Vec<u8>, (Option<Vec<u8>>, Vec<u8>)>> {
        let groups = self.find_all::<PersistedMlsGroup>(EntityFindParams::default()).await?;
        Ok(groups
            .into_iter()
            // Each entity is owned here, so take its buffers instead of cloning them:
            // group state can be large. `mem::take` (unlike a plain field move) also
            // works if the entity type implements `Drop`.
            .map(|mut group: PersistedMlsGroup| {
                (
                    std::mem::take(&mut group.id),
                    (group.parent_id.take(), std::mem::take(&mut group.state)),
                )
            })
            .collect())
    }

    /// Removes the persisted group `group_id`.
    async fn mls_group_delete(&self, group_id: &[u8]) -> CryptoKeystoreResult<()> {
        self.remove::<PersistedMlsGroup, _>(group_id).await?;
        Ok(())
    }

    /// Persists a pending group with its serialized state, custom configuration and
    /// optional parent group id.
    async fn mls_pending_groups_save(
        &self,
        group_id: &[u8],
        mls_group: &[u8],
        custom_configuration: &[u8],
        parent_group_id: Option<&[u8]>,
    ) -> CryptoKeystoreResult<()> {
        self.save(PersistedMlsPendingGroup {
            id: group_id.into(),
            state: mls_group.into(),
            custom_configuration: custom_configuration.into(),
            parent_id: parent_group_id.map(Into::into),
        })
        .await?;
        Ok(())
    }

    /// Loads `(state, custom_configuration)` for the pending group `group_id`.
    ///
    /// # Errors
    /// `MissingKeyInStore(MissingKeyErrorKind::MlsPendingGroup)` when no such pending
    /// group is stored, or any error raised by the lookup itself.
    async fn mls_pending_groups_load(&self, group_id: &[u8]) -> CryptoKeystoreResult<(Vec<u8>, Vec<u8>)> {
        self.find::<PersistedMlsPendingGroup>(group_id)
            .await?
            // Move the buffers out of the owned entity instead of cloning them.
            .map(|mut r: PersistedMlsPendingGroup| {
                (
                    std::mem::take(&mut r.state),
                    std::mem::take(&mut r.custom_configuration),
                )
            })
            .ok_or(CryptoKeystoreError::MissingKeyInStore(
                MissingKeyErrorKind::MlsPendingGroup,
            ))
    }

    /// Removes the pending group `group_id`.
    async fn mls_pending_groups_delete(&self, group_id: &[u8]) -> CryptoKeystoreResult<()> {
        self.remove::<PersistedMlsPendingGroup, _>(group_id).await
    }

    /// Stores an opaque E2EI enrollment blob under `id`.
    async fn save_e2ei_enrollment(&self, id: &[u8], content: &[u8]) -> CryptoKeystoreResult<()> {
        self.save(E2eiEnrollment {
            id: id.into(),
            content: content.into(),
        })
        .await?;
        Ok(())
    }

    /// Returns the E2EI enrollment stored under `id` and removes it from the store.
    ///
    /// The lookup and the removal are two separate calls on the connection.
    ///
    /// # Errors
    /// `MissingKeyInStore(MissingKeyErrorKind::E2eiEnrollment)` when nothing is stored
    /// under `id`, or any error raised by the lookup or the removal.
    async fn pop_e2ei_enrollment(&self, id: &[u8]) -> CryptoKeystoreResult<Vec<u8>> {
        let mut enrollment = self
            .find::<E2eiEnrollment>(id)
            .await?
            .ok_or(CryptoKeystoreError::MissingKeyInStore(
                MissingKeyErrorKind::E2eiEnrollment,
            ))?;
        self.remove::<E2eiEnrollment, _>(id).await?;
        // The entity is owned and about to be dropped: hand its buffer out instead of copying it.
        Ok(std::mem::take(&mut enrollment.content))
    }
}
#[inline(always)]
pub fn deser<T: MlsEntity>(bytes: &[u8]) -> Result<T, CryptoKeystoreError> {
Ok(postcard::from_bytes(bytes)?)
}
#[inline(always)]
pub fn ser<T: MlsEntity>(value: &T) -> Result<Vec<u8>, CryptoKeystoreError> {
Ok(postcard::to_stdvec(value)?)
}
#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
impl openmls_traits::key_store::OpenMlsKeyStore for crate::connection::Connection {
    type Error = CryptoKeystoreError;

    /// Serializes `v` with postcard and saves it under `k` in the entity matching `V::ID`.
    ///
    /// # Errors
    /// - `MlsKeyStoreError` if `k` is empty.
    /// - `IncorrectApiUsage` for `MlsEntityId::GroupState`: groups must be persisted
    ///   through the keystore's own methods, not through OpenMLS.
    /// - Serialization or storage errors otherwise.
    ///
    /// # Panics
    /// If `V::ID` is `SignatureKeyPair` but `v` cannot be downcast to `SignatureKeyPair`.
    async fn store<V: MlsEntity + Sync>(&self, k: &[u8], v: &V) -> Result<(), Self::Error>
    where
        Self: Sized,
    {
        if k.is_empty() {
            return Err(CryptoKeystoreError::MlsKeyStoreError(
                "The provided key is empty".into(),
            ));
        }

        // Serialization happens only inside the accepting arms, so a (potentially
        // large) group state is never serialized just to be rejected.
        match V::ID {
            MlsEntityId::GroupState => {
                return Err(CryptoKeystoreError::IncorrectApiUsage(
                    "Groups must not be saved using OpenMLS's APIs. You should use the keystore's provided methods",
                ));
            }
            MlsEntityId::SignatureKeyPair => {
                let data = ser(v)?;
                let concrete_signature_keypair: &SignatureKeyPair = v
                    .downcast()
                    .expect("There's an implementation issue in OpenMLS. This shouln't be happening.");
                // NOTE(review): the credential id is stored empty on this path — presumably
                // it is not known through the OpenMLS API; confirm.
                let credential_id = vec![];
                let kp = MlsSignatureKeyPair::new(
                    concrete_signature_keypair.signature_scheme(),
                    k.into(),
                    data,
                    credential_id,
                );
                self.save(kp).await?;
            }
            MlsEntityId::KeyPackage => {
                let kp = MlsKeyPackage {
                    keypackage_ref: k.into(),
                    keypackage: ser(v)?,
                };
                self.save(kp).await?;
            }
            MlsEntityId::HpkePrivateKey => {
                let kp = MlsHpkePrivateKey {
                    pk: k.into(),
                    sk: ser(v)?,
                };
                self.save(kp).await?;
            }
            MlsEntityId::PskBundle => {
                let kp = MlsPskBundle {
                    psk_id: k.into(),
                    psk: ser(v)?,
                };
                self.save(kp).await?;
            }
            MlsEntityId::EncryptionKeyPair => {
                let kp = MlsEncryptionKeyPair {
                    pk: k.into(),
                    sk: ser(v)?,
                };
                self.save(kp).await?;
            }
            MlsEntityId::EpochEncryptionKeyPair => {
                let kp = MlsEpochEncryptionKeyPair {
                    id: k.into(),
                    keypairs: ser(v)?,
                };
                self.save(kp).await?;
            }
        }
        Ok(())
    }

    /// Loads and deserializes the value stored under `k` for entity kind `V::ID`.
    ///
    /// Returns `None` for an empty key, a missing entry, a lookup error, or bytes
    /// that fail to deserialize as `V`.
    async fn read<V: MlsEntity>(&self, k: &[u8]) -> Option<V>
    where
        Self: Sized,
    {
        if k.is_empty() {
            return None;
        }
        match V::ID {
            MlsEntityId::GroupState => {
                let group: PersistedMlsGroup = self.find(k).await.ok().flatten()?;
                deser(&group.state).ok()
            }
            MlsEntityId::SignatureKeyPair => {
                let sig: MlsSignatureKeyPair = self.find(k).await.ok().flatten()?;
                deser(&sig.keypair).ok()
            }
            MlsEntityId::KeyPackage => {
                let kp: MlsKeyPackage = self.find(k).await.ok().flatten()?;
                deser(&kp.keypackage).ok()
            }
            MlsEntityId::HpkePrivateKey => {
                let hpke_pk: MlsHpkePrivateKey = self.find(k).await.ok().flatten()?;
                deser(&hpke_pk.sk).ok()
            }
            MlsEntityId::PskBundle => {
                let psk_bundle: MlsPskBundle = self.find(k).await.ok().flatten()?;
                deser(&psk_bundle.psk).ok()
            }
            MlsEntityId::EncryptionKeyPair => {
                let kp: MlsEncryptionKeyPair = self.find(k).await.ok().flatten()?;
                deser(&kp.sk).ok()
            }
            MlsEntityId::EpochEncryptionKeyPair => {
                let kp: MlsEpochEncryptionKeyPair = self.find(k).await.ok().flatten()?;
                deser(&kp.keypairs).ok()
            }
        }
    }

    /// Removes the entry stored under `k` from the entity matching `V::ID`.
    ///
    /// # Errors
    /// Any error raised by the underlying removal.
    async fn delete<V: MlsEntity>(&self, k: &[u8]) -> Result<(), Self::Error> {
        match V::ID {
            MlsEntityId::GroupState => self.remove::<PersistedMlsGroup, _>(k).await?,
            MlsEntityId::SignatureKeyPair => self.remove::<MlsSignatureKeyPair, _>(k).await?,
            MlsEntityId::HpkePrivateKey => self.remove::<MlsHpkePrivateKey, _>(k).await?,
            MlsEntityId::KeyPackage => self.remove::<MlsKeyPackage, _>(k).await?,
            MlsEntityId::PskBundle => self.remove::<MlsPskBundle, _>(k).await?,
            MlsEntityId::EncryptionKeyPair => self.remove::<MlsEncryptionKeyPair, _>(k).await?,
            MlsEntityId::EpochEncryptionKeyPair => self.remove::<MlsEpochEncryptionKeyPair, _>(k).await?,
        }
        Ok(())
    }
}