core_crypto_keystore/entities/
mls.rs

1use super::{Entity, EntityBase, EntityFindParams, EntityTransactionExt, StringEntityId};
2use crate::{CryptoKeystoreError, CryptoKeystoreResult, connection::TransactionWrapper};
3use openmls_traits::types::SignatureScheme;
4use zeroize::Zeroize;
5
6/// Entity representing a persisted `MlsGroup`
7#[derive(Debug, Clone, PartialEq, Eq, Zeroize, core_crypto_macros::Entity)]
8#[zeroize(drop)]
9#[entity(collection_name = "mls_groups")]
10#[cfg_attr(
11    any(target_family = "wasm", feature = "serde"),
12    derive(serde::Serialize, serde::Deserialize)
13)]
14pub struct PersistedMlsGroup {
15    #[id(hex, column = "id_hex")]
16    pub id: Vec<u8>,
17    pub state: Vec<u8>,
18    pub parent_id: Option<Vec<u8>>,
19}
20
21#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
22#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
23pub trait PersistedMlsGroupExt: Entity {
24    fn parent_id(&self) -> Option<&[u8]>;
25
26    async fn parent_group(
27        &self,
28        conn: &mut <Self as super::EntityBase>::ConnectionType,
29    ) -> CryptoKeystoreResult<Option<Self>> {
30        let Some(parent_id) = self.parent_id() else {
31            return Ok(None);
32        };
33
34        <Self as super::Entity>::find_one(conn, &parent_id.into()).await
35    }
36
37    async fn child_groups(
38        &self,
39        conn: &mut <Self as super::EntityBase>::ConnectionType,
40    ) -> CryptoKeystoreResult<Vec<Self>> {
41        let entities = <Self as super::Entity>::find_all(conn, super::EntityFindParams::default()).await?;
42
43        let id = self.id_raw();
44
45        Ok(entities
46            .into_iter()
47            .filter(|entity| entity.parent_id().map(|parent_id| parent_id == id).unwrap_or_default())
48            .collect())
49    }
50}
51
52/// Entity representing a temporarily persisted `MlsGroup`
53#[derive(Debug, Clone, PartialEq, Eq, Zeroize)]
54#[zeroize(drop)]
55#[cfg_attr(
56    any(target_family = "wasm", feature = "serde"),
57    derive(serde::Serialize, serde::Deserialize)
58)]
59pub struct PersistedMlsPendingGroup {
60    pub id: Vec<u8>,
61    pub state: Vec<u8>,
62    pub parent_id: Option<Vec<u8>>,
63    pub custom_configuration: Vec<u8>,
64}
65
66/// Entity representing a buffered message
67#[derive(Debug, Clone, PartialEq, Eq, Zeroize)]
68#[zeroize(drop)]
69#[cfg_attr(
70    any(target_family = "wasm", feature = "serde"),
71    derive(serde::Serialize, serde::Deserialize)
72)]
73pub struct MlsPendingMessage {
74    pub foreign_id: Vec<u8>,
75    pub message: Vec<u8>,
76}
77
78/// Entity representing a buffered commit.
79///
80/// There should always exist either 0 or 1 of these in the store per conversation.
81/// Commits are buffered if not all proposals they reference have yet been received.
82///
83/// We don't automatically zeroize on drop because the commit data is still encrypted at this point;
84/// it is not risky to leave it in memory.
85#[derive(Debug, Clone, PartialEq, Eq, Zeroize, core_crypto_macros::Entity)]
86#[cfg_attr(
87    any(target_family = "wasm", feature = "serde"),
88    derive(serde::Serialize, serde::Deserialize)
89)]
90pub struct MlsBufferedCommit {
91    // we'd ideally just call this field `conversation_id`, but as of right now the
92    // Entity macro does not yet support id columns not named `id`
93    #[id(hex, column = "conversation_id_hex")]
94    conversation_id: Vec<u8>,
95    commit_data: Vec<u8>,
96}
97
98impl MlsBufferedCommit {
99    /// Create a new `Self` from conversation id and the commit data.
100    pub fn new(conversation_id: Vec<u8>, commit_data: Vec<u8>) -> Self {
101        Self {
102            conversation_id,
103            commit_data,
104        }
105    }
106
107    pub fn conversation_id(&self) -> &[u8] {
108        &self.conversation_id
109    }
110
111    pub fn commit_data(&self) -> &[u8] {
112        &self.commit_data
113    }
114
115    pub fn into_commit_data(self) -> Vec<u8> {
116        self.commit_data
117    }
118}
119
120/// Entity representing a persisted `Credential`
121#[derive(Debug, Clone, PartialEq, Eq, Zeroize)]
122#[zeroize(drop)]
123#[cfg_attr(
124    any(target_family = "wasm", feature = "serde"),
125    derive(serde::Serialize, serde::Deserialize)
126)]
127pub struct MlsCredential {
128    pub id: Vec<u8>,
129    pub credential: Vec<u8>,
130    pub created_at: u64,
131}
132
133#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
134#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
135pub trait MlsCredentialExt: Entity {
136    async fn delete_by_credential(tx: &TransactionWrapper<'_>, credential: Vec<u8>) -> CryptoKeystoreResult<()>;
137}
138
139/// Entity representing a persisted `SignatureKeyPair`
140#[derive(Debug, Clone, PartialEq, Eq, Zeroize)]
141#[zeroize(drop)]
142#[cfg_attr(
143    any(target_family = "wasm", feature = "serde"),
144    derive(serde::Serialize, serde::Deserialize)
145)]
146pub struct MlsSignatureKeyPair {
147    pub signature_scheme: u16,
148    pub pk: Vec<u8>,
149    pub keypair: Vec<u8>,
150    pub credential_id: Vec<u8>,
151}
152
153impl MlsSignatureKeyPair {
154    pub fn new(signature_scheme: SignatureScheme, pk: Vec<u8>, keypair: Vec<u8>, credential_id: Vec<u8>) -> Self {
155        Self {
156            signature_scheme: signature_scheme as u16,
157            pk,
158            keypair,
159            credential_id,
160        }
161    }
162}
163
164/// Entity representing a persisted `HpkePrivateKey` (related to LeafNode Private keys that the client is aware of)
165#[derive(Debug, Clone, PartialEq, Eq, Zeroize)]
166#[zeroize(drop)]
167#[cfg_attr(
168    any(target_family = "wasm", feature = "serde"),
169    derive(serde::Serialize, serde::Deserialize)
170)]
171pub struct MlsHpkePrivateKey {
172    pub sk: Vec<u8>,
173    pub pk: Vec<u8>,
174}
175
176/// Entity representing a persisted `HpkePrivateKey` (related to LeafNode Private keys that the client is aware of)
177#[derive(Debug, Clone, PartialEq, Eq, Zeroize)]
178#[zeroize(drop)]
179#[cfg_attr(
180    any(target_family = "wasm", feature = "serde"),
181    derive(serde::Serialize, serde::Deserialize)
182)]
183pub struct MlsEncryptionKeyPair {
184    pub sk: Vec<u8>,
185    pub pk: Vec<u8>,
186}
187
188/// Entity representing a list of [MlsEncryptionKeyPair]
189#[derive(Debug, Clone, PartialEq, Eq, Zeroize, core_crypto_macros::Entity)]
190#[zeroize(drop)]
191#[entity(collection_name = "mls_epoch_encryption_keypairs")]
192#[cfg_attr(target_family = "wasm", derive(serde::Serialize, serde::Deserialize))]
193pub struct MlsEpochEncryptionKeyPair {
194    #[id(hex, column = "id_hex")]
195    pub id: Vec<u8>,
196    pub keypairs: Vec<u8>,
197}
198
199/// Entity representing a persisted `SignatureKeyPair`
200#[derive(Debug, Clone, PartialEq, Eq, Zeroize)]
201#[zeroize(drop)]
202#[cfg_attr(
203    any(target_family = "wasm", feature = "serde"),
204    derive(serde::Serialize, serde::Deserialize)
205)]
206pub struct MlsPskBundle {
207    pub psk_id: Vec<u8>,
208    pub psk: Vec<u8>,
209}
210
211/// Entity representing a persisted `KeyPackage`
212#[derive(Debug, Clone, PartialEq, Eq, Zeroize, core_crypto_macros::Entity)]
213#[zeroize(drop)]
214#[entity(collection_name = "mls_keypackages")]
215#[cfg_attr(
216    any(target_family = "wasm", feature = "serde"),
217    derive(serde::Serialize, serde::Deserialize)
218)]
219pub struct MlsKeyPackage {
220    #[id(hex, column = "keypackage_ref_hex")]
221    pub keypackage_ref: Vec<u8>,
222    pub keypackage: Vec<u8>,
223}
224
225/// Entity representing an enrollment instance used to fetch a x509 certificate and persisted when
226/// context switches and the memory it lives in is about to be erased
227#[derive(Debug, Clone, PartialEq, Eq, Zeroize, core_crypto_macros::Entity)]
228#[zeroize(drop)]
229#[entity(collection_name = "e2ei_enrollment", no_upsert)]
230#[cfg_attr(
231    any(target_family = "wasm", feature = "serde"),
232    derive(serde::Serialize, serde::Deserialize)
233)]
234pub struct E2eiEnrollment {
235    pub id: Vec<u8>,
236    pub content: Vec<u8>,
237}
238
/// An entity of which at most one record is looked up per collection (wasm flavour).
///
/// The record is read from `Self::COLLECTION_NAME` under the fixed key [`Self::ID`].
#[cfg(target_family = "wasm")]
#[async_trait::async_trait(?Send)]
pub trait UniqueEntity:
    EntityBase<ConnectionType = crate::connection::KeystoreDatabaseConnection>
    + serde::Serialize
    + serde::de::DeserializeOwned
where
    Self: 'static,
{
    /// Fixed storage key used to look up the single record.
    const ID: [u8; 1] = [0];

    /// Borrows the entity's content bytes.
    fn content(&self) -> &[u8];

    /// Replaces the entity's content bytes.
    fn set_content(&mut self, content: Vec<u8>);

    /// Fetches the record stored under [`Self::ID`].
    ///
    /// # Errors
    /// Returns `CryptoKeystoreError::NotFound` (collection name, empty id string) when the
    /// storage has no such record, and propagates any storage error.
    async fn find_unique(conn: &mut Self::ConnectionType) -> CryptoKeystoreResult<Self> {
        let stored: Option<Self> = conn.storage().get(Self::COLLECTION_NAME, &Self::ID).await?;
        stored.ok_or_else(|| CryptoKeystoreError::NotFound(Self::COLLECTION_NAME, "".to_string()))
    }

    /// Returns the record wrapped in a `Vec`, or an empty `Vec` when it does not exist.
    ///
    /// `_params` is ignored.
    async fn find_all(conn: &mut Self::ConnectionType, _params: EntityFindParams) -> CryptoKeystoreResult<Vec<Self>> {
        match Self::find_unique(conn).await {
            // A missing record is not an error here; it simply yields no results.
            Err(CryptoKeystoreError::NotFound(..)) => Ok(Vec::new()),
            other => other.map(|record| vec![record]),
        }
    }

    /// Returns `Some(record)`, or `None` when the record does not exist.
    async fn find_one(conn: &mut Self::ConnectionType) -> CryptoKeystoreResult<Option<Self>> {
        match Self::find_unique(conn).await {
            // A missing record is mapped to `None` instead of an error.
            Err(CryptoKeystoreError::NotFound(..)) => Ok(None),
            other => other.map(Some),
        }
    }

    /// Returns the count the storage reports for `Self::COLLECTION_NAME`.
    async fn count(conn: &mut Self::ConnectionType) -> CryptoKeystoreResult<usize> {
        conn.storage().count(Self::COLLECTION_NAME).await
    }

    /// Saves a clone of `self` through the given transaction.
    async fn replace<'a>(&'a self, transaction: &TransactionWrapper<'a>) -> CryptoKeystoreResult<()> {
        transaction.save(self.clone()).await?;
        Ok(())
    }
}
287
288#[cfg(not(target_family = "wasm"))]
289#[async_trait::async_trait]
290pub trait UniqueEntity: EntityBase<ConnectionType = crate::connection::KeystoreDatabaseConnection> {
291    const ID: usize = 0;
292
293    fn new(content: Vec<u8>) -> Self;
294
295    async fn find_unique(conn: &mut Self::ConnectionType) -> CryptoKeystoreResult<Self> {
296        let mut conn = conn.conn().await;
297        let transaction = conn.transaction()?;
298        use rusqlite::OptionalExtension as _;
299
300        let maybe_content = transaction
301            .query_row(
302                &format!("SELECT content FROM {} WHERE id = ?", Self::COLLECTION_NAME),
303                [Self::ID],
304                |r| r.get::<_, Vec<u8>>(0),
305            )
306            .optional()?;
307
308        if let Some(content) = maybe_content {
309            Ok(Self::new(content))
310        } else {
311            Err(CryptoKeystoreError::NotFound(Self::COLLECTION_NAME, "".to_string()))
312        }
313    }
314
315    async fn find_all(conn: &mut Self::ConnectionType, _params: EntityFindParams) -> CryptoKeystoreResult<Vec<Self>> {
316        match Self::find_unique(conn).await {
317            Ok(record) => Ok(vec![record]),
318            Err(CryptoKeystoreError::NotFound(_, _)) => Ok(vec![]),
319            Err(err) => Err(err),
320        }
321    }
322
323    async fn find_one(conn: &mut Self::ConnectionType) -> CryptoKeystoreResult<Option<Self>> {
324        match Self::find_unique(conn).await {
325            Ok(record) => Ok(Some(record)),
326            Err(CryptoKeystoreError::NotFound(_, _)) => Ok(None),
327            Err(err) => Err(err),
328        }
329    }
330
331    async fn count(conn: &mut Self::ConnectionType) -> CryptoKeystoreResult<usize> {
332        let conn = conn.conn().await;
333        conn.query_row(&format!("SELECT COUNT(*) FROM {}", Self::COLLECTION_NAME), [], |r| {
334            r.get(0)
335        })
336        .map_err(Into::into)
337    }
338
339    fn content(&self) -> &[u8];
340
341    async fn replace(&self, transaction: &TransactionWrapper<'_>) -> CryptoKeystoreResult<()> {
342        use crate::connection::DatabaseConnection;
343        Self::ConnectionType::check_buffer_size(self.content().len())?;
344        let zb_content = rusqlite::blob::ZeroBlob(self.content().len() as i32);
345
346        use rusqlite::ToSql;
347        let params: [rusqlite::types::ToSqlOutput; 2] = [Self::ID.to_sql()?, zb_content.to_sql()?];
348
349        transaction.execute(
350            &format!(
351                "INSERT OR REPLACE INTO {} (id, content) VALUES (?, ?)",
352                Self::COLLECTION_NAME
353            ),
354            params,
355        )?;
356        let row_id = transaction.last_insert_rowid();
357
358        let mut blob = transaction.blob_open(
359            rusqlite::DatabaseName::Main,
360            Self::COLLECTION_NAME,
361            "content",
362            row_id,
363            false,
364        )?;
365        use std::io::Write;
366        blob.write_all(self.content())?;
367        blob.close()?;
368
369        Ok(())
370    }
371}
372
373#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
374#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
375impl<T: UniqueEntity + Send + Sync> EntityTransactionExt for T {
376    #[cfg(not(target_family = "wasm"))]
377    async fn save(&self, tx: &TransactionWrapper<'_>) -> CryptoKeystoreResult<()> {
378        self.replace(tx).await
379    }
380
381    #[cfg(target_family = "wasm")]
382    async fn save<'a>(&'a self, tx: &TransactionWrapper<'a>) -> CryptoKeystoreResult<()> {
383        self.replace(tx).await
384    }
385
386    #[cfg(not(target_family = "wasm"))]
387    async fn delete_fail_on_missing_id(
388        _: &TransactionWrapper<'_>,
389        _id: StringEntityId<'_>,
390    ) -> CryptoKeystoreResult<()> {
391        Err(CryptoKeystoreError::NotImplemented)
392    }
393
394    #[cfg(target_family = "wasm")]
395    async fn delete_fail_on_missing_id<'a>(
396        _: &TransactionWrapper<'a>,
397        _id: StringEntityId<'a>,
398    ) -> CryptoKeystoreResult<()> {
399        Err(CryptoKeystoreError::NotImplemented)
400    }
401}
402
403/// OIDC refresh token used in E2EI
404#[derive(Debug, Clone, PartialEq, Eq, Zeroize)]
405#[zeroize(drop)]
406#[cfg_attr(
407    any(target_family = "wasm", feature = "serde"),
408    derive(serde::Serialize, serde::Deserialize)
409)]
410pub struct E2eiRefreshToken {
411    pub content: Vec<u8>,
412}
413
414#[derive(Debug, Clone, PartialEq, Eq, Zeroize)]
415#[zeroize(drop)]
416#[cfg_attr(
417    any(target_family = "wasm", feature = "serde"),
418    derive(serde::Serialize, serde::Deserialize)
419)]
420pub struct E2eiAcmeCA {
421    pub content: Vec<u8>,
422}
423
424#[derive(Debug, Clone, PartialEq, Eq, Zeroize, core_crypto_macros::Entity)]
425#[zeroize(drop)]
426#[cfg_attr(
427    any(target_family = "wasm", feature = "serde"),
428    derive(serde::Serialize, serde::Deserialize)
429)]
430pub struct E2eiIntermediateCert {
431    // key to identify the CA cert; Using a combination of SKI & AKI extensions concatenated like so is suitable: `SKI[+AKI]`
432    #[id]
433    pub ski_aki_pair: String,
434    pub content: Vec<u8>,
435}
436
437#[derive(Debug, Clone, PartialEq, Eq, Zeroize, core_crypto_macros::Entity)]
438#[zeroize(drop)]
439#[cfg_attr(
440    any(target_family = "wasm", feature = "serde"),
441    derive(serde::Serialize, serde::Deserialize)
442)]
443pub struct E2eiCrl {
444    #[id]
445    pub distribution_point: String,
446    pub content: Vec<u8>,
447}