// core_crypto/mls/session/epoch_observer.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4
5use super::{Error, Result, Session};
6use crate::ConversationId;
7
8/// An `EpochObserver` is notified whenever a conversation's epoch changes.
9#[cfg_attr(target_family = "wasm", async_trait(?Send))]
10#[cfg_attr(not(target_family = "wasm"), async_trait)]
11pub trait EpochObserver: Send + Sync {
12    /// This function will be called every time a conversation's epoch changes.
13    ///
14    /// The `epoch` parameter is the new epoch.
15    ///
16    /// <div class="warning">
17    /// This function must not block! Foreign implementors of this inteface can
18    /// spawn a task indirecting the notification, or (unblocking) send the notification
19    /// on some kind of channel, or anything else, as long as the operation completes
20    /// quickly.
21    /// </div>
22    async fn epoch_changed(&self, conversation_id: ConversationId, epoch: u64);
23}
24
25impl Session {
26    /// Add an epoch observer to this session.
27    /// (see [EpochObserver]).
28    ///
29    /// This function should be called 0 or 1 times in a session's lifetime. If called
30    /// when an epoch observer already exists, this will return an error.
31    pub async fn register_epoch_observer(&self, epoch_observer: Arc<dyn EpochObserver>) -> Result<()> {
32        let mut observer_guard = self.epoch_observer.write().await;
33        if observer_guard.is_some() {
34            return Err(Error::EpochObserverAlreadyExists);
35        }
36        observer_guard.replace(epoch_observer);
37        Ok(())
38    }
39
40    /// Notify the observer that the epoch has changed, if one is present.
41    pub(crate) async fn notify_epoch_changed(&self, conversation_id: ConversationId, epoch: u64) {
42        if let Some(observer) = self.epoch_observer.read().await.as_ref() {
43            observer.epoch_changed(conversation_id, epoch).await;
44        }
45    }
46}
47
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use rstest_reuse::apply;

    use crate::test_utils::{TestContext, TestEpochObserver, all_cred_cipher};

    /// An observer registered on the session that advances the epoch itself
    /// must see exactly that one change, for the right conversation.
    #[apply(all_cred_cipher)]
    pub async fn observe_local_epoch_change(case: TestContext) {
        let [ctx] = case.sessions().await;
        Box::pin(async move {
            let conversation = case.create_conversation([&ctx]).await;

            let epoch_observer = TestEpochObserver::new();
            ctx.session()
                .await
                .register_epoch_observer(epoch_observer.clone())
                .await
                .unwrap();

            // advance the conversation by a single epoch
            let conversation_id = conversation.advance_epoch().await.id;

            // one notification, carrying this conversation's id
            let observed = epoch_observer.observed_epochs().await;
            assert_eq!(
                observed.len(),
                1,
                "we triggered exactly one epoch change and so should observe one epoch change"
            );
            assert_eq!(
                observed[0].0, conversation_id,
                "conversation id of observed epoch change must match"
            );
        })
        .await
    }

    /// An observer registered on a second member (bob) must see the epoch change
    /// caused by the other member (alice), exactly once.
    #[apply(all_cred_cipher)]
    pub async fn observe_remote_epoch_change(case: TestContext) {
        let [alice, bob] = case.sessions().await;
        Box::pin(async move {
            let conversation = case.create_conversation([&alice, &bob]).await;

            // only bob registers an observer
            let epoch_observer = TestEpochObserver::new();
            bob.session()
                .await
                .register_epoch_observer(epoch_observer.clone())
                .await
                .unwrap();

            // alice moves the conversation forward by one epoch
            let conversation_id = conversation.advance_epoch().await.id;

            // bob's observer saw one notification, carrying this conversation's id
            let observed = epoch_observer.observed_epochs().await;
            assert_eq!(
                observed.len(),
                1,
                "we triggered exactly one epoch change and so should observe one epoch change"
            );
            assert_eq!(
                observed[0].0, conversation_id,
                "conversation id of observed epoch change must match"
            );
        })
        .await
    }
}