core_crypto/mls/session/
epoch_observer.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4
5use crate::prelude::ConversationId;
6
7use super::{Error, Result, Session};
8
/// An `EpochObserver` is notified whenever a conversation's epoch changes.
///
/// Implementors must be `Send + Sync`. On `wasm` targets the trait is expanded with
/// `async_trait(?Send)`, so the future returned by [`EpochObserver::epoch_changed`] is not
/// required to be `Send` there; on every other target the default `async_trait` expansion
/// is used.
#[cfg_attr(target_family = "wasm", async_trait(?Send))]
#[cfg_attr(not(target_family = "wasm"), async_trait)]
pub trait EpochObserver: Send + Sync {
    /// This function will be called every time a conversation's epoch changes.
    ///
    /// The `conversation_id` parameter identifies the conversation whose epoch changed.
    /// The `epoch` parameter is the new epoch.
    ///
    /// <div class="warning">
    /// This function must not block! Foreign implementors of this interface can
    /// spawn a task indirecting the notification, or send the notification on some
    /// kind of channel without blocking, or anything else, as long as the operation
    /// completes quickly.
    /// </div>
    async fn epoch_changed(&self, conversation_id: ConversationId, epoch: u64);
}
25
26impl Session {
27    /// Add an epoch observer to this session.
28    /// (see [EpochObserver]).
29    ///
30    /// This function should be called 0 or 1 times in a session's lifetime. If called
31    /// when an epoch observer already exists, this will return an error.
32    pub async fn register_epoch_observer(&self, epoch_observer: Arc<dyn EpochObserver>) -> Result<()> {
33        let mut observer_guard = self.epoch_observer.write().await;
34        if observer_guard.is_some() {
35            return Err(Error::EpochObserverAlreadyExists);
36        }
37        observer_guard.replace(epoch_observer);
38        Ok(())
39    }
40
41    /// Notify the observer that the epoch has changed, if one is present.
42    pub(crate) async fn notify_epoch_changed(&self, conversation_id: ConversationId, epoch: u64) {
43        if let Some(observer) = self.epoch_observer.read().await.as_ref() {
44            observer.epoch_changed(conversation_id, epoch).await;
45        }
46    }
47}
48
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use rstest_reuse::apply;
    use wasm_bindgen_test::*;

    use crate::test_utils::{TestContext, TestEpochObserver, all_cred_cipher, conversation_id};

    // A key-material update performed by the client that owns the observer must be
    // reported exactly once, tagged with the id of the affected conversation.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    pub async fn observe_local_epoch_change(case: TestContext) {
        let [alice] = case.sessions().await;
        Box::pin(async move {
            let conv_id = conversation_id();
            alice
                .transaction
                .new_conversation(&conv_id, case.credential_type, case.cfg.clone())
                .await
                .unwrap();

            // the observer is registered after the conversation has been created
            let epoch_observer = TestEpochObserver::new();
            {
                let session = alice.session().await;
                session
                    .register_epoch_observer(epoch_observer.clone())
                    .await
                    .unwrap();
            }

            // updating the key material moves the conversation to a new epoch
            alice
                .transaction
                .conversation(&conv_id)
                .await
                .unwrap()
                .update_key_material()
                .await
                .unwrap();

            // the observer must have recorded that single change
            let recorded = epoch_observer.observed_epochs().await;
            assert_eq!(
                recorded.len(),
                1,
                "we triggered exactly one epoch change and so should observe one epoch change"
            );
            assert_eq!(
                recorded[0].0, conv_id,
                "conversation id of observed epoch change must match"
            );
        })
        .await
    }

    // An epoch change made by another member must be reported to the observer once the
    // observing client has decrypted the corresponding commit.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    pub async fn observe_remote_epoch_change(case: TestContext) {
        let [alice, bob] = case.sessions().await;
        Box::pin(async move {
            let conv_id = conversation_id();
            alice
                .transaction
                .new_conversation(&conv_id, case.credential_type, case.cfg.clone())
                .await
                .unwrap();

            alice.invite_all(&case, &conv_id, [&bob]).await.unwrap();

            // only bob observes epoch changes
            let epoch_observer = TestEpochObserver::new();
            {
                let bob_session = bob.session().await;
                bob_session
                    .register_epoch_observer(epoch_observer.clone())
                    .await
                    .unwrap();
            }

            // alice is the one who advances the epoch
            alice
                .transaction
                .conversation(&conv_id)
                .await
                .unwrap()
                .update_key_material()
                .await
                .unwrap();

            // hand alice's latest commit over to bob, who decrypts it
            let latest_commit = alice.mls_transport.latest_commit().await;
            let commit_bytes = latest_commit.to_bytes().unwrap();
            bob.transaction
                .conversation(&conv_id)
                .await
                .unwrap()
                .decrypt_message(commit_bytes)
                .await
                .unwrap();

            // bob's observer must have recorded that single change
            let recorded = epoch_observer.observed_epochs().await;
            assert_eq!(
                recorded.len(),
                1,
                "we triggered exactly one epoch change and so should observe one epoch change"
            );
            assert_eq!(
                recorded[0].0, conv_id,
                "conversation id of observed epoch change must match"
            );
        })
        .await
    }
}