core_crypto/transaction_context/conversation/proposal.rs

1use openmls::prelude::KeyPackage;
2
3use super::{Error, Result};
4use crate::{
5    RecursiveError,
6    prelude::{ClientId, ConversationId, MlsProposal, MlsProposalBundle},
7    transaction_context::TransactionContext,
8};
9
10impl TransactionContext {
11    /// Creates a new Add proposal
12    #[cfg_attr(test, crate::idempotent)]
13    pub async fn new_add_proposal(&self, id: &ConversationId, key_package: KeyPackage) -> Result<MlsProposalBundle> {
14        self.new_proposal(id, MlsProposal::Add(key_package)).await
15    }
16
17    /// Creates a new Add proposal
18    #[cfg_attr(test, crate::idempotent)]
19    pub async fn new_remove_proposal(&self, id: &ConversationId, client_id: ClientId) -> Result<MlsProposalBundle> {
20        self.new_proposal(id, MlsProposal::Remove(client_id)).await
21    }
22
23    /// Creates a new Add proposal
24    #[cfg_attr(test, crate::dispotent)]
25    pub async fn new_update_proposal(&self, id: &ConversationId) -> Result<MlsProposalBundle> {
26        self.new_proposal(id, MlsProposal::Update).await
27    }
28
29    /// Creates a new proposal within a group
30    ///
31    /// # Arguments
32    /// * `conversation` - the group/conversation id
33    /// * `proposal` - the proposal do be added in the group
34    ///
35    /// # Return type
36    /// A [MlsProposalBundle] with the proposal in a Mls message and a reference to that proposal in order to rollback it if required
37    ///
38    /// # Errors
39    /// If the conversation is not found, an error will be returned. Errors from OpenMls can be
40    /// returned as well, when for example there's a commit pending to be merged
41    async fn new_proposal(&self, id: &ConversationId, proposal: MlsProposal) -> Result<MlsProposalBundle> {
42        let mut conversation = self.conversation(id).await?;
43        let mut conversation = conversation.conversation_mut().await;
44        let client = &self.session().await?;
45        let backend = &self.mls_provider().await?;
46        let proposal = match proposal {
47            MlsProposal::Add(key_package) => conversation
48                .propose_add_member(client, backend, key_package.into())
49                .await
50                .map_err(RecursiveError::mls_conversation("proposing to add member"))?,
51            MlsProposal::Update => conversation
52                .propose_self_update(client, backend)
53                .await
54                .map_err(RecursiveError::mls_conversation("proposing self update"))?,
55            MlsProposal::Remove(client_id) => {
56                let index = conversation
57                    .group
58                    .members()
59                    .find(|kp| kp.credential.identity() == client_id.as_slice())
60                    .ok_or(Error::ClientNotFound(client_id))
61                    .map(|kp| kp.index)?;
62                (*conversation)
63                    .propose_remove_member(client, backend, index)
64                    .await
65                    .map_err(RecursiveError::mls_conversation("proposing to remove member"))?
66            }
67        };
68        Ok(proposal)
69    }
70}
71
#[cfg(test)]
mod tests {
    use crate::{prelude::*, test_utils::*};

    use super::Error;

    /// Add proposals: a proposed member joins once the pending proposal is committed.
    mod add {
        use super::*;

        #[apply(all_cred_cipher)]
        pub async fn should_add_member(case: TestContext) {
            let [alice, bob] = case.sessions().await;
            Box::pin(async move {
                // Alice starts alone, proposes Bob, then commits the pending proposal.
                let convo = case
                    .create_conversation([&alice])
                    .await
                    .invite_proposal_notify(&bob)
                    .await
                    .commit_pending_proposals_notify()
                    .await;

                assert_eq!(convo.member_count().await, 2);
                assert!(convo.is_functional_and_contains([&alice, &bob]).await);
            })
            .await
        }
    }

    /// Update proposals: committing a self-update changes the reported encryption key.
    mod update {
        use super::*;

        #[apply(all_cred_cipher)]
        pub async fn should_update_hpke_key(case: TestContext) {
            use crate::mls::conversation::ConversationWithMls as _;

            let [member] = case.sessions().await;
            let convo = case.create_conversation([&member]).await;
            let guard = convo.guard().await;

            // Snapshot the first encryption key reported for the conversation.
            let key_before = guard.conversation().await.encryption_keys().next().unwrap();

            convo
                .update_proposal_notify()
                .await
                .commit_pending_proposals_notify()
                .await;

            let key_after = guard.conversation().await.encryption_keys().next().unwrap();
            assert_ne!(key_before, key_after)
        }
    }

    /// Remove proposals: removing an existing member, and rejecting an unknown client id.
    mod remove {
        use super::*;

        #[apply(all_cred_cipher)]
        pub async fn should_remove_member(case: TestContext) {
            let [alice, bob] = case.sessions().await;
            Box::pin(async move {
                let convo = case.create_conversation([&alice, &bob]).await;
                let conversation_id = convo.id().clone();
                assert_eq!(convo.member_count().await, 2);

                let convo = convo
                    .remove_proposal_notify(&bob)
                    .await
                    .commit_pending_proposals_notify()
                    .await;
                assert_eq!(convo.member_count().await, 1);

                // Once removed, Bob can no longer look the conversation up.
                let lookup_error = bob.transaction.conversation(&conversation_id).await.unwrap_err();
                assert!(matches!(
                    lookup_error,
                    Error::Leaf(LeafError::ConversationNotFound(missing_id)) if missing_id == conversation_id
                ));
            })
            .await
        }

        #[apply(all_cred_cipher)]
        pub async fn should_fail_when_unknown_client(case: TestContext) {
            let [alice] = case.sessions().await;
            Box::pin(async move {
                let convo = case.create_conversation([&alice]).await;
                let conversation_id = convo.id().clone();

                // No member carries this identity, so the proposal must be rejected.
                let outcome = alice
                    .transaction
                    .new_remove_proposal(&conversation_id, b"unknown"[..].into())
                    .await;
                assert!(matches!(
                    outcome.unwrap_err(),
                    Error::ClientNotFound(client_id) if client_id == b"unknown"[..].into()
                ));
            })
            .await
        }
    }
}