1use crate::mls::HasClientAndProvider;
5#[cfg(feature = "proteus")]
6use crate::proteus::ProteusCentral;
7use crate::{
8 CoreCrypto, Error, KeystoreError, MlsError, MlsTransport, RecursiveError, Result,
9 group_store::GroupStore,
10 mls::MlsCentral,
11 prelude::{Client, MlsConversation},
12};
13use async_lock::{Mutex, RwLock, RwLockReadGuardArc, RwLockWriteGuardArc};
14use core_crypto_keystore::{CryptoKeystoreError, connection::FetchFromDatabase, entities::ConsumerData};
15use mls_crypto_provider::{CryptoKeystore, MlsCryptoProvider};
16use std::{ops::Deref, sync::Arc};
17
18#[derive(Debug, Clone)]
25pub struct CentralContext {
26 state: Arc<RwLock<ContextState>>,
27}
28
29#[derive(Debug, Clone)]
33enum ContextState {
34 Valid {
35 provider: MlsCryptoProvider,
36 transport: Arc<RwLock<Option<Arc<dyn MlsTransport + 'static>>>>,
37 mls_client: Client,
38 mls_groups: Arc<RwLock<GroupStore<MlsConversation>>>,
39 #[cfg(feature = "proteus")]
40 proteus_central: Arc<Mutex<Option<ProteusCentral>>>,
41 },
42 Invalid,
43}
44
45impl CoreCrypto {
46 pub async fn new_transaction(&self) -> Result<CentralContext> {
50 CentralContext::new(
51 &self.mls,
52 #[cfg(feature = "proteus")]
53 self.proteus.clone(),
54 )
55 .await
56 }
57}
58
59#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
60#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
61impl HasClientAndProvider for CentralContext {
62 async fn client(&self) -> crate::mls::Result<Client> {
63 self.mls_client()
64 .await
65 .map_err(RecursiveError::root("getting mls client"))
66 .map_err(Into::into)
67 }
68
69 async fn mls_provider(&self) -> crate::mls::Result<MlsCryptoProvider> {
70 self.mls_provider()
71 .await
72 .map_err(RecursiveError::root("getting mls provider"))
73 .map_err(Into::into)
74 }
75}
76
77impl CentralContext {
78 async fn new(
79 mls_central: &MlsCentral,
80 #[cfg(feature = "proteus")] proteus_central: Arc<Mutex<Option<ProteusCentral>>>,
81 ) -> Result<Self> {
82 mls_central
83 .mls_backend
84 .new_transaction()
85 .await
86 .map_err(MlsError::wrap("creating new transaction"))?;
87 let mls_groups = Arc::new(RwLock::new(Default::default()));
88 let callbacks = mls_central.transport.clone();
89 let mls_client = mls_central.mls_client.clone();
90 Ok(Self {
91 state: Arc::new(
92 ContextState::Valid {
93 mls_client,
94 transport: callbacks,
95 provider: mls_central.mls_backend.clone(),
96 mls_groups,
97 #[cfg(feature = "proteus")]
98 proteus_central,
99 }
100 .into(),
101 ),
102 })
103 }
104
105 pub(crate) async fn mls_client(&self) -> Result<Client> {
106 match self.state.read().await.deref() {
107 ContextState::Valid { mls_client, .. } => Ok(mls_client.clone()),
108 ContextState::Invalid => Err(Error::InvalidContext),
109 }
110 }
111
112 pub(crate) async fn mls_transport(&self) -> Result<RwLockReadGuardArc<Option<Arc<dyn MlsTransport + 'static>>>> {
113 match self.state.read().await.deref() {
114 ContextState::Valid {
115 transport: callbacks, ..
116 } => Ok(callbacks.read_arc().await),
117 ContextState::Invalid => Err(Error::InvalidContext),
118 }
119 }
120
121 #[cfg(test)]
122 pub(crate) async fn set_transport_callbacks(
123 &self,
124 callbacks: Option<Arc<dyn MlsTransport + 'static>>,
125 ) -> Result<()> {
126 match self.state.read().await.deref() {
127 ContextState::Valid { transport: cbs, .. } => {
128 *cbs.write_arc().await = callbacks;
129 Ok(())
130 }
131 ContextState::Invalid => Err(Error::InvalidContext),
132 }
133 }
134
135 pub async fn mls_provider(&self) -> Result<MlsCryptoProvider> {
137 match self.state.read().await.deref() {
138 ContextState::Valid { provider, .. } => Ok(provider.clone()),
139 ContextState::Invalid => Err(Error::InvalidContext),
140 }
141 }
142
143 pub(crate) async fn keystore(&self) -> Result<CryptoKeystore> {
144 match self.state.read().await.deref() {
145 ContextState::Valid { provider, .. } => Ok(provider.keystore()),
146 ContextState::Invalid => Err(Error::InvalidContext),
147 }
148 }
149
150 pub(crate) async fn mls_groups(&self) -> Result<RwLockWriteGuardArc<GroupStore<MlsConversation>>> {
151 match self.state.read().await.deref() {
152 ContextState::Valid { mls_groups, .. } => Ok(mls_groups.write_arc().await),
153 ContextState::Invalid => Err(Error::InvalidContext),
154 }
155 }
156
157 #[cfg(feature = "proteus")]
158 pub(crate) async fn proteus_central(&self) -> Result<Arc<Mutex<Option<ProteusCentral>>>> {
159 match self.state.read().await.deref() {
160 ContextState::Valid { proteus_central, .. } => Ok(proteus_central.clone()),
161 ContextState::Invalid => Err(Error::InvalidContext),
162 }
163 }
164
165 pub async fn finish(&self) -> Result<()> {
169 let mut guard = self.state.write().await;
170 let ContextState::Valid { provider, .. } = guard.deref() else {
171 return Err(Error::InvalidContext);
172 };
173
174 let commit_result = provider
175 .keystore()
176 .commit_transaction()
177 .await
178 .map_err(KeystoreError::wrap("commiting transaction"))
179 .map_err(Into::into);
180
181 *guard = ContextState::Invalid;
182 commit_result
183 }
184
185 pub async fn abort(&self) -> Result<()> {
189 let mut guard = self.state.write().await;
190
191 let ContextState::Valid { provider, .. } = guard.deref() else {
192 return Err(Error::InvalidContext);
193 };
194
195 let result = provider
196 .keystore()
197 .rollback_transaction()
198 .await
199 .map_err(KeystoreError::wrap("rolling back transaction"))
200 .map_err(Into::into);
201
202 *guard = ContextState::Invalid;
203 result
204 }
205
206 pub async fn set_data(&self, data: Vec<u8>) -> Result<()> {
210 self.keystore()
211 .await?
212 .save(ConsumerData::from(data))
213 .await
214 .map_err(KeystoreError::wrap("saving consumer data"))?;
215 Ok(())
216 }
217
218 pub async fn get_data(&self) -> Result<Option<Vec<u8>>> {
221 match self.keystore().await?.find_unique::<ConsumerData>().await {
222 Ok(data) => Ok(Some(data.into())),
223 Err(CryptoKeystoreError::NotFound(..)) => Ok(None),
224 Err(err) => Err(KeystoreError::wrap("finding unique consumer data")(err).into()),
225 }
226 }
227}