mz_auth/
hash.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
// rustdoc misreads placeholders such as `<iterations>` in doc comments as HTML tags, so we disable its `invalid_html_tags` lint
11#![allow(rustdoc::invalid_html_tags)]
12
13use std::fmt::Display;
14use std::num::NonZeroU32;
15
16use base64::prelude::*;
17use itertools::Itertools;
18
19use crate::password::Password;
20
/// Number of PBKDF2 rounds applied to newly hashed passwords, per the
/// recommendation in
/// <https://cheatsheetseries.owasp.org/cheatsheets/Password_Storage_Cheat_Sheet.html>.
const DEFAULT_ITERATIONS: NonZeroU32 = NonZeroU32::new(600_000).unwrap();

/// Length in bytes of generated salts. Not currently configurable.
const DEFAULT_SALT_SIZE: usize = 32;

/// Length in bytes of a SHA-256 digest (256 bits), which is also the size of
/// the PBKDF2 output buffers used in this module.
const SHA256_OUTPUT_LEN: usize = 256 / 8;
29
30/// The options for hashing a password
31#[derive(Debug, PartialEq)]
32pub struct HashOpts {
33    /// The number of iterations to use for PBKDF2
34    pub iterations: NonZeroU32,
35    /// The salt to use for PBKDF2. It is up to the caller to
36    /// ensure that however the salt is generated, it is cryptographically
37    /// secure.
38    pub salt: [u8; DEFAULT_SALT_SIZE],
39}
40
41pub struct PasswordHash {
42    /// The salt used for hashing
43    pub salt: [u8; DEFAULT_SALT_SIZE],
44    /// The number of iterations used for hashing
45    pub iterations: NonZeroU32,
46    /// The hash of the password.
47    /// This is the result of PBKDF2 with SHA256
48    pub hash: [u8; SHA256_OUTPUT_LEN],
49}
50
51#[derive(Debug)]
52pub enum VerifyError {
53    MalformedHash,
54    InvalidPassword,
55    Hash(HashError),
56}
57
58#[derive(Debug)]
59pub enum HashError {
60    Openssl(openssl::error::ErrorStack),
61}
62
63impl Display for HashError {
64    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65        match self {
66            HashError::Openssl(e) => write!(f, "OpenSSL error: {}", e),
67        }
68    }
69}
70
71/// Hashes a password using PBKDF2 with SHA256
72/// and a random salt.
73pub fn hash_password(password: &Password) -> Result<PasswordHash, HashError> {
74    let mut salt = [0u8; DEFAULT_SALT_SIZE];
75    openssl::rand::rand_bytes(&mut salt).map_err(HashError::Openssl)?;
76
77    let hash = hash_password_inner(
78        &HashOpts {
79            iterations: DEFAULT_ITERATIONS,
80            salt,
81        },
82        password.to_string().as_bytes(),
83    )?;
84
85    Ok(PasswordHash {
86        salt,
87        iterations: DEFAULT_ITERATIONS,
88        hash,
89    })
90}
91
92pub fn generate_nonce(client_nonce: &str) -> Result<String, HashError> {
93    let mut nonce = [0u8; 24];
94    openssl::rand::rand_bytes(&mut nonce).map_err(HashError::Openssl)?;
95    let nonce = BASE64_STANDARD.encode(&nonce);
96    let new_nonce = format!("{}{}", client_nonce, nonce);
97    Ok(new_nonce)
98}
99
100/// Hashes a password using PBKDF2 with SHA256
101/// and the given options.
102pub fn hash_password_with_opts(
103    opts: &HashOpts,
104    password: &Password,
105) -> Result<PasswordHash, HashError> {
106    let hash = hash_password_inner(opts, password.to_string().as_bytes())?;
107
108    Ok(PasswordHash {
109        salt: opts.salt,
110        iterations: opts.iterations,
111        hash,
112    })
113}
114
115/// Hashes a password using PBKDF2 with SHA256,
116/// and returns it in the SCRAM-SHA-256 format.
117/// The format is SCRAM-SHA-256$<iterations>:<salt>$<stored_key>:<server_key>
118pub fn scram256_hash(password: &Password) -> Result<String, HashError> {
119    let hashed_password = hash_password(password)?;
120    Ok(scram256_hash_inner(hashed_password).to_string())
121}
122
123fn constant_time_compare(a: &[u8], b: &[u8]) -> bool {
124    if a.len() != b.len() {
125        return false;
126    }
127    openssl::memcmp::eq(a, b)
128}
129
130/// Verifies a password against a SCRAM-SHA-256 hash.
131pub fn scram256_verify(password: &Password, hashed_password: &str) -> Result<(), VerifyError> {
132    let opts = scram256_parse_opts(hashed_password)?;
133    let hashed = hash_password_with_opts(&opts, password).map_err(VerifyError::Hash)?;
134    let scram = scram256_hash_inner(hashed);
135    if constant_time_compare(hashed_password.as_bytes(), scram.to_string().as_bytes()) {
136        Ok(())
137    } else {
138        Err(VerifyError::InvalidPassword)
139    }
140}
141
142pub fn sasl_verify(
143    hashed_password: &str,
144    proof: &str,
145    auth_message: &str,
146) -> Result<String, VerifyError> {
147    // Parse SCRAM hash: SCRAM-SHA-256$<iterations>:<salt>$<stored_key>:<server_key>
148    let parts: Vec<&str> = hashed_password.split('$').collect();
149    if parts.len() != 3 {
150        return Err(VerifyError::MalformedHash);
151    }
152    let auth_info = parts[1].split(':').collect::<Vec<&str>>();
153    if auth_info.len() != 2 {
154        return Err(VerifyError::MalformedHash);
155    }
156    let auth_value = parts[2].split(':').collect::<Vec<&str>>();
157    if auth_value.len() != 2 {
158        return Err(VerifyError::MalformedHash);
159    }
160
161    let stored_key = BASE64_STANDARD
162        .decode(auth_value[0])
163        .map_err(|_| VerifyError::MalformedHash)?;
164    let server_key = BASE64_STANDARD
165        .decode(auth_value[1])
166        .map_err(|_| VerifyError::MalformedHash)?;
167
168    // Compute client signature: HMAC(stored_key, auth_message)
169    let client_signature = generate_signature(&stored_key, auth_message)?;
170
171    // Decode provided proof
172    let provided_client_proof = BASE64_STANDARD
173        .decode(proof)
174        .map_err(|_| VerifyError::InvalidPassword)?;
175
176    // Recover client_key = proof XOR client_signature
177    let client_key: Vec<u8> = provided_client_proof
178        .iter()
179        .zip_eq(client_signature.iter())
180        .map(|(p, s)| p ^ s)
181        .collect();
182
183    if !constant_time_compare(&openssl::sha::sha256(&client_key), &stored_key) {
184        return Err(VerifyError::InvalidPassword);
185    }
186
187    // Compute server verifier: HMAC(server_key, auth_message)
188    let verifier = generate_signature(&server_key, auth_message)?;
189    Ok(BASE64_STANDARD.encode(&verifier))
190}
191
192fn generate_signature(key: &[u8], message: &str) -> Result<Vec<u8>, VerifyError> {
193    let signing_key =
194        openssl::pkey::PKey::hmac(key).map_err(|e| VerifyError::Hash(HashError::Openssl(e)))?;
195    let mut signer =
196        openssl::sign::Signer::new(openssl::hash::MessageDigest::sha256(), &signing_key)
197            .map_err(|e| VerifyError::Hash(HashError::Openssl(e)))?;
198    signer
199        .update(message.as_bytes())
200        .map_err(|e| VerifyError::Hash(HashError::Openssl(e)))?;
201    let signature = signer
202        .sign_to_vec()
203        .map_err(|e| VerifyError::Hash(HashError::Openssl(e)))?;
204    Ok(signature)
205}
206
207// Generate a mock challenge based on the username and client nonce
208// We do this so that we can present a deterministic challenge even for
209// nonexistent users, to avoid user enumeration attacks.
210pub fn mock_sasl_challenge(username: &str, mock_nonce: &str) -> HashOpts {
211    let mut buf = Vec::with_capacity(username.len() + mock_nonce.len());
212    buf.extend_from_slice(username.as_bytes());
213    buf.extend_from_slice(mock_nonce.as_bytes());
214    let digest = openssl::sha::sha256(&buf);
215
216    HashOpts {
217        iterations: DEFAULT_ITERATIONS,
218        salt: digest,
219    }
220}
221
222/// Parses a SCRAM-SHA-256 hash and returns the options used to create it.
223pub fn scram256_parse_opts(hashed_password: &str) -> Result<HashOpts, VerifyError> {
224    let parts: Vec<&str> = hashed_password.split('$').collect();
225    if parts.len() != 3 {
226        return Err(VerifyError::MalformedHash);
227    }
228    let scheme = parts[0];
229    if scheme != "SCRAM-SHA-256" {
230        return Err(VerifyError::MalformedHash);
231    }
232    let auth_info = parts[1].split(':').collect::<Vec<&str>>();
233    if auth_info.len() != 2 {
234        return Err(VerifyError::MalformedHash);
235    }
236    let auth_value = parts[2].split(':').collect::<Vec<&str>>();
237    if auth_value.len() != 2 {
238        return Err(VerifyError::MalformedHash);
239    }
240
241    let iterations = auth_info[0]
242        .parse::<u32>()
243        .map_err(|_| VerifyError::MalformedHash)?;
244
245    let salt = BASE64_STANDARD
246        .decode(auth_info[1])
247        .map_err(|_| VerifyError::MalformedHash)?;
248
249    let salt = salt.try_into().map_err(|_| VerifyError::MalformedHash)?;
250
251    Ok(HashOpts {
252        iterations: NonZeroU32::new(iterations).ok_or(VerifyError::MalformedHash)?,
253        salt,
254    })
255}
256
257/// The SCRAM-SHA-256 hash
258struct ScramSha256Hash {
259    /// The number of iterations used for hashing
260    iterations: NonZeroU32,
261    /// The salt used for hashing
262    salt: [u8; 32],
263    /// The server key
264    server_key: [u8; SHA256_OUTPUT_LEN],
265    /// The stored key
266    stored_key: [u8; SHA256_OUTPUT_LEN],
267}
268
269impl Display for ScramSha256Hash {
270    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
271        write!(
272            f,
273            "SCRAM-SHA-256${}:{}${}:{}",
274            self.iterations,
275            BASE64_STANDARD.encode(&self.salt),
276            BASE64_STANDARD.encode(&self.stored_key),
277            BASE64_STANDARD.encode(&self.server_key)
278        )
279    }
280}
281
282fn scram256_hash_inner(hashed_password: PasswordHash) -> ScramSha256Hash {
283    let signing_key = openssl::pkey::PKey::hmac(&hashed_password.hash).unwrap();
284    let mut signer =
285        openssl::sign::Signer::new(openssl::hash::MessageDigest::sha256(), &signing_key).unwrap();
286    signer.update(b"Client Key").unwrap();
287    let client_key = signer.sign_to_vec().unwrap();
288    let stored_key = openssl::sha::sha256(&client_key);
289    let mut signer =
290        openssl::sign::Signer::new(openssl::hash::MessageDigest::sha256(), &signing_key).unwrap();
291    signer.update(b"Server Key").unwrap();
292    let mut server_key: [u8; SHA256_OUTPUT_LEN] = [0; SHA256_OUTPUT_LEN];
293    signer.sign(server_key.as_mut()).unwrap();
294
295    ScramSha256Hash {
296        iterations: hashed_password.iterations,
297        salt: hashed_password.salt,
298        server_key,
299        stored_key,
300    }
301}
302
303fn hash_password_inner(
304    opts: &HashOpts,
305    password: &[u8],
306) -> Result<[u8; SHA256_OUTPUT_LEN], HashError> {
307    let mut salted_password = [0u8; SHA256_OUTPUT_LEN];
308    openssl::pkcs5::pbkdf2_hmac(
309        password,
310        &opts.salt,
311        opts.iterations.get().try_into().unwrap(),
312        openssl::hash::MessageDigest::sha256(),
313        &mut salted_password,
314    )
315    .map_err(HashError::Openssl)?;
316    Ok(salted_password)
317}
318
#[cfg(test)]
mod tests {
    use itertools::Itertools;

    use super::*;

    #[mz_ore::test]
    #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `OPENSSL_init_ssl` on OS `linux`
    fn test_hash_password() {
        let password = "password".to_string();
        let hashed_password = hash_password(&password.into()).expect("Failed to hash password");
        assert_eq!(hashed_password.iterations, DEFAULT_ITERATIONS);
        assert_eq!(hashed_password.salt.len(), DEFAULT_SALT_SIZE);
        assert_eq!(hashed_password.hash.len(), SHA256_OUTPUT_LEN);
    }

    #[mz_ore::test]
    #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `OPENSSL_init_ssl` on OS `linux`
    fn test_scram256_hash() {
        let password = "password".into();
        let scram_hash = scram256_hash(&password).expect("Failed to hash password");

        let res = scram256_verify(&password, &scram_hash);
        assert!(res.is_ok());
        let res = scram256_verify(&"wrong_password".into(), &scram_hash);
        assert!(matches!(res, Err(VerifyError::InvalidPassword)));
    }

    #[mz_ore::test]
    fn test_scram256_parse_opts() {
        let salt = "9bkIQQjQ7f1OwPsXZGC/YfIkbZsOMDXK0cxxvPBaSfM=";
        let hashed_password = format!("SCRAM-SHA-256$600000:{}$client-key:server-key", salt);
        let opts = scram256_parse_opts(&hashed_password);

        assert!(opts.is_ok());
        let opts = opts.unwrap();
        assert_eq!(opts.iterations, DEFAULT_ITERATIONS);
        assert_eq!(opts.salt.len(), DEFAULT_SALT_SIZE);
        let decoded_salt = BASE64_STANDARD.decode(salt).expect("Failed to decode salt");
        assert_eq!(opts.salt, decoded_salt.as_ref());
    }

    #[mz_ore::test]
    fn test_scram256_parse_opts_rejects_malformed() {
        let salt = "9bkIQQjQ7f1OwPsXZGC/YfIkbZsOMDXK0cxxvPBaSfM=";
        let malformed = [
            // Wrong scheme.
            format!("SCRAM-SHA-1$600000:{salt}$client-key:server-key"),
            // Zero iterations.
            format!("SCRAM-SHA-256$0:{salt}$client-key:server-key"),
            // Non-numeric iterations.
            format!("SCRAM-SHA-256$many:{salt}$client-key:server-key"),
            // Salt that decodes to the wrong number of bytes.
            "SCRAM-SHA-256$600000:AAAA$client-key:server-key".to_string(),
            // Salt that is not base64.
            "SCRAM-SHA-256$600000:!!!!$client-key:server-key".to_string(),
            // Missing key section.
            format!("SCRAM-SHA-256$600000:{salt}"),
            // Key section with the wrong number of fields.
            format!("SCRAM-SHA-256$600000:{salt}$client-key"),
        ];
        for hashed_password in &malformed {
            assert!(
                matches!(
                    scram256_parse_opts(hashed_password),
                    Err(VerifyError::MalformedHash)
                ),
                "expected MalformedHash for {hashed_password}"
            );
        }
    }

    #[mz_ore::test]
    #[cfg_attr(miri, ignore)]
    fn test_mock_sasl_challenge() {
        let username = "alice";
        let mock = "cnonce";
        let opts1 = mock_sasl_challenge(username, mock);
        let opts2 = mock_sasl_challenge(username, mock);
        assert_eq!(opts1, opts2);
        assert_eq!(opts1.iterations, DEFAULT_ITERATIONS);

        // Different users must not share a mock salt.
        let other_user = mock_sasl_challenge("bob", mock);
        assert_ne!(opts1.salt, other_user.salt);
    }

    #[mz_ore::test]
    #[cfg_attr(miri, ignore)]
    fn test_sasl_verify_success() {
        let password: Password = "password".into();
        let hashed_password = scram256_hash(&password).expect("hash password");
        let auth_message = "n=user,r=clientnonce,s=somesalt"; // arbitrary auth message

        // Parse stored_key and server_key from the SCRAM hash
        // Format: SCRAM-SHA-256$<iterations>:<salt>$<stored_key>:<server_key>
        let parts: Vec<&str> = hashed_password.split('$').collect();
        assert_eq!(parts.len(), 3);
        let key_parts: Vec<&str> = parts[2].split(':').collect();
        assert_eq!(key_parts.len(), 2);
        let stored_key = BASE64_STANDARD
            .decode(key_parts[0])
            .expect("decode stored key");
        let server_key = BASE64_STANDARD
            .decode(key_parts[1])
            .expect("decode server key");

        // Simulate client generating a proof
        let client_proof: Vec<u8> = {
            // client_key = HMAC(salted_password, "Client Key")
            let opts = scram256_parse_opts(&hashed_password).expect("parse opts");
            let salted_password = hash_password_with_opts(&opts, &password)
                .expect("hash password")
                .hash;
            let signing_key = openssl::pkey::PKey::hmac(&salted_password).expect("signing key");
            let mut signer =
                openssl::sign::Signer::new(openssl::hash::MessageDigest::sha256(), &signing_key)
                    .expect("signer");
            signer.update(b"Client Key").expect("update");
            let client_key = signer.sign_to_vec().expect("client key");
            // client_proof = client_key XOR client_signature
            let client_signature =
                generate_signature(&stored_key, auth_message).expect("client signature");
            client_key
                .iter()
                .zip_eq(client_signature.iter())
                .map(|(c, s)| c ^ s)
                .collect::<Vec<u8>>()
        };

        let client_proof_b64 = BASE64_STANDARD.encode(&client_proof);

        let verifier = sasl_verify(&hashed_password, &client_proof_b64, auth_message)
            .expect("sasl_verify should succeed");

        // Expected verifier: HMAC(server_key, auth_message)
        let expected_verifier = BASE64_STANDARD
            .encode(&generate_signature(&server_key, auth_message).expect("server verifier"));
        assert_eq!(verifier, expected_verifier);
    }

    #[mz_ore::test]
    #[cfg_attr(miri, ignore)]
    fn test_sasl_verify_invalid_proof() {
        let password: Password = "password".into();
        let hashed_password = scram256_hash(&password).expect("hash password");
        let auth_message = "n=user,r=clientnonce,s=somesalt";
        // A well-formed proof of the correct length (32 zero bytes) that does
        // not correspond to the stored key.
        let bad_proof = BASE64_STANDARD.encode([0u8; 32]);
        let res = sasl_verify(&hashed_password, &bad_proof, auth_message);
        assert!(matches!(res, Err(VerifyError::InvalidPassword)));
    }

    #[mz_ore::test]
    fn test_sasl_verify_malformed_hash() {
        let malformed_hash = "NOT-SCRAM$bad"; // clearly malformed (wrong parts count)
        let auth_message = "n=user,r=clientnonce,s=somesalt";
        let bad_proof = BASE64_STANDARD.encode([0u8; 32]);
        let res = sasl_verify(malformed_hash, &bad_proof, auth_message);
        assert!(matches!(res, Err(VerifyError::MalformedHash)));
    }
}