postgres_protocol/authentication/
sasl.rs

1//! SASL-based authentication support.
2
3use base64::display::Base64Display;
4use base64::engine::general_purpose::STANDARD;
5use base64::Engine;
6use hmac::{Hmac, Mac};
7use rand::{self, Rng};
8use sha2::digest::FixedOutput;
9use sha2::{Digest, Sha256};
10use std::fmt::Write;
11use std::io;
12use std::iter;
13use std::mem;
14use std::str;
15
/// Number of random characters in each client-generated nonce.
const NONCE_LENGTH: usize = 24;

/// Mechanism name for SCRAM-SHA-256 SASL authentication.
pub const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
/// Mechanism name for SCRAM-SHA-256-PLUS (channel-bound) SASL authentication.
pub const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS";
22
23// since postgres passwords are not required to exclude saslprep-prohibited
24// characters or even be valid UTF8, we run saslprep if possible and otherwise
25// return the raw password.
26fn normalize(pass: &[u8]) -> Vec<u8> {
27    let pass = match str::from_utf8(pass) {
28        Ok(pass) => pass,
29        Err(_) => return pass.to_vec(),
30    };
31
32    match stringprep::saslprep(pass) {
33        Ok(pass) => pass.into_owned().into_bytes(),
34        Err(_) => pass.as_bytes().to_vec(),
35    }
36}
37
38pub(crate) fn hi(str: &[u8], salt: &[u8], i: u32) -> [u8; 32] {
39    let mut hmac =
40        Hmac::<Sha256>::new_from_slice(str).expect("HMAC is able to accept all key sizes");
41    hmac.update(salt);
42    hmac.update(&[0, 0, 0, 1]);
43    let mut prev = hmac.finalize().into_bytes();
44
45    let mut hi = prev;
46
47    for _ in 1..i {
48        let mut hmac = Hmac::<Sha256>::new_from_slice(str).expect("already checked above");
49        hmac.update(&prev);
50        prev = hmac.finalize().into_bytes();
51
52        for (hi, prev) in hi.iter_mut().zip(prev) {
53            *hi ^= prev;
54        }
55    }
56
57    hi.into()
58}
59
/// Private representation of the channel-binding choice; callers use the constructors on
/// `ChannelBinding` instead.
enum ChannelBindingInner {
    Unrequested,
    Unsupported,
    TlsServerEndPoint(Vec<u8>),
}

/// The channel binding configuration for a SCRAM authentication exchange.
pub struct ChannelBinding(ChannelBindingInner);

impl ChannelBinding {
    /// The server did not request channel binding.
    pub fn unrequested() -> ChannelBinding {
        ChannelBinding(ChannelBindingInner::Unrequested)
    }

    /// The server requested channel binding but the client is unable to provide it.
    pub fn unsupported() -> ChannelBinding {
        ChannelBinding(ChannelBindingInner::Unsupported)
    }

    /// The server requested channel binding and the client will use the `tls-server-end-point`
    /// method.
    ///
    /// `signature` is the binding data. It is sent after the GS2 header, base64-encoded together
    /// with it, in the `c=` attribute of the client-final-message.
    pub fn tls_server_end_point(signature: Vec<u8>) -> ChannelBinding {
        ChannelBinding(ChannelBindingInner::TlsServerEndPoint(signature))
    }

    /// The GS2 header that starts the client-first-message for this binding choice.
    fn gs2_header(&self) -> &'static str {
        match &self.0 {
            ChannelBindingInner::TlsServerEndPoint(_) => "p=tls-server-end-point,,",
            ChannelBindingInner::Unsupported => "n,,",
            ChannelBindingInner::Unrequested => "y,,",
        }
    }

    /// The binding data appended after the GS2 header; empty unless a TLS binding was supplied.
    fn cbind_data(&self) -> &[u8] {
        if let ChannelBindingInner::TlsServerEndPoint(signature) = &self.0 {
            signature.as_slice()
        } else {
            &[]
        }
    }
}
101
102enum State {
103    Update {
104        nonce: String,
105        password: Vec<u8>,
106        channel_binding: ChannelBinding,
107    },
108    Finish {
109        salted_password: [u8; 32],
110        auth_message: String,
111    },
112    Done,
113}
114
115/// A type which handles the client side of the SCRAM-SHA-256/SCRAM-SHA-256-PLUS authentication
116/// process.
117///
118/// During the authentication process, if the backend sends an `AuthenticationSASL` message which
119/// includes `SCRAM-SHA-256` as an authentication mechanism, this type can be used.
120///
121/// After a `ScramSha256` is constructed, the buffer returned by the `message()` method should be
122/// sent to the backend in a `SASLInitialResponse` message along with the mechanism name.
123///
124/// The server will reply with an `AuthenticationSASLContinue` message. Its contents should be
125/// passed to the `update()` method, after which the buffer returned by the `message()` method
126/// should be sent to the backend in a `SASLResponse` message.
127///
128/// The server will reply with an `AuthenticationSASLFinal` message. Its contents should be passed
129/// to the `finish()` method, after which the authentication process is complete.
130pub struct ScramSha256 {
131    message: String,
132    state: State,
133}
134
135impl ScramSha256 {
136    /// Constructs a new instance which will use the provided password for authentication.
137    pub fn new(password: &[u8], channel_binding: ChannelBinding) -> ScramSha256 {
138        // rand 0.5's ThreadRng is cryptographically secure
139        let mut rng = rand::thread_rng();
140        let nonce = (0..NONCE_LENGTH)
141            .map(|_| {
142                let mut v = rng.gen_range(0x21u8..0x7e);
143                if v == 0x2c {
144                    v = 0x7e
145                }
146                v as char
147            })
148            .collect::<String>();
149
150        ScramSha256::new_inner(password, channel_binding, nonce)
151    }
152
153    fn new_inner(password: &[u8], channel_binding: ChannelBinding, nonce: String) -> ScramSha256 {
154        ScramSha256 {
155            message: format!("{}n=,r={}", channel_binding.gs2_header(), nonce),
156            state: State::Update {
157                nonce,
158                password: normalize(password),
159                channel_binding,
160            },
161        }
162    }
163
164    /// Returns the message which should be sent to the backend in an `SASLResponse` message.
165    pub fn message(&self) -> &[u8] {
166        if let State::Done = self.state {
167            panic!("invalid SCRAM state");
168        }
169        self.message.as_bytes()
170    }
171
172    /// Updates the state machine with the response from the backend.
173    ///
174    /// This should be called when an `AuthenticationSASLContinue` message is received.
175    pub fn update(&mut self, message: &[u8]) -> io::Result<()> {
176        let (client_nonce, password, channel_binding) =
177            match mem::replace(&mut self.state, State::Done) {
178                State::Update {
179                    nonce,
180                    password,
181                    channel_binding,
182                } => (nonce, password, channel_binding),
183                _ => return Err(io::Error::new(io::ErrorKind::Other, "invalid SCRAM state")),
184            };
185
186        let message =
187            str::from_utf8(message).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
188
189        let parsed = Parser::new(message).server_first_message()?;
190
191        if !parsed.nonce.starts_with(&client_nonce) {
192            return Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid nonce"));
193        }
194
195        let salt = match STANDARD.decode(parsed.salt) {
196            Ok(salt) => salt,
197            Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
198        };
199
200        let salted_password = hi(&password, &salt, parsed.iteration_count);
201
202        let mut hmac = Hmac::<Sha256>::new_from_slice(&salted_password)
203            .expect("HMAC is able to accept all key sizes");
204        hmac.update(b"Client Key");
205        let client_key = hmac.finalize().into_bytes();
206
207        let mut hash = Sha256::default();
208        hash.update(client_key.as_slice());
209        let stored_key = hash.finalize_fixed();
210
211        let mut cbind_input = vec![];
212        cbind_input.extend(channel_binding.gs2_header().as_bytes());
213        cbind_input.extend(channel_binding.cbind_data());
214        let cbind_input = STANDARD.encode(&cbind_input);
215
216        self.message.clear();
217        write!(&mut self.message, "c={},r={}", cbind_input, parsed.nonce).unwrap();
218
219        let auth_message = format!("n=,r={},{},{}", client_nonce, message, self.message);
220
221        let mut hmac = Hmac::<Sha256>::new_from_slice(&stored_key)
222            .expect("HMAC is able to accept all key sizes");
223        hmac.update(auth_message.as_bytes());
224        let client_signature = hmac.finalize().into_bytes();
225
226        let mut client_proof = client_key;
227        for (proof, signature) in client_proof.iter_mut().zip(client_signature) {
228            *proof ^= signature;
229        }
230
231        write!(
232            &mut self.message,
233            ",p={}",
234            Base64Display::new(&client_proof, &STANDARD)
235        )
236        .unwrap();
237
238        self.state = State::Finish {
239            salted_password,
240            auth_message,
241        };
242        Ok(())
243    }
244
245    /// Finalizes the authentication process.
246    ///
247    /// This should be called when the backend sends an `AuthenticationSASLFinal` message.
248    /// Authentication has only succeeded if this method returns `Ok(())`.
249    pub fn finish(&mut self, message: &[u8]) -> io::Result<()> {
250        let (salted_password, auth_message) = match mem::replace(&mut self.state, State::Done) {
251            State::Finish {
252                salted_password,
253                auth_message,
254            } => (salted_password, auth_message),
255            _ => return Err(io::Error::new(io::ErrorKind::Other, "invalid SCRAM state")),
256        };
257
258        let message =
259            str::from_utf8(message).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
260
261        let parsed = Parser::new(message).server_final_message()?;
262
263        let verifier = match parsed {
264            ServerFinalMessage::Error(e) => {
265                return Err(io::Error::new(
266                    io::ErrorKind::Other,
267                    format!("SCRAM error: {}", e),
268                ));
269            }
270            ServerFinalMessage::Verifier(verifier) => verifier,
271        };
272
273        let verifier = match STANDARD.decode(verifier) {
274            Ok(verifier) => verifier,
275            Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
276        };
277
278        let mut hmac = Hmac::<Sha256>::new_from_slice(&salted_password)
279            .expect("HMAC is able to accept all key sizes");
280        hmac.update(b"Server Key");
281        let server_key = hmac.finalize().into_bytes();
282
283        let mut hmac = Hmac::<Sha256>::new_from_slice(&server_key)
284            .expect("HMAC is able to accept all key sizes");
285        hmac.update(auth_message.as_bytes());
286        hmac.verify_slice(&verifier)
287            .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "SCRAM verification error"))
288    }
289}
290
/// A cursor over a SCRAM message that tracks the byte offset of every character.
struct Parser<'a> {
    /// The complete input; matched runs are returned as slices of it.
    s: &'a str,
    /// The remaining characters, each paired with its byte offset into `s`.
    it: iter::Peekable<str::CharIndices<'a>>,
}
295
296impl<'a> Parser<'a> {
297    fn new(s: &'a str) -> Parser<'a> {
298        Parser {
299            s,
300            it: s.char_indices().peekable(),
301        }
302    }
303
304    fn eat(&mut self, target: char) -> io::Result<()> {
305        match self.it.next() {
306            Some((_, c)) if c == target => Ok(()),
307            Some((i, c)) => {
308                let m = format!(
309                    "unexpected character at byte {}: expected `{}` but got `{}",
310                    i, target, c
311                );
312                Err(io::Error::new(io::ErrorKind::InvalidInput, m))
313            }
314            None => Err(io::Error::new(
315                io::ErrorKind::UnexpectedEof,
316                "unexpected EOF",
317            )),
318        }
319    }
320
321    fn take_while<F>(&mut self, f: F) -> io::Result<&'a str>
322    where
323        F: Fn(char) -> bool,
324    {
325        let start = match self.it.peek() {
326            Some(&(i, _)) => i,
327            None => return Ok(""),
328        };
329
330        loop {
331            match self.it.peek() {
332                Some(&(_, c)) if f(c) => {
333                    self.it.next();
334                }
335                Some(&(i, _)) => return Ok(&self.s[start..i]),
336                None => return Ok(&self.s[start..]),
337            }
338        }
339    }
340
341    fn printable(&mut self) -> io::Result<&'a str> {
342        self.take_while(|c| matches!(c, '\x21'..='\x2b' | '\x2d'..='\x7e'))
343    }
344
345    fn nonce(&mut self) -> io::Result<&'a str> {
346        self.eat('r')?;
347        self.eat('=')?;
348        self.printable()
349    }
350
351    fn base64(&mut self) -> io::Result<&'a str> {
352        self.take_while(|c| matches!(c, 'a'..='z' | 'A'..='Z' | '0'..='9' | '/' | '+' | '='))
353    }
354
355    fn salt(&mut self) -> io::Result<&'a str> {
356        self.eat('s')?;
357        self.eat('=')?;
358        self.base64()
359    }
360
361    fn posit_number(&mut self) -> io::Result<u32> {
362        let n = self.take_while(|c| c.is_ascii_digit())?;
363        n.parse()
364            .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))
365    }
366
367    fn iteration_count(&mut self) -> io::Result<u32> {
368        self.eat('i')?;
369        self.eat('=')?;
370        self.posit_number()
371    }
372
373    fn eof(&mut self) -> io::Result<()> {
374        match self.it.peek() {
375            Some(&(i, _)) => Err(io::Error::new(
376                io::ErrorKind::InvalidInput,
377                format!("unexpected trailing data at byte {}", i),
378            )),
379            None => Ok(()),
380        }
381    }
382
383    fn server_first_message(&mut self) -> io::Result<ServerFirstMessage<'a>> {
384        let nonce = self.nonce()?;
385        self.eat(',')?;
386        let salt = self.salt()?;
387        self.eat(',')?;
388        let iteration_count = self.iteration_count()?;
389        self.eof()?;
390
391        Ok(ServerFirstMessage {
392            nonce,
393            salt,
394            iteration_count,
395        })
396    }
397
398    fn value(&mut self) -> io::Result<&'a str> {
399        self.take_while(|c| matches!(c, '\0' | '=' | ','))
400    }
401
402    fn server_error(&mut self) -> io::Result<Option<&'a str>> {
403        match self.it.peek() {
404            Some(&(_, 'e')) => {}
405            _ => return Ok(None),
406        }
407
408        self.eat('e')?;
409        self.eat('=')?;
410        self.value().map(Some)
411    }
412
413    fn verifier(&mut self) -> io::Result<&'a str> {
414        self.eat('v')?;
415        self.eat('=')?;
416        self.base64()
417    }
418
419    fn server_final_message(&mut self) -> io::Result<ServerFinalMessage<'a>> {
420        let message = match self.server_error()? {
421            Some(error) => ServerFinalMessage::Error(error),
422            None => ServerFinalMessage::Verifier(self.verifier()?),
423        };
424        self.eof()?;
425        Ok(message)
426    }
427}
428
/// Fields of a parsed server-first-message, borrowed from the raw message text.
struct ServerFirstMessage<'a> {
    /// The `r=` nonce; for a valid exchange it begins with the client's own nonce.
    nonce: &'a str,
    /// The `s=` salt, still base64-encoded.
    salt: &'a str,
    /// The `i=` iteration count, passed to `hi`.
    iteration_count: u32,
}
434
/// A parsed server-final-message: either an error report or the server's signature.
enum ServerFinalMessage<'a> {
    /// `e=`: the server rejected the exchange with this error value.
    Error(&'a str),
    /// `v=`: the server signature, still base64-encoded.
    Verifier(&'a str),
}
439
#[cfg(test)]
mod test {
    use super::*;

    /// Views an outgoing SCRAM message as text so it can be compared with a literal.
    fn text(bytes: &[u8]) -> &str {
        str::from_utf8(bytes).unwrap()
    }

    #[test]
    fn parse_server_first_message() {
        let raw = "r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,s=QSXCR+Q6sek8bf92,i=4096";
        let parsed = Parser::new(raw).server_first_message().unwrap();
        assert_eq!(parsed.nonce, "fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j");
        assert_eq!(parsed.salt, "QSXCR+Q6sek8bf92");
        assert_eq!(parsed.iteration_count, 4096);
    }

    // Replays an authentication exchange recorded from psql.
    #[test]
    fn exchange() {
        const PASSWORD: &str = "foobar";
        const NONCE: &str = "9IZ2O01zb9IgiIZ1WJ/zgpJB";

        const CLIENT_FIRST: &str = "n,,n=,r=9IZ2O01zb9IgiIZ1WJ/zgpJB";
        const SERVER_FIRST: &str =
            "r=9IZ2O01zb9IgiIZ1WJ/zgpJBjx/oIRLs02gGSHcw1KEty3eY,s=fs3IXBy7U7+IvVjZ,i\
             =4096";
        const CLIENT_FINAL: &str =
            "c=biws,r=9IZ2O01zb9IgiIZ1WJ/zgpJBjx/oIRLs02gGSHcw1KEty3eY,p=AmNKosjJzS3\
             1NTlQYNs5BTeQjdHdk7lOflDo5re2an8=";
        const SERVER_FINAL: &str = "v=U+ppxD5XUKtradnv8e2MkeupiA8FU87Sg8CXzXHDAzw=";

        let mut scram = ScramSha256::new_inner(
            PASSWORD.as_bytes(),
            ChannelBinding::unsupported(),
            NONCE.to_owned(),
        );
        assert_eq!(text(scram.message()), CLIENT_FIRST);

        scram.update(SERVER_FIRST.as_bytes()).unwrap();
        assert_eq!(text(scram.message()), CLIENT_FINAL);

        scram.finish(SERVER_FINAL.as_bytes()).unwrap();
    }
}