core_crypto/mls/session/
identities.rs1use crate::mls::session::{
2 SessionInner,
3 error::{Error, Result},
4};
5use crate::{
6 mls::credential::{CredentialBundle, typ::MlsCredentialType},
7 prelude::Session,
8};
9use openmls::prelude::{Credential, SignaturePublicKey};
10use openmls_traits::types::SignatureScheme;
11use std::collections::HashMap;
12use std::ops::Deref;
13use std::sync::Arc;
14
15#[derive(Debug, Clone)]
21pub(crate) struct Identities(HashMap<SignatureScheme, indexmap::IndexSet<Arc<CredentialBundle>>>);
22
23impl Identities {
24 pub(crate) fn new(capacity: usize) -> Self {
25 Self(HashMap::with_capacity(capacity))
26 }
27
28 pub(crate) async fn find_credential_bundle_by_public_key(
29 &self,
30 sc: SignatureScheme,
31 ct: MlsCredentialType,
32 pk: &SignaturePublicKey,
33 ) -> Option<Arc<CredentialBundle>> {
34 self.0
35 .get(&sc)?
36 .iter()
37 .find(|c| {
38 let ct_match = ct == c.credential.credential_type().into();
39 let pk_match = c.signature_key.public() == pk.as_slice();
40 ct_match && pk_match
41 })
42 .cloned()
43 }
44
45 pub(crate) async fn find_most_recent_credential_bundle(
46 &self,
47 sc: SignatureScheme,
48 ct: MlsCredentialType,
49 ) -> Option<Arc<CredentialBundle>> {
50 self.0
51 .get(&sc)?
52 .iter()
53 .rfind(|c| ct == c.credential.credential_type().into())
54 .cloned()
55 }
56
57 pub(crate) async fn push_credential_bundle(&mut self, sc: SignatureScheme, cb: CredentialBundle) -> Result<()> {
60 debug_assert_ne!(cb.created_at, 0);
62
63 match self.0.get_mut(&sc) {
64 Some(cbs) => {
65 let already_exists = !cbs.insert(Arc::new(cb));
66 if already_exists {
67 return Err(Error::CredentialBundleConflict);
68 }
69 }
70 None => {
71 self.0.insert(sc, indexmap::IndexSet::from([Arc::new(cb)]));
72 }
73 }
74 Ok(())
75 }
76
77 pub(crate) async fn remove(&mut self, credential: &Credential) -> Result<()> {
78 self.0.iter_mut().for_each(|(_, cbs)| {
79 cbs.retain(|c| c.credential() != credential);
80 });
81 Ok(())
82 }
83
84 pub(crate) fn iter(&self) -> impl Iterator<Item = (SignatureScheme, Arc<CredentialBundle>)> + '_ {
85 self.0.iter().flat_map(|(sc, cb)| cb.iter().map(|c| (*sc, c.clone())))
86 }
87}
88
89impl Session {
90 pub(crate) async fn find_most_recent_credential_bundle(
91 &self,
92 sc: SignatureScheme,
93 ct: MlsCredentialType,
94 ) -> Result<Arc<CredentialBundle>> {
95 match self.inner.read().await.deref() {
96 None => Err(Error::MlsNotInitialized),
97 Some(SessionInner { identities, .. }) => identities
98 .find_most_recent_credential_bundle(sc, ct)
99 .await
100 .ok_or(Error::CredentialNotFound(ct)),
101 }
102 }
103
104 pub(crate) async fn find_credential_bundle_by_public_key(
105 &self,
106 sc: SignatureScheme,
107 ct: MlsCredentialType,
108 pk: &SignaturePublicKey,
109 ) -> Result<Arc<CredentialBundle>> {
110 match self.inner.read().await.deref() {
111 None => Err(Error::MlsNotInitialized),
112 Some(SessionInner { identities, .. }) => identities
113 .find_credential_bundle_by_public_key(sc, ct, pk)
114 .await
115 .ok_or(Error::CredentialNotFound(ct)),
116 }
117 }
118
119 #[cfg(test)]
120 pub(crate) async fn identities_count(&self) -> Result<usize> {
121 match self.inner.read().await.deref() {
122 None => Err(Error::MlsNotInitialized),
123 Some(SessionInner { identities, .. }) => Ok(identities.iter().count()),
124 }
125 }
126}
127
#[cfg(test)]
mod tests {
    use crate::{mls, test_utils::*};
    use openmls::prelude::SignaturePublicKey;
    use rand::Rng;
    use wasm_bindgen_test::*;

    wasm_bindgen_test_configure!(run_in_browser);

    /// Lookup behaviour: most-recent bundle and lookup by signature public key.
    mod find {
        use super::*;

        /// After creating two bundles in sequence, the most-recent lookup must return
        /// the second one.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        async fn should_find_most_recent(case: TestContext) {
            run_test_with_client_ids(case.clone(), ["alice"], move |[mut central]| {
                Box::pin(async move {
                    let cert = central.get_intermediate_ca().cloned();
                    let old = central.new_credential_bundle(&case, cert.as_ref()).await;

                    // NOTE(review): presumably here so the two bundles cannot share a
                    // creation timestamp — confirm the timestamp granularity.
                    async_std::task::sleep(core::time::Duration::from_secs(1)).await;

                    let new = central.new_credential_bundle(&case, cert.as_ref()).await;
                    assert_ne!(old, new);

                    let found = central
                        .find_most_recent_credential_bundle(case.signature_scheme(), case.credential_type)
                        .await
                        .unwrap();
                    assert_eq!(found.as_ref(), &new);
                })
            })
            .await
        }

        /// Creates `N` bundles, remembers one picked at random, and checks that looking
        /// it up by its signature public key returns exactly that bundle.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        async fn should_find_by_public_key(case: TestContext) {
            run_test_with_client_ids(case.clone(), ["alice"], move |[mut central]| {
                Box::pin(async move {
                    const N: usize = 50;

                    // Index of the bundle that will be searched for afterwards.
                    let r = rand::thread_rng().gen_range(0..N);
                    let mut to_search = None;
                    for i in 0..N {
                        let cert = central.get_intermediate_ca().cloned();
                        let cb = central.new_credential_bundle(&case, cert.as_ref()).await;
                        if i == r {
                            to_search = Some(cb.clone());
                        }
                    }
                    // `r` is drawn from `0..N`, so the loop above always sets this.
                    let to_search = to_search.unwrap();
                    let pk = SignaturePublicKey::from(to_search.signature_key.public());
                    let client = central.transaction.session().await.unwrap();
                    let found = client
                        .find_credential_bundle_by_public_key(case.signature_scheme(), case.credential_type, &pk)
                        .await
                        .unwrap();
                    assert_eq!(&to_search, found.as_ref());
                })
            })
            .await
        }
    }

    /// Insertion behaviour: counting and duplicate rejection.
    mod push {
        use super::*;

        /// Creating one more bundle must raise the session's identity count by one.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        async fn should_add_credential(case: TestContext) {
            run_test_with_client_ids(case.clone(), ["alice"], move |[mut central]| {
                Box::pin(async move {
                    let client = central.session().await;
                    let prev_count = client.identities_count().await.unwrap();
                    let cert = central.get_intermediate_ca().cloned();
                    central.new_credential_bundle(&case, cert.as_ref()).await;
                    let next_count = client.identities_count().await.unwrap();
                    assert_eq!(next_count, prev_count + 1);
                })
            })
            .await
        }

        /// Saving a bundle a second time must fail with `CredentialBundleConflict`.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        async fn pushing_duplicates_should_fail(case: TestContext) {
            run_test_with_client_ids(case.clone(), ["alice"], move |[mut central]| {
                Box::pin(async move {
                    let cert = central.get_intermediate_ca().cloned();
                    // NOTE(review): presumably `new_credential_bundle` already registers
                    // the bundle with the session, which is what makes the explicit
                    // `save_identity` below a duplicate — confirm against the helper.
                    let cb = central.new_credential_bundle(&case, cert.as_ref()).await;
                    let client = central.transaction.session().await.unwrap();
                    let push = client
                        .save_identity(
                            &central.transaction.keystore().await.unwrap(),
                            None,
                            case.signature_scheme(),
                            cb,
                        )
                        .await;
                    assert!(matches!(
                        push.unwrap_err(),
                        mls::session::Error::CredentialBundleConflict
                    ));
                })
            })
            .await
        }
    }
}