core_crypto/mls/session/
epoch_observer.rs1use std::sync::Arc;
2
3use async_trait::async_trait;
4
5use crate::prelude::ConversationId;
6
7use super::{Error, Result, Session};
8
9#[cfg_attr(target_family = "wasm", async_trait(?Send))]
11#[cfg_attr(not(target_family = "wasm"), async_trait)]
12pub trait EpochObserver: Send + Sync {
13 async fn epoch_changed(&self, conversation_id: ConversationId, epoch: u64);
24}
25
26impl Session {
27 pub async fn register_epoch_observer(&self, epoch_observer: Arc<dyn EpochObserver>) -> Result<()> {
33 let mut observer_guard = self.epoch_observer.write().await;
34 if observer_guard.is_some() {
35 return Err(Error::EpochObserverAlreadyExists);
36 }
37 observer_guard.replace(epoch_observer);
38 Ok(())
39 }
40
41 pub(crate) async fn notify_epoch_changed(&self, conversation_id: ConversationId, epoch: u64) {
43 if let Some(observer) = self.epoch_observer.read().await.as_ref() {
44 observer.epoch_changed(conversation_id, epoch).await;
45 }
46 }
47}
48
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use rstest_reuse::apply;

    use crate::test_utils::{TestContext, TestEpochObserver, all_cred_cipher};

    /// Asserts that `$observed` (the epochs recorded by a `TestEpochObserver`) contains exactly
    /// one entry, and that this entry belongs to the conversation `$expected_id`.
    ///
    /// A macro rather than a function so it does not need to name the test-utils element types.
    macro_rules! assert_single_epoch_change {
        ($observed:expr, $expected_id:expr $(,)?) => {{
            let observed = $observed;
            assert_eq!(
                observed.len(),
                1,
                "we triggered exactly one epoch change and so should observe one epoch change"
            );
            assert_eq!(
                observed[0].0, $expected_id,
                "conversation id of observed epoch change must match"
            );
        }};
    }

    /// Advancing the epoch of a single-member conversation must be reported exactly once to the
    /// observer registered on that same session.
    #[apply(all_cred_cipher)]
    pub async fn observe_local_epoch_change(case: TestContext) {
        let [ctx] = case.sessions().await;
        Box::pin(async move {
            let conversation = case.create_conversation([&ctx]).await;

            let epoch_observer = TestEpochObserver::new();
            ctx.session()
                .await
                .register_epoch_observer(epoch_observer.clone())
                .await
                .unwrap();

            let conversation_id = conversation.advance_epoch().await.id;
            // Finish the transaction before inspecting what the observer recorded.
            ctx.transaction.finish().await.unwrap();

            assert_single_epoch_change!(epoch_observer.observed_epochs().await, conversation_id);
        })
        .await
    }

    /// An observer registered on bob must be notified exactly once when the shared
    /// conversation's epoch advances. Presumably `advance_epoch` commits from alice and bob
    /// merges it — confirm against `test_utils`.
    #[apply(all_cred_cipher)]
    pub async fn observe_remote_epoch_change(case: TestContext) {
        let [alice, bob] = case.sessions().await;
        Box::pin(async move {
            let conversation = case.create_conversation([&alice, &bob]).await;

            let bob_observer = TestEpochObserver::new();
            bob.session()
                .await
                .register_epoch_observer(bob_observer.clone())
                .await
                .unwrap();

            let conversation_id = conversation.advance_epoch().await.id;
            // Finish bob's transaction before inspecting what his observer recorded.
            bob.transaction.finish().await.unwrap();

            assert_single_epoch_change!(bob_observer.observed_epochs().await, conversation_id);
        })
        .await
    }
}