core_crypto/mls/session/
epoch_observer.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4
5use crate::{CoreCrypto, RecursiveError, mls::HasSessionAndCrypto as _, prelude::ConversationId};
6
7use super::{Error, Result, Session};
8
9/// An `EpochObserver` is notified whenever a conversation's epoch changes.
10#[cfg_attr(target_family = "wasm", async_trait(?Send))]
11#[cfg_attr(not(target_family = "wasm"), async_trait)]
12pub trait EpochObserver: Send + Sync {
13    /// This function will be called every time a conversation's epoch changes.
14    ///
15    /// The `epoch` parameter is the new epoch.
16    ///
17    /// <div class="warning">
18    /// This function must not block! Foreign implementors of this inteface can
19    /// spawn a task indirecting the notification, or (unblocking) send the notification
20    /// on some kind of channel, or anything else, as long as the operation completes
21    /// quickly.
22    /// </div>
23    async fn epoch_changed(&self, conversation_id: ConversationId, epoch: u64);
24}
25
26impl Session {
27    /// Add an epoch observer to this session.
28    ///
29    /// This function should be called 0 or 1 times in a session's lifetime. If called
30    /// when an epoch observer already exists, this will return an error.
31    pub(crate) async fn register_epoch_observer(&self, epoch_observer: Arc<dyn EpochObserver>) -> Result<()> {
32        let mut guard = self.inner.write().await;
33        let inner = guard.as_mut().ok_or(Error::MlsNotInitialized)?;
34        if inner.epoch_observer.is_some() {
35            return Err(Error::EpochObserverAlreadyExists);
36        }
37        inner.epoch_observer = Some(epoch_observer);
38        Ok(())
39    }
40
41    /// Notify the observer that the epoch has changed, if one is present.
42    pub(crate) async fn notify_epoch_changed(&self, conversation_id: ConversationId, epoch: u64) {
43        let guard = self.inner.read().await;
44        if let Some(inner) = guard.as_ref() {
45            if let Some(observer) = inner.epoch_observer.as_ref() {
46                observer.epoch_changed(conversation_id, epoch).await;
47            }
48        }
49    }
50}
51
52impl CoreCrypto {
53    /// Add an epoch observer to this session.
54    ///
55    /// This function should be called 0 or 1 times in a session's lifetime.
56    /// If called when an epoch observer already exists, this will return an error.
57    pub async fn register_epoch_observer(&self, epoch_observer: Arc<dyn EpochObserver>) -> Result<()> {
58        let session = self
59            .session()
60            .await
61            .map_err(RecursiveError::mls("getting mls session"))?;
62        session.register_epoch_observer(epoch_observer).await
63    }
64}
65
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use rstest_reuse::apply;
    use wasm_bindgen_test::*;

    use crate::test_utils::{TestCase, TestEpochObserver, all_cred_cipher, conversation_id, run_test_with_client_ids};

    // An epoch change caused by the local client must reach the registered observer exactly once.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    pub async fn observe_local_epoch_change(case: TestCase) {
        run_test_with_client_ids(case.clone(), ["alice"], move |[alice]| {
            Box::pin(async move {
                let conv_id = conversation_id();
                alice
                    .context
                    .new_conversation(&conv_id, case.credential_type, case.cfg.clone())
                    .await
                    .unwrap();

                // register the observer only after the conversation exists
                let epoch_observer = TestEpochObserver::new();
                alice
                    .session()
                    .await
                    .register_epoch_observer(epoch_observer.clone())
                    .await
                    .unwrap();

                // updating the key material triggers an epoch
                alice
                    .context
                    .conversation(&conv_id)
                    .await
                    .unwrap()
                    .update_key_material()
                    .await
                    .unwrap();

                // the observer must have recorded exactly that one change, for this conversation
                let seen = epoch_observer.observed_epochs().await;
                assert_eq!(
                    seen.len(),
                    1,
                    "we triggered exactly one epoch change and so should observe one epoch change"
                );
                assert_eq!(
                    seen[0].0, conv_id,
                    "conversation id of observed epoch change must match"
                );
            })
        })
        .await
    }

    // An epoch change caused by a remote member must reach the observer once the commit is decrypted.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    pub async fn observe_remote_epoch_change(case: TestCase) {
        run_test_with_client_ids(case.clone(), ["alice", "bob"], move |[alice, bob]| {
            Box::pin(async move {
                let conv_id = conversation_id();
                alice
                    .context
                    .new_conversation(&conv_id, case.credential_type, case.cfg.clone())
                    .await
                    .unwrap();

                alice.invite_all(&case, &conv_id, [&bob]).await.unwrap();

                // the observer lives on bob's side and is registered after he has been invited
                let epoch_observer = TestEpochObserver::new();
                bob.session()
                    .await
                    .register_epoch_observer(epoch_observer.clone())
                    .await
                    .unwrap();

                // alice triggers an epoch by updating her key material
                alice
                    .context
                    .conversation(&conv_id)
                    .await
                    .unwrap()
                    .update_key_material()
                    .await
                    .unwrap();

                // hand alice's latest commit over to bob
                let latest_commit = alice.mls_transport.latest_commit().await;
                bob.context
                    .conversation(&conv_id)
                    .await
                    .unwrap()
                    .decrypt_message(latest_commit.to_bytes().unwrap())
                    .await
                    .unwrap();

                // bob's observer must have recorded exactly that one change, for this conversation
                let seen = epoch_observer.observed_epochs().await;
                assert_eq!(
                    seen.len(),
                    1,
                    "we triggered exactly one epoch change and so should observe one epoch change"
                );
                assert_eq!(
                    seen[0].0, conv_id,
                    "conversation id of observed epoch change must match"
                );
            })
        })
        .await
    }
}