core_crypto/mls/conversation/
commit_delay.rs1use log::{debug, trace};
2use openmls::prelude::LeafNodeIndex;
3
4use super::MlsConversation;
5use crate::MlsError;
6
/// 1-based positions that receive a linear delay in `MlsConversation::calculate_delay`.
const DELAY_POS_LINEAR_RANGE: std::ops::RangeInclusive<u64> = std::ops::RangeInclusive::new(1, 3);
/// Delay added per position step while inside `DELAY_POS_LINEAR_RANGE`
/// (position 1 gets 0, position 2 gets this value, ...).
const DELAY_POS_LINEAR_INCR: u64 = 15;
/// Factor applied to `ln(position)` for positions past the linear range.
const DELAY_RAMP_UP_MULTIPLIER: f32 = 120_f32;
/// Amount subtracted (saturating at 0) from the scaled logarithm past the linear range.
const DELAY_RAMP_UP_SUB: u64 = 106;
12
13impl MlsConversation {
14 pub fn compute_next_commit_delay(&self) -> Option<u64> {
20 use openmls::messages::proposals::Proposal;
21
22 if self.group.pending_proposals().count() > 0 {
23 let removed_index = self
24 .group
25 .pending_proposals()
26 .filter_map(|proposal| {
27 if let Proposal::Remove(remove_proposal) = proposal.proposal() {
28 Some(remove_proposal.removed())
29 } else {
30 None
31 }
32 })
33 .collect::<Vec<LeafNodeIndex>>();
34
35 let self_index = self.group.own_leaf_index();
36 debug!(removed_index:? = removed_index, self_index:? = self_index; "Indexes");
37 let is_self_removed = removed_index.iter().any(|&i| i == self_index);
39
40 if is_self_removed {
42 debug!("Self removed from group, no delay needed");
43 return None;
44 }
45
46 let epoch = self.group.epoch().as_u64();
47 let mut own_index = self.group.own_leaf_index().u32() as u64;
48
49 let left_tree_diff = self
51 .group
52 .members()
53 .take(own_index as usize)
54 .try_fold(0u32, |mut acc, kp| {
55 if removed_index.contains(&kp.index) {
56 acc += 1;
57 }
58
59 Result::<_, MlsError>::Ok(acc)
60 })
61 .unwrap_or_default();
62
63 let nb_members = (self.group.members().count() as u64).saturating_sub(removed_index.len() as u64);
65 own_index = own_index.saturating_sub(left_tree_diff as u64);
67
68 Some(Self::calculate_delay(own_index, epoch, nb_members))
69 } else {
70 trace!("No pending proposals, no delay needed");
71 None
72 }
73 }
74
75 fn calculate_delay(self_index: u64, epoch: u64, nb_members: u64) -> u64 {
76 let position = if nb_members > 0 {
77 ((epoch % nb_members) + (self_index % nb_members)) % nb_members + 1
78 } else {
79 1
80 };
81
82 if DELAY_POS_LINEAR_RANGE.contains(&position) {
83 position.saturating_sub(1) * DELAY_POS_LINEAR_INCR
84 } else {
85 (((position as f32).ln() * DELAY_RAMP_UP_MULTIPLIER) as u64).saturating_sub(DELAY_RAMP_UP_SUB)
86 }
87 }
88}
89
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{prelude::MlsConversationCreationMessage, test_utils::*};
    use tls_codec::Serialize as _;
    use wasm_bindgen_test::*;

    wasm_bindgen_test_configure!(run_in_browser);

    // Argument order of `calculate_delay` is (self_index, epoch, nb_members).

    #[test]
    #[wasm_bindgen_test]
    fn calculate_delay_single() {
        // A lone member always lands on position 1, which carries no delay.
        assert_eq!(MlsConversation::calculate_delay(0, 0, 1), 0);
    }

    #[test]
    #[wasm_bindgen_test]
    fn calculate_delay_max() {
        assert_eq!(MlsConversation::calculate_delay(u64::MAX, u64::MAX, u64::MAX), 0);
    }

    #[test]
    #[wasm_bindgen_test]
    fn calculate_delay_min() {
        assert_eq!(MlsConversation::calculate_delay(u64::MIN, u64::MIN, u64::MAX), 0);
    }

    #[test]
    #[wasm_bindgen_test]
    fn calculate_delay_zero_members() {
        // An empty group falls back to position 1.
        assert_eq!(MlsConversation::calculate_delay(0, 0, u64::MIN), 0);
    }

    #[test]
    #[wasm_bindgen_test]
    fn calculate_delay_min_max() {
        assert_eq!(MlsConversation::calculate_delay(u64::MIN, u64::MAX, u64::MAX), 0);
    }

    #[test]
    #[wasm_bindgen_test]
    fn calculate_delay_n() {
        const EPOCH: u64 = 1;
        const NB_MEMBERS: u64 = 10;

        // (self_index, expected_delay). The epoch shifts every position by one, so index 9
        // wraps around to position 1 and index 10 behaves exactly like index 0.
        let expectations: [(u64, u64); 11] = [
            (0, 15),
            (1, 30),
            (2, 60),
            (3, 87),
            (4, 109),
            (5, 127),
            (6, 143),
            (7, 157),
            (8, 170),
            (9, 0),
            (10, 15),
        ];

        for &(self_index, expected) in expectations.iter() {
            assert_eq!(MlsConversation::calculate_delay(self_index, EPOCH, NB_MEMBERS), expected);
        }
    }

    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    async fn calculate_delay_creator_removed(case: TestCase) {
        run_test_with_client_ids(
            case.clone(),
            ["alice", "bob", "charlie"],
            move |[alice_central, bob_central, charlie_central]| {
                Box::pin(async move {
                    let conv_id = conversation_id();

                    // Alice creates the conversation.
                    alice_central
                        .context
                        .new_conversation(&conv_id, case.credential_type, case.cfg.clone())
                        .await
                        .unwrap();

                    // Alice adds Bob; membership only grows once the commit is accepted.
                    let bob_key_package = bob_central.rand_key_package(&case).await;
                    let MlsConversationCreationMessage {
                        welcome: bob_welcome, ..
                    } = alice_central
                        .context
                        .add_members_to_conversation(&conv_id, vec![bob_key_package])
                        .await
                        .unwrap();
                    assert_eq!(alice_central.get_conversation_unchecked(&conv_id).await.members().len(), 1);
                    alice_central.context.commit_accepted(&conv_id).await.unwrap();
                    assert_eq!(alice_central.get_conversation_unchecked(&conv_id).await.members().len(), 2);

                    bob_central
                        .context
                        .process_welcome_message(bob_welcome.into(), case.custom_cfg())
                        .await
                        .unwrap();

                    // Alice adds Charlie; Bob catches up through the commit.
                    let charlie_key_package = charlie_central.rand_key_package(&case).await;
                    let MlsConversationCreationMessage {
                        welcome: charlie_welcome,
                        commit: add_charlie_commit,
                        ..
                    } = alice_central
                        .context
                        .add_members_to_conversation(&conv_id, vec![charlie_key_package])
                        .await
                        .unwrap();
                    assert_eq!(alice_central.get_conversation_unchecked(&conv_id).await.members().len(), 2);
                    alice_central.context.commit_accepted(&conv_id).await.unwrap();
                    assert_eq!(alice_central.get_conversation_unchecked(&conv_id).await.members().len(), 3);

                    let add_charlie_commit_bytes = add_charlie_commit.tls_serialize_detached().unwrap();
                    let _ = bob_central
                        .context
                        .decrypt_message(&conv_id, &add_charlie_commit_bytes)
                        .await
                        .unwrap();

                    charlie_central
                        .context
                        .process_welcome_message(charlie_welcome.into(), case.custom_cfg())
                        .await
                        .unwrap();

                    let alice_conv_id = alice_central.get_conversation_unchecked(&conv_id).await.id();
                    assert_eq!(bob_central.get_conversation_unchecked(&conv_id).await.id(), alice_conv_id);
                    assert_eq!(
                        charlie_central.get_conversation_unchecked(&conv_id).await.id(),
                        alice_central.get_conversation_unchecked(&conv_id).await.id()
                    );

                    // Alice proposes her own removal; Bob and Charlie must each derive a delay
                    // from their position among the remaining members.
                    let remove_bundle = alice_central
                        .context
                        .new_remove_proposal(&conv_id, alice_central.get_client_id().await)
                        .await
                        .unwrap();
                    let remove_proposal_bytes = remove_bundle.proposal.tls_serialize_detached().unwrap();

                    let bob_expected_position: u64 = 0;
                    let charlie_expected_position: u64 = 1;

                    let bob_outcome = bob_central
                        .context
                        .decrypt_message(&conv_id, &remove_proposal_bytes)
                        .await
                        .unwrap();
                    assert_eq!(bob_outcome.delay, Some(DELAY_POS_LINEAR_INCR * bob_expected_position));

                    let charlie_outcome = charlie_central
                        .context
                        .decrypt_message(&conv_id, &remove_proposal_bytes)
                        .await
                        .unwrap();
                    assert_eq!(
                        charlie_outcome.delay,
                        Some(DELAY_POS_LINEAR_INCR * charlie_expected_position)
                    );
                })
            },
        )
        .await;
    }
}