core_crypto/transaction_context/conversation/
proposal.rs1use openmls::prelude::KeyPackage;
2
3use super::{Error, Result};
4use crate::{
5 ClientId, ConversationId, MlsProposal, MlsProposalBundle, RecursiveError, transaction_context::TransactionContext,
6};
7
8impl TransactionContext {
9 #[cfg_attr(test, crate::idempotent)]
11 pub async fn new_add_proposal(&self, id: &ConversationId, key_package: KeyPackage) -> Result<MlsProposalBundle> {
12 self.new_proposal(id, MlsProposal::Add(key_package)).await
13 }
14
15 #[cfg_attr(test, crate::idempotent)]
17 pub async fn new_remove_proposal(&self, id: &ConversationId, client_id: ClientId) -> Result<MlsProposalBundle> {
18 self.new_proposal(id, MlsProposal::Remove(client_id)).await
19 }
20
21 #[cfg_attr(test, crate::dispotent)]
23 pub async fn new_update_proposal(&self, id: &ConversationId) -> Result<MlsProposalBundle> {
24 self.new_proposal(id, MlsProposal::Update).await
25 }
26
27 async fn new_proposal(&self, id: &ConversationId, proposal: MlsProposal) -> Result<MlsProposalBundle> {
41 let mut conversation = self.conversation(id).await?;
42 let mut conversation = conversation.conversation_mut().await;
43 let client = &self.session().await?;
44 let provider = &self.mls_provider().await?;
45 let database = &self.database().await?;
46 let proposal = match proposal {
47 MlsProposal::Add(key_package) => conversation
48 .propose_add_member(client, provider, database, key_package.into())
49 .await
50 .map_err(RecursiveError::mls_conversation("proposing to add member"))?,
51 MlsProposal::Update => conversation
52 .propose_self_update(client, provider, database)
53 .await
54 .map_err(RecursiveError::mls_conversation("proposing self update"))?,
55 MlsProposal::Remove(client_id) => {
56 let index = conversation
57 .group
58 .members()
59 .find(|kp| kp.credential.identity() == client_id.as_slice())
60 .ok_or(Error::ClientNotFound(client_id))
61 .map(|kp| kp.index)?;
62 (*conversation)
63 .propose_remove_member(client, provider, database, index)
64 .await
65 .map_err(RecursiveError::mls_conversation("proposing to remove member"))?
66 }
67 };
68 Ok(proposal)
69 }
70}
71
#[cfg(test)]
mod tests {
    use super::Error;
    use crate::{mls::conversation::ConversationWithMls as _, test_utils::*, *};

    mod add {
        use super::*;

        // Alice proposes adding Bob and commits the pending proposal: the group must then hold
        // exactly the two of them and be usable by both.
        #[apply(all_cred_cipher)]
        pub async fn should_add_member(case: TestContext) {
            let [alice, bob] = case.sessions().await;
            Box::pin(async move {
                let conversation = case
                    .create_conversation([&alice])
                    .await
                    .invite_proposal_notify(&bob)
                    .await
                    .commit_pending_proposals_notify()
                    .await;
                assert_eq!(conversation.member_count().await, 2);
                assert!(conversation.is_functional_and_contains([&alice, &bob]).await);
            })
            .await
        }
    }

    mod update {
        use super::*;

        // Committing a self-update proposal must change the member's encryption key.
        #[apply(all_cred_cipher)]
        pub async fn should_update_hpke_key(case: TestContext) {
            let [session] = case.sessions().await;
            // Boxed like the sibling tests, keeping the test future on the heap rather than the stack.
            Box::pin(async move {
                let conversation = case.create_conversation([&session]).await;
                let conversation_guard = conversation.guard().await;
                // The group has a single member, so the first encryption key is the one to track.
                let before = conversation_guard
                    .conversation()
                    .await
                    .encryption_keys()
                    .next()
                    .expect("a one-member conversation exposes an encryption key");
                conversation
                    .update_proposal_notify()
                    .await
                    .commit_pending_proposals_notify()
                    .await;
                let after = conversation_guard
                    .conversation()
                    .await
                    .encryption_keys()
                    .next()
                    .expect("the conversation still exposes an encryption key after the update");
                assert_ne!(before, after);
            })
            .await
        }
    }

    mod remove {
        use super::*;

        // Alice proposes removing Bob and commits it: the group shrinks to one member, and Bob can
        // no longer load the conversation at all.
        #[apply(all_cred_cipher)]
        pub async fn should_remove_member(case: TestContext) {
            let [alice, bob] = case.sessions().await;
            Box::pin(async move {
                let conversation = case.create_conversation([&alice, &bob]).await;
                let id = conversation.id().clone();
                assert_eq!(conversation.member_count().await, 2);

                let conversation = conversation
                    .remove_proposal_notify(&bob)
                    .await
                    .commit_pending_proposals_notify()
                    .await;

                assert_eq!(conversation.member_count().await, 1);

                assert!(matches!(
                    bob.transaction.conversation(&id).await.unwrap_err(),
                    Error::Leaf(LeafError::ConversationNotFound(conv_id)) if conv_id == id
                ));
            })
            .await
        }

        // Proposing to remove an identity that is not a member must fail with `ClientNotFound`,
        // carrying back the identity that was looked up.
        #[apply(all_cred_cipher)]
        pub async fn should_fail_when_unknown_client(case: TestContext) {
            let [alice] = case.sessions().await;
            Box::pin(async move {
                let conversation = case.create_conversation([&alice]).await;
                let id = conversation.id().clone();

                let remove_proposal = alice
                    .transaction
                    .new_remove_proposal(&id, b"unknown".to_vec().into())
                    .await;
                assert!(matches!(
                    remove_proposal.unwrap_err(),
                    Error::ClientNotFound(client_id) if client_id == b"unknown".as_slice()
                ));
            })
            .await
        }
    }
}