core_crypto/mls/session/
credential.rs1use std::sync::Arc;
2
3use openmls::prelude::{SignaturePublicKey, SignatureScheme};
4use openmls_traits::OpenMlsCryptoProvider as _;
5
6use super::{Error, Result};
7use crate::{
8 Credential, CredentialFindFilters, CredentialRef, CredentialType, LeafError, MlsConversation, RecursiveError,
9 Session, mls::session::SessionInner,
10};
11
12impl Session {
13 pub async fn find_credentials(&self, find_filters: CredentialFindFilters<'_>) -> Result<Vec<CredentialRef>> {
17 let guard = self.inner.read().await;
18 let inner = guard.as_ref().ok_or(Error::MlsNotInitialized)?;
19 Ok(inner
20 .identities
21 .find_credential(find_filters)
22 .map(|credential| CredentialRef::from_credential(&credential))
23 .collect())
24 }
25
26 pub async fn get_credentials(&self) -> Result<Vec<CredentialRef>> {
28 self.find_credentials(Default::default()).await
29 }
30
31 pub(crate) async fn add_credential(&self, credential: Credential) -> Result<CredentialRef> {
35 let credential = self.add_credential_producing_arc(credential).await?;
36 Ok(CredentialRef::from_credential(&credential))
37 }
38
39 pub(crate) async fn add_credential_producing_arc(&self, credential: Credential) -> Result<Arc<Credential>> {
47 if *credential.client_id() != self.id().await? {
48 return Err(Error::WrongCredential);
49 }
50
51 self.add_credential_without_clientid_check(credential).await
52 }
53
54 pub(crate) async fn add_credential_without_clientid_check(
61 &self,
62 mut credential: Credential,
63 ) -> Result<Arc<Credential>> {
64 let credential_ref = credential
65 .save(&self.crypto_provider.keystore())
66 .await
67 .map_err(RecursiveError::mls_credential("saving credential"))?;
68
69 let guard = self.inner.upgradable_read().await;
70 let inner = guard.as_ref().ok_or(Error::MlsNotInitialized)?;
71
72 let distinct_result = inner.identities.ensure_distinct(
75 credential_ref.signature_scheme(),
76 credential_ref.r#type(),
77 credential_ref.earliest_validity(),
78 );
79 if let Err(err) = distinct_result {
80 credential
89 .delete(&self.crypto_provider.keystore())
90 .await
91 .map_err(RecursiveError::mls_credential(
92 "deleting nondistinct credential from keystore",
93 ))?;
94 return Err(err);
95 }
96
97 let mut guard = async_lock::RwLockUpgradableReadGuard::upgrade(guard).await;
99 let inner = guard.as_mut().ok_or(Error::MlsNotInitialized)?;
100 let credential = inner.identities.push_credential(credential).await?;
101
102 Ok(credential)
103 }
104
105 pub async fn remove_credential(&self, credential_ref: &CredentialRef) -> Result<()> {
110 if *credential_ref.client_id() != self.id().await? {
112 return Err(Error::WrongCredential);
113 }
114
115 let database = self.crypto_provider.keystore();
116
117 let credentials = credential_ref
118 .load(&database)
119 .await
120 .map_err(RecursiveError::mls_credential_ref(
121 "loading all credentials from ref to remove from session identities",
122 ))?;
123
124 for (conversation_id, conversation) in
130 MlsConversation::load_all(&database)
131 .await
132 .map_err(RecursiveError::mls_conversation(
133 "loading all conversations to check if the credential to be removed is present",
134 ))?
135 {
136 let converation_credential = conversation
137 .own_mls_credential()
138 .map_err(RecursiveError::mls_conversation("geting conversation credential"))?;
139 if credentials
140 .iter()
141 .any(|credential| credential.mls_credential() == converation_credential)
142 {
143 return Err(Error::CredentialStillInUse(conversation_id));
144 }
145 }
146
147 let keypackages = self.find_all_keypackages(&self.crypto_provider.keystore()).await?;
149 let keypackages_from_this_credential = keypackages.iter().filter_map(|(_stored_key_package, key_package)| {
150 credentials
151 .iter()
152 .any(|credential| key_package.leaf_node().credential() == credential.mls_credential())
153 .then(|| key_package.hash_ref(self.crypto_provider.crypto()).ok()).flatten()
155 });
156 self.prune_keypackages(&self.crypto_provider, keypackages_from_this_credential)
157 .await?;
158
159 {
163 let mut inner = self.inner.write().await;
164 let inner = inner.as_mut().ok_or(Error::MlsNotInitialized)?;
165 for credential in &credentials {
166 inner.identities.remove_by_mls_credential(credential.mls_credential());
167 }
168 }
169
170 for credential in credentials {
172 credential
173 .delete(&database)
174 .await
175 .map_err(RecursiveError::mls_credential("deleting credential from keystore"))?;
176 }
177
178 Ok(())
179 }
180
181 pub(crate) async fn find_most_recent_credential(
183 &self,
184 signature_scheme: SignatureScheme,
185 credential_type: CredentialType,
186 ) -> Result<Arc<Credential>> {
187 match &*self.inner.read().await {
188 None => Err(Error::MlsNotInitialized),
189 Some(SessionInner { identities, .. }) => identities
190 .find_most_recent_credential(signature_scheme, credential_type)
191 .await
192 .ok_or(Error::CredentialNotFound(credential_type, signature_scheme)),
193 }
194 }
195
196 pub(crate) async fn find_credential_by_public_key(
198 &self,
199 signature_scheme: SignatureScheme,
200 credential_type: CredentialType,
201 public_key: &SignaturePublicKey,
202 ) -> Result<Arc<Credential>> {
203 match &*self.inner.read().await {
204 None => Err(Error::MlsNotInitialized),
205 Some(SessionInner { identities, .. }) => identities
206 .find_credential_by_public_key(signature_scheme, credential_type, public_key)
207 .await
208 .ok_or(Error::CredentialNotFound(credential_type, signature_scheme)),
209 }
210 }
211
212 pub(crate) async fn find_most_recent_or_create_basic_credential(
216 &self,
217 signature_scheme: SignatureScheme,
218 credential_type: CredentialType,
219 ) -> Result<Arc<Credential>> {
220 let credential = match self
221 .find_most_recent_credential(signature_scheme, credential_type)
222 .await
223 {
224 Ok(credential) => credential,
225 Err(Error::CredentialNotFound(..)) if credential_type == CredentialType::Basic => {
226 let client_id = self.id().await?;
227 let credential = Credential::basic(signature_scheme, client_id, &self.crypto_provider).map_err(
228 RecursiveError::mls_credential(
229 "creating basic credential in find_most_recent_or_create_basic_credential",
230 ),
231 )?;
232 self.add_credential_producing_arc(credential).await?
233 }
234 Err(Error::CredentialNotFound(..)) if credential_type == CredentialType::X509 => {
235 return Err(LeafError::E2eiEnrollmentNotDone.into());
236 }
237 Err(err) => return Err(err),
238 };
239 Ok(credential)
240 }
241}