1use crate::mls::HasClientAndProvider;
5#[cfg(feature = "proteus")]
6use crate::proteus::ProteusCentral;
7use crate::{
8 CoreCrypto, Error, KeystoreError, MlsError, MlsTransport, RecursiveError, Result,
9 group_store::GroupStore,
10 prelude::{Client, MlsConversation},
11};
12use async_lock::{Mutex, RwLock, RwLockReadGuardArc, RwLockWriteGuardArc};
13use core_crypto_keystore::{CryptoKeystoreError, connection::FetchFromDatabase, entities::ConsumerData};
14use mls_crypto_provider::{CryptoKeystore, MlsCryptoProvider};
15use std::{ops::Deref, sync::Arc};
16
/// Handle to an in-progress transaction, obtained from [`CoreCrypto::new_transaction`].
///
/// Close it with [`CentralContext::finish`] (commit) or [`CentralContext::abort`]
/// (roll back). After either call the context is invalid, and every accessor returns
/// [`Error::InvalidContext`].
///
/// Cloning is cheap: all clones share the same `state` behind an `Arc`. Finishing or
/// aborting through one clone therefore invalidates all of them.
#[derive(Debug, Clone)]
pub struct CentralContext {
    state: Arc<RwLock<ContextState>>,
}
27
/// Internal lifecycle state of a [`CentralContext`].
#[derive(Debug, Clone)]
enum ContextState {
    /// The transaction is open; holds everything the context hands out to callers.
    Valid {
        /// MLS crypto provider whose keystore carries the open transaction.
        provider: MlsCryptoProvider,
        /// Transport slot, cloned from the client's `transport` field when the context is created.
        transport: Arc<RwLock<Option<Arc<dyn MlsTransport + 'static>>>>,
        /// Clone of the MLS client the context was created from.
        mls_client: Client,
        /// Conversation store for this context; starts empty when the context is created.
        mls_groups: Arc<RwLock<GroupStore<MlsConversation>>>,
        /// Proteus state shared with the owning `CoreCrypto`.
        #[cfg(feature = "proteus")]
        proteus_central: Arc<Mutex<Option<ProteusCentral>>>,
    },
    /// Set by `finish`/`abort`; every accessor then fails with `Error::InvalidContext`.
    Invalid,
}
43
44impl CoreCrypto {
45 pub async fn new_transaction(&self) -> Result<CentralContext> {
49 CentralContext::new(
50 &self.mls,
51 #[cfg(feature = "proteus")]
52 self.proteus.clone(),
53 )
54 .await
55 }
56}
57
58#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
59#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
60impl HasClientAndProvider for CentralContext {
61 async fn client(&self) -> crate::mls::Result<Client> {
62 self.mls_client()
63 .await
64 .map_err(RecursiveError::root("getting mls client"))
65 .map_err(Into::into)
66 }
67
68 async fn mls_provider(&self) -> crate::mls::Result<MlsCryptoProvider> {
69 self.mls_provider()
70 .await
71 .map_err(RecursiveError::root("getting mls provider"))
72 .map_err(Into::into)
73 }
74}
75
76impl CentralContext {
77 async fn new(
78 client: &Client,
79 #[cfg(feature = "proteus")] proteus_central: Arc<Mutex<Option<ProteusCentral>>>,
80 ) -> Result<Self> {
81 client
82 .mls_backend
83 .new_transaction()
84 .await
85 .map_err(MlsError::wrap("creating new transaction"))?;
86 let mls_groups = Arc::new(RwLock::new(Default::default()));
87 let callbacks = client.transport.clone();
88 let mls_client = client.clone();
89 Ok(Self {
90 state: Arc::new(
91 ContextState::Valid {
92 mls_client,
93 transport: callbacks,
94 provider: client.mls_backend.clone(),
95 mls_groups,
96 #[cfg(feature = "proteus")]
97 proteus_central,
98 }
99 .into(),
100 ),
101 })
102 }
103
104 pub(crate) async fn mls_client(&self) -> Result<Client> {
105 match self.state.read().await.deref() {
106 ContextState::Valid { mls_client, .. } => Ok(mls_client.clone()),
107 ContextState::Invalid => Err(Error::InvalidContext),
108 }
109 }
110
111 pub(crate) async fn mls_transport(&self) -> Result<RwLockReadGuardArc<Option<Arc<dyn MlsTransport + 'static>>>> {
112 match self.state.read().await.deref() {
113 ContextState::Valid {
114 transport: callbacks, ..
115 } => Ok(callbacks.read_arc().await),
116 ContextState::Invalid => Err(Error::InvalidContext),
117 }
118 }
119
120 #[cfg(test)]
121 pub(crate) async fn set_transport_callbacks(
122 &self,
123 callbacks: Option<Arc<dyn MlsTransport + 'static>>,
124 ) -> Result<()> {
125 match self.state.read().await.deref() {
126 ContextState::Valid { transport: cbs, .. } => {
127 *cbs.write_arc().await = callbacks;
128 Ok(())
129 }
130 ContextState::Invalid => Err(Error::InvalidContext),
131 }
132 }
133
134 pub async fn mls_provider(&self) -> Result<MlsCryptoProvider> {
136 match self.state.read().await.deref() {
137 ContextState::Valid { provider, .. } => Ok(provider.clone()),
138 ContextState::Invalid => Err(Error::InvalidContext),
139 }
140 }
141
142 pub(crate) async fn keystore(&self) -> Result<CryptoKeystore> {
143 match self.state.read().await.deref() {
144 ContextState::Valid { provider, .. } => Ok(provider.keystore()),
145 ContextState::Invalid => Err(Error::InvalidContext),
146 }
147 }
148
149 pub(crate) async fn mls_groups(&self) -> Result<RwLockWriteGuardArc<GroupStore<MlsConversation>>> {
150 match self.state.read().await.deref() {
151 ContextState::Valid { mls_groups, .. } => Ok(mls_groups.write_arc().await),
152 ContextState::Invalid => Err(Error::InvalidContext),
153 }
154 }
155
156 #[cfg(feature = "proteus")]
157 pub(crate) async fn proteus_central(&self) -> Result<Arc<Mutex<Option<ProteusCentral>>>> {
158 match self.state.read().await.deref() {
159 ContextState::Valid { proteus_central, .. } => Ok(proteus_central.clone()),
160 ContextState::Invalid => Err(Error::InvalidContext),
161 }
162 }
163
164 pub async fn finish(&self) -> Result<()> {
168 let mut guard = self.state.write().await;
169 let ContextState::Valid { provider, .. } = guard.deref() else {
170 return Err(Error::InvalidContext);
171 };
172
173 let commit_result = provider
174 .keystore()
175 .commit_transaction()
176 .await
177 .map_err(KeystoreError::wrap("commiting transaction"))
178 .map_err(Into::into);
179
180 *guard = ContextState::Invalid;
181 commit_result
182 }
183
184 pub async fn abort(&self) -> Result<()> {
188 let mut guard = self.state.write().await;
189
190 let ContextState::Valid { provider, .. } = guard.deref() else {
191 return Err(Error::InvalidContext);
192 };
193
194 let result = provider
195 .keystore()
196 .rollback_transaction()
197 .await
198 .map_err(KeystoreError::wrap("rolling back transaction"))
199 .map_err(Into::into);
200
201 *guard = ContextState::Invalid;
202 result
203 }
204
205 pub async fn set_data(&self, data: Vec<u8>) -> Result<()> {
209 self.keystore()
210 .await?
211 .save(ConsumerData::from(data))
212 .await
213 .map_err(KeystoreError::wrap("saving consumer data"))?;
214 Ok(())
215 }
216
217 pub async fn get_data(&self) -> Result<Option<Vec<u8>>> {
220 match self.keystore().await?.find_unique::<ConsumerData>().await {
221 Ok(data) => Ok(Some(data.into())),
222 Err(CryptoKeystoreError::NotFound(..)) => Ok(None),
223 Err(err) => Err(KeystoreError::wrap("finding unique consumer data")(err).into()),
224 }
225 }
226}