Skip to main content

mz_auth/
hash.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
// Rustdoc misreads placeholders such as `<iterations>` in doc comments as HTML tags, so we disable that lint
11#![allow(rustdoc::invalid_html_tags)]
12
13use std::fmt::Display;
14use std::num::NonZeroU32;
15
16use aws_lc_rs::constant_time::verify_slices_are_equal;
17use aws_lc_rs::digest;
18use aws_lc_rs::hmac;
19use aws_lc_rs::rand::{SecureRandom, SystemRandom};
20use base64::prelude::*;
21use itertools::Itertools;
22use mz_ore::secure::{Zeroize, Zeroizing};
23
24use crate::password::Password;
25
/// Size in bytes of the salts this module generates and accepts when parsing
/// stored hashes. Not currently configurable.
const DEFAULT_SALT_SIZE: usize = 32;

/// Size in bytes of a SHA-256 digest (256 bits), which is also the length of
/// the PBKDF2 output and of the SCRAM keys derived from it.
const SHA256_OUTPUT_LEN: usize = 256 / 8;
30
31/// The options for hashing a password
32#[derive(Debug, PartialEq)]
33pub struct HashOpts {
34    /// The number of iterations to use for PBKDF2
35    pub iterations: NonZeroU32,
36    /// The salt to use for PBKDF2. It is up to the caller to
37    /// ensure that however the salt is generated, it is cryptographically
38    /// secure.
39    pub salt: [u8; DEFAULT_SALT_SIZE],
40}
41
42impl Drop for HashOpts {
43    fn drop(&mut self) {
44        self.salt.zeroize();
45    }
46}
47
48pub struct PasswordHash {
49    /// The salt used for hashing
50    pub salt: [u8; DEFAULT_SALT_SIZE],
51    /// The number of iterations used for hashing
52    pub iterations: NonZeroU32,
53    /// The hash of the password.
54    /// This is the result of PBKDF2 with SHA256
55    pub hash: [u8; SHA256_OUTPUT_LEN],
56}
57
58impl Drop for PasswordHash {
59    fn drop(&mut self) {
60        self.salt.zeroize();
61        self.hash.zeroize();
62    }
63}
64
65#[derive(Debug)]
66pub enum VerifyError {
67    MalformedHash,
68    InvalidPassword,
69    Hash(HashError),
70}
71
72#[derive(Debug)]
73pub enum HashError {
74    Crypto(aws_lc_rs::error::Unspecified),
75}
76
77impl Display for HashError {
78    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79        match self {
80            HashError::Crypto(e) => write!(f, "crypto error: {}", e),
81        }
82    }
83}
84
85/// Hashes a password using PBKDF2 with SHA256
86/// and a random salt.
87pub fn hash_password(
88    password: &Password,
89    iterations: &NonZeroU32,
90) -> Result<PasswordHash, HashError> {
91    let rng = SystemRandom::new();
92    let mut salt = Zeroizing::new([0u8; DEFAULT_SALT_SIZE]);
93    rng.fill(&mut *salt).map_err(HashError::Crypto)?;
94
95    let hash = hash_password_inner(
96        &HashOpts {
97            iterations: iterations.to_owned(),
98            salt: *salt,
99        },
100        password.as_bytes(),
101    )?;
102
103    Ok(PasswordHash {
104        salt: *salt,
105        iterations: iterations.to_owned(),
106        hash,
107    })
108}
109
110pub fn generate_nonce(client_nonce: &str) -> Result<String, HashError> {
111    let rng = SystemRandom::new();
112    let mut nonce = Zeroizing::new([0u8; 24]);
113    rng.fill(&mut *nonce).map_err(HashError::Crypto)?;
114    let nonce = BASE64_STANDARD.encode(&*nonce);
115    let new_nonce = format!("{}{}", client_nonce, nonce);
116    Ok(new_nonce)
117}
118
119/// Hashes a password using PBKDF2 with SHA256
120/// and the given options.
121pub fn hash_password_with_opts(
122    opts: &HashOpts,
123    password: &Password,
124) -> Result<PasswordHash, HashError> {
125    let hash = hash_password_inner(opts, password.as_bytes())?;
126
127    Ok(PasswordHash {
128        salt: opts.salt,
129        iterations: opts.iterations,
130        hash,
131    })
132}
133
134/// Hashes a password using PBKDF2 with SHA256,
135/// and returns it in the SCRAM-SHA-256 format.
136/// The format is SCRAM-SHA-256$<iterations>:<salt>$<stored_key>:<server_key>
137pub fn scram256_hash(password: &Password, iterations: &NonZeroU32) -> Result<String, HashError> {
138    let hashed_password = hash_password(password, iterations)?;
139    Ok(scram256_hash_inner(hashed_password).to_string())
140}
141
142fn constant_time_compare(a: &[u8], b: &[u8]) -> bool {
143    verify_slices_are_equal(a, b).is_ok()
144}
145
146/// Verifies a password against a SCRAM-SHA-256 hash.
147pub fn scram256_verify(password: &Password, hashed_password: &str) -> Result<(), VerifyError> {
148    let opts = scram256_parse_opts(hashed_password)?;
149    let hashed = hash_password_with_opts(&opts, password).map_err(VerifyError::Hash)?;
150    let scram = scram256_hash_inner(hashed);
151    if constant_time_compare(hashed_password.as_bytes(), scram.to_string().as_bytes()) {
152        Ok(())
153    } else {
154        Err(VerifyError::InvalidPassword)
155    }
156}
157
158pub fn sasl_verify(
159    hashed_password: &str,
160    proof: &str,
161    auth_message: &str,
162) -> Result<String, VerifyError> {
163    // Parse SCRAM hash: SCRAM-SHA-256$<iterations>:<salt>$<stored_key>:<server_key>
164    let parts: Vec<&str> = hashed_password.split('$').collect();
165    if parts.len() != 3 {
166        return Err(VerifyError::MalformedHash);
167    }
168    let auth_info = parts[1].split(':').collect::<Vec<&str>>();
169    if auth_info.len() != 2 {
170        return Err(VerifyError::MalformedHash);
171    }
172    let auth_value = parts[2].split(':').collect::<Vec<&str>>();
173    if auth_value.len() != 2 {
174        return Err(VerifyError::MalformedHash);
175    }
176
177    let stored_key = Zeroizing::new(
178        BASE64_STANDARD
179            .decode(auth_value[0])
180            .map_err(|_| VerifyError::MalformedHash)?,
181    );
182    let server_key = Zeroizing::new(
183        BASE64_STANDARD
184            .decode(auth_value[1])
185            .map_err(|_| VerifyError::MalformedHash)?,
186    );
187
188    // Compute client signature: HMAC(stored_key, auth_message)
189    let client_signature = Zeroizing::new(generate_signature(&stored_key, auth_message)?);
190
191    // Decode provided proof
192    let provided_client_proof = Zeroizing::new(
193        BASE64_STANDARD
194            .decode(proof)
195            .map_err(|_| VerifyError::InvalidPassword)?,
196    );
197
198    if provided_client_proof.len() != client_signature.len() {
199        return Err(VerifyError::InvalidPassword);
200    }
201
202    // Recover client_key = proof XOR client_signature
203    let client_key: Zeroizing<Vec<u8>> = Zeroizing::new(
204        provided_client_proof
205            .iter()
206            .zip_eq(client_signature.iter())
207            .map(|(p, s)| p ^ s)
208            .collect(),
209    );
210
211    let computed_stored_key = digest::digest(&digest::SHA256, &client_key);
212    if !constant_time_compare(computed_stored_key.as_ref(), &stored_key) {
213        return Err(VerifyError::InvalidPassword);
214    }
215
216    // Compute server verifier: HMAC(server_key, auth_message)
217    let verifier = Zeroizing::new(generate_signature(&server_key, auth_message)?);
218    Ok(BASE64_STANDARD.encode(&*verifier))
219}
220
221fn generate_signature(key: &[u8], message: &str) -> Result<Zeroizing<Vec<u8>>, VerifyError> {
222    let signing_key = hmac::Key::new(hmac::HMAC_SHA256, key);
223    let tag = hmac::sign(&signing_key, message.as_bytes());
224    Ok(Zeroizing::new(tag.as_ref().to_vec()))
225}
226
227// Generate a mock challenge based on the username and client nonce
228// We do this so that we can present a deterministic challenge even for
229// nonexistent users, to avoid user enumeration attacks.
230pub fn mock_sasl_challenge(username: &str, mock_nonce: &str, iterations: &NonZeroU32) -> HashOpts {
231    let mut buf = Vec::with_capacity(username.len() + mock_nonce.len());
232    buf.extend_from_slice(username.as_bytes());
233    buf.extend_from_slice(mock_nonce.as_bytes());
234    let hash = digest::digest(&digest::SHA256, &buf);
235    let mut salt = [0u8; DEFAULT_SALT_SIZE];
236    salt.copy_from_slice(hash.as_ref());
237
238    HashOpts {
239        iterations: iterations.to_owned(),
240        salt,
241    }
242}
243
244/// Parses a SCRAM-SHA-256 hash and returns the options used to create it.
245pub fn scram256_parse_opts(hashed_password: &str) -> Result<HashOpts, VerifyError> {
246    let parts: Vec<&str> = hashed_password.split('$').collect();
247    if parts.len() != 3 {
248        return Err(VerifyError::MalformedHash);
249    }
250    let scheme = parts[0];
251    if scheme != "SCRAM-SHA-256" {
252        return Err(VerifyError::MalformedHash);
253    }
254    let auth_info = parts[1].split(':').collect::<Vec<&str>>();
255    if auth_info.len() != 2 {
256        return Err(VerifyError::MalformedHash);
257    }
258    let auth_value = parts[2].split(':').collect::<Vec<&str>>();
259    if auth_value.len() != 2 {
260        return Err(VerifyError::MalformedHash);
261    }
262
263    let iterations = auth_info[0]
264        .parse::<u32>()
265        .map_err(|_| VerifyError::MalformedHash)?;
266
267    let salt = BASE64_STANDARD
268        .decode(auth_info[1])
269        .map_err(|_| VerifyError::MalformedHash)?;
270
271    let salt = salt.try_into().map_err(|_| VerifyError::MalformedHash)?;
272
273    Ok(HashOpts {
274        iterations: NonZeroU32::new(iterations).ok_or(VerifyError::MalformedHash)?,
275        salt,
276    })
277}
278
279/// The SCRAM-SHA-256 hash
280struct ScramSha256Hash {
281    /// The number of iterations used for hashing
282    iterations: NonZeroU32,
283    /// The salt used for hashing
284    salt: [u8; 32],
285    /// The server key
286    server_key: [u8; SHA256_OUTPUT_LEN],
287    /// The stored key
288    stored_key: [u8; SHA256_OUTPUT_LEN],
289}
290
291impl Drop for ScramSha256Hash {
292    fn drop(&mut self) {
293        self.salt.zeroize();
294        self.server_key.zeroize();
295        self.stored_key.zeroize();
296    }
297}
298
299impl Display for ScramSha256Hash {
300    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
301        write!(
302            f,
303            "SCRAM-SHA-256${}:{}${}:{}",
304            self.iterations,
305            BASE64_STANDARD.encode(&self.salt),
306            BASE64_STANDARD.encode(&self.stored_key),
307            BASE64_STANDARD.encode(&self.server_key)
308        )
309    }
310}
311
312fn scram256_hash_inner(hashed_password: PasswordHash) -> ScramSha256Hash {
313    let signing_key = hmac::Key::new(hmac::HMAC_SHA256, &hashed_password.hash);
314    let client_key_tag = hmac::sign(&signing_key, b"Client Key");
315    let client_key = Zeroizing::new(client_key_tag.as_ref().to_vec());
316    let stored_key_digest = digest::digest(&digest::SHA256, &client_key);
317    let mut stored_key = [0u8; SHA256_OUTPUT_LEN];
318    stored_key.copy_from_slice(stored_key_digest.as_ref());
319
320    let server_key_tag = hmac::sign(&signing_key, b"Server Key");
321    let mut server_key = Zeroizing::new([0u8; SHA256_OUTPUT_LEN]);
322    server_key.copy_from_slice(server_key_tag.as_ref());
323
324    ScramSha256Hash {
325        iterations: hashed_password.iterations,
326        salt: hashed_password.salt,
327        server_key: *server_key,
328        stored_key,
329    }
330}
331
332fn hash_password_inner(
333    opts: &HashOpts,
334    password: &[u8],
335) -> Result<[u8; SHA256_OUTPUT_LEN], HashError> {
336    let mut salted_password = Zeroizing::new([0u8; SHA256_OUTPUT_LEN]);
337    aws_lc_rs::pbkdf2::derive(
338        aws_lc_rs::pbkdf2::PBKDF2_HMAC_SHA256,
339        opts.iterations,
340        &opts.salt,
341        password,
342        &mut *salted_password,
343    );
344    Ok(*salted_password)
345}
346
#[cfg(test)]
mod tests {
    use itertools::Itertools;

    use super::*;

    const DEFAULT_ITERATIONS: NonZeroU32 = NonZeroU32::new(60).expect("Trust me on this");

    #[mz_ore::test]
    #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function on OS `linux`
    fn test_hash_password() {
        let password = "password".to_string();
        let iterations = NonZeroU32::new(100).expect("Trust me on this");
        let hashed_password =
            hash_password(&password.into(), &iterations).expect("Failed to hash password");
        assert_eq!(hashed_password.iterations, iterations);
        assert_eq!(hashed_password.salt.len(), DEFAULT_SALT_SIZE);
        assert_eq!(hashed_password.hash.len(), SHA256_OUTPUT_LEN);
    }

    #[mz_ore::test]
    #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function on OS `linux`
    fn test_hash_password_with_opts_is_deterministic() {
        let password: Password = "password".into();
        let opts = HashOpts {
            iterations: DEFAULT_ITERATIONS,
            salt: [1u8; DEFAULT_SALT_SIZE],
        };
        let first = hash_password_with_opts(&opts, &password).expect("hash password");
        let second = hash_password_with_opts(&opts, &password).expect("hash password");
        // Same password, salt, and iteration count must give the same hash.
        assert_eq!(first.hash, second.hash);
        assert_eq!(first.salt, opts.salt);
        assert_eq!(first.iterations, opts.iterations);

        // A different salt must give a different hash.
        let other_opts = HashOpts {
            iterations: DEFAULT_ITERATIONS,
            salt: [2u8; DEFAULT_SALT_SIZE],
        };
        let other = hash_password_with_opts(&other_opts, &password).expect("hash password");
        assert_ne!(first.hash, other.hash);
    }

    #[mz_ore::test]
    #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function on OS `linux`
    fn test_generate_nonce() {
        let client_nonce = "clientnonce";
        let nonce = generate_nonce(client_nonce).expect("generate nonce");
        assert!(nonce.starts_with(client_nonce));
        // 24 random bytes encode to exactly 32 base64 characters.
        assert_eq!(nonce.len(), client_nonce.len() + 32);
        // Two independently generated nonces should (overwhelmingly likely) differ.
        let other = generate_nonce(client_nonce).expect("generate nonce");
        assert_ne!(nonce, other);
    }

    #[mz_ore::test]
    #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function on OS `linux`
    fn test_scram256_hash() {
        let password = "password".into();
        let scram_hash =
            scram256_hash(&password, &DEFAULT_ITERATIONS).expect("Failed to hash password");

        let res = scram256_verify(&password, &scram_hash);
        assert!(res.is_ok());
        let res = scram256_verify(&"wrong_password".into(), &scram_hash);
        assert!(res.is_err());
    }

    #[mz_ore::test]
    fn test_scram256_parse_opts() {
        let salt = "9bkIQQjQ7f1OwPsXZGC/YfIkbZsOMDXK0cxxvPBaSfM=";
        let hashed_password = format!("SCRAM-SHA-256$600000:{}$client-key:server-key", salt);
        let opts = scram256_parse_opts(&hashed_password).expect("valid SCRAM hash should parse");

        assert_eq!(
            opts.iterations,
            NonZeroU32::new(600_000).expect("known valid")
        );
        assert_eq!(opts.salt.len(), DEFAULT_SALT_SIZE);
        let decoded_salt = BASE64_STANDARD.decode(salt).expect("Failed to decode salt");
        assert_eq!(opts.salt, decoded_salt.as_ref());
    }

    #[mz_ore::test]
    fn test_scram256_parse_opts_rejects_malformed() {
        let salt = BASE64_STANDARD.encode([7u8; DEFAULT_SALT_SIZE]);
        let short_salt = BASE64_STANDARD.encode([7u8; 16]);
        let malformed = [
            // Wrong scheme.
            format!("SCRAM-SHA-1$4096:{salt}$a:b"),
            // Zero iterations.
            format!("SCRAM-SHA-256$0:{salt}$a:b"),
            // Non-numeric iterations.
            format!("SCRAM-SHA-256$many:{salt}$a:b"),
            // Salt of the wrong length.
            format!("SCRAM-SHA-256$4096:{short_salt}$a:b"),
            // Salt that is not valid base64.
            "SCRAM-SHA-256$4096:!!!$a:b".to_string(),
            // Key section without a `:` separator.
            format!("SCRAM-SHA-256$4096:{salt}$a"),
            // Wrong number of `$`-separated sections.
            format!("SCRAM-SHA-256$4096:{salt}"),
        ];
        for hash in &malformed {
            assert!(
                matches!(scram256_parse_opts(hash), Err(VerifyError::MalformedHash)),
                "expected MalformedHash for {hash}"
            );
        }
    }

    #[mz_ore::test]
    #[cfg_attr(miri, ignore)]
    fn test_mock_sasl_challenge() {
        let username = "alice";
        let mock = "cnonce";
        let opts1 = mock_sasl_challenge(username, mock, &DEFAULT_ITERATIONS);
        let opts2 = mock_sasl_challenge(username, mock, &DEFAULT_ITERATIONS);
        assert_eq!(opts1, opts2);
        assert_eq!(opts1.iterations, DEFAULT_ITERATIONS);

        // Different usernames must not share a mock salt.
        let opts3 = mock_sasl_challenge("bob", mock, &DEFAULT_ITERATIONS);
        assert_ne!(opts1, opts3);
    }

    #[mz_ore::test]
    #[cfg_attr(miri, ignore)]
    fn test_sasl_verify_success() {
        let password: Password = "password".into();
        let hashed_password = scram256_hash(&password, &DEFAULT_ITERATIONS).expect("hash password");
        let auth_message = "n=user,r=clientnonce,s=somesalt"; // arbitrary auth message

        // Parse client_key and server_key from the SCRAM hash
        // Format: SCRAM-SHA-256$<iterations>:<salt>$<stored_key>:<server_key>
        let parts: Vec<&str> = hashed_password.split('$').collect();
        assert_eq!(parts.len(), 3);
        let key_parts: Vec<&str> = parts[2].split(':').collect();
        assert_eq!(key_parts.len(), 2);
        let stored_key = BASE64_STANDARD
            .decode(key_parts[0])
            .expect("decode stored key");
        let server_key = BASE64_STANDARD
            .decode(key_parts[1])
            .expect("decode server key");

        // Simulate client generating a proof
        let client_proof: Vec<u8> = {
            // client_key = HMAC(salted_password, "Client Key")
            let opts = scram256_parse_opts(&hashed_password).expect("parse opts");
            let salted_password = hash_password_with_opts(&opts, &password)
                .expect("hash password")
                .hash;
            let signing_key = hmac::Key::new(hmac::HMAC_SHA256, &salted_password);
            let client_key = hmac::sign(&signing_key, b"Client Key");
            let client_key = client_key.as_ref();
            // client_proof = client_key XOR client_signature
            let client_signature =
                generate_signature(&stored_key, auth_message).expect("client signature");
            client_key
                .iter()
                .zip_eq(client_signature.iter())
                .map(|(c, s)| c ^ s)
                .collect::<Vec<u8>>()
        };

        let client_proof_b64 = BASE64_STANDARD.encode(&client_proof);

        let verifier = sasl_verify(&hashed_password, &client_proof_b64, auth_message)
            .expect("sasl_verify should succeed");

        // Expected verifier: HMAC(server_key, auth_message)
        let expected_verifier = BASE64_STANDARD
            .encode(&generate_signature(&server_key, auth_message).expect("server verifier"));
        assert_eq!(verifier, expected_verifier);
    }

    #[mz_ore::test]
    #[cfg_attr(miri, ignore)]
    fn test_sasl_verify_invalid_proof() {
        let password: Password = "password".into();
        let hashed_password = scram256_hash(&password, &DEFAULT_ITERATIONS).expect("hash password");
        let auth_message = "n=user,r=clientnonce,s=somesalt";
        // A well-formed proof of the correct length (32 zero bytes) that does
        // not correspond to the stored key.
        let bad_proof = BASE64_STANDARD.encode([0u8; 32]);
        let res = sasl_verify(&hashed_password, &bad_proof, auth_message);
        assert!(matches!(res, Err(VerifyError::InvalidPassword)));
    }

    #[mz_ore::test]
    fn test_sasl_verify_malformed_hash() {
        let malformed_hash = "NOT-SCRAM$bad"; // clearly malformed (wrong parts count)
        let auth_message = "n=user,r=clientnonce,s=somesalt";
        let bad_proof = BASE64_STANDARD.encode([0u8; 32]);
        let res = sasl_verify(malformed_hash, &bad_proof, auth_message);
        assert!(matches!(res, Err(VerifyError::MalformedHash)));
    }

    #[mz_ore::test]
    #[cfg_attr(miri, ignore)]
    fn test_sasl_verify_truncated_proof_no_panic() {
        // A truncated client proof (not 32 bytes) should return InvalidPassword, not panic
        let password: Password = "password".into();
        let hashed_password = scram256_hash(&password, &DEFAULT_ITERATIONS).expect("hash password");
        let auth_message = "n=user,r=clientnonce,s=somesalt";

        // Truncated proof: 16 bytes instead of the expected 32 (SHA-256 output)
        let truncated_proof = BASE64_STANDARD.encode([0u8; 16]);
        let res = sasl_verify(&hashed_password, &truncated_proof, auth_message);
        assert!(
            matches!(res, Err(VerifyError::InvalidPassword)),
            "truncated proof should return InvalidPassword, not panic"
        );

        // Oversized proof: 64 bytes instead of 32
        let oversized_proof = BASE64_STANDARD.encode([0u8; 64]);
        let res = sasl_verify(&hashed_password, &oversized_proof, auth_message);
        assert!(
            matches!(res, Err(VerifyError::InvalidPassword)),
            "oversized proof should return InvalidPassword, not panic"
        );

        // Empty proof
        let empty_proof = BASE64_STANDARD.encode([0u8; 0]);
        let res = sasl_verify(&hashed_password, &empty_proof, auth_message);
        assert!(
            matches!(res, Err(VerifyError::InvalidPassword)),
            "empty proof should return InvalidPassword, not panic"
        );
    }
}