mod error;
pub(crate) mod id;
pub(crate) mod identifier;
pub(crate) mod identities;
pub(crate) mod key_package;
pub(crate) mod user_id;
use crate::{
    KeystoreError, LeafError, MlsError, RecursiveError,
    mls::credential::{CredentialBundle, ext::CredentialExt},
    prelude::{
        CertificateBundle, ClientId, MlsCiphersuite, MlsCredentialType, identifier::ClientIdentifier,
        key_package::KEYPACKAGE_DEFAULT_LIFETIME,
    },
};
pub(crate) use error::{Error, Result};
use async_lock::RwLock;
use core_crypto_keystore::{Connection, CryptoKeystoreError, connection::FetchFromDatabase};
use log::debug;
use openmls::prelude::{Credential, CredentialType};
use openmls_basic_credential::SignatureKeyPair;
use openmls_traits::{OpenMlsCryptoProvider, crypto::OpenMlsCrypto, types::SignatureScheme};
use std::collections::HashSet;
use std::ops::{Deref, DerefMut};
use std::sync::Arc;
use tls_codec::{Deserialize, Serialize};
use core_crypto_keystore::entities::{EntityFindParams, MlsCredential, MlsSignatureKeyPair};
use identities::ClientIdentities;
use mls_crypto_provider::MlsCryptoProvider;
/// Handle to the MLS client state.
///
/// The state is an `Option<ClientInner>` kept behind an `Arc<RwLock<_>>`. It is `None` for a
/// default-constructed handle and becomes `Some` once one of the initialization methods stores a
/// `ClientInner`. Cloning the handle clones the `Arc`, so clones share the same state.
#[derive(Clone, Debug, Default)]
pub struct Client {
    /// Shared, lock-protected inner state; `None` means the client has not been initialized.
    state: Arc<RwLock<Option<ClientInner>>>,
}
/// Data held by an initialized [`Client`].
#[derive(Debug, Clone)]
struct ClientInner {
    /// Identifier of this client.
    id: ClientId,
    /// Credential bundles registered for this client, pushed per signature scheme.
    pub(crate) identities: ClientIdentities,
    /// Set to `KEYPACKAGE_DEFAULT_LIFETIME` by every constructor in this file.
    /// NOTE(review): presumably the validity period of generated key packages; confirm at the use sites.
    keypackage_lifetime: std::time::Duration,
}
impl Client {
    /// Initializes this client, restoring a stored identity or generating a new one.
    ///
    /// Every `MlsCredential` in the keystore is fetched and the ones stored under the id of
    /// `identifier` are decoded. If at least one exists the client is restored through `load`;
    /// when `load` reports `Error::ClientSignatureNotFound`, or when no credential was stored at
    /// all, a new client is created through `generate`.
    ///
    /// # Arguments
    /// * `identifier` - identifier from which the client id is obtained
    /// * `ciphersuites` - their signature algorithms are handed to `load`; also forwarded to `generate`
    /// * `backend` - crypto provider giving access to the keystore
    /// * `nb_key_package` - forwarded to `generate` when a new client has to be created
    ///
    /// # Errors
    /// `Error::UnexpectedlyReady` if the client is already initialized, keystore and TLS decoding
    /// failures, and any error returned by `load` or `generate`.
    pub async fn init(
        &self,
        identifier: ClientIdentifier,
        ciphersuites: &[MlsCiphersuite],
        backend: &MlsCryptoProvider,
        nb_key_package: usize,
    ) -> Result<()> {
        self.ensure_unready().await?;
        let id = identifier.get_id()?;

        let stored_credentials = backend
            .key_store()
            .find_all::<MlsCredential>(EntityFindParams::default())
            .await
            .map_err(KeystoreError::wrap("finding all mls credentials"))?;

        // Keep only the credentials stored under our id, paired with their creation timestamp.
        let mut credentials = Vec::new();
        for stored in stored_credentials {
            if &stored.id[..] != id.as_slice() {
                continue;
            }
            let credential = Credential::tls_deserialize(&mut stored.credential.as_slice())
                .map_err(Error::tls_deserialize("mls credential"))?;
            credentials.push((credential, stored.created_at));
        }

        // Nothing stored for this id: start from scratch.
        if credentials.is_empty() {
            debug!(count = nb_key_package, ciphersuites:? = ciphersuites; "Generating client");
            return self.generate(identifier, backend, ciphersuites, nb_key_package).await;
        }

        let signature_schemes: HashSet<_> = ciphersuites.iter().map(|cs| cs.signature_algorithm()).collect();
        match self.load(backend, id.as_ref(), credentials, signature_schemes).await {
            Err(Error::ClientSignatureNotFound) => {
                debug!(count = nb_key_package, ciphersuites:? = ciphersuites; "Client signature not found. Generating client");
                self.generate(identifier, backend, ciphersuites, nb_key_package).await
            }
            outcome => outcome,
        }
    }
    /// Returns `true` when the inner state has been set, i.e. the client is initialized.
    pub(crate) async fn is_ready(&self) -> bool {
        self.state.read().await.is_some()
    }
    /// Fails with `Error::UnexpectedlyReady` when the client is already initialized.
    async fn ensure_unready(&self) -> Result<()> {
        if !self.is_ready().await {
            return Ok(());
        }
        Err(Error::UnexpectedlyReady)
    }
    /// Stores `new_inner` as the client state under the write lock, replacing any previous state.
    async fn replace_inner(&self, new_inner: ClientInner) {
        *self.state.write().await = Some(new_inner);
    }
    pub async fn generate_raw_keypairs(
        &self,
        ciphersuites: &[MlsCiphersuite],
        backend: &MlsCryptoProvider,
    ) -> Result<Vec<ClientId>> {
        self.ensure_unready().await?;
        const TEMP_KEY_SIZE: usize = 16;
        let credentials = Self::find_all_basic_credentials(backend).await?;
        if !credentials.is_empty() {
            return Err(Error::IdentityAlreadyPresent);
        }
        use openmls_traits::random::OpenMlsRand as _;
        let mut tmp_client_ids = Vec::with_capacity(ciphersuites.len());
        for cs in ciphersuites {
            let tmp_client_id: ClientId = backend
                .rand()
                .random_vec(TEMP_KEY_SIZE)
                .map_err(MlsError::wrap("generating random client id"))?
                .into();
            let cb = Self::new_basic_credential_bundle(&tmp_client_id, cs.signature_algorithm(), backend)
                .map_err(RecursiveError::mls_credential("creating new basic credential bundle"))?;
            let sign_kp = MlsSignatureKeyPair::new(
                cs.signature_algorithm(),
                cb.signature_key.to_public_vec(),
                cb.signature_key
                    .tls_serialize_detached()
                    .map_err(Error::tls_serialize("signature key"))?,
                tmp_client_id.clone().into(),
            );
            backend
                .key_store()
                .save(sign_kp)
                .await
                .map_err(KeystoreError::wrap("save signature keypair in keystore"))?;
            tmp_client_ids.push(tmp_client_id);
        }
        Ok(tmp_client_ids)
    }
    pub async fn init_with_external_client_id(
        &self,
        client_id: ClientId,
        tmp_ids: Vec<ClientId>,
        ciphersuites: &[MlsCiphersuite],
        backend: &MlsCryptoProvider,
    ) -> Result<()> {
        self.ensure_unready().await?;
        let stored_skp = backend
            .key_store()
            .find_all::<MlsSignatureKeyPair>(EntityFindParams::default())
            .await
            .map_err(KeystoreError::wrap("finding all mls signature keypairs"))?;
        match stored_skp.len().cmp(&tmp_ids.len()) {
            std::cmp::Ordering::Less => return Err(Error::NoProvisionalIdentityFound),
            std::cmp::Ordering::Greater => return Err(Error::TooManyIdentitiesPresent),
            _ => {}
        }
        let all_tmp_ids_exist = stored_skp
            .iter()
            .all(|kp| tmp_ids.contains(&kp.credential_id.as_slice().into()));
        if !all_tmp_ids_exist {
            return Err(Error::NoProvisionalIdentityFound);
        }
        let identities = stored_skp.iter().zip(ciphersuites);
        self.replace_inner(ClientInner {
            id: client_id.clone(),
            identities: ClientIdentities::new(stored_skp.len()),
            keypackage_lifetime: KEYPACKAGE_DEFAULT_LIFETIME,
        })
        .await;
        let id = &client_id;
        for (tmp_kp, &cs) in identities {
            let scheme = tmp_kp
                .signature_scheme
                .try_into()
                .map_err(|_| Error::InvalidSignatureScheme)?;
            let new_keypair =
                MlsSignatureKeyPair::new(scheme, tmp_kp.pk.clone(), tmp_kp.keypair.clone(), id.clone().into());
            let new_credential = MlsCredential {
                id: id.clone().into(),
                credential: tmp_kp.credential_id.clone(),
                created_at: 0,
            };
            backend
                .key_store()
                .remove::<MlsSignatureKeyPair, &[u8]>(&new_keypair.pk)
                .await
                .map_err(KeystoreError::wrap("removing mls signature keypair"))?;
            let signature_key = SignatureKeyPair::tls_deserialize(&mut new_keypair.keypair.as_slice())
                .map_err(Error::tls_deserialize("signature key"))?;
            let cb = CredentialBundle {
                credential: Credential::new_basic(new_credential.credential.clone()),
                signature_key,
                created_at: 0, };
            self.save_identity(&backend.keystore(), Some(id), cs.signature_algorithm(), cb)
                .await?;
        }
        Ok(())
    }
    pub(crate) async fn generate(
        &self,
        identifier: ClientIdentifier,
        backend: &MlsCryptoProvider,
        ciphersuites: &[MlsCiphersuite],
        nb_key_package: usize,
    ) -> Result<()> {
        self.ensure_unready().await?;
        let id = identifier.get_id()?;
        let signature_schemes = ciphersuites
            .iter()
            .map(|cs| cs.signature_algorithm())
            .collect::<HashSet<_>>();
        self.replace_inner(ClientInner {
            id: id.into_owned(),
            identities: ClientIdentities::new(signature_schemes.len()),
            keypackage_lifetime: KEYPACKAGE_DEFAULT_LIFETIME,
        })
        .await;
        let identities = identifier.generate_credential_bundles(backend, signature_schemes)?;
        for (sc, id, cb) in identities {
            self.save_identity(&backend.keystore(), Some(&id), sc, cb).await?;
        }
        let identities = match self.state.read().await.deref() {
            None => return Err(Error::MlsNotInitialized),
            Some(ClientInner { identities, .. }) => identities.clone(),
        };
        if nb_key_package != 0 {
            for cs in ciphersuites {
                let sc = cs.signature_algorithm();
                let identity = identities.iter().filter(|(id_sc, _)| id_sc == &sc);
                for (_, cb) in identity {
                    self.request_key_packages(nb_key_package, *cs, cb.credential.credential_type().into(), backend)
                        .await?;
                }
            }
        }
        Ok(())
    }
    pub(crate) async fn load(
        &self,
        backend: &MlsCryptoProvider,
        id: &ClientId,
        mut credentials: Vec<(Credential, u64)>,
        signature_schemes: HashSet<SignatureScheme>,
    ) -> Result<()> {
        self.ensure_unready().await?;
        let mut identities = ClientIdentities::new(signature_schemes.len());
        credentials.sort_by_key(|(_, timestamp)| *timestamp);
        let store_skps = backend
            .key_store()
            .find_all::<MlsSignatureKeyPair>(EntityFindParams::default())
            .await
            .map_err(KeystoreError::wrap("finding all mls signature keypairs"))?;
        for sc in signature_schemes {
            let kp = store_skps.iter().find(|skp| skp.signature_scheme == (sc as u16));
            let signature_key = if let Some(kp) = kp {
                SignatureKeyPair::tls_deserialize(&mut kp.keypair.as_slice())
                    .map_err(Error::tls_deserialize("signature keypair"))?
            } else {
                let (sk, pk) = backend
                    .crypto()
                    .signature_key_gen(sc)
                    .map_err(MlsError::wrap("generating signature key"))?;
                let keypair = SignatureKeyPair::from_raw(sc, sk, pk.clone());
                let raw_keypair = keypair
                    .tls_serialize_detached()
                    .map_err(Error::tls_serialize("raw keypair"))?;
                let store_keypair = MlsSignatureKeyPair::new(sc, pk, raw_keypair, id.as_slice().into());
                backend
                    .key_store()
                    .save(store_keypair.clone())
                    .await
                    .map_err(KeystoreError::wrap("storing keypairs in keystore"))?;
                SignatureKeyPair::tls_deserialize(&mut store_keypair.keypair.as_slice())
                    .map_err(Error::tls_deserialize("signature keypair"))?
            };
            for (credential, created_at) in &credentials {
                match credential.mls_credential() {
                    openmls::prelude::MlsCredentialType::Basic(_) => {
                        if id.as_slice() != credential.identity() {
                            return Err(Error::WrongCredential);
                        }
                    }
                    openmls::prelude::MlsCredentialType::X509(cert) => {
                        let spk = cert
                            .extract_public_key()
                            .map_err(RecursiveError::mls_credential("extracting public key"))?
                            .ok_or(LeafError::InternalMlsError)?;
                        if signature_key.public() != spk {
                            return Err(Error::WrongCredential);
                        }
                    }
                };
                let cb = CredentialBundle {
                    credential: credential.clone(),
                    signature_key: signature_key.clone(),
                    created_at: *created_at,
                };
                identities.push_credential_bundle(sc, cb).await?;
            }
        }
        self.replace_inner(ClientInner {
            id: id.clone(),
            identities,
            keypackage_lifetime: KEYPACKAGE_DEFAULT_LIFETIME,
        })
        .await;
        Ok(())
    }
    /// Returns every credential of type `Basic` found in the keystore.
    ///
    /// All stored `MlsCredential`s are fetched and TLS-decoded; credentials of any other type are
    /// skipped.
    ///
    /// # Errors
    /// Keystore failures, and a TLS deserialization error as soon as one stored credential
    /// cannot be decoded (regardless of its type).
    async fn find_all_basic_credentials(backend: &MlsCryptoProvider) -> Result<Vec<Credential>> {
        let store_credentials = backend
            .key_store()
            .find_all::<MlsCredential>(EntityFindParams::default())
            .await
            .map_err(KeystoreError::wrap("finding all mls credentials"))?;

        // Upper bound on the result size: non-basic credentials are filtered out below.
        let mut credentials = Vec::with_capacity(store_credentials.len());
        for store_credential in store_credentials {
            let credential = Credential::tls_deserialize(&mut store_credential.credential.as_slice())
                .map_err(Error::tls_deserialize("credential"))?;
            if matches!(credential.credential_type(), CredentialType::Basic) {
                credentials.push(credential);
            }
        }

        Ok(credentials)
    }
    pub(crate) async fn save_identity(
        &self,
        keystore: &Connection,
        id: Option<&ClientId>,
        sc: SignatureScheme,
        mut cb: CredentialBundle,
    ) -> Result<CredentialBundle> {
        match self.state.write().await.deref_mut() {
            None => Err(Error::MlsNotInitialized),
            Some(ClientInner {
                id: existing_id,
                identities,
                ..
            }) => {
                let id = id.unwrap_or(existing_id);
                let credential = cb
                    .credential
                    .tls_serialize_detached()
                    .map_err(Error::tls_serialize("credential bundle"))?;
                let credential = MlsCredential {
                    id: id.clone().into(),
                    credential,
                    created_at: 0,
                };
                let credential = keystore
                    .save(credential)
                    .await
                    .map_err(KeystoreError::wrap("saving credential"))?;
                let sign_kp = MlsSignatureKeyPair::new(
                    sc,
                    cb.signature_key.to_public_vec(),
                    cb.signature_key
                        .tls_serialize_detached()
                        .map_err(Error::tls_serialize("signature keypair"))?,
                    id.clone().into(),
                );
                keystore.save(sign_kp).await.map_err(|e| match e {
                    CryptoKeystoreError::AlreadyExists => Error::CredentialBundleConflict,
                    _ => KeystoreError::wrap("saving mls signature key pair")(e).into(),
                })?;
                cb.created_at = credential.created_at;
                identities.push_credential_bundle(sc, cb.clone()).await?;
                Ok(cb)
            }
        }
    }
    pub async fn id(&self) -> Result<ClientId> {
        match self.state.read().await.deref() {
            None => Err(Error::MlsNotInitialized),
            Some(ClientInner { id, .. }) => Ok(id.clone()),
        }
    }
    pub async fn is_e2ei_capable(&self) -> bool {
        match self.state.read().await.deref() {
            None => false,
            Some(ClientInner { identities, .. }) => identities
                .iter()
                .any(|(_, cred)| cred.credential().credential_type() == CredentialType::X509),
        }
    }
    pub(crate) async fn get_most_recent_or_create_credential_bundle(
        &self,
        backend: &MlsCryptoProvider,
        sc: SignatureScheme,
        ct: MlsCredentialType,
    ) -> Result<Arc<CredentialBundle>> {
        match ct {
            MlsCredentialType::Basic => {
                self.init_basic_credential_bundle_if_missing(backend, sc).await?;
                self.find_most_recent_credential_bundle(sc, ct).await
            }
            MlsCredentialType::X509 => self
                .find_most_recent_credential_bundle(sc, ct)
                .await
                .map_err(|e| match e {
                    Error::CredentialNotFound(_) => LeafError::E2eiEnrollmentNotDone.into(),
                    _ => e,
                }),
        }
    }
    pub(crate) async fn init_basic_credential_bundle_if_missing(
        &self,
        backend: &MlsCryptoProvider,
        sc: SignatureScheme,
    ) -> Result<()> {
        let existing_cb = self
            .find_most_recent_credential_bundle(sc, MlsCredentialType::Basic)
            .await;
        if matches!(existing_cb, Err(Error::CredentialNotFound(_))) {
            let id = self.id().await?;
            debug!(id:% = &id; "Initializing basic credential bundle");
            let cb = Self::new_basic_credential_bundle(&id, sc, backend)
                .map_err(RecursiveError::mls_credential("creating new basic credential bundle"))?;
            self.save_identity(&backend.keystore(), None, sc, cb).await?;
        }
        Ok(())
    }
    /// Builds an X509 credential bundle from a certificate bundle and saves it.
    ///
    /// The client id is read from the certificate bundle and used as the id under which
    /// `save_identity` stores the new bundle.
    ///
    /// # Errors
    /// Errors from extracting the client id, building the bundle, or `save_identity`.
    pub(crate) async fn save_new_x509_credential_bundle(
        &self,
        keystore: &Connection,
        sc: SignatureScheme,
        cb: CertificateBundle,
    ) -> Result<CredentialBundle> {
        let certificate_client_id = cb
            .get_client_id()
            .map_err(RecursiveError::mls_credential("getting client id"))?;
        let credential_bundle = Self::new_x509_credential_bundle(cb)
            .map_err(RecursiveError::mls_credential("creating new x509 credential bundle"))?;

        self.save_identity(keystore, Some(&certificate_client_id), sc, credential_bundle)
            .await
    }
}
#[cfg(test)]
impl Client {
    #![allow(missing_docs)]
    /// Test helper: generates a client with a random id of the form
    /// `<uuid>:<hex>@members.wire.com`.
    ///
    /// The identifier kind follows `case.credential_type`; X509 requires `signer` and panics
    /// without it. When `provision` is set, `INITIAL_KEYING_MATERIAL_COUNT` key packages are
    /// requested, otherwise none.
    pub async fn random_generate(
        case: &crate::test_utils::TestCase,
        backend: &MlsCryptoProvider,
        signer: Option<&crate::test_utils::x509::X509Certificate>,
        provision: bool,
    ) -> Result<Self> {
        let user_uuid = uuid::Uuid::new_v4();
        let rnd_id = rand::random::<usize>();
        let client_id = format!("{}:{rnd_id:x}@members.wire.com", user_uuid.hyphenated());

        let identity = match case.credential_type {
            MlsCredentialType::X509 => {
                let intermediate_ca = signer.expect("Missing intermediate CA");
                CertificateBundle::rand_identifier(&client_id, &[intermediate_ca])
            }
            MlsCredentialType::Basic => ClientIdentifier::Basic(client_id.as_str().into()),
        };

        let nb_key_package = match provision {
            true => crate::prelude::INITIAL_KEYING_MATERIAL_COUNT,
            false => 0,
        };

        let client = Self::default();
        client
            .generate(identity, backend, &[case.ciphersuite()], nb_key_package)
            .await?;
        Ok(client)
    }
    /// Test helper: fetches up to `u32::MAX` key packages from the keystore.
    pub async fn find_keypackages(&self, backend: &MlsCryptoProvider) -> Result<Vec<openmls::prelude::KeyPackage>> {
        use core_crypto_keystore::CryptoKeystoreMls as _;

        let fetched = backend
            .key_store()
            .mls_fetch_keypackages::<openmls::prelude::KeyPackage>(u32::MAX)
            .await
            .map_err(KeystoreError::wrap("fetching mls keypackages"));
        Ok(fetched?)
    }
}
#[cfg(test)]
mod tests {
    use crate::prelude::ClientId;
    use crate::test_utils::*;
    use core_crypto_keystore::connection::FetchFromDatabase;
    use core_crypto_keystore::entities::{EntityFindParams, MlsSignatureKeyPair};
    use mls_crypto_provider::MlsCryptoProvider;
    use wasm_bindgen_test::*;
    use super::Client;
    wasm_bindgen_test_configure!(run_in_browser);
    /// A client can be generated against an in-memory backend; X509 cases first register a
    /// test certificate chain with the provider.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    async fn can_generate_client(case: TestCase) {
        let backend = MlsCryptoProvider::try_new_in_memory("test").await.unwrap();

        let x509_test_chain = if case.is_x509() {
            let chain = crate::test_utils::x509::X509TestChain::init_empty(case.signature_scheme());
            chain.register_with_provider(&backend).await;
            Some(chain)
        } else {
            None
        };

        backend.new_transaction().await.unwrap();

        let signer = x509_test_chain
            .as_ref()
            .map(|chain| chain.find_local_intermediate_ca());
        let _ = Client::random_generate(&case, &backend, signer, false)
            .await
            .unwrap();
    }
    /// Provisional keypairs can be generated first and the client finalized afterwards with an
    /// externally supplied client id. Only runs for basic credentials.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    async fn can_externally_generate_client(case: TestCase) {
        if !case.is_basic() {
            return;
        }

        run_tests(move |[tmp_dir_argument]| {
            Box::pin(async move {
                let backend = MlsCryptoProvider::try_new(tmp_dir_argument, "test").await.unwrap();
                backend.new_transaction().await.unwrap();

                let client_id: ClientId = b"whatever:my:client:is@world.com".to_vec().into();
                let ciphersuites = [case.ciphersuite()];
                let alice = Client::default();

                // Step 1: provisional keypairs, stored under a random id.
                let handles = alice.generate_raw_keypairs(&ciphersuites, &backend).await.unwrap();

                let mut stored_keypairs = backend
                    .keystore()
                    .find_all::<MlsSignatureKeyPair>(EntityFindParams::default())
                    .await
                    .unwrap();
                assert_eq!(stored_keypairs.len(), 1);

                let provisional_keypair = stored_keypairs.pop().unwrap();
                let provisional_id: ClientId = provisional_keypair.credential_id.as_slice().into();
                assert_eq!(&provisional_id, handles.first().unwrap());

                // Step 2: finalize with the real client id.
                alice
                    .init_with_external_client_id(client_id.clone(), handles.clone(), &ciphersuites, &backend)
                    .await
                    .unwrap();
                assert_eq!(alice.id().await.unwrap(), client_id);

                // The credential identity is still the provisional id.
                let bundle = alice
                    .find_most_recent_credential_bundle(case.signature_scheme(), case.credential_type)
                    .await
                    .unwrap();
                let credential_identity: ClientId = bundle.credential().identity().into();
                assert_eq!(&credential_identity, handles.first().unwrap());
            })
        })
        .await
    }
}