core_crypto/mls/session/
identities.rs

1use crate::mls::session::{
2    SessionInner,
3    error::{Error, Result},
4};
5use crate::{
6    mls::credential::{CredentialBundle, typ::MlsCredentialType},
7    prelude::Session,
8};
9use openmls::prelude::{Credential, SignaturePublicKey};
10use openmls_traits::types::SignatureScheme;
11use std::collections::HashMap;
12use std::ops::Deref;
13use std::sync::Arc;
14
/// In-memory map of a Session's identities, keyed by `SignatureScheme`.
///
/// Each signature scheme maps to an `indexmap::IndexSet` of credential bundles. We need an
/// `IndexSet` because each `CredentialBundle` has to be unique, and because insertion order
/// matters: it is what keeps the values sorted by their `created_at` time, so that we can
/// identify the most recent ones.
///
/// We keep each credential bundle inside an `Arc` to avoid cloning them, as X509 credentials can get quite large.
#[derive(Debug, Clone)]
pub(crate) struct Identities(HashMap<SignatureScheme, indexmap::IndexSet<Arc<CredentialBundle>>>);
22
23impl Identities {
24    pub(crate) fn new(capacity: usize) -> Self {
25        Self(HashMap::with_capacity(capacity))
26    }
27
28    pub(crate) async fn find_credential_bundle_by_public_key(
29        &self,
30        sc: SignatureScheme,
31        ct: MlsCredentialType,
32        pk: &SignaturePublicKey,
33    ) -> Option<Arc<CredentialBundle>> {
34        self.0
35            .get(&sc)?
36            .iter()
37            .find(|c| {
38                let ct_match = ct == c.credential.credential_type().into();
39                let pk_match = c.signature_key.public() == pk.as_slice();
40                ct_match && pk_match
41            })
42            .cloned()
43    }
44
45    pub(crate) async fn find_most_recent_credential_bundle(
46        &self,
47        sc: SignatureScheme,
48        ct: MlsCredentialType,
49    ) -> Option<Arc<CredentialBundle>> {
50        self.0
51            .get(&sc)?
52            .iter()
53            .rfind(|c| ct == c.credential.credential_type().into())
54            .cloned()
55    }
56
57    /// Having `cb` requiring ownership kinda forces the caller to first persist it in the keystore and
58    /// only then store it in this in-memory map
59    pub(crate) async fn push_credential_bundle(&mut self, sc: SignatureScheme, cb: CredentialBundle) -> Result<()> {
60        // this would mean we have messed something up and that we do no init this CredentialBundle from a keypair just inserted in the keystore
61        debug_assert_ne!(cb.created_at, 0);
62
63        match self.0.get_mut(&sc) {
64            Some(cbs) => {
65                let already_exists = !cbs.insert(Arc::new(cb));
66                if already_exists {
67                    return Err(Error::CredentialBundleConflict);
68                }
69            }
70            None => {
71                self.0.insert(sc, indexmap::IndexSet::from([Arc::new(cb)]));
72            }
73        }
74        Ok(())
75    }
76
77    pub(crate) async fn remove(&mut self, credential: &Credential) -> Result<()> {
78        self.0.iter_mut().for_each(|(_, cbs)| {
79            cbs.retain(|c| c.credential() != credential);
80        });
81        Ok(())
82    }
83
84    pub(crate) fn iter(&self) -> impl Iterator<Item = (SignatureScheme, Arc<CredentialBundle>)> + '_ {
85        self.0.iter().flat_map(|(sc, cb)| cb.iter().map(|c| (*sc, c.clone())))
86    }
87}
88
89impl Session {
90    pub(crate) async fn find_most_recent_credential_bundle(
91        &self,
92        sc: SignatureScheme,
93        ct: MlsCredentialType,
94    ) -> Result<Arc<CredentialBundle>> {
95        match self.inner.read().await.deref() {
96            None => Err(Error::MlsNotInitialized),
97            Some(SessionInner { identities, .. }) => identities
98                .find_most_recent_credential_bundle(sc, ct)
99                .await
100                .ok_or(Error::CredentialNotFound(ct)),
101        }
102    }
103
104    pub(crate) async fn find_credential_bundle_by_public_key(
105        &self,
106        sc: SignatureScheme,
107        ct: MlsCredentialType,
108        pk: &SignaturePublicKey,
109    ) -> Result<Arc<CredentialBundle>> {
110        match self.inner.read().await.deref() {
111            None => Err(Error::MlsNotInitialized),
112            Some(SessionInner { identities, .. }) => identities
113                .find_credential_bundle_by_public_key(sc, ct, pk)
114                .await
115                .ok_or(Error::CredentialNotFound(ct)),
116        }
117    }
118
119    #[cfg(test)]
120    pub(crate) async fn identities_count(&self) -> Result<usize> {
121        match self.inner.read().await.deref() {
122            None => Err(Error::MlsNotInitialized),
123            Some(SessionInner { identities, .. }) => Ok(identities.iter().count()),
124        }
125    }
126}
127
#[cfg(test)]
mod tests {
    use crate::{mls, test_utils::*};
    use openmls::prelude::SignaturePublicKey;
    use rand::Rng;
    use wasm_bindgen_test::*;

    wasm_bindgen_test_configure!(run_in_browser);

    // Lookups of credential bundles.
    mod find {
        use super::*;

        // Of two bundles created one after the other, the later one must be reported as the
        // most recent.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        async fn should_find_most_recent(case: TestContext) {
            run_test_with_client_ids(case.clone(), ["alice"], move |[mut central]| {
                Box::pin(async move {
                    let intermediate_ca = central.get_intermediate_ca().cloned();
                    let older = central.new_credential_bundle(&case, intermediate_ca.as_ref()).await;

                    // pause for a full second so both bundles cannot be created within the same second
                    let one_second = core::time::Duration::from_secs(1);
                    async_std::task::sleep(one_second).await;

                    let newer = central.new_credential_bundle(&case, intermediate_ca.as_ref()).await;
                    assert_ne!(older, newer);

                    let scheme = case.signature_scheme();
                    let most_recent = central
                        .find_most_recent_credential_bundle(scheme, case.credential_type)
                        .await
                        .unwrap();
                    assert_eq!(most_recent.as_ref(), &newer);
                })
            })
            .await
        }

        // Among many bundles, a randomly chosen one must be retrievable through its public key.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        async fn should_find_by_public_key(case: TestContext) {
            run_test_with_client_ids(case.clone(), ["alice"], move |[mut central]| {
                Box::pin(async move {
                    const N: usize = 50;

                    // index of the bundle we are going to look up afterwards
                    let picked = rand::thread_rng().gen_range(0..N);
                    let mut expected = None;
                    for idx in 0..N {
                        let intermediate_ca = central.get_intermediate_ca().cloned();
                        let bundle = central.new_credential_bundle(&case, intermediate_ca.as_ref()).await;
                        if idx == picked {
                            expected = Some(bundle);
                        }
                    }
                    let expected = expected.unwrap();

                    let public_key = SignaturePublicKey::from(expected.signature_key.public());
                    let session = central.transaction.session().await.unwrap();
                    let scheme = case.signature_scheme();
                    let found = session
                        .find_credential_bundle_by_public_key(scheme, case.credential_type, &public_key)
                        .await
                        .unwrap();
                    assert_eq!(&expected, found.as_ref());
                })
            })
            .await
        }
    }

    // Insertion of credential bundles.
    mod push {
        use super::*;

        // Creating a credential bundle must grow the number of identities by exactly one.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        async fn should_add_credential(case: TestContext) {
            run_test_with_client_ids(case.clone(), ["alice"], move |[mut central]| {
                Box::pin(async move {
                    let session = central.session().await;
                    let count_before = session.identities_count().await.unwrap();

                    let intermediate_ca = central.get_intermediate_ca().cloned();
                    // 'push_credential_bundle' gets called under the hood
                    central.new_credential_bundle(&case, intermediate_ca.as_ref()).await;

                    let count_after = session.identities_count().await.unwrap();
                    assert_eq!(count_after, count_before + 1);
                })
            })
            .await
        }

        // Saving a bundle which is already known must be rejected as a conflict.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        async fn pushing_duplicates_should_fail(case: TestContext) {
            run_test_with_client_ids(case.clone(), ["alice"], move |[mut central]| {
                Box::pin(async move {
                    let intermediate_ca = central.get_intermediate_ca().cloned();
                    let duplicate = central.new_credential_bundle(&case, intermediate_ca.as_ref()).await;
                    let session = central.transaction.session().await.unwrap();

                    let keystore = central.transaction.keystore().await.unwrap();
                    let outcome = session
                        .save_identity(&keystore, None, case.signature_scheme(), duplicate)
                        .await;

                    let error = outcome.unwrap_err();
                    assert!(matches!(error, mls::session::Error::CredentialBundleConflict));
                })
            })
            .await
        }
    }
}