core_crypto/transaction_context/conversation/
proposal.rs1use openmls::prelude::KeyPackage;
2
3use super::{Error, Result};
4use crate::{
5 ClientId, ConversationId, MlsProposal, MlsProposalBundle, RecursiveError, transaction_context::TransactionContext,
6};
7
8impl TransactionContext {
9 #[cfg_attr(test, crate::idempotent)]
11 pub async fn new_add_proposal(&self, id: &ConversationId, key_package: KeyPackage) -> Result<MlsProposalBundle> {
12 self.new_proposal(id, MlsProposal::Add(key_package)).await
13 }
14
15 #[cfg_attr(test, crate::idempotent)]
17 pub async fn new_remove_proposal(&self, id: &ConversationId, client_id: ClientId) -> Result<MlsProposalBundle> {
18 self.new_proposal(id, MlsProposal::Remove(client_id)).await
19 }
20
21 #[cfg_attr(test, crate::dispotent)]
23 pub async fn new_update_proposal(&self, id: &ConversationId) -> Result<MlsProposalBundle> {
24 self.new_proposal(id, MlsProposal::Update).await
25 }
26
27 async fn new_proposal(&self, id: &ConversationId, proposal: MlsProposal) -> Result<MlsProposalBundle> {
40 let mut conversation = self.conversation(id).await?;
41 let mut conversation = conversation.conversation_mut().await;
42 let client = &self.session().await?;
43 let backend = &self.mls_provider().await?;
44 let proposal = match proposal {
45 MlsProposal::Add(key_package) => conversation
46 .propose_add_member(client, backend, key_package.into())
47 .await
48 .map_err(RecursiveError::mls_conversation("proposing to add member"))?,
49 MlsProposal::Update => conversation
50 .propose_self_update(client, backend)
51 .await
52 .map_err(RecursiveError::mls_conversation("proposing self update"))?,
53 MlsProposal::Remove(client_id) => {
54 let index = conversation
55 .group
56 .members()
57 .find(|kp| kp.credential.identity() == client_id.as_slice())
58 .ok_or(Error::ClientNotFound(client_id))
59 .map(|kp| kp.index)?;
60 (*conversation)
61 .propose_remove_member(client, backend, index)
62 .await
63 .map_err(RecursiveError::mls_conversation("proposing to remove member"))?
64 }
65 };
66 Ok(proposal)
67 }
68}
69
#[cfg(test)]
mod tests {
    use super::Error;
    use crate::{mls::conversation::ConversationWithMls as _, test_utils::*, *};

    // Tests for add proposals.
    mod add {
        use super::*;

        // Alice starts alone, proposes inviting Bob and commits the pending proposal;
        // afterwards the conversation must report two members and contain both sessions.
        #[apply(all_cred_cipher)]
        pub async fn should_add_member(case: TestContext) {
            let [alice, bob] = case.sessions().await;
            Box::pin(async move {
                let solo = case.create_conversation([&alice]).await;
                let joined = solo
                    .invite_proposal_notify(&bob)
                    .await
                    .commit_pending_proposals_notify()
                    .await;

                assert_eq!(joined.member_count().await, 2);
                assert!(joined.is_functional_and_contains([&alice, &bob]).await);
            })
            .await
        }
    }

    // Tests for self-update proposals.
    mod update {
        use itertools::Itertools;

        use super::*;

        // Committing a self-update proposal must change the first encryption key
        // reported by the conversation.
        #[apply(all_cred_cipher)]
        pub async fn should_update_hpke_key(case: TestContext) {
            let [session] = case.sessions().await;
            let conversation = case.create_conversation([&session]).await;
            let guard = conversation.guard().await;

            // Each snapshot stays a single statement so that whatever `conversation()`
            // hands out is released again before the next step runs.
            let key_before = guard
                .conversation()
                .await
                .encryption_keys()
                .find_or_first(|_| true)
                .unwrap();

            conversation
                .update_proposal_notify()
                .await
                .commit_pending_proposals_notify()
                .await;

            let key_after = guard
                .conversation()
                .await
                .encryption_keys()
                .find_or_first(|_| true)
                .unwrap();

            assert_ne!(key_before, key_after);
        }
    }

    // Tests for remove proposals.
    mod remove {
        use super::*;

        // Alice removes Bob through a proposal followed by a commit; the conversation
        // shrinks to one member and Bob can no longer look it up.
        #[apply(all_cred_cipher)]
        pub async fn should_remove_member(case: TestContext) {
            let [alice, bob] = case.sessions().await;
            Box::pin(async move {
                let pair = case.create_conversation([&alice, &bob]).await;
                let id = pair.id().clone();
                assert_eq!(pair.member_count().await, 2);

                let remaining = pair
                    .remove_proposal_notify(&bob)
                    .await
                    .commit_pending_proposals_notify()
                    .await;

                assert_eq!(remaining.member_count().await, 1);

                let lookup_failure = bob.transaction.conversation(&id).await.unwrap_err();
                assert!(matches!(
                    lookup_failure,
                    Error::Leaf(LeafError::ConversationNotFound(conv_id)) if conv_id == id
                ));
            })
            .await
        }

        // A remove proposal targeting a client id that is not part of the group must
        // fail with `ClientNotFound` carrying that same id.
        #[apply(all_cred_cipher)]
        pub async fn should_fail_when_unknown_client(case: TestContext) {
            let [alice] = case.sessions().await;
            Box::pin(async move {
                let conversation = case.create_conversation([&alice]).await;
                let id = conversation.id().clone();

                let failure = alice
                    .transaction
                    .new_remove_proposal(&id, b"unknown"[..].into())
                    .await
                    .unwrap_err();
                assert!(matches!(
                    failure,
                    Error::ClientNotFound(client_id) if client_id == b"unknown"[..].into()
                ));
            })
            .await
        }
    }
}