1#![allow(rustdoc::invalid_html_tags)]
12
13use std::fmt::Display;
14use std::num::NonZeroU32;
15
16use base64::prelude::*;
17use itertools::Itertools;
18
19use crate::password::Password;
20
/// PBKDF2 iteration count applied to every newly hashed password.
const DEFAULT_ITERATIONS: NonZeroU32 = NonZeroU32::new(600_000).unwrap();

/// Size, in bytes, of the random salt generated for each new password hash.
const DEFAULT_SALT_SIZE: usize = 32;

/// Size, in bytes, of a SHA-256 digest (and therefore of an HMAC-SHA256 tag).
const SHA256_OUTPUT_LEN: usize = 32;

/// Inputs to the PBKDF2-HMAC-SHA256 derivation of a salted password.
#[derive(Debug, PartialEq)]
pub struct HashOpts {
    /// How many PBKDF2 rounds to run.
    pub iterations: NonZeroU32,
    /// Salt fed to PBKDF2 alongside the password.
    pub salt: [u8; DEFAULT_SALT_SIZE],
}

/// Result of salting and hashing a password, together with the parameters that produced it.
pub struct PasswordHash {
    /// Salt that was fed to PBKDF2.
    pub salt: [u8; DEFAULT_SALT_SIZE],
    /// Number of PBKDF2 rounds that were run.
    pub iterations: NonZeroU32,
    /// Raw PBKDF2-HMAC-SHA256 output (the SCRAM "SaltedPassword").
    pub hash: [u8; SHA256_OUTPUT_LEN],
}
50
51#[derive(Debug)]
52pub enum VerifyError {
53 MalformedHash,
54 InvalidPassword,
55 Hash(HashError),
56}
57
58#[derive(Debug)]
59pub enum HashError {
60 Openssl(openssl::error::ErrorStack),
61}
62
63impl Display for HashError {
64 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65 match self {
66 HashError::Openssl(e) => write!(f, "OpenSSL error: {}", e),
67 }
68 }
69}
70
71pub fn hash_password(password: &Password) -> Result<PasswordHash, HashError> {
74 let mut salt = [0u8; DEFAULT_SALT_SIZE];
75 openssl::rand::rand_bytes(&mut salt).map_err(HashError::Openssl)?;
76
77 let hash = hash_password_inner(
78 &HashOpts {
79 iterations: DEFAULT_ITERATIONS,
80 salt,
81 },
82 password.to_string().as_bytes(),
83 )?;
84
85 Ok(PasswordHash {
86 salt,
87 iterations: DEFAULT_ITERATIONS,
88 hash,
89 })
90}
91
92pub fn generate_nonce(client_nonce: &str) -> Result<String, HashError> {
93 let mut nonce = [0u8; 24];
94 openssl::rand::rand_bytes(&mut nonce).map_err(HashError::Openssl)?;
95 let nonce = BASE64_STANDARD.encode(&nonce);
96 let new_nonce = format!("{}{}", client_nonce, nonce);
97 Ok(new_nonce)
98}
99
100pub fn hash_password_with_opts(
103 opts: &HashOpts,
104 password: &Password,
105) -> Result<PasswordHash, HashError> {
106 let hash = hash_password_inner(opts, password.to_string().as_bytes())?;
107
108 Ok(PasswordHash {
109 salt: opts.salt,
110 iterations: opts.iterations,
111 hash,
112 })
113}
114
115pub fn scram256_hash(password: &Password) -> Result<String, HashError> {
119 let hashed_password = hash_password(password)?;
120 Ok(scram256_hash_inner(hashed_password).to_string())
121}
122
123fn constant_time_compare(a: &[u8], b: &[u8]) -> bool {
124 if a.len() != b.len() {
125 return false;
126 }
127 openssl::memcmp::eq(a, b)
128}
129
130pub fn scram256_verify(password: &Password, hashed_password: &str) -> Result<(), VerifyError> {
132 let opts = scram256_parse_opts(hashed_password)?;
133 let hashed = hash_password_with_opts(&opts, password).map_err(VerifyError::Hash)?;
134 let scram = scram256_hash_inner(hashed);
135 if constant_time_compare(hashed_password.as_bytes(), scram.to_string().as_bytes()) {
136 Ok(())
137 } else {
138 Err(VerifyError::InvalidPassword)
139 }
140}
141
142pub fn sasl_verify(
143 hashed_password: &str,
144 proof: &str,
145 auth_message: &str,
146) -> Result<String, VerifyError> {
147 let parts: Vec<&str> = hashed_password.split('$').collect();
149 if parts.len() != 3 {
150 return Err(VerifyError::MalformedHash);
151 }
152 let auth_info = parts[1].split(':').collect::<Vec<&str>>();
153 if auth_info.len() != 2 {
154 return Err(VerifyError::MalformedHash);
155 }
156 let auth_value = parts[2].split(':').collect::<Vec<&str>>();
157 if auth_value.len() != 2 {
158 return Err(VerifyError::MalformedHash);
159 }
160
161 let stored_key = BASE64_STANDARD
162 .decode(auth_value[0])
163 .map_err(|_| VerifyError::MalformedHash)?;
164 let server_key = BASE64_STANDARD
165 .decode(auth_value[1])
166 .map_err(|_| VerifyError::MalformedHash)?;
167
168 let client_signature = generate_signature(&stored_key, auth_message)?;
170
171 let provided_client_proof = BASE64_STANDARD
173 .decode(proof)
174 .map_err(|_| VerifyError::InvalidPassword)?;
175
176 let client_key: Vec<u8> = provided_client_proof
178 .iter()
179 .zip_eq(client_signature.iter())
180 .map(|(p, s)| p ^ s)
181 .collect();
182
183 if !constant_time_compare(&openssl::sha::sha256(&client_key), &stored_key) {
184 return Err(VerifyError::InvalidPassword);
185 }
186
187 let verifier = generate_signature(&server_key, auth_message)?;
189 Ok(BASE64_STANDARD.encode(&verifier))
190}
191
192fn generate_signature(key: &[u8], message: &str) -> Result<Vec<u8>, VerifyError> {
193 let signing_key =
194 openssl::pkey::PKey::hmac(key).map_err(|e| VerifyError::Hash(HashError::Openssl(e)))?;
195 let mut signer =
196 openssl::sign::Signer::new(openssl::hash::MessageDigest::sha256(), &signing_key)
197 .map_err(|e| VerifyError::Hash(HashError::Openssl(e)))?;
198 signer
199 .update(message.as_bytes())
200 .map_err(|e| VerifyError::Hash(HashError::Openssl(e)))?;
201 let signature = signer
202 .sign_to_vec()
203 .map_err(|e| VerifyError::Hash(HashError::Openssl(e)))?;
204 Ok(signature)
205}
206
207pub fn mock_sasl_challenge(username: &str, mock_nonce: &str) -> HashOpts {
211 let mut buf = Vec::with_capacity(username.len() + mock_nonce.len());
212 buf.extend_from_slice(username.as_bytes());
213 buf.extend_from_slice(mock_nonce.as_bytes());
214 let digest = openssl::sha::sha256(&buf);
215
216 HashOpts {
217 iterations: DEFAULT_ITERATIONS,
218 salt: digest,
219 }
220}
221
222pub fn scram256_parse_opts(hashed_password: &str) -> Result<HashOpts, VerifyError> {
224 let parts: Vec<&str> = hashed_password.split('$').collect();
225 if parts.len() != 3 {
226 return Err(VerifyError::MalformedHash);
227 }
228 let scheme = parts[0];
229 if scheme != "SCRAM-SHA-256" {
230 return Err(VerifyError::MalformedHash);
231 }
232 let auth_info = parts[1].split(':').collect::<Vec<&str>>();
233 if auth_info.len() != 2 {
234 return Err(VerifyError::MalformedHash);
235 }
236 let auth_value = parts[2].split(':').collect::<Vec<&str>>();
237 if auth_value.len() != 2 {
238 return Err(VerifyError::MalformedHash);
239 }
240
241 let iterations = auth_info[0]
242 .parse::<u32>()
243 .map_err(|_| VerifyError::MalformedHash)?;
244
245 let salt = BASE64_STANDARD
246 .decode(auth_info[1])
247 .map_err(|_| VerifyError::MalformedHash)?;
248
249 let salt = salt.try_into().map_err(|_| VerifyError::MalformedHash)?;
250
251 Ok(HashOpts {
252 iterations: NonZeroU32::new(iterations).ok_or(VerifyError::MalformedHash)?,
253 salt,
254 })
255}
256
257struct ScramSha256Hash {
259 iterations: NonZeroU32,
261 salt: [u8; 32],
263 server_key: [u8; SHA256_OUTPUT_LEN],
265 stored_key: [u8; SHA256_OUTPUT_LEN],
267}
268
269impl Display for ScramSha256Hash {
270 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
271 write!(
272 f,
273 "SCRAM-SHA-256${}:{}${}:{}",
274 self.iterations,
275 BASE64_STANDARD.encode(&self.salt),
276 BASE64_STANDARD.encode(&self.stored_key),
277 BASE64_STANDARD.encode(&self.server_key)
278 )
279 }
280}
281
282fn scram256_hash_inner(hashed_password: PasswordHash) -> ScramSha256Hash {
283 let signing_key = openssl::pkey::PKey::hmac(&hashed_password.hash).unwrap();
284 let mut signer =
285 openssl::sign::Signer::new(openssl::hash::MessageDigest::sha256(), &signing_key).unwrap();
286 signer.update(b"Client Key").unwrap();
287 let client_key = signer.sign_to_vec().unwrap();
288 let stored_key = openssl::sha::sha256(&client_key);
289 let mut signer =
290 openssl::sign::Signer::new(openssl::hash::MessageDigest::sha256(), &signing_key).unwrap();
291 signer.update(b"Server Key").unwrap();
292 let mut server_key: [u8; SHA256_OUTPUT_LEN] = [0; SHA256_OUTPUT_LEN];
293 signer.sign(server_key.as_mut()).unwrap();
294
295 ScramSha256Hash {
296 iterations: hashed_password.iterations,
297 salt: hashed_password.salt,
298 server_key,
299 stored_key,
300 }
301}
302
303fn hash_password_inner(
304 opts: &HashOpts,
305 password: &[u8],
306) -> Result<[u8; SHA256_OUTPUT_LEN], HashError> {
307 let mut salted_password = [0u8; SHA256_OUTPUT_LEN];
308 openssl::pkcs5::pbkdf2_hmac(
309 password,
310 &opts.salt,
311 opts.iterations.get().try_into().unwrap(),
312 openssl::hash::MessageDigest::sha256(),
313 &mut salted_password,
314 )
315 .map_err(HashError::Openssl)?;
316 Ok(salted_password)
317}
318
#[cfg(test)]
mod tests {
    use itertools::Itertools;

    use super::*;

    /// Hashing with the defaults records the default iteration count and produces a salt
    /// and a hash of the expected sizes.
    #[mz_ore::test]
    #[cfg_attr(miri, ignore)]
    fn test_hash_password() {
        let password = "password".to_string();
        let hashed_password = hash_password(&password.into()).expect("Failed to hash password");
        assert_eq!(hashed_password.iterations, DEFAULT_ITERATIONS);
        assert_eq!(hashed_password.salt.len(), DEFAULT_SALT_SIZE);
        assert_eq!(hashed_password.hash.len(), SHA256_OUTPUT_LEN);
    }

    /// A freshly produced SCRAM hash verifies against its own password and rejects a
    /// different one.
    #[mz_ore::test]
    #[cfg_attr(miri, ignore)]
    fn test_scram256_hash() {
        let password = "password".into();
        let scram_hash = scram256_hash(&password).expect("Failed to hash password");

        let res = scram256_verify(&password, &scram_hash);
        assert!(res.is_ok());
        let res = scram256_verify(&"wrong_password".into(), &scram_hash);
        assert!(res.is_err());
    }

    /// Parsing extracts the iteration count and the decoded salt.
    // The key fields are placeholder text (not valid standard base64), which shows the
    // parser only checks their shape and does not decode them.
    #[mz_ore::test]
    fn test_scram256_parse_opts() {
        let salt = "9bkIQQjQ7f1OwPsXZGC/YfIkbZsOMDXK0cxxvPBaSfM=";
        let hashed_password = format!("SCRAM-SHA-256$600000:{}$client-key:server-key", salt);
        let opts = scram256_parse_opts(&hashed_password);

        assert!(opts.is_ok());
        let opts = opts.unwrap();
        assert_eq!(opts.iterations, DEFAULT_ITERATIONS);
        assert_eq!(opts.salt.len(), DEFAULT_SALT_SIZE);
        let decoded_salt = BASE64_STANDARD.decode(salt).expect("Failed to decode salt");
        assert_eq!(opts.salt, decoded_salt.as_ref());
    }

    /// The mock challenge is deterministic for a given username and nonce.
    #[mz_ore::test]
    #[cfg_attr(miri, ignore)]
    fn test_mock_sasl_challenge() {
        let username = "alice";
        let mock = "cnonce";
        let opts1 = mock_sasl_challenge(username, mock);
        let opts2 = mock_sasl_challenge(username, mock);
        assert_eq!(opts1, opts2);
    }

    /// End-to-end check of `sasl_verify`. It builds the client proof the way a SCRAM client
    /// would, then expects the returned verifier to equal `HMAC(ServerKey, AuthMessage)`.
    #[mz_ore::test]
    #[cfg_attr(miri, ignore)]
    fn test_sasl_verify_success() {
        let password: Password = "password".into();
        let hashed_password = scram256_hash(&password).expect("hash password");
        let auth_message = "n=user,r=clientnonce,s=somesalt";
        // Pull the base64 stored key and server key out of the final `$` section.
        let parts: Vec<&str> = hashed_password.split('$').collect();
        assert_eq!(parts.len(), 3);
        let key_parts: Vec<&str> = parts[2].split(':').collect();
        assert_eq!(key_parts.len(), 2);
        let stored_key = BASE64_STANDARD
            .decode(key_parts[0])
            .expect("decode stored key");
        let server_key = BASE64_STANDARD
            .decode(key_parts[1])
            .expect("decode server key");

        let client_proof: Vec<u8> = {
            // Re-derive SaltedPassword from the stored salt and iteration count.
            let opts = scram256_parse_opts(&hashed_password).expect("parse opts");
            let salted_password = hash_password_with_opts(&opts, &password)
                .expect("hash password")
                .hash;
            // ClientKey = HMAC(SaltedPassword, "Client Key")
            let signing_key = openssl::pkey::PKey::hmac(&salted_password).expect("signing key");
            let mut signer =
                openssl::sign::Signer::new(openssl::hash::MessageDigest::sha256(), &signing_key)
                    .expect("signer");
            signer.update(b"Client Key").expect("update");
            let client_key = signer.sign_to_vec().expect("client key");
            // ClientSignature = HMAC(StoredKey, AuthMessage); ClientProof = ClientKey XOR ClientSignature
            let client_signature =
                generate_signature(&stored_key, auth_message).expect("client signature");
            client_key
                .iter()
                .zip_eq(client_signature.iter())
                .map(|(c, s)| c ^ s)
                .collect::<Vec<u8>>()
        };

        let client_proof_b64 = BASE64_STANDARD.encode(&client_proof);

        let verifier = sasl_verify(&hashed_password, &client_proof_b64, auth_message)
            .expect("sasl_verify should succeed");

        // The verifier must be HMAC(ServerKey, AuthMessage), base64-encoded.
        let expected_verifier = BASE64_STANDARD
            .encode(&generate_signature(&server_key, auth_message).expect("server verifier"));
        assert_eq!(verifier, expected_verifier);
    }

    /// An all-zero proof of the correct length (32 bytes) is rejected as an invalid password.
    #[mz_ore::test]
    #[cfg_attr(miri, ignore)]
    fn test_sasl_verify_invalid_proof() {
        let password: Password = "password".into();
        let hashed_password = scram256_hash(&password).expect("hash password");
        let auth_message = "n=user,r=clientnonce,s=somesalt";
        let bad_proof = BASE64_STANDARD.encode([0u8; 32]);
        let res = sasl_verify(&hashed_password, &bad_proof, auth_message);
        assert!(matches!(res, Err(VerifyError::InvalidPassword)));
    }

    /// A stored hash without three `$`-separated sections is rejected as malformed.
    #[mz_ore::test]
    fn test_sasl_verify_malformed_hash() {
        let malformed_hash = "NOT-SCRAM$bad";
        let auth_message = "n=user,r=clientnonce,s=somesalt";
        let bad_proof = BASE64_STANDARD.encode([0u8; 32]);
        let res = sasl_verify(malformed_hash, &bad_proof, auth_message);
        assert!(matches!(res, Err(VerifyError::MalformedHash)));
    }
}