core_crypto_macros/entity_derive/derive_impl.rs

1use crate::entity_derive::{IdColumnType, IdTransformation, KeyStoreEntityFlattened};
2use quote::quote;
3
4impl quote::ToTokens for KeyStoreEntityFlattened {
5    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
6        let entity_base_impl = self.entity_base_impl();
7        let entity_generic_impl = self.entity_generic_impl();
8        let entity_wasm_impl = self.entity_wasm_impl();
9        let entity_transaction_ext_impl = self.entity_transaction_ext_impl();
10        tokens.extend(quote! {
11            #entity_base_impl
12            #entity_generic_impl
13            #entity_wasm_impl
14            #entity_transaction_ext_impl
15        });
16    }
17}
18impl KeyStoreEntityFlattened {
19    fn entity_base_impl(&self) -> proc_macro2::TokenStream {
20        let Self {
21            collection_name,
22            struct_name,
23            ..
24        } = self;
25
26        // Identical for both wasm and non-wasm
27        quote! {
28            #[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
29            #[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
30            impl crate::entities::EntityBase for #struct_name {
31                type ConnectionType = crate::connection::KeystoreDatabaseConnection;
32                type AutoGeneratedFields = ();
33                const COLLECTION_NAME: &'static str = #collection_name;
34
35                fn to_missing_key_err_kind() -> crate::MissingKeyErrorKind {
36                    crate::MissingKeyErrorKind::#struct_name
37                }
38
39                fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity {
40                    crate::transaction::dynamic_dispatch::Entity::#struct_name(self)
41                }
42            }
43        }
44    }
45
46    fn entity_generic_impl(&self) -> proc_macro2::TokenStream {
47        let Self {
48            collection_name,
49            struct_name,
50            id,
51            id_type,
52            id_name,
53            id_transformation,
54            blob_columns,
55            blob_column_names,
56            all_columns,
57            optional_blob_columns,
58            optional_blob_column_names,
59            ..
60        } = self;
61
62        let string_id_conversion = (*id_type == IdColumnType::String).then(|| {
63            quote! { let #id: String = id.try_into()?; }
64        });
65
66        let id_to_byte_slice = match id_type {
67            IdColumnType::String => quote! {self.#id.as_bytes() },
68            IdColumnType::Bytes => quote! { &self.#id.as_slice() },
69        };
70
71        let id_field_find_one = match id_type {
72            IdColumnType::String => quote! { #id, },
73            IdColumnType::Bytes => quote! { #id: id.to_bytes(), },
74        };
75
76        let id_slice = match id_type {
77            IdColumnType::String => quote! { #id.as_str() },
78            IdColumnType::Bytes => quote! { #id.as_slice() },
79        };
80
81        let id_input_transformed = match id_transformation {
82            Some(IdTransformation::Hex) => quote! { id.as_hex_string() },
83            Some(IdTransformation::Sha256) => todo!(),
84            None => id_slice,
85        };
86
87        let destructure_row = match id_transformation {
88            Some(IdTransformation::Hex) => quote! { let (rowid, #id): (_, String) = row?; },
89            Some(IdTransformation::Sha256) => todo!(),
90            None => quote! { let (rowid, #id) = row?; },
91        };
92
93        let id_from_transformed = match id_transformation {
94            Some(IdTransformation::Hex) => {
95                quote! { let #id = <Self as crate::entities::EntityIdStringExt>::id_from_hex(&#id)?; }
96            }
97            Some(IdTransformation::Sha256) => todo!(),
98            None => quote! {},
99        };
100
101        let find_all_query = format!("SELECT rowid, {id_name} FROM {collection_name} ");
102
103        let find_one_query = format!("SELECT rowid FROM {collection_name} WHERE {id_name} = ?");
104
105        let count_query = format!("SELECT COUNT(*) FROM {collection_name}");
106
107        quote! {
108            #[cfg(not(target_family = "wasm"))]
109            #[async_trait::async_trait]
110            impl crate::entities::Entity for #struct_name {
111                fn id_raw(&self) -> &[u8] {
112                    #id_to_byte_slice
113                }
114
115                async fn find_all(
116                    conn: &mut Self::ConnectionType,
117                    params: crate::entities::EntityFindParams,
118                ) -> crate::CryptoKeystoreResult<Vec<Self>> {
119                    let transaction = conn.transaction()?;
120                    let query = #find_all_query.to_string() + &params.to_sql();
121
122                    let mut stmt = transaction.prepare_cached(&query)?;
123                    let mut rows = stmt.query_map([], |r| Ok((r.get(0)?, r.get(1)?)))?;
124                    use std::io::Read as _;
125                    rows.map(|row| {
126                        #destructure_row
127                        #id_from_transformed
128
129                        #(
130                            let mut blob = transaction.blob_open(rusqlite::DatabaseName::Main, #collection_name, #blob_column_names, rowid, true)?;
131                            let mut #blob_columns = Vec::with_capacity(blob.len());
132                            blob.read_to_end(&mut #blob_columns)?;
133                            blob.close()?;
134                        )*
135
136                        #(
137                            let mut #optional_blob_columns = None;
138                            if let Ok(mut blob) =
139                                transaction.blob_open(rusqlite::DatabaseName::Main, #collection_name, #optional_blob_column_names, rowid, true)
140                            {
141                                if !blob.is_empty() {
142                                    let mut blob_data = Vec::with_capacity(blob.len());
143                                    blob.read_to_end(&mut blob_data)?;
144                                    #optional_blob_columns.replace(blob_data);
145                                }
146                                blob.close()?;
147                            }
148                        )*
149
150                        Ok(Self { #id
151                            #(
152                            , #all_columns
153                            )*
154                        })
155                    }).collect()
156                }
157
158                async fn find_one(
159                    conn: &mut Self::ConnectionType,
160                    id: &crate::entities::StringEntityId,
161                ) -> crate::CryptoKeystoreResult<Option<Self>> {
162                    let transaction = conn.transaction()?;
163                    use rusqlite::OptionalExtension as _;
164
165                   #string_id_conversion
166
167                    let mut rowid: Option<i64> = transaction
168                        .query_row(&#find_one_query, [#id_input_transformed], |r| {
169                            r.get::<_, i64>(0)
170                        })
171                        .optional()?;
172
173                    use std::io::Read as _;
174                    if let Some(rowid) = rowid.take() {
175                        #(
176                            let mut blob = transaction.blob_open(rusqlite::DatabaseName::Main, #collection_name, #blob_column_names, rowid, true)?;
177                            let mut #blob_columns = Vec::with_capacity(blob.len());
178                            blob.read_to_end(&mut #blob_columns)?;
179                            blob.close()?;
180                        )*
181
182                        #(
183                            let mut #optional_blob_columns = None;
184                            if let Ok(mut blob) =
185                                transaction.blob_open(rusqlite::DatabaseName::Main, #collection_name, #optional_blob_column_names, rowid, true)
186                            {
187                                if !blob.is_empty() {
188                                    let mut blob_data = Vec::with_capacity(blob.len());
189                                    blob.read_to_end(&mut blob_data)?;
190                                    #optional_blob_columns.replace(blob_data);
191                                }
192                                blob.close()?;
193                            }
194                        )*
195
196                        Ok(Some(Self {
197                            #id_field_find_one
198                            #(
199                                #blob_columns,
200                            )*
201                            #(
202                                #optional_blob_columns,
203                            )*
204                        }))
205                    } else {
206                        Ok(None)
207                    }
208                }
209
210                async fn count(conn: &mut Self::ConnectionType) -> crate::CryptoKeystoreResult<usize> {
211                    Ok(conn.query_row(&#count_query, [], |r| r.get(0))?)
212                }
213            }
214        }
215    }
216
217    fn entity_wasm_impl(&self) -> proc_macro2::TokenStream {
218        let Self {
219            collection_name,
220            struct_name,
221            id,
222            id_type,
223            blob_columns,
224            ..
225        } = self;
226
227        let id_to_byte_slice = match id_type {
228            IdColumnType::String => quote! {self.#id.as_bytes() },
229            IdColumnType::Bytes => quote! { self.#id.as_slice() },
230        };
231
232        quote! {
233            #[cfg(target_family = "wasm")]
234            #[async_trait::async_trait(?Send)]
235            impl crate::entities::Entity for #struct_name {
236                fn id_raw(&self) -> &[u8] {
237                    #id_to_byte_slice
238                }
239
240                async fn find_all(conn: &mut Self::ConnectionType, params: crate::entities::EntityFindParams) ->  crate::CryptoKeystoreResult<Vec<Self>> {
241                    let storage = conn.storage();
242                    storage.get_all(#collection_name, Some(params)).await
243                }
244
245                async fn find_one(conn: &mut Self::ConnectionType, id: &crate::entities::StringEntityId) ->  crate::CryptoKeystoreResult<Option<Self>> {
246                    conn.storage().get(#collection_name, id.as_slice()).await
247                }
248
249                async fn count(conn: &mut Self::ConnectionType) ->  crate::CryptoKeystoreResult<usize> {
250                    conn.storage().count(#collection_name).await
251                }
252
253                fn encrypt(&mut self, cipher: &aes_gcm::Aes256Gcm) -> crate::CryptoKeystoreResult<()> {
254                    use crate::connection::DatabaseConnection as _;
255                    #(
256                        self.#blob_columns = self.encrypt_data(cipher, self.#blob_columns.as_slice())?;
257                        Self::ConnectionType::check_buffer_size(self.#blob_columns.len())?;
258                    )*
259                    Ok(())
260                }
261
262                fn decrypt(&mut self, cipher: &aes_gcm::Aes256Gcm) -> crate::CryptoKeystoreResult<()> {
263                    #(
264                        self.#blob_columns = self.decrypt_data(cipher, self.#blob_columns.as_slice())?;
265                    )*
266                    Ok(())
267                }
268            }
269        }
270    }
271
272    fn entity_transaction_ext_impl(&self) -> proc_macro2::TokenStream {
273        let Self {
274            collection_name,
275            struct_name,
276            id,
277            id_name,
278            all_columns,
279            all_column_names,
280            blob_columns,
281            blob_column_names,
282            optional_blob_columns,
283            optional_blob_column_names,
284            id_transformation,
285            no_upsert,
286            id_type,
287            ..
288        } = self;
289
290        let upsert_pairs: Vec<_> = all_column_names
291            .iter()
292            .map(|col| format! { "{col} = excluded.{col}"})
293            .collect();
294        let upsert_postfix = (!no_upsert)
295            // UPSERT (ON CONFLICT DO UPDATE) with RETURNING to get the rowid
296            .then(|| format!(" ON CONFLICT({id_name}) DO UPDATE SET {}", upsert_pairs.join(", ")))
297            .unwrap_or_default();
298
299        let column_list = all_columns
300            .iter()
301            .map(ToString::to_string)
302            .collect::<Vec<_>>()
303            .join(", ");
304
305        let import_id_string_ext = match id_transformation {
306            Some(IdTransformation::Hex) => quote! { use crate::entities::EntityIdStringExt as _; },
307            Some(IdTransformation::Sha256) => todo!(),
308            None => quote! {},
309        };
310
311        let upsert_query = format!(
312            "INSERT INTO {collection_name} ({id_name}, {column_list}) VALUES (?{}){upsert_postfix} RETURNING rowid",
313            ", ?".repeat(self.all_columns.len()),
314        );
315
316        let self_id_transformed = match id_transformation {
317            Some(IdTransformation::Hex) => quote! { self.id_hex() },
318            Some(IdTransformation::Sha256) => todo!(),
319            None => quote! { self.#id },
320        };
321
322        let delete_query = format!("DELETE FROM {collection_name} WHERE {id_name} = ?");
323
324        let id_slice_delete = match id_type {
325            IdColumnType::String => quote! { id.try_as_str()? },
326            IdColumnType::Bytes => quote! { id.as_slice() },
327        };
328
329        let id_input_transformed_delete = match id_transformation {
330            Some(IdTransformation::Hex) => quote! { id.as_hex_string() },
331            Some(IdTransformation::Sha256) => todo!(),
332            None => id_slice_delete,
333        };
334
335        quote! {
336            #[cfg(target_family = "wasm")]
337            #[async_trait::async_trait(?Send)]
338            impl crate::entities::EntityTransactionExt for #struct_name {}
339
340            #[cfg(not(target_family = "wasm"))]
341            #[async_trait::async_trait]
342            impl crate::entities::EntityTransactionExt for #struct_name {
343                async fn save(&self, transaction: &crate::connection::TransactionWrapper<'_>) -> crate::CryptoKeystoreResult<()> {
344                    use crate::entities::EntityBase as _;
345                    use rusqlite::ToSql as _;
346                    use crate::connection::DatabaseConnection as _;
347
348                    #(
349                        crate::connection::KeystoreDatabaseConnection::check_buffer_size(self.#blob_columns.len())?;
350                    )*
351                    #(
352                      crate::connection::KeystoreDatabaseConnection::check_buffer_size(
353                            self.#optional_blob_columns.as_ref().map(|v| v.len()).unwrap_or_default()
354                      )?;
355                    )*
356
357                    #import_id_string_ext
358
359                    let sql = #upsert_query;
360
361                    let rowid_result: Result<i64, rusqlite::Error> =
362                        transaction.query_row(&sql, [
363                        #self_id_transformed.to_sql()?
364                        #(
365                            ,
366                            rusqlite::blob::ZeroBlob(self.#blob_columns.len() as i32).to_sql()?
367                        )*
368                        #(
369                            ,
370                            rusqlite::blob::ZeroBlob(self.#optional_blob_columns.as_ref().map(|v| v.len() as i32).unwrap_or_default()).to_sql()?
371                        )*
372                    ], |r| r.get(0));
373
374                    use std::io::Write as _;
375                    match rowid_result {
376                        Ok(rowid) => {
377                            #(
378                                let mut blob = transaction.blob_open(
379                                    rusqlite::DatabaseName::Main,
380                                    #collection_name,
381                                    #blob_column_names,
382                                    rowid,
383                                    false,
384                                )?;
385
386                                blob.write_all(&self.#blob_columns)?;
387                                blob.close()?;
388                            )*
389
390                            #(
391                                let mut blob = transaction.blob_open(rusqlite::DatabaseName::Main, #collection_name, #optional_blob_column_names, rowid, false)?;
392                                if let Some(#optional_blob_columns) = self.#optional_blob_columns.as_ref() {
393                                    blob.write_all(#optional_blob_columns)?;
394                                }
395                                blob.close()?;
396                            )*
397
398                            Ok(())
399                        }
400                        Err(rusqlite::Error::SqliteFailure(e, _)) if e.extended_code == rusqlite::ffi::SQLITE_CONSTRAINT_UNIQUE => {
401                            Err(crate::CryptoKeystoreError::AlreadyExists)
402                        }
403                        Err(e) => Err(e.into()),
404                    }
405                }
406
407                async fn delete_fail_on_missing_id(
408                    transaction: &crate::connection::TransactionWrapper<'_>,
409                    id: crate::entities::StringEntityId<'_>,
410                ) -> crate::CryptoKeystoreResult<()> {
411                    use crate::entities::EntityBase as _;
412                    let deleted = transaction.execute(&#delete_query, [#id_input_transformed_delete])?;
413
414                    if deleted > 0 {
415                        Ok(())
416                    } else {
417                        Err(Self::to_missing_key_err_kind().into())
418                    }
419                }
420            }
421        }
422    }
423}