// core_crypto/mls/session/credential.rs

1use std::sync::Arc;
2
3use openmls::prelude::{SignaturePublicKey, SignatureScheme};
4
5use super::{Error, Result};
6use crate::{
7    Ciphersuite, Credential, CredentialFindFilters, CredentialRef, CredentialType, LeafError, MlsConversation,
8    RecursiveError, Session,
9};
10
11impl Session {
12    /// Find all credentials known by this session which match the specified conditions.
13    ///
14    /// If no filters are set, this is equivalent to [`Self::get_credentials`].
15    pub async fn find_credentials(&self, find_filters: CredentialFindFilters<'_>) -> Result<Vec<CredentialRef>> {
16        Ok(self
17            .identities
18            .read()
19            .await
20            .find_credential(find_filters)
21            .map(|credential| CredentialRef::from_credential(&credential))
22            .collect())
23    }
24
25    /// Get all credentials known by this session.
26    pub async fn get_credentials(&self) -> Result<Vec<CredentialRef>> {
27        self.find_credentials(Default::default()).await
28    }
29
30    /// Add a credential to the identities of this session.
31    ///
32    /// As a side effect, stores the credential in the keystore.
33    pub(crate) async fn add_credential(&self, credential: Credential) -> Result<CredentialRef> {
34        let credential = self.add_credential_producing_arc(credential).await?;
35        Ok(CredentialRef::from_credential(&credential))
36    }
37
38    /// Add a credential to the identities of this session.
39    ///
40    /// As a side effect, stores the credential in the keystore.
41    ///
42    /// Returns the actual credential instance which was loaded from the DB.
43    /// This is a convenience for internal use and should _not_ be propagated across
44    /// the FFI boundary. Instead, use [`Self::add_credential`] to produce a [`CredentialRef`].
45    pub(crate) async fn add_credential_producing_arc(&self, credential: Credential) -> Result<Arc<Credential>> {
46        if *credential.client_id() != self.id() {
47            return Err(Error::WrongCredential);
48        }
49
50        self.add_credential_without_clientid_check(credential).await
51    }
52
53    /// Add a credential to the identities of this session without validating that its client ID matches the session
54    /// client id.
55    ///
56    /// This is rarely useful and should only be used when absolutely necessary. You'll know it if you need it.
57    ///
58    /// Prefer [`Self::add_credential`].
59    pub(crate) async fn add_credential_without_clientid_check(
60        &self,
61        mut credential: Credential,
62    ) -> Result<Arc<Credential>> {
63        let _credential_ref = credential
64            .save(&self.crypto_provider.keystore())
65            .await
66            .map_err(RecursiveError::mls_credential("saving credential"))?;
67
68        let mut identities_guard = self.identities.write().await;
69        let credential = identities_guard.push_credential(credential).await?;
70
71        Ok(credential)
72    }
73
74    /// Remove a credential from the identities of this session.
75    ///
76    /// First checks that the credential is not used in any conversation.
77    /// Removes both the credential itself and also any key packages which were generated from it.
78    pub async fn remove_credential(&self, credential_ref: &CredentialRef) -> Result<()> {
79        // setup
80        if *credential_ref.client_id() != self.id() {
81            return Err(Error::WrongCredential);
82        }
83
84        let database = self.crypto_provider.keystore();
85
86        let credential = credential_ref
87            .load(&database)
88            .await
89            .map_err(RecursiveError::mls_credential_ref(
90                "loading all credentials from ref to remove from session identities",
91            ))?;
92
93        // in a perfect world, we'd pre-cache the mls credentials in a set structure of some sort for faster querying.
94        // unfortunately, `MlsCredential` is `!Hash` and `!Ord`, so both the standard sets are out.
95        // so whatever, linear scan over the credentials every time will have to do.
96
97        // ensure this credential is not in use by any conversation
98        for (conversation_id, conversation) in
99            MlsConversation::load_all(&database)
100                .await
101                .map_err(RecursiveError::mls_conversation(
102                    "loading all conversations to check if the credential to be removed is present",
103                ))?
104        {
105            let converation_credential = conversation
106                .own_mls_credential()
107                .map_err(RecursiveError::mls_conversation("geting conversation credential"))?;
108            if credential.mls_credential() == converation_credential {
109                return Err(Error::CredentialStillInUse(conversation_id));
110            }
111        }
112
113        // remove any key packages generated by this credential
114        self.remove_keypackages_for(credential_ref).await?;
115
116        // remove all credentials associated with this ref
117        // only remove the actual credential after the keypackages are all gone,
118        // and keep the lock open as briefly as possible
119        {
120            let mut identities = self.identities.write().await;
121            identities.remove_by_mls_credential(credential.mls_credential());
122        }
123
124        // finally remove the credentials from the keystore so they won't be loaded on next mls_init
125        credential
126            .delete(&database)
127            .await
128            .map_err(RecursiveError::mls_credential("deleting credential from keystore"))
129            .map_err(Into::into)
130    }
131
132    /// convenience function deferring to the implementation on the inner type
133    pub(crate) async fn find_most_recent_credential(
134        &self,
135        signature_scheme: SignatureScheme,
136        credential_type: CredentialType,
137    ) -> Result<Arc<Credential>> {
138        self.identities
139            .read()
140            .await
141            .find_most_recent_credential(signature_scheme, credential_type)
142            .await
143            .ok_or(Error::CredentialNotFound(credential_type, signature_scheme))
144    }
145
146    /// convenience function deferring to the implementation on the inner type
147    pub(crate) async fn find_credential_by_public_key(
148        &self,
149        signature_scheme: SignatureScheme,
150        credential_type: CredentialType,
151        public_key: &SignaturePublicKey,
152    ) -> Result<Arc<Credential>> {
153        self.identities
154            .read()
155            .await
156            .find_credential_by_public_key(signature_scheme, credential_type, public_key)
157            .await
158            .ok_or(Error::CredentialNotFound(credential_type, signature_scheme))
159    }
160
161    /// Convenience function to get the most recent credential, creating it if the credential type is basic.
162    ///
163    /// If the credential type is X509, a missing credential returns `LeafError::E2eiEnrollmentNotDone`
164    pub(crate) async fn find_most_recent_or_create_basic_credential(
165        &self,
166        ciphersuite: Ciphersuite,
167        credential_type: CredentialType,
168    ) -> Result<Arc<Credential>> {
169        let credential = match self
170            .find_most_recent_credential(ciphersuite.signature_algorithm(), credential_type)
171            .await
172        {
173            Ok(credential) => credential,
174            Err(Error::CredentialNotFound(..)) if credential_type == CredentialType::Basic => {
175                let client_id = self.id();
176                let credential = Credential::basic(ciphersuite, client_id, &self.crypto_provider).map_err(
177                    RecursiveError::mls_credential(
178                        "creating basic credential in find_most_recent_or_create_basic_credential",
179                    ),
180                )?;
181                self.add_credential_producing_arc(credential).await?
182            }
183            Err(Error::CredentialNotFound(..)) if credential_type == CredentialType::X509 => {
184                return Err(LeafError::E2eiEnrollmentNotDone.into());
185            }
186            Err(err) => return Err(err),
187        };
188        Ok(credential)
189    }
190}