core_crypto/
proteus.rs

1use std::{collections::HashMap, sync::Arc};
2
3use core_crypto_keystore::{
4    Database as CryptoKeystore,
5    connection::FetchFromDatabase,
6    entities::{ProteusIdentity, ProteusSession},
7};
8use proteus_wasm::{
9    keys::{IdentityKeyPair, PreKeyBundle},
10    message::Envelope,
11    session::Session,
12};
13
14use crate::{
15    CoreCrypto, Error, KeystoreError, LeafError, ProteusError, Result,
16    group_store::{GroupStore, GroupStoreEntity, GroupStoreValue},
17};
18
/// Identifier of a Proteus session.
///
/// This is a plain `String`; no particular format is enforced at this level.
pub type SessionIdentifier = String;
21
22/// Proteus Session wrapper, that contains the identifier and the associated proteus Session
23#[derive(Debug)]
24pub struct ProteusConversationSession {
25    pub(crate) identifier: SessionIdentifier,
26    pub(crate) session: Session<Arc<IdentityKeyPair>>,
27}
28
29impl ProteusConversationSession {
30    /// Encrypts a message for this Proteus session
31    pub fn encrypt(&mut self, plaintext: &[u8]) -> Result<Vec<u8>> {
32        self.session
33            .encrypt(plaintext)
34            .and_then(|e| e.serialise())
35            .map_err(ProteusError::wrap("encrypting message for proteus session"))
36            .map_err(Into::into)
37    }
38
39    /// Decrypts a message for this Proteus session
40    pub async fn decrypt(&mut self, store: &mut core_crypto_keystore::Database, ciphertext: &[u8]) -> Result<Vec<u8>> {
41        let envelope = Envelope::deserialise(ciphertext).map_err(ProteusError::wrap("deserializing envelope"))?;
42        self.session
43            .decrypt(store, &envelope)
44            .await
45            .map_err(ProteusError::wrap("decrypting message for proteus session"))
46            .map_err(Into::into)
47    }
48
49    /// Returns the session identifier
50    pub fn identifier(&self) -> &str {
51        &self.identifier
52    }
53
54    /// Returns the public key fingerprint of the local identity (= self identity)
55    pub fn fingerprint_local(&self) -> String {
56        self.session.local_identity().fingerprint()
57    }
58
59    /// Returns the public key fingerprint of the remote identity (= client you're communicating with)
60    pub fn fingerprint_remote(&self) -> String {
61        self.session.remote_identity().fingerprint()
62    }
63}
64
65#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
66#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
67impl GroupStoreEntity for ProteusConversationSession {
68    type RawStoreValue = core_crypto_keystore::entities::ProteusSession;
69    type IdentityType = Arc<proteus_wasm::keys::IdentityKeyPair>;
70
71    async fn fetch_from_id(
72        id: impl AsRef<[u8]> + Send,
73        identity: Option<Self::IdentityType>,
74        keystore: &impl FetchFromDatabase,
75    ) -> crate::Result<Option<Self>> {
76        let result = keystore
77            .find::<Self::RawStoreValue>(id)
78            .await
79            .map_err(KeystoreError::wrap("finding raw group store entity by id"))?;
80        let Some(store_value) = result else {
81            return Ok(None);
82        };
83
84        let Some(identity) = identity else {
85            return Err(crate::Error::ProteusNotInitialized);
86        };
87
88        let session = proteus_wasm::session::Session::deserialise(identity, &store_value.session)
89            .map_err(ProteusError::wrap("deserializing session"))?;
90
91        Ok(Some(Self {
92            identifier: store_value.id.clone(),
93            session,
94        }))
95    }
96}
97
98impl CoreCrypto {
99    /// Proteus session accessor
100    ///
101    /// Warning: The Proteus client **MUST** be initialized with
102    /// [crate::transaction_context::TransactionContext::proteus_init] first or an error will be
103    /// returned
104    pub async fn proteus_session(
105        &self,
106        session_id: &str,
107    ) -> Result<Option<GroupStoreValue<ProteusConversationSession>>> {
108        let mut mutex = self.proteus.lock().await;
109        let proteus = mutex.as_mut().ok_or(Error::ProteusNotInitialized)?;
110        let keystore = self.mls.crypto_provider.keystore();
111        proteus.session(session_id, &keystore).await
112    }
113
114    /// Proteus session exists
115    ///
116    /// Warning: The Proteus client **MUST** be initialized with
117    /// [crate::transaction_context::TransactionContext::proteus_init] first or an error will be
118    /// returned
119    pub async fn proteus_session_exists(&self, session_id: &str) -> Result<bool> {
120        let mut mutex = self.proteus.lock().await;
121        let proteus = mutex.as_mut().ok_or(Error::ProteusNotInitialized)?;
122        let keystore = self.mls.crypto_provider.keystore();
123        Ok(proteus.session_exists(session_id, &keystore).await)
124    }
125
126    /// Returns the proteus last resort prekey id (u16::MAX = 65535)
127    pub fn proteus_last_resort_prekey_id() -> u16 {
128        ProteusCentral::last_resort_prekey_id()
129    }
130
131    /// Returns the proteus identity's public key fingerprint
132    ///
133    /// Warning: The Proteus client **MUST** be initialized with
134    /// [crate::transaction_context::TransactionContext::proteus_init] first or an error will be
135    /// returned
136    pub async fn proteus_fingerprint(&self) -> Result<String> {
137        let mutex = self.proteus.lock().await;
138        let proteus = mutex.as_ref().ok_or(Error::ProteusNotInitialized)?;
139        Ok(proteus.fingerprint())
140    }
141
142    /// Returns the proteus identity's public key fingerprint
143    ///
144    /// Warning: The Proteus client **MUST** be initialized with
145    /// [crate::transaction_context::TransactionContext::proteus_init] first or an error will be
146    /// returned
147    pub async fn proteus_fingerprint_local(&self, session_id: &str) -> Result<String> {
148        let mut mutex = self.proteus.lock().await;
149        let proteus = mutex.as_mut().ok_or(Error::ProteusNotInitialized)?;
150        let keystore = self.mls.crypto_provider.keystore();
151        proteus.fingerprint_local(session_id, &keystore).await
152    }
153
154    /// Returns the proteus identity's public key fingerprint
155    ///
156    /// Warning: The Proteus client **MUST** be initialized with
157    /// [crate::transaction_context::TransactionContext::proteus_init] first or an error will be
158    /// returned
159    pub async fn proteus_fingerprint_remote(&self, session_id: &str) -> Result<String> {
160        let mut mutex = self.proteus.lock().await;
161        let proteus = mutex.as_mut().ok_or(Error::ProteusNotInitialized)?;
162        let keystore = self.mls.crypto_provider.keystore();
163        proteus.fingerprint_remote(session_id, &keystore).await
164    }
165}
166
167/// Proteus counterpart of [crate::mls::session::Session]
168///
169/// The big difference is that [ProteusCentral] doesn't *own* its own keystore but must borrow it from the outside.
170/// Whether it's exclusively for this struct's purposes or it's shared with our main struct, [crate::mls::session::Session]
171#[derive(Debug)]
172pub struct ProteusCentral {
173    proteus_identity: Arc<IdentityKeyPair>,
174    proteus_sessions: GroupStore<ProteusConversationSession>,
175}
176
177impl ProteusCentral {
178    /// Initializes the [ProteusCentral]
179    pub async fn try_new(keystore: &CryptoKeystore) -> Result<Self> {
180        let proteus_identity: Arc<IdentityKeyPair> = Arc::new(Self::load_or_create_identity(keystore).await?);
181        let proteus_sessions = Self::restore_sessions(keystore, &proteus_identity).await?;
182
183        Ok(Self {
184            proteus_identity,
185            proteus_sessions,
186        })
187    }
188
189    /// Restore proteus sessions from disk
190    pub(crate) async fn reload_sessions(&mut self, keystore: &CryptoKeystore) -> Result<()> {
191        self.proteus_sessions = Self::restore_sessions(keystore, &self.proteus_identity).await?;
192        Ok(())
193    }
194
195    /// This function will try to load a proteus Identity from our keystore; If it cannot, it will create a new one
196    /// This means this function doesn't fail except in cases of deeper errors (such as in the Keystore and other crypto errors)
197    async fn load_or_create_identity(keystore: &CryptoKeystore) -> Result<IdentityKeyPair> {
198        let Some(identity) = keystore
199            .find::<ProteusIdentity>(ProteusIdentity::ID)
200            .await
201            .map_err(KeystoreError::wrap("finding proteus identity"))?
202        else {
203            return Self::create_identity(keystore).await;
204        };
205
206        let sk = identity.sk_raw();
207        let pk = identity.pk_raw();
208
209        // SAFETY: Byte lengths are ensured at the keystore level so this function is safe to call, despite being cursed
210        IdentityKeyPair::from_raw_key_pair(*sk, *pk)
211            .map_err(ProteusError::wrap("constructing identity keypair"))
212            .map_err(Into::into)
213    }
214
215    /// Internal function to create and save a new Proteus Identity
216    async fn create_identity(keystore: &CryptoKeystore) -> Result<IdentityKeyPair> {
217        let kp = IdentityKeyPair::new();
218        let pk = kp.public_key.public_key.as_slice().to_vec();
219
220        let ks_identity = ProteusIdentity {
221            sk: kp.secret_key.to_keypair_bytes().into(),
222            pk,
223        };
224        keystore
225            .save(ks_identity)
226            .await
227            .map_err(KeystoreError::wrap("saving new proteus identity"))?;
228
229        Ok(kp)
230    }
231
232    /// Restores the saved sessions in memory. This is performed automatically on init
233    async fn restore_sessions(
234        keystore: &core_crypto_keystore::Database,
235        identity: &Arc<IdentityKeyPair>,
236    ) -> Result<GroupStore<ProteusConversationSession>> {
237        let mut proteus_sessions = GroupStore::new_with_limit(crate::group_store::ITEM_LIMIT * 2);
238        for session in keystore
239            .find_all::<ProteusSession>(Default::default())
240            .await
241            .map_err(KeystoreError::wrap("finding all proteus sessions"))?
242            .into_iter()
243        {
244            let proteus_session = Session::deserialise(identity.clone(), &session.session)
245                .map_err(ProteusError::wrap("deserializing session"))?;
246
247            let identifier = session.id.clone();
248
249            let proteus_conversation = ProteusConversationSession {
250                identifier: identifier.clone(),
251                session: proteus_session,
252            };
253
254            if proteus_sessions
255                .try_insert(identifier.into_bytes(), proteus_conversation)
256                .is_err()
257            {
258                break;
259            }
260        }
261
262        Ok(proteus_sessions)
263    }
264
265    /// Creates a new session from a prekey
266    pub async fn session_from_prekey(
267        &mut self,
268        session_id: &str,
269        key: &[u8],
270    ) -> Result<GroupStoreValue<ProteusConversationSession>> {
271        let prekey = PreKeyBundle::deserialise(key).map_err(ProteusError::wrap("deserializing prekey bundle"))?;
272        // Note on the `::<>` turbofish below:
273        //
274        // `init_from_prekey` returns an error type which is parametric over some wrapped `E`,
275        // because one variant (not relevant to this particular operation) wraps an error type based
276        // on a parameter of a different function entirely.
277        //
278        // Rust complains here, because it can't figure out what type that `E` should be. After all, it's
279        // not inferrable from this function call! It is also entirely irrelevant in this case.
280        //
281        // We can derive two general rules about error-handling in Rust from this example:
282        //
283        // 1. It's better to make smaller error types where possible, encapsulating fallible operations
284        //    with their own error variants, and then wrapping those errors where required, as opposed to
285        //    creating giant catch-all errors. Doing so also has knock-on benefits with regard to tracing
286        //    the precise origin of the error.
287        // 2. One should never make an error wrapper parametric. If you need to wrap an unknown error,
288        //    it's always better to wrap a `Box<dyn std::error::Error>` than to make your error type parametric.
289        //    The allocation cost of creating the `Box` is utterly trivial in an error-handling path, and
290        //    it avoids parametric virality. (`init_from_prekey` is itself only generic because it returns
291        //    this error type with a type-parametric variant, which the function never returns.)
292        //
293        // In this case, we have the out of band knowledge that `ProteusErrorKind` has a `#[from]` implementation
294        // for `proteus_wasm::session::Error<core_crypto_keystore::CryptoKeystoreError>` and for no other kinds
295        // of session error. So we can safely say that the type of error we are meant to catch here, and
296        // therefore pass in that otherwise-irrelevant type, to ensure that error handling works properly.
297        //
298        // Some people say that if it's stupid but it works, it's not stupid. I disagree. If it's stupid but
299        // it works, that's our cue to seek out even better, non-stupid ways to get things done. I reiterate:
300        // the actual type referred to in this turbofish is nothing but a magic incantation to make error
301        // handling work; it has no bearing on the error retured from this function. How much better would it
302        // have been if `session::Error` were not parametric and we could have avoided the turbofish entirely?
303        let proteus_session = Session::init_from_prekey::<core_crypto_keystore::CryptoKeystoreError>(
304            self.proteus_identity.clone(),
305            prekey,
306        )
307        .map_err(ProteusError::wrap("initializing session from prekey"))?;
308
309        let proteus_conversation = ProteusConversationSession {
310            identifier: session_id.into(),
311            session: proteus_session,
312        };
313
314        self.proteus_sessions
315            .insert(session_id.as_bytes(), proteus_conversation);
316
317        Ok(self.proteus_sessions.get(session_id.as_bytes()).unwrap().clone())
318    }
319
320    /// Creates a new proteus Session from a received message
321    pub(crate) async fn session_from_message(
322        &mut self,
323        keystore: &mut CryptoKeystore,
324        session_id: &str,
325        envelope: &[u8],
326    ) -> Result<(GroupStoreValue<ProteusConversationSession>, Vec<u8>)> {
327        let message = Envelope::deserialise(envelope).map_err(ProteusError::wrap("deserialising envelope"))?;
328        let (session, payload) = Session::init_from_message(self.proteus_identity.clone(), keystore, &message)
329            .await
330            .map_err(ProteusError::wrap("initializing session from message"))?;
331
332        let proteus_conversation = ProteusConversationSession {
333            identifier: session_id.into(),
334            session,
335        };
336
337        self.proteus_sessions
338            .insert(session_id.as_bytes(), proteus_conversation);
339
340        Ok((
341            self.proteus_sessions.get(session_id.as_bytes()).unwrap().clone(),
342            payload,
343        ))
344    }
345
346    /// Persists a session in store
347    ///
348    /// **Note**: This isn't usually needed as persisting sessions happens automatically when decrypting/encrypting messages and initializing Sessions
349    pub(crate) async fn session_save(&mut self, keystore: &CryptoKeystore, session_id: &str) -> Result<()> {
350        if let Some(session) = self
351            .proteus_sessions
352            .get_fetch(session_id.as_bytes(), keystore, Some(self.proteus_identity.clone()))
353            .await?
354        {
355            Self::session_save_by_ref(keystore, session).await?;
356        }
357
358        Ok(())
359    }
360
361    pub(crate) async fn session_save_by_ref(
362        keystore: &CryptoKeystore,
363        session: GroupStoreValue<ProteusConversationSession>,
364    ) -> Result<()> {
365        let session = session.read().await;
366        let db_session = ProteusSession {
367            id: session.identifier().to_string(),
368            session: session
369                .session
370                .serialise()
371                .map_err(ProteusError::wrap("serializing session"))?,
372        };
373        keystore
374            .save(db_session)
375            .await
376            .map_err(KeystoreError::wrap("saving proteus session"))?;
377        Ok(())
378    }
379
380    /// Deletes a session in the store
381    pub(crate) async fn session_delete(&mut self, keystore: &CryptoKeystore, session_id: &str) -> Result<()> {
382        if keystore.remove::<ProteusSession, _>(session_id).await.is_ok() {
383            let _ = self.proteus_sessions.remove(session_id.as_bytes());
384        }
385        Ok(())
386    }
387
388    /// Session accessor
389    pub(crate) async fn session(
390        &mut self,
391        session_id: &str,
392        keystore: &CryptoKeystore,
393    ) -> Result<Option<GroupStoreValue<ProteusConversationSession>>> {
394        self.proteus_sessions
395            .get_fetch(session_id.as_bytes(), keystore, Some(self.proteus_identity.clone()))
396            .await
397    }
398
399    /// Session exists
400    pub(crate) async fn session_exists(&mut self, session_id: &str, keystore: &CryptoKeystore) -> bool {
401        self.session(session_id, keystore).await.ok().flatten().is_some()
402    }
403
404    /// Decrypt a proteus message for an already existing session
405    /// Note: This cannot be used for handshake messages, see [ProteusCentral::session_from_message]
406    pub(crate) async fn decrypt(
407        &mut self,
408        keystore: &mut CryptoKeystore,
409        session_id: &str,
410        ciphertext: &[u8],
411    ) -> Result<Vec<u8>> {
412        let session = self
413            .proteus_sessions
414            .get_fetch(session_id.as_bytes(), keystore, Some(self.proteus_identity.clone()))
415            .await?
416            .ok_or(LeafError::ConversationNotFound(session_id.as_bytes().into()))
417            .map_err(ProteusError::wrap("getting session"))?;
418
419        let plaintext = session.write().await.decrypt(keystore, ciphertext).await?;
420        ProteusCentral::session_save_by_ref(keystore, session).await?;
421
422        Ok(plaintext)
423    }
424
425    /// Encrypt a message for a session
426    pub(crate) async fn encrypt(
427        &mut self,
428        keystore: &mut CryptoKeystore,
429        session_id: &str,
430        plaintext: &[u8],
431    ) -> Result<Vec<u8>> {
432        let session = self
433            .session(session_id, keystore)
434            .await?
435            .ok_or(LeafError::ConversationNotFound(session_id.as_bytes().into()))
436            .map_err(ProteusError::wrap("getting session"))?;
437
438        let ciphertext = session.write().await.encrypt(plaintext)?;
439        ProteusCentral::session_save_by_ref(keystore, session).await?;
440
441        Ok(ciphertext)
442    }
443
444    /// Encrypts a message for a list of sessions
445    /// This is mainly used for conversations with multiple clients, this allows to minimize FFI roundtrips
446    pub(crate) async fn encrypt_batched(
447        &mut self,
448        keystore: &mut CryptoKeystore,
449        sessions: &[impl AsRef<str>],
450        plaintext: &[u8],
451    ) -> Result<HashMap<String, Vec<u8>>> {
452        let mut acc = HashMap::new();
453        for session_id in sessions {
454            if let Some(session) = self.session(session_id.as_ref(), keystore).await? {
455                let mut session_w = session.write().await;
456                acc.insert(session_w.identifier.clone(), session_w.encrypt(plaintext)?);
457                drop(session_w);
458
459                ProteusCentral::session_save_by_ref(keystore, session).await?;
460            }
461        }
462        Ok(acc)
463    }
464
465    /// Generates a new Proteus PreKey, stores it in the keystore and returns a serialized PreKeyBundle to be consumed externally
466    pub(crate) async fn new_prekey(&self, id: u16, keystore: &CryptoKeystore) -> Result<Vec<u8>> {
467        use proteus_wasm::keys::{PreKey, PreKeyId};
468
469        let prekey_id = PreKeyId::new(id);
470        let prekey = PreKey::new(prekey_id);
471        let keystore_prekey = core_crypto_keystore::entities::ProteusPrekey::from_raw(
472            id,
473            prekey.serialise().map_err(ProteusError::wrap("serialising prekey"))?,
474        );
475        let bundle = PreKeyBundle::new(self.proteus_identity.as_ref().public_key.clone(), &prekey);
476        let bundle = bundle
477            .serialise()
478            .map_err(ProteusError::wrap("serialising prekey bundle"))?;
479        keystore
480            .save(keystore_prekey)
481            .await
482            .map_err(KeystoreError::wrap("saving keystore prekey"))?;
483        Ok(bundle)
484    }
485
486    /// Generates a new Proteus Prekey, with an automatically auto-incremented ID.
487    ///
488    /// See [ProteusCentral::new_prekey]
489    pub(crate) async fn new_prekey_auto(&self, keystore: &CryptoKeystore) -> Result<(u16, Vec<u8>)> {
490        let id = core_crypto_keystore::entities::ProteusPrekey::get_free_id(keystore)
491            .await
492            .map_err(KeystoreError::wrap("getting proteus prekey by id"))?;
493        Ok((id, self.new_prekey(id, keystore).await?))
494    }
495
496    /// Returns the Proteus last resort prekey ID (u16::MAX = 65535 = 0xFFFF)
497    pub fn last_resort_prekey_id() -> u16 {
498        proteus_wasm::keys::MAX_PREKEY_ID.value()
499    }
500
501    /// Returns the Proteus last resort prekey
502    /// If it cannot be found, one will be created.
503    pub(crate) async fn last_resort_prekey(&self, keystore: &CryptoKeystore) -> Result<Vec<u8>> {
504        let last_resort = if let Some(last_resort) = keystore
505            .find::<core_crypto_keystore::entities::ProteusPrekey>(
506                Self::last_resort_prekey_id().to_le_bytes().as_slice(),
507            )
508            .await
509            .map_err(KeystoreError::wrap("finding proteus prekey"))?
510        {
511            proteus_wasm::keys::PreKey::deserialise(&last_resort.prekey)
512                .map_err(ProteusError::wrap("deserialising proteus prekey"))?
513        } else {
514            let last_resort = proteus_wasm::keys::PreKey::last_resort();
515
516            use core_crypto_keystore::CryptoKeystoreProteus as _;
517            keystore
518                .proteus_store_prekey(
519                    Self::last_resort_prekey_id(),
520                    &last_resort
521                        .serialise()
522                        .map_err(ProteusError::wrap("serialising last resort prekey"))?,
523                )
524                .await
525                .map_err(KeystoreError::wrap("storing proteus prekey"))?;
526
527            last_resort
528        };
529
530        let bundle = PreKeyBundle::new(self.proteus_identity.as_ref().public_key.clone(), &last_resort);
531        let bundle = bundle
532            .serialise()
533            .map_err(ProteusError::wrap("serialising prekey bundle"))?;
534
535        Ok(bundle)
536    }
537
538    /// Proteus identity keypair
539    pub fn identity(&self) -> &IdentityKeyPair {
540        self.proteus_identity.as_ref()
541    }
542
543    /// Proteus Public key hex-encoded fingerprint
544    pub fn fingerprint(&self) -> String {
545        self.proteus_identity.as_ref().public_key.fingerprint()
546    }
547
548    /// Proteus Session local hex-encoded fingerprint
549    ///
550    /// # Errors
551    /// When the session is not found
552    pub(crate) async fn fingerprint_local(&mut self, session_id: &str, keystore: &CryptoKeystore) -> Result<String> {
553        let session = self
554            .session(session_id, keystore)
555            .await?
556            .ok_or(LeafError::ConversationNotFound(session_id.as_bytes().into()))
557            .map_err(ProteusError::wrap("getting session"))?;
558        let fingerprint = session.read().await.fingerprint_local();
559        Ok(fingerprint)
560    }
561
562    /// Proteus Session remote hex-encoded fingerprint
563    ///
564    /// # Errors
565    /// When the session is not found
566    pub(crate) async fn fingerprint_remote(&mut self, session_id: &str, keystore: &CryptoKeystore) -> Result<String> {
567        let session = self
568            .session(session_id, keystore)
569            .await?
570            .ok_or(LeafError::ConversationNotFound(session_id.as_bytes().into()))
571            .map_err(ProteusError::wrap("getting session"))?;
572        let fingerprint = session.read().await.fingerprint_remote();
573        Ok(fingerprint)
574    }
575
576    /// Hex-encoded fingerprint of the given prekey
577    ///
578    /// # Errors
579    /// If the prekey cannot be deserialized
580    pub fn fingerprint_prekeybundle(prekey: &[u8]) -> Result<String> {
581        let prekey = PreKeyBundle::deserialise(prekey).map_err(ProteusError::wrap("deserialising prekey bundle"))?;
582        Ok(prekey.identity_key.fingerprint())
583    }
584}
585
586#[cfg(test)]
587mod tests {
588    use core_crypto_keystore::{ConnectionType, Database, DatabaseKey};
589
590    use super::*;
591    use crate::{
592        CertificateBundle, ClientIdentifier, MlsCredentialType, Session, SessionConfig,
593        test_utils::{proteus_utils::*, x509::X509TestChain, *},
594    };
595
596    #[apply(all_cred_cipher)]
597    async fn cc_can_init(case: TestContext) {
598        #[cfg(not(target_family = "wasm"))]
599        let (path, db_file) = tmp_db_file();
600        #[cfg(target_family = "wasm")]
601        let (path, _) = tmp_db_file();
602        let client_id = "alice".into();
603        let db = Database::open(ConnectionType::Persistent(&path), &DatabaseKey::generate())
604            .await
605            .unwrap();
606        let cfg = SessionConfig::builder()
607            .database(db)
608            .client_id(client_id)
609            .ciphersuites([case.ciphersuite()])
610            .build()
611            .validate()
612            .unwrap();
613
614        let cc: CoreCrypto = Session::try_new(cfg).await.unwrap().into();
615        let context = cc.new_transaction().await.unwrap();
616        assert!(context.proteus_init().await.is_ok());
617        assert!(context.proteus_new_prekey(1).await.is_ok());
618        context.finish().await.unwrap();
619        #[cfg(not(target_family = "wasm"))]
620        drop(db_file);
621    }
622
623    #[apply(all_cred_cipher)]
624    async fn cc_can_2_phase_init(case: TestContext) {
625        #[cfg(not(target_family = "wasm"))]
626        let (path, db_file) = tmp_db_file();
627        #[cfg(target_family = "wasm")]
628        let (path, _) = tmp_db_file();
629        let db = Database::open(ConnectionType::Persistent(&path), &DatabaseKey::generate())
630            .await
631            .unwrap();
632        // we are deferring MLS initialization here, not passing a MLS 'client_id' yet
633        let cfg = SessionConfig::builder()
634            .database(db)
635            .ciphersuites([case.ciphersuite()])
636            .build()
637            .validate()
638            .unwrap();
639
640        let cc: CoreCrypto = Session::try_new(cfg).await.unwrap().into();
641        let transaction = cc.new_transaction().await.unwrap();
642        let x509_test_chain = X509TestChain::init_empty(case.signature_scheme());
643        x509_test_chain.register_with_central(&transaction).await;
644        assert!(transaction.proteus_init().await.is_ok());
645        // proteus is initialized, prekeys can be generated
646        assert!(transaction.proteus_new_prekey(1).await.is_ok());
647        // 👇 and so a unique 'client_id' can be fetched from wire-server
648        let client_id = "alice";
649        let identifier = match case.credential_type {
650            MlsCredentialType::Basic => ClientIdentifier::Basic(client_id.into()),
651            MlsCredentialType::X509 => {
652                CertificateBundle::rand_identifier(client_id, &[x509_test_chain.find_local_intermediate_ca()])
653            }
654        };
655        transaction.mls_init(identifier, &[case.ciphersuite()]).await.unwrap();
656        // expect MLS to work
657        assert_eq!(
658            transaction
659                .get_or_create_client_keypackages(case.ciphersuite(), case.credential_type, 2)
660                .await
661                .unwrap()
662                .len(),
663            2
664        );
665        #[cfg(not(target_family = "wasm"))]
666        drop(db_file);
667    }
668
669    #[macro_rules_attribute::apply(smol_macros::test)]
670    async fn can_init() {
671        #[cfg(not(target_family = "wasm"))]
672        let (path, db_file) = tmp_db_file();
673        #[cfg(target_family = "wasm")]
674        let (path, _) = tmp_db_file();
675        let key = DatabaseKey::generate();
676        let keystore = core_crypto_keystore::Database::open(ConnectionType::Persistent(&path), &key)
677            .await
678            .unwrap();
679        keystore.new_transaction().await.unwrap();
680        let central = ProteusCentral::try_new(&keystore).await.unwrap();
681        let identity = (*central.proteus_identity).clone();
682        keystore.commit_transaction().await.unwrap();
683
684        let keystore = core_crypto_keystore::Database::open(ConnectionType::Persistent(&path), &key)
685            .await
686            .unwrap();
687        keystore.new_transaction().await.unwrap();
688        let central = ProteusCentral::try_new(&keystore).await.unwrap();
689        keystore.commit_transaction().await.unwrap();
690        assert_eq!(identity, *central.proteus_identity);
691
692        keystore.wipe().await.unwrap();
693        #[cfg(not(target_family = "wasm"))]
694        drop(db_file);
695    }
696
697    #[macro_rules_attribute::apply(smol_macros::test)]
698    async fn can_talk_with_proteus() {
699        #[cfg(not(target_family = "wasm"))]
700        let (path, db_file) = tmp_db_file();
701        #[cfg(target_family = "wasm")]
702        let (path, _) = tmp_db_file();
703
704        let session_id = uuid::Uuid::new_v4().hyphenated().to_string();
705
706        let key = DatabaseKey::generate();
707        let mut keystore = core_crypto_keystore::Database::open(ConnectionType::Persistent(&path), &key)
708            .await
709            .unwrap();
710        keystore.new_transaction().await.unwrap();
711
712        let mut alice = ProteusCentral::try_new(&keystore).await.unwrap();
713
714        let mut bob = CryptoboxLike::init();
715        let bob_pk_bundle = bob.new_prekey();
716
717        alice
718            .session_from_prekey(&session_id, &bob_pk_bundle.serialise().unwrap())
719            .await
720            .unwrap();
721
722        let message = b"Hello world";
723
724        let encrypted = alice.encrypt(&mut keystore, &session_id, message).await.unwrap();
725        let decrypted = bob.decrypt(&session_id, &encrypted).await;
726        assert_eq!(decrypted, message);
727
728        let encrypted = bob.encrypt(&session_id, message);
729        let decrypted = alice.decrypt(&mut keystore, &session_id, &encrypted).await.unwrap();
730        assert_eq!(decrypted, message);
731
732        keystore.commit_transaction().await.unwrap();
733        keystore.wipe().await.unwrap();
734        #[cfg(not(target_family = "wasm"))]
735        drop(db_file);
736    }
737
738    #[macro_rules_attribute::apply(smol_macros::test)]
739    async fn can_produce_proteus_consumed_prekeys() {
740        #[cfg(not(target_family = "wasm"))]
741        let (path, db_file) = tmp_db_file();
742        #[cfg(target_family = "wasm")]
743        let (path, _) = tmp_db_file();
744
745        let session_id = uuid::Uuid::new_v4().hyphenated().to_string();
746
747        let key = DatabaseKey::generate();
748        let mut keystore = core_crypto_keystore::Database::open(ConnectionType::Persistent(&path), &key)
749            .await
750            .unwrap();
751        keystore.new_transaction().await.unwrap();
752        let mut alice = ProteusCentral::try_new(&keystore).await.unwrap();
753
754        let mut bob = CryptoboxLike::init();
755
756        let alice_prekey_bundle_ser = alice.new_prekey(1, &keystore).await.unwrap();
757
758        bob.init_session_from_prekey_bundle(&session_id, &alice_prekey_bundle_ser);
759        let message = b"Hello world!";
760        let encrypted = bob.encrypt(&session_id, message);
761
762        let (_, decrypted) = alice
763            .session_from_message(&mut keystore, &session_id, &encrypted)
764            .await
765            .unwrap();
766
767        assert_eq!(message, decrypted.as_slice());
768
769        let encrypted = alice.encrypt(&mut keystore, &session_id, message).await.unwrap();
770        let decrypted = bob.decrypt(&session_id, &encrypted).await;
771
772        assert_eq!(message, decrypted.as_slice());
773        keystore.commit_transaction().await.unwrap();
774        keystore.wipe().await.unwrap();
775        #[cfg(not(target_family = "wasm"))]
776        drop(db_file);
777    }
778
779    #[macro_rules_attribute::apply(smol_macros::test)]
780    async fn auto_prekeys_are_sequential() {
781        use core_crypto_keystore::entities::ProteusPrekey;
782        const GAP_AMOUNT: u16 = 5;
783        const ID_TEST_RANGE: std::ops::RangeInclusive<u16> = 1..=30;
784
785        #[cfg(not(target_family = "wasm"))]
786        let (path, db_file) = tmp_db_file();
787        #[cfg(target_family = "wasm")]
788        let (path, _) = tmp_db_file();
789
790        let key = DatabaseKey::generate();
791        let keystore = core_crypto_keystore::Database::open(ConnectionType::Persistent(&path), &key)
792            .await
793            .unwrap();
794        keystore.new_transaction().await.unwrap();
795        let alice = ProteusCentral::try_new(&keystore).await.unwrap();
796
797        for i in ID_TEST_RANGE {
798            let (pk_id, pkb) = alice.new_prekey_auto(&keystore).await.unwrap();
799            assert_eq!(i, pk_id);
800            let prekey = proteus_wasm::keys::PreKeyBundle::deserialise(&pkb).unwrap();
801            assert_eq!(prekey.prekey_id.value(), pk_id);
802        }
803
804        use rand::Rng as _;
805        let mut rng = rand::thread_rng();
806        let mut gap_ids: Vec<u16> = (0..GAP_AMOUNT).map(|_| rng.gen_range(ID_TEST_RANGE)).collect();
807        gap_ids.sort();
808        gap_ids.dedup();
809        while gap_ids.len() < GAP_AMOUNT as usize {
810            gap_ids.push(rng.gen_range(ID_TEST_RANGE));
811            gap_ids.sort();
812            gap_ids.dedup();
813        }
814        for gap_id in gap_ids.iter() {
815            keystore.remove::<ProteusPrekey, _>(gap_id.to_le_bytes()).await.unwrap();
816        }
817
818        gap_ids.sort();
819
820        for gap_id in gap_ids.iter() {
821            let (pk_id, pkb) = alice.new_prekey_auto(&keystore).await.unwrap();
822            assert_eq!(pk_id, *gap_id);
823            let prekey = proteus_wasm::keys::PreKeyBundle::deserialise(&pkb).unwrap();
824            assert_eq!(prekey.prekey_id.value(), *gap_id);
825        }
826
827        let mut gap_ids: Vec<u16> = (0..GAP_AMOUNT).map(|_| rng.gen_range(ID_TEST_RANGE)).collect();
828        gap_ids.sort();
829        gap_ids.dedup();
830        while gap_ids.len() < GAP_AMOUNT as usize {
831            gap_ids.push(rng.gen_range(ID_TEST_RANGE));
832            gap_ids.sort();
833            gap_ids.dedup();
834        }
835        for gap_id in gap_ids.iter() {
836            keystore.remove::<ProteusPrekey, _>(gap_id.to_le_bytes()).await.unwrap();
837        }
838
839        let potential_range = *ID_TEST_RANGE.end()..=(*ID_TEST_RANGE.end() * 2);
840        let potential_range_check = potential_range.clone();
841        for _ in potential_range {
842            let (pk_id, pkb) = alice.new_prekey_auto(&keystore).await.unwrap();
843            assert!(gap_ids.contains(&pk_id) || potential_range_check.contains(&pk_id));
844            let prekey = proteus_wasm::keys::PreKeyBundle::deserialise(&pkb).unwrap();
845            assert_eq!(prekey.prekey_id.value(), pk_id);
846        }
847        keystore.commit_transaction().await.unwrap();
848        keystore.wipe().await.unwrap();
849        #[cfg(not(target_family = "wasm"))]
850        drop(db_file);
851    }
852}