core_crypto/mls/session/
epoch_observer.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4
5use crate::{CoreCrypto, RecursiveError, mls::HasSessionAndCrypto as _, prelude::ConversationId};
6
7use super::{Error, Result, Session};
8
9/// An `EpochObserver` is notified whenever a conversation's epoch changes.
10#[cfg_attr(target_family = "wasm", async_trait(?Send))]
11#[cfg_attr(not(target_family = "wasm"), async_trait)]
12pub trait EpochObserver: Send + Sync {
13    /// This function will be called every time a conversation's epoch changes.
14    ///
15    /// The `epoch` parameter is the new epoch.
16    ///
17    /// <div class="warning">
18    /// This function must not block! Foreign implementors of this inteface can
19    /// spawn a task indirecting the notification, or (unblocking) send the notification
20    /// on some kind of channel, or anything else, as long as the operation completes
21    /// quickly.
22    /// </div>
23    async fn epoch_changed(&self, conversation_id: ConversationId, epoch: u64);
24}
25
26impl Session {
27    /// Add an epoch observer to this session.
28    ///
29    /// This function should be called 0 or 1 times in a session's lifetime. If called
30    /// when an epoch observer already exists, this will return an error.
31    pub(crate) async fn register_epoch_observer(&self, epoch_observer: Arc<dyn EpochObserver>) -> Result<()> {
32        let mut guard = self.inner.write().await;
33        let inner = guard.as_mut().ok_or(Error::MlsNotInitialized)?;
34        if inner.epoch_observer.is_some() {
35            return Err(Error::EpochObserverAlreadyExists);
36        }
37        inner.epoch_observer = Some(epoch_observer);
38        Ok(())
39    }
40
41    /// Notify the observer that the epoch has changed, if one is present.
42    pub(crate) async fn notify_epoch_changed(&self, conversation_id: ConversationId, epoch: u64) {
43        let guard = self.inner.read().await;
44        if let Some(inner) = guard.as_ref() {
45            if let Some(observer) = inner.epoch_observer.as_ref() {
46                observer.epoch_changed(conversation_id, epoch).await;
47            }
48        }
49    }
50}
51
52impl CoreCrypto {
53    /// Add an epoch observer to this session.
54    ///
55    /// This function should be called 0 or 1 times in a session's lifetime.
56    /// If called when an epoch observer already exists, this will return an error.
57    pub async fn register_epoch_observer(&self, epoch_observer: Arc<dyn EpochObserver>) -> Result<()> {
58        let session = self
59            .session()
60            .await
61            .map_err(RecursiveError::mls("getting mls session"))?;
62        session.register_epoch_observer(epoch_observer).await
63    }
64}
65
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use rstest_reuse::apply;
    use wasm_bindgen_test::*;

    use crate::test_utils::{
        TestContext, TestEpochObserver, all_cred_cipher, conversation_id, run_test_with_client_ids,
    };

    // A commit created locally must be reported to the local session's observer.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    pub async fn observe_local_epoch_change(case: TestContext) {
        run_test_with_client_ids(case.clone(), ["alice"], move |[alice]| {
            Box::pin(async move {
                let conv_id = conversation_id();
                alice
                    .transaction
                    .new_conversation(&conv_id, case.credential_type, case.cfg.clone())
                    .await
                    .unwrap();

                // Attach the observer only once the conversation exists.
                let epoch_observer = TestEpochObserver::new();
                alice
                    .session()
                    .await
                    .register_epoch_observer(epoch_observer.clone())
                    .await
                    .unwrap();

                // Advance the epoch with a self-update commit.
                alice
                    .transaction
                    .conversation(&conv_id)
                    .await
                    .unwrap()
                    .update_key_material()
                    .await
                    .unwrap();

                // Exactly one notification, for the conversation we touched.
                let seen = epoch_observer.observed_epochs().await;
                assert_eq!(
                    seen.len(),
                    1,
                    "we triggered exactly one epoch change and so should observe one epoch change"
                );
                assert_eq!(seen[0].0, conv_id, "conversation id of observed epoch change must match");
            })
        })
        .await
    }

    // A commit received from another member must be reported to the receiver's observer.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    pub async fn observe_remote_epoch_change(case: TestContext) {
        run_test_with_client_ids(case.clone(), ["alice", "bob"], move |[sender, receiver]| {
            Box::pin(async move {
                let conv_id = conversation_id();
                sender
                    .transaction
                    .new_conversation(&conv_id, case.credential_type, case.cfg.clone())
                    .await
                    .unwrap();

                sender.invite_all(&case, &conv_id, [&receiver]).await.unwrap();

                // Only the receiving side is observed.
                let epoch_observer = TestEpochObserver::new();
                receiver
                    .session()
                    .await
                    .register_epoch_observer(epoch_observer.clone())
                    .await
                    .unwrap();

                // The sender advances the epoch...
                sender
                    .transaction
                    .conversation(&conv_id)
                    .await
                    .unwrap()
                    .update_key_material()
                    .await
                    .unwrap();

                // ...and the receiver processes the resulting commit.
                let commit = sender.mls_transport.latest_commit().await;
                receiver
                    .transaction
                    .conversation(&conv_id)
                    .await
                    .unwrap()
                    .decrypt_message(commit.to_bytes().unwrap())
                    .await
                    .unwrap();

                // Exactly one notification, for the conversation we touched.
                let seen = epoch_observer.observed_epochs().await;
                assert_eq!(
                    seen.len(),
                    1,
                    "we triggered exactly one epoch change and so should observe one epoch change"
                );
                assert_eq!(seen[0].0, conv_id, "conversation id of observed epoch change must match");
            })
        })
        .await
    }
}