mz_pgwire/
codec.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Encoding/decoding of messages in pgwire. See "[Frontend/Backend Protocol:
11//! Message Formats][1]" in the PostgreSQL reference for the specification.
12//!
13//! See the [crate docs](crate) for higher level concerns.
14//!
15//! [1]: https://www.postgresql.org/docs/11/protocol-message-formats.html
16
17use std::net::IpAddr;
18
19use async_trait::async_trait;
20use bytes::{Buf, BufMut, BytesMut};
21use bytesize::ByteSize;
22use futures::{SinkExt, TryStreamExt, sink};
23use itertools::Itertools;
24use mz_adapter_types::connection::ConnectionId;
25use mz_ore::cast::CastFrom;
26use mz_ore::future::OreSinkExt;
27use mz_ore::netio::AsyncReady;
28use mz_pgwire_common::{
29    ChannelBinding, Conn, Cursor, DecodeState, ErrorResponse, FrontendMessage, GS2Header,
30    MAX_REQUEST_SIZE, Pgbuf, SASLClientFinalResponse, SASLInitialResponse, input_err,
31    parse_frame_len,
32};
33use tokio::io::{self, AsyncRead, AsyncWrite, Interest, Ready};
34use tokio::time::{self, Duration};
35use tokio_util::codec::{Decoder, Encoder, Framed};
36use tracing::trace;
37
38use crate::message::{BackendMessage, BackendMessageKind, SASLServerFinalMessageKinds};
39
40/// A connection that manages the encoding and decoding of pgwire frames.
41pub struct FramedConn<A> {
42    conn_id: ConnectionId,
43    peer_addr: Option<IpAddr>,
44    inner: sink::Buffer<Framed<Conn<A>, Codec>, BackendMessage>,
45}
46
47impl<A> FramedConn<A>
48where
49    A: AsyncRead + AsyncWrite + Unpin,
50{
51    /// Constructs a new framed connection.
52    ///
53    /// The underlying connection, `inner`, is expected to be something like a
54    /// TCP stream. Anything that implements [`AsyncRead`] and [`AsyncWrite`]
55    /// will do.
56    ///
57    /// The supplied `conn_id` is used to identify the connection in logging
58    /// messages.
59    pub fn new(conn_id: ConnectionId, peer_addr: Option<IpAddr>, inner: Conn<A>) -> FramedConn<A> {
60        FramedConn {
61            conn_id,
62            peer_addr,
63            inner: Framed::new(inner, Codec::new()).buffer(32),
64        }
65    }
66
67    /// Reads and decodes one frontend message from the client.
68    ///
69    /// Blocks until the client sends a complete message. If the client
70    /// terminates the stream, returns `None`. Returns an error if the client
71    /// sends a malformed message or if the connection underlying is broken.
72    ///
73    /// # Cancel safety
74    ///
75    /// This method is cancel safe. The returned future only holds onto a
76    /// reference to thea underlying stream, so dropping it will never lose a
77    /// value.
78    ///
79    /// <https://docs.rs/tokio-stream/latest/tokio_stream/trait.StreamExt.html#cancel-safety-1>
80    pub async fn recv(&mut self) -> Result<Option<FrontendMessage>, io::Error> {
81        let message = self.inner.try_next().await?;
82        match &message {
83            Some(message) => trace!("cid={} recv_name={}", self.conn_id, message.name()),
84            None => trace!("cid={} recv=<eof>", self.conn_id),
85        }
86        Ok(message)
87    }
88
89    /// Encodes and sends one backend message to the client.
90    ///
91    /// Note that the connection is not flushed after calling this method. You
92    /// must call [`FramedConn::flush`] explicitly. Returns an error if the
93    /// underlying connection is broken.
94    ///
95    /// Please use `StateMachine::send` instead if calling from `StateMachine`,
96    /// as it applies session-based filters before calling this method.
97    pub async fn send<M>(&mut self, message: M) -> Result<(), io::Error>
98    where
99        M: Into<BackendMessage>,
100    {
101        let message = message.into();
102        trace!(
103            "cid={} send={:?}",
104            self.conn_id,
105            BackendMessageKind::from(&message)
106        );
107        self.inner.enqueue(message).await
108    }
109
110    /// Encodes and sends the backend messages in the `messages` iterator to the
111    /// client.
112    ///
113    /// As with [`FramedConn::send`], the connection is not flushed after
114    /// calling this method. You must call [`FramedConn::flush`] explicitly.
115    /// Returns an error if the underlying connection is broken.
116    pub async fn send_all(
117        &mut self,
118        messages: impl IntoIterator<Item = BackendMessage>,
119    ) -> Result<(), io::Error> {
120        // N.B. we intentionally don't use `self.conn.send_all` here to avoid
121        // flushing the sink unnecessarily.
122        for m in messages {
123            self.send(m).await?;
124        }
125        Ok(())
126    }
127
128    /// Flushes all outstanding messages.
129    pub async fn flush(&mut self) -> Result<(), io::Error> {
130        self.inner.flush().await
131    }
132
133    /// Injects state that affects how certain backend messages are encoded.
134    ///
135    /// Specifically, the encoding of `BackendMessage::DataRow` depends upon the
136    /// types of the datums in the row. To avoid including the same type
137    /// information in each message, we use this side channel to install the
138    /// type information in the codec before sending any data row messages. This
139    /// violates the abstraction boundary a bit but results in much better
140    /// performance.
141    pub fn set_encode_state(
142        &mut self,
143        encode_state: Vec<(mz_pgrepr::Type, mz_pgwire_common::Format)>,
144    ) {
145        self.inner.get_mut().codec_mut().encode_state = encode_state;
146    }
147
148    /// Waits for the connection to be closed.
149    ///
150    /// Returns a "connection closed" error when the connection is closed. If
151    /// another error occurs before the connection is closed, that error is
152    /// returned instead.
153    ///
154    /// Use this method when you have an unbounded stream of data to forward to
155    /// the connection and the protocol does not require the client to
156    /// periodically acknowledge receipt. If you don't call this method to
157    /// periodically check if the connection has closed, you may not notice that
158    /// the client has gone away for an unboundedly long amount of time; usually
159    /// not until the stream of data produces its next message and you attempt
160    /// to write the data to the connection.
161    pub async fn wait_closed(&self) -> io::Error
162    where
163        A: AsyncReady + Send + Sync,
164    {
165        loop {
166            time::sleep(Duration::from_secs(1)).await;
167
168            match self.ready(Interest::READABLE | Interest::WRITABLE).await {
169                Ok(ready) if ready.is_read_closed() || ready.is_write_closed() => {
170                    return io::Error::new(io::ErrorKind::Other, "connection closed");
171                }
172                Ok(_) => (),
173                Err(err) => return err,
174            }
175        }
176    }
177
178    /// Returns the ID associated with this connection.
179    pub fn conn_id(&self) -> &ConnectionId {
180        &self.conn_id
181    }
182
183    /// Returns the peer address of the connection.
184    pub fn peer_addr(&self) -> &Option<IpAddr> {
185        &self.peer_addr
186    }
187}
188
189impl<A> FramedConn<A>
190where
191    A: AsyncRead + AsyncWrite + Unpin,
192{
193    pub fn inner(&self) -> &Conn<A> {
194        self.inner.get_ref().get_ref()
195    }
196}
197
198#[async_trait]
199impl<A> AsyncReady for FramedConn<A>
200where
201    A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
202{
203    async fn ready(&self, interest: Interest) -> io::Result<Ready> {
204        self.inner.get_ref().get_ref().ready(interest).await
205    }
206}
207
208struct Codec {
209    decode_state: DecodeState,
210    encode_state: Vec<(mz_pgrepr::Type, mz_pgwire_common::Format)>,
211}
212
213impl Codec {
214    /// Creates a new `Codec`.
215    pub fn new() -> Codec {
216        Codec {
217            decode_state: DecodeState::Head,
218            encode_state: vec![],
219        }
220    }
221}
222
223impl Default for Codec {
224    fn default() -> Codec {
225        Codec::new()
226    }
227}
228
229impl Encoder<BackendMessage> for Codec {
230    type Error = io::Error;
231
232    fn encode(&mut self, msg: BackendMessage, dst: &mut BytesMut) -> Result<(), io::Error> {
233        // Write type byte.
234        let byte = match &msg {
235            BackendMessage::AuthenticationOk => b'R',
236            BackendMessage::AuthenticationCleartextPassword
237            | BackendMessage::AuthenticationSASL
238            | BackendMessage::AuthenticationSASLContinue(_)
239            | BackendMessage::AuthenticationSASLFinal(_) => b'R',
240            BackendMessage::RowDescription(_) => b'T',
241            BackendMessage::DataRow(_) => b'D',
242            BackendMessage::CommandComplete { .. } => b'C',
243            BackendMessage::EmptyQueryResponse => b'I',
244            BackendMessage::ReadyForQuery(_) => b'Z',
245            BackendMessage::NoData => b'n',
246            BackendMessage::ParameterStatus(_, _) => b'S',
247            BackendMessage::PortalSuspended => b's',
248            BackendMessage::BackendKeyData { .. } => b'K',
249            BackendMessage::ParameterDescription(_) => b't',
250            BackendMessage::ParseComplete => b'1',
251            BackendMessage::BindComplete => b'2',
252            BackendMessage::CloseComplete => b'3',
253            BackendMessage::ErrorResponse(r) => {
254                if r.severity.is_error() {
255                    b'E'
256                } else {
257                    b'N'
258                }
259            }
260            BackendMessage::CopyInResponse { .. } => b'G',
261            BackendMessage::CopyOutResponse { .. } => b'H',
262            BackendMessage::CopyData(_) => b'd',
263            BackendMessage::CopyDone => b'c',
264        };
265        dst.put_u8(byte);
266
267        // Write message length placeholder. The true length is filled in later.
268        let base = dst.len();
269        dst.put_u32(0);
270
271        // Write message contents.
272        match msg {
273            BackendMessage::CopyInResponse {
274                overall_format,
275                column_formats,
276            }
277            | BackendMessage::CopyOutResponse {
278                overall_format,
279                column_formats,
280            } => {
281                dst.put_format_i8(overall_format);
282                dst.put_length_i16(column_formats.len())?;
283                for format in column_formats {
284                    dst.put_format_i16(format);
285                }
286            }
287            BackendMessage::CopyData(data) => {
288                dst.put_slice(&data);
289            }
290            BackendMessage::CopyDone => (),
291            BackendMessage::AuthenticationOk => {
292                dst.put_u32(0);
293            }
294            BackendMessage::AuthenticationCleartextPassword => {
295                dst.put_u32(3);
296            }
297            BackendMessage::AuthenticationSASL => {
298                dst.put_u32(10);
299                dst.put_string("SCRAM-SHA-256");
300                dst.put_u8(b'\0');
301            }
302            BackendMessage::AuthenticationSASLContinue(data) => {
303                dst.put_u32(11);
304                let data = format!(
305                    "r={},s={},i={}",
306                    data.nonce, data.salt, data.iteration_count
307                );
308                dst.put_slice(data.as_bytes());
309            }
310            BackendMessage::AuthenticationSASLFinal(data) => {
311                dst.put_u32(12);
312                let res = match data.kind {
313                    SASLServerFinalMessageKinds::Verifier(verifier) => {
314                        format!("v={}", verifier)
315                    }
316                };
317                dst.put_slice(res.as_bytes());
318                if !data.extensions.is_empty() {
319                    dst.put_slice(b",");
320                    dst.put_slice(data.extensions.join(",").as_bytes());
321                }
322            }
323            BackendMessage::RowDescription(fields) => {
324                dst.put_length_i16(fields.len())?;
325                for f in &fields {
326                    dst.put_string(&f.name.to_string());
327                    dst.put_u32(f.table_id);
328                    dst.put_u16(f.column_id);
329                    dst.put_u32(f.type_oid);
330                    dst.put_i16(f.type_len);
331                    dst.put_i32(f.type_mod);
332                    // TODO: make the format correct
333                    dst.put_format_i16(f.format);
334                }
335            }
336            BackendMessage::DataRow(fields) => {
337                dst.put_length_i16(fields.len())?;
338                for (f, (ty, format)) in fields.iter().zip_eq(&self.encode_state) {
339                    if let Some(f) = f {
340                        let base = dst.len();
341                        dst.put_u32(0);
342                        f.encode(ty, *format, dst)?;
343                        let len = dst.len() - base - 4;
344                        let len = i32::try_from(len).map_err(|_| {
345                            io::Error::new(
346                                io::ErrorKind::Other,
347                                "length of encoded data row field does not fit into an i32",
348                            )
349                        })?;
350                        dst[base..base + 4].copy_from_slice(&len.to_be_bytes());
351                    } else {
352                        dst.put_i32(-1);
353                    }
354                }
355            }
356            BackendMessage::CommandComplete { tag } => {
357                dst.put_string(&tag);
358            }
359            BackendMessage::ParseComplete => (),
360            BackendMessage::BindComplete => (),
361            BackendMessage::CloseComplete => (),
362            BackendMessage::EmptyQueryResponse => (),
363            BackendMessage::ReadyForQuery(status) => {
364                dst.put_u8(status.into());
365            }
366            BackendMessage::ParameterStatus(name, value) => {
367                dst.put_string(name);
368                dst.put_string(&value);
369            }
370            BackendMessage::PortalSuspended => (),
371            BackendMessage::NoData => (),
372            BackendMessage::BackendKeyData {
373                conn_id,
374                secret_key,
375            } => {
376                dst.put_u32(conn_id);
377                dst.put_u32(secret_key);
378            }
379            BackendMessage::ParameterDescription(params) => {
380                dst.put_length_i16(params.len())?;
381                for param in params {
382                    dst.put_u32(param.oid());
383                }
384            }
385            BackendMessage::ErrorResponse(ErrorResponse {
386                severity,
387                code,
388                message,
389                detail,
390                hint,
391                position,
392            }) => {
393                dst.put_u8(b'S');
394                dst.put_string(severity.as_str());
395                dst.put_u8(b'C');
396                dst.put_string(code.code());
397                dst.put_u8(b'M');
398                dst.put_string(&message);
399                if let Some(detail) = &detail {
400                    dst.put_u8(b'D');
401                    dst.put_string(detail);
402                }
403                if let Some(hint) = &hint {
404                    dst.put_u8(b'H');
405                    dst.put_string(hint);
406                }
407                if let Some(position) = &position {
408                    dst.put_u8(b'P');
409                    dst.put_string(&position.to_string());
410                }
411                dst.put_u8(b'\0');
412            }
413        }
414
415        let len = dst.len() - base;
416
417        // Overwrite length placeholder with true length.
418        let len = i32::try_from(len).map_err(|_| {
419            io::Error::new(
420                io::ErrorKind::Other,
421                "length of encoded message does not fit into an i32",
422            )
423        })?;
424        dst[base..base + 4].copy_from_slice(&len.to_be_bytes());
425
426        Ok(())
427    }
428}
429
430impl Decoder for Codec {
431    type Item = FrontendMessage;
432    type Error = io::Error;
433
434    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<FrontendMessage>, io::Error> {
435        if src.len() > MAX_REQUEST_SIZE {
436            return Err(io::Error::new(
437                io::ErrorKind::InvalidData,
438                format!(
439                    "request larger than {}",
440                    ByteSize::b(u64::cast_from(MAX_REQUEST_SIZE))
441                ),
442            ));
443        }
444        loop {
445            match self.decode_state {
446                DecodeState::Head => {
447                    if src.len() < 5 {
448                        return Ok(None);
449                    }
450                    let msg_type = src[0];
451                    let frame_len = parse_frame_len(&src[1..])?;
452                    src.advance(5);
453                    src.reserve(frame_len);
454                    self.decode_state = DecodeState::Data(msg_type, frame_len);
455                }
456
457                DecodeState::Data(msg_type, frame_len) => {
458                    if src.len() < frame_len {
459                        return Ok(None);
460                    }
461                    let buf = src.split_to(frame_len).freeze();
462                    let buf = Cursor::new(&buf);
463                    let msg = match msg_type {
464                        // Simple query flow.
465                        b'Q' => decode_query(buf)?,
466
467                        // Extended query flow.
468                        b'P' => decode_parse(buf)?,
469                        b'D' => decode_describe(buf)?,
470                        b'B' => decode_bind(buf)?,
471                        b'E' => decode_execute(buf)?,
472                        b'H' => decode_flush(buf)?,
473                        b'S' => decode_sync(buf)?,
474                        b'C' => decode_close(buf)?,
475
476                        // Termination.
477                        b'X' => decode_terminate(buf)?,
478
479                        // Authentication.
480                        b'p' => decode_auth(buf)?,
481
482                        // Copy from flow.
483                        b'f' => decode_copy_fail(buf)?,
484                        b'd' => decode_copy_data(buf, frame_len)?,
485                        b'c' => decode_copy_done(buf)?,
486
487                        // Invalid.
488                        _ => {
489                            return Err(io::Error::new(
490                                io::ErrorKind::InvalidData,
491                                format!("unknown message type {}", msg_type),
492                            ));
493                        }
494                    };
495                    src.reserve(5);
496                    self.decode_state = DecodeState::Head;
497                    return Ok(Some(msg));
498                }
499            }
500        }
501    }
502}
503
504fn decode_terminate(mut _buf: Cursor) -> Result<FrontendMessage, io::Error> {
505    // Nothing more to decode.
506    Ok(FrontendMessage::Terminate)
507}
508
509fn decode_auth(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
510    let mut value = Vec::new();
511    while let Ok(b) = buf.read_byte() {
512        value.push(b);
513    }
514    Ok(FrontendMessage::RawAuthentication(value))
515}
516
517fn expect(buf: &mut Cursor, expected: &[u8]) -> Result<(), io::Error> {
518    for i in 0..expected.len() {
519        if buf.read_byte()? != expected[i] {
520            return Err(input_err(format!(
521                "Invalid SASL initial response: expected '{}'",
522                std::str::from_utf8(expected).unwrap_or("invalid UTF-8")
523            )));
524        }
525    }
526    Ok(())
527}
528
529fn read_until_comma(buf: &mut Cursor) -> Result<Vec<u8>, io::Error> {
530    let mut v = Vec::new();
531    while let Ok(b) = buf.peek_byte() {
532        if b == b',' {
533            break;
534        }
535        v.push(buf.read_byte()?);
536    }
537    Ok(v)
538}
539
540// All SASL parsing is based on RFC 5802, [section 7](https://datatracker.ietf.org/doc/html/rfc5802#section-7)
541
542//   extensions = attr-val *("," attr-val)
543//                     ;; All extensions are optional,
544//                     ;; i.e., unrecognized attributes
545//                     ;; not defined in this document
546//                     ;; MUST be ignored.
547//   reserved-mext  = "m=" 1*(value-char)
548//                     ;; Reserved for signaling mandatory extensions.
549//                     ;; The exact syntax will be defined in
550//                     ;; the future.
551//   gs2-cbind-flag  = ("p=" cb-name) / "n" / "y"
552//                     ;; "n" -> client doesn't support channel binding.
553//                     ;; "y" -> client does support channel binding
554//                     ;;        but thinks the server does not.
555//                     ;; "p" -> client requires channel binding.
556//                     ;; The selected channel binding follows "p=".
557//
558//   gs2-header      = gs2-cbind-flag "," [ authzid ] ","
559//                     ;; GS2 header for SCRAM
560//                     ;; (the actual GS2 header includes an optional
561//                     ;; flag to indicate that the GSS mechanism is not
562//                     ;; "standard", but since SCRAM is "standard", we
563//                     ;; don't include that flag).
564//   client-first-message-bare =
565//                     [reserved-mext ","]
566//                     username "," nonce ["," extensions]
567//
568//   client-first-message =
569//                     gs2-header client-first-message-bare
570pub fn decode_sasl_client_first_message(mut buf: Cursor) -> Result<SASLInitialResponse, io::Error> {
571    // 1) GS2 cbind flag
572    let cbind_flag = match buf.read_byte()? {
573        b'n' => ChannelBinding::None,
574        b'y' => ChannelBinding::ClientSupported,
575        b'p' => {
576            // must be "p=" then cbname up to next comma
577            expect(&mut buf, b"=")?;
578            let cbname = String::from_utf8(read_until_comma(&mut buf)?)
579                .map_err(|_| input_err("invalid cbname utf8"))?;
580            ChannelBinding::Required(cbname)
581        }
582        other => {
583            return Err(input_err(format!(
584                "Invalid channel binding flag: {}",
585                other
586            )));
587        }
588    };
589    expect(&mut buf, b",")?;
590
591    // 2) Optional authzid: either empty, or "a=" up to next comma
592    let mut authzid = None;
593    if buf.peek_byte()? == b'a' {
594        expect(&mut buf, b"a=")?;
595        let a = String::from_utf8(read_until_comma(&mut buf)?)
596            .map_err(|_| input_err("invalid authzid utf8"))?;
597        authzid = Some(a);
598    }
599    expect(&mut buf, b",")?;
600
601    let mut client_first_message_bare_raw = String::new();
602
603    // 3) Optional reserved "m=" extension before n=
604    let mut reserved_mext = None;
605    if buf.peek_byte()? == b'm' {
606        expect(&mut buf, b"m=")?;
607        let mext_val = String::from_utf8(read_until_comma(&mut buf)?)
608            .map_err(|_| input_err("invalid m ext utf8"))?;
609        client_first_message_bare_raw.push_str(&format!("m={},", mext_val));
610        reserved_mext = Some(mext_val);
611        expect(&mut buf, b",")?;
612    }
613
614    // 4) Username: must be "n=" then saslname
615    expect(&mut buf, b"n=")?;
616    // Postgres doesn't use the username here, so we just consume
617    let username = String::from_utf8(read_until_comma(&mut buf)?)
618        .map_err(|_| input_err("invalid username utf8"))?;
619    expect(&mut buf, b",")?;
620    client_first_message_bare_raw.push_str(&format!("n={},", username));
621
622    // 5) Nonce: must be "r=" then value up to next comma or end
623    expect(&mut buf, b"r=")?;
624    let nonce = String::from_utf8(read_until_comma(&mut buf)?)
625        .map_err(|_| input_err("invalid nonce utf8"))?;
626    client_first_message_bare_raw.push_str(&format!("r={}", nonce));
627
628    // 6) Optional extensions: "," key=value chunks
629    let mut extensions = Vec::new();
630    while let Ok(b',') = buf.peek_byte().map(|b| b) {
631        expect(&mut buf, b",")?;
632        let ext = String::from_utf8(read_until_comma(&mut buf)?)
633            .map_err(|_| input_err("invalid ext utf8"))?;
634        if !ext.is_empty() {
635            client_first_message_bare_raw.push_str(&format!(",{}", ext));
636            extensions.push(ext);
637        }
638    }
639
640    Ok(SASLInitialResponse {
641        gs2_header: GS2Header {
642            cbind_flag,
643            authzid,
644        },
645        nonce,
646        extensions,
647        reserved_mext,
648        client_first_message_bare_raw,
649    })
650}
651
652pub fn decode_sasl_initial_response(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
653    let mechanism = buf.read_cstr()?;
654    let initial_resp_len = buf.read_i32()?;
655    if initial_resp_len < 0 {
656        // -1 means no response? We bail here
657        return Err(input_err("No initial response"));
658    }
659
660    let initial_response = decode_sasl_client_first_message(buf)?;
661    Ok(FrontendMessage::SASLInitialResponse {
662        gs2_header: initial_response.gs2_header.clone(),
663        mechanism: mechanism.to_owned(),
664        initial_response,
665    })
666}
667
668//   proof           = "p=" base64
669//
670//   channel-binding = "c=" base64
671//                     ;; base64 encoding of cbind-input.
672//   client-final-message-without-proof =
673//                     channel-binding "," nonce [","
674//                     extensions]
675//
676//   client-final-message =
677//                     client-final-message-without-proof "," proof
678pub fn decode_sasl_response(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
679    // --- client-final-message-without-proof ---
680    let mut client_final_message_bare_raw = String::new();
681    // channel-binding: "c=" <base64>, up to the next comma
682    expect(&mut buf, b"c=")?;
683    let channel_binding = String::from_utf8(read_until_comma(&mut buf)?)
684        .map_err(|_| input_err("invalid channel-binding utf8"))?;
685    expect(&mut buf, b",")?;
686    client_final_message_bare_raw.push_str(&format!("c={},", channel_binding));
687
688    // nonce: "r=" <printable>, up to the next comma
689    expect(&mut buf, b"r=")?;
690    let nonce = String::from_utf8(read_until_comma(&mut buf)?)
691        .map_err(|_| input_err("invalid nonce utf8"))?;
692    client_final_message_bare_raw.push_str(&format!("r={}", nonce));
693
694    // after reading channel-binding and nonce
695    let mut extensions = Vec::new();
696
697    // Keep reading ",<token>" until we see ",p="
698    while buf.peek_byte()? == b',' {
699        expect(&mut buf, b",")?;
700        if buf.peek_byte()? == b'p' {
701            break;
702        }
703        let ext = String::from_utf8(read_until_comma(&mut buf)?)
704            .map_err(|_| input_err("invalid extension utf8"))?;
705        if !ext.is_empty() {
706            client_final_message_bare_raw.push_str(&format!(",{}", ext));
707            extensions.push(ext);
708        }
709    }
710
711    // Proof is mandatory and last
712    expect(&mut buf, b"p=")?;
713    let proof = String::from_utf8(read_until_comma(&mut buf)?)
714        .map_err(|_| input_err("invalid proof utf8"))?;
715
716    Ok(FrontendMessage::SASLResponse(SASLClientFinalResponse {
717        channel_binding,
718        nonce,
719        extensions,
720        proof,
721        client_final_message_bare_raw,
722    }))
723}
724
725pub fn decode_password(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
726    Ok(FrontendMessage::Password {
727        password: buf.read_cstr()?.to_owned(),
728    })
729}
730
731fn decode_query(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
732    Ok(FrontendMessage::Query {
733        sql: buf.read_cstr()?.to_string(),
734    })
735}
736
737fn decode_parse(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
738    let name = buf.read_cstr()?;
739    let sql = buf.read_cstr()?;
740
741    let mut param_types = vec![];
742    for _ in 0..buf.read_i16()? {
743        param_types.push(buf.read_u32()?);
744    }
745
746    Ok(FrontendMessage::Parse {
747        name: name.into(),
748        sql: sql.into(),
749        param_types,
750    })
751}
752
753fn decode_close(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
754    match buf.read_byte()? {
755        b'S' => Ok(FrontendMessage::CloseStatement {
756            name: buf.read_cstr()?.to_owned(),
757        }),
758        b'P' => Ok(FrontendMessage::ClosePortal {
759            name: buf.read_cstr()?.to_owned(),
760        }),
761        b => Err(input_err(format!(
762            "invalid type byte in close message: {}",
763            b
764        ))),
765    }
766}
767
768fn decode_describe(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
769    let first_char = buf.read_byte()?;
770    let name = buf.read_cstr()?.to_string();
771    match first_char {
772        b'S' => Ok(FrontendMessage::DescribeStatement { name }),
773        b'P' => Ok(FrontendMessage::DescribePortal { name }),
774        other => Err(input_err(format!("Invalid describe type: {:#x?}", other))),
775    }
776}
777
778fn decode_bind(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
779    let portal_name = buf.read_cstr()?.to_string();
780    let statement_name = buf.read_cstr()?.to_string();
781
782    let mut param_formats = Vec::new();
783    for _ in 0..buf.read_i16()? {
784        param_formats.push(buf.read_format()?);
785    }
786
787    let mut raw_params = Vec::new();
788    for _ in 0..buf.read_i16()? {
789        let len = buf.read_i32()?;
790        if len == -1 {
791            raw_params.push(None); // NULL
792        } else {
793            // TODO(benesch): this should use bytes::Bytes to avoid the copy.
794            let mut value = Vec::new();
795            for _ in 0..len {
796                value.push(buf.read_byte()?);
797            }
798            raw_params.push(Some(value));
799        }
800    }
801
802    let mut result_formats = Vec::new();
803    for _ in 0..buf.read_i16()? {
804        result_formats.push(buf.read_format()?);
805    }
806
807    Ok(FrontendMessage::Bind {
808        portal_name,
809        statement_name,
810        param_formats,
811        raw_params,
812        result_formats,
813    })
814}
815
816fn decode_execute(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
817    let portal_name = buf.read_cstr()?.to_string();
818    let max_rows = buf.read_i32()?;
819    Ok(FrontendMessage::Execute {
820        portal_name,
821        max_rows,
822    })
823}
824
825fn decode_flush(mut _buf: Cursor) -> Result<FrontendMessage, io::Error> {
826    // Nothing more to decode.
827    Ok(FrontendMessage::Flush)
828}
829
830fn decode_sync(mut _buf: Cursor) -> Result<FrontendMessage, io::Error> {
831    // Nothing more to decode.
832    Ok(FrontendMessage::Sync)
833}
834
835fn decode_copy_data(mut buf: Cursor, frame_len: usize) -> Result<FrontendMessage, io::Error> {
836    let mut data = Vec::with_capacity(frame_len);
837    for _ in 0..frame_len {
838        data.push(buf.read_byte()?);
839    }
840    Ok(FrontendMessage::CopyData(data))
841}
842
843fn decode_copy_done(mut _buf: Cursor) -> Result<FrontendMessage, io::Error> {
844    // Nothing more to decode.
845    Ok(FrontendMessage::CopyDone)
846}
847
848fn decode_copy_fail(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
849    Ok(FrontendMessage::CopyFail(buf.read_cstr()?.to_string()))
850}