core_crypto/mls/
proposal.rs

1// Wire
2// Copyright (C) 2022 Wire Swiss GmbH
3
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8
9// This program is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU General Public License for more details.
13
14// You should have received a copy of the GNU General Public License
15// along with this program. If not, see http://www.gnu.org/licenses/.
16
17use openmls::prelude::{KeyPackage, hash_ref::ProposalRef};
18
19use mls_crypto_provider::MlsCryptoProvider;
20
21use super::{Error, Result};
22use crate::{
23    RecursiveError,
24    mls::{ClientId, ConversationId, MlsConversation},
25    prelude::{Client, MlsProposalBundle},
26};
27
28use crate::context::CentralContext;
29
30/// Abstraction over a [openmls::prelude::hash_ref::ProposalRef] to deal with conversions
31#[derive(Debug, Clone, Eq, PartialEq, derive_more::From, derive_more::Deref, derive_more::Display)]
32pub struct MlsProposalRef(ProposalRef);
33
34impl From<Vec<u8>> for MlsProposalRef {
35    fn from(value: Vec<u8>) -> Self {
36        Self(ProposalRef::from_slice(value.as_slice()))
37    }
38}
39
40impl MlsProposalRef {
41    /// Duh
42    pub fn into_inner(self) -> ProposalRef {
43        self.0
44    }
45
46    pub(crate) fn to_bytes(&self) -> Vec<u8> {
47        self.0.as_slice().to_vec()
48    }
49}
50
/// Test-only conversion of a proposal reference back into its raw bytes.
#[cfg(test)]
impl From<MlsProposalRef> for Vec<u8> {
    fn from(proposal_ref: MlsProposalRef) -> Self {
        // `to_bytes` performs exactly `self.0.as_slice().to_vec()`.
        proposal_ref.to_bytes()
    }
}
57
58/// Internal representation of proposal to ease further additions
59// To solve the clippy issue we'd need to box the `KeyPackage`, but we can't because we need an
60// owned value of it. We can have it when Box::into_inner is stablized.
61// https://github.com/rust-lang/rust/issues/80437
62#[allow(clippy::large_enum_variant)]
63pub enum MlsProposal {
64    /// Requests that a client with a specified KeyPackage be added to the group
65    Add(KeyPackage),
66    /// Similar mechanism to Add with the distinction that it replaces
67    /// the sender's LeafNode in the tree instead of adding a new leaf to the tree
68    Update,
69    /// Requests that the member with LeafNodeRef removed be removed from the group
70    Remove(ClientId),
71}
72
73impl MlsProposal {
74    /// Creates a new proposal within the specified `MlsGroup`
75    async fn create(
76        self,
77        client: &Client,
78        backend: &MlsCryptoProvider,
79        mut conversation: impl std::ops::DerefMut<Target = MlsConversation>,
80    ) -> Result<MlsProposalBundle> {
81        let proposal = match self {
82            MlsProposal::Add(key_package) => (*conversation)
83                .propose_add_member(client, backend, key_package.into())
84                .await
85                .map_err(RecursiveError::mls_conversation("proposing to add member"))?,
86            MlsProposal::Update => (*conversation)
87                .propose_self_update(client, backend)
88                .await
89                .map_err(RecursiveError::mls_conversation("proposing self update"))?,
90            MlsProposal::Remove(client_id) => {
91                let index = conversation
92                    .group
93                    .members()
94                    .find(|kp| kp.credential.identity() == client_id.as_slice())
95                    .ok_or(Error::ClientNotFound(client_id))
96                    .map(|kp| kp.index)?;
97                (*conversation)
98                    .propose_remove_member(client, backend, index)
99                    .await
100                    .map_err(RecursiveError::mls_conversation("proposing to remove member"))?
101            }
102        };
103        Ok(proposal)
104    }
105}
106
107impl CentralContext {
108    /// Creates a new Add proposal
109    #[cfg_attr(test, crate::idempotent)]
110    pub async fn new_add_proposal(&self, id: &ConversationId, key_package: KeyPackage) -> Result<MlsProposalBundle> {
111        self.new_proposal(id, MlsProposal::Add(key_package)).await
112    }
113
114    /// Creates a new Add proposal
115    #[cfg_attr(test, crate::idempotent)]
116    pub async fn new_remove_proposal(&self, id: &ConversationId, client_id: ClientId) -> Result<MlsProposalBundle> {
117        self.new_proposal(id, MlsProposal::Remove(client_id)).await
118    }
119
120    /// Creates a new Add proposal
121    #[cfg_attr(test, crate::dispotent)]
122    pub async fn new_update_proposal(&self, id: &ConversationId) -> Result<MlsProposalBundle> {
123        self.new_proposal(id, MlsProposal::Update).await
124    }
125
126    /// Creates a new proposal within a group
127    ///
128    /// # Arguments
129    /// * `conversation` - the group/conversation id
130    /// * `proposal` - the proposal do be added in the group
131    ///
132    /// # Return type
133    /// A [MlsProposalBundle] with the proposal in a Mls message and a reference to that proposal in order to rollback it if required
134    ///
135    /// # Errors
136    /// If the conversation is not found, an error will be returned. Errors from OpenMls can be
137    /// returned as well, when for example there's a commit pending to be merged
138    async fn new_proposal(&self, id: &ConversationId, proposal: MlsProposal) -> Result<MlsProposalBundle> {
139        let mut conversation = self
140            .conversation(id)
141            .await
142            .map_err(RecursiveError::mls_conversation("getting conversation by id"))?;
143        let client = &self
144            .mls_client()
145            .await
146            .map_err(RecursiveError::root("getting mls client"))?;
147        proposal
148            .create(
149                client,
150                &self
151                    .mls_provider()
152                    .await
153                    .map_err(RecursiveError::root("getting mls provider"))?,
154                conversation.conversation_mut().await,
155            )
156            .await
157    }
158}
159
#[cfg(test)]
mod tests {
    use wasm_bindgen_test::*;

    use crate::{prelude::*, test_utils::*};

    wasm_bindgen_test_configure!(run_in_browser);

    /// Tests for `CentralContext::new_add_proposal`.
    mod add {
        use super::*;

        /// Alice proposes adding Bob, commits her pending proposal, and Bob joins via the
        /// resulting Welcome; both can then exchange messages.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        pub async fn should_add_member(case: TestCase) {
            run_test_with_client_ids(case.clone(), ["alice", "bob"], move |[alice_central, bob_central]| {
                Box::pin(async move {
                    let id = conversation_id();
                    alice_central
                        .context
                        .new_conversation(&id, case.credential_type, case.cfg.clone())
                        .await
                        .unwrap();
                    let bob_kp = bob_central.get_one_key_package(&case).await;
                    alice_central.context.new_add_proposal(&id, bob_kp).await.unwrap();
                    // Committing the pending Add proposal should make Bob a member.
                    alice_central
                        .context
                        .conversation(&id)
                        .await
                        .unwrap()
                        .commit_pending_proposals()
                        .await
                        .unwrap();
                    // The commit is expected to have emitted a Welcome for Bob through the transport.
                    let welcome = alice_central.mls_transport.latest_welcome_message().await;
                    assert_eq!(alice_central.get_conversation_unchecked(&id).await.members().len(), 2);
                    let new_id = bob_central
                        .context
                        .process_welcome_message(welcome.into(), case.custom_cfg())
                        .await
                        .unwrap()
                        .id;
                    // Bob must end up in the very same conversation.
                    assert_eq!(id, new_id);
                    assert!(bob_central.try_talk_to(&id, &alice_central).await.is_ok());
                })
            })
            .await
        }
    }

    /// Tests for `CentralContext::new_update_proposal`.
    mod update {
        use super::*;
        use itertools::Itertools;

        /// Committing a self-update proposal must change the client's encryption key.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        pub async fn should_update_hpke_key(case: TestCase) {
            run_test_with_central(case.clone(), move |[central]| {
                Box::pin(async move {
                    let id = conversation_id();
                    central
                        .context
                        .new_conversation(&id, case.credential_type, case.cfg.clone())
                        .await
                        .unwrap();
                    // Snapshot the first encryption key before the update.
                    let before = central
                        .get_conversation_unchecked(&id)
                        .await
                        .encryption_keys()
                        .find_or_first(|_| true)
                        .unwrap();
                    central.context.new_update_proposal(&id).await.unwrap();
                    central
                        .context
                        .conversation(&id)
                        .await
                        .unwrap()
                        .commit_pending_proposals()
                        .await
                        .unwrap();
                    // Same lookup after the commit; the key must differ.
                    let after = central
                        .get_conversation_unchecked(&id)
                        .await
                        .encryption_keys()
                        .find_or_first(|_| true)
                        .unwrap();
                    assert_ne!(before, after)
                })
            })
            .await
        }
    }

    /// Tests for `CentralContext::new_remove_proposal`.
    mod remove {
        use super::*;

        /// Alice proposes removing Bob and commits it. Once Bob processes both the proposal and
        /// the commit, he can no longer find the conversation.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        pub async fn should_remove_member(case: TestCase) {
            use crate::mls;

            run_test_with_client_ids(case.clone(), ["alice", "bob"], |[alice_central, bob_central]| {
                Box::pin(async move {
                    let id = conversation_id();
                    alice_central
                        .context
                        .new_conversation(&id, case.credential_type, case.cfg.clone())
                        .await
                        .unwrap();
                    alice_central.invite_all(&case, &id, [&bob_central]).await.unwrap();
                    assert_eq!(alice_central.get_conversation_unchecked(&id).await.members().len(), 2);
                    assert_eq!(bob_central.get_conversation_unchecked(&id).await.members().len(), 2);

                    let remove_proposal = alice_central
                        .context
                        .new_remove_proposal(&id, bob_central.get_client_id().await)
                        .await
                        .unwrap();
                    // Bob receives the proposal before the commit that references it.
                    bob_central
                        .context
                        .conversation(&id)
                        .await
                        .unwrap()
                        .decrypt_message(remove_proposal.proposal.to_bytes().unwrap())
                        .await
                        .unwrap();
                    alice_central
                        .context
                        .conversation(&id)
                        .await
                        .unwrap()
                        .commit_pending_proposals()
                        .await
                        .unwrap();
                    assert_eq!(alice_central.get_conversation_unchecked(&id).await.members().len(), 1);

                    let commit = alice_central.mls_transport.latest_commit().await;
                    bob_central
                        .context
                        .conversation(&id)
                        .await
                        .unwrap()
                        .decrypt_message(commit.to_bytes().unwrap())
                        .await
                        .unwrap();
                    // After processing his own removal, Bob's lookup of the conversation fails.
                    assert!(matches!(
                        bob_central.context.conversation(&id).await.unwrap_err(),
                        mls::conversation::Error::Leaf(LeafError::ConversationNotFound(conv_id)) if conv_id == id
                    ));
                })
            })
            .await
        }

        /// Proposing to remove a client id that is not a member yields `ClientNotFound`
        /// carrying that id.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        pub async fn should_fail_when_unknown_client(case: TestCase) {
            use crate::mls;

            run_test_with_client_ids(case.clone(), ["alice"], move |[alice_central]| {
                Box::pin(async move {
                    let id = conversation_id();
                    alice_central
                        .context
                        .new_conversation(&id, case.credential_type, case.cfg.clone())
                        .await
                        .unwrap();

                    let remove_proposal = alice_central
                        .context
                        .new_remove_proposal(&id, b"unknown"[..].into())
                        .await;
                    assert!(matches!(
                        remove_proposal.unwrap_err(),
                        mls::Error::ClientNotFound(client_id) if client_id == b"unknown"[..].into()
                    ));
                })
            })
            .await
        }
    }
}