core_crypto/transaction_context/conversation/
proposal.rs1use openmls::prelude::KeyPackage;
2
3use super::{Error, Result};
4use crate::{
5 ClientId, ConversationId, MlsProposal, MlsProposalBundle, RecursiveError, transaction_context::TransactionContext,
6};
7
8impl TransactionContext {
9 #[cfg_attr(test, crate::idempotent)]
11 pub async fn new_add_proposal(&self, id: &ConversationId, key_package: KeyPackage) -> Result<MlsProposalBundle> {
12 self.new_proposal(id, MlsProposal::Add(key_package)).await
13 }
14
15 #[cfg_attr(test, crate::idempotent)]
17 pub async fn new_remove_proposal(&self, id: &ConversationId, client_id: ClientId) -> Result<MlsProposalBundle> {
18 self.new_proposal(id, MlsProposal::Remove(client_id)).await
19 }
20
21 #[cfg_attr(test, crate::dispotent)]
23 pub async fn new_update_proposal(&self, id: &ConversationId) -> Result<MlsProposalBundle> {
24 self.new_proposal(id, MlsProposal::Update).await
25 }
26
27 async fn new_proposal(&self, id: &ConversationId, proposal: MlsProposal) -> Result<MlsProposalBundle> {
41 let mut conversation = self.conversation(id).await?;
42 let mut conversation = conversation.conversation_mut().await;
43 let client = &self.session().await?;
44 let backend = &self.mls_provider().await?;
45 let proposal = match proposal {
46 MlsProposal::Add(key_package) => conversation
47 .propose_add_member(client, backend, key_package.into())
48 .await
49 .map_err(RecursiveError::mls_conversation("proposing to add member"))?,
50 MlsProposal::Update => conversation
51 .propose_self_update(client, backend)
52 .await
53 .map_err(RecursiveError::mls_conversation("proposing self update"))?,
54 MlsProposal::Remove(client_id) => {
55 let index = conversation
56 .group
57 .members()
58 .find(|kp| kp.credential.identity() == client_id.as_slice())
59 .ok_or(Error::ClientNotFound(client_id))
60 .map(|kp| kp.index)?;
61 (*conversation)
62 .propose_remove_member(client, backend, index)
63 .await
64 .map_err(RecursiveError::mls_conversation("proposing to remove member"))?
65 }
66 };
67 Ok(proposal)
68 }
69}
70
#[cfg(test)]
mod tests {
    use super::Error;
    use crate::{mls::conversation::ConversationWithMls as _, test_utils::*, *};

    mod add {
        use super::*;

        // Alice proposes adding Bob and commits the pending proposal; both must end up in a
        // functional two-member conversation.
        #[apply(all_cred_cipher)]
        pub async fn should_add_member(case: TestContext) {
            let [alice, bob] = case.sessions().await;
            Box::pin(async move {
                let conversation = case
                    .create_conversation([&alice])
                    .await
                    .invite_proposal_notify(&bob)
                    .await
                    .commit_pending_proposals_notify()
                    .await;
                assert_eq!(conversation.member_count().await, 2);
                assert!(conversation.is_functional_and_contains([&alice, &bob]).await);
            })
            .await
        }
    }

    mod update {
        use super::*;

        // Committing a self-update proposal must change the encryption key exposed by the
        // conversation. The first key is sampled before and after the commit.
        #[apply(all_cred_cipher)]
        pub async fn should_update_hpke_key(case: TestContext) {
            let [session] = case.sessions().await;
            // Boxed like the sibling tests, so the large test future lives on the heap.
            Box::pin(async move {
                let conversation = case.create_conversation([&session]).await;
                let conversation_guard = conversation.guard().await;
                let before = conversation_guard
                    .conversation()
                    .await
                    .encryption_keys()
                    .next()
                    .expect("a freshly created conversation should expose an encryption key");
                conversation
                    .update_proposal_notify()
                    .await
                    .commit_pending_proposals_notify()
                    .await;
                let after = conversation_guard
                    .conversation()
                    .await
                    .encryption_keys()
                    .next()
                    .expect("the conversation should still expose an encryption key after the update");
                assert_ne!(before, after, "committing a self-update should rotate the encryption key");
            })
            .await
        }
    }

    mod remove {
        use super::*;

        // Alice proposes removing Bob and commits it. Alice's view shrinks to one member, and
        // Bob can no longer load the conversation.
        #[apply(all_cred_cipher)]
        pub async fn should_remove_member(case: TestContext) {
            let [alice, bob] = case.sessions().await;
            Box::pin(async move {
                let conversation = case.create_conversation([&alice, &bob]).await;
                let id = conversation.id().clone();
                assert_eq!(conversation.member_count().await, 2);

                let conversation = conversation
                    .remove_proposal_notify(&bob)
                    .await
                    .commit_pending_proposals_notify()
                    .await;

                assert_eq!(conversation.member_count().await, 1);

                assert!(matches!(
                    bob.transaction.conversation(&id).await.unwrap_err(),
                    Error::Leaf(LeafError::ConversationNotFound(conv_id)) if conv_id == id
                ));
            })
            .await
        }

        // Proposing to remove an identity that is not a group member must fail with
        // `Error::ClientNotFound` carrying that identity.
        #[apply(all_cred_cipher)]
        pub async fn should_fail_when_unknown_client(case: TestContext) {
            let [alice] = case.sessions().await;
            Box::pin(async move {
                let conversation = case.create_conversation([&alice]).await;
                let id = conversation.id().clone();

                let remove_proposal = alice
                    .transaction
                    .new_remove_proposal(&id, b"unknown".as_slice().to_owned().into())
                    .await;
                assert!(matches!(
                    remove_proposal.unwrap_err(),
                    Error::ClientNotFound(client_id) if client_id == b"unknown".as_slice()
                ));
            })
            .await
        }
    }
}