core_crypto/mls/client/
identities.rs1use crate::{
2 LeafError, RecursiveError,
3 mls::client::{
4 ClientInner,
5 error::{Error, Result},
6 },
7};
8use crate::{
9 mls::credential::{CredentialBundle, typ::MlsCredentialType},
10 prelude::{Client, MlsConversation},
11};
12use openmls::prelude::{Credential, SignaturePublicKey};
13use openmls_traits::types::SignatureScheme;
14use std::collections::HashMap;
15use std::ops::Deref;
16use std::sync::Arc;
17
18#[derive(Debug, Clone)]
24pub(crate) struct ClientIdentities(HashMap<SignatureScheme, indexmap::IndexSet<Arc<CredentialBundle>>>);
25
26impl ClientIdentities {
27 pub(crate) fn new(capacity: usize) -> Self {
28 Self(HashMap::with_capacity(capacity))
29 }
30
31 pub(crate) async fn find_credential_bundle_by_public_key(
32 &self,
33 sc: SignatureScheme,
34 ct: MlsCredentialType,
35 pk: &SignaturePublicKey,
36 ) -> Option<Arc<CredentialBundle>> {
37 self.0
38 .get(&sc)?
39 .iter()
40 .find(|c| {
41 let ct_match = ct == c.credential.credential_type().into();
42 let pk_match = c.signature_key.public() == pk.as_slice();
43 ct_match && pk_match
44 })
45 .cloned()
46 }
47
48 pub(crate) async fn find_most_recent_credential_bundle(
49 &self,
50 sc: SignatureScheme,
51 ct: MlsCredentialType,
52 ) -> Option<Arc<CredentialBundle>> {
53 self.0
54 .get(&sc)?
55 .iter()
56 .rfind(|c| ct == c.credential.credential_type().into())
57 .cloned()
58 }
59
60 pub(crate) async fn push_credential_bundle(&mut self, sc: SignatureScheme, cb: CredentialBundle) -> Result<()> {
63 debug_assert_ne!(cb.created_at, 0);
65
66 match self.0.get_mut(&sc) {
67 Some(cbs) => {
68 let already_exists = !cbs.insert(Arc::new(cb));
69 if already_exists {
70 return Err(Error::CredentialBundleConflict);
71 }
72 }
73 None => {
74 self.0.insert(sc, indexmap::IndexSet::from([Arc::new(cb)]));
75 }
76 }
77 Ok(())
78 }
79
80 pub(crate) async fn remove(&mut self, credential: &Credential) -> Result<()> {
81 self.0.iter_mut().for_each(|(_, cbs)| {
82 cbs.retain(|c| c.credential() != credential);
83 });
84 Ok(())
85 }
86
87 pub(crate) fn iter(&self) -> impl Iterator<Item = (SignatureScheme, Arc<CredentialBundle>)> + '_ {
88 self.0.iter().flat_map(|(sc, cb)| cb.iter().map(|c| (*sc, c.clone())))
89 }
90}
91
92impl MlsConversation {
93 pub(crate) async fn find_current_credential_bundle(&self, client: &Client) -> Result<Arc<CredentialBundle>> {
94 let own_leaf = self.group.own_leaf().ok_or(LeafError::InternalMlsError)?;
95 let sc = self.ciphersuite().signature_algorithm();
96 let ct = self
97 .own_credential_type()
98 .map_err(RecursiveError::mls_conversation("getting own credential type"))?;
99
100 client
101 .find_credential_bundle_by_public_key(sc, ct, own_leaf.signature_key())
102 .await
103 }
104
105 pub(crate) async fn find_most_recent_credential_bundle(&self, client: &Client) -> Result<Arc<CredentialBundle>> {
106 let sc = self.ciphersuite().signature_algorithm();
107 let ct = self
108 .own_credential_type()
109 .map_err(RecursiveError::mls_conversation("getting own credential type"))?;
110
111 client.find_most_recent_credential_bundle(sc, ct).await
112 }
113}
114
115impl Client {
116 pub(crate) async fn find_most_recent_credential_bundle(
117 &self,
118 sc: SignatureScheme,
119 ct: MlsCredentialType,
120 ) -> Result<Arc<CredentialBundle>> {
121 match self.state.read().await.deref() {
122 None => Err(Error::MlsNotInitialized),
123 Some(ClientInner { identities, .. }) => identities
124 .find_most_recent_credential_bundle(sc, ct)
125 .await
126 .ok_or(Error::CredentialNotFound(ct)),
127 }
128 }
129
130 pub(crate) async fn find_credential_bundle_by_public_key(
131 &self,
132 sc: SignatureScheme,
133 ct: MlsCredentialType,
134 pk: &SignaturePublicKey,
135 ) -> Result<Arc<CredentialBundle>> {
136 match self.state.read().await.deref() {
137 None => Err(Error::MlsNotInitialized),
138 Some(ClientInner { identities, .. }) => identities
139 .find_credential_bundle_by_public_key(sc, ct, pk)
140 .await
141 .ok_or(Error::CredentialNotFound(ct)),
142 }
143 }
144
145 #[cfg(test)]
146 pub(crate) async fn identities_count(&self) -> Result<usize> {
147 match self.state.read().await.deref() {
148 None => Err(Error::MlsNotInitialized),
149 Some(ClientInner { identities, .. }) => Ok(identities.iter().count()),
150 }
151 }
152}
153
#[cfg(test)]
mod tests {
    use crate::{mls, test_utils::*};
    use openmls::prelude::SignaturePublicKey;
    use rand::Rng;
    use wasm_bindgen_test::*;

    wasm_bindgen_test_configure!(run_in_browser);

    /// Lookup behavior of the identity store.
    mod find {
        use super::*;

        /// After creating two bundles, the "most recent" lookup must return the second.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        async fn should_find_most_recent(case: TestCase) {
            run_test_with_client_ids(case.clone(), ["alice"], move |[mut central]| {
                Box::pin(async move {
                    let cert = central.get_intermediate_ca().cloned();
                    let old = central.new_credential_bundle(&case, cert.as_ref()).await;

                    // Pause so the second bundle is created at a later wall-clock time
                    // than the first (presumably `created_at` has second granularity).
                    async_std::task::sleep(core::time::Duration::from_secs(1)).await;

                    let new = central.new_credential_bundle(&case, cert.as_ref()).await;
                    assert_ne!(old, new);

                    let found = central
                        .find_most_recent_credential_bundle(case.signature_scheme(), case.credential_type)
                        .await
                        .unwrap();
                    assert_eq!(found.as_ref(), &new);
                })
            })
            .await
        }

        /// Among `N` freshly created bundles, a randomly chosen one must be retrievable
        /// through its public signature key.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        async fn should_find_by_public_key(case: TestCase) {
            run_test_with_client_ids(case.clone(), ["alice"], move |[mut central]| {
                Box::pin(async move {
                    const N: usize = 50;

                    // Index of the bundle that will be searched for afterwards.
                    let r = rand::thread_rng().gen_range(0..N);
                    let mut to_search = None;
                    for i in 0..N {
                        let cert = central.get_intermediate_ca().cloned();
                        let cb = central.new_credential_bundle(&case, cert.as_ref()).await;
                        if i == r {
                            // `cb` is not used afterwards, so it is moved rather than cloned.
                            to_search = Some(cb);
                        }
                    }
                    let to_search = to_search.unwrap();
                    let pk = SignaturePublicKey::from(to_search.signature_key.public());
                    let client = central.context.mls_client().await.unwrap();
                    let found = client
                        .find_credential_bundle_by_public_key(case.signature_scheme(), case.credential_type, &pk)
                        .await
                        .unwrap();
                    assert_eq!(&to_search, found.as_ref());
                })
            })
            .await
        }
    }

    /// Insertion behavior of the identity store.
    mod push {
        use super::*;

        /// Creating a new credential bundle must grow the identity count by exactly one.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        async fn should_add_credential(case: TestCase) {
            run_test_with_client_ids(case.clone(), ["alice"], move |[mut central]| {
                Box::pin(async move {
                    let client = central.client().await;
                    let prev_count = client.identities_count().await.unwrap();
                    let cert = central.get_intermediate_ca().cloned();
                    central.new_credential_bundle(&case, cert.as_ref()).await;
                    let next_count = client.identities_count().await.unwrap();
                    assert_eq!(next_count, prev_count + 1);
                })
            })
            .await
        }

        /// Saving a bundle that is already registered must be rejected with
        /// `CredentialBundleConflict`.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        async fn pushing_duplicates_should_fail(case: TestCase) {
            run_test_with_client_ids(case.clone(), ["alice"], move |[mut central]| {
                Box::pin(async move {
                    let cert = central.get_intermediate_ca().cloned();
                    let cb = central.new_credential_bundle(&case, cert.as_ref()).await;
                    let client = central.context.mls_client().await.unwrap();
                    let push = client
                        .save_identity(
                            &central.context.keystore().await.unwrap(),
                            None,
                            case.signature_scheme(),
                            cb,
                        )
                        .await;
                    assert!(matches!(
                        push.unwrap_err(),
                        mls::client::Error::CredentialBundleConflict
                    ));
                })
            })
            .await
        }
    }
}