Skip to main content

mz_pgwire/
codec.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Encoding/decoding of messages in pgwire. See "[Frontend/Backend Protocol:
11//! Message Formats][1]" in the PostgreSQL reference for the specification.
12//!
13//! See the [crate docs](crate) for higher level concerns.
14//!
15//! [1]: https://www.postgresql.org/docs/11/protocol-message-formats.html
16
17use std::net::IpAddr;
18
19use async_trait::async_trait;
20use bytes::{Buf, BufMut, BytesMut};
21use bytesize::ByteSize;
22use futures::{SinkExt, TryStreamExt, sink};
23use itertools::Itertools;
24use mz_adapter_types::connection::ConnectionId;
25use mz_ore::cast::CastFrom;
26use mz_ore::future::OreSinkExt;
27use mz_ore::netio::AsyncReady;
28use mz_pgwire_common::{
29    ChannelBinding, Conn, Cursor, DecodeState, ErrorResponse, FrontendMessage, GS2Header,
30    MAX_REQUEST_SIZE, Pgbuf, SASLClientFinalResponse, SASLInitialResponse, input_err,
31    parse_frame_len,
32};
33use tokio::io::{self, AsyncRead, AsyncWrite, Interest, Ready};
34use tokio::time::{self, Duration};
35use tokio_util::codec::{Decoder, Encoder, Framed};
36use tracing::trace;
37
38use crate::message::{BackendMessage, BackendMessageKind, SASLServerFinalMessageKinds};
39
40/// A connection that manages the encoding and decoding of pgwire frames.
41pub struct FramedConn<A> {
42    conn_id: ConnectionId,
43    peer_addr: Option<IpAddr>,
44    inner: sink::Buffer<Framed<Conn<A>, Codec>, BackendMessage>,
45}
46
47impl<A> FramedConn<A>
48where
49    A: AsyncRead + AsyncWrite + Unpin,
50{
51    /// Constructs a new framed connection.
52    ///
53    /// The underlying connection, `inner`, is expected to be something like a
54    /// TCP stream. Anything that implements [`AsyncRead`] and [`AsyncWrite`]
55    /// will do.
56    ///
57    /// The supplied `conn_id` is used to identify the connection in logging
58    /// messages.
59    pub fn new(conn_id: ConnectionId, peer_addr: Option<IpAddr>, inner: Conn<A>) -> FramedConn<A> {
60        FramedConn {
61            conn_id,
62            peer_addr,
63            inner: Framed::new(inner, Codec::new()).buffer(32),
64        }
65    }
66
67    /// Reads and decodes one frontend message from the client.
68    ///
69    /// Blocks until the client sends a complete message. If the client
70    /// terminates the stream, returns `None`. Returns an error if the client
71    /// sends a malformed message or if the connection underlying is broken.
72    ///
73    /// # Cancel safety
74    ///
75    /// This method is cancel safe. The returned future only holds onto a
76    /// reference to thea underlying stream, so dropping it will never lose a
77    /// value.
78    ///
79    /// <https://docs.rs/tokio-stream/latest/tokio_stream/trait.StreamExt.html#cancel-safety-1>
80    pub async fn recv(&mut self) -> Result<Option<FrontendMessage>, io::Error> {
81        let message = self.inner.try_next().await?;
82        match &message {
83            Some(message) => trace!("cid={} recv_name={}", self.conn_id, message.name()),
84            None => trace!("cid={} recv=<eof>", self.conn_id),
85        }
86        Ok(message)
87    }
88
89    /// Encodes and sends one backend message to the client.
90    ///
91    /// Note that the connection is not flushed after calling this method. You
92    /// must call [`FramedConn::flush`] explicitly. Returns an error if the
93    /// underlying connection is broken.
94    ///
95    /// Please use `StateMachine::send` instead if calling from `StateMachine`,
96    /// as it applies session-based filters before calling this method.
97    pub async fn send<M>(&mut self, message: M) -> Result<(), io::Error>
98    where
99        M: Into<BackendMessage>,
100    {
101        let message = message.into();
102        trace!(
103            "cid={} send={:?}",
104            self.conn_id,
105            BackendMessageKind::from(&message)
106        );
107        self.inner.enqueue(message).await
108    }
109
110    /// Encodes and sends the backend messages in the `messages` iterator to the
111    /// client.
112    ///
113    /// As with [`FramedConn::send`], the connection is not flushed after
114    /// calling this method. You must call [`FramedConn::flush`] explicitly.
115    /// Returns an error if the underlying connection is broken.
116    pub async fn send_all(
117        &mut self,
118        messages: impl IntoIterator<Item = BackendMessage>,
119    ) -> Result<(), io::Error> {
120        // N.B. we intentionally don't use `self.conn.send_all` here to avoid
121        // flushing the sink unnecessarily.
122        for m in messages {
123            self.send(m).await?;
124        }
125        Ok(())
126    }
127
128    /// Flushes all outstanding messages.
129    pub async fn flush(&mut self) -> Result<(), io::Error> {
130        self.inner.flush().await
131    }
132
133    /// Injects state that affects how certain backend messages are encoded.
134    ///
135    /// Specifically, the encoding of `BackendMessage::DataRow` depends upon the
136    /// types of the datums in the row. To avoid including the same type
137    /// information in each message, we use this side channel to install the
138    /// type information in the codec before sending any data row messages. This
139    /// violates the abstraction boundary a bit but results in much better
140    /// performance.
141    pub fn set_encode_state(
142        &mut self,
143        encode_state: Vec<(mz_pgrepr::Type, mz_pgwire_common::Format)>,
144    ) {
145        self.inner.get_mut().codec_mut().encode_state = encode_state;
146    }
147
148    /// Enables or disables copy mode on the codec.
149    ///
150    /// When copy mode is enabled, the aggregate buffer size check in the
151    /// decoder is skipped. This is needed during COPY FROM STDIN because
152    /// many small CopyData frames can accumulate in the TCP read buffer.
153    pub fn set_copy_mode(&mut self, enabled: bool) {
154        self.inner.get_mut().codec_mut().in_copy_mode = enabled;
155    }
156
157    /// Waits for the connection to be closed.
158    ///
159    /// Returns a "connection closed" error when the connection is closed. If
160    /// another error occurs before the connection is closed, that error is
161    /// returned instead.
162    ///
163    /// Use this method when you have an unbounded stream of data to forward to
164    /// the connection and the protocol does not require the client to
165    /// periodically acknowledge receipt. If you don't call this method to
166    /// periodically check if the connection has closed, you may not notice that
167    /// the client has gone away for an unboundedly long amount of time; usually
168    /// not until the stream of data produces its next message and you attempt
169    /// to write the data to the connection.
170    pub async fn wait_closed(&self) -> io::Error
171    where
172        A: AsyncReady + Send + Sync,
173    {
174        loop {
175            time::sleep(Duration::from_secs(1)).await;
176
177            match self.ready(Interest::READABLE | Interest::WRITABLE).await {
178                Ok(ready) if ready.is_read_closed() || ready.is_write_closed() => {
179                    return io::Error::new(io::ErrorKind::Other, "connection closed");
180                }
181                Ok(_) => (),
182                Err(err) => return err,
183            }
184        }
185    }
186
187    /// Returns the ID associated with this connection.
188    pub fn conn_id(&self) -> &ConnectionId {
189        &self.conn_id
190    }
191
192    /// Returns the peer address of the connection.
193    pub fn peer_addr(&self) -> &Option<IpAddr> {
194        &self.peer_addr
195    }
196}
197
198impl<A> FramedConn<A>
199where
200    A: AsyncRead + AsyncWrite + Unpin,
201{
202    pub fn inner(&self) -> &Conn<A> {
203        self.inner.get_ref().get_ref()
204    }
205}
206
207#[async_trait]
208impl<A> AsyncReady for FramedConn<A>
209where
210    A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
211{
212    async fn ready(&self, interest: Interest) -> io::Result<Ready> {
213        self.inner.get_ref().get_ref().ready(interest).await
214    }
215}
216
217struct Codec {
218    decode_state: DecodeState,
219    encode_state: Vec<(mz_pgrepr::Type, mz_pgwire_common::Format)>,
220    /// When true, skip the aggregate buffer size check in `decode()`.
221    /// During COPY FROM STDIN, many small CopyData frames accumulate in the
222    /// TCP read buffer and can exceed MAX_REQUEST_SIZE even though individual
223    /// frames are small. Individual frame lengths are still validated by
224    /// `parse_frame_len()`.
225    in_copy_mode: bool,
226}
227
228impl Codec {
229    /// Creates a new `Codec`.
230    pub fn new() -> Codec {
231        Codec {
232            decode_state: DecodeState::Head,
233            encode_state: vec![],
234            in_copy_mode: false,
235        }
236    }
237}
238
239impl Default for Codec {
240    fn default() -> Codec {
241        Codec::new()
242    }
243}
244
245impl Encoder<BackendMessage> for Codec {
246    type Error = io::Error;
247
248    /// Encode a backend message into `dst`.
249    /// If this function returns an error result, `dst` is left unmodified.
250    fn encode(&mut self, msg: BackendMessage, dst: &mut BytesMut) -> Result<(), io::Error> {
251        // Record the starting position so we can truncate on error.
252        // This prevents partial messages from being left in the buffer,
253        // which could be sent to the client and cause "lost synchronization" errors.
254        let start = dst.len();
255        match self.encode_inner(msg, dst) {
256            Ok(()) => Ok(()),
257            Err(e) => {
258                dst.truncate(start);
259                Err(e)
260            }
261        }
262    }
263}
264
265impl Codec {
266    /// This is the meat of the encoding logic. It's a separate function so that errors returned by
267    /// `?` can be handled in the outer `encode` function.
268    fn encode_inner(&self, msg: BackendMessage, dst: &mut BytesMut) -> Result<(), io::Error> {
269        // Write type byte.
270        let byte = match &msg {
271            BackendMessage::AuthenticationOk => b'R',
272            BackendMessage::AuthenticationCleartextPassword
273            | BackendMessage::AuthenticationSASL
274            | BackendMessage::AuthenticationSASLContinue(_)
275            | BackendMessage::AuthenticationSASLFinal(_) => b'R',
276            BackendMessage::RowDescription(_) => b'T',
277            BackendMessage::DataRow(_) => b'D',
278            BackendMessage::CommandComplete { .. } => b'C',
279            BackendMessage::EmptyQueryResponse => b'I',
280            BackendMessage::ReadyForQuery(_) => b'Z',
281            BackendMessage::NoData => b'n',
282            BackendMessage::ParameterStatus(_, _) => b'S',
283            BackendMessage::PortalSuspended => b's',
284            BackendMessage::BackendKeyData { .. } => b'K',
285            BackendMessage::ParameterDescription(_) => b't',
286            BackendMessage::ParseComplete => b'1',
287            BackendMessage::BindComplete => b'2',
288            BackendMessage::CloseComplete => b'3',
289            BackendMessage::ErrorResponse(r) => {
290                if r.severity.is_error() {
291                    b'E'
292                } else {
293                    b'N'
294                }
295            }
296            BackendMessage::CopyInResponse { .. } => b'G',
297            BackendMessage::CopyOutResponse { .. } => b'H',
298            BackendMessage::CopyData(_) => b'd',
299            BackendMessage::CopyDone => b'c',
300        };
301        dst.put_u8(byte);
302
303        // Write message length placeholder. The true length is filled in later.
304        let base = dst.len();
305        dst.put_u32(0);
306
307        // Write message contents.
308        match msg {
309            BackendMessage::CopyInResponse {
310                overall_format,
311                column_formats,
312            }
313            | BackendMessage::CopyOutResponse {
314                overall_format,
315                column_formats,
316            } => {
317                dst.put_format_i8(overall_format);
318                if column_formats.len() > usize::try_from(i16::MAX).expect("i16::MAX is positive") {
319                    return Err(io::Error::new(
320                        io::ErrorKind::InvalidData,
321                        format!(
322                            "{} columns in COPY response, which exceeds {}",
323                            column_formats.len(),
324                            i16::MAX
325                        ),
326                    ));
327                }
328                dst.put_length_i16(column_formats.len())?;
329                for format in column_formats {
330                    dst.put_format_i16(format);
331                }
332            }
333            BackendMessage::CopyData(data) => {
334                dst.put_slice(&data);
335            }
336            BackendMessage::CopyDone => (),
337            BackendMessage::AuthenticationOk => {
338                dst.put_u32(0);
339            }
340            BackendMessage::AuthenticationCleartextPassword => {
341                dst.put_u32(3);
342            }
343            BackendMessage::AuthenticationSASL => {
344                dst.put_u32(10);
345                dst.put_string("SCRAM-SHA-256");
346                dst.put_u8(b'\0');
347            }
348            BackendMessage::AuthenticationSASLContinue(data) => {
349                dst.put_u32(11);
350                let data = format!(
351                    "r={},s={},i={}",
352                    data.nonce, data.salt, data.iteration_count
353                );
354                dst.put_slice(data.as_bytes());
355            }
356            BackendMessage::AuthenticationSASLFinal(data) => {
357                dst.put_u32(12);
358                let res = match data.kind {
359                    SASLServerFinalMessageKinds::Verifier(verifier) => {
360                        format!("v={}", verifier)
361                    }
362                };
363                dst.put_slice(res.as_bytes());
364                if !data.extensions.is_empty() {
365                    dst.put_slice(b",");
366                    dst.put_slice(data.extensions.join(",").as_bytes());
367                }
368            }
369            BackendMessage::RowDescription(fields) => {
370                if fields.len() > usize::try_from(i16::MAX).expect("i16::MAX is positive") {
371                    return Err(io::Error::new(
372                        io::ErrorKind::InvalidData,
373                        format!(
374                            "{} fields in row description, which exceeds {}",
375                            fields.len(),
376                            i16::MAX
377                        ),
378                    ));
379                }
380                dst.put_length_i16(fields.len())?;
381                for f in &fields {
382                    dst.put_string(&f.name.to_string());
383                    dst.put_u32(f.table_id);
384                    dst.put_u16(f.column_id);
385                    dst.put_u32(f.type_oid);
386                    dst.put_i16(f.type_len);
387                    dst.put_i32(f.type_mod);
388                    // TODO: make the format correct
389                    dst.put_format_i16(f.format);
390                }
391            }
392            BackendMessage::DataRow(fields) => {
393                if fields.len() > usize::try_from(i16::MAX).expect("i16::MAX is positive") {
394                    return Err(io::Error::new(
395                        io::ErrorKind::InvalidData,
396                        format!(
397                            "{} fields in data row, which exceeds {}",
398                            fields.len(),
399                            i16::MAX
400                        ),
401                    ));
402                }
403                dst.put_length_i16(fields.len())?;
404                for (f, (ty, format)) in fields.iter().zip_eq(&self.encode_state) {
405                    if let Some(f) = f {
406                        let base = dst.len();
407                        dst.put_u32(0);
408                        f.encode(ty, *format, dst)?;
409                        let len = dst.len() - base - 4;
410                        let len = i32::try_from(len).map_err(|_| {
411                            io::Error::new(
412                                io::ErrorKind::InvalidData,
413                                "length of encoded data row field does not fit into an i32",
414                            )
415                        })?;
416                        dst[base..base + 4].copy_from_slice(&len.to_be_bytes());
417                    } else {
418                        dst.put_i32(-1);
419                    }
420                }
421            }
422            BackendMessage::CommandComplete { tag } => {
423                dst.put_string(&tag);
424            }
425            BackendMessage::ParseComplete => (),
426            BackendMessage::BindComplete => (),
427            BackendMessage::CloseComplete => (),
428            BackendMessage::EmptyQueryResponse => (),
429            BackendMessage::ReadyForQuery(status) => {
430                dst.put_u8(status.into());
431            }
432            BackendMessage::ParameterStatus(name, value) => {
433                dst.put_string(name);
434                dst.put_string(&value);
435            }
436            BackendMessage::PortalSuspended => (),
437            BackendMessage::NoData => (),
438            BackendMessage::BackendKeyData {
439                conn_id,
440                secret_key,
441            } => {
442                dst.put_u32(conn_id);
443                dst.put_u32(secret_key);
444            }
445            BackendMessage::ParameterDescription(params) => {
446                if params.len() > usize::try_from(i16::MAX).expect("i16::MAX is positive") {
447                    return Err(io::Error::new(
448                        io::ErrorKind::InvalidData,
449                        format!(
450                            "{} params in parameter description, which exceeds {}",
451                            params.len(),
452                            i16::MAX
453                        ),
454                    ));
455                }
456                dst.put_length_i16(params.len())?;
457                for param in params {
458                    dst.put_u32(param.oid());
459                }
460            }
461            BackendMessage::ErrorResponse(ErrorResponse {
462                severity,
463                code,
464                message,
465                detail,
466                hint,
467                position,
468            }) => {
469                dst.put_u8(b'S');
470                dst.put_string(severity.as_str());
471                dst.put_u8(b'C');
472                dst.put_string(code.code());
473                dst.put_u8(b'M');
474                dst.put_string(&message);
475                if let Some(detail) = &detail {
476                    dst.put_u8(b'D');
477                    dst.put_string(detail);
478                }
479                if let Some(hint) = &hint {
480                    dst.put_u8(b'H');
481                    dst.put_string(hint);
482                }
483                if let Some(position) = &position {
484                    dst.put_u8(b'P');
485                    dst.put_string(&position.to_string());
486                }
487                dst.put_u8(b'\0');
488            }
489        }
490
491        let len = dst.len() - base;
492
493        // Overwrite length placeholder with true length.
494        let len = i32::try_from(len).map_err(|_| {
495            io::Error::new(
496                io::ErrorKind::InvalidData,
497                "length of encoded message does not fit into an i32",
498            )
499        })?;
500        dst[base..base + 4].copy_from_slice(&len.to_be_bytes());
501
502        Ok(())
503    }
504}
505
506impl Decoder for Codec {
507    type Item = FrontendMessage;
508    type Error = io::Error;
509
510    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<FrontendMessage>, io::Error> {
511        if !self.in_copy_mode && src.len() > MAX_REQUEST_SIZE {
512            return Err(io::Error::new(
513                io::ErrorKind::InvalidData,
514                format!(
515                    "request larger than {}",
516                    ByteSize::b(u64::cast_from(MAX_REQUEST_SIZE))
517                ),
518            ));
519        }
520        loop {
521            match self.decode_state {
522                DecodeState::Head => {
523                    if src.len() < 5 {
524                        return Ok(None);
525                    }
526                    let msg_type = src[0];
527                    let frame_len = parse_frame_len(&src[1..])?;
528                    src.advance(5);
529                    src.reserve(frame_len);
530                    self.decode_state = DecodeState::Data(msg_type, frame_len);
531                }
532
533                DecodeState::Data(msg_type, frame_len) => {
534                    if src.len() < frame_len {
535                        return Ok(None);
536                    }
537                    let buf = src.split_to(frame_len).freeze();
538                    let buf = Cursor::new(&buf);
539                    let msg = match msg_type {
540                        // Simple query flow.
541                        b'Q' => decode_query(buf)?,
542
543                        // Extended query flow.
544                        b'P' => decode_parse(buf)?,
545                        b'D' => decode_describe(buf)?,
546                        b'B' => decode_bind(buf)?,
547                        b'E' => decode_execute(buf)?,
548                        b'H' => decode_flush(buf)?,
549                        b'S' => decode_sync(buf)?,
550                        b'C' => decode_close(buf)?,
551
552                        // Termination.
553                        b'X' => decode_terminate(buf)?,
554
555                        // Authentication.
556                        b'p' => decode_auth(buf)?,
557
558                        // Copy from flow.
559                        b'f' => decode_copy_fail(buf)?,
560                        b'd' => decode_copy_data(buf, frame_len)?,
561                        b'c' => decode_copy_done(buf)?,
562
563                        // Invalid.
564                        _ => {
565                            return Err(io::Error::new(
566                                io::ErrorKind::InvalidData,
567                                format!("unknown message type {}", msg_type),
568                            ));
569                        }
570                    };
571                    src.reserve(5);
572                    self.decode_state = DecodeState::Head;
573                    return Ok(Some(msg));
574                }
575            }
576        }
577    }
578}
579
580fn decode_terminate(mut _buf: Cursor) -> Result<FrontendMessage, io::Error> {
581    // Nothing more to decode.
582    Ok(FrontendMessage::Terminate)
583}
584
585fn decode_auth(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
586    let mut value = Vec::new();
587    while let Ok(b) = buf.read_byte() {
588        value.push(b);
589    }
590    Ok(FrontendMessage::RawAuthentication(value))
591}
592
593fn expect(buf: &mut Cursor, expected: &[u8]) -> Result<(), io::Error> {
594    for i in 0..expected.len() {
595        if buf.read_byte()? != expected[i] {
596            return Err(input_err(format!(
597                "Invalid SASL initial response: expected '{}'",
598                std::str::from_utf8(expected).unwrap_or("invalid UTF-8")
599            )));
600        }
601    }
602    Ok(())
603}
604
605fn read_until_comma(buf: &mut Cursor) -> Result<Vec<u8>, io::Error> {
606    let mut v = Vec::new();
607    while let Ok(b) = buf.peek_byte() {
608        if b == b',' {
609            break;
610        }
611        v.push(buf.read_byte()?);
612    }
613    Ok(v)
614}
615
616// All SASL parsing is based on RFC 5802, [section 7](https://datatracker.ietf.org/doc/html/rfc5802#section-7)
617
618//   extensions = attr-val *("," attr-val)
619//                     ;; All extensions are optional,
620//                     ;; i.e., unrecognized attributes
621//                     ;; not defined in this document
622//                     ;; MUST be ignored.
623//   reserved-mext  = "m=" 1*(value-char)
624//                     ;; Reserved for signaling mandatory extensions.
625//                     ;; The exact syntax will be defined in
626//                     ;; the future.
627//   gs2-cbind-flag  = ("p=" cb-name) / "n" / "y"
628//                     ;; "n" -> client doesn't support channel binding.
629//                     ;; "y" -> client does support channel binding
630//                     ;;        but thinks the server does not.
631//                     ;; "p" -> client requires channel binding.
632//                     ;; The selected channel binding follows "p=".
633//
634//   gs2-header      = gs2-cbind-flag "," [ authzid ] ","
635//                     ;; GS2 header for SCRAM
636//                     ;; (the actual GS2 header includes an optional
637//                     ;; flag to indicate that the GSS mechanism is not
638//                     ;; "standard", but since SCRAM is "standard", we
639//                     ;; don't include that flag).
640//   client-first-message-bare =
641//                     [reserved-mext ","]
642//                     username "," nonce ["," extensions]
643//
644//   client-first-message =
645//                     gs2-header client-first-message-bare
646pub fn decode_sasl_client_first_message(mut buf: Cursor) -> Result<SASLInitialResponse, io::Error> {
647    // 1) GS2 cbind flag
648    let cbind_flag = match buf.read_byte()? {
649        b'n' => ChannelBinding::None,
650        b'y' => ChannelBinding::ClientSupported,
651        b'p' => {
652            // must be "p=" then cbname up to next comma
653            expect(&mut buf, b"=")?;
654            let cbname = String::from_utf8(read_until_comma(&mut buf)?)
655                .map_err(|_| input_err("invalid cbname utf8"))?;
656            ChannelBinding::Required(cbname)
657        }
658        other => {
659            return Err(input_err(format!(
660                "Invalid channel binding flag: {}",
661                other
662            )));
663        }
664    };
665    expect(&mut buf, b",")?;
666
667    // 2) Optional authzid: either empty, or "a=" up to next comma
668    let mut authzid = None;
669    if buf.peek_byte()? == b'a' {
670        expect(&mut buf, b"a=")?;
671        let a = String::from_utf8(read_until_comma(&mut buf)?)
672            .map_err(|_| input_err("invalid authzid utf8"))?;
673        authzid = Some(a);
674    }
675    expect(&mut buf, b",")?;
676
677    let mut client_first_message_bare_raw = String::new();
678
679    // 3) Optional reserved "m=" extension before n=
680    let mut reserved_mext = None;
681    if buf.peek_byte()? == b'm' {
682        expect(&mut buf, b"m=")?;
683        let mext_val = String::from_utf8(read_until_comma(&mut buf)?)
684            .map_err(|_| input_err("invalid m ext utf8"))?;
685        client_first_message_bare_raw.push_str(&format!("m={},", mext_val));
686        reserved_mext = Some(mext_val);
687        expect(&mut buf, b",")?;
688    }
689
690    // 4) Username: must be "n=" then saslname
691    expect(&mut buf, b"n=")?;
692    // Postgres doesn't use the username here, so we just consume
693    let username = String::from_utf8(read_until_comma(&mut buf)?)
694        .map_err(|_| input_err("invalid username utf8"))?;
695    expect(&mut buf, b",")?;
696    client_first_message_bare_raw.push_str(&format!("n={},", username));
697
698    // 5) Nonce: must be "r=" then value up to next comma or end
699    expect(&mut buf, b"r=")?;
700    let nonce = String::from_utf8(read_until_comma(&mut buf)?)
701        .map_err(|_| input_err("invalid nonce utf8"))?;
702    client_first_message_bare_raw.push_str(&format!("r={}", nonce));
703
704    // 6) Optional extensions: "," key=value chunks
705    let mut extensions = Vec::new();
706    while let Ok(b',') = buf.peek_byte().map(|b| b) {
707        expect(&mut buf, b",")?;
708        let ext = String::from_utf8(read_until_comma(&mut buf)?)
709            .map_err(|_| input_err("invalid ext utf8"))?;
710        if !ext.is_empty() {
711            client_first_message_bare_raw.push_str(&format!(",{}", ext));
712            extensions.push(ext);
713        }
714    }
715
716    Ok(SASLInitialResponse {
717        gs2_header: GS2Header {
718            cbind_flag,
719            authzid,
720        },
721        nonce,
722        extensions,
723        reserved_mext,
724        client_first_message_bare_raw,
725    })
726}
727
728pub fn decode_sasl_initial_response(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
729    let mechanism = buf.read_cstr()?;
730    let initial_resp_len = buf.read_i32()?;
731    if initial_resp_len < 0 {
732        // -1 means no response? We bail here
733        return Err(input_err("No initial response"));
734    }
735
736    let initial_response = decode_sasl_client_first_message(buf)?;
737    Ok(FrontendMessage::SASLInitialResponse {
738        gs2_header: initial_response.gs2_header.clone(),
739        mechanism: mechanism.to_owned(),
740        initial_response,
741    })
742}
743
744//   proof           = "p=" base64
745//
746//   channel-binding = "c=" base64
747//                     ;; base64 encoding of cbind-input.
748//   client-final-message-without-proof =
749//                     channel-binding "," nonce [","
750//                     extensions]
751//
752//   client-final-message =
753//                     client-final-message-without-proof "," proof
754pub fn decode_sasl_response(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
755    // --- client-final-message-without-proof ---
756    let mut client_final_message_bare_raw = String::new();
757    // channel-binding: "c=" <base64>, up to the next comma
758    expect(&mut buf, b"c=")?;
759    let channel_binding = String::from_utf8(read_until_comma(&mut buf)?)
760        .map_err(|_| input_err("invalid channel-binding utf8"))?;
761    expect(&mut buf, b",")?;
762    client_final_message_bare_raw.push_str(&format!("c={},", channel_binding));
763
764    // nonce: "r=" <printable>, up to the next comma
765    expect(&mut buf, b"r=")?;
766    let nonce = String::from_utf8(read_until_comma(&mut buf)?)
767        .map_err(|_| input_err("invalid nonce utf8"))?;
768    client_final_message_bare_raw.push_str(&format!("r={}", nonce));
769
770    // after reading channel-binding and nonce
771    let mut extensions = Vec::new();
772
773    // Keep reading ",<token>" until we see ",p="
774    while buf.peek_byte()? == b',' {
775        expect(&mut buf, b",")?;
776        if buf.peek_byte()? == b'p' {
777            break;
778        }
779        let ext = String::from_utf8(read_until_comma(&mut buf)?)
780            .map_err(|_| input_err("invalid extension utf8"))?;
781        if !ext.is_empty() {
782            client_final_message_bare_raw.push_str(&format!(",{}", ext));
783            extensions.push(ext);
784        }
785    }
786
787    // Proof is mandatory and last
788    expect(&mut buf, b"p=")?;
789    let proof = String::from_utf8(read_until_comma(&mut buf)?)
790        .map_err(|_| input_err("invalid proof utf8"))?;
791
792    Ok(FrontendMessage::SASLResponse(SASLClientFinalResponse {
793        channel_binding,
794        nonce,
795        extensions,
796        proof,
797        client_final_message_bare_raw,
798    }))
799}
800
801pub fn decode_password(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
802    Ok(FrontendMessage::Password {
803        password: buf.read_cstr()?.to_owned(),
804    })
805}
806
807fn decode_query(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
808    Ok(FrontendMessage::Query {
809        sql: buf.read_cstr()?.to_string(),
810    })
811}
812
813fn decode_parse(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
814    let name = buf.read_cstr()?;
815    let sql = buf.read_cstr()?;
816
817    let mut param_types = vec![];
818    for _ in 0..buf.read_i16()? {
819        param_types.push(buf.read_u32()?);
820    }
821
822    Ok(FrontendMessage::Parse {
823        name: name.into(),
824        sql: sql.into(),
825        param_types,
826    })
827}
828
829fn decode_close(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
830    match buf.read_byte()? {
831        b'S' => Ok(FrontendMessage::CloseStatement {
832            name: buf.read_cstr()?.to_owned(),
833        }),
834        b'P' => Ok(FrontendMessage::ClosePortal {
835            name: buf.read_cstr()?.to_owned(),
836        }),
837        b => Err(input_err(format!(
838            "invalid type byte in close message: {}",
839            b
840        ))),
841    }
842}
843
844fn decode_describe(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
845    let first_char = buf.read_byte()?;
846    let name = buf.read_cstr()?.to_string();
847    match first_char {
848        b'S' => Ok(FrontendMessage::DescribeStatement { name }),
849        b'P' => Ok(FrontendMessage::DescribePortal { name }),
850        other => Err(input_err(format!("Invalid describe type: {:#x?}", other))),
851    }
852}
853
854fn decode_bind(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
855    let portal_name = buf.read_cstr()?.to_string();
856    let statement_name = buf.read_cstr()?.to_string();
857
858    let mut param_formats = Vec::new();
859    for _ in 0..buf.read_i16()? {
860        param_formats.push(buf.read_format()?);
861    }
862
863    let mut raw_params = Vec::new();
864    for _ in 0..buf.read_i16()? {
865        let len = buf.read_i32()?;
866        if len == -1 {
867            raw_params.push(None); // NULL
868        } else {
869            // TODO(benesch): this should use bytes::Bytes to avoid the copy.
870            let mut value = Vec::new();
871            for _ in 0..len {
872                value.push(buf.read_byte()?);
873            }
874            raw_params.push(Some(value));
875        }
876    }
877
878    let mut result_formats = Vec::new();
879    for _ in 0..buf.read_i16()? {
880        result_formats.push(buf.read_format()?);
881    }
882
883    Ok(FrontendMessage::Bind {
884        portal_name,
885        statement_name,
886        param_formats,
887        raw_params,
888        result_formats,
889    })
890}
891
892fn decode_execute(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
893    let portal_name = buf.read_cstr()?.to_string();
894    let max_rows = buf.read_i32()?;
895    Ok(FrontendMessage::Execute {
896        portal_name,
897        max_rows,
898    })
899}
900
901fn decode_flush(mut _buf: Cursor) -> Result<FrontendMessage, io::Error> {
902    // Nothing more to decode.
903    Ok(FrontendMessage::Flush)
904}
905
906fn decode_sync(mut _buf: Cursor) -> Result<FrontendMessage, io::Error> {
907    // Nothing more to decode.
908    Ok(FrontendMessage::Sync)
909}
910
911fn decode_copy_data(mut buf: Cursor, frame_len: usize) -> Result<FrontendMessage, io::Error> {
912    let mut data = Vec::with_capacity(frame_len);
913    for _ in 0..frame_len {
914        data.push(buf.read_byte()?);
915    }
916    Ok(FrontendMessage::CopyData(data))
917}
918
919fn decode_copy_done(mut _buf: Cursor) -> Result<FrontendMessage, io::Error> {
920    // Nothing more to decode.
921    Ok(FrontendMessage::CopyDone)
922}
923
924fn decode_copy_fail(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
925    Ok(FrontendMessage::CopyFail(buf.read_cstr()?.to_string()))
926}