1use crate::mls::MlsCentral;
5#[cfg(feature = "proteus")]
6use crate::proteus::ProteusCentral;
7use crate::{
8 group_store::GroupStore,
9 prelude::{Client, MlsConversation},
10 CoreCrypto, CoreCryptoCallbacks, CryptoError, CryptoResult,
11};
12use async_lock::{Mutex, RwLock, RwLockReadGuardArc, RwLockWriteGuardArc};
13use core_crypto_keystore::connection::FetchFromDatabase;
14use core_crypto_keystore::entities::ConsumerData;
15use core_crypto_keystore::CryptoKeystoreError;
16use mls_crypto_provider::{CryptoKeystore, MlsCryptoProvider};
17use std::{ops::Deref, sync::Arc};
18
19#[derive(Debug, Clone)]
26pub struct CentralContext {
27 state: Arc<RwLock<ContextState>>,
28}
29
30#[derive(Debug, Clone)]
34enum ContextState {
35 Valid {
36 provider: MlsCryptoProvider,
37 callbacks: Arc<RwLock<Option<std::sync::Arc<dyn CoreCryptoCallbacks + 'static>>>>,
38 mls_client: Client,
39 mls_groups: Arc<RwLock<GroupStore<MlsConversation>>>,
40 #[cfg(feature = "proteus")]
41 proteus_central: Arc<Mutex<Option<ProteusCentral>>>,
42 },
43 Invalid,
44}
45
46impl CoreCrypto {
47 pub async fn new_transaction(&self) -> CryptoResult<CentralContext> {
51 CentralContext::new(
52 &self.mls,
53 #[cfg(feature = "proteus")]
54 self.proteus.clone(),
55 )
56 .await
57 }
58}
59
60impl CentralContext {
61 async fn new(
62 mls_central: &MlsCentral,
63 #[cfg(feature = "proteus")] proteus_central: Arc<Mutex<Option<ProteusCentral>>>,
64 ) -> CryptoResult<Self> {
65 mls_central.mls_backend.new_transaction().await?;
66 let mls_groups = Arc::new(RwLock::new(Default::default()));
67 let callbacks = mls_central.callbacks.clone();
68 let mls_client = mls_central.mls_client.clone();
69 Ok(Self {
70 state: Arc::new(
71 ContextState::Valid {
72 mls_client,
73 callbacks,
74 provider: mls_central.mls_backend.clone(),
75 mls_groups,
76 #[cfg(feature = "proteus")]
77 proteus_central,
78 }
79 .into(),
80 ),
81 })
82 }
83
84 pub(crate) async fn mls_client(&self) -> CryptoResult<Client> {
85 match self.state.read().await.deref() {
86 ContextState::Valid { mls_client, .. } => Ok(mls_client.clone()),
87 ContextState::Invalid => Err(CryptoError::InvalidContext),
88 }
89 }
90
91 pub(crate) async fn callbacks(
92 &self,
93 ) -> CryptoResult<RwLockReadGuardArc<Option<Arc<dyn CoreCryptoCallbacks + 'static>>>> {
94 match self.state.read().await.deref() {
95 ContextState::Valid { callbacks, .. } => Ok(callbacks.read_arc().await),
96 ContextState::Invalid => Err(CryptoError::InvalidContext),
97 }
98 }
99
100 #[cfg(test)]
101 pub(crate) async fn set_callbacks(
102 &self,
103 callbacks: Option<Arc<dyn CoreCryptoCallbacks + 'static>>,
104 ) -> CryptoResult<()> {
105 match self.state.read().await.deref() {
106 ContextState::Valid { callbacks: cbs, .. } => {
107 *cbs.write_arc().await = callbacks;
108 Ok(())
109 }
110 ContextState::Invalid => Err(CryptoError::InvalidContext),
111 }
112 }
113
114 pub async fn mls_provider(&self) -> CryptoResult<MlsCryptoProvider> {
116 match self.state.read().await.deref() {
117 ContextState::Valid { provider, .. } => Ok(provider.clone()),
118 ContextState::Invalid => Err(CryptoError::InvalidContext),
119 }
120 }
121
122 pub(crate) async fn keystore(&self) -> CryptoResult<CryptoKeystore> {
123 match self.state.read().await.deref() {
124 ContextState::Valid { provider, .. } => Ok(provider.keystore()),
125 ContextState::Invalid => Err(CryptoError::InvalidContext),
126 }
127 }
128
129 pub(crate) async fn mls_groups(&self) -> CryptoResult<RwLockWriteGuardArc<GroupStore<MlsConversation>>> {
130 match self.state.read().await.deref() {
131 ContextState::Valid { mls_groups, .. } => Ok(mls_groups.write_arc().await),
132 ContextState::Invalid => Err(CryptoError::InvalidContext),
133 }
134 }
135
136 #[cfg(feature = "proteus")]
137 pub(crate) async fn proteus_central(&self) -> CryptoResult<Arc<Mutex<Option<ProteusCentral>>>> {
138 match self.state.read().await.deref() {
139 ContextState::Valid { proteus_central, .. } => Ok(proteus_central.clone()),
140 ContextState::Invalid => Err(CryptoError::InvalidContext),
141 }
142 }
143
144 pub async fn finish(&self) -> CryptoResult<()> {
148 let mut guard = self.state.write().await;
149 let commit_result = match guard.deref() {
150 ContextState::Valid { provider, .. } => provider.keystore().commit_transaction().await,
151 ContextState::Invalid => return Err(CryptoError::InvalidContext),
152 };
153 *guard = ContextState::Invalid;
154 commit_result.map_err(Into::into)
155 }
156
157 pub async fn abort(&self) -> CryptoResult<()> {
161 let mut guard = self.state.write().await;
162 let rollback_result = match guard.deref() {
163 ContextState::Valid { provider, .. } => provider.keystore().rollback_transaction().await,
164 ContextState::Invalid => return Err(CryptoError::InvalidContext),
165 };
166 *guard = ContextState::Invalid;
167 rollback_result.map_err(Into::into)
168 }
169
170 pub async fn set_data(&self, data: Vec<u8>) -> CryptoResult<()> {
174 self.keystore().await?.save(ConsumerData::from(data)).await?;
175 Ok(())
176 }
177
178 pub async fn get_data(&self) -> CryptoResult<Option<Vec<u8>>> {
181 match self.keystore().await?.find_unique::<ConsumerData>().await {
182 Ok(data) => Ok(Some(data.into())),
183 Err(CryptoKeystoreError::NotFound(..)) => Ok(None),
184 Err(err) => Err(err.into()),
185 }
186 }
187}