Skip to main content

mz_auth/
hash.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
// rustdoc misreads some doc comments (e.g. `<iterations>` placeholders) as HTML tags, so we disable its lint
11#![allow(rustdoc::invalid_html_tags)]
12
13use std::fmt::Display;
14use std::num::NonZeroU32;
15
16use base64::prelude::*;
17use itertools::Itertools;
18use mz_ore::secure::{Zeroize, Zeroizing};
19
20use crate::password::Password;
21
/// Length in bytes of every salt this module generates or accepts.
/// Not currently configurable.
const DEFAULT_SALT_SIZE: usize = 32;

/// Length in bytes of a SHA-256 digest (256 bits).
const SHA256_OUTPUT_LEN: usize = 256 / 8;
26
27/// The options for hashing a password
28#[derive(Debug, PartialEq)]
29pub struct HashOpts {
30    /// The number of iterations to use for PBKDF2
31    pub iterations: NonZeroU32,
32    /// The salt to use for PBKDF2. It is up to the caller to
33    /// ensure that however the salt is generated, it is cryptographically
34    /// secure.
35    pub salt: [u8; DEFAULT_SALT_SIZE],
36}
37
38impl Drop for HashOpts {
39    fn drop(&mut self) {
40        self.salt.zeroize();
41    }
42}
43
44pub struct PasswordHash {
45    /// The salt used for hashing
46    pub salt: [u8; DEFAULT_SALT_SIZE],
47    /// The number of iterations used for hashing
48    pub iterations: NonZeroU32,
49    /// The hash of the password.
50    /// This is the result of PBKDF2 with SHA256
51    pub hash: [u8; SHA256_OUTPUT_LEN],
52}
53
54impl Drop for PasswordHash {
55    fn drop(&mut self) {
56        self.salt.zeroize();
57        self.hash.zeroize();
58    }
59}
60
61#[derive(Debug)]
62pub enum VerifyError {
63    MalformedHash,
64    InvalidPassword,
65    Hash(HashError),
66}
67
68#[derive(Debug)]
69pub enum HashError {
70    Openssl(openssl::error::ErrorStack),
71}
72
73impl Display for HashError {
74    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75        match self {
76            HashError::Openssl(e) => write!(f, "OpenSSL error: {}", e),
77        }
78    }
79}
80
81/// Hashes a password using PBKDF2 with SHA256
82/// and a random salt.
83pub fn hash_password(
84    password: &Password,
85    iterations: &NonZeroU32,
86) -> Result<PasswordHash, HashError> {
87    let mut salt = Zeroizing::new([0u8; DEFAULT_SALT_SIZE]);
88    openssl::rand::rand_bytes(&mut *salt).map_err(HashError::Openssl)?;
89
90    let hash = hash_password_inner(
91        &HashOpts {
92            iterations: iterations.to_owned(),
93            salt: *salt,
94        },
95        password.as_bytes(),
96    )?;
97
98    Ok(PasswordHash {
99        salt: *salt,
100        iterations: iterations.to_owned(),
101        hash,
102    })
103}
104
105pub fn generate_nonce(client_nonce: &str) -> Result<String, HashError> {
106    let mut nonce = Zeroizing::new([0u8; 24]);
107    openssl::rand::rand_bytes(&mut *nonce).map_err(HashError::Openssl)?;
108    let nonce = BASE64_STANDARD.encode(&*nonce);
109    let new_nonce = format!("{}{}", client_nonce, nonce);
110    Ok(new_nonce)
111}
112
113/// Hashes a password using PBKDF2 with SHA256
114/// and the given options.
115pub fn hash_password_with_opts(
116    opts: &HashOpts,
117    password: &Password,
118) -> Result<PasswordHash, HashError> {
119    let hash = hash_password_inner(opts, password.as_bytes())?;
120
121    Ok(PasswordHash {
122        salt: opts.salt,
123        iterations: opts.iterations,
124        hash,
125    })
126}
127
128/// Hashes a password using PBKDF2 with SHA256,
129/// and returns it in the SCRAM-SHA-256 format.
130/// The format is SCRAM-SHA-256$<iterations>:<salt>$<stored_key>:<server_key>
131pub fn scram256_hash(password: &Password, iterations: &NonZeroU32) -> Result<String, HashError> {
132    let hashed_password = hash_password(password, iterations)?;
133    Ok(scram256_hash_inner(hashed_password).to_string())
134}
135
136fn constant_time_compare(a: &[u8], b: &[u8]) -> bool {
137    if a.len() != b.len() {
138        return false;
139    }
140    openssl::memcmp::eq(a, b)
141}
142
143/// Verifies a password against a SCRAM-SHA-256 hash.
144pub fn scram256_verify(password: &Password, hashed_password: &str) -> Result<(), VerifyError> {
145    let opts = scram256_parse_opts(hashed_password)?;
146    let hashed = hash_password_with_opts(&opts, password).map_err(VerifyError::Hash)?;
147    let scram = scram256_hash_inner(hashed);
148    if constant_time_compare(hashed_password.as_bytes(), scram.to_string().as_bytes()) {
149        Ok(())
150    } else {
151        Err(VerifyError::InvalidPassword)
152    }
153}
154
155pub fn sasl_verify(
156    hashed_password: &str,
157    proof: &str,
158    auth_message: &str,
159) -> Result<String, VerifyError> {
160    // Parse SCRAM hash: SCRAM-SHA-256$<iterations>:<salt>$<stored_key>:<server_key>
161    let parts: Vec<&str> = hashed_password.split('$').collect();
162    if parts.len() != 3 {
163        return Err(VerifyError::MalformedHash);
164    }
165    let auth_info = parts[1].split(':').collect::<Vec<&str>>();
166    if auth_info.len() != 2 {
167        return Err(VerifyError::MalformedHash);
168    }
169    let auth_value = parts[2].split(':').collect::<Vec<&str>>();
170    if auth_value.len() != 2 {
171        return Err(VerifyError::MalformedHash);
172    }
173
174    let stored_key = Zeroizing::new(
175        BASE64_STANDARD
176            .decode(auth_value[0])
177            .map_err(|_| VerifyError::MalformedHash)?,
178    );
179    let server_key = Zeroizing::new(
180        BASE64_STANDARD
181            .decode(auth_value[1])
182            .map_err(|_| VerifyError::MalformedHash)?,
183    );
184
185    // Compute client signature: HMAC(stored_key, auth_message)
186    let client_signature = Zeroizing::new(generate_signature(&stored_key, auth_message)?);
187
188    // Decode provided proof
189    let provided_client_proof = Zeroizing::new(
190        BASE64_STANDARD
191            .decode(proof)
192            .map_err(|_| VerifyError::InvalidPassword)?,
193    );
194
195    if provided_client_proof.len() != client_signature.len() {
196        return Err(VerifyError::InvalidPassword);
197    }
198
199    // Recover client_key = proof XOR client_signature
200    let client_key: Zeroizing<Vec<u8>> = Zeroizing::new(
201        provided_client_proof
202            .iter()
203            .zip_eq(client_signature.iter())
204            .map(|(p, s)| p ^ s)
205            .collect(),
206    );
207
208    if !constant_time_compare(&openssl::sha::sha256(&client_key), &stored_key) {
209        return Err(VerifyError::InvalidPassword);
210    }
211
212    // Compute server verifier: HMAC(server_key, auth_message)
213    let verifier = Zeroizing::new(generate_signature(&server_key, auth_message)?);
214    Ok(BASE64_STANDARD.encode(&*verifier))
215}
216
217fn generate_signature(key: &[u8], message: &str) -> Result<Zeroizing<Vec<u8>>, VerifyError> {
218    let signing_key =
219        openssl::pkey::PKey::hmac(key).map_err(|e| VerifyError::Hash(HashError::Openssl(e)))?;
220    let mut signer =
221        openssl::sign::Signer::new(openssl::hash::MessageDigest::sha256(), &signing_key)
222            .map_err(|e| VerifyError::Hash(HashError::Openssl(e)))?;
223    signer
224        .update(message.as_bytes())
225        .map_err(|e| VerifyError::Hash(HashError::Openssl(e)))?;
226    let signature = signer
227        .sign_to_vec()
228        .map_err(|e| VerifyError::Hash(HashError::Openssl(e)))?;
229    Ok(Zeroizing::new(signature))
230}
231
232// Generate a mock challenge based on the username and client nonce
233// We do this so that we can present a deterministic challenge even for
234// nonexistent users, to avoid user enumeration attacks.
235pub fn mock_sasl_challenge(username: &str, mock_nonce: &str, iterations: &NonZeroU32) -> HashOpts {
236    let mut buf = Vec::with_capacity(username.len() + mock_nonce.len());
237    buf.extend_from_slice(username.as_bytes());
238    buf.extend_from_slice(mock_nonce.as_bytes());
239    let digest = openssl::sha::sha256(&buf);
240
241    HashOpts {
242        iterations: iterations.to_owned(),
243        salt: digest,
244    }
245}
246
247/// Parses a SCRAM-SHA-256 hash and returns the options used to create it.
248pub fn scram256_parse_opts(hashed_password: &str) -> Result<HashOpts, VerifyError> {
249    let parts: Vec<&str> = hashed_password.split('$').collect();
250    if parts.len() != 3 {
251        return Err(VerifyError::MalformedHash);
252    }
253    let scheme = parts[0];
254    if scheme != "SCRAM-SHA-256" {
255        return Err(VerifyError::MalformedHash);
256    }
257    let auth_info = parts[1].split(':').collect::<Vec<&str>>();
258    if auth_info.len() != 2 {
259        return Err(VerifyError::MalformedHash);
260    }
261    let auth_value = parts[2].split(':').collect::<Vec<&str>>();
262    if auth_value.len() != 2 {
263        return Err(VerifyError::MalformedHash);
264    }
265
266    let iterations = auth_info[0]
267        .parse::<u32>()
268        .map_err(|_| VerifyError::MalformedHash)?;
269
270    let salt = BASE64_STANDARD
271        .decode(auth_info[1])
272        .map_err(|_| VerifyError::MalformedHash)?;
273
274    let salt = salt.try_into().map_err(|_| VerifyError::MalformedHash)?;
275
276    Ok(HashOpts {
277        iterations: NonZeroU32::new(iterations).ok_or(VerifyError::MalformedHash)?,
278        salt,
279    })
280}
281
282/// The SCRAM-SHA-256 hash
283struct ScramSha256Hash {
284    /// The number of iterations used for hashing
285    iterations: NonZeroU32,
286    /// The salt used for hashing
287    salt: [u8; 32],
288    /// The server key
289    server_key: [u8; SHA256_OUTPUT_LEN],
290    /// The stored key
291    stored_key: [u8; SHA256_OUTPUT_LEN],
292}
293
294impl Drop for ScramSha256Hash {
295    fn drop(&mut self) {
296        self.salt.zeroize();
297        self.server_key.zeroize();
298        self.stored_key.zeroize();
299    }
300}
301
302impl Display for ScramSha256Hash {
303    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
304        write!(
305            f,
306            "SCRAM-SHA-256${}:{}${}:{}",
307            self.iterations,
308            BASE64_STANDARD.encode(&self.salt),
309            BASE64_STANDARD.encode(&self.stored_key),
310            BASE64_STANDARD.encode(&self.server_key)
311        )
312    }
313}
314
315fn scram256_hash_inner(hashed_password: PasswordHash) -> ScramSha256Hash {
316    let signing_key = openssl::pkey::PKey::hmac(&hashed_password.hash).unwrap();
317    let mut signer =
318        openssl::sign::Signer::new(openssl::hash::MessageDigest::sha256(), &signing_key).unwrap();
319    signer.update(b"Client Key").unwrap();
320    let client_key = Zeroizing::new(signer.sign_to_vec().unwrap());
321    let stored_key = openssl::sha::sha256(&client_key);
322    let mut signer =
323        openssl::sign::Signer::new(openssl::hash::MessageDigest::sha256(), &signing_key).unwrap();
324    signer.update(b"Server Key").unwrap();
325    let mut server_key = Zeroizing::new([0u8; SHA256_OUTPUT_LEN]);
326    signer.sign(server_key.as_mut()).unwrap();
327
328    ScramSha256Hash {
329        iterations: hashed_password.iterations,
330        salt: hashed_password.salt,
331        server_key: *server_key,
332        stored_key,
333    }
334}
335
336fn hash_password_inner(
337    opts: &HashOpts,
338    password: &[u8],
339) -> Result<[u8; SHA256_OUTPUT_LEN], HashError> {
340    let mut salted_password = Zeroizing::new([0u8; SHA256_OUTPUT_LEN]);
341    openssl::pkcs5::pbkdf2_hmac(
342        password,
343        &opts.salt,
344        opts.iterations.get().try_into().unwrap(),
345        openssl::hash::MessageDigest::sha256(),
346        &mut *salted_password,
347    )
348    .map_err(HashError::Openssl)?;
349    Ok(*salted_password)
350}
351
#[cfg(test)]
mod tests {
    // Unit tests for PBKDF2 hashing, SCRAM-SHA-256 hash rendering/parsing,
    // the mock SASL challenge, and SASL client-proof verification.
    use itertools::Itertools;

    use super::*;

    // Iteration count shared by the SCRAM tests; presumably kept small so the
    // PBKDF2-backed tests run quickly.
    const DEFAULT_ITERATIONS: NonZeroU32 = NonZeroU32::new(60).expect("Trust me on this");

    // `hash_password` keeps the requested iteration count and produces a salt
    // and hash of the expected sizes.
    #[mz_ore::test]
    #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `OPENSSL_init_ssl` on OS `linux`
    fn test_hash_password() {
        let password = "password".to_string();
        let iterations = NonZeroU32::new(100).expect("Trust me on this");
        let hashed_password =
            hash_password(&password.into(), &iterations).expect("Failed to hash password");
        assert_eq!(hashed_password.iterations, iterations);
        assert_eq!(hashed_password.salt.len(), DEFAULT_SALT_SIZE);
        assert_eq!(hashed_password.hash.len(), SHA256_OUTPUT_LEN);
    }

    // Round trip: a hash from `scram256_hash` verifies with the original
    // password and is rejected for a different one.
    #[mz_ore::test]
    #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `OPENSSL_init_ssl` on OS `linux`
    fn test_scram256_hash() {
        let password = "password".into();
        let scram_hash =
            scram256_hash(&password, &DEFAULT_ITERATIONS).expect("Failed to hash password");

        let res = scram256_verify(&password, &scram_hash);
        assert!(res.is_ok());
        let res = scram256_verify(&"wrong_password".into(), &scram_hash);
        assert!(res.is_err());
    }

    // Parsing extracts the iteration count and the base64-decoded salt. The
    // key section is only shape-checked, so placeholder (non-base64) keys are
    // accepted here.
    #[mz_ore::test]
    fn test_scram256_parse_opts() {
        let salt = "9bkIQQjQ7f1OwPsXZGC/YfIkbZsOMDXK0cxxvPBaSfM=";
        let hashed_password = format!("SCRAM-SHA-256$600000:{}$client-key:server-key", salt);
        let opts = scram256_parse_opts(&hashed_password);

        assert!(opts.is_ok());
        let opts = opts.unwrap();
        assert_eq!(
            opts.iterations,
            NonZeroU32::new(600_000).expect("known valid")
        );
        assert_eq!(opts.salt.len(), DEFAULT_SALT_SIZE);
        let decoded_salt = BASE64_STANDARD.decode(salt).expect("Failed to decode salt");
        assert_eq!(opts.salt, decoded_salt.as_ref());
    }

    // The mock challenge is deterministic for the same username and mock nonce.
    #[mz_ore::test]
    #[cfg_attr(miri, ignore)]
    fn test_mock_sasl_challenge() {
        let username = "alice";
        let mock = "cnonce";
        let opts1 = mock_sasl_challenge(username, mock, &DEFAULT_ITERATIONS);
        let opts2 = mock_sasl_challenge(username, mock, &DEFAULT_ITERATIONS);
        assert_eq!(opts1, opts2);
    }

    // End to end: simulate a client deriving ClientProof from the password,
    // then check that `sasl_verify` accepts it and returns
    // base64(HMAC(ServerKey, AuthMessage)).
    #[mz_ore::test]
    #[cfg_attr(miri, ignore)]
    fn test_sasl_verify_success() {
        let password: Password = "password".into();
        let hashed_password = scram256_hash(&password, &DEFAULT_ITERATIONS).expect("hash password");
        let auth_message = "n=user,r=clientnonce,s=somesalt"; // arbitrary auth message

        // Parse stored_key and server_key from the SCRAM hash
        // Format: SCRAM-SHA-256$<iterations>:<salt>$<stored_key>:<server_key>
        let parts: Vec<&str> = hashed_password.split('$').collect();
        assert_eq!(parts.len(), 3);
        let key_parts: Vec<&str> = parts[2].split(':').collect();
        assert_eq!(key_parts.len(), 2);
        let stored_key = BASE64_STANDARD
            .decode(key_parts[0])
            .expect("decode stored key");
        let server_key = BASE64_STANDARD
            .decode(key_parts[1])
            .expect("decode server key");

        // Simulate client generating a proof
        let client_proof: Vec<u8> = {
            // client_key = HMAC(salted_password, "Client Key")
            let opts = scram256_parse_opts(&hashed_password).expect("parse opts");
            let salted_password = hash_password_with_opts(&opts, &password)
                .expect("hash password")
                .hash;
            let signing_key = openssl::pkey::PKey::hmac(&salted_password).expect("signing key");
            let mut signer =
                openssl::sign::Signer::new(openssl::hash::MessageDigest::sha256(), &signing_key)
                    .expect("signer");
            signer.update(b"Client Key").expect("update");
            let client_key = signer.sign_to_vec().expect("client key");
            // client_proof = client_key XOR client_signature
            let client_signature =
                generate_signature(&stored_key, auth_message).expect("client signature");
            client_key
                .iter()
                .zip_eq(client_signature.iter())
                .map(|(c, s)| c ^ s)
                .collect::<Vec<u8>>()
        };

        let client_proof_b64 = BASE64_STANDARD.encode(&client_proof);

        let verifier = sasl_verify(&hashed_password, &client_proof_b64, auth_message)
            .expect("sasl_verify should succeed");

        // Expected verifier: HMAC(server_key, auth_message)
        let expected_verifier = BASE64_STANDARD
            .encode(&generate_signature(&server_key, auth_message).expect("server verifier"));
        assert_eq!(verifier, expected_verifier);
    }

    // A well-formed proof of the correct size that does not match the stored
    // key is rejected as an invalid password.
    #[mz_ore::test]
    #[cfg_attr(miri, ignore)]
    fn test_sasl_verify_invalid_proof() {
        let password: Password = "password".into();
        let hashed_password = scram256_hash(&password, &DEFAULT_ITERATIONS).expect("hash password");
        let auth_message = "n=user,r=clientnonce,s=somesalt";
        // Valid base64 of the correct length (32 zero bytes), but not the
        // proof derived from the password
        let bad_proof = BASE64_STANDARD.encode([0u8; 32]);
        let res = sasl_verify(&hashed_password, &bad_proof, auth_message);
        assert!(matches!(res, Err(VerifyError::InvalidPassword)));
    }

    // A stored hash without three `$`-separated sections is reported as
    // malformed before the proof is examined.
    #[mz_ore::test]
    fn test_sasl_verify_malformed_hash() {
        let malformed_hash = "NOT-SCRAM$bad"; // clearly malformed (wrong parts count)
        let auth_message = "n=user,r=clientnonce,s=somesalt";
        let bad_proof = BASE64_STANDARD.encode([0u8; 32]);
        let res = sasl_verify(malformed_hash, &bad_proof, auth_message);
        assert!(matches!(res, Err(VerifyError::MalformedHash)));
    }

    // Proofs of the wrong length must be rejected before the XOR step
    // (`zip_eq` panics on unequal lengths).
    #[mz_ore::test]
    #[cfg_attr(miri, ignore)]
    fn test_sasl_verify_truncated_proof_no_panic() {
        // A truncated client proof (not 32 bytes) should return InvalidPassword, not panic
        let password: Password = "password".into();
        let hashed_password = scram256_hash(&password, &DEFAULT_ITERATIONS).expect("hash password");
        let auth_message = "n=user,r=clientnonce,s=somesalt";

        // Truncated proof: 16 bytes instead of the expected 32 (SHA-256 output)
        let truncated_proof = BASE64_STANDARD.encode([0u8; 16]);
        let res = sasl_verify(&hashed_password, &truncated_proof, auth_message);
        assert!(
            matches!(res, Err(VerifyError::InvalidPassword)),
            "truncated proof should return InvalidPassword, not panic"
        );

        // Oversized proof: 64 bytes instead of 32
        let oversized_proof = BASE64_STANDARD.encode([0u8; 64]);
        let res = sasl_verify(&hashed_password, &oversized_proof, auth_message);
        assert!(
            matches!(res, Err(VerifyError::InvalidPassword)),
            "oversized proof should return InvalidPassword, not panic"
        );

        // Empty proof
        let empty_proof = BASE64_STANDARD.encode([0u8; 0]);
        let res = sasl_verify(&hashed_password, &empty_proof, auth_message);
        assert!(
            matches!(res, Err(VerifyError::InvalidPassword)),
            "empty proof should return InvalidPassword, not panic"
        );
    }
}
515            matches!(res, Err(VerifyError::InvalidPassword)),
516            "empty proof should return InvalidPassword, not panic"
517        );
518    }
519}