1#![allow(rustdoc::invalid_html_tags)]
12
13use std::fmt::Display;
14use std::num::NonZeroU32;
15
16use aws_lc_rs::constant_time::verify_slices_are_equal;
17use aws_lc_rs::digest;
18use aws_lc_rs::hmac;
19use aws_lc_rs::rand::{SecureRandom, SystemRandom};
20use base64::prelude::*;
21use itertools::Itertools;
22use mz_ore::secure::{Zeroize, Zeroizing};
23
24use crate::password::Password;
25
/// Size, in bytes, of the salts this module generates in [`hash_password`] and accepts in
/// [`scram256_parse_opts`].
const DEFAULT_SALT_SIZE: usize = 32;

/// Size, in bytes, of a SHA-256 digest (256 bits). HMAC-SHA-256 tags and the
/// PBKDF2-HMAC-SHA-256 output used throughout this module have the same length.
const SHA256_OUTPUT_LEN: usize = 256 / 8;
30
31#[derive(Debug, PartialEq)]
33pub struct HashOpts {
34 pub iterations: NonZeroU32,
36 pub salt: [u8; DEFAULT_SALT_SIZE],
40}
41
42impl Drop for HashOpts {
43 fn drop(&mut self) {
44 self.salt.zeroize();
45 }
46}
47
48pub struct PasswordHash {
49 pub salt: [u8; DEFAULT_SALT_SIZE],
51 pub iterations: NonZeroU32,
53 pub hash: [u8; SHA256_OUTPUT_LEN],
56}
57
58impl Drop for PasswordHash {
59 fn drop(&mut self) {
60 self.salt.zeroize();
61 self.hash.zeroize();
62 }
63}
64
65#[derive(Debug)]
66pub enum VerifyError {
67 MalformedHash,
68 InvalidPassword,
69 Hash(HashError),
70}
71
72#[derive(Debug)]
73pub enum HashError {
74 Crypto(aws_lc_rs::error::Unspecified),
75}
76
77impl Display for HashError {
78 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79 match self {
80 HashError::Crypto(e) => write!(f, "crypto error: {}", e),
81 }
82 }
83}
84
85pub fn hash_password(
88 password: &Password,
89 iterations: &NonZeroU32,
90) -> Result<PasswordHash, HashError> {
91 let rng = SystemRandom::new();
92 let mut salt = Zeroizing::new([0u8; DEFAULT_SALT_SIZE]);
93 rng.fill(&mut *salt).map_err(HashError::Crypto)?;
94
95 let hash = hash_password_inner(
96 &HashOpts {
97 iterations: iterations.to_owned(),
98 salt: *salt,
99 },
100 password.as_bytes(),
101 )?;
102
103 Ok(PasswordHash {
104 salt: *salt,
105 iterations: iterations.to_owned(),
106 hash,
107 })
108}
109
110pub fn generate_nonce(client_nonce: &str) -> Result<String, HashError> {
111 let rng = SystemRandom::new();
112 let mut nonce = Zeroizing::new([0u8; 24]);
113 rng.fill(&mut *nonce).map_err(HashError::Crypto)?;
114 let nonce = BASE64_STANDARD.encode(&*nonce);
115 let new_nonce = format!("{}{}", client_nonce, nonce);
116 Ok(new_nonce)
117}
118
119pub fn hash_password_with_opts(
122 opts: &HashOpts,
123 password: &Password,
124) -> Result<PasswordHash, HashError> {
125 let hash = hash_password_inner(opts, password.as_bytes())?;
126
127 Ok(PasswordHash {
128 salt: opts.salt,
129 iterations: opts.iterations,
130 hash,
131 })
132}
133
134pub fn scram256_hash(password: &Password, iterations: &NonZeroU32) -> Result<String, HashError> {
138 let hashed_password = hash_password(password, iterations)?;
139 Ok(scram256_hash_inner(hashed_password).to_string())
140}
141
142fn constant_time_compare(a: &[u8], b: &[u8]) -> bool {
143 verify_slices_are_equal(a, b).is_ok()
144}
145
146pub fn scram256_verify(password: &Password, hashed_password: &str) -> Result<(), VerifyError> {
148 let opts = scram256_parse_opts(hashed_password)?;
149 let hashed = hash_password_with_opts(&opts, password).map_err(VerifyError::Hash)?;
150 let scram = scram256_hash_inner(hashed);
151 if constant_time_compare(hashed_password.as_bytes(), scram.to_string().as_bytes()) {
152 Ok(())
153 } else {
154 Err(VerifyError::InvalidPassword)
155 }
156}
157
158pub fn sasl_verify(
159 hashed_password: &str,
160 proof: &str,
161 auth_message: &str,
162) -> Result<String, VerifyError> {
163 let parts: Vec<&str> = hashed_password.split('$').collect();
165 if parts.len() != 3 {
166 return Err(VerifyError::MalformedHash);
167 }
168 let auth_info = parts[1].split(':').collect::<Vec<&str>>();
169 if auth_info.len() != 2 {
170 return Err(VerifyError::MalformedHash);
171 }
172 let auth_value = parts[2].split(':').collect::<Vec<&str>>();
173 if auth_value.len() != 2 {
174 return Err(VerifyError::MalformedHash);
175 }
176
177 let stored_key = Zeroizing::new(
178 BASE64_STANDARD
179 .decode(auth_value[0])
180 .map_err(|_| VerifyError::MalformedHash)?,
181 );
182 let server_key = Zeroizing::new(
183 BASE64_STANDARD
184 .decode(auth_value[1])
185 .map_err(|_| VerifyError::MalformedHash)?,
186 );
187
188 let client_signature = Zeroizing::new(generate_signature(&stored_key, auth_message)?);
190
191 let provided_client_proof = Zeroizing::new(
193 BASE64_STANDARD
194 .decode(proof)
195 .map_err(|_| VerifyError::InvalidPassword)?,
196 );
197
198 if provided_client_proof.len() != client_signature.len() {
199 return Err(VerifyError::InvalidPassword);
200 }
201
202 let client_key: Zeroizing<Vec<u8>> = Zeroizing::new(
204 provided_client_proof
205 .iter()
206 .zip_eq(client_signature.iter())
207 .map(|(p, s)| p ^ s)
208 .collect(),
209 );
210
211 let computed_stored_key = digest::digest(&digest::SHA256, &client_key);
212 if !constant_time_compare(computed_stored_key.as_ref(), &stored_key) {
213 return Err(VerifyError::InvalidPassword);
214 }
215
216 let verifier = Zeroizing::new(generate_signature(&server_key, auth_message)?);
218 Ok(BASE64_STANDARD.encode(&*verifier))
219}
220
221fn generate_signature(key: &[u8], message: &str) -> Result<Zeroizing<Vec<u8>>, VerifyError> {
222 let signing_key = hmac::Key::new(hmac::HMAC_SHA256, key);
223 let tag = hmac::sign(&signing_key, message.as_bytes());
224 Ok(Zeroizing::new(tag.as_ref().to_vec()))
225}
226
227pub fn mock_sasl_challenge(username: &str, mock_nonce: &str, iterations: &NonZeroU32) -> HashOpts {
231 let mut buf = Vec::with_capacity(username.len() + mock_nonce.len());
232 buf.extend_from_slice(username.as_bytes());
233 buf.extend_from_slice(mock_nonce.as_bytes());
234 let hash = digest::digest(&digest::SHA256, &buf);
235 let mut salt = [0u8; DEFAULT_SALT_SIZE];
236 salt.copy_from_slice(hash.as_ref());
237
238 HashOpts {
239 iterations: iterations.to_owned(),
240 salt,
241 }
242}
243
244pub fn scram256_parse_opts(hashed_password: &str) -> Result<HashOpts, VerifyError> {
246 let parts: Vec<&str> = hashed_password.split('$').collect();
247 if parts.len() != 3 {
248 return Err(VerifyError::MalformedHash);
249 }
250 let scheme = parts[0];
251 if scheme != "SCRAM-SHA-256" {
252 return Err(VerifyError::MalformedHash);
253 }
254 let auth_info = parts[1].split(':').collect::<Vec<&str>>();
255 if auth_info.len() != 2 {
256 return Err(VerifyError::MalformedHash);
257 }
258 let auth_value = parts[2].split(':').collect::<Vec<&str>>();
259 if auth_value.len() != 2 {
260 return Err(VerifyError::MalformedHash);
261 }
262
263 let iterations = auth_info[0]
264 .parse::<u32>()
265 .map_err(|_| VerifyError::MalformedHash)?;
266
267 let salt = BASE64_STANDARD
268 .decode(auth_info[1])
269 .map_err(|_| VerifyError::MalformedHash)?;
270
271 let salt = salt.try_into().map_err(|_| VerifyError::MalformedHash)?;
272
273 Ok(HashOpts {
274 iterations: NonZeroU32::new(iterations).ok_or(VerifyError::MalformedHash)?,
275 salt,
276 })
277}
278
279struct ScramSha256Hash {
281 iterations: NonZeroU32,
283 salt: [u8; 32],
285 server_key: [u8; SHA256_OUTPUT_LEN],
287 stored_key: [u8; SHA256_OUTPUT_LEN],
289}
290
291impl Drop for ScramSha256Hash {
292 fn drop(&mut self) {
293 self.salt.zeroize();
294 self.server_key.zeroize();
295 self.stored_key.zeroize();
296 }
297}
298
299impl Display for ScramSha256Hash {
300 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
301 write!(
302 f,
303 "SCRAM-SHA-256${}:{}${}:{}",
304 self.iterations,
305 BASE64_STANDARD.encode(&self.salt),
306 BASE64_STANDARD.encode(&self.stored_key),
307 BASE64_STANDARD.encode(&self.server_key)
308 )
309 }
310}
311
312fn scram256_hash_inner(hashed_password: PasswordHash) -> ScramSha256Hash {
313 let signing_key = hmac::Key::new(hmac::HMAC_SHA256, &hashed_password.hash);
314 let client_key_tag = hmac::sign(&signing_key, b"Client Key");
315 let client_key = Zeroizing::new(client_key_tag.as_ref().to_vec());
316 let stored_key_digest = digest::digest(&digest::SHA256, &client_key);
317 let mut stored_key = [0u8; SHA256_OUTPUT_LEN];
318 stored_key.copy_from_slice(stored_key_digest.as_ref());
319
320 let server_key_tag = hmac::sign(&signing_key, b"Server Key");
321 let mut server_key = Zeroizing::new([0u8; SHA256_OUTPUT_LEN]);
322 server_key.copy_from_slice(server_key_tag.as_ref());
323
324 ScramSha256Hash {
325 iterations: hashed_password.iterations,
326 salt: hashed_password.salt,
327 server_key: *server_key,
328 stored_key,
329 }
330}
331
332fn hash_password_inner(
333 opts: &HashOpts,
334 password: &[u8],
335) -> Result<[u8; SHA256_OUTPUT_LEN], HashError> {
336 let mut salted_password = Zeroizing::new([0u8; SHA256_OUTPUT_LEN]);
337 aws_lc_rs::pbkdf2::derive(
338 aws_lc_rs::pbkdf2::PBKDF2_HMAC_SHA256,
339 opts.iterations,
340 &opts.salt,
341 password,
342 &mut *salted_password,
343 );
344 Ok(*salted_password)
345}
346
#[cfg(test)]
mod tests {
    use itertools::Itertools;

    use super::*;

    // Low iteration count, presumably to keep the PBKDF2 work in these tests cheap.
    const DEFAULT_ITERATIONS: NonZeroU32 = NonZeroU32::new(60).expect("Trust me on this");

    // `hash_password` echoes the requested iteration count and returns correctly sized
    // salt and hash. NOTE(review): the Miri-ignored tests presumably skip Miri because
    // hashing calls into aws-lc-rs (foreign code); confirm.
    #[mz_ore::test]
    #[cfg_attr(miri, ignore)]
    fn test_hash_password() {
        let password = "password".to_string();
        let iterations = NonZeroU32::new(100).expect("Trust me on this");
        let hashed_password =
            hash_password(&password.into(), &iterations).expect("Failed to hash password");
        assert_eq!(hashed_password.iterations, iterations);
        assert_eq!(hashed_password.salt.len(), DEFAULT_SALT_SIZE);
        assert_eq!(hashed_password.hash.len(), SHA256_OUTPUT_LEN);
    }

    // A verifier produced by `scram256_hash` accepts the original password and rejects a
    // different one.
    #[mz_ore::test]
    #[cfg_attr(miri, ignore)]
    fn test_scram256_hash() {
        let password = "password".into();
        let scram_hash =
            scram256_hash(&password, &DEFAULT_ITERATIONS).expect("Failed to hash password");

        let res = scram256_verify(&password, &scram_hash);
        assert!(res.is_ok());
        let res = scram256_verify(&"wrong_password".into(), &scram_hash);
        assert!(res.is_err());
    }

    // `scram256_parse_opts` extracts the iteration count and salt. It only checks that the
    // key section has two `:`-separated fields, so placeholder key strings suffice here.
    #[mz_ore::test]
    fn test_scram256_parse_opts() {
        let salt = "9bkIQQjQ7f1OwPsXZGC/YfIkbZsOMDXK0cxxvPBaSfM=";
        let hashed_password = format!("SCRAM-SHA-256$600000:{}$client-key:server-key", salt);
        let opts = scram256_parse_opts(&hashed_password);

        assert!(opts.is_ok());
        let opts = opts.unwrap();
        assert_eq!(
            opts.iterations,
            NonZeroU32::new(600_000).expect("known valid")
        );
        assert_eq!(opts.salt.len(), DEFAULT_SALT_SIZE);
        let decoded_salt = BASE64_STANDARD.decode(salt).expect("Failed to decode salt");
        assert_eq!(opts.salt, decoded_salt.as_ref());
    }

    // The mock challenge is deterministic: the same username and nonce give equal opts.
    #[mz_ore::test]
    #[cfg_attr(miri, ignore)]
    fn test_mock_sasl_challenge() {
        let username = "alice";
        let mock = "cnonce";
        let opts1 = mock_sasl_challenge(username, mock, &DEFAULT_ITERATIONS);
        let opts2 = mock_sasl_challenge(username, mock, &DEFAULT_ITERATIONS);
        assert_eq!(opts1, opts2);
    }

    // End-to-end check of `sasl_verify`. Build a correct client proof the way a SCRAM client
    // would, then confirm the returned server signature is HMAC(ServerKey, AuthMessage).
    #[mz_ore::test]
    #[cfg_attr(miri, ignore)]
    fn test_sasl_verify_success() {
        let password: Password = "password".into();
        let hashed_password = scram256_hash(&password, &DEFAULT_ITERATIONS).expect("hash password");
        let auth_message = "n=user,r=clientnonce,s=somesalt";
        // Pull StoredKey and ServerKey out of the `<scheme>$<info>$<stored>:<server>` string.
        let parts: Vec<&str> = hashed_password.split('$').collect();
        assert_eq!(parts.len(), 3);
        let key_parts: Vec<&str> = parts[2].split(':').collect();
        assert_eq!(key_parts.len(), 2);
        let stored_key = BASE64_STANDARD
            .decode(key_parts[0])
            .expect("decode stored key");
        let server_key = BASE64_STANDARD
            .decode(key_parts[1])
            .expect("decode server key");

        let client_proof: Vec<u8> = {
            // Recompute SaltedPassword from the salt and iterations stored in the hash.
            let opts = scram256_parse_opts(&hashed_password).expect("parse opts");
            let salted_password = hash_password_with_opts(&opts, &password)
                .expect("hash password")
                .hash;
            // ClientKey = HMAC(SaltedPassword, "Client Key")
            let signing_key = hmac::Key::new(hmac::HMAC_SHA256, &salted_password);
            let client_key = hmac::sign(&signing_key, b"Client Key");
            let client_key = client_key.as_ref();
            // ClientSignature = HMAC(StoredKey, AuthMessage);
            // ClientProof = ClientKey XOR ClientSignature
            let client_signature =
                generate_signature(&stored_key, auth_message).expect("client signature");
            client_key
                .iter()
                .zip_eq(client_signature.iter())
                .map(|(c, s)| c ^ s)
                .collect::<Vec<u8>>()
        };

        let client_proof_b64 = BASE64_STANDARD.encode(&client_proof);

        let verifier = sasl_verify(&hashed_password, &client_proof_b64, auth_message)
            .expect("sasl_verify should succeed");

        // The returned server signature must be HMAC(ServerKey, AuthMessage), base64-encoded.
        let expected_verifier = BASE64_STANDARD
            .encode(&generate_signature(&server_key, auth_message).expect("server verifier"));
        assert_eq!(verifier, expected_verifier);
    }

    // A correctly sized but wrong proof (all zeros) is rejected as an invalid password.
    #[mz_ore::test]
    #[cfg_attr(miri, ignore)]
    fn test_sasl_verify_invalid_proof() {
        let password: Password = "password".into();
        let hashed_password = scram256_hash(&password, &DEFAULT_ITERATIONS).expect("hash password");
        let auth_message = "n=user,r=clientnonce,s=somesalt";
        let bad_proof = BASE64_STANDARD.encode([0u8; 32]);
        // 32 bytes matches the HMAC-SHA-256 signature length, so this gets past the length
        // check and fails on the stored-key comparison.
        let res = sasl_verify(&hashed_password, &bad_proof, auth_message);
        assert!(matches!(res, Err(VerifyError::InvalidPassword)));
    }

    // A stored hash without three `$`-separated sections is reported as malformed.
    #[mz_ore::test]
    fn test_sasl_verify_malformed_hash() {
        let malformed_hash = "NOT-SCRAM$bad";
        let auth_message = "n=user,r=clientnonce,s=somesalt";
        let bad_proof = BASE64_STANDARD.encode([0u8; 32]);
        let res = sasl_verify(malformed_hash, &bad_proof, auth_message);
        assert!(matches!(res, Err(VerifyError::MalformedHash)));
    }

    // Proofs of the wrong length must be rejected cleanly. They must not panic in the
    // length-asserting XOR (`zip_eq`) inside `sasl_verify`.
    #[mz_ore::test]
    #[cfg_attr(miri, ignore)]
    fn test_sasl_verify_truncated_proof_no_panic() {
        let password: Password = "password".into();
        let hashed_password = scram256_hash(&password, &DEFAULT_ITERATIONS).expect("hash password");
        let auth_message = "n=user,r=clientnonce,s=somesalt";

        // Shorter than the 32-byte HMAC-SHA-256 signature.
        let truncated_proof = BASE64_STANDARD.encode([0u8; 16]);
        let res = sasl_verify(&hashed_password, &truncated_proof, auth_message);
        assert!(
            matches!(res, Err(VerifyError::InvalidPassword)),
            "truncated proof should return InvalidPassword, not panic"
        );

        // Longer than the 32-byte HMAC-SHA-256 signature.
        let oversized_proof = BASE64_STANDARD.encode([0u8; 64]);
        let res = sasl_verify(&hashed_password, &oversized_proof, auth_message);
        assert!(
            matches!(res, Err(VerifyError::InvalidPassword)),
            "oversized proof should return InvalidPassword, not panic"
        );

        // Zero-length proof (encodes to the empty string).
        let empty_proof = BASE64_STANDARD.encode([0u8; 0]);
        let res = sasl_verify(&hashed_password, &empty_proof, auth_message);
        assert!(
            matches!(res, Err(VerifyError::InvalidPassword)),
            "empty proof should return InvalidPassword, not panic"
        );
    }
}