core_crypto/mls/session/
credential.rs1use std::sync::Arc;
2
3use openmls::prelude::{SignaturePublicKey, SignatureScheme};
4
5use super::{Error, Result};
6use crate::{
7 Ciphersuite, Credential, CredentialFindFilters, CredentialRef, CredentialType, LeafError, MlsConversation,
8 RecursiveError, Session,
9};
10
11impl Session {
12 pub async fn find_credentials(&self, find_filters: CredentialFindFilters<'_>) -> Result<Vec<CredentialRef>> {
16 Ok(self
17 .identities
18 .read()
19 .await
20 .find_credential(find_filters)
21 .map(|credential| CredentialRef::from_credential(&credential))
22 .collect())
23 }
24
25 pub async fn get_credentials(&self) -> Result<Vec<CredentialRef>> {
27 self.find_credentials(Default::default()).await
28 }
29
30 pub(crate) async fn add_credential(&self, credential: Credential) -> Result<CredentialRef> {
34 let credential = self.add_credential_producing_arc(credential).await?;
35 Ok(CredentialRef::from_credential(&credential))
36 }
37
38 pub(crate) async fn add_credential_producing_arc(&self, credential: Credential) -> Result<Arc<Credential>> {
46 if *credential.client_id() != self.id() {
47 return Err(Error::WrongCredential);
48 }
49
50 self.add_credential_without_clientid_check(credential).await
51 }
52
53 pub(crate) async fn add_credential_without_clientid_check(
60 &self,
61 mut credential: Credential,
62 ) -> Result<Arc<Credential>> {
63 let _credential_ref = credential
64 .save(&self.crypto_provider.keystore())
65 .await
66 .map_err(RecursiveError::mls_credential("saving credential"))?;
67
68 let mut identities_guard = self.identities.write().await;
69 let credential = identities_guard.push_credential(credential).await?;
70
71 Ok(credential)
72 }
73
74 pub async fn remove_credential(&self, credential_ref: &CredentialRef) -> Result<()> {
79 if *credential_ref.client_id() != self.id() {
81 return Err(Error::WrongCredential);
82 }
83
84 let database = self.crypto_provider.keystore();
85
86 let credential = credential_ref
87 .load(&database)
88 .await
89 .map_err(RecursiveError::mls_credential_ref(
90 "loading all credentials from ref to remove from session identities",
91 ))?;
92
93 for (conversation_id, conversation) in
99 MlsConversation::load_all(&database)
100 .await
101 .map_err(RecursiveError::mls_conversation(
102 "loading all conversations to check if the credential to be removed is present",
103 ))?
104 {
105 let converation_credential = conversation
106 .own_mls_credential()
107 .map_err(RecursiveError::mls_conversation("geting conversation credential"))?;
108 if credential.mls_credential() == converation_credential {
109 return Err(Error::CredentialStillInUse(conversation_id));
110 }
111 }
112
113 self.remove_keypackages_for(credential_ref).await?;
115
116 {
120 let mut identities = self.identities.write().await;
121 identities.remove_by_mls_credential(credential.mls_credential());
122 }
123
124 credential
126 .delete(&database)
127 .await
128 .map_err(RecursiveError::mls_credential("deleting credential from keystore"))
129 .map_err(Into::into)
130 }
131
132 pub(crate) async fn find_most_recent_credential(
134 &self,
135 signature_scheme: SignatureScheme,
136 credential_type: CredentialType,
137 ) -> Result<Arc<Credential>> {
138 self.identities
139 .read()
140 .await
141 .find_most_recent_credential(signature_scheme, credential_type)
142 .await
143 .ok_or(Error::CredentialNotFound(credential_type, signature_scheme))
144 }
145
146 pub(crate) async fn find_credential_by_public_key(
148 &self,
149 signature_scheme: SignatureScheme,
150 credential_type: CredentialType,
151 public_key: &SignaturePublicKey,
152 ) -> Result<Arc<Credential>> {
153 self.identities
154 .read()
155 .await
156 .find_credential_by_public_key(signature_scheme, credential_type, public_key)
157 .await
158 .ok_or(Error::CredentialNotFound(credential_type, signature_scheme))
159 }
160
161 pub(crate) async fn find_most_recent_or_create_basic_credential(
165 &self,
166 ciphersuite: Ciphersuite,
167 credential_type: CredentialType,
168 ) -> Result<Arc<Credential>> {
169 let credential = match self
170 .find_most_recent_credential(ciphersuite.signature_algorithm(), credential_type)
171 .await
172 {
173 Ok(credential) => credential,
174 Err(Error::CredentialNotFound(..)) if credential_type == CredentialType::Basic => {
175 let client_id = self.id();
176 let credential = Credential::basic(ciphersuite, client_id, &self.crypto_provider).map_err(
177 RecursiveError::mls_credential(
178 "creating basic credential in find_most_recent_or_create_basic_credential",
179 ),
180 )?;
181 self.add_credential_producing_arc(credential).await?
182 }
183 Err(Error::CredentialNotFound(..)) if credential_type == CredentialType::X509 => {
184 return Err(LeafError::E2eiEnrollmentNotDone.into());
185 }
186 Err(err) => return Err(err),
187 };
188 Ok(credential)
189 }
190}