use openmls::prelude::{hash_ref::ProposalRef, KeyPackage};
use mls_crypto_provider::MlsCryptoProvider;
use crate::{
mls::{ClientId, ConversationId, MlsConversation},
prelude::{Client, CryptoError, CryptoResult, MlsProposalBundle},
};
use crate::context::CentralContext;
/// Reference to an MLS proposal; a thin newtype around openmls' [`ProposalRef`].
///
/// `From<ProposalRef>`, `Deref<Target = ProposalRef>` and `Display` are derived through
/// `derive_more`, so a `&MlsProposalRef` can be used wherever a `&ProposalRef` is expected.
#[derive(Debug, Clone, Eq, PartialEq, derive_more::From, derive_more::Deref, derive_more::Display)]
pub struct MlsProposalRef(ProposalRef);
impl From<Vec<u8>> for MlsProposalRef {
    /// Rebuilds a proposal reference from its raw byte form.
    ///
    /// Any validation is whatever `ProposalRef::from_slice` performs; nothing extra is
    /// checked here.
    fn from(value: Vec<u8>) -> Self {
        let raw: &[u8] = &value;
        Self(ProposalRef::from_slice(raw))
    }
}
impl MlsProposalRef {
pub fn into_inner(self) -> ProposalRef {
self.0
}
pub(crate) fn to_bytes(&self) -> Vec<u8> {
self.0.as_slice().to_vec()
}
}
/// Test-only conversion of a proposal reference back into its raw bytes.
#[cfg(test)]
impl From<MlsProposalRef> for Vec<u8> {
    fn from(prop_ref: MlsProposalRef) -> Self {
        // Same bytes as the crate-internal `to_bytes` helper; reuse it rather than
        // duplicating the slice-to-vec conversion.
        prop_ref.to_bytes()
    }
}
/// A proposal that can be created against an existing MLS conversation.
// NOTE(review): the lint is presumably silenced because `Add` carries a whole `KeyPackage`
// inline, making it much larger than `Update`/`Remove`. Boxing it would change this public
// variant's payload type, so confirm callers before changing it.
#[allow(clippy::large_enum_variant)]
pub enum MlsProposal {
    /// Propose adding the client described by this `KeyPackage`.
    Add(KeyPackage),
    /// Propose a self-update (dispatched to `MlsConversation::propose_self_update`).
    Update,
    /// Propose removing the member whose credential identity equals this `ClientId`.
    Remove(ClientId),
}
impl MlsProposal {
    /// Builds this proposal in `conversation` and returns the resulting bundle.
    ///
    /// `conversation` is any mutable-deref handle (for example a write guard). It is taken
    /// by value, so it is dropped as soon as this method returns.
    ///
    /// # Errors
    /// * [`CryptoError::ClientNotFound`] for [`MlsProposal::Remove`] when no group member's
    ///   credential identity equals the given client id.
    /// * Any error returned by the underlying `propose_*` call, converted through `From`.
    async fn create(
        self,
        client: &Client,
        backend: &MlsCryptoProvider,
        mut conversation: impl std::ops::DerefMut<Target = MlsConversation>,
    ) -> CryptoResult<MlsProposalBundle> {
        // Reborrow once through `DerefMut` instead of dereferencing at every call site.
        let conv = &mut *conversation;
        let bundle = match self {
            Self::Add(key_package) => {
                conv.propose_add_member(client, backend, key_package.into())
                    .await?
            }
            Self::Update => conv.propose_self_update(client, backend).await?,
            Self::Remove(client_id) => {
                // Members are matched on the raw identity bytes of their credential.
                let member_index = conv
                    .group
                    .members()
                    .find(|member| member.credential.identity() == client_id.as_slice())
                    .map(|member| member.index)
                    .ok_or(CryptoError::ClientNotFound(client_id))?;
                conv.propose_remove_member(client, backend, member_index)
                    .await?
            }
        };
        Ok(bundle)
    }
}
impl CentralContext {
    /// Creates a proposal to add the client described by `key_package` to conversation `id`.
    ///
    /// # Errors
    /// Fails if the conversation, MLS client or crypto provider cannot be obtained, or if
    /// building the proposal fails.
    #[cfg_attr(test, crate::idempotent)]
    pub async fn new_add_proposal(
        &self,
        id: &ConversationId,
        key_package: KeyPackage,
    ) -> CryptoResult<MlsProposalBundle> {
        let proposal = MlsProposal::Add(key_package);
        self.new_proposal(id, proposal).await
    }

    /// Creates a proposal to remove `client_id` from conversation `id`.
    ///
    /// # Errors
    /// Returns [`CryptoError::ClientNotFound`] when no member of the conversation has a
    /// credential identity equal to `client_id`. Otherwise it fails for the same reasons
    /// as `new_add_proposal`.
    #[cfg_attr(test, crate::idempotent)]
    pub async fn new_remove_proposal(
        &self,
        id: &ConversationId,
        client_id: ClientId,
    ) -> CryptoResult<MlsProposalBundle> {
        let proposal = MlsProposal::Remove(client_id);
        self.new_proposal(id, proposal).await
    }

    /// Creates a self-update proposal in conversation `id`.
    ///
    /// # Errors
    /// Fails for the same reasons as `new_add_proposal`.
    // NOTE(review): unlike its siblings this is tagged `dispotent` in tests. That is
    // presumably because it is expected to change persisted state; confirm against the macro.
    #[cfg_attr(test, crate::dispotent)]
    pub async fn new_update_proposal(&self, id: &ConversationId) -> CryptoResult<MlsProposalBundle> {
        let proposal = MlsProposal::Update;
        self.new_proposal(id, proposal).await
    }

    /// Shared path for every proposal constructor.
    ///
    /// Resolves the conversation, the MLS client and the crypto provider, then builds
    /// `proposal` through the conversation's write guard.
    async fn new_proposal(&self, id: &ConversationId, proposal: MlsProposal) -> CryptoResult<MlsProposalBundle> {
        let conversation = self.get_conversation(id).await?;
        let client = self.mls_client().await?;
        let provider = self.mls_provider().await?;
        // Acquired last. The guard is moved into `create` and dropped when that returns.
        let conversation_guard = conversation.write().await;
        proposal.create(&client, &provider, conversation_guard).await
    }
}
/// Tests for the proposal constructors on `CentralContext`.
#[cfg(test)]
mod tests {
    use wasm_bindgen_test::*;

    use crate::{prelude::MlsCommitBundle, prelude::*, test_utils::*};

    // Under wasm-bindgen-test, run these tests in a browser rather than in Node.
    wasm_bindgen_test_configure!(run_in_browser);

    /// Tests for `new_add_proposal`.
    mod add {
        use super::*;

        /// Alice proposes adding Bob, then commits her pending proposals, and Bob joins
        /// from the resulting Welcome. Both must end up in the same conversation and be
        /// able to exchange messages.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        pub async fn should_add_member(case: TestCase) {
            run_test_with_client_ids(case.clone(), ["alice", "bob"], move |[alice_central, bob_central]| {
                Box::pin(async move {
                    let id = conversation_id();
                    alice_central
                        .context
                        .new_conversation(&id, case.credential_type, case.cfg.clone())
                        .await
                        .unwrap();
                    let bob_kp = bob_central.get_one_key_package(&case).await;
                    alice_central.context.new_add_proposal(&id, bob_kp).await.unwrap();
                    // Committing the pending Add proposal is what produces the Welcome for Bob.
                    let MlsCommitBundle { welcome, .. } = alice_central
                        .context
                        .commit_pending_proposals(&id)
                        .await
                        .unwrap()
                        .unwrap();
                    alice_central.context.commit_accepted(&id).await.unwrap();
                    assert_eq!(alice_central.get_conversation_unchecked(&id).await.members().len(), 2);
                    let new_id = bob_central
                        .context
                        .process_welcome_message(welcome.unwrap().into(), case.custom_cfg())
                        .await
                        .unwrap()
                        .id;
                    // Bob must have joined the very conversation Alice created.
                    assert_eq!(id, new_id);
                    assert!(bob_central.try_talk_to(&id, &alice_central).await.is_ok());
                })
            })
            .await
        }
    }

    /// Tests for `new_update_proposal`.
    mod update {
        use super::*;
        use itertools::Itertools;

        /// An update proposal followed by a commit must change the member's encryption key.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        pub async fn should_update_hpke_key(case: TestCase) {
            run_test_with_central(case.clone(), move |[central]| {
                Box::pin(async move {
                    let id = conversation_id();
                    central
                        .context
                        .new_conversation(&id, case.credential_type, case.cfg.clone())
                        .await
                        .unwrap();
                    // `find_or_first(|_| true)` simply takes the first encryption key.
                    let before = central
                        .get_conversation_unchecked(&id)
                        .await
                        .encryption_keys()
                        .find_or_first(|_| true)
                        .unwrap();
                    central.context.new_update_proposal(&id).await.unwrap();
                    central.context.commit_pending_proposals(&id).await.unwrap();
                    central.context.commit_accepted(&id).await.unwrap();
                    let after = central
                        .get_conversation_unchecked(&id)
                        .await
                        .encryption_keys()
                        .find_or_first(|_| true)
                        .unwrap();
                    assert_ne!(before, after)
                })
            })
            .await
        }
    }

    /// Tests for `new_remove_proposal`.
    mod remove {
        use super::*;

        /// Alice proposes removing Bob and commits the proposal. Once Bob processes that
        /// commit, his copy of the conversation must be gone.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        pub async fn should_remove_member(case: TestCase) {
            run_test_with_client_ids(case.clone(), ["alice", "bob"], |[alice_central, bob_central]| {
                Box::pin(async move {
                    let id = conversation_id();
                    alice_central
                        .context
                        .new_conversation(&id, case.credential_type, case.cfg.clone())
                        .await
                        .unwrap();
                    alice_central.invite_all(&case, &id, [&bob_central]).await.unwrap();
                    assert_eq!(alice_central.get_conversation_unchecked(&id).await.members().len(), 2);
                    assert_eq!(bob_central.get_conversation_unchecked(&id).await.members().len(), 2);
                    let remove_proposal = alice_central
                        .context
                        .new_remove_proposal(&id, bob_central.get_client_id().await)
                        .await
                        .unwrap();
                    // Bob receives the proposal before the commit that references it.
                    bob_central
                        .context
                        .decrypt_message(&id, remove_proposal.proposal.to_bytes().unwrap())
                        .await
                        .unwrap();
                    let MlsCommitBundle { commit, .. } = alice_central
                        .context
                        .commit_pending_proposals(&id)
                        .await
                        .unwrap()
                        .unwrap();
                    alice_central.context.commit_accepted(&id).await.unwrap();
                    assert_eq!(alice_central.get_conversation_unchecked(&id).await.members().len(), 1);
                    bob_central
                        .context
                        .decrypt_message(&id, commit.to_bytes().unwrap())
                        .await
                        .unwrap();
                    // After processing his own removal, Bob no longer has the conversation.
                    assert!(matches!(
                        bob_central.context.get_conversation(&id).await.unwrap_err(),
                        CryptoError::ConversationNotFound(conv_id) if conv_id == id
                    ));
                })
            })
            .await
        }

        /// Removing a client id that matches no member must fail with `ClientNotFound`
        /// carrying that same id.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        pub async fn should_fail_when_unknown_client(case: TestCase) {
            run_test_with_client_ids(case.clone(), ["alice"], move |[alice_central]| {
                Box::pin(async move {
                    let id = conversation_id();
                    alice_central
                        .context
                        .new_conversation(&id, case.credential_type, case.cfg.clone())
                        .await
                        .unwrap();
                    let remove_proposal = alice_central
                        .context
                        .new_remove_proposal(&id, b"unknown"[..].into())
                        .await;
                    assert!(matches!(
                        remove_proposal.unwrap_err(),
                        CryptoError::ClientNotFound(client_id) if client_id == b"unknown"[..].into()
                    ));
                })
            })
            .await
        }
    }
}