core_crypto/transaction_context/conversation/proposal.rs

1use openmls::prelude::KeyPackage;
2
3use super::{Error, Result};
4use crate::{
5    ClientId, ConversationId, MlsProposal, MlsProposalBundle, RecursiveError, transaction_context::TransactionContext,
6};
7
8impl TransactionContext {
9    /// Creates a new Add proposal
10    #[cfg_attr(test, crate::idempotent)]
11    pub async fn new_add_proposal(&self, id: &ConversationId, key_package: KeyPackage) -> Result<MlsProposalBundle> {
12        self.new_proposal(id, MlsProposal::Add(key_package)).await
13    }
14
15    /// Creates a new Add proposal
16    #[cfg_attr(test, crate::idempotent)]
17    pub async fn new_remove_proposal(&self, id: &ConversationId, client_id: ClientId) -> Result<MlsProposalBundle> {
18        self.new_proposal(id, MlsProposal::Remove(client_id)).await
19    }
20
21    /// Creates a new Add proposal
22    #[cfg_attr(test, crate::dispotent)]
23    pub async fn new_update_proposal(&self, id: &ConversationId) -> Result<MlsProposalBundle> {
24        self.new_proposal(id, MlsProposal::Update).await
25    }
26
27    /// Creates a new proposal within a group
28    ///
29    /// # Arguments
30    /// * `conversation` - the group/conversation id
31    /// * `proposal` - the proposal do be added in the group
32    ///
33    /// # Return type
34    /// A [MlsProposalBundle] with the proposal in a Mls message and a reference to that proposal in order to rollback
35    /// it if required
36    ///
37    /// # Errors
38    /// If the conversation is not found, an error will be returned. Errors from OpenMls can be
39    /// returned as well, when for example there's a commit pending to be merged
40    async fn new_proposal(&self, id: &ConversationId, proposal: MlsProposal) -> Result<MlsProposalBundle> {
41        let mut conversation = self.conversation(id).await?;
42        let mut conversation = conversation.conversation_mut().await;
43        let client = &self.session().await?;
44        let provider = &self.mls_provider().await?;
45        let database = &self.database().await?;
46        let proposal = match proposal {
47            MlsProposal::Add(key_package) => conversation
48                .propose_add_member(client, provider, database, key_package.into())
49                .await
50                .map_err(RecursiveError::mls_conversation("proposing to add member"))?,
51            MlsProposal::Update => conversation
52                .propose_self_update(client, provider, database)
53                .await
54                .map_err(RecursiveError::mls_conversation("proposing self update"))?,
55            MlsProposal::Remove(client_id) => {
56                let index = conversation
57                    .group
58                    .members()
59                    .find(|kp| kp.credential.identity() == client_id.as_slice())
60                    .ok_or(Error::ClientNotFound(client_id))
61                    .map(|kp| kp.index)?;
62                (*conversation)
63                    .propose_remove_member(client, provider, database, index)
64                    .await
65                    .map_err(RecursiveError::mls_conversation("proposing to remove member"))?
66            }
67        };
68        Ok(proposal)
69    }
70}
71
#[cfg(test)]
mod tests {
    use super::Error;
    use crate::{mls::conversation::ConversationWithMls as _, test_utils::*, *};

    mod add {
        use super::*;

        #[apply(all_cred_cipher)]
        pub async fn should_add_member(case: TestContext) {
            let [alice, bob] = case.sessions().await;
            Box::pin(async move {
                // Alice proposes Bob, then commits the pending proposal.
                let conversation = case
                    .create_conversation([&alice])
                    .await
                    .invite_proposal_notify(&bob)
                    .await
                    .commit_pending_proposals_notify()
                    .await;
                assert_eq!(conversation.member_count().await, 2);
                assert!(conversation.is_functional_and_contains([&alice, &bob]).await);
            })
            .await
        }
    }

    mod update {
        use super::*;

        #[apply(all_cred_cipher)]
        pub async fn should_update_hpke_key(case: TestContext) {
            let [session] = case.sessions().await;
            let conversation = case.create_conversation([&session]).await;
            let conversation_guard = conversation.guard().await;
            // Snapshot the first encryption key before the self-update round-trip...
            let before = conversation_guard
                .conversation()
                .await
                .encryption_keys()
                .next()
                .unwrap();
            conversation
                .update_proposal_notify()
                .await
                .commit_pending_proposals_notify()
                .await;
            // ...and after it; committing the update proposal must have rotated it.
            let after = conversation_guard
                .conversation()
                .await
                .encryption_keys()
                .next()
                .unwrap();
            assert_ne!(before, after)
        }
    }

    mod remove {
        use super::*;

        #[apply(all_cred_cipher)]
        pub async fn should_remove_member(case: TestContext) {
            let [alice, bob] = case.sessions().await;
            Box::pin(async move {
                let conversation = case.create_conversation([&alice, &bob]).await;
                let id = conversation.id().clone();
                assert_eq!(conversation.member_count().await, 2);

                let conversation = conversation
                    .remove_proposal_notify(&bob)
                    .await
                    .commit_pending_proposals_notify()
                    .await;

                assert_eq!(conversation.member_count().await, 1);

                // Once removed, Bob no longer has the conversation at all.
                assert!(matches!(
                    bob.transaction.conversation(&id).await.unwrap_err(),
                    Error::Leaf(LeafError::ConversationNotFound(conv_id)) if conv_id == id
                ));
            })
            .await
        }

        #[apply(all_cred_cipher)]
        pub async fn should_fail_when_unknown_client(case: TestContext) {
            let [alice] = case.sessions().await;
            Box::pin(async move {
                let conversation = case.create_conversation([&alice]).await;
                let id = conversation.id().clone();

                // No member carries this identity, so the proposal must be rejected up front.
                let remove_proposal = alice
                    .transaction
                    .new_remove_proposal(&id, b"unknown".as_slice().to_owned().into())
                    .await;
                assert!(matches!(
                    remove_proposal.unwrap_err(),
                    Error::ClientNotFound(client_id) if client_id == b"unknown".as_slice()
                ));
            })
            .await
        }
    }
}