core_crypto/mls/session/
epoch_observer.rs1use std::sync::Arc;
2
3use async_trait::async_trait;
4
5use crate::prelude::ConversationId;
6
7use super::{Error, Result, Session};
8
9#[cfg_attr(target_family = "wasm", async_trait(?Send))]
11#[cfg_attr(not(target_family = "wasm"), async_trait)]
12pub trait EpochObserver: Send + Sync {
13 async fn epoch_changed(&self, conversation_id: ConversationId, epoch: u64);
24}
25
26impl Session {
27 pub async fn register_epoch_observer(&self, epoch_observer: Arc<dyn EpochObserver>) -> Result<()> {
33 let mut observer_guard = self.epoch_observer.write().await;
34 if observer_guard.is_some() {
35 return Err(Error::EpochObserverAlreadyExists);
36 }
37 observer_guard.replace(epoch_observer);
38 Ok(())
39 }
40
41 pub(crate) async fn notify_epoch_changed(&self, conversation_id: ConversationId, epoch: u64) {
43 if let Some(observer) = self.epoch_observer.read().await.as_ref() {
44 observer.epoch_changed(conversation_id, epoch).await;
45 }
46 }
47}
48
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use rstest_reuse::apply;

    use crate::test_utils::{TestContext, TestEpochObserver, all_cred_cipher};

    /// A member that advances the epoch of its own conversation is notified exactly once,
    /// for that conversation.
    #[apply(all_cred_cipher)]
    pub async fn observe_local_epoch_change(case: TestContext) {
        let [member] = case.sessions().await;
        // Boxed, presumably to keep the large test future off the stack.
        Box::pin(async move {
            let conversation = case.create_conversation([&member]).await;

            let epoch_observer = TestEpochObserver::new();
            member
                .session()
                .await
                .register_epoch_observer(epoch_observer.clone())
                .await
                .unwrap();

            let conversation_id = conversation.advance_epoch().await.id;

            let observations = epoch_observer.observed_epochs().await;
            // Check the count first so the indexing below cannot panic on an empty list.
            assert_eq!(
                observations.len(),
                1,
                "we triggered exactly one epoch change and so should observe one epoch change"
            );
            assert_eq!(
                observations[0].0, conversation_id,
                "conversation id of observed epoch change must match"
            );
        })
        .await
    }

    /// A member whose conversation epoch is advanced through the shared test conversation
    /// is notified exactly once, for that conversation.
    #[apply(all_cred_cipher)]
    pub async fn observe_remote_epoch_change(case: TestContext) {
        let [alice, bob] = case.sessions().await;
        Box::pin(async move {
            let conversation = case.create_conversation([&alice, &bob]).await;

            // Only bob observes. The change is expected to reach him as a remote one
            // (assuming `advance_epoch` commits from the creator — confirm in test_utils).
            let epoch_observer = TestEpochObserver::new();
            bob.session()
                .await
                .register_epoch_observer(epoch_observer.clone())
                .await
                .unwrap();

            let conversation_id = conversation.advance_epoch().await.id;

            let observations = epoch_observer.observed_epochs().await;
            assert_eq!(
                observations.len(),
                1,
                "we triggered exactly one epoch change and so should observe one epoch change"
            );
            assert_eq!(
                observations[0].0, conversation_id,
                "conversation id of observed epoch change must match"
            );
        })
        .await
    }
}