mz_pgwire/
codec.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Encoding/decoding of messages in pgwire. See "[Frontend/Backend Protocol:
11//! Message Formats][1]" in the PostgreSQL reference for the specification.
12//!
13//! See the [crate docs](crate) for higher level concerns.
14//!
15//! [1]: https://www.postgresql.org/docs/11/protocol-message-formats.html
16
17use std::net::IpAddr;
18
19use async_trait::async_trait;
20use bytes::{Buf, BufMut, BytesMut};
21use bytesize::ByteSize;
22use futures::{SinkExt, TryStreamExt, sink};
23use itertools::Itertools;
24use mz_adapter_types::connection::ConnectionId;
25use mz_ore::cast::CastFrom;
26use mz_ore::future::OreSinkExt;
27use mz_ore::netio::AsyncReady;
28use mz_pgwire_common::{
29    Conn, Cursor, DecodeState, ErrorResponse, FrontendMessage, MAX_REQUEST_SIZE, Pgbuf, input_err,
30    parse_frame_len,
31};
32use tokio::io::{self, AsyncRead, AsyncWrite, Interest, Ready};
33use tokio::time::{self, Duration};
34use tokio_util::codec::{Decoder, Encoder, Framed};
35use tracing::trace;
36
37use crate::message::{BackendMessage, BackendMessageKind};
38
/// A connection that manages the encoding and decoding of pgwire frames.
pub struct FramedConn<A> {
    /// Identifier for this connection; included in the trace output emitted
    /// when messages are received and sent.
    conn_id: ConnectionId,
    /// Address of the remote peer, if one was supplied at construction.
    peer_addr: Option<IpAddr>,
    /// The underlying connection wrapped in the pgwire codec, fronted by a
    /// buffer of outgoing backend messages.
    inner: sink::Buffer<Framed<Conn<A>, Codec>, BackendMessage>,
}
45
46impl<A> FramedConn<A>
47where
48    A: AsyncRead + AsyncWrite + Unpin,
49{
50    /// Constructs a new framed connection.
51    ///
52    /// The underlying connection, `inner`, is expected to be something like a
53    /// TCP stream. Anything that implements [`AsyncRead`] and [`AsyncWrite`]
54    /// will do.
55    ///
56    /// The supplied `conn_id` is used to identify the connection in logging
57    /// messages.
58    pub fn new(conn_id: ConnectionId, peer_addr: Option<IpAddr>, inner: Conn<A>) -> FramedConn<A> {
59        FramedConn {
60            conn_id,
61            peer_addr,
62            inner: Framed::new(inner, Codec::new()).buffer(32),
63        }
64    }
65
66    /// Reads and decodes one frontend message from the client.
67    ///
68    /// Blocks until the client sends a complete message. If the client
69    /// terminates the stream, returns `None`. Returns an error if the client
70    /// sends a malformed message or if the connection underlying is broken.
71    ///
72    /// # Cancel safety
73    ///
74    /// This method is cancel safe. The returned future only holds onto a
75    /// reference to thea underlying stream, so dropping it will never lose a
76    /// value.
77    ///
78    /// <https://docs.rs/tokio-stream/latest/tokio_stream/trait.StreamExt.html#cancel-safety-1>
79    pub async fn recv(&mut self) -> Result<Option<FrontendMessage>, io::Error> {
80        let message = self.inner.try_next().await?;
81        match &message {
82            Some(message) => trace!("cid={} recv_name={}", self.conn_id, message.name()),
83            None => trace!("cid={} recv=<eof>", self.conn_id),
84        }
85        Ok(message)
86    }
87
88    /// Encodes and sends one backend message to the client.
89    ///
90    /// Note that the connection is not flushed after calling this method. You
91    /// must call [`FramedConn::flush`] explicitly. Returns an error if the
92    /// underlying connection is broken.
93    ///
94    /// Please use `StateMachine::send` instead if calling from `StateMachine`,
95    /// as it applies session-based filters before calling this method.
96    pub async fn send<M>(&mut self, message: M) -> Result<(), io::Error>
97    where
98        M: Into<BackendMessage>,
99    {
100        let message = message.into();
101        trace!(
102            "cid={} send={:?}",
103            self.conn_id,
104            BackendMessageKind::from(&message)
105        );
106        self.inner.enqueue(message).await
107    }
108
109    /// Encodes and sends the backend messages in the `messages` iterator to the
110    /// client.
111    ///
112    /// As with [`FramedConn::send`], the connection is not flushed after
113    /// calling this method. You must call [`FramedConn::flush`] explicitly.
114    /// Returns an error if the underlying connection is broken.
115    pub async fn send_all(
116        &mut self,
117        messages: impl IntoIterator<Item = BackendMessage>,
118    ) -> Result<(), io::Error> {
119        // N.B. we intentionally don't use `self.conn.send_all` here to avoid
120        // flushing the sink unnecessarily.
121        for m in messages {
122            self.send(m).await?;
123        }
124        Ok(())
125    }
126
    /// Flushes all outstanding messages.
    ///
    /// Messages queued via [`FramedConn::send`] or [`FramedConn::send_all`]
    /// are not flushed by those methods, so this must be called to push them
    /// out. Returns any I/O error reported by the underlying sink.
    pub async fn flush(&mut self) -> Result<(), io::Error> {
        self.inner.flush().await
    }
131
132    /// Injects state that affects how certain backend messages are encoded.
133    ///
134    /// Specifically, the encoding of `BackendMessage::DataRow` depends upon the
135    /// types of the datums in the row. To avoid including the same type
136    /// information in each message, we use this side channel to install the
137    /// type information in the codec before sending any data row messages. This
138    /// violates the abstraction boundary a bit but results in much better
139    /// performance.
140    pub fn set_encode_state(
141        &mut self,
142        encode_state: Vec<(mz_pgrepr::Type, mz_pgwire_common::Format)>,
143    ) {
144        self.inner.get_mut().codec_mut().encode_state = encode_state;
145    }
146
147    /// Waits for the connection to be closed.
148    ///
149    /// Returns a "connection closed" error when the connection is closed. If
150    /// another error occurs before the connection is closed, that error is
151    /// returned instead.
152    ///
153    /// Use this method when you have an unbounded stream of data to forward to
154    /// the connection and the protocol does not require the client to
155    /// periodically acknowledge receipt. If you don't call this method to
156    /// periodically check if the connection has closed, you may not notice that
157    /// the client has gone away for an unboundedly long amount of time; usually
158    /// not until the stream of data produces its next message and you attempt
159    /// to write the data to the connection.
160    pub async fn wait_closed(&self) -> io::Error
161    where
162        A: AsyncReady + Send + Sync,
163    {
164        loop {
165            time::sleep(Duration::from_secs(1)).await;
166
167            match self.ready(Interest::READABLE | Interest::WRITABLE).await {
168                Ok(ready) if ready.is_read_closed() || ready.is_write_closed() => {
169                    return io::Error::new(io::ErrorKind::Other, "connection closed");
170                }
171                Ok(_) => (),
172                Err(err) => return err,
173            }
174        }
175    }
176
    /// Returns the ID associated with this connection.
    ///
    /// This is the `conn_id` that was supplied to [`FramedConn::new`].
    pub fn conn_id(&self) -> &ConnectionId {
        &self.conn_id
    }
181
    /// Returns the peer address of the connection.
    ///
    /// This is the `peer_addr` that was supplied to [`FramedConn::new`], and
    /// is `None` if no address was provided.
    pub fn peer_addr(&self) -> &Option<IpAddr> {
        &self.peer_addr
    }
186}
187
188impl<A> FramedConn<A>
189where
190    A: AsyncRead + AsyncWrite + Unpin,
191{
192    pub fn inner(&self) -> &Conn<A> {
193        self.inner.get_ref().get_ref()
194    }
195}
196
197#[async_trait]
198impl<A> AsyncReady for FramedConn<A>
199where
200    A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
201{
202    async fn ready(&self, interest: Interest) -> io::Result<Ready> {
203        self.inner.get_ref().get_ref().ready(interest).await
204    }
205}
206
/// The pgwire codec: encodes [`BackendMessage`]s into frames and decodes
/// frames into [`FrontendMessage`]s.
struct Codec {
    /// Where the decoder currently is within a frame: waiting for a frame
    /// header, or waiting for the body of a frame whose header has already
    /// been consumed.
    decode_state: DecodeState,
    /// The type and format of each column, consulted when encoding
    /// `BackendMessage::DataRow`. Installed from outside the codec via
    /// `FramedConn::set_encode_state`.
    encode_state: Vec<(mz_pgrepr::Type, mz_pgwire_common::Format)>,
}
211
212impl Codec {
213    /// Creates a new `Codec`.
214    pub fn new() -> Codec {
215        Codec {
216            decode_state: DecodeState::Head,
217            encode_state: vec![],
218        }
219    }
220}
221
222impl Default for Codec {
223    fn default() -> Codec {
224        Codec::new()
225    }
226}
227
228impl Encoder<BackendMessage> for Codec {
229    type Error = io::Error;
230
231    fn encode(&mut self, msg: BackendMessage, dst: &mut BytesMut) -> Result<(), io::Error> {
232        // Write type byte.
233        let byte = match &msg {
234            BackendMessage::AuthenticationOk => b'R',
235            BackendMessage::AuthenticationCleartextPassword => b'R',
236            BackendMessage::RowDescription(_) => b'T',
237            BackendMessage::DataRow(_) => b'D',
238            BackendMessage::CommandComplete { .. } => b'C',
239            BackendMessage::EmptyQueryResponse => b'I',
240            BackendMessage::ReadyForQuery(_) => b'Z',
241            BackendMessage::NoData => b'n',
242            BackendMessage::ParameterStatus(_, _) => b'S',
243            BackendMessage::PortalSuspended => b's',
244            BackendMessage::BackendKeyData { .. } => b'K',
245            BackendMessage::ParameterDescription(_) => b't',
246            BackendMessage::ParseComplete => b'1',
247            BackendMessage::BindComplete => b'2',
248            BackendMessage::CloseComplete => b'3',
249            BackendMessage::ErrorResponse(r) => {
250                if r.severity.is_error() {
251                    b'E'
252                } else {
253                    b'N'
254                }
255            }
256            BackendMessage::CopyInResponse { .. } => b'G',
257            BackendMessage::CopyOutResponse { .. } => b'H',
258            BackendMessage::CopyData(_) => b'd',
259            BackendMessage::CopyDone => b'c',
260        };
261        dst.put_u8(byte);
262
263        // Write message length placeholder. The true length is filled in later.
264        let base = dst.len();
265        dst.put_u32(0);
266
267        // Write message contents.
268        match msg {
269            BackendMessage::CopyInResponse {
270                overall_format,
271                column_formats,
272            }
273            | BackendMessage::CopyOutResponse {
274                overall_format,
275                column_formats,
276            } => {
277                dst.put_format_i8(overall_format);
278                dst.put_length_i16(column_formats.len())?;
279                for format in column_formats {
280                    dst.put_format_i16(format);
281                }
282            }
283            BackendMessage::CopyData(data) => {
284                dst.put_slice(&data);
285            }
286            BackendMessage::CopyDone => (),
287            BackendMessage::AuthenticationOk => {
288                dst.put_u32(0);
289            }
290            BackendMessage::AuthenticationCleartextPassword => {
291                dst.put_u32(3);
292            }
293            BackendMessage::RowDescription(fields) => {
294                dst.put_length_i16(fields.len())?;
295                for f in &fields {
296                    dst.put_string(&f.name.to_string());
297                    dst.put_u32(f.table_id);
298                    dst.put_u16(f.column_id);
299                    dst.put_u32(f.type_oid);
300                    dst.put_i16(f.type_len);
301                    dst.put_i32(f.type_mod);
302                    // TODO: make the format correct
303                    dst.put_format_i16(f.format);
304                }
305            }
306            BackendMessage::DataRow(fields) => {
307                dst.put_length_i16(fields.len())?;
308                for (f, (ty, format)) in fields.iter().zip_eq(&self.encode_state) {
309                    if let Some(f) = f {
310                        let base = dst.len();
311                        dst.put_u32(0);
312                        f.encode(ty, *format, dst)?;
313                        let len = dst.len() - base - 4;
314                        let len = i32::try_from(len).map_err(|_| {
315                            io::Error::new(
316                                io::ErrorKind::Other,
317                                "length of encoded data row field does not fit into an i32",
318                            )
319                        })?;
320                        dst[base..base + 4].copy_from_slice(&len.to_be_bytes());
321                    } else {
322                        dst.put_i32(-1);
323                    }
324                }
325            }
326            BackendMessage::CommandComplete { tag } => {
327                dst.put_string(&tag);
328            }
329            BackendMessage::ParseComplete => (),
330            BackendMessage::BindComplete => (),
331            BackendMessage::CloseComplete => (),
332            BackendMessage::EmptyQueryResponse => (),
333            BackendMessage::ReadyForQuery(status) => {
334                dst.put_u8(status.into());
335            }
336            BackendMessage::ParameterStatus(name, value) => {
337                dst.put_string(name);
338                dst.put_string(&value);
339            }
340            BackendMessage::PortalSuspended => (),
341            BackendMessage::NoData => (),
342            BackendMessage::BackendKeyData {
343                conn_id,
344                secret_key,
345            } => {
346                dst.put_u32(conn_id);
347                dst.put_u32(secret_key);
348            }
349            BackendMessage::ParameterDescription(params) => {
350                dst.put_length_i16(params.len())?;
351                for param in params {
352                    dst.put_u32(param.oid());
353                }
354            }
355            BackendMessage::ErrorResponse(ErrorResponse {
356                severity,
357                code,
358                message,
359                detail,
360                hint,
361                position,
362            }) => {
363                dst.put_u8(b'S');
364                dst.put_string(severity.as_str());
365                dst.put_u8(b'C');
366                dst.put_string(code.code());
367                dst.put_u8(b'M');
368                dst.put_string(&message);
369                if let Some(detail) = &detail {
370                    dst.put_u8(b'D');
371                    dst.put_string(detail);
372                }
373                if let Some(hint) = &hint {
374                    dst.put_u8(b'H');
375                    dst.put_string(hint);
376                }
377                if let Some(position) = &position {
378                    dst.put_u8(b'P');
379                    dst.put_string(&position.to_string());
380                }
381                dst.put_u8(b'\0');
382            }
383        }
384
385        let len = dst.len() - base;
386
387        // Overwrite length placeholder with true length.
388        let len = i32::try_from(len).map_err(|_| {
389            io::Error::new(
390                io::ErrorKind::Other,
391                "length of encoded message does not fit into an i32",
392            )
393        })?;
394        dst[base..base + 4].copy_from_slice(&len.to_be_bytes());
395
396        Ok(())
397    }
398}
399
impl Decoder for Codec {
    type Item = FrontendMessage;
    type Error = io::Error;

    /// Attempts to decode one frontend message from `src`.
    ///
    /// Decoding alternates between two states: `DecodeState::Head`, in which
    /// the five-byte frame header (one type byte plus four length bytes) is
    /// awaited, and `DecodeState::Data`, in which the frame body is awaited.
    /// The state is kept in `self` between calls, so a frame may arrive in
    /// arbitrarily many pieces.
    ///
    /// Returns `Ok(None)` when `src` does not yet hold enough bytes to make
    /// progress, and `Ok(Some(_))` once a complete message has been decoded.
    ///
    /// # Errors
    ///
    /// Returns an error if more than `MAX_REQUEST_SIZE` bytes are buffered,
    /// if the frame length is rejected by `parse_frame_len`, if the message
    /// type byte is not recognized, or if the body fails to decode.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<FrontendMessage>, io::Error> {
        // Refuse to keep buffering once the pending input exceeds the limit.
        if src.len() > MAX_REQUEST_SIZE {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!(
                    "request larger than {}",
                    ByteSize::b(u64::cast_from(MAX_REQUEST_SIZE))
                ),
            ));
        }
        loop {
            match self.decode_state {
                DecodeState::Head => {
                    // The header is one type byte plus four length bytes.
                    if src.len() < 5 {
                        return Ok(None);
                    }
                    let msg_type = src[0];
                    let frame_len = parse_frame_len(&src[1..])?;
                    // Consume the header and make room for the body, which
                    // the next loop iteration will look for.
                    src.advance(5);
                    src.reserve(frame_len);
                    self.decode_state = DecodeState::Data(msg_type, frame_len);
                }

                DecodeState::Data(msg_type, frame_len) => {
                    // Wait until the entire body has been buffered.
                    if src.len() < frame_len {
                        return Ok(None);
                    }
                    // Split off exactly this frame's body; any bytes after it
                    // belong to the next frame and remain in `src`.
                    let buf = src.split_to(frame_len).freeze();
                    let buf = Cursor::new(&buf);
                    let msg = match msg_type {
                        // Simple query flow.
                        b'Q' => decode_query(buf)?,

                        // Extended query flow.
                        b'P' => decode_parse(buf)?,
                        b'D' => decode_describe(buf)?,
                        b'B' => decode_bind(buf)?,
                        b'E' => decode_execute(buf)?,
                        b'H' => decode_flush(buf)?,
                        b'S' => decode_sync(buf)?,
                        b'C' => decode_close(buf)?,

                        // Termination.
                        b'X' => decode_terminate(buf)?,

                        // Authentication.
                        b'p' => decode_password(buf)?,

                        // Copy from flow.
                        b'f' => decode_copy_fail(buf)?,
                        b'd' => decode_copy_data(buf, frame_len)?,
                        b'c' => decode_copy_done(buf)?,

                        // Invalid.
                        _ => {
                            return Err(io::Error::new(
                                io::ErrorKind::InvalidData,
                                format!("unknown message type {}", msg_type),
                            ));
                        }
                    };
                    // Make room for the next frame's header and go back to
                    // waiting for it.
                    src.reserve(5);
                    self.decode_state = DecodeState::Head;
                    return Ok(Some(msg));
                }
            }
        }
    }
}
473
/// Decodes a Terminate message. The message body is not inspected.
fn decode_terminate(mut _buf: Cursor) -> Result<FrontendMessage, io::Error> {
    // Nothing more to decode.
    Ok(FrontendMessage::Terminate)
}
478
479fn decode_password(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
480    Ok(FrontendMessage::Password {
481        password: buf.read_cstr()?.to_owned(),
482    })
483}
484
485fn decode_query(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
486    Ok(FrontendMessage::Query {
487        sql: buf.read_cstr()?.to_string(),
488    })
489}
490
491fn decode_parse(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
492    let name = buf.read_cstr()?;
493    let sql = buf.read_cstr()?;
494
495    let mut param_types = vec![];
496    for _ in 0..buf.read_i16()? {
497        param_types.push(buf.read_u32()?);
498    }
499
500    Ok(FrontendMessage::Parse {
501        name: name.into(),
502        sql: sql.into(),
503        param_types,
504    })
505}
506
507fn decode_close(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
508    match buf.read_byte()? {
509        b'S' => Ok(FrontendMessage::CloseStatement {
510            name: buf.read_cstr()?.to_owned(),
511        }),
512        b'P' => Ok(FrontendMessage::ClosePortal {
513            name: buf.read_cstr()?.to_owned(),
514        }),
515        b => Err(input_err(format!(
516            "invalid type byte in close message: {}",
517            b
518        ))),
519    }
520}
521
522fn decode_describe(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
523    let first_char = buf.read_byte()?;
524    let name = buf.read_cstr()?.to_string();
525    match first_char {
526        b'S' => Ok(FrontendMessage::DescribeStatement { name }),
527        b'P' => Ok(FrontendMessage::DescribePortal { name }),
528        other => Err(input_err(format!("Invalid describe type: {:#x?}", other))),
529    }
530}
531
532fn decode_bind(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
533    let portal_name = buf.read_cstr()?.to_string();
534    let statement_name = buf.read_cstr()?.to_string();
535
536    let mut param_formats = Vec::new();
537    for _ in 0..buf.read_i16()? {
538        param_formats.push(buf.read_format()?);
539    }
540
541    let mut raw_params = Vec::new();
542    for _ in 0..buf.read_i16()? {
543        let len = buf.read_i32()?;
544        if len == -1 {
545            raw_params.push(None); // NULL
546        } else {
547            // TODO(benesch): this should use bytes::Bytes to avoid the copy.
548            let mut value = Vec::new();
549            for _ in 0..len {
550                value.push(buf.read_byte()?);
551            }
552            raw_params.push(Some(value));
553        }
554    }
555
556    let mut result_formats = Vec::new();
557    for _ in 0..buf.read_i16()? {
558        result_formats.push(buf.read_format()?);
559    }
560
561    Ok(FrontendMessage::Bind {
562        portal_name,
563        statement_name,
564        param_formats,
565        raw_params,
566        result_formats,
567    })
568}
569
570fn decode_execute(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
571    let portal_name = buf.read_cstr()?.to_string();
572    let max_rows = buf.read_i32()?;
573    Ok(FrontendMessage::Execute {
574        portal_name,
575        max_rows,
576    })
577}
578
/// Decodes a Flush message. The message body is not inspected.
fn decode_flush(mut _buf: Cursor) -> Result<FrontendMessage, io::Error> {
    // Nothing more to decode.
    Ok(FrontendMessage::Flush)
}
583
/// Decodes a Sync message. The message body is not inspected.
fn decode_sync(mut _buf: Cursor) -> Result<FrontendMessage, io::Error> {
    // Nothing more to decode.
    Ok(FrontendMessage::Sync)
}
588
589fn decode_copy_data(mut buf: Cursor, frame_len: usize) -> Result<FrontendMessage, io::Error> {
590    let mut data = Vec::with_capacity(frame_len);
591    for _ in 0..frame_len {
592        data.push(buf.read_byte()?);
593    }
594    Ok(FrontendMessage::CopyData(data))
595}
596
/// Decodes a CopyDone message. The message body is not inspected.
fn decode_copy_done(mut _buf: Cursor) -> Result<FrontendMessage, io::Error> {
    // Nothing more to decode.
    Ok(FrontendMessage::CopyDone)
}
601
602fn decode_copy_fail(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
603    Ok(FrontendMessage::CopyFail(buf.read_cstr()?.to_string()))
604}