1use std::net::IpAddr;
18
19use async_trait::async_trait;
20use bytes::{Buf, BufMut, BytesMut};
21use bytesize::ByteSize;
22use futures::{SinkExt, TryStreamExt, sink};
23use itertools::Itertools;
24use mz_adapter_types::connection::ConnectionId;
25use mz_ore::cast::CastFrom;
26use mz_ore::future::OreSinkExt;
27use mz_ore::netio::AsyncReady;
28use mz_pgwire_common::{
29 Conn, Cursor, DecodeState, ErrorResponse, FrontendMessage, MAX_REQUEST_SIZE, Pgbuf, input_err,
30 parse_frame_len,
31};
32use tokio::io::{self, AsyncRead, AsyncWrite, Interest, Ready};
33use tokio::time::{self, Duration};
34use tokio_util::codec::{Decoder, Encoder, Framed};
35use tracing::trace;
36
37use crate::message::{BackendMessage, BackendMessageKind};
38
/// A pgwire connection whose traffic is framed by `Codec`.
///
/// Bytes read from the underlying `Conn` are decoded into `FrontendMessage`s,
/// and outgoing `BackendMessage`s are encoded and passed through a sink
/// buffer on their way to the connection.
pub struct FramedConn<A> {
    /// Identifier of this connection; included in the trace output of
    /// `recv` and `send`.
    conn_id: ConnectionId,
    /// Address of the remote peer, if one was supplied at construction.
    peer_addr: Option<IpAddr>,
    /// The framed transport, wrapped in a buffering sink for outgoing
    /// messages.
    inner: sink::Buffer<Framed<Conn<A>, Codec>, BackendMessage>,
}
45
46impl<A> FramedConn<A>
47where
48 A: AsyncRead + AsyncWrite + Unpin,
49{
50 pub fn new(conn_id: ConnectionId, peer_addr: Option<IpAddr>, inner: Conn<A>) -> FramedConn<A> {
59 FramedConn {
60 conn_id,
61 peer_addr,
62 inner: Framed::new(inner, Codec::new()).buffer(32),
63 }
64 }
65
66 pub async fn recv(&mut self) -> Result<Option<FrontendMessage>, io::Error> {
80 let message = self.inner.try_next().await?;
81 match &message {
82 Some(message) => trace!("cid={} recv_name={}", self.conn_id, message.name()),
83 None => trace!("cid={} recv=<eof>", self.conn_id),
84 }
85 Ok(message)
86 }
87
88 pub async fn send<M>(&mut self, message: M) -> Result<(), io::Error>
97 where
98 M: Into<BackendMessage>,
99 {
100 let message = message.into();
101 trace!(
102 "cid={} send={:?}",
103 self.conn_id,
104 BackendMessageKind::from(&message)
105 );
106 self.inner.enqueue(message).await
107 }
108
109 pub async fn send_all(
116 &mut self,
117 messages: impl IntoIterator<Item = BackendMessage>,
118 ) -> Result<(), io::Error> {
119 for m in messages {
122 self.send(m).await?;
123 }
124 Ok(())
125 }
126
127 pub async fn flush(&mut self) -> Result<(), io::Error> {
129 self.inner.flush().await
130 }
131
132 pub fn set_encode_state(
141 &mut self,
142 encode_state: Vec<(mz_pgrepr::Type, mz_pgwire_common::Format)>,
143 ) {
144 self.inner.get_mut().codec_mut().encode_state = encode_state;
145 }
146
147 pub async fn wait_closed(&self) -> io::Error
161 where
162 A: AsyncReady + Send + Sync,
163 {
164 loop {
165 time::sleep(Duration::from_secs(1)).await;
166
167 match self.ready(Interest::READABLE | Interest::WRITABLE).await {
168 Ok(ready) if ready.is_read_closed() || ready.is_write_closed() => {
169 return io::Error::new(io::ErrorKind::Other, "connection closed");
170 }
171 Ok(_) => (),
172 Err(err) => return err,
173 }
174 }
175 }
176
177 pub fn conn_id(&self) -> &ConnectionId {
179 &self.conn_id
180 }
181
182 pub fn peer_addr(&self) -> &Option<IpAddr> {
184 &self.peer_addr
185 }
186}
187
188impl<A> FramedConn<A>
189where
190 A: AsyncRead + AsyncWrite + Unpin,
191{
192 pub fn inner(&self) -> &Conn<A> {
193 self.inner.get_ref().get_ref()
194 }
195}
196
197#[async_trait]
198impl<A> AsyncReady for FramedConn<A>
199where
200 A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
201{
202 async fn ready(&self, interest: Interest) -> io::Result<Ready> {
203 self.inner.get_ref().get_ref().ready(interest).await
204 }
205}
206
/// Encoder/decoder state for pgwire frames.
struct Codec {
    /// Where the decoder is within the current frame: waiting for a header,
    /// or waiting for the body of a frame whose header was already read.
    decode_state: DecodeState,
    /// Type and format pairs used, position by position, to encode the
    /// fields of `DataRow` messages.
    encode_state: Vec<(mz_pgrepr::Type, mz_pgwire_common::Format)>,
}
211
212impl Codec {
213 pub fn new() -> Codec {
215 Codec {
216 decode_state: DecodeState::Head,
217 encode_state: vec![],
218 }
219 }
220}
221
222impl Default for Codec {
223 fn default() -> Codec {
224 Codec::new()
225 }
226}
227
228impl Encoder<BackendMessage> for Codec {
229 type Error = io::Error;
230
231 fn encode(&mut self, msg: BackendMessage, dst: &mut BytesMut) -> Result<(), io::Error> {
232 let byte = match &msg {
234 BackendMessage::AuthenticationOk => b'R',
235 BackendMessage::AuthenticationCleartextPassword => b'R',
236 BackendMessage::RowDescription(_) => b'T',
237 BackendMessage::DataRow(_) => b'D',
238 BackendMessage::CommandComplete { .. } => b'C',
239 BackendMessage::EmptyQueryResponse => b'I',
240 BackendMessage::ReadyForQuery(_) => b'Z',
241 BackendMessage::NoData => b'n',
242 BackendMessage::ParameterStatus(_, _) => b'S',
243 BackendMessage::PortalSuspended => b's',
244 BackendMessage::BackendKeyData { .. } => b'K',
245 BackendMessage::ParameterDescription(_) => b't',
246 BackendMessage::ParseComplete => b'1',
247 BackendMessage::BindComplete => b'2',
248 BackendMessage::CloseComplete => b'3',
249 BackendMessage::ErrorResponse(r) => {
250 if r.severity.is_error() {
251 b'E'
252 } else {
253 b'N'
254 }
255 }
256 BackendMessage::CopyInResponse { .. } => b'G',
257 BackendMessage::CopyOutResponse { .. } => b'H',
258 BackendMessage::CopyData(_) => b'd',
259 BackendMessage::CopyDone => b'c',
260 };
261 dst.put_u8(byte);
262
263 let base = dst.len();
265 dst.put_u32(0);
266
267 match msg {
269 BackendMessage::CopyInResponse {
270 overall_format,
271 column_formats,
272 }
273 | BackendMessage::CopyOutResponse {
274 overall_format,
275 column_formats,
276 } => {
277 dst.put_format_i8(overall_format);
278 dst.put_length_i16(column_formats.len())?;
279 for format in column_formats {
280 dst.put_format_i16(format);
281 }
282 }
283 BackendMessage::CopyData(data) => {
284 dst.put_slice(&data);
285 }
286 BackendMessage::CopyDone => (),
287 BackendMessage::AuthenticationOk => {
288 dst.put_u32(0);
289 }
290 BackendMessage::AuthenticationCleartextPassword => {
291 dst.put_u32(3);
292 }
293 BackendMessage::RowDescription(fields) => {
294 dst.put_length_i16(fields.len())?;
295 for f in &fields {
296 dst.put_string(&f.name.to_string());
297 dst.put_u32(f.table_id);
298 dst.put_u16(f.column_id);
299 dst.put_u32(f.type_oid);
300 dst.put_i16(f.type_len);
301 dst.put_i32(f.type_mod);
302 dst.put_format_i16(f.format);
304 }
305 }
306 BackendMessage::DataRow(fields) => {
307 dst.put_length_i16(fields.len())?;
308 for (f, (ty, format)) in fields.iter().zip_eq(&self.encode_state) {
309 if let Some(f) = f {
310 let base = dst.len();
311 dst.put_u32(0);
312 f.encode(ty, *format, dst)?;
313 let len = dst.len() - base - 4;
314 let len = i32::try_from(len).map_err(|_| {
315 io::Error::new(
316 io::ErrorKind::Other,
317 "length of encoded data row field does not fit into an i32",
318 )
319 })?;
320 dst[base..base + 4].copy_from_slice(&len.to_be_bytes());
321 } else {
322 dst.put_i32(-1);
323 }
324 }
325 }
326 BackendMessage::CommandComplete { tag } => {
327 dst.put_string(&tag);
328 }
329 BackendMessage::ParseComplete => (),
330 BackendMessage::BindComplete => (),
331 BackendMessage::CloseComplete => (),
332 BackendMessage::EmptyQueryResponse => (),
333 BackendMessage::ReadyForQuery(status) => {
334 dst.put_u8(status.into());
335 }
336 BackendMessage::ParameterStatus(name, value) => {
337 dst.put_string(name);
338 dst.put_string(&value);
339 }
340 BackendMessage::PortalSuspended => (),
341 BackendMessage::NoData => (),
342 BackendMessage::BackendKeyData {
343 conn_id,
344 secret_key,
345 } => {
346 dst.put_u32(conn_id);
347 dst.put_u32(secret_key);
348 }
349 BackendMessage::ParameterDescription(params) => {
350 dst.put_length_i16(params.len())?;
351 for param in params {
352 dst.put_u32(param.oid());
353 }
354 }
355 BackendMessage::ErrorResponse(ErrorResponse {
356 severity,
357 code,
358 message,
359 detail,
360 hint,
361 position,
362 }) => {
363 dst.put_u8(b'S');
364 dst.put_string(severity.as_str());
365 dst.put_u8(b'C');
366 dst.put_string(code.code());
367 dst.put_u8(b'M');
368 dst.put_string(&message);
369 if let Some(detail) = &detail {
370 dst.put_u8(b'D');
371 dst.put_string(detail);
372 }
373 if let Some(hint) = &hint {
374 dst.put_u8(b'H');
375 dst.put_string(hint);
376 }
377 if let Some(position) = &position {
378 dst.put_u8(b'P');
379 dst.put_string(&position.to_string());
380 }
381 dst.put_u8(b'\0');
382 }
383 }
384
385 let len = dst.len() - base;
386
387 let len = i32::try_from(len).map_err(|_| {
389 io::Error::new(
390 io::ErrorKind::Other,
391 "length of encoded message does not fit into an i32",
392 )
393 })?;
394 dst[base..base + 4].copy_from_slice(&len.to_be_bytes());
395
396 Ok(())
397 }
398}
399
400impl Decoder for Codec {
401 type Item = FrontendMessage;
402 type Error = io::Error;
403
404 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<FrontendMessage>, io::Error> {
405 if src.len() > MAX_REQUEST_SIZE {
406 return Err(io::Error::new(
407 io::ErrorKind::InvalidData,
408 format!(
409 "request larger than {}",
410 ByteSize::b(u64::cast_from(MAX_REQUEST_SIZE))
411 ),
412 ));
413 }
414 loop {
415 match self.decode_state {
416 DecodeState::Head => {
417 if src.len() < 5 {
418 return Ok(None);
419 }
420 let msg_type = src[0];
421 let frame_len = parse_frame_len(&src[1..])?;
422 src.advance(5);
423 src.reserve(frame_len);
424 self.decode_state = DecodeState::Data(msg_type, frame_len);
425 }
426
427 DecodeState::Data(msg_type, frame_len) => {
428 if src.len() < frame_len {
429 return Ok(None);
430 }
431 let buf = src.split_to(frame_len).freeze();
432 let buf = Cursor::new(&buf);
433 let msg = match msg_type {
434 b'Q' => decode_query(buf)?,
436
437 b'P' => decode_parse(buf)?,
439 b'D' => decode_describe(buf)?,
440 b'B' => decode_bind(buf)?,
441 b'E' => decode_execute(buf)?,
442 b'H' => decode_flush(buf)?,
443 b'S' => decode_sync(buf)?,
444 b'C' => decode_close(buf)?,
445
446 b'X' => decode_terminate(buf)?,
448
449 b'p' => decode_password(buf)?,
451
452 b'f' => decode_copy_fail(buf)?,
454 b'd' => decode_copy_data(buf, frame_len)?,
455 b'c' => decode_copy_done(buf)?,
456
457 _ => {
459 return Err(io::Error::new(
460 io::ErrorKind::InvalidData,
461 format!("unknown message type {}", msg_type),
462 ));
463 }
464 };
465 src.reserve(5);
466 self.decode_state = DecodeState::Head;
467 return Ok(Some(msg));
468 }
469 }
470 }
471 }
472}
473
474fn decode_terminate(mut _buf: Cursor) -> Result<FrontendMessage, io::Error> {
475 Ok(FrontendMessage::Terminate)
477}
478
479fn decode_password(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
480 Ok(FrontendMessage::Password {
481 password: buf.read_cstr()?.to_owned(),
482 })
483}
484
485fn decode_query(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
486 Ok(FrontendMessage::Query {
487 sql: buf.read_cstr()?.to_string(),
488 })
489}
490
491fn decode_parse(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
492 let name = buf.read_cstr()?;
493 let sql = buf.read_cstr()?;
494
495 let mut param_types = vec![];
496 for _ in 0..buf.read_i16()? {
497 param_types.push(buf.read_u32()?);
498 }
499
500 Ok(FrontendMessage::Parse {
501 name: name.into(),
502 sql: sql.into(),
503 param_types,
504 })
505}
506
507fn decode_close(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
508 match buf.read_byte()? {
509 b'S' => Ok(FrontendMessage::CloseStatement {
510 name: buf.read_cstr()?.to_owned(),
511 }),
512 b'P' => Ok(FrontendMessage::ClosePortal {
513 name: buf.read_cstr()?.to_owned(),
514 }),
515 b => Err(input_err(format!(
516 "invalid type byte in close message: {}",
517 b
518 ))),
519 }
520}
521
522fn decode_describe(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
523 let first_char = buf.read_byte()?;
524 let name = buf.read_cstr()?.to_string();
525 match first_char {
526 b'S' => Ok(FrontendMessage::DescribeStatement { name }),
527 b'P' => Ok(FrontendMessage::DescribePortal { name }),
528 other => Err(input_err(format!("Invalid describe type: {:#x?}", other))),
529 }
530}
531
532fn decode_bind(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
533 let portal_name = buf.read_cstr()?.to_string();
534 let statement_name = buf.read_cstr()?.to_string();
535
536 let mut param_formats = Vec::new();
537 for _ in 0..buf.read_i16()? {
538 param_formats.push(buf.read_format()?);
539 }
540
541 let mut raw_params = Vec::new();
542 for _ in 0..buf.read_i16()? {
543 let len = buf.read_i32()?;
544 if len == -1 {
545 raw_params.push(None); } else {
547 let mut value = Vec::new();
549 for _ in 0..len {
550 value.push(buf.read_byte()?);
551 }
552 raw_params.push(Some(value));
553 }
554 }
555
556 let mut result_formats = Vec::new();
557 for _ in 0..buf.read_i16()? {
558 result_formats.push(buf.read_format()?);
559 }
560
561 Ok(FrontendMessage::Bind {
562 portal_name,
563 statement_name,
564 param_formats,
565 raw_params,
566 result_formats,
567 })
568}
569
570fn decode_execute(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
571 let portal_name = buf.read_cstr()?.to_string();
572 let max_rows = buf.read_i32()?;
573 Ok(FrontendMessage::Execute {
574 portal_name,
575 max_rows,
576 })
577}
578
579fn decode_flush(mut _buf: Cursor) -> Result<FrontendMessage, io::Error> {
580 Ok(FrontendMessage::Flush)
582}
583
584fn decode_sync(mut _buf: Cursor) -> Result<FrontendMessage, io::Error> {
585 Ok(FrontendMessage::Sync)
587}
588
589fn decode_copy_data(mut buf: Cursor, frame_len: usize) -> Result<FrontendMessage, io::Error> {
590 let mut data = Vec::with_capacity(frame_len);
591 for _ in 0..frame_len {
592 data.push(buf.read_byte()?);
593 }
594 Ok(FrontendMessage::CopyData(data))
595}
596
597fn decode_copy_done(mut _buf: Cursor) -> Result<FrontendMessage, io::Error> {
598 Ok(FrontendMessage::CopyDone)
600}
601
602fn decode_copy_fail(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
603 Ok(FrontendMessage::CopyFail(buf.read_cstr()?.to_string()))
604}