1use base64::display::Base64Display;
4use base64::engine::general_purpose::STANDARD;
5use base64::Engine;
6use hmac::{Hmac, Mac};
7use rand::{self, Rng};
8use sha2::digest::FixedOutput;
9use sha2::{Digest, Sha256};
10use std::fmt::Write;
11use std::io;
12use std::iter;
13use std::mem;
14use std::str;
15
/// Length, in characters, of the randomly generated client nonce.
const NONCE_LENGTH: usize = 24;

/// The SASL mechanism name for SCRAM-SHA-256.
pub const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
/// The SASL mechanism name for SCRAM-SHA-256-PLUS (the channel-binding variant of SCRAM-SHA-256).
pub const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS";
22
23fn normalize(pass: &[u8]) -> Vec<u8> {
27 let pass = match str::from_utf8(pass) {
28 Ok(pass) => pass,
29 Err(_) => return pass.to_vec(),
30 };
31
32 match stringprep::saslprep(pass) {
33 Ok(pass) => pass.into_owned().into_bytes(),
34 Err(_) => pass.as_bytes().to_vec(),
35 }
36}
37
38pub(crate) fn hi(str: &[u8], salt: &[u8], i: u32) -> [u8; 32] {
39 let mut hmac =
40 Hmac::<Sha256>::new_from_slice(str).expect("HMAC is able to accept all key sizes");
41 hmac.update(salt);
42 hmac.update(&[0, 0, 0, 1]);
43 let mut prev = hmac.finalize().into_bytes();
44
45 let mut hi = prev;
46
47 for _ in 1..i {
48 let mut hmac = Hmac::<Sha256>::new_from_slice(str).expect("already checked above");
49 hmac.update(&prev);
50 prev = hmac.finalize().into_bytes();
51
52 for (hi, prev) in hi.iter_mut().zip(prev) {
53 *hi ^= prev;
54 }
55 }
56
57 hi.into()
58}
59
/// Internal representation of the channel binding choice.
enum ChannelBindingInner {
    Unrequested,
    Unsupported,
    TlsServerEndPoint(Vec<u8>),
}

/// The channel binding configuration used for a SCRAM exchange.
pub struct ChannelBinding(ChannelBindingInner);

impl ChannelBinding {
    /// Channel binding is not used and the GS2 header carries the `y` flag
    /// (RFC 5802: client supports channel binding but believes the server does not).
    pub fn unrequested() -> Self {
        Self(ChannelBindingInner::Unrequested)
    }

    /// Channel binding is not used and the GS2 header carries the `n` flag
    /// (RFC 5802: client does not support channel binding).
    pub fn unsupported() -> Self {
        Self(ChannelBindingInner::Unsupported)
    }

    /// Channel binding of type `tls-server-end-point`.
    ///
    /// `signature` is appended verbatim as the channel binding data. It is
    /// presumably the RFC 5929 `tls-server-end-point` value (a hash of the server
    /// certificate) — confirm at the call site.
    pub fn tls_server_end_point(signature: Vec<u8>) -> Self {
        Self(ChannelBindingInner::TlsServerEndPoint(signature))
    }

    /// The GS2 header that prefixes the client-first message.
    fn gs2_header(&self) -> &'static str {
        match &self.0 {
            ChannelBindingInner::TlsServerEndPoint(_) => "p=tls-server-end-point,,",
            ChannelBindingInner::Unsupported => "n,,",
            ChannelBindingInner::Unrequested => "y,,",
        }
    }

    /// Channel binding data appended after the GS2 header in the `c=` attribute;
    /// empty unless `tls-server-end-point` binding is in use.
    fn cbind_data(&self) -> &[u8] {
        if let ChannelBindingInner::TlsServerEndPoint(data) = &self.0 {
            data.as_slice()
        } else {
            &[]
        }
    }
}
101
/// Progress of the client through the SCRAM exchange.
///
/// Every transition goes through `mem::replace(.., State::Done)`, so any
/// failed step leaves the machine in `Done`.
enum State {
    /// The client-first message has been built; the server-first message is expected next.
    Update {
        /// Client nonce sent in the client-first message.
        nonce: String,
        /// The password after `normalize`.
        password: Vec<u8>,
        /// Needed to build the `c=` attribute of the client-final message.
        channel_binding: ChannelBinding,
    },
    /// The client-final message has been built; the server-final message is expected next.
    Finish {
        /// `Hi(password, salt, i)`, used to derive the server key for verification.
        salted_password: [u8; 32],
        /// The AuthMessage that both sides sign.
        auth_message: String,
    },
    /// The exchange has completed or failed; no further steps are accepted.
    Done,
}
114
/// A client-side SCRAM-SHA-256 authentication state machine.
///
/// Flow, as implemented by the methods on this type: send `message()` (client-first),
/// pass the server's reply to `update()`, send `message()` again (client-final),
/// then pass the server's final reply to `finish()` to verify the server.
pub struct ScramSha256 {
    /// The next message to send: the client-first message, later replaced by the client-final message.
    message: String,
    state: State,
}
134
135impl ScramSha256 {
136 pub fn new(password: &[u8], channel_binding: ChannelBinding) -> ScramSha256 {
138 let mut rng = rand::thread_rng();
140 let nonce = (0..NONCE_LENGTH)
141 .map(|_| {
142 let mut v = rng.gen_range(0x21u8..0x7e);
143 if v == 0x2c {
144 v = 0x7e
145 }
146 v as char
147 })
148 .collect::<String>();
149
150 ScramSha256::new_inner(password, channel_binding, nonce)
151 }
152
153 fn new_inner(password: &[u8], channel_binding: ChannelBinding, nonce: String) -> ScramSha256 {
154 ScramSha256 {
155 message: format!("{}n=,r={}", channel_binding.gs2_header(), nonce),
156 state: State::Update {
157 nonce,
158 password: normalize(password),
159 channel_binding,
160 },
161 }
162 }
163
164 pub fn message(&self) -> &[u8] {
166 if let State::Done = self.state {
167 panic!("invalid SCRAM state");
168 }
169 self.message.as_bytes()
170 }
171
172 pub fn update(&mut self, message: &[u8]) -> io::Result<()> {
176 let (client_nonce, password, channel_binding) =
177 match mem::replace(&mut self.state, State::Done) {
178 State::Update {
179 nonce,
180 password,
181 channel_binding,
182 } => (nonce, password, channel_binding),
183 _ => return Err(io::Error::new(io::ErrorKind::Other, "invalid SCRAM state")),
184 };
185
186 let message =
187 str::from_utf8(message).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
188
189 let parsed = Parser::new(message).server_first_message()?;
190
191 if !parsed.nonce.starts_with(&client_nonce) {
192 return Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid nonce"));
193 }
194
195 let salt = match STANDARD.decode(parsed.salt) {
196 Ok(salt) => salt,
197 Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
198 };
199
200 let salted_password = hi(&password, &salt, parsed.iteration_count);
201
202 let mut hmac = Hmac::<Sha256>::new_from_slice(&salted_password)
203 .expect("HMAC is able to accept all key sizes");
204 hmac.update(b"Client Key");
205 let client_key = hmac.finalize().into_bytes();
206
207 let mut hash = Sha256::default();
208 hash.update(client_key.as_slice());
209 let stored_key = hash.finalize_fixed();
210
211 let mut cbind_input = vec![];
212 cbind_input.extend(channel_binding.gs2_header().as_bytes());
213 cbind_input.extend(channel_binding.cbind_data());
214 let cbind_input = STANDARD.encode(&cbind_input);
215
216 self.message.clear();
217 write!(&mut self.message, "c={},r={}", cbind_input, parsed.nonce).unwrap();
218
219 let auth_message = format!("n=,r={},{},{}", client_nonce, message, self.message);
220
221 let mut hmac = Hmac::<Sha256>::new_from_slice(&stored_key)
222 .expect("HMAC is able to accept all key sizes");
223 hmac.update(auth_message.as_bytes());
224 let client_signature = hmac.finalize().into_bytes();
225
226 let mut client_proof = client_key;
227 for (proof, signature) in client_proof.iter_mut().zip(client_signature) {
228 *proof ^= signature;
229 }
230
231 write!(
232 &mut self.message,
233 ",p={}",
234 Base64Display::new(&client_proof, &STANDARD)
235 )
236 .unwrap();
237
238 self.state = State::Finish {
239 salted_password,
240 auth_message,
241 };
242 Ok(())
243 }
244
245 pub fn finish(&mut self, message: &[u8]) -> io::Result<()> {
250 let (salted_password, auth_message) = match mem::replace(&mut self.state, State::Done) {
251 State::Finish {
252 salted_password,
253 auth_message,
254 } => (salted_password, auth_message),
255 _ => return Err(io::Error::new(io::ErrorKind::Other, "invalid SCRAM state")),
256 };
257
258 let message =
259 str::from_utf8(message).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
260
261 let parsed = Parser::new(message).server_final_message()?;
262
263 let verifier = match parsed {
264 ServerFinalMessage::Error(e) => {
265 return Err(io::Error::new(
266 io::ErrorKind::Other,
267 format!("SCRAM error: {}", e),
268 ));
269 }
270 ServerFinalMessage::Verifier(verifier) => verifier,
271 };
272
273 let verifier = match STANDARD.decode(verifier) {
274 Ok(verifier) => verifier,
275 Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
276 };
277
278 let mut hmac = Hmac::<Sha256>::new_from_slice(&salted_password)
279 .expect("HMAC is able to accept all key sizes");
280 hmac.update(b"Server Key");
281 let server_key = hmac.finalize().into_bytes();
282
283 let mut hmac = Hmac::<Sha256>::new_from_slice(&server_key)
284 .expect("HMAC is able to accept all key sizes");
285 hmac.update(auth_message.as_bytes());
286 hmac.verify_slice(&verifier)
287 .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "SCRAM verification error"))
288 }
289}
290
/// A small recursive-descent parser for server-side SCRAM messages.
///
/// Each grammar method consumes characters from `it` and returns matched text
/// as a slice borrowed from `s`, following the attribute syntax of RFC 5802 §7.
struct Parser<'a> {
    /// The full input message; matched spans are sliced out of it.
    s: &'a str,
    /// Cursor over `s` yielding `(byte offset, char)` pairs.
    it: iter::Peekable<str::CharIndices<'a>>,
}

impl<'a> Parser<'a> {
    /// Creates a parser positioned at the start of `s`.
    fn new(s: &'a str) -> Parser<'a> {
        Parser {
            s,
            it: s.char_indices().peekable(),
        }
    }

    /// Consumes the next character, which must be `target`.
    ///
    /// # Errors
    ///
    /// Returns `InvalidInput` if a different character is found, or
    /// `UnexpectedEof` if the input is exhausted.
    fn eat(&mut self, target: char) -> io::Result<()> {
        match self.it.next() {
            Some((_, c)) if c == target => Ok(()),
            Some((i, c)) => {
                let m = format!(
                    "unexpected character at byte {}: expected `{}` but got `{}`",
                    i, target, c
                );
                Err(io::Error::new(io::ErrorKind::InvalidInput, m))
            }
            None => Err(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                "unexpected EOF",
            )),
        }
    }

    /// Consumes the longest run of characters satisfying `f` and returns it.
    ///
    /// Never fails. An empty slice is returned when the next character does not
    /// match or the input is exhausted.
    fn take_while<F>(&mut self, f: F) -> io::Result<&'a str>
    where
        F: Fn(char) -> bool,
    {
        let start = match self.it.peek() {
            Some(&(i, _)) => i,
            None => return Ok(""),
        };

        loop {
            match self.it.peek() {
                Some(&(_, c)) if f(c) => {
                    self.it.next();
                }
                Some(&(i, _)) => return Ok(&self.s[start..i]),
                None => return Ok(&self.s[start..]),
            }
        }
    }

    /// `printable`: ASCII 0x21..=0x7E excluding `,`.
    fn printable(&mut self) -> io::Result<&'a str> {
        self.take_while(|c| matches!(c, '\x21'..='\x2b' | '\x2d'..='\x7e'))
    }

    /// `nonce = "r=" c-nonce [s-nonce]`
    fn nonce(&mut self) -> io::Result<&'a str> {
        self.eat('r')?;
        self.eat('=')?;
        self.printable()
    }

    /// Characters of the standard base64 alphabet, including `=` padding.
    fn base64(&mut self) -> io::Result<&'a str> {
        self.take_while(|c| matches!(c, 'a'..='z' | 'A'..='Z' | '0'..='9' | '/' | '+' | '='))
    }

    /// `salt = "s=" base64`
    fn salt(&mut self) -> io::Result<&'a str> {
        self.eat('s')?;
        self.eat('=')?;
        self.base64()
    }

    /// `posit-number`: a strictly positive decimal integer that fits in a `u32`.
    ///
    /// # Errors
    ///
    /// Returns `InvalidInput` if there are no digits, the value overflows `u32`,
    /// or the value is `0`.
    fn posit_number(&mut self) -> io::Result<u32> {
        let digits = self.take_while(|c| c.is_ascii_digit())?;
        let n = digits
            .parse::<u32>()
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
        // RFC 5802 defines posit-number as %x31-39 *DIGIT, so zero is not allowed.
        // Accepting it would silently run `hi` with a single round.
        if n == 0 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "expected a positive number but got 0",
            ));
        }
        Ok(n)
    }

    /// `iteration-count = "i=" posit-number`
    fn iteration_count(&mut self) -> io::Result<u32> {
        self.eat('i')?;
        self.eat('=')?;
        self.posit_number()
    }

    /// Succeeds only if all input has been consumed.
    fn eof(&mut self) -> io::Result<()> {
        match self.it.peek() {
            Some(&(i, _)) => Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!("unexpected trailing data at byte {}", i),
            )),
            None => Ok(()),
        }
    }

    /// Parses `nonce "," salt "," iteration-count` with no trailing data.
    fn server_first_message(&mut self) -> io::Result<ServerFirstMessage<'a>> {
        let nonce = self.nonce()?;
        self.eat(',')?;
        let salt = self.salt()?;
        self.eat(',')?;
        let iteration_count = self.iteration_count()?;
        self.eof()?;

        Ok(ServerFirstMessage {
            nonce,
            salt,
            iteration_count,
        })
    }

    /// `value`: a run of characters other than `,`, `=` and NUL.
    fn value(&mut self) -> io::Result<&'a str> {
        // The predicate selects the characters to KEEP, so the excluded set is negated.
        self.take_while(|c| !matches!(c, '\0' | '=' | ','))
    }

    /// Parses `"e=" server-error-value` if the next character is `e`;
    /// otherwise consumes nothing and returns `None`.
    fn server_error(&mut self) -> io::Result<Option<&'a str>> {
        match self.it.peek() {
            Some(&(_, 'e')) => {}
            _ => return Ok(None),
        }

        self.eat('e')?;
        self.eat('=')?;
        self.value().map(Some)
    }

    /// `verifier = "v=" base64`
    fn verifier(&mut self) -> io::Result<&'a str> {
        self.eat('v')?;
        self.eat('=')?;
        self.base64()
    }

    /// Parses either a server error or a verifier, with no trailing data.
    fn server_final_message(&mut self) -> io::Result<ServerFinalMessage<'a>> {
        let message = match self.server_error()? {
            Some(error) => ServerFinalMessage::Error(error),
            None => ServerFinalMessage::Verifier(self.verifier()?),
        };
        self.eof()?;
        Ok(message)
    }
}

/// Fields of a parsed server-first message, borrowed from the input.
struct ServerFirstMessage<'a> {
    /// Combined client + server nonce.
    nonce: &'a str,
    /// Base64-encoded salt (not yet decoded).
    salt: &'a str,
    /// Iteration count for `Hi`; always at least 1.
    iteration_count: u32,
}

/// A parsed server-final message, borrowed from the input.
enum ServerFinalMessage<'a> {
    /// The `e=` server error value.
    Error(&'a str),
    /// The base64-encoded `v=` server signature (not yet decoded).
    Verifier(&'a str),
}
439
#[cfg(test)]
mod test {
    use super::*;

    /// Checks that every attribute of a server-first message is extracted.
    #[test]
    fn parse_server_first_message() {
        let parsed = Parser::new(
            "r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,s=QSXCR+Q6sek8bf92,i=4096",
        )
        .server_first_message()
        .unwrap();

        assert_eq!(parsed.nonce, "fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j");
        assert_eq!(parsed.salt, "QSXCR+Q6sek8bf92");
        assert_eq!(parsed.iteration_count, 4096);
    }

    /// Drives a full exchange with a fixed client nonce and checks every client
    /// message, plus verification of the server signature.
    #[test]
    fn exchange() {
        let password = "foobar";
        let client_nonce = "9IZ2O01zb9IgiIZ1WJ/zgpJB";

        let expected_client_first = "n,,n=,r=9IZ2O01zb9IgiIZ1WJ/zgpJB";
        let server_first = concat!(
            "r=9IZ2O01zb9IgiIZ1WJ/zgpJBjx/oIRLs02gGSHcw1KEty3eY,",
            "s=fs3IXBy7U7+IvVjZ,",
            "i=4096",
        );
        let expected_client_final = concat!(
            "c=biws,",
            "r=9IZ2O01zb9IgiIZ1WJ/zgpJBjx/oIRLs02gGSHcw1KEty3eY,",
            "p=AmNKosjJzS31NTlQYNs5BTeQjdHdk7lOflDo5re2an8=",
        );
        let server_final = "v=U+ppxD5XUKtradnv8e2MkeupiA8FU87Sg8CXzXHDAzw=";

        let mut scram = ScramSha256::new_inner(
            password.as_bytes(),
            ChannelBinding::unsupported(),
            String::from(client_nonce),
        );

        let first = str::from_utf8(scram.message()).unwrap();
        assert_eq!(first, expected_client_first);

        scram.update(server_first.as_bytes()).unwrap();
        let last = str::from_utf8(scram.message()).unwrap();
        assert_eq!(last, expected_client_final);

        scram.finish(server_final.as_bytes()).unwrap();
    }
}