core_crypto/mls/session/credential.rs

1use std::sync::Arc;
2
3use openmls::prelude::{SignaturePublicKey, SignatureScheme};
4use openmls_traits::OpenMlsCryptoProvider as _;
5
6use super::{Error, Result};
7use crate::{
8    Credential, CredentialFindFilters, CredentialRef, CredentialType, LeafError, MlsConversation, RecursiveError,
9    Session, mls::session::SessionInner,
10};
11
12impl Session {
13    /// Find all credentials known by this session which match the specified conditions.
14    ///
15    /// If no filters are set, this is equivalent to [`Self::get_credentials`].
16    pub async fn find_credentials(&self, find_filters: CredentialFindFilters<'_>) -> Result<Vec<CredentialRef>> {
17        let guard = self.inner.read().await;
18        let inner = guard.as_ref().ok_or(Error::MlsNotInitialized)?;
19        Ok(inner
20            .identities
21            .find_credential(find_filters)
22            .map(|credential| CredentialRef::from_credential(&credential))
23            .collect())
24    }
25
26    /// Get all credentials known by this session.
27    pub async fn get_credentials(&self) -> Result<Vec<CredentialRef>> {
28        self.find_credentials(Default::default()).await
29    }
30
31    /// Add a credential to the identities of this session.
32    ///
33    /// As a side effect, stores the credential in the keystore.
34    pub(crate) async fn add_credential(&self, credential: Credential) -> Result<CredentialRef> {
35        let credential = self.add_credential_producing_arc(credential).await?;
36        Ok(CredentialRef::from_credential(&credential))
37    }
38
39    /// Add a credential to the identities of this session.
40    ///
41    /// As a side effect, stores the credential in the keystore.
42    ///
43    /// Returns the actual credential instance which was loaded from the DB.
44    /// This is a convenience for internal use and should _not_ be propagated across
45    /// the FFI boundary. Instead, use [`Self::add_credential`] to produce a [`CredentialRef`].
46    pub(crate) async fn add_credential_producing_arc(&self, credential: Credential) -> Result<Arc<Credential>> {
47        if *credential.client_id() != self.id().await? {
48            return Err(Error::WrongCredential);
49        }
50
51        self.add_credential_without_clientid_check(credential).await
52    }
53
54    /// Add a credential to the identities of this session without validating that its client ID matches the session
55    /// client id.
56    ///
57    /// This is rarely useful and should only be used when absolutely necessary. You'll know it if you need it.
58    ///
59    /// Prefer [`Self::add_credential`].
60    pub(crate) async fn add_credential_without_clientid_check(
61        &self,
62        mut credential: Credential,
63    ) -> Result<Arc<Credential>> {
64        let credential_ref = credential
65            .save(&self.crypto_provider.keystore())
66            .await
67            .map_err(RecursiveError::mls_credential("saving credential"))?;
68
69        let guard = self.inner.upgradable_read().await;
70        let inner = guard.as_ref().ok_or(Error::MlsNotInitialized)?;
71
72        // failfast before loading the cache if we know already that this credential ref can't be added to the identity
73        // set
74        let distinct_result = inner.identities.ensure_distinct(
75            credential_ref.signature_scheme(),
76            credential_ref.r#type(),
77            credential_ref.earliest_validity(),
78        );
79        if let Err(err) = distinct_result {
80            // first clean up by removing the credential we just saved
81            // otherwise, we'll have nondeterministic results when we load
82            //
83            // TODO this depends for correctness that no two added credentials have the same keypair;
84            // if this happens for a keypair which was removed, we'll remove the (old, used) keypair
85            // and forever after be unable to mls_init on that DB due to a missing keypair for the given credential
86            // this is pointlessly difficult to check right now, but we should do a uniqueness check
87            // after WPB-20844
88            credential
89                .delete(&self.crypto_provider.keystore())
90                .await
91                .map_err(RecursiveError::mls_credential(
92                    "deleting nondistinct credential from keystore",
93                ))?;
94            return Err(err);
95        }
96
97        // only upgrade to a write guard here in order to minimize the amount of time the unique lock is held
98        let mut guard = async_lock::RwLockUpgradableReadGuard::upgrade(guard).await;
99        let inner = guard.as_mut().ok_or(Error::MlsNotInitialized)?;
100        let credential = inner.identities.push_credential(credential).await?;
101
102        Ok(credential)
103    }
104
105    /// Remove a credential from the identities of this session.
106    ///
107    /// First checks that the credential is not used in any conversation.
108    /// Removes both the credential itself and also any key packages which were generated from it.
109    pub async fn remove_credential(&self, credential_ref: &CredentialRef) -> Result<()> {
110        // setup
111        if *credential_ref.client_id() != self.id().await? {
112            return Err(Error::WrongCredential);
113        }
114
115        let database = self.crypto_provider.keystore();
116
117        let credentials = credential_ref
118            .load(&database)
119            .await
120            .map_err(RecursiveError::mls_credential_ref(
121                "loading all credentials from ref to remove from session identities",
122            ))?;
123
124        // in a perfect world, we'd pre-cache the mls credentials in a set structure of some sort for faster querying.
125        // unfortunately, `MlsCredential` is `!Hash` and `!Ord`, so both the standard sets are out.
126        // so whatever, linear scan over the credentials every time will have to do.
127
128        // ensure this credential is not in use by any conversation
129        for (conversation_id, conversation) in
130            MlsConversation::load_all(&database)
131                .await
132                .map_err(RecursiveError::mls_conversation(
133                    "loading all conversations to check if the credential to be removed is present",
134                ))?
135        {
136            let converation_credential = conversation
137                .own_mls_credential()
138                .map_err(RecursiveError::mls_conversation("geting conversation credential"))?;
139            if credentials
140                .iter()
141                .any(|credential| credential.mls_credential() == converation_credential)
142            {
143                return Err(Error::CredentialStillInUse(conversation_id));
144            }
145        }
146
147        // remove any key packages generated by this credential
148        let keypackages = self.find_all_keypackages(&self.crypto_provider.keystore()).await?;
149        let keypackages_from_this_credential = keypackages.iter().filter_map(|(_stored_key_package, key_package)| {
150            credentials
151                    .iter()
152                    .any(|credential| key_package.leaf_node().credential() == credential.mls_credential())
153                    // if computing the hash reference fails, we will just not delete the key package
154                    .then(|| key_package.hash_ref(self.crypto_provider.crypto()).ok()).flatten()
155        });
156        self.prune_keypackages(&self.crypto_provider, keypackages_from_this_credential)
157            .await?;
158
159        // remove all credentials associated with this ref
160        // only remove the actual credential after the keypackages are all gone,
161        // and keep the lock open as briefly as possible
162        {
163            let mut inner = self.inner.write().await;
164            let inner = inner.as_mut().ok_or(Error::MlsNotInitialized)?;
165            for credential in &credentials {
166                inner.identities.remove_by_mls_credential(credential.mls_credential());
167            }
168        }
169
170        // finally remove the credentials from the keystore so they won't be loaded on next mls_init
171        for credential in credentials {
172            credential
173                .delete(&database)
174                .await
175                .map_err(RecursiveError::mls_credential("deleting credential from keystore"))?;
176        }
177
178        Ok(())
179    }
180
181    /// convenience function deferring to the implementation on the inner type
182    pub(crate) async fn find_most_recent_credential(
183        &self,
184        signature_scheme: SignatureScheme,
185        credential_type: CredentialType,
186    ) -> Result<Arc<Credential>> {
187        match &*self.inner.read().await {
188            None => Err(Error::MlsNotInitialized),
189            Some(SessionInner { identities, .. }) => identities
190                .find_most_recent_credential(signature_scheme, credential_type)
191                .await
192                .ok_or(Error::CredentialNotFound(credential_type, signature_scheme)),
193        }
194    }
195
196    /// convenience function deferring to the implementation on the inner type
197    pub(crate) async fn find_credential_by_public_key(
198        &self,
199        signature_scheme: SignatureScheme,
200        credential_type: CredentialType,
201        public_key: &SignaturePublicKey,
202    ) -> Result<Arc<Credential>> {
203        match &*self.inner.read().await {
204            None => Err(Error::MlsNotInitialized),
205            Some(SessionInner { identities, .. }) => identities
206                .find_credential_by_public_key(signature_scheme, credential_type, public_key)
207                .await
208                .ok_or(Error::CredentialNotFound(credential_type, signature_scheme)),
209        }
210    }
211
212    /// Convenience function to get the most recent credential, creating it if the credential type is basic.
213    ///
214    /// If the credential type is X509, a missing credential returns `LeafError::E2eiEnrollmentNotDone`
215    pub(crate) async fn find_most_recent_or_create_basic_credential(
216        &self,
217        signature_scheme: SignatureScheme,
218        credential_type: CredentialType,
219    ) -> Result<Arc<Credential>> {
220        let credential = match self
221            .find_most_recent_credential(signature_scheme, credential_type)
222            .await
223        {
224            Ok(credential) => credential,
225            Err(Error::CredentialNotFound(..)) if credential_type == CredentialType::Basic => {
226                let client_id = self.id().await?;
227                let credential = Credential::basic(signature_scheme, client_id, &self.crypto_provider).map_err(
228                    RecursiveError::mls_credential(
229                        "creating basic credential in find_most_recent_or_create_basic_credential",
230                    ),
231                )?;
232                self.add_credential_producing_arc(credential).await?
233            }
234            Err(Error::CredentialNotFound(..)) if credential_type == CredentialType::X509 => {
235                return Err(LeafError::E2eiEnrollmentNotDone.into());
236            }
237            Err(err) => return Err(err),
238        };
239        Ok(credential)
240    }
241}