core_crypto/
proteus.rs

1use std::{collections::HashMap, sync::Arc};
2
3use core_crypto_keystore::{
4    Database as CryptoKeystore,
5    connection::FetchFromDatabase,
6    entities::{ProteusIdentity, ProteusSession},
7};
8use proteus_wasm::{
9    keys::{IdentityKeyPair, PreKeyBundle},
10    message::Envelope,
11    session::Session,
12};
13
14use crate::{
15    CoreCrypto, Error, KeystoreError, LeafError, ProteusError, Result,
16    group_store::{GroupStore, GroupStoreEntity, GroupStoreValue},
17};
18
/// Identifier of a Proteus session.
///
/// A plain owned [`String`]: its UTF-8 bytes are the key under which sessions are cached in memory, and the
/// string itself is the `id` under which a session is persisted in the keystore.
pub type SessionIdentifier = String;
21
22/// Proteus Session wrapper, that contains the identifier and the associated proteus Session
23#[derive(Debug)]
24pub struct ProteusConversationSession {
25    pub(crate) identifier: SessionIdentifier,
26    pub(crate) session: Session<Arc<IdentityKeyPair>>,
27}
28
29impl ProteusConversationSession {
30    /// Encrypts a message for this Proteus session
31    pub fn encrypt(&mut self, plaintext: &[u8]) -> Result<Vec<u8>> {
32        self.session
33            .encrypt(plaintext)
34            .and_then(|e| e.serialise())
35            .map_err(ProteusError::wrap("encrypting message for proteus session"))
36            .map_err(Into::into)
37    }
38
39    /// Decrypts a message for this Proteus session
40    pub async fn decrypt(&mut self, store: &mut core_crypto_keystore::Database, ciphertext: &[u8]) -> Result<Vec<u8>> {
41        let envelope = Envelope::deserialise(ciphertext).map_err(ProteusError::wrap("deserializing envelope"))?;
42        self.session
43            .decrypt(store, &envelope)
44            .await
45            .map_err(ProteusError::wrap("decrypting message for proteus session"))
46            .map_err(Into::into)
47    }
48
49    /// Returns the session identifier
50    pub fn identifier(&self) -> &str {
51        &self.identifier
52    }
53
54    /// Returns the public key fingerprint of the local identity (= self identity)
55    pub fn fingerprint_local(&self) -> String {
56        self.session.local_identity().fingerprint()
57    }
58
59    /// Returns the public key fingerprint of the remote identity (= client you're communicating with)
60    pub fn fingerprint_remote(&self) -> String {
61        self.session.remote_identity().fingerprint()
62    }
63}
64
65#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
66#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
67impl GroupStoreEntity for ProteusConversationSession {
68    type RawStoreValue = core_crypto_keystore::entities::ProteusSession;
69    type IdentityType = Arc<proteus_wasm::keys::IdentityKeyPair>;
70
71    async fn fetch_from_id(
72        id: impl AsRef<[u8]> + Send,
73        identity: Option<Self::IdentityType>,
74        keystore: &impl FetchFromDatabase,
75    ) -> crate::Result<Option<Self>> {
76        let result = keystore
77            .find::<Self::RawStoreValue>(id)
78            .await
79            .map_err(KeystoreError::wrap("finding raw group store entity by id"))?;
80        let Some(store_value) = result else {
81            return Ok(None);
82        };
83
84        let Some(identity) = identity else {
85            return Err(crate::Error::ProteusNotInitialized);
86        };
87
88        let session = proteus_wasm::session::Session::deserialise(identity, &store_value.session)
89            .map_err(ProteusError::wrap("deserializing session"))?;
90
91        Ok(Some(Self {
92            identifier: store_value.id.clone(),
93            session,
94        }))
95    }
96}
97
98impl CoreCrypto {
99    /// Proteus session accessor
100    ///
101    /// Warning: The Proteus client **MUST** be initialized with
102    /// [crate::transaction_context::TransactionContext::proteus_init] first or an error will be
103    /// returned
104    pub async fn proteus_session(
105        &self,
106        session_id: &str,
107    ) -> Result<Option<GroupStoreValue<ProteusConversationSession>>> {
108        let mut mutex = self.proteus.lock().await;
109        let proteus = mutex.as_mut().ok_or(Error::ProteusNotInitialized)?;
110        let keystore = self.mls.crypto_provider.keystore();
111        proteus.session(session_id, &keystore).await
112    }
113
114    /// Proteus session exists
115    ///
116    /// Warning: The Proteus client **MUST** be initialized with
117    /// [crate::transaction_context::TransactionContext::proteus_init] first or an error will be
118    /// returned
119    pub async fn proteus_session_exists(&self, session_id: &str) -> Result<bool> {
120        let mut mutex = self.proteus.lock().await;
121        let proteus = mutex.as_mut().ok_or(Error::ProteusNotInitialized)?;
122        let keystore = self.mls.crypto_provider.keystore();
123        Ok(proteus.session_exists(session_id, &keystore).await)
124    }
125
126    /// Returns the proteus last resort prekey id (u16::MAX = 65535)
127    pub fn proteus_last_resort_prekey_id() -> u16 {
128        ProteusCentral::last_resort_prekey_id()
129    }
130
131    /// Returns the proteus identity's public key fingerprint
132    ///
133    /// Warning: The Proteus client **MUST** be initialized with
134    /// [crate::transaction_context::TransactionContext::proteus_init] first or an error will be
135    /// returned
136    pub async fn proteus_fingerprint(&self) -> Result<String> {
137        let mutex = self.proteus.lock().await;
138        let proteus = mutex.as_ref().ok_or(Error::ProteusNotInitialized)?;
139        Ok(proteus.fingerprint())
140    }
141
142    /// Returns the proteus identity's public key fingerprint
143    ///
144    /// Warning: The Proteus client **MUST** be initialized with
145    /// [crate::transaction_context::TransactionContext::proteus_init] first or an error will be
146    /// returned
147    pub async fn proteus_fingerprint_local(&self, session_id: &str) -> Result<String> {
148        let mut mutex = self.proteus.lock().await;
149        let proteus = mutex.as_mut().ok_or(Error::ProteusNotInitialized)?;
150        let keystore = self.mls.crypto_provider.keystore();
151        proteus.fingerprint_local(session_id, &keystore).await
152    }
153
154    /// Returns the proteus identity's public key fingerprint
155    ///
156    /// Warning: The Proteus client **MUST** be initialized with
157    /// [crate::transaction_context::TransactionContext::proteus_init] first or an error will be
158    /// returned
159    pub async fn proteus_fingerprint_remote(&self, session_id: &str) -> Result<String> {
160        let mut mutex = self.proteus.lock().await;
161        let proteus = mutex.as_mut().ok_or(Error::ProteusNotInitialized)?;
162        let keystore = self.mls.crypto_provider.keystore();
163        proteus.fingerprint_remote(session_id, &keystore).await
164    }
165}
166
167/// Proteus counterpart of [crate::mls::session::Session]
168///
169/// The big difference is that [ProteusCentral] doesn't *own* its own keystore but must borrow it from the outside.
170/// Whether it's exclusively for this struct's purposes or it's shared with our main struct,
171/// [crate::mls::session::Session]
172#[derive(Debug)]
173pub struct ProteusCentral {
174    proteus_identity: Arc<IdentityKeyPair>,
175    proteus_sessions: GroupStore<ProteusConversationSession>,
176}
177
178impl ProteusCentral {
179    /// Initializes the [ProteusCentral]
180    pub async fn try_new(keystore: &CryptoKeystore) -> Result<Self> {
181        let proteus_identity: Arc<IdentityKeyPair> = Arc::new(Self::load_or_create_identity(keystore).await?);
182        let proteus_sessions = Self::restore_sessions(keystore, &proteus_identity).await?;
183
184        Ok(Self {
185            proteus_identity,
186            proteus_sessions,
187        })
188    }
189
190    /// Restore proteus sessions from disk
191    pub(crate) async fn reload_sessions(&mut self, keystore: &CryptoKeystore) -> Result<()> {
192        self.proteus_sessions = Self::restore_sessions(keystore, &self.proteus_identity).await?;
193        Ok(())
194    }
195
196    /// This function will try to load a proteus Identity from our keystore; If it cannot, it will create a new one
197    /// This means this function doesn't fail except in cases of deeper errors (such as in the Keystore and other crypto
198    /// errors)
199    async fn load_or_create_identity(keystore: &CryptoKeystore) -> Result<IdentityKeyPair> {
200        let Some(identity) = keystore
201            .find::<ProteusIdentity>(ProteusIdentity::ID)
202            .await
203            .map_err(KeystoreError::wrap("finding proteus identity"))?
204        else {
205            return Self::create_identity(keystore).await;
206        };
207
208        let sk = identity.sk_raw();
209        let pk = identity.pk_raw();
210
211        // SAFETY: Byte lengths are ensured at the keystore level so this function is safe to call, despite being cursed
212        IdentityKeyPair::from_raw_key_pair(*sk, *pk)
213            .map_err(ProteusError::wrap("constructing identity keypair"))
214            .map_err(Into::into)
215    }
216
217    /// Internal function to create and save a new Proteus Identity
218    async fn create_identity(keystore: &CryptoKeystore) -> Result<IdentityKeyPair> {
219        let kp = IdentityKeyPair::new();
220        let pk = kp.public_key.public_key.as_slice().to_vec();
221
222        let ks_identity = ProteusIdentity {
223            sk: kp.secret_key.to_keypair_bytes().into(),
224            pk,
225        };
226        keystore
227            .save(ks_identity)
228            .await
229            .map_err(KeystoreError::wrap("saving new proteus identity"))?;
230
231        Ok(kp)
232    }
233
234    /// Restores the saved sessions in memory. This is performed automatically on init
235    async fn restore_sessions(
236        keystore: &core_crypto_keystore::Database,
237        identity: &Arc<IdentityKeyPair>,
238    ) -> Result<GroupStore<ProteusConversationSession>> {
239        let mut proteus_sessions = GroupStore::new_with_limit(crate::group_store::ITEM_LIMIT * 2);
240        for session in keystore
241            .find_all::<ProteusSession>(Default::default())
242            .await
243            .map_err(KeystoreError::wrap("finding all proteus sessions"))?
244            .into_iter()
245        {
246            let proteus_session = Session::deserialise(identity.clone(), &session.session)
247                .map_err(ProteusError::wrap("deserializing session"))?;
248
249            let identifier = session.id.clone();
250
251            let proteus_conversation = ProteusConversationSession {
252                identifier: identifier.clone(),
253                session: proteus_session,
254            };
255
256            if proteus_sessions
257                .try_insert(identifier.into_bytes(), proteus_conversation)
258                .is_err()
259            {
260                break;
261            }
262        }
263
264        Ok(proteus_sessions)
265    }
266
267    /// Creates a new session from a prekey
268    pub async fn session_from_prekey(
269        &mut self,
270        session_id: &str,
271        key: &[u8],
272    ) -> Result<GroupStoreValue<ProteusConversationSession>> {
273        let prekey = PreKeyBundle::deserialise(key).map_err(ProteusError::wrap("deserializing prekey bundle"))?;
274        // Note on the `::<>` turbofish below:
275        //
276        // `init_from_prekey` returns an error type which is parametric over some wrapped `E`,
277        // because one variant (not relevant to this particular operation) wraps an error type based
278        // on a parameter of a different function entirely.
279        //
280        // Rust complains here, because it can't figure out what type that `E` should be. After all, it's
281        // not inferrable from this function call! It is also entirely irrelevant in this case.
282        //
283        // We can derive two general rules about error-handling in Rust from this example:
284        //
285        // 1. It's better to make smaller error types where possible, encapsulating fallible operations with their own
286        //    error variants, and then wrapping those errors where required, as opposed to creating giant catch-all
287        //    errors. Doing so also has knock-on benefits with regard to tracing the precise origin of the error.
288        // 2. One should never make an error wrapper parametric. If you need to wrap an unknown error, it's always
289        //    better to wrap a `Box<dyn std::error::Error>` than to make your error type parametric. The allocation cost
290        //    of creating the `Box` is utterly trivial in an error-handling path, and it avoids parametric virality.
291        //    (`init_from_prekey` is itself only generic because it returns this error type with a type-parametric
292        //    variant, which the function never returns.)
293        //
294        // In this case, we have the out of band knowledge that `ProteusErrorKind` has a `#[from]` implementation
295        // for `proteus_wasm::session::Error<core_crypto_keystore::CryptoKeystoreError>` and for no other kinds
296        // of session error. So we can safely say that the type of error we are meant to catch here, and
297        // therefore pass in that otherwise-irrelevant type, to ensure that error handling works properly.
298        //
299        // Some people say that if it's stupid but it works, it's not stupid. I disagree. If it's stupid but
300        // it works, that's our cue to seek out even better, non-stupid ways to get things done. I reiterate:
301        // the actual type referred to in this turbofish is nothing but a magic incantation to make error
302        // handling work; it has no bearing on the error retured from this function. How much better would it
303        // have been if `session::Error` were not parametric and we could have avoided the turbofish entirely?
304        let proteus_session = Session::init_from_prekey::<core_crypto_keystore::CryptoKeystoreError>(
305            self.proteus_identity.clone(),
306            prekey,
307        )
308        .map_err(ProteusError::wrap("initializing session from prekey"))?;
309
310        let proteus_conversation = ProteusConversationSession {
311            identifier: session_id.into(),
312            session: proteus_session,
313        };
314
315        self.proteus_sessions
316            .insert(session_id.as_bytes(), proteus_conversation);
317
318        Ok(self.proteus_sessions.get(session_id.as_bytes()).unwrap().clone())
319    }
320
321    /// Creates a new proteus Session from a received message
322    pub(crate) async fn session_from_message(
323        &mut self,
324        keystore: &mut CryptoKeystore,
325        session_id: &str,
326        envelope: &[u8],
327    ) -> Result<(GroupStoreValue<ProteusConversationSession>, Vec<u8>)> {
328        let message = Envelope::deserialise(envelope).map_err(ProteusError::wrap("deserialising envelope"))?;
329        let (session, payload) = Session::init_from_message(self.proteus_identity.clone(), keystore, &message)
330            .await
331            .map_err(ProteusError::wrap("initializing session from message"))?;
332
333        let proteus_conversation = ProteusConversationSession {
334            identifier: session_id.into(),
335            session,
336        };
337
338        self.proteus_sessions
339            .insert(session_id.as_bytes(), proteus_conversation);
340
341        Ok((
342            self.proteus_sessions.get(session_id.as_bytes()).unwrap().clone(),
343            payload,
344        ))
345    }
346
347    /// Persists a session in store
348    ///
349    /// **Note**: This isn't usually needed as persisting sessions happens automatically when decrypting/encrypting
350    /// messages and initializing Sessions
351    pub(crate) async fn session_save(&mut self, keystore: &CryptoKeystore, session_id: &str) -> Result<()> {
352        if let Some(session) = self
353            .proteus_sessions
354            .get_fetch(session_id.as_bytes(), keystore, Some(self.proteus_identity.clone()))
355            .await?
356        {
357            Self::session_save_by_ref(keystore, session).await?;
358        }
359
360        Ok(())
361    }
362
363    pub(crate) async fn session_save_by_ref(
364        keystore: &CryptoKeystore,
365        session: GroupStoreValue<ProteusConversationSession>,
366    ) -> Result<()> {
367        let session = session.read().await;
368        let db_session = ProteusSession {
369            id: session.identifier().to_string(),
370            session: session
371                .session
372                .serialise()
373                .map_err(ProteusError::wrap("serializing session"))?,
374        };
375        keystore
376            .save(db_session)
377            .await
378            .map_err(KeystoreError::wrap("saving proteus session"))?;
379        Ok(())
380    }
381
382    /// Deletes a session in the store
383    pub(crate) async fn session_delete(&mut self, keystore: &CryptoKeystore, session_id: &str) -> Result<()> {
384        if keystore.remove::<ProteusSession, _>(session_id).await.is_ok() {
385            let _ = self.proteus_sessions.remove(session_id.as_bytes());
386        }
387        Ok(())
388    }
389
390    /// Session accessor
391    pub(crate) async fn session(
392        &mut self,
393        session_id: &str,
394        keystore: &CryptoKeystore,
395    ) -> Result<Option<GroupStoreValue<ProteusConversationSession>>> {
396        self.proteus_sessions
397            .get_fetch(session_id.as_bytes(), keystore, Some(self.proteus_identity.clone()))
398            .await
399    }
400
401    /// Session exists
402    pub(crate) async fn session_exists(&mut self, session_id: &str, keystore: &CryptoKeystore) -> bool {
403        self.session(session_id, keystore).await.ok().flatten().is_some()
404    }
405
406    /// Decrypt a proteus message for an already existing session
407    /// Note: This cannot be used for handshake messages, see [ProteusCentral::session_from_message]
408    pub(crate) async fn decrypt(
409        &mut self,
410        keystore: &mut CryptoKeystore,
411        session_id: &str,
412        ciphertext: &[u8],
413    ) -> Result<Vec<u8>> {
414        let session = self
415            .proteus_sessions
416            .get_fetch(session_id.as_bytes(), keystore, Some(self.proteus_identity.clone()))
417            .await?
418            .ok_or(LeafError::ConversationNotFound(session_id.as_bytes().into()))
419            .map_err(ProteusError::wrap("getting session"))?;
420
421        let plaintext = session.write().await.decrypt(keystore, ciphertext).await?;
422        ProteusCentral::session_save_by_ref(keystore, session).await?;
423
424        Ok(plaintext)
425    }
426
427    /// Encrypt a message for a session
428    pub(crate) async fn encrypt(
429        &mut self,
430        keystore: &mut CryptoKeystore,
431        session_id: &str,
432        plaintext: &[u8],
433    ) -> Result<Vec<u8>> {
434        let session = self
435            .session(session_id, keystore)
436            .await?
437            .ok_or(LeafError::ConversationNotFound(session_id.as_bytes().into()))
438            .map_err(ProteusError::wrap("getting session"))?;
439
440        let ciphertext = session.write().await.encrypt(plaintext)?;
441        ProteusCentral::session_save_by_ref(keystore, session).await?;
442
443        Ok(ciphertext)
444    }
445
446    /// Encrypts a message for a list of sessions
447    /// This is mainly used for conversations with multiple clients, this allows to minimize FFI roundtrips
448    pub(crate) async fn encrypt_batched(
449        &mut self,
450        keystore: &mut CryptoKeystore,
451        sessions: &[impl AsRef<str>],
452        plaintext: &[u8],
453    ) -> Result<HashMap<String, Vec<u8>>> {
454        let mut acc = HashMap::new();
455        for session_id in sessions {
456            if let Some(session) = self.session(session_id.as_ref(), keystore).await? {
457                let mut session_w = session.write().await;
458                acc.insert(session_w.identifier.clone(), session_w.encrypt(plaintext)?);
459                drop(session_w);
460
461                ProteusCentral::session_save_by_ref(keystore, session).await?;
462            }
463        }
464        Ok(acc)
465    }
466
467    /// Generates a new Proteus PreKey, stores it in the keystore and returns a serialized PreKeyBundle to be consumed
468    /// externally
469    pub(crate) async fn new_prekey(&self, id: u16, keystore: &CryptoKeystore) -> Result<Vec<u8>> {
470        use proteus_wasm::keys::{PreKey, PreKeyId};
471
472        let prekey_id = PreKeyId::new(id);
473        let prekey = PreKey::new(prekey_id);
474        let keystore_prekey = core_crypto_keystore::entities::ProteusPrekey::from_raw(
475            id,
476            prekey.serialise().map_err(ProteusError::wrap("serialising prekey"))?,
477        );
478        let bundle = PreKeyBundle::new(self.proteus_identity.as_ref().public_key.clone(), &prekey);
479        let bundle = bundle
480            .serialise()
481            .map_err(ProteusError::wrap("serialising prekey bundle"))?;
482        keystore
483            .save(keystore_prekey)
484            .await
485            .map_err(KeystoreError::wrap("saving keystore prekey"))?;
486        Ok(bundle)
487    }
488
489    /// Generates a new Proteus Prekey, with an automatically auto-incremented ID.
490    ///
491    /// See [ProteusCentral::new_prekey]
492    pub(crate) async fn new_prekey_auto(&self, keystore: &CryptoKeystore) -> Result<(u16, Vec<u8>)> {
493        let id = core_crypto_keystore::entities::ProteusPrekey::get_free_id(keystore)
494            .await
495            .map_err(KeystoreError::wrap("getting proteus prekey by id"))?;
496        Ok((id, self.new_prekey(id, keystore).await?))
497    }
498
499    /// Returns the Proteus last resort prekey ID (u16::MAX = 65535 = 0xFFFF)
500    pub fn last_resort_prekey_id() -> u16 {
501        proteus_wasm::keys::MAX_PREKEY_ID.value()
502    }
503
504    /// Returns the Proteus last resort prekey
505    /// If it cannot be found, one will be created.
506    pub(crate) async fn last_resort_prekey(&self, keystore: &CryptoKeystore) -> Result<Vec<u8>> {
507        let last_resort = if let Some(last_resort) = keystore
508            .find::<core_crypto_keystore::entities::ProteusPrekey>(
509                Self::last_resort_prekey_id().to_le_bytes().as_slice(),
510            )
511            .await
512            .map_err(KeystoreError::wrap("finding proteus prekey"))?
513        {
514            proteus_wasm::keys::PreKey::deserialise(&last_resort.prekey)
515                .map_err(ProteusError::wrap("deserialising proteus prekey"))?
516        } else {
517            let last_resort = proteus_wasm::keys::PreKey::last_resort();
518
519            use core_crypto_keystore::CryptoKeystoreProteus as _;
520            keystore
521                .proteus_store_prekey(
522                    Self::last_resort_prekey_id(),
523                    &last_resort
524                        .serialise()
525                        .map_err(ProteusError::wrap("serialising last resort prekey"))?,
526                )
527                .await
528                .map_err(KeystoreError::wrap("storing proteus prekey"))?;
529
530            last_resort
531        };
532
533        let bundle = PreKeyBundle::new(self.proteus_identity.as_ref().public_key.clone(), &last_resort);
534        let bundle = bundle
535            .serialise()
536            .map_err(ProteusError::wrap("serialising prekey bundle"))?;
537
538        Ok(bundle)
539    }
540
541    /// Proteus identity keypair
542    pub fn identity(&self) -> &IdentityKeyPair {
543        self.proteus_identity.as_ref()
544    }
545
546    /// Proteus Public key hex-encoded fingerprint
547    pub fn fingerprint(&self) -> String {
548        self.proteus_identity.as_ref().public_key.fingerprint()
549    }
550
551    /// Proteus Session local hex-encoded fingerprint
552    ///
553    /// # Errors
554    /// When the session is not found
555    pub(crate) async fn fingerprint_local(&mut self, session_id: &str, keystore: &CryptoKeystore) -> Result<String> {
556        let session = self
557            .session(session_id, keystore)
558            .await?
559            .ok_or(LeafError::ConversationNotFound(session_id.as_bytes().into()))
560            .map_err(ProteusError::wrap("getting session"))?;
561        let fingerprint = session.read().await.fingerprint_local();
562        Ok(fingerprint)
563    }
564
565    /// Proteus Session remote hex-encoded fingerprint
566    ///
567    /// # Errors
568    /// When the session is not found
569    pub(crate) async fn fingerprint_remote(&mut self, session_id: &str, keystore: &CryptoKeystore) -> Result<String> {
570        let session = self
571            .session(session_id, keystore)
572            .await?
573            .ok_or(LeafError::ConversationNotFound(session_id.as_bytes().into()))
574            .map_err(ProteusError::wrap("getting session"))?;
575        let fingerprint = session.read().await.fingerprint_remote();
576        Ok(fingerprint)
577    }
578
579    /// Hex-encoded fingerprint of the given prekey
580    ///
581    /// # Errors
582    /// If the prekey cannot be deserialized
583    pub fn fingerprint_prekeybundle(prekey: &[u8]) -> Result<String> {
584        let prekey = PreKeyBundle::deserialise(prekey).map_err(ProteusError::wrap("deserialising prekey bundle"))?;
585        Ok(prekey.identity_key.fingerprint())
586    }
587}
588
589#[cfg(test)]
590mod tests {
591    use core_crypto_keystore::{ConnectionType, Database, DatabaseKey};
592
593    use super::*;
594    use crate::{
595        CertificateBundle, ClientId, ClientIdentifier, CredentialType, Session,
596        test_utils::{proteus_utils::*, x509::X509TestChain, *},
597    };
598
599    #[apply(all_cred_cipher)]
600    async fn cc_can_init(case: TestContext) {
601        #[cfg(not(target_family = "wasm"))]
602        let (path, db_file) = tmp_db_file();
603        #[cfg(target_family = "wasm")]
604        let (path, _) = tmp_db_file();
605        let client_id = ClientId::from("alice").into();
606        let db = Database::open(ConnectionType::Persistent(&path), &DatabaseKey::generate())
607            .await
608            .unwrap();
609
610        let cc: CoreCrypto = Session::try_new(&db).await.unwrap().into();
611        cc.init(client_id, &[case.ciphersuite().signature_algorithm()])
612            .await
613            .unwrap();
614        let context = cc.new_transaction().await.unwrap();
615        assert!(context.proteus_init().await.is_ok());
616        assert!(context.proteus_new_prekey(1).await.is_ok());
617        context.finish().await.unwrap();
618        #[cfg(not(target_family = "wasm"))]
619        drop(db_file);
620    }
621
622    #[apply(all_cred_cipher)]
623    async fn cc_can_2_phase_init(case: TestContext) {
624        use crate::{ClientId, Credential};
625
626        #[cfg(not(target_family = "wasm"))]
627        let (path, db_file) = tmp_db_file();
628        #[cfg(target_family = "wasm")]
629        let (path, _) = tmp_db_file();
630        let db = Database::open(ConnectionType::Persistent(&path), &DatabaseKey::generate())
631            .await
632            .unwrap();
633
634        let cc: CoreCrypto = Session::try_new(&db).await.unwrap().into();
635        let transaction = cc.new_transaction().await.unwrap();
636        let x509_test_chain = X509TestChain::init_empty(case.signature_scheme());
637        x509_test_chain.register_with_central(&transaction).await;
638        assert!(transaction.proteus_init().await.is_ok());
639        // proteus is initialized, prekeys can be generated
640        assert!(transaction.proteus_new_prekey(1).await.is_ok());
641        // 👇 and so a unique 'client_id' can be fetched from wire-server
642        let client_id = ClientId::from("alice");
643        let identifier = match case.credential_type {
644            CredentialType::Basic => ClientIdentifier::Basic(client_id),
645            CredentialType::X509 => {
646                CertificateBundle::rand_identifier(&client_id, &[x509_test_chain.find_local_intermediate_ca()])
647            }
648        };
649        transaction
650            .mls_init(identifier.clone(), &[case.ciphersuite()])
651            .await
652            .unwrap();
653
654        let credential =
655            Credential::from_identifier(&identifier, case.signature_scheme(), &cc.mls.crypto_provider).unwrap();
656        cc.add_credential(credential).await.unwrap();
657
658        // expect MLS to work
659        assert_eq!(
660            transaction
661                .get_or_create_client_keypackages(case.ciphersuite(), case.credential_type, 2)
662                .await
663                .unwrap()
664                .len(),
665            2
666        );
667        #[cfg(not(target_family = "wasm"))]
668        drop(db_file);
669    }
670
671    #[macro_rules_attribute::apply(smol_macros::test)]
672    async fn can_init() {
673        #[cfg(not(target_family = "wasm"))]
674        let (path, db_file) = tmp_db_file();
675        #[cfg(target_family = "wasm")]
676        let (path, _) = tmp_db_file();
677        let key = DatabaseKey::generate();
678        let keystore = core_crypto_keystore::Database::open(ConnectionType::Persistent(&path), &key)
679            .await
680            .unwrap();
681        keystore.new_transaction().await.unwrap();
682        let central = ProteusCentral::try_new(&keystore).await.unwrap();
683        let identity = (*central.proteus_identity).clone();
684        keystore.commit_transaction().await.unwrap();
685
686        let keystore = core_crypto_keystore::Database::open(ConnectionType::Persistent(&path), &key)
687            .await
688            .unwrap();
689        keystore.new_transaction().await.unwrap();
690        let central = ProteusCentral::try_new(&keystore).await.unwrap();
691        keystore.commit_transaction().await.unwrap();
692        assert_eq!(identity, *central.proteus_identity);
693
694        keystore.wipe().await.unwrap();
695        #[cfg(not(target_family = "wasm"))]
696        drop(db_file);
697    }
698
699    #[macro_rules_attribute::apply(smol_macros::test)]
700    async fn can_talk_with_proteus() {
701        #[cfg(not(target_family = "wasm"))]
702        let (path, db_file) = tmp_db_file();
703        #[cfg(target_family = "wasm")]
704        let (path, _) = tmp_db_file();
705
706        let session_id = uuid::Uuid::new_v4().hyphenated().to_string();
707
708        let key = DatabaseKey::generate();
709        let mut keystore = core_crypto_keystore::Database::open(ConnectionType::Persistent(&path), &key)
710            .await
711            .unwrap();
712        keystore.new_transaction().await.unwrap();
713
714        let mut alice = ProteusCentral::try_new(&keystore).await.unwrap();
715
716        let mut bob = CryptoboxLike::init();
717        let bob_pk_bundle = bob.new_prekey();
718
719        alice
720            .session_from_prekey(&session_id, &bob_pk_bundle.serialise().unwrap())
721            .await
722            .unwrap();
723
724        let message = b"Hello world";
725
726        let encrypted = alice.encrypt(&mut keystore, &session_id, message).await.unwrap();
727        let decrypted = bob.decrypt(&session_id, &encrypted).await;
728        assert_eq!(decrypted, message);
729
730        let encrypted = bob.encrypt(&session_id, message);
731        let decrypted = alice.decrypt(&mut keystore, &session_id, &encrypted).await.unwrap();
732        assert_eq!(decrypted, message);
733
734        keystore.commit_transaction().await.unwrap();
735        keystore.wipe().await.unwrap();
736        #[cfg(not(target_family = "wasm"))]
737        drop(db_file);
738    }
739
740    #[macro_rules_attribute::apply(smol_macros::test)]
741    async fn can_produce_proteus_consumed_prekeys() {
742        #[cfg(not(target_family = "wasm"))]
743        let (path, db_file) = tmp_db_file();
744        #[cfg(target_family = "wasm")]
745        let (path, _) = tmp_db_file();
746
747        let session_id = uuid::Uuid::new_v4().hyphenated().to_string();
748
749        let key = DatabaseKey::generate();
750        let mut keystore = core_crypto_keystore::Database::open(ConnectionType::Persistent(&path), &key)
751            .await
752            .unwrap();
753        keystore.new_transaction().await.unwrap();
754        let mut alice = ProteusCentral::try_new(&keystore).await.unwrap();
755
756        let mut bob = CryptoboxLike::init();
757
758        let alice_prekey_bundle_ser = alice.new_prekey(1, &keystore).await.unwrap();
759
760        bob.init_session_from_prekey_bundle(&session_id, &alice_prekey_bundle_ser);
761        let message = b"Hello world!";
762        let encrypted = bob.encrypt(&session_id, message);
763
764        let (_, decrypted) = alice
765            .session_from_message(&mut keystore, &session_id, &encrypted)
766            .await
767            .unwrap();
768
769        assert_eq!(message, decrypted.as_slice());
770
771        let encrypted = alice.encrypt(&mut keystore, &session_id, message).await.unwrap();
772        let decrypted = bob.decrypt(&session_id, &encrypted).await;
773
774        assert_eq!(message, decrypted.as_slice());
775        keystore.commit_transaction().await.unwrap();
776        keystore.wipe().await.unwrap();
777        #[cfg(not(target_family = "wasm"))]
778        drop(db_file);
779    }
780
781    #[macro_rules_attribute::apply(smol_macros::test)]
782    async fn auto_prekeys_are_sequential() {
783        use core_crypto_keystore::entities::ProteusPrekey;
784        const GAP_AMOUNT: u16 = 5;
785        const ID_TEST_RANGE: std::ops::RangeInclusive<u16> = 1..=30;
786
787        #[cfg(not(target_family = "wasm"))]
788        let (path, db_file) = tmp_db_file();
789        #[cfg(target_family = "wasm")]
790        let (path, _) = tmp_db_file();
791
792        let key = DatabaseKey::generate();
793        let keystore = core_crypto_keystore::Database::open(ConnectionType::Persistent(&path), &key)
794            .await
795            .unwrap();
796        keystore.new_transaction().await.unwrap();
797        let alice = ProteusCentral::try_new(&keystore).await.unwrap();
798
799        for i in ID_TEST_RANGE {
800            let (pk_id, pkb) = alice.new_prekey_auto(&keystore).await.unwrap();
801            assert_eq!(i, pk_id);
802            let prekey = proteus_wasm::keys::PreKeyBundle::deserialise(&pkb).unwrap();
803            assert_eq!(prekey.prekey_id.value(), pk_id);
804        }
805
806        use rand::Rng as _;
807        let mut rng = rand::thread_rng();
808        let mut gap_ids: Vec<u16> = (0..GAP_AMOUNT).map(|_| rng.gen_range(ID_TEST_RANGE)).collect();
809        gap_ids.sort();
810        gap_ids.dedup();
811        while gap_ids.len() < GAP_AMOUNT as usize {
812            gap_ids.push(rng.gen_range(ID_TEST_RANGE));
813            gap_ids.sort();
814            gap_ids.dedup();
815        }
816        for gap_id in gap_ids.iter() {
817            keystore.remove::<ProteusPrekey, _>(gap_id.to_le_bytes()).await.unwrap();
818        }
819
820        gap_ids.sort();
821
822        for gap_id in gap_ids.iter() {
823            let (pk_id, pkb) = alice.new_prekey_auto(&keystore).await.unwrap();
824            assert_eq!(pk_id, *gap_id);
825            let prekey = proteus_wasm::keys::PreKeyBundle::deserialise(&pkb).unwrap();
826            assert_eq!(prekey.prekey_id.value(), *gap_id);
827        }
828
829        let mut gap_ids: Vec<u16> = (0..GAP_AMOUNT).map(|_| rng.gen_range(ID_TEST_RANGE)).collect();
830        gap_ids.sort();
831        gap_ids.dedup();
832        while gap_ids.len() < GAP_AMOUNT as usize {
833            gap_ids.push(rng.gen_range(ID_TEST_RANGE));
834            gap_ids.sort();
835            gap_ids.dedup();
836        }
837        for gap_id in gap_ids.iter() {
838            keystore.remove::<ProteusPrekey, _>(gap_id.to_le_bytes()).await.unwrap();
839        }
840
841        let potential_range = *ID_TEST_RANGE.end()..=(*ID_TEST_RANGE.end() * 2);
842        let potential_range_check = potential_range.clone();
843        for _ in potential_range {
844            let (pk_id, pkb) = alice.new_prekey_auto(&keystore).await.unwrap();
845            assert!(gap_ids.contains(&pk_id) || potential_range_check.contains(&pk_id));
846            let prekey = proteus_wasm::keys::PreKeyBundle::deserialise(&pkb).unwrap();
847            assert_eq!(prekey.prekey_id.value(), pk_id);
848        }
849        keystore.commit_transaction().await.unwrap();
850        keystore.wipe().await.unwrap();
851        #[cfg(not(target_family = "wasm"))]
852        drop(db_file);
853    }
854}