1#![allow(rustdoc::invalid_html_tags)]
12
13use std::fmt::Display;
14use std::num::NonZeroU32;
15
16use base64::prelude::*;
17
18use crate::password::Password;
19
/// Number of PBKDF2 iterations applied when hashing a new password.
///
/// NOTE(review): 600,000 presumably follows current guidance for
/// PBKDF2-HMAC-SHA256 work factors; confirm before changing.
const DEFAULT_ITERATIONS: NonZeroU32 = match NonZeroU32::new(600_000) {
    Some(iterations) => iterations,
    None => panic!("default iteration count must be non-zero"),
};

/// Length, in bytes, of the random salt generated for each new password hash.
const DEFAULT_SALT_SIZE: usize = 32;

/// Length, in bytes, of a SHA-256 digest (256 bits).
const SHA256_OUTPUT_LEN: usize = 256 / 8;
28
29pub struct HashOpts {
31 pub iterations: NonZeroU32,
33 pub salt: [u8; DEFAULT_SALT_SIZE],
37}
38
39pub struct PasswordHash {
40 pub salt: [u8; DEFAULT_SALT_SIZE],
42 pub iterations: NonZeroU32,
44 pub hash: [u8; SHA256_OUTPUT_LEN],
47}
48
49#[derive(Debug)]
50pub enum VerifyError {
51 MalformedHash,
52 InvalidPassword,
53 Hash(HashError),
54}
55
56#[derive(Debug)]
57pub enum HashError {
58 Openssl(openssl::error::ErrorStack),
59}
60
61pub fn hash_password(password: &Password) -> Result<PasswordHash, HashError> {
64 let mut salt = [0u8; DEFAULT_SALT_SIZE];
65 openssl::rand::rand_bytes(&mut salt).map_err(HashError::Openssl)?;
66
67 let hash = hash_password_inner(
68 &HashOpts {
69 iterations: DEFAULT_ITERATIONS,
70 salt,
71 },
72 password.to_string().as_bytes(),
73 )?;
74
75 Ok(PasswordHash {
76 salt,
77 iterations: DEFAULT_ITERATIONS,
78 hash,
79 })
80}
81
82pub fn hash_password_with_opts(
85 opts: &HashOpts,
86 password: &Password,
87) -> Result<PasswordHash, HashError> {
88 let hash = hash_password_inner(opts, password.to_string().as_bytes())?;
89
90 Ok(PasswordHash {
91 salt: opts.salt,
92 iterations: opts.iterations,
93 hash,
94 })
95}
96
97pub fn scram256_hash(password: &Password) -> Result<String, HashError> {
101 let hashed_password = hash_password(password)?;
102 Ok(scram256_hash_inner(hashed_password).to_string())
103}
104
105pub fn scram256_verify(password: &Password, hashed_password: &str) -> Result<(), VerifyError> {
107 let opts = scram256_parse_opts(hashed_password)?;
108 let hashed = hash_password_with_opts(&opts, password).map_err(VerifyError::Hash)?;
109 let scram = scram256_hash_inner(hashed);
110 if *hashed_password == scram.to_string() {
111 Ok(())
112 } else {
113 Err(VerifyError::InvalidPassword)
114 }
115}
116
117fn scram256_parse_opts(hashed_password: &str) -> Result<HashOpts, VerifyError> {
119 let parts: Vec<&str> = hashed_password.split('$').collect();
120 if parts.len() != 3 {
121 return Err(VerifyError::MalformedHash);
122 }
123 let scheme = parts[0];
124 if scheme != "SCRAM-SHA-256" {
125 return Err(VerifyError::MalformedHash);
126 }
127 let auth_info = parts[1].split(':').collect::<Vec<&str>>();
128 if auth_info.len() != 2 {
129 return Err(VerifyError::MalformedHash);
130 }
131 let auth_value = parts[2].split(':').collect::<Vec<&str>>();
132 if auth_value.len() != 2 {
133 return Err(VerifyError::MalformedHash);
134 }
135
136 let iterations = auth_info[0]
137 .parse::<u32>()
138 .map_err(|_| VerifyError::MalformedHash)?;
139
140 let salt = BASE64_STANDARD
141 .decode(auth_info[1])
142 .map_err(|_| VerifyError::MalformedHash)?;
143
144 let salt = salt.try_into().map_err(|_| VerifyError::MalformedHash)?;
145
146 Ok(HashOpts {
147 iterations: NonZeroU32::new(iterations).ok_or(VerifyError::MalformedHash)?,
148 salt,
149 })
150}
151
152struct ScramSha256Hash {
154 iterations: NonZeroU32,
156 salt: [u8; 32],
158 server_key: [u8; SHA256_OUTPUT_LEN],
160 client_key: [u8; SHA256_OUTPUT_LEN],
162}
163
164impl Display for ScramSha256Hash {
165 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
166 write!(
167 f,
168 "SCRAM-SHA-256${}:{}${}:{}",
169 self.iterations,
170 BASE64_STANDARD.encode(&self.salt),
171 BASE64_STANDARD.encode(&self.client_key),
172 BASE64_STANDARD.encode(&self.server_key)
173 )
174 }
175}
176
177fn scram256_hash_inner(hashed_password: PasswordHash) -> ScramSha256Hash {
178 let signing_key = openssl::pkey::PKey::hmac(&hashed_password.hash).unwrap();
179 let mut signer =
180 openssl::sign::Signer::new(openssl::hash::MessageDigest::sha256(), &signing_key).unwrap();
181 signer.update(b"Client Key").unwrap();
182 let client_key = signer.sign_to_vec().unwrap();
183 let mut signer =
184 openssl::sign::Signer::new(openssl::hash::MessageDigest::sha256(), &signing_key).unwrap();
185 signer.update(b"Server Key").unwrap();
186 let server_key = signer.sign_to_vec().unwrap();
187
188 ScramSha256Hash {
189 iterations: hashed_password.iterations,
190 salt: hashed_password.salt,
191 server_key: server_key.try_into().unwrap(),
192 client_key: client_key.try_into().unwrap(),
193 }
194}
195
196fn hash_password_inner(
197 opts: &HashOpts,
198 password: &[u8],
199) -> Result<[u8; SHA256_OUTPUT_LEN], HashError> {
200 let mut salted_password = [0u8; SHA256_OUTPUT_LEN];
201 openssl::pkcs5::pbkdf2_hmac(
202 password,
203 &opts.salt,
204 opts.iterations.get().try_into().unwrap(),
205 openssl::hash::MessageDigest::sha256(),
206 &mut salted_password,
207 )
208 .map_err(HashError::Openssl)?;
209 Ok(salted_password)
210}
211
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly hashed password reports the default parameters and sizes.
    #[mz_ore::test]
    #[cfg_attr(miri, ignore)] // NOTE(review): presumably skipped because Miri cannot run the OpenSSL FFI calls — confirm.
    fn test_hash_password() {
        let plaintext = "password".to_string();
        let PasswordHash {
            salt,
            iterations,
            hash,
        } = hash_password(&plaintext.into()).expect("Failed to hash password");

        assert_eq!(iterations, DEFAULT_ITERATIONS);
        assert_eq!(salt.len(), DEFAULT_SALT_SIZE);
        assert_eq!(hash.len(), SHA256_OUTPUT_LEN);
    }

    /// A generated verifier accepts the original password and rejects another.
    #[mz_ore::test]
    #[cfg_attr(miri, ignore)] // NOTE(review): presumably skipped because Miri cannot run the OpenSSL FFI calls — confirm.
    fn test_scram256_hash() {
        let password = "password".into();
        let scram_hash = scram256_hash(&password).expect("Failed to hash password");

        assert!(scram256_verify(&password, &scram_hash).is_ok());
        assert!(scram256_verify(&"wrong_password".into(), &scram_hash).is_err());
    }

    /// Parsing a well-formed verifier recovers the iteration count and salt.
    #[mz_ore::test]
    fn test_scram256_parse_opts() {
        let salt = "9bkIQQjQ7f1OwPsXZGC/YfIkbZsOMDXK0cxxvPBaSfM=";
        let encoded = format!("SCRAM-SHA-256$600000:{}$client-key:server-key", salt);

        let parsed = scram256_parse_opts(&encoded);
        assert!(parsed.is_ok());
        let opts = parsed.unwrap();

        assert_eq!(opts.iterations, DEFAULT_ITERATIONS);
        assert_eq!(opts.salt.len(), DEFAULT_SALT_SIZE);
        let expected_salt = BASE64_STANDARD.decode(salt).expect("Failed to decode salt");
        assert_eq!(opts.salt.as_slice(), expected_salt.as_slice());
    }
}