core_crypto/mls/session/
epoch_observer.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4
5use crate::prelude::ConversationId;
6
7use super::{Error, Result, Session};
8
/// An `EpochObserver` is notified whenever a conversation's epoch changes.
///
/// Implementors must be `Send + Sync`. An observer is attached to a session via
/// `Session::register_epoch_observer`, which accepts it as an `Arc<dyn EpochObserver>`.
// On wasm targets the `?Send` flavor of `async_trait` is selected, so the generated
// future is not required to be `Send` there; on all other targets it is.
#[cfg_attr(target_family = "wasm", async_trait(?Send))]
#[cfg_attr(not(target_family = "wasm"), async_trait)]
pub trait EpochObserver: Send + Sync {
    /// This function will be called every time a conversation's epoch changes.
    ///
    /// The `conversation_id` parameter identifies the conversation whose epoch changed.
    ///
    /// The `epoch` parameter is the new epoch.
    ///
    /// <div class="warning">
    /// This function must not block! Foreign implementors of this interface can
    /// spawn a task indirecting the notification, or send the notification on some
    /// kind of channel without blocking, or anything else, as long as the operation
    /// completes quickly.
    /// </div>
    async fn epoch_changed(&self, conversation_id: ConversationId, epoch: u64);
}
25
26impl Session {
27    /// Add an epoch observer to this session.
28    /// (see [EpochObserver]).
29    ///
30    /// This function should be called 0 or 1 times in a session's lifetime. If called
31    /// when an epoch observer already exists, this will return an error.
32    pub async fn register_epoch_observer(&self, epoch_observer: Arc<dyn EpochObserver>) -> Result<()> {
33        let mut observer_guard = self.epoch_observer.write().await;
34        if observer_guard.is_some() {
35            return Err(Error::EpochObserverAlreadyExists);
36        }
37        observer_guard.replace(epoch_observer);
38        Ok(())
39    }
40
41    /// Notify the observer that the epoch has changed, if one is present.
42    pub(crate) async fn notify_epoch_changed(&self, conversation_id: ConversationId, epoch: u64) {
43        if let Some(observer) = self.epoch_observer.read().await.as_ref() {
44            observer.epoch_changed(conversation_id, epoch).await;
45        }
46    }
47}
48
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use rstest_reuse::apply;
    use wasm_bindgen_test::*;

    use crate::test_utils::{TestContext, TestEpochObserver, all_cred_cipher};

    // A single-member conversation: the observer is registered on the same session
    // context that is part of the conversation whose epoch is advanced.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    pub async fn observe_local_epoch_change(case: TestContext) {
        let [session_context] = case.sessions().await;
        Box::pin(async move {
            let conversation = case.create_conversation([&session_context]).await;

            // Register the observer; the session handle is confined to this scope.
            let observer = TestEpochObserver::new();
            {
                let session = session_context.session().await;
                session.register_epoch_observer(observer.clone()).await.unwrap();
            }

            // Advance the epoch and remember which conversation was affected.
            let conversation_id = conversation.advance_epoch().await.id;

            // Exactly one notification must have arrived, and for that conversation.
            let observations = observer.observed_epochs().await;
            assert_eq!(
                observations.len(),
                1,
                "we triggered exactly one epoch change and so should observe one epoch change"
            );
            let first_observation = &observations[0];
            assert_eq!(
                first_observation.0, conversation_id,
                "conversation id of observed epoch change must match"
            );
        })
        .await
    }

    // A two-member conversation: only bob carries an observer. The intent of the
    // test is that the epoch change originates from alice's side.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    pub async fn observe_remote_epoch_change(case: TestContext) {
        let [alice, bob] = case.sessions().await;
        Box::pin(async move {
            let conversation = case.create_conversation([&alice, &bob]).await;

            // Register the observer on bob; the session handle is confined to this scope.
            let observer = TestEpochObserver::new();
            {
                let bob_session = bob.session().await;
                bob_session.register_epoch_observer(observer.clone()).await.unwrap();
            }

            // Advance the epoch and remember which conversation was affected.
            let conversation_id = conversation.advance_epoch().await.id;

            // Exactly one notification must have arrived, and for that conversation.
            let observations = observer.observed_epochs().await;
            assert_eq!(
                observations.len(),
                1,
                "we triggered exactly one epoch change and so should observe one epoch change"
            );
            let first_observation = &observations[0];
            assert_eq!(
                first_observation.0, conversation_id,
                "conversation id of observed epoch change must match"
            );
        })
        .await
    }
}