core_crypto/mls/session/
identities.rs

1use crate::mls::session::{
2    SessionInner,
3    error::{Error, Result},
4};
5use crate::{
6    mls::credential::{CredentialBundle, typ::MlsCredentialType},
7    prelude::Session,
8};
9use openmls::prelude::{Credential, SignaturePublicKey};
10use openmls_traits::types::SignatureScheme;
11use std::collections::HashMap;
12use std::ops::Deref;
13use std::sync::Arc;
14
15/// In memory Map of a Session's identities: one per SignatureScheme.
16/// We need `indexmap::IndexSet` because each `CredentialBundle` has to be unique and insertion
17/// order matters in order to keep values sorted by time `created_at` so that we can identify most recent ones.
18///
19/// We keep each credential bundle inside an arc to avoid cloning them, as X509 credentials can get quite large.
20#[derive(Debug, Clone)]
21pub(crate) struct Identities(HashMap<SignatureScheme, indexmap::IndexSet<Arc<CredentialBundle>>>);
22
23impl Identities {
24    pub(crate) fn new(capacity: usize) -> Self {
25        Self(HashMap::with_capacity(capacity))
26    }
27
28    pub(crate) async fn find_credential_bundle_by_public_key(
29        &self,
30        sc: SignatureScheme,
31        ct: MlsCredentialType,
32        pk: &SignaturePublicKey,
33    ) -> Option<Arc<CredentialBundle>> {
34        self.0
35            .get(&sc)?
36            .iter()
37            .find(|c| {
38                let ct_match = ct == c.credential.credential_type().into();
39                let pk_match = c.signature_key.public() == pk.as_slice();
40                ct_match && pk_match
41            })
42            .cloned()
43    }
44
45    pub(crate) async fn find_most_recent_credential_bundle(
46        &self,
47        sc: SignatureScheme,
48        ct: MlsCredentialType,
49    ) -> Option<Arc<CredentialBundle>> {
50        self.0
51            .get(&sc)?
52            .iter()
53            .rfind(|c| ct == c.credential.credential_type().into())
54            .cloned()
55    }
56
57    /// Having `cb` requiring ownership kinda forces the caller to first persist it in the keystore and
58    /// only then store it in this in-memory map
59    pub(crate) async fn push_credential_bundle(&mut self, sc: SignatureScheme, cb: CredentialBundle) -> Result<()> {
60        // this would mean we have messed something up and that we do no init this CredentialBundle from a keypair just inserted in the keystore
61        debug_assert_ne!(cb.created_at, 0);
62
63        match self.0.get_mut(&sc) {
64            Some(cbs) => {
65                let already_exists = !cbs.insert(Arc::new(cb));
66                if already_exists {
67                    return Err(Error::CredentialBundleConflict);
68                }
69            }
70            None => {
71                self.0.insert(sc, indexmap::IndexSet::from([Arc::new(cb)]));
72            }
73        }
74        Ok(())
75    }
76
77    pub(crate) async fn remove(&mut self, credential: &Credential) -> Result<()> {
78        self.0.iter_mut().for_each(|(_, cbs)| {
79            cbs.retain(|c| c.credential() != credential);
80        });
81        Ok(())
82    }
83
84    pub(crate) fn iter(&self) -> impl Iterator<Item = (SignatureScheme, Arc<CredentialBundle>)> + '_ {
85        self.0.iter().flat_map(|(sc, cb)| cb.iter().map(|c| (*sc, c.clone())))
86    }
87}
88
89impl Session {
90    pub(crate) async fn find_most_recent_credential_bundle(
91        &self,
92        sc: SignatureScheme,
93        ct: MlsCredentialType,
94    ) -> Result<Arc<CredentialBundle>> {
95        match self.inner.read().await.deref() {
96            None => Err(Error::MlsNotInitialized),
97            Some(SessionInner { identities, .. }) => identities
98                .find_most_recent_credential_bundle(sc, ct)
99                .await
100                .ok_or(Error::CredentialNotFound(ct)),
101        }
102    }
103
104    pub(crate) async fn find_credential_bundle_by_public_key(
105        &self,
106        sc: SignatureScheme,
107        ct: MlsCredentialType,
108        pk: &SignaturePublicKey,
109    ) -> Result<Arc<CredentialBundle>> {
110        match self.inner.read().await.deref() {
111            None => Err(Error::MlsNotInitialized),
112            Some(SessionInner { identities, .. }) => identities
113                .find_credential_bundle_by_public_key(sc, ct, pk)
114                .await
115                .ok_or(Error::CredentialNotFound(ct)),
116        }
117    }
118
119    #[cfg(test)]
120    pub(crate) async fn identities_count(&self) -> Result<usize> {
121        match self.inner.read().await.deref() {
122            None => Err(Error::MlsNotInitialized),
123            Some(SessionInner { identities, .. }) => Ok(identities.iter().count()),
124        }
125    }
126}
127
#[cfg(test)]
mod tests {
    use crate::{mls, test_utils::*};
    use openmls::prelude::SignaturePublicKey;
    use rand::Rng;
    use wasm_bindgen_test::*;

    wasm_bindgen_test_configure!(run_in_browser);

    /// Lookup of credential bundles already held by a session.
    mod find {
        use super::*;

        /// Of two bundles created one after the other, the lookup must return the second one.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        async fn should_find_most_recent(case: TestContext) {
            let [mut central] = case.sessions().await;
            Box::pin(async move {
                let intermediate = central.get_intermediate_ca().cloned();
                let older = central.new_credential_bundle(&case, intermediate.as_ref()).await;

                // wait to make sure we're not in the same second
                async_std::task::sleep(core::time::Duration::from_secs(1)).await;

                let newer = central.new_credential_bundle(&case, intermediate.as_ref()).await;
                assert_ne!(older, newer);

                let most_recent = central
                    .find_most_recent_credential_bundle(case.signature_scheme(), case.credential_type)
                    .await
                    .unwrap();
                assert_eq!(most_recent.as_ref(), &newer);
            })
            .await
        }

        /// Creates `N` bundles, remembers one picked at random and expects to get exactly that
        /// one back when searching by its signature public key.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        async fn should_find_by_public_key(case: TestContext) {
            let [mut central] = case.sessions().await;
            Box::pin(async move {
                const N: usize = 50;

                let picked_idx = rand::thread_rng().gen_range(0..N);
                let mut picked = None;
                for idx in 0..N {
                    let intermediate = central.get_intermediate_ca().cloned();
                    let bundle = central.new_credential_bundle(&case, intermediate.as_ref()).await;
                    if idx == picked_idx {
                        // the bundle is not needed past this point, so it is moved rather than cloned
                        picked = Some(bundle);
                    }
                }
                let picked = picked.unwrap();
                let public_key = SignaturePublicKey::from(picked.signature_key.public());
                let session = central.transaction.session().await.unwrap();
                let found = session
                    .find_credential_bundle_by_public_key(case.signature_scheme(), case.credential_type, &public_key)
                    .await
                    .unwrap();
                assert_eq!(&picked, found.as_ref());
            })
            .await
        }
    }

    /// Insertion of credential bundles into a session.
    mod push {
        use super::*;

        /// Creating a new bundle must grow the session's identity count by exactly one.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        async fn should_add_credential(case: TestContext) {
            let [mut central] = case.sessions().await;
            Box::pin(async move {
                let session = central.session().await;
                let count_before = session.identities_count().await.unwrap();
                let intermediate = central.get_intermediate_ca().cloned();
                // this calls 'push_credential_bundle' under the hood
                central.new_credential_bundle(&case, intermediate.as_ref()).await;
                let count_after = session.identities_count().await.unwrap();
                assert_eq!(count_after, count_before + 1);
            })
            .await
        }

        /// Saving a bundle that the session already holds must be rejected as a conflict.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        async fn pushing_duplicates_should_fail(case: TestContext) {
            let [mut central] = case.sessions().await;
            Box::pin(async move {
                let intermediate = central.get_intermediate_ca().cloned();
                let bundle = central.new_credential_bundle(&case, intermediate.as_ref()).await;
                let session = central.transaction.session().await.unwrap();
                let outcome = session
                    .save_identity(
                        &central.transaction.keystore().await.unwrap(),
                        None,
                        case.signature_scheme(),
                        bundle,
                    )
                    .await;
                assert!(matches!(
                    outcome.unwrap_err(),
                    mls::session::Error::CredentialBundleConflict
                ));
            })
            .await
        }
    }
}