mz_pgwire/
codec.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Encoding/decoding of messages in pgwire. See "[Frontend/Backend Protocol:
11//! Message Formats][1]" in the PostgreSQL reference for the specification.
12//!
13//! See the [crate docs](crate) for higher level concerns.
14//!
15//! [1]: https://www.postgresql.org/docs/11/protocol-message-formats.html
16
17use std::net::IpAddr;
18
19use async_trait::async_trait;
20use bytes::{Buf, BufMut, BytesMut};
21use bytesize::ByteSize;
22use futures::{SinkExt, TryStreamExt, sink};
23use mz_adapter_types::connection::ConnectionId;
24use mz_ore::cast::CastFrom;
25use mz_ore::future::OreSinkExt;
26use mz_ore::netio::AsyncReady;
27use mz_pgwire_common::{
28    Conn, Cursor, DecodeState, ErrorResponse, FrontendMessage, MAX_REQUEST_SIZE, Pgbuf, input_err,
29    parse_frame_len,
30};
31use tokio::io::{self, AsyncRead, AsyncWrite, Interest, Ready};
32use tokio::time::{self, Duration};
33use tokio_util::codec::{Decoder, Encoder, Framed};
34use tracing::trace;
35
36use crate::message::{BackendMessage, BackendMessageKind};
37
38/// A connection that manages the encoding and decoding of pgwire frames.
/// A connection that manages the encoding and decoding of pgwire frames.
pub struct FramedConn<A> {
    /// Identifies this connection in trace logging.
    conn_id: ConnectionId,
    /// The remote peer's IP address, if known.
    peer_addr: Option<IpAddr>,
    /// The framed transport (using [`Codec`]) behind a sink buffer that stages
    /// up to 32 outgoing backend messages (see `FramedConn::new`).
    inner: sink::Buffer<Framed<Conn<A>, Codec>, BackendMessage>,
}
44
45impl<A> FramedConn<A>
46where
47    A: AsyncRead + AsyncWrite + Unpin,
48{
49    /// Constructs a new framed connection.
50    ///
51    /// The underlying connection, `inner`, is expected to be something like a
52    /// TCP stream. Anything that implements [`AsyncRead`] and [`AsyncWrite`]
53    /// will do.
54    ///
55    /// The supplied `conn_id` is used to identify the connection in logging
56    /// messages.
57    pub fn new(conn_id: ConnectionId, peer_addr: Option<IpAddr>, inner: Conn<A>) -> FramedConn<A> {
58        FramedConn {
59            conn_id,
60            peer_addr,
61            inner: Framed::new(inner, Codec::new()).buffer(32),
62        }
63    }
64
65    /// Reads and decodes one frontend message from the client.
66    ///
67    /// Blocks until the client sends a complete message. If the client
68    /// terminates the stream, returns `None`. Returns an error if the client
69    /// sends a malformed message or if the connection underlying is broken.
70    ///
71    /// # Cancel safety
72    ///
73    /// This method is cancel safe. The returned future only holds onto a
74    /// reference to thea underlying stream, so dropping it will never lose a
75    /// value.
76    ///
77    /// <https://docs.rs/tokio-stream/latest/tokio_stream/trait.StreamExt.html#cancel-safety-1>
78    pub async fn recv(&mut self) -> Result<Option<FrontendMessage>, io::Error> {
79        let message = self.inner.try_next().await?;
80        match &message {
81            Some(message) => trace!("cid={} recv_name={}", self.conn_id, message.name()),
82            None => trace!("cid={} recv=<eof>", self.conn_id),
83        }
84        Ok(message)
85    }
86
87    /// Encodes and sends one backend message to the client.
88    ///
89    /// Note that the connection is not flushed after calling this method. You
90    /// must call [`FramedConn::flush`] explicitly. Returns an error if the
91    /// underlying connection is broken.
92    ///
93    /// Please use `StateMachine::send` instead if calling from `StateMachine`,
94    /// as it applies session-based filters before calling this method.
95    pub async fn send<M>(&mut self, message: M) -> Result<(), io::Error>
96    where
97        M: Into<BackendMessage>,
98    {
99        let message = message.into();
100        trace!(
101            "cid={} send={:?}",
102            self.conn_id,
103            BackendMessageKind::from(&message)
104        );
105        self.inner.enqueue(message).await
106    }
107
108    /// Encodes and sends the backend messages in the `messages` iterator to the
109    /// client.
110    ///
111    /// As with [`FramedConn::send`], the connection is not flushed after
112    /// calling this method. You must call [`FramedConn::flush`] explicitly.
113    /// Returns an error if the underlying connection is broken.
114    pub async fn send_all(
115        &mut self,
116        messages: impl IntoIterator<Item = BackendMessage>,
117    ) -> Result<(), io::Error> {
118        // N.B. we intentionally don't use `self.conn.send_all` here to avoid
119        // flushing the sink unnecessarily.
120        for m in messages {
121            self.send(m).await?;
122        }
123        Ok(())
124    }
125
126    /// Flushes all outstanding messages.
127    pub async fn flush(&mut self) -> Result<(), io::Error> {
128        self.inner.flush().await
129    }
130
131    /// Injects state that affects how certain backend messages are encoded.
132    ///
133    /// Specifically, the encoding of `BackendMessage::DataRow` depends upon the
134    /// types of the datums in the row. To avoid including the same type
135    /// information in each message, we use this side channel to install the
136    /// type information in the codec before sending any data row messages. This
137    /// violates the abstraction boundary a bit but results in much better
138    /// performance.
139    pub fn set_encode_state(
140        &mut self,
141        encode_state: Vec<(mz_pgrepr::Type, mz_pgwire_common::Format)>,
142    ) {
143        self.inner.get_mut().codec_mut().encode_state = encode_state;
144    }
145
146    /// Waits for the connection to be closed.
147    ///
148    /// Returns a "connection closed" error when the connection is closed. If
149    /// another error occurs before the connection is closed, that error is
150    /// returned instead.
151    ///
152    /// Use this method when you have an unbounded stream of data to forward to
153    /// the connection and the protocol does not require the client to
154    /// periodically acknowledge receipt. If you don't call this method to
155    /// periodically check if the connection has closed, you may not notice that
156    /// the client has gone away for an unboundedly long amount of time; usually
157    /// not until the stream of data produces its next message and you attempt
158    /// to write the data to the connection.
159    pub async fn wait_closed(&self) -> io::Error
160    where
161        A: AsyncReady + Send + Sync,
162    {
163        loop {
164            time::sleep(Duration::from_secs(1)).await;
165
166            match self.ready(Interest::READABLE | Interest::WRITABLE).await {
167                Ok(ready) if ready.is_read_closed() || ready.is_write_closed() => {
168                    return io::Error::new(io::ErrorKind::Other, "connection closed");
169                }
170                Ok(_) => (),
171                Err(err) => return err,
172            }
173        }
174    }
175
176    /// Returns the ID associated with this connection.
177    pub fn conn_id(&self) -> &ConnectionId {
178        &self.conn_id
179    }
180
181    /// Returns the peer address of the connection.
182    pub fn peer_addr(&self) -> &Option<IpAddr> {
183        &self.peer_addr
184    }
185}
186
187impl<A> FramedConn<A>
188where
189    A: AsyncRead + AsyncWrite + Unpin,
190{
191    pub fn inner(&self) -> &Conn<A> {
192        self.inner.get_ref().get_ref()
193    }
194}
195
196#[async_trait]
197impl<A> AsyncReady for FramedConn<A>
198where
199    A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
200{
201    async fn ready(&self, interest: Interest) -> io::Result<Ready> {
202        self.inner.get_ref().get_ref().ready(interest).await
203    }
204}
205
/// The pgwire codec: decodes [`FrontendMessage`]s from the client and encodes
/// [`BackendMessage`]s to it.
struct Codec {
    /// Progress through the current incoming frame: either waiting for the
    /// 5-byte header (`Head`) or for the body of a frame whose type byte and
    /// body length have already been read (`Data`).
    decode_state: DecodeState,
    /// Per-column `(type, format)` pairs used when encoding `DataRow`
    /// messages; installed via `FramedConn::set_encode_state`.
    encode_state: Vec<(mz_pgrepr::Type, mz_pgwire_common::Format)>,
}
210
211impl Codec {
212    /// Creates a new `Codec`.
213    pub fn new() -> Codec {
214        Codec {
215            decode_state: DecodeState::Head,
216            encode_state: vec![],
217        }
218    }
219}
220
221impl Default for Codec {
222    fn default() -> Codec {
223        Codec::new()
224    }
225}
226
/// Encodes backend messages into the pgwire wire format.
///
/// Every message is written as a one-byte type tag, then a big-endian 4-byte
/// length, then the body. The length counts its own four bytes plus the body,
/// but not the type tag. A zero placeholder is written first and backfilled
/// once the body has been encoded.
impl Encoder<BackendMessage> for Codec {
    type Error = io::Error;

    /// Appends the encoding of `msg` to `dst`.
    ///
    /// Returns an error if the encoded message (or a single encoded `DataRow`
    /// field) is longer than `i32::MAX` bytes, or if a nested helper
    /// (`put_length_i16`, datum `encode`) fails. `put_length_i16` presumably
    /// fails when a count does not fit in an `i16` (NOTE(review): confirm).
    /// On error, `dst` may already hold a partially written message.
    fn encode(&mut self, msg: BackendMessage, dst: &mut BytesMut) -> Result<(), io::Error> {
        // Write type byte.
        let byte = match &msg {
            BackendMessage::AuthenticationOk => b'R',
            BackendMessage::AuthenticationCleartextPassword => b'R',
            BackendMessage::RowDescription(_) => b'T',
            BackendMessage::DataRow(_) => b'D',
            BackendMessage::CommandComplete { .. } => b'C',
            BackendMessage::EmptyQueryResponse => b'I',
            BackendMessage::ReadyForQuery(_) => b'Z',
            BackendMessage::NoData => b'n',
            BackendMessage::ParameterStatus(_, _) => b'S',
            BackendMessage::PortalSuspended => b's',
            BackendMessage::BackendKeyData { .. } => b'K',
            BackendMessage::ParameterDescription(_) => b't',
            BackendMessage::ParseComplete => b'1',
            BackendMessage::BindComplete => b'2',
            BackendMessage::CloseComplete => b'3',
            BackendMessage::ErrorResponse(r) => {
                // The same payload is sent as an ErrorResponse ('E') when the
                // severity is an error, and as a NoticeResponse ('N')
                // otherwise.
                if r.severity.is_error() {
                    b'E'
                } else {
                    b'N'
                }
            }
            BackendMessage::CopyInResponse { .. } => b'G',
            BackendMessage::CopyOutResponse { .. } => b'H',
            BackendMessage::CopyData(_) => b'd',
            BackendMessage::CopyDone => b'c',
        };
        dst.put_u8(byte);

        // Write message length placeholder. The true length is filled in later.
        let base = dst.len();
        dst.put_u32(0);

        // Write message contents.
        match msg {
            BackendMessage::CopyInResponse {
                overall_format,
                column_formats,
            }
            | BackendMessage::CopyOutResponse {
                overall_format,
                column_formats,
            } => {
                // Overall format as one byte, then a count-prefixed list of
                // per-column formats as two bytes each.
                dst.put_format_i8(overall_format);
                dst.put_length_i16(column_formats.len())?;
                for format in column_formats {
                    dst.put_format_i16(format);
                }
            }
            BackendMessage::CopyData(data) => {
                dst.put_slice(&data);
            }
            BackendMessage::CopyDone => (),
            BackendMessage::AuthenticationOk => {
                // Authentication request code 0: authentication succeeded.
                dst.put_u32(0);
            }
            BackendMessage::AuthenticationCleartextPassword => {
                // Authentication request code 3: cleartext password required.
                dst.put_u32(3);
            }
            BackendMessage::RowDescription(fields) => {
                dst.put_length_i16(fields.len())?;
                for f in &fields {
                    dst.put_string(&f.name.to_string());
                    dst.put_u32(f.table_id);
                    dst.put_u16(f.column_id);
                    dst.put_u32(f.type_oid);
                    dst.put_i16(f.type_len);
                    dst.put_i32(f.type_mod);
                    // TODO: make the format correct
                    dst.put_format_i16(f.format);
                }
            }
            BackendMessage::DataRow(fields) => {
                // NOTE(review): the declared count is `fields.len()`, but
                // `zip` stops at the shorter of `fields` and
                // `self.encode_state`. If fewer encode states were installed
                // than there are fields, this writes fewer fields than
                // declared and the message is malformed. Confirm that callers
                // always install matching state via `set_encode_state`.
                dst.put_length_i16(fields.len())?;
                for (f, (ty, format)) in fields.iter().zip(&self.encode_state) {
                    if let Some(f) = f {
                        // Each non-NULL field is a 4-byte length placeholder
                        // followed by the encoded datum. This length excludes
                        // its own four bytes.
                        let base = dst.len();
                        dst.put_u32(0);
                        f.encode(ty, *format, dst)?;
                        let len = dst.len() - base - 4;
                        let len = i32::try_from(len).map_err(|_| {
                            io::Error::new(
                                io::ErrorKind::Other,
                                "length of encoded data row field does not fit into an i32",
                            )
                        })?;
                        dst[base..base + 4].copy_from_slice(&len.to_be_bytes());
                    } else {
                        // A length of -1 marks a NULL field with no bytes.
                        dst.put_i32(-1);
                    }
                }
            }
            BackendMessage::CommandComplete { tag } => {
                dst.put_string(&tag);
            }
            BackendMessage::ParseComplete => (),
            BackendMessage::BindComplete => (),
            BackendMessage::CloseComplete => (),
            BackendMessage::EmptyQueryResponse => (),
            BackendMessage::ReadyForQuery(status) => {
                dst.put_u8(status.into());
            }
            BackendMessage::ParameterStatus(name, value) => {
                dst.put_string(name);
                dst.put_string(&value);
            }
            BackendMessage::PortalSuspended => (),
            BackendMessage::NoData => (),
            BackendMessage::BackendKeyData {
                conn_id,
                secret_key,
            } => {
                dst.put_u32(conn_id);
                dst.put_u32(secret_key);
            }
            BackendMessage::ParameterDescription(params) => {
                dst.put_length_i16(params.len())?;
                for param in params {
                    dst.put_u32(param.oid());
                }
            }
            BackendMessage::ErrorResponse(ErrorResponse {
                severity,
                code,
                message,
                detail,
                hint,
                position,
            }) => {
                // A sequence of (field tag byte, string) pairs: severity 'S',
                // code 'C' and message 'M' are always sent. Detail 'D',
                // hint 'H' and position 'P' are sent only when present. The
                // list ends with a zero byte.
                dst.put_u8(b'S');
                dst.put_string(severity.as_str());
                dst.put_u8(b'C');
                dst.put_string(code.code());
                dst.put_u8(b'M');
                dst.put_string(&message);
                if let Some(detail) = &detail {
                    dst.put_u8(b'D');
                    dst.put_string(detail);
                }
                if let Some(hint) = &hint {
                    dst.put_u8(b'H');
                    dst.put_string(hint);
                }
                if let Some(position) = &position {
                    dst.put_u8(b'P');
                    dst.put_string(&position.to_string());
                }
                dst.put_u8(b'\0');
            }
        }

        // Includes the 4-byte placeholder itself, but not the type byte.
        let len = dst.len() - base;

        // Overwrite length placeholder with true length.
        let len = i32::try_from(len).map_err(|_| {
            io::Error::new(
                io::ErrorKind::Other,
                "length of encoded message does not fit into an i32",
            )
        })?;
        dst[base..base + 4].copy_from_slice(&len.to_be_bytes());

        Ok(())
    }
}
398
/// Decodes frontend messages from the pgwire byte stream.
///
/// Decoding is a two-state machine driven by `decode_state`:
/// - In `Head`, it waits for the 5-byte header (type byte plus length).
/// - In `Data`, it waits until the whole frame body is buffered, then
///   dispatches on the type byte.
impl Decoder for Codec {
    type Item = FrontendMessage;
    type Error = io::Error;

    /// Attempts to decode one frontend message from `src`.
    ///
    /// Returns `Ok(None)` when more bytes are needed. Returns an error in
    /// these cases:
    /// - the buffered input exceeds `MAX_REQUEST_SIZE`;
    /// - the header's length is rejected by `parse_frame_len`;
    /// - the type byte is unknown;
    /// - a per-message decoder fails.
    ///
    /// On a decode error the state is left as `Data`.
    /// NOTE(review): this presumably relies on the caller dropping the
    /// connection after any error — confirm.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<FrontendMessage>, io::Error> {
        // Refuse to keep buffering once the unconsumed input grows past
        // MAX_REQUEST_SIZE.
        if src.len() > MAX_REQUEST_SIZE {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!(
                    "request larger than {}",
                    ByteSize::b(u64::cast_from(MAX_REQUEST_SIZE))
                ),
            ));
        }
        loop {
            match self.decode_state {
                DecodeState::Head => {
                    // Need the type byte and the 4-byte length before any
                    // progress can be made.
                    if src.len() < 5 {
                        return Ok(None);
                    }
                    let msg_type = src[0];
                    // `frame_len` is used below as the body length, i.e. the
                    // number of bytes following the 5-byte header.
                    let frame_len = parse_frame_len(&src[1..])?;
                    src.advance(5);
                    // Pre-size the buffer so the body can arrive without
                    // repeated reallocation.
                    src.reserve(frame_len);
                    self.decode_state = DecodeState::Data(msg_type, frame_len);
                }

                DecodeState::Data(msg_type, frame_len) => {
                    // Wait for the full body; the state is retained, so the
                    // header is not re-parsed on the next call.
                    if src.len() < frame_len {
                        return Ok(None);
                    }
                    let buf = src.split_to(frame_len).freeze();
                    let buf = Cursor::new(&buf);
                    let msg = match msg_type {
                        // Simple query flow.
                        b'Q' => decode_query(buf)?,

                        // Extended query flow.
                        b'P' => decode_parse(buf)?,
                        b'D' => decode_describe(buf)?,
                        b'B' => decode_bind(buf)?,
                        b'E' => decode_execute(buf)?,
                        b'H' => decode_flush(buf)?,
                        b'S' => decode_sync(buf)?,
                        b'C' => decode_close(buf)?,

                        // Termination.
                        b'X' => decode_terminate(buf)?,

                        // Authentication.
                        b'p' => decode_password(buf)?,

                        // Copy from flow.
                        b'f' => decode_copy_fail(buf)?,
                        b'd' => decode_copy_data(buf, frame_len)?,
                        b'c' => decode_copy_done(buf)?,

                        // Invalid.
                        _ => {
                            return Err(io::Error::new(
                                io::ErrorKind::InvalidData,
                                format!("unknown message type {}", msg_type),
                            ));
                        }
                    };
                    // Make room for the next message's header.
                    src.reserve(5);
                    self.decode_state = DecodeState::Head;
                    return Ok(Some(msg));
                }
            }
        }
    }
}
472
473fn decode_terminate(mut _buf: Cursor) -> Result<FrontendMessage, io::Error> {
474    // Nothing more to decode.
475    Ok(FrontendMessage::Terminate)
476}
477
478fn decode_password(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
479    Ok(FrontendMessage::Password {
480        password: buf.read_cstr()?.to_owned(),
481    })
482}
483
484fn decode_query(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
485    Ok(FrontendMessage::Query {
486        sql: buf.read_cstr()?.to_string(),
487    })
488}
489
490fn decode_parse(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
491    let name = buf.read_cstr()?;
492    let sql = buf.read_cstr()?;
493
494    let mut param_types = vec![];
495    for _ in 0..buf.read_i16()? {
496        param_types.push(buf.read_u32()?);
497    }
498
499    Ok(FrontendMessage::Parse {
500        name: name.into(),
501        sql: sql.into(),
502        param_types,
503    })
504}
505
506fn decode_close(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
507    match buf.read_byte()? {
508        b'S' => Ok(FrontendMessage::CloseStatement {
509            name: buf.read_cstr()?.to_owned(),
510        }),
511        b'P' => Ok(FrontendMessage::ClosePortal {
512            name: buf.read_cstr()?.to_owned(),
513        }),
514        b => Err(input_err(format!(
515            "invalid type byte in close message: {}",
516            b
517        ))),
518    }
519}
520
521fn decode_describe(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
522    let first_char = buf.read_byte()?;
523    let name = buf.read_cstr()?.to_string();
524    match first_char {
525        b'S' => Ok(FrontendMessage::DescribeStatement { name }),
526        b'P' => Ok(FrontendMessage::DescribePortal { name }),
527        other => Err(input_err(format!("Invalid describe type: {:#x?}", other))),
528    }
529}
530
531fn decode_bind(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
532    let portal_name = buf.read_cstr()?.to_string();
533    let statement_name = buf.read_cstr()?.to_string();
534
535    let mut param_formats = Vec::new();
536    for _ in 0..buf.read_i16()? {
537        param_formats.push(buf.read_format()?);
538    }
539
540    let mut raw_params = Vec::new();
541    for _ in 0..buf.read_i16()? {
542        let len = buf.read_i32()?;
543        if len == -1 {
544            raw_params.push(None); // NULL
545        } else {
546            // TODO(benesch): this should use bytes::Bytes to avoid the copy.
547            let mut value = Vec::new();
548            for _ in 0..len {
549                value.push(buf.read_byte()?);
550            }
551            raw_params.push(Some(value));
552        }
553    }
554
555    let mut result_formats = Vec::new();
556    for _ in 0..buf.read_i16()? {
557        result_formats.push(buf.read_format()?);
558    }
559
560    Ok(FrontendMessage::Bind {
561        portal_name,
562        statement_name,
563        param_formats,
564        raw_params,
565        result_formats,
566    })
567}
568
569fn decode_execute(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
570    let portal_name = buf.read_cstr()?.to_string();
571    let max_rows = buf.read_i32()?;
572    Ok(FrontendMessage::Execute {
573        portal_name,
574        max_rows,
575    })
576}
577
578fn decode_flush(mut _buf: Cursor) -> Result<FrontendMessage, io::Error> {
579    // Nothing more to decode.
580    Ok(FrontendMessage::Flush)
581}
582
583fn decode_sync(mut _buf: Cursor) -> Result<FrontendMessage, io::Error> {
584    // Nothing more to decode.
585    Ok(FrontendMessage::Sync)
586}
587
588fn decode_copy_data(mut buf: Cursor, frame_len: usize) -> Result<FrontendMessage, io::Error> {
589    let mut data = Vec::with_capacity(frame_len);
590    for _ in 0..frame_len {
591        data.push(buf.read_byte()?);
592    }
593    Ok(FrontendMessage::CopyData(data))
594}
595
596fn decode_copy_done(mut _buf: Cursor) -> Result<FrontendMessage, io::Error> {
597    // Nothing more to decode.
598    Ok(FrontendMessage::CopyDone)
599}
600
601fn decode_copy_fail(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
602    Ok(FrontendMessage::CopyFail(buf.read_cstr()?.to_string()))
603}