core_crypto/mls/client/
identities.rs

1use crate::mls::client::ClientInner;
2use crate::{
3    mls::credential::{typ::MlsCredentialType, CredentialBundle},
4    prelude::{Client, CryptoError, CryptoResult, MlsConversation},
5};
6use openmls::prelude::{Credential, SignaturePublicKey};
7use openmls_traits::types::SignatureScheme;
8use std::collections::HashMap;
9use std::ops::Deref;
10use std::sync::Arc;
11
12/// In memory Map of a Client's identities: one per SignatureScheme.
13/// We need `indexmap::IndexSet` because each `CredentialBundle` has to be unique and insertion
14/// order matters in order to keep values sorted by time `created_at` so that we can identify most recent ones.
15///
16/// We keep each credential bundle inside an arc to avoid cloning them, as X509 credentials can get quite large.
17#[derive(Debug, Clone)]
18pub(crate) struct ClientIdentities(HashMap<SignatureScheme, indexmap::IndexSet<Arc<CredentialBundle>>>);
19
20impl ClientIdentities {
21    pub(crate) fn new(capacity: usize) -> Self {
22        Self(HashMap::with_capacity(capacity))
23    }
24
25    pub(crate) async fn find_credential_bundle_by_public_key(
26        &self,
27        sc: SignatureScheme,
28        ct: MlsCredentialType,
29        pk: &SignaturePublicKey,
30    ) -> Option<Arc<CredentialBundle>> {
31        self.0
32            .get(&sc)?
33            .iter()
34            .find(|c| {
35                let ct_match = ct == c.credential.credential_type().into();
36                let pk_match = c.signature_key.public() == pk.as_slice();
37                ct_match && pk_match
38            })
39            .cloned()
40    }
41
42    pub(crate) async fn find_most_recent_credential_bundle(
43        &self,
44        sc: SignatureScheme,
45        ct: MlsCredentialType,
46    ) -> Option<Arc<CredentialBundle>> {
47        self.0
48            .get(&sc)?
49            .iter()
50            .rfind(|c| ct == c.credential.credential_type().into())
51            .cloned()
52    }
53
54    /// Having `cb` requiring ownership kinda forces the caller to first persist it in the keystore and
55    /// only then store it in this in-memory map
56    pub(crate) async fn push_credential_bundle(
57        &mut self,
58        sc: SignatureScheme,
59        cb: CredentialBundle,
60    ) -> CryptoResult<()> {
61        // this would mean we have messed something up and that we do no init this CredentialBundle from a keypair just inserted in the keystore
62        debug_assert_ne!(cb.created_at, 0);
63
64        match self.0.get_mut(&sc) {
65            Some(cbs) => {
66                let already_exists = !cbs.insert(Arc::new(cb));
67                if already_exists {
68                    return Err(CryptoError::CredentialBundleConflict);
69                }
70            }
71            None => {
72                self.0.insert(sc, indexmap::IndexSet::from([Arc::new(cb)]));
73            }
74        }
75        Ok(())
76    }
77
78    pub(crate) async fn remove(&mut self, credential: &Credential) -> CryptoResult<()> {
79        self.0.iter_mut().for_each(|(_, cbs)| {
80            cbs.retain(|c| c.credential() != credential);
81        });
82        Ok(())
83    }
84
85    pub(crate) fn iter(&self) -> impl Iterator<Item = (SignatureScheme, Arc<CredentialBundle>)> + '_ {
86        self.0.iter().flat_map(|(sc, cb)| cb.iter().map(|c| (*sc, c.clone())))
87    }
88}
89
90impl MlsConversation {
91    pub(crate) async fn find_current_credential_bundle(&self, client: &Client) -> CryptoResult<Arc<CredentialBundle>> {
92        let own_leaf = self.group.own_leaf().ok_or(CryptoError::InternalMlsError)?;
93        let sc = self.ciphersuite().signature_algorithm();
94        let ct = self.own_credential_type()?;
95
96        client
97            .find_credential_bundle_by_public_key(sc, ct, own_leaf.signature_key())
98            .await
99    }
100
101    pub(crate) async fn find_most_recent_credential_bundle(
102        &self,
103        client: &Client,
104    ) -> CryptoResult<Arc<CredentialBundle>> {
105        let sc = self.ciphersuite().signature_algorithm();
106        let ct = self.own_credential_type()?;
107
108        client.find_most_recent_credential_bundle(sc, ct).await
109    }
110}
111
112impl Client {
113    pub(crate) async fn find_most_recent_credential_bundle(
114        &self,
115        sc: SignatureScheme,
116        ct: MlsCredentialType,
117    ) -> CryptoResult<Arc<CredentialBundle>> {
118        match self.state.read().await.deref() {
119            None => Err(CryptoError::MlsNotInitialized),
120            Some(ClientInner { identities, .. }) => identities
121                .find_most_recent_credential_bundle(sc, ct)
122                .await
123                .ok_or(CryptoError::CredentialNotFound(ct)),
124        }
125    }
126
127    pub(crate) async fn find_credential_bundle_by_public_key(
128        &self,
129        sc: SignatureScheme,
130        ct: MlsCredentialType,
131        pk: &SignaturePublicKey,
132    ) -> CryptoResult<Arc<CredentialBundle>> {
133        match self.state.read().await.deref() {
134            None => Err(CryptoError::MlsNotInitialized),
135            Some(ClientInner { identities, .. }) => identities
136                .find_credential_bundle_by_public_key(sc, ct, pk)
137                .await
138                .ok_or(CryptoError::CredentialNotFound(ct)),
139        }
140    }
141
142    #[cfg(test)]
143    pub(crate) async fn identities_count(&self) -> CryptoResult<usize> {
144        match self.state.read().await.deref() {
145            None => Err(CryptoError::MlsNotInitialized),
146            Some(ClientInner { identities, .. }) => Ok(identities.iter().count()),
147        }
148    }
149}
150
#[cfg(test)]
mod tests {
    use crate::test_utils::*;
    use openmls::prelude::SignaturePublicKey;
    use rand::Rng;
    use wasm_bindgen_test::*;

    wasm_bindgen_test_configure!(run_in_browser);

    // Lookups against the in-memory identities.
    mod find {
        use super::*;

        // Two bundles are created one second apart; the lookup must return the second one.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        async fn should_find_most_recent(case: TestCase) {
            run_test_with_client_ids(case.clone(), ["alice"], move |[mut alice]| {
                Box::pin(async move {
                    let intermediate_ca = alice.get_intermediate_ca().cloned();
                    let first = alice.new_credential_bundle(&case, intermediate_ca.as_ref()).await;

                    // sleep so that both bundles cannot end up created within the same second
                    async_std::task::sleep(core::time::Duration::from_secs(1)).await;

                    let second = alice.new_credential_bundle(&case, intermediate_ca.as_ref()).await;
                    assert_ne!(first, second);

                    let signature_scheme = case.signature_scheme();
                    let found = alice
                        .find_most_recent_credential_bundle(signature_scheme, case.credential_type)
                        .await
                        .unwrap();
                    assert_eq!(found.as_ref(), &second);
                })
            })
            .await
        }

        // Creates a batch of bundles, remembers a randomly chosen one and checks that it can be
        // retrieved again through its signature public key.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        async fn should_find_by_public_key(case: TestCase) {
            run_test_with_client_ids(case.clone(), ["alice"], move |[mut alice]| {
                Box::pin(async move {
                    const BUNDLE_COUNT: usize = 50;

                    let picked_index = rand::thread_rng().gen_range(0..BUNDLE_COUNT);
                    let mut picked = None;
                    for index in 0..BUNDLE_COUNT {
                        let intermediate_ca = alice.get_intermediate_ca().cloned();
                        let bundle = alice.new_credential_bundle(&case, intermediate_ca.as_ref()).await;
                        if index == picked_index {
                            picked = Some(bundle);
                        }
                    }
                    let expected = picked.unwrap();

                    let public_key = SignaturePublicKey::from(expected.signature_key.public());
                    let client = alice.context.mls_client().await.unwrap();
                    let signature_scheme = case.signature_scheme();
                    let found = client
                        .find_credential_bundle_by_public_key(signature_scheme, case.credential_type, &public_key)
                        .await
                        .unwrap();
                    assert_eq!(&expected, found.as_ref());
                })
            })
            .await
        }
    }

    // Insertions into the in-memory identities.
    mod push {
        use super::*;

        // Creating one credential bundle must grow the identity count by exactly one.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        async fn should_add_credential(case: TestCase) {
            run_test_with_client_ids(case.clone(), ["alice"], move |[mut alice]| {
                Box::pin(async move {
                    let client = alice.client().await;
                    let count_before = client.identities_count().await.unwrap();

                    let intermediate_ca = alice.get_intermediate_ca().cloned();
                    // this calls 'push_credential_bundle' under the hood
                    alice.new_credential_bundle(&case, intermediate_ca.as_ref()).await;

                    let count_after = client.identities_count().await.unwrap();
                    assert_eq!(count_after, count_before + 1);
                })
            })
            .await
        }

        // Saving an already created bundle a second time must be rejected as a conflict.
        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        async fn pushing_duplicates_should_fail(case: TestCase) {
            run_test_with_client_ids(case.clone(), ["alice"], move |[mut alice]| {
                Box::pin(async move {
                    let intermediate_ca = alice.get_intermediate_ca().cloned();
                    let duplicate = alice.new_credential_bundle(&case, intermediate_ca.as_ref()).await;

                    let client = alice.context.mls_client().await.unwrap();
                    let second_save = client
                        .save_identity(
                            &alice.context.keystore().await.unwrap(),
                            None,
                            case.signature_scheme(),
                            duplicate,
                        )
                        .await;

                    let error = second_save.unwrap_err();
                    assert!(matches!(error, crate::CryptoError::CredentialBundleConflict));
                })
            })
            .await
        }
    }
}