1use std::net::IpAddr;
18
19use async_trait::async_trait;
20use bytes::{Buf, BufMut, BytesMut};
21use bytesize::ByteSize;
22use futures::{SinkExt, TryStreamExt, sink};
23use itertools::Itertools;
24use mz_adapter_types::connection::ConnectionId;
25use mz_ore::cast::CastFrom;
26use mz_ore::future::OreSinkExt;
27use mz_ore::netio::AsyncReady;
28use mz_pgwire_common::{
29 ChannelBinding, Conn, Cursor, DecodeState, ErrorResponse, FrontendMessage, GS2Header,
30 MAX_REQUEST_SIZE, Pgbuf, SASLClientFinalResponse, SASLInitialResponse, input_err,
31 parse_frame_len,
32};
33use tokio::io::{self, AsyncRead, AsyncWrite, Interest, Ready};
34use tokio::time::{self, Duration};
35use tokio_util::codec::{Decoder, Encoder, Framed};
36use tracing::trace;
37
38use crate::message::{BackendMessage, BackendMessageKind, SASLServerFinalMessageKinds};
39
/// A pgwire connection that frames outgoing [`BackendMessage`]s and incoming
/// `FrontendMessage`s using [`Codec`].
///
/// NOTE(review): fields are dropped in declaration order; keep this order unless
/// `ConnectionId`'s drop behavior (if any) is confirmed to be order-independent.
pub struct FramedConn<A> {
    /// Identifier of this connection; returned by `conn_id()` and included in
    /// trace logs as `cid=`.
    conn_id: ConnectionId,
    /// Address of the remote peer, if known; returned by `peer_addr()`.
    peer_addr: Option<IpAddr>,
    /// The framed connection behind an outbound message buffer whose capacity
    /// is chosen in `new`.
    inner: sink::Buffer<Framed<Conn<A>, Codec>, BackendMessage>,
}
46
47impl<A> FramedConn<A>
48where
49 A: AsyncRead + AsyncWrite + Unpin,
50{
51 pub fn new(conn_id: ConnectionId, peer_addr: Option<IpAddr>, inner: Conn<A>) -> FramedConn<A> {
60 FramedConn {
61 conn_id,
62 peer_addr,
63 inner: Framed::new(inner, Codec::new()).buffer(32),
64 }
65 }
66
67 pub async fn recv(&mut self) -> Result<Option<FrontendMessage>, io::Error> {
81 let message = self.inner.try_next().await?;
82 match &message {
83 Some(message) => trace!("cid={} recv_name={}", self.conn_id, message.name()),
84 None => trace!("cid={} recv=<eof>", self.conn_id),
85 }
86 Ok(message)
87 }
88
89 pub async fn send<M>(&mut self, message: M) -> Result<(), io::Error>
98 where
99 M: Into<BackendMessage>,
100 {
101 let message = message.into();
102 trace!(
103 "cid={} send={:?}",
104 self.conn_id,
105 BackendMessageKind::from(&message)
106 );
107 self.inner.enqueue(message).await
108 }
109
110 pub async fn send_all(
117 &mut self,
118 messages: impl IntoIterator<Item = BackendMessage>,
119 ) -> Result<(), io::Error> {
120 for m in messages {
123 self.send(m).await?;
124 }
125 Ok(())
126 }
127
128 pub async fn flush(&mut self) -> Result<(), io::Error> {
130 self.inner.flush().await
131 }
132
133 pub fn set_encode_state(
142 &mut self,
143 encode_state: Vec<(mz_pgrepr::Type, mz_pgwire_common::Format)>,
144 ) {
145 self.inner.get_mut().codec_mut().encode_state = encode_state;
146 }
147
148 pub async fn wait_closed(&self) -> io::Error
162 where
163 A: AsyncReady + Send + Sync,
164 {
165 loop {
166 time::sleep(Duration::from_secs(1)).await;
167
168 match self.ready(Interest::READABLE | Interest::WRITABLE).await {
169 Ok(ready) if ready.is_read_closed() || ready.is_write_closed() => {
170 return io::Error::new(io::ErrorKind::Other, "connection closed");
171 }
172 Ok(_) => (),
173 Err(err) => return err,
174 }
175 }
176 }
177
178 pub fn conn_id(&self) -> &ConnectionId {
180 &self.conn_id
181 }
182
183 pub fn peer_addr(&self) -> &Option<IpAddr> {
185 &self.peer_addr
186 }
187}
188
189impl<A> FramedConn<A>
190where
191 A: AsyncRead + AsyncWrite + Unpin,
192{
193 pub fn inner(&self) -> &Conn<A> {
194 self.inner.get_ref().get_ref()
195 }
196}
197
198#[async_trait]
199impl<A> AsyncReady for FramedConn<A>
200where
201 A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
202{
203 async fn ready(&self, interest: Interest) -> io::Result<Ready> {
204 self.inner.get_ref().get_ref().ready(interest).await
205 }
206}
207
208struct Codec {
209 decode_state: DecodeState,
210 encode_state: Vec<(mz_pgrepr::Type, mz_pgwire_common::Format)>,
211}
212
213impl Codec {
214 pub fn new() -> Codec {
216 Codec {
217 decode_state: DecodeState::Head,
218 encode_state: vec![],
219 }
220 }
221}
222
223impl Default for Codec {
224 fn default() -> Codec {
225 Codec::new()
226 }
227}
228
impl Encoder<BackendMessage> for Codec {
    type Error = io::Error;

    /// Appends `msg` to `dst` as one complete backend message frame.
    ///
    /// A frame has three parts:
    /// - a one-byte type tag;
    /// - a big-endian `i32` length that counts the length field itself plus the
    ///   payload, but not the tag;
    /// - the payload.
    ///
    /// The length is first written as a placeholder and back-patched once the
    /// payload size is known. `dst` may already contain earlier frames, so all
    /// offsets are relative to `dst.len()` at the time of the call.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - `put_length_i16` rejects a count (presumably one outside the `i16`
    ///   range — confirm);
    /// - encoding a `DataRow` field fails;
    /// - the message, or any single `DataRow` field, encodes to a length that
    ///   does not fit in an `i32`.
    ///
    /// # Panics
    ///
    /// `DataRow` fields are paired with `self.encode_state` via `zip_eq`. That
    /// call panics if the number of fields differs from the number of
    /// configured `(type, format)` pairs.
    fn encode(&mut self, msg: BackendMessage, dst: &mut BytesMut) -> Result<(), io::Error> {
        // Select the message type tag.
        let byte = match &msg {
            BackendMessage::AuthenticationOk => b'R',
            // Every authentication message uses the `R` tag. The `u32` code at
            // the start of the payload (written below) distinguishes them.
            BackendMessage::AuthenticationCleartextPassword
            | BackendMessage::AuthenticationSASL
            | BackendMessage::AuthenticationSASLContinue(_)
            | BackendMessage::AuthenticationSASLFinal(_) => b'R',
            BackendMessage::RowDescription(_) => b'T',
            BackendMessage::DataRow(_) => b'D',
            BackendMessage::CommandComplete { .. } => b'C',
            BackendMessage::EmptyQueryResponse => b'I',
            BackendMessage::ReadyForQuery(_) => b'Z',
            BackendMessage::NoData => b'n',
            BackendMessage::ParameterStatus(_, _) => b'S',
            BackendMessage::PortalSuspended => b's',
            BackendMessage::BackendKeyData { .. } => b'K',
            BackendMessage::ParameterDescription(_) => b't',
            BackendMessage::ParseComplete => b'1',
            BackendMessage::BindComplete => b'2',
            BackendMessage::CloseComplete => b'3',
            BackendMessage::ErrorResponse(r) => {
                // Error-severity responses are tagged `E`. Anything less
                // severe is tagged `N` (a notice) and uses the same payload
                // layout.
                if r.severity.is_error() {
                    b'E'
                } else {
                    b'N'
                }
            }
            BackendMessage::CopyInResponse { .. } => b'G',
            BackendMessage::CopyOutResponse { .. } => b'H',
            BackendMessage::CopyData(_) => b'd',
            BackendMessage::CopyDone => b'c',
        };
        dst.put_u8(byte);

        // Placeholder for the frame length; `base` is its offset in `dst`.
        let base = dst.len();
        dst.put_u32(0);

        // Payload.
        match msg {
            // COPY IN and COPY OUT responses share a layout: an `i8` overall
            // format, then an `i16`-counted list of `i16` per-column formats.
            BackendMessage::CopyInResponse {
                overall_format,
                column_formats,
            }
            | BackendMessage::CopyOutResponse {
                overall_format,
                column_formats,
            } => {
                dst.put_format_i8(overall_format);
                dst.put_length_i16(column_formats.len())?;
                for format in column_formats {
                    dst.put_format_i16(format);
                }
            }
            // Raw COPY bytes, written verbatim.
            BackendMessage::CopyData(data) => {
                dst.put_slice(&data);
            }
            BackendMessage::CopyDone => (),
            // Authentication codes: 0 = OK, 3 = cleartext password,
            // 10 = SASL, 11 = SASL continue, 12 = SASL final.
            BackendMessage::AuthenticationOk => {
                dst.put_u32(0);
            }
            BackendMessage::AuthenticationCleartextPassword => {
                dst.put_u32(3);
            }
            BackendMessage::AuthenticationSASL => {
                dst.put_u32(10);
                // Advertises a single mechanism. The extra NUL byte ends the
                // mechanism list (assuming `put_string` NUL-terminates each
                // name — confirm in `Pgbuf`).
                dst.put_string("SCRAM-SHA-256");
                dst.put_u8(b'\0');
            }
            BackendMessage::AuthenticationSASLContinue(data) => {
                dst.put_u32(11);
                // SCRAM server-first-message: nonce, salt, iteration count.
                let data = format!(
                    "r={},s={},i={}",
                    data.nonce, data.salt, data.iteration_count
                );
                dst.put_slice(data.as_bytes());
            }
            BackendMessage::AuthenticationSASLFinal(data) => {
                dst.put_u32(12);
                // SCRAM server-final-message: the verifier, followed by any
                // comma-separated extensions.
                let res = match data.kind {
                    SASLServerFinalMessageKinds::Verifier(verifier) => {
                        format!("v={}", verifier)
                    }
                };
                dst.put_slice(res.as_bytes());
                if !data.extensions.is_empty() {
                    dst.put_slice(b",");
                    dst.put_slice(data.extensions.join(",").as_bytes());
                }
            }
            // Field count, then a fixed per-field descriptor.
            BackendMessage::RowDescription(fields) => {
                dst.put_length_i16(fields.len())?;
                for f in &fields {
                    dst.put_string(&f.name.to_string());
                    dst.put_u32(f.table_id);
                    dst.put_u16(f.column_id);
                    dst.put_u32(f.type_oid);
                    dst.put_i16(f.type_len);
                    dst.put_i32(f.type_mod);
                    dst.put_format_i16(f.format);
                }
            }
            // Field count, then each field as an `i32` length and its bytes.
            // A length of -1 (with no bytes) marks a NULL field.
            BackendMessage::DataRow(fields) => {
                dst.put_length_i16(fields.len())?;
                for (f, (ty, format)) in fields.iter().zip_eq(&self.encode_state) {
                    if let Some(f) = f {
                        // Per-field length placeholder. This `base` shadows the
                        // frame-level one only inside this block.
                        let base = dst.len();
                        dst.put_u32(0);
                        f.encode(ty, *format, dst)?;
                        // Unlike the frame length, a field length excludes its
                        // own four bytes.
                        let len = dst.len() - base - 4;
                        let len = i32::try_from(len).map_err(|_| {
                            io::Error::new(
                                io::ErrorKind::Other,
                                "length of encoded data row field does not fit into an i32",
                            )
                        })?;
                        dst[base..base + 4].copy_from_slice(&len.to_be_bytes());
                    } else {
                        dst.put_i32(-1);
                    }
                }
            }
            BackendMessage::CommandComplete { tag } => {
                dst.put_string(&tag);
            }
            BackendMessage::ParseComplete => (),
            BackendMessage::BindComplete => (),
            BackendMessage::CloseComplete => (),
            BackendMessage::EmptyQueryResponse => (),
            // Single transaction-status byte.
            BackendMessage::ReadyForQuery(status) => {
                dst.put_u8(status.into());
            }
            BackendMessage::ParameterStatus(name, value) => {
                dst.put_string(name);
                dst.put_string(&value);
            }
            BackendMessage::PortalSuspended => (),
            BackendMessage::NoData => (),
            BackendMessage::BackendKeyData {
                conn_id,
                secret_key,
            } => {
                dst.put_u32(conn_id);
                dst.put_u32(secret_key);
            }
            // Parameter count, then each parameter's type OID.
            BackendMessage::ParameterDescription(params) => {
                dst.put_length_i16(params.len())?;
                for param in params {
                    dst.put_u32(param.oid());
                }
            }
            // A sequence of (field-code byte, string) pairs: severity `S`,
            // SQLSTATE code `C`, message `M`, and optionally detail `D`,
            // hint `H`, and position `P`. A final NUL byte ends the list.
            BackendMessage::ErrorResponse(ErrorResponse {
                severity,
                code,
                message,
                detail,
                hint,
                position,
            }) => {
                dst.put_u8(b'S');
                dst.put_string(severity.as_str());
                dst.put_u8(b'C');
                dst.put_string(code.code());
                dst.put_u8(b'M');
                dst.put_string(&message);
                if let Some(detail) = &detail {
                    dst.put_u8(b'D');
                    dst.put_string(detail);
                }
                if let Some(hint) = &hint {
                    dst.put_u8(b'H');
                    dst.put_string(hint);
                }
                if let Some(position) = &position {
                    dst.put_u8(b'P');
                    dst.put_string(&position.to_string());
                }
                dst.put_u8(b'\0');
            }
        }

        // The frame length counts the four length bytes plus the payload.
        let len = dst.len() - base;

        let len = i32::try_from(len).map_err(|_| {
            io::Error::new(
                io::ErrorKind::Other,
                "length of encoded message does not fit into an i32",
            )
        })?;
        // Back-patch the placeholder written after the tag.
        dst[base..base + 4].copy_from_slice(&len.to_be_bytes());

        Ok(())
    }
}
429
430impl Decoder for Codec {
431 type Item = FrontendMessage;
432 type Error = io::Error;
433
434 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<FrontendMessage>, io::Error> {
435 if src.len() > MAX_REQUEST_SIZE {
436 return Err(io::Error::new(
437 io::ErrorKind::InvalidData,
438 format!(
439 "request larger than {}",
440 ByteSize::b(u64::cast_from(MAX_REQUEST_SIZE))
441 ),
442 ));
443 }
444 loop {
445 match self.decode_state {
446 DecodeState::Head => {
447 if src.len() < 5 {
448 return Ok(None);
449 }
450 let msg_type = src[0];
451 let frame_len = parse_frame_len(&src[1..])?;
452 src.advance(5);
453 src.reserve(frame_len);
454 self.decode_state = DecodeState::Data(msg_type, frame_len);
455 }
456
457 DecodeState::Data(msg_type, frame_len) => {
458 if src.len() < frame_len {
459 return Ok(None);
460 }
461 let buf = src.split_to(frame_len).freeze();
462 let buf = Cursor::new(&buf);
463 let msg = match msg_type {
464 b'Q' => decode_query(buf)?,
466
467 b'P' => decode_parse(buf)?,
469 b'D' => decode_describe(buf)?,
470 b'B' => decode_bind(buf)?,
471 b'E' => decode_execute(buf)?,
472 b'H' => decode_flush(buf)?,
473 b'S' => decode_sync(buf)?,
474 b'C' => decode_close(buf)?,
475
476 b'X' => decode_terminate(buf)?,
478
479 b'p' => decode_auth(buf)?,
481
482 b'f' => decode_copy_fail(buf)?,
484 b'd' => decode_copy_data(buf, frame_len)?,
485 b'c' => decode_copy_done(buf)?,
486
487 _ => {
489 return Err(io::Error::new(
490 io::ErrorKind::InvalidData,
491 format!("unknown message type {}", msg_type),
492 ));
493 }
494 };
495 src.reserve(5);
496 self.decode_state = DecodeState::Head;
497 return Ok(Some(msg));
498 }
499 }
500 }
501 }
502}
503
504fn decode_terminate(mut _buf: Cursor) -> Result<FrontendMessage, io::Error> {
505 Ok(FrontendMessage::Terminate)
507}
508
509fn decode_auth(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
510 let mut value = Vec::new();
511 while let Ok(b) = buf.read_byte() {
512 value.push(b);
513 }
514 Ok(FrontendMessage::RawAuthentication(value))
515}
516
517fn expect(buf: &mut Cursor, expected: &[u8]) -> Result<(), io::Error> {
518 for i in 0..expected.len() {
519 if buf.read_byte()? != expected[i] {
520 return Err(input_err(format!(
521 "Invalid SASL initial response: expected '{}'",
522 std::str::from_utf8(expected).unwrap_or("invalid UTF-8")
523 )));
524 }
525 }
526 Ok(())
527}
528
529fn read_until_comma(buf: &mut Cursor) -> Result<Vec<u8>, io::Error> {
530 let mut v = Vec::new();
531 while let Ok(b) = buf.peek_byte() {
532 if b == b',' {
533 break;
534 }
535 v.push(buf.read_byte()?);
536 }
537 Ok(v)
538}
539
540pub fn decode_sasl_client_first_message(mut buf: Cursor) -> Result<SASLInitialResponse, io::Error> {
571 let cbind_flag = match buf.read_byte()? {
573 b'n' => ChannelBinding::None,
574 b'y' => ChannelBinding::ClientSupported,
575 b'p' => {
576 expect(&mut buf, b"=")?;
578 let cbname = String::from_utf8(read_until_comma(&mut buf)?)
579 .map_err(|_| input_err("invalid cbname utf8"))?;
580 ChannelBinding::Required(cbname)
581 }
582 other => {
583 return Err(input_err(format!(
584 "Invalid channel binding flag: {}",
585 other
586 )));
587 }
588 };
589 expect(&mut buf, b",")?;
590
591 let mut authzid = None;
593 if buf.peek_byte()? == b'a' {
594 expect(&mut buf, b"a=")?;
595 let a = String::from_utf8(read_until_comma(&mut buf)?)
596 .map_err(|_| input_err("invalid authzid utf8"))?;
597 authzid = Some(a);
598 }
599 expect(&mut buf, b",")?;
600
601 let mut client_first_message_bare_raw = String::new();
602
603 let mut reserved_mext = None;
605 if buf.peek_byte()? == b'm' {
606 expect(&mut buf, b"m=")?;
607 let mext_val = String::from_utf8(read_until_comma(&mut buf)?)
608 .map_err(|_| input_err("invalid m ext utf8"))?;
609 client_first_message_bare_raw.push_str(&format!("m={},", mext_val));
610 reserved_mext = Some(mext_val);
611 expect(&mut buf, b",")?;
612 }
613
614 expect(&mut buf, b"n=")?;
616 let username = String::from_utf8(read_until_comma(&mut buf)?)
618 .map_err(|_| input_err("invalid username utf8"))?;
619 expect(&mut buf, b",")?;
620 client_first_message_bare_raw.push_str(&format!("n={},", username));
621
622 expect(&mut buf, b"r=")?;
624 let nonce = String::from_utf8(read_until_comma(&mut buf)?)
625 .map_err(|_| input_err("invalid nonce utf8"))?;
626 client_first_message_bare_raw.push_str(&format!("r={}", nonce));
627
628 let mut extensions = Vec::new();
630 while let Ok(b',') = buf.peek_byte().map(|b| b) {
631 expect(&mut buf, b",")?;
632 let ext = String::from_utf8(read_until_comma(&mut buf)?)
633 .map_err(|_| input_err("invalid ext utf8"))?;
634 if !ext.is_empty() {
635 client_first_message_bare_raw.push_str(&format!(",{}", ext));
636 extensions.push(ext);
637 }
638 }
639
640 Ok(SASLInitialResponse {
641 gs2_header: GS2Header {
642 cbind_flag,
643 authzid,
644 },
645 nonce,
646 extensions,
647 reserved_mext,
648 client_first_message_bare_raw,
649 })
650}
651
652pub fn decode_sasl_initial_response(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
653 let mechanism = buf.read_cstr()?;
654 let initial_resp_len = buf.read_i32()?;
655 if initial_resp_len < 0 {
656 return Err(input_err("No initial response"));
658 }
659
660 let initial_response = decode_sasl_client_first_message(buf)?;
661 Ok(FrontendMessage::SASLInitialResponse {
662 gs2_header: initial_response.gs2_header.clone(),
663 mechanism: mechanism.to_owned(),
664 initial_response,
665 })
666}
667
668pub fn decode_sasl_response(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
679 let mut client_final_message_bare_raw = String::new();
681 expect(&mut buf, b"c=")?;
683 let channel_binding = String::from_utf8(read_until_comma(&mut buf)?)
684 .map_err(|_| input_err("invalid channel-binding utf8"))?;
685 expect(&mut buf, b",")?;
686 client_final_message_bare_raw.push_str(&format!("c={},", channel_binding));
687
688 expect(&mut buf, b"r=")?;
690 let nonce = String::from_utf8(read_until_comma(&mut buf)?)
691 .map_err(|_| input_err("invalid nonce utf8"))?;
692 client_final_message_bare_raw.push_str(&format!("r={}", nonce));
693
694 let mut extensions = Vec::new();
696
697 while buf.peek_byte()? == b',' {
699 expect(&mut buf, b",")?;
700 if buf.peek_byte()? == b'p' {
701 break;
702 }
703 let ext = String::from_utf8(read_until_comma(&mut buf)?)
704 .map_err(|_| input_err("invalid extension utf8"))?;
705 if !ext.is_empty() {
706 client_final_message_bare_raw.push_str(&format!(",{}", ext));
707 extensions.push(ext);
708 }
709 }
710
711 expect(&mut buf, b"p=")?;
713 let proof = String::from_utf8(read_until_comma(&mut buf)?)
714 .map_err(|_| input_err("invalid proof utf8"))?;
715
716 Ok(FrontendMessage::SASLResponse(SASLClientFinalResponse {
717 channel_binding,
718 nonce,
719 extensions,
720 proof,
721 client_final_message_bare_raw,
722 }))
723}
724
725pub fn decode_password(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
726 Ok(FrontendMessage::Password {
727 password: buf.read_cstr()?.to_owned(),
728 })
729}
730
731fn decode_query(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
732 Ok(FrontendMessage::Query {
733 sql: buf.read_cstr()?.to_string(),
734 })
735}
736
737fn decode_parse(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
738 let name = buf.read_cstr()?;
739 let sql = buf.read_cstr()?;
740
741 let mut param_types = vec![];
742 for _ in 0..buf.read_i16()? {
743 param_types.push(buf.read_u32()?);
744 }
745
746 Ok(FrontendMessage::Parse {
747 name: name.into(),
748 sql: sql.into(),
749 param_types,
750 })
751}
752
753fn decode_close(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
754 match buf.read_byte()? {
755 b'S' => Ok(FrontendMessage::CloseStatement {
756 name: buf.read_cstr()?.to_owned(),
757 }),
758 b'P' => Ok(FrontendMessage::ClosePortal {
759 name: buf.read_cstr()?.to_owned(),
760 }),
761 b => Err(input_err(format!(
762 "invalid type byte in close message: {}",
763 b
764 ))),
765 }
766}
767
768fn decode_describe(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
769 let first_char = buf.read_byte()?;
770 let name = buf.read_cstr()?.to_string();
771 match first_char {
772 b'S' => Ok(FrontendMessage::DescribeStatement { name }),
773 b'P' => Ok(FrontendMessage::DescribePortal { name }),
774 other => Err(input_err(format!("Invalid describe type: {:#x?}", other))),
775 }
776}
777
778fn decode_bind(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
779 let portal_name = buf.read_cstr()?.to_string();
780 let statement_name = buf.read_cstr()?.to_string();
781
782 let mut param_formats = Vec::new();
783 for _ in 0..buf.read_i16()? {
784 param_formats.push(buf.read_format()?);
785 }
786
787 let mut raw_params = Vec::new();
788 for _ in 0..buf.read_i16()? {
789 let len = buf.read_i32()?;
790 if len == -1 {
791 raw_params.push(None); } else {
793 let mut value = Vec::new();
795 for _ in 0..len {
796 value.push(buf.read_byte()?);
797 }
798 raw_params.push(Some(value));
799 }
800 }
801
802 let mut result_formats = Vec::new();
803 for _ in 0..buf.read_i16()? {
804 result_formats.push(buf.read_format()?);
805 }
806
807 Ok(FrontendMessage::Bind {
808 portal_name,
809 statement_name,
810 param_formats,
811 raw_params,
812 result_formats,
813 })
814}
815
816fn decode_execute(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
817 let portal_name = buf.read_cstr()?.to_string();
818 let max_rows = buf.read_i32()?;
819 Ok(FrontendMessage::Execute {
820 portal_name,
821 max_rows,
822 })
823}
824
825fn decode_flush(mut _buf: Cursor) -> Result<FrontendMessage, io::Error> {
826 Ok(FrontendMessage::Flush)
828}
829
830fn decode_sync(mut _buf: Cursor) -> Result<FrontendMessage, io::Error> {
831 Ok(FrontendMessage::Sync)
833}
834
835fn decode_copy_data(mut buf: Cursor, frame_len: usize) -> Result<FrontendMessage, io::Error> {
836 let mut data = Vec::with_capacity(frame_len);
837 for _ in 0..frame_len {
838 data.push(buf.read_byte()?);
839 }
840 Ok(FrontendMessage::CopyData(data))
841}
842
843fn decode_copy_done(mut _buf: Cursor) -> Result<FrontendMessage, io::Error> {
844 Ok(FrontendMessage::CopyDone)
846}
847
848fn decode_copy_fail(mut buf: Cursor) -> Result<FrontendMessage, io::Error> {
849 Ok(FrontendMessage::CopyFail(buf.read_cstr()?.to_string()))
850}