core_crypto/e2e_identity/stash.rs

1use openmls_traits::{OpenMlsCryptoProvider, random::OpenMlsRand};
2
3use super::Result;
4use crate::context::CentralContext;
5use crate::prelude::E2eiEnrollment;
6use crate::{KeystoreError, MlsError, RecursiveError};
7use core_crypto_keystore::CryptoKeystoreMls;
8use mls_crypto_provider::MlsCryptoProvider;
9
/// Opaque identifier handed out when an enrollment is stashed in the keystore.
///
/// Consumers keep it and pass it back later to retrieve (and remove) the stashed enrollment,
/// so the enrollment process can be resumed where it left off.
pub(crate) type EnrollmentHandle = Vec<u8>;
13
14impl E2eiEnrollment {
15    pub(crate) async fn stash(self, backend: &MlsCryptoProvider) -> Result<EnrollmentHandle> {
16        // should be enough to prevent collisions
17        const HANDLE_SIZE: usize = 32;
18
19        let content = serde_json::to_vec(&self)?;
20        let handle = backend
21            .crypto()
22            .random_vec(HANDLE_SIZE)
23            .map_err(MlsError::wrap("generating random vector of bytes"))?;
24        backend
25            .key_store()
26            .save_e2ei_enrollment(&handle, &content)
27            .await
28            .map_err(KeystoreError::wrap("saving e2ei enrollment"))?;
29        Ok(handle)
30    }
31
32    pub(crate) async fn stash_pop(backend: &MlsCryptoProvider, handle: EnrollmentHandle) -> Result<Self> {
33        let content = backend
34            .key_store()
35            .pop_e2ei_enrollment(&handle)
36            .await
37            .map_err(KeystoreError::wrap("popping e2ei enrollment"))?;
38        Ok(serde_json::from_slice(&content)?)
39    }
40}
41
42impl CentralContext {
43    /// Allows persisting an active enrollment (for example while redirecting the user during OAuth)
44    /// in order to resume it later with [CentralContext::e2ei_enrollment_stash_pop]
45    ///
46    /// # Arguments
47    /// * `enrollment` - the enrollment instance to persist
48    ///
49    /// # Returns
50    /// A handle for retrieving the enrollment later on
51    pub async fn e2ei_enrollment_stash(&self, enrollment: E2eiEnrollment) -> Result<EnrollmentHandle> {
52        enrollment
53            .stash(
54                &self
55                    .mls_provider()
56                    .await
57                    .map_err(RecursiveError::root("getting mls provider"))?,
58            )
59            .await
60    }
61
62    /// Fetches the persisted enrollment and deletes it from the keystore
63    ///
64    /// # Arguments
65    /// * `handle` - returned by [CentralContext::e2ei_enrollment_stash]
66    pub async fn e2ei_enrollment_stash_pop(&self, handle: EnrollmentHandle) -> Result<E2eiEnrollment> {
67        E2eiEnrollment::stash_pop(
68            &self
69                .mls_provider()
70                .await
71                .map_err(RecursiveError::root("getting mls provider"))?,
72            handle,
73        )
74        .await
75    }
76}
77
#[cfg(test)]
mod tests {

    use mls_crypto_provider::MlsCryptoProvider;
    use wasm_bindgen_test::*;

    use crate::{
        e2e_identity::id::WireQualifiedClientId,
        e2e_identity::tests::*,
        prelude::{E2eiEnrollment, INITIAL_KEYING_MATERIAL_COUNT},
        test_utils::{x509::X509TestChain, *},
    };

    wasm_bindgen_test_configure!(run_in_browser);

    // Stashing an enrollment mid-flow and popping it back must yield an enrollment that can still
    // finish the flow and initialize MLS.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    async fn stash_and_pop_should_not_abort_enrollment(case: TestCase) {
        run_test_wo_clients(case.clone(), move |mut cc| {
            Box::pin(async move {
                let x509_test_chain = X509TestChain::init_empty(case.signature_scheme());

                let is_renewal = false;
                let (mut enrollment, cert) = e2ei_enrollment(
                    &mut cc,
                    &case,
                    &x509_test_chain,
                    Some(E2EI_CLIENT_ID_URI),
                    is_renewal,
                    init_enrollment,
                    |e, cc| {
                        Box::pin(async move {
                            // round-trip the in-progress enrollment through the keystore
                            let handle = cc.e2ei_enrollment_stash(e).await.unwrap();
                            cc.e2ei_enrollment_stash_pop(handle).await.unwrap()
                        })
                    },
                )
                .await
                .unwrap();

                // `expect` instead of `assert!(.. .is_ok())` so a failure reports the actual error
                cc.context
                    .e2ei_mls_init_only(&mut enrollment, cert, Some(INITIAL_KEYING_MATERIAL_COUNT))
                    .await
                    .expect("a stashed then popped enrollment should still initialize MLS");
            })
        })
        .await
    }

    // Negative control for the test above: when the "restored" enrollment is a freshly built one
    // (new in-memory backend, state from the in-progress flow lost), the enrollment flow must
    // fail. This shows the stash/pop round-trip test would catch a restore that drops state.
    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    async fn should_fail_when_restoring_invalid(case: TestCase) {
        run_test_wo_clients(case.clone(), move |mut cc| {
            Box::pin(async move {
                let x509_test_chain = X509TestChain::init_empty(case.signature_scheme());

                let is_renewal = false;
                let result = e2ei_enrollment(
                    &mut cc,
                    &case,
                    &x509_test_chain,
                    Some(E2EI_CLIENT_ID_URI),
                    is_renewal,
                    init_enrollment,
                    move |e, _cc| {
                        Box::pin(async move {
                            // this "restore" recreates only a partial enrollment from public fields
                            let backend = MlsCryptoProvider::try_new_in_memory("new").await.unwrap();
                            backend.new_transaction().await.unwrap();
                            let client_id = e.client_id.parse::<WireQualifiedClientId>().unwrap();
                            E2eiEnrollment::try_new(
                                client_id.into(),
                                e.display_name,
                                e.handle,
                                e.team,
                                1,
                                &backend,
                                e.ciphersuite,
                                None,
                                #[cfg(not(target_family = "wasm"))]
                                None,
                            )
                            .unwrap()
                        })
                    },
                )
                .await;
                assert!(result.is_err());
            })
        })
        .await
    }
}