use log::{debug, trace};
use openmls::prelude::LeafNodeIndex;
use super::MlsConversation;
use crate::MlsError;
/// Scale factor applied to `ln(position)` for positions past `DELAY_POS_LINEAR_RANGE`.
const DELAY_RAMP_UP_MULTIPLIER: f32 = 120_f32;
/// Offset subtracted (saturating at 0) from the scaled logarithm in the ramp-up regime.
const DELAY_RAMP_UP_SUB: u64 = 106;
/// Delay step per position while the position lies inside `DELAY_POS_LINEAR_RANGE`.
const DELAY_POS_LINEAR_INCR: u64 = 15;
/// 1-based positions whose delay grows linearly: `(position - 1) * DELAY_POS_LINEAR_INCR`.
const DELAY_POS_LINEAR_RANGE: std::ops::RangeInclusive<u64> = std::ops::RangeInclusive::new(1, 3);
impl MlsConversation {
    /// Computes how long this member should wait before committing the group's pending
    /// proposals, so that members do not all race to commit at once.
    ///
    /// The member's position in the rotation is its rank among the members that survive the
    /// pending `Remove` proposals. That rank is combined with the current epoch in
    /// `calculate_delay`.
    ///
    /// Returns `None` when there are no pending proposals, or when one of them removes this
    /// member. Otherwise returns the computed delay. The unit is defined by callers.
    pub fn compute_next_commit_delay(&self) -> Option<u64> {
        use openmls::messages::proposals::Proposal;

        if self.group.pending_proposals().next().is_none() {
            trace!("No pending proposals, no delay needed");
            return None;
        }

        let removed_index = self
            .group
            .pending_proposals()
            .filter_map(|proposal| match proposal.proposal() {
                Proposal::Remove(remove_proposal) => Some(remove_proposal.removed()),
                _ => None,
            })
            .collect::<Vec<LeafNodeIndex>>();

        let self_index = self.group.own_leaf_index();
        debug!(removed_index:? = removed_index, self_index:? = self_index; "Indexes");

        if removed_index.contains(&self_index) {
            debug!("Self removed from group, no delay needed");
            return None;
        }

        let epoch = self.group.epoch().as_u64();

        // A single pass over the members that survive the pending removals yields:
        // - how many remain. Counting survivors, rather than subtracting
        //   `removed_index.len()`, stays correct if several proposals remove the same member.
        // - our rank among them. Comparing leaf indexes, rather than taking the first
        //   `own_leaf_index` members, stays correct when the tree contains blank leaves,
        //   where a leaf index is not the same as a member position.
        let (nb_members, own_position) = self
            .group
            .members()
            .filter(|member| !removed_index.contains(&member.index))
            .fold((0u64, 0u64), |(total, before_self), member| {
                let precedes_self = member.index.u32() < self_index.u32();
                (total + 1, before_self + u64::from(precedes_self))
            });

        Some(Self::calculate_delay(own_position, epoch, nb_members))
    }

    /// Maps a member's rank to a commit delay.
    ///
    /// The 1-based position is `(epoch + self_index) mod nb_members + 1`, or `1` when
    /// `nb_members` is 0. Because it depends on the epoch, the member that is first in line
    /// rotates from one epoch to the next.
    ///
    /// - Positions inside `DELAY_POS_LINEAR_RANGE` map to
    ///   `(position - 1) * DELAY_POS_LINEAR_INCR`.
    /// - Larger positions map to `ln(position) * DELAY_RAMP_UP_MULTIPLIER` truncated, minus
    ///   `DELAY_RAMP_UP_SUB`, saturating at 0.
    fn calculate_delay(self_index: u64, epoch: u64, nb_members: u64) -> u64 {
        let position = if nb_members == 0 {
            1
        } else {
            // Widen to u128 so the sum cannot overflow for values close to u64::MAX.
            let slot = (u128::from(epoch) + u128::from(self_index)) % u128::from(nb_members);
            // slot < nb_members <= u64::MAX, so the cast is lossless and `+ 1` cannot overflow.
            slot as u64 + 1
        };

        if DELAY_POS_LINEAR_RANGE.contains(&position) {
            position.saturating_sub(1) * DELAY_POS_LINEAR_INCR
        } else {
            let ramp_up = (position as f32).ln() * DELAY_RAMP_UP_MULTIPLIER;
            (ramp_up as u64).saturating_sub(DELAY_RAMP_UP_SUB)
        }
    }
}
#[cfg(test)]
mod tests {
// Unit tests for `calculate_delay`, plus an end-to-end check of the delay reported
// when members decrypt a pending Remove proposal.
use super::*;
use crate::{prelude::MlsConversationCreationMessage, test_utils::*};
use tls_codec::Serialize as _;
use wasm_bindgen_test::*;
wasm_bindgen_test_configure!(run_in_browser);
// One member at epoch 0 lands on position 1, the first slot of the linear range => no delay.
#[test]
#[wasm_bindgen_test]
fn calculate_delay_single() {
let (self_index, epoch, nb_members) = (0, 0, 1);
let delay = MlsConversation::calculate_delay(self_index, epoch, nb_members);
assert_eq!(delay, 0);
}
// Extreme values: every operand is u64::MAX, so each `% nb_members` is 0 => position 1.
#[test]
#[wasm_bindgen_test]
fn calculate_delay_max() {
let (self_index, epoch, nb_members) = (u64::MAX, u64::MAX, u64::MAX);
let delay = MlsConversation::calculate_delay(self_index, epoch, nb_members);
assert_eq!(delay, 0);
}
// Index and epoch at 0 with a huge member count => position 1 => no delay.
#[test]
#[wasm_bindgen_test]
fn calculate_delay_min() {
let (self_index, epoch, nb_members) = (u64::MIN, u64::MIN, u64::MAX);
let delay = MlsConversation::calculate_delay(self_index, epoch, nb_members);
assert_eq!(delay, 0);
}
// Zero members must not divide by zero; the function falls back to position 1.
#[test]
#[wasm_bindgen_test]
fn calculate_delay_zero_members() {
let (self_index, epoch, nb_members) = (0, 0, u64::MIN);
let delay = MlsConversation::calculate_delay(self_index, epoch, nb_members);
assert_eq!(delay, 0);
}
// epoch == nb_members == u64::MAX, so the epoch contributes 0 and index 0 gives position 1.
#[test]
#[wasm_bindgen_test]
fn calculate_delay_min_max() {
let (self_index, epoch, nb_members) = (u64::MIN, u64::MAX, u64::MAX);
let delay = MlsConversation::calculate_delay(self_index, epoch, nb_members);
assert_eq!(delay, 0);
}
// Full table for 10 members at epoch 1. Index i maps to position (1 + i) % 10 + 1:
// - positions 1..=3 are linear (0, 15, 30);
// - later positions follow the ln ramp-up;
// - index 9 wraps to position 1 (delay 0), and index 10 behaves like index 0.
#[test]
#[wasm_bindgen_test]
fn calculate_delay_n() {
let epoch = 1;
let nb_members = 10;
let indexes_delays = [
(0, 15),
(1, 30),
(2, 60),
(3, 87),
(4, 109),
(5, 127),
(6, 143),
(7, 157),
(8, 170),
(9, 0),
(10, 15),
];
for (self_index, expected_delay) in indexes_delays {
let delay = MlsConversation::calculate_delay(self_index, epoch, nb_members);
assert_eq!(delay, expected_delay);
}
}
// Alice creates the group and adds Bob, then Charlie, then proposes removing herself.
// Bob and Charlie must each get a delay based on their position among the remaining members.
#[apply(all_cred_cipher)]
#[wasm_bindgen_test]
async fn calculate_delay_creator_removed(case: TestCase) {
run_test_with_client_ids(
case.clone(),
["alice", "bob", "charlie"],
move |[mut alice_central, mut bob_central, mut charlie_central]| {
Box::pin(async move {
let id = conversation_id();
alice_central
.mls_central
.new_conversation(&id, case.credential_type, case.cfg.clone())
.await
.unwrap();
// Add Bob. The member count only changes once the pending commit is accepted.
let bob = bob_central.mls_central.rand_key_package(&case).await;
let MlsConversationCreationMessage {
welcome: bob_welcome, ..
} = alice_central
.mls_central
.add_members_to_conversation(&id, vec![bob])
.await
.unwrap();
assert_eq!(
alice_central
.mls_central
.get_conversation_unchecked(&id)
.await
.members()
.len(),
1
);
alice_central.mls_central.commit_accepted(&id).await.unwrap();
assert_eq!(
alice_central
.mls_central
.get_conversation_unchecked(&id)
.await
.members()
.len(),
2
);
bob_central
.mls_central
.process_welcome_message(bob_welcome.clone().into(), case.custom_cfg())
.await
.unwrap();
// Add Charlie. Bob, already a member, processes the commit; Charlie joins via Welcome.
let charlie = charlie_central.mls_central.rand_key_package(&case).await;
let MlsConversationCreationMessage {
welcome: charlie_welcome,
commit,
..
} = alice_central
.mls_central
.add_members_to_conversation(&id, vec![charlie])
.await
.unwrap();
assert_eq!(
alice_central
.mls_central
.get_conversation_unchecked(&id)
.await
.members()
.len(),
2
);
alice_central.mls_central.commit_accepted(&id).await.unwrap();
assert_eq!(
alice_central
.mls_central
.get_conversation_unchecked(&id)
.await
.members()
.len(),
3
);
let _ = bob_central
.mls_central
.decrypt_message(&id, &commit.tls_serialize_detached().unwrap())
.await
.unwrap();
charlie_central
.mls_central
.process_welcome_message(charlie_welcome.into(), case.custom_cfg())
.await
.unwrap();
assert_eq!(
bob_central.mls_central.get_conversation_unchecked(&id).await.id(),
alice_central.mls_central.get_conversation_unchecked(&id).await.id()
);
assert_eq!(
charlie_central.mls_central.get_conversation_unchecked(&id).await.id(),
alice_central.mls_central.get_conversation_unchecked(&id).await.id()
);
// The creator proposes removing herself, so Bob and Charlie are the members left.
let proposal_bundle = alice_central
.mls_central
.new_remove_proposal(&id, alice_central.mls_central.get_client_id())
.await
.unwrap();
// Expected ranks among the remaining members once Alice is discounted.
// NOTE(review): the expected delays also depend on the epoch through
// `(epoch + position) % nb_members`. They assume the epoch offset is 0 modulo the
// 2 remaining members, so the rank maps straight to position - 1. Confirm if the
// setup above changes.
let bob_hypothetical_position = 0;
let charlie_hypothetical_position = 1;
let bob_decrypted_message = bob_central
.mls_central
.decrypt_message(&id, &proposal_bundle.proposal.tls_serialize_detached().unwrap())
.await
.unwrap();
// Both ranks fall in the linear range, so delay = DELAY_POS_LINEAR_INCR * rank.
assert_eq!(
bob_decrypted_message.delay,
Some(DELAY_POS_LINEAR_INCR * bob_hypothetical_position)
);
let charlie_decrypted_message = charlie_central
.mls_central
.decrypt_message(&id, &proposal_bundle.proposal.tls_serialize_detached().unwrap())
.await
.unwrap();
assert_eq!(
charlie_decrypted_message.delay,
Some(DELAY_POS_LINEAR_INCR * charlie_hypothetical_position)
);
})
},
)
.await;
}
}