core_crypto/mls/session/
epoch_observer.rs1use std::sync::Arc;
2
3use async_trait::async_trait;
4
5use crate::prelude::ConversationId;
6
7use super::{Error, Result, Session};
8
9#[cfg_attr(target_family = "wasm", async_trait(?Send))]
11#[cfg_attr(not(target_family = "wasm"), async_trait)]
12pub trait EpochObserver: Send + Sync {
13 async fn epoch_changed(&self, conversation_id: ConversationId, epoch: u64);
24}
25
26impl Session {
27 pub async fn register_epoch_observer(&self, epoch_observer: Arc<dyn EpochObserver>) -> Result<()> {
33 let mut observer_guard = self.epoch_observer.write().await;
34 if observer_guard.is_some() {
35 return Err(Error::EpochObserverAlreadyExists);
36 }
37 observer_guard.replace(epoch_observer);
38 Ok(())
39 }
40
41 pub(crate) async fn notify_epoch_changed(&self, conversation_id: ConversationId, epoch: u64) {
43 if let Some(observer) = self.epoch_observer.read().await.as_ref() {
44 observer.epoch_changed(conversation_id, epoch).await;
45 }
46 }
47}
48
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use rstest_reuse::apply;
    use wasm_bindgen_test::*;

    use crate::test_utils::{TestContext, TestEpochObserver, all_cred_cipher};

    // Shared check for both tests: exactly one epoch change was observed, and it was
    // reported for the expected conversation. Written as a macro so it does not need to
    // name the concrete type returned by `observed_epochs()`.
    macro_rules! assert_single_epoch_change {
        ($observed:expr, $conversation_id:expr $(,)?) => {{
            let observed = $observed;
            assert_eq!(
                observed.len(),
                1,
                "we triggered exactly one epoch change and so should observe one epoch change"
            );
            assert_eq!(
                observed[0].0, $conversation_id,
                "conversation id of observed epoch change must match"
            );
        }};
    }

    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    pub async fn observe_local_epoch_change(case: TestContext) {
        let [ctx] = case.sessions().await;
        Box::pin(async move {
            let conversation = case.create_conversation([&ctx]).await;

            let observer = TestEpochObserver::new();
            ctx.session()
                .await
                .register_epoch_observer(observer.clone())
                .await
                .unwrap();

            let conversation_id = conversation.advance_epoch().await.id;

            assert_single_epoch_change!(observer.observed_epochs().await, conversation_id);
        })
        .await
    }

    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    pub async fn observe_remote_epoch_change(case: TestContext) {
        let [alice, bob] = case.sessions().await;
        Box::pin(async move {
            let conversation = case.create_conversation([&alice, &bob]).await;

            // Only bob observes; per the test's intent the epoch change reaches him
            // remotely (assumes `advance_epoch` commits on behalf of another member).
            let observer = TestEpochObserver::new();
            bob.session()
                .await
                .register_epoch_observer(observer.clone())
                .await
                .unwrap();

            let conversation_id = conversation.advance_epoch().await.id;

            assert_single_epoch_change!(observer.observed_epochs().await, conversation_id);
        })
        .await
    }
}