core_crypto/transaction_context/conversation/
proposal.rs1use openmls::prelude::KeyPackage;
2
3use super::{Error, Result};
4use crate::{
5 RecursiveError,
6 prelude::{ClientId, ConversationId, MlsProposal, MlsProposalBundle},
7 transaction_context::TransactionContext,
8};
9
10impl TransactionContext {
11 #[cfg_attr(test, crate::idempotent)]
13 pub async fn new_add_proposal(&self, id: &ConversationId, key_package: KeyPackage) -> Result<MlsProposalBundle> {
14 self.new_proposal(id, MlsProposal::Add(key_package)).await
15 }
16
17 #[cfg_attr(test, crate::idempotent)]
19 pub async fn new_remove_proposal(&self, id: &ConversationId, client_id: ClientId) -> Result<MlsProposalBundle> {
20 self.new_proposal(id, MlsProposal::Remove(client_id)).await
21 }
22
23 #[cfg_attr(test, crate::dispotent)]
25 pub async fn new_update_proposal(&self, id: &ConversationId) -> Result<MlsProposalBundle> {
26 self.new_proposal(id, MlsProposal::Update).await
27 }
28
29 async fn new_proposal(&self, id: &ConversationId, proposal: MlsProposal) -> Result<MlsProposalBundle> {
42 let mut conversation = self.conversation(id).await?;
43 let mut conversation = conversation.conversation_mut().await;
44 let client = &self.session().await?;
45 let backend = &self.mls_provider().await?;
46 let proposal = match proposal {
47 MlsProposal::Add(key_package) => conversation
48 .propose_add_member(client, backend, key_package.into())
49 .await
50 .map_err(RecursiveError::mls_conversation("proposing to add member"))?,
51 MlsProposal::Update => conversation
52 .propose_self_update(client, backend)
53 .await
54 .map_err(RecursiveError::mls_conversation("proposing self update"))?,
55 MlsProposal::Remove(client_id) => {
56 let index = conversation
57 .group
58 .members()
59 .find(|kp| kp.credential.identity() == client_id.as_slice())
60 .ok_or(Error::ClientNotFound(client_id))
61 .map(|kp| kp.index)?;
62 (*conversation)
63 .propose_remove_member(client, backend, index)
64 .await
65 .map_err(RecursiveError::mls_conversation("proposing to remove member"))?
66 }
67 };
68 Ok(proposal)
69 }
70}
71
#[cfg(test)]
mod tests {
    use crate::mls::conversation::ConversationWithMls as _;
    use crate::{prelude::*, test_utils::*};

    use super::Error;

    mod add {
        use super::*;

        // Alice invites Bob through an Add proposal and commits it; afterwards the group has
        // two members and is functional for both of them.
        #[apply(all_cred_cipher)]
        pub async fn should_add_member(case: TestContext) {
            let [alice, bob] = case.sessions().await;
            Box::pin(async move {
                let conversation = case
                    .create_conversation([&alice])
                    .await
                    .invite_proposal_notify(&bob)
                    .await
                    .commit_pending_proposals_notify()
                    .await;
                assert_eq!(conversation.member_count().await, 2);
                assert!(conversation.is_functional_and_contains([&alice, &bob]).await);
            })
            .await
        }
    }

    mod update {
        use super::*;

        // Committing a self-update proposal must change the conversation's (first) encryption key.
        #[apply(all_cred_cipher)]
        pub async fn should_update_hpke_key(case: TestContext) {
            let [session] = case.sessions().await;
            let conversation = case.create_conversation([&session]).await;
            let conversation_guard = conversation.guard().await;
            let before = conversation_guard
                .conversation()
                .await
                .encryption_keys()
                .next()
                .expect("conversation should expose an encryption key before the update");
            conversation
                .update_proposal_notify()
                .await
                .commit_pending_proposals_notify()
                .await;
            let after = conversation_guard
                .conversation()
                .await
                .encryption_keys()
                .next()
                .expect("conversation should expose an encryption key after the update");
            assert_ne!(before, after)
        }
    }

    mod remove {
        use super::*;

        // Alice proposes and commits Bob's removal; the group shrinks to one member and Bob no
        // longer has the conversation.
        #[apply(all_cred_cipher)]
        pub async fn should_remove_member(case: TestContext) {
            let [alice, bob] = case.sessions().await;
            Box::pin(async move {
                let conversation = case.create_conversation([&alice, &bob]).await;
                let id = conversation.id().clone();
                assert_eq!(conversation.member_count().await, 2);

                let conversation = conversation
                    .remove_proposal_notify(&bob)
                    .await
                    .commit_pending_proposals_notify()
                    .await;

                assert_eq!(conversation.member_count().await, 1);

                assert!(matches!(
                    bob.transaction.conversation(&id).await.unwrap_err(),
                    Error::Leaf(LeafError::ConversationNotFound(conv_id)) if conv_id == id
                ));
            })
            .await
        }

        // Proposing to remove a client id that is not in the group yields `ClientNotFound`
        // carrying that id.
        #[apply(all_cred_cipher)]
        pub async fn should_fail_when_unknown_client(case: TestContext) {
            let [alice] = case.sessions().await;
            Box::pin(async move {
                let conversation = case.create_conversation([&alice]).await;
                let id = conversation.id().clone();

                let remove_proposal = alice.transaction.new_remove_proposal(&id, b"unknown"[..].into()).await;
                assert!(matches!(
                    remove_proposal.unwrap_err(),
                    Error::ClientNotFound(client_id) if client_id == b"unknown"[..].into()
                ));
            })
            .await
        }
    }
}