core_crypto/mls/session/epoch_observer.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4
5use crate::prelude::ConversationId;
6
7use super::{Error, Result, Session};
8
9/// An `EpochObserver` is notified whenever a conversation's epoch changes.
10#[cfg_attr(target_family = "wasm", async_trait(?Send))]
11#[cfg_attr(not(target_family = "wasm"), async_trait)]
12pub trait EpochObserver: Send + Sync {
13    /// This function will be called every time a conversation's epoch changes.
14    ///
15    /// The `epoch` parameter is the new epoch.
16    ///
17    /// <div class="warning">
18    /// This function must not block! Foreign implementors of this inteface can
19    /// spawn a task indirecting the notification, or (unblocking) send the notification
20    /// on some kind of channel, or anything else, as long as the operation completes
21    /// quickly.
22    /// </div>
23    async fn epoch_changed(&self, conversation_id: ConversationId, epoch: u64);
24}
25
26impl Session {
27    /// Add an epoch observer to this session.
28    /// (see [EpochObserver]).
29    ///
30    /// This function should be called 0 or 1 times in a session's lifetime. If called
31    /// when an epoch observer already exists, this will return an error.
32    pub async fn register_epoch_observer(&self, epoch_observer: Arc<dyn EpochObserver>) -> Result<()> {
33        let mut observer_guard = self.epoch_observer.write().await;
34        if observer_guard.is_some() {
35            return Err(Error::EpochObserverAlreadyExists);
36        }
37        observer_guard.replace(epoch_observer);
38        Ok(())
39    }
40
41    /// Notify the observer that the epoch has changed, if one is present.
42    pub(crate) async fn notify_epoch_changed(&self, conversation_id: ConversationId, epoch: u64) {
43        if let Some(observer) = self.epoch_observer.read().await.as_ref() {
44            observer.epoch_changed(conversation_id, epoch).await;
45        }
46    }
47}
48
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use rstest_reuse::apply;

    use crate::test_utils::{TestContext, TestEpochObserver, all_cred_cipher};

    // The session that advances the epoch is also the one observing it.
    #[apply(all_cred_cipher)]
    pub async fn observe_local_epoch_change(case: TestContext) {
        let [alice] = case.sessions().await;
        Box::pin(async move {
            let conversation = case.create_conversation([&alice]).await;

            // Attach the observer to the only session in the conversation.
            let epoch_observer = TestEpochObserver::new();
            alice
                .session()
                .await
                .register_epoch_observer(epoch_observer.clone())
                .await
                .unwrap();

            // Move the conversation forward by a single epoch.
            let conversation_id = conversation.advance_epoch().await.id;

            // Exactly one change must have been recorded, for this conversation.
            let observed = epoch_observer.observed_epochs().await;
            assert_eq!(
                observed.len(),
                1,
                "we triggered exactly one epoch change and so should observe one epoch change"
            );
            assert_eq!(
                observed[0].0, conversation_id,
                "conversation id of observed epoch change must match"
            );
        })
        .await
    }

    // A member that did not cause the epoch change still observes it.
    #[apply(all_cred_cipher)]
    pub async fn observe_remote_epoch_change(case: TestContext) {
        let [alice, bob] = case.sessions().await;
        Box::pin(async move {
            let conversation = case.create_conversation([&alice, &bob]).await;

            // Only bob's session gets an observer.
            let epoch_observer = TestEpochObserver::new();
            bob.session()
                .await
                .register_epoch_observer(epoch_observer.clone())
                .await
                .unwrap();

            // Alice moves the conversation forward by a single epoch.
            let conversation_id = conversation.advance_epoch().await.id;

            // Bob's observer must have recorded exactly that one change.
            let observed = epoch_observer.observed_epochs().await;
            assert_eq!(
                observed.len(),
                1,
                "we triggered exactly one epoch change and so should observe one epoch change"
            );
            assert_eq!(
                observed[0].0, conversation_id,
                "conversation id of observed epoch change must match"
            );
        })
        .await
    }
}