core_crypto_macros/entity_derive/derive_impl.rs

1use crate::entity_derive::{IdColumnType, IdTransformation, KeyStoreEntityFlattened};
2use quote::quote;
3
4impl quote::ToTokens for KeyStoreEntityFlattened {
5    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
6        let entity_base_impl = self.entity_base_impl();
7        let entity_generic_impl = self.entity_generic_impl();
8        let entity_wasm_impl = self.entity_wasm_impl();
9        let entity_transaction_ext_impl = self.entity_transaction_ext_impl();
10        tokens.extend(quote! {
11            #entity_base_impl
12            #entity_generic_impl
13            #entity_wasm_impl
14            #entity_transaction_ext_impl
15        });
16    }
17}
18impl KeyStoreEntityFlattened {
19    fn entity_base_impl(&self) -> proc_macro2::TokenStream {
20        let Self {
21            collection_name,
22            struct_name,
23            ..
24        } = self;
25
26        // Identical for both wasm and non-wasm
27        quote! {
28            #[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
29            #[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
30            impl crate::entities::EntityBase for #struct_name {
31                type ConnectionType = crate::connection::KeystoreDatabaseConnection;
32                type AutoGeneratedFields = ();
33                const COLLECTION_NAME: &'static str = #collection_name;
34
35                fn to_missing_key_err_kind() -> crate::MissingKeyErrorKind {
36                    crate::MissingKeyErrorKind::#struct_name
37                }
38
39                fn to_transaction_entity(self) -> crate::transaction::dynamic_dispatch::Entity {
40                    crate::transaction::dynamic_dispatch::Entity::#struct_name(self)
41                }
42            }
43        }
44    }
45
46    fn entity_generic_impl(&self) -> proc_macro2::TokenStream {
47        let Self {
48            collection_name,
49            struct_name,
50            id,
51            id_type,
52            id_name,
53            id_transformation,
54            blob_columns,
55            blob_column_names,
56            all_columns,
57            optional_blob_columns,
58            optional_blob_column_names,
59            ..
60        } = self;
61
62        let string_id_conversion = (*id_type == IdColumnType::String).then(|| {
63            quote! { let #id: String = id.try_into()?; }
64        });
65
66        let id_to_byte_slice = match id_type {
67            IdColumnType::String => quote! {self.#id.as_bytes() },
68            IdColumnType::Bytes => quote! { &self.#id.as_slice() },
69        };
70
71        let id_field_find_one = match id_type {
72            IdColumnType::String => quote! { #id, },
73            IdColumnType::Bytes => quote! { #id: id.to_bytes(), },
74        };
75
76        let id_slice = match id_type {
77            IdColumnType::String => quote! { #id.as_str() },
78            IdColumnType::Bytes => quote! { #id.as_slice() },
79        };
80
81        let id_input_transformed = match id_transformation {
82            Some(IdTransformation::Hex) => quote! { id.as_hex_string() },
83            Some(IdTransformation::Sha256) => todo!(),
84            None => id_slice,
85        };
86
87        let destructure_row = match id_transformation {
88            Some(IdTransformation::Hex) => quote! { let (rowid, #id): (_, String) = row?; },
89            Some(IdTransformation::Sha256) => todo!(),
90            None => quote! { let (rowid, #id) = row?; },
91        };
92
93        let id_from_transformed = match id_transformation {
94            Some(IdTransformation::Hex) => {
95                quote! { let #id = <Self as crate::entities::EntityIdStringExt>::id_from_hex(&#id)?; }
96            }
97            Some(IdTransformation::Sha256) => todo!(),
98            None => quote! {},
99        };
100
101        let find_all_query = format!("SELECT rowid, {id_name} FROM {collection_name} ");
102
103        let find_one_query = format!("SELECT rowid FROM {collection_name} WHERE {id_name} = ?");
104
105        let count_query = format!("SELECT COUNT(*) FROM {collection_name}");
106
107        quote! {
108            #[cfg(not(target_family = "wasm"))]
109            #[async_trait::async_trait]
110            impl crate::entities::Entity for #struct_name {
111                fn id_raw(&self) -> &[u8] {
112                    #id_to_byte_slice
113                }
114
115                async fn find_all(
116                    conn: &mut Self::ConnectionType,
117                    params: crate::entities::EntityFindParams,
118                ) -> crate::CryptoKeystoreResult<Vec<Self>> {
119                    let mut conn = conn.conn().await;
120                    let transaction = conn.transaction()?;
121                    let query = #find_all_query.to_string() + &params.to_sql();
122
123                    let mut stmt = transaction.prepare_cached(&query)?;
124                    let mut rows = stmt.query_map([], |r| Ok((r.get(0)?, r.get(1)?)))?;
125                    use std::io::Read as _;
126                    rows.map(|row| {
127                        #destructure_row
128                        #id_from_transformed
129
130                        #(
131                            let mut blob = transaction.blob_open(rusqlite::DatabaseName::Main, #collection_name, #blob_column_names, rowid, true)?;
132                            let mut #blob_columns = Vec::with_capacity(blob.len());
133                            blob.read_to_end(&mut #blob_columns)?;
134                            blob.close()?;
135                        )*
136
137                        #(
138                            let mut #optional_blob_columns = None;
139                            if let Ok(mut blob) =
140                                transaction.blob_open(rusqlite::DatabaseName::Main, #collection_name, #optional_blob_column_names, rowid, true)
141                            {
142                                if !blob.is_empty() {
143                                    let mut blob_data = Vec::with_capacity(blob.len());
144                                    blob.read_to_end(&mut blob_data)?;
145                                    #optional_blob_columns.replace(blob_data);
146                                }
147                                blob.close()?;
148                            }
149                        )*
150
151                        Ok(Self { #id
152                            #(
153                            , #all_columns
154                            )*
155                        })
156                    }).collect()
157                }
158
159                async fn find_one(
160                    conn: &mut Self::ConnectionType,
161                    id: &crate::entities::StringEntityId,
162                ) -> crate::CryptoKeystoreResult<Option<Self>> {
163                    let mut conn = conn.conn().await;
164                    let transaction = conn.transaction()?;
165                    use rusqlite::OptionalExtension as _;
166
167                   #string_id_conversion
168
169                    let mut rowid: Option<i64> = transaction
170                        .query_row(&#find_one_query, [#id_input_transformed], |r| {
171                            r.get::<_, i64>(0)
172                        })
173                        .optional()?;
174
175                    use std::io::Read as _;
176                    if let Some(rowid) = rowid.take() {
177                        #(
178                            let mut blob = transaction.blob_open(rusqlite::DatabaseName::Main, #collection_name, #blob_column_names, rowid, true)?;
179                            let mut #blob_columns = Vec::with_capacity(blob.len());
180                            blob.read_to_end(&mut #blob_columns)?;
181                            blob.close()?;
182                        )*
183
184                        #(
185                            let mut #optional_blob_columns = None;
186                            if let Ok(mut blob) =
187                                transaction.blob_open(rusqlite::DatabaseName::Main, #collection_name, #optional_blob_column_names, rowid, true)
188                            {
189                                if !blob.is_empty() {
190                                    let mut blob_data = Vec::with_capacity(blob.len());
191                                    blob.read_to_end(&mut blob_data)?;
192                                    #optional_blob_columns.replace(blob_data);
193                                }
194                                blob.close()?;
195                            }
196                        )*
197
198                        Ok(Some(Self {
199                            #id_field_find_one
200                            #(
201                                #blob_columns,
202                            )*
203                            #(
204                                #optional_blob_columns,
205                            )*
206                        }))
207                    } else {
208                        Ok(None)
209                    }
210                }
211
212                async fn count(conn: &mut Self::ConnectionType) -> crate::CryptoKeystoreResult<usize> {
213                    let conn = conn.conn().await;
214                    conn.query_row(&#count_query, [], |r| r.get(0)).map_err(Into::into)
215                }
216            }
217        }
218    }
219
220    fn entity_wasm_impl(&self) -> proc_macro2::TokenStream {
221        let Self {
222            collection_name,
223            struct_name,
224            id,
225            id_type,
226            blob_columns,
227            ..
228        } = self;
229
230        let id_to_byte_slice = match id_type {
231            IdColumnType::String => quote! {self.#id.as_bytes() },
232            IdColumnType::Bytes => quote! { self.#id.as_slice() },
233        };
234
235        quote! {
236            #[cfg(target_family = "wasm")]
237            #[async_trait::async_trait(?Send)]
238            impl crate::entities::Entity for #struct_name {
239                fn id_raw(&self) -> &[u8] {
240                    #id_to_byte_slice
241                }
242
243                async fn find_all(conn: &mut Self::ConnectionType, params: crate::entities::EntityFindParams) ->  crate::CryptoKeystoreResult<Vec<Self>> {
244                    let storage = conn.storage();
245                    storage.get_all(#collection_name, Some(params)).await
246                }
247
248                async fn find_one(conn: &mut Self::ConnectionType, id: &crate::entities::StringEntityId) ->  crate::CryptoKeystoreResult<Option<Self>> {
249                    conn.storage().get(#collection_name, id.as_slice()).await
250                }
251
252                async fn count(conn: &mut Self::ConnectionType) ->  crate::CryptoKeystoreResult<usize> {
253                    conn.storage().count(#collection_name).await
254                }
255
256                fn encrypt(&mut self, cipher: &aes_gcm::Aes256Gcm) -> crate::CryptoKeystoreResult<()> {
257                    use crate::connection::DatabaseConnection as _;
258                    #(
259                        self.#blob_columns = self.encrypt_data(cipher, self.#blob_columns.as_slice())?;
260                        Self::ConnectionType::check_buffer_size(self.#blob_columns.len())?;
261                    )*
262                    Ok(())
263                }
264
265                fn decrypt(&mut self, cipher: &aes_gcm::Aes256Gcm) -> crate::CryptoKeystoreResult<()> {
266                    #(
267                        self.#blob_columns = self.decrypt_data(cipher, self.#blob_columns.as_slice())?;
268                    )*
269                    Ok(())
270                }
271            }
272        }
273    }
274
275    fn entity_transaction_ext_impl(&self) -> proc_macro2::TokenStream {
276        let Self {
277            collection_name,
278            struct_name,
279            id,
280            id_name,
281            all_columns,
282            all_column_names,
283            blob_columns,
284            blob_column_names,
285            optional_blob_columns,
286            optional_blob_column_names,
287            id_transformation,
288            no_upsert,
289            id_type,
290            ..
291        } = self;
292
293        let upsert_pairs: Vec<_> = all_column_names
294            .iter()
295            .map(|col| format! { "{col} = excluded.{col}"})
296            .collect();
297        let upsert_postfix = (!no_upsert)
298            // UPSERT (ON CONFLICT DO UPDATE) with RETURNING to get the rowid
299            .then(|| format!(" ON CONFLICT({id_name}) DO UPDATE SET {}", upsert_pairs.join(", ")))
300            .unwrap_or_default();
301
302        let column_list = all_columns
303            .iter()
304            .map(ToString::to_string)
305            .collect::<Vec<_>>()
306            .join(", ");
307
308        let import_id_string_ext = match id_transformation {
309            Some(IdTransformation::Hex) => quote! { use crate::entities::EntityIdStringExt as _; },
310            Some(IdTransformation::Sha256) => todo!(),
311            None => quote! {},
312        };
313
314        let upsert_query = format!(
315            "INSERT INTO {collection_name} ({id_name}, {column_list}) VALUES (?{}){upsert_postfix} RETURNING rowid",
316            ", ?".repeat(self.all_columns.len()),
317        );
318
319        let self_id_transformed = match id_transformation {
320            Some(IdTransformation::Hex) => quote! { self.id_hex() },
321            Some(IdTransformation::Sha256) => todo!(),
322            None => quote! { self.#id },
323        };
324
325        let delete_query = format!("DELETE FROM {collection_name} WHERE {id_name} = ?");
326
327        let id_slice_delete = match id_type {
328            IdColumnType::String => quote! { id.try_as_str()? },
329            IdColumnType::Bytes => quote! { id.as_slice() },
330        };
331
332        let id_input_transformed_delete = match id_transformation {
333            Some(IdTransformation::Hex) => quote! { id.as_hex_string() },
334            Some(IdTransformation::Sha256) => todo!(),
335            None => id_slice_delete,
336        };
337
338        quote! {
339            #[cfg(target_family = "wasm")]
340            #[async_trait::async_trait(?Send)]
341            impl crate::entities::EntityTransactionExt for #struct_name {}
342
343            #[cfg(not(target_family = "wasm"))]
344            #[async_trait::async_trait]
345            impl crate::entities::EntityTransactionExt for #struct_name {
346                async fn save(&self, transaction: &crate::connection::TransactionWrapper<'_>) -> crate::CryptoKeystoreResult<()> {
347                    use crate::entities::EntityBase as _;
348                    use rusqlite::ToSql as _;
349                    use crate::connection::DatabaseConnection as _;
350
351                    #(
352                        crate::connection::KeystoreDatabaseConnection::check_buffer_size(self.#blob_columns.len())?;
353                    )*
354                    #(
355                      crate::connection::KeystoreDatabaseConnection::check_buffer_size(
356                            self.#optional_blob_columns.as_ref().map(|v| v.len()).unwrap_or_default()
357                      )?;
358                    )*
359
360                    #import_id_string_ext
361
362                    let sql = #upsert_query;
363
364                    let rowid_result: Result<i64, rusqlite::Error> =
365                        transaction.query_row(&sql, [
366                        #self_id_transformed.to_sql()?
367                        #(
368                            ,
369                            rusqlite::blob::ZeroBlob(self.#blob_columns.len() as i32).to_sql()?
370                        )*
371                        #(
372                            ,
373                            rusqlite::blob::ZeroBlob(self.#optional_blob_columns.as_ref().map(|v| v.len() as i32).unwrap_or_default()).to_sql()?
374                        )*
375                    ], |r| r.get(0));
376
377                    use std::io::Write as _;
378                    match rowid_result {
379                        Ok(rowid) => {
380                            #(
381                                let mut blob = transaction.blob_open(
382                                    rusqlite::DatabaseName::Main,
383                                    #collection_name,
384                                    #blob_column_names,
385                                    rowid,
386                                    false,
387                                )?;
388
389                                blob.write_all(&self.#blob_columns)?;
390                                blob.close()?;
391                            )*
392
393                            #(
394                                let mut blob = transaction.blob_open(rusqlite::DatabaseName::Main, #collection_name, #optional_blob_column_names, rowid, false)?;
395                                if let Some(#optional_blob_columns) = self.#optional_blob_columns.as_ref() {
396                                    blob.write_all(#optional_blob_columns)?;
397                                }
398                                blob.close()?;
399                            )*
400
401                            Ok(())
402                        }
403                        Err(rusqlite::Error::SqliteFailure(e, _)) if e.extended_code == rusqlite::ffi::SQLITE_CONSTRAINT_UNIQUE => {
404                            Err(crate::CryptoKeystoreError::AlreadyExists)
405                        }
406                        Err(e) => Err(e.into()),
407                    }
408                }
409
410                async fn delete_fail_on_missing_id(
411                    transaction: &crate::connection::TransactionWrapper<'_>,
412                    id: crate::entities::StringEntityId<'_>,
413                ) -> crate::CryptoKeystoreResult<()> {
414                    use crate::entities::EntityBase as _;
415                    let deleted = transaction.execute(&#delete_query, [#id_input_transformed_delete])?;
416
417                    if deleted > 0 {
418                        Ok(())
419                    } else {
420                        Err(Self::to_missing_key_err_kind().into())
421                    }
422                }
423            }
424        }
425    }
426}