core_crypto/mls/session/
epoch_observer.rs1use std::sync::Arc;
2
3use async_trait::async_trait;
4
5use crate::prelude::ConversationId;
6
7use super::{Error, Result, Session};
8
9#[cfg_attr(target_family = "wasm", async_trait(?Send))]
11#[cfg_attr(not(target_family = "wasm"), async_trait)]
12pub trait EpochObserver: Send + Sync {
13 async fn epoch_changed(&self, conversation_id: ConversationId, epoch: u64);
24}
25
26impl Session {
27 pub async fn register_epoch_observer(&self, epoch_observer: Arc<dyn EpochObserver>) -> Result<()> {
33 let mut observer_guard = self.epoch_observer.write().await;
34 if observer_guard.is_some() {
35 return Err(Error::EpochObserverAlreadyExists);
36 }
37 observer_guard.replace(epoch_observer);
38 Ok(())
39 }
40
41 pub(crate) async fn notify_epoch_changed(&self, conversation_id: ConversationId, epoch: u64) {
43 if let Some(observer) = self.epoch_observer.read().await.as_ref() {
44 observer.epoch_changed(conversation_id, epoch).await;
45 }
46 }
47}
48
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use rstest_reuse::apply;
    use wasm_bindgen_test::*;

    use crate::test_utils::{TestContext, TestEpochObserver, all_cred_cipher, conversation_id};

    /// A commit produced locally (`update_key_material`) must be reported exactly once,
    /// for the right conversation, to the observer registered on the committing session.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    pub async fn observe_local_epoch_change(case: TestContext) {
        let [ctx] = case.sessions().await;
        Box::pin(async move {
            let conv_id = conversation_id();
            ctx.transaction
                .new_conversation(&conv_id, case.credential_type, case.cfg.clone())
                .await
                .unwrap();

            // The observer is attached only after the conversation exists; the test expects
            // exactly the one commit below to be observed.
            let epoch_observer = TestEpochObserver::new();
            ctx.session()
                .await
                .register_epoch_observer(epoch_observer.clone())
                .await
                .unwrap();

            ctx.transaction
                .conversation(&conv_id)
                .await
                .unwrap()
                .update_key_material()
                .await
                .unwrap();

            let seen = epoch_observer.observed_epochs().await;
            assert_eq!(
                seen.len(),
                1,
                "we triggered exactly one epoch change and so should observe one epoch change"
            );
            assert_eq!(
                seen[0].0, conv_id,
                "conversation id of observed epoch change must match"
            );
        })
        .await
    }

    /// A commit produced by another member (alice) must be reported exactly once, for the
    /// right conversation, to the observer on the receiving session (bob) once bob
    /// decrypts that commit.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    pub async fn observe_remote_epoch_change(case: TestContext) {
        let [alice, bob] = case.sessions().await;
        Box::pin(async move {
            let conv_id = conversation_id();
            alice
                .transaction
                .new_conversation(&conv_id, case.credential_type, case.cfg.clone())
                .await
                .unwrap();
            alice.invite_all(&case, &conv_id, [&bob]).await.unwrap();

            // Only bob observes; attach after he has joined.
            let bob_observer = TestEpochObserver::new();
            bob.session()
                .await
                .register_epoch_observer(bob_observer.clone())
                .await
                .unwrap();

            // Alice commits, then bob processes the commit she sent.
            alice
                .transaction
                .conversation(&conv_id)
                .await
                .unwrap()
                .update_key_material()
                .await
                .unwrap();
            let commit = alice.mls_transport.latest_commit().await;
            bob.transaction
                .conversation(&conv_id)
                .await
                .unwrap()
                .decrypt_message(commit.to_bytes().unwrap())
                .await
                .unwrap();

            let seen = bob_observer.observed_epochs().await;
            assert_eq!(
                seen.len(),
                1,
                "we triggered exactly one epoch change and so should observe one epoch change"
            );
            assert_eq!(
                seen[0].0, conv_id,
                "conversation id of observed epoch change must match"
            );
        })
        .await
    }
}