core_crypto/mls/client/epoch_observer.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4
5use crate::{CoreCrypto, RecursiveError, mls::HasClientAndProvider as _, prelude::ConversationId};
6
7use super::{Client, Error, Result};
8
9/// An `EpochObserver` is notified whenever a conversation's epoch changes.
10#[async_trait]
11pub trait EpochObserver: Send + Sync {
12    /// This function will be called every time a conversation's epoch changes.
13    ///
14    /// The `epoch` parameter is the new epoch.
15    ///
16    /// <div class="warning">
17    /// This function must not block! Foreign implementors of this inteface can
18    /// spawn a task indirecting the notification, or (unblocking) send the notification
19    /// on some kind of channel, or anything else, as long as the operation completes
20    /// quickly.
21    /// </div>
22    async fn epoch_changed(&self, conversation_id: ConversationId, epoch: u64);
23}
24
25impl Client {
26    /// Add an epoch observer to this client.
27    ///
28    /// This function should be called 0 or 1 times in a client's lifetime. If called
29    /// when an epoch observer already exists, this will return an error.
30    pub(crate) async fn register_epoch_observer(&self, epoch_observer: Arc<dyn EpochObserver>) -> Result<()> {
31        let mut guard = self.state.write().await;
32        let inner = guard.as_mut().ok_or(Error::MlsNotInitialized)?;
33        if inner.epoch_observer.is_some() {
34            return Err(Error::EpochObserverAlreadyExists);
35        }
36        inner.epoch_observer = Some(epoch_observer);
37        Ok(())
38    }
39
40    /// Notify the observer that the epoch has changed, if one is present.
41    pub(crate) async fn notify_epoch_changed(&self, conversation_id: ConversationId, epoch: u64) {
42        let guard = self.state.read().await;
43        if let Some(inner) = guard.as_ref() {
44            if let Some(observer) = inner.epoch_observer.as_ref() {
45                observer.epoch_changed(conversation_id, epoch).await;
46            }
47        }
48    }
49}
50
51impl CoreCrypto {
52    /// Add an epoch observer to this client.
53    ///
54    /// This function should be called 0 or 1 times in a client's lifetime.
55    /// If called when an epoch observer already exists, this will return an error.
56    pub async fn register_epoch_observer(&self, epoch_observer: Arc<dyn EpochObserver>) -> Result<()> {
57        let client = self.client().await.map_err(RecursiveError::mls("getting mls client"))?;
58        client.register_epoch_observer(epoch_observer).await
59    }
60}
61
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use rstest_reuse::apply;
    use wasm_bindgen_test::*;

    use crate::test_utils::{TestCase, TestEpochObserver, all_cred_cipher, conversation_id, run_test_with_client_ids};

    // A key-material update committed by the client itself must reach its own observer.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    pub async fn observe_local_epoch_change(case: TestCase) {
        run_test_with_client_ids(case.clone(), ["alice"], move |[client]| {
            Box::pin(async move {
                let conv_id = conversation_id();
                client
                    .context
                    .new_conversation(&conv_id, case.credential_type, case.cfg.clone())
                    .await
                    .unwrap();

                // Register only after setup, so the single update below is the one change
                // this test expects to see.
                let epoch_observer = TestEpochObserver::new();
                client
                    .client()
                    .await
                    .register_epoch_observer(epoch_observer.clone())
                    .await
                    .unwrap();

                // Advance the epoch locally.
                client
                    .context
                    .conversation(&conv_id)
                    .await
                    .unwrap()
                    .update_key_material()
                    .await
                    .unwrap();

                let observed = epoch_observer.observed_epochs().await;
                assert_eq!(
                    observed.len(),
                    1,
                    "we triggered exactly one epoch change and so should observe one epoch change"
                );
                let first_observation = &observed[0];
                assert_eq!(
                    first_observation.0, conv_id,
                    "conversation id of observed epoch change must match"
                );
            })
        })
        .await
    }

    // A commit authored by alice must reach bob's observer once bob processes it.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    pub async fn observe_remote_epoch_change(case: TestCase) {
        run_test_with_client_ids(case.clone(), ["alice", "bob"], move |[alice, bob]| {
            Box::pin(async move {
                let conv_id = conversation_id();
                alice
                    .context
                    .new_conversation(&conv_id, case.credential_type, case.cfg.clone())
                    .await
                    .unwrap();

                alice.invite_all(&case, &conv_id, [&bob]).await.unwrap();

                // Only bob observes. He registers after joining, so the commit sent below is
                // the one change this test expects him to see.
                let epoch_observer = TestEpochObserver::new();
                bob.client()
                    .await
                    .register_epoch_observer(epoch_observer.clone())
                    .await
                    .unwrap();

                // Alice advances the epoch on her side.
                alice
                    .context
                    .conversation(&conv_id)
                    .await
                    .unwrap()
                    .update_key_material()
                    .await
                    .unwrap();

                // Deliver alice's commit to bob.
                let latest_commit = alice.mls_transport.latest_commit().await;
                let commit_bytes = latest_commit.to_bytes().unwrap();
                bob.context
                    .conversation(&conv_id)
                    .await
                    .unwrap()
                    .decrypt_message(commit_bytes)
                    .await
                    .unwrap();

                let observed = epoch_observer.observed_epochs().await;
                assert_eq!(
                    observed.len(),
                    1,
                    "we triggered exactly one epoch change and so should observe one epoch change"
                );
                let first_observation = &observed[0];
                assert_eq!(
                    first_observation.0, conv_id,
                    "conversation id of observed epoch change must match"
                );
            })
        })
        .await
    }
}