core_crypto/mls/client/
epoch_observer.rs1use std::sync::Arc;
2
3use async_trait::async_trait;
4
5use crate::{CoreCrypto, RecursiveError, mls::HasClientAndProvider as _, prelude::ConversationId};
6
7use super::{Client, Error, Result};
8
/// Receives notifications when an MLS conversation changes epoch.
///
/// An implementation is registered via [`CoreCrypto::register_epoch_observer`] and is passed in
/// and stored as an `Arc<dyn EpochObserver>`; implementations must be `Send + Sync`.
/// A client accepts at most one observer: a second registration fails with
/// `Error::EpochObserverAlreadyExists`.
#[async_trait]
pub trait EpochObserver: Send + Sync {
    /// Called when the conversation identified by `conversation_id` changes epoch.
    ///
    /// This is exercised both for commits created locally (e.g. `update_key_material`) and for
    /// commits received from another member and processed via `decrypt_message`.
    ///
    /// NOTE(review): `epoch` is whatever value the client passes to `notify_epoch_changed`;
    /// presumably the conversation's new epoch — confirm at the call sites.
    ///
    /// `Client::notify_epoch_changed` awaits this call before returning, so a slow implementation
    /// delays its caller. The method returns `()`, so failures cannot be reported back.
    async fn epoch_changed(&self, conversation_id: ConversationId, epoch: u64);
}
24
25impl Client {
26 pub(crate) async fn register_epoch_observer(&self, epoch_observer: Arc<dyn EpochObserver>) -> Result<()> {
31 let mut guard = self.state.write().await;
32 let inner = guard.as_mut().ok_or(Error::MlsNotInitialized)?;
33 if inner.epoch_observer.is_some() {
34 return Err(Error::EpochObserverAlreadyExists);
35 }
36 inner.epoch_observer = Some(epoch_observer);
37 Ok(())
38 }
39
40 pub(crate) async fn notify_epoch_changed(&self, conversation_id: ConversationId, epoch: u64) {
42 let guard = self.state.read().await;
43 if let Some(inner) = guard.as_ref() {
44 if let Some(observer) = inner.epoch_observer.as_ref() {
45 observer.epoch_changed(conversation_id, epoch).await;
46 }
47 }
48 }
49}
50
51impl CoreCrypto {
52 pub async fn register_epoch_observer(&self, epoch_observer: Arc<dyn EpochObserver>) -> Result<()> {
57 let client = self.client().await.map_err(RecursiveError::mls("getting mls client"))?;
58 client.register_epoch_observer(epoch_observer).await
59 }
60}
61
// Tests for epoch-change notification. Each registers a `TestEpochObserver` after setup and
// checks that exactly one epoch change, for the expected conversation, is recorded.
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use rstest_reuse::apply;
    use wasm_bindgen_test::*;

    use crate::test_utils::{TestCase, TestEpochObserver, all_cred_cipher, conversation_id, run_test_with_client_ids};

    // A commit created locally (via `update_key_material`) must be reported to the local
    // client's observer exactly once, with the matching conversation id.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    pub async fn observe_local_epoch_change(case: TestCase) {
        run_test_with_client_ids(case.clone(), ["alice"], move |[client]| {
            Box::pin(async move {
                let id = conversation_id();
                client
                    .context
                    .new_conversation(&id, case.credential_type, case.cfg.clone())
                    .await
                    .unwrap();

                // Registered after setup, so only the epoch change triggered below is expected
                // to be recorded.
                let observer = TestEpochObserver::new();
                client
                    .client()
                    .await
                    .register_epoch_observer(observer.clone())
                    .await
                    .unwrap();

                // Advance the epoch with a local commit.
                client
                    .context
                    .conversation(&id)
                    .await
                    .unwrap()
                    .update_key_material()
                    .await
                    .unwrap();

                let observed_epochs = observer.observed_epochs().await;
                assert_eq!(
                    observed_epochs.len(),
                    1,
                    "we triggered exactly one epoch change and so should observe one epoch change"
                );
                assert_eq!(
                    observed_epochs[0].0, id,
                    "conversation id of observed epoch change must match"
                );
            })
        })
        .await
    }

    // A commit created by another member (alice) and processed by bob via `decrypt_message`
    // must be reported to bob's observer exactly once, with the matching conversation id.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    pub async fn observe_remote_epoch_change(case: TestCase) {
        run_test_with_client_ids(case.clone(), ["alice", "bob"], move |[alice, bob]| {
            Box::pin(async move {
                let id = conversation_id();
                alice
                    .context
                    .new_conversation(&id, case.credential_type, case.cfg.clone())
                    .await
                    .unwrap();

                alice.invite_all(&case, &id, [&bob]).await.unwrap();

                // Registered on bob only after he has joined, so only the epoch change caused
                // by alice's commit below is expected to be recorded.
                let observer = TestEpochObserver::new();
                bob.client()
                    .await
                    .register_epoch_observer(observer.clone())
                    .await
                    .unwrap();

                // Alice advances the epoch with a commit of her own.
                alice
                    .context
                    .conversation(&id)
                    .await
                    .unwrap()
                    .update_key_material()
                    .await
                    .unwrap();

                // Deliver that commit to bob. NOTE(review): presumably `latest_commit` returns
                // the commit just produced by alice's `update_key_material` — confirm.
                let commit = alice.mls_transport.latest_commit().await;
                bob.context
                    .conversation(&id)
                    .await
                    .unwrap()
                    .decrypt_message(commit.to_bytes().unwrap())
                    .await
                    .unwrap();

                let observed_epochs = observer.observed_epochs().await;
                assert_eq!(
                    observed_epochs.len(),
                    1,
                    "we triggered exactly one epoch change and so should observe one epoch change"
                );
                assert_eq!(
                    observed_epochs[0].0, id,
                    "conversation id of observed epoch change must match"
                );
            })
        })
        .await
    }
}