1use std::collections::BTreeMap;
11use std::convert::TryFrom;
12use std::future::Future;
13use std::ops::Deref;
14use std::sync::Arc;
15use std::time::{Duration, Instant};
16use std::{iter, mem};
17
18use byteorder::{ByteOrder, NetworkEndian};
19use futures::future::{BoxFuture, FutureExt, pending};
20use itertools::Itertools;
21use mz_adapter::client::RecordFirstRowStream;
22use mz_adapter::session::{
23 EndTransactionAction, InProgressRows, LifecycleTimestamps, PortalRefMut, PortalState,
24 SessionConfig, TransactionStatus,
25};
26use mz_adapter::statement_logging::{StatementEndedExecutionReason, StatementExecutionStrategy};
27use mz_adapter::{
28 AdapterError, AdapterNotice, ExecuteContextExtra, ExecuteResponse, PeekResponseUnary, metrics,
29 verify_datum_desc,
30};
31use mz_auth::password::Password;
32use mz_authenticator::Authenticator;
33use mz_ore::cast::CastFrom;
34use mz_ore::netio::AsyncReady;
35use mz_ore::now::{EpochMillis, SYSTEM_TIME};
36use mz_ore::str::StrExt;
37use mz_ore::{assert_none, assert_ok, instrument, soft_assert_eq_or_log};
38use mz_pgcopy::{CopyCsvFormatParams, CopyFormatParams, CopyTextFormatParams};
39use mz_pgwire_common::{
40 ConnectionCounter, ErrorResponse, Format, FrontendMessage, Severity, VERSION_3, VERSIONS,
41};
42use mz_repr::user::InternalUserMetadata;
43use mz_repr::{
44 CatalogItemId, ColumnIndex, Datum, RelationDesc, RelationType, RowArena, RowIterator, RowRef,
45 ScalarType,
46};
47use mz_server_core::TlsMode;
48use mz_server_core::listeners::AllowedRoles;
49use mz_sql::ast::display::AstDisplay;
50use mz_sql::ast::{CopyDirection, CopyStatement, FetchDirection, Ident, Raw, Statement};
51use mz_sql::parse::StatementParseResult;
52use mz_sql::plan::{CopyFormat, ExecuteTimeout, StatementDesc};
53use mz_sql::session::metadata::SessionMetadata;
54use mz_sql::session::user::INTERNAL_USER_NAMES;
55use mz_sql::session::vars::{MAX_COPY_FROM_SIZE, Var, VarInput};
56use postgres::error::SqlState;
57use tokio::io::{self, AsyncRead, AsyncWrite};
58use tokio::select;
59use tokio::time::{self};
60use tokio_metrics::TaskMetrics;
61use tokio_stream::wrappers::UnboundedReceiverStream;
62use tracing::{Instrument, debug, debug_span, warn};
63use uuid::Uuid;
64
65use crate::codec::FramedConn;
66use crate::message::{self, BackendMessage};
67
68pub fn match_handshake(buf: &[u8]) -> bool {
72 if buf.len() < 8 {
82 return false;
83 }
84 let version = NetworkEndian::read_i32(&buf[4..8]);
85 VERSIONS.contains(&version)
86}
87
/// Parameters for [`run`], which serves a single pgwire connection.
pub struct RunParams<'a, A, I>
where
    I: Iterator<Item = TaskMetrics> + Send,
{
    /// TLS requirement, checked against the connection with
    /// `ensure_tls_compatibility` before authentication starts.
    pub tls_mode: Option<TlsMode>,
    /// Adapter client used to fetch system variables and create the session.
    pub adapter_client: mz_adapter::Client,
    /// The framed connection to the client.
    pub conn: &'a mut FramedConn<A>,
    /// Connection UUID recorded in the new session's `SessionConfig`.
    pub conn_uuid: Uuid,
    /// Protocol version requested by the client; only `VERSION_3` is served.
    pub version: i32,
    /// Startup parameters from the client. `user` selects the role,
    /// `options` is parsed as command-line-style settings, and every other
    /// key is applied as a session variable.
    pub params: BTreeMap<String, String>,
    /// How the client is authenticated.
    pub authenticator: Authenticator,
    /// Counter that admits (or refuses) this connection for the session's user.
    pub active_connection_counter: ConnectionCounter,
    /// Helm chart version passed through to the session config.
    pub helm_chart_version: Option<String>,
    /// Which classes of roles may log in through this listener.
    pub allowed_roles: AllowedRoles,
    /// Source of Tokio task-metrics intervals. The connection state machine
    /// calls `.expect("infinite iterator")` on it, so it must never end.
    pub tokio_metrics_intervals: I,
}
116
/// Serves one pgwire connection from startup until it terminates.
///
/// In order, this:
/// 1. rejects any protocol version other than `VERSION_3`;
/// 2. checks the requested `user` against `allowed_roles`;
/// 3. checks TLS compatibility against `tls_mode`;
/// 4. authenticates according to `authenticator` and creates a session;
/// 5. applies the remaining startup parameters as session variables, turning
///    failures into `BadStartupSetting` notices rather than errors;
/// 6. allocates a slot in `active_connection_counter` and starts the session;
/// 7. sends `AuthenticationOk`, the `ParameterStatus` set, `BackendKeyData`,
///    any pending notices, and `ReadyForQuery`;
/// 8. runs the message-processing state machine until it finishes or the
///    authentication session expires.
///
/// Each fatal startup problem is sent to the client as a fatal
/// `ErrorResponse`, and the result of that send is returned.
///
/// # Errors
///
/// Returns an I/O error if communicating with the client fails.
#[mz_ore::instrument(level = "debug")]
pub async fn run<'a, A, I>(
    RunParams {
        tls_mode,
        adapter_client,
        conn,
        conn_uuid,
        version,
        mut params,
        authenticator,
        active_connection_counter,
        helm_chart_version,
        allowed_roles,
        tokio_metrics_intervals,
    }: RunParams<'a, A, I>,
) -> Result<(), io::Error>
where
    A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
    I: Iterator<Item = TaskMetrics> + Send,
{
    if version != VERSION_3 {
        return conn
            .send(ErrorResponse::fatal(
                SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
                "server does not support the client's requested protocol version",
            ))
            .await;
    }

    // `user` is removed here so the parameter loop below does not try to
    // apply it as a session variable. A missing `user` becomes "".
    let user = params.remove("user").unwrap_or_else(String::new);

    let is_internal_user = INTERNAL_USER_NAMES.contains(&user);
    let is_reserved_user = mz_adapter::catalog::is_reserved_role_name(user.as_str());
    // Normal listeners reject reserved role names; internal listeners accept
    // only internal user names; mixed listeners accept either.
    let role_allowed = match allowed_roles {
        AllowedRoles::Normal => !is_reserved_user,
        AllowedRoles::Internal => is_internal_user,
        AllowedRoles::NormalAndInternal => !is_reserved_user || is_internal_user,
    };
    if !role_allowed {
        let msg = format!("unauthorized login to user '{user}'");
        return conn
            .send(ErrorResponse::fatal(SqlState::INSUFFICIENT_PRIVILEGE, msg))
            .await;
    }

    if let Err(err) = conn.inner().ensure_tls_compatibility(&tls_mode) {
        return conn.send(err).await;
    }

    // Authenticate and build the session. `expired` resolves when the
    // authentication session expires. It is `pending()`, which never
    // resolves, for authenticators without expiry.
    let (mut session, expired) = match authenticator {
        Authenticator::Frontegg(frontegg) => {
            // Ask for a cleartext password and validate it with Frontegg.
            conn.send(BackendMessage::AuthenticationCleartextPassword)
                .await?;
            conn.flush().await?;
            let password = match conn.recv().await? {
                Some(FrontendMessage::Password { password }) => password,
                _ => {
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                            "expected Password message",
                        ))
                        .await;
                }
            };

            let auth_response = frontegg.authenticate(&user, &password).await;
            match auth_response {
                Ok(mut auth_session) => {
                    // The session user comes from the authenticated session,
                    // not from the `user` startup parameter.
                    let session = adapter_client.new_session(SessionConfig {
                        conn_id: conn.conn_id().clone(),
                        uuid: conn_uuid,
                        user: auth_session.user().into(),
                        client_ip: conn.peer_addr().clone(),
                        external_metadata_rx: Some(auth_session.external_metadata_rx()),
                        internal_user_metadata: None,
                        helm_chart_version,
                    });
                    let expired = async move { auth_session.expired().await };
                    (session, expired.left_future())
                }
                Err(err) => {
                    warn!(?err, "pgwire connection failed authentication");
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_PASSWORD,
                            "invalid password",
                        ))
                        .await;
                }
            }
        }
        // NOTE: within this arm, `adapter_client` is the authenticator's
        // client and shadows the outer `adapter_client`.
        Authenticator::Password(adapter_client) => {
            conn.send(BackendMessage::AuthenticationCleartextPassword)
                .await?;
            conn.flush().await?;
            let password = match conn.recv().await? {
                Some(FrontendMessage::Password { password }) => Password(password),
                _ => {
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                            "expected Password message",
                        ))
                        .await;
                }
            };
            let auth_response = match adapter_client.authenticate(&user, &password).await {
                Ok(resp) => resp,
                Err(err) => {
                    warn!(?err, "pgwire connection failed authentication");
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_PASSWORD,
                            "invalid password",
                        ))
                        .await;
                }
            };
            // The superuser flag from the authentication response is
            // carried into the session's internal user metadata.
            let session = adapter_client.new_session(SessionConfig {
                conn_id: conn.conn_id().clone(),
                uuid: conn_uuid,
                user,
                client_ip: conn.peer_addr().clone(),
                external_metadata_rx: None,
                internal_user_metadata: Some(InternalUserMetadata {
                    superuser: auth_response.superuser,
                }),
                helm_chart_version,
            });
            let auth_session = pending().right_future();
            (session, auth_session)
        }
        Authenticator::None => {
            // No authentication: trust the `user` parameter; never expires.
            let session = adapter_client.new_session(SessionConfig {
                conn_id: conn.conn_id().clone(),
                uuid: conn_uuid,
                user,
                client_ip: conn.peer_addr().clone(),
                external_metadata_rx: None,
                internal_user_metadata: None,
                helm_chart_version,
            });
            let auth_session = pending().right_future();
            (session, auth_session)
        }
    };

    // Apply the remaining startup parameters. `options` expands to several
    // settings; any other key is a single setting. Failures become notices.
    let system_vars = adapter_client.get_system_vars().await;
    for (name, value) in params {
        let settings = match name.as_str() {
            "options" => match parse_options(&value) {
                Ok(opts) => opts,
                Err(()) => {
                    session.add_notice(AdapterNotice::BadStartupSetting {
                        name,
                        reason: "could not parse".into(),
                    });
                    continue;
                }
            },
            _ => vec![(name, value)],
        };
        for (key, val) in settings {
            // NOTE(review): presumably `false` selects a session-level rather
            // than a transaction-local (`SET LOCAL`) assignment.
            const LOCAL: bool = false;
            if let Err(err) =
                session
                    .vars_mut()
                    .set(&system_vars, &key, VarInput::Flat(&val), LOCAL)
            {
                session.add_notice(AdapterNotice::BadStartupSetting {
                    name: key,
                    reason: err.to_string(),
                });
            }
        }
    }
    // Commit the variable assignments made above.
    session
        .vars_mut()
        .end_transaction(EndTransactionAction::Commit);

    // Held for the rest of the connection; presumably releases the
    // connection slot when dropped.
    let _guard = match active_connection_counter.allocate_connection(session.user()) {
        Ok(drop_connection) => drop_connection,
        Err(e) => {
            let e: AdapterError = e.into();
            return conn.send(e.into_response(Severity::Fatal)).await;
        }
    };

    let mut adapter_client = match adapter_client.startup(session).await {
        Ok(adapter_client) => adapter_client,
        Err(e) => return conn.send(e.into_response(Severity::Fatal)).await,
    };

    // Finish startup: AuthenticationOk, ParameterStatus for each notify-set
    // variable, BackendKeyData, queued notices, then ReadyForQuery.
    let mut buf = vec![BackendMessage::AuthenticationOk];
    for var in adapter_client.session().vars().notify_set() {
        buf.push(BackendMessage::ParameterStatus(var.name(), var.value()));
    }
    buf.push(BackendMessage::BackendKeyData {
        conn_id: adapter_client.session().conn_id().unhandled(),
        secret_key: adapter_client.session().secret_key(),
    });
    buf.extend(
        adapter_client
            .session()
            .drain_notices()
            .into_iter()
            .map(|notice| BackendMessage::ErrorResponse(notice.into_response())),
    );
    buf.push(BackendMessage::ReadyForQuery(
        adapter_client.session().transaction().into(),
    ));
    conn.send_all(buf).await?;
    conn.flush().await?;

    let machine = StateMachine {
        conn,
        adapter_client,
        txn_needs_commit: false,
        tokio_metrics_intervals,
    };

    // Race the state machine against authentication expiry. If the state
    // machine fails, make a best-effort attempt to report the error to the
    // client before returning it.
    select! {
        r = machine.run() => {
            if let Err(err) = &r {
                let _ = conn
                    .send(ErrorResponse::fatal(
                        SqlState::CONNECTION_FAILURE,
                        err.to_string(),
                    ))
                    .await;
                let _ = conn.flush().await;
            }
            r
        },
        _ = expired => {
            conn
                .send(ErrorResponse::fatal(SqlState::INVALID_AUTHORIZATION_SPECIFICATION, "authentication expired"))
                .await?;
            conn.flush().await
        }
    }
}
389
390fn parse_options(value: &str) -> Result<Vec<(String, String)>, ()> {
395 let opts = split_options(value);
396 let mut pairs = Vec::with_capacity(opts.len());
397 let mut seen_prefix = false;
398 for opt in opts {
399 if !seen_prefix {
400 if opt == "-c" {
401 seen_prefix = true;
402 } else {
403 let (key, val) = parse_option(&opt)?;
404 pairs.push((key.to_owned(), val.to_owned()));
405 }
406 } else {
407 let (key, val) = opt.split_once('=').ok_or(())?;
408 pairs.push((key.to_owned(), val.to_owned()));
409 seen_prefix = false;
410 }
411 }
412 Ok(pairs)
413}
414
/// Parses a single `-cname=value` or `--name=value` token into
/// `(name, value)`.
///
/// The token is split at its first `=`. The `-c` prefix is tried before
/// `--`. Returns `Err(())` if there is no `=` or neither prefix is present.
fn parse_option(option: &str) -> Result<(&str, &str), ()> {
    let (raw_key, value) = option.split_once('=').ok_or(())?;
    ["-c", "--"]
        .iter()
        .find_map(|prefix| raw_key.strip_prefix(prefix))
        .map(|key| (key, value))
        .ok_or(())
}
427
/// Splits an `options` value into tokens.
///
/// Unescaped spaces separate tokens, and runs of them never produce empty
/// tokens. A backslash escapes the following character: `\ ` yields a
/// literal space and `\\` a literal backslash. Any other escaped character
/// is kept and the backslash is dropped. A trailing lone backslash is
/// discarded.
fn split_options(value: &str) -> Vec<String> {
    let mut tokens = Vec::new();
    let mut token = String::new();
    let mut escaped = false;
    for ch in value.chars() {
        match (ch, escaped) {
            // An escaped space or backslash is taken literally.
            (' ', true) | ('\\', true) => {
                token.push(ch);
                escaped = false;
            }
            // An unescaped space ends the current token, if any.
            (' ', false) => {
                if !token.is_empty() {
                    tokens.push(std::mem::take(&mut token));
                }
            }
            (
                '\\', false) => escaped = true,
            // Any other character (escaped or not) is kept as-is.
            (other, _) => {
                token.push(other);
                escaped = false;
            }
        }
    }
    if !token.is_empty() {
        tokens.push(token);
    }
    tokens
}
470
/// States of the per-connection message loop driven by `StateMachine::run`.
#[derive(Debug)]
enum State {
    /// Waiting for the next frontend message; handled by `advance_ready`.
    Ready,
    /// Discarding frontend messages until a `Sync` arrives; handled by
    /// `advance_drain`. Entered, for example, when a COPY or `Password`
    /// message arrives in the ready state.
    Drain,
    /// The connection is finished; `StateMachine::run` returns.
    Done,
}
477
/// Processes frontend messages for one connection after startup completes.
struct StateMachine<'a, A, I>
where
    I: Iterator<Item = TaskMetrics> + Send + 'a,
{
    /// The client connection.
    conn: &'a mut FramedConn<A>,
    /// Client for the adapter session started in `run`.
    adapter_client: mz_adapter::SessionClient,
    /// Set when an `Execute` leaves an implicit transaction open. The next
    /// `ensure_transaction` commits it, and `end_transaction` clears it.
    txn_needs_commit: bool,
    /// Tokio task-metrics intervals. `advance_ready` calls
    /// `.expect("infinite iterator")` on it, so it must never end.
    tokio_metrics_intervals: I,
}
487
/// How `send_rows` finished streaming a portal's rows. `execute` maps it onto
/// a `StatementEndedExecutionReason` when retiring the execution.
enum SendRowsEndedReason {
    /// Streaming completed. NOTE(review): `result_size` is presumably a byte
    /// count and `rows_returned` a row count; confirm against `send_rows`.
    Success {
        result_size: u64,
        rows_returned: u64,
    },
    /// Streaming stopped with an error; `error` is its message.
    Errored {
        error: String,
    },
    /// Streaming was canceled.
    Canceled,
}
498
/// Error text for a statement that is not a transaction-exit statement but is
/// executed inside an aborted transaction. `execute` records it when retiring
/// such an execution.
const ABORTED_TXN_MSG: &str =
    "current transaction is aborted, commands ignored until end of transaction block";
501
502impl<'a, A, I> StateMachine<'a, A, I>
503where
504 A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin + 'a,
505 I: Iterator<Item = TaskMetrics> + Send + 'a,
506{
507 #[allow(clippy::manual_async_fn)]
511 #[mz_ore::instrument(level = "debug")]
512 fn run(mut self) -> impl Future<Output = Result<(), io::Error>> + Send + 'a {
513 async move {
514 let mut state = State::Ready;
515 loop {
516 self.send_pending_notices().await?;
517 state = match state {
518 State::Ready => self.advance_ready().await?,
519 State::Drain => self.advance_drain().await?,
520 State::Done => return Ok(()),
521 };
522 self.adapter_client
523 .add_idle_in_transaction_session_timeout();
524 }
525 }
526 }
527
/// Reads one frontend message in the `Ready` state, dispatches it to its
/// handler, and returns the state to move to next.
///
/// A session timeout reported by the adapter takes priority over reading a
/// message. Per-message processing time and the receive scheduling delay are
/// recorded in the adapter's metrics, labeled by message name.
#[instrument(level = "debug")]
async fn advance_ready(&mut self) -> Result<State, io::Error> {
    // Consume the current metrics interval without using it. The interval
    // taken after `recv` below then presumably covers only the time spent
    // waiting for the next message.
    self.tokio_metrics_intervals
        .next()
        .expect("infinite iterator");

    let message = select! {
        // `biased`: poll the timeout branch before the client branch.
        biased;

        // The session timed out. Send a fatal error, terminate the adapter
        // session, then wait for one more message (or EOF) from the client
        // before returning. NOTE(review): presumably the wait gives the client
        // a chance to read the error before the connection is dropped.
        Some(timeout) = self.adapter_client.recv_timeout() => {
            let err: AdapterError = timeout.into();
            let conn_id = self.adapter_client.session().conn_id();
            tracing::warn!("session timed out, conn_id {}", conn_id);

            let error_response = err.into_response(Severity::Fatal);
            let error_state = self.error(error_response).await;

            self.adapter_client.terminate().await;

            let _ = self.conn.recv().await?;
            return error_state;
        },
        message = self.conn.recv() => message?,
    };

    // Scheduling delay observed while receiving, in milliseconds.
    let interval = self
        .tokio_metrics_intervals
        .next()
        .expect("infinite iterator");
    let recv_scheduling_delay_ms = interval.total_scheduled_duration.as_secs_f64() * 1000.0;

    let received = SYSTEM_TIME();

    // A message arrived, so the session is no longer idle.
    self.adapter_client
        .remove_idle_in_transaction_session_timeout();

    // "" when the connection closed (`None`).
    let message_name = message.as_ref().map(|m| m.name()).unwrap_or_default();

    // Only time processing when there is a message to process.
    let start = message.as_ref().map(|_| Instant::now());
    let next_state = match message {
        Some(FrontendMessage::Query { sql }) => {
            let query_root_span =
                tracing::info_span!(parent: None, "advance_ready", otel.name = message_name);
            query_root_span.follows_from(tracing::Span::current());
            self.query(sql, received)
                .instrument(query_root_span)
                .await?
        }
        Some(FrontendMessage::Parse {
            name,
            sql,
            param_types,
        }) => self.parse(name, sql, param_types).await?,
        Some(FrontendMessage::Bind {
            portal_name,
            statement_name,
            param_formats,
            raw_params,
            result_formats,
        }) => {
            self.bind(
                portal_name,
                statement_name,
                param_formats,
                raw_params,
                result_formats,
            )
            .await?
        }
        Some(FrontendMessage::Execute {
            portal_name,
            max_rows,
        }) => {
            let max_rows = match usize::try_from(max_rows) {
                // Zero, or a value that does not fit in `usize` (such as a
                // negative count), means "no row limit".
                Ok(0) | Err(_) => ExecuteCount::All,
                Ok(n) => ExecuteCount::Count(n),
            };
            let execute_root_span =
                tracing::info_span!(parent: None, "advance_ready", otel.name = message_name);
            execute_root_span.follows_from(tracing::Span::current());
            let state = self
                .execute(
                    portal_name,
                    max_rows,
                    portal_exec_message,
                    None,
                    ExecuteTimeout::None,
                    None,
                    Some(received),
                )
                .instrument(execute_root_span)
                .await?;
            // Leave an implicit transaction open for now; the next
            // `ensure_transaction` commits it.
            if self.adapter_client.session().transaction().is_implicit() {
                self.txn_needs_commit = true;
            }
            state
        }
        Some(FrontendMessage::DescribeStatement { name }) => {
            self.describe_statement(&name).await?
        }
        Some(FrontendMessage::DescribePortal { name }) => self.describe_portal(&name).await?,
        Some(FrontendMessage::CloseStatement { name }) => self.close_statement(name).await?,
        Some(FrontendMessage::ClosePortal { name }) => self.close_portal(name).await?,
        Some(FrontendMessage::Flush) => self.flush().await?,
        Some(FrontendMessage::Sync) => self.sync().await?,
        Some(FrontendMessage::Terminate) => State::Done,

        // Not acceptable in the ready state: discard input until `Sync`.
        Some(FrontendMessage::CopyData(_))
        | Some(FrontendMessage::CopyDone)
        | Some(FrontendMessage::CopyFail(_))
        | Some(FrontendMessage::Password { .. }) => State::Drain,
        // `recv` returned no message: the connection is finished.
        None => State::Done,
    };

    if let Some(start) = start {
        self.adapter_client
            .inner()
            .metrics()
            .pgwire_message_processing_seconds
            .with_label_values(&[message_name])
            .observe(start.elapsed().as_secs_f64());
    }
    self.adapter_client
        .inner()
        .metrics()
        .pgwire_recv_scheduling_delay_ms
        .with_label_values(&[message_name])
        .observe(recv_scheduling_delay_ms);

    Ok(next_state)
}
688
689 async fn advance_drain(&mut self) -> Result<State, io::Error> {
690 let message = self.conn.recv().await?;
691 if message.is_some() {
692 self.adapter_client
693 .remove_idle_in_transaction_session_timeout();
694 }
695 match message {
696 Some(FrontendMessage::Sync) => self.sync().await,
697 None => Ok(State::Done),
698 _ => Ok(State::Drain),
699 }
700 }
701
/// Executes one statement from a simple-protocol `Query` message.
///
/// The statement is declared into the unnamed portal. A text-format
/// `RowDescription` is sent when the statement returns rows and is not a
/// COPY. The portal is then executed, its response is sent, and the portal
/// is removed. Statements containing parameter placeholders are rejected,
/// since simple queries carry no parameters.
///
/// Returns the next state as reported by the response or error helpers.
#[instrument(level = "debug")]
async fn one_query(
    &mut self,
    stmt: Statement<Raw>,
    sql: String,
    lifecycle_timestamps: LifecycleTimestamps,
) -> Result<State, io::Error> {
    const EMPTY_PORTAL: &str = "";
    if let Err(e) = self
        .adapter_client
        .declare(EMPTY_PORTAL.to_string(), stmt, sql)
        .await
    {
        return self.error(e.into_response(Severity::Error)).await;
    }
    let portal = self
        .adapter_client
        .session()
        .get_portal_unverified_mut(EMPTY_PORTAL)
        .expect("unnamed portal should be present");

    *portal.lifecycle_timestamps = Some(lifecycle_timestamps);

    let stmt_desc = portal.desc.clone();
    // NOTE(review): this early return leaves the unnamed portal in place,
    // unlike the normal path below; confirm that is intended.
    if !stmt_desc.param_types.is_empty() {
        return self
            .error(ErrorResponse::error(
                SqlState::UNDEFINED_PARAMETER,
                "there is no parameter $1",
            ))
            .await;
    }

    // Simple-protocol results are always described in text format.
    if let Some(relation_desc) = &stmt_desc.relation_desc {
        if !stmt_desc.is_copy {
            let formats = vec![Format::Text; stmt_desc.arity()];
            self.send(BackendMessage::RowDescription(
                message::encode_row_description(relation_desc, &formats),
            ))
            .await?;
        }
    }

    let result = match self
        .adapter_client
        .execute(EMPTY_PORTAL.to_string(), self.conn.wait_closed(), None)
        .await
    {
        Ok((response, execute_started)) => {
            self.send_pending_notices().await?;
            self.send_execute_response(
                response,
                stmt_desc.relation_desc,
                EMPTY_PORTAL.to_string(),
                ExecuteCount::All,
                portal_exec_message,
                None,
                ExecuteTimeout::None,
                execute_started,
            )
            .await
        }
        Err(e) => {
            self.send_pending_notices().await?;
            self.error(e.into_response(Severity::Error)).await
        }
    };

    // The unnamed portal is removed whether execution succeeded or failed.
    self.adapter_client.session().remove_portal(EMPTY_PORTAL);

    result
}
781
782 async fn ensure_transaction(
783 &mut self,
784 num_stmts: usize,
785 message_type: &str,
786 ) -> Result<(), io::Error> {
787 let start = Instant::now();
788 if self.txn_needs_commit {
789 self.commit_transaction().await?;
790 }
791 let res = self.adapter_client.start_transaction(Some(num_stmts));
794 assert_ok!(res);
795 self.adapter_client
796 .inner()
797 .metrics()
798 .pgwire_ensure_transaction_seconds
799 .with_label_values(&[message_type])
800 .observe(start.elapsed().as_secs_f64());
801 Ok(())
802 }
803
804 fn parse_sql<'b>(&self, sql: &'b str) -> Result<Vec<StatementParseResult<'b>>, ErrorResponse> {
805 let parse_start = Instant::now();
806 let result = match self.adapter_client.parse(sql) {
807 Ok(result) => result.map_err(|e| {
808 let pos = sql[..e.error.pos].chars().count() + 1;
811 ErrorResponse::error(SqlState::SYNTAX_ERROR, e.error.message).with_position(pos)
812 }),
813 Err(msg) => Err(ErrorResponse::error(SqlState::PROGRAM_LIMIT_EXCEEDED, msg)),
814 };
815 self.adapter_client
816 .inner()
817 .metrics()
818 .parse_seconds
819 .with_label_values(&[])
820 .observe(parse_start.elapsed().as_secs_f64());
821 result
822 }
823
824 #[instrument(level = "debug")]
829 async fn query(&mut self, sql: String, received: EpochMillis) -> Result<State, io::Error> {
830 let stmts = match self.parse_sql(&sql) {
832 Ok(stmts) => stmts,
833 Err(err) => {
834 self.error(err).await?;
835 return self.ready().await;
836 }
837 };
838
839 let num_stmts = stmts.len();
840
841 for StatementParseResult { ast: stmt, sql } in stmts {
843 if self.is_aborted_txn() && !is_txn_exit_stmt(Some(&stmt)) {
845 self.aborted_txn_error().await?;
846 break;
847 }
848
849 self.ensure_transaction(num_stmts, "query").await?;
857
858 match self
859 .one_query(stmt, sql.to_string(), LifecycleTimestamps { received })
860 .await?
861 {
862 State::Ready => (),
863 State::Drain => break,
864 State::Done => return Ok(State::Done),
865 }
866 }
867
868 {
870 if self.adapter_client.session().transaction().is_implicit() {
871 self.commit_transaction().await?;
872 }
873 }
874
875 if num_stmts == 0 {
876 self.send(BackendMessage::EmptyQueryResponse).await?;
877 }
878
879 self.ready().await
880 }
881
882 #[instrument(level = "debug")]
883 async fn parse(
884 &mut self,
885 name: String,
886 sql: String,
887 param_oids: Vec<u32>,
888 ) -> Result<State, io::Error> {
889 self.ensure_transaction(1, "parse").await?;
891
892 let mut param_types = vec![];
893 for oid in param_oids {
894 match mz_pgrepr::Type::from_oid(oid) {
895 Ok(ty) => match ScalarType::try_from(&ty) {
896 Ok(ty) => param_types.push(Some(ty)),
897 Err(err) => {
898 return self
899 .error(ErrorResponse::error(
900 SqlState::INVALID_PARAMETER_VALUE,
901 err.to_string(),
902 ))
903 .await;
904 }
905 },
906 Err(_) if oid == 0 => param_types.push(None),
907 Err(e) => {
908 return self
909 .error(ErrorResponse::error(
910 SqlState::PROTOCOL_VIOLATION,
911 e.to_string(),
912 ))
913 .await;
914 }
915 }
916 }
917
918 let stmts = match self.parse_sql(&sql) {
919 Ok(stmts) => stmts,
920 Err(err) => {
921 return self.error(err).await;
922 }
923 };
924 if stmts.len() > 1 {
925 return self
926 .error(ErrorResponse::error(
927 SqlState::INTERNAL_ERROR,
928 "cannot insert multiple commands into a prepared statement",
929 ))
930 .await;
931 }
932 let (maybe_stmt, sql) = match stmts.into_iter().next() {
933 None => (None, ""),
934 Some(StatementParseResult { ast, sql }) => (Some(ast), sql),
935 };
936 if self.is_aborted_txn() && !is_txn_exit_stmt(maybe_stmt.as_ref()) {
937 return self.aborted_txn_error().await;
938 }
939 match self
940 .adapter_client
941 .prepare(name, maybe_stmt, sql.to_string(), param_types)
942 .await
943 {
944 Ok(()) => {
945 self.send(BackendMessage::ParseComplete).await?;
946 Ok(State::Ready)
947 }
948 Err(e) => self.error(e.into_response(Severity::Error)).await,
949 }
950 }
951
952 #[instrument(level = "debug")]
954 async fn commit_transaction(&mut self) -> Result<(), io::Error> {
955 self.end_transaction(EndTransactionAction::Commit).await
956 }
957
958 #[instrument(level = "debug")]
960 async fn rollback_transaction(&mut self) -> Result<(), io::Error> {
961 self.end_transaction(EndTransactionAction::Rollback).await
962 }
963
964 #[instrument(level = "debug")]
966 async fn end_transaction(&mut self, action: EndTransactionAction) -> Result<(), io::Error> {
967 self.txn_needs_commit = false;
968 let resp = self.adapter_client.end_transaction(action).await;
969 if let Err(err) = resp {
970 self.send(BackendMessage::ErrorResponse(
971 err.into_response(Severity::Error),
972 ))
973 .await?;
974 }
975 Ok(())
976 }
977
978 #[instrument(level = "debug")]
979 async fn bind(
980 &mut self,
981 portal_name: String,
982 statement_name: String,
983 param_formats: Vec<Format>,
984 raw_params: Vec<Option<Vec<u8>>>,
985 result_formats: Vec<Format>,
986 ) -> Result<State, io::Error> {
987 self.ensure_transaction(1, "bind").await?;
989
990 let aborted_txn = self.is_aborted_txn();
991 let stmt = match self
992 .adapter_client
993 .get_prepared_statement(&statement_name)
994 .await
995 {
996 Ok(stmt) => stmt,
997 Err(err) => return self.error(err.into_response(Severity::Error)).await,
998 };
999
1000 let param_types = &stmt.desc().param_types;
1001 if param_types.len() != raw_params.len() {
1002 let message = format!(
1003 "bind message supplies {actual} parameters, \
1004 but prepared statement \"{name}\" requires {expected}",
1005 name = statement_name,
1006 actual = raw_params.len(),
1007 expected = param_types.len()
1008 );
1009 return self
1010 .error(ErrorResponse::error(SqlState::PROTOCOL_VIOLATION, message))
1011 .await;
1012 }
1013 let param_formats = match pad_formats(param_formats, raw_params.len()) {
1014 Ok(param_formats) => param_formats,
1015 Err(msg) => {
1016 return self
1017 .error(ErrorResponse::error(SqlState::PROTOCOL_VIOLATION, msg))
1018 .await;
1019 }
1020 };
1021 if aborted_txn && !is_txn_exit_stmt(stmt.stmt()) {
1022 return self.aborted_txn_error().await;
1023 }
1024 let buf = RowArena::new();
1025 let mut params = vec![];
1026 for ((raw_param, mz_typ), format) in raw_params
1027 .into_iter()
1028 .zip_eq(param_types)
1029 .zip_eq(param_formats)
1030 {
1031 let pg_typ = mz_pgrepr::Type::from(mz_typ);
1032 let datum = match raw_param {
1033 None => Datum::Null,
1034 Some(bytes) => match mz_pgrepr::Value::decode(format, &pg_typ, &bytes) {
1035 Ok(param) => param.into_datum(&buf, &pg_typ),
1036 Err(err) => {
1037 let msg = format!("unable to decode parameter: {}", err);
1038 return self
1039 .error(ErrorResponse::error(SqlState::INVALID_PARAMETER_VALUE, msg))
1040 .await;
1041 }
1042 },
1043 };
1044 params.push((datum, mz_typ.clone()))
1045 }
1046
1047 let result_formats = match pad_formats(
1048 result_formats,
1049 stmt.desc()
1050 .relation_desc
1051 .clone()
1052 .map(|desc| desc.typ().column_types.len())
1053 .unwrap_or(0),
1054 ) {
1055 Ok(result_formats) => result_formats,
1056 Err(msg) => {
1057 return self
1058 .error(ErrorResponse::error(SqlState::PROTOCOL_VIOLATION, msg))
1059 .await;
1060 }
1061 };
1062
1063 if !stmt.stmt().map_or(false, |stmt| {
1066 matches!(
1067 stmt,
1068 Statement::Copy(CopyStatement {
1069 direction: CopyDirection::To,
1070 ..
1071 })
1072 )
1073 }) {
1074 if let Some(desc) = stmt.desc().relation_desc.clone() {
1075 for (format, ty) in result_formats.iter().zip_eq(desc.iter_types()) {
1076 match (format, &ty.scalar_type) {
1077 (Format::Binary, mz_repr::ScalarType::List { .. }) => {
1078 return self
1079 .error(ErrorResponse::error(
1080 SqlState::PROTOCOL_VIOLATION,
1081 "binary encoding of list types is not implemented",
1082 ))
1083 .await;
1084 }
1085 (Format::Binary, mz_repr::ScalarType::Map { .. }) => {
1086 return self
1087 .error(ErrorResponse::error(
1088 SqlState::PROTOCOL_VIOLATION,
1089 "binary encoding of map types is not implemented",
1090 ))
1091 .await;
1092 }
1093 (Format::Binary, mz_repr::ScalarType::AclItem) => {
1094 return self
1095 .error(ErrorResponse::error(
1096 SqlState::PROTOCOL_VIOLATION,
1097 "binary encoding of aclitem types does not exist",
1098 ))
1099 .await;
1100 }
1101 _ => (),
1102 }
1103 }
1104 }
1105 }
1106
1107 let desc = stmt.desc().clone();
1108 let logging = Arc::clone(stmt.logging());
1109 let stmt_ast = stmt.stmt().cloned();
1110 let state_revision = stmt.state_revision;
1111 if let Err(err) = self.adapter_client.session().set_portal(
1112 portal_name,
1113 desc,
1114 stmt_ast,
1115 logging,
1116 params,
1117 result_formats,
1118 state_revision,
1119 ) {
1120 return self.error(err.into_response(Severity::Error)).await;
1121 }
1122
1123 self.send(BackendMessage::BindComplete).await?;
1124 Ok(State::Ready)
1125 }
1126
/// Executes, or resumes, the portal named `portal_name`.
///
/// Parameters:
/// * `max_rows`, `get_response`, `fetch_portal_name` and `timeout` are
///   forwarded to `send_execute_response` (first run) or `send_rows`
///   (resumption).
/// * `outer_ctx_extra`, if present, is handed to `adapter_client.execute`
///   on a first run. On every other path it is retired with
///   `retire_execute` and an outcome describing that path.
/// * `received`, if present, is stored in the portal's lifecycle
///   timestamps.
///
/// The portal's state selects the action:
/// * `NotStarted` executes it;
/// * `InProgress` resumes streaming its remaining rows;
/// * `Completed(Some(tag))` re-sends the stored completion tag;
/// * `Completed(None)` reports that the portal cannot be run.
///
/// Returns a boxed future. NOTE(review): presumably boxed because this
/// function can be reached recursively (e.g. through `fetch`).
fn execute(
    &mut self,
    portal_name: String,
    max_rows: ExecuteCount,
    get_response: GetResponse,
    fetch_portal_name: Option<String>,
    timeout: ExecuteTimeout,
    outer_ctx_extra: Option<ExecuteContextExtra>,
    received: Option<EpochMillis>,
) -> BoxFuture<'_, Result<State, io::Error>> {
    async move {
        let aborted_txn = self.is_aborted_txn();

        let portal = match self
            .adapter_client
            .session()
            .get_portal_unverified_mut(&portal_name)
        {
            Some(portal) => portal,
            None => {
                let msg = format!("portal {} does not exist", portal_name.quoted());
                if let Some(outer_ctx_extra) = outer_ctx_extra {
                    self.adapter_client.retire_execute(
                        outer_ctx_extra,
                        StatementEndedExecutionReason::Errored { error: msg.clone() },
                    );
                }
                return self
                    .error(ErrorResponse::error(SqlState::INVALID_CURSOR_NAME, msg))
                    .await;
            }
        };

        *portal.lifecycle_timestamps = received.map(LifecycleTimestamps::new);

        // Inside an aborted transaction, only transaction-exit statements
        // may run.
        let txn_exit_stmt = is_txn_exit_stmt(portal.stmt.as_deref());
        if aborted_txn && !txn_exit_stmt {
            if let Some(outer_ctx_extra) = outer_ctx_extra {
                self.adapter_client.retire_execute(
                    outer_ctx_extra,
                    StatementEndedExecutionReason::Errored {
                        error: ABORTED_TXN_MSG.to_string(),
                    },
                );
            }
            return self.aborted_txn_error().await;
        }

        let row_desc = portal.desc.relation_desc.clone();
        match portal.state {
            PortalState::NotStarted => {
                self.ensure_transaction(1, "execute").await?;
                match self
                    .adapter_client
                    .execute(
                        portal_name.clone(),
                        self.conn.wait_closed(),
                        outer_ctx_extra,
                    )
                    .await
                {
                    Ok((response, execute_started)) => {
                        self.send_pending_notices().await?;
                        self.send_execute_response(
                            response,
                            row_desc,
                            portal_name,
                            max_rows,
                            get_response,
                            fetch_portal_name,
                            timeout,
                            execute_started,
                        )
                        .await
                    }
                    Err(e) => {
                        self.send_pending_notices().await?;
                        self.error(e.into_response(Severity::Error)).await
                    }
                }
            }
            PortalState::InProgress(rows) => {
                // Resume a portal that still holds rows (presumably one
                // suspended earlier by a row limit).
                let rows = rows.take().expect("InProgress rows must be populated");
                let (result, statement_ended_execution_reason) = match self
                    .send_rows(
                        row_desc.expect("portal missing row desc on resumption"),
                        portal_name,
                        rows,
                        max_rows,
                        get_response,
                        fetch_portal_name,
                        timeout,
                    )
                    .await
                {
                    // An I/O error while streaming is recorded as a cancellation.
                    Err(e) => {
                        (Err(e), StatementEndedExecutionReason::Canceled)
                    }
                    Ok((ok, SendRowsEndedReason::Canceled)) => {
                        (Ok(ok), StatementEndedExecutionReason::Canceled)
                    }
                    // NOTE(review): the counts reported by `send_rows` are
                    // discarded here and the execution is retired with `None`
                    // counts; the reason is not visible in this file chunk.
                    Ok((
                        ok,
                        SendRowsEndedReason::Success {
                            result_size: _,
                            rows_returned: _,
                        },
                    )) => (
                        Ok(ok),
                        StatementEndedExecutionReason::Success {
                            result_size: None,
                            rows_returned: None,
                            execution_strategy: None,
                        },
                    ),
                    Ok((ok, SendRowsEndedReason::Errored { error })) => {
                        (Ok(ok), StatementEndedExecutionReason::Errored { error })
                    }
                };
                if let Some(outer_ctx_extra) = outer_ctx_extra {
                    self.adapter_client
                        .retire_execute(outer_ctx_extra, statement_ended_execution_reason);
                }
                result
            }
            PortalState::Completed(Some(tag)) => {
                // Already finished: re-send its completion tag.
                let tag = tag.to_string();
                if let Some(outer_ctx_extra) = outer_ctx_extra {
                    self.adapter_client.retire_execute(
                        outer_ctx_extra,
                        StatementEndedExecutionReason::Success {
                            result_size: None,
                            rows_returned: None,
                            execution_strategy: None,
                        },
                    );
                }
                self.send(BackendMessage::CommandComplete { tag }).await?;
                Ok(State::Ready)
            }
            PortalState::Completed(None) => {
                let error = format!(
                    "portal {} cannot be run",
                    Ident::new_unchecked(portal_name).to_ast_string_stable()
                );
                if let Some(outer_ctx_extra) = outer_ctx_extra {
                    self.adapter_client.retire_execute(
                        outer_ctx_extra,
                        StatementEndedExecutionReason::Errored {
                            error: error.clone(),
                        },
                    );
                }
                self.error(ErrorResponse::error(
                    SqlState::OBJECT_NOT_IN_PREREQUISITE_STATE,
                    error,
                ))
                .await
            }
        }
    }
    .instrument(debug_span!("execute"))
    .boxed()
}
1314
1315 #[instrument(level = "debug")]
1316 async fn describe_statement(&mut self, name: &str) -> Result<State, io::Error> {
1317 self.ensure_transaction(1, "describe_statement").await?;
1319
1320 let stmt = match self.adapter_client.get_prepared_statement(name).await {
1321 Ok(stmt) => stmt,
1322 Err(err) => return self.error(err.into_response(Severity::Error)).await,
1323 };
1324 let parameter_desc = BackendMessage::ParameterDescription(
1326 stmt.desc()
1327 .param_types
1328 .iter()
1329 .map(mz_pgrepr::Type::from)
1330 .collect(),
1331 );
1332 let formats = vec![Format::Text; stmt.desc().arity()];
1336 let row_desc = describe_rows(stmt.desc(), &formats);
1337 self.send_all([parameter_desc, row_desc]).await?;
1338 Ok(State::Ready)
1339 }
1340
1341 #[instrument(level = "debug")]
1342 async fn describe_portal(&mut self, name: &str) -> Result<State, io::Error> {
1343 self.ensure_transaction(1, "describe_portal").await?;
1345
1346 let session = self.adapter_client.session();
1347 let row_desc = session
1348 .get_portal_unverified(name)
1349 .map(|portal| describe_rows(&portal.desc, &portal.result_formats));
1350 match row_desc {
1351 Some(row_desc) => {
1352 self.send(row_desc).await?;
1353 Ok(State::Ready)
1354 }
1355 None => {
1356 self.error(ErrorResponse::error(
1357 SqlState::INVALID_CURSOR_NAME,
1358 format!("portal {} does not exist", name.quoted()),
1359 ))
1360 .await
1361 }
1362 }
1363 }
1364
1365 #[instrument(level = "debug")]
1366 async fn close_statement(&mut self, name: String) -> Result<State, io::Error> {
1367 self.adapter_client
1368 .session()
1369 .remove_prepared_statement(&name);
1370 self.send(BackendMessage::CloseComplete).await?;
1371 Ok(State::Ready)
1372 }
1373
1374 #[instrument(level = "debug")]
1375 async fn close_portal(&mut self, name: String) -> Result<State, io::Error> {
1376 self.adapter_client.session().remove_portal(&name);
1377 self.send(BackendMessage::CloseComplete).await?;
1378 Ok(State::Ready)
1379 }
1380
1381 fn complete_portal(&mut self, name: &str) {
1382 let portal = self
1383 .adapter_client
1384 .session()
1385 .get_portal_unverified_mut(name)
1386 .expect("portal should exist");
1387 *portal.state = PortalState::Completed(None);
1388 }
1389
1390 async fn fetch(
1391 &mut self,
1392 name: String,
1393 count: Option<FetchDirection>,
1394 max_rows: ExecuteCount,
1395 fetch_portal_name: Option<String>,
1396 timeout: ExecuteTimeout,
1397 ctx_extra: ExecuteContextExtra,
1398 ) -> Result<State, io::Error> {
1399 let count = count.unwrap_or(FetchDirection::ForwardCount(1));
1402
1403 let count = match (max_rows, count) {
1416 (ExecuteCount::Count(max_rows), FetchDirection::ForwardCount(count)) => {
1417 let count = usize::cast_from(count);
1418 if max_rows < count {
1419 let msg = "Execute with max_rows < a FETCH's count is not supported";
1420 self.adapter_client.retire_execute(
1421 ctx_extra,
1422 StatementEndedExecutionReason::Errored {
1423 error: msg.to_string(),
1424 },
1425 );
1426 return self
1427 .error(ErrorResponse::error(SqlState::FEATURE_NOT_SUPPORTED, msg))
1428 .await;
1429 }
1430 ExecuteCount::Count(count)
1431 }
1432 (ExecuteCount::Count(_), FetchDirection::ForwardAll) => {
1433 let msg = "Execute with max_rows of a FETCH ALL is not supported";
1434 self.adapter_client.retire_execute(
1435 ctx_extra,
1436 StatementEndedExecutionReason::Errored {
1437 error: msg.to_string(),
1438 },
1439 );
1440 return self
1441 .error(ErrorResponse::error(SqlState::FEATURE_NOT_SUPPORTED, msg))
1442 .await;
1443 }
1444 (ExecuteCount::All, FetchDirection::ForwardAll) => ExecuteCount::All,
1445 (ExecuteCount::All, FetchDirection::ForwardCount(count)) => {
1446 ExecuteCount::Count(usize::cast_from(count))
1447 }
1448 };
1449 let cursor_name = name.to_string();
1450 self.execute(
1451 cursor_name,
1452 count,
1453 fetch_message,
1454 fetch_portal_name,
1455 timeout,
1456 Some(ctx_extra),
1457 None,
1458 )
1459 .await
1460 }
1461
/// Flushes any output buffered on the connection and returns `State::Ready`.
async fn flush(&mut self) -> Result<State, io::Error> {
    self.conn.flush().await?;
    Ok(State::Ready)
}
1466
1467 #[instrument(level = "debug")]
1472 async fn send<M>(&mut self, message: M) -> Result<(), io::Error>
1473 where
1474 M: Into<BackendMessage>,
1475 {
1476 let message: BackendMessage = message.into();
1477 let is_error =
1478 matches!(&message, BackendMessage::ErrorResponse(e) if e.severity.is_error());
1479
1480 self.conn.send(message).await?;
1481
1482 if is_error {
1489 self.conn.flush().await?;
1490 }
1491
1492 Ok(())
1493 }
1494
1495 #[instrument(level = "debug")]
1496 pub async fn send_all(
1497 &mut self,
1498 messages: impl IntoIterator<Item = BackendMessage>,
1499 ) -> Result<(), io::Error> {
1500 for m in messages {
1501 self.send(m).await?;
1502 }
1503 Ok(())
1504 }
1505
1506 #[instrument(level = "debug")]
1507 async fn sync(&mut self) -> Result<State, io::Error> {
1508 if self.adapter_client.session().transaction().is_implicit() {
1510 self.commit_transaction().await?;
1511 }
1512 self.ready().await
1513 }
1514
1515 #[instrument(level = "debug")]
1516 async fn ready(&mut self) -> Result<State, io::Error> {
1517 let txn_state = self.adapter_client.session().transaction().into();
1518 self.send(BackendMessage::ReadyForQuery(txn_state)).await?;
1519 self.flush().await
1520 }
1521
/// Translates an adapter [`ExecuteResponse`] into pgwire backend messages.
///
/// Row-producing responses are delegated:
/// * streamed and immediate results go to `send_rows`;
/// * `SUBSCRIBE` goes to `send_rows`, or to `copy_rows` when wrapped in `COPY TO`;
/// * other `COPY TO` results go to `copy_rows`;
/// * `COPY FROM` goes to `copy_from`;
/// * `FETCH` goes to `fetch`.
///
/// Every other response replies with `CommandComplete` using `response.tag()`.
///
/// * `row_desc` – required for every row-producing response; a missing one
///   panics via `expect`.
/// * `portal_name` – the portal being executed.
/// * `max_rows` – the Execute row limit, forwarded to `send_rows` / `fetch`.
/// * `get_response` – builds the message `send_rows` sends after the rows.
/// * `fetch_portal_name` – set when rows are produced on behalf of a `FETCH`.
/// * `timeout` – how long `send_rows` may wait for rows.
/// * `execute_started` – when execution began; passed to `RecordFirstRowStream`.
///
/// For subscribes (direct or wrapped in `COPY TO`), the statement's `ctx_extra`
/// is retired here with a reason derived from how sending ended. Those paths
/// return early.
///
/// # Panics
///
/// Panics if a tag-less response reaches `command_complete!`. Also panics if a
/// response that produced a tag finishes on a non-early-returning path without
/// consuming it.
#[allow(clippy::too_many_arguments)]
#[instrument(level = "debug")]
async fn send_execute_response(
    &mut self,
    response: ExecuteResponse,
    row_desc: Option<RelationDesc>,
    portal_name: String,
    max_rows: ExecuteCount,
    get_response: GetResponse,
    fetch_portal_name: Option<String>,
    timeout: ExecuteTimeout,
    execute_started: Instant,
) -> Result<State, io::Error> {
    let mut tag = response.tag();

    // Sends `CommandComplete` with the response's tag, consuming it so the
    // assertion at the end of this function holds, and evaluates to
    // `Ok(State::Ready)`.
    macro_rules! command_complete {
        () => {{
            self.send(BackendMessage::CommandComplete {
                tag: tag
                    .take()
                    .expect("command_complete only called on tag-generating results"),
            })
            .await?;
            Ok(State::Ready)
        }};
    }

    let r = match response {
        // Cursor open/close: mark the executing portal as completed so it is not
        // run again, then acknowledge.
        ExecuteResponse::ClosedCursor => {
            self.complete_portal(&portal_name);
            command_complete!()
        }
        ExecuteResponse::DeclaredCursor => {
            self.complete_portal(&portal_name);
            command_complete!()
        }
        ExecuteResponse::EmptyQuery => {
            self.send(BackendMessage::EmptyQueryResponse).await?;
            Ok(State::Ready)
        }
        // The `timeout` bound here is the one carried by the Fetch response and
        // shadows this function's `timeout` parameter.
        ExecuteResponse::Fetch {
            name,
            count,
            timeout,
            ctx_extra,
        } => {
            self.fetch(
                name,
                count,
                max_rows,
                Some(portal_name.to_string()),
                timeout,
                ctx_extra,
            )
            .await
        }
        // Streaming results. The send-rows end reason is discarded and no
        // `ctx_extra` is retired on this path.
        ExecuteResponse::SendingRowsStreaming {
            rows,
            instance_id,
            strategy,
        } => {
            let row_desc = row_desc
                .expect("missing row description for ExecuteResponse::SendingRowsStreaming");

            let span = tracing::debug_span!("sending_rows_streaming");

            self.send_rows(
                row_desc,
                portal_name,
                InProgressRows::new(RecordFirstRowStream::new(
                    Box::new(rows),
                    execute_started,
                    &self.adapter_client,
                    Some(instance_id),
                    Some(strategy),
                )),
                max_rows,
                get_response,
                fetch_portal_name,
                timeout,
            )
            .instrument(span)
            .await
            .map(|(state, _)| state)
        }
        // Already-materialized rows are wrapped in a single-item stream so they
        // go through the same `send_rows` machinery as streamed results.
        ExecuteResponse::SendingRowsImmediate { rows } => {
            let row_desc = row_desc
                .expect("missing row description for ExecuteResponse::SendingRowsImmediate");

            let span = tracing::debug_span!("sending_rows_immediate");

            let stream =
                futures::stream::once(futures::future::ready(PeekResponseUnary::Rows(rows)));
            self.send_rows(
                row_desc,
                portal_name,
                InProgressRows::new(RecordFirstRowStream::new(
                    Box::new(stream),
                    execute_started,
                    &self.adapter_client,
                    None,
                    Some(StatementExecutionStrategy::Constant),
                )),
                max_rows,
                get_response,
                fetch_portal_name,
                timeout,
            )
            .instrument(span)
            .await
            .map(|(state, _)| state)
        }
        ExecuteResponse::SetVariable { name, .. } => {
            // If the variable is one the client is notified about, echo its new
            // value as `ParameterStatus` before completing. The message is built
            // in its own statement so the borrow of the session (through `var`)
            // ends before the `.await` below.
            let qn = name.to_string();
            let msg = if let Some(var) = self
                .adapter_client
                .session()
                .vars_mut()
                .notify_set()
                .find(|v| v.name() == qn)
            {
                Some(BackendMessage::ParameterStatus(var.name(), var.value()))
            } else {
                None
            };
            if let Some(msg) = msg {
                self.send(msg).await?;
            }
            command_complete!()
        }
        ExecuteResponse::Subscribing {
            rx,
            ctx_extra,
            instance_id,
        } => {
            // A SUBSCRIBE that is not being read through a cursor streams rows
            // indefinitely, so warn that the client must not buffer its output
            // (with a psql-specific hint), and flush the notice immediately.
            if fetch_portal_name.is_none() {
                let mut msg = ErrorResponse::notice(
                    SqlState::WARNING,
                    "streaming SUBSCRIBE rows directly requires a client that does not buffer output",
                );
                if self.adapter_client.session().vars().application_name() == "psql" {
                    msg.hint = Some(
                        "Wrap your SUBSCRIBE statement in `COPY (SUBSCRIBE ...) TO STDOUT`."
                            .into(),
                    )
                }
                self.send(msg).await?;
                self.conn.flush().await?;
            }
            let row_desc =
                row_desc.expect("missing row description for ExecuteResponse::Subscribing");
            // Map how sending ended onto the statement-logging end reason; an I/O
            // error is recorded as a cancellation.
            let (result, statement_ended_execution_reason) = match self
                .send_rows(
                    row_desc,
                    portal_name,
                    InProgressRows::new(RecordFirstRowStream::new(
                        Box::new(UnboundedReceiverStream::new(rx)),
                        execute_started,
                        &self.adapter_client,
                        Some(instance_id),
                        None,
                    )),
                    max_rows,
                    get_response,
                    fetch_portal_name,
                    timeout,
                )
                .await
            {
                Err(e) => {
                    (Err(e), StatementEndedExecutionReason::Canceled)
                }
                Ok((ok, SendRowsEndedReason::Canceled)) => {
                    (Ok(ok), StatementEndedExecutionReason::Canceled)
                }
                Ok((
                    ok,
                    SendRowsEndedReason::Success {
                        result_size,
                        rows_returned,
                    },
                )) => (
                    Ok(ok),
                    StatementEndedExecutionReason::Success {
                        result_size: Some(result_size),
                        rows_returned: Some(rows_returned),
                        execution_strategy: None,
                    },
                ),
                Ok((ok, SendRowsEndedReason::Errored { error })) => {
                    (Ok(ok), StatementEndedExecutionReason::Errored { error })
                }
            };
            self.adapter_client
                .retire_execute(ctx_extra, statement_ended_execution_reason);
            // Early return: skips the tag assertion below.
            return result;
        }
        ExecuteResponse::CopyTo { format, resp } => {
            let row_desc =
                row_desc.expect("missing row description for ExecuteResponse::CopyTo");
            // `COPY (...) TO STDOUT` wraps the response of the inner query; each
            // supported inner kind is streamed via `copy_rows`. All arms return.
            match *resp {
                ExecuteResponse::Subscribing {
                    rx,
                    ctx_extra,
                    instance_id,
                } => {
                    let (result, statement_ended_execution_reason) = match self
                        .copy_rows(
                            format,
                            row_desc,
                            RecordFirstRowStream::new(
                                Box::new(UnboundedReceiverStream::new(rx)),
                                execute_started,
                                &self.adapter_client,
                                Some(instance_id),
                                None,
                            ),
                        )
                        .await
                    {
                        Err(e) => {
                            (Err(e), StatementEndedExecutionReason::Canceled)
                        }
                        Ok((
                            state,
                            SendRowsEndedReason::Success {
                                result_size,
                                rows_returned,
                            },
                        )) => (
                            Ok(state),
                            StatementEndedExecutionReason::Success {
                                result_size: Some(result_size),
                                rows_returned: Some(rows_returned),
                                execution_strategy: None,
                            },
                        ),
                        Ok((state, SendRowsEndedReason::Errored { error })) => {
                            (Ok(state), StatementEndedExecutionReason::Errored { error })
                        }
                        Ok((state, SendRowsEndedReason::Canceled)) => {
                            (Ok(state), StatementEndedExecutionReason::Canceled)
                        }
                    };
                    self.adapter_client
                        .retire_execute(ctx_extra, statement_ended_execution_reason);
                    return result;
                }
                ExecuteResponse::SendingRowsStreaming {
                    rows,
                    instance_id,
                    strategy,
                } => {
                    return self
                        .copy_rows(
                            format,
                            row_desc,
                            RecordFirstRowStream::new(
                                Box::new(rows),
                                execute_started,
                                &self.adapter_client,
                                Some(instance_id),
                                Some(strategy),
                            ),
                        )
                        .await
                        .map(|(state, _)| state);
                }
                ExecuteResponse::SendingRowsImmediate { rows } => {
                    let span = tracing::debug_span!("sending_rows_immediate");

                    let rows = futures::stream::once(futures::future::ready(
                        PeekResponseUnary::Rows(rows),
                    ));
                    return self
                        .copy_rows(
                            format,
                            row_desc,
                            RecordFirstRowStream::new(
                                Box::new(rows),
                                execute_started,
                                &self.adapter_client,
                                None,
                                Some(StatementExecutionStrategy::Constant),
                            ),
                        )
                        .instrument(span)
                        .await
                        .map(|(state, _)| state);
                }
                _ => {
                    return self
                        .error(ErrorResponse::error(
                            SqlState::INTERNAL_ERROR,
                            "unsupported COPY response type".to_string(),
                        ))
                        .await;
                }
            };
        }
        ExecuteResponse::CopyFrom {
            id,
            columns,
            params,
            ctx_extra,
        } => {
            let row_desc =
                row_desc.expect("missing row description for ExecuteResponse::CopyFrom");
            self.copy_from(id, columns, params, row_desc, ctx_extra)
                .await
        }
        ExecuteResponse::TransactionCommitted { params }
        | ExecuteResponse::TransactionRolledBack { params } => {
            // Ending the transaction may change session variables; forward those
            // the client is notified about as `ParameterStatus` messages.
            let notify_set: mz_ore::collections::HashSet<String> = self
                .adapter_client
                .session()
                .vars()
                .notify_set()
                .map(|v| v.name().to_string())
                .collect();

            for (name, value) in params
                .into_iter()
                .filter(|(name, _v)| notify_set.contains(*name))
            {
                let msg = BackendMessage::ParameterStatus(name, value);
                self.send(msg).await?;
            }
            command_complete!()
        }

        // Responses with nothing to send besides their command tag.
        ExecuteResponse::AlteredDefaultPrivileges
        | ExecuteResponse::AlteredObject(..)
        | ExecuteResponse::AlteredRole
        | ExecuteResponse::AlteredSystemConfiguration
        | ExecuteResponse::CreatedCluster { .. }
        | ExecuteResponse::CreatedClusterReplica { .. }
        | ExecuteResponse::CreatedConnection { .. }
        | ExecuteResponse::CreatedDatabase { .. }
        | ExecuteResponse::CreatedIndex { .. }
        | ExecuteResponse::CreatedIntrospectionSubscribe
        | ExecuteResponse::CreatedMaterializedView { .. }
        | ExecuteResponse::CreatedContinualTask { .. }
        | ExecuteResponse::CreatedRole
        | ExecuteResponse::CreatedSchema { .. }
        | ExecuteResponse::CreatedSecret { .. }
        | ExecuteResponse::CreatedSink { .. }
        | ExecuteResponse::CreatedSource { .. }
        | ExecuteResponse::CreatedTable { .. }
        | ExecuteResponse::CreatedType
        | ExecuteResponse::CreatedView { .. }
        | ExecuteResponse::CreatedViews { .. }
        | ExecuteResponse::CreatedNetworkPolicy
        | ExecuteResponse::Comment
        | ExecuteResponse::Deallocate { .. }
        | ExecuteResponse::Deleted(..)
        | ExecuteResponse::DiscardedAll
        | ExecuteResponse::DiscardedTemp
        | ExecuteResponse::DroppedObject(_)
        | ExecuteResponse::DroppedOwned
        | ExecuteResponse::GrantedPrivilege
        | ExecuteResponse::GrantedRole
        | ExecuteResponse::Inserted(..)
        | ExecuteResponse::Copied(..)
        | ExecuteResponse::Prepare
        | ExecuteResponse::Raised
        | ExecuteResponse::ReassignOwned
        | ExecuteResponse::RevokedPrivilege
        | ExecuteResponse::RevokedRole
        | ExecuteResponse::StartedTransaction { .. }
        | ExecuteResponse::Updated(..)
        | ExecuteResponse::ValidatedConnection => {
            command_complete!()
        }
    };

    // Every path that reaches this point must have consumed any tag it produced.
    assert_none!(tag, "tag created but not consumed: {:?}", tag);
    r
}
1916
/// Sends rows from `rows` to the client as `DataRow` messages for the portal
/// `portal_name`, followed by the message built by `get_response`.
///
/// * Result formats come from the `fetch_portal_name` portal when set, otherwise
///   from the `portal_name` portal.
/// * At most `max_rows` rows are sent. Rows already pulled from the stream but
///   not sent are kept in `rows.current` for the next call.
/// * `timeout` controls waiting:
///   * `Seconds` sets a deadline.
///   * `WaitOnce` waits for the first non-empty batch and then uses an
///     already-expired deadline.
///   * `None` waits until the stream ends or the row limit is reached.
/// * Notices arriving while waiting are forwarded to the client.
///
/// On the normal exit path, `rows` is stored back into the portal as
/// `PortalState::InProgress`. If no rows remain, the first-to-last-row latency
/// metric is recorded.
///
/// Stream errors and cancellations are reported through [`Self::error`] and
/// returned with the matching [`SendRowsEndedReason`]. An `Err` is an I/O
/// failure, including the connection closing while waiting.
#[allow(clippy::too_many_arguments)]
#[mz_ore::instrument(level = "debug")]
async fn send_rows(
    &mut self,
    row_desc: RelationDesc,
    portal_name: String,
    mut rows: InProgressRows,
    max_rows: ExecuteCount,
    get_response: GetResponse,
    fetch_portal_name: Option<String>,
    timeout: ExecuteTimeout,
) -> Result<(State, SendRowsEndedReason), io::Error> {
    // When serving a FETCH, the result formats are those of the portal that ran
    // the FETCH, not of the cursor's portal.
    let result_format_portal_name: &str = if let Some(ref name) = fetch_portal_name {
        name
    } else {
        &portal_name
    };
    let result_formats = self
        .adapter_client
        .session()
        .get_portal_unverified(result_format_portal_name)
        .expect("valid fetch portal name for send rows")
        .result_formats
        .clone();

    // `wait_once`: block for the first non-empty batch, then stop waiting.
    // `deadline`: absolute time after which waiting yields "no more rows".
    let (mut wait_once, mut deadline) = match timeout {
        ExecuteTimeout::None => (false, None),
        ExecuteTimeout::Seconds(t) => (
            false,
            Some(tokio::time::Instant::now() + tokio::time::Duration::from_secs_f64(t)),
        ),
        ExecuteTimeout::WaitOnce => (true, None),
    };

    // Sanity check: the portals' recorded relation descriptions (when present)
    // should match the description the rows are being encoded with.
    {
        let portal_name_desc = &self
            .adapter_client
            .session()
            .get_portal_unverified(portal_name.as_str())
            .expect("portal should exist")
            .desc
            .relation_desc;
        if let Some(portal_name_desc) = portal_name_desc {
            soft_assert_eq_or_log!(portal_name_desc, &row_desc);
        }
        if let Some(fetch_portal_name) = &fetch_portal_name {
            let fetch_portal_desc = &self
                .adapter_client
                .session()
                .get_portal_unverified(fetch_portal_name)
                .expect("portal should exist")
                .desc
                .relation_desc;
            if let Some(fetch_portal_desc) = fetch_portal_desc {
                soft_assert_eq_or_log!(fetch_portal_desc, &row_desc);
            }
        }
    }

    // Pair each column's pgwire type with its result format for row encoding.
    self.conn.set_encode_state(
        row_desc
            .typ()
            .column_types
            .iter()
            .map(|ty| mz_pgrepr::Type::from(&ty.scalar_type))
            .zip_eq(result_formats)
            .collect(),
    );

    let mut total_sent_rows = 0;
    let mut total_sent_bytes = 0;
    // Remaining number of rows the client asked for.
    let mut want_rows = match max_rows {
        ExecuteCount::All => usize::MAX,
        ExecuteCount::Count(count) => count,
    };

    loop {
        // Next batch: a leftover batch from a previous call first; nothing if
        // the client wants no more rows; otherwise wait on the stream, a notice,
        // the deadline, or the connection closing.
        let batch = if rows.current.is_some() {
            FetchResult::Rows(rows.current.take())
        } else if want_rows == 0 {
            FetchResult::Rows(None)
        } else {
            let notice_fut = self.adapter_client.session().recv_notice();
            tokio::select! {
                err = self.conn.wait_closed() => return Err(err),
                // Only armed when a deadline is set; the `unwrap_or_else` just
                // supplies a value for the disabled branch's expression.
                _ = time::sleep_until(deadline.unwrap_or_else(tokio::time::Instant::now)), if deadline.is_some() => FetchResult::Rows(None),
                notice = notice_fut => {
                    FetchResult::Notice(notice)
                }
                batch = rows.remaining.recv() => match batch {
                    None => FetchResult::Rows(None),
                    Some(PeekResponseUnary::Rows(rows)) => FetchResult::Rows(Some(rows)),
                    Some(PeekResponseUnary::Error(err)) => FetchResult::Error(err),
                    Some(PeekResponseUnary::Canceled) => FetchResult::Canceled,
                },
            }
        };

        match batch {
            FetchResult::Rows(None) => break,
            FetchResult::Rows(Some(mut batch_rows)) => {
                if let Err(err) = verify_datum_desc(&row_desc, &mut batch_rows) {
                    let msg = err.to_string();
                    return self
                        .error(err.into_response(Severity::Error))
                        .await
                        .map(|state| (state, SendRowsEndedReason::Errored { error: msg }));
                }

                // With `WaitOnce`, the first non-empty batch turns later waits
                // into an immediately-expiring deadline.
                if wait_once && batch_rows.peek().is_some() {
                    deadline = Some(tokio::time::Instant::now());
                    wait_once = false;
                }

                // Send up to `want_rows` rows from this batch, counting rows and
                // bytes as they are pulled by `take`.
                let mut sent_rows = 0;
                let mut sent_bytes = 0;
                let messages = (&mut batch_rows)
                    .map(|row| {
                        let row_len = row.byte_len();
                        let values = mz_pgrepr::values_from_row(row, row_desc.typ());
                        (row_len, BackendMessage::DataRow(values))
                    })
                    .inspect(|(row_len, _)| {
                        sent_bytes += row_len;
                        sent_rows += 1
                    })
                    .map(|(_row_len, row)| row)
                    .take(want_rows);
                self.send_all(messages).await?;

                total_sent_rows += sent_rows;
                total_sent_bytes += sent_bytes;
                want_rows -= sent_rows;

                // Limit reached: stash any unsent remainder of the batch for the
                // next Execute and stop.
                if want_rows == 0 {
                    if batch_rows.peek().is_some() {
                        rows.current = Some(batch_rows);
                    }
                    break;
                }

                self.conn.flush().await?;
            }
            FetchResult::Notice(notice) => {
                self.send(notice.into_response()).await?;
                self.conn.flush().await?;
            }
            FetchResult::Error(text) => {
                return self
                    .error(ErrorResponse::error(SqlState::INTERNAL_ERROR, text.clone()))
                    .await
                    .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
            }
            FetchResult::Canceled => {
                return self
                    .error(ErrorResponse::error(
                        SqlState::QUERY_CANCELED,
                        "canceling statement due to user request",
                    ))
                    .await
                    .map(|state| (state, SendRowsEndedReason::Canceled));
            }
        }
    }

    let portal = self
        .adapter_client
        .session()
        .get_portal_unverified_mut(&portal_name)
        .expect("valid portal name for send rows");

    // Read what the metrics need before `rows` is moved into the portal.
    let saw_rows = rows.remaining.saw_rows;
    let no_more_rows = rows.no_more_rows();
    let recorded_first_row_instant = rows.remaining.recorded_first_row_instant;

    // Store the rows back even when exhausted, so the portal stays `InProgress`.
    *portal.state = PortalState::InProgress(Some(rows));

    let fetch_portal = fetch_portal_name.map(|name| {
        self.adapter_client
            .session()
            .get_portal_unverified_mut(&name)
            .expect("valid fetch portal")
    });
    let response_message = get_response(max_rows, total_sent_rows, fetch_portal);
    self.send(response_message).await?;

    // Once the result is fully consumed, record first-to-last-row latency per
    // statement type (zero when no rows were seen).
    if no_more_rows {
        let statement_type = if let Some(stmt) = &self
            .adapter_client
            .session()
            .get_portal_unverified(&portal_name)
            .expect("valid portal name for send_rows")
            .stmt
        {
            metrics::statement_type_label_value(stmt.deref())
        } else {
            "no-statement"
        };
        let duration = if saw_rows {
            recorded_first_row_instant
                .expect("recorded_first_row_instant because saw_rows")
                .elapsed()
        } else {
            Duration::ZERO
        };
        self.adapter_client
            .inner()
            .metrics()
            .result_rows_first_to_last_byte_seconds
            .with_label_values(&[statement_type])
            .observe(duration.as_secs_f64());
    }

    Ok((
        State::Ready,
        SendRowsEndedReason::Success {
            result_size: u64::cast_from(total_sent_bytes),
            rows_returned: u64::cast_from(total_sent_rows),
        },
    ))
}
2162
/// Streams rows from `stream` to the client as `COPY ... TO STDOUT`.
///
/// Sends `CopyOutResponse`, then one `CopyData` message per row encoded in
/// `format` (text, binary or CSV; Parquet is rejected with an error). It finishes
/// with `CopyDone` and `CommandComplete` tagged `COPY <n>`.
///
/// For the binary format, the file header travels with the first row's data (or
/// with the trailer if there are no rows), and the trailer is sent at the end.
/// Notices received meanwhile are forwarded to the client.
///
/// Returns the next state together with how the copy ended: success with
/// byte/row counts, error, or cancellation. An `Err` propagates a connection
/// failure (including the connection closing) or a row-encoding failure.
#[mz_ore::instrument(level = "debug")]
async fn copy_rows(
    &mut self,
    format: CopyFormat,
    row_desc: RelationDesc,
    mut stream: RecordFirstRowStream,
) -> Result<(State, SendRowsEndedReason), io::Error> {
    // Per-row encoding parameters plus the format code advertised to the client.
    let (row_format, encode_format) = match format {
        CopyFormat::Text => (
            CopyFormatParams::Text(CopyTextFormatParams::default()),
            Format::Text,
        ),
        CopyFormat::Binary => (CopyFormatParams::Binary, Format::Binary),
        CopyFormat::Csv => (
            CopyFormatParams::Csv(CopyCsvFormatParams::default()),
            Format::Text,
        ),
        CopyFormat::Parquet => {
            let text = "Parquet format is not supported".to_string();
            return self
                .error(ErrorResponse::error(SqlState::INTERNAL_ERROR, text.clone()))
                .await
                .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
        }
    };

    let encode_fn = |row: &RowRef, typ: &RelationType, out: &mut Vec<u8>| {
        mz_pgcopy::encode_copy_format(&row_format, row, typ, out)
    };

    let typ = row_desc.typ();
    let column_formats = iter::repeat(encode_format)
        .take(typ.column_types.len())
        .collect();
    self.send(BackendMessage::CopyOutResponse {
        overall_format: encode_format,
        column_formats,
    })
    .await?;

    // Scratch buffer for the encoded bytes of one `CopyData` message.
    let mut out = Vec::new();

    // Binary COPY header: 11-byte signature, 32-bit flags field (zero), and a
    // 32-bit header-extension length (zero).
    if let CopyFormat::Binary = format {
        out.extend(b"PGCOPY\n\xFF\r\n\0");
        out.extend([0, 0, 0, 0]);
        out.extend([0, 0, 0, 0]);
    }

    let mut count = 0;
    let mut total_sent_bytes = 0;
    loop {
        tokio::select! {
            e = self.conn.wait_closed() => return Err(e),
            batch = stream.recv() => match batch {
                // Stream exhausted: leave the loop and finish the copy.
                None => break,
                Some(PeekResponseUnary::Error(text)) => {
                    return self
                        .error(ErrorResponse::error(SqlState::INTERNAL_ERROR, text.clone()))
                        .await
                        .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
                }
                Some(PeekResponseUnary::Canceled) => {
                    return self.error(ErrorResponse::error(
                            SqlState::QUERY_CANCELED,
                            "canceling statement due to user request",
                        ))
                        .await.map(|state| (state, SendRowsEndedReason::Canceled));
                }
                Some(PeekResponseUnary::Rows(mut rows)) => {
                    // NOTE(review): assumes `RowIterator::count` reports the
                    // batch size without advancing it, since the rows are
                    // iterated right after.
                    count += rows.count();
                    while let Some(row) = rows.next() {
                        total_sent_bytes += row.byte_len();
                        encode_fn(row, typ, &mut out)?;
                        self.send(BackendMessage::CopyData(mem::take(&mut out)))
                            .await?;
                    }
                }
            },
            notice = self.adapter_client.session().recv_notice() => {
                self.send(notice.into_response())
                    .await?;
                self.conn.flush().await?;
            }
        }

        self.conn.flush().await?;
    }
    // Binary COPY trailer: a 16-bit field count of -1.
    if let CopyFormat::Binary = format {
        let trailer: i16 = -1;
        out.extend(trailer.to_be_bytes());
        self.send(BackendMessage::CopyData(mem::take(&mut out)))
            .await?;
    }

    let tag = format!("COPY {}", count);
    self.send(BackendMessage::CopyDone).await?;
    self.send(BackendMessage::CommandComplete { tag }).await?;
    Ok((
        State::Ready,
        SendRowsEndedReason::Success {
            result_size: u64::cast_from(total_sent_bytes),
            rows_returned: u64::cast_from(count),
        },
    ))
}
2276
2277 #[instrument(level = "debug")]
2280 async fn copy_from(
2281 &mut self,
2282 id: CatalogItemId,
2283 columns: Vec<ColumnIndex>,
2284 params: CopyFormatParams<'_>,
2285 row_desc: RelationDesc,
2286 mut ctx_extra: ExecuteContextExtra,
2287 ) -> Result<State, io::Error> {
2288 let res = self
2289 .copy_from_inner(id, columns, params, row_desc, &mut ctx_extra)
2290 .await;
2291 match &res {
2292 Ok(State::Done) => {
2293 self.adapter_client
2297 .retire_execute(ctx_extra, StatementEndedExecutionReason::Canceled);
2298 }
2299 Err(e) => {
2300 self.adapter_client.retire_execute(
2301 ctx_extra,
2302 StatementEndedExecutionReason::Errored {
2303 error: format!("{e}"),
2304 },
2305 );
2306 }
2307 other => {
2308 tracing::warn!(?other, "aborting COPY FROM");
2309 self.adapter_client
2310 .retire_execute(ctx_extra, StatementEndedExecutionReason::Aborted);
2311 }
2312 }
2313 res
2314 }
2315
2316 async fn copy_from_inner(
2317 &mut self,
2318 id: CatalogItemId,
2319 columns: Vec<ColumnIndex>,
2320 params: CopyFormatParams<'_>,
2321 row_desc: RelationDesc,
2322 ctx_extra: &mut ExecuteContextExtra,
2323 ) -> Result<State, io::Error> {
2324 let typ = row_desc.typ();
2325 let column_formats = vec![Format::Text; typ.column_types.len()];
2326 self.send(BackendMessage::CopyInResponse {
2327 overall_format: Format::Text,
2328 column_formats,
2329 })
2330 .await?;
2331 self.conn.flush().await?;
2332
2333 let system_vars = self.adapter_client.get_system_vars().await;
2334 let max_size = system_vars
2335 .get(MAX_COPY_FROM_SIZE.name())
2336 .ok()
2337 .and_then(|max_size| max_size.value().parse().ok())
2338 .unwrap_or(usize::MAX);
2339 tracing::debug!("COPY FROM max buffer size: {max_size} bytes");
2340
2341 let mut data = Vec::new();
2342 loop {
2343 let message = self.conn.recv().await?;
2344 match message {
2345 Some(FrontendMessage::CopyData(buf)) => {
2346 if (data.len() + buf.len()) > max_size {
2348 return self
2349 .error(ErrorResponse::error(
2350 SqlState::INSUFFICIENT_RESOURCES,
2351 "COPY FROM STDIN too large",
2352 ))
2353 .await;
2354 }
2355 data.extend(buf)
2356 }
2357 Some(FrontendMessage::CopyDone) => break,
2358 Some(FrontendMessage::CopyFail(err)) => {
2359 self.adapter_client.retire_execute(
2360 std::mem::take(ctx_extra),
2361 StatementEndedExecutionReason::Canceled,
2362 );
2363 return self
2364 .error(ErrorResponse::error(
2365 SqlState::QUERY_CANCELED,
2366 format!("COPY from stdin failed: {}", err),
2367 ))
2368 .await;
2369 }
2370 Some(FrontendMessage::Flush) | Some(FrontendMessage::Sync) => {}
2371 Some(_) => {
2372 let msg = "unexpected message type during COPY from stdin";
2373 self.adapter_client.retire_execute(
2374 std::mem::take(ctx_extra),
2375 StatementEndedExecutionReason::Errored {
2376 error: msg.to_string(),
2377 },
2378 );
2379 return self
2380 .error(ErrorResponse::error(SqlState::PROTOCOL_VIOLATION, msg))
2381 .await;
2382 }
2383 None => {
2384 return Ok(State::Done);
2385 }
2386 }
2387 }
2388
2389 let column_types = typ
2390 .column_types
2391 .iter()
2392 .map(|x| &x.scalar_type)
2393 .map(mz_pgrepr::Type::from)
2394 .collect::<Vec<mz_pgrepr::Type>>();
2395
2396 let rows = match mz_pgcopy::decode_copy_format(&data, &column_types, params) {
2397 Ok(rows) => rows,
2398 Err(e) => {
2399 self.adapter_client.retire_execute(
2400 std::mem::take(ctx_extra),
2401 StatementEndedExecutionReason::Errored {
2402 error: e.to_string(),
2403 },
2404 );
2405 return self
2406 .error(ErrorResponse::error(
2407 SqlState::BAD_COPY_FILE_FORMAT,
2408 format!("{}", e),
2409 ))
2410 .await;
2411 }
2412 };
2413
2414 let count = rows.len();
2415
2416 if let Err(e) = self
2417 .adapter_client
2418 .insert_rows(id, columns, rows, std::mem::take(ctx_extra))
2419 .await
2420 {
2421 self.adapter_client.retire_execute(
2422 std::mem::take(ctx_extra),
2423 StatementEndedExecutionReason::Errored {
2424 error: e.to_string(),
2425 },
2426 );
2427 return self.error(e.into_response(Severity::Error)).await;
2428 }
2429
2430 let tag = format!("COPY {}", count);
2431 self.send(BackendMessage::CommandComplete { tag }).await?;
2432
2433 Ok(State::Ready)
2434 }
2435
2436 #[instrument(level = "debug")]
2437 async fn send_pending_notices(&mut self) -> Result<(), io::Error> {
2438 let notices = self
2439 .adapter_client
2440 .session()
2441 .drain_notices()
2442 .into_iter()
2443 .map(|notice| BackendMessage::ErrorResponse(notice.into_response()));
2444 self.send_all(notices).await?;
2445 Ok(())
2446 }
2447
2448 #[instrument(level = "debug")]
2449 async fn error(&mut self, err: ErrorResponse) -> Result<State, io::Error> {
2450 assert!(err.severity.is_error());
2451 debug!(
2452 "cid={} error code={}",
2453 self.adapter_client.session().conn_id(),
2454 err.code.code()
2455 );
2456 let is_fatal = err.severity.is_fatal();
2457 self.send(BackendMessage::ErrorResponse(err)).await?;
2458
2459 let txn = self.adapter_client.session().transaction();
2460 match txn {
2461 TransactionStatus::Default | TransactionStatus::Failed(_) => {}
2464 TransactionStatus::Started(_) => {
2466 self.rollback_transaction().await?;
2467 }
2468 TransactionStatus::InTransactionImplicit(_) => {
2470 self.rollback_transaction().await?;
2471 }
2472 TransactionStatus::InTransaction(_) => {
2474 self.adapter_client.fail_transaction();
2475 }
2476 };
2477 if is_fatal {
2478 Ok(State::Done)
2479 } else {
2480 Ok(State::Drain)
2481 }
2482 }
2483
2484 #[instrument(level = "debug")]
2485 async fn aborted_txn_error(&mut self) -> Result<State, io::Error> {
2486 self.send(BackendMessage::ErrorResponse(ErrorResponse::error(
2487 SqlState::IN_FAILED_SQL_TRANSACTION,
2488 ABORTED_TXN_MSG,
2489 )))
2490 .await?;
2491 Ok(State::Drain)
2492 }
2493
2494 fn is_aborted_txn(&mut self) -> bool {
2495 matches!(
2496 self.adapter_client.session().transaction(),
2497 TransactionStatus::Failed(_)
2498 )
2499 }
2500}
2501
2502fn pad_formats(formats: Vec<Format>, n: usize) -> Result<Vec<Format>, String> {
2503 match (formats.len(), n) {
2504 (0, e) => Ok(vec![Format::Text; e]),
2505 (1, e) => Ok(iter::repeat(formats[0]).take(e).collect()),
2506 (a, e) if a == e => Ok(formats),
2507 (a, e) => Err(format!(
2508 "expected {} field format specifiers, but got {}",
2509 e, a
2510 )),
2511 }
2512}
2513
2514fn describe_rows(stmt_desc: &StatementDesc, formats: &[Format]) -> BackendMessage {
2515 match &stmt_desc.relation_desc {
2516 Some(desc) if !stmt_desc.is_copy => {
2517 BackendMessage::RowDescription(message::encode_row_description(desc, formats))
2518 }
2519 _ => BackendMessage::NoData,
2520 }
2521}
2522
/// Builds the message `send_rows` sends after it has streamed rows.
///
/// Receives the Execute row limit, the total number of rows sent, and, when the
/// rows were produced for a `FETCH`, the fetch portal. `portal_exec_message` and
/// `fetch_message` in this file have this shape.
type GetResponse = fn(
    max_rows: ExecuteCount,
    total_sent_rows: usize,
    fetch_portal: Option<PortalRefMut>,
) -> BackendMessage;
2528
2529fn portal_exec_message(
2532 max_rows: ExecuteCount,
2533 total_sent_rows: usize,
2534 _fetch_portal: Option<PortalRefMut>,
2535) -> BackendMessage {
2536 match max_rows {
2543 ExecuteCount::Count(max_rows) if max_rows <= total_sent_rows => {
2544 BackendMessage::PortalSuspended
2545 }
2546 _ => BackendMessage::CommandComplete {
2547 tag: format!("SELECT {}", total_sent_rows),
2548 },
2549 }
2550}
2551
2552fn fetch_message(
2554 _max_rows: ExecuteCount,
2555 total_sent_rows: usize,
2556 fetch_portal: Option<PortalRefMut>,
2557) -> BackendMessage {
2558 let tag = format!("FETCH {}", total_sent_rows);
2559 if let Some(portal) = fetch_portal {
2560 *portal.state = PortalState::Completed(Some(tag.clone()));
2561 }
2562 BackendMessage::CommandComplete { tag }
2563}
2564
/// Row budget for a request that streams rows from a portal.
#[derive(Debug, Copy, Clone)]
enum ExecuteCount {
    /// No row limit: `portal_exec_message` always reports the command complete.
    All,
    /// A row limit: `portal_exec_message` reports `PortalSuspended` once the number of
    /// rows sent reaches this value.
    Count(usize),
}
2570
2571fn is_txn_exit_stmt(stmt: Option<&Statement<Raw>>) -> bool {
2573 match stmt {
2574 Some(stmt) => matches!(stmt, Statement::Commit(_) | Statement::Rollback(_)),
2576 None => false,
2577 }
2578}
2579
/// Outcome of one attempt to obtain results.
///
/// NOTE(review): producers and consumers of this type are outside this view; the variant
/// descriptions below are inferred from their payloads — confirm against callers.
#[derive(Debug)]
enum FetchResult {
    /// Rows to send. What `None` signifies (e.g. no batch available) is not shown here.
    Rows(Option<Box<dyn RowIterator + Send + Sync>>),
    /// The operation was canceled.
    Canceled,
    /// The operation failed; the string is presumably the error message to report.
    Error(String),
    /// A notice raised by the adapter, presumably to be relayed to the client.
    Notice(AdapterNotice),
}
2587
#[cfg(test)]
mod test {
    use super::*;

    #[mz_ore::test]
    fn test_parse_options() {
        struct TestCase {
            input: &'static str,
            expect: Result<Vec<(&'static str, &'static str)>, ()>,
        }
        let tests = vec![
            TestCase {
                input: "",
                expect: Ok(vec![]),
            },
            TestCase {
                input: "--key",
                expect: Err(()),
            },
            TestCase {
                input: "--key=val",
                expect: Ok(vec![("key", "val")]),
            },
            TestCase {
                input: r#"--key=val -ckey2=val2 -c key3=val3 -c key4=val4 -ckey5=val5"#,
                expect: Ok(vec![
                    ("key", "val"),
                    ("key2", "val2"),
                    ("key3", "val3"),
                    ("key4", "val4"),
                    ("key5", "val5"),
                ]),
            },
            TestCase {
                input: r#"-c\ key=val"#,
                expect: Ok(vec![(" key", "val")]),
            },
            TestCase {
                input: "--key=val -ckey2 val2",
                expect: Err(()),
            },
            TestCase {
                // An empty value is accepted.
                input: "--key=",
                expect: Ok(vec![("key", "")]),
            },
        ];
        for test in tests {
            let got = parse_options(test.input);
            let expect = test.expect.map(|r| {
                r.into_iter()
                    .map(|(k, v)| (k.to_owned(), v.to_owned()))
                    .collect()
            });
            assert_eq!(got, expect, "input: {}", test.input);
        }
    }

    #[mz_ore::test]
    fn test_parse_option() {
        struct TestCase {
            input: &'static str,
            expect: Result<(&'static str, &'static str), ()>,
        }
        let tests = vec![
            TestCase {
                input: "",
                expect: Err(()),
            },
            TestCase {
                input: "--",
                expect: Err(()),
            },
            TestCase {
                input: "--c",
                expect: Err(()),
            },
            TestCase {
                input: "a=b",
                expect: Err(()),
            },
            TestCase {
                input: "--a=b",
                expect: Ok(("a", "b")),
            },
            TestCase {
                input: "--ca=b",
                expect: Ok(("ca", "b")),
            },
            TestCase {
                input: "-ca=b",
                expect: Ok(("a", "b")),
            },
            TestCase {
                // Both the key and the value may be empty.
                input: "--=",
                expect: Ok(("", "")),
            },
        ];
        for test in tests {
            let got = parse_option(test.input);
            assert_eq!(got, test.expect, "input: {}", test.input);
        }
    }

    #[mz_ore::test]
    fn test_split_options() {
        struct TestCase {
            input: &'static str,
            expect: Vec<&'static str>,
        }
        let tests = vec![
            TestCase {
                input: "",
                expect: vec![],
            },
            TestCase {
                input: " ",
                expect: vec![],
            },
            TestCase {
                input: " a ",
                expect: vec!["a"],
            },
            TestCase {
                input: " ab cd ",
                expect: vec!["ab", "cd"],
            },
            TestCase {
                // An escaped space stays in the word; the unescaped space after it
                // separates words.
                input: r#" ab\  cd "#,
                expect: vec!["ab ", "cd"],
            },
            TestCase {
                // `\\` is a literal backslash, so the following space separates words.
                input: r#" ab\\ cd "#,
                expect: vec![r#"ab\"#, "cd"],
            },
            TestCase {
                // Literal backslash, then an escaped space, then a separating space.
                input: r#" ab\\\  cd "#,
                expect: vec![r#"ab\ "#, "cd"],
            },
            TestCase {
                // Literal backslash followed by an escaped space: a single word.
                input: r#" ab\\\ cd "#,
                expect: vec![r#"ab\ cd"#],
            },
            TestCase {
                // A backslash before an ordinary character yields that character.
                input: r#" ab\\\cd "#,
                expect: vec![r#"ab\cd"#],
            },
            TestCase {
                // A trailing lone backslash is dropped.
                input: r#"a\"#,
                expect: vec!["a"],
            },
            TestCase {
                input: r#"a\ "#,
                expect: vec!["a "],
            },
            TestCase {
                input: r#"\"#,
                expect: vec![],
            },
            TestCase {
                input: r#"\ "#,
                expect: vec![r#" "#],
            },
            TestCase {
                input: r#" \ "#,
                expect: vec![r#" "#],
            },
            TestCase {
                // An escaped space followed by a separating space.
                input: r#"\  "#,
                expect: vec![r#" "#],
            },
        ];
        for test in tests {
            let got = split_options(test.input);
            assert_eq!(got, test.expect, "input: {}", test.input);
        }
    }
}