core_crypto/mls/session/
epoch_observer.rs1use std::sync::Arc;
2
3use async_trait::async_trait;
4
5use crate::{CoreCrypto, RecursiveError, mls::HasSessionAndCrypto as _, prelude::ConversationId};
6
7use super::{Error, Result, Session};
8
9#[cfg_attr(target_family = "wasm", async_trait(?Send))]
11#[cfg_attr(not(target_family = "wasm"), async_trait)]
12pub trait EpochObserver: Send + Sync {
13 async fn epoch_changed(&self, conversation_id: ConversationId, epoch: u64);
24}
25
26impl Session {
27 pub(crate) async fn register_epoch_observer(&self, epoch_observer: Arc<dyn EpochObserver>) -> Result<()> {
32 let mut guard = self.inner.write().await;
33 let inner = guard.as_mut().ok_or(Error::MlsNotInitialized)?;
34 if inner.epoch_observer.is_some() {
35 return Err(Error::EpochObserverAlreadyExists);
36 }
37 inner.epoch_observer = Some(epoch_observer);
38 Ok(())
39 }
40
41 pub(crate) async fn notify_epoch_changed(&self, conversation_id: ConversationId, epoch: u64) {
43 let guard = self.inner.read().await;
44 if let Some(inner) = guard.as_ref() {
45 if let Some(observer) = inner.epoch_observer.as_ref() {
46 observer.epoch_changed(conversation_id, epoch).await;
47 }
48 }
49 }
50}
51
52impl CoreCrypto {
53 pub async fn register_epoch_observer(&self, epoch_observer: Arc<dyn EpochObserver>) -> Result<()> {
58 let session = self
59 .session()
60 .await
61 .map_err(RecursiveError::mls("getting mls session"))?;
62 session.register_epoch_observer(epoch_observer).await
63 }
64}
65
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use rstest_reuse::apply;
    use wasm_bindgen_test::*;

    use crate::test_utils::{
        TestContext, TestEpochObserver, all_cred_cipher, conversation_id, run_test_with_client_ids,
    };

    /// A client's observer must see the epoch change produced by that client's own commit.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    pub async fn observe_local_epoch_change(case: TestContext) {
        run_test_with_client_ids(case.clone(), ["alice"], move |[alice]| {
            Box::pin(async move {
                let conv_id = conversation_id();
                alice
                    .transaction
                    .new_conversation(&conv_id, case.credential_type, case.cfg.clone())
                    .await
                    .unwrap();

                // The observer is attached only after the conversation has been created.
                let epoch_observer = TestEpochObserver::new();
                alice
                    .session()
                    .await
                    .register_epoch_observer(epoch_observer.clone())
                    .await
                    .unwrap();

                // Updating key material is the single epoch change this test triggers.
                alice
                    .transaction
                    .conversation(&conv_id)
                    .await
                    .unwrap()
                    .update_key_material()
                    .await
                    .unwrap();

                let seen = epoch_observer.observed_epochs().await;
                assert_eq!(
                    seen.len(),
                    1,
                    "we triggered exactly one epoch change and so should observe one epoch change"
                );
                assert_eq!(seen[0].0, conv_id, "conversation id of observed epoch change must match");
            })
        })
        .await
    }

    /// A client's observer must see the epoch change produced by a commit it receives from a peer.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    pub async fn observe_remote_epoch_change(case: TestContext) {
        run_test_with_client_ids(case.clone(), ["alice", "bob"], move |[alice, bob]| {
            Box::pin(async move {
                let conv_id = conversation_id();
                alice
                    .transaction
                    .new_conversation(&conv_id, case.credential_type, case.cfg.clone())
                    .await
                    .unwrap();
                alice.invite_all(&case, &conv_id, [&bob]).await.unwrap();

                // Only bob is observed; alice acts purely as the committer.
                let bob_observer = TestEpochObserver::new();
                bob.session()
                    .await
                    .register_epoch_observer(bob_observer.clone())
                    .await
                    .unwrap();

                // Alice commits a key-material update...
                alice
                    .transaction
                    .conversation(&conv_id)
                    .await
                    .unwrap()
                    .update_key_material()
                    .await
                    .unwrap();

                // ...and bob processes that commit, which is what should notify his observer.
                let latest_commit = alice.mls_transport.latest_commit().await;
                bob.transaction
                    .conversation(&conv_id)
                    .await
                    .unwrap()
                    .decrypt_message(latest_commit.to_bytes().unwrap())
                    .await
                    .unwrap();

                let seen = bob_observer.observed_epochs().await;
                assert_eq!(
                    seen.len(),
                    1,
                    "we triggered exactly one epoch change and so should observe one epoch change"
                );
                assert_eq!(seen[0].0, conv_id, "conversation id of observed epoch change must match");
            })
        })
        .await
    }
}
179}