// core_crypto/mls/session/credential.rs

1use std::sync::Arc;
2
3use openmls::prelude::{SignaturePublicKey, SignatureScheme};
4
5use super::{Error, Result};
6use crate::{
7    Ciphersuite, Credential, CredentialFindFilters, CredentialRef, CredentialType, LeafError, MlsConversation,
8    RecursiveError, Session, mls::session::SessionInner,
9};
10
11impl Session {
12    /// Find all credentials known by this session which match the specified conditions.
13    ///
14    /// If no filters are set, this is equivalent to [`Self::get_credentials`].
15    pub async fn find_credentials(&self, find_filters: CredentialFindFilters<'_>) -> Result<Vec<CredentialRef>> {
16        let guard = self.inner.read().await;
17        let inner = guard.as_ref().ok_or(Error::MlsNotInitialized)?;
18        Ok(inner
19            .identities
20            .find_credential(find_filters)
21            .map(|credential| CredentialRef::from_credential(&credential))
22            .collect())
23    }
24
25    /// Get all credentials known by this session.
26    pub async fn get_credentials(&self) -> Result<Vec<CredentialRef>> {
27        self.find_credentials(Default::default()).await
28    }
29
30    /// Add a credential to the identities of this session.
31    ///
32    /// As a side effect, stores the credential in the keystore.
33    pub(crate) async fn add_credential(&self, credential: Credential) -> Result<CredentialRef> {
34        let credential = self.add_credential_producing_arc(credential).await?;
35        Ok(CredentialRef::from_credential(&credential))
36    }
37
38    /// Add a credential to the identities of this session.
39    ///
40    /// As a side effect, stores the credential in the keystore.
41    ///
42    /// Returns the actual credential instance which was loaded from the DB.
43    /// This is a convenience for internal use and should _not_ be propagated across
44    /// the FFI boundary. Instead, use [`Self::add_credential`] to produce a [`CredentialRef`].
45    pub(crate) async fn add_credential_producing_arc(&self, credential: Credential) -> Result<Arc<Credential>> {
46        if *credential.client_id() != self.id().await? {
47            return Err(Error::WrongCredential);
48        }
49
50        self.add_credential_without_clientid_check(credential).await
51    }
52
53    /// Add a credential to the identities of this session without validating that its client ID matches the session
54    /// client id.
55    ///
56    /// This is rarely useful and should only be used when absolutely necessary. You'll know it if you need it.
57    ///
58    /// Prefer [`Self::add_credential`].
59    pub(crate) async fn add_credential_without_clientid_check(
60        &self,
61        mut credential: Credential,
62    ) -> Result<Arc<Credential>> {
63        let _credential_ref = credential
64            .save(&self.crypto_provider.keystore())
65            .await
66            .map_err(RecursiveError::mls_credential("saving credential"))?;
67
68        let guard = self.inner.upgradable_read().await;
69
70        // only upgrade to a write guard here in order to minimize the amount of time the unique lock is held
71        let mut guard = async_lock::RwLockUpgradableReadGuard::upgrade(guard).await;
72        let inner = guard.as_mut().ok_or(Error::MlsNotInitialized)?;
73        let credential = inner.identities.push_credential(credential).await?;
74
75        Ok(credential)
76    }
77
78    /// Remove a credential from the identities of this session.
79    ///
80    /// First checks that the credential is not used in any conversation.
81    /// Removes both the credential itself and also any key packages which were generated from it.
82    pub async fn remove_credential(&self, credential_ref: &CredentialRef) -> Result<()> {
83        // setup
84        if *credential_ref.client_id() != self.id().await? {
85            return Err(Error::WrongCredential);
86        }
87
88        let database = self.crypto_provider.keystore();
89
90        let credential = credential_ref
91            .load(&database)
92            .await
93            .map_err(RecursiveError::mls_credential_ref(
94                "loading all credentials from ref to remove from session identities",
95            ))?;
96
97        // in a perfect world, we'd pre-cache the mls credentials in a set structure of some sort for faster querying.
98        // unfortunately, `MlsCredential` is `!Hash` and `!Ord`, so both the standard sets are out.
99        // so whatever, linear scan over the credentials every time will have to do.
100
101        // ensure this credential is not in use by any conversation
102        for (conversation_id, conversation) in
103            MlsConversation::load_all(&database)
104                .await
105                .map_err(RecursiveError::mls_conversation(
106                    "loading all conversations to check if the credential to be removed is present",
107                ))?
108        {
109            let converation_credential = conversation
110                .own_mls_credential()
111                .map_err(RecursiveError::mls_conversation("geting conversation credential"))?;
112            if credential.mls_credential() == converation_credential {
113                return Err(Error::CredentialStillInUse(conversation_id));
114            }
115        }
116
117        // remove any key packages generated by this credential
118        self.remove_keypackages_for(credential_ref).await?;
119
120        // remove all credentials associated with this ref
121        // only remove the actual credential after the keypackages are all gone,
122        // and keep the lock open as briefly as possible
123        {
124            let mut inner = self.inner.write().await;
125            let inner = inner.as_mut().ok_or(Error::MlsNotInitialized)?;
126            inner.identities.remove_by_mls_credential(credential.mls_credential());
127        }
128
129        // finally remove the credentials from the keystore so they won't be loaded on next mls_init
130        credential
131            .delete(&database)
132            .await
133            .map_err(RecursiveError::mls_credential("deleting credential from keystore"))
134            .map_err(Into::into)
135    }
136
137    /// convenience function deferring to the implementation on the inner type
138    pub(crate) async fn find_most_recent_credential(
139        &self,
140        signature_scheme: SignatureScheme,
141        credential_type: CredentialType,
142    ) -> Result<Arc<Credential>> {
143        match &*self.inner.read().await {
144            None => Err(Error::MlsNotInitialized),
145            Some(SessionInner { identities, .. }) => identities
146                .find_most_recent_credential(signature_scheme, credential_type)
147                .await
148                .ok_or(Error::CredentialNotFound(credential_type, signature_scheme)),
149        }
150    }
151
152    /// convenience function deferring to the implementation on the inner type
153    pub(crate) async fn find_credential_by_public_key(
154        &self,
155        signature_scheme: SignatureScheme,
156        credential_type: CredentialType,
157        public_key: &SignaturePublicKey,
158    ) -> Result<Arc<Credential>> {
159        match &*self.inner.read().await {
160            None => Err(Error::MlsNotInitialized),
161            Some(SessionInner { identities, .. }) => identities
162                .find_credential_by_public_key(signature_scheme, credential_type, public_key)
163                .await
164                .ok_or(Error::CredentialNotFound(credential_type, signature_scheme)),
165        }
166    }
167
168    /// Convenience function to get the most recent credential, creating it if the credential type is basic.
169    ///
170    /// If the credential type is X509, a missing credential returns `LeafError::E2eiEnrollmentNotDone`
171    pub(crate) async fn find_most_recent_or_create_basic_credential(
172        &self,
173        ciphersuite: Ciphersuite,
174        credential_type: CredentialType,
175    ) -> Result<Arc<Credential>> {
176        let credential = match self
177            .find_most_recent_credential(ciphersuite.signature_algorithm(), credential_type)
178            .await
179        {
180            Ok(credential) => credential,
181            Err(Error::CredentialNotFound(..)) if credential_type == CredentialType::Basic => {
182                let client_id = self.id().await?;
183                let credential = Credential::basic(ciphersuite, client_id, &self.crypto_provider).map_err(
184                    RecursiveError::mls_credential(
185                        "creating basic credential in find_most_recent_or_create_basic_credential",
186                    ),
187                )?;
188                self.add_credential_producing_arc(credential).await?
189            }
190            Err(Error::CredentialNotFound(..)) if credential_type == CredentialType::X509 => {
191                return Err(LeafError::E2eiEnrollmentNotDone.into());
192            }
193            Err(err) => return Err(err),
194        };
195        Ok(credential)
196    }
197}