core_crypto/mls/session/
epoch_observer.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4
5use crate::prelude::ConversationId;
6
7use super::{Error, Result, Session};
8
9/// An `EpochObserver` is notified whenever a conversation's epoch changes.
10#[cfg_attr(target_family = "wasm", async_trait(?Send))]
11#[cfg_attr(not(target_family = "wasm"), async_trait)]
12pub trait EpochObserver: Send + Sync {
13    /// This function will be called every time a conversation's epoch changes.
14    ///
15    /// The `epoch` parameter is the new epoch.
16    ///
17    /// <div class="warning">
18    /// This function must not block! Foreign implementors of this inteface can
19    /// spawn a task indirecting the notification, or (unblocking) send the notification
20    /// on some kind of channel, or anything else, as long as the operation completes
21    /// quickly.
22    /// </div>
23    async fn epoch_changed(&self, conversation_id: ConversationId, epoch: u64);
24}
25
26impl Session {
27    /// Add an epoch observer to this session.
28    /// (see [EpochObserver]).
29    ///
30    /// This function should be called 0 or 1 times in a session's lifetime. If called
31    /// when an epoch observer already exists, this will return an error.
32    pub async fn register_epoch_observer(&self, epoch_observer: Arc<dyn EpochObserver>) -> Result<()> {
33        let mut observer_guard = self.epoch_observer.write().await;
34        if observer_guard.is_some() {
35            return Err(Error::EpochObserverAlreadyExists);
36        }
37        observer_guard.replace(epoch_observer);
38        Ok(())
39    }
40
41    /// Notify the observer that the epoch has changed, if one is present.
42    pub(crate) async fn notify_epoch_changed(&self, conversation_id: ConversationId, epoch: u64) {
43        if let Some(observer) = self.epoch_observer.read().await.as_ref() {
44            observer.epoch_changed(conversation_id, epoch).await;
45        }
46    }
47}
48
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use rstest_reuse::apply;

    use crate::test_utils::{TestContext, TestEpochObserver, all_cred_cipher};

    // A lone session must observe the epoch change it produces itself.
    #[apply(all_cred_cipher)]
    pub async fn observe_local_epoch_change(case: TestContext) {
        let [alice] = case.sessions().await;
        Box::pin(async move {
            let conversation = case.create_conversation([&alice]).await;

            // attach the observer to the only session in the test
            let epoch_observer = TestEpochObserver::new();
            alice
                .session()
                .await
                .register_epoch_observer(epoch_observer.clone())
                .await
                .unwrap();

            // advance the epoch once, then commit the transaction
            let conversation_id = conversation.advance_epoch().await.id;
            alice.transaction.finish().await.unwrap();

            // exactly one notification, for the right conversation
            let observed = epoch_observer.observed_epochs().await;
            assert_eq!(
                observed.len(),
                1,
                "we triggered exactly one epoch change and so should observe one epoch change"
            );
            assert_eq!(
                observed[0].0, conversation_id,
                "conversation id of observed epoch change must match"
            );
        })
        .await
    }

    // A session must observe an epoch change produced by another member.
    #[apply(all_cred_cipher)]
    pub async fn observe_remote_epoch_change(case: TestContext) {
        let [alice, bob] = case.sessions().await;
        Box::pin(async move {
            let conversation = case.create_conversation([&alice, &bob]).await;

            // only bob listens for epoch changes
            let bob_observer = TestEpochObserver::new();
            bob.session()
                .await
                .register_epoch_observer(bob_observer.clone())
                .await
                .unwrap();

            // alice moves the conversation forward by one epoch
            let conversation_id = conversation.advance_epoch().await.id;
            bob.transaction.finish().await.unwrap();

            // bob must have been told exactly once, about this conversation
            let observed = bob_observer.observed_epochs().await;
            assert_eq!(
                observed.len(),
                1,
                "we triggered exactly one epoch change and so should observe one epoch change"
            );
            assert_eq!(
                observed[0].0, conversation_id,
                "conversation id of observed epoch change must match"
            );
        })
        .await
    }
}