core_crypto/mls/client/identities.rs
1use crate::mls::client::ClientInner;
2use crate::{
3 mls::credential::{typ::MlsCredentialType, CredentialBundle},
4 prelude::{Client, CryptoError, CryptoResult, MlsConversation},
5};
6use openmls::prelude::{Credential, SignaturePublicKey};
7use openmls_traits::types::SignatureScheme;
8use std::collections::HashMap;
9use std::ops::Deref;
10use std::sync::Arc;
11
12#[derive(Debug, Clone)]
18pub(crate) struct ClientIdentities(HashMap<SignatureScheme, indexmap::IndexSet<Arc<CredentialBundle>>>);
19
20impl ClientIdentities {
21 pub(crate) fn new(capacity: usize) -> Self {
22 Self(HashMap::with_capacity(capacity))
23 }
24
25 pub(crate) async fn find_credential_bundle_by_public_key(
26 &self,
27 sc: SignatureScheme,
28 ct: MlsCredentialType,
29 pk: &SignaturePublicKey,
30 ) -> Option<Arc<CredentialBundle>> {
31 self.0
32 .get(&sc)?
33 .iter()
34 .find(|c| {
35 let ct_match = ct == c.credential.credential_type().into();
36 let pk_match = c.signature_key.public() == pk.as_slice();
37 ct_match && pk_match
38 })
39 .cloned()
40 }
41
42 pub(crate) async fn find_most_recent_credential_bundle(
43 &self,
44 sc: SignatureScheme,
45 ct: MlsCredentialType,
46 ) -> Option<Arc<CredentialBundle>> {
47 self.0
48 .get(&sc)?
49 .iter()
50 .rfind(|c| ct == c.credential.credential_type().into())
51 .cloned()
52 }
53
54 pub(crate) async fn push_credential_bundle(
57 &mut self,
58 sc: SignatureScheme,
59 cb: CredentialBundle,
60 ) -> CryptoResult<()> {
61 debug_assert_ne!(cb.created_at, 0);
63
64 match self.0.get_mut(&sc) {
65 Some(cbs) => {
66 let already_exists = !cbs.insert(Arc::new(cb));
67 if already_exists {
68 return Err(CryptoError::CredentialBundleConflict);
69 }
70 }
71 None => {
72 self.0.insert(sc, indexmap::IndexSet::from([Arc::new(cb)]));
73 }
74 }
75 Ok(())
76 }
77
78 pub(crate) async fn remove(&mut self, credential: &Credential) -> CryptoResult<()> {
79 self.0.iter_mut().for_each(|(_, cbs)| {
80 cbs.retain(|c| c.credential() != credential);
81 });
82 Ok(())
83 }
84
85 pub(crate) fn iter(&self) -> impl Iterator<Item = (SignatureScheme, Arc<CredentialBundle>)> + '_ {
86 self.0.iter().flat_map(|(sc, cb)| cb.iter().map(|c| (*sc, c.clone())))
87 }
88}
89
90impl MlsConversation {
91 pub(crate) async fn find_current_credential_bundle(&self, client: &Client) -> CryptoResult<Arc<CredentialBundle>> {
92 let own_leaf = self.group.own_leaf().ok_or(CryptoError::InternalMlsError)?;
93 let sc = self.ciphersuite().signature_algorithm();
94 let ct = self.own_credential_type()?;
95
96 client
97 .find_credential_bundle_by_public_key(sc, ct, own_leaf.signature_key())
98 .await
99 }
100
101 pub(crate) async fn find_most_recent_credential_bundle(
102 &self,
103 client: &Client,
104 ) -> CryptoResult<Arc<CredentialBundle>> {
105 let sc = self.ciphersuite().signature_algorithm();
106 let ct = self.own_credential_type()?;
107
108 client.find_most_recent_credential_bundle(sc, ct).await
109 }
110}
111
112impl Client {
113 pub(crate) async fn find_most_recent_credential_bundle(
114 &self,
115 sc: SignatureScheme,
116 ct: MlsCredentialType,
117 ) -> CryptoResult<Arc<CredentialBundle>> {
118 match self.state.read().await.deref() {
119 None => Err(CryptoError::MlsNotInitialized),
120 Some(ClientInner { identities, .. }) => identities
121 .find_most_recent_credential_bundle(sc, ct)
122 .await
123 .ok_or(CryptoError::CredentialNotFound(ct)),
124 }
125 }
126
127 pub(crate) async fn find_credential_bundle_by_public_key(
128 &self,
129 sc: SignatureScheme,
130 ct: MlsCredentialType,
131 pk: &SignaturePublicKey,
132 ) -> CryptoResult<Arc<CredentialBundle>> {
133 match self.state.read().await.deref() {
134 None => Err(CryptoError::MlsNotInitialized),
135 Some(ClientInner { identities, .. }) => identities
136 .find_credential_bundle_by_public_key(sc, ct, pk)
137 .await
138 .ok_or(CryptoError::CredentialNotFound(ct)),
139 }
140 }
141
142 #[cfg(test)]
143 pub(crate) async fn identities_count(&self) -> CryptoResult<usize> {
144 match self.state.read().await.deref() {
145 None => Err(CryptoError::MlsNotInitialized),
146 Some(ClientInner { identities, .. }) => Ok(identities.iter().count()),
147 }
148 }
149}
150
#[cfg(test)]
mod tests {
    use crate::test_utils::*;
    use openmls::prelude::SignaturePublicKey;
    use rand::Rng;
    use wasm_bindgen_test::*;

    wasm_bindgen_test_configure!(run_in_browser);

    mod find {
        use super::*;

        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        async fn should_find_most_recent(case: TestCase) {
            run_test_with_client_ids(case.clone(), ["alice"], move |[mut central]| {
                Box::pin(async move {
                    let cert = central.get_intermediate_ca().cloned();
                    let old = central.new_credential_bundle(&case, cert.as_ref()).await;

                    // Space the two bundles out so their `created_at` timestamps differ.
                    // NOTE(review): presumably `created_at` has one-second resolution — confirm.
                    async_std::task::sleep(core::time::Duration::from_secs(1)).await;

                    let new = central.new_credential_bundle(&case, cert.as_ref()).await;
                    assert_ne!(old, new);

                    let found = central
                        .find_most_recent_credential_bundle(case.signature_scheme(), case.credential_type)
                        .await
                        .unwrap();
                    assert_eq!(found.as_ref(), &new);
                })
            })
            .await
        }

        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        async fn should_find_by_public_key(case: TestCase) {
            run_test_with_client_ids(case.clone(), ["alice"], move |[mut central]| {
                Box::pin(async move {
                    const N: usize = 50;

                    // The intermediate CA does not change across iterations; fetch it once.
                    let cert = central.get_intermediate_ca().cloned();
                    // Pick one of the N bundles at random as the lookup target.
                    let r = rand::thread_rng().gen_range(0..N);
                    let mut to_search = None;
                    for i in 0..N {
                        let cb = central.new_credential_bundle(&case, cert.as_ref()).await;
                        if i == r {
                            to_search = Some(cb);
                        }
                    }
                    let to_search = to_search.unwrap();
                    let pk = SignaturePublicKey::from(to_search.signature_key.public());
                    let client = central.context.mls_client().await.unwrap();
                    let found = client
                        .find_credential_bundle_by_public_key(case.signature_scheme(), case.credential_type, &pk)
                        .await
                        .unwrap();
                    assert_eq!(&to_search, found.as_ref());
                })
            })
            .await
        }
    }

    mod push {
        use super::*;

        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        async fn should_add_credential(case: TestCase) {
            run_test_with_client_ids(case.clone(), ["alice"], move |[mut central]| {
                Box::pin(async move {
                    let client = central.client().await;
                    let prev_count = client.identities_count().await.unwrap();
                    let cert = central.get_intermediate_ca().cloned();
                    central.new_credential_bundle(&case, cert.as_ref()).await;
                    let next_count = client.identities_count().await.unwrap();
                    assert_eq!(next_count, prev_count + 1);
                })
            })
            .await
        }

        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        async fn pushing_duplicates_should_fail(case: TestCase) {
            run_test_with_client_ids(case.clone(), ["alice"], move |[mut central]| {
                Box::pin(async move {
                    let cert = central.get_intermediate_ca().cloned();
                    let cb = central.new_credential_bundle(&case, cert.as_ref()).await;
                    let client = central.context.mls_client().await.unwrap();
                    // Saving the very same bundle a second time must be rejected.
                    let push = client
                        .save_identity(
                            &central.context.keystore().await.unwrap(),
                            None,
                            case.signature_scheme(),
                            cb,
                        )
                        .await;
                    assert!(matches!(
                        push.unwrap_err(),
                        crate::CryptoError::CredentialBundleConflict
                    ));
                })
            })
            .await
        }
    }
}