Skip to main content

mz_auth/
hash.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
// Rustdoc misreads placeholders such as `<iterations>` in doc comments as HTML tags,
// so we disable the lint
11#![allow(rustdoc::invalid_html_tags)]
12
13use std::fmt::Display;
14use std::num::NonZeroU32;
15
16use base64::prelude::*;
17use itertools::Itertools;
18
19use crate::password::Password;
20
/// The default salt size in bytes, which isn't currently configurable.
const DEFAULT_SALT_SIZE: usize = 32;

/// The length in bytes of a SHA-256 digest. This is also the length of the
/// PBKDF2-HMAC-SHA256 output and of the SCRAM keys derived from it in this module.
const SHA256_OUTPUT_LEN: usize = 32;
25
/// The options for hashing a password with PBKDF2-HMAC-SHA256: the iteration
/// count and the salt.
#[derive(Debug, PartialEq)]
pub struct HashOpts {
    /// The number of iterations to use for PBKDF2
    pub iterations: NonZeroU32,
    /// The salt to use for PBKDF2. It is up to the caller to
    /// ensure that however the salt is generated, it is cryptographically
    /// secure.
    pub salt: [u8; DEFAULT_SALT_SIZE],
}
36
/// A password hashed with PBKDF2-HMAC-SHA256, together with the salt and
/// iteration count needed to recompute the same hash later.
pub struct PasswordHash {
    /// The salt used for hashing
    pub salt: [u8; DEFAULT_SALT_SIZE],
    /// The number of iterations used for hashing
    pub iterations: NonZeroU32,
    /// The hash of the password.
    /// This is the result of PBKDF2 with SHA256
    pub hash: [u8; SHA256_OUTPUT_LEN],
}
46
/// Errors returned when verifying a password or a SASL/SCRAM client proof.
#[derive(Debug)]
pub enum VerifyError {
    /// The stored hash could not be parsed. Examples: the wrong number of
    /// `$`/`:`-separated fields, invalid base64, a salt of the wrong length, or
    /// an unparsable or zero iteration count.
    MalformedHash,
    /// The password or client proof did not match the stored hash, or the
    /// client proof could not be decoded.
    InvalidPassword,
    /// An underlying hashing operation failed.
    Hash(HashError),
}
53
/// Errors that can occur while hashing a password or generating random data.
#[derive(Debug)]
pub enum HashError {
    /// An error reported by OpenSSL.
    Openssl(openssl::error::ErrorStack),
}
58
59impl Display for HashError {
60    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61        match self {
62            HashError::Openssl(e) => write!(f, "OpenSSL error: {}", e),
63        }
64    }
65}
66
67/// Hashes a password using PBKDF2 with SHA256
68/// and a random salt.
69pub fn hash_password(
70    password: &Password,
71    iterations: &NonZeroU32,
72) -> Result<PasswordHash, HashError> {
73    let mut salt = [0u8; DEFAULT_SALT_SIZE];
74    openssl::rand::rand_bytes(&mut salt).map_err(HashError::Openssl)?;
75
76    let hash = hash_password_inner(
77        &HashOpts {
78            iterations: iterations.to_owned(),
79            salt,
80        },
81        password.to_string().as_bytes(),
82    )?;
83
84    Ok(PasswordHash {
85        salt,
86        iterations: iterations.to_owned(),
87        hash,
88    })
89}
90
91pub fn generate_nonce(client_nonce: &str) -> Result<String, HashError> {
92    let mut nonce = [0u8; 24];
93    openssl::rand::rand_bytes(&mut nonce).map_err(HashError::Openssl)?;
94    let nonce = BASE64_STANDARD.encode(&nonce);
95    let new_nonce = format!("{}{}", client_nonce, nonce);
96    Ok(new_nonce)
97}
98
99/// Hashes a password using PBKDF2 with SHA256
100/// and the given options.
101pub fn hash_password_with_opts(
102    opts: &HashOpts,
103    password: &Password,
104) -> Result<PasswordHash, HashError> {
105    let hash = hash_password_inner(opts, password.to_string().as_bytes())?;
106
107    Ok(PasswordHash {
108        salt: opts.salt,
109        iterations: opts.iterations,
110        hash,
111    })
112}
113
114/// Hashes a password using PBKDF2 with SHA256,
115/// and returns it in the SCRAM-SHA-256 format.
116/// The format is SCRAM-SHA-256$<iterations>:<salt>$<stored_key>:<server_key>
117pub fn scram256_hash(password: &Password, iterations: &NonZeroU32) -> Result<String, HashError> {
118    let hashed_password = hash_password(password, iterations)?;
119    Ok(scram256_hash_inner(hashed_password).to_string())
120}
121
122fn constant_time_compare(a: &[u8], b: &[u8]) -> bool {
123    if a.len() != b.len() {
124        return false;
125    }
126    openssl::memcmp::eq(a, b)
127}
128
129/// Verifies a password against a SCRAM-SHA-256 hash.
130pub fn scram256_verify(password: &Password, hashed_password: &str) -> Result<(), VerifyError> {
131    let opts = scram256_parse_opts(hashed_password)?;
132    let hashed = hash_password_with_opts(&opts, password).map_err(VerifyError::Hash)?;
133    let scram = scram256_hash_inner(hashed);
134    if constant_time_compare(hashed_password.as_bytes(), scram.to_string().as_bytes()) {
135        Ok(())
136    } else {
137        Err(VerifyError::InvalidPassword)
138    }
139}
140
141pub fn sasl_verify(
142    hashed_password: &str,
143    proof: &str,
144    auth_message: &str,
145) -> Result<String, VerifyError> {
146    // Parse SCRAM hash: SCRAM-SHA-256$<iterations>:<salt>$<stored_key>:<server_key>
147    let parts: Vec<&str> = hashed_password.split('$').collect();
148    if parts.len() != 3 {
149        return Err(VerifyError::MalformedHash);
150    }
151    let auth_info = parts[1].split(':').collect::<Vec<&str>>();
152    if auth_info.len() != 2 {
153        return Err(VerifyError::MalformedHash);
154    }
155    let auth_value = parts[2].split(':').collect::<Vec<&str>>();
156    if auth_value.len() != 2 {
157        return Err(VerifyError::MalformedHash);
158    }
159
160    let stored_key = BASE64_STANDARD
161        .decode(auth_value[0])
162        .map_err(|_| VerifyError::MalformedHash)?;
163    let server_key = BASE64_STANDARD
164        .decode(auth_value[1])
165        .map_err(|_| VerifyError::MalformedHash)?;
166
167    // Compute client signature: HMAC(stored_key, auth_message)
168    let client_signature = generate_signature(&stored_key, auth_message)?;
169
170    // Decode provided proof
171    let provided_client_proof = BASE64_STANDARD
172        .decode(proof)
173        .map_err(|_| VerifyError::InvalidPassword)?;
174
175    if provided_client_proof.len() != client_signature.len() {
176        return Err(VerifyError::InvalidPassword);
177    }
178
179    // Recover client_key = proof XOR client_signature
180    let client_key: Vec<u8> = provided_client_proof
181        .iter()
182        .zip_eq(client_signature.iter())
183        .map(|(p, s)| p ^ s)
184        .collect();
185
186    if !constant_time_compare(&openssl::sha::sha256(&client_key), &stored_key) {
187        return Err(VerifyError::InvalidPassword);
188    }
189
190    // Compute server verifier: HMAC(server_key, auth_message)
191    let verifier = generate_signature(&server_key, auth_message)?;
192    Ok(BASE64_STANDARD.encode(&verifier))
193}
194
195fn generate_signature(key: &[u8], message: &str) -> Result<Vec<u8>, VerifyError> {
196    let signing_key =
197        openssl::pkey::PKey::hmac(key).map_err(|e| VerifyError::Hash(HashError::Openssl(e)))?;
198    let mut signer =
199        openssl::sign::Signer::new(openssl::hash::MessageDigest::sha256(), &signing_key)
200            .map_err(|e| VerifyError::Hash(HashError::Openssl(e)))?;
201    signer
202        .update(message.as_bytes())
203        .map_err(|e| VerifyError::Hash(HashError::Openssl(e)))?;
204    let signature = signer
205        .sign_to_vec()
206        .map_err(|e| VerifyError::Hash(HashError::Openssl(e)))?;
207    Ok(signature)
208}
209
210// Generate a mock challenge based on the username and client nonce
211// We do this so that we can present a deterministic challenge even for
212// nonexistent users, to avoid user enumeration attacks.
213pub fn mock_sasl_challenge(username: &str, mock_nonce: &str, iterations: &NonZeroU32) -> HashOpts {
214    let mut buf = Vec::with_capacity(username.len() + mock_nonce.len());
215    buf.extend_from_slice(username.as_bytes());
216    buf.extend_from_slice(mock_nonce.as_bytes());
217    let digest = openssl::sha::sha256(&buf);
218
219    HashOpts {
220        iterations: iterations.to_owned(),
221        salt: digest,
222    }
223}
224
225/// Parses a SCRAM-SHA-256 hash and returns the options used to create it.
226pub fn scram256_parse_opts(hashed_password: &str) -> Result<HashOpts, VerifyError> {
227    let parts: Vec<&str> = hashed_password.split('$').collect();
228    if parts.len() != 3 {
229        return Err(VerifyError::MalformedHash);
230    }
231    let scheme = parts[0];
232    if scheme != "SCRAM-SHA-256" {
233        return Err(VerifyError::MalformedHash);
234    }
235    let auth_info = parts[1].split(':').collect::<Vec<&str>>();
236    if auth_info.len() != 2 {
237        return Err(VerifyError::MalformedHash);
238    }
239    let auth_value = parts[2].split(':').collect::<Vec<&str>>();
240    if auth_value.len() != 2 {
241        return Err(VerifyError::MalformedHash);
242    }
243
244    let iterations = auth_info[0]
245        .parse::<u32>()
246        .map_err(|_| VerifyError::MalformedHash)?;
247
248    let salt = BASE64_STANDARD
249        .decode(auth_info[1])
250        .map_err(|_| VerifyError::MalformedHash)?;
251
252    let salt = salt.try_into().map_err(|_| VerifyError::MalformedHash)?;
253
254    Ok(HashOpts {
255        iterations: NonZeroU32::new(iterations).ok_or(VerifyError::MalformedHash)?,
256        salt,
257    })
258}
259
260/// The SCRAM-SHA-256 hash
261struct ScramSha256Hash {
262    /// The number of iterations used for hashing
263    iterations: NonZeroU32,
264    /// The salt used for hashing
265    salt: [u8; 32],
266    /// The server key
267    server_key: [u8; SHA256_OUTPUT_LEN],
268    /// The stored key
269    stored_key: [u8; SHA256_OUTPUT_LEN],
270}
271
272impl Display for ScramSha256Hash {
273    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
274        write!(
275            f,
276            "SCRAM-SHA-256${}:{}${}:{}",
277            self.iterations,
278            BASE64_STANDARD.encode(&self.salt),
279            BASE64_STANDARD.encode(&self.stored_key),
280            BASE64_STANDARD.encode(&self.server_key)
281        )
282    }
283}
284
285fn scram256_hash_inner(hashed_password: PasswordHash) -> ScramSha256Hash {
286    let signing_key = openssl::pkey::PKey::hmac(&hashed_password.hash).unwrap();
287    let mut signer =
288        openssl::sign::Signer::new(openssl::hash::MessageDigest::sha256(), &signing_key).unwrap();
289    signer.update(b"Client Key").unwrap();
290    let client_key = signer.sign_to_vec().unwrap();
291    let stored_key = openssl::sha::sha256(&client_key);
292    let mut signer =
293        openssl::sign::Signer::new(openssl::hash::MessageDigest::sha256(), &signing_key).unwrap();
294    signer.update(b"Server Key").unwrap();
295    let mut server_key: [u8; SHA256_OUTPUT_LEN] = [0; SHA256_OUTPUT_LEN];
296    signer.sign(server_key.as_mut()).unwrap();
297
298    ScramSha256Hash {
299        iterations: hashed_password.iterations,
300        salt: hashed_password.salt,
301        server_key,
302        stored_key,
303    }
304}
305
306fn hash_password_inner(
307    opts: &HashOpts,
308    password: &[u8],
309) -> Result<[u8; SHA256_OUTPUT_LEN], HashError> {
310    let mut salted_password = [0u8; SHA256_OUTPUT_LEN];
311    openssl::pkcs5::pbkdf2_hmac(
312        password,
313        &opts.salt,
314        opts.iterations.get().try_into().unwrap(),
315        openssl::hash::MessageDigest::sha256(),
316        &mut salted_password,
317    )
318    .map_err(HashError::Openssl)?;
319    Ok(salted_password)
320}
321
#[cfg(test)]
mod tests {
    use itertools::Itertools;

    use super::*;

    // A low iteration count, presumably to keep the PBKDF2-based tests fast.
    const DEFAULT_ITERATIONS: NonZeroU32 = NonZeroU32::new(60).expect("Trust me on this");

    #[mz_ore::test]
    #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `OPENSSL_init_ssl` on OS `linux`
    fn test_hash_password() {
        let password = "password".to_string();
        let iterations = NonZeroU32::new(100).expect("Trust me on this");
        let hashed_password =
            hash_password(&password.into(), &iterations).expect("Failed to hash password");
        assert_eq!(hashed_password.iterations, iterations);
        assert_eq!(hashed_password.salt.len(), DEFAULT_SALT_SIZE);
        assert_eq!(hashed_password.hash.len(), SHA256_OUTPUT_LEN);
    }

    #[mz_ore::test]
    #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `OPENSSL_init_ssl` on OS `linux`
    fn test_scram256_hash() {
        let password = "password".into();
        let scram_hash =
            scram256_hash(&password, &DEFAULT_ITERATIONS).expect("Failed to hash password");

        // Round trip: the correct password verifies, a different one does not.
        let res = scram256_verify(&password, &scram_hash);
        assert!(res.is_ok());
        let res = scram256_verify(&"wrong_password".into(), &scram_hash);
        assert!(res.is_err());
    }

    #[mz_ore::test]
    fn test_scram256_parse_opts() {
        let salt = "9bkIQQjQ7f1OwPsXZGC/YfIkbZsOMDXK0cxxvPBaSfM=";
        // The key section is only shape-checked by `scram256_parse_opts`, so
        // placeholder (non-base64) keys are fine here.
        let hashed_password = format!("SCRAM-SHA-256$600000:{}$client-key:server-key", salt);
        let opts = scram256_parse_opts(&hashed_password);

        assert!(opts.is_ok());
        let opts = opts.unwrap();
        assert_eq!(
            opts.iterations,
            NonZeroU32::new(600_000).expect("known valid")
        );
        assert_eq!(opts.salt.len(), DEFAULT_SALT_SIZE);
        let decoded_salt = BASE64_STANDARD.decode(salt).expect("Failed to decode salt");
        assert_eq!(opts.salt, decoded_salt.as_ref());
    }

    #[mz_ore::test]
    #[cfg_attr(miri, ignore)]
    fn test_mock_sasl_challenge() {
        // The same username and nonce must always yield the same mock challenge.
        let username = "alice";
        let mock = "cnonce";
        let opts1 = mock_sasl_challenge(username, mock, &DEFAULT_ITERATIONS);
        let opts2 = mock_sasl_challenge(username, mock, &DEFAULT_ITERATIONS);
        assert_eq!(opts1, opts2);
    }

    #[mz_ore::test]
    #[cfg_attr(miri, ignore)]
    fn test_sasl_verify_success() {
        let password: Password = "password".into();
        let hashed_password = scram256_hash(&password, &DEFAULT_ITERATIONS).expect("hash password");
        let auth_message = "n=user,r=clientnonce,s=somesalt"; // arbitrary auth message

        // Parse stored_key and server_key from the SCRAM hash
        // Format: SCRAM-SHA-256$<iterations>:<salt>$<stored_key>:<server_key>
        let parts: Vec<&str> = hashed_password.split('$').collect();
        assert_eq!(parts.len(), 3);
        let key_parts: Vec<&str> = parts[2].split(':').collect();
        assert_eq!(key_parts.len(), 2);
        let stored_key = BASE64_STANDARD
            .decode(key_parts[0])
            .expect("decode stored key");
        let server_key = BASE64_STANDARD
            .decode(key_parts[1])
            .expect("decode server key");

        // Simulate client generating a proof
        let client_proof: Vec<u8> = {
            // client_key = HMAC(salted_password, "Client Key")
            let opts = scram256_parse_opts(&hashed_password).expect("parse opts");
            let salted_password = hash_password_with_opts(&opts, &password)
                .expect("hash password")
                .hash;
            let signing_key = openssl::pkey::PKey::hmac(&salted_password).expect("signing key");
            let mut signer =
                openssl::sign::Signer::new(openssl::hash::MessageDigest::sha256(), &signing_key)
                    .expect("signer");
            signer.update(b"Client Key").expect("update");
            let client_key = signer.sign_to_vec().expect("client key");
            // client_proof = client_key XOR client_signature
            let client_signature =
                generate_signature(&stored_key, auth_message).expect("client signature");
            client_key
                .iter()
                .zip_eq(client_signature.iter())
                .map(|(c, s)| c ^ s)
                .collect::<Vec<u8>>()
        };

        let client_proof_b64 = BASE64_STANDARD.encode(&client_proof);

        let verifier = sasl_verify(&hashed_password, &client_proof_b64, auth_message)
            .expect("sasl_verify should succeed");

        // Expected verifier: HMAC(server_key, auth_message)
        let expected_verifier = BASE64_STANDARD
            .encode(&generate_signature(&server_key, auth_message).expect("server verifier"));
        assert_eq!(verifier, expected_verifier);
    }

    #[mz_ore::test]
    #[cfg_attr(miri, ignore)]
    fn test_sasl_verify_invalid_proof() {
        let password: Password = "password".into();
        let hashed_password = scram256_hash(&password, &DEFAULT_ITERATIONS).expect("hash password");
        let auth_message = "n=user,r=clientnonce,s=somesalt";
        // Provide a well-formed, correctly sized (32-byte) base64 proof that does
        // not correspond to the stored key, so it must be rejected on content.
        let bad_proof = BASE64_STANDARD.encode([0u8; 32]);
        let res = sasl_verify(&hashed_password, &bad_proof, auth_message);
        assert!(matches!(res, Err(VerifyError::InvalidPassword)));
    }

    #[mz_ore::test]
    fn test_sasl_verify_malformed_hash() {
        let malformed_hash = "NOT-SCRAM$bad"; // clearly malformed (wrong parts count)
        let auth_message = "n=user,r=clientnonce,s=somesalt";
        let bad_proof = BASE64_STANDARD.encode([0u8; 32]);
        let res = sasl_verify(malformed_hash, &bad_proof, auth_message);
        assert!(matches!(res, Err(VerifyError::MalformedHash)));
    }

    #[mz_ore::test]
    #[cfg_attr(miri, ignore)]
    fn test_sasl_verify_truncated_proof_no_panic() {
        // A truncated client proof (not 32 bytes) should return InvalidPassword, not panic
        let password: Password = "password".into();
        let hashed_password = scram256_hash(&password, &DEFAULT_ITERATIONS).expect("hash password");
        let auth_message = "n=user,r=clientnonce,s=somesalt";

        // Truncated proof: 16 bytes instead of the expected 32 (SHA-256 output)
        let truncated_proof = BASE64_STANDARD.encode([0u8; 16]);
        let res = sasl_verify(&hashed_password, &truncated_proof, auth_message);
        assert!(
            matches!(res, Err(VerifyError::InvalidPassword)),
            "truncated proof should return InvalidPassword, not panic"
        );

        // Oversized proof: 64 bytes instead of 32
        let oversized_proof = BASE64_STANDARD.encode([0u8; 64]);
        let res = sasl_verify(&hashed_password, &oversized_proof, auth_message);
        assert!(
            matches!(res, Err(VerifyError::InvalidPassword)),
            "oversized proof should return InvalidPassword, not panic"
        );

        // Empty proof
        let empty_proof = BASE64_STANDARD.encode([0u8; 0]);
        let res = sasl_verify(&hashed_password, &empty_proof, auth_message);
        assert!(
            matches!(res, Err(VerifyError::InvalidPassword)),
            "empty proof should return InvalidPassword, not panic"
        );
    }
}