1#![allow(rustdoc::invalid_html_tags)]
12
13use std::fmt::Display;
14use std::num::NonZeroU32;
15
16use base64::prelude::*;
17use itertools::Itertools;
18use mz_ore::secure::{Zeroize, Zeroizing};
19
20use crate::password::Password;
21
/// Length, in bytes, of the salt carried by [`HashOpts`] and [`PasswordHash`].
const DEFAULT_SALT_SIZE: usize = 32;
24
/// Length, in bytes, of a SHA-256 digest (256 bits).
const SHA256_OUTPUT_LEN: usize = 256 / 8;
26
27#[derive(Debug, PartialEq)]
29pub struct HashOpts {
30 pub iterations: NonZeroU32,
32 pub salt: [u8; DEFAULT_SALT_SIZE],
36}
37
38impl Drop for HashOpts {
39 fn drop(&mut self) {
40 self.salt.zeroize();
41 }
42}
43
44pub struct PasswordHash {
45 pub salt: [u8; DEFAULT_SALT_SIZE],
47 pub iterations: NonZeroU32,
49 pub hash: [u8; SHA256_OUTPUT_LEN],
52}
53
54impl Drop for PasswordHash {
55 fn drop(&mut self) {
56 self.salt.zeroize();
57 self.hash.zeroize();
58 }
59}
60
61#[derive(Debug)]
62pub enum VerifyError {
63 MalformedHash,
64 InvalidPassword,
65 Hash(HashError),
66}
67
68#[derive(Debug)]
69pub enum HashError {
70 Openssl(openssl::error::ErrorStack),
71}
72
73impl Display for HashError {
74 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75 match self {
76 HashError::Openssl(e) => write!(f, "OpenSSL error: {}", e),
77 }
78 }
79}
80
81pub fn hash_password(
84 password: &Password,
85 iterations: &NonZeroU32,
86) -> Result<PasswordHash, HashError> {
87 let mut salt = Zeroizing::new([0u8; DEFAULT_SALT_SIZE]);
88 openssl::rand::rand_bytes(&mut *salt).map_err(HashError::Openssl)?;
89
90 let hash = hash_password_inner(
91 &HashOpts {
92 iterations: iterations.to_owned(),
93 salt: *salt,
94 },
95 password.as_bytes(),
96 )?;
97
98 Ok(PasswordHash {
99 salt: *salt,
100 iterations: iterations.to_owned(),
101 hash,
102 })
103}
104
105pub fn generate_nonce(client_nonce: &str) -> Result<String, HashError> {
106 let mut nonce = Zeroizing::new([0u8; 24]);
107 openssl::rand::rand_bytes(&mut *nonce).map_err(HashError::Openssl)?;
108 let nonce = BASE64_STANDARD.encode(&*nonce);
109 let new_nonce = format!("{}{}", client_nonce, nonce);
110 Ok(new_nonce)
111}
112
113pub fn hash_password_with_opts(
116 opts: &HashOpts,
117 password: &Password,
118) -> Result<PasswordHash, HashError> {
119 let hash = hash_password_inner(opts, password.as_bytes())?;
120
121 Ok(PasswordHash {
122 salt: opts.salt,
123 iterations: opts.iterations,
124 hash,
125 })
126}
127
128pub fn scram256_hash(password: &Password, iterations: &NonZeroU32) -> Result<String, HashError> {
132 let hashed_password = hash_password(password, iterations)?;
133 Ok(scram256_hash_inner(hashed_password).to_string())
134}
135
136fn constant_time_compare(a: &[u8], b: &[u8]) -> bool {
137 if a.len() != b.len() {
138 return false;
139 }
140 openssl::memcmp::eq(a, b)
141}
142
143pub fn scram256_verify(password: &Password, hashed_password: &str) -> Result<(), VerifyError> {
145 let opts = scram256_parse_opts(hashed_password)?;
146 let hashed = hash_password_with_opts(&opts, password).map_err(VerifyError::Hash)?;
147 let scram = scram256_hash_inner(hashed);
148 if constant_time_compare(hashed_password.as_bytes(), scram.to_string().as_bytes()) {
149 Ok(())
150 } else {
151 Err(VerifyError::InvalidPassword)
152 }
153}
154
155pub fn sasl_verify(
156 hashed_password: &str,
157 proof: &str,
158 auth_message: &str,
159) -> Result<String, VerifyError> {
160 let parts: Vec<&str> = hashed_password.split('$').collect();
162 if parts.len() != 3 {
163 return Err(VerifyError::MalformedHash);
164 }
165 let auth_info = parts[1].split(':').collect::<Vec<&str>>();
166 if auth_info.len() != 2 {
167 return Err(VerifyError::MalformedHash);
168 }
169 let auth_value = parts[2].split(':').collect::<Vec<&str>>();
170 if auth_value.len() != 2 {
171 return Err(VerifyError::MalformedHash);
172 }
173
174 let stored_key = Zeroizing::new(
175 BASE64_STANDARD
176 .decode(auth_value[0])
177 .map_err(|_| VerifyError::MalformedHash)?,
178 );
179 let server_key = Zeroizing::new(
180 BASE64_STANDARD
181 .decode(auth_value[1])
182 .map_err(|_| VerifyError::MalformedHash)?,
183 );
184
185 let client_signature = Zeroizing::new(generate_signature(&stored_key, auth_message)?);
187
188 let provided_client_proof = Zeroizing::new(
190 BASE64_STANDARD
191 .decode(proof)
192 .map_err(|_| VerifyError::InvalidPassword)?,
193 );
194
195 if provided_client_proof.len() != client_signature.len() {
196 return Err(VerifyError::InvalidPassword);
197 }
198
199 let client_key: Zeroizing<Vec<u8>> = Zeroizing::new(
201 provided_client_proof
202 .iter()
203 .zip_eq(client_signature.iter())
204 .map(|(p, s)| p ^ s)
205 .collect(),
206 );
207
208 if !constant_time_compare(&openssl::sha::sha256(&client_key), &stored_key) {
209 return Err(VerifyError::InvalidPassword);
210 }
211
212 let verifier = Zeroizing::new(generate_signature(&server_key, auth_message)?);
214 Ok(BASE64_STANDARD.encode(&*verifier))
215}
216
217fn generate_signature(key: &[u8], message: &str) -> Result<Zeroizing<Vec<u8>>, VerifyError> {
218 let signing_key =
219 openssl::pkey::PKey::hmac(key).map_err(|e| VerifyError::Hash(HashError::Openssl(e)))?;
220 let mut signer =
221 openssl::sign::Signer::new(openssl::hash::MessageDigest::sha256(), &signing_key)
222 .map_err(|e| VerifyError::Hash(HashError::Openssl(e)))?;
223 signer
224 .update(message.as_bytes())
225 .map_err(|e| VerifyError::Hash(HashError::Openssl(e)))?;
226 let signature = signer
227 .sign_to_vec()
228 .map_err(|e| VerifyError::Hash(HashError::Openssl(e)))?;
229 Ok(Zeroizing::new(signature))
230}
231
232pub fn mock_sasl_challenge(username: &str, mock_nonce: &str, iterations: &NonZeroU32) -> HashOpts {
236 let mut buf = Vec::with_capacity(username.len() + mock_nonce.len());
237 buf.extend_from_slice(username.as_bytes());
238 buf.extend_from_slice(mock_nonce.as_bytes());
239 let digest = openssl::sha::sha256(&buf);
240
241 HashOpts {
242 iterations: iterations.to_owned(),
243 salt: digest,
244 }
245}
246
247pub fn scram256_parse_opts(hashed_password: &str) -> Result<HashOpts, VerifyError> {
249 let parts: Vec<&str> = hashed_password.split('$').collect();
250 if parts.len() != 3 {
251 return Err(VerifyError::MalformedHash);
252 }
253 let scheme = parts[0];
254 if scheme != "SCRAM-SHA-256" {
255 return Err(VerifyError::MalformedHash);
256 }
257 let auth_info = parts[1].split(':').collect::<Vec<&str>>();
258 if auth_info.len() != 2 {
259 return Err(VerifyError::MalformedHash);
260 }
261 let auth_value = parts[2].split(':').collect::<Vec<&str>>();
262 if auth_value.len() != 2 {
263 return Err(VerifyError::MalformedHash);
264 }
265
266 let iterations = auth_info[0]
267 .parse::<u32>()
268 .map_err(|_| VerifyError::MalformedHash)?;
269
270 let salt = BASE64_STANDARD
271 .decode(auth_info[1])
272 .map_err(|_| VerifyError::MalformedHash)?;
273
274 let salt = salt.try_into().map_err(|_| VerifyError::MalformedHash)?;
275
276 Ok(HashOpts {
277 iterations: NonZeroU32::new(iterations).ok_or(VerifyError::MalformedHash)?,
278 salt,
279 })
280}
281
282struct ScramSha256Hash {
284 iterations: NonZeroU32,
286 salt: [u8; 32],
288 server_key: [u8; SHA256_OUTPUT_LEN],
290 stored_key: [u8; SHA256_OUTPUT_LEN],
292}
293
294impl Drop for ScramSha256Hash {
295 fn drop(&mut self) {
296 self.salt.zeroize();
297 self.server_key.zeroize();
298 self.stored_key.zeroize();
299 }
300}
301
302impl Display for ScramSha256Hash {
303 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
304 write!(
305 f,
306 "SCRAM-SHA-256${}:{}${}:{}",
307 self.iterations,
308 BASE64_STANDARD.encode(&self.salt),
309 BASE64_STANDARD.encode(&self.stored_key),
310 BASE64_STANDARD.encode(&self.server_key)
311 )
312 }
313}
314
315fn scram256_hash_inner(hashed_password: PasswordHash) -> ScramSha256Hash {
316 let signing_key = openssl::pkey::PKey::hmac(&hashed_password.hash).unwrap();
317 let mut signer =
318 openssl::sign::Signer::new(openssl::hash::MessageDigest::sha256(), &signing_key).unwrap();
319 signer.update(b"Client Key").unwrap();
320 let client_key = Zeroizing::new(signer.sign_to_vec().unwrap());
321 let stored_key = openssl::sha::sha256(&client_key);
322 let mut signer =
323 openssl::sign::Signer::new(openssl::hash::MessageDigest::sha256(), &signing_key).unwrap();
324 signer.update(b"Server Key").unwrap();
325 let mut server_key = Zeroizing::new([0u8; SHA256_OUTPUT_LEN]);
326 signer.sign(server_key.as_mut()).unwrap();
327
328 ScramSha256Hash {
329 iterations: hashed_password.iterations,
330 salt: hashed_password.salt,
331 server_key: *server_key,
332 stored_key,
333 }
334}
335
336fn hash_password_inner(
337 opts: &HashOpts,
338 password: &[u8],
339) -> Result<[u8; SHA256_OUTPUT_LEN], HashError> {
340 let mut salted_password = Zeroizing::new([0u8; SHA256_OUTPUT_LEN]);
341 openssl::pkcs5::pbkdf2_hmac(
342 password,
343 &opts.salt,
344 opts.iterations.get().try_into().unwrap(),
345 openssl::hash::MessageDigest::sha256(),
346 &mut *salted_password,
347 )
348 .map_err(HashError::Openssl)?;
349 Ok(*salted_password)
350}
351
#[cfg(test)]
mod tests {
    use itertools::Itertools;

    use super::*;

    /// Iteration count shared by most tests.
    const DEFAULT_ITERATIONS: NonZeroU32 = NonZeroU32::new(60).expect("Trust me on this");

    /// Auth message shared by the SASL tests.
    const AUTH_MESSAGE: &str = "n=user,r=clientnonce,s=somesalt";

    /// Returns the fixed test password together with a freshly generated SCRAM hash of it.
    fn password_and_hash() -> (Password, String) {
        let password: Password = "password".into();
        let hashed = scram256_hash(&password, &DEFAULT_ITERATIONS).expect("hash password");
        (password, hashed)
    }

    /// Returns the base64 encoding of `len` zero bytes, used as a proof that cannot match.
    fn zero_proof(len: usize) -> String {
        BASE64_STANDARD.encode(vec![0u8; len])
    }

    /// Computes the proof a correct client would send: ClientKey XOR ClientSignature.
    fn client_proof(password: &Password, hashed_password: &str, stored_key: &[u8]) -> Vec<u8> {
        let opts = scram256_parse_opts(hashed_password).expect("parse opts");
        let salted_password = hash_password_with_opts(&opts, password)
            .expect("hash password")
            .hash;
        let hmac_key = openssl::pkey::PKey::hmac(&salted_password).expect("signing key");
        let mut signer =
            openssl::sign::Signer::new(openssl::hash::MessageDigest::sha256(), &hmac_key)
                .expect("signer");
        signer.update(b"Client Key").expect("update");
        let client_key = signer.sign_to_vec().expect("client key");
        let client_signature =
            generate_signature(stored_key, AUTH_MESSAGE).expect("client signature");
        client_key
            .iter()
            .zip_eq(client_signature.iter())
            .map(|(c, s)| c ^ s)
            .collect()
    }

    #[mz_ore::test]
    #[cfg_attr(miri, ignore)]
    fn test_hash_password() {
        let iterations = NonZeroU32::new(100).expect("Trust me on this");
        let password: Password = "password".to_string().into();
        let hashed = hash_password(&password, &iterations).expect("Failed to hash password");
        assert_eq!(hashed.iterations, iterations);
        assert_eq!(hashed.salt.len(), DEFAULT_SALT_SIZE);
        assert_eq!(hashed.hash.len(), SHA256_OUTPUT_LEN);
    }

    #[mz_ore::test]
    #[cfg_attr(miri, ignore)]
    fn test_scram256_hash() {
        let password: Password = "password".into();
        let scram_hash =
            scram256_hash(&password, &DEFAULT_ITERATIONS).expect("Failed to hash password");

        // The right password verifies; a different one does not.
        assert!(scram256_verify(&password, &scram_hash).is_ok());
        assert!(scram256_verify(&"wrong_password".into(), &scram_hash).is_err());
    }

    #[mz_ore::test]
    fn test_scram256_parse_opts() {
        let salt = "9bkIQQjQ7f1OwPsXZGC/YfIkbZsOMDXK0cxxvPBaSfM=";
        let hashed_password = format!("SCRAM-SHA-256$600000:{}$client-key:server-key", salt);

        let parsed = scram256_parse_opts(&hashed_password);
        assert!(parsed.is_ok());
        let opts = parsed.unwrap();

        assert_eq!(
            opts.iterations,
            NonZeroU32::new(600_000).expect("known valid")
        );
        assert_eq!(opts.salt.len(), DEFAULT_SALT_SIZE);
        let decoded_salt = BASE64_STANDARD.decode(salt).expect("Failed to decode salt");
        assert_eq!(opts.salt, decoded_salt.as_ref());
    }

    #[mz_ore::test]
    #[cfg_attr(miri, ignore)]
    fn test_mock_sasl_challenge() {
        // Identical inputs must produce identical options.
        let first = mock_sasl_challenge("alice", "cnonce", &DEFAULT_ITERATIONS);
        let second = mock_sasl_challenge("alice", "cnonce", &DEFAULT_ITERATIONS);
        assert_eq!(first, second);
    }

    #[mz_ore::test]
    #[cfg_attr(miri, ignore)]
    fn test_sasl_verify_success() {
        let (password, hashed_password) = password_and_hash();

        // Pull the stored and server keys out of the rendered hash string.
        let sections: Vec<&str> = hashed_password.split('$').collect();
        assert_eq!(sections.len(), 3);
        let encoded_keys: Vec<&str> = sections[2].split(':').collect();
        assert_eq!(encoded_keys.len(), 2);
        let stored_key = BASE64_STANDARD
            .decode(encoded_keys[0])
            .expect("decode stored key");
        let server_key = BASE64_STANDARD
            .decode(encoded_keys[1])
            .expect("decode server key");

        let proof = client_proof(&password, &hashed_password, &stored_key);
        let proof_b64 = BASE64_STANDARD.encode(&proof);

        let verifier = sasl_verify(&hashed_password, &proof_b64, AUTH_MESSAGE)
            .expect("sasl_verify should succeed");

        let expected = generate_signature(&server_key, AUTH_MESSAGE).expect("server verifier");
        assert_eq!(verifier, BASE64_STANDARD.encode(&expected));
    }

    #[mz_ore::test]
    #[cfg_attr(miri, ignore)]
    fn test_sasl_verify_invalid_proof() {
        let (_password, hashed_password) = password_and_hash();
        let res = sasl_verify(&hashed_password, &zero_proof(32), AUTH_MESSAGE);
        assert!(matches!(res, Err(VerifyError::InvalidPassword)));
    }

    #[mz_ore::test]
    fn test_sasl_verify_malformed_hash() {
        let malformed_hash = "NOT-SCRAM$bad";
        let res = sasl_verify(malformed_hash, &zero_proof(32), AUTH_MESSAGE);
        assert!(matches!(res, Err(VerifyError::MalformedHash)));
    }

    #[mz_ore::test]
    #[cfg_attr(miri, ignore)]
    fn test_sasl_verify_truncated_proof_no_panic() {
        let (_password, hashed_password) = password_and_hash();

        // Each case is a proof whose decoded length differs from a SHA-256 signature.
        let cases = [
            (
                16,
                "truncated proof should return InvalidPassword, not panic",
            ),
            (
                64,
                "oversized proof should return InvalidPassword, not panic",
            ),
            (0, "empty proof should return InvalidPassword, not panic"),
        ];
        for (len, failure_message) in cases {
            let res = sasl_verify(&hashed_password, &zero_proof(len), AUTH_MESSAGE);
            assert!(
                matches!(res, Err(VerifyError::InvalidPassword)),
                "{}",
                failure_message
            );
        }
    }
}