core_crypto/mls/session/
epoch_observer.rs1use std::sync::Arc;
2
3use async_trait::async_trait;
4
5use super::{Error, Result, Session};
6use crate::ConversationId;
7
8#[cfg_attr(target_os = "unknown", async_trait(?Send))]
10#[cfg_attr(not(target_os = "unknown"), async_trait)]
11pub trait EpochObserver: Send + Sync {
12 async fn epoch_changed(&self, conversation_id: ConversationId, epoch: u64);
23}
24
25impl<D> Session<D> {
26 pub async fn register_epoch_observer(&self, epoch_observer: Arc<dyn EpochObserver>) -> Result<()> {
32 let mut observer_guard = self.epoch_observer.write().await;
33 if observer_guard.is_some() {
34 return Err(Error::EpochObserverAlreadyExists);
35 }
36 observer_guard.replace(epoch_observer);
37 Ok(())
38 }
39
40 pub(crate) async fn notify_epoch_changed(&self, conversation_id: ConversationId, epoch: u64) {
42 if let Some(observer) = self.epoch_observer.read().await.as_ref() {
43 observer.epoch_changed(conversation_id, epoch).await;
44 }
45 }
46}
47
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use rstest_reuse::apply;

    use crate::test_utils::{TestContext, TestEpochObserver, all_cred_cipher};

    /// A single epoch advance in a one-member conversation must be reported exactly once
    /// to the observer registered on that member's session, with the matching
    /// conversation id, once the transaction has been finished.
    #[apply(all_cred_cipher)]
    pub async fn observe_local_epoch_change(case: TestContext) {
        let [ctx] = case.sessions().await;
        Box::pin(async move {
            let conversation = case.create_conversation([&ctx]).await;

            let epoch_observer = TestEpochObserver::new();
            ctx.session()
                .await
                .register_epoch_observer(epoch_observer.clone())
                .await
                .unwrap();

            let conversation_id = conversation.advance_epoch().await.id;
            ctx.transaction.finish().await.unwrap();

            let seen = epoch_observer.observed_epochs().await;
            assert_eq!(
                seen.len(),
                1,
                "we triggered exactly one epoch change and so should observe one epoch change"
            );
            assert_eq!(
                seen[0].0, conversation_id,
                "conversation id of observed epoch change must match"
            );
        })
        .await
    }

    /// The observer is registered on bob only, so this checks that an epoch advance in a
    /// two-member conversation is reported exactly once to bob's observer after bob's
    /// transaction has been finished (presumably the commit is authored by alice, the
    /// conversation creator — confirm in `advance_epoch`).
    #[apply(all_cred_cipher)]
    pub async fn observe_remote_epoch_change(case: TestContext) {
        let [alice, bob] = case.sessions().await;
        Box::pin(async move {
            let conversation = case.create_conversation([&alice, &bob]).await;

            let epoch_observer = TestEpochObserver::new();
            bob.session()
                .await
                .register_epoch_observer(epoch_observer.clone())
                .await
                .unwrap();

            let conversation_id = conversation.advance_epoch().await.id;
            bob.transaction.finish().await.unwrap();

            let seen = epoch_observer.observed_epochs().await;
            assert_eq!(
                seen.len(),
                1,
                "we triggered exactly one epoch change and so should observe one epoch change"
            );
            assert_eq!(
                seen[0].0, conversation_id,
                "conversation id of observed epoch change must match"
            );
        })
        .await
    }
}