core_crypto/transaction_context/conversation/proposal.rs

1use openmls::prelude::KeyPackage;
2
3use super::{Error, Result};
4use crate::{
5    RecursiveError,
6    prelude::{ClientId, ConversationId, MlsProposal, MlsProposalBundle},
7    transaction_context::TransactionContext,
8};
9
10impl TransactionContext {
11    /// Creates a new Add proposal
12    #[cfg_attr(test, crate::idempotent)]
13    pub async fn new_add_proposal(&self, id: &ConversationId, key_package: KeyPackage) -> Result<MlsProposalBundle> {
14        self.new_proposal(id, MlsProposal::Add(key_package)).await
15    }
16
17    /// Creates a new Add proposal
18    #[cfg_attr(test, crate::idempotent)]
19    pub async fn new_remove_proposal(&self, id: &ConversationId, client_id: ClientId) -> Result<MlsProposalBundle> {
20        self.new_proposal(id, MlsProposal::Remove(client_id)).await
21    }
22
23    /// Creates a new Add proposal
24    #[cfg_attr(test, crate::dispotent)]
25    pub async fn new_update_proposal(&self, id: &ConversationId) -> Result<MlsProposalBundle> {
26        self.new_proposal(id, MlsProposal::Update).await
27    }
28
29    /// Creates a new proposal within a group
30    ///
31    /// # Arguments
32    /// * `conversation` - the group/conversation id
33    /// * `proposal` - the proposal do be added in the group
34    ///
35    /// # Return type
36    /// A [MlsProposalBundle] with the proposal in a Mls message and a reference to that proposal in order to rollback it if required
37    ///
38    /// # Errors
39    /// If the conversation is not found, an error will be returned. Errors from OpenMls can be
40    /// returned as well, when for example there's a commit pending to be merged
41    async fn new_proposal(&self, id: &ConversationId, proposal: MlsProposal) -> Result<MlsProposalBundle> {
42        let mut conversation = self.conversation(id).await?;
43        let mut conversation = conversation.conversation_mut().await;
44        let client = &self.session().await?;
45        let backend = &self.mls_provider().await?;
46        let proposal = match proposal {
47            MlsProposal::Add(key_package) => conversation
48                .propose_add_member(client, backend, key_package.into())
49                .await
50                .map_err(RecursiveError::mls_conversation("proposing to add member"))?,
51            MlsProposal::Update => conversation
52                .propose_self_update(client, backend)
53                .await
54                .map_err(RecursiveError::mls_conversation("proposing self update"))?,
55            MlsProposal::Remove(client_id) => {
56                let index = conversation
57                    .group
58                    .members()
59                    .find(|kp| kp.credential.identity() == client_id.as_slice())
60                    .ok_or(Error::ClientNotFound(client_id))
61                    .map(|kp| kp.index)?;
62                (*conversation)
63                    .propose_remove_member(client, backend, index)
64                    .await
65                    .map_err(RecursiveError::mls_conversation("proposing to remove member"))?
66            }
67        };
68        Ok(proposal)
69    }
70}
71
#[cfg(test)]
mod tests {
    use crate::mls::conversation::ConversationWithMls as _;
    use crate::{prelude::*, test_utils::*};

    use super::Error;

    mod add {
        use super::*;

        // Alice invites Bob through a proposal and then commits the pending proposals;
        // the resulting two-member group must be functional for both of them.
        #[apply(all_cred_cipher)]
        pub async fn should_add_member(case: TestContext) {
            let [alice, bob] = case.sessions().await;
            Box::pin(async move {
                let conversation = case
                    .create_conversation([&alice])
                    .await
                    .invite_proposal_notify(&bob)
                    .await
                    .commit_pending_proposals_notify()
                    .await;
                assert_eq!(conversation.member_count().await, 2);
                assert!(conversation.is_functional_and_contains([&alice, &bob]).await);
            })
            .await
        }
    }

    mod update {
        use super::*;

        // Proposing and committing an update must change the encryption key observed
        // in the conversation.
        #[apply(all_cred_cipher)]
        pub async fn should_update_hpke_key(case: TestContext) {
            let [session] = case.sessions().await;
            let conversation = case.create_conversation([&session]).await;
            let conversation_guard = conversation.guard().await;
            // Only the first key matters here; `next()` is all that is needed.
            let before = conversation_guard
                .conversation()
                .await
                .encryption_keys()
                .next()
                .unwrap();
            conversation
                .update_proposal_notify()
                .await
                .commit_pending_proposals_notify()
                .await;
            let after = conversation_guard
                .conversation()
                .await
                .encryption_keys()
                .next()
                .unwrap();
            assert_ne!(before, after)
        }
    }

    mod remove {
        use super::*;

        // After a committed remove proposal, the group shrinks and the removed member
        // can no longer look the conversation up.
        #[apply(all_cred_cipher)]
        pub async fn should_remove_member(case: TestContext) {
            let [alice, bob] = case.sessions().await;
            Box::pin(async move {
                let conversation = case.create_conversation([&alice, &bob]).await;
                let id = conversation.id().clone();
                assert_eq!(conversation.member_count().await, 2);

                let conversation = conversation
                    .remove_proposal_notify(&bob)
                    .await
                    .commit_pending_proposals_notify()
                    .await;

                assert_eq!(conversation.member_count().await, 1);

                assert!(matches!(
                    bob.transaction.conversation(&id).await.unwrap_err(),
                    Error::Leaf(LeafError::ConversationNotFound(conv_id)) if conv_id == id
                ));
            })
            .await
        }

        // Proposing to remove an identity that is not in the group must fail with
        // `ClientNotFound` carrying that identity.
        #[apply(all_cred_cipher)]
        pub async fn should_fail_when_unknown_client(case: TestContext) {
            let [alice] = case.sessions().await;
            Box::pin(async move {
                let conversation = case.create_conversation([&alice]).await;
                let id = conversation.id().clone();

                let remove_proposal = alice.transaction.new_remove_proposal(&id, b"unknown"[..].into()).await;
                assert!(matches!(
                    remove_proposal.unwrap_err(),
                    Error::ClientNotFound(client_id) if client_id == b"unknown"[..].into()
                ));
            })
            .await
        }
    }
}