core_crypto/transaction_context/conversation/proposal.rs

1use openmls::prelude::KeyPackage;
2
3use super::{Error, Result};
4use crate::{
5    ClientId, ConversationId, MlsProposal, MlsProposalBundle, RecursiveError, transaction_context::TransactionContext,
6};
7
8impl TransactionContext {
9    /// Creates a new Add proposal
10    #[cfg_attr(test, crate::idempotent)]
11    pub async fn new_add_proposal(&self, id: &ConversationId, key_package: KeyPackage) -> Result<MlsProposalBundle> {
12        self.new_proposal(id, MlsProposal::Add(key_package)).await
13    }
14
15    /// Creates a new Add proposal
16    #[cfg_attr(test, crate::idempotent)]
17    pub async fn new_remove_proposal(&self, id: &ConversationId, client_id: ClientId) -> Result<MlsProposalBundle> {
18        self.new_proposal(id, MlsProposal::Remove(client_id)).await
19    }
20
21    /// Creates a new Add proposal
22    #[cfg_attr(test, crate::dispotent)]
23    pub async fn new_update_proposal(&self, id: &ConversationId) -> Result<MlsProposalBundle> {
24        self.new_proposal(id, MlsProposal::Update).await
25    }
26
27    /// Creates a new proposal within a group
28    ///
29    /// # Arguments
30    /// * `conversation` - the group/conversation id
31    /// * `proposal` - the proposal do be added in the group
32    ///
33    /// # Return type
34    /// A [MlsProposalBundle] with the proposal in a Mls message and a reference to that proposal in order to rollback it if required
35    ///
36    /// # Errors
37    /// If the conversation is not found, an error will be returned. Errors from OpenMls can be
38    /// returned as well, when for example there's a commit pending to be merged
39    async fn new_proposal(&self, id: &ConversationId, proposal: MlsProposal) -> Result<MlsProposalBundle> {
40        let mut conversation = self.conversation(id).await?;
41        let mut conversation = conversation.conversation_mut().await;
42        let client = &self.session().await?;
43        let backend = &self.mls_provider().await?;
44        let proposal = match proposal {
45            MlsProposal::Add(key_package) => conversation
46                .propose_add_member(client, backend, key_package.into())
47                .await
48                .map_err(RecursiveError::mls_conversation("proposing to add member"))?,
49            MlsProposal::Update => conversation
50                .propose_self_update(client, backend)
51                .await
52                .map_err(RecursiveError::mls_conversation("proposing self update"))?,
53            MlsProposal::Remove(client_id) => {
54                let index = conversation
55                    .group
56                    .members()
57                    .find(|kp| kp.credential.identity() == client_id.as_slice())
58                    .ok_or(Error::ClientNotFound(client_id))
59                    .map(|kp| kp.index)?;
60                (*conversation)
61                    .propose_remove_member(client, backend, index)
62                    .await
63                    .map_err(RecursiveError::mls_conversation("proposing to remove member"))?
64            }
65        };
66        Ok(proposal)
67    }
68}
69
#[cfg(test)]
mod tests {
    use super::Error;
    use crate::{mls::conversation::ConversationWithMls as _, test_utils::*, *};

    mod add {
        use super::*;

        // Alice proposes adding Bob, commits the pending proposal, and both end up in a
        // functional two-member conversation.
        #[apply(all_cred_cipher)]
        pub async fn should_add_member(case: TestContext) {
            let [alice, bob] = case.sessions().await;
            Box::pin(async move {
                let conversation = case
                    .create_conversation([&alice])
                    .await
                    .invite_proposal_notify(&bob)
                    .await
                    .commit_pending_proposals_notify()
                    .await;
                assert_eq!(conversation.member_count().await, 2);
                assert!(conversation.is_functional_and_contains([&alice, &bob]).await);
            })
            .await
        }
    }

    mod update {
        use super::*;

        // A committed self-update proposal must change the conversation's encryption key.
        #[apply(all_cred_cipher)]
        pub async fn should_update_hpke_key(case: TestContext) {
            let [session] = case.sessions().await;
            Box::pin(async move {
                let conversation = case.create_conversation([&session]).await;
                let conversation_guard = conversation.guard().await;
                let before = conversation_guard
                    .conversation()
                    .await
                    .encryption_keys()
                    .next()
                    .expect("a freshly created conversation should expose an encryption key");
                conversation
                    .update_proposal_notify()
                    .await
                    .commit_pending_proposals_notify()
                    .await;
                let after = conversation_guard
                    .conversation()
                    .await
                    .encryption_keys()
                    .next()
                    .expect("an updated conversation should still expose an encryption key");
                assert_ne!(before, after)
            })
            .await
        }
    }

    mod remove {
        use super::*;

        // Alice proposes removing Bob and commits it; Bob then no longer has the conversation.
        #[apply(all_cred_cipher)]
        pub async fn should_remove_member(case: TestContext) {
            let [alice, bob] = case.sessions().await;
            Box::pin(async move {
                let conversation = case.create_conversation([&alice, &bob]).await;
                let id = conversation.id().clone();
                assert_eq!(conversation.member_count().await, 2);

                let conversation = conversation
                    .remove_proposal_notify(&bob)
                    .await
                    .commit_pending_proposals_notify()
                    .await;

                assert_eq!(conversation.member_count().await, 1);

                assert!(matches!(
                    bob.transaction.conversation(&id).await.unwrap_err(),
                    Error::Leaf(LeafError::ConversationNotFound(conv_id)) if conv_id == id
                ));
            })
            .await
        }

        // Proposing to remove a client that is not a group member yields `ClientNotFound`.
        #[apply(all_cred_cipher)]
        pub async fn should_fail_when_unknown_client(case: TestContext) {
            let [alice] = case.sessions().await;
            Box::pin(async move {
                let conversation = case.create_conversation([&alice]).await;
                let id = conversation.id().clone();

                let remove_proposal = alice.transaction.new_remove_proposal(&id, b"unknown"[..].into()).await;
                assert!(matches!(
                    remove_proposal.unwrap_err(),
                    Error::ClientNotFound(client_id) if client_id == b"unknown"[..].into()
                ));
            })
            .await
        }
    }
}