1use log::trace;
2
3use crate::{
4 MlsError, RecursiveError,
5 prelude::{
6 ClientId, MlsCiphersuite, MlsConversation, MlsCredentialType, Session, identifier::ClientIdentifier,
7 key_package::INITIAL_KEYING_MATERIAL_COUNT,
8 },
9};
10use core_crypto_keystore::DatabaseKey;
11use mls_crypto_provider::MlsCryptoProvider;
12use openmls_traits::OpenMlsCryptoProvider;
13
14use crate::transaction_context::TransactionContext;
15
16pub(crate) mod ciphersuite;
17pub mod conversation;
18pub(crate) mod credential;
19mod error;
20pub(crate) mod proposal;
21pub(crate) mod session;
22
23pub use error::{Error, Result};
24pub use session::EpochObserver;
25
26pub(crate) mod config {
28 use ciphersuite::MlsCiphersuite;
29 use mls_crypto_provider::EntropySeed;
30
31 use super::*;
32
33 #[derive(Debug, Clone)]
35 #[non_exhaustive]
36 pub struct MlsClientConfiguration {
37 pub store_path: String,
39 pub database_key: DatabaseKey,
41 pub client_id: Option<ClientId>,
43 pub external_entropy: Option<EntropySeed>,
45 pub ciphersuites: Vec<ciphersuite::MlsCiphersuite>,
47 pub nb_init_key_packages: Option<usize>,
49 }
50
51 impl MlsClientConfiguration {
52 pub fn try_new(
81 store_path: String,
82 database_key: DatabaseKey,
83 client_id: Option<ClientId>,
84 ciphersuites: Vec<MlsCiphersuite>,
85 entropy: Option<Vec<u8>>,
86 nb_init_key_packages: Option<usize>,
87 ) -> Result<Self> {
88 if store_path.trim().is_empty() {
90 return Err(Error::MalformedIdentifier("store_path"));
91 }
92 if let Some(client_id) = client_id.as_ref() {
94 if client_id.is_empty() {
95 return Err(Error::MalformedIdentifier("client_id"));
96 }
97 }
98 let external_entropy = entropy
99 .as_deref()
100 .map(|seed| &seed[..EntropySeed::EXPECTED_LEN])
101 .map(EntropySeed::try_from_slice)
102 .transpose()
103 .map_err(MlsError::wrap("gathering external entropy"))?;
104 Ok(Self {
105 store_path,
106 database_key,
107 client_id,
108 ciphersuites,
109 external_entropy,
110 nb_init_key_packages,
111 })
112 }
113
114 pub fn set_entropy(&mut self, entropy: EntropySeed) {
116 self.external_entropy = Some(entropy);
117 }
118
119 #[cfg(test)]
120 #[allow(dead_code)]
121 pub(crate) fn tmp_store_path(tmp_dir: &tempfile::TempDir) -> String {
124 let path = tmp_dir.path().join("store.edb");
125 std::fs::File::create(&path).unwrap();
126 path.to_str().unwrap().to_string()
127 }
128 }
129}
130
/// Implemented by types that can hand out an MLS [`Session`] and an [`MlsCryptoProvider`].
///
/// On wasm targets the trait is expanded with `async_trait(?Send)`, so the returned futures are not
/// required to be `Send` there; on other targets the default (`Send`) expansion is used.
#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
pub(crate) trait HasSessionAndCrypto: Send {
    /// Returns the MLS session.
    async fn session(&self) -> Result<Session>;
    /// Returns the MLS crypto provider.
    async fn crypto_provider(&self) -> Result<MlsCryptoProvider>;
}
137
138impl TransactionContext {
139 pub async fn mls_init(
143 &self,
144 identifier: ClientIdentifier,
145 ciphersuites: Vec<MlsCiphersuite>,
146 nb_init_key_packages: Option<usize>,
147 ) -> Result<()> {
148 let nb_key_package = nb_init_key_packages.unwrap_or(INITIAL_KEYING_MATERIAL_COUNT);
149 let mls_client = self
150 .session()
151 .await
152 .map_err(RecursiveError::transaction("getting mls client"))?;
153 mls_client
154 .init(
155 identifier,
156 &ciphersuites,
157 &self
158 .mls_provider()
159 .await
160 .map_err(RecursiveError::transaction("getting mls provider"))?,
161 nb_key_package,
162 )
163 .await
164 .map_err(RecursiveError::mls_client("initializing mls client"))?;
165
166 if mls_client.is_e2ei_capable().await {
167 let client_id = mls_client
168 .id()
169 .await
170 .map_err(RecursiveError::mls_client("getting client id"))?;
171 trace!(client_id:% = client_id; "Initializing PKI environment");
172 self.init_pki_env()
173 .await
174 .map_err(RecursiveError::transaction("initializing pki env"))?;
175 }
176
177 Ok(())
178 }
179
180 #[cfg_attr(test, crate::dispotent)]
185 pub async fn mls_generate_keypairs(&self, ciphersuites: Vec<MlsCiphersuite>) -> Result<Vec<ClientId>> {
186 self.session()
187 .await
188 .map_err(RecursiveError::transaction("getting mls client"))?
189 .generate_raw_keypairs(
190 &ciphersuites,
191 &self
192 .mls_provider()
193 .await
194 .map_err(RecursiveError::transaction("getting mls provider"))?,
195 )
196 .await
197 .map_err(RecursiveError::mls_client("generating raw keypairs"))
198 .map_err(Into::into)
199 }
200
201 #[cfg_attr(test, crate::dispotent)]
205 pub async fn mls_init_with_client_id(
206 &self,
207 client_id: ClientId,
208 tmp_client_ids: Vec<ClientId>,
209 ciphersuites: Vec<MlsCiphersuite>,
210 ) -> Result<()> {
211 self.session()
212 .await
213 .map_err(RecursiveError::transaction("getting mls client"))?
214 .init_with_external_client_id(
215 client_id,
216 tmp_client_ids,
217 &ciphersuites,
218 &self
219 .mls_provider()
220 .await
221 .map_err(RecursiveError::transaction("getting mls provider"))?,
222 )
223 .await
224 .map_err(RecursiveError::mls_client(
225 "initializing mls client with external client id",
226 ))
227 .map_err(Into::into)
228 }
229
230 pub async fn client_public_key(
232 &self,
233 ciphersuite: MlsCiphersuite,
234 credential_type: MlsCredentialType,
235 ) -> Result<Vec<u8>> {
236 let cb = self
237 .session()
238 .await
239 .map_err(RecursiveError::transaction("getting mls client"))?
240 .find_most_recent_credential_bundle(ciphersuite.signature_algorithm(), credential_type)
241 .await
242 .map_err(RecursiveError::mls_client("finding most recent credential bundle"))?;
243 Ok(cb.signature_key.to_public_vec())
244 }
245
246 pub async fn client_id(&self) -> Result<ClientId> {
248 self.session()
249 .await
250 .map_err(RecursiveError::transaction("getting mls client"))?
251 .id()
252 .await
253 .map_err(RecursiveError::mls_client("getting client id"))
254 .map_err(Into::into)
255 }
256
257 pub async fn random_bytes(&self, len: usize) -> Result<Vec<u8>> {
259 use openmls_traits::random::OpenMlsRand as _;
260 self.mls_provider()
261 .await
262 .map_err(RecursiveError::transaction("getting mls provider"))?
263 .rand()
264 .random_vec(len)
265 .map_err(MlsError::wrap("generating random vector"))
266 .map_err(Into::into)
267 }
268}
269
#[cfg(test)]
mod tests {
    use crate::transaction_context::Error as TransactionError;
    use wasm_bindgen_test::*;

    use crate::prelude::{
        CertificateBundle, ClientIdentifier, INITIAL_KEYING_MATERIAL_COUNT, MlsClientConfiguration, MlsCredentialType,
    };
    use crate::{
        CoreCrypto,
        mls::Session,
        test_utils::{x509::X509TestChain, *},
    };

    wasm_bindgen_test_configure!(run_in_browser);

    use core_crypto_keystore::DatabaseKey;

    /// Epoch reporting for conversations.
    mod conversation_epoch {
        use super::*;
        use crate::mls::conversation::Conversation as _;

        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        async fn can_get_newly_created_conversation_epoch(case: TestCase) {
            run_test_with_central(case.clone(), move |[central]| {
                Box::pin(async move {
                    let id = conversation_id();
                    central
                        .context
                        .new_conversation(&id, case.credential_type, case.cfg.clone())
                        .await
                        .unwrap();
                    let epoch = central.context.conversation(&id).await.unwrap().epoch().await;
                    assert_eq!(epoch, 0);
                })
            })
            .await;
        }

        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        async fn can_get_conversation_epoch(case: TestCase) {
            run_test_with_client_ids(case.clone(), ["alice", "bob"], move |[alice_central, bob_central]| {
                Box::pin(async move {
                    let id = conversation_id();
                    alice_central
                        .context
                        .new_conversation(&id, case.credential_type, case.cfg.clone())
                        .await
                        .unwrap();
                    alice_central.invite_all(&case, &id, [&bob_central]).await.unwrap();
                    let epoch = alice_central.context.conversation(&id).await.unwrap().epoch().await;
                    assert_eq!(epoch, 1);
                })
            })
            .await;
        }

        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        async fn conversation_not_found(case: TestCase) {
            use crate::LeafError;

            run_test_with_central(case.clone(), move |[central]| {
                Box::pin(async move {
                    let id = conversation_id();
                    let err = central.context.conversation(&id).await.unwrap_err();
                    assert!(matches!(
                        err,
                        TransactionError::Leaf(LeafError::ConversationNotFound(i)) if i == id
                    ));
                })
            })
            .await;
        }
    }

    /// Validation performed by `MlsClientConfiguration::try_new` and `Session::try_new`.
    mod invariants {
        use crate::{mls, prelude::MlsCiphersuite};

        use super::*;

        #[apply(all_cred_cipher)]
        #[wasm_bindgen_test]
        async fn can_create_from_valid_configuration(case: TestCase) {
            run_tests(move |[tmp_dir_argument]| {
                Box::pin(async move {
                    let configuration = MlsClientConfiguration::try_new(
                        tmp_dir_argument,
                        DatabaseKey::generate(),
                        Some("alice".into()),
                        vec![case.ciphersuite()],
                        None,
                        Some(INITIAL_KEYING_MATERIAL_COUNT),
                    )
                    .unwrap();

                    let new_client_result = Session::try_new(configuration).await;
                    assert!(new_client_result.is_ok())
                })
            })
            .await
        }

        #[test]
        #[wasm_bindgen_test]
        fn store_path_should_not_be_empty_nor_blank() {
            // Cover both the empty string and several whitespace-only variants.
            for store_path in ["", " ", "\t\n"] {
                let configuration = MlsClientConfiguration::try_new(
                    store_path.to_string(),
                    DatabaseKey::generate(),
                    Some("alice".into()),
                    vec![MlsCiphersuite::default()],
                    None,
                    Some(INITIAL_KEYING_MATERIAL_COUNT),
                );
                assert!(matches!(
                    configuration.unwrap_err(),
                    mls::Error::MalformedIdentifier("store_path")
                ));
            }
        }

        #[cfg_attr(not(target_family = "wasm"), async_std::test)]
        #[wasm_bindgen_test]
        async fn client_id_should_not_be_empty() {
            run_tests(|[tmp_dir_argument]| {
                Box::pin(async move {
                    let ciphersuites = vec![MlsCiphersuite::default()];
                    let configuration = MlsClientConfiguration::try_new(
                        tmp_dir_argument,
                        DatabaseKey::generate(),
                        Some("".into()),
                        ciphersuites,
                        None,
                        Some(INITIAL_KEYING_MATERIAL_COUNT),
                    );
                    assert!(matches!(
                        configuration.unwrap_err(),
                        mls::Error::MalformedIdentifier("client_id")
                    ));
                })
            })
            .await
        }
    }

    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    async fn create_conversation_should_fail_when_already_exists(case: TestCase) {
        use crate::LeafError;

        run_test_with_client_ids(case.clone(), ["alice"], move |[alice_central]| {
            Box::pin(async move {
                let id = conversation_id();

                let create = alice_central
                    .context
                    .new_conversation(&id, case.credential_type, case.cfg.clone())
                    .await;
                assert!(create.is_ok());

                let repeat_create = alice_central
                    .context
                    .new_conversation(&id, case.credential_type, case.cfg.clone())
                    .await;
                assert!(matches!(
                    repeat_create.unwrap_err(),
                    TransactionError::Leaf(LeafError::ConversationAlreadyExists(i)) if i == id
                ));
            })
        })
        .await;
    }

    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    async fn can_fetch_client_public_key(case: TestCase) {
        // Actually exercise `client_public_key` on an initialized client rather than only building a session.
        run_test_with_central(case.clone(), move |[central]| {
            Box::pin(async move {
                let public_key = central
                    .context
                    .client_public_key(case.ciphersuite(), case.credential_type)
                    .await
                    .unwrap();
                assert!(!public_key.is_empty());
            })
        })
        .await;
    }

    #[apply(all_cred_cipher)]
    #[wasm_bindgen_test]
    async fn can_2_phase_init_central(case: TestCase) {
        run_tests(move |[tmp_dir_argument]| {
            Box::pin(async move {
                let x509_test_chain = X509TestChain::init_empty(case.signature_scheme());
                let configuration = MlsClientConfiguration::try_new(
                    tmp_dir_argument,
                    DatabaseKey::generate(),
                    None,
                    vec![case.ciphersuite()],
                    None,
                    Some(INITIAL_KEYING_MATERIAL_COUNT),
                )
                .unwrap();
                let client = Session::try_new(configuration).await.unwrap();
                let cc = CoreCrypto::from(client);
                let context = cc.new_transaction().await.unwrap();
                x509_test_chain.register_with_central(&context).await;

                // Without a client id in the configuration, the session is not ready until `mls_init`.
                assert!(!context.session().await.unwrap().is_ready().await);
                let client_id = "alice";
                let identifier = match case.credential_type {
                    MlsCredentialType::Basic => ClientIdentifier::Basic(client_id.into()),
                    MlsCredentialType::X509 => {
                        CertificateBundle::rand_identifier(client_id, &[x509_test_chain.find_local_intermediate_ca()])
                    }
                };
                context
                    .mls_init(
                        identifier,
                        vec![case.ciphersuite()],
                        Some(INITIAL_KEYING_MATERIAL_COUNT),
                    )
                    .await
                    .unwrap();
                assert!(context.session().await.unwrap().is_ready().await);
                assert_eq!(
                    context
                        .get_or_create_client_keypackages(case.ciphersuite(), case.credential_type, 2)
                        .await
                        .unwrap()
                        .len(),
                    2
                );
            })
        })
        .await
    }
}