// mz_pgwire_common/codec.rs
1use std::collections::BTreeMap;
18use std::error::Error;
19use std::{fmt, str};
20
21use byteorder::{ByteOrder, NetworkEndian};
22use bytes::{BufMut, BytesMut};
23use mz_ore::cast::{CastFrom, u64_to_usize};
24use mz_ore::netio::{self};
25use tokio::io::{self, AsyncRead, AsyncReadExt};
26
27use crate::FrontendMessage;
28use crate::format::Format;
29use crate::message::{FrontendStartupMessage, VERSION_CANCEL, VERSION_GSSENC, VERSION_SSL};
30
/// The ASCII byte `N`.
// NOTE(review): presumably the one-byte answer that declines an encryption
// request; confirm at the use site.
pub const REJECT_ENCRYPTION: u8 = b'N';
/// The ASCII byte `S`.
// NOTE(review): presumably the one-byte answer that accepts an SSL request;
// confirm at the use site.
pub const ACCEPT_SSL_ENCRYPTION: u8 = b'S';

/// Request size limit: twice `bytesize::MB`, converted to `usize`.
// NOTE(review): not referenced in this part of the file; where the limit is
// enforced should be confirmed against callers.
pub const MAX_REQUEST_SIZE: usize = u64_to_usize(2 * bytesize::MB);
36
/// Errors produced by this codec while decoding wire data.
#[derive(Debug)]
pub enum CodecError {
    /// A string was read but no terminating NUL byte was found.
    StringNoTerminator,
}

// No source error is wrapped, so the default `Error` methods suffice.
impl Error for CodecError {}

impl fmt::Display for CodecError {
    /// Writes the human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let description = match self {
            CodecError::StringNoTerminator => "The string does not have a terminator",
        };
        f.write_str(description)
    }
}
51
52pub trait Pgbuf: BufMut {
53 fn put_string(&mut self, s: &str);
54 fn put_length_i16(&mut self, len: usize) -> Result<(), io::Error>;
55 fn put_format_i8(&mut self, format: Format);
56 fn put_format_i16(&mut self, format: Format);
57}
58
59impl<B: BufMut> Pgbuf for B {
60 fn put_string(&mut self, s: &str) {
61 self.put(s.as_bytes());
62 self.put_u8(b'\0');
63 }
64
65 fn put_length_i16(&mut self, len: usize) -> Result<(), io::Error> {
66 let len = i16::try_from(len)
67 .map_err(|_| io::Error::new(io::ErrorKind::Other, "length does not fit in an i16"))?;
68 self.put_i16(len);
69 Ok(())
70 }
71
72 fn put_format_i8(&mut self, format: Format) {
73 self.put_i8(format.into())
74 }
75
76 fn put_format_i16(&mut self, format: Format) {
77 self.put_i8(0);
78 self.put_format_i8(format);
79 }
80}
81
82pub async fn decode_startup<A>(mut conn: A) -> Result<Option<FrontendStartupMessage>, io::Error>
83where
84 A: AsyncRead + Unpin,
85{
86 let mut frame_len = [0; 4];
87 let nread = netio::read_exact_or_eof(&mut conn, &mut frame_len).await?;
88 match nread {
89 4 => (),
91 0 => return Ok(None),
94 _ => return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "early eof")),
97 };
98 let frame_len = parse_frame_len(&frame_len)?;
99
100 let mut buf = BytesMut::new();
101 buf.resize(frame_len, b'0');
102 conn.read_exact(&mut buf).await?;
103
104 let mut buf = Cursor::new(&buf);
105 let version = buf.read_i32()?;
106 let message = match version {
107 VERSION_CANCEL => FrontendStartupMessage::CancelRequest {
108 conn_id: buf.read_u32()?,
109 secret_key: buf.read_u32()?,
110 },
111 VERSION_SSL => FrontendStartupMessage::SslRequest,
112 VERSION_GSSENC => FrontendStartupMessage::GssEncRequest,
113 _ => {
114 let mut params = BTreeMap::new();
115 while buf.peek_byte()? != 0 {
116 let name = buf.read_cstr()?.to_owned();
117 let value = buf.read_cstr()?.to_owned();
118 params.insert(name, value);
119 }
120 FrontendStartupMessage::Startup { version, params }
121 }
122 };
123 Ok(Some(message))
124}
125
126impl FrontendStartupMessage {
127 pub fn encode(&self, dst: &mut BytesMut) -> Result<(), io::Error> {
129 let base = dst.len();
131 dst.put_u32(0);
132
133 match self {
135 FrontendStartupMessage::Startup { version, params } => {
136 dst.put_i32(*version);
137 for (k, v) in params {
138 dst.put_string(k);
139 dst.put_string(v);
140 }
141 dst.put_i8(0);
142 }
143 FrontendStartupMessage::CancelRequest {
144 conn_id,
145 secret_key,
146 } => {
147 dst.put_i32(VERSION_CANCEL);
148 dst.put_u32(*conn_id);
149 dst.put_u32(*secret_key);
150 }
151 FrontendStartupMessage::SslRequest {} => dst.put_i32(VERSION_SSL),
152 _ => panic!("unsupported"),
153 }
154
155 let len = dst.len() - base;
156
157 let len = i32::try_from(len).map_err(|_| {
159 io::Error::new(
160 io::ErrorKind::Other,
161 "length of encoded message does not fit into an i32",
162 )
163 })?;
164 dst[base..base + 4].copy_from_slice(&len.to_be_bytes());
165
166 Ok(())
167 }
168}
169
170impl FrontendMessage {
171 pub fn encode(&self, dst: &mut BytesMut) -> Result<(), io::Error> {
173 let byte = match self {
175 FrontendMessage::Password { .. } => b'p',
176 _ => panic!("unsupported"),
177 };
178 dst.put_u8(byte);
179
180 let base = dst.len();
182 dst.put_u32(0);
183
184 match self {
186 FrontendMessage::Password { password } => {
187 dst.put_string(password);
188 }
189 _ => panic!("unsupported"),
190 }
191
192 let len = dst.len() - base;
193
194 let len = i32::try_from(len).map_err(|_| {
196 io::Error::new(
197 io::ErrorKind::Other,
198 "length of encoded message does not fit into an i32",
199 )
200 })?;
201 dst[base..base + 4].copy_from_slice(&len.to_be_bytes());
202
203 Ok(())
204 }
205}
206
/// State of an incremental message decoder.
// NOTE(review): the decoder that drives this state machine is not in this
// part of the file; the variant descriptions below are inferred from the
// names and should be confirmed against it.
#[derive(Debug)]
pub enum DecodeState {
    /// Presumably: waiting for a message header.
    Head,
    /// Presumably: header seen; carries the message type byte and the number
    /// of payload bytes still expected.
    Data(u8, usize),
}
212
213pub fn parse_frame_len(src: &[u8]) -> Result<usize, io::Error> {
214 let n = usize::cast_from(NetworkEndian::read_u32(src));
215 if n > netio::MAX_FRAME_SIZE {
216 return Err(io::Error::new(
217 io::ErrorKind::InvalidData,
218 netio::FrameTooBig,
219 ));
220 } else if n < 4 {
221 return Err(io::Error::new(
222 io::ErrorKind::InvalidInput,
223 "invalid frame length",
224 ));
225 }
226 Ok(n - 4)
227}
228
229#[derive(Debug)]
238pub struct Cursor<'a> {
239 buf: &'a [u8],
240}
241
242impl<'a> Cursor<'a> {
243 pub fn new(buf: &'a [u8]) -> Cursor<'a> {
246 Cursor { buf }
247 }
248
249 pub fn peek_byte(&self) -> Result<u8, io::Error> {
251 self.buf
252 .get(0)
253 .copied()
254 .ok_or_else(|| input_err("No byte to read"))
255 }
256
257 pub fn read_byte(&mut self) -> Result<u8, io::Error> {
259 let byte = self.peek_byte()?;
260 self.advance(1);
261 Ok(byte)
262 }
263
264 pub fn read_cstr(&mut self) -> Result<&'a str, io::Error> {
278 if let Some(pos) = self.buf.iter().position(|b| *b == 0) {
279 let val = std::str::from_utf8(&self.buf[..pos]).map_err(input_err)?;
280 self.advance(pos + 1);
281 Ok(val)
282 } else {
283 Err(input_err(CodecError::StringNoTerminator))
284 }
285 }
286
287 pub fn read_i16(&mut self) -> Result<i16, io::Error> {
290 if self.buf.len() < 2 {
291 return Err(input_err("not enough buffer for an Int16"));
292 }
293 let val = NetworkEndian::read_i16(self.buf);
294 self.advance(2);
295 Ok(val)
296 }
297
298 pub fn read_i32(&mut self) -> Result<i32, io::Error> {
301 if self.buf.len() < 4 {
302 return Err(input_err("not enough buffer for an Int32"));
303 }
304 let val = NetworkEndian::read_i32(self.buf);
305 self.advance(4);
306 Ok(val)
307 }
308
309 pub fn read_u32(&mut self) -> Result<u32, io::Error> {
312 if self.buf.len() < 4 {
313 return Err(input_err("not enough buffer for an Int32"));
314 }
315 let val = NetworkEndian::read_u32(self.buf);
316 self.advance(4);
317 Ok(val)
318 }
319
320 pub fn read_format(&mut self) -> Result<Format, io::Error> {
322 Format::try_from(self.read_i16()?)
323 }
324
325 pub fn advance(&mut self, n: usize) {
327 self.buf = &self.buf[n..]
328 }
329}
330
/// Wraps `source` in an [`io::Error`] of kind `InvalidInput`.
pub fn input_err(source: impl Into<Box<dyn std::error::Error + Send + Sync>>) -> io::Error {
    let cause: Box<dyn std::error::Error + Send + Sync> = source.into();
    io::Error::new(io::ErrorKind::InvalidInput, cause)
}