core_crypto/mls/session/
credential.rs1use std::sync::Arc;
2
3use openmls::prelude::{SignaturePublicKey, SignatureScheme};
4
5use super::{Error, Result};
6use crate::{
7 Ciphersuite, Credential, CredentialFindFilters, CredentialRef, CredentialType, LeafError, MlsConversation,
8 RecursiveError, Session, mls::session::SessionInner,
9};
10
11impl Session {
12 pub async fn find_credentials(&self, find_filters: CredentialFindFilters<'_>) -> Result<Vec<CredentialRef>> {
16 let guard = self.inner.read().await;
17 let inner = guard.as_ref().ok_or(Error::MlsNotInitialized)?;
18 Ok(inner
19 .identities
20 .find_credential(find_filters)
21 .map(|credential| CredentialRef::from_credential(&credential))
22 .collect())
23 }
24
25 pub async fn get_credentials(&self) -> Result<Vec<CredentialRef>> {
27 self.find_credentials(Default::default()).await
28 }
29
30 pub(crate) async fn add_credential(&self, credential: Credential) -> Result<CredentialRef> {
34 let credential = self.add_credential_producing_arc(credential).await?;
35 Ok(CredentialRef::from_credential(&credential))
36 }
37
38 pub(crate) async fn add_credential_producing_arc(&self, credential: Credential) -> Result<Arc<Credential>> {
46 if *credential.client_id() != self.id().await? {
47 return Err(Error::WrongCredential);
48 }
49
50 self.add_credential_without_clientid_check(credential).await
51 }
52
53 pub(crate) async fn add_credential_without_clientid_check(
60 &self,
61 mut credential: Credential,
62 ) -> Result<Arc<Credential>> {
63 let _credential_ref = credential
64 .save(&self.crypto_provider.keystore())
65 .await
66 .map_err(RecursiveError::mls_credential("saving credential"))?;
67
68 let guard = self.inner.upgradable_read().await;
69
70 let mut guard = async_lock::RwLockUpgradableReadGuard::upgrade(guard).await;
72 let inner = guard.as_mut().ok_or(Error::MlsNotInitialized)?;
73 let credential = inner.identities.push_credential(credential).await?;
74
75 Ok(credential)
76 }
77
78 pub async fn remove_credential(&self, credential_ref: &CredentialRef) -> Result<()> {
83 if *credential_ref.client_id() != self.id().await? {
85 return Err(Error::WrongCredential);
86 }
87
88 let database = self.crypto_provider.keystore();
89
90 let credential = credential_ref
91 .load(&database)
92 .await
93 .map_err(RecursiveError::mls_credential_ref(
94 "loading all credentials from ref to remove from session identities",
95 ))?;
96
97 for (conversation_id, conversation) in
103 MlsConversation::load_all(&database)
104 .await
105 .map_err(RecursiveError::mls_conversation(
106 "loading all conversations to check if the credential to be removed is present",
107 ))?
108 {
109 let converation_credential = conversation
110 .own_mls_credential()
111 .map_err(RecursiveError::mls_conversation("geting conversation credential"))?;
112 if credential.mls_credential() == converation_credential {
113 return Err(Error::CredentialStillInUse(conversation_id));
114 }
115 }
116
117 self.remove_keypackages_for(credential_ref).await?;
119
120 {
124 let mut inner = self.inner.write().await;
125 let inner = inner.as_mut().ok_or(Error::MlsNotInitialized)?;
126 inner.identities.remove_by_mls_credential(credential.mls_credential());
127 }
128
129 credential
131 .delete(&database)
132 .await
133 .map_err(RecursiveError::mls_credential("deleting credential from keystore"))
134 .map_err(Into::into)
135 }
136
137 pub(crate) async fn find_most_recent_credential(
139 &self,
140 signature_scheme: SignatureScheme,
141 credential_type: CredentialType,
142 ) -> Result<Arc<Credential>> {
143 match &*self.inner.read().await {
144 None => Err(Error::MlsNotInitialized),
145 Some(SessionInner { identities, .. }) => identities
146 .find_most_recent_credential(signature_scheme, credential_type)
147 .await
148 .ok_or(Error::CredentialNotFound(credential_type, signature_scheme)),
149 }
150 }
151
152 pub(crate) async fn find_credential_by_public_key(
154 &self,
155 signature_scheme: SignatureScheme,
156 credential_type: CredentialType,
157 public_key: &SignaturePublicKey,
158 ) -> Result<Arc<Credential>> {
159 match &*self.inner.read().await {
160 None => Err(Error::MlsNotInitialized),
161 Some(SessionInner { identities, .. }) => identities
162 .find_credential_by_public_key(signature_scheme, credential_type, public_key)
163 .await
164 .ok_or(Error::CredentialNotFound(credential_type, signature_scheme)),
165 }
166 }
167
168 pub(crate) async fn find_most_recent_or_create_basic_credential(
172 &self,
173 ciphersuite: Ciphersuite,
174 credential_type: CredentialType,
175 ) -> Result<Arc<Credential>> {
176 let credential = match self
177 .find_most_recent_credential(ciphersuite.signature_algorithm(), credential_type)
178 .await
179 {
180 Ok(credential) => credential,
181 Err(Error::CredentialNotFound(..)) if credential_type == CredentialType::Basic => {
182 let client_id = self.id().await?;
183 let credential = Credential::basic(ciphersuite, client_id, &self.crypto_provider).map_err(
184 RecursiveError::mls_credential(
185 "creating basic credential in find_most_recent_or_create_basic_credential",
186 ),
187 )?;
188 self.add_credential_producing_arc(credential).await?
189 }
190 Err(Error::CredentialNotFound(..)) if credential_type == CredentialType::X509 => {
191 return Err(LeafError::E2eiEnrollmentNotDone.into());
192 }
193 Err(err) => return Err(err),
194 };
195 Ok(credential)
196 }
197}