1#![allow(rustdoc::invalid_html_tags)]
12
13use std::fmt::Display;
14use std::num::NonZeroU32;
15
16use base64::prelude::*;
17use itertools::Itertools;
18
19use crate::password::Password;
20
/// Size, in bytes, of every salt handled by this module.
const DEFAULT_SALT_SIZE: usize = 32;

/// Size, in bytes, of a SHA-256 digest (and therefore of an HMAC-SHA-256 output).
const SHA256_OUTPUT_LEN: usize = 32;
25
26#[derive(Debug, PartialEq)]
28pub struct HashOpts {
29 pub iterations: NonZeroU32,
31 pub salt: [u8; DEFAULT_SALT_SIZE],
35}
36
37pub struct PasswordHash {
38 pub salt: [u8; DEFAULT_SALT_SIZE],
40 pub iterations: NonZeroU32,
42 pub hash: [u8; SHA256_OUTPUT_LEN],
45}
46
47#[derive(Debug)]
48pub enum VerifyError {
49 MalformedHash,
50 InvalidPassword,
51 Hash(HashError),
52}
53
54#[derive(Debug)]
55pub enum HashError {
56 Openssl(openssl::error::ErrorStack),
57}
58
59impl Display for HashError {
60 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61 match self {
62 HashError::Openssl(e) => write!(f, "OpenSSL error: {}", e),
63 }
64 }
65}
66
67pub fn hash_password(
70 password: &Password,
71 iterations: &NonZeroU32,
72) -> Result<PasswordHash, HashError> {
73 let mut salt = [0u8; DEFAULT_SALT_SIZE];
74 openssl::rand::rand_bytes(&mut salt).map_err(HashError::Openssl)?;
75
76 let hash = hash_password_inner(
77 &HashOpts {
78 iterations: iterations.to_owned(),
79 salt,
80 },
81 password.to_string().as_bytes(),
82 )?;
83
84 Ok(PasswordHash {
85 salt,
86 iterations: iterations.to_owned(),
87 hash,
88 })
89}
90
91pub fn generate_nonce(client_nonce: &str) -> Result<String, HashError> {
92 let mut nonce = [0u8; 24];
93 openssl::rand::rand_bytes(&mut nonce).map_err(HashError::Openssl)?;
94 let nonce = BASE64_STANDARD.encode(&nonce);
95 let new_nonce = format!("{}{}", client_nonce, nonce);
96 Ok(new_nonce)
97}
98
99pub fn hash_password_with_opts(
102 opts: &HashOpts,
103 password: &Password,
104) -> Result<PasswordHash, HashError> {
105 let hash = hash_password_inner(opts, password.to_string().as_bytes())?;
106
107 Ok(PasswordHash {
108 salt: opts.salt,
109 iterations: opts.iterations,
110 hash,
111 })
112}
113
114pub fn scram256_hash(password: &Password, iterations: &NonZeroU32) -> Result<String, HashError> {
118 let hashed_password = hash_password(password, iterations)?;
119 Ok(scram256_hash_inner(hashed_password).to_string())
120}
121
122fn constant_time_compare(a: &[u8], b: &[u8]) -> bool {
123 if a.len() != b.len() {
124 return false;
125 }
126 openssl::memcmp::eq(a, b)
127}
128
129pub fn scram256_verify(password: &Password, hashed_password: &str) -> Result<(), VerifyError> {
131 let opts = scram256_parse_opts(hashed_password)?;
132 let hashed = hash_password_with_opts(&opts, password).map_err(VerifyError::Hash)?;
133 let scram = scram256_hash_inner(hashed);
134 if constant_time_compare(hashed_password.as_bytes(), scram.to_string().as_bytes()) {
135 Ok(())
136 } else {
137 Err(VerifyError::InvalidPassword)
138 }
139}
140
141pub fn sasl_verify(
142 hashed_password: &str,
143 proof: &str,
144 auth_message: &str,
145) -> Result<String, VerifyError> {
146 let parts: Vec<&str> = hashed_password.split('$').collect();
148 if parts.len() != 3 {
149 return Err(VerifyError::MalformedHash);
150 }
151 let auth_info = parts[1].split(':').collect::<Vec<&str>>();
152 if auth_info.len() != 2 {
153 return Err(VerifyError::MalformedHash);
154 }
155 let auth_value = parts[2].split(':').collect::<Vec<&str>>();
156 if auth_value.len() != 2 {
157 return Err(VerifyError::MalformedHash);
158 }
159
160 let stored_key = BASE64_STANDARD
161 .decode(auth_value[0])
162 .map_err(|_| VerifyError::MalformedHash)?;
163 let server_key = BASE64_STANDARD
164 .decode(auth_value[1])
165 .map_err(|_| VerifyError::MalformedHash)?;
166
167 let client_signature = generate_signature(&stored_key, auth_message)?;
169
170 let provided_client_proof = BASE64_STANDARD
172 .decode(proof)
173 .map_err(|_| VerifyError::InvalidPassword)?;
174
175 if provided_client_proof.len() != client_signature.len() {
176 return Err(VerifyError::InvalidPassword);
177 }
178
179 let client_key: Vec<u8> = provided_client_proof
181 .iter()
182 .zip_eq(client_signature.iter())
183 .map(|(p, s)| p ^ s)
184 .collect();
185
186 if !constant_time_compare(&openssl::sha::sha256(&client_key), &stored_key) {
187 return Err(VerifyError::InvalidPassword);
188 }
189
190 let verifier = generate_signature(&server_key, auth_message)?;
192 Ok(BASE64_STANDARD.encode(&verifier))
193}
194
195fn generate_signature(key: &[u8], message: &str) -> Result<Vec<u8>, VerifyError> {
196 let signing_key =
197 openssl::pkey::PKey::hmac(key).map_err(|e| VerifyError::Hash(HashError::Openssl(e)))?;
198 let mut signer =
199 openssl::sign::Signer::new(openssl::hash::MessageDigest::sha256(), &signing_key)
200 .map_err(|e| VerifyError::Hash(HashError::Openssl(e)))?;
201 signer
202 .update(message.as_bytes())
203 .map_err(|e| VerifyError::Hash(HashError::Openssl(e)))?;
204 let signature = signer
205 .sign_to_vec()
206 .map_err(|e| VerifyError::Hash(HashError::Openssl(e)))?;
207 Ok(signature)
208}
209
210pub fn mock_sasl_challenge(username: &str, mock_nonce: &str, iterations: &NonZeroU32) -> HashOpts {
214 let mut buf = Vec::with_capacity(username.len() + mock_nonce.len());
215 buf.extend_from_slice(username.as_bytes());
216 buf.extend_from_slice(mock_nonce.as_bytes());
217 let digest = openssl::sha::sha256(&buf);
218
219 HashOpts {
220 iterations: iterations.to_owned(),
221 salt: digest,
222 }
223}
224
225pub fn scram256_parse_opts(hashed_password: &str) -> Result<HashOpts, VerifyError> {
227 let parts: Vec<&str> = hashed_password.split('$').collect();
228 if parts.len() != 3 {
229 return Err(VerifyError::MalformedHash);
230 }
231 let scheme = parts[0];
232 if scheme != "SCRAM-SHA-256" {
233 return Err(VerifyError::MalformedHash);
234 }
235 let auth_info = parts[1].split(':').collect::<Vec<&str>>();
236 if auth_info.len() != 2 {
237 return Err(VerifyError::MalformedHash);
238 }
239 let auth_value = parts[2].split(':').collect::<Vec<&str>>();
240 if auth_value.len() != 2 {
241 return Err(VerifyError::MalformedHash);
242 }
243
244 let iterations = auth_info[0]
245 .parse::<u32>()
246 .map_err(|_| VerifyError::MalformedHash)?;
247
248 let salt = BASE64_STANDARD
249 .decode(auth_info[1])
250 .map_err(|_| VerifyError::MalformedHash)?;
251
252 let salt = salt.try_into().map_err(|_| VerifyError::MalformedHash)?;
253
254 Ok(HashOpts {
255 iterations: NonZeroU32::new(iterations).ok_or(VerifyError::MalformedHash)?,
256 salt,
257 })
258}
259
260struct ScramSha256Hash {
262 iterations: NonZeroU32,
264 salt: [u8; 32],
266 server_key: [u8; SHA256_OUTPUT_LEN],
268 stored_key: [u8; SHA256_OUTPUT_LEN],
270}
271
272impl Display for ScramSha256Hash {
273 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
274 write!(
275 f,
276 "SCRAM-SHA-256${}:{}${}:{}",
277 self.iterations,
278 BASE64_STANDARD.encode(&self.salt),
279 BASE64_STANDARD.encode(&self.stored_key),
280 BASE64_STANDARD.encode(&self.server_key)
281 )
282 }
283}
284
285fn scram256_hash_inner(hashed_password: PasswordHash) -> ScramSha256Hash {
286 let signing_key = openssl::pkey::PKey::hmac(&hashed_password.hash).unwrap();
287 let mut signer =
288 openssl::sign::Signer::new(openssl::hash::MessageDigest::sha256(), &signing_key).unwrap();
289 signer.update(b"Client Key").unwrap();
290 let client_key = signer.sign_to_vec().unwrap();
291 let stored_key = openssl::sha::sha256(&client_key);
292 let mut signer =
293 openssl::sign::Signer::new(openssl::hash::MessageDigest::sha256(), &signing_key).unwrap();
294 signer.update(b"Server Key").unwrap();
295 let mut server_key: [u8; SHA256_OUTPUT_LEN] = [0; SHA256_OUTPUT_LEN];
296 signer.sign(server_key.as_mut()).unwrap();
297
298 ScramSha256Hash {
299 iterations: hashed_password.iterations,
300 salt: hashed_password.salt,
301 server_key,
302 stored_key,
303 }
304}
305
306fn hash_password_inner(
307 opts: &HashOpts,
308 password: &[u8],
309) -> Result<[u8; SHA256_OUTPUT_LEN], HashError> {
310 let mut salted_password = [0u8; SHA256_OUTPUT_LEN];
311 openssl::pkcs5::pbkdf2_hmac(
312 password,
313 &opts.salt,
314 opts.iterations.get().try_into().unwrap(),
315 openssl::hash::MessageDigest::sha256(),
316 &mut salted_password,
317 )
318 .map_err(HashError::Openssl)?;
319 Ok(salted_password)
320}
321
#[cfg(test)]
mod tests {
    use itertools::Itertools;

    use super::*;

    /// Iteration count shared by the tests that do not need a specific value.
    const DEFAULT_ITERATIONS: NonZeroU32 = NonZeroU32::new(60).expect("Trust me on this");

    /// Auth message shared by the SASL tests; its exact content is irrelevant,
    /// it only has to be the same on both sides of the exchange.
    const AUTH_MESSAGE: &str = "n=user,r=clientnonce,s=somesalt";

    // Tests that call into OpenSSL are marked `cfg_attr(miri, ignore)`.

    /// A fresh hash echoes the iteration count and has correctly sized fields.
    #[mz_ore::test]
    #[cfg_attr(miri, ignore)]
    fn test_hash_password() {
        let iterations = NonZeroU32::new(100).expect("Trust me on this");
        let password: Password = "password".to_string().into();

        let hashed_password =
            hash_password(&password, &iterations).expect("Failed to hash password");

        assert_eq!(hashed_password.iterations, iterations);
        assert_eq!(hashed_password.salt.len(), DEFAULT_SALT_SIZE);
        assert_eq!(hashed_password.hash.len(), SHA256_OUTPUT_LEN);
    }

    /// A SCRAM hash verifies against its own password and rejects another one.
    #[mz_ore::test]
    #[cfg_attr(miri, ignore)]
    fn test_scram256_hash() {
        let password: Password = "password".into();
        let wrong_password: Password = "wrong_password".into();
        let scram_hash =
            scram256_hash(&password, &DEFAULT_ITERATIONS).expect("Failed to hash password");

        assert!(scram256_verify(&password, &scram_hash).is_ok());
        assert!(scram256_verify(&wrong_password, &scram_hash).is_err());
    }

    /// Iteration count and salt are recovered from a well-formed hash string.
    #[mz_ore::test]
    fn test_scram256_parse_opts() {
        let salt = "9bkIQQjQ7f1OwPsXZGC/YfIkbZsOMDXK0cxxvPBaSfM=";
        let hashed_password = format!("SCRAM-SHA-256$600000:{}$client-key:server-key", salt);

        let parsed = scram256_parse_opts(&hashed_password);
        assert!(parsed.is_ok());
        let opts = parsed.unwrap();

        assert_eq!(
            opts.iterations,
            NonZeroU32::new(600_000).expect("known valid")
        );
        assert_eq!(opts.salt.len(), DEFAULT_SALT_SIZE);
        let decoded_salt = BASE64_STANDARD.decode(salt).expect("Failed to decode salt");
        assert_eq!(opts.salt, decoded_salt.as_ref());
    }

    /// The mock challenge is a pure function of its inputs.
    #[mz_ore::test]
    #[cfg_attr(miri, ignore)]
    fn test_mock_sasl_challenge() {
        let (username, mock) = ("alice", "cnonce");
        let first = mock_sasl_challenge(username, mock, &DEFAULT_ITERATIONS);
        let second = mock_sasl_challenge(username, mock, &DEFAULT_ITERATIONS);
        assert_eq!(first, second);
    }

    /// A proof computed the way a client would compute it is accepted, and the
    /// returned verifier is the server signature over the auth message.
    #[mz_ore::test]
    #[cfg_attr(miri, ignore)]
    fn test_sasl_verify_success() {
        let password: Password = "password".into();
        let hashed_password = scram256_hash(&password, &DEFAULT_ITERATIONS).expect("hash password");

        // Pull the stored key and server key out of the hash string.
        let parts: Vec<&str> = hashed_password.split('$').collect();
        assert_eq!(parts.len(), 3);
        let key_parts: Vec<&str> = parts[2].split(':').collect();
        assert_eq!(key_parts.len(), 2);
        let stored_key = BASE64_STANDARD
            .decode(key_parts[0])
            .expect("decode stored key");
        let server_key = BASE64_STANDARD
            .decode(key_parts[1])
            .expect("decode server key");

        // Client side: re-derive the salted password, then the client key from it.
        let opts = scram256_parse_opts(&hashed_password).expect("parse opts");
        let salted_password = hash_password_with_opts(&opts, &password)
            .expect("hash password")
            .hash;
        let client_key = {
            let signing_key = openssl::pkey::PKey::hmac(&salted_password).expect("signing key");
            let mut signer =
                openssl::sign::Signer::new(openssl::hash::MessageDigest::sha256(), &signing_key)
                    .expect("signer");
            signer.update(b"Client Key").expect("update");
            signer.sign_to_vec().expect("client key")
        };

        // proof = client_key XOR HMAC(stored_key, auth_message)
        let client_signature =
            generate_signature(&stored_key, AUTH_MESSAGE).expect("client signature");
        let client_proof: Vec<u8> = client_key
            .iter()
            .zip_eq(client_signature.iter())
            .map(|(c, s)| c ^ s)
            .collect();
        let client_proof_b64 = BASE64_STANDARD.encode(&client_proof);

        let verifier = sasl_verify(&hashed_password, &client_proof_b64, AUTH_MESSAGE)
            .expect("sasl_verify should succeed");

        let server_signature =
            generate_signature(&server_key, AUTH_MESSAGE).expect("server verifier");
        let expected_verifier = BASE64_STANDARD.encode(&server_signature);
        assert_eq!(verifier, expected_verifier);
    }

    /// A correctly sized but wrong proof is rejected as an invalid password.
    #[mz_ore::test]
    #[cfg_attr(miri, ignore)]
    fn test_sasl_verify_invalid_proof() {
        let password: Password = "password".into();
        let hashed_password = scram256_hash(&password, &DEFAULT_ITERATIONS).expect("hash password");
        let bad_proof = BASE64_STANDARD.encode([0u8; 32]);

        let res = sasl_verify(&hashed_password, &bad_proof, AUTH_MESSAGE);
        assert!(matches!(res, Err(VerifyError::InvalidPassword)));
    }

    /// A hash string with the wrong number of sections is reported as malformed.
    #[mz_ore::test]
    fn test_sasl_verify_malformed_hash() {
        let malformed_hash = "NOT-SCRAM$bad";
        let bad_proof = BASE64_STANDARD.encode([0u8; 32]);

        let res = sasl_verify(malformed_hash, &bad_proof, AUTH_MESSAGE);
        assert!(matches!(res, Err(VerifyError::MalformedHash)));
    }

    /// Proofs of the wrong length must be rejected cleanly instead of panicking.
    #[mz_ore::test]
    #[cfg_attr(miri, ignore)]
    fn test_sasl_verify_truncated_proof_no_panic() {
        let password: Password = "password".into();
        let hashed_password = scram256_hash(&password, &DEFAULT_ITERATIONS).expect("hash password");

        // (proof length in bytes, message shown if the case fails)
        let cases = [
            (
                16,
                "truncated proof should return InvalidPassword, not panic",
            ),
            (
                64,
                "oversized proof should return InvalidPassword, not panic",
            ),
            (0, "empty proof should return InvalidPassword, not panic"),
        ];

        for (proof_len, failure_message) in cases {
            let proof = BASE64_STANDARD.encode(vec![0u8; proof_len]);
            let res = sasl_verify(&hashed_password, &proof, AUTH_MESSAGE);
            assert!(
                matches!(res, Err(VerifyError::InvalidPassword)),
                "{}",
                failure_message
            );
        }
    }
}