core_crypto_keystore/entities/
mls.rs

1// Wire
2// Copyright (C) 2022 Wire Swiss GmbH
3
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8
9// This program is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU General Public License for more details.
13
14// You should have received a copy of the GNU General Public License
15// along with this program. If not, see http://www.gnu.org/licenses/.
16
17use super::{Entity, EntityBase, EntityFindParams, EntityTransactionExt, StringEntityId};
18use crate::{CryptoKeystoreError, CryptoKeystoreResult, connection::TransactionWrapper};
19use openmls_traits::types::SignatureScheme;
20use zeroize::Zeroize;
21
/// Entity representing a persisted `MlsGroup`
///
/// Stored in the `mls_groups` collection and zeroized on drop (`#[zeroize(drop)]`).
#[derive(Debug, Clone, PartialEq, Eq, Zeroize, core_crypto_macros::Entity)]
#[zeroize(drop)]
#[entity(collection_name = "mls_groups")]
#[cfg_attr(
    any(target_family = "wasm", feature = "serde"),
    derive(serde::Serialize, serde::Deserialize)
)]
pub struct PersistedMlsGroup {
    /// Group id; the `#[id(hex, ...)]` attribute maps it to the `id_hex` column.
    #[id(hex, column = "id_hex")]
    pub id: Vec<u8>,
    /// Opaque group state bytes. NOTE(review): encoding is chosen by the writer — confirm.
    pub state: Vec<u8>,
    /// Id of this group's parent group, if it has one.
    pub parent_id: Option<Vec<u8>>,
}
36
37#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
38#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
39pub trait PersistedMlsGroupExt: Entity {
40    fn parent_id(&self) -> Option<&[u8]>;
41
42    async fn parent_group(
43        &self,
44        conn: &mut <Self as super::EntityBase>::ConnectionType,
45    ) -> CryptoKeystoreResult<Option<Self>> {
46        let Some(parent_id) = self.parent_id() else {
47            return Ok(None);
48        };
49
50        <Self as super::Entity>::find_one(conn, &parent_id.into()).await
51    }
52
53    async fn child_groups(
54        &self,
55        conn: &mut <Self as super::EntityBase>::ConnectionType,
56    ) -> CryptoKeystoreResult<Vec<Self>> {
57        let entities = <Self as super::Entity>::find_all(conn, super::EntityFindParams::default()).await?;
58
59        let id = self.id_raw();
60
61        Ok(entities
62            .into_iter()
63            .filter(|entity| entity.parent_id().map(|parent_id| parent_id == id).unwrap_or_default())
64            .collect())
65    }
66}
67
/// Entity representing a temporarily persisted `MlsGroup`
///
/// Zeroized on drop (`#[zeroize(drop)]`).
#[derive(Debug, Clone, PartialEq, Eq, Zeroize)]
#[zeroize(drop)]
#[cfg_attr(
    any(target_family = "wasm", feature = "serde"),
    derive(serde::Serialize, serde::Deserialize)
)]
pub struct PersistedMlsPendingGroup {
    /// Group id bytes.
    pub id: Vec<u8>,
    /// Opaque group state bytes. NOTE(review): encoding is chosen by the writer — confirm.
    pub state: Vec<u8>,
    /// Id of this group's parent group, if it has one.
    pub parent_id: Option<Vec<u8>>,
    /// Opaque custom-configuration bytes for the group.
    pub custom_configuration: Vec<u8>,
}
81
/// Entity representing a buffered message
///
/// Zeroized on drop (`#[zeroize(drop)]`).
#[derive(Debug, Clone, PartialEq, Eq, Zeroize)]
#[zeroize(drop)]
#[cfg_attr(
    any(target_family = "wasm", feature = "serde"),
    derive(serde::Serialize, serde::Deserialize)
)]
pub struct MlsPendingMessage {
    /// Id of the entity this message is buffered for.
    /// NOTE(review): presumably the conversation / group id — confirm against callers.
    pub foreign_id: Vec<u8>,
    /// Raw bytes of the buffered message.
    pub message: Vec<u8>,
}
93
/// Entity representing a buffered commit.
///
/// There should always exist either 0 or 1 of these in the store per conversation.
/// Commits are buffered if not all proposals they reference have yet been received.
///
/// We don't automatically zeroize on drop because the commit data is still encrypted at this point;
/// it is not risky to leave it in memory. `Zeroize` is still derived, so the data can be wiped
/// explicitly when needed.
#[derive(Debug, Clone, PartialEq, Eq, Zeroize, core_crypto_macros::Entity)]
#[cfg_attr(
    any(target_family = "wasm", feature = "serde"),
    derive(serde::Serialize, serde::Deserialize)
)]
pub struct MlsBufferedCommit {
    // Id of the conversation this commit belongs to; it is also the entity id. The
    // `#[id(hex, ...)]` attribute maps it to the `conversation_id_hex` column.
    #[id(hex, column = "conversation_id_hex")]
    conversation_id: Vec<u8>,
    // Buffered commit payload (still encrypted, see the type-level doc).
    commit_data: Vec<u8>,
}
113
114impl MlsBufferedCommit {
115    /// Create a new `Self` from conversation id and the commit data.
116    pub fn new(conversation_id: Vec<u8>, commit_data: Vec<u8>) -> Self {
117        Self {
118            conversation_id,
119            commit_data,
120        }
121    }
122
123    pub fn conversation_id(&self) -> &[u8] {
124        &self.conversation_id
125    }
126
127    pub fn commit_data(&self) -> &[u8] {
128        &self.commit_data
129    }
130
131    pub fn into_commit_data(self) -> Vec<u8> {
132        self.commit_data
133    }
134}
135
/// Entity representing a persisted `Credential`
///
/// Zeroized on drop (`#[zeroize(drop)]`).
#[derive(Debug, Clone, PartialEq, Eq, Zeroize)]
#[zeroize(drop)]
#[cfg_attr(
    any(target_family = "wasm", feature = "serde"),
    derive(serde::Serialize, serde::Deserialize)
)]
pub struct MlsCredential {
    /// Credential id bytes. NOTE(review): presumably the owning client's id — confirm.
    pub id: Vec<u8>,
    /// Raw credential bytes.
    pub credential: Vec<u8>,
    /// Creation timestamp. NOTE(review): unit and epoch are not visible here — confirm.
    pub created_at: u64,
}
148
/// Extra operations on credential entities.
#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
pub trait MlsCredentialExt: Entity {
    /// Deletes, within `tx`, the stored credential(s) identified by the raw `credential` bytes.
    /// NOTE(review): matching / not-found semantics are defined by the implementations, which are
    /// not visible here — confirm.
    async fn delete_by_credential(tx: &TransactionWrapper<'_>, credential: Vec<u8>) -> CryptoKeystoreResult<()>;
}
154
/// Entity representing a persisted `SignatureKeyPair`
///
/// Zeroized on drop (`#[zeroize(drop)]`).
#[derive(Debug, Clone, PartialEq, Eq, Zeroize)]
#[zeroize(drop)]
#[cfg_attr(
    any(target_family = "wasm", feature = "serde"),
    derive(serde::Serialize, serde::Deserialize)
)]
pub struct MlsSignatureKeyPair {
    /// Signature scheme stored as a `u16` (`MlsSignatureKeyPair::new` casts it with `as u16`).
    pub signature_scheme: u16,
    /// Public key bytes.
    pub pk: Vec<u8>,
    /// Key pair bytes. NOTE(review): encoding is not visible here — confirm.
    pub keypair: Vec<u8>,
    /// Id of the credential this key pair belongs to.
    /// NOTE(review): presumably matches an `MlsCredential::id` — confirm.
    pub credential_id: Vec<u8>,
}
168
169impl MlsSignatureKeyPair {
170    pub fn new(signature_scheme: SignatureScheme, pk: Vec<u8>, keypair: Vec<u8>, credential_id: Vec<u8>) -> Self {
171        Self {
172            signature_scheme: signature_scheme as u16,
173            pk,
174            keypair,
175            credential_id,
176        }
177    }
178}
179
/// Entity representing a persisted `HpkePrivateKey` (related to LeafNode Private keys that the client is aware of)
///
/// Zeroized on drop (`#[zeroize(drop)]`).
#[derive(Debug, Clone, PartialEq, Eq, Zeroize)]
#[zeroize(drop)]
#[cfg_attr(
    any(target_family = "wasm", feature = "serde"),
    derive(serde::Serialize, serde::Deserialize)
)]
pub struct MlsHpkePrivateKey {
    /// HPKE private (secret) key bytes.
    pub sk: Vec<u8>,
    /// Matching HPKE public key bytes.
    pub pk: Vec<u8>,
}
191
/// Entity representing a persisted encryption key pair: a private part (`sk`) and its public
/// part (`pk`).
///
/// NOTE(review): presumably persists openmls' leaf-node `EncryptionKeyPair` — confirm.
/// Zeroized on drop (`#[zeroize(drop)]`).
#[derive(Debug, Clone, PartialEq, Eq, Zeroize)]
#[zeroize(drop)]
#[cfg_attr(
    any(target_family = "wasm", feature = "serde"),
    derive(serde::Serialize, serde::Deserialize)
)]
pub struct MlsEncryptionKeyPair {
    /// Private (secret) key bytes.
    pub sk: Vec<u8>,
    /// Public key bytes.
    pub pk: Vec<u8>,
}
203
204/// Entity representing a list of [MlsEncryptionKeyPair]
205#[derive(Debug, Clone, PartialEq, Eq, Zeroize, core_crypto_macros::Entity)]
206#[zeroize(drop)]
207#[entity(collection_name = "mls_epoch_encryption_keypairs")]
208#[cfg_attr(target_family = "wasm", derive(serde::Serialize, serde::Deserialize))]
209pub struct MlsEpochEncryptionKeyPair {
210    #[id(hex, column = "id_hex")]
211    pub id: Vec<u8>,
212    pub keypairs: Vec<u8>,
213}
214
/// Entity representing a persisted PSK bundle: a pre-shared key identifier (`psk_id`) together
/// with the pre-shared key itself (`psk`).
///
/// Zeroized on drop (`#[zeroize(drop)]`).
#[derive(Debug, Clone, PartialEq, Eq, Zeroize)]
#[zeroize(drop)]
#[cfg_attr(
    any(target_family = "wasm", feature = "serde"),
    derive(serde::Serialize, serde::Deserialize)
)]
pub struct MlsPskBundle {
    /// Identifier of the pre-shared key.
    pub psk_id: Vec<u8>,
    /// Pre-shared key bytes.
    pub psk: Vec<u8>,
}
226
/// Entity representing a persisted `KeyPackage`
///
/// Stored in the `mls_keypackages` collection and zeroized on drop (`#[zeroize(drop)]`).
#[derive(Debug, Clone, PartialEq, Eq, Zeroize, core_crypto_macros::Entity)]
#[zeroize(drop)]
#[entity(collection_name = "mls_keypackages")]
#[cfg_attr(
    any(target_family = "wasm", feature = "serde"),
    derive(serde::Serialize, serde::Deserialize)
)]
pub struct MlsKeyPackage {
    /// Key package reference, used as the entity id; the `#[id(hex, ...)]` attribute maps it to the
    /// `keypackage_ref_hex` column.
    #[id(hex, column = "keypackage_ref_hex")]
    pub keypackage_ref: Vec<u8>,
    /// Raw key package bytes.
    pub keypackage: Vec<u8>,
}
240
/// Entity representing an enrollment instance used to fetch an X.509 certificate. It is persisted
/// when the context switches and the memory it lives in is about to be erased.
///
/// Stored in the `e2ei_enrollment` collection with the `no_upsert` entity option, and zeroized on
/// drop (`#[zeroize(drop)]`).
/// NOTE(review): `no_upsert` presumably makes saving an already-existing id fail instead of
/// overwriting it — confirm against `core_crypto_macros`.
#[derive(Debug, Clone, PartialEq, Eq, Zeroize, core_crypto_macros::Entity)]
#[zeroize(drop)]
#[entity(collection_name = "e2ei_enrollment", no_upsert)]
#[cfg_attr(
    any(target_family = "wasm", feature = "serde"),
    derive(serde::Serialize, serde::Deserialize)
)]
pub struct E2eiEnrollment {
    /// Enrollment id.
    pub id: Vec<u8>,
    /// Opaque enrollment bytes.
    pub content: Vec<u8>,
}
254
/// Keystore entity of which a single record is kept, read back from the fixed key
/// [`UniqueEntity::ID`] (wasm backend).
#[cfg(target_family = "wasm")]
#[async_trait::async_trait(?Send)]
pub trait UniqueEntity:
    EntityBase<ConnectionType = crate::connection::KeystoreDatabaseConnection>
    + serde::Serialize
    + serde::de::DeserializeOwned
where
    Self: 'static,
{
    /// Fixed storage key of the single record.
    const ID: [u8; 1] = [0];

    /// Raw payload of this entity.
    fn content(&self) -> &[u8];

    /// Replaces the raw payload of this entity.
    fn set_content(&mut self, content: Vec<u8>);

    /// Loads the record stored under [`UniqueEntity::ID`].
    ///
    /// # Errors
    /// `CryptoKeystoreError::NotFound` (with an empty id string) when nothing is stored under that
    /// key; storage errors are propagated unchanged.
    async fn find_unique(conn: &mut Self::ConnectionType) -> CryptoKeystoreResult<Self> {
        conn.storage()
            .get(Self::COLLECTION_NAME, &Self::ID)
            .await?
            .ok_or_else(|| CryptoKeystoreError::NotFound(Self::COLLECTION_NAME, String::new()))
    }

    /// Returns the record as a one-element vector, or an empty vector when none is stored.
    ///
    /// `_params` is ignored, since this returns at most the single record.
    async fn find_all(conn: &mut Self::ConnectionType, _params: EntityFindParams) -> CryptoKeystoreResult<Vec<Self>> {
        // Reuse `find_one`'s NotFound -> None mapping instead of duplicating it here.
        Self::find_one(conn).await.map(|record| record.into_iter().collect())
    }

    /// Like [`UniqueEntity::find_unique`], but reports a missing record as `Ok(None)`.
    async fn find_one(conn: &mut Self::ConnectionType) -> CryptoKeystoreResult<Option<Self>> {
        match Self::find_unique(conn).await {
            Ok(record) => Ok(Some(record)),
            Err(CryptoKeystoreError::NotFound(_, _)) => Ok(None),
            Err(err) => Err(err),
        }
    }

    /// Number of records in the collection, as reported by the storage backend.
    async fn count(conn: &mut Self::ConnectionType) -> CryptoKeystoreResult<usize> {
        conn.storage().count(Self::COLLECTION_NAME).await
    }

    /// Saves a clone of `self` through `transaction`.
    async fn replace<'a>(&'a self, transaction: &TransactionWrapper<'a>) -> CryptoKeystoreResult<()> {
        transaction.save(self.clone()).await?;
        Ok(())
    }
}
303
/// Keystore entity stored as a single row of a SQLite table with `id` and `content` columns
/// (native backend).
///
/// The row is addressed by the fixed id [`UniqueEntity::ID`]; every read and write below uses it.
#[cfg(not(target_family = "wasm"))]
#[async_trait::async_trait]
pub trait UniqueEntity: EntityBase<ConnectionType = crate::connection::KeystoreDatabaseConnection> {
    /// Value of the `id` column of the single row.
    const ID: usize = 0;

    /// Builds the entity from the raw bytes read from the `content` column.
    fn new(content: Vec<u8>) -> Self;

    /// Reads the `content` column of the row whose `id` is [`UniqueEntity::ID`].
    ///
    /// # Errors
    /// `CryptoKeystoreError::NotFound` (with an empty id string) when no such row exists; SQLite
    /// errors are propagated.
    async fn find_unique(conn: &mut Self::ConnectionType) -> CryptoKeystoreResult<Self> {
        let mut conn = conn.conn().await;
        // NOTE(review): this transaction is never committed and is just dropped when the function
        // returns; a single SELECT does not appear to need it — confirm before removing.
        let transaction = conn.transaction()?;
        use rusqlite::OptionalExtension as _;

        // `COLLECTION_NAME` is an associated constant, not user input, so formatting it into the
        // SQL is safe. `optional()` turns "no row" into `Ok(None)` instead of an error.
        let maybe_content = transaction
            .query_row(
                &format!("SELECT content FROM {} WHERE id = ?", Self::COLLECTION_NAME),
                [Self::ID],
                |r| r.get::<_, Vec<u8>>(0),
            )
            .optional()?;

        if let Some(content) = maybe_content {
            Ok(Self::new(content))
        } else {
            Err(CryptoKeystoreError::NotFound(Self::COLLECTION_NAME, "".to_string()))
        }
    }

    /// Returns the row as a one-element vector, or an empty vector when it does not exist.
    /// `_params` is ignored.
    async fn find_all(conn: &mut Self::ConnectionType, _params: EntityFindParams) -> CryptoKeystoreResult<Vec<Self>> {
        match Self::find_unique(conn).await {
            Ok(record) => Ok(vec![record]),
            Err(CryptoKeystoreError::NotFound(_, _)) => Ok(vec![]),
            Err(err) => Err(err),
        }
    }

    /// Like [`UniqueEntity::find_unique`], but reports a missing row as `Ok(None)`.
    async fn find_one(conn: &mut Self::ConnectionType) -> CryptoKeystoreResult<Option<Self>> {
        match Self::find_unique(conn).await {
            Ok(record) => Ok(Some(record)),
            Err(CryptoKeystoreError::NotFound(_, _)) => Ok(None),
            Err(err) => Err(err),
        }
    }

    /// Returns the number of rows in the collection's table (`SELECT COUNT(*)`).
    async fn count(conn: &mut Self::ConnectionType) -> CryptoKeystoreResult<usize> {
        let conn = conn.conn().await;
        conn.query_row(&format!("SELECT COUNT(*) FROM {}", Self::COLLECTION_NAME), [], |r| {
            r.get(0)
        })
        .map_err(Into::into)
    }

    /// Raw payload that [`UniqueEntity::replace`] writes to the `content` column.
    fn content(&self) -> &[u8];

    /// Writes [`UniqueEntity::content`] as the row with id [`UniqueEntity::ID`] inside
    /// `transaction`, replacing any existing row with that id (`INSERT OR REPLACE`).
    ///
    /// The row is first inserted with a zero-filled blob of the payload's length. The payload is
    /// then written into that blob through SQLite's incremental blob I/O.
    ///
    /// # Errors
    /// Fails if `check_buffer_size` rejects the payload length, or on any SQLite / I/O error.
    async fn replace(&self, transaction: &TransactionWrapper<'_>) -> CryptoKeystoreResult<()> {
        use crate::connection::DatabaseConnection;
        Self::ConnectionType::check_buffer_size(self.content().len())?;
        // NOTE(review): the `as i32` cast relies on `check_buffer_size` rejecting lengths above
        // `i32::MAX`; confirm, otherwise oversized payloads would be truncated here.
        let zb_content = rusqlite::blob::ZeroBlob(self.content().len() as i32);

        use rusqlite::ToSql;
        let params: [rusqlite::types::ToSqlOutput; 2] = [Self::ID.to_sql()?, zb_content.to_sql()?];

        transaction.execute(
            &format!(
                "INSERT OR REPLACE INTO {} (id, content) VALUES (?, ?)",
                Self::COLLECTION_NAME
            ),
            params,
        )?;
        // Rowid of the row just inserted; used to open its `content` blob below.
        let row_id = transaction.last_insert_rowid();

        // The final `false` is rusqlite's `read_only` flag: the blob is opened for writing.
        let mut blob = transaction.blob_open(
            rusqlite::DatabaseName::Main,
            Self::COLLECTION_NAME,
            "content",
            row_id,
            false,
        )?;
        use std::io::Write;
        blob.write_all(self.content())?;
        blob.close()?;

        Ok(())
    }
}
388
/// Blanket [`EntityTransactionExt`] implementation for every [`UniqueEntity`].
///
/// Saving delegates to [`UniqueEntity::replace`]. Deleting by id is not supported for unique
/// entities and always returns `CryptoKeystoreError::NotImplemented`.
///
/// The wasm and native variants of each method differ only in their lifetime annotations,
/// presumably to mirror the per-target signatures declared by `EntityTransactionExt`.
#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
impl<T: UniqueEntity + Send + Sync> EntityTransactionExt for T {
    // Native: saving overwrites the single row via `UniqueEntity::replace`.
    #[cfg(not(target_family = "wasm"))]
    async fn save(&self, tx: &TransactionWrapper<'_>) -> CryptoKeystoreResult<()> {
        self.replace(tx).await
    }

    // Wasm: same delegation, with the lifetime tying `self` to the transaction.
    #[cfg(target_family = "wasm")]
    async fn save<'a>(&'a self, tx: &TransactionWrapper<'a>) -> CryptoKeystoreResult<()> {
        self.replace(tx).await
    }

    // Unique entities have no caller-chosen id, so deletion by id is rejected.
    #[cfg(not(target_family = "wasm"))]
    async fn delete_fail_on_missing_id(
        _: &TransactionWrapper<'_>,
        _id: StringEntityId<'_>,
    ) -> CryptoKeystoreResult<()> {
        Err(CryptoKeystoreError::NotImplemented)
    }

    #[cfg(target_family = "wasm")]
    async fn delete_fail_on_missing_id<'a>(
        _: &TransactionWrapper<'a>,
        _id: StringEntityId<'a>,
    ) -> CryptoKeystoreResult<()> {
        Err(CryptoKeystoreError::NotImplemented)
    }
}
418
/// OIDC refresh token used in E2EI
///
/// Zeroized on drop (`#[zeroize(drop)]`).
#[derive(Debug, Clone, PartialEq, Eq, Zeroize)]
#[zeroize(drop)]
#[cfg_attr(
    any(target_family = "wasm", feature = "serde"),
    derive(serde::Serialize, serde::Deserialize)
)]
pub struct E2eiRefreshToken {
    /// Raw refresh token bytes.
    pub content: Vec<u8>,
}
429
/// ACME CA data used in E2EI, persisted as raw bytes and zeroized on drop (`#[zeroize(drop)]`).
///
/// NOTE(review): presumably the ACME server's root CA certificate — confirm its encoding
/// (e.g. DER vs PEM).
#[derive(Debug, Clone, PartialEq, Eq, Zeroize)]
#[zeroize(drop)]
#[cfg_attr(
    any(target_family = "wasm", feature = "serde"),
    derive(serde::Serialize, serde::Deserialize)
)]
pub struct E2eiAcmeCA {
    /// Raw CA bytes.
    pub content: Vec<u8>,
}
439
/// Intermediate CA certificate used in E2EI, keyed by its SKI/AKI pair.
///
/// Zeroized on drop (`#[zeroize(drop)]`).
#[derive(Debug, Clone, PartialEq, Eq, Zeroize, core_crypto_macros::Entity)]
#[zeroize(drop)]
#[cfg_attr(
    any(target_family = "wasm", feature = "serde"),
    derive(serde::Serialize, serde::Deserialize)
)]
pub struct E2eiIntermediateCert {
    // key to identify the CA cert; Using a combination of SKI & AKI extensions concatenated like so is suitable: `SKI[+AKI]`
    #[id]
    pub ski_aki_pair: String,
    // Raw certificate bytes. NOTE(review): encoding (DER/PEM) is not visible here — confirm.
    pub content: Vec<u8>,
}
452
/// Certificate revocation list (CRL) used in E2EI, keyed by its distribution point.
///
/// Zeroized on drop (`#[zeroize(drop)]`).
#[derive(Debug, Clone, PartialEq, Eq, Zeroize, core_crypto_macros::Entity)]
#[zeroize(drop)]
#[cfg_attr(
    any(target_family = "wasm", feature = "serde"),
    derive(serde::Serialize, serde::Deserialize)
)]
pub struct E2eiCrl {
    // CRL distribution point, used as the entity id. NOTE(review): presumably a URL — confirm.
    #[id]
    pub distribution_point: String,
    // Raw CRL bytes.
    pub content: Vec<u8>,
}