core_crypto/transaction_context/e2e_identity/
conversation_state.rs1use crate::{
2 MlsError, RecursiveError,
3 prelude::{MlsCredentialType, Session},
4};
5
6use openmls_traits::OpenMlsCryptoProvider;
7
8use crate::transaction_context::TransactionContext;
9use openmls::{messages::group_info::VerifiableGroupInfo, prelude::Node};
10
11use super::Result;
12
13#[derive(Debug, Clone, Copy, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
18#[repr(u8)]
19pub enum E2eiConversationState {
20 Verified = 1,
22 NotVerified,
24 NotEnabled,
26}
27
28impl TransactionContext {
29 pub async fn e2ei_verify_group_state(&self, group_info: VerifiableGroupInfo) -> Result<E2eiConversationState> {
31 let mls_provider = self
32 .mls_provider()
33 .await
34 .map_err(RecursiveError::transaction("getting mls provider"))?;
35 let auth_service = mls_provider.authentication_service();
36 auth_service.refresh_time_of_interest().await;
37 let cs = group_info.ciphersuite().into();
38
39 let is_sender = true; let Ok(rt) = group_info
41 .take_ratchet_tree(
42 &self
43 .mls_provider()
44 .await
45 .map_err(RecursiveError::transaction("getting mls provider"))?,
46 is_sender,
47 )
48 .await
49 else {
50 return Ok(E2eiConversationState::NotVerified);
51 };
52
53 let credentials = rt.iter().filter_map(|n| match n {
54 Some(Node::LeafNode(ln)) => Some(ln.credential()),
55 _ => None,
56 });
57
58 let auth_service = auth_service.borrow().await;
59 Ok(Session::compute_conversation_state(cs, credentials, MlsCredentialType::X509, auth_service.as_ref()).await)
60 }
61
62 pub async fn get_credential_in_use(
64 &self,
65 group_info: VerifiableGroupInfo,
66 credential_type: MlsCredentialType,
67 ) -> Result<E2eiConversationState> {
68 let cs = group_info.ciphersuite().into();
69 let rt = group_info
74 .take_ratchet_tree(
75 &self
76 .mls_provider()
77 .await
78 .map_err(RecursiveError::transaction("getting mls provider"))?,
79 false,
80 )
81 .await
82 .map_err(MlsError::wrap("taking ratchet tree"))?;
83 let mls_provider = self
84 .mls_provider()
85 .await
86 .map_err(RecursiveError::transaction("getting mls provider"))?;
87 let auth_service = mls_provider.authentication_service().borrow().await;
88 Session::get_credential_in_use_in_ratchet_tree(cs, rt, credential_type, auth_service.as_ref())
89 .await
90 .map_err(RecursiveError::mls_client("getting credentials in use"))
91 .map_err(Into::into)
92 }
93}
94
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mls::conversation::Conversation as _;
    use crate::{
        prelude::{CertificateBundle, MlsCredentialType, Session},
        test_utils::*,
    };
    use wasm_bindgen_test::*;

    wasm_bindgen_test_configure!(run_in_browser);

    // Reads the e2ei state of conversation `$id` as seen by the test session `$central`,
    // panicking if the conversation cannot be fetched or its state cannot be computed.
    macro_rules! e2ei_state_of {
        ($central:expr, $id:expr) => {
            $central
                .context
                .conversation($id)
                .await
                .unwrap()
                .e2ei_conversation_state()
                .await
                .unwrap()
        };
    }

    // Sleeps until one second past `expiration_time` (measured from `start`).
    // If `expiration_time` has already elapsed, returns immediately without sleeping.
    async fn sleep_until_expired(start: web_time::Instant, expiration_time: core::time::Duration) {
        let elapsed = start.elapsed();
        if elapsed < expiration_time {
            let margin = core::time::Duration::from_secs(1);
            async_std::task::sleep(expiration_time - elapsed + margin).await;
        }
    }

    // A group where every member uses the same credential type: all-Basic is expected to be
    // `NotEnabled`, all-X509 to be `Verified`, both from the members' view and from the GroupInfo.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    async fn uniform_conversation_should_be_not_verified_when_basic(case: TestCase) {
        run_test_with_client_ids(case.clone(), ["alice", "bob"], move |[alice_central, bob_central]| {
            Box::pin(async move {
                let id = conversation_id();

                alice_central
                    .context
                    .new_conversation(&id, case.credential_type, case.cfg.clone())
                    .await
                    .unwrap();
                alice_central.invite_all(&case, &id, [&bob_central]).await.unwrap();

                let expected = match case.credential_type {
                    MlsCredentialType::Basic => E2eiConversationState::NotEnabled,
                    MlsCredentialType::X509 => E2eiConversationState::Verified,
                };

                let alice_state = e2ei_state_of!(alice_central, &id);
                let bob_state = e2ei_state_of!(bob_central, &id);
                assert_eq!(alice_state, expected);
                assert_eq!(bob_state, expected);

                let group_info = alice_central.get_group_info(&id).await;
                let gi_state = alice_central
                    .context
                    .get_credential_in_use(group_info, MlsCredentialType::X509)
                    .await
                    .unwrap();
                assert_eq!(gi_state, expected);
            })
        })
        .await
    }

    // Alice creates the group with the credential type opposite to the test case's, so the group
    // mixes Basic and X509 credentials and is expected to be `NotVerified`.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    async fn heterogeneous_conversation_should_be_not_verified(case: TestCase) {
        use crate::e2e_identity::enrollment::test_utils::failsafe_ctx;

        run_test_with_client_ids(
            case.clone(),
            ["alice", "bob"],
            move |[mut alice_central, mut bob_central]| {
                Box::pin(async move {
                    let id = conversation_id();
                    let chain_holder =
                        failsafe_ctx(&mut [&mut alice_central, &mut bob_central], case.signature_scheme()).await;
                    let test_chain = chain_holder.as_ref().as_ref().unwrap();

                    let alice_session = alice_central.context.session().await.unwrap();
                    let alice_provider = alice_central.context.mls_provider().await.unwrap();

                    // Make sure Alice owns a bundle of the opposite credential type, then use it.
                    let opposite_ct = match case.credential_type {
                        MlsCredentialType::Basic => {
                            let intermediate_ca = test_chain.find_local_intermediate_ca();
                            let bundle = CertificateBundle::rand(&alice_session.id().await.unwrap(), intermediate_ca);
                            alice_session
                                .init_x509_credential_bundle_if_missing(&alice_provider, case.signature_scheme(), bundle)
                                .await
                                .unwrap();
                            MlsCredentialType::X509
                        }
                        MlsCredentialType::X509 => {
                            alice_session
                                .init_basic_credential_bundle_if_missing(&alice_provider, case.signature_scheme())
                                .await
                                .unwrap();
                            MlsCredentialType::Basic
                        }
                    };

                    alice_central
                        .context
                        .new_conversation(&id, opposite_ct, case.cfg.clone())
                        .await
                        .unwrap();
                    alice_central.invite_all(&case, &id, [&bob_central]).await.unwrap();

                    let alice_state = e2ei_state_of!(alice_central, &id);
                    let bob_state = e2ei_state_of!(bob_central, &id);
                    assert_eq!(alice_state, E2eiConversationState::NotVerified);
                    assert_eq!(bob_state, E2eiConversationState::NotVerified);

                    let group_info = alice_central.get_group_info(&id).await;
                    let gi_state = alice_central
                        .context
                        .get_credential_in_use(group_info, MlsCredentialType::X509)
                        .await
                        .unwrap();
                    assert_eq!(gi_state, E2eiConversationState::NotVerified);
                })
            },
        )
        .await
    }

    // Alice rotates onto a short-lived certificate; once it has expired, the group is expected to
    // be `NotVerified` for both members and from a GroupInfo captured before the expiry.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    async fn should_be_not_verified_when_one_expired(case: TestCase) {
        if !case.is_x509() {
            return;
        }
        run_test_with_client_ids(case.clone(), ["alice", "bob"], move |[alice_central, bob_central]| {
            Box::pin(async move {
                let id = conversation_id();

                alice_central
                    .context
                    .new_conversation(&id, case.credential_type, case.cfg.clone())
                    .await
                    .unwrap();
                alice_central.invite_all(&case, &id, [&bob_central]).await.unwrap();

                let expiration_time = core::time::Duration::from_secs(14);
                let start = web_time::Instant::now();

                let intermediate_ca = alice_central
                    .x509_test_chain
                    .as_ref()
                    .as_ref()
                    .expect("No x509 test chain")
                    .find_local_intermediate_ca();
                let short_lived = CertificateBundle::new_with_default_values(intermediate_ca, Some(expiration_time));
                let bundle = Session::new_x509_credential_bundle(short_lived.clone()).unwrap();
                alice_central
                    .context
                    .conversation(&id)
                    .await
                    .unwrap()
                    .e2ei_rotate(Some(&bundle))
                    .await
                    .unwrap();

                // Bob processes Alice's rotation commit.
                let rotation_commit = alice_central.mls_transport.latest_commit().await;
                bob_central
                    .context
                    .conversation(&id)
                    .await
                    .unwrap()
                    .decrypt_message(rotation_commit.to_bytes().unwrap())
                    .await
                    .unwrap();

                let alice_session = alice_central.context.session().await.unwrap();
                let alice_provider = alice_central.context.mls_provider().await.unwrap();
                alice_session
                    .save_new_x509_credential_bundle(&alice_provider.keystore(), case.signature_scheme(), short_lived)
                    .await
                    .unwrap();

                // Captured before the certificate expires, inspected only afterwards.
                let group_info = alice_central.get_group_info(&id).await;

                sleep_until_expired(start, expiration_time).await;

                let alice_state = e2ei_state_of!(alice_central, &id);
                let bob_state = e2ei_state_of!(bob_central, &id);
                assert_eq!(alice_state, E2eiConversationState::NotVerified);
                assert_eq!(bob_state, E2eiConversationState::NotVerified);

                let gi_state = alice_central
                    .context
                    .get_credential_in_use(group_info, MlsCredentialType::X509)
                    .await
                    .unwrap();
                assert_eq!(gi_state, E2eiConversationState::NotVerified);
            })
        })
        .await
    }

    // Alice, the only member, rotates onto a re-issued short-lived certificate; once it has
    // expired, the group is expected to be `NotVerified`.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    async fn should_be_not_verified_when_all_expired(case: TestCase) {
        if !case.is_x509() {
            return;
        }
        run_test_with_client_ids(case.clone(), ["alice"], move |[alice_central]| {
            Box::pin(async move {
                let id = conversation_id();

                alice_central
                    .context
                    .new_conversation(&id, case.credential_type, case.cfg.clone())
                    .await
                    .unwrap();

                let expiration_time = core::time::Duration::from_secs(14);
                let start = web_time::Instant::now();
                let test_chain = alice_central.x509_test_chain.as_ref().as_ref().unwrap();

                // Re-issue Alice's end-identity certificate with the short validity.
                let issuer = test_chain.find_local_intermediate_ca();
                let mut alice_actor = test_chain
                    .actors
                    .iter()
                    .find(|actor| actor.name == "alice")
                    .unwrap()
                    .clone();
                issuer.update_end_identity(&mut alice_actor.certificate, Some(expiration_time));

                let short_lived = CertificateBundle::from_certificate_and_issuer(&alice_actor.certificate, issuer);
                let bundle = Session::new_x509_credential_bundle(short_lived.clone()).unwrap();
                alice_central
                    .context
                    .conversation(&id)
                    .await
                    .unwrap()
                    .e2ei_rotate(Some(&bundle))
                    .await
                    .unwrap();

                let alice_session = alice_central.session().await;
                let alice_provider = alice_central.context.mls_provider().await.unwrap();
                alice_session
                    .save_new_x509_credential_bundle(&alice_provider.keystore(), case.signature_scheme(), short_lived)
                    .await
                    .unwrap();

                sleep_until_expired(start, expiration_time).await;

                let alice_state = e2ei_state_of!(alice_central, &id);
                assert_eq!(alice_state, E2eiConversationState::NotVerified);

                let group_info = alice_central.get_group_info(&id).await;
                let gi_state = alice_central
                    .context
                    .get_credential_in_use(group_info, MlsCredentialType::X509)
                    .await
                    .unwrap();
                assert_eq!(gi_state, E2eiConversationState::NotVerified);
            })
        })
        .await
    }
}