core_crypto/mls/session/
epoch_observer.rs1use std::sync::Arc;
2
3use async_trait::async_trait;
4
5use super::{Error, Result, Session};
6use crate::ConversationId;
7
8#[cfg_attr(target_family = "wasm", async_trait(?Send))]
10#[cfg_attr(not(target_family = "wasm"), async_trait)]
11pub trait EpochObserver: Send + Sync {
12 async fn epoch_changed(&self, conversation_id: ConversationId, epoch: u64);
23}
24
25impl Session {
26 pub async fn register_epoch_observer(&self, epoch_observer: Arc<dyn EpochObserver>) -> Result<()> {
32 let mut observer_guard = self.epoch_observer.write().await;
33 if observer_guard.is_some() {
34 return Err(Error::EpochObserverAlreadyExists);
35 }
36 observer_guard.replace(epoch_observer);
37 Ok(())
38 }
39
40 pub(crate) async fn notify_epoch_changed(&self, conversation_id: ConversationId, epoch: u64) {
42 if let Some(observer) = self.epoch_observer.read().await.as_ref() {
43 observer.epoch_changed(conversation_id, epoch).await;
44 }
45 }
46}
47
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use rstest_reuse::apply;

    use crate::test_utils::{TestContext, TestEpochObserver, all_cred_cipher};

    /// A session that advances the epoch of its own conversation must see exactly one
    /// notification, tagged with that conversation's id.
    #[apply(all_cred_cipher)]
    pub async fn observe_local_epoch_change(case: TestContext) {
        let [session_context] = case.sessions().await;
        Box::pin(async move {
            let conversation = case.create_conversation([&session_context]).await;

            // Install the observer before the epoch moves so the change is captured.
            let epoch_observer = TestEpochObserver::new();
            session_context
                .session()
                .await
                .register_epoch_observer(epoch_observer.clone())
                .await
                .unwrap();

            let conversation_id = conversation.advance_epoch().await.id;

            let observed = epoch_observer.observed_epochs().await;
            assert_eq!(
                observed.len(),
                1,
                "we triggered exactly one epoch change and so should observe one epoch change"
            );
            assert_eq!(
                observed[0].0, conversation_id,
                "conversation id of observed epoch change must match"
            );
        })
        .await
    }

    /// A member (bob) whose conversation epoch is advanced through the shared test conversation
    /// must see exactly one notification, tagged with that conversation's id.
    #[apply(all_cred_cipher)]
    pub async fn observe_remote_epoch_change(case: TestContext) {
        let [alice, bob] = case.sessions().await;
        Box::pin(async move {
            let conversation = case.create_conversation([&alice, &bob]).await;

            // Only bob observes; the epoch change is driven through the test conversation helper.
            let epoch_observer = TestEpochObserver::new();
            bob.session()
                .await
                .register_epoch_observer(epoch_observer.clone())
                .await
                .unwrap();

            let conversation_id = conversation.advance_epoch().await.id;

            let observed = epoch_observer.observed_epochs().await;
            assert_eq!(
                observed.len(),
                1,
                "we triggered exactly one epoch change and so should observe one epoch change"
            );
            assert_eq!(
                observed[0].0, conversation_id,
                "conversation id of observed epoch change must match"
            );
        })
        .await
    }
}