mz_pgwire_common/
codec.rs1use std::collections::BTreeMap;
18use std::error::Error;
19use std::{fmt, str};
20
21use byteorder::{ByteOrder, NetworkEndian};
22use bytes::{BufMut, BytesMut};
23use mz_ore::cast::{CastFrom, u64_to_usize};
24use mz_ore::netio::{self};
25use tokio::io::{self, AsyncRead, AsyncReadExt};
26
27use crate::FrontendMessage;
28use crate::format::Format;
29use crate::message::{FrontendStartupMessage, VERSION_CANCEL, VERSION_GSSENC, VERSION_SSL};
30
/// Single-byte reply `'N'`; per its name, the answer that rejects a client's
/// request to encrypt the connection.
pub const REJECT_ENCRYPTION: u8 = b'N';

/// Single-byte reply `'S'`; per its name, the answer that accepts a client's
/// request for SSL encryption.
pub const ACCEPT_SSL_ENCRYPTION: u8 = b'S';
33
34pub const MAX_REQUEST_SIZE: usize = u64_to_usize(2 * bytesize::MB);
36
/// Errors raised while decoding wire-protocol primitives.
#[derive(Debug)]
pub enum CodecError {
    /// A NUL-terminated string ran to the end of the buffer without its
    /// terminating zero byte.
    StringNoTerminator,
}

/// `CodecError` has no underlying cause, so the default `source` is kept.
impl Error for CodecError {}

impl fmt::Display for CodecError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let message = match self {
            Self::StringNoTerminator => "The string does not have a terminator",
        };
        f.write_str(message)
    }
}
51
52pub trait Pgbuf: BufMut {
53 fn put_string(&mut self, s: &str);
54 fn put_length_i16(&mut self, len: usize) -> Result<(), io::Error>;
55 fn put_format_i8(&mut self, format: Format);
56 fn put_format_i16(&mut self, format: Format);
57}
58
59impl<B: BufMut> Pgbuf for B {
60 fn put_string(&mut self, s: &str) {
61 self.put(s.as_bytes());
62 self.put_u8(b'\0');
63 }
64
65 fn put_length_i16(&mut self, len: usize) -> Result<(), io::Error> {
66 let len = i16::try_from(len).map_err(|_| {
67 io::Error::new(io::ErrorKind::InvalidData, "length does not fit in an i16")
68 })?;
69 self.put_i16(len);
70 Ok(())
71 }
72
73 fn put_format_i8(&mut self, format: Format) {
74 self.put_i8(format.into())
75 }
76
77 fn put_format_i16(&mut self, format: Format) {
78 self.put_i8(0);
79 self.put_format_i8(format);
80 }
81}
82
83pub async fn decode_startup<A>(mut conn: A) -> Result<Option<FrontendStartupMessage>, io::Error>
84where
85 A: AsyncRead + Unpin,
86{
87 let mut frame_len = [0; 4];
88 let nread = netio::read_exact_or_eof(&mut conn, &mut frame_len).await?;
89 match nread {
90 4 => (),
92 0 => return Ok(None),
95 _ => return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "early eof")),
98 };
99 let frame_len = parse_frame_len(&frame_len)?;
100
101 let mut buf = BytesMut::new();
102 buf.resize(frame_len, b'0');
103 conn.read_exact(&mut buf).await?;
104
105 let mut buf = Cursor::new(&buf);
106 let version = buf.read_i32()?;
107 let message = match version {
108 VERSION_CANCEL => FrontendStartupMessage::CancelRequest {
109 conn_id: buf.read_u32()?,
110 secret_key: buf.read_u32()?,
111 },
112 VERSION_SSL => FrontendStartupMessage::SslRequest,
113 VERSION_GSSENC => FrontendStartupMessage::GssEncRequest,
114 _ => {
115 let mut params = BTreeMap::new();
116 while buf.peek_byte()? != 0 {
117 let name = buf.read_cstr()?.to_owned();
118 let value = buf.read_cstr()?.to_owned();
119 params.insert(name, value);
120 }
121 FrontendStartupMessage::Startup { version, params }
122 }
123 };
124 Ok(Some(message))
125}
126
127impl FrontendStartupMessage {
128 pub fn encode(&self, dst: &mut BytesMut) -> Result<(), io::Error> {
130 let base = dst.len();
132 dst.put_u32(0);
133
134 match self {
136 FrontendStartupMessage::Startup { version, params } => {
137 dst.put_i32(*version);
138 for (k, v) in params {
139 dst.put_string(k);
140 dst.put_string(v);
141 }
142 dst.put_i8(0);
143 }
144 FrontendStartupMessage::CancelRequest {
145 conn_id,
146 secret_key,
147 } => {
148 dst.put_i32(VERSION_CANCEL);
149 dst.put_u32(*conn_id);
150 dst.put_u32(*secret_key);
151 }
152 FrontendStartupMessage::SslRequest {} => dst.put_i32(VERSION_SSL),
153 _ => panic!("unsupported"),
154 }
155
156 let len = dst.len() - base;
157
158 let len = i32::try_from(len).map_err(|_| {
160 io::Error::new(
161 io::ErrorKind::InvalidData,
162 "length of encoded message does not fit into an i32",
163 )
164 })?;
165 dst[base..base + 4].copy_from_slice(&len.to_be_bytes());
166
167 Ok(())
168 }
169}
170
171impl FrontendMessage {
172 pub fn encode(&self, dst: &mut BytesMut) -> Result<(), io::Error> {
174 let byte = match self {
176 FrontendMessage::Password { .. } => b'p',
177 _ => panic!("unsupported"),
178 };
179 dst.put_u8(byte);
180
181 let base = dst.len();
183 dst.put_u32(0);
184
185 match self {
187 FrontendMessage::Password { password } => {
188 dst.put_string(password);
189 }
190 _ => panic!("unsupported"),
191 }
192
193 let len = dst.len() - base;
194
195 let len = i32::try_from(len).map_err(|_| {
197 io::Error::new(
198 io::ErrorKind::InvalidData,
199 "length of encoded message does not fit into an i32",
200 )
201 })?;
202 dst[base..base + 4].copy_from_slice(&len.to_be_bytes());
203
204 Ok(())
205 }
206}
207
/// State value that appears to belong to an incremental message decoder.
///
/// NOTE(review): nothing in this part of the file constructs or matches on
/// this type. The variant descriptions below are inferred from names and
/// payload types only; confirm them against the decoder that uses it.
#[derive(Debug)]
pub enum DecodeState {
    /// Presumably: waiting for the next message header.
    Head,
    /// Presumably: header consumed. The fields look like the message type byte
    /// and the payload length still to read — TODO confirm.
    Data(u8, usize),
}
213
214pub fn parse_frame_len(src: &[u8]) -> Result<usize, io::Error> {
215 let n = usize::cast_from(NetworkEndian::read_u32(src));
216 if n > netio::MAX_FRAME_SIZE {
217 return Err(io::Error::new(
218 io::ErrorKind::InvalidData,
219 netio::FrameTooBig,
220 ));
221 } else if n < 4 {
222 return Err(io::Error::new(
223 io::ErrorKind::InvalidInput,
224 "invalid frame length",
225 ));
226 }
227 Ok(n - 4)
228}
229
230#[derive(Debug)]
239pub struct Cursor<'a> {
240 buf: &'a [u8],
241}
242
243impl<'a> Cursor<'a> {
244 pub fn new(buf: &'a [u8]) -> Cursor<'a> {
247 Cursor { buf }
248 }
249
250 pub fn peek_byte(&self) -> Result<u8, io::Error> {
252 self.buf
253 .get(0)
254 .copied()
255 .ok_or_else(|| input_err("No byte to read"))
256 }
257
258 pub fn read_byte(&mut self) -> Result<u8, io::Error> {
260 let byte = self.peek_byte()?;
261 self.advance(1);
262 Ok(byte)
263 }
264
265 pub fn read_cstr(&mut self) -> Result<&'a str, io::Error> {
279 if let Some(pos) = self.buf.iter().position(|b| *b == 0) {
280 let val = std::str::from_utf8(&self.buf[..pos]).map_err(input_err)?;
281 self.advance(pos + 1);
282 Ok(val)
283 } else {
284 Err(input_err(CodecError::StringNoTerminator))
285 }
286 }
287
288 pub fn read_i16(&mut self) -> Result<i16, io::Error> {
291 if self.buf.len() < 2 {
292 return Err(input_err("not enough buffer for an Int16"));
293 }
294 let val = NetworkEndian::read_i16(self.buf);
295 self.advance(2);
296 Ok(val)
297 }
298
299 pub fn read_i32(&mut self) -> Result<i32, io::Error> {
302 if self.buf.len() < 4 {
303 return Err(input_err("not enough buffer for an Int32"));
304 }
305 let val = NetworkEndian::read_i32(self.buf);
306 self.advance(4);
307 Ok(val)
308 }
309
310 pub fn read_u32(&mut self) -> Result<u32, io::Error> {
313 if self.buf.len() < 4 {
314 return Err(input_err("not enough buffer for an Int32"));
315 }
316 let val = NetworkEndian::read_u32(self.buf);
317 self.advance(4);
318 Ok(val)
319 }
320
321 pub fn read_format(&mut self) -> Result<Format, io::Error> {
323 Format::try_from(self.read_i16()?)
324 }
325
326 pub fn advance(&mut self, n: usize) {
328 self.buf = &self.buf[n..]
329 }
330}
331
/// Wraps `source` in an [`io::Error`] of kind [`io::ErrorKind::InvalidInput`].
///
/// Anything convertible into a boxed error is accepted. That includes plain
/// string messages as well as concrete error values such as UTF-8 decoding
/// failures.
pub fn input_err(source: impl Into<Box<dyn std::error::Error + Send + Sync>>) -> io::Error {
    let boxed: Box<dyn std::error::Error + Send + Sync> = source.into();
    io::Error::new(io::ErrorKind::InvalidInput, boxed)
}