mz_pgwire_common/
codec.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Encoding/decoding of messages in pgwire. See "[Frontend/Backend Protocol:
11//! Message Formats][1]" in the PostgreSQL reference for the specification.
12//!
13//! See the [crate docs](crate) for higher level concerns.
14//!
15//! [1]: https://www.postgresql.org/docs/11/protocol-message-formats.html
16
17use std::collections::BTreeMap;
18use std::error::Error;
19use std::{fmt, str};
20
21use byteorder::{ByteOrder, NetworkEndian};
22use bytes::{BufMut, BytesMut};
23use mz_ore::cast::{CastFrom, u64_to_usize};
24use mz_ore::netio::{self};
25use tokio::io::{self, AsyncRead, AsyncReadExt};
26
27use crate::FrontendMessage;
28use crate::format::Format;
29use crate::message::{FrontendStartupMessage, VERSION_CANCEL, VERSION_GSSENC, VERSION_SSL};
30
/// Single-byte code `N` signalling that an encryption request is rejected.
///
/// NOTE(review): presumably the server's reply to an SSL/GSSAPI encryption
/// request, per the PostgreSQL protocol; the send site is not in this file.
pub const REJECT_ENCRYPTION: u8 = b'N';
/// Single-byte code `S` signalling that an SSL encryption request is accepted.
///
/// NOTE(review): presumably the server's reply to an `SslRequest`; the send
/// site is not in this file.
pub const ACCEPT_SSL_ENCRYPTION: u8 = b'S';
33
/// Maximum allowed size for a request, in bytes (two `bytesize::MB`).
///
/// NOTE(review): this constant is not referenced elsewhere in this file;
/// `parse_frame_len` enforces `netio::MAX_FRAME_SIZE` instead.
pub const MAX_REQUEST_SIZE: usize = u64_to_usize(2 * bytesize::MB);
36
/// Errors specific to decoding the contents of pgwire messages.
#[derive(Debug)]
pub enum CodecError {
    /// A string was expected to end with a terminator, but none was found.
    StringNoTerminator,
}

impl fmt::Display for CodecError {
    /// Writes the fixed, human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let description = match self {
            Self::StringNoTerminator => "The string does not have a terminator",
        };
        f.write_str(description)
    }
}

impl Error for CodecError {}
51
/// Extension methods for writing pgwire primitives into any [`BufMut`].
pub trait Pgbuf: BufMut {
    /// Writes the bytes of `s` followed by a single NUL (`\0`) terminator.
    fn put_string(&mut self, s: &str);
    /// Writes `len` as a 16-bit signed integer.
    ///
    /// Returns an error, without writing anything, if `len` does not fit in
    /// an `i16`.
    fn put_length_i16(&mut self, len: usize) -> Result<(), io::Error>;
    /// Writes `format` as a single signed byte.
    fn put_format_i8(&mut self, format: Format);
    /// Writes `format` as two bytes: a zero byte followed by the byte that
    /// `put_format_i8` writes.
    fn put_format_i16(&mut self, format: Format);
}
58
59impl<B: BufMut> Pgbuf for B {
60    fn put_string(&mut self, s: &str) {
61        self.put(s.as_bytes());
62        self.put_u8(b'\0');
63    }
64
65    fn put_length_i16(&mut self, len: usize) -> Result<(), io::Error> {
66        let len = i16::try_from(len)
67            .map_err(|_| io::Error::new(io::ErrorKind::Other, "length does not fit in an i16"))?;
68        self.put_i16(len);
69        Ok(())
70    }
71
72    fn put_format_i8(&mut self, format: Format) {
73        self.put_i8(format.into())
74    }
75
76    fn put_format_i16(&mut self, format: Format) {
77        self.put_i8(0);
78        self.put_format_i8(format);
79    }
80}
81
82pub async fn decode_startup<A>(mut conn: A) -> Result<Option<FrontendStartupMessage>, io::Error>
83where
84    A: AsyncRead + Unpin,
85{
86    let mut frame_len = [0; 4];
87    let nread = netio::read_exact_or_eof(&mut conn, &mut frame_len).await?;
88    match nread {
89        // Complete frame length. Continue.
90        4 => (),
91        // Connection closed cleanly. Indicate that the startup sequence has
92        // been terminated by the client.
93        0 => return Ok(None),
94        // Partial frame length. Likely a client bug or network glitch, so
95        // surface the unexpected EOF.
96        _ => return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "early eof")),
97    };
98    let frame_len = parse_frame_len(&frame_len)?;
99
100    let mut buf = BytesMut::new();
101    buf.resize(frame_len, b'0');
102    conn.read_exact(&mut buf).await?;
103
104    let mut buf = Cursor::new(&buf);
105    let version = buf.read_i32()?;
106    let message = match version {
107        VERSION_CANCEL => FrontendStartupMessage::CancelRequest {
108            conn_id: buf.read_u32()?,
109            secret_key: buf.read_u32()?,
110        },
111        VERSION_SSL => FrontendStartupMessage::SslRequest,
112        VERSION_GSSENC => FrontendStartupMessage::GssEncRequest,
113        _ => {
114            let mut params = BTreeMap::new();
115            while buf.peek_byte()? != 0 {
116                let name = buf.read_cstr()?.to_owned();
117                let value = buf.read_cstr()?.to_owned();
118                params.insert(name, value);
119            }
120            FrontendStartupMessage::Startup { version, params }
121        }
122    };
123    Ok(Some(message))
124}
125
126impl FrontendStartupMessage {
127    /// Encodes self into dst.
128    pub fn encode(&self, dst: &mut BytesMut) -> Result<(), io::Error> {
129        // Write message length placeholder. The true length is filled in later.
130        let base = dst.len();
131        dst.put_u32(0);
132
133        // Write message contents.
134        match self {
135            FrontendStartupMessage::Startup { version, params } => {
136                dst.put_i32(*version);
137                for (k, v) in params {
138                    dst.put_string(k);
139                    dst.put_string(v);
140                }
141                dst.put_i8(0);
142            }
143            FrontendStartupMessage::CancelRequest {
144                conn_id,
145                secret_key,
146            } => {
147                dst.put_i32(VERSION_CANCEL);
148                dst.put_u32(*conn_id);
149                dst.put_u32(*secret_key);
150            }
151            FrontendStartupMessage::SslRequest {} => dst.put_i32(VERSION_SSL),
152            _ => panic!("unsupported"),
153        }
154
155        let len = dst.len() - base;
156
157        // Overwrite length placeholder with true length.
158        let len = i32::try_from(len).map_err(|_| {
159            io::Error::new(
160                io::ErrorKind::Other,
161                "length of encoded message does not fit into an i32",
162            )
163        })?;
164        dst[base..base + 4].copy_from_slice(&len.to_be_bytes());
165
166        Ok(())
167    }
168}
169
170impl FrontendMessage {
171    /// Encodes self into dst.
172    pub fn encode(&self, dst: &mut BytesMut) -> Result<(), io::Error> {
173        // Write type byte.
174        let byte = match self {
175            FrontendMessage::Password { .. } => b'p',
176            _ => panic!("unsupported"),
177        };
178        dst.put_u8(byte);
179
180        // Write message length placeholder. The true length is filled in later.
181        let base = dst.len();
182        dst.put_u32(0);
183
184        // Write message contents.
185        match self {
186            FrontendMessage::Password { password } => {
187                dst.put_string(password);
188            }
189            _ => panic!("unsupported"),
190        }
191
192        let len = dst.len() - base;
193
194        // Overwrite length placeholder with true length.
195        let len = i32::try_from(len).map_err(|_| {
196            io::Error::new(
197                io::ErrorKind::Other,
198                "length of encoded message does not fit into an i32",
199            )
200        })?;
201        dst[base..base + 4].copy_from_slice(&len.to_be_bytes());
202
203        Ok(())
204    }
205}
206
/// State of an incremental message decoder.
///
/// NOTE(review): no code in this file uses this type; the descriptions below
/// are inferred from the variant names and should be confirmed against the
/// decoder that consumes it.
#[derive(Debug)]
pub enum DecodeState {
    /// Presumably: waiting for a message header.
    Head,
    /// Presumably: waiting for a message body, carrying the message type byte
    /// and the remaining frame length — TODO confirm.
    Data(u8, usize),
}
212
213pub fn parse_frame_len(src: &[u8]) -> Result<usize, io::Error> {
214    let n = usize::cast_from(NetworkEndian::read_u32(src));
215    if n > netio::MAX_FRAME_SIZE {
216        return Err(io::Error::new(
217            io::ErrorKind::InvalidData,
218            netio::FrameTooBig,
219        ));
220    } else if n < 4 {
221        return Err(io::Error::new(
222            io::ErrorKind::InvalidInput,
223            "invalid frame length",
224        ));
225    }
226    Ok(n - 4)
227}
228
/// Decodes data within pgwire messages.
///
/// The API provided is very similar to [`bytes::Buf`], but operations return
/// errors rather than panicking. This is important for safety, as we don't want
/// to crash if the user sends us malformed pgwire messages.
///
/// There are also some special-purpose methods, like [`Cursor::read_cstr`],
/// that are specific to pgwire messages.
#[derive(Debug)]
pub struct Cursor<'a> {
    /// The bytes not yet consumed. Advancing the cursor re-slices this to
    /// drop bytes from the front.
    buf: &'a [u8],
}
241
242impl<'a> Cursor<'a> {
243    /// Constructs a new `Cursor` from a byte slice. The cursor will begin
244    /// decoding from the beginning of the slice.
245    pub fn new(buf: &'a [u8]) -> Cursor<'a> {
246        Cursor { buf }
247    }
248
249    /// Returns the next byte without advancing the cursor.
250    pub fn peek_byte(&self) -> Result<u8, io::Error> {
251        self.buf
252            .get(0)
253            .copied()
254            .ok_or_else(|| input_err("No byte to read"))
255    }
256
257    /// Returns the next byte, advancing the cursor by one byte.
258    pub fn read_byte(&mut self) -> Result<u8, io::Error> {
259        let byte = self.peek_byte()?;
260        self.advance(1);
261        Ok(byte)
262    }
263
264    /// Returns the next null-terminated string. The null character is not
265    /// included the returned string. The cursor is advanced past the null-
266    /// terminated string.
267    ///
268    /// If there is no null byte remaining in the string, returns
269    /// `CodecError::StringNoTerminator`. If the string is not valid UTF-8,
270    /// returns an `io::Error` with an error kind of
271    /// `io::ErrorKind::InvalidInput`.
272    ///
273    /// NOTE(benesch): it is possible that returning a string here is wrong, and
274    /// we should be returning bytes, so that we can support messages that are
275    /// not UTF-8 encoded. At the moment, we've not discovered a need for this,
276    /// though, and using proper strings is convenient.
277    pub fn read_cstr(&mut self) -> Result<&'a str, io::Error> {
278        if let Some(pos) = self.buf.iter().position(|b| *b == 0) {
279            let val = std::str::from_utf8(&self.buf[..pos]).map_err(input_err)?;
280            self.advance(pos + 1);
281            Ok(val)
282        } else {
283            Err(input_err(CodecError::StringNoTerminator))
284        }
285    }
286
287    /// Reads the next 16-bit signed integer, advancing the cursor by two
288    /// bytes.
289    pub fn read_i16(&mut self) -> Result<i16, io::Error> {
290        if self.buf.len() < 2 {
291            return Err(input_err("not enough buffer for an Int16"));
292        }
293        let val = NetworkEndian::read_i16(self.buf);
294        self.advance(2);
295        Ok(val)
296    }
297
298    /// Reads the next 32-bit signed integer, advancing the cursor by four
299    /// bytes.
300    pub fn read_i32(&mut self) -> Result<i32, io::Error> {
301        if self.buf.len() < 4 {
302            return Err(input_err("not enough buffer for an Int32"));
303        }
304        let val = NetworkEndian::read_i32(self.buf);
305        self.advance(4);
306        Ok(val)
307    }
308
309    /// Reads the next 32-bit unsigned integer, advancing the cursor by four
310    /// bytes.
311    pub fn read_u32(&mut self) -> Result<u32, io::Error> {
312        if self.buf.len() < 4 {
313            return Err(input_err("not enough buffer for an Int32"));
314        }
315        let val = NetworkEndian::read_u32(self.buf);
316        self.advance(4);
317        Ok(val)
318    }
319
320    /// Reads the next 16-bit format code, advancing the cursor by two bytes.
321    pub fn read_format(&mut self) -> Result<Format, io::Error> {
322        Format::try_from(self.read_i16()?)
323    }
324
325    /// Advances the cursor by `n` bytes.
326    pub fn advance(&mut self, n: usize) {
327        self.buf = &self.buf[n..]
328    }
329}
330
/// Builds the `io::Error` used to report that the client has violated the
/// pgwire protocol.
///
/// The result has kind `io::ErrorKind::InvalidInput` and wraps `source` as
/// its inner error.
pub fn input_err(source: impl Into<Box<dyn std::error::Error + Send + Sync>>) -> io::Error {
    let cause: Box<dyn std::error::Error + Send + Sync> = source.into();
    io::Error::new(io::ErrorKind::InvalidInput, cause)
}