core_crypto/transaction_context/conversation/welcome.rs

//! This module contains transactional conversation operations related to processing welcome messages.
2
3use std::borrow::BorrowMut as _;
4
5use openmls::prelude::{MlsMessageIn, MlsMessageInBody};
6use tls_codec::Deserialize as _;
7
8use super::{Error, Result, TransactionContext};
9use crate::{ConversationId, MlsConversation, MlsConversationConfiguration, RecursiveError};
10
11impl TransactionContext {
12    /// Create a conversation from a TLS serialized MLS Welcome message. The `MlsConversationConfiguration` used in this
13    /// function will be the default implementation.
14    ///
15    /// # Arguments
16    /// * `welcome` - a TLS serialized welcome message
17    ///
18    /// # Return type
19    /// This function will return the conversation/group id
20    ///
21    /// # Errors
22    /// see [TransactionContext::process_welcome_message]
23    #[cfg_attr(test, crate::dispotent)]
24    pub async fn process_raw_welcome_message(&self, welcome: &[u8]) -> Result<ConversationId> {
25        let mut cursor = std::io::Cursor::new(welcome);
26        let welcome =
27            MlsMessageIn::tls_deserialize(&mut cursor).map_err(Error::tls_deserialize("mls message in (welcome)"))?;
28        self.process_welcome_message(welcome).await
29    }
30
31    /// Create a conversation from a received MLS Welcome message
32    ///
33    /// # Arguments
34    /// * `welcome` - a `Welcome` message received as a result of a commit adding new members to a group
35    ///
36    /// # Return type
37    /// This function will return the conversation/group id
38    ///
39    /// # Errors
40    /// Errors can be originating from the KeyStore of from OpenMls:
41    /// * if no [openmls::key_packages::KeyPackage] can be read from the KeyStore
42    /// * if the message can't be decrypted
43    #[cfg_attr(test, crate::dispotent)]
44    pub async fn process_welcome_message(&self, welcome: MlsMessageIn) -> Result<ConversationId> {
45        let database = &self.database().await?;
46        let MlsMessageInBody::Welcome(welcome) = welcome.extract() else {
47            return Err(Error::CallerError(
48                "the message provided to process_welcome_message was not a welcome message",
49            ));
50        };
51        let cs = welcome.ciphersuite().into();
52        let configuration = MlsConversationConfiguration {
53            ciphersuite: cs,
54            ..Default::default()
55        };
56        let mls_provider = self
57            .mls_provider()
58            .await
59            .map_err(RecursiveError::transaction("getting mls provider"))?;
60        let mut mls_groups = self
61            .mls_groups()
62            .await
63            .map_err(RecursiveError::transaction("getting mls groups"))?;
64        let conversation = MlsConversation::from_welcome_message(
65            welcome,
66            configuration,
67            &mls_provider,
68            database,
69            mls_groups.borrow_mut(),
70        )
71        .await
72        .map_err(RecursiveError::mls_conversation("creating conversation from welcome"))?;
73
74        let id = conversation.id.clone();
75        mls_groups.insert(&id, conversation);
76
77        Ok(id)
78    }
79}
80
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::*;

    #[apply(all_cred_cipher)]
    async fn joining_from_welcome_should_prune_local_key_material(case: TestContext) {
        let [alice, bob] = case.sessions().await;
        Box::pin(async move {
            // Alice creates a conversation and invites Bob. The invite creates a key package for
            // Bob, so the baseline below has to be taken only after this call.
            let pending_commit = case.create_conversation([&alice]).await.invite([&bob]).await;

            // Snapshot of Bob's store before he handles the welcome.
            let before = bob.transaction.count_entities().await;

            // Delivering the commit makes Bob accept the welcome message, which should prune the
            // key package it consumed from the store.
            pending_commit.notify_members().await;

            // Exactly one key package bundle must be gone, since OpenMLS consumed it while
            // processing the Welcome.
            let after = bob.transaction.count_entities().await;
            assert_eq!(after.key_package, before.key_package - 1);
            assert_eq!(after.hpke_private_key, before.hpke_private_key - 1);
            assert_eq!(after.encryption_keypair, before.encryption_keypair - 1);
        })
        .await;
    }

    #[apply(all_cred_cipher)]
    async fn process_welcome_should_fail_when_already_exists(case: TestContext) {
        use crate::LeafError;

        let [alice, bob] = case.sessions().await;
        Box::pin(async move {
            let credential_ref = &bob.initial_credential;
            let commit = case.create_conversation([&alice]).await.invite([&bob]).await;
            let conversation = commit.conversation();
            let id = conversation.id().clone();

            // Before joining, Bob creates a conversation with the very id he is about to join.
            bob.transaction
                .new_conversation(&id, credential_ref, case.cfg.clone())
                .await
                .unwrap();

            let welcome = conversation.transport().await.latest_welcome_message().await;
            let error = bob
                .transaction
                .process_welcome_message(welcome.into())
                .await
                .unwrap_err();

            // The failure must be the "conversation already exists" leaf error for that same id.
            assert!(matches!(
                error,
                Error::Recursive(crate::RecursiveError::MlsConversation { source, .. })
                    if matches!(
                        *source,
                        crate::mls::conversation::Error::Leaf(LeafError::ConversationAlreadyExists(ref i))
                            if i == &id
                    )
            ));
        })
        .await;
    }
}