core_crypto/mls/session/
epoch_observer.rs1use std::sync::Arc;
2
3use async_trait::async_trait;
4
5use crate::{CoreCrypto, RecursiveError, mls::HasSessionAndCrypto as _, prelude::ConversationId};
6
7use super::{Error, Result, Session};
8
9#[cfg_attr(target_family = "wasm", async_trait(?Send))]
11#[cfg_attr(not(target_family = "wasm"), async_trait)]
12pub trait EpochObserver: Send + Sync {
13 async fn epoch_changed(&self, conversation_id: ConversationId, epoch: u64);
24}
25
26impl Session {
27 pub(crate) async fn register_epoch_observer(&self, epoch_observer: Arc<dyn EpochObserver>) -> Result<()> {
32 let mut guard = self.inner.write().await;
33 let inner = guard.as_mut().ok_or(Error::MlsNotInitialized)?;
34 if inner.epoch_observer.is_some() {
35 return Err(Error::EpochObserverAlreadyExists);
36 }
37 inner.epoch_observer = Some(epoch_observer);
38 Ok(())
39 }
40
41 pub(crate) async fn notify_epoch_changed(&self, conversation_id: ConversationId, epoch: u64) {
43 let guard = self.inner.read().await;
44 if let Some(inner) = guard.as_ref() {
45 if let Some(observer) = inner.epoch_observer.as_ref() {
46 observer.epoch_changed(conversation_id, epoch).await;
47 }
48 }
49 }
50}
51
52impl CoreCrypto {
53 pub async fn register_epoch_observer(&self, epoch_observer: Arc<dyn EpochObserver>) -> Result<()> {
58 let session = self
59 .session()
60 .await
61 .map_err(RecursiveError::mls("getting mls session"))?;
62 session.register_epoch_observer(epoch_observer).await
63 }
64}
65
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use rstest_reuse::apply;
    use wasm_bindgen_test::*;

    use crate::test_utils::{TestCase, TestEpochObserver, all_cred_cipher, conversation_id, run_test_with_client_ids};

    /// A locally generated commit (`update_key_material`) must notify the session's observer
    /// exactly once, for the conversation that changed.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    pub async fn observe_local_epoch_change(case: TestCase) {
        run_test_with_client_ids(case.clone(), ["alice"], move |[alice]| {
            Box::pin(async move {
                let conv_id = conversation_id();
                alice
                    .context
                    .new_conversation(&conv_id, case.credential_type, case.cfg.clone())
                    .await
                    .unwrap();

                // Register only after the conversation exists so creation itself is not observed.
                let epoch_observer = TestEpochObserver::new();
                alice
                    .session()
                    .await
                    .register_epoch_observer(epoch_observer.clone())
                    .await
                    .unwrap();

                // Advance the epoch locally.
                alice
                    .context
                    .conversation(&conv_id)
                    .await
                    .unwrap()
                    .update_key_material()
                    .await
                    .unwrap();

                let seen = epoch_observer.observed_epochs().await;
                assert_eq!(
                    seen.len(),
                    1,
                    "we triggered exactly one epoch change and so should observe one epoch change"
                );
                assert_eq!(seen[0].0, conv_id, "conversation id of observed epoch change must match");
            })
        })
        .await
    }

    /// Processing a commit produced by another member must notify the receiving session's
    /// observer exactly once, for the conversation that changed.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    pub async fn observe_remote_epoch_change(case: TestCase) {
        run_test_with_client_ids(case.clone(), ["alice", "bob"], move |[alice, bob]| {
            Box::pin(async move {
                let conv_id = conversation_id();
                alice
                    .context
                    .new_conversation(&conv_id, case.credential_type, case.cfg.clone())
                    .await
                    .unwrap();
                alice.invite_all(&case, &conv_id, [&bob]).await.unwrap();

                // Observe on bob's side only, after he has joined.
                let bob_observer = TestEpochObserver::new();
                bob.session()
                    .await
                    .register_epoch_observer(bob_observer.clone())
                    .await
                    .unwrap();

                // Alice advances the epoch ...
                alice
                    .context
                    .conversation(&conv_id)
                    .await
                    .unwrap()
                    .update_key_material()
                    .await
                    .unwrap();

                // ... and bob catches up by processing her commit.
                let commit = alice.mls_transport.latest_commit().await;
                bob.context
                    .conversation(&conv_id)
                    .await
                    .unwrap()
                    .decrypt_message(commit.to_bytes().unwrap())
                    .await
                    .unwrap();

                let seen = bob_observer.observed_epochs().await;
                assert_eq!(
                    seen.len(),
                    1,
                    "we triggered exactly one epoch change and so should observe one epoch change"
                );
                assert_eq!(seen[0].0, conv_id, "conversation id of observed epoch change must match");
            })
        })
        .await
    }
}