core_crypto/transaction_context/conversation/
welcome.rs1use std::borrow::BorrowMut as _;
4
5use openmls::prelude::{MlsMessageIn, MlsMessageInBody};
6use tls_codec::Deserialize as _;
7
8use super::{Error, Result, TransactionContext};
9use crate::{ConversationId, MlsConversation, MlsConversationConfiguration, RecursiveError};
10
11impl TransactionContext {
12 #[cfg_attr(test, crate::dispotent)]
24 pub async fn process_raw_welcome_message(&self, welcome: &[u8]) -> Result<ConversationId> {
25 let mut cursor = std::io::Cursor::new(welcome);
26 let welcome =
27 MlsMessageIn::tls_deserialize(&mut cursor).map_err(Error::tls_deserialize("mls message in (welcome)"))?;
28 self.process_welcome_message(welcome).await
29 }
30
31 #[cfg_attr(test, crate::dispotent)]
44 pub async fn process_welcome_message(&self, welcome: MlsMessageIn) -> Result<ConversationId> {
45 let database = &self.database().await?;
46 let MlsMessageInBody::Welcome(welcome) = welcome.extract() else {
47 return Err(Error::CallerError(
48 "the message provided to process_welcome_message was not a welcome message",
49 ));
50 };
51 let cs = welcome.ciphersuite().into();
52 let configuration = MlsConversationConfiguration {
53 ciphersuite: cs,
54 ..Default::default()
55 };
56 let mls_provider = self
57 .mls_provider()
58 .await
59 .map_err(RecursiveError::transaction("getting mls provider"))?;
60 let mut mls_groups = self
61 .mls_groups()
62 .await
63 .map_err(RecursiveError::transaction("getting mls groups"))?;
64 let conversation = MlsConversation::from_welcome_message(
65 welcome,
66 configuration,
67 &mls_provider,
68 database,
69 mls_groups.borrow_mut(),
70 )
71 .await
72 .map_err(RecursiveError::mls_conversation("creating conversation from welcome"))?;
73
74 let id = conversation.id.clone();
75 mls_groups.insert(&id, conversation);
76
77 Ok(id)
78 }
79}
80
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::*;

    /// After bob joins through the invite, his local store must hold exactly one fewer
    /// key package, HPKE private key and encryption keypair.
    #[apply(all_cred_cipher)]
    async fn joining_from_welcome_should_prune_local_key_material(case: TestContext) {
        let [alice, bob] = case.sessions().await;
        Box::pin(async move {
            let invite = case.create_conversation([&alice]).await.invite([&bob]).await;

            // Snapshot bob's stored entities before the invite is delivered...
            let before = bob.transaction.count_entities().await;
            // NOTE(review): bob is expected to process the welcome during this call.
            invite.notify_members().await;
            // ...and again once it has been delivered.
            let after = bob.transaction.count_entities().await;

            assert_eq!(after.key_package, before.key_package - 1);
            assert_eq!(after.hpke_private_key, before.hpke_private_key - 1);
            assert_eq!(after.encryption_keypair, before.encryption_keypair - 1);
        })
        .await;
    }

    /// A welcome for a conversation id that already exists locally must be rejected with
    /// `ConversationAlreadyExists`, and the error must carry that id.
    #[apply(all_cred_cipher)]
    async fn process_welcome_should_fail_when_already_exists(case: TestContext) {
        use crate::LeafError;

        let [alice, bob] = case.sessions().await;
        Box::pin(async move {
            let bob_credential = &bob.initial_credential;
            let invite = case.create_conversation([&alice]).await.invite([&bob]).await;
            let conversation = invite.conversation();
            let id = conversation.id().clone();

            // Bob creates a conversation with the same id before he sees the welcome.
            bob.transaction
                .new_conversation(&id, bob_credential, case.cfg.clone())
                .await
                .unwrap();

            let welcome = conversation.transport().await.latest_welcome_message().await;
            let error = bob
                .transaction
                .process_welcome_message(welcome.into())
                .await
                .unwrap_err();

            let is_already_exists = matches!(
                error,
                Error::Recursive(crate::RecursiveError::MlsConversation { source, .. })
                    if matches!(
                        *source,
                        crate::mls::conversation::Error::Leaf(
                            LeafError::ConversationAlreadyExists(ref existing)
                        ) if existing == &id
                    )
            );
            assert!(is_already_exists);
        })
        .await;
    }
}