1use std::collections::BTreeMap;
11use std::convert::TryFrom;
12use std::future::Future;
13use std::ops::Deref;
14use std::sync::Arc;
15use std::time::{Duration, Instant};
16use std::{iter, mem};
17
18use base64::prelude::*;
19use byteorder::{ByteOrder, NetworkEndian};
20use futures::future::{BoxFuture, FutureExt, pending};
21use itertools::Itertools;
22use mz_adapter::client::RecordFirstRowStream;
23use mz_adapter::session::{
24 EndTransactionAction, InProgressRows, LifecycleTimestamps, PortalRefMut, PortalState,
25 SessionConfig, TransactionStatus,
26};
27use mz_adapter::statement_logging::{StatementEndedExecutionReason, StatementExecutionStrategy};
28use mz_adapter::{
29 AdapterError, AdapterNotice, ExecuteContextExtra, ExecuteResponse, PeekResponseUnary, metrics,
30 verify_datum_desc,
31};
32use mz_auth::password::Password;
33use mz_authenticator::Authenticator;
34use mz_ore::cast::CastFrom;
35use mz_ore::netio::AsyncReady;
36use mz_ore::now::{EpochMillis, SYSTEM_TIME};
37use mz_ore::str::StrExt;
38use mz_ore::{assert_none, assert_ok, instrument, soft_assert_eq_or_log};
39use mz_pgcopy::{CopyCsvFormatParams, CopyFormatParams, CopyTextFormatParams};
40use mz_pgwire_common::{
41 ConnectionCounter, Cursor, ErrorResponse, Format, FrontendMessage, Severity, VERSION_3,
42 VERSIONS,
43};
44use mz_repr::user::InternalUserMetadata;
45use mz_repr::{
46 CatalogItemId, ColumnIndex, Datum, RelationDesc, RowArena, RowIterator, RowRef,
47 SqlRelationType, SqlScalarType,
48};
49use mz_server_core::TlsMode;
50use mz_server_core::listeners::AllowedRoles;
51use mz_sql::ast::display::AstDisplay;
52use mz_sql::ast::{CopyDirection, CopyStatement, FetchDirection, Ident, Raw, Statement};
53use mz_sql::parse::StatementParseResult;
54use mz_sql::plan::{CopyFormat, ExecuteTimeout, StatementDesc};
55use mz_sql::session::metadata::SessionMetadata;
56use mz_sql::session::user::INTERNAL_USER_NAMES;
57use mz_sql::session::vars::{MAX_COPY_FROM_SIZE, Var, VarInput};
58use postgres::error::SqlState;
59use tokio::io::{self, AsyncRead, AsyncWrite};
60use tokio::select;
61use tokio::time::{self};
62use tokio_metrics::TaskMetrics;
63use tokio_stream::wrappers::UnboundedReceiverStream;
64use tracing::{Instrument, debug, debug_span, warn};
65use uuid::Uuid;
66
67use crate::codec::{
68 FramedConn, decode_password, decode_sasl_initial_response, decode_sasl_response,
69};
70use crate::message::{
71 self, BackendMessage, SASLServerFinalMessage, SASLServerFinalMessageKinds,
72 SASLServerFirstMessage,
73};
74
75pub fn match_handshake(buf: &[u8]) -> bool {
79 if buf.len() < 8 {
89 return false;
90 }
91 let version = NetworkEndian::read_i32(&buf[4..8]);
92 VERSIONS.contains(&version)
93}
94
/// Everything [`run`] needs to serve one client connection. `version` and `params` come from
/// the client's startup message.
pub struct RunParams<'a, A, I>
where
    I: Iterator<Item = TaskMetrics> + Send,
{
    /// TLS requirement the connection is checked against (`ensure_tls_compatibility`).
    pub tls_mode: Option<TlsMode>,
    /// Adapter client used to read system variables and start up the session (and to create
    /// the session for Frontegg and unauthenticated logins).
    pub adapter_client: mz_adapter::Client,
    /// The framed connection to the client.
    pub conn: &'a mut FramedConn<A>,
    /// Identifier for this connection; becomes the session's `uuid`.
    pub conn_uuid: Uuid,
    /// Protocol version requested by the client; only `VERSION_3` is accepted.
    pub version: i32,
    /// Startup parameters. `user` names the role to log in as, `options` holds command-line
    /// style settings, and every other entry is applied as a session variable.
    pub params: BTreeMap<String, String>,
    /// How the client must authenticate.
    pub authenticator: Authenticator,
    /// Connection limiter; a slot is allocated for the session's user during startup.
    pub active_connection_counter: ConnectionCounter,
    /// Helm chart version recorded in the session configuration, if any.
    pub helm_chart_version: Option<String>,
    /// Which kinds of roles (normal and/or internal) may log in through this listener.
    pub allowed_roles: AllowedRoles,
    /// Per-interval tokio task metrics; must be an infinite iterator.
    pub tokio_metrics_intervals: I,
}
123
/// Serves one client connection whose startup message has already been read.
///
/// In order, this:
/// 1. rejects protocol versions other than `VERSION_3`;
/// 2. takes the login role from the `user` startup parameter (empty if absent) and checks it
///    against `allowed_roles`;
/// 3. checks the connection against `tls_mode`;
/// 4. authenticates according to `authenticator` (cleartext password via Frontegg or the
///    adapter, SCRAM-SHA-256 over SASL, or none) and creates a session;
/// 5. applies the remaining startup parameters as session variables, reporting unusable ones
///    as notices instead of failing the login;
/// 6. allocates a slot in `active_connection_counter` and starts the session in the adapter;
/// 7. sends `AuthenticationOk`, `ParameterStatus` messages, `BackendKeyData`, queued notices,
///    and `ReadyForQuery`;
/// 8. drives the [`StateMachine`] until it finishes or, for Frontegg logins, the
///    authentication session expires.
///
/// Any rejection in steps 1-6 is sent to the client as a fatal `ErrorResponse`, and the result
/// of that send is returned.
///
/// # Errors
///
/// Returns the `io::Error` from a failed read or write on the connection. An error from the
/// state machine is also reported to the client, as `CONNECTION_FAILURE`, on a best-effort
/// basis before it is returned.
#[mz_ore::instrument(level = "debug")]
pub async fn run<'a, A, I>(
    RunParams {
        tls_mode,
        adapter_client,
        conn,
        conn_uuid,
        version,
        mut params,
        authenticator,
        active_connection_counter,
        helm_chart_version,
        allowed_roles,
        tokio_metrics_intervals,
    }: RunParams<'a, A, I>,
) -> Result<(), io::Error>
where
    A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
    I: Iterator<Item = TaskMetrics> + Send,
{
    // Only protocol version 3 is supported.
    if version != VERSION_3 {
        return conn
            .send(ErrorResponse::fatal(
                SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
                "server does not support the client's requested protocol version",
            ))
            .await;
    }

    // A missing `user` startup parameter is treated as an empty user name.
    let user = params.remove("user").unwrap_or_else(String::new);

    // Decide whether this listener admits the requested role. Normal listeners refuse
    // reserved roles. Internal listeners accept only internal users. Mixed listeners accept
    // either.
    let is_internal_user = INTERNAL_USER_NAMES.contains(&user);
    let is_reserved_user = mz_adapter::catalog::is_reserved_role_name(user.as_str());
    let role_allowed = match allowed_roles {
        AllowedRoles::Normal => !is_reserved_user,
        AllowedRoles::Internal => is_internal_user,
        AllowedRoles::NormalAndInternal => !is_reserved_user || is_internal_user,
    };
    if !role_allowed {
        let msg = format!("unauthorized login to user '{user}'");
        return conn
            .send(ErrorResponse::fatal(SqlState::INSUFFICIENT_PRIVILEGE, msg))
            .await;
    }

    // Reject the connection if it does not satisfy the configured TLS mode.
    if let Err(err) = conn.inner().ensure_tls_compatibility(&tls_mode) {
        return conn.send(err).await;
    }

    // Authenticate and build the session. `expired` completes when the authentication stops
    // being valid. Only Frontegg logins can expire; the other arms use `pending()`, which
    // never completes. `left_future`/`right_future` unify the two future types.
    let (mut session, expired) = match authenticator {
        Authenticator::Frontegg(frontegg) => {
            // Request a cleartext password and verify it with Frontegg.
            conn.send(BackendMessage::AuthenticationCleartextPassword)
                .await?;
            conn.flush().await?;
            let password = match conn.recv().await? {
                Some(FrontendMessage::RawAuthentication(data)) => {
                    match decode_password(Cursor::new(&data)).ok() {
                        Some(FrontendMessage::Password { password }) => password,
                        _ => {
                            return conn
                                .send(ErrorResponse::fatal(
                                    SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                                    "expected Password message",
                                ))
                                .await;
                        }
                    }
                }
                _ => {
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                            "expected Password message",
                        ))
                        .await;
                }
            };

            let auth_response = frontegg.authenticate(&user, &password).await;
            match auth_response {
                Ok(mut auth_session) => {
                    // The session's user name comes from the Frontegg auth session, not from
                    // the `user` startup parameter.
                    let session = adapter_client.new_session(SessionConfig {
                        conn_id: conn.conn_id().clone(),
                        uuid: conn_uuid,
                        user: auth_session.user().into(),
                        client_ip: conn.peer_addr().clone(),
                        external_metadata_rx: Some(auth_session.external_metadata_rx()),
                        internal_user_metadata: None,
                        helm_chart_version,
                    });
                    let expired = async move { auth_session.expired().await };
                    (session, expired.left_future())
                }
                Err(err) => {
                    warn!(?err, "pgwire connection failed authentication");
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_PASSWORD,
                            "invalid password",
                        ))
                        .await;
                }
            }
        }
        Authenticator::Password(adapter_client) => {
            // Request a cleartext password and verify it with the adapter. In this arm
            // `adapter_client` is the authenticator's client, which shadows the one from
            // `RunParams`.
            conn.send(BackendMessage::AuthenticationCleartextPassword)
                .await?;
            conn.flush().await?;
            let password = match conn.recv().await? {
                Some(FrontendMessage::RawAuthentication(data)) => {
                    match decode_password(Cursor::new(&data)).ok() {
                        Some(FrontendMessage::Password { password }) => Password(password),
                        _ => {
                            return conn
                                .send(ErrorResponse::fatal(
                                    SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                                    "expected Password message",
                                ))
                                .await;
                        }
                    }
                }
                _ => {
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                            "expected Password message",
                        ))
                        .await;
                }
            };
            let auth_response = match adapter_client.authenticate(&user, &password).await {
                Ok(resp) => resp,
                Err(err) => {
                    warn!(?err, "pgwire connection failed authentication");
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_PASSWORD,
                            "invalid password",
                        ))
                        .await;
                }
            };
            let session = adapter_client.new_session(SessionConfig {
                conn_id: conn.conn_id().clone(),
                uuid: conn_uuid,
                user,
                client_ip: conn.peer_addr().clone(),
                external_metadata_rx: None,
                internal_user_metadata: Some(InternalUserMetadata {
                    superuser: auth_response.superuser,
                }),
                helm_chart_version,
            });
            // Password logins never expire.
            let auth_session = pending().right_future();
            (session, auth_session)
        }
        Authenticator::Sasl(adapter_client) => {
            // SCRAM-SHA-256 over SASL. As in the `Password` arm, `adapter_client` here is the
            // authenticator's client.
            conn.send(BackendMessage::AuthenticationSASL).await?;
            conn.flush().await?;
            let (mechanism, initial_response) = match conn.recv().await? {
                Some(FrontendMessage::RawAuthentication(data)) => {
                    match decode_sasl_initial_response(Cursor::new(&data)).ok() {
                        Some(FrontendMessage::SASLInitialResponse {
                            gs2_header,
                            mechanism,
                            initial_response,
                        }) => {
                            // Channel binding is not supported, so refuse clients that enable it.
                            if gs2_header.channel_binding_enabled() {
                                return conn
                                    .send(ErrorResponse::fatal(
                                        SqlState::PROTOCOL_VIOLATION,
                                        "channel binding not supported",
                                    ))
                                    .await;
                            }
                            (mechanism, initial_response)
                        }
                        _ => {
                            return conn
                                .send(ErrorResponse::fatal(
                                    SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                                    "expected SASLInitialResponse message",
                                ))
                                .await;
                        }
                    }
                }
                _ => {
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                            "expected SASLInitialResponse message",
                        ))
                        .await;
                }
            };

            if mechanism != "SCRAM-SHA-256" {
                return conn
                    .send(ErrorResponse::fatal(
                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                        "unsupported SASL mechanism",
                    ))
                    .await;
            }

            // Bound the size of the client-supplied nonce before using it.
            if initial_response.nonce.len() > 256 {
                return conn
                    .send(ErrorResponse::fatal(
                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                        "nonce too long",
                    ))
                    .await;
            }

            let (server_first_message_raw, mock_hash) = match adapter_client
                .generate_sasl_challenge(&user, &initial_response.nonce)
                .await
            {
                Ok(response) => {
                    // Keep the exact server-first-message text, because it is part of the
                    // SCRAM auth message assembled below.
                    let server_first_message_raw = format!(
                        "r={},s={},i={}",
                        response.nonce, response.salt, response.iteration_count
                    );

                    // NOTE(review): this is a SCRAM hash string with fixed placeholder keys,
                    // built from the challenge's salt and iteration count and passed to
                    // `verify_sasl_proof`. It is presumably used when the user has no stored
                    // hash; confirm.
                    let client_key = [0u8; 32];
                    let server_key = [1u8; 32];
                    let mock_hash = format!(
                        "SCRAM-SHA-256${}:{}${}:{}",
                        response.iteration_count,
                        response.salt,
                        BASE64_STANDARD.encode(client_key),
                        BASE64_STANDARD.encode(server_key)
                    );

                    conn.send(BackendMessage::AuthenticationSASLContinue(
                        SASLServerFirstMessage {
                            iteration_count: response.iteration_count,
                            nonce: response.nonce,
                            salt: response.salt,
                        },
                    ))
                    .await?;
                    conn.flush().await?;
                    (server_first_message_raw, mock_hash)
                }
                Err(e) => {
                    return conn.send(e.into_response(Severity::Fatal)).await;
                }
            };

            let auth_resp = match conn.recv().await? {
                Some(FrontendMessage::RawAuthentication(data)) => {
                    match decode_sasl_response(Cursor::new(&data)).ok() {
                        Some(FrontendMessage::SASLResponse(response)) => {
                            // SCRAM auth message: client-first-message-bare,
                            // server-first-message and the client final message without its
                            // proof, joined by commas.
                            let auth_message = format!(
                                "{},{},{}",
                                initial_response.client_first_message_bare_raw,
                                server_first_message_raw,
                                response.client_final_message_bare_raw
                            );
                            // Bound the size of the client-supplied proof before verifying it.
                            if response.proof.len() > 1024 {
                                return conn
                                    .send(ErrorResponse::fatal(
                                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                                        "proof too long",
                                    ))
                                    .await;
                            }
                            match adapter_client
                                .verify_sasl_proof(
                                    &user,
                                    &response.proof,
                                    &auth_message,
                                    &mock_hash,
                                )
                                .await
                            {
                                Ok(resp) => {
                                    conn.send(BackendMessage::AuthenticationSASLFinal(
                                        SASLServerFinalMessage {
                                            kind: SASLServerFinalMessageKinds::Verifier(
                                                resp.verifier,
                                            ),
                                            extensions: vec![],
                                        },
                                    ))
                                    .await?;
                                    conn.flush().await?;
                                    resp.auth_resp
                                }
                                // The verification error is not forwarded to the client.
                                Err(_) => {
                                    return conn
                                        .send(ErrorResponse::fatal(
                                            SqlState::INVALID_PASSWORD,
                                            "invalid password",
                                        ))
                                        .await;
                                }
                            }
                        }
                        _ => {
                            return conn
                                .send(ErrorResponse::fatal(
                                    SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                                    "expected SASLResponse message",
                                ))
                                .await;
                        }
                    }
                }
                _ => {
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                            "expected SASLResponse message",
                        ))
                        .await;
                }
            };

            let session = adapter_client.new_session(SessionConfig {
                conn_id: conn.conn_id().clone(),
                uuid: conn_uuid,
                user,
                client_ip: conn.peer_addr().clone(),
                external_metadata_rx: None,
                internal_user_metadata: Some(InternalUserMetadata {
                    superuser: auth_resp.superuser,
                }),
                helm_chart_version,
            });
            // SASL logins never expire.
            let auth_session = pending().right_future();
            (session, auth_session)
        }
        Authenticator::None => {
            // No authentication: the user name from the startup parameters is used as given.
            let session = adapter_client.new_session(SessionConfig {
                conn_id: conn.conn_id().clone(),
                uuid: conn_uuid,
                user,
                client_ip: conn.peer_addr().clone(),
                external_metadata_rx: None,
                internal_user_metadata: None,
                helm_chart_version,
            });
            let auth_session = pending().right_future();
            (session, auth_session)
        }
    };

    // Apply the remaining startup parameters as session variables. `options` holds
    // command-line style settings (see `parse_options`). Settings that cannot be parsed or
    // applied become notices for the client rather than failing the connection.
    let system_vars = adapter_client.get_system_vars().await;
    for (name, value) in params {
        let settings = match name.as_str() {
            "options" => match parse_options(&value) {
                Ok(opts) => opts,
                Err(()) => {
                    session.add_notice(AdapterNotice::BadStartupSetting {
                        name,
                        reason: "could not parse".into(),
                    });
                    continue;
                }
            },
            _ => vec![(name, value)],
        };
        for (key, val) in settings {
            // Startup settings are applied as non-local values.
            const LOCAL: bool = false;
            if let Err(err) =
                session
                    .vars_mut()
                    .set(&system_vars, &key, VarInput::Flat(&val), LOCAL)
            {
                session.add_notice(AdapterNotice::BadStartupSetting {
                    name: key,
                    reason: err.to_string(),
                });
            }
        }
    }
    // Commit the variable changes made above.
    session
        .vars_mut()
        .end_transaction(EndTransactionAction::Commit);

    // Reserve a connection slot for the user. `_guard` is a named binding (not `_`), so it
    // stays alive until `run` returns.
    let _guard = match active_connection_counter.allocate_connection(session.user()) {
        Ok(drop_connection) => drop_connection,
        Err(e) => {
            let e: AdapterError = e.into();
            return conn.send(e.into_response(Severity::Fatal)).await;
        }
    };

    // Register the session with the adapter. This yields the per-session client the state
    // machine uses.
    let mut adapter_client = match adapter_client.startup(session).await {
        Ok(adapter_client) => adapter_client,
        Err(e) => return conn.send(e.into_response(Severity::Fatal)).await,
    };

    // Send the startup response in one batch: AuthenticationOk, the parameters that must be
    // reported to the client, the cancellation key, queued notices, and ReadyForQuery.
    let mut buf = vec![BackendMessage::AuthenticationOk];
    for var in adapter_client.session().vars().notify_set() {
        buf.push(BackendMessage::ParameterStatus(var.name(), var.value()));
    }
    buf.push(BackendMessage::BackendKeyData {
        conn_id: adapter_client.session().conn_id().unhandled(),
        secret_key: adapter_client.session().secret_key(),
    });
    buf.extend(
        adapter_client
            .session()
            .drain_notices()
            .into_iter()
            .map(|notice| BackendMessage::ErrorResponse(notice.into_response())),
    );
    buf.push(BackendMessage::ReadyForQuery(
        adapter_client.session().transaction().into(),
    ));
    conn.send_all(buf).await?;
    conn.flush().await?;

    let machine = StateMachine {
        conn,
        adapter_client,
        txn_needs_commit: false,
        tokio_metrics_intervals,
    };

    select! {
        r = machine.run() => {
            // Tell the client why the connection is closing. The send and flush results are
            // ignored; the state machine's own error is what gets returned.
            if let Err(err) = &r {
                let _ = conn
                    .send(ErrorResponse::fatal(
                        SqlState::CONNECTION_FAILURE,
                        err.to_string(),
                    ))
                    .await;
                let _ = conn.flush().await;
            }
            r
        },
        // The authentication session expired (Frontegg only), so close with a fatal error.
        _ = expired => {
            conn
                .send(ErrorResponse::fatal(SqlState::INVALID_AUTHORIZATION_SPECIFICATION, "authentication expired"))
                .await?;
            conn.flush().await
        }
    }
}
603
604fn parse_options(value: &str) -> Result<Vec<(String, String)>, ()> {
609 let opts = split_options(value);
610 let mut pairs = Vec::with_capacity(opts.len());
611 let mut seen_prefix = false;
612 for opt in opts {
613 if !seen_prefix {
614 if opt == "-c" {
615 seen_prefix = true;
616 } else {
617 let (key, val) = parse_option(&opt)?;
618 pairs.push((key.to_owned(), val.to_owned()));
619 }
620 } else {
621 let (key, val) = opt.split_once('=').ok_or(())?;
622 pairs.push((key.to_owned(), val.to_owned()));
623 seen_prefix = false;
624 }
625 }
626 Ok(pairs)
627}
628
/// Parses a single `-cname=value` or `--name=value` token into `(name, value)`.
///
/// The value is everything after the first `=`. Tokens without an `=`, or whose part before
/// the `=` starts with neither `-c` nor `--`, yield `Err(())`.
fn parse_option(option: &str) -> Result<(&str, &str), ()> {
    let (raw_key, value) = option.split_once('=').ok_or(())?;
    ["-c", "--"]
        .iter()
        .find_map(|prefix| raw_key.strip_prefix(*prefix))
        .map(|key| (key, value))
        .ok_or(())
}
641
/// Splits `value` into tokens separated by runs of unescaped spaces.
///
/// A backslash makes the character after it literal: `\ ` is a space inside a token and
/// `\\` is a single backslash. A lone backslash at the very end is dropped. Empty tokens are
/// never produced.
fn split_options(value: &str) -> Vec<String> {
    let mut tokens = Vec::new();
    let mut token = String::new();
    let mut escaped = false;
    for ch in value.chars() {
        if escaped {
            // Whatever follows a backslash is taken literally.
            escaped = false;
            token.push(ch);
            continue;
        }
        match ch {
            '\\' => escaped = true,
            ' ' => {
                // Consecutive spaces would otherwise create empty tokens.
                if !token.is_empty() {
                    tokens.push(std::mem::take(&mut token));
                }
            }
            other => token.push(other),
        }
    }
    if !token.is_empty() {
        tokens.push(token);
    }
    tokens
}
684
/// Where the connection's [`StateMachine`] is between frontend messages.
#[derive(Debug)]
enum State {
    /// Waiting for the next frontend message, which is processed normally.
    Ready,
    /// Discarding frontend messages until the next `Sync` (or until the connection closes).
    Drain,
    /// The session is over; the state machine's run loop returns.
    Done,
}
691
/// Drives a started-up session: reads frontend messages from `conn` and carries them out
/// through `adapter_client`.
struct StateMachine<'a, A, I>
where
    I: Iterator<Item = TaskMetrics> + Send + 'a,
{
    /// The client connection.
    conn: &'a mut FramedConn<A>,
    /// The adapter client for this connection's session.
    adapter_client: mz_adapter::SessionClient,
    /// Set after an extended-protocol `Execute` runs in an implicit transaction. The next
    /// `ensure_transaction` then commits that transaction before starting another, and
    /// `end_transaction` clears the flag.
    txn_needs_commit: bool,
    /// Tokio task metrics intervals, advanced before and after each `recv` to measure
    /// scheduling delay. Expected to be infinite.
    tokio_metrics_intervals: I,
}
701
/// How sending a result's rows to the client ended.
///
/// NOTE(review): this type is used outside the visible code; the variant descriptions below
/// follow the names and should be confirmed against the call sites.
enum SendRowsEndedReason {
    /// Row sending completed.
    Success {
        /// Size of the result that was sent (units not visible here).
        result_size: u64,
        /// Number of rows returned to the client.
        rows_returned: u64,
    },
    /// Row sending stopped because of an error, described by `error`.
    Errored {
        error: String,
    },
    /// Row sending was canceled.
    Canceled,
}
712
/// Error text reported when a statement that does not end the transaction is issued while
/// the current transaction is aborted.
const ABORTED_TXN_MSG: &str =
    "current transaction is aborted, commands ignored until end of transaction block";
715
716impl<'a, A, I> StateMachine<'a, A, I>
717where
718 A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin + 'a,
719 I: Iterator<Item = TaskMetrics> + Send + 'a,
720{
721 #[allow(clippy::manual_async_fn)]
725 #[mz_ore::instrument(level = "debug")]
726 fn run(mut self) -> impl Future<Output = Result<(), io::Error>> + Send + 'a {
727 async move {
728 let mut state = State::Ready;
729 loop {
730 self.send_pending_notices().await?;
731 state = match state {
732 State::Ready => self.advance_ready().await?,
733 State::Drain => self.advance_drain().await?,
734 State::Done => return Ok(()),
735 };
736 self.adapter_client
737 .add_idle_in_transaction_session_timeout();
738 }
739 }
740 }
741
    /// Waits for and handles the next frontend message in the `Ready` state, returning the
    /// state to move to.
    ///
    /// A pending session timeout takes priority over reading a message. The timeout is
    /// reported as a fatal error, the session is terminated, and one more client message is
    /// read before the error's resulting state is returned. For every processed message, the
    /// processing time and the tokio scheduling delay around `recv` are recorded, labeled by
    /// message type.
    #[instrument(level = "debug")]
    async fn advance_ready(&mut self) -> Result<State, io::Error> {
        // Advance the metrics interval before receiving. The interval taken after the message
        // arrives is used for the receive scheduling delay below.
        self.tokio_metrics_intervals
            .next()
            .expect("infinite iterator");

        // `biased` polls the branches in order, so a pending timeout is handled before any
        // newly arrived message.
        let message = select! {
            biased;

            Some(timeout) = self.adapter_client.recv_timeout() => {
                let err: AdapterError = timeout.into();
                let conn_id = self.adapter_client.session().conn_id();
                tracing::warn!("session timed out, conn_id {}", conn_id);

                // Report the error (computing the resulting state) before terminating.
                let error_response = err.into_response(Severity::Fatal);
                let error_state = self.error(error_response).await;

                self.adapter_client.terminate().await;

                // NOTE(review): one more client message is read and discarded before
                // returning, presumably so the error is answered in response to a client
                // message. Confirm.
                let _ = self.conn.recv().await?;
                return error_state;
            },
            message = self.conn.recv() => message?,
        };

        let interval = self
            .tokio_metrics_intervals
            .next()
            .expect("infinite iterator");
        let recv_scheduling_delay_ms = interval.total_scheduled_duration.as_secs_f64() * 1000.0;

        // Wall-clock receipt time, passed on as the statement's `received` lifecycle timestamp.
        let received = SYSTEM_TIME();

        // A message arrived, so drop the idle-in-transaction timeout. `run` re-arms it after
        // each step.
        self.adapter_client
            .remove_idle_in_transaction_session_timeout();

        // Metric label for this message. A closed connection (`None`) gets the default label.
        let message_name = message.as_ref().map(|m| m.name()).unwrap_or_default();

        // Processing time is measured only when there is a message.
        let start = message.as_ref().map(|_| Instant::now());
        let next_state = match message {
            Some(FrontendMessage::Query { sql }) => {
                // Each query runs under its own root span, linked to the current span.
                let query_root_span =
                    tracing::info_span!(parent: None, "advance_ready", otel.name = message_name);
                query_root_span.follows_from(tracing::Span::current());
                self.query(sql, received)
                    .instrument(query_root_span)
                    .await?
            }
            Some(FrontendMessage::Parse {
                name,
                sql,
                param_types,
            }) => self.parse(name, sql, param_types).await?,
            Some(FrontendMessage::Bind {
                portal_name,
                statement_name,
                param_formats,
                raw_params,
                result_formats,
            }) => {
                self.bind(
                    portal_name,
                    statement_name,
                    param_formats,
                    raw_params,
                    result_formats,
                )
                .await?
            }
            Some(FrontendMessage::Execute {
                portal_name,
                max_rows,
            }) => {
                // A `max_rows` of 0, or one that cannot be converted to `usize` (presumably a
                // negative value), means "no limit".
                let max_rows = match usize::try_from(max_rows) {
                    Ok(0) | Err(_) => ExecuteCount::All,
                    Ok(n) => ExecuteCount::Count(n),
                };
                let execute_root_span =
                    tracing::info_span!(parent: None, "advance_ready", otel.name = message_name);
                execute_root_span.follows_from(tracing::Span::current());
                let state = self
                    .execute(
                        portal_name,
                        max_rows,
                        portal_exec_message,
                        None,
                        ExecuteTimeout::None,
                        None,
                        Some(received),
                    )
                    .instrument(execute_root_span)
                    .await?;
                // Do not commit an implicit transaction here. Mark it so the next
                // `ensure_transaction` commits it (presumably so the portal can still be
                // resumed by a later Execute).
                if self.adapter_client.session().transaction().is_implicit() {
                    self.txn_needs_commit = true;
                }
                state
            }
            Some(FrontendMessage::DescribeStatement { name }) => {
                self.describe_statement(&name).await?
            }
            Some(FrontendMessage::DescribePortal { name }) => self.describe_portal(&name).await?,
            Some(FrontendMessage::CloseStatement { name }) => self.close_statement(name).await?,
            Some(FrontendMessage::ClosePortal { name }) => self.close_portal(name).await?,
            Some(FrontendMessage::Flush) => self.flush().await?,
            Some(FrontendMessage::Sync) => self.sync().await?,
            Some(FrontendMessage::Terminate) => State::Done,

            // COPY and authentication messages are not handled in the Ready state; drain
            // input until the next Sync.
            Some(FrontendMessage::CopyData(_))
            | Some(FrontendMessage::CopyDone)
            | Some(FrontendMessage::CopyFail(_))
            | Some(FrontendMessage::Password { .. })
            | Some(FrontendMessage::RawAuthentication(_))
            | Some(FrontendMessage::SASLInitialResponse { .. })
            | Some(FrontendMessage::SASLResponse(_)) => State::Drain,
            // The client closed the connection.
            None => State::Done,
        };

        if let Some(start) = start {
            self.adapter_client
                .inner()
                .metrics()
                .pgwire_message_processing_seconds
                .with_label_values(&[message_name])
                .observe(start.elapsed().as_secs_f64());
        }
        self.adapter_client
            .inner()
            .metrics()
            .pgwire_recv_scheduling_delay_ms
            .with_label_values(&[message_name])
            .observe(recv_scheduling_delay_ms);

        Ok(next_state)
    }
905
906 async fn advance_drain(&mut self) -> Result<State, io::Error> {
907 let message = self.conn.recv().await?;
908 if message.is_some() {
909 self.adapter_client
910 .remove_idle_in_transaction_session_timeout();
911 }
912 match message {
913 Some(FrontendMessage::Sync) => self.sync().await,
914 None => Ok(State::Done),
915 _ => Ok(State::Drain),
916 }
917 }
918
    /// Executes one statement of a simple-protocol query through the unnamed portal (`""`).
    ///
    /// The statement is declared as the unnamed portal, stamped with `lifecycle_timestamps`,
    /// and given a `RowDescription` (text format for every column) unless it is a COPY. It is
    /// then executed and its response sent. Once execution has been attempted, the portal is
    /// removed. Statements that declare parameters are rejected with `UNDEFINED_PARAMETER`,
    /// since no parameter values are available here.
    #[instrument(level = "debug")]
    async fn one_query(
        &mut self,
        stmt: Statement<Raw>,
        sql: String,
        lifecycle_timestamps: LifecycleTimestamps,
    ) -> Result<State, io::Error> {
        // The simple protocol uses the unnamed portal.
        const EMPTY_PORTAL: &str = "";
        if let Err(e) = self
            .adapter_client
            .declare(EMPTY_PORTAL.to_string(), stmt, sql)
            .await
        {
            return self.error(e.into_response(Severity::Error)).await;
        }
        let portal = self
            .adapter_client
            .session()
            .get_portal_unverified_mut(EMPTY_PORTAL)
            .expect("unnamed portal should be present");

        *portal.lifecycle_timestamps = Some(lifecycle_timestamps);

        let stmt_desc = portal.desc.clone();
        if !stmt_desc.param_types.is_empty() {
            return self
                .error(ErrorResponse::error(
                    SqlState::UNDEFINED_PARAMETER,
                    "there is no parameter $1",
                ))
                .await;
        }

        // Describe result rows up front. COPY statements are excluded here.
        if let Some(relation_desc) = &stmt_desc.relation_desc {
            if !stmt_desc.is_copy {
                let formats = vec![Format::Text; stmt_desc.arity()];
                self.send(BackendMessage::RowDescription(
                    message::encode_row_description(relation_desc, &formats),
                ))
                .await?;
            }
        }

        // NOTE(review): `wait_closed()` presumably lets execution notice a client disconnect.
        // Confirm.
        let result = match self
            .adapter_client
            .execute(EMPTY_PORTAL.to_string(), self.conn.wait_closed(), None)
            .await
        {
            Ok((response, execute_started)) => {
                self.send_pending_notices().await?;
                self.send_execute_response(
                    response,
                    stmt_desc.relation_desc,
                    EMPTY_PORTAL.to_string(),
                    ExecuteCount::All,
                    portal_exec_message,
                    None,
                    ExecuteTimeout::None,
                    execute_started,
                )
                .await
            }
            Err(e) => {
                self.send_pending_notices().await?;
                self.error(e.into_response(Severity::Error)).await
            }
        };

        // Remove the unnamed portal whether or not execution succeeded.
        self.adapter_client.session().remove_portal(EMPTY_PORTAL);

        result
    }
998
999 async fn ensure_transaction(
1000 &mut self,
1001 num_stmts: usize,
1002 message_type: &str,
1003 ) -> Result<(), io::Error> {
1004 let start = Instant::now();
1005 if self.txn_needs_commit {
1006 self.commit_transaction().await?;
1007 }
1008 let res = self.adapter_client.start_transaction(Some(num_stmts));
1011 assert_ok!(res);
1012 self.adapter_client
1013 .inner()
1014 .metrics()
1015 .pgwire_ensure_transaction_seconds
1016 .with_label_values(&[message_type])
1017 .observe(start.elapsed().as_secs_f64());
1018 Ok(())
1019 }
1020
1021 fn parse_sql<'b>(&self, sql: &'b str) -> Result<Vec<StatementParseResult<'b>>, ErrorResponse> {
1022 let parse_start = Instant::now();
1023 let result = match self.adapter_client.parse(sql) {
1024 Ok(result) => result.map_err(|e| {
1025 let pos = sql[..e.error.pos].chars().count() + 1;
1028 ErrorResponse::error(SqlState::SYNTAX_ERROR, e.error.message).with_position(pos)
1029 }),
1030 Err(msg) => Err(ErrorResponse::error(SqlState::PROGRAM_LIMIT_EXCEEDED, msg)),
1031 };
1032 self.adapter_client
1033 .inner()
1034 .metrics()
1035 .parse_seconds
1036 .observe(parse_start.elapsed().as_secs_f64());
1037 result
1038 }
1039
    /// Handles a simple-protocol `Query` message and finishes with `self.ready()`.
    ///
    /// `sql` is parsed first. A parse error is reported and nothing runs. Statements then run
    /// one at a time via `one_query`. Inside an aborted transaction, only statements accepted
    /// by `is_txn_exit_stmt` run; the first other statement reports an error and stops
    /// processing. An implicit transaction is committed at the end, and an empty query string
    /// produces `EmptyQueryResponse`.
    #[instrument(level = "debug")]
    async fn query(&mut self, sql: String, received: EpochMillis) -> Result<State, io::Error> {
        // Parse everything before touching transaction state.
        let stmts = match self.parse_sql(&sql) {
            Ok(stmts) => stmts,
            Err(err) => {
                self.error(err).await?;
                return self.ready().await;
            }
        };

        let num_stmts = stmts.len();

        for StatementParseResult { ast: stmt, sql } in stmts {
            // In an aborted transaction, refuse everything except transaction-exit statements.
            if self.is_aborted_txn() && !is_txn_exit_stmt(Some(&stmt)) {
                self.aborted_txn_error().await?;
                break;
            }

            // Called once per statement rather than once up front. This is presumably so a
            // statement that follows COMMIT/ROLLBACK gets a fresh transaction.
            self.ensure_transaction(num_stmts, "query").await?;

            match self
                .one_query(stmt, sql.to_string(), LifecycleTimestamps { received })
                .await?
            {
                State::Ready => (),
                State::Drain => break,
                State::Done => return Ok(State::Done),
            }
        }

        // An implicit transaction ends with the Query message.
        {
            if self.adapter_client.session().transaction().is_implicit() {
                self.commit_transaction().await?;
            }
        }

        if num_stmts == 0 {
            self.send(BackendMessage::EmptyQueryResponse).await?;
        }

        self.ready().await
    }
1097
1098 #[instrument(level = "debug")]
1099 async fn parse(
1100 &mut self,
1101 name: String,
1102 sql: String,
1103 param_oids: Vec<u32>,
1104 ) -> Result<State, io::Error> {
1105 self.ensure_transaction(1, "parse").await?;
1107
1108 let mut param_types = vec![];
1109 for oid in param_oids {
1110 match mz_pgrepr::Type::from_oid(oid) {
1111 Ok(ty) => match SqlScalarType::try_from(&ty) {
1112 Ok(ty) => param_types.push(Some(ty)),
1113 Err(err) => {
1114 return self
1115 .error(ErrorResponse::error(
1116 SqlState::INVALID_PARAMETER_VALUE,
1117 err.to_string(),
1118 ))
1119 .await;
1120 }
1121 },
1122 Err(_) if oid == 0 => param_types.push(None),
1123 Err(e) => {
1124 return self
1125 .error(ErrorResponse::error(
1126 SqlState::PROTOCOL_VIOLATION,
1127 e.to_string(),
1128 ))
1129 .await;
1130 }
1131 }
1132 }
1133
1134 let stmts = match self.parse_sql(&sql) {
1135 Ok(stmts) => stmts,
1136 Err(err) => {
1137 return self.error(err).await;
1138 }
1139 };
1140 if stmts.len() > 1 {
1141 return self
1142 .error(ErrorResponse::error(
1143 SqlState::INTERNAL_ERROR,
1144 "cannot insert multiple commands into a prepared statement",
1145 ))
1146 .await;
1147 }
1148 let (maybe_stmt, sql) = match stmts.into_iter().next() {
1149 None => (None, ""),
1150 Some(StatementParseResult { ast, sql }) => (Some(ast), sql),
1151 };
1152 if self.is_aborted_txn() && !is_txn_exit_stmt(maybe_stmt.as_ref()) {
1153 return self.aborted_txn_error().await;
1154 }
1155 match self
1156 .adapter_client
1157 .prepare(name, maybe_stmt, sql.to_string(), param_types)
1158 .await
1159 {
1160 Ok(()) => {
1161 self.send(BackendMessage::ParseComplete).await?;
1162 Ok(State::Ready)
1163 }
1164 Err(e) => self.error(e.into_response(Severity::Error)).await,
1165 }
1166 }
1167
    /// Commits the current transaction via `end_transaction`.
    #[instrument(level = "debug")]
    async fn commit_transaction(&mut self) -> Result<(), io::Error> {
        self.end_transaction(EndTransactionAction::Commit).await
    }
1173
    /// Rolls back the current transaction via `end_transaction`.
    #[instrument(level = "debug")]
    async fn rollback_transaction(&mut self) -> Result<(), io::Error> {
        self.end_transaction(EndTransactionAction::Rollback).await
    }
1179
1180 #[instrument(level = "debug")]
1182 async fn end_transaction(&mut self, action: EndTransactionAction) -> Result<(), io::Error> {
1183 self.txn_needs_commit = false;
1184 let resp = self.adapter_client.end_transaction(action).await;
1185 if let Err(err) = resp {
1186 self.send(BackendMessage::ErrorResponse(
1187 err.into_response(Severity::Error),
1188 ))
1189 .await?;
1190 }
1191 Ok(())
1192 }
1193
    /// Handles an extended-protocol `Bind` message. It creates portal `portal_name` from
    /// prepared statement `statement_name`, using the decoded parameter values and the
    /// requested result formats, and then sends `BindComplete`.
    ///
    /// Errors are reported through `self.error` for any of the following:
    /// - an unknown statement;
    /// - a parameter count mismatch;
    /// - format lists rejected by `pad_formats`;
    /// - an aborted transaction, unless the statement ends the transaction;
    /// - parameters that cannot be decoded;
    /// - binary result formats for list, map, or aclitem columns (allowed for `COPY ... TO`);
    /// - a failure from `set_portal`.
    #[instrument(level = "debug")]
    async fn bind(
        &mut self,
        portal_name: String,
        statement_name: String,
        param_formats: Vec<Format>,
        raw_params: Vec<Option<Vec<u8>>>,
        result_formats: Vec<Format>,
    ) -> Result<State, io::Error> {
        self.ensure_transaction(1, "bind").await?;

        // Capture the aborted state now. It is only acted on after the parameter count and
        // formats have been validated.
        let aborted_txn = self.is_aborted_txn();
        let stmt = match self
            .adapter_client
            .get_prepared_statement(&statement_name)
            .await
        {
            Ok(stmt) => stmt,
            Err(err) => return self.error(err.into_response(Severity::Error)).await,
        };

        let param_types = &stmt.desc().param_types;
        if param_types.len() != raw_params.len() {
            let message = format!(
                "bind message supplies {actual} parameters, \
                 but prepared statement \"{name}\" requires {expected}",
                name = statement_name,
                actual = raw_params.len(),
                expected = param_types.len()
            );
            return self
                .error(ErrorResponse::error(SqlState::PROTOCOL_VIOLATION, message))
                .await;
        }
        // NOTE(review): `pad_formats` presumably expands the client's format codes to one per
        // parameter. Confirm.
        let param_formats = match pad_formats(param_formats, raw_params.len()) {
            Ok(param_formats) => param_formats,
            Err(msg) => {
                return self
                    .error(ErrorResponse::error(SqlState::PROTOCOL_VIOLATION, msg))
                    .await;
            }
        };
        if aborted_txn && !is_txn_exit_stmt(stmt.stmt()) {
            return self.aborted_txn_error().await;
        }
        // Decode each parameter using its declared type and format. A missing value becomes
        // NULL.
        let buf = RowArena::new();
        let mut params = vec![];
        for ((raw_param, mz_typ), format) in raw_params
            .into_iter()
            .zip_eq(param_types)
            .zip_eq(param_formats)
        {
            let pg_typ = mz_pgrepr::Type::from(mz_typ);
            let datum = match raw_param {
                None => Datum::Null,
                Some(bytes) => match mz_pgrepr::Value::decode(format, &pg_typ, &bytes) {
                    Ok(param) => param.into_datum(&buf, &pg_typ),
                    Err(err) => {
                        let msg = format!("unable to decode parameter: {}", err);
                        return self
                            .error(ErrorResponse::error(SqlState::INVALID_PARAMETER_VALUE, msg))
                            .await;
                    }
                },
            };
            params.push((datum, mz_typ.clone()))
        }

        // One result format per output column (zero columns if the statement returns no rows).
        let result_formats = match pad_formats(
            result_formats,
            stmt.desc()
                .relation_desc
                .clone()
                .map(|desc| desc.typ().column_types.len())
                .unwrap_or(0),
        ) {
            Ok(result_formats) => result_formats,
            Err(msg) => {
                return self
                    .error(ErrorResponse::error(SqlState::PROTOCOL_VIOLATION, msg))
                    .await;
            }
        };

        // Binary encoding is refused for list, map, and aclitem columns, except for
        // `COPY ... TO` statements.
        if !stmt.stmt().map_or(false, |stmt| {
            matches!(
                stmt,
                Statement::Copy(CopyStatement {
                    direction: CopyDirection::To,
                    ..
                })
            )
        }) {
            if let Some(desc) = stmt.desc().relation_desc.clone() {
                for (format, ty) in result_formats.iter().zip_eq(desc.iter_types()) {
                    match (format, &ty.scalar_type) {
                        (Format::Binary, mz_repr::SqlScalarType::List { .. }) => {
                            return self
                                .error(ErrorResponse::error(
                                    SqlState::PROTOCOL_VIOLATION,
                                    "binary encoding of list types is not implemented",
                                ))
                                .await;
                        }
                        (Format::Binary, mz_repr::SqlScalarType::Map { .. }) => {
                            return self
                                .error(ErrorResponse::error(
                                    SqlState::PROTOCOL_VIOLATION,
                                    "binary encoding of map types is not implemented",
                                ))
                                .await;
                        }
                        (Format::Binary, mz_repr::SqlScalarType::AclItem) => {
                            return self
                                .error(ErrorResponse::error(
                                    SqlState::PROTOCOL_VIOLATION,
                                    "binary encoding of aclitem types does not exist",
                                ))
                                .await;
                        }
                        _ => (),
                    }
                }
            }
        }

        let desc = stmt.desc().clone();
        let logging = Arc::clone(stmt.logging());
        let stmt_ast = stmt.stmt().cloned();
        let state_revision = stmt.state_revision;
        if let Err(err) = self.adapter_client.session().set_portal(
            portal_name,
            desc,
            stmt_ast,
            logging,
            params,
            result_formats,
            state_revision,
        ) {
            return self.error(err.into_response(Severity::Error)).await;
        }

        self.send(BackendMessage::BindComplete).await?;
        Ok(State::Ready)
    }
1342
/// Runs, or resumes, the portal named `portal_name`. This serves both an
/// extended-protocol `Execute` message and a cursor `FETCH` (which re-enters
/// here from `fetch`).
///
/// * `max_rows` bounds how many rows are sent before the portal is parked.
/// * `get_response` builds the message sent once row sending stops.
/// * `fetch_portal_name`, when set, names the portal that issued a `FETCH`.
///   It is forwarded to `send_rows`/`send_execute_response`.
/// * `outer_ctx_extra`, when present, is retired on every path handled here.
///   On the `NotStarted` path it is instead handed to the adapter's `execute`.
/// * `received` seeds the portal's lifecycle timestamps.
///
/// Returns a boxed future. Presumably this is because `execute` ->
/// `send_execute_response` -> `fetch` -> `execute` is recursive, which a plain
/// `async fn` cannot express.
fn execute(
    &mut self,
    portal_name: String,
    max_rows: ExecuteCount,
    get_response: GetResponse,
    fetch_portal_name: Option<String>,
    timeout: ExecuteTimeout,
    outer_ctx_extra: Option<ExecuteContextExtra>,
    received: Option<EpochMillis>,
) -> BoxFuture<'_, Result<State, io::Error>> {
    async move {
        // Read the aborted-transaction flag before the portal borrow below
        // ties up the session.
        let aborted_txn = self.is_aborted_txn();

        let portal = match self
            .adapter_client
            .session()
            .get_portal_unverified_mut(&portal_name)
        {
            Some(portal) => portal,
            None => {
                let msg = format!("portal {} does not exist", portal_name.quoted());
                if let Some(outer_ctx_extra) = outer_ctx_extra {
                    self.adapter_client.retire_execute(
                        outer_ctx_extra,
                        StatementEndedExecutionReason::Errored { error: msg.clone() },
                    );
                }
                return self
                    .error(ErrorResponse::error(SqlState::INVALID_CURSOR_NAME, msg))
                    .await;
            }
        };

        *portal.lifecycle_timestamps = received.map(LifecycleTimestamps::new);

        // Inside an aborted transaction, only statements accepted by
        // `is_txn_exit_stmt` may run (presumably COMMIT/ROLLBACK; confirm
        // against that helper). Everything else gets the aborted-txn error.
        let txn_exit_stmt = is_txn_exit_stmt(portal.stmt.as_deref());
        if aborted_txn && !txn_exit_stmt {
            if let Some(outer_ctx_extra) = outer_ctx_extra {
                self.adapter_client.retire_execute(
                    outer_ctx_extra,
                    StatementEndedExecutionReason::Errored {
                        error: ABORTED_TXN_MSG.to_string(),
                    },
                );
            }
            return self.aborted_txn_error().await;
        }

        let row_desc = portal.desc.relation_desc.clone();
        match portal.state {
            // First execution: run the statement through the adapter, then
            // translate its response into wire messages.
            PortalState::NotStarted => {
                self.ensure_transaction(1, "execute").await?;
                match self
                    .adapter_client
                    .execute(
                        portal_name.clone(),
                        self.conn.wait_closed(),
                        outer_ctx_extra,
                    )
                    .await
                {
                    Ok((response, execute_started)) => {
                        self.send_pending_notices().await?;
                        self.send_execute_response(
                            response,
                            row_desc,
                            portal_name,
                            max_rows,
                            get_response,
                            fetch_portal_name,
                            timeout,
                            execute_started,
                        )
                        .await
                    }
                    Err(e) => {
                        self.send_pending_notices().await?;
                        self.error(e.into_response(Severity::Error)).await
                    }
                }
            }
            // Resumption: a previous call parked unsent rows on the portal.
            // Continue streaming them.
            PortalState::InProgress(rows) => {
                let rows = rows.take().expect("InProgress rows must be populated");
                let (result, statement_ended_execution_reason) = match self
                    .send_rows(
                        row_desc.expect("portal missing row desc on resumption"),
                        portal_name,
                        rows,
                        max_rows,
                        get_response,
                        fetch_portal_name,
                        timeout,
                    )
                    .await
                {
                    // An I/O error (e.g. the client went away) is recorded
                    // as a cancellation.
                    Err(e) => {
                        (Err(e), StatementEndedExecutionReason::Canceled)
                    }
                    Ok((ok, SendRowsEndedReason::Canceled)) => {
                        (Ok(ok), StatementEndedExecutionReason::Canceled)
                    }
                    // The sizes reported here cover only this resumed slice,
                    // and they are deliberately discarded. NOTE(review):
                    // presumably because they would misstate the statement's
                    // totals. Confirm.
                    Ok((
                        ok,
                        SendRowsEndedReason::Success {
                            result_size: _,
                            rows_returned: _,
                        },
                    )) => (
                        Ok(ok),
                        StatementEndedExecutionReason::Success {
                            result_size: None,
                            rows_returned: None,
                            execution_strategy: None,
                        },
                    ),
                    Ok((ok, SendRowsEndedReason::Errored { error })) => {
                        (Ok(ok), StatementEndedExecutionReason::Errored { error })
                    }
                };
                if let Some(outer_ctx_extra) = outer_ctx_extra {
                    self.adapter_client
                        .retire_execute(outer_ctx_extra, statement_ended_execution_reason);
                }
                result
            }
            // The portal already ran to completion. Replay its stored tag.
            PortalState::Completed(Some(tag)) => {
                let tag = tag.to_string();
                if let Some(outer_ctx_extra) = outer_ctx_extra {
                    self.adapter_client.retire_execute(
                        outer_ctx_extra,
                        StatementEndedExecutionReason::Success {
                            result_size: None,
                            rows_returned: None,
                            execution_strategy: None,
                        },
                    );
                }
                self.send(BackendMessage::CommandComplete { tag }).await?;
                Ok(State::Ready)
            }
            // Completed without a tag (e.g. marked by `complete_portal`).
            // The portal cannot be re-run.
            PortalState::Completed(None) => {
                let error = format!(
                    "portal {} cannot be run",
                    Ident::new_unchecked(portal_name).to_ast_string_stable()
                );
                if let Some(outer_ctx_extra) = outer_ctx_extra {
                    self.adapter_client.retire_execute(
                        outer_ctx_extra,
                        StatementEndedExecutionReason::Errored {
                            error: error.clone(),
                        },
                    );
                }
                self.error(ErrorResponse::error(
                    SqlState::OBJECT_NOT_IN_PREREQUISITE_STATE,
                    error,
                ))
                .await
            }
        }
    }
    .instrument(debug_span!("execute"))
    .boxed()
}
1530
1531 #[instrument(level = "debug")]
1532 async fn describe_statement(&mut self, name: &str) -> Result<State, io::Error> {
1533 self.ensure_transaction(1, "describe_statement").await?;
1535
1536 let stmt = match self.adapter_client.get_prepared_statement(name).await {
1537 Ok(stmt) => stmt,
1538 Err(err) => return self.error(err.into_response(Severity::Error)).await,
1539 };
1540 let parameter_desc = BackendMessage::ParameterDescription(
1542 stmt.desc()
1543 .param_types
1544 .iter()
1545 .map(mz_pgrepr::Type::from)
1546 .collect(),
1547 );
1548 let formats = vec![Format::Text; stmt.desc().arity()];
1552 let row_desc = describe_rows(stmt.desc(), &formats);
1553 self.send_all([parameter_desc, row_desc]).await?;
1554 Ok(State::Ready)
1555 }
1556
1557 #[instrument(level = "debug")]
1558 async fn describe_portal(&mut self, name: &str) -> Result<State, io::Error> {
1559 self.ensure_transaction(1, "describe_portal").await?;
1561
1562 let session = self.adapter_client.session();
1563 let row_desc = session
1564 .get_portal_unverified(name)
1565 .map(|portal| describe_rows(&portal.desc, &portal.result_formats));
1566 match row_desc {
1567 Some(row_desc) => {
1568 self.send(row_desc).await?;
1569 Ok(State::Ready)
1570 }
1571 None => {
1572 self.error(ErrorResponse::error(
1573 SqlState::INVALID_CURSOR_NAME,
1574 format!("portal {} does not exist", name.quoted()),
1575 ))
1576 .await
1577 }
1578 }
1579 }
1580
1581 #[instrument(level = "debug")]
1582 async fn close_statement(&mut self, name: String) -> Result<State, io::Error> {
1583 self.adapter_client
1584 .session()
1585 .remove_prepared_statement(&name);
1586 self.send(BackendMessage::CloseComplete).await?;
1587 Ok(State::Ready)
1588 }
1589
1590 #[instrument(level = "debug")]
1591 async fn close_portal(&mut self, name: String) -> Result<State, io::Error> {
1592 self.adapter_client.session().remove_portal(&name);
1593 self.send(BackendMessage::CloseComplete).await?;
1594 Ok(State::Ready)
1595 }
1596
1597 fn complete_portal(&mut self, name: &str) {
1598 let portal = self
1599 .adapter_client
1600 .session()
1601 .get_portal_unverified_mut(name)
1602 .expect("portal should exist");
1603 *portal.state = PortalState::Completed(None);
1604 }
1605
1606 async fn fetch(
1607 &mut self,
1608 name: String,
1609 count: Option<FetchDirection>,
1610 max_rows: ExecuteCount,
1611 fetch_portal_name: Option<String>,
1612 timeout: ExecuteTimeout,
1613 ctx_extra: ExecuteContextExtra,
1614 ) -> Result<State, io::Error> {
1615 let count = count.unwrap_or(FetchDirection::ForwardCount(1));
1618
1619 let count = match (max_rows, count) {
1632 (ExecuteCount::Count(max_rows), FetchDirection::ForwardCount(count)) => {
1633 let count = usize::cast_from(count);
1634 if max_rows < count {
1635 let msg = "Execute with max_rows < a FETCH's count is not supported";
1636 self.adapter_client.retire_execute(
1637 ctx_extra,
1638 StatementEndedExecutionReason::Errored {
1639 error: msg.to_string(),
1640 },
1641 );
1642 return self
1643 .error(ErrorResponse::error(SqlState::FEATURE_NOT_SUPPORTED, msg))
1644 .await;
1645 }
1646 ExecuteCount::Count(count)
1647 }
1648 (ExecuteCount::Count(_), FetchDirection::ForwardAll) => {
1649 let msg = "Execute with max_rows of a FETCH ALL is not supported";
1650 self.adapter_client.retire_execute(
1651 ctx_extra,
1652 StatementEndedExecutionReason::Errored {
1653 error: msg.to_string(),
1654 },
1655 );
1656 return self
1657 .error(ErrorResponse::error(SqlState::FEATURE_NOT_SUPPORTED, msg))
1658 .await;
1659 }
1660 (ExecuteCount::All, FetchDirection::ForwardAll) => ExecuteCount::All,
1661 (ExecuteCount::All, FetchDirection::ForwardCount(count)) => {
1662 ExecuteCount::Count(usize::cast_from(count))
1663 }
1664 };
1665 let cursor_name = name.to_string();
1666 self.execute(
1667 cursor_name,
1668 count,
1669 fetch_message,
1670 fetch_portal_name,
1671 timeout,
1672 Some(ctx_extra),
1673 None,
1674 )
1675 .await
1676 }
1677
1678 async fn flush(&mut self) -> Result<State, io::Error> {
1679 self.conn.flush().await?;
1680 Ok(State::Ready)
1681 }
1682
1683 #[instrument(level = "debug")]
1688 async fn send<M>(&mut self, message: M) -> Result<(), io::Error>
1689 where
1690 M: Into<BackendMessage>,
1691 {
1692 let message: BackendMessage = message.into();
1693 let is_error =
1694 matches!(&message, BackendMessage::ErrorResponse(e) if e.severity.is_error());
1695
1696 self.conn.send(message).await?;
1697
1698 if is_error {
1705 self.conn.flush().await?;
1706 }
1707
1708 Ok(())
1709 }
1710
1711 #[instrument(level = "debug")]
1712 pub async fn send_all(
1713 &mut self,
1714 messages: impl IntoIterator<Item = BackendMessage>,
1715 ) -> Result<(), io::Error> {
1716 for m in messages {
1717 self.send(m).await?;
1718 }
1719 Ok(())
1720 }
1721
1722 #[instrument(level = "debug")]
1723 async fn sync(&mut self) -> Result<State, io::Error> {
1724 if self.adapter_client.session().transaction().is_implicit() {
1726 self.commit_transaction().await?;
1727 }
1728 self.ready().await
1729 }
1730
1731 #[instrument(level = "debug")]
1732 async fn ready(&mut self) -> Result<State, io::Error> {
1733 let txn_state = self.adapter_client.session().transaction().into();
1734 self.send(BackendMessage::ReadyForQuery(txn_state)).await?;
1735 self.flush().await
1736 }
1737
/// Translates `response`, the adapter's result of executing portal
/// `portal_name`, into wire messages.
///
/// * Row-producing responses are streamed through `send_rows`, or through
///   `copy_rows` when wrapped in `CopyTo`. These honor `max_rows`,
///   `get_response`, `fetch_portal_name` and `timeout`.
/// * `Fetch` responses recurse into `fetch`.
/// * `CopyFrom` hands off to `copy_from`.
/// * Every other response sends its command tag via `command_complete!`.
///
/// `row_desc` must be `Some` for every row-producing response. The
/// corresponding arms panic otherwise. `execute_started` is passed to
/// `RecordFirstRowStream::new` (presumably as the baseline for first-row
/// timing).
///
/// Asserts (via `assert_none!`) that any tag produced by `response.tag()` was
/// consumed. Arms that `return` early bypass this check.
#[allow(clippy::too_many_arguments)]
#[instrument(level = "debug")]
async fn send_execute_response(
    &mut self,
    response: ExecuteResponse,
    row_desc: Option<RelationDesc>,
    portal_name: String,
    max_rows: ExecuteCount,
    get_response: GetResponse,
    fetch_portal_name: Option<String>,
    timeout: ExecuteTimeout,
    execute_started: Instant,
) -> Result<State, io::Error> {
    let mut tag = response.tag();

    // Sends the tag computed above as `CommandComplete` and yields
    // `State::Ready`. The tag is `take`n so the final `assert_none!` can
    // detect arms that generate a tag but never send it.
    macro_rules! command_complete {
        () => {{
            self.send(BackendMessage::CommandComplete {
                tag: tag
                    .take()
                    .expect("command_complete only called on tag-generating results"),
            })
            .await?;
            Ok(State::Ready)
        }};
    }

    let r = match response {
        // Cursor declaration/closing finishes the portal that ran it, so it
        // cannot be re-executed.
        ExecuteResponse::ClosedCursor => {
            self.complete_portal(&portal_name);
            command_complete!()
        }
        ExecuteResponse::DeclaredCursor => {
            self.complete_portal(&portal_name);
            command_complete!()
        }
        ExecuteResponse::EmptyQuery => {
            self.send(BackendMessage::EmptyQueryResponse).await?;
            Ok(State::Ready)
        }
        // FETCH: run the named cursor, encoding rows with this portal's
        // result formats (it becomes the `fetch_portal_name`).
        ExecuteResponse::Fetch {
            name,
            count,
            timeout,
            ctx_extra,
        } => {
            self.fetch(
                name,
                count,
                max_rows,
                Some(portal_name.to_string()),
                timeout,
                ctx_extra,
            )
            .await
        }
        ExecuteResponse::SendingRowsStreaming {
            rows,
            instance_id,
            strategy,
        } => {
            let row_desc = row_desc
                .expect("missing row description for ExecuteResponse::SendingRowsStreaming");

            let span = tracing::debug_span!("sending_rows_streaming");

            self.send_rows(
                row_desc,
                portal_name,
                InProgressRows::new(RecordFirstRowStream::new(
                    Box::new(rows),
                    execute_started,
                    &self.adapter_client,
                    Some(instance_id),
                    Some(strategy),
                )),
                max_rows,
                get_response,
                fetch_portal_name,
                timeout,
            )
            .instrument(span)
            .await
            .map(|(state, _)| state)
        }
        // Rows already in hand: wrap them in a one-element stream so they go
        // through the same `send_rows` path as streamed results.
        ExecuteResponse::SendingRowsImmediate { rows } => {
            let row_desc = row_desc
                .expect("missing row description for ExecuteResponse::SendingRowsImmediate");

            let span = tracing::debug_span!("sending_rows_immediate");

            let stream =
                futures::stream::once(futures::future::ready(PeekResponseUnary::Rows(rows)));
            self.send_rows(
                row_desc,
                portal_name,
                InProgressRows::new(RecordFirstRowStream::new(
                    Box::new(stream),
                    execute_started,
                    &self.adapter_client,
                    None,
                    Some(StatementExecutionStrategy::Constant),
                )),
                max_rows,
                get_response,
                fetch_portal_name,
                timeout,
            )
            .instrument(span)
            .await
            .map(|(state, _)| state)
        }
        // Only variables in the session's notify set are reported back to
        // the client as `ParameterStatus`.
        ExecuteResponse::SetVariable { name, .. } => {
            let qn = name.to_string();
            let msg = if let Some(var) = self
                .adapter_client
                .session()
                .vars_mut()
                .notify_set()
                .find(|v| v.name() == qn)
            {
                Some(BackendMessage::ParameterStatus(var.name(), var.value()))
            } else {
                None
            };
            if let Some(msg) = msg {
                self.send(msg).await?;
            }
            command_complete!()
        }
        ExecuteResponse::Subscribing {
            rx,
            ctx_extra,
            instance_id,
        } => {
            // Not reached through a cursor FETCH: warn that the client must
            // not buffer output, with a psql-specific hint.
            if fetch_portal_name.is_none() {
                let mut msg = ErrorResponse::notice(
                    SqlState::WARNING,
                    "streaming SUBSCRIBE rows directly requires a client that does not buffer output",
                );
                if self.adapter_client.session().vars().application_name() == "psql" {
                    msg.hint = Some(
                        "Wrap your SUBSCRIBE statement in `COPY (SUBSCRIBE ...) TO STDOUT`."
                            .into(),
                    )
                }
                self.send(msg).await?;
                self.conn.flush().await?;
            }
            let row_desc =
                row_desc.expect("missing row description for ExecuteResponse::Subscribing");
            let (result, statement_ended_execution_reason) = match self
                .send_rows(
                    row_desc,
                    portal_name,
                    InProgressRows::new(RecordFirstRowStream::new(
                        Box::new(UnboundedReceiverStream::new(rx)),
                        execute_started,
                        &self.adapter_client,
                        Some(instance_id),
                        None,
                    )),
                    max_rows,
                    get_response,
                    fetch_portal_name,
                    timeout,
                )
                .await
            {
                // An I/O error is recorded as a cancellation.
                Err(e) => {
                    (Err(e), StatementEndedExecutionReason::Canceled)
                }
                Ok((ok, SendRowsEndedReason::Canceled)) => {
                    (Ok(ok), StatementEndedExecutionReason::Canceled)
                }
                Ok((
                    ok,
                    SendRowsEndedReason::Success {
                        result_size,
                        rows_returned,
                    },
                )) => (
                    Ok(ok),
                    StatementEndedExecutionReason::Success {
                        result_size: Some(result_size),
                        rows_returned: Some(rows_returned),
                        execution_strategy: None,
                    },
                ),
                Ok((ok, SendRowsEndedReason::Errored { error })) => {
                    (Ok(ok), StatementEndedExecutionReason::Errored { error })
                }
            };
            self.adapter_client
                .retire_execute(ctx_extra, statement_ended_execution_reason);
            return result;
        }
        // COPY ... TO STDOUT: stream the wrapped response through
        // `copy_rows` in the requested COPY format.
        ExecuteResponse::CopyTo { format, resp } => {
            let row_desc =
                row_desc.expect("missing row description for ExecuteResponse::CopyTo");
            match *resp {
                ExecuteResponse::Subscribing {
                    rx,
                    ctx_extra,
                    instance_id,
                } => {
                    let (result, statement_ended_execution_reason) = match self
                        .copy_rows(
                            format,
                            row_desc,
                            RecordFirstRowStream::new(
                                Box::new(UnboundedReceiverStream::new(rx)),
                                execute_started,
                                &self.adapter_client,
                                Some(instance_id),
                                None,
                            ),
                        )
                        .await
                    {
                        Err(e) => {
                            (Err(e), StatementEndedExecutionReason::Canceled)
                        }
                        Ok((
                            state,
                            SendRowsEndedReason::Success {
                                result_size,
                                rows_returned,
                            },
                        )) => (
                            Ok(state),
                            StatementEndedExecutionReason::Success {
                                result_size: Some(result_size),
                                rows_returned: Some(rows_returned),
                                execution_strategy: None,
                            },
                        ),
                        Ok((state, SendRowsEndedReason::Errored { error })) => {
                            (Ok(state), StatementEndedExecutionReason::Errored { error })
                        }
                        Ok((state, SendRowsEndedReason::Canceled)) => {
                            (Ok(state), StatementEndedExecutionReason::Canceled)
                        }
                    };
                    self.adapter_client
                        .retire_execute(ctx_extra, statement_ended_execution_reason);
                    return result;
                }
                ExecuteResponse::SendingRowsStreaming {
                    rows,
                    instance_id,
                    strategy,
                } => {
                    return self
                        .copy_rows(
                            format,
                            row_desc,
                            RecordFirstRowStream::new(
                                Box::new(rows),
                                execute_started,
                                &self.adapter_client,
                                Some(instance_id),
                                Some(strategy),
                            ),
                        )
                        .await
                        .map(|(state, _)| state);
                }
                ExecuteResponse::SendingRowsImmediate { rows } => {
                    let span = tracing::debug_span!("sending_rows_immediate");

                    let rows = futures::stream::once(futures::future::ready(
                        PeekResponseUnary::Rows(rows),
                    ));
                    return self
                        .copy_rows(
                            format,
                            row_desc,
                            RecordFirstRowStream::new(
                                Box::new(rows),
                                execute_started,
                                &self.adapter_client,
                                None,
                                Some(StatementExecutionStrategy::Constant),
                            ),
                        )
                        .instrument(span)
                        .await
                        .map(|(state, _)| state);
                }
                // No other response kind is expected inside `CopyTo`.
                _ => {
                    return self
                        .error(ErrorResponse::error(
                            SqlState::INTERNAL_ERROR,
                            "unsupported COPY response type".to_string(),
                        ))
                        .await;
                }
            };
        }
        ExecuteResponse::CopyFrom {
            target_id,
            target_name,
            columns,
            params,
            ctx_extra,
        } => {
            let row_desc =
                row_desc.expect("missing row description for ExecuteResponse::CopyFrom");
            self.copy_from(target_id, target_name, columns, params, row_desc, ctx_extra)
                .await
        }
        // End of transaction: report the entries of `params` that are in the
        // session's notify set as `ParameterStatus`. NOTE(review): presumably
        // `params` holds the values changed by the transaction. Confirm.
        ExecuteResponse::TransactionCommitted { params }
        | ExecuteResponse::TransactionRolledBack { params } => {
            let notify_set: mz_ore::collections::HashSet<String> = self
                .adapter_client
                .session()
                .vars()
                .notify_set()
                .map(|v| v.name().to_string())
                .collect();

            for (name, value) in params
                .into_iter()
                .filter(|(name, _v)| notify_set.contains(*name))
            {
                let msg = BackendMessage::ParameterStatus(name, value);
                self.send(msg).await?;
            }
            command_complete!()
        }

        // Everything else only needs its command tag.
        ExecuteResponse::AlteredDefaultPrivileges
        | ExecuteResponse::AlteredObject(..)
        | ExecuteResponse::AlteredRole
        | ExecuteResponse::AlteredSystemConfiguration
        | ExecuteResponse::CreatedCluster { .. }
        | ExecuteResponse::CreatedClusterReplica { .. }
        | ExecuteResponse::CreatedConnection { .. }
        | ExecuteResponse::CreatedDatabase { .. }
        | ExecuteResponse::CreatedIndex { .. }
        | ExecuteResponse::CreatedIntrospectionSubscribe
        | ExecuteResponse::CreatedMaterializedView { .. }
        | ExecuteResponse::CreatedContinualTask { .. }
        | ExecuteResponse::CreatedRole
        | ExecuteResponse::CreatedSchema { .. }
        | ExecuteResponse::CreatedSecret { .. }
        | ExecuteResponse::CreatedSink { .. }
        | ExecuteResponse::CreatedSource { .. }
        | ExecuteResponse::CreatedTable { .. }
        | ExecuteResponse::CreatedType
        | ExecuteResponse::CreatedView { .. }
        | ExecuteResponse::CreatedViews { .. }
        | ExecuteResponse::CreatedNetworkPolicy
        | ExecuteResponse::Comment
        | ExecuteResponse::Deallocate { .. }
        | ExecuteResponse::Deleted(..)
        | ExecuteResponse::DiscardedAll
        | ExecuteResponse::DiscardedTemp
        | ExecuteResponse::DroppedObject(_)
        | ExecuteResponse::DroppedOwned
        | ExecuteResponse::GrantedPrivilege
        | ExecuteResponse::GrantedRole
        | ExecuteResponse::Inserted(..)
        | ExecuteResponse::Copied(..)
        | ExecuteResponse::Prepare
        | ExecuteResponse::Raised
        | ExecuteResponse::ReassignOwned
        | ExecuteResponse::RevokedPrivilege
        | ExecuteResponse::RevokedRole
        | ExecuteResponse::StartedTransaction { .. }
        | ExecuteResponse::Updated(..)
        | ExecuteResponse::ValidatedConnection => {
            command_complete!()
        }
    };

    // Every arm that reaches here must have consumed the generated tag.
    assert_none!(tag, "tag created but not consumed: {:?}", tag);
    r
}
2133
/// Streams rows from `rows` to the client as `DataRow` messages for the
/// portal `portal_name`.
///
/// * At most `max_rows` rows are sent. Rows pulled but not sent are parked
///   back on the portal as `PortalState::InProgress` so a later call can
///   resume.
/// * Rows are encoded with the result formats of `fetch_portal_name` when it
///   is set (the portal issuing a FETCH), otherwise with those of
///   `portal_name`.
/// * `timeout` may stop the stream early (see the deadline logic below).
///
/// After the loop, the message built by `get_response` is sent. The return
/// value pairs the next `State` with why sending stopped. Adapter errors,
/// cancellations and description mismatches are reported to the client and
/// surfaced as `SendRowsEndedReason::Errored`/`Canceled`. I/O errors are
/// returned as `Err`.
///
/// # Panics
///
/// Panics if `portal_name`, or `fetch_portal_name` when set, does not name an
/// existing portal.
#[allow(clippy::too_many_arguments)]
#[mz_ore::instrument(level = "debug")]
async fn send_rows(
    &mut self,
    row_desc: RelationDesc,
    portal_name: String,
    mut rows: InProgressRows,
    max_rows: ExecuteCount,
    get_response: GetResponse,
    fetch_portal_name: Option<String>,
    timeout: ExecuteTimeout,
) -> Result<(State, SendRowsEndedReason), io::Error> {
    // Rows are encoded using the result formats chosen by whichever portal
    // the client is actually reading from.
    let result_format_portal_name: &str = if let Some(ref name) = fetch_portal_name {
        name
    } else {
        &portal_name
    };
    let result_formats = self
        .adapter_client
        .session()
        .get_portal_unverified(result_format_portal_name)
        .expect("valid fetch portal name for send rows")
        .result_formats
        .clone();

    // `Seconds(t)`: stop once `t` seconds have elapsed. `WaitOnce`: no
    // deadline until the first non-empty batch arrives (see below).
    let (mut wait_once, mut deadline) = match timeout {
        ExecuteTimeout::None => (false, None),
        ExecuteTimeout::Seconds(t) => (
            false,
            Some(tokio::time::Instant::now() + tokio::time::Duration::from_secs_f64(t)),
        ),
        ExecuteTimeout::WaitOnce => (true, None),
    };

    // Sanity check: both portals' descriptions, when present, should match
    // `row_desc`. Mismatches are soft-asserted, not fatal. The block scope
    // keeps the session borrows short.
    {
        let portal_name_desc = &self
            .adapter_client
            .session()
            .get_portal_unverified(portal_name.as_str())
            .expect("portal should exist")
            .desc
            .relation_desc;
        if let Some(portal_name_desc) = portal_name_desc {
            soft_assert_eq_or_log!(portal_name_desc, &row_desc);
        }
        if let Some(fetch_portal_name) = &fetch_portal_name {
            let fetch_portal_desc = &self
                .adapter_client
                .session()
                .get_portal_unverified(fetch_portal_name)
                .expect("portal should exist")
                .desc
                .relation_desc;
            if let Some(fetch_portal_desc) = fetch_portal_desc {
                soft_assert_eq_or_log!(fetch_portal_desc, &row_desc);
            }
        }
    }

    // Tell the codec how to encode each column: its pg type paired with the
    // chosen result format.
    self.conn.set_encode_state(
        row_desc
            .typ()
            .column_types
            .iter()
            .map(|ty| mz_pgrepr::Type::from(&ty.scalar_type))
            .zip_eq(result_formats)
            .collect(),
    );

    let mut total_sent_rows = 0;
    let mut total_sent_bytes = 0;
    // Remaining row budget for this call.
    let mut want_rows = match max_rows {
        ExecuteCount::All => usize::MAX,
        ExecuteCount::Count(count) => count,
    };

    loop {
        // A batch left over from a previous call is drained first. Once the
        // budget is spent, stop. Otherwise wait for whichever comes first:
        // connection close, the deadline, a session notice, or the next batch.
        let batch = if rows.current.is_some() {
            FetchResult::Rows(rows.current.take())
        } else if want_rows == 0 {
            FetchResult::Rows(None)
        } else {
            let notice_fut = self.adapter_client.session().recv_notice();
            tokio::select! {
                err = self.conn.wait_closed() => return Err(err),
                _ = time::sleep_until(deadline.unwrap_or_else(tokio::time::Instant::now)), if deadline.is_some() => FetchResult::Rows(None),
                notice = notice_fut => {
                    FetchResult::Notice(notice)
                }
                batch = rows.remaining.recv() => match batch {
                    None => FetchResult::Rows(None),
                    Some(PeekResponseUnary::Rows(rows)) => FetchResult::Rows(Some(rows)),
                    Some(PeekResponseUnary::Error(err)) => FetchResult::Error(err),
                    Some(PeekResponseUnary::Canceled) => FetchResult::Canceled,
                },
            }
        };

        match batch {
            FetchResult::Rows(None) => break,
            FetchResult::Rows(Some(mut batch_rows)) => {
                if let Err(err) = verify_datum_desc(&row_desc, &mut batch_rows) {
                    let msg = err.to_string();
                    return self
                        .error(err.into_response(Severity::Error))
                        .await
                        .map(|state| (state, SendRowsEndedReason::Errored { error: msg }));
                }

                // `WaitOnce`: after the first non-empty batch, set the
                // deadline to "now" so the loop stops as soon as no further
                // batch is immediately ready.
                if wait_once && batch_rows.peek().is_some() {
                    deadline = Some(tokio::time::Instant::now());
                    wait_once = false;
                }

                // Encode at most `want_rows` rows. `take` sits last in the
                // chain, so `inspect` counts only rows actually sent, and the
                // iterator is not advanced past the budget.
                let mut sent_rows = 0;
                let mut sent_bytes = 0;
                let messages = (&mut batch_rows)
                    .map(|row| {
                        let row_len = row.byte_len();
                        let values = mz_pgrepr::values_from_row(row, row_desc.typ());
                        (row_len, BackendMessage::DataRow(values))
                    })
                    .inspect(|(row_len, _)| {
                        sent_bytes += row_len;
                        sent_rows += 1
                    })
                    .map(|(_row_len, row)| row)
                    .take(want_rows);
                self.send_all(messages).await?;

                total_sent_rows += sent_rows;
                total_sent_bytes += sent_bytes;
                want_rows -= sent_rows;

                // Budget exhausted: keep any unsent rows of this batch for
                // the next call.
                if want_rows == 0 {
                    if batch_rows.peek().is_some() {
                        rows.current = Some(batch_rows);
                    }
                    break;
                }

                self.conn.flush().await?;
            }
            FetchResult::Notice(notice) => {
                self.send(notice.into_response()).await?;
                self.conn.flush().await?;
            }
            FetchResult::Error(text) => {
                return self
                    .error(ErrorResponse::error(SqlState::INTERNAL_ERROR, text.clone()))
                    .await
                    .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
            }
            FetchResult::Canceled => {
                return self
                    .error(ErrorResponse::error(
                        SqlState::QUERY_CANCELED,
                        "canceling statement due to user request",
                    ))
                    .await
                    .map(|state| (state, SendRowsEndedReason::Canceled));
            }
        }
    }

    let portal = self
        .adapter_client
        .session()
        .get_portal_unverified_mut(&portal_name)
        .expect("valid portal name for send rows");

    // Read what the metrics below need before `rows` moves into the portal.
    let saw_rows = rows.remaining.saw_rows;
    let no_more_rows = rows.no_more_rows();
    let recorded_first_row_instant = rows.remaining.recorded_first_row_instant;

    // Park the stream on the portal so a later Execute/FETCH can resume it.
    *portal.state = PortalState::InProgress(Some(rows));

    let fetch_portal = fetch_portal_name.map(|name| {
        self.adapter_client
            .session()
            .get_portal_unverified_mut(&name)
            .expect("valid fetch portal")
    });
    let response_message = get_response(max_rows, total_sent_rows, fetch_portal);
    self.send(response_message).await?;

    // Once the stream is exhausted, record the first-row-to-last-row
    // duration (zero if no rows were ever seen), labeled by statement type.
    if no_more_rows {
        let statement_type = if let Some(stmt) = &self
            .adapter_client
            .session()
            .get_portal_unverified(&portal_name)
            .expect("valid portal name for send_rows")
            .stmt
        {
            metrics::statement_type_label_value(stmt.deref())
        } else {
            "no-statement"
        };
        let duration = if saw_rows {
            recorded_first_row_instant
                .expect("recorded_first_row_instant because saw_rows")
                .elapsed()
        } else {
            Duration::ZERO
        };
        self.adapter_client
            .inner()
            .metrics()
            .result_rows_first_to_last_byte_seconds
            .with_label_values(&[statement_type])
            .observe(duration.as_secs_f64());
    }

    Ok((
        State::Ready,
        SendRowsEndedReason::Success {
            result_size: u64::cast_from(total_sent_bytes),
            rows_returned: u64::cast_from(total_sent_rows),
        },
    ))
}
2379
/// Streams the rows of `stream` to the client as a `COPY ... TO STDOUT`
/// in `format`.
///
/// Sends `CopyOutResponse`, then one `CopyData` message per row (with the
/// binary header/trailer when `format` is binary), then `CopyDone` and a
/// `COPY <n>` command tag. Parquet is rejected with an error.
///
/// Adapter errors and cancellations are reported to the client and surfaced
/// as `SendRowsEndedReason::Errored`/`Canceled`. I/O errors, including the
/// connection closing, are returned as `Err`.
#[mz_ore::instrument(level = "debug")]
async fn copy_rows(
    &mut self,
    format: CopyFormat,
    row_desc: RelationDesc,
    mut stream: RecordFirstRowStream,
) -> Result<(State, SendRowsEndedReason), io::Error> {
    // Per-row encoder parameters, plus the wire format advertised in
    // CopyOutResponse. CSV is advertised as text.
    let (row_format, encode_format) = match format {
        CopyFormat::Text => (
            CopyFormatParams::Text(CopyTextFormatParams::default()),
            Format::Text,
        ),
        CopyFormat::Binary => (CopyFormatParams::Binary, Format::Binary),
        CopyFormat::Csv => (
            CopyFormatParams::Csv(CopyCsvFormatParams::default()),
            Format::Text,
        ),
        CopyFormat::Parquet => {
            let text = "Parquet format is not supported".to_string();
            return self
                .error(ErrorResponse::error(SqlState::INTERNAL_ERROR, text.clone()))
                .await
                .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
        }
    };

    let encode_fn = |row: &RowRef, typ: &SqlRelationType, out: &mut Vec<u8>| {
        mz_pgcopy::encode_copy_format(&row_format, row, typ, out)
    };

    // Every column uses the same format as the overall COPY.
    let typ = row_desc.typ();
    let column_formats = iter::repeat(encode_format)
        .take(typ.column_types.len())
        .collect();
    self.send(BackendMessage::CopyOutResponse {
        overall_format: encode_format,
        column_formats,
    })
    .await?;

    let mut out = Vec::new();

    // PostgreSQL binary COPY header. These bytes stay in `out` and go out
    // with the first row's CopyData, or with the trailer if there are no rows.
    if let CopyFormat::Binary = format {
        // 11-byte signature.
        out.extend(b"PGCOPY\n\xFF\r\n\0");
        // 32-bit flags field: no flags set.
        out.extend([0, 0, 0, 0]);
        // 32-bit header-extension length: no extension.
        out.extend([0, 0, 0, 0]);
    }

    let mut count = 0;
    let mut total_sent_bytes = 0;
    loop {
        tokio::select! {
            e = self.conn.wait_closed() => return Err(e),
            batch = stream.recv() => match batch {
                None => break,
                Some(PeekResponseUnary::Error(text)) => {
                    return self
                        .error(ErrorResponse::error(SqlState::INTERNAL_ERROR, text.clone()))
                        .await
                        .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
                }
                Some(PeekResponseUnary::Canceled) => {
                    return self.error(ErrorResponse::error(
                            SqlState::QUERY_CANCELED,
                            "canceling statement due to user request",
                        ))
                        .await.map(|state| (state, SendRowsEndedReason::Canceled));
                }
                Some(PeekResponseUnary::Rows(mut rows)) => {
                    // NOTE(review): assumes `RowIterator::count` reports the
                    // batch's row total without consuming it. Confirm.
                    count += rows.count();
                    while let Some(row) = rows.next() {
                        total_sent_bytes += row.byte_len();
                        encode_fn(row, typ, &mut out)?;
                        // One CopyData message per encoded row.
                        self.send(BackendMessage::CopyData(mem::take(&mut out)))
                            .await?;
                    }
                }
            },
            notice = self.adapter_client.session().recv_notice() => {
                self.send(notice.into_response())
                    .await?;
                self.conn.flush().await?;
            }
        }

        self.conn.flush().await?;
    }
    // Binary COPY trailer: a 16-bit -1 in network byte order.
    if let CopyFormat::Binary = format {
        let trailer: i16 = -1;
        out.extend(trailer.to_be_bytes());
        self.send(BackendMessage::CopyData(mem::take(&mut out)))
            .await?;
    }

    let tag = format!("COPY {}", count);
    self.send(BackendMessage::CopyDone).await?;
    self.send(BackendMessage::CommandComplete { tag }).await?;
    Ok((
        State::Ready,
        SendRowsEndedReason::Success {
            result_size: u64::cast_from(total_sent_bytes),
            rows_returned: u64::cast_from(count),
        },
    ))
}
2493
2494 #[instrument(level = "debug")]
2497 async fn copy_from(
2498 &mut self,
2499 target_id: CatalogItemId,
2500 target_name: String,
2501 columns: Vec<ColumnIndex>,
2502 params: CopyFormatParams<'_>,
2503 row_desc: RelationDesc,
2504 mut ctx_extra: ExecuteContextExtra,
2505 ) -> Result<State, io::Error> {
2506 let res = self
2507 .copy_from_inner(
2508 target_id,
2509 target_name,
2510 columns,
2511 params,
2512 row_desc,
2513 &mut ctx_extra,
2514 )
2515 .await;
2516 match &res {
2517 Ok(State::Done) => {
2518 self.adapter_client
2522 .retire_execute(ctx_extra, StatementEndedExecutionReason::Canceled);
2523 }
2524 Err(e) => {
2525 self.adapter_client.retire_execute(
2526 ctx_extra,
2527 StatementEndedExecutionReason::Errored {
2528 error: format!("{e}"),
2529 },
2530 );
2531 }
2532 other => {
2533 tracing::warn!(?other, "aborting COPY FROM");
2534 self.adapter_client
2535 .retire_execute(ctx_extra, StatementEndedExecutionReason::Aborted);
2536 }
2537 }
2538 res
2539 }
2540
2541 async fn copy_from_inner(
2542 &mut self,
2543 target_id: CatalogItemId,
2544 target_name: String,
2545 columns: Vec<ColumnIndex>,
2546 params: CopyFormatParams<'_>,
2547 row_desc: RelationDesc,
2548 ctx_extra: &mut ExecuteContextExtra,
2549 ) -> Result<State, io::Error> {
2550 let typ = row_desc.typ();
2551 let column_formats = vec![Format::Text; typ.column_types.len()];
2552 self.send(BackendMessage::CopyInResponse {
2553 overall_format: Format::Text,
2554 column_formats,
2555 })
2556 .await?;
2557 self.conn.flush().await?;
2558
2559 let system_vars = self.adapter_client.get_system_vars().await;
2560 let max_size = system_vars
2561 .get(MAX_COPY_FROM_SIZE.name())
2562 .ok()
2563 .and_then(|max_size| max_size.value().parse().ok())
2564 .unwrap_or(usize::MAX);
2565 tracing::debug!("COPY FROM max buffer size: {max_size} bytes");
2566
2567 let mut data = Vec::new();
2568 loop {
2569 let message = self.conn.recv().await?;
2570 match message {
2571 Some(FrontendMessage::CopyData(buf)) => {
2572 if (data.len() + buf.len()) > max_size {
2574 return self
2575 .error(ErrorResponse::error(
2576 SqlState::INSUFFICIENT_RESOURCES,
2577 "COPY FROM STDIN too large",
2578 ))
2579 .await;
2580 }
2581 data.extend(buf)
2582 }
2583 Some(FrontendMessage::CopyDone) => break,
2584 Some(FrontendMessage::CopyFail(err)) => {
2585 self.adapter_client.retire_execute(
2586 std::mem::take(ctx_extra),
2587 StatementEndedExecutionReason::Canceled,
2588 );
2589 return self
2590 .error(ErrorResponse::error(
2591 SqlState::QUERY_CANCELED,
2592 format!("COPY from stdin failed: {}", err),
2593 ))
2594 .await;
2595 }
2596 Some(FrontendMessage::Flush) | Some(FrontendMessage::Sync) => {}
2597 Some(_) => {
2598 let msg = "unexpected message type during COPY from stdin";
2599 self.adapter_client.retire_execute(
2600 std::mem::take(ctx_extra),
2601 StatementEndedExecutionReason::Errored {
2602 error: msg.to_string(),
2603 },
2604 );
2605 return self
2606 .error(ErrorResponse::error(SqlState::PROTOCOL_VIOLATION, msg))
2607 .await;
2608 }
2609 None => {
2610 return Ok(State::Done);
2611 }
2612 }
2613 }
2614
2615 let column_types = typ
2616 .column_types
2617 .iter()
2618 .map(|x| &x.scalar_type)
2619 .map(mz_pgrepr::Type::from)
2620 .collect::<Vec<mz_pgrepr::Type>>();
2621
2622 let rows = match mz_pgcopy::decode_copy_format(&data, &column_types, params) {
2623 Ok(rows) => rows,
2624 Err(e) => {
2625 self.adapter_client.retire_execute(
2626 std::mem::take(ctx_extra),
2627 StatementEndedExecutionReason::Errored {
2628 error: e.to_string(),
2629 },
2630 );
2631 return self
2632 .error(ErrorResponse::error(
2633 SqlState::BAD_COPY_FILE_FORMAT,
2634 format!("{}", e),
2635 ))
2636 .await;
2637 }
2638 };
2639
2640 let count = rows.len();
2641
2642 if let Err(e) = self
2643 .adapter_client
2644 .insert_rows(
2645 target_id,
2646 target_name,
2647 columns,
2648 rows,
2649 std::mem::take(ctx_extra),
2650 )
2651 .await
2652 {
2653 self.adapter_client.retire_execute(
2654 std::mem::take(ctx_extra),
2655 StatementEndedExecutionReason::Errored {
2656 error: e.to_string(),
2657 },
2658 );
2659 return self.error(e.into_response(Severity::Error)).await;
2660 }
2661
2662 let tag = format!("COPY {}", count);
2663 self.send(BackendMessage::CommandComplete { tag }).await?;
2664
2665 Ok(State::Ready)
2666 }
2667
2668 #[instrument(level = "debug")]
2669 async fn send_pending_notices(&mut self) -> Result<(), io::Error> {
2670 let notices = self
2671 .adapter_client
2672 .session()
2673 .drain_notices()
2674 .into_iter()
2675 .map(|notice| BackendMessage::ErrorResponse(notice.into_response()));
2676 self.send_all(notices).await?;
2677 Ok(())
2678 }
2679
2680 #[instrument(level = "debug")]
2681 async fn error(&mut self, err: ErrorResponse) -> Result<State, io::Error> {
2682 assert!(err.severity.is_error());
2683 debug!(
2684 "cid={} error code={}",
2685 self.adapter_client.session().conn_id(),
2686 err.code.code()
2687 );
2688 let is_fatal = err.severity.is_fatal();
2689 self.send(BackendMessage::ErrorResponse(err)).await?;
2690
2691 let txn = self.adapter_client.session().transaction();
2692 match txn {
2693 TransactionStatus::Default | TransactionStatus::Failed(_) => {}
2696 TransactionStatus::Started(_) => {
2698 self.rollback_transaction().await?;
2699 }
2700 TransactionStatus::InTransactionImplicit(_) => {
2702 self.rollback_transaction().await?;
2703 }
2704 TransactionStatus::InTransaction(_) => {
2706 self.adapter_client.fail_transaction();
2707 }
2708 };
2709 if is_fatal {
2710 Ok(State::Done)
2711 } else {
2712 Ok(State::Drain)
2713 }
2714 }
2715
2716 #[instrument(level = "debug")]
2717 async fn aborted_txn_error(&mut self) -> Result<State, io::Error> {
2718 self.send(BackendMessage::ErrorResponse(ErrorResponse::error(
2719 SqlState::IN_FAILED_SQL_TRANSACTION,
2720 ABORTED_TXN_MSG,
2721 )))
2722 .await?;
2723 Ok(State::Drain)
2724 }
2725
2726 fn is_aborted_txn(&mut self) -> bool {
2727 matches!(
2728 self.adapter_client.session().transaction(),
2729 TransactionStatus::Failed(_)
2730 )
2731 }
2732}
2733
2734fn pad_formats(formats: Vec<Format>, n: usize) -> Result<Vec<Format>, String> {
2735 match (formats.len(), n) {
2736 (0, e) => Ok(vec![Format::Text; e]),
2737 (1, e) => Ok(iter::repeat(formats[0]).take(e).collect()),
2738 (a, e) if a == e => Ok(formats),
2739 (a, e) => Err(format!(
2740 "expected {} field format specifiers, but got {}",
2741 e, a
2742 )),
2743 }
2744}
2745
2746fn describe_rows(stmt_desc: &StatementDesc, formats: &[Format]) -> BackendMessage {
2747 match &stmt_desc.relation_desc {
2748 Some(desc) if !stmt_desc.is_copy => {
2749 BackendMessage::RowDescription(message::encode_row_description(desc, formats))
2750 }
2751 _ => BackendMessage::NoData,
2752 }
2753}
2754
2755type GetResponse = fn(
2756 max_rows: ExecuteCount,
2757 total_sent_rows: usize,
2758 fetch_portal: Option<PortalRefMut>,
2759) -> BackendMessage;
2760
2761fn portal_exec_message(
2764 max_rows: ExecuteCount,
2765 total_sent_rows: usize,
2766 _fetch_portal: Option<PortalRefMut>,
2767) -> BackendMessage {
2768 match max_rows {
2775 ExecuteCount::Count(max_rows) if max_rows <= total_sent_rows => {
2776 BackendMessage::PortalSuspended
2777 }
2778 _ => BackendMessage::CommandComplete {
2779 tag: format!("SELECT {}", total_sent_rows),
2780 },
2781 }
2782}
2783
2784fn fetch_message(
2786 _max_rows: ExecuteCount,
2787 total_sent_rows: usize,
2788 fetch_portal: Option<PortalRefMut>,
2789) -> BackendMessage {
2790 let tag = format!("FETCH {}", total_sent_rows);
2791 if let Some(portal) = fetch_portal {
2792 *portal.state = PortalState::Completed(Some(tag.clone()));
2793 }
2794 BackendMessage::CommandComplete { tag }
2795}
2796
/// Row limit for one batch of results.
///
/// `All` imposes no limit. `Count(n)` means the portal is suspended once `n`
/// rows have been sent (see `portal_exec_message`).
#[derive(Clone, Copy, Debug)]
enum ExecuteCount {
    /// No row limit.
    All,
    /// At most this many rows.
    Count(usize),
}
2802
2803fn is_txn_exit_stmt(stmt: Option<&Statement<Raw>>) -> bool {
2805 match stmt {
2806 Some(stmt) => matches!(stmt, Statement::Commit(_) | Statement::Rollback(_)),
2808 None => false,
2809 }
2810}
2811
/// Result of fetching the next item from a result stream.
///
/// NOTE(review): the variant descriptions below are inferred from their
/// payloads; confirm them against the code that produces `FetchResult`.
#[derive(Debug)]
enum FetchResult {
    /// A batch of rows. `None` presumably means the stream is exhausted; confirm.
    Rows(Option<Box<dyn RowIterator + Send + Sync>>),
    /// The fetch was canceled.
    Canceled,
    /// The fetch failed with the given message.
    Error(String),
    /// An adapter notice, presumably to be forwarded to the client; confirm.
    Notice(AdapterNotice),
}
2819
#[cfg(test)]
mod test {
    use super::*;

    #[mz_ore::test]
    fn test_parse_options() {
        struct TestCase {
            input: &'static str,
            expect: Result<Vec<(&'static str, &'static str)>, ()>,
        }
        let tests = vec![
            TestCase {
                input: "",
                expect: Ok(vec![]),
            },
            TestCase {
                input: "--key",
                expect: Err(()),
            },
            TestCase {
                input: "--key=val",
                expect: Ok(vec![("key", "val")]),
            },
            TestCase {
                input: r#"--key=val -ckey2=val2 -c key3=val3 -c key4=val4 -ckey5=val5"#,
                expect: Ok(vec![
                    ("key", "val"),
                    ("key2", "val2"),
                    ("key3", "val3"),
                    ("key4", "val4"),
                    ("key5", "val5"),
                ]),
            },
            TestCase {
                input: r#"-c\ key=val"#,
                expect: Ok(vec![(" key", "val")]),
            },
            TestCase {
                input: "--key=val -ckey2 val2",
                expect: Err(()),
            },
            TestCase {
                input: "--key=",
                expect: Ok(vec![("key", "")]),
            },
        ];
        for test in tests {
            let got = parse_options(test.input);
            let expect = test.expect.map(|r| {
                r.into_iter()
                    .map(|(k, v)| (k.to_owned(), v.to_owned()))
                    .collect()
            });
            assert_eq!(got, expect, "input: {}", test.input);
        }
    }

    #[mz_ore::test]
    fn test_parse_option() {
        struct TestCase {
            input: &'static str,
            expect: Result<(&'static str, &'static str), ()>,
        }
        let tests = vec![
            TestCase {
                input: "",
                expect: Err(()),
            },
            TestCase {
                input: "--",
                expect: Err(()),
            },
            TestCase {
                input: "--c",
                expect: Err(()),
            },
            TestCase {
                input: "a=b",
                expect: Err(()),
            },
            TestCase {
                input: "--a=b",
                expect: Ok(("a", "b")),
            },
            TestCase {
                input: "--ca=b",
                expect: Ok(("ca", "b")),
            },
            TestCase {
                input: "-ca=b",
                expect: Ok(("a", "b")),
            },
            TestCase {
                input: "--=",
                expect: Ok(("", "")),
            },
        ];
        for test in tests {
            let got = parse_option(test.input);
            assert_eq!(got, test.expect, "input: {}", test.input);
        }
    }

    #[mz_ore::test]
    fn test_split_options() {
        struct TestCase {
            input: &'static str,
            expect: Vec<&'static str>,
        }
        let tests = vec![
            TestCase {
                input: "",
                expect: vec![],
            },
            TestCase {
                input: " ",
                expect: vec![],
            },
            TestCase {
                input: " a ",
                expect: vec!["a"],
            },
            TestCase {
                input: " ab cd ",
                expect: vec!["ab", "cd"],
            },
            TestCase {
                input: r#" ab\ cd "#,
                expect: vec!["ab ", "cd"],
            },
            TestCase {
                input: r#" ab\\ cd "#,
                expect: vec![r#"ab\"#, "cd"],
            },
            TestCase {
                // `\\` is a literal backslash and `\ ` a literal space; the
                // second, unescaped space then ends the token.
                input: r#" ab\\\  cd "#,
                expect: vec![r#"ab\ "#, "cd"],
            },
            TestCase {
                // With only the escaped space, `cd` stays in the same token.
                input: r#" ab\\\ cd "#,
                expect: vec![r#"ab\ cd"#],
            },
            TestCase {
                input: r#" ab\\\cd "#,
                expect: vec![r#"ab\cd"#],
            },
            TestCase {
                input: r#"a\"#,
                expect: vec!["a"],
            },
            TestCase {
                input: r#"a\ "#,
                expect: vec!["a "],
            },
            TestCase {
                input: r#"\"#,
                expect: vec![],
            },
            TestCase {
                input: r#"\ "#,
                expect: vec![r#" "#],
            },
            TestCase {
                input: r#" \ "#,
                expect: vec![r#" "#],
            },
            TestCase {
                // An escaped space followed by unescaped trailing whitespace.
                input: r#"\  "#,
                expect: vec![r#" "#],
            },
        ];
        for test in tests {
            let got = split_options(test.input);
            assert_eq!(got, test.expect, "input: {}", test.input);
        }
    }
}