1use crate::prelude::{ClientId, MlsConversation, Session};
2use mls_crypto_provider::MlsCryptoProvider;
3
4pub(crate) mod ciphersuite;
5pub mod conversation;
6pub(crate) mod credential;
7mod error;
8pub(crate) mod proposal;
9pub(crate) mod session;
10
11pub use error::{Error, Result};
12pub use session::EpochObserver;
13pub use session::HistoryObserver;
14
15#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
16#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
17pub(crate) trait HasSessionAndCrypto: Send {
18 async fn session(&self) -> Result<Session>;
19 async fn crypto_provider(&self) -> Result<MlsCryptoProvider>;
20}
21
#[cfg(test)]
mod tests {
    use crate::transaction_context::Error as TransactionError;

    use crate::prelude::{
        CertificateBundle, ClientIdentifier, INITIAL_KEYING_MATERIAL_COUNT, MlsCredentialType, SessionConfig,
    };
    use crate::{
        CoreCrypto,
        mls::Session,
        test_utils::{x509::X509TestChain, *},
    };

    use core_crypto_keystore::DatabaseKey;

    // Reading a conversation's epoch, and looking up conversations that do not exist.
    mod conversation_epoch {
        use super::*;
        use crate::mls::conversation::Conversation as _;

        // A conversation created with a single member reports epoch 0.
        #[apply(all_cred_cipher)]
        async fn can_get_newly_created_conversation_epoch(case: TestContext) {
            let [session] = case.sessions().await;
            let conversation = case.create_conversation([&session]).await;
            let epoch = conversation.guard().await.epoch().await;
            assert_eq!(epoch, 0);
        }

        // A conversation created with two members reports epoch 1.
        #[apply(all_cred_cipher)]
        async fn can_get_conversation_epoch(case: TestContext) {
            let [alice, bob] = case.sessions().await;
            // The body is boxed, presumably to keep the (large) future off the stack.
            Box::pin(async move {
                let conversation = case.create_conversation([&alice, &bob]).await;
                let epoch = conversation.guard().await.epoch().await;
                // NOTE(review): presumably 1 because adding bob took one commit — confirm.
                assert_eq!(epoch, 1);
            })
            .await;
        }

        // Looking up an id the session never created fails with `ConversationNotFound`
        // carrying exactly that id.
        #[apply(all_cred_cipher)]
        async fn conversation_not_found(case: TestContext) {
            use crate::LeafError;
            let [session] = case.sessions().await;
            let id = conversation_id();
            let err = session.transaction.conversation(&id).await.unwrap_err();
            assert!(matches!(
                err,
                TransactionError::Leaf(LeafError::ConversationNotFound(i)) if i == id
            ));
        }
    }

    // Validation rules enforced by `SessionConfig::builder()...build().validate()`.
    mod invariants {
        use crate::{mls, prelude::MlsCiphersuite};

        use super::*;

        // A persistent configuration with a client id and the case's ciphersuite validates,
        // and a session can be created from it.
        #[apply(all_cred_cipher)]
        async fn can_create_from_valid_configuration(mut case: TestContext) {
            let tmp_dir = case.tmp_dir().await;
            Box::pin(async move {
                let configuration = SessionConfig::builder()
                    .persistent(&tmp_dir)
                    .database_key(DatabaseKey::generate())
                    .client_id("alice".into())
                    .ciphersuites([case.ciphersuite()])
                    .build()
                    .validate()
                    .unwrap();

                let new_client_result = Session::try_new(configuration).await;
                // Surface the error in the failure message rather than a bare `is_ok()` failure.
                assert!(new_client_result.is_ok(), "{new_client_result:?}");
            })
            .await
        }

        // A whitespace-only persistent path is rejected with a `MalformedIdentifier`
        // whose message mentions "path".
        #[test]
        fn store_path_should_not_be_empty_nor_blank() {
            let config_err = SessionConfig::builder()
                .persistent(" ")
                .database_key(DatabaseKey::generate())
                .ciphersuites([MlsCiphersuite::default()])
                .build()
                .validate()
                .unwrap_err();

            assert!(matches!(config_err, mls::Error::MalformedIdentifier(msg) if msg.contains("path")));
        }

        // An empty client id is rejected with `MalformedIdentifier("client_id")`.
        #[async_std::test]
        async fn client_id_should_not_be_empty() {
            let mut case = TestContext::default();
            let tmp_dir = case.tmp_dir().await;
            Box::pin(async move {
                let config_err = SessionConfig::builder()
                    .persistent(&tmp_dir)
                    .database_key(DatabaseKey::generate())
                    .client_id("".into())
                    .ciphersuites([MlsCiphersuite::default()])
                    .build()
                    .validate()
                    .unwrap_err();

                assert!(matches!(config_err, mls::Error::MalformedIdentifier("client_id")));
            })
            .await
        }
    }

    // Creating a conversation under an id that already exists fails with
    // `ConversationAlreadyExists` carrying that id.
    #[apply(all_cred_cipher)]
    async fn create_conversation_should_fail_when_already_exists(case: TestContext) {
        use crate::LeafError;

        let [alice] = case.sessions().await;
        Box::pin(async move {
            let conversation = case.create_conversation([&alice]).await;
            let id = conversation.id().clone();

            let repeat_create = alice
                .transaction
                .new_conversation(&id, case.credential_type, case.cfg.clone())
                .await;
            assert!(matches!(
                repeat_create.unwrap_err(),
                TransactionError::Leaf(LeafError::ConversationAlreadyExists(i)) if i == id
            ));
        })
        .await;
    }

    // Despite its name, this test only checks that a session can be created from a
    // persistent configuration with a client id; no public key is fetched here.
    #[apply(all_cred_cipher)]
    async fn can_fetch_client_public_key(mut case: TestContext) {
        let tmp_dir = case.tmp_dir().await;
        Box::pin(async move {
            let configuration = SessionConfig::builder()
                .persistent(&tmp_dir)
                .database_key(DatabaseKey::generate())
                .client_id("potato".into())
                .ciphersuites([case.ciphersuite()])
                .build()
                .validate()
                .unwrap();

            let result = Session::try_new(configuration).await;
            // The debug output now appears only on failure, instead of being printed on every run.
            assert!(result.is_ok(), "{result:?}");
        })
        .await
    }

    // Two-phase initialisation:
    // 1. Build a session from a configuration with no client id, and check it is not ready.
    // 2. Complete it with `mls_init`, and check it is ready.
    // 3. Check it can produce the requested number of key packages.
    #[apply(all_cred_cipher)]
    async fn can_2_phase_init_central(mut case: TestContext) {
        let tmp_dir = case.tmp_dir().await;
        Box::pin(async move {
            let x509_test_chain = X509TestChain::init_empty(case.signature_scheme());
            // Deliberately no `.client_id(..)`: the identity is supplied later via `mls_init`.
            let configuration = SessionConfig::builder()
                .persistent(&tmp_dir)
                .database_key(DatabaseKey::generate())
                .ciphersuites([case.ciphersuite()])
                .build()
                .validate()
                .unwrap();

            let client = Session::try_new(configuration).await.unwrap();
            let cc = CoreCrypto::from(client);
            let context = cc.new_transaction().await.unwrap();
            x509_test_chain.register_with_central(&context).await;

            assert!(!context.session().await.unwrap().is_ready().await);

            let client_id = "alice";
            // X509 identities are derived from the test chain's local intermediate CA.
            let identifier = match case.credential_type {
                MlsCredentialType::Basic => ClientIdentifier::Basic(client_id.into()),
                MlsCredentialType::X509 => {
                    CertificateBundle::rand_identifier(client_id, &[x509_test_chain.find_local_intermediate_ca()])
                }
            };
            context
                .mls_init(identifier, vec![case.ciphersuite()], Some(INITIAL_KEYING_MATERIAL_COUNT))
                .await
                .unwrap();
            assert!(context.session().await.unwrap().is_ready().await);

            let key_packages = context
                .get_or_create_client_keypackages(case.ciphersuite(), case.credential_type, 2)
                .await
                .unwrap();
            assert_eq!(key_packages.len(), 2);
        })
        .await
    }
}