core_crypto/mls/conversation/
mod.rs

1//! MLS groups (aka conversation) are the actual entities cementing all the participants in a
2//! conversation.
3//!
//! This table summarizes which operations are permitted on a group depending on its state:
5//! *(PP=pending proposal, PC=pending commit)*
6//!
7//! | can I ?   | 0 PP / 0 PC | 1+ PP / 0 PC | 0 PP / 1 PC | 1+ PP / 1 PC |
8//! |-----------|-------------|--------------|-------------|--------------|
9//! | encrypt   | ✅           | ❌            | ❌           | ❌            |
10//! | handshake | ✅           | ✅            | ❌           | ❌            |
11//! | merge     | ❌           | ❌            | ✅           | ✅            |
12//! | decrypt   | ✅           | ✅            | ✅           | ✅            |
13
14use std::{
15    borrow::{Borrow, Cow},
16    collections::{HashMap, HashSet},
17    ops::Deref,
18    sync::Arc,
19};
20
21use config::MlsConversationConfiguration;
22use core_crypto_keystore::CryptoKeystoreMls;
23use itertools::Itertools as _;
24use log::trace;
25use mls_crypto_provider::{Database, MlsCryptoProvider};
26use openmls::{
27    group::MlsGroup,
28    prelude::{Credential, CredentialWithKey, LeafNodeIndex, Proposal, SignaturePublicKey},
29};
30use openmls_traits::{OpenMlsCryptoProvider, types::SignatureScheme};
31
32use crate::{
33    ClientId, E2eiConversationState, KeystoreError, LeafError, MlsCiphersuite, MlsCredentialType, MlsError,
34    RecursiveError, WireIdentity, mls::Session,
35};
36
37pub(crate) mod commit;
38mod commit_delay;
39pub(crate) mod config;
40pub(crate) mod conversation_guard;
41mod duplicate;
42#[cfg(test)]
43mod durability;
44mod error;
45pub(crate) mod group_info;
46mod immutable_conversation;
47pub(crate) mod merge;
48mod orphan_welcome;
49mod own_commit;
50pub(crate) mod pending_conversation;
51pub(crate) mod proposal;
52mod renew;
53pub(crate) mod welcome;
54mod wipe;
55
56pub use conversation_guard::ConversationGuard;
57pub use error::{Error, Result};
58pub use immutable_conversation::ImmutableConversation;
59
60use super::credential::CredentialBundle;
61use crate::{
62    UserId,
63    mls::{HasSessionAndCrypto, credential::ext::CredentialExt as _},
64};
65
66/// The base layer for [Conversation].
67/// The trait is only exposed internally.
68#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
69#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
70pub(crate) trait ConversationWithMls<'a> {
71    /// [Session] or [TransactionContext] both implement [HasSessionAndCrypto].
72    type Context: HasSessionAndCrypto;
73
74    type Conversation: Deref<Target = MlsConversation> + Send;
75
76    async fn context(&self) -> Result<Self::Context>;
77
78    async fn conversation(&'a self) -> Self::Conversation;
79
80    async fn crypto_provider(&self) -> Result<MlsCryptoProvider> {
81        self.context()
82            .await?
83            .crypto_provider()
84            .await
85            .map_err(RecursiveError::mls("getting mls provider"))
86            .map_err(Into::into)
87    }
88
89    async fn session(&self) -> Result<Session> {
90        self.context()
91            .await?
92            .session()
93            .await
94            .map_err(RecursiveError::mls("getting mls client"))
95            .map_err(Into::into)
96    }
97}
98
99/// The `Conversation` trait provides a set of operations that can be done on
100/// an **immutable** conversation.
101// We keep the super trait internal intentionally, as it is not meant to be used by the public API,
102// hence #[expect(private_bounds)].
103#[expect(private_bounds)]
104#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
105#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
106pub trait Conversation<'a>: ConversationWithMls<'a> {
107    /// Returns the epoch of a given conversation
108    async fn epoch(&'a self) -> u64 {
109        self.conversation().await.group().epoch().as_u64()
110    }
111
112    /// Returns the ciphersuite of a given conversation
113    async fn ciphersuite(&'a self) -> MlsCiphersuite {
114        self.conversation().await.ciphersuite()
115    }
116
117    /// Derives a new key from the one in the group, to be used elsewhere.
118    ///
119    /// # Arguments
120    /// * `key_length` - the length of the key to be derived. If the value is higher than the
121    ///     bounds of `u16` or the context hash * 255, an error will be returned
122    ///
123    /// # Errors
124    /// OpenMls secret generation error
125    async fn export_secret_key(&'a self, key_length: usize) -> Result<Vec<u8>> {
126        const EXPORTER_LABEL: &str = "exporter";
127        const EXPORTER_CONTEXT: &[u8] = &[];
128        let backend = self.crypto_provider().await?;
129        let inner = self.conversation().await;
130        inner
131            .group()
132            .export_secret(&backend, EXPORTER_LABEL, EXPORTER_CONTEXT, key_length)
133            .map_err(MlsError::wrap("exporting secret key"))
134            .map_err(Into::into)
135    }
136
137    /// Exports the clients from a conversation
138    ///
139    /// # Arguments
140    /// * `conversation_id` - the group/conversation id
141    async fn get_client_ids(&'a self) -> Vec<ClientId> {
142        let inner = self.conversation().await;
143        inner
144            .group()
145            .members()
146            .map(|kp| ClientId::from(kp.credential.identity()))
147            .collect()
148    }
149
150    /// Returns the raw public key of the single external sender present in this group.
151    /// This should be used to initialize a subconversation
152    async fn get_external_sender(&'a self) -> Result<Vec<u8>> {
153        let inner = self.conversation().await;
154        let ext_senders = inner
155            .group()
156            .group_context_extensions()
157            .external_senders()
158            .ok_or(Error::MissingExternalSenderExtension)?;
159        let ext_sender = ext_senders.first().ok_or(Error::MissingExternalSenderExtension)?;
160        let ext_sender_public_key = ext_sender.signature_key().as_slice().to_vec();
161        Ok(ext_sender_public_key)
162    }
163
164    /// Indicates when to mark a conversation as not verified i.e. when not all its members have a X509
165    /// Credential generated by Wire's end-to-end identity enrollment
166    async fn e2ei_conversation_state(&'a self) -> Result<E2eiConversationState> {
167        let backend = self.crypto_provider().await?;
168        let authentication_service = backend.authentication_service();
169        authentication_service.refresh_time_of_interest().await;
170        let inner = self.conversation().await;
171        let state = Session::compute_conversation_state(
172            inner.ciphersuite(),
173            inner.group.members_credentials(),
174            MlsCredentialType::X509,
175            authentication_service.borrow().await.as_ref(),
176        )
177        .await;
178        Ok(state)
179    }
180
181    /// From a given conversation, get the identity of the members supplied. Identity is only present for
182    /// members with a Certificate Credential (after turning on end-to-end identity).
183    /// If no member has a x509 certificate, it will return an empty Vec
184    async fn get_device_identities(&'a self, device_ids: &[ClientId]) -> Result<Vec<WireIdentity>> {
185        if device_ids.is_empty() {
186            return Err(Error::CallerError(
187                "This function accepts a list of IDs as a parameter, but that list was empty.",
188            ));
189        }
190        let mls_provider = self.crypto_provider().await?;
191        let auth_service = mls_provider.authentication_service();
192        auth_service.refresh_time_of_interest().await;
193        let auth_service = auth_service.borrow().await;
194        let env = auth_service.as_ref();
195        let conversation = self.conversation().await;
196        conversation
197            .members_with_key()
198            .into_iter()
199            .filter(|(id, _)| device_ids.contains(&ClientId::from(id.as_slice())))
200            .map(|(_, c)| {
201                c.extract_identity(conversation.ciphersuite(), env)
202                    .map_err(RecursiveError::mls_credential("extracting identity"))
203            })
204            .collect::<Result<Vec<_>, _>>()
205            .map_err(Into::into)
206    }
207
208    /// From a given conversation, get the identity of the users (device holders) supplied.
209    /// Identity is only present for devices with a Certificate Credential (after turning on end-to-end identity).
210    /// If no member has a x509 certificate, it will return an empty Vec.
211    ///
212    /// Returns a Map with all the identities for a given users. Consumers are then recommended to
213    /// reduce those identities to determine the actual status of a user.
214    async fn get_user_identities(&'a self, user_ids: &[String]) -> Result<HashMap<String, Vec<WireIdentity>>> {
215        if user_ids.is_empty() {
216            return Err(Error::CallerError(
217                "This function accepts a list of IDs as a parameter, but that list was empty.",
218            ));
219        }
220        let mls_provider = self.crypto_provider().await?;
221        let auth_service = mls_provider.authentication_service();
222        auth_service.refresh_time_of_interest().await;
223        let auth_service = auth_service.borrow().await;
224        let env = auth_service.as_ref();
225        let conversation = self.conversation().await;
226        let user_ids = user_ids.iter().map(|uid| uid.as_bytes()).collect::<Vec<_>>();
227
228        conversation
229            .members_with_key()
230            .iter()
231            .filter_map(|(id, c)| UserId::try_from(id.as_slice()).ok().zip(Some(c)))
232            .filter(|(uid, _)| user_ids.contains(uid))
233            .map(|(uid, c)| {
234                let uid = String::try_from(uid).map_err(RecursiveError::mls_client("getting user identities"))?;
235                let identity = c
236                    .extract_identity(conversation.ciphersuite(), env)
237                    .map_err(RecursiveError::mls_credential("extracting identity"))?;
238                Ok((uid, identity))
239            })
240            .process_results(|iter| iter.into_group_map())
241    }
242
243    /// Generate a new [`crate::HistorySecret`].
244    ///
245    /// This is useful when it's this client's turn to generate a new history client.
246    ///
247    /// The generated secret is cryptographically unrelated to the current CoreCrypto client.
248    async fn generate_history_secret(&'a self) -> Result<crate::HistorySecret> {
249        let ciphersuite = self.ciphersuite().await;
250        crate::ephemeral::generate_history_secret(ciphersuite)
251            .await
252            .map_err(RecursiveError::root("generating history secret"))
253            .map_err(Into::into)
254    }
255
256    /// Check if history sharing is enabled, i.e., if any of the conversation members have a [ClientId] starting
257    /// with [crate::HISTORY_CLIENT_ID_PREFIX].
258    async fn is_history_sharing_enabled(&'a self) -> bool {
259        self.get_client_ids()
260            .await
261            .iter()
262            .any(|client_id| client_id.starts_with(crate::ephemeral::HISTORY_CLIENT_ID_PREFIX.as_bytes()))
263    }
264}
265
266impl<'a, T: ConversationWithMls<'a>> Conversation<'a> for T {}
267
268/// A unique identifier for a group/conversation. The identifier must be unique within a client.
269#[derive(
270    core_crypto_macros::Debug,
271    derive_more::AsRef,
272    derive_more::From,
273    derive_more::Into,
274    PartialEq,
275    Eq,
276    PartialOrd,
277    Ord,
278    Hash,
279    Clone,
280)]
281#[sensitive]
282#[as_ref([u8])]
283#[from(&[u8], Vec<u8>)]
284pub struct ConversationId(Vec<u8>);
285
286impl From<ConversationId> for Cow<'_, [u8]> {
287    fn from(value: ConversationId) -> Self {
288        Cow::Owned(value.0)
289    }
290}
291
292impl<'a> From<&'a ConversationId> for Cow<'a, [u8]> {
293    fn from(value: &'a ConversationId) -> Self {
294        Cow::Borrowed(value.as_ref())
295    }
296}
297
/// Reference to a ConversationId.
///
/// This type is `!Sized` and is only ever seen as a reference, like `str` or `[u8]`.
//
// pattern from https://stackoverflow.com/a/64990850
#[repr(transparent)]
#[derive(PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct ConversationIdRef([u8]);

impl ConversationIdRef {
    /// Creates a `ConversationId` Ref, needed to implement `Borrow<ConversationIdRef>` for `T`
    ///
    /// Accepts anything that can be viewed as a byte slice; no bytes are copied and the
    /// returned reference lives exactly as long as `bytes`.
    pub fn new<Bytes>(bytes: &Bytes) -> &ConversationIdRef
    where
        Bytes: AsRef<[u8]> + ?Sized,
    {
        let raw: *const [u8] = bytes.as_ref();
        // SAFETY: `ConversationIdRef` is `#[repr(transparent)]` over `[u8]`, so both types share
        // layout and pointer metadata. `raw` comes from a live `&[u8]`, and the returned
        // reference is tied to the lifetime of `bytes` by the signature.
        unsafe { &*(raw as *const ConversationIdRef) }
    }

    /// Copies the identifier into a freshly allocated byte vector.
    fn to_bytes(&self) -> Vec<u8> {
        self.0.to_vec()
    }
}
324
325impl Borrow<ConversationIdRef> for ConversationId {
326    fn borrow(&self) -> &ConversationIdRef {
327        ConversationIdRef::new(&self.0)
328    }
329}
330
331impl Deref for ConversationId {
332    type Target = ConversationIdRef;
333
334    fn deref(&self) -> &Self::Target {
335        ConversationIdRef::new(&self.0)
336    }
337}
338
339impl ToOwned for ConversationIdRef {
340    type Owned = ConversationId;
341
342    fn to_owned(&self) -> Self::Owned {
343        ConversationId(self.0.to_owned())
344    }
345}
346
347impl AsRef<[u8]> for ConversationIdRef {
348    fn as_ref(&self) -> &[u8] {
349        &self.0
350    }
351}
352
353impl<'a> From<&'a ConversationIdRef> for Cow<'a, [u8]> {
354    fn from(value: &'a ConversationIdRef) -> Self {
355        Cow::Borrowed(value.as_ref())
356    }
357}
358
359/// This is a wrapper on top of the OpenMls's [MlsGroup], that provides Core Crypto specific functionality
360///
361/// This type will store the state of a group. With the [MlsGroup] it holds, it provides all
362/// operations that can be done in a group, such as creating proposals and commits.
363/// More information [here](https://messaginglayersecurity.rocks/mls-architecture/draft-ietf-mls-architecture.html#name-general-setting)
364#[derive(Debug)]
365#[allow(dead_code)]
366pub struct MlsConversation {
367    pub(crate) id: ConversationId,
368    pub(crate) parent_id: Option<ConversationId>,
369    pub(crate) group: MlsGroup,
370    configuration: MlsConversationConfiguration,
371}
372
373impl MlsConversation {
374    /// Creates a new group/conversation
375    ///
376    /// # Arguments
377    /// * `id` - group/conversation identifier
378    /// * `author_client` - the client responsible for creating the group
379    /// * `creator_credential_type` - kind of credential the creator wants to join the group with
380    /// * `config` - group configuration
381    /// * `backend` - MLS Provider that will be used to persist the group
382    ///
383    /// # Errors
384    /// Errors can happen from OpenMls or from the KeyStore
385    pub async fn create(
386        id: ConversationId,
387        author_client: &Session,
388        creator_credential_type: MlsCredentialType,
389        configuration: MlsConversationConfiguration,
390        backend: &MlsCryptoProvider,
391    ) -> Result<Self> {
392        let (cs, ct) = (configuration.ciphersuite, creator_credential_type);
393        let cb = author_client
394            .get_most_recent_or_create_credential_bundle(backend, cs.signature_algorithm(), ct)
395            .await
396            .map_err(RecursiveError::mls_client("getting or creating credential bundle"))?;
397
398        let group = MlsGroup::new_with_group_id(
399            backend,
400            &cb.signature_key,
401            &configuration.as_openmls_default_configuration()?,
402            openmls::prelude::GroupId::from_slice(id.as_ref()),
403            cb.to_mls_credential_with_key(),
404        )
405        .await
406        .map_err(MlsError::wrap("creating group with id"))?;
407
408        let mut conversation = Self {
409            id,
410            group,
411            parent_id: None,
412            configuration,
413        };
414
415        conversation
416            .persist_group_when_changed(&backend.keystore(), true)
417            .await?;
418
419        Ok(conversation)
420    }
421
422    /// Internal API: create a group from an existing conversation. For example by external commit
423    pub(crate) async fn from_mls_group(
424        group: MlsGroup,
425        configuration: MlsConversationConfiguration,
426        backend: &MlsCryptoProvider,
427    ) -> Result<Self> {
428        let id = ConversationId::from(group.group_id().as_slice());
429
430        let mut conversation = Self {
431            id,
432            group,
433            configuration,
434            parent_id: None,
435        };
436
437        conversation
438            .persist_group_when_changed(&backend.keystore(), true)
439            .await?;
440
441        Ok(conversation)
442    }
443
444    /// Internal API: restore the conversation from a persistence-saved serialized Group State.
445    pub(crate) fn from_serialized_state(buf: Vec<u8>, parent_id: Option<ConversationId>) -> Result<Self> {
446        let group: MlsGroup =
447            core_crypto_keystore::deser(&buf).map_err(KeystoreError::wrap("deserializing group state"))?;
448        let id = ConversationId::from(group.group_id().as_slice());
449        let configuration = MlsConversationConfiguration {
450            ciphersuite: group.ciphersuite().into(),
451            ..Default::default()
452        };
453
454        Ok(Self {
455            id,
456            group,
457            parent_id,
458            configuration,
459        })
460    }
461
462    /// Group/conversation id
463    pub fn id(&self) -> &ConversationId {
464        &self.id
465    }
466
467    pub(crate) fn group(&self) -> &MlsGroup {
468        &self.group
469    }
470
471    /// Returns all members credentials from the group/conversation
472    pub fn members(&self) -> HashMap<Vec<u8>, Credential> {
473        self.group.members().fold(HashMap::new(), |mut acc, kp| {
474            let credential = kp.credential;
475            let id = credential.identity().to_vec();
476            acc.entry(id).or_insert(credential);
477            acc
478        })
479    }
480
481    /// Get actual group members and subtract pending remove proposals
482    pub fn members_in_next_epoch(&self) -> Vec<ClientId> {
483        let pending_removals = self.pending_removals();
484        let existing_clients = self
485            .group
486            .members()
487            .filter_map(|kp| {
488                if !pending_removals.contains(&kp.index) {
489                    Some(kp.credential.identity().into())
490                } else {
491                    trace!(client_index:% = kp.index; "Client is pending removal");
492                    None
493                }
494            })
495            .collect::<HashSet<_>>();
496        existing_clients.into_iter().collect()
497    }
498
499    /// Gather pending remove proposals
500    fn pending_removals(&self) -> Vec<LeafNodeIndex> {
501        self.group
502            .pending_proposals()
503            .filter_map(|proposal| match proposal.proposal() {
504                Proposal::Remove(remove) => Some(remove.removed()),
505                _ => None,
506            })
507            .collect::<Vec<_>>()
508    }
509
510    /// Returns all members credentials with their signature public key from the group/conversation
511    pub fn members_with_key(&self) -> HashMap<Vec<u8>, CredentialWithKey> {
512        self.group.members().fold(HashMap::new(), |mut acc, kp| {
513            let credential = kp.credential;
514            let id = credential.identity().to_vec();
515            let signature_key = SignaturePublicKey::from(kp.signature_key);
516            let credential = CredentialWithKey {
517                credential,
518                signature_key,
519            };
520            acc.entry(id).or_insert(credential);
521            acc
522        })
523    }
524
525    pub(crate) async fn persist_group_when_changed(&mut self, keystore: &Database, force: bool) -> Result<()> {
526        if force || self.group.state_changed() == openmls::group::InnerState::Changed {
527            keystore
528                .mls_group_persist(
529                    &self.id,
530                    &core_crypto_keystore::ser(&self.group).map_err(KeystoreError::wrap("serializing group state"))?,
531                    self.parent_id.as_ref().map(|id| id.as_ref()),
532                )
533                .await
534                .map_err(KeystoreError::wrap("persisting mls group"))?;
535
536            self.group.set_state(openmls::group::InnerState::Persisted);
537        }
538
539        Ok(())
540    }
541
542    pub(crate) fn own_credential_type(&self) -> Result<MlsCredentialType> {
543        Ok(self
544            .group
545            .own_leaf_node()
546            .ok_or(Error::MlsGroupInvalidState("own_leaf_node not present in group"))?
547            .credential()
548            .credential_type()
549            .into())
550    }
551
552    pub(crate) fn ciphersuite(&self) -> MlsCiphersuite {
553        self.configuration.ciphersuite
554    }
555
556    pub(crate) fn signature_scheme(&self) -> SignatureScheme {
557        self.ciphersuite().signature_algorithm()
558    }
559
560    pub(crate) async fn find_current_credential_bundle(&self, client: &Session) -> Result<Arc<CredentialBundle>> {
561        let own_leaf = self.group.own_leaf().ok_or(LeafError::InternalMlsError)?;
562        let sc = self.ciphersuite().signature_algorithm();
563        let ct = self
564            .own_credential_type()
565            .map_err(RecursiveError::mls_conversation("getting own credential type"))?;
566
567        client
568            .find_credential_bundle_by_public_key(sc, ct, own_leaf.signature_key())
569            .await
570            .map_err(RecursiveError::mls_client("finding current credential bundle"))
571            .map_err(Into::into)
572    }
573
574    pub(crate) async fn find_most_recent_credential_bundle(&self, client: &Session) -> Result<Arc<CredentialBundle>> {
575        let sc = self.ciphersuite().signature_algorithm();
576        let ct = self
577            .own_credential_type()
578            .map_err(RecursiveError::mls_conversation("getting own credential type"))?;
579
580        client
581            .find_most_recent_credential_bundle(sc, ct)
582            .await
583            .map_err(RecursiveError::mls_client("finding most recent credential bundle"))
584            .map_err(Into::into)
585    }
586}
587
#[cfg(test)]
pub mod test_utils {
    use super::*;

    /// Test-only accessors into the underlying group.
    impl MlsConversation {
        /// Signature public key of every member.
        pub fn signature_keys(&self) -> impl Iterator<Item = SignaturePublicKey> + '_ {
            self.group
                .members()
                .map(|member| SignaturePublicKey::from(member.signature_key.as_slice()))
        }

        /// Raw encryption key of every member.
        pub fn encryption_keys(&self) -> impl Iterator<Item = Vec<u8>> + '_ {
            self.group.members().map(|member| member.encryption_key)
        }

        /// Extensions of the exported group context.
        pub fn extensions(&self) -> &openmls::prelude::Extensions {
            let group_context = self.group.export_group_context();
            group_context.extensions()
        }
    }
}
609
610#[cfg(test)]
611mod tests {
612    use super::*;
613    use crate::test_utils::*;
614
615    #[apply(all_cred_cipher)]
616    pub async fn create_self_conversation_should_succeed(case: TestContext) {
617        let [alice] = case.sessions().await;
618        Box::pin(async move {
619            let conversation = case.create_conversation([&alice]).await;
620            assert_eq!(1, conversation.member_count().await);
621            let alice_can_send_message = conversation.guard().await.encrypt_message(b"me").await;
622            assert!(alice_can_send_message.is_ok());
623        })
624        .await;
625    }
626
627    #[apply(all_cred_cipher)]
628    pub async fn create_1_1_conversation_should_succeed(case: TestContext) {
629        let [alice, bob] = case.sessions().await;
630        Box::pin(async move {
631            let conversation = case.create_conversation([&alice, &bob]).await;
632            assert_eq!(2, conversation.member_count().await);
633            assert!(conversation.is_functional_and_contains([&alice, &bob]).await);
634        })
635        .await;
636    }
637
638    #[apply(all_cred_cipher)]
639    pub async fn create_many_people_conversation(case: TestContext) {
640        const SIZE_PLUS_1: usize = GROUP_SAMPLE_SIZE + 1;
641        let alice_and_friends = case.sessions::<SIZE_PLUS_1>().await;
642        Box::pin(async move {
643            let alice = &alice_and_friends[0];
644            let conversation = case.create_conversation([alice]).await;
645
646            let bob_and_friends = &alice_and_friends[1..];
647            let conversation = conversation.invite_notify(bob_and_friends).await;
648
649            assert_eq!(conversation.member_count().await, 1 + GROUP_SAMPLE_SIZE);
650            assert!(conversation.is_functional_and_contains(&alice_and_friends).await);
651        })
652        .await;
653    }
654
655    mod wire_identity_getters {
656        use super::Error;
657        use crate::{
658            ClientId, DeviceStatus, E2eiConversationState, MlsCredentialType, mls::conversation::Conversation,
659            test_utils::*,
660        };
661
662        async fn all_identities_check<'a, C, const N: usize>(
663            conversation: &'a C,
664            user_ids: &[String; N],
665            expected_sizes: [usize; N],
666        ) where
667            C: Conversation<'a> + Sync,
668        {
669            let all_identities = conversation.get_user_identities(user_ids).await.unwrap();
670            assert_eq!(all_identities.len(), N);
671            for (expected_size, user_id) in expected_sizes.into_iter().zip(user_ids.iter()) {
672                let alice_identities = all_identities.get(user_id).unwrap();
673                assert_eq!(alice_identities.len(), expected_size);
674            }
675            // Not found
676            let not_found = conversation
677                .get_user_identities(&["aaaaaaaaaaaaa".to_string()])
678                .await
679                .unwrap();
680            assert!(not_found.is_empty());
681
682            // Invalid usage
683            let invalid = conversation.get_user_identities(&[]).await;
684            assert!(matches!(invalid.unwrap_err(), Error::CallerError(_)));
685        }
686
687        async fn check_identities_device_status<'a, C, const N: usize>(
688            conversation: &'a C,
689            client_ids: &[ClientId; N],
690            name_status: &[(impl ToString, DeviceStatus); N],
691        ) where
692            C: Conversation<'a> + Sync,
693        {
694            let mut identities = conversation.get_device_identities(client_ids).await.unwrap();
695
696            for (user_name, status) in name_status.iter() {
697                let client_identity = identities.remove(
698                    identities
699                        .iter()
700                        .position(|i| i.x509_identity.as_ref().unwrap().display_name == user_name.to_string())
701                        .unwrap(),
702                );
703                assert_eq!(client_identity.status, *status);
704            }
705            assert!(identities.is_empty());
706
707            assert_eq!(
708                conversation.e2ei_conversation_state().await.unwrap(),
709                E2eiConversationState::NotVerified
710            );
711        }
712
713        #[macro_rules_attribute::apply(smol_macros::test)]
714        async fn should_read_device_identities() {
715            let case = TestContext::default_x509();
716
717            let [alice_android, alice_ios] = case.sessions().await;
718            Box::pin(async move {
719                let conversation = case.create_conversation([&alice_android, &alice_ios]).await;
720
721                let (android_id, ios_id) = (alice_android.get_client_id().await, alice_ios.get_client_id().await);
722
723                let mut android_ids = conversation
724                    .guard()
725                    .await
726                    .get_device_identities(&[android_id.clone(), ios_id.clone()])
727                    .await
728                    .unwrap();
729                android_ids.sort_by(|a, b| a.client_id.cmp(&b.client_id));
730                assert_eq!(android_ids.len(), 2);
731                let mut ios_ids = conversation
732                    .guard_of(&alice_ios)
733                    .await
734                    .get_device_identities(&[android_id.clone(), ios_id.clone()])
735                    .await
736                    .unwrap();
737                ios_ids.sort_by(|a, b| a.client_id.cmp(&b.client_id));
738                assert_eq!(ios_ids.len(), 2);
739
740                assert_eq!(android_ids, ios_ids);
741
742                let android_identities = conversation
743                    .guard()
744                    .await
745                    .get_device_identities(&[android_id])
746                    .await
747                    .unwrap();
748                let android_id = android_identities.first().unwrap();
749                assert_eq!(
750                    android_id.client_id.as_bytes(),
751                    alice_android.transaction.client_id().await.unwrap().0.as_slice()
752                );
753
754                let ios_identities = conversation
755                    .guard()
756                    .await
757                    .get_device_identities(&[ios_id])
758                    .await
759                    .unwrap();
760                let ios_id = ios_identities.first().unwrap();
761                assert_eq!(
762                    ios_id.client_id.as_bytes(),
763                    alice_ios.transaction.client_id().await.unwrap().0.as_slice()
764                );
765
766                let invalid = conversation.guard().await.get_device_identities(&[]).await;
767                assert!(matches!(invalid.unwrap_err(), Error::CallerError(_)));
768            })
769            .await
770        }
771
772        #[macro_rules_attribute::apply(smol_macros::test)]
773        async fn should_read_revoked_device_cross_signed() {
774            let case = TestContext::default_x509();
775            let alice_user_id = uuid::Uuid::new_v4();
776            let bob_user_id = uuid::Uuid::new_v4();
777            let rupert_user_id = uuid::Uuid::new_v4();
778            let john_user_id = uuid::Uuid::new_v4();
779            let dilbert_user_id = uuid::Uuid::new_v4();
780
781            let [alice_client_id] = case.x509_client_ids_for_user(&alice_user_id);
782            let [bob_client_id] = case.x509_client_ids_for_user(&bob_user_id);
783            let [rupert_client_id] = case.x509_client_ids_for_user(&rupert_user_id);
784            let [john_client_id] = case.x509_client_ids_for_user(&john_user_id);
785            let [dilbert_client_id] = case.x509_client_ids_for_user(&dilbert_user_id);
786
787            let sessions = case
788                .sessions_x509_cross_signed_with_client_ids_and_revocation(
789                    [alice_client_id, bob_client_id, rupert_client_id],
790                    [john_client_id, dilbert_client_id],
791                    &[dilbert_user_id.to_string(), rupert_user_id.to_string()],
792                )
793                .await;
794
795            Box::pin(async move {
796                let ([alice, bob, rupert], [john, dilbert]) = &sessions;
797                let mut sessions = sessions.0.iter().chain(sessions.1.iter());
798                let conversation = case.create_conversation(&mut sessions).await;
799
800                let (alice_id, bob_id, rupert_id, john_id, dilbert_id) = (
801                    alice.get_client_id().await,
802                    bob.get_client_id().await,
803                    rupert.get_client_id().await,
804                    john.get_client_id().await,
805                    dilbert.get_client_id().await,
806                );
807
808                let client_ids = [alice_id, bob_id, rupert_id, john_id, dilbert_id];
809                let name_status = [
810                    (alice_user_id, DeviceStatus::Valid),
811                    (bob_user_id, DeviceStatus::Valid),
812                    (rupert_user_id, DeviceStatus::Revoked),
813                    (john_user_id, DeviceStatus::Valid),
814                    (dilbert_user_id, DeviceStatus::Revoked),
815                ];
816                // Do it a multiple times to avoid WPB-6904 happening again
817                for _ in 0..2 {
818                    for session in sessions.clone() {
819                        let conversation = conversation.guard_of(session).await;
820                        check_identities_device_status(&conversation, &client_ids, &name_status).await;
821                    }
822                }
823            })
824            .await
825        }
826
827        #[macro_rules_attribute::apply(smol_macros::test)]
828        async fn should_read_revoked_device() {
829            let case = TestContext::default_x509();
830            let rupert_user_id = uuid::Uuid::new_v4();
831            let bob_user_id = uuid::Uuid::new_v4();
832            let alice_user_id = uuid::Uuid::new_v4();
833
834            let [rupert_client_id] = case.x509_client_ids_for_user(&rupert_user_id);
835            let [alice_client_id] = case.x509_client_ids_for_user(&alice_user_id);
836            let [bob_client_id] = case.x509_client_ids_for_user(&bob_user_id);
837
838            let sessions = case
839                .sessions_x509_with_client_ids_and_revocation(
840                    [alice_client_id.clone(), bob_client_id.clone(), rupert_client_id.clone()],
841                    &[rupert_user_id.to_string()],
842                )
843                .await;
844
845            Box::pin(async move {
846                let [alice, bob, rupert] = &sessions;
847                let conversation = case.create_conversation(&sessions).await;
848
849                let (alice_id, bob_id, rupert_id) = (
850                    alice.get_client_id().await,
851                    bob.get_client_id().await,
852                    rupert.get_client_id().await,
853                );
854
855                let client_ids = [alice_id, bob_id, rupert_id];
856                let name_status = [
857                    (alice_user_id, DeviceStatus::Valid),
858                    (bob_user_id, DeviceStatus::Valid),
859                    (rupert_user_id, DeviceStatus::Revoked),
860                ];
861
862                // Do it a multiple times to avoid WPB-6904 happening again
863                for _ in 0..2 {
864                    for session in sessions.iter() {
865                        let conversation = conversation.guard_of(session).await;
866                        check_identities_device_status(&conversation, &client_ids, &name_status).await;
867                    }
868                }
869            })
870            .await
871        }
872
873        #[macro_rules_attribute::apply(smol_macros::test)]
874        async fn should_not_fail_when_basic() {
875            let case = TestContext::default();
876
877            let [alice_android, alice_ios] = case.sessions().await;
878            Box::pin(async move {
879                let conversation = case.create_conversation([&alice_android, &alice_ios]).await;
880
881                let (android_id, ios_id) = (alice_android.get_client_id().await, alice_ios.get_client_id().await);
882
883                let mut android_ids = conversation
884                    .guard()
885                    .await
886                    .get_device_identities(&[android_id.clone(), ios_id.clone()])
887                    .await
888                    .unwrap();
889                android_ids.sort();
890
891                let mut ios_ids = conversation
892                    .guard_of(&alice_ios)
893                    .await
894                    .get_device_identities(&[android_id, ios_id])
895                    .await
896                    .unwrap();
897                ios_ids.sort();
898
899                assert_eq!(ios_ids.len(), 2);
900                assert_eq!(ios_ids, android_ids);
901
902                assert!(ios_ids.iter().all(|i| {
903                    matches!(i.credential_type, MlsCredentialType::Basic)
904                        && matches!(i.status, DeviceStatus::Valid)
905                        && i.x509_identity.is_none()
906                        && !i.thumbprint.is_empty()
907                        && !i.client_id.is_empty()
908                }));
909            })
910            .await
911        }
912
913        // this test is a duplicate of its counterpart but taking federation into account
914        // The heavy lifting of cross-signing the certificates is being done by the test utils.
915        #[macro_rules_attribute::apply(smol_macros::test)]
916        async fn should_read_users_cross_signed() {
917            let case = TestContext::default_x509();
918            let [alice_1_id, alice_2_id] = case.x509_client_ids_for_user(&uuid::Uuid::new_v4());
919            let [federated_alice_1_id, federated_alice_2_id] = case.x509_client_ids_for_user(&uuid::Uuid::new_v4());
920            let [bob_id, federated_bob_id] = case.x509_client_ids();
921
922            let ([alice_1, alice_2, bob], [federated_alice_1, federated_alice_2, federated_bob]) = case
923                .sessions_x509_cross_signed_with_client_ids(
924                    [alice_1_id, alice_2_id, bob_id],
925                    [federated_alice_1_id, federated_alice_2_id, federated_bob_id],
926                )
927                .await;
928            Box::pin(async move {
929                let sessions = [
930                    &alice_1,
931                    &alice_2,
932                    &bob,
933                    &federated_bob,
934                    &federated_alice_1,
935                    &federated_alice_2,
936                ];
937                let conversation = case.create_conversation(sessions).await;
938
939                let nb_members = conversation.member_count().await;
940                assert_eq!(nb_members, 6);
941                let conversation_guard = conversation.guard().await;
942
943                assert_eq!(alice_1.get_user_id().await, alice_2.get_user_id().await);
944
945                let alicem_user_id = federated_alice_2.get_user_id().await;
946                let bobt_user_id = federated_bob.get_user_id().await;
947
948                // Finds both Alice's devices
949                let alice_user_id = alice_1.get_user_id().await;
950                let alice_identities = conversation_guard
951                    .get_user_identities(std::slice::from_ref(&alice_user_id))
952                    .await
953                    .unwrap();
954                assert_eq!(alice_identities.len(), 1);
955                let identities = alice_identities.get(&alice_user_id).unwrap();
956                assert_eq!(identities.len(), 2);
957
958                // Finds Bob only device
959                let bob_user_id = bob.get_user_id().await;
960                let bob_identities = conversation_guard
961                    .get_user_identities(std::slice::from_ref(&bob_user_id))
962                    .await
963                    .unwrap();
964                assert_eq!(bob_identities.len(), 1);
965                let identities = bob_identities.get(&bob_user_id).unwrap();
966                assert_eq!(identities.len(), 1);
967
968                // Finds all devices
969                let user_ids = [alice_user_id, bob_user_id, alicem_user_id, bobt_user_id];
970                let expected_sizes = [2, 1, 2, 1];
971
972                for session in sessions {
973                    all_identities_check(&conversation.guard_of(session).await, &user_ids, expected_sizes).await;
974                }
975            })
976            .await
977        }
978
979        #[macro_rules_attribute::apply(smol_macros::test)]
980        async fn should_read_users() {
981            let case = TestContext::default_x509();
982            let [alice_android, alice_ios] = case.x509_client_ids_for_user(&uuid::Uuid::new_v4());
983            let [bob_android] = case.x509_client_ids();
984
985            let sessions = case
986                .sessions_x509_with_client_ids([alice_android, alice_ios, bob_android])
987                .await;
988
989            Box::pin(async move {
990                let conversation = case.create_conversation(&sessions).await;
991
992                let nb_members = conversation.member_count().await;
993                assert_eq!(nb_members, 3);
994
995                let [alice_android, alice_ios, bob_android] = &sessions;
996                assert_eq!(alice_android.get_user_id().await, alice_ios.get_user_id().await);
997
998                // Finds both Alice's devices
999                let alice_user_id = alice_android.get_user_id().await;
1000                let alice_identities = conversation
1001                    .guard()
1002                    .await
1003                    .get_user_identities(std::slice::from_ref(&alice_user_id))
1004                    .await
1005                    .unwrap();
1006                assert_eq!(alice_identities.len(), 1);
1007                let identities = alice_identities.get(&alice_user_id).unwrap();
1008                assert_eq!(identities.len(), 2);
1009
1010                // Finds Bob only device
1011                let bob_user_id = bob_android.get_user_id().await;
1012                let bob_identities = conversation
1013                    .guard()
1014                    .await
1015                    .get_user_identities(std::slice::from_ref(&bob_user_id))
1016                    .await
1017                    .unwrap();
1018                assert_eq!(bob_identities.len(), 1);
1019                let identities = bob_identities.get(&bob_user_id).unwrap();
1020                assert_eq!(identities.len(), 1);
1021
1022                let user_ids = [alice_user_id, bob_user_id];
1023                let expected_sizes = [2, 1];
1024
1025                for session in &sessions {
1026                    all_identities_check(&conversation.guard_of(session).await, &user_ids, expected_sizes).await;
1027                }
1028            })
1029            .await
1030        }
1031
1032        #[macro_rules_attribute::apply(smol_macros::test)]
1033        async fn should_exchange_messages_cross_signed() {
1034            let case = TestContext::default_x509();
1035            let sessions = case.sessions_x509_cross_signed::<3, 3>().await;
1036            Box::pin(async move {
1037                let sessions = sessions.0.iter().chain(sessions.1.iter());
1038                let conversation = case.create_conversation(sessions.clone()).await;
1039
1040                assert_eq!(conversation.member_count().await, 6);
1041
1042                assert!(conversation.is_functional_and_contains(sessions).await);
1043            })
1044            .await;
1045        }
1046    }
1047
1048    mod export_secret {
1049        use openmls::prelude::ExportSecretError;
1050
1051        use super::*;
1052        use crate::MlsErrorKind;
1053
1054        #[apply(all_cred_cipher)]
1055        pub async fn can_export_secret_key(case: TestContext) {
1056            let [alice] = case.sessions().await;
1057            Box::pin(async move {
1058                let conversation = case.create_conversation([&alice]).await;
1059
1060                let key_length = 128;
1061                let result = conversation.guard().await.export_secret_key(key_length).await;
1062                assert!(result.is_ok());
1063                assert_eq!(result.unwrap().len(), key_length);
1064            })
1065            .await
1066        }
1067
1068        #[apply(all_cred_cipher)]
1069        pub async fn cannot_export_secret_key_invalid_length(case: TestContext) {
1070            let [alice] = case.sessions().await;
1071            Box::pin(async move {
1072                let conversation = case.create_conversation([&alice]).await;
1073
1074                let result = conversation.guard().await.export_secret_key(usize::MAX).await;
1075                let error = result.unwrap_err();
1076                assert!(innermost_source_matches!(
1077                    error,
1078                    MlsErrorKind::MlsExportSecretError(ExportSecretError::KeyLengthTooLong)
1079                ));
1080            })
1081            .await
1082        }
1083    }
1084
1085    mod get_client_ids {
1086        use super::*;
1087
1088        #[apply(all_cred_cipher)]
1089        pub async fn can_get_client_ids(case: TestContext) {
1090            let [alice, bob] = case.sessions().await;
1091            Box::pin(async move {
1092                let conversation = case.create_conversation([&alice]).await;
1093
1094                assert_eq!(conversation.guard().await.get_client_ids().await.len(), 1);
1095
1096                let conversation = conversation.invite_notify([&bob]).await;
1097
1098                assert_eq!(conversation.guard().await.get_client_ids().await.len(), 2);
1099            })
1100            .await
1101        }
1102    }
1103
1104    mod external_sender {
1105        use super::*;
1106
1107        #[apply(all_cred_cipher)]
1108        pub async fn should_fetch_ext_sender(mut case: TestContext) {
1109            let [alice, external_sender] = case.sessions().await;
1110            Box::pin(async move {
1111                let conversation = case
1112                    .create_conversation_with_external_sender(&external_sender, [&alice])
1113                    .await;
1114
1115                let alice_ext_sender = conversation.guard().await.get_external_sender().await.unwrap();
1116                assert!(!alice_ext_sender.is_empty());
1117                assert_eq!(
1118                    alice_ext_sender,
1119                    external_sender.client_signature_key(&case).await.as_slice().to_vec()
1120                );
1121            })
1122            .await
1123        }
1124    }
1125}