// core_crypto/mls/conversation/commit_delay.rs

1use log::{debug, trace};
2use openmls::prelude::LeafNodeIndex;
3
4use super::MlsConversation;
5use crate::MlsError;
6
// Tuning of the commit delay computed by `MlsConversation::calculate_delay` from a 1-based
// `position`:
// * positions in `DELAY_POS_LINEAR_RANGE` get `(position - 1) * DELAY_POS_LINEAR_INCR` seconds;
// * later positions get `ln(position) * DELAY_RAMP_UP_MULTIPLIER - DELAY_RAMP_UP_SUB` seconds
//   (truncated, saturating at zero), which ramps up the delay and flattens the curve for later
//   positions.

/// Multiplier applied to `ln(position)` for positions after `DELAY_POS_LINEAR_RANGE`.
const DELAY_RAMP_UP_MULTIPLIER: f32 = 120.0;
/// Offset, in seconds, subtracted from the logarithmic delay; with the multiplier above,
/// position 4 yields a 60 second delay.
const DELAY_RAMP_UP_SUB: u64 = 106;
/// Delay increment, in seconds, between two consecutive positions of `DELAY_POS_LINEAR_RANGE`.
const DELAY_POS_LINEAR_INCR: u64 = 15;
/// 1-based positions that receive a linear delay (0, 15 and 30 seconds).
const DELAY_POS_LINEAR_RANGE: std::ops::RangeInclusive<u64> = 1..=3;
12
13impl MlsConversation {
14    /// Helps consumer by providing a deterministic delay in seconds for him to commit its pending proposal.
15    /// It depends on the index of the client in the ratchet tree
16    /// * `self_index` - ratchet tree index of self client
17    /// * `epoch` - current group epoch
18    /// * `nb_members` - number of clients in the group
19    pub fn compute_next_commit_delay(&self) -> Option<u64> {
20        use openmls::messages::proposals::Proposal;
21
22        if self.group.pending_proposals().count() > 0 {
23            let removed_index = self
24                .group
25                .pending_proposals()
26                .filter_map(|proposal| {
27                    if let Proposal::Remove(remove_proposal) = proposal.proposal() {
28                        Some(remove_proposal.removed())
29                    } else {
30                        None
31                    }
32                })
33                .collect::<Vec<LeafNodeIndex>>();
34
35            let self_index = self.group.own_leaf_index();
36            debug!(removed_index:? = removed_index, self_index:? = self_index; "Indexes");
37            // Find a remove proposal that concerns us
38            let is_self_removed = removed_index.iter().any(|&i| i == self_index);
39
40            // If our own client has been removed, don't commit
41            if is_self_removed {
42                debug!("Self removed from group, no delay needed");
43                return None;
44            }
45
46            let epoch = self.group.epoch().as_u64();
47            let mut own_index = self.group.own_leaf_index().u32() as u64;
48
49            // Look for members that were removed at the left of our tree in order to shift our own leaf index (post-commit tree visualization)
50            let left_tree_diff = self
51                .group
52                .members()
53                .take(own_index as usize)
54                .try_fold(0u32, |mut acc, kp| {
55                    if removed_index.contains(&kp.index) {
56                        acc += 1;
57                    }
58
59                    Result::<_, MlsError>::Ok(acc)
60                })
61                .unwrap_or_default();
62
63            // Post-commit visualization of the number of members after remove proposals
64            let nb_members = (self.group.members().count() as u64).saturating_sub(removed_index.len() as u64);
65            // This shifts our own leaf index to the left (tree-wise) from as many as there was removed members that have a smaller leaf index than us (older members)
66            own_index = own_index.saturating_sub(left_tree_diff as u64);
67
68            Some(Self::calculate_delay(own_index, epoch, nb_members))
69        } else {
70            trace!("No pending proposals, no delay needed");
71            None
72        }
73    }
74
75    fn calculate_delay(self_index: u64, epoch: u64, nb_members: u64) -> u64 {
76        let position = if nb_members > 0 {
77            ((epoch % nb_members) + (self_index % nb_members)) % nb_members + 1
78        } else {
79            1
80        };
81
82        if DELAY_POS_LINEAR_RANGE.contains(&position) {
83            position.saturating_sub(1) * DELAY_POS_LINEAR_INCR
84        } else {
85            (((position as f32).ln() * DELAY_RAMP_UP_MULTIPLIER) as u64).saturating_sub(DELAY_RAMP_UP_SUB)
86        }
87    }
88}
89
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{prelude::MlsConversationCreationMessage, test_utils::*};
    use tls_codec::Serialize as _;
    use wasm_bindgen_test::*;

    wasm_bindgen_test_configure!(run_in_browser);

    /// Thin alias over `MlsConversation::calculate_delay` keeping the pure tests below terse.
    fn delay_for(self_index: u64, epoch: u64, nb_members: u64) -> u64 {
        MlsConversation::calculate_delay(self_index, epoch, nb_members)
    }

    #[test]
    #[wasm_bindgen_test]
    fn calculate_delay_single() {
        // The only member of the group is first in line
        assert_eq!(delay_for(0, 0, 1), 0);
    }

    #[test]
    #[wasm_bindgen_test]
    fn calculate_delay_max() {
        assert_eq!(delay_for(u64::MAX, u64::MAX, u64::MAX), 0);
    }

    #[test]
    #[wasm_bindgen_test]
    fn calculate_delay_min() {
        assert_eq!(delay_for(u64::MIN, u64::MIN, u64::MAX), 0);
    }

    #[test]
    #[wasm_bindgen_test]
    fn calculate_delay_zero_members() {
        // An empty group must not trigger a modulo by zero
        assert_eq!(delay_for(0, 0, u64::MIN), 0);
    }

    #[test]
    #[wasm_bindgen_test]
    fn calculate_delay_min_max() {
        assert_eq!(delay_for(u64::MIN, u64::MAX, u64::MAX), 0);
    }

    #[test]
    #[wasm_bindgen_test]
    fn calculate_delay_n() {
        const EPOCH: u64 = 1;
        const NB_MEMBERS: u64 = 10;

        // (self_index, expected delay in seconds)
        let expectations: [(u64, u64); 11] = [
            (0, 15),
            (1, 30),
            (2, 60),
            (3, 87),
            (4, 109),
            (5, 127),
            (6, 143),
            (7, 157),
            (8, 170),
            (9, 0),
            // An index past the group size is not expected in practice; it wraps around modulo
            // the group size onto index 0's slot and must not panic
            (10, 15),
        ];

        expectations
            .iter()
            .for_each(|&(self_index, expected)| assert_eq!(delay_for(self_index, EPOCH, NB_MEMBERS), expected));
    }

    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    async fn calculate_delay_creator_removed(case: TestCase) {
        run_test_with_client_ids(
            case.clone(),
            ["alice", "bob", "charlie"],
            move |[alice_central, bob_central, charlie_central]| {
                Box::pin(async move {
                    let id = conversation_id();

                    // Alice creates the conversation, then adds Bob
                    alice_central
                        .context
                        .new_conversation(&id, case.credential_type, case.cfg.clone())
                        .await
                        .unwrap();

                    let bob_key_package = bob_central.rand_key_package(&case).await;
                    let bob_welcome = alice_central
                        .context
                        .add_members_to_conversation(&id, vec![bob_key_package])
                        .await
                        .unwrap()
                        .welcome;
                    assert_eq!(alice_central.get_conversation_unchecked(&id).await.members().len(), 1);
                    alice_central.context.commit_accepted(&id).await.unwrap();
                    assert_eq!(alice_central.get_conversation_unchecked(&id).await.members().len(), 2);

                    bob_central
                        .context
                        .process_welcome_message(bob_welcome.into(), case.custom_cfg())
                        .await
                        .unwrap();

                    // Alice then adds Charlie, and Bob catches up with the resulting commit
                    let charlie_key_package = charlie_central.rand_key_package(&case).await;
                    let MlsConversationCreationMessage {
                        welcome: charlie_welcome,
                        commit,
                        ..
                    } = alice_central
                        .context
                        .add_members_to_conversation(&id, vec![charlie_key_package])
                        .await
                        .unwrap();
                    assert_eq!(alice_central.get_conversation_unchecked(&id).await.members().len(), 2);
                    alice_central.context.commit_accepted(&id).await.unwrap();
                    assert_eq!(alice_central.get_conversation_unchecked(&id).await.members().len(), 3);

                    let commit_bytes = commit.tls_serialize_detached().unwrap();
                    let _ = bob_central.context.decrypt_message(&id, &commit_bytes).await.unwrap();

                    charlie_central
                        .context
                        .process_welcome_message(charlie_welcome.into(), case.custom_cfg())
                        .await
                        .unwrap();

                    // Everyone ended up in the same conversation
                    for member_central in [&bob_central, &charlie_central] {
                        assert_eq!(
                            member_central.get_conversation_unchecked(&id).await.id(),
                            alice_central.get_conversation_unchecked(&id).await.id()
                        );
                    }

                    // Alice (the creator) proposes to remove herself
                    let proposal_bundle = alice_central
                        .context
                        .new_remove_proposal(&id, alice_central.get_client_id().await)
                        .await
                        .unwrap();
                    let proposal_bytes = proposal_bundle.proposal.tls_serialize_detached().unwrap();

                    // Once Alice is gone, Bob and Charlie are expected at post-commit positions 0 and 1
                    let expected_delay = |position: u64| Some(DELAY_POS_LINEAR_INCR * position);

                    let bob_decrypted_message = bob_central.context.decrypt_message(&id, &proposal_bytes).await.unwrap();
                    assert_eq!(bob_decrypted_message.delay, expected_delay(0));

                    let charlie_decrypted_message = charlie_central
                        .context
                        .decrypt_message(&id, &proposal_bytes)
                        .await
                        .unwrap();
                    assert_eq!(charlie_decrypted_message.delay, expected_delay(1));
                })
            },
        )
        .await;
    }
}