1use std::net::IpAddr;
18
19use async_trait::async_trait;
20use bytes::{Buf, BufMut, BytesMut};
21use bytesize::ByteSize;
22use futures::{SinkExt, TryStreamExt, sink};
23use mz_adapter_types::connection::ConnectionId;
24use mz_ore::cast::CastFrom;
25use mz_ore::future::OreSinkExt;
26use mz_ore::netio::AsyncReady;
27use mz_pgwire_common::{
28 Conn, Cursor, DecodeState, ErrorResponse, FrontendMessage, MAX_REQUEST_SIZE, Pgbuf, input_err,
29 parse_frame_len,
30};
31use tokio::io::{self, AsyncRead, AsyncWrite, Interest, Ready};
32use tokio::time::{self, Duration};
33use tokio_util::codec::{Decoder, Encoder, Framed};
34use tracing::trace;
35
36use crate::message::{BackendMessage, BackendMessageKind};
37
/// A pgwire connection that decodes client messages and encodes server
/// messages over an underlying [`Conn`].
pub struct FramedConn<A> {
    /// Identifier of this connection; included in the trace logs emitted by
    /// `recv` and `send`.
    conn_id: ConnectionId,
    /// Address of the remote peer, if known; exposed via `peer_addr`.
    peer_addr: Option<IpAddr>,
    /// The framed transport (using `Codec`) wrapped in a sink buffer that
    /// holds not-yet-encoded `BackendMessage`s.
    inner: sink::Buffer<Framed<Conn<A>, Codec>, BackendMessage>,
}
44
45impl<A> FramedConn<A>
46where
47 A: AsyncRead + AsyncWrite + Unpin,
48{
49 pub fn new(conn_id: ConnectionId, peer_addr: Option<IpAddr>, inner: Conn<A>) -> FramedConn<A> {
58 FramedConn {
59 conn_id,
60 peer_addr,
61 inner: Framed::new(inner, Codec::new()).buffer(32),
62 }
63 }
64
65 pub async fn recv(&mut self) -> Result<Option<FrontendMessage>, io::Error> {
79 let message = self.inner.try_next().await?;
80 match &message {
81 Some(message) => trace!("cid={} recv_name={}", self.conn_id, message.name()),
82 None => trace!("cid={} recv=<eof>", self.conn_id),
83 }
84 Ok(message)
85 }
86
87 pub async fn send<M>(&mut self, message: M) -> Result<(), io::Error>
96 where
97 M: Into<BackendMessage>,
98 {
99 let message = message.into();
100 trace!(
101 "cid={} send={:?}",
102 self.conn_id,
103 BackendMessageKind::from(&message)
104 );
105 self.inner.enqueue(message).await
106 }
107
108 pub async fn send_all(
115 &mut self,
116 messages: impl IntoIterator<Item = BackendMessage>,
117 ) -> Result<(), io::Error> {
118 for m in messages {
121 self.send(m).await?;
122 }
123 Ok(())
124 }
125
126 pub async fn flush(&mut self) -> Result<(), io::Error> {
128 self.inner.flush().await
129 }
130
131 pub fn set_encode_state(
140 &mut self,
141 encode_state: Vec<(mz_pgrepr::Type, mz_pgwire_common::Format)>,
142 ) {
143 self.inner.get_mut().codec_mut().encode_state = encode_state;
144 }
145
146 pub async fn wait_closed(&self) -> io::Error
160 where
161 A: AsyncReady + Send + Sync,
162 {
163 loop {
164 time::sleep(Duration::from_secs(1)).await;
165
166 match self.ready(Interest::READABLE | Interest::WRITABLE).await {
167 Ok(ready) if ready.is_read_closed() || ready.is_write_closed() => {
168 return io::Error::new(io::ErrorKind::Other, "connection closed");
169 }
170 Ok(_) => (),
171 Err(err) => return err,
172 }
173 }
174 }
175
176 pub fn conn_id(&self) -> &ConnectionId {
178 &self.conn_id
179 }
180
181 pub fn peer_addr(&self) -> &Option<IpAddr> {
183 &self.peer_addr
184 }
185}
186
187impl<A> FramedConn<A>
188where
189 A: AsyncRead + AsyncWrite + Unpin,
190{
191 pub fn inner(&self) -> &Conn<A> {
192 self.inner.get_ref().get_ref()
193 }
194}
195
196#[async_trait]
197impl<A> AsyncReady for FramedConn<A>
198where
199 A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
200{
201 async fn ready(&self, interest: Interest) -> io::Result<Ready> {
202 self.inner.get_ref().get_ref().ready(interest).await
203 }
204}
205
/// A tokio-util codec for the pgwire protocol.
///
/// It decodes `FrontendMessage`s from the client and encodes
/// `BackendMessage`s to it.
struct Codec {
    /// Progress through the frame being decoded: either waiting for the
    /// 5-byte header, or holding the parsed `(type, payload length)` while
    /// the payload arrives.
    decode_state: DecodeState,
    /// Per-column `(type, format)` pairs used to encode `DataRow` values,
    /// in column order.
    encode_state: Vec<(mz_pgrepr::Type, mz_pgwire_common::Format)>,
}
210
211impl Codec {
212 pub fn new() -> Codec {
214 Codec {
215 decode_state: DecodeState::Head,
216 encode_state: vec![],
217 }
218 }
219}
220
221impl Default for Codec {
222 fn default() -> Codec {
223 Codec::new()
224 }
225}
226
227impl Encoder<BackendMessage> for Codec {
228 type Error = io::Error;
229
230 fn encode(&mut self, msg: BackendMessage, dst: &mut BytesMut) -> Result<(), io::Error> {
231 let byte = match &msg {
233 BackendMessage::AuthenticationOk => b'R',
234 BackendMessage::AuthenticationCleartextPassword => b'R',
235 BackendMessage::RowDescription(_) => b'T',
236 BackendMessage::DataRow(_) => b'D',
237 BackendMessage::CommandComplete { .. } => b'C',
238 BackendMessage::EmptyQueryResponse => b'I',
239 BackendMessage::ReadyForQuery(_) => b'Z',
240 BackendMessage::NoData => b'n',
241 BackendMessage::ParameterStatus(_, _) => b'S',
242 BackendMessage::PortalSuspended => b's',
243 BackendMessage::BackendKeyData { .. } => b'K',
244 BackendMessage::ParameterDescription(_) => b't',
245 BackendMessage::ParseComplete => b'1',
246 BackendMessage::BindComplete => b'2',
247 BackendMessage::CloseComplete => b'3',
248 BackendMessage::ErrorResponse(r) => {
249 if r.severity.is_error() {
250 b'E'
251 } else {
252 b'N'
253 }
254 }
255 BackendMessage::CopyInResponse { .. } => b'G',
256 BackendMessage::CopyOutResponse { .. } => b'H',
257 BackendMessage::CopyData(_) => b'd',
258 BackendMessage::CopyDone => b'c',
259 };
260 dst.put_u8(byte);
261
262 let base = dst.len();
264 dst.put_u32(0);
265
266 match msg {
268 BackendMessage::CopyInResponse {
269 overall_format,
270 column_formats,
271 }
272 | BackendMessage::CopyOutResponse {
273 overall_format,
274 column_formats,
275 } => {
276 dst.put_format_i8(overall_format);
277 dst.put_length_i16(column_formats.len())?;
278 for format in column_formats {
279 dst.put_format_i16(format);
280 }
281 }
282 BackendMessage::CopyData(data) => {
283 dst.put_slice(&data);
284 }
285 BackendMessage::CopyDone => (),
286 BackendMessage::AuthenticationOk => {
287 dst.put_u32(0);
288 }
289 BackendMessage::AuthenticationCleartextPassword => {
290 dst.put_u32(3);
291 }
292 BackendMessage::RowDescription(fields) => {
293 dst.put_length_i16(fields.len())?;
294 for f in &fields {
295 dst.put_string(&f.name.to_string());
296 dst.put_u32(f.table_id);
297 dst.put_u16(f.column_id);
298 dst.put_u32(f.type_oid);
299 dst.put_i16(f.type_len);
300 dst.put_i32(f.type_mod);
301 dst.put_format_i16(f.format);
303 }
304 }
305 BackendMessage::DataRow(fields) => {
306 dst.put_length_i16(fields.len())?;
307 for (f, (ty, format)) in fields.iter().zip(&self.encode_state) {
308 if let Some(f) = f {
309 let base = dst.len();
310 dst.put_u32(0);
311 f.encode(ty, *format, dst)?;
312 let len = dst.len() - base - 4;
313 let len = i32::try_from(len).map_err(|_| {
314 io::Error::new(
315 io::ErrorKind::Other,
316 "length of encoded data row field does not fit into an i32",
317 )
318 })?;
319 dst[base..base + 4].copy_from_slice(&len.to_be_bytes());
320 } else {
321 dst.put_i32(-1);
322 }
323 }
324 }
325 BackendMessage::CommandComplete { tag } => {
326 dst.put_string(&tag);
327 }
328 BackendMessage::ParseComplete => (),
329 BackendMessage::BindComplete => (),
330 BackendMessage::CloseComplete => (),
331 BackendMessage::EmptyQueryResponse => (),
332 BackendMessage::ReadyForQuery(status) => {
333 dst.put_u8(status.into());
334 }
335 BackendMessage::ParameterStatus(name, value) => {
336 dst.put_string(name);
337 dst.put_string(&value);
338 }
339 BackendMessage::PortalSuspended => (),
340 BackendMessage::NoData => (),
341 BackendMessage::BackendKeyData {
342 conn_id,
343 secret_key,
344 } => {
345 dst.put_u32(conn_id);
346 dst.put_u32(secret_key);
347 }
348 BackendMessage::ParameterDescription(params) => {
349 dst.put_length_i16(params.len())?;
350 for param in params {
351 dst.put_u32(param.oid());
352 }
353 }
354 BackendMessage::ErrorResponse(ErrorResponse {
355 severity,
356 code,
357 message,
358 detail,
359 hint,
360 position,
361 }) => {
362 dst.put_u8(b'S');
363 dst.put_string(severity.as_str());
364 dst.put_u8(b'C');
365 dst.put_string(code.code());
366 dst.put_u8(b'M');
367 dst.put_string(&message);
368 if let Some(detail) = &detail {
369 dst.put_u8(b'D');
370 dst.put_string(detail);
371 }
372 if let Some(hint) = &hint {
373 dst.put_u8(b'H');
374 dst.put_string(hint);
375 }
376 if let Some(position) = &position {
377 dst.put_u8(b'P');
378 dst.put_string(&position.to_string());
379 }
380 dst.put_u8(b'\0');
381 }
382 }
383
384 let len = dst.len() - base;
385
386 let len = i32::try_from(len).map_err(|_| {
388 io::Error::new(
389 io::ErrorKind::Other,
390 "length of encoded message does not fit into an i32",
391 )
392 })?;
393 dst[base..base + 4].copy_from_slice(&len.to_be_bytes());
394
395 Ok(())
396 }
397}
398
399impl Decoder for Codec {
400 type Item = FrontendMessage;
401 type Error = io::Error;
402
403 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<FrontendMessage>, io::Error> {
404 if src.len() > MAX_REQUEST_SIZE {
405 return Err(io::Error::new(
406 io::ErrorKind::InvalidData,
407 format!(
408 "request larger than {}",
409 ByteSize::b(u64::cast_from(MAX_REQUEST_SIZE))
410 ),
411 ));
412 }
413 loop {
414 match self.decode_state {
415 DecodeState::Head => {
416 if src.len() < 5 {
417 return Ok(None);
418 }
419 let msg_type = src[0];
420 let frame_len = parse_frame_len(&src[1..])?;
421 src.advance(5);
422 src.reserve(frame_len);
423 self.decode_state = DecodeState::Data(msg_type, frame_len);
424 }
425
426 DecodeState::Data(msg_type, frame_len) => {
427 if src.len() < frame_len {
428 return Ok(None);
429 }
430 let buf = src.split_to(frame_len).freeze();
431 let buf = Cursor::new(&buf);
432 let msg = match msg_type {
433 b'Q' => decode_query(buf)?,
435
436 b'P' => decode_parse(buf)?,
438 b'D' => decode_describe(buf)?,
439 b'B' => decode_bind(buf)?,
440 b'E' => decode_execute(buf)?,
441 b'H' => decode_flush(buf)?,
442 b'S' => decode_sync(buf)?,
443 b'C' => decode_close(buf)?,
444
445 b'X' => decode_terminate(buf)?,
447
448 b'p' => decode_password(buf)?,
450
451 b'f' => decode_copy_fail(buf)?,
453 b'd' => decode_copy_data(buf, frame_len)?,
454 b'c' => decode_copy_done(buf)?,
455
456 _ => {
458 return Err(io::Error::new(
459 io::ErrorKind::InvalidData,
460 format!("unknown message type {}", msg_type),
461 ));
462 }
463 };
464 src.reserve(5);
465 self.decode_state = DecodeState::Head;
466 return Ok(Some(msg));
467 }
468 }
469 }
470 }
471}
472
473fn decode_terminate(mut _buf: Cursor) -> Result<FrontendMessage, io::Error> {
474 Ok(FrontendMessage::Terminate)
476}
477
478fn decode_password(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
479 Ok(FrontendMessage::Password {
480 password: buf.read_cstr()?.to_owned(),
481 })
482}
483
484fn decode_query(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
485 Ok(FrontendMessage::Query {
486 sql: buf.read_cstr()?.to_string(),
487 })
488}
489
490fn decode_parse(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
491 let name = buf.read_cstr()?;
492 let sql = buf.read_cstr()?;
493
494 let mut param_types = vec![];
495 for _ in 0..buf.read_i16()? {
496 param_types.push(buf.read_u32()?);
497 }
498
499 Ok(FrontendMessage::Parse {
500 name: name.into(),
501 sql: sql.into(),
502 param_types,
503 })
504}
505
506fn decode_close(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
507 match buf.read_byte()? {
508 b'S' => Ok(FrontendMessage::CloseStatement {
509 name: buf.read_cstr()?.to_owned(),
510 }),
511 b'P' => Ok(FrontendMessage::ClosePortal {
512 name: buf.read_cstr()?.to_owned(),
513 }),
514 b => Err(input_err(format!(
515 "invalid type byte in close message: {}",
516 b
517 ))),
518 }
519}
520
521fn decode_describe(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
522 let first_char = buf.read_byte()?;
523 let name = buf.read_cstr()?.to_string();
524 match first_char {
525 b'S' => Ok(FrontendMessage::DescribeStatement { name }),
526 b'P' => Ok(FrontendMessage::DescribePortal { name }),
527 other => Err(input_err(format!("Invalid describe type: {:#x?}", other))),
528 }
529}
530
531fn decode_bind(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
532 let portal_name = buf.read_cstr()?.to_string();
533 let statement_name = buf.read_cstr()?.to_string();
534
535 let mut param_formats = Vec::new();
536 for _ in 0..buf.read_i16()? {
537 param_formats.push(buf.read_format()?);
538 }
539
540 let mut raw_params = Vec::new();
541 for _ in 0..buf.read_i16()? {
542 let len = buf.read_i32()?;
543 if len == -1 {
544 raw_params.push(None); } else {
546 let mut value = Vec::new();
548 for _ in 0..len {
549 value.push(buf.read_byte()?);
550 }
551 raw_params.push(Some(value));
552 }
553 }
554
555 let mut result_formats = Vec::new();
556 for _ in 0..buf.read_i16()? {
557 result_formats.push(buf.read_format()?);
558 }
559
560 Ok(FrontendMessage::Bind {
561 portal_name,
562 statement_name,
563 param_formats,
564 raw_params,
565 result_formats,
566 })
567}
568
569fn decode_execute(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
570 let portal_name = buf.read_cstr()?.to_string();
571 let max_rows = buf.read_i32()?;
572 Ok(FrontendMessage::Execute {
573 portal_name,
574 max_rows,
575 })
576}
577
578fn decode_flush(mut _buf: Cursor) -> Result<FrontendMessage, io::Error> {
579 Ok(FrontendMessage::Flush)
581}
582
583fn decode_sync(mut _buf: Cursor) -> Result<FrontendMessage, io::Error> {
584 Ok(FrontendMessage::Sync)
586}
587
588fn decode_copy_data(mut buf: Cursor, frame_len: usize) -> Result<FrontendMessage, io::Error> {
589 let mut data = Vec::with_capacity(frame_len);
590 for _ in 0..frame_len {
591 data.push(buf.read_byte()?);
592 }
593 Ok(FrontendMessage::CopyData(data))
594}
595
596fn decode_copy_done(mut _buf: Cursor) -> Result<FrontendMessage, io::Error> {
597 Ok(FrontendMessage::CopyDone)
599}
600
601fn decode_copy_fail(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
602 Ok(FrontendMessage::CopyFail(buf.read_cstr()?.to_string()))
603}