// core_crypto/transaction_context/conversation/proposal.rs

1use openmls::prelude::KeyPackage;
2
3use super::{Error, Result};
4use crate::{
5    ClientId, ConversationId, MlsProposal, MlsProposalBundle, RecursiveError, transaction_context::TransactionContext,
6};
7
8impl TransactionContext {
9    /// Creates a new Add proposal
10    #[cfg_attr(test, crate::idempotent)]
11    pub async fn new_add_proposal(&self, id: &ConversationId, key_package: KeyPackage) -> Result<MlsProposalBundle> {
12        self.new_proposal(id, MlsProposal::Add(key_package)).await
13    }
14
15    /// Creates a new Add proposal
16    #[cfg_attr(test, crate::idempotent)]
17    pub async fn new_remove_proposal(&self, id: &ConversationId, client_id: ClientId) -> Result<MlsProposalBundle> {
18        self.new_proposal(id, MlsProposal::Remove(client_id)).await
19    }
20
21    /// Creates a new Add proposal
22    #[cfg_attr(test, crate::dispotent)]
23    pub async fn new_update_proposal(&self, id: &ConversationId) -> Result<MlsProposalBundle> {
24        self.new_proposal(id, MlsProposal::Update).await
25    }
26
27    /// Creates a new proposal within a group
28    ///
29    /// # Arguments
30    /// * `conversation` - the group/conversation id
31    /// * `proposal` - the proposal do be added in the group
32    ///
33    /// # Return type
34    /// A [MlsProposalBundle] with the proposal in a Mls message and a reference to that proposal in order to rollback
35    /// it if required
36    ///
37    /// # Errors
38    /// If the conversation is not found, an error will be returned. Errors from OpenMls can be
39    /// returned as well, when for example there's a commit pending to be merged
40    async fn new_proposal(&self, id: &ConversationId, proposal: MlsProposal) -> Result<MlsProposalBundle> {
41        let mut conversation = self.conversation(id).await?;
42        let mut conversation = conversation.conversation_mut().await;
43        let client = &self.session().await?;
44        let backend = &self.mls_provider().await?;
45        let proposal = match proposal {
46            MlsProposal::Add(key_package) => conversation
47                .propose_add_member(client, backend, key_package.into())
48                .await
49                .map_err(RecursiveError::mls_conversation("proposing to add member"))?,
50            MlsProposal::Update => conversation
51                .propose_self_update(client, backend)
52                .await
53                .map_err(RecursiveError::mls_conversation("proposing self update"))?,
54            MlsProposal::Remove(client_id) => {
55                let index = conversation
56                    .group
57                    .members()
58                    .find(|kp| kp.credential.identity() == client_id.as_slice())
59                    .ok_or(Error::ClientNotFound(client_id))
60                    .map(|kp| kp.index)?;
61                (*conversation)
62                    .propose_remove_member(client, backend, index)
63                    .await
64                    .map_err(RecursiveError::mls_conversation("proposing to remove member"))?
65            }
66        };
67        Ok(proposal)
68    }
69}
70
#[cfg(test)]
mod tests {
    use super::Error;
    use crate::{mls::conversation::ConversationWithMls as _, test_utils::*, *};

    mod add {
        use super::*;

        /// An Add proposal, once committed, brings the invited client into the conversation.
        #[apply(all_cred_cipher)]
        pub async fn should_add_member(case: TestContext) {
            let [alice, bob] = case.sessions().await;
            Box::pin(async move {
                let conversation = case
                    .create_conversation([&alice])
                    .await
                    .invite_proposal_notify(&bob)
                    .await
                    .commit_pending_proposals_notify()
                    .await;
                assert_eq!(conversation.member_count().await, 2);
                assert!(conversation.is_functional_and_contains([&alice, &bob]).await);
            })
            .await
        }
    }

    mod update {
        use super::*;

        /// Committing an Update proposal changes the conversation's encryption key.
        #[apply(all_cred_cipher)]
        pub async fn should_update_hpke_key(case: TestContext) {
            let [session] = case.sessions().await;
            let conversation = case.create_conversation([&session]).await;
            let conversation_guard = conversation.guard().await;
            // Comparing the first encryption key before and after the commit is enough to detect a change.
            let before = conversation_guard
                .conversation()
                .await
                .encryption_keys()
                .next()
                .unwrap();
            conversation
                .update_proposal_notify()
                .await
                .commit_pending_proposals_notify()
                .await;
            let after = conversation_guard
                .conversation()
                .await
                .encryption_keys()
                .next()
                .unwrap();
            assert_ne!(before, after)
        }
    }

    mod remove {
        use super::*;

        /// A committed Remove proposal evicts the member, who then no longer finds the conversation.
        #[apply(all_cred_cipher)]
        pub async fn should_remove_member(case: TestContext) {
            let [alice, bob] = case.sessions().await;
            Box::pin(async move {
                let conversation = case.create_conversation([&alice, &bob]).await;
                let id = conversation.id().clone();
                assert_eq!(conversation.member_count().await, 2);

                let conversation = conversation
                    .remove_proposal_notify(&bob)
                    .await
                    .commit_pending_proposals_notify()
                    .await;

                assert_eq!(conversation.member_count().await, 1);

                assert!(matches!(
                    bob.transaction.conversation(&id).await.unwrap_err(),
                    Error::Leaf(LeafError::ConversationNotFound(conv_id)) if conv_id == id
                ));
            })
            .await
        }

        /// Proposing to remove a client that is not a group member yields `ClientNotFound`.
        #[apply(all_cred_cipher)]
        pub async fn should_fail_when_unknown_client(case: TestContext) {
            let [alice] = case.sessions().await;
            Box::pin(async move {
                let conversation = case.create_conversation([&alice]).await;
                let id = conversation.id().clone();

                let remove_proposal = alice
                    .transaction
                    .new_remove_proposal(&id, b"unknown".to_vec().into())
                    .await;
                assert!(matches!(
                    remove_proposal.unwrap_err(),
                    Error::ClientNotFound(client_id) if client_id == b"unknown".as_slice()
                ));
            })
            .await
        }
    }
}