core_crypto/mls/
proposal.rs

1// Wire
2// Copyright (C) 2022 Wire Swiss GmbH
3
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8
9// This program is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU General Public License for more details.
13
14// You should have received a copy of the GNU General Public License
15// along with this program. If not, see http://www.gnu.org/licenses/.
16
17use openmls::prelude::{hash_ref::ProposalRef, KeyPackage};
18
19use mls_crypto_provider::MlsCryptoProvider;
20
21use crate::{
22    mls::{ClientId, ConversationId, MlsConversation},
23    prelude::{Client, CryptoError, CryptoResult, MlsProposalBundle},
24};
25
26use crate::context::CentralContext;
27
28/// Abstraction over a [openmls::prelude::hash_ref::ProposalRef] to deal with conversions
29#[derive(Debug, Clone, Eq, PartialEq, derive_more::From, derive_more::Deref, derive_more::Display)]
30pub struct MlsProposalRef(ProposalRef);
31
32impl From<Vec<u8>> for MlsProposalRef {
33    fn from(value: Vec<u8>) -> Self {
34        Self(ProposalRef::from_slice(value.as_slice()))
35    }
36}
37
38impl MlsProposalRef {
39    /// Duh
40    pub fn into_inner(self) -> ProposalRef {
41        self.0
42    }
43
44    pub(crate) fn to_bytes(&self) -> Vec<u8> {
45        self.0.as_slice().to_vec()
46    }
47}
48
#[cfg(test)]
impl From<MlsProposalRef> for Vec<u8> {
    /// Test-only conversion of a proposal reference into an owned copy of its bytes
    fn from(prop_ref: MlsProposalRef) -> Self {
        // Same slice copy as before, routed through the inherent helper.
        prop_ref.to_bytes()
    }
}
55
56/// Internal representation of proposal to ease further additions
57// To solve the clippy issue we'd need to box the `KeyPackage`, but we can't because we need an
58// owned value of it. We can have it when Box::into_inner is stablized.
59// https://github.com/rust-lang/rust/issues/80437
60#[allow(clippy::large_enum_variant)]
61pub enum MlsProposal {
62    /// Requests that a client with a specified KeyPackage be added to the group
63    Add(KeyPackage),
64    /// Similar mechanism to Add with the distinction that it replaces
65    /// the sender's LeafNode in the tree instead of adding a new leaf to the tree
66    Update,
67    /// Requests that the member with LeafNodeRef removed be removed from the group
68    Remove(ClientId),
69}
70
71impl MlsProposal {
72    /// Creates a new proposal within the specified `MlsGroup`
73    async fn create(
74        self,
75        client: &Client,
76        backend: &MlsCryptoProvider,
77        mut conversation: impl std::ops::DerefMut<Target = MlsConversation>,
78    ) -> CryptoResult<MlsProposalBundle> {
79        let proposal = match self {
80            MlsProposal::Add(key_package) => {
81                (*conversation)
82                    .propose_add_member(client, backend, key_package.into())
83                    .await
84            }
85            MlsProposal::Update => (*conversation).propose_self_update(client, backend).await,
86            MlsProposal::Remove(client_id) => {
87                let index = conversation
88                    .group
89                    .members()
90                    .find(|kp| kp.credential.identity() == client_id.as_slice())
91                    .ok_or(CryptoError::ClientNotFound(client_id))
92                    .map(|kp| kp.index)?;
93                (*conversation).propose_remove_member(client, backend, index).await
94            }
95        }?;
96        Ok(proposal)
97    }
98}
99
100impl CentralContext {
101    /// Creates a new Add proposal
102    #[cfg_attr(test, crate::idempotent)]
103    pub async fn new_add_proposal(
104        &self,
105        id: &ConversationId,
106        key_package: KeyPackage,
107    ) -> CryptoResult<MlsProposalBundle> {
108        self.new_proposal(id, MlsProposal::Add(key_package)).await
109    }
110
111    /// Creates a new Add proposal
112    #[cfg_attr(test, crate::idempotent)]
113    pub async fn new_remove_proposal(
114        &self,
115        id: &ConversationId,
116        client_id: ClientId,
117    ) -> CryptoResult<MlsProposalBundle> {
118        self.new_proposal(id, MlsProposal::Remove(client_id)).await
119    }
120
121    /// Creates a new Add proposal
122    #[cfg_attr(test, crate::dispotent)]
123    pub async fn new_update_proposal(&self, id: &ConversationId) -> CryptoResult<MlsProposalBundle> {
124        self.new_proposal(id, MlsProposal::Update).await
125    }
126
127    /// Creates a new proposal within a group
128    ///
129    /// # Arguments
130    /// * `conversation` - the group/conversation id
131    /// * `proposal` - the proposal do be added in the group
132    ///
133    /// # Return type
134    /// A [MlsProposalBundle] with the proposal in a Mls message and a reference to that proposal in order to rollback it if required
135    ///
136    /// # Errors
137    /// If the conversation is not found, an error will be returned. Errors from OpenMls can be
138    /// returned as well, when for example there's a commit pending to be merged
139    async fn new_proposal(&self, id: &ConversationId, proposal: MlsProposal) -> CryptoResult<MlsProposalBundle> {
140        let conversation = self.get_conversation(id).await?;
141        let client = &self.mls_client().await?;
142        proposal
143            .create(client, &self.mls_provider().await?, conversation.write().await)
144            .await
145    }
146}
147
#[cfg(test)]
mod tests {
    use wasm_bindgen_test::*;

    use crate::{prelude::MlsCommitBundle, prelude::*, test_utils::*};

    wasm_bindgen_test_configure!(run_in_browser);

    mod add {
        use super::*;

        /// Alice proposes to add Bob, commits the pending proposal, and Bob joins through the
        /// resulting welcome message.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        pub async fn should_add_member(case: TestCase) {
            run_test_with_client_ids(case.clone(), ["alice", "bob"], move |[alice_central, bob_central]| {
                Box::pin(async move {
                    let id = conversation_id();
                    alice_central
                        .context
                        .new_conversation(&id, case.credential_type, case.cfg.clone())
                        .await
                        .unwrap();
                    let bob_kp = bob_central.get_one_key_package(&case).await;
                    alice_central.context.new_add_proposal(&id, bob_kp).await.unwrap();
                    let MlsCommitBundle { welcome, .. } = alice_central
                        .context
                        .commit_pending_proposals(&id)
                        .await
                        .unwrap()
                        .unwrap();
                    alice_central.context.commit_accepted(&id).await.unwrap();
                    assert_eq!(alice_central.get_conversation_unchecked(&id).await.members().len(), 2);
                    let new_id = bob_central
                        .context
                        .process_welcome_message(welcome.unwrap().into(), case.custom_cfg())
                        .await
                        .unwrap()
                        .id;
                    assert_eq!(id, new_id);
                    assert!(bob_central.try_talk_to(&id, &alice_central).await.is_ok());
                })
            })
            .await
        }
    }

    mod update {
        use super::*;

        /// After an update proposal has been committed, the first encryption key of the
        /// conversation must differ from the one observed before.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        pub async fn should_update_hpke_key(case: TestCase) {
            run_test_with_central(case.clone(), move |[central]| {
                Box::pin(async move {
                    let id = conversation_id();
                    central
                        .context
                        .new_conversation(&id, case.credential_type, case.cfg.clone())
                        .await
                        .unwrap();
                    // `next()` yields the first key, exactly what `find_or_first(|_| true)` did.
                    let before = central
                        .get_conversation_unchecked(&id)
                        .await
                        .encryption_keys()
                        .next()
                        .unwrap();
                    central.context.new_update_proposal(&id).await.unwrap();
                    central.context.commit_pending_proposals(&id).await.unwrap();
                    central.context.commit_accepted(&id).await.unwrap();
                    let after = central
                        .get_conversation_unchecked(&id)
                        .await
                        .encryption_keys()
                        .next()
                        .unwrap();
                    assert_ne!(before, after)
                })
            })
            .await
        }
    }

    mod remove {
        use super::*;

        /// Alice proposes to remove Bob and commits; once Bob has processed both the proposal
        /// and the commit, he no longer has the conversation.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        pub async fn should_remove_member(case: TestCase) {
            run_test_with_client_ids(case.clone(), ["alice", "bob"], move |[alice_central, bob_central]| {
                Box::pin(async move {
                    let id = conversation_id();
                    alice_central
                        .context
                        .new_conversation(&id, case.credential_type, case.cfg.clone())
                        .await
                        .unwrap();
                    alice_central.invite_all(&case, &id, [&bob_central]).await.unwrap();
                    assert_eq!(alice_central.get_conversation_unchecked(&id).await.members().len(), 2);
                    assert_eq!(bob_central.get_conversation_unchecked(&id).await.members().len(), 2);

                    let remove_proposal = alice_central
                        .context
                        .new_remove_proposal(&id, bob_central.get_client_id().await)
                        .await
                        .unwrap();
                    bob_central
                        .context
                        .decrypt_message(&id, remove_proposal.proposal.to_bytes().unwrap())
                        .await
                        .unwrap();
                    let MlsCommitBundle { commit, .. } = alice_central
                        .context
                        .commit_pending_proposals(&id)
                        .await
                        .unwrap()
                        .unwrap();
                    alice_central.context.commit_accepted(&id).await.unwrap();
                    assert_eq!(alice_central.get_conversation_unchecked(&id).await.members().len(), 1);

                    bob_central
                        .context
                        .decrypt_message(&id, commit.to_bytes().unwrap())
                        .await
                        .unwrap();
                    assert!(matches!(
                        bob_central.context.get_conversation(&id).await.unwrap_err(),
                        CryptoError::ConversationNotFound(conv_id) if conv_id == id
                    ));
                })
            })
            .await
        }

        /// Proposing the removal of a client that is not a member of the group must fail with
        /// `ClientNotFound` carrying the requested client id.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        pub async fn should_fail_when_unknown_client(case: TestCase) {
            run_test_with_client_ids(case.clone(), ["alice"], move |[alice_central]| {
                Box::pin(async move {
                    let id = conversation_id();
                    alice_central
                        .context
                        .new_conversation(&id, case.credential_type, case.cfg.clone())
                        .await
                        .unwrap();

                    let remove_proposal = alice_central
                        .context
                        .new_remove_proposal(&id, b"unknown"[..].into())
                        .await;
                    assert!(matches!(
                        remove_proposal.unwrap_err(),
                        CryptoError::ClientNotFound(client_id) if client_id == b"unknown"[..].into()
                    ));
                })
            })
            .await
        }
    }
}