core_crypto/mls/client/
identities.rs

1use crate::{
2    LeafError, RecursiveError,
3    mls::client::{
4        ClientInner,
5        error::{Error, Result},
6    },
7};
8use crate::{
9    mls::credential::{CredentialBundle, typ::MlsCredentialType},
10    prelude::{Client, MlsConversation},
11};
12use openmls::prelude::{Credential, SignaturePublicKey};
13use openmls_traits::types::SignatureScheme;
14use std::collections::HashMap;
15use std::ops::Deref;
16use std::sync::Arc;
17
18/// In memory Map of a Client's identities: one per SignatureScheme.
19/// We need `indexmap::IndexSet` because each `CredentialBundle` has to be unique and insertion
20/// order matters in order to keep values sorted by time `created_at` so that we can identify most recent ones.
21///
22/// We keep each credential bundle inside an arc to avoid cloning them, as X509 credentials can get quite large.
23#[derive(Debug, Clone)]
24pub(crate) struct ClientIdentities(HashMap<SignatureScheme, indexmap::IndexSet<Arc<CredentialBundle>>>);
25
26impl ClientIdentities {
27    pub(crate) fn new(capacity: usize) -> Self {
28        Self(HashMap::with_capacity(capacity))
29    }
30
31    pub(crate) async fn find_credential_bundle_by_public_key(
32        &self,
33        sc: SignatureScheme,
34        ct: MlsCredentialType,
35        pk: &SignaturePublicKey,
36    ) -> Option<Arc<CredentialBundle>> {
37        self.0
38            .get(&sc)?
39            .iter()
40            .find(|c| {
41                let ct_match = ct == c.credential.credential_type().into();
42                let pk_match = c.signature_key.public() == pk.as_slice();
43                ct_match && pk_match
44            })
45            .cloned()
46    }
47
48    pub(crate) async fn find_most_recent_credential_bundle(
49        &self,
50        sc: SignatureScheme,
51        ct: MlsCredentialType,
52    ) -> Option<Arc<CredentialBundle>> {
53        self.0
54            .get(&sc)?
55            .iter()
56            .rfind(|c| ct == c.credential.credential_type().into())
57            .cloned()
58    }
59
60    /// Having `cb` requiring ownership kinda forces the caller to first persist it in the keystore and
61    /// only then store it in this in-memory map
62    pub(crate) async fn push_credential_bundle(&mut self, sc: SignatureScheme, cb: CredentialBundle) -> Result<()> {
63        // this would mean we have messed something up and that we do no init this CredentialBundle from a keypair just inserted in the keystore
64        debug_assert_ne!(cb.created_at, 0);
65
66        match self.0.get_mut(&sc) {
67            Some(cbs) => {
68                let already_exists = !cbs.insert(Arc::new(cb));
69                if already_exists {
70                    return Err(Error::CredentialBundleConflict);
71                }
72            }
73            None => {
74                self.0.insert(sc, indexmap::IndexSet::from([Arc::new(cb)]));
75            }
76        }
77        Ok(())
78    }
79
80    pub(crate) async fn remove(&mut self, credential: &Credential) -> Result<()> {
81        self.0.iter_mut().for_each(|(_, cbs)| {
82            cbs.retain(|c| c.credential() != credential);
83        });
84        Ok(())
85    }
86
87    pub(crate) fn iter(&self) -> impl Iterator<Item = (SignatureScheme, Arc<CredentialBundle>)> + '_ {
88        self.0.iter().flat_map(|(sc, cb)| cb.iter().map(|c| (*sc, c.clone())))
89    }
90}
91
92impl MlsConversation {
93    pub(crate) async fn find_current_credential_bundle(&self, client: &Client) -> Result<Arc<CredentialBundle>> {
94        let own_leaf = self.group.own_leaf().ok_or(LeafError::InternalMlsError)?;
95        let sc = self.ciphersuite().signature_algorithm();
96        let ct = self
97            .own_credential_type()
98            .map_err(RecursiveError::mls_conversation("getting own credential type"))?;
99
100        client
101            .find_credential_bundle_by_public_key(sc, ct, own_leaf.signature_key())
102            .await
103    }
104
105    pub(crate) async fn find_most_recent_credential_bundle(&self, client: &Client) -> Result<Arc<CredentialBundle>> {
106        let sc = self.ciphersuite().signature_algorithm();
107        let ct = self
108            .own_credential_type()
109            .map_err(RecursiveError::mls_conversation("getting own credential type"))?;
110
111        client.find_most_recent_credential_bundle(sc, ct).await
112    }
113}
114
115impl Client {
116    pub(crate) async fn find_most_recent_credential_bundle(
117        &self,
118        sc: SignatureScheme,
119        ct: MlsCredentialType,
120    ) -> Result<Arc<CredentialBundle>> {
121        match self.state.read().await.deref() {
122            None => Err(Error::MlsNotInitialized),
123            Some(ClientInner { identities, .. }) => identities
124                .find_most_recent_credential_bundle(sc, ct)
125                .await
126                .ok_or(Error::CredentialNotFound(ct)),
127        }
128    }
129
130    pub(crate) async fn find_credential_bundle_by_public_key(
131        &self,
132        sc: SignatureScheme,
133        ct: MlsCredentialType,
134        pk: &SignaturePublicKey,
135    ) -> Result<Arc<CredentialBundle>> {
136        match self.state.read().await.deref() {
137            None => Err(Error::MlsNotInitialized),
138            Some(ClientInner { identities, .. }) => identities
139                .find_credential_bundle_by_public_key(sc, ct, pk)
140                .await
141                .ok_or(Error::CredentialNotFound(ct)),
142        }
143    }
144
145    #[cfg(test)]
146    pub(crate) async fn identities_count(&self) -> Result<usize> {
147        match self.state.read().await.deref() {
148            None => Err(Error::MlsNotInitialized),
149            Some(ClientInner { identities, .. }) => Ok(identities.iter().count()),
150        }
151    }
152}
153
#[cfg(test)]
mod tests {
    use crate::{mls, test_utils::*};
    use openmls::prelude::SignaturePublicKey;
    use rand::Rng;
    use wasm_bindgen_test::*;

    wasm_bindgen_test_configure!(run_in_browser);

    // Lookups of stored credential bundles.
    mod find {
        use super::*;

        // After creating two bundles one after the other, the lookup must return the second.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        async fn should_find_most_recent(case: TestCase) {
            run_test_with_client_ids(case.clone(), ["alice"], move |[mut central]| {
                Box::pin(async move {
                    let intermediate_ca = central.get_intermediate_ca().cloned();
                    let older = central.new_credential_bundle(&case, intermediate_ca.as_ref()).await;

                    // Let a full second elapse so both bundles cannot share the same second.
                    async_std::task::sleep(core::time::Duration::from_secs(1)).await;

                    let newer = central.new_credential_bundle(&case, intermediate_ca.as_ref()).await;
                    assert_ne!(older, newer);

                    let signature_scheme = case.signature_scheme();
                    let most_recent = central
                        .find_most_recent_credential_bundle(signature_scheme, case.credential_type)
                        .await
                        .unwrap();
                    assert_eq!(most_recent.as_ref(), &newer);
                })
            })
            .await
        }

        // Among many bundles, a randomly chosen one must be retrievable through its public key.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        async fn should_find_by_public_key(case: TestCase) {
            run_test_with_client_ids(case.clone(), ["alice"], move |[mut central]| {
                Box::pin(async move {
                    const N: usize = 50;

                    // Index of the bundle we will search for afterwards.
                    let picked_index = rand::thread_rng().gen_range(0..N);
                    let mut picked = None;
                    for index in 0..N {
                        let intermediate_ca = central.get_intermediate_ca().cloned();
                        let bundle = central.new_credential_bundle(&case, intermediate_ca.as_ref()).await;
                        if index == picked_index {
                            picked = Some(bundle);
                        }
                    }
                    let to_search = picked.unwrap();

                    let public_key = SignaturePublicKey::from(to_search.signature_key.public());
                    let client = central.context.mls_client().await.unwrap();
                    let signature_scheme = case.signature_scheme();
                    let found = client
                        .find_credential_bundle_by_public_key(signature_scheme, case.credential_type, &public_key)
                        .await
                        .unwrap();
                    assert_eq!(&to_search, found.as_ref());
                })
            })
            .await
        }
    }

    // Insertion of credential bundles.
    mod push {
        use super::*;

        // Creating a new bundle must grow the number of stored identities by exactly one.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        async fn should_add_credential(case: TestCase) {
            run_test_with_client_ids(case.clone(), ["alice"], move |[mut central]| {
                Box::pin(async move {
                    let client = central.client().await;
                    let count_before = client.identities_count().await.unwrap();

                    let intermediate_ca = central.get_intermediate_ca().cloned();
                    // `new_credential_bundle` ends up calling 'push_credential_bundle'
                    central.new_credential_bundle(&case, intermediate_ca.as_ref()).await;

                    let count_after = client.identities_count().await.unwrap();
                    assert_eq!(count_after, count_before + 1);
                })
            })
            .await
        }

        // Saving a bundle that is already stored must be rejected as a conflict.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        async fn pushing_duplicates_should_fail(case: TestCase) {
            run_test_with_client_ids(case.clone(), ["alice"], move |[mut central]| {
                Box::pin(async move {
                    let intermediate_ca = central.get_intermediate_ca().cloned();
                    let duplicate = central.new_credential_bundle(&case, intermediate_ca.as_ref()).await;

                    let client = central.context.mls_client().await.unwrap();
                    let keystore = central.context.keystore().await.unwrap();
                    let outcome = client
                        .save_identity(&keystore, None, case.signature_scheme(), duplicate)
                        .await;
                    assert!(matches!(
                        outcome.unwrap_err(),
                        mls::client::Error::CredentialBundleConflict
                    ));
                })
            })
            .await
        }
    }
}