Skip to main content

mz_pgwire_common/
codec.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Encoding/decoding of messages in pgwire. See "[Frontend/Backend Protocol:
11//! Message Formats][1]" in the PostgreSQL reference for the specification.
12//!
13//! See the [crate docs](crate) for higher level concerns.
14//!
15//! [1]: https://www.postgresql.org/docs/11/protocol-message-formats.html
16
17use std::collections::BTreeMap;
18use std::error::Error;
19use std::{fmt, str};
20
21use byteorder::{ByteOrder, NetworkEndian};
22use bytes::{BufMut, BytesMut};
23use mz_ore::cast::{CastFrom, u64_to_usize};
24use mz_ore::netio::{self};
25use tokio::io::{self, AsyncRead, AsyncReadExt};
26
27use crate::FrontendMessage;
28use crate::format::Format;
29use crate::message::{FrontendStartupMessage, VERSION_CANCEL, VERSION_GSSENC, VERSION_SSL};
30
/// Single-byte reply telling the client that its encryption request was
/// declined.
pub const REJECT_ENCRYPTION: u8 = b'N';

/// Single-byte reply telling the client that its SSL encryption request was
/// accepted.
pub const ACCEPT_SSL_ENCRYPTION: u8 = b'S';
33
34/// Maximum allowed size for a request.
35pub const MAX_REQUEST_SIZE: usize = u64_to_usize(2 * bytesize::MB);
36
/// Failures that can occur while decoding individual pgwire message fields.
#[derive(Debug)]
pub enum CodecError {
    /// A null-terminated string was expected, but no null byte remained in
    /// the input.
    StringNoTerminator,
}

impl fmt::Display for CodecError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let description = match self {
            CodecError::StringNoTerminator => "The string does not have a terminator",
        };
        f.write_str(description)
    }
}

impl Error for CodecError {}
51
52pub trait Pgbuf: BufMut {
53    fn put_string(&mut self, s: &str);
54    fn put_length_i16(&mut self, len: usize) -> Result<(), io::Error>;
55    fn put_format_i8(&mut self, format: Format);
56    fn put_format_i16(&mut self, format: Format);
57}
58
59impl<B: BufMut> Pgbuf for B {
60    fn put_string(&mut self, s: &str) {
61        self.put(s.as_bytes());
62        self.put_u8(b'\0');
63    }
64
65    fn put_length_i16(&mut self, len: usize) -> Result<(), io::Error> {
66        let len = i16::try_from(len).map_err(|_| {
67            io::Error::new(io::ErrorKind::InvalidData, "length does not fit in an i16")
68        })?;
69        self.put_i16(len);
70        Ok(())
71    }
72
73    fn put_format_i8(&mut self, format: Format) {
74        self.put_i8(format.into())
75    }
76
77    fn put_format_i16(&mut self, format: Format) {
78        self.put_i8(0);
79        self.put_format_i8(format);
80    }
81}
82
83pub async fn decode_startup<A>(mut conn: A) -> Result<Option<FrontendStartupMessage>, io::Error>
84where
85    A: AsyncRead + Unpin,
86{
87    let mut frame_len = [0; 4];
88    let nread = netio::read_exact_or_eof(&mut conn, &mut frame_len).await?;
89    match nread {
90        // Complete frame length. Continue.
91        4 => (),
92        // Connection closed cleanly. Indicate that the startup sequence has
93        // been terminated by the client.
94        0 => return Ok(None),
95        // Partial frame length. Likely a client bug or network glitch, so
96        // surface the unexpected EOF.
97        _ => return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "early eof")),
98    };
99    let frame_len = parse_frame_len(&frame_len)?;
100
101    let mut buf = BytesMut::new();
102    buf.resize(frame_len, b'0');
103    conn.read_exact(&mut buf).await?;
104
105    let mut buf = Cursor::new(&buf);
106    let version = buf.read_i32()?;
107    let message = match version {
108        VERSION_CANCEL => FrontendStartupMessage::CancelRequest {
109            conn_id: buf.read_u32()?,
110            secret_key: buf.read_u32()?,
111        },
112        VERSION_SSL => FrontendStartupMessage::SslRequest,
113        VERSION_GSSENC => FrontendStartupMessage::GssEncRequest,
114        _ => {
115            let mut params = BTreeMap::new();
116            while buf.peek_byte()? != 0 {
117                let name = buf.read_cstr()?.to_owned();
118                let value = buf.read_cstr()?.to_owned();
119                params.insert(name, value);
120            }
121            FrontendStartupMessage::Startup { version, params }
122        }
123    };
124    Ok(Some(message))
125}
126
127impl FrontendStartupMessage {
128    /// Encodes self into dst.
129    pub fn encode(&self, dst: &mut BytesMut) -> Result<(), io::Error> {
130        // Write message length placeholder. The true length is filled in later.
131        let base = dst.len();
132        dst.put_u32(0);
133
134        // Write message contents.
135        match self {
136            FrontendStartupMessage::Startup { version, params } => {
137                dst.put_i32(*version);
138                for (k, v) in params {
139                    dst.put_string(k);
140                    dst.put_string(v);
141                }
142                dst.put_i8(0);
143            }
144            FrontendStartupMessage::CancelRequest {
145                conn_id,
146                secret_key,
147            } => {
148                dst.put_i32(VERSION_CANCEL);
149                dst.put_u32(*conn_id);
150                dst.put_u32(*secret_key);
151            }
152            FrontendStartupMessage::SslRequest {} => dst.put_i32(VERSION_SSL),
153            _ => panic!("unsupported"),
154        }
155
156        let len = dst.len() - base;
157
158        // Overwrite length placeholder with true length.
159        let len = i32::try_from(len).map_err(|_| {
160            io::Error::new(
161                io::ErrorKind::InvalidData,
162                "length of encoded message does not fit into an i32",
163            )
164        })?;
165        dst[base..base + 4].copy_from_slice(&len.to_be_bytes());
166
167        Ok(())
168    }
169}
170
171impl FrontendMessage {
172    /// Encodes self into dst.
173    pub fn encode(&self, dst: &mut BytesMut) -> Result<(), io::Error> {
174        // Write type byte.
175        let byte = match self {
176            FrontendMessage::Password { .. } => b'p',
177            _ => panic!("unsupported"),
178        };
179        dst.put_u8(byte);
180
181        // Write message length placeholder. The true length is filled in later.
182        let base = dst.len();
183        dst.put_u32(0);
184
185        // Write message contents.
186        match self {
187            FrontendMessage::Password { password } => {
188                dst.put_string(password);
189            }
190            _ => panic!("unsupported"),
191        }
192
193        let len = dst.len() - base;
194
195        // Overwrite length placeholder with true length.
196        let len = i32::try_from(len).map_err(|_| {
197            io::Error::new(
198                io::ErrorKind::InvalidData,
199                "length of encoded message does not fit into an i32",
200            )
201        })?;
202        dst[base..base + 4].copy_from_slice(&len.to_be_bytes());
203
204        Ok(())
205    }
206}
207
/// Decoder state.
///
/// NOTE(review): this type is not used within this file. Presumably `Head`
/// means "waiting for the next message header" and `Data(ty, len)` records a
/// message type byte and the length of a body still to be read. Confirm
/// against the decoder that drives this state.
#[derive(Debug)]
pub enum DecodeState {
    /// NOTE(review): presumably "awaiting the next message header".
    Head,
    /// NOTE(review): presumably `(message type byte, body length)`.
    Data(u8, usize),
}
213
214pub fn parse_frame_len(src: &[u8]) -> Result<usize, io::Error> {
215    let n = usize::cast_from(NetworkEndian::read_u32(src));
216    if n > netio::MAX_FRAME_SIZE {
217        return Err(io::Error::new(
218            io::ErrorKind::InvalidData,
219            netio::FrameTooBig,
220        ));
221    } else if n < 4 {
222        return Err(io::Error::new(
223            io::ErrorKind::InvalidInput,
224            "invalid frame length",
225        ));
226    }
227    Ok(n - 4)
228}
229
230/// Decodes data within pgwire messages.
231///
232/// The API provided is very similar to [`bytes::Buf`], but operations return
233/// errors rather than panicking. This is important for safety, as we don't want
234/// to crash if the user sends us malformed pgwire messages.
235///
236/// There are also some special-purpose methods, like [`Cursor::read_cstr`],
237/// that are specific to pgwire messages.
238#[derive(Debug)]
239pub struct Cursor<'a> {
240    buf: &'a [u8],
241}
242
243impl<'a> Cursor<'a> {
244    /// Constructs a new `Cursor` from a byte slice. The cursor will begin
245    /// decoding from the beginning of the slice.
246    pub fn new(buf: &'a [u8]) -> Cursor<'a> {
247        Cursor { buf }
248    }
249
250    /// Returns the next byte without advancing the cursor.
251    pub fn peek_byte(&self) -> Result<u8, io::Error> {
252        self.buf
253            .get(0)
254            .copied()
255            .ok_or_else(|| input_err("No byte to read"))
256    }
257
258    /// Returns the next byte, advancing the cursor by one byte.
259    pub fn read_byte(&mut self) -> Result<u8, io::Error> {
260        let byte = self.peek_byte()?;
261        self.advance(1);
262        Ok(byte)
263    }
264
265    /// Returns the next null-terminated string. The null character is not
266    /// included the returned string. The cursor is advanced past the null-
267    /// terminated string.
268    ///
269    /// If there is no null byte remaining in the string, returns
270    /// `CodecError::StringNoTerminator`. If the string is not valid UTF-8,
271    /// returns an `io::Error` with an error kind of
272    /// `io::ErrorKind::InvalidInput`.
273    ///
274    /// NOTE(benesch): it is possible that returning a string here is wrong, and
275    /// we should be returning bytes, so that we can support messages that are
276    /// not UTF-8 encoded. At the moment, we've not discovered a need for this,
277    /// though, and using proper strings is convenient.
278    pub fn read_cstr(&mut self) -> Result<&'a str, io::Error> {
279        if let Some(pos) = self.buf.iter().position(|b| *b == 0) {
280            let val = std::str::from_utf8(&self.buf[..pos]).map_err(input_err)?;
281            self.advance(pos + 1);
282            Ok(val)
283        } else {
284            Err(input_err(CodecError::StringNoTerminator))
285        }
286    }
287
288    /// Reads the next 16-bit signed integer, advancing the cursor by two
289    /// bytes.
290    pub fn read_i16(&mut self) -> Result<i16, io::Error> {
291        if self.buf.len() < 2 {
292            return Err(input_err("not enough buffer for an Int16"));
293        }
294        let val = NetworkEndian::read_i16(self.buf);
295        self.advance(2);
296        Ok(val)
297    }
298
299    /// Reads the next 32-bit signed integer, advancing the cursor by four
300    /// bytes.
301    pub fn read_i32(&mut self) -> Result<i32, io::Error> {
302        if self.buf.len() < 4 {
303            return Err(input_err("not enough buffer for an Int32"));
304        }
305        let val = NetworkEndian::read_i32(self.buf);
306        self.advance(4);
307        Ok(val)
308    }
309
310    /// Reads the next 32-bit unsigned integer, advancing the cursor by four
311    /// bytes.
312    pub fn read_u32(&mut self) -> Result<u32, io::Error> {
313        if self.buf.len() < 4 {
314            return Err(input_err("not enough buffer for an Int32"));
315        }
316        let val = NetworkEndian::read_u32(self.buf);
317        self.advance(4);
318        Ok(val)
319    }
320
321    /// Reads the next 16-bit format code, advancing the cursor by two bytes.
322    pub fn read_format(&mut self) -> Result<Format, io::Error> {
323        Format::try_from(self.read_i16()?)
324    }
325
326    /// Advances the cursor by `n` bytes.
327    pub fn advance(&mut self, n: usize) {
328        self.buf = &self.buf[n..]
329    }
330}
331
/// Builds the error reported when a client breaks the pgwire protocol.
///
/// `source` becomes the inner error of an `io::Error` whose kind is always
/// `io::ErrorKind::InvalidInput`.
pub fn input_err(source: impl Into<Box<dyn std::error::Error + Send + Sync>>) -> io::Error {
    let inner: Box<dyn std::error::Error + Send + Sync> = source.into();
    io::Error::new(io::ErrorKind::InvalidInput, inner)
}