core_crypto/mls/session/
identities.rs1use crate::mls::session::{
2 SessionInner,
3 error::{Error, Result},
4};
5use crate::{
6 mls::credential::{CredentialBundle, typ::MlsCredentialType},
7 prelude::Session,
8};
9use openmls::prelude::{Credential, SignaturePublicKey};
10use openmls_traits::types::SignatureScheme;
11use std::collections::HashMap;
12use std::ops::Deref;
13use std::sync::Arc;
14
15#[derive(Debug, Clone)]
21pub(crate) struct Identities(HashMap<SignatureScheme, indexmap::IndexSet<Arc<CredentialBundle>>>);
22
23impl Identities {
24 pub(crate) fn new(capacity: usize) -> Self {
25 Self(HashMap::with_capacity(capacity))
26 }
27
28 pub(crate) async fn find_credential_bundle_by_public_key(
29 &self,
30 sc: SignatureScheme,
31 ct: MlsCredentialType,
32 pk: &SignaturePublicKey,
33 ) -> Option<Arc<CredentialBundle>> {
34 self.0
35 .get(&sc)?
36 .iter()
37 .find(|c| {
38 let ct_match = ct == c.credential.credential_type().into();
39 let pk_match = c.signature_key.public() == pk.as_slice();
40 ct_match && pk_match
41 })
42 .cloned()
43 }
44
45 pub(crate) async fn find_most_recent_credential_bundle(
46 &self,
47 sc: SignatureScheme,
48 ct: MlsCredentialType,
49 ) -> Option<Arc<CredentialBundle>> {
50 self.0
51 .get(&sc)?
52 .iter()
53 .rfind(|c| ct == c.credential.credential_type().into())
54 .cloned()
55 }
56
57 pub(crate) async fn push_credential_bundle(&mut self, sc: SignatureScheme, cb: CredentialBundle) -> Result<()> {
60 debug_assert_ne!(cb.created_at, 0);
62
63 match self.0.get_mut(&sc) {
64 Some(cbs) => {
65 let already_exists = !cbs.insert(Arc::new(cb));
66 if already_exists {
67 return Err(Error::CredentialBundleConflict);
68 }
69 }
70 None => {
71 self.0.insert(sc, indexmap::IndexSet::from([Arc::new(cb)]));
72 }
73 }
74 Ok(())
75 }
76
77 pub(crate) async fn remove(&mut self, credential: &Credential) -> Result<()> {
78 self.0.iter_mut().for_each(|(_, cbs)| {
79 cbs.retain(|c| c.credential() != credential);
80 });
81 Ok(())
82 }
83
84 pub(crate) fn iter(&self) -> impl Iterator<Item = (SignatureScheme, Arc<CredentialBundle>)> + '_ {
85 self.0.iter().flat_map(|(sc, cb)| cb.iter().map(|c| (*sc, c.clone())))
86 }
87}
88
89impl Session {
90 pub(crate) async fn find_most_recent_credential_bundle(
91 &self,
92 sc: SignatureScheme,
93 ct: MlsCredentialType,
94 ) -> Result<Arc<CredentialBundle>> {
95 match self.inner.read().await.deref() {
96 None => Err(Error::MlsNotInitialized),
97 Some(SessionInner { identities, .. }) => identities
98 .find_most_recent_credential_bundle(sc, ct)
99 .await
100 .ok_or(Error::CredentialNotFound(ct)),
101 }
102 }
103
104 pub(crate) async fn find_credential_bundle_by_public_key(
105 &self,
106 sc: SignatureScheme,
107 ct: MlsCredentialType,
108 pk: &SignaturePublicKey,
109 ) -> Result<Arc<CredentialBundle>> {
110 match self.inner.read().await.deref() {
111 None => Err(Error::MlsNotInitialized),
112 Some(SessionInner { identities, .. }) => identities
113 .find_credential_bundle_by_public_key(sc, ct, pk)
114 .await
115 .ok_or(Error::CredentialNotFound(ct)),
116 }
117 }
118
119 #[cfg(test)]
120 pub(crate) async fn identities_count(&self) -> Result<usize> {
121 match self.inner.read().await.deref() {
122 None => Err(Error::MlsNotInitialized),
123 Some(SessionInner { identities, .. }) => Ok(identities.iter().count()),
124 }
125 }
126}
127
#[cfg(test)]
mod tests {
    use crate::{mls, test_utils::*};
    use openmls::prelude::SignaturePublicKey;
    use rand::Rng;
    use wasm_bindgen_test::*;

    wasm_bindgen_test_configure!(run_in_browser);

    mod find {
        use super::*;

        /// Two bundles are created one after the other; the "most recent" lookup must return the
        /// second one.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        async fn should_find_most_recent(case: TestContext) {
            let [mut central] = case.sessions().await;
            Box::pin(async move {
                let cert = central.get_intermediate_ca().cloned();
                let old = central.new_credential_bundle(&case, cert.as_ref()).await;

                // Space the two bundles out in time so they are distinguishable; presumably
                // `created_at` has one-second granularity — TODO confirm.
                async_std::task::sleep(core::time::Duration::from_secs(1)).await;

                let new = central.new_credential_bundle(&case, cert.as_ref()).await;
                assert_ne!(old, new);

                let found = central
                    .find_most_recent_credential_bundle(case.signature_scheme(), case.credential_type)
                    .await
                    .unwrap();
                assert_eq!(found.as_ref(), &new);
            })
            .await
        }

        /// Creates `N` bundles and checks that a randomly chosen one is found by its public key.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        async fn should_find_by_public_key(case: TestContext) {
            let [mut central] = case.sessions().await;
            Box::pin(async move {
                const N: usize = 50;

                let r = rand::thread_rng().gen_range(0..N);
                // Same intermediate CA for every bundle (as in `should_find_most_recent`): fetch it once.
                let cert = central.get_intermediate_ca().cloned();
                let mut to_search = None;
                for i in 0..N {
                    let cb = central.new_credential_bundle(&case, cert.as_ref()).await;
                    if i == r {
                        to_search = Some(cb);
                    }
                }
                let to_search = to_search.expect("r is drawn from 0..N, so one bundle was selected");
                let pk = SignaturePublicKey::from(to_search.signature_key.public());
                let client = central.transaction.session().await.unwrap();
                let found = client
                    .find_credential_bundle_by_public_key(case.signature_scheme(), case.credential_type, &pk)
                    .await
                    .unwrap();
                assert_eq!(&to_search, found.as_ref());
            })
            .await
        }
    }

    mod push {
        use super::*;

        /// Creating a bundle through the test helper must add exactly one registered identity.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        async fn should_add_credential(case: TestContext) {
            let [mut central] = case.sessions().await;
            Box::pin(async move {
                let client = central.session().await;
                let prev_count = client.identities_count().await.unwrap();
                let cert = central.get_intermediate_ca().cloned();
                central.new_credential_bundle(&case, cert.as_ref()).await;
                let next_count = client.identities_count().await.unwrap();
                assert_eq!(next_count, prev_count + 1);
            })
            .await
        }

        /// Saving a bundle that is already registered must fail with `CredentialBundleConflict`.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        async fn pushing_duplicates_should_fail(case: TestContext) {
            let [mut central] = case.sessions().await;
            Box::pin(async move {
                let cert = central.get_intermediate_ca().cloned();
                let cb = central.new_credential_bundle(&case, cert.as_ref()).await;
                let client = central.transaction.session().await.unwrap();
                let push = client
                    .save_identity(
                        &central.transaction.keystore().await.unwrap(),
                        None,
                        case.signature_scheme(),
                        cb,
                    )
                    .await;
                assert!(matches!(
                    push.unwrap_err(),
                    mls::session::Error::CredentialBundleConflict
                ));
            })
            .await
        }
    }
}