1use std::collections::BTreeMap;
11use std::convert::TryFrom;
12use std::future::Future;
13use std::ops::Deref;
14use std::sync::Arc;
15use std::time::{Duration, Instant};
16use std::{iter, mem};
17
18use base64::prelude::*;
19use byteorder::{ByteOrder, NetworkEndian};
20use futures::future::{BoxFuture, FutureExt, pending};
21use itertools::Itertools;
22use mz_adapter::client::RecordFirstRowStream;
23use mz_adapter::session::{
24 EndTransactionAction, InProgressRows, LifecycleTimestamps, PortalRefMut, PortalState,
25 SessionConfig, TransactionStatus,
26};
27use mz_adapter::statement_logging::{StatementEndedExecutionReason, StatementExecutionStrategy};
28use mz_adapter::{
29 AdapterError, AdapterNotice, ExecuteContextExtra, ExecuteResponse, PeekResponseUnary, metrics,
30 verify_datum_desc,
31};
32use mz_auth::password::Password;
33use mz_authenticator::Authenticator;
34use mz_ore::cast::CastFrom;
35use mz_ore::netio::AsyncReady;
36use mz_ore::now::{EpochMillis, SYSTEM_TIME};
37use mz_ore::str::StrExt;
38use mz_ore::{assert_none, assert_ok, instrument, soft_assert_eq_or_log};
39use mz_pgcopy::{CopyCsvFormatParams, CopyFormatParams, CopyTextFormatParams};
40use mz_pgwire_common::{
41 ConnectionCounter, Cursor, ErrorResponse, Format, FrontendMessage, Severity, VERSION_3,
42 VERSIONS,
43};
44use mz_repr::user::InternalUserMetadata;
45use mz_repr::{
46 CatalogItemId, ColumnIndex, Datum, RelationDesc, RowArena, RowIterator, RowRef,
47 SqlRelationType, SqlScalarType,
48};
49use mz_server_core::TlsMode;
50use mz_server_core::listeners::AllowedRoles;
51use mz_sql::ast::display::AstDisplay;
52use mz_sql::ast::{CopyDirection, CopyStatement, FetchDirection, Ident, Raw, Statement};
53use mz_sql::parse::StatementParseResult;
54use mz_sql::plan::{CopyFormat, ExecuteTimeout, StatementDesc};
55use mz_sql::session::metadata::SessionMetadata;
56use mz_sql::session::user::INTERNAL_USER_NAMES;
57use mz_sql::session::vars::{MAX_COPY_FROM_SIZE, Var, VarInput};
58use postgres::error::SqlState;
59use tokio::io::{self, AsyncRead, AsyncWrite};
60use tokio::select;
61use tokio::time::{self};
62use tokio_metrics::TaskMetrics;
63use tokio_stream::wrappers::UnboundedReceiverStream;
64use tracing::{Instrument, debug, debug_span, warn};
65use uuid::Uuid;
66
67use crate::codec::{
68 FramedConn, decode_password, decode_sasl_initial_response, decode_sasl_response,
69};
70use crate::message::{
71 self, BackendMessage, SASLServerFinalMessage, SASLServerFinalMessageKinds,
72 SASLServerFirstMessage,
73};
74
75pub fn match_handshake(buf: &[u8]) -> bool {
79 if buf.len() < 8 {
89 return false;
90 }
91 let version = NetworkEndian::read_i32(&buf[4..8]);
92 VERSIONS.contains(&version)
93}
94
/// Inputs to [`run`] for a single pgwire connection.
pub struct RunParams<'a, A, I>
where
    I: Iterator<Item = TaskMetrics> + Send,
{
    /// TLS mode configured for the listener. `run` checks the connection
    /// against it with `ensure_tls_compatibility`.
    pub tls_mode: Option<TlsMode>,
    /// Adapter client used to create the session, read system variables, and
    /// start the session up.
    pub adapter_client: mz_adapter::Client,
    /// The framed connection to the client.
    pub conn: &'a mut FramedConn<A>,
    /// UUID of this connection, recorded on the session.
    pub conn_uuid: Uuid,
    /// Protocol version requested by the client. Only `VERSION_3` is served.
    pub version: i32,
    /// Startup parameters sent by the client. `user` is consumed as the login
    /// role, `options` is parsed as command-line style settings, and every
    /// other entry is applied as a session variable.
    pub params: BTreeMap<String, String>,
    /// Selects the authentication flow the client must complete.
    pub authenticator: Authenticator,
    /// Counter from which a connection slot is allocated for the session's
    /// user. An allocation failure rejects the login.
    pub active_connection_counter: ConnectionCounter,
    /// Helm chart version recorded on the session, if any.
    pub helm_chart_version: Option<String>,
    /// Which classes of roles may log in through this listener.
    pub allowed_roles: AllowedRoles,
    /// Tokio task-metric intervals. The state machine calls
    /// `next().expect("infinite iterator")`, so this iterator must never end.
    pub tokio_metrics_intervals: I,
}
123
/// Serves one pgwire connection from startup until the session ends.
///
/// Startup proceeds in this order:
/// 1. Reject protocol versions other than `VERSION_3`.
/// 2. Check the login role against `allowed_roles`.
/// 3. Check TLS compatibility.
/// 4. Authenticate according to `authenticator` and create a session.
/// 5. Apply the remaining startup parameters (including settings parsed from
///    `options`) as session variables. Failures become `BadStartupSetting`
///    notices, not errors.
/// 6. Allocate a connection slot and start the adapter session.
/// 7. Send `AuthenticationOk`, parameter statuses, `BackendKeyData`, pending
///    notices and `ReadyForQuery`.
///
/// The connection is then driven by [`StateMachine`] until it finishes or
/// (for Frontegg sessions) authentication expires.
///
/// Most rejections are sent to the client as fatal `ErrorResponse`s, and the
/// result of sending that message is returned. Returned errors are I/O errors
/// from `conn` or from the state machine.
#[mz_ore::instrument(level = "debug")]
pub async fn run<'a, A, I>(
    RunParams {
        tls_mode,
        adapter_client,
        conn,
        conn_uuid,
        version,
        mut params,
        authenticator,
        active_connection_counter,
        helm_chart_version,
        allowed_roles,
        tokio_metrics_intervals,
    }: RunParams<'a, A, I>,
) -> Result<(), io::Error>
where
    A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
    I: Iterator<Item = TaskMetrics> + Send,
{
    // Only protocol version 3 is served.
    if version != VERSION_3 {
        return conn
            .send(ErrorResponse::fatal(
                SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
                "server does not support the client's requested protocol version",
            ))
            .await;
    }

    // A missing `user` startup parameter becomes the empty string.
    let user = params.remove("user").unwrap_or_else(String::new);

    // Decide whether this listener accepts the requested role. `Normal`
    // rejects reserved roles, `Internal` accepts only internal users, and
    // `NormalAndInternal` accepts non-reserved roles plus internal users.
    let is_internal_user = INTERNAL_USER_NAMES.contains(&user);
    let is_reserved_user = mz_adapter::catalog::is_reserved_role_name(user.as_str());
    let role_allowed = match allowed_roles {
        AllowedRoles::Normal => !is_reserved_user,
        AllowedRoles::Internal => is_internal_user,
        AllowedRoles::NormalAndInternal => !is_reserved_user || is_internal_user,
    };
    if !role_allowed {
        let msg = format!("unauthorized login to user '{user}'");
        return conn
            .send(ErrorResponse::fatal(SqlState::INSUFFICIENT_PRIVILEGE, msg))
            .await;
    }

    if let Err(err) = conn.inner().ensure_tls_compatibility(&tls_mode) {
        return conn.send(err).await;
    }

    // Authenticate and build the session. `expired` is a future that resolves
    // when authentication should be treated as expired. Only the Frontegg flow
    // supplies a real one. The other flows use `pending()`, which never
    // completes.
    let (mut session, expired) = match authenticator {
        Authenticator::Frontegg(frontegg) => {
            // Cleartext password exchange, verified by Frontegg.
            conn.send(BackendMessage::AuthenticationCleartextPassword)
                .await?;
            conn.flush().await?;
            let password = match conn.recv().await? {
                Some(FrontendMessage::RawAuthentication(data)) => {
                    match decode_password(Cursor::new(&data)).ok() {
                        Some(FrontendMessage::Password { password }) => password,
                        _ => {
                            return conn
                                .send(ErrorResponse::fatal(
                                    SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                                    "expected Password message",
                                ))
                                .await;
                        }
                    }
                }
                _ => {
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                            "expected Password message",
                        ))
                        .await;
                }
            };

            let auth_response = frontegg.authenticate(&user, &password).await;
            match auth_response {
                Ok(mut auth_session) => {
                    // The session user comes from the authenticated session,
                    // not directly from the `user` startup parameter.
                    let session = adapter_client.new_session(SessionConfig {
                        conn_id: conn.conn_id().clone(),
                        uuid: conn_uuid,
                        user: auth_session.user().into(),
                        client_ip: conn.peer_addr().clone(),
                        external_metadata_rx: Some(auth_session.external_metadata_rx()),
                        internal_user_metadata: None,
                        helm_chart_version,
                    });
                    let expired = async move { auth_session.expired().await };
                    (session, expired.left_future())
                }
                Err(err) => {
                    warn!(?err, "pgwire connection failed authentication");
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_PASSWORD,
                            "invalid password",
                        ))
                        .await;
                }
            }
        }
        Authenticator::Password(adapter_client) => {
            // Cleartext password exchange, verified by the adapter. Within this
            // arm `adapter_client` is the authenticator's client, shadowing the
            // outer one.
            conn.send(BackendMessage::AuthenticationCleartextPassword)
                .await?;
            conn.flush().await?;
            let password = match conn.recv().await? {
                Some(FrontendMessage::RawAuthentication(data)) => {
                    match decode_password(Cursor::new(&data)).ok() {
                        Some(FrontendMessage::Password { password }) => Password(password),
                        _ => {
                            return conn
                                .send(ErrorResponse::fatal(
                                    SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                                    "expected Password message",
                                ))
                                .await;
                        }
                    }
                }
                _ => {
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                            "expected Password message",
                        ))
                        .await;
                }
            };
            let auth_response = match adapter_client.authenticate(&user, &password).await {
                Ok(resp) => resp,
                Err(err) => {
                    warn!(?err, "pgwire connection failed authentication");
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_PASSWORD,
                            "invalid password",
                        ))
                        .await;
                }
            };
            let session = adapter_client.new_session(SessionConfig {
                conn_id: conn.conn_id().clone(),
                uuid: conn_uuid,
                user,
                client_ip: conn.peer_addr().clone(),
                external_metadata_rx: None,
                internal_user_metadata: Some(InternalUserMetadata {
                    superuser: auth_response.superuser,
                }),
                helm_chart_version,
            });
            // Password sessions never expire.
            let auth_session = pending().right_future();
            (session, auth_session)
        }
        Authenticator::Sasl(adapter_client) => {
            // SCRAM-SHA-256 exchange over SASL. Within this arm
            // `adapter_client` is the authenticator's client, shadowing the
            // outer one.
            conn.send(BackendMessage::AuthenticationSASL).await?;
            conn.flush().await?;
            let (mechanism, initial_response) = match conn.recv().await? {
                Some(FrontendMessage::RawAuthentication(data)) => {
                    match decode_sasl_initial_response(Cursor::new(&data)).ok() {
                        Some(FrontendMessage::SASLInitialResponse {
                            gs2_header,
                            mechanism,
                            initial_response,
                        }) => {
                            // Channel binding is not supported.
                            if gs2_header.channel_binding_enabled() {
                                return conn
                                    .send(ErrorResponse::fatal(
                                        SqlState::PROTOCOL_VIOLATION,
                                        "channel binding not supported",
                                    ))
                                    .await;
                            }
                            (mechanism, initial_response)
                        }
                        _ => {
                            return conn
                                .send(ErrorResponse::fatal(
                                    SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                                    "expected SASLInitialResponse message",
                                ))
                                .await;
                        }
                    }
                }
                _ => {
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                            "expected SASLInitialResponse message",
                        ))
                        .await;
                }
            };

            if mechanism != "SCRAM-SHA-256" {
                return conn
                    .send(ErrorResponse::fatal(
                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                        "unsupported SASL mechanism",
                    ))
                    .await;
            }

            // Bound the client-supplied nonce before passing it on.
            if initial_response.nonce.len() > 256 {
                return conn
                    .send(ErrorResponse::fatal(
                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                        "nonce too long",
                    ))
                    .await;
            }

            let (server_first_message_raw, mock_hash) = match adapter_client
                .generate_sasl_challenge(&user, &initial_response.nonce)
                .await
            {
                Ok(response) => {
                    // Keep the exact server-first-message text. It becomes part
                    // of the auth message verified below.
                    let server_first_message_raw = format!(
                        "r={},s={},i={}",
                        response.nonce, response.salt, response.iteration_count
                    );

                    // NOTE(review): fixed placeholder keys, formatted like a
                    // SCRAM hash and passed to `verify_sasl_proof` as
                    // `mock_hash`. Presumably used when the role has no real
                    // hash, so unknown users cannot be distinguished. Confirm in
                    // the adapter.
                    let client_key = [0u8; 32];
                    let server_key = [1u8; 32];
                    let mock_hash = format!(
                        "SCRAM-SHA-256${}:{}${}:{}",
                        response.iteration_count,
                        response.salt,
                        BASE64_STANDARD.encode(client_key),
                        BASE64_STANDARD.encode(server_key)
                    );

                    conn.send(BackendMessage::AuthenticationSASLContinue(
                        SASLServerFirstMessage {
                            iteration_count: response.iteration_count,
                            nonce: response.nonce,
                            salt: response.salt,
                        },
                    ))
                    .await?;
                    conn.flush().await?;
                    (server_first_message_raw, mock_hash)
                }
                Err(e) => {
                    return conn.send(e.into_response(Severity::Fatal)).await;
                }
            };

            let auth_resp = match conn.recv().await? {
                Some(FrontendMessage::RawAuthentication(data)) => {
                    match decode_sasl_response(Cursor::new(&data)).ok() {
                        Some(FrontendMessage::SASLResponse(response)) => {
                            // The three raw SCRAM messages joined by commas,
                            // mirroring RFC 5802's AuthMessage.
                            let auth_message = format!(
                                "{},{},{}",
                                initial_response.client_first_message_bare_raw,
                                server_first_message_raw,
                                response.client_final_message_bare_raw
                            );
                            // Bound the client-supplied proof before verifying it.
                            if response.proof.len() > 1024 {
                                return conn
                                    .send(ErrorResponse::fatal(
                                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                                        "proof too long",
                                    ))
                                    .await;
                            }
                            match adapter_client
                                .verify_sasl_proof(
                                    &user,
                                    &response.proof,
                                    &auth_message,
                                    &mock_hash,
                                )
                                .await
                            {
                                Ok(resp) => {
                                    conn.send(BackendMessage::AuthenticationSASLFinal(
                                        SASLServerFinalMessage {
                                            kind: SASLServerFinalMessageKinds::Verifier(
                                                resp.verifier,
                                            ),
                                            extensions: vec![],
                                        },
                                    ))
                                    .await?;
                                    conn.flush().await?;
                                    resp.auth_resp
                                }
                                Err(_) => {
                                    return conn
                                        .send(ErrorResponse::fatal(
                                            SqlState::INVALID_PASSWORD,
                                            "invalid password",
                                        ))
                                        .await;
                                }
                            }
                        }
                        _ => {
                            return conn
                                .send(ErrorResponse::fatal(
                                    SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                                    "expected SASLResponse message",
                                ))
                                .await;
                        }
                    }
                }
                _ => {
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                            "expected SASLResponse message",
                        ))
                        .await;
                }
            };

            let session = adapter_client.new_session(SessionConfig {
                conn_id: conn.conn_id().clone(),
                uuid: conn_uuid,
                user,
                client_ip: conn.peer_addr().clone(),
                external_metadata_rx: None,
                internal_user_metadata: Some(InternalUserMetadata {
                    superuser: auth_resp.superuser,
                }),
                helm_chart_version,
            });
            // SASL sessions never expire.
            let auth_session = pending().right_future();
            (session, auth_session)
        }
        Authenticator::None => {
            // No authentication. The `user` startup parameter is trusted as-is.
            let session = adapter_client.new_session(SessionConfig {
                conn_id: conn.conn_id().clone(),
                uuid: conn_uuid,
                user,
                client_ip: conn.peer_addr().clone(),
                external_metadata_rx: None,
                internal_user_metadata: None,
                helm_chart_version,
            });
            let auth_session = pending().right_future();
            (session, auth_session)
        }
    };

    // Apply the remaining startup parameters. `options` is expanded into
    // individual settings. A parse failure or a rejected setting becomes a
    // notice for the client instead of failing the connection.
    let system_vars = adapter_client.get_system_vars().await;
    for (name, value) in params {
        let settings = match name.as_str() {
            "options" => match parse_options(&value) {
                Ok(opts) => opts,
                Err(()) => {
                    session.add_notice(AdapterNotice::BadStartupSetting {
                        name,
                        reason: "could not parse".into(),
                    });
                    continue;
                }
            },
            _ => vec![(name, value)],
        };
        for (key, val) in settings {
            // Passed to `set` as its transaction-local flag (presumably:
            // session-level, not `SET LOCAL`).
            const LOCAL: bool = false;
            if let Err(err) =
                session
                    .vars_mut()
                    .set(&system_vars, &key, VarInput::Flat(&val), LOCAL)
            {
                session.add_notice(AdapterNotice::BadStartupSetting {
                    name: key,
                    reason: err.to_string(),
                });
            }
        }
    }
    // Commit the variable changes made above.
    session
        .vars_mut()
        .end_transaction(EndTransactionAction::Commit);

    // `_guard` keeps the allocated connection slot alive for the rest of this
    // function.
    let _guard = match active_connection_counter.allocate_connection(session.user()) {
        Ok(drop_connection) => drop_connection,
        Err(e) => {
            let e: AdapterError = e.into();
            return conn.send(e.into_response(Severity::Fatal)).await;
        }
    };

    let mut adapter_client = match adapter_client.startup(session).await {
        Ok(adapter_client) => adapter_client,
        Err(e) => return conn.send(e.into_response(Severity::Fatal)).await,
    };

    // Finish the startup handshake in a single write.
    let mut buf = vec![BackendMessage::AuthenticationOk];
    for var in adapter_client.session().vars().notify_set() {
        buf.push(BackendMessage::ParameterStatus(var.name(), var.value()));
    }
    buf.push(BackendMessage::BackendKeyData {
        conn_id: adapter_client.session().conn_id().unhandled(),
        secret_key: adapter_client.session().secret_key(),
    });
    buf.extend(
        adapter_client
            .session()
            .drain_notices()
            .into_iter()
            .map(|notice| BackendMessage::ErrorResponse(notice.into_response())),
    );
    buf.push(BackendMessage::ReadyForQuery(
        adapter_client.session().transaction().into(),
    ));
    conn.send_all(buf).await?;
    conn.flush().await?;

    let machine = StateMachine {
        conn,
        adapter_client,
        txn_needs_commit: false,
        tokio_metrics_intervals,
    };

    // Run the message loop until it ends or authentication expires.
    select! {
        r = machine.run() => {
            // On failure, try (best effort: send/flush errors are ignored) to
            // tell the client why the connection is closing, then return the
            // original error.
            if let Err(err) = &r {
                let _ = conn
                    .send(ErrorResponse::fatal(
                        SqlState::CONNECTION_FAILURE,
                        err.to_string(),
                    ))
                    .await;
                let _ = conn.flush().await;
            }
            r
        },
        _ = expired => {
            conn
                .send(ErrorResponse::fatal(SqlState::INVALID_AUTHORIZATION_SPECIFICATION, "authentication expired"))
                .await?;
            conn.flush().await
        }
    }
}
603
604fn parse_options(value: &str) -> Result<Vec<(String, String)>, ()> {
609 let opts = split_options(value);
610 let mut pairs = Vec::with_capacity(opts.len());
611 let mut seen_prefix = false;
612 for opt in opts {
613 if !seen_prefix {
614 if opt == "-c" {
615 seen_prefix = true;
616 } else {
617 let (key, val) = parse_option(&opt)?;
618 pairs.push((key.to_owned(), val.to_owned()));
619 }
620 } else {
621 let (key, val) = opt.split_once('=').ok_or(())?;
622 pairs.push((key.to_owned(), val.to_owned()));
623 seen_prefix = false;
624 }
625 }
626 Ok(pairs)
627}
628
/// Splits a single `-cname=value` or `--name=value` token into `(name, value)`.
///
/// The token is split at its first `=`. The part before it must start with
/// `-c` or `--`, and that prefix is removed. Returns `Err(())` otherwise.
fn parse_option(option: &str) -> Result<(&str, &str), ()> {
    let (raw_key, value) = option.split_once('=').ok_or(())?;
    ["-c", "--"]
        .iter()
        .find_map(|&prefix| raw_key.strip_prefix(prefix))
        .map(|key| (key, value))
        .ok_or(())
}
641
/// Splits an `options` string into tokens separated by unescaped spaces.
///
/// A backslash escapes the next character, which is kept literally. `\ ` is a
/// space inside a token, and `\\` is a backslash. A trailing lone backslash is
/// dropped. Runs of separators never produce empty tokens.
fn split_options(value: &str) -> Vec<String> {
    let mut tokens = Vec::new();
    let mut token = String::new();
    let mut escaped = false;
    for ch in value.chars() {
        match (ch, escaped) {
            ('\\', false) => {
                escaped = true;
                continue;
            }
            (' ', false) => {
                if !token.is_empty() {
                    tokens.push(std::mem::take(&mut token));
                }
            }
            _ => token.push(ch),
        }
        escaped = false;
    }
    if !token.is_empty() {
        tokens.push(token);
    }
    tokens
}
684
/// Position of the per-connection message loop (see `StateMachine::run`).
#[derive(Debug)]
enum State {
    /// Read and handle the next frontend message normally (`advance_ready`).
    Ready,
    /// Discard frontend messages until the next `Sync` (`advance_drain`).
    Drain,
    /// Stop the message loop.
    Done,
}
691
/// Drives the post-startup pgwire message loop for one connection.
struct StateMachine<'a, A, I>
where
    I: Iterator<Item = TaskMetrics> + Send + 'a,
{
    /// Connection to the client.
    conn: &'a mut FramedConn<A>,
    /// Adapter client for this connection's session.
    adapter_client: mz_adapter::SessionClient,
    /// Set after an `Execute` that ran in an implicit transaction.
    /// `ensure_transaction` commits that transaction before starting another.
    txn_needs_commit: bool,
    /// Tokio task-metric intervals. Assumed to be infinite.
    tokio_metrics_intervals: I,
}
701
/// How sending a statement's result rows ended.
// NOTE(review): constructed outside this view; field semantics below are
// inferred from names — confirm.
enum SendRowsEndedReason {
    /// All requested rows were sent.
    Success {
        /// Size of the result (presumably in bytes).
        result_size: u64,
        /// Number of rows returned to the client.
        rows_returned: u64,
    },
    /// Sending stopped because of an error.
    Errored {
        /// The error message.
        error: String,
    },
    /// Sending stopped because the statement was canceled.
    Canceled,
}
712
/// Error text for statements rejected because the current transaction is
/// aborted and the statement does not exit the transaction.
const ABORTED_TXN_MSG: &str =
    "current transaction is aborted, commands ignored until end of transaction block";
715
716impl<'a, A, I> StateMachine<'a, A, I>
717where
718 A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin + 'a,
719 I: Iterator<Item = TaskMetrics> + Send + 'a,
720{
721 #[allow(clippy::manual_async_fn)]
725 #[mz_ore::instrument(level = "debug")]
726 fn run(mut self) -> impl Future<Output = Result<(), io::Error>> + Send + 'a {
727 async move {
728 let mut state = State::Ready;
729 loop {
730 self.send_pending_notices().await?;
731 state = match state {
732 State::Ready => self.advance_ready().await?,
733 State::Drain => self.advance_drain().await?,
734 State::Done => return Ok(()),
735 };
736 self.adapter_client
737 .add_idle_in_transaction_session_timeout();
738 }
739 }
740 }
741
    /// Handles one frontend message in [`State::Ready`] and returns the next
    /// state.
    ///
    /// While waiting for a message, this also watches for a session timeout
    /// from the adapter. On timeout it:
    /// 1. builds the error state via `error`,
    /// 2. terminates the session,
    /// 3. awaits (and discards) one more client message,
    /// 4. returns the error state.
    ///
    /// Per-message processing time and receive scheduling delay are recorded
    /// as metrics.
    #[instrument(level = "debug")]
    async fn advance_ready(&mut self) -> Result<State, io::Error> {
        // Discard the current metrics interval so the interval read after
        // `recv` below covers only the wait for the next message.
        self.tokio_metrics_intervals
            .next()
            .expect("infinite iterator");

        let message = select! {
            // Poll branches in order, so a pending timeout wins over a ready
            // message.
            biased;

            Some(timeout) = self.adapter_client.recv_timeout() => {
                let err: AdapterError = timeout.into();
                let conn_id = self.adapter_client.session().conn_id();
                tracing::warn!("session timed out, conn_id {}", conn_id);

                // Build the error state before terminating the session.
                let error_response = err.into_response(Severity::Fatal);
                let error_state = self.error(error_response).await;

                self.adapter_client.terminate().await;

                // NOTE(review): presumably waits for a client message because
                // the error can only be delivered in response to one. Confirm.
                let _ = self.conn.recv().await?;
                return error_state;
            },
            message = self.conn.recv() => message?,
        };

        // Metrics interval covering the `recv` wait above. Its total scheduled
        // duration is reported as the receive scheduling delay.
        let interval = self
            .tokio_metrics_intervals
            .next()
            .expect("infinite iterator");
        let recv_scheduling_delay_ms = interval.total_scheduled_duration.as_secs_f64() * 1000.0;

        let received = SYSTEM_TIME();

        // A message arrived (or the client went away), so the session is no
        // longer idle.
        self.adapter_client
            .remove_idle_in_transaction_session_timeout();

        // Metric label for this message; empty when the client disconnected.
        let message_name = message.as_ref().map(|m| m.name()).unwrap_or_default();

        let start = message.as_ref().map(|_| Instant::now());
        let next_state = match message {
            Some(FrontendMessage::Query { sql }) => {
                // Each query gets its own root span, linked to the current span.
                let query_root_span =
                    tracing::info_span!(parent: None, "advance_ready", otel.name = message_name);
                query_root_span.follows_from(tracing::Span::current());
                self.query(sql, received)
                    .instrument(query_root_span)
                    .await?
            }
            Some(FrontendMessage::Parse {
                name,
                sql,
                param_types,
            }) => self.parse(name, sql, param_types).await?,
            Some(FrontendMessage::Bind {
                portal_name,
                statement_name,
                param_formats,
                raw_params,
                result_formats,
            }) => {
                self.bind(
                    portal_name,
                    statement_name,
                    param_formats,
                    raw_params,
                    result_formats,
                )
                .await?
            }
            Some(FrontendMessage::Execute {
                portal_name,
                max_rows,
            }) => {
                // A `max_rows` of 0, or one that does not fit in `usize`
                // (e.g. negative), means no row limit.
                let max_rows = match usize::try_from(max_rows) {
                    Ok(0) | Err(_) => ExecuteCount::All,
                    Ok(n) => ExecuteCount::Count(n),
                };
                let execute_root_span =
                    tracing::info_span!(parent: None, "advance_ready", otel.name = message_name);
                execute_root_span.follows_from(tracing::Span::current());
                let state = self
                    .execute(
                        portal_name,
                        max_rows,
                        portal_exec_message,
                        None,
                        ExecuteTimeout::None,
                        None,
                        Some(received),
                    )
                    .instrument(execute_root_span)
                    .await?;
                // Defer committing an implicit transaction. `ensure_transaction`
                // commits it when the next transaction is needed.
                if self.adapter_client.session().transaction().is_implicit() {
                    self.txn_needs_commit = true;
                }
                state
            }
            Some(FrontendMessage::DescribeStatement { name }) => {
                self.describe_statement(&name).await?
            }
            Some(FrontendMessage::DescribePortal { name }) => self.describe_portal(&name).await?,
            Some(FrontendMessage::CloseStatement { name }) => self.close_statement(name).await?,
            Some(FrontendMessage::ClosePortal { name }) => self.close_portal(name).await?,
            Some(FrontendMessage::Flush) => self.flush().await?,
            Some(FrontendMessage::Sync) => self.sync().await?,
            Some(FrontendMessage::Terminate) => State::Done,

            // Messages not expected in this state: skip input until the next
            // `Sync`.
            Some(FrontendMessage::CopyData(_))
            | Some(FrontendMessage::CopyDone)
            | Some(FrontendMessage::CopyFail(_))
            | Some(FrontendMessage::Password { .. })
            | Some(FrontendMessage::RawAuthentication(_))
            | Some(FrontendMessage::SASLInitialResponse { .. })
            | Some(FrontendMessage::SASLResponse(_)) => State::Drain,
            // The connection yielded no message (client closed it).
            None => State::Done,
        };

        if let Some(start) = start {
            self.adapter_client
                .inner()
                .metrics()
                .pgwire_message_processing_seconds
                .with_label_values(&[message_name])
                .observe(start.elapsed().as_secs_f64());
        }
        self.adapter_client
            .inner()
            .metrics()
            .pgwire_recv_scheduling_delay_ms
            .with_label_values(&[message_name])
            .observe(recv_scheduling_delay_ms);

        Ok(next_state)
    }
905
906 async fn advance_drain(&mut self) -> Result<State, io::Error> {
907 let message = self.conn.recv().await?;
908 if message.is_some() {
909 self.adapter_client
910 .remove_idle_in_transaction_session_timeout();
911 }
912 match message {
913 Some(FrontendMessage::Sync) => self.sync().await,
914 None => Ok(State::Done),
915 _ => Ok(State::Drain),
916 }
917 }
918
    /// Executes one statement from a simple-protocol `Query` message through
    /// the unnamed portal.
    ///
    /// Steps:
    /// 1. Declare `stmt` into the unnamed portal and stamp it with
    ///    `lifecycle_timestamps`.
    /// 2. Reject statements that reference parameters, since none can be
    ///    supplied here.
    /// 3. For non-COPY statements that return rows, send a `RowDescription`
    ///    using text format for every column.
    /// 4. Execute the portal and send its response.
    /// 5. Remove the unnamed portal.
    ///
    /// Errors are reported through `error`, and its resulting state is
    /// returned.
    #[instrument(level = "debug")]
    async fn one_query(
        &mut self,
        stmt: Statement<Raw>,
        sql: String,
        lifecycle_timestamps: LifecycleTimestamps,
    ) -> Result<State, io::Error> {
        // Simple-protocol statements always use the unnamed portal.
        const EMPTY_PORTAL: &str = "";
        if let Err(e) = self
            .adapter_client
            .declare(EMPTY_PORTAL.to_string(), stmt, sql)
            .await
        {
            return self.error(e.into_response(Severity::Error)).await;
        }
        let portal = self
            .adapter_client
            .session()
            .get_portal_unverified_mut(EMPTY_PORTAL)
            .expect("unnamed portal should be present");

        *portal.lifecycle_timestamps = Some(lifecycle_timestamps);

        let stmt_desc = portal.desc.clone();
        if !stmt_desc.param_types.is_empty() {
            return self
                .error(ErrorResponse::error(
                    SqlState::UNDEFINED_PARAMETER,
                    "there is no parameter $1",
                ))
                .await;
        }

        // COPY statements get no `RowDescription` here.
        if let Some(relation_desc) = &stmt_desc.relation_desc {
            if !stmt_desc.is_copy {
                let formats = vec![Format::Text; stmt_desc.arity()];
                self.send(BackendMessage::RowDescription(
                    message::encode_row_description(relation_desc, &formats),
                ))
                .await?;
            }
        }

        let result = match self
            .adapter_client
            .execute(EMPTY_PORTAL.to_string(), self.conn.wait_closed(), None)
            .await
        {
            Ok((response, execute_started)) => {
                self.send_pending_notices().await?;
                self.send_execute_response(
                    response,
                    stmt_desc.relation_desc,
                    EMPTY_PORTAL.to_string(),
                    ExecuteCount::All,
                    portal_exec_message,
                    None,
                    ExecuteTimeout::None,
                    execute_started,
                )
                .await
            }
            Err(e) => {
                self.send_pending_notices().await?;
                self.error(e.into_response(Severity::Error)).await
            }
        };

        // The unnamed portal does not outlive this statement.
        self.adapter_client.session().remove_portal(EMPTY_PORTAL);

        result
    }
998
999 async fn ensure_transaction(
1000 &mut self,
1001 num_stmts: usize,
1002 message_type: &str,
1003 ) -> Result<(), io::Error> {
1004 let start = Instant::now();
1005 if self.txn_needs_commit {
1006 self.commit_transaction().await?;
1007 }
1008 let res = self.adapter_client.start_transaction(Some(num_stmts));
1011 assert_ok!(res);
1012 self.adapter_client
1013 .inner()
1014 .metrics()
1015 .pgwire_ensure_transaction_seconds
1016 .with_label_values(&[message_type])
1017 .observe(start.elapsed().as_secs_f64());
1018 Ok(())
1019 }
1020
1021 fn parse_sql<'b>(&self, sql: &'b str) -> Result<Vec<StatementParseResult<'b>>, ErrorResponse> {
1022 let parse_start = Instant::now();
1023 let result = match self.adapter_client.parse(sql) {
1024 Ok(result) => result.map_err(|e| {
1025 let pos = sql[..e.error.pos].chars().count() + 1;
1028 ErrorResponse::error(SqlState::SYNTAX_ERROR, e.error.message).with_position(pos)
1029 }),
1030 Err(msg) => Err(ErrorResponse::error(SqlState::PROGRAM_LIMIT_EXCEEDED, msg)),
1031 };
1032 self.adapter_client
1033 .inner()
1034 .metrics()
1035 .parse_seconds
1036 .observe(parse_start.elapsed().as_secs_f64());
1037 result
1038 }
1039
    /// Handles a simple-protocol `Query` message containing zero or more
    /// statements.
    ///
    /// The whole message is parsed first; a parse error is reported and
    /// nothing runs. Statements then run in order via `one_query`, each inside
    /// a transaction opened for `num_stmts` statements. An implicit transaction
    /// is committed afterwards. If the message contained no statements,
    /// `EmptyQueryResponse` is sent.
    ///
    /// Finishes with `ready`, unless a statement returned [`State::Done`].
    #[instrument(level = "debug")]
    async fn query(&mut self, sql: String, received: EpochMillis) -> Result<State, io::Error> {
        let stmts = match self.parse_sql(&sql) {
            Ok(stmts) => stmts,
            Err(err) => {
                self.error(err).await?;
                return self.ready().await;
            }
        };

        let num_stmts = stmts.len();

        for StatementParseResult { ast: stmt, sql } in stmts {
            // In an aborted transaction only transaction-exit statements may
            // run. Anything else reports the error and stops this message.
            if self.is_aborted_txn() && !is_txn_exit_stmt(Some(&stmt)) {
                self.aborted_txn_error().await?;
                break;
            }

            self.ensure_transaction(num_stmts, "query").await?;

            match self
                .one_query(stmt, sql.to_string(), LifecycleTimestamps { received })
                .await?
            {
                State::Ready => (),
                // Skip the remaining statements of this message.
                State::Drain => break,
                State::Done => return Ok(State::Done),
            }
        }

        // Implicit transactions end with the message.
        {
            if self.adapter_client.session().transaction().is_implicit() {
                self.commit_transaction().await?;
            }
        }

        if num_stmts == 0 {
            self.send(BackendMessage::EmptyQueryResponse).await?;
        }

        self.ready().await
    }
1097
1098 #[instrument(level = "debug")]
1099 async fn parse(
1100 &mut self,
1101 name: String,
1102 sql: String,
1103 param_oids: Vec<u32>,
1104 ) -> Result<State, io::Error> {
1105 self.ensure_transaction(1, "parse").await?;
1107
1108 let mut param_types = vec![];
1109 for oid in param_oids {
1110 match mz_pgrepr::Type::from_oid(oid) {
1111 Ok(ty) => match SqlScalarType::try_from(&ty) {
1112 Ok(ty) => param_types.push(Some(ty)),
1113 Err(err) => {
1114 return self
1115 .error(ErrorResponse::error(
1116 SqlState::INVALID_PARAMETER_VALUE,
1117 err.to_string(),
1118 ))
1119 .await;
1120 }
1121 },
1122 Err(_) if oid == 0 => param_types.push(None),
1123 Err(e) => {
1124 return self
1125 .error(ErrorResponse::error(
1126 SqlState::PROTOCOL_VIOLATION,
1127 e.to_string(),
1128 ))
1129 .await;
1130 }
1131 }
1132 }
1133
1134 let stmts = match self.parse_sql(&sql) {
1135 Ok(stmts) => stmts,
1136 Err(err) => {
1137 return self.error(err).await;
1138 }
1139 };
1140 if stmts.len() > 1 {
1141 return self
1142 .error(ErrorResponse::error(
1143 SqlState::INTERNAL_ERROR,
1144 "cannot insert multiple commands into a prepared statement",
1145 ))
1146 .await;
1147 }
1148 let (maybe_stmt, sql) = match stmts.into_iter().next() {
1149 None => (None, ""),
1150 Some(StatementParseResult { ast, sql }) => (Some(ast), sql),
1151 };
1152 if self.is_aborted_txn() && !is_txn_exit_stmt(maybe_stmt.as_ref()) {
1153 return self.aborted_txn_error().await;
1154 }
1155 match self
1156 .adapter_client
1157 .prepare(name, maybe_stmt, sql.to_string(), param_types)
1158 .await
1159 {
1160 Ok(()) => {
1161 self.send(BackendMessage::ParseComplete).await?;
1162 Ok(State::Ready)
1163 }
1164 Err(e) => self.error(e.into_response(Severity::Error)).await,
1165 }
1166 }
1167
1168 #[instrument(level = "debug")]
1170 async fn commit_transaction(&mut self) -> Result<(), io::Error> {
1171 self.end_transaction(EndTransactionAction::Commit).await
1172 }
1173
1174 #[instrument(level = "debug")]
1176 async fn rollback_transaction(&mut self) -> Result<(), io::Error> {
1177 self.end_transaction(EndTransactionAction::Rollback).await
1178 }
1179
1180 #[instrument(level = "debug")]
1182 async fn end_transaction(&mut self, action: EndTransactionAction) -> Result<(), io::Error> {
1183 self.txn_needs_commit = false;
1184 let resp = self.adapter_client.end_transaction(action).await;
1185 if let Err(err) = resp {
1186 self.send(BackendMessage::ErrorResponse(
1187 err.into_response(Severity::Error),
1188 ))
1189 .await?;
1190 }
1191 Ok(())
1192 }
1193
1194 #[instrument(level = "debug")]
1195 async fn bind(
1196 &mut self,
1197 portal_name: String,
1198 statement_name: String,
1199 param_formats: Vec<Format>,
1200 raw_params: Vec<Option<Vec<u8>>>,
1201 result_formats: Vec<Format>,
1202 ) -> Result<State, io::Error> {
1203 self.ensure_transaction(1, "bind").await?;
1205
1206 let aborted_txn = self.is_aborted_txn();
1207 let stmt = match self
1208 .adapter_client
1209 .get_prepared_statement(&statement_name)
1210 .await
1211 {
1212 Ok(stmt) => stmt,
1213 Err(err) => return self.error(err.into_response(Severity::Error)).await,
1214 };
1215
1216 let param_types = &stmt.desc().param_types;
1217 if param_types.len() != raw_params.len() {
1218 let message = format!(
1219 "bind message supplies {actual} parameters, \
1220 but prepared statement \"{name}\" requires {expected}",
1221 name = statement_name,
1222 actual = raw_params.len(),
1223 expected = param_types.len()
1224 );
1225 return self
1226 .error(ErrorResponse::error(SqlState::PROTOCOL_VIOLATION, message))
1227 .await;
1228 }
1229 let param_formats = match pad_formats(param_formats, raw_params.len()) {
1230 Ok(param_formats) => param_formats,
1231 Err(msg) => {
1232 return self
1233 .error(ErrorResponse::error(SqlState::PROTOCOL_VIOLATION, msg))
1234 .await;
1235 }
1236 };
1237 if aborted_txn && !is_txn_exit_stmt(stmt.stmt()) {
1238 return self.aborted_txn_error().await;
1239 }
1240 let buf = RowArena::new();
1241 let mut params = vec![];
1242 for ((raw_param, mz_typ), format) in raw_params
1243 .into_iter()
1244 .zip_eq(param_types)
1245 .zip_eq(param_formats)
1246 {
1247 let pg_typ = mz_pgrepr::Type::from(mz_typ);
1248 let datum = match raw_param {
1249 None => Datum::Null,
1250 Some(bytes) => match mz_pgrepr::Value::decode(format, &pg_typ, &bytes) {
1251 Ok(param) => param.into_datum(&buf, &pg_typ),
1252 Err(err) => {
1253 let msg = format!("unable to decode parameter: {}", err);
1254 return self
1255 .error(ErrorResponse::error(SqlState::INVALID_PARAMETER_VALUE, msg))
1256 .await;
1257 }
1258 },
1259 };
1260 params.push((datum, mz_typ.clone()))
1261 }
1262
1263 let result_formats = match pad_formats(
1264 result_formats,
1265 stmt.desc()
1266 .relation_desc
1267 .clone()
1268 .map(|desc| desc.typ().column_types.len())
1269 .unwrap_or(0),
1270 ) {
1271 Ok(result_formats) => result_formats,
1272 Err(msg) => {
1273 return self
1274 .error(ErrorResponse::error(SqlState::PROTOCOL_VIOLATION, msg))
1275 .await;
1276 }
1277 };
1278
1279 if !stmt.stmt().map_or(false, |stmt| {
1282 matches!(
1283 stmt,
1284 Statement::Copy(CopyStatement {
1285 direction: CopyDirection::To,
1286 ..
1287 })
1288 )
1289 }) {
1290 if let Some(desc) = stmt.desc().relation_desc.clone() {
1291 for (format, ty) in result_formats.iter().zip_eq(desc.iter_types()) {
1292 match (format, &ty.scalar_type) {
1293 (Format::Binary, mz_repr::SqlScalarType::List { .. }) => {
1294 return self
1295 .error(ErrorResponse::error(
1296 SqlState::PROTOCOL_VIOLATION,
1297 "binary encoding of list types is not implemented",
1298 ))
1299 .await;
1300 }
1301 (Format::Binary, mz_repr::SqlScalarType::Map { .. }) => {
1302 return self
1303 .error(ErrorResponse::error(
1304 SqlState::PROTOCOL_VIOLATION,
1305 "binary encoding of map types is not implemented",
1306 ))
1307 .await;
1308 }
1309 (Format::Binary, mz_repr::SqlScalarType::AclItem) => {
1310 return self
1311 .error(ErrorResponse::error(
1312 SqlState::PROTOCOL_VIOLATION,
1313 "binary encoding of aclitem types does not exist",
1314 ))
1315 .await;
1316 }
1317 _ => (),
1318 }
1319 }
1320 }
1321 }
1322
1323 let desc = stmt.desc().clone();
1324 let logging = Arc::clone(stmt.logging());
1325 let stmt_ast = stmt.stmt().cloned();
1326 let state_revision = stmt.state_revision;
1327 if let Err(err) = self.adapter_client.session().set_portal(
1328 portal_name,
1329 desc,
1330 stmt_ast,
1331 logging,
1332 params,
1333 result_formats,
1334 state_revision,
1335 ) {
1336 return self.error(err.into_response(Severity::Error)).await;
1337 }
1338
1339 self.send(BackendMessage::BindComplete).await?;
1340 Ok(State::Ready)
1341 }
1342
/// Executes the portal named `portal_name` and streams its results to the client.
///
/// Dispatches on the portal's current state:
/// - `NotStarted`: ensures a transaction, asks the adapter to execute the portal,
///   then forwards the adapter's response through `send_execute_response`.
/// - `InProgress`: resumes sending the rows left over from a previous execution
///   via `send_rows`.
/// - `Completed(Some(tag))`: re-sends the stored `CommandComplete` tag.
/// - `Completed(None)`: reports that the portal cannot be run.
///
/// `outer_ctx_extra`, when present, is retired on every path handled here (missing
/// portal, aborted transaction, resumed or completed portal); on the `NotStarted`
/// path it is instead handed to `adapter_client.execute`. `received`, when present,
/// seeds the portal's lifecycle timestamps.
///
/// This returns a `BoxFuture` rather than being an `async fn` because it is reached
/// recursively (`execute` -> `send_execute_response` -> `fetch` -> `execute`), and
/// recursive async calls require boxing.
fn execute(
    &mut self,
    portal_name: String,
    max_rows: ExecuteCount,
    get_response: GetResponse,
    fetch_portal_name: Option<String>,
    timeout: ExecuteTimeout,
    outer_ctx_extra: Option<ExecuteContextExtra>,
    received: Option<EpochMillis>,
) -> BoxFuture<'_, Result<State, io::Error>> {
    async move {
        // Captured before borrowing the portal mutably from the session.
        let aborted_txn = self.is_aborted_txn();

        let portal = match self
            .adapter_client
            .session()
            .get_portal_unverified_mut(&portal_name)
        {
            Some(portal) => portal,
            None => {
                let msg = format!("portal {} does not exist", portal_name.quoted());
                if let Some(outer_ctx_extra) = outer_ctx_extra {
                    self.adapter_client.retire_execute(
                        outer_ctx_extra,
                        StatementEndedExecutionReason::Errored { error: msg.clone() },
                    );
                }
                return self
                    .error(ErrorResponse::error(SqlState::INVALID_CURSOR_NAME, msg))
                    .await;
            }
        };

        *portal.lifecycle_timestamps = received.map(LifecycleTimestamps::new);

        // Inside an aborted transaction, only statements that exit the transaction
        // (as decided by `is_txn_exit_stmt`) are allowed to run.
        let txn_exit_stmt = is_txn_exit_stmt(portal.stmt.as_deref());
        if aborted_txn && !txn_exit_stmt {
            if let Some(outer_ctx_extra) = outer_ctx_extra {
                self.adapter_client.retire_execute(
                    outer_ctx_extra,
                    StatementEndedExecutionReason::Errored {
                        error: ABORTED_TXN_MSG.to_string(),
                    },
                );
            }
            return self.aborted_txn_error().await;
        }

        // Cloned so the portal borrow does not have to outlive the calls below.
        let row_desc = portal.desc.relation_desc.clone();
        match portal.state {
            PortalState::NotStarted => {
                self.ensure_transaction(1, "execute").await?;
                match self
                    .adapter_client
                    .execute(
                        portal_name.clone(),
                        self.conn.wait_closed(),
                        outer_ctx_extra,
                    )
                    .await
                {
                    Ok((response, execute_started)) => {
                        self.send_pending_notices().await?;
                        self.send_execute_response(
                            response,
                            row_desc,
                            portal_name,
                            max_rows,
                            get_response,
                            fetch_portal_name,
                            timeout,
                            execute_started,
                        )
                        .await
                    }
                    Err(e) => {
                        self.send_pending_notices().await?;
                        self.error(e.into_response(Severity::Error)).await
                    }
                }
            }
            PortalState::InProgress(rows) => {
                // Take ownership of the buffered rows; `send_rows` stores whatever
                // remains back into the portal state.
                let rows = rows.take().expect("InProgress rows must be populated");
                let (result, statement_ended_execution_reason) = match self
                    .send_rows(
                        row_desc.expect("portal missing row desc on resumption"),
                        portal_name,
                        rows,
                        max_rows,
                        get_response,
                        fetch_portal_name,
                        timeout,
                    )
                    .await
                {
                    // An I/O error while sending; recorded as `Canceled` for
                    // statement-logging purposes.
                    Err(e) => {
                        (Err(e), StatementEndedExecutionReason::Canceled)
                    }
                    Ok((ok, SendRowsEndedReason::Canceled)) => {
                        (Ok(ok), StatementEndedExecutionReason::Canceled)
                    }
                    // The per-resumption size/row counts are deliberately discarded
                    // here and reported as `None`. NOTE(review): presumably because they
                    // only cover this resumed batch, not the whole statement — confirm.
                    Ok((
                        ok,
                        SendRowsEndedReason::Success {
                            result_size: _,
                            rows_returned: _,
                        },
                    )) => (
                        Ok(ok),
                        StatementEndedExecutionReason::Success {
                            result_size: None,
                            rows_returned: None,
                            execution_strategy: None,
                        },
                    ),
                    Ok((ok, SendRowsEndedReason::Errored { error })) => {
                        (Ok(ok), StatementEndedExecutionReason::Errored { error })
                    }
                };
                if let Some(outer_ctx_extra) = outer_ctx_extra {
                    self.adapter_client
                        .retire_execute(outer_ctx_extra, statement_ended_execution_reason);
                }
                result
            }
            PortalState::Completed(Some(tag)) => {
                // The portal already finished; replay its completion tag.
                let tag = tag.to_string();
                if let Some(outer_ctx_extra) = outer_ctx_extra {
                    self.adapter_client.retire_execute(
                        outer_ctx_extra,
                        StatementEndedExecutionReason::Success {
                            result_size: None,
                            rows_returned: None,
                            execution_strategy: None,
                        },
                    );
                }
                self.send(BackendMessage::CommandComplete { tag }).await?;
                Ok(State::Ready)
            }
            PortalState::Completed(None) => {
                let error = format!(
                    "portal {} cannot be run",
                    Ident::new_unchecked(portal_name).to_ast_string_stable()
                );
                if let Some(outer_ctx_extra) = outer_ctx_extra {
                    self.adapter_client.retire_execute(
                        outer_ctx_extra,
                        StatementEndedExecutionReason::Errored {
                            error: error.clone(),
                        },
                    );
                }
                self.error(ErrorResponse::error(
                    SqlState::OBJECT_NOT_IN_PREREQUISITE_STATE,
                    error,
                ))
                .await
            }
        }
    }
    .instrument(debug_span!("execute"))
    .boxed()
}
1532
1533 #[instrument(level = "debug")]
1534 async fn describe_statement(&mut self, name: &str) -> Result<State, io::Error> {
1535 self.ensure_transaction(1, "describe_statement").await?;
1537
1538 let stmt = match self.adapter_client.get_prepared_statement(name).await {
1539 Ok(stmt) => stmt,
1540 Err(err) => return self.error(err.into_response(Severity::Error)).await,
1541 };
1542 let parameter_desc = BackendMessage::ParameterDescription(
1544 stmt.desc()
1545 .param_types
1546 .iter()
1547 .map(mz_pgrepr::Type::from)
1548 .collect(),
1549 );
1550 let formats = vec![Format::Text; stmt.desc().arity()];
1554 let row_desc = describe_rows(stmt.desc(), &formats);
1555 self.send_all([parameter_desc, row_desc]).await?;
1556 Ok(State::Ready)
1557 }
1558
1559 #[instrument(level = "debug")]
1560 async fn describe_portal(&mut self, name: &str) -> Result<State, io::Error> {
1561 self.ensure_transaction(1, "describe_portal").await?;
1563
1564 let session = self.adapter_client.session();
1565 let row_desc = session
1566 .get_portal_unverified(name)
1567 .map(|portal| describe_rows(&portal.desc, &portal.result_formats));
1568 match row_desc {
1569 Some(row_desc) => {
1570 self.send(row_desc).await?;
1571 Ok(State::Ready)
1572 }
1573 None => {
1574 self.error(ErrorResponse::error(
1575 SqlState::INVALID_CURSOR_NAME,
1576 format!("portal {} does not exist", name.quoted()),
1577 ))
1578 .await
1579 }
1580 }
1581 }
1582
1583 #[instrument(level = "debug")]
1584 async fn close_statement(&mut self, name: String) -> Result<State, io::Error> {
1585 self.adapter_client
1586 .session()
1587 .remove_prepared_statement(&name);
1588 self.send(BackendMessage::CloseComplete).await?;
1589 Ok(State::Ready)
1590 }
1591
1592 #[instrument(level = "debug")]
1593 async fn close_portal(&mut self, name: String) -> Result<State, io::Error> {
1594 self.adapter_client.session().remove_portal(&name);
1595 self.send(BackendMessage::CloseComplete).await?;
1596 Ok(State::Ready)
1597 }
1598
1599 fn complete_portal(&mut self, name: &str) {
1600 let portal = self
1601 .adapter_client
1602 .session()
1603 .get_portal_unverified_mut(name)
1604 .expect("portal should exist");
1605 *portal.state = PortalState::Completed(None);
1606 }
1607
1608 async fn fetch(
1609 &mut self,
1610 name: String,
1611 count: Option<FetchDirection>,
1612 max_rows: ExecuteCount,
1613 fetch_portal_name: Option<String>,
1614 timeout: ExecuteTimeout,
1615 ctx_extra: ExecuteContextExtra,
1616 ) -> Result<State, io::Error> {
1617 let count = count.unwrap_or(FetchDirection::ForwardCount(1));
1620
1621 let count = match (max_rows, count) {
1634 (ExecuteCount::Count(max_rows), FetchDirection::ForwardCount(count)) => {
1635 let count = usize::cast_from(count);
1636 if max_rows < count {
1637 let msg = "Execute with max_rows < a FETCH's count is not supported";
1638 self.adapter_client.retire_execute(
1639 ctx_extra,
1640 StatementEndedExecutionReason::Errored {
1641 error: msg.to_string(),
1642 },
1643 );
1644 return self
1645 .error(ErrorResponse::error(SqlState::FEATURE_NOT_SUPPORTED, msg))
1646 .await;
1647 }
1648 ExecuteCount::Count(count)
1649 }
1650 (ExecuteCount::Count(_), FetchDirection::ForwardAll) => {
1651 let msg = "Execute with max_rows of a FETCH ALL is not supported";
1652 self.adapter_client.retire_execute(
1653 ctx_extra,
1654 StatementEndedExecutionReason::Errored {
1655 error: msg.to_string(),
1656 },
1657 );
1658 return self
1659 .error(ErrorResponse::error(SqlState::FEATURE_NOT_SUPPORTED, msg))
1660 .await;
1661 }
1662 (ExecuteCount::All, FetchDirection::ForwardAll) => ExecuteCount::All,
1663 (ExecuteCount::All, FetchDirection::ForwardCount(count)) => {
1664 ExecuteCount::Count(usize::cast_from(count))
1665 }
1666 };
1667 let cursor_name = name.to_string();
1668 self.execute(
1669 cursor_name,
1670 count,
1671 fetch_message,
1672 fetch_portal_name,
1673 timeout,
1674 Some(ctx_extra),
1675 None,
1676 )
1677 .await
1678 }
1679
1680 async fn flush(&mut self) -> Result<State, io::Error> {
1681 self.conn.flush().await?;
1682 Ok(State::Ready)
1683 }
1684
1685 #[instrument(level = "debug")]
1690 async fn send<M>(&mut self, message: M) -> Result<(), io::Error>
1691 where
1692 M: Into<BackendMessage>,
1693 {
1694 let message: BackendMessage = message.into();
1695 let is_error =
1696 matches!(&message, BackendMessage::ErrorResponse(e) if e.severity.is_error());
1697
1698 self.conn.send(message).await?;
1699
1700 if is_error {
1707 self.conn.flush().await?;
1708 }
1709
1710 Ok(())
1711 }
1712
1713 #[instrument(level = "debug")]
1714 pub async fn send_all(
1715 &mut self,
1716 messages: impl IntoIterator<Item = BackendMessage>,
1717 ) -> Result<(), io::Error> {
1718 for m in messages {
1719 self.send(m).await?;
1720 }
1721 Ok(())
1722 }
1723
1724 #[instrument(level = "debug")]
1725 async fn sync(&mut self) -> Result<State, io::Error> {
1726 if self.adapter_client.session().transaction().is_implicit() {
1728 self.commit_transaction().await?;
1729 }
1730 self.ready().await
1731 }
1732
1733 #[instrument(level = "debug")]
1734 async fn ready(&mut self) -> Result<State, io::Error> {
1735 let txn_state = self.adapter_client.session().transaction().into();
1736 self.send(BackendMessage::ReadyForQuery(txn_state)).await?;
1737 self.flush().await
1738 }
1739
1740 #[allow(clippy::too_many_arguments)]
1741 #[instrument(level = "debug")]
1742 async fn send_execute_response(
1743 &mut self,
1744 response: ExecuteResponse,
1745 row_desc: Option<RelationDesc>,
1746 portal_name: String,
1747 max_rows: ExecuteCount,
1748 get_response: GetResponse,
1749 fetch_portal_name: Option<String>,
1750 timeout: ExecuteTimeout,
1751 execute_started: Instant,
1752 ) -> Result<State, io::Error> {
1753 let mut tag = response.tag();
1754
1755 macro_rules! command_complete {
1756 () => {{
1757 self.send(BackendMessage::CommandComplete {
1758 tag: tag
1759 .take()
1760 .expect("command_complete only called on tag-generating results"),
1761 })
1762 .await?;
1763 Ok(State::Ready)
1764 }};
1765 }
1766
1767 let r = match response {
1768 ExecuteResponse::ClosedCursor => {
1769 self.complete_portal(&portal_name);
1770 command_complete!()
1771 }
1772 ExecuteResponse::DeclaredCursor => {
1773 self.complete_portal(&portal_name);
1774 command_complete!()
1775 }
1776 ExecuteResponse::EmptyQuery => {
1777 self.send(BackendMessage::EmptyQueryResponse).await?;
1778 Ok(State::Ready)
1779 }
1780 ExecuteResponse::Fetch {
1781 name,
1782 count,
1783 timeout,
1784 ctx_extra,
1785 } => {
1786 self.fetch(
1787 name,
1788 count,
1789 max_rows,
1790 Some(portal_name.to_string()),
1791 timeout,
1792 ctx_extra,
1793 )
1794 .await
1795 }
1796 ExecuteResponse::SendingRowsStreaming {
1797 rows,
1798 instance_id,
1799 strategy,
1800 } => {
1801 let row_desc = row_desc
1802 .expect("missing row description for ExecuteResponse::SendingRowsStreaming");
1803
1804 let span = tracing::debug_span!("sending_rows_streaming");
1805
1806 self.send_rows(
1807 row_desc,
1808 portal_name,
1809 InProgressRows::new(RecordFirstRowStream::new(
1810 Box::new(rows),
1811 execute_started,
1812 &self.adapter_client,
1813 Some(instance_id),
1814 Some(strategy),
1815 )),
1816 max_rows,
1817 get_response,
1818 fetch_portal_name,
1819 timeout,
1820 )
1821 .instrument(span)
1822 .await
1823 .map(|(state, _)| state)
1824 }
1825 ExecuteResponse::SendingRowsImmediate { rows } => {
1826 let row_desc = row_desc
1827 .expect("missing row description for ExecuteResponse::SendingRowsImmediate");
1828
1829 let span = tracing::debug_span!("sending_rows_immediate");
1830
1831 let stream =
1832 futures::stream::once(futures::future::ready(PeekResponseUnary::Rows(rows)));
1833 self.send_rows(
1834 row_desc,
1835 portal_name,
1836 InProgressRows::new(RecordFirstRowStream::new(
1837 Box::new(stream),
1838 execute_started,
1839 &self.adapter_client,
1840 None,
1841 Some(StatementExecutionStrategy::Constant),
1842 )),
1843 max_rows,
1844 get_response,
1845 fetch_portal_name,
1846 timeout,
1847 )
1848 .instrument(span)
1849 .await
1850 .map(|(state, _)| state)
1851 }
1852 ExecuteResponse::SetVariable { name, .. } => {
1853 let qn = name.to_string();
1856 let msg = if let Some(var) = self
1857 .adapter_client
1858 .session()
1859 .vars_mut()
1860 .notify_set()
1861 .find(|v| v.name() == qn)
1862 {
1863 Some(BackendMessage::ParameterStatus(var.name(), var.value()))
1864 } else {
1865 None
1866 };
1867 if let Some(msg) = msg {
1868 self.send(msg).await?;
1869 }
1870 command_complete!()
1871 }
1872 ExecuteResponse::Subscribing {
1873 rx,
1874 ctx_extra,
1875 instance_id,
1876 } => {
1877 if fetch_portal_name.is_none() {
1878 let mut msg = ErrorResponse::notice(
1879 SqlState::WARNING,
1880 "streaming SUBSCRIBE rows directly requires a client that does not buffer output",
1881 );
1882 if self.adapter_client.session().vars().application_name() == "psql" {
1883 msg.hint = Some(
1884 "Wrap your SUBSCRIBE statement in `COPY (SUBSCRIBE ...) TO STDOUT`."
1885 .into(),
1886 )
1887 }
1888 self.send(msg).await?;
1889 self.conn.flush().await?;
1890 }
1891 let row_desc =
1892 row_desc.expect("missing row description for ExecuteResponse::Subscribing");
1893 let (result, statement_ended_execution_reason) = match self
1894 .send_rows(
1895 row_desc,
1896 portal_name,
1897 InProgressRows::new(RecordFirstRowStream::new(
1898 Box::new(UnboundedReceiverStream::new(rx)),
1899 execute_started,
1900 &self.adapter_client,
1901 Some(instance_id),
1902 None,
1903 )),
1904 max_rows,
1905 get_response,
1906 fetch_portal_name,
1907 timeout,
1908 )
1909 .await
1910 {
1911 Err(e) => {
1912 (Err(e), StatementEndedExecutionReason::Canceled)
1915 }
1916 Ok((ok, SendRowsEndedReason::Canceled)) => {
1917 (Ok(ok), StatementEndedExecutionReason::Canceled)
1918 }
1919 Ok((
1920 ok,
1921 SendRowsEndedReason::Success {
1922 result_size,
1923 rows_returned,
1924 },
1925 )) => (
1926 Ok(ok),
1927 StatementEndedExecutionReason::Success {
1928 result_size: Some(result_size),
1929 rows_returned: Some(rows_returned),
1930 execution_strategy: None,
1931 },
1932 ),
1933 Ok((ok, SendRowsEndedReason::Errored { error })) => {
1934 (Ok(ok), StatementEndedExecutionReason::Errored { error })
1935 }
1936 };
1937 self.adapter_client
1938 .retire_execute(ctx_extra, statement_ended_execution_reason);
1939 return result;
1940 }
1941 ExecuteResponse::CopyTo { format, resp } => {
1942 let row_desc =
1943 row_desc.expect("missing row description for ExecuteResponse::CopyTo");
1944 match *resp {
1945 ExecuteResponse::Subscribing {
1946 rx,
1947 ctx_extra,
1948 instance_id,
1949 } => {
1950 let (result, statement_ended_execution_reason) = match self
1951 .copy_rows(
1952 format,
1953 row_desc,
1954 RecordFirstRowStream::new(
1955 Box::new(UnboundedReceiverStream::new(rx)),
1956 execute_started,
1957 &self.adapter_client,
1958 Some(instance_id),
1959 None,
1960 ),
1961 )
1962 .await
1963 {
1964 Err(e) => {
1965 (Err(e), StatementEndedExecutionReason::Canceled)
1968 }
1969 Ok((
1970 state,
1971 SendRowsEndedReason::Success {
1972 result_size,
1973 rows_returned,
1974 },
1975 )) => (
1976 Ok(state),
1977 StatementEndedExecutionReason::Success {
1978 result_size: Some(result_size),
1979 rows_returned: Some(rows_returned),
1980 execution_strategy: None,
1981 },
1982 ),
1983 Ok((state, SendRowsEndedReason::Errored { error })) => {
1984 (Ok(state), StatementEndedExecutionReason::Errored { error })
1985 }
1986 Ok((state, SendRowsEndedReason::Canceled)) => {
1987 (Ok(state), StatementEndedExecutionReason::Canceled)
1988 }
1989 };
1990 self.adapter_client
1991 .retire_execute(ctx_extra, statement_ended_execution_reason);
1992 return result;
1993 }
1994 ExecuteResponse::SendingRowsStreaming {
1995 rows,
1996 instance_id,
1997 strategy,
1998 } => {
1999 return self
2004 .copy_rows(
2005 format,
2006 row_desc,
2007 RecordFirstRowStream::new(
2008 Box::new(rows),
2009 execute_started,
2010 &self.adapter_client,
2011 Some(instance_id),
2012 Some(strategy),
2013 ),
2014 )
2015 .await
2016 .map(|(state, _)| state);
2017 }
2018 ExecuteResponse::SendingRowsImmediate { rows } => {
2019 let span = tracing::debug_span!("sending_rows_immediate");
2020
2021 let rows = futures::stream::once(futures::future::ready(
2022 PeekResponseUnary::Rows(rows),
2023 ));
2024 return self
2029 .copy_rows(
2030 format,
2031 row_desc,
2032 RecordFirstRowStream::new(
2033 Box::new(rows),
2034 execute_started,
2035 &self.adapter_client,
2036 None,
2037 Some(StatementExecutionStrategy::Constant),
2038 ),
2039 )
2040 .instrument(span)
2041 .await
2042 .map(|(state, _)| state);
2043 }
2044 _ => {
2045 return self
2046 .error(ErrorResponse::error(
2047 SqlState::INTERNAL_ERROR,
2048 "unsupported COPY response type".to_string(),
2049 ))
2050 .await;
2051 }
2052 };
2053 }
2054 ExecuteResponse::CopyFrom {
2055 target_id,
2056 target_name,
2057 columns,
2058 params,
2059 ctx_extra,
2060 } => {
2061 let row_desc =
2062 row_desc.expect("missing row description for ExecuteResponse::CopyFrom");
2063 self.copy_from(target_id, target_name, columns, params, row_desc, ctx_extra)
2064 .await
2065 }
2066 ExecuteResponse::TransactionCommitted { params }
2067 | ExecuteResponse::TransactionRolledBack { params } => {
2068 let notify_set: mz_ore::collections::HashSet<String> = self
2069 .adapter_client
2070 .session()
2071 .vars()
2072 .notify_set()
2073 .map(|v| v.name().to_string())
2074 .collect();
2075
2076 for (name, value) in params
2078 .into_iter()
2079 .filter(|(name, _v)| notify_set.contains(*name))
2080 {
2081 let msg = BackendMessage::ParameterStatus(name, value);
2082 self.send(msg).await?;
2083 }
2084 command_complete!()
2085 }
2086
2087 ExecuteResponse::AlteredDefaultPrivileges
2088 | ExecuteResponse::AlteredObject(..)
2089 | ExecuteResponse::AlteredRole
2090 | ExecuteResponse::AlteredSystemConfiguration
2091 | ExecuteResponse::CreatedCluster { .. }
2092 | ExecuteResponse::CreatedClusterReplica { .. }
2093 | ExecuteResponse::CreatedConnection { .. }
2094 | ExecuteResponse::CreatedDatabase { .. }
2095 | ExecuteResponse::CreatedIndex { .. }
2096 | ExecuteResponse::CreatedIntrospectionSubscribe
2097 | ExecuteResponse::CreatedMaterializedView { .. }
2098 | ExecuteResponse::CreatedContinualTask { .. }
2099 | ExecuteResponse::CreatedRole
2100 | ExecuteResponse::CreatedSchema { .. }
2101 | ExecuteResponse::CreatedSecret { .. }
2102 | ExecuteResponse::CreatedSink { .. }
2103 | ExecuteResponse::CreatedSource { .. }
2104 | ExecuteResponse::CreatedTable { .. }
2105 | ExecuteResponse::CreatedType
2106 | ExecuteResponse::CreatedView { .. }
2107 | ExecuteResponse::CreatedViews { .. }
2108 | ExecuteResponse::CreatedNetworkPolicy
2109 | ExecuteResponse::Comment
2110 | ExecuteResponse::Deallocate { .. }
2111 | ExecuteResponse::Deleted(..)
2112 | ExecuteResponse::DiscardedAll
2113 | ExecuteResponse::DiscardedTemp
2114 | ExecuteResponse::DroppedObject(_)
2115 | ExecuteResponse::DroppedOwned
2116 | ExecuteResponse::GrantedPrivilege
2117 | ExecuteResponse::GrantedRole
2118 | ExecuteResponse::Inserted(..)
2119 | ExecuteResponse::Copied(..)
2120 | ExecuteResponse::Prepare
2121 | ExecuteResponse::Raised
2122 | ExecuteResponse::ReassignOwned
2123 | ExecuteResponse::RevokedPrivilege
2124 | ExecuteResponse::RevokedRole
2125 | ExecuteResponse::StartedTransaction { .. }
2126 | ExecuteResponse::Updated(..)
2127 | ExecuteResponse::ValidatedConnection => {
2128 command_complete!()
2129 }
2130 };
2131
2132 assert_none!(tag, "tag created but not consumed: {:?}", tag);
2133 r
2134 }
2135
/// Streams rows from `rows` to the client as `DataRow` messages.
///
/// Sends at most `max_rows` rows (all of them for `ExecuteCount::All`). While waiting
/// for batches, it also forwards session notices and watches for the connection
/// closing. `timeout` controls waiting:
/// - `None`: wait indefinitely;
/// - `Seconds(t)`: stop waiting `t` seconds from now;
/// - `WaitOnce`: block until the first non-empty batch arrives, then stop waiting for
///   more.
///
/// Any rows left unsent are stored back into the portal as
/// `PortalState::InProgress`. The message produced by `get_response` is then sent.
/// If the stream is exhausted, the first-to-last-row duration metric is observed.
///
/// Result formats are read from `fetch_portal_name`'s portal when given, otherwise
/// from `portal_name`'s portal.
///
/// Returns the next state together with how sending ended: success with byte/row
/// totals, errored, or canceled. Returns `Err` on I/O failure or when the connection
/// closes.
#[allow(clippy::too_many_arguments)]
#[mz_ore::instrument(level = "debug")]
async fn send_rows(
    &mut self,
    row_desc: RelationDesc,
    portal_name: String,
    mut rows: InProgressRows,
    max_rows: ExecuteCount,
    get_response: GetResponse,
    fetch_portal_name: Option<String>,
    timeout: ExecuteTimeout,
) -> Result<(State, SendRowsEndedReason), io::Error> {
    // For a FETCH, the fetching portal's result formats apply.
    let result_format_portal_name: &str = if let Some(ref name) = fetch_portal_name {
        name
    } else {
        &portal_name
    };
    let result_formats = self
        .adapter_client
        .session()
        .get_portal_unverified(result_format_portal_name)
        .expect("valid fetch portal name for send rows")
        .result_formats
        .clone();

    let (mut wait_once, mut deadline) = match timeout {
        ExecuteTimeout::None => (false, None),
        ExecuteTimeout::Seconds(t) => (
            false,
            Some(tokio::time::Instant::now() + tokio::time::Duration::from_secs_f64(t)),
        ),
        ExecuteTimeout::WaitOnce => (true, None),
    };

    // Consistency check: any relation desc stored on the executing portal (and on
    // the fetching portal, if any) should equal `row_desc`. Mismatches are only
    // soft-asserted (logged), not fatal.
    {
        let portal_name_desc = &self
            .adapter_client
            .session()
            .get_portal_unverified(portal_name.as_str())
            .expect("portal should exist")
            .desc
            .relation_desc;
        if let Some(portal_name_desc) = portal_name_desc {
            soft_assert_eq_or_log!(portal_name_desc, &row_desc);
        }
        if let Some(fetch_portal_name) = &fetch_portal_name {
            let fetch_portal_desc = &self
                .adapter_client
                .session()
                .get_portal_unverified(fetch_portal_name)
                .expect("portal should exist")
                .desc
                .relation_desc;
            if let Some(fetch_portal_desc) = fetch_portal_desc {
                soft_assert_eq_or_log!(fetch_portal_desc, &row_desc);
            }
        }
    }

    // Configure the connection's row encoder: one (pg type, format) pair per column.
    self.conn.set_encode_state(
        row_desc
            .typ()
            .column_types
            .iter()
            .map(|ty| mz_pgrepr::Type::from(&ty.scalar_type))
            .zip_eq(result_formats)
            .collect(),
    );

    let mut total_sent_rows = 0;
    let mut total_sent_bytes = 0;
    // Rows still allowed to be sent during this call.
    let mut want_rows = match max_rows {
        ExecuteCount::All => usize::MAX,
        ExecuteCount::Count(count) => count,
    };

    loop {
        // Prefer a batch left over from a previous call. Once the row budget is
        // spent, stop. Otherwise wait for the next batch, a notice, the deadline,
        // or the connection closing.
        let batch = if rows.current.is_some() {
            FetchResult::Rows(rows.current.take())
        } else if want_rows == 0 {
            FetchResult::Rows(None)
        } else {
            let notice_fut = self.adapter_client.session().recv_notice();
            // `select!` still constructs the sleep future when its precondition is
            // false (it is just never polled), hence the `unwrap_or_else`
            // placeholder. Without `biased;`, ready branches are chosen in random order.
            tokio::select! {
                err = self.conn.wait_closed() => return Err(err),
                _ = time::sleep_until(deadline.unwrap_or_else(tokio::time::Instant::now)), if deadline.is_some() => FetchResult::Rows(None),
                notice = notice_fut => {
                    FetchResult::Notice(notice)
                }
                batch = rows.remaining.recv() => match batch {
                    None => FetchResult::Rows(None),
                    Some(PeekResponseUnary::Rows(rows)) => FetchResult::Rows(Some(rows)),
                    Some(PeekResponseUnary::Error(err)) => FetchResult::Error(err),
                    Some(PeekResponseUnary::Canceled) => FetchResult::Canceled,
                },
            }
        };

        match batch {
            FetchResult::Rows(None) => break,
            FetchResult::Rows(Some(mut batch_rows)) => {
                if let Err(err) = verify_datum_desc(&row_desc, &mut batch_rows) {
                    let msg = err.to_string();
                    return self
                        .error(err.into_response(Severity::Error))
                        .await
                        .map(|state| (state, SendRowsEndedReason::Errored { error: msg }));
                }

                // With `WaitOnce`, the first non-empty batch turns the remaining wait
                // into an immediate (0s) deadline.
                if wait_once && batch_rows.peek().is_some() {
                    deadline = Some(tokio::time::Instant::now());
                    wait_once = false;
                }

                // Encode rows into `DataRow` messages, counting rows and bytes only
                // for the rows actually pulled (at most `want_rows`, due to `take`).
                let mut sent_rows = 0;
                let mut sent_bytes = 0;
                let messages = (&mut batch_rows)
                    .map(|row| {
                        let row_len = row.byte_len();
                        let values = mz_pgrepr::values_from_row(row, row_desc.typ());
                        (row_len, BackendMessage::DataRow(values))
                    })
                    .inspect(|(row_len, _)| {
                        sent_bytes += row_len;
                        sent_rows += 1
                    })
                    .map(|(_row_len, row)| row)
                    .take(want_rows);
                self.send_all(messages).await?;

                total_sent_rows += sent_rows;
                total_sent_bytes += sent_bytes;
                want_rows -= sent_rows;

                // Budget exhausted: keep any unsent rows of this batch for the next
                // Execute/FETCH.
                if want_rows == 0 {
                    if batch_rows.peek().is_some() {
                        rows.current = Some(batch_rows);
                    }
                    break;
                }

                self.conn.flush().await?;
            }
            FetchResult::Notice(notice) => {
                self.send(notice.into_response()).await?;
                self.conn.flush().await?;
            }
            FetchResult::Error(text) => {
                return self
                    .error(ErrorResponse::error(SqlState::INTERNAL_ERROR, text.clone()))
                    .await
                    .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
            }
            FetchResult::Canceled => {
                return self
                    .error(ErrorResponse::error(
                        SqlState::QUERY_CANCELED,
                        "canceling statement due to user request",
                    ))
                    .await
                    .map(|state| (state, SendRowsEndedReason::Canceled));
            }
        }
    }

    let portal = self
        .adapter_client
        .session()
        .get_portal_unverified_mut(&portal_name)
        .expect("valid portal name for send rows");

    // Read these before `rows` is moved into the portal state.
    let saw_rows = rows.remaining.saw_rows;
    let no_more_rows = rows.no_more_rows();
    let recorded_first_row_instant = rows.remaining.recorded_first_row_instant;

    // Park the (possibly partially consumed) row stream on the portal so a later
    // Execute/FETCH can resume it.
    *portal.state = PortalState::InProgress(Some(rows));

    let fetch_portal = fetch_portal_name.map(|name| {
        self.adapter_client
            .session()
            .get_portal_unverified_mut(&name)
            .expect("valid fetch portal")
    });
    let response_message = get_response(max_rows, total_sent_rows, fetch_portal);
    self.send(response_message).await?;

    // Once the stream is exhausted, record the first-to-last-row duration, labelled
    // by statement type. A stream that never produced rows records zero.
    if no_more_rows {
        let statement_type = if let Some(stmt) = &self
            .adapter_client
            .session()
            .get_portal_unverified(&portal_name)
            .expect("valid portal name for send_rows")
            .stmt
        {
            metrics::statement_type_label_value(stmt.deref())
        } else {
            "no-statement"
        };
        let duration = if saw_rows {
            recorded_first_row_instant
                .expect("recorded_first_row_instant because saw_rows")
                .elapsed()
        } else {
            Duration::ZERO
        };
        self.adapter_client
            .inner()
            .metrics()
            .result_rows_first_to_last_byte_seconds
            .with_label_values(&[statement_type])
            .observe(duration.as_secs_f64());
    }

    Ok((
        State::Ready,
        SendRowsEndedReason::Success {
            result_size: u64::cast_from(total_sent_bytes),
            rows_returned: u64::cast_from(total_sent_rows),
        },
    ))
}
2381
/// Streams all rows from `stream` to the client as a `COPY ... TO STDOUT` response.
///
/// Sends `CopyOutResponse`, then one `CopyData` message per encoded row, then
/// `CopyDone` and `CommandComplete` with tag `COPY <count>`. For the binary format,
/// the PGCOPY header is prepended to the first `CopyData` payload and the trailer is
/// sent last. If there are no rows, header and trailer travel together.
///
/// Text and CSV use default format parameters. Parquet is rejected with an error.
/// Session notices received while streaming are forwarded.
///
/// Returns the next state together with how copying ended. Returns `Err` on I/O
/// failure, encoding failure, or when the connection closes.
#[mz_ore::instrument(level = "debug")]
async fn copy_rows(
    &mut self,
    format: CopyFormat,
    row_desc: RelationDesc,
    mut stream: RecordFirstRowStream,
) -> Result<(State, SendRowsEndedReason), io::Error> {
    // `row_format` drives row encoding; `encode_format` is what is advertised in
    // `CopyOutResponse`. CSV is advertised as text.
    let (row_format, encode_format) = match format {
        CopyFormat::Text => (
            CopyFormatParams::Text(CopyTextFormatParams::default()),
            Format::Text,
        ),
        CopyFormat::Binary => (CopyFormatParams::Binary, Format::Binary),
        CopyFormat::Csv => (
            CopyFormatParams::Csv(CopyCsvFormatParams::default()),
            Format::Text,
        ),
        CopyFormat::Parquet => {
            let text = "Parquet format is not supported".to_string();
            return self
                .error(ErrorResponse::error(SqlState::INTERNAL_ERROR, text.clone()))
                .await
                .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
        }
    };

    let encode_fn = |row: &RowRef, typ: &SqlRelationType, out: &mut Vec<u8>| {
        mz_pgcopy::encode_copy_format(&row_format, row, typ, out)
    };

    let typ = row_desc.typ();
    let column_formats = iter::repeat(encode_format)
        .take(typ.column_types.len())
        .collect();
    self.send(BackendMessage::CopyOutResponse {
        overall_format: encode_format,
        column_formats,
    })
    .await?;

    // Reusable output buffer; drained into each `CopyData` message.
    let mut out = Vec::new();

    // Binary COPY header: 11-byte signature, 32-bit flags field (0), and 32-bit
    // header-extension length (0). It is sent with the first `CopyData` payload.
    if let CopyFormat::Binary = format {
        out.extend(b"PGCOPY\n\xFF\r\n\0");
        out.extend([0, 0, 0, 0]);
        out.extend([0, 0, 0, 0]);
    }

    let mut count = 0;
    let mut total_sent_bytes = 0;
    loop {
        tokio::select! {
            e = self.conn.wait_closed() => return Err(e),
            batch = stream.recv() => match batch {
                None => break,
                Some(PeekResponseUnary::Error(text)) => {
                    return self
                        .error(ErrorResponse::error(SqlState::INTERNAL_ERROR, text.clone()))
                        .await
                        .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
                }
                Some(PeekResponseUnary::Canceled) => {
                    return self.error(ErrorResponse::error(
                        SqlState::QUERY_CANCELED,
                        "canceling statement due to user request",
                    ))
                    .await.map(|state| (state, SendRowsEndedReason::Canceled));
                }
                Some(PeekResponseUnary::Rows(mut rows)) => {
                    // The whole batch counts toward the tag before any row is sent.
                    count += rows.count();
                    // Each row is encoded and sent as its own `CopyData` message.
                    while let Some(row) = rows.next() {
                        total_sent_bytes += row.byte_len();
                        encode_fn(row, typ, &mut out)?;
                        self.send(BackendMessage::CopyData(mem::take(&mut out)))
                            .await?;
                    }
                }
            },
            notice = self.adapter_client.session().recv_notice() => {
                self.send(notice.into_response())
                    .await?;
                self.conn.flush().await?;
            }
        }

        self.conn.flush().await?;
    }
    // Binary COPY trailer: a 16-bit field count of -1. `out` still holds the header
    // if no rows were sent.
    if let CopyFormat::Binary = format {
        let trailer: i16 = -1;
        out.extend(trailer.to_be_bytes());
        self.send(BackendMessage::CopyData(mem::take(&mut out)))
            .await?;
    }

    let tag = format!("COPY {}", count);
    self.send(BackendMessage::CopyDone).await?;
    self.send(BackendMessage::CommandComplete { tag }).await?;
    Ok((
        State::Ready,
        SendRowsEndedReason::Success {
            result_size: u64::cast_from(total_sent_bytes),
            rows_returned: u64::cast_from(count),
        },
    ))
}
2495
2496 #[instrument(level = "debug")]
2499 async fn copy_from(
2500 &mut self,
2501 target_id: CatalogItemId,
2502 target_name: String,
2503 columns: Vec<ColumnIndex>,
2504 params: CopyFormatParams<'_>,
2505 row_desc: RelationDesc,
2506 mut ctx_extra: ExecuteContextExtra,
2507 ) -> Result<State, io::Error> {
2508 let res = self
2509 .copy_from_inner(
2510 target_id,
2511 target_name,
2512 columns,
2513 params,
2514 row_desc,
2515 &mut ctx_extra,
2516 )
2517 .await;
2518 match &res {
2519 Ok(State::Done) => {
2520 self.adapter_client
2524 .retire_execute(ctx_extra, StatementEndedExecutionReason::Canceled);
2525 }
2526 Err(e) => {
2527 self.adapter_client.retire_execute(
2528 ctx_extra,
2529 StatementEndedExecutionReason::Errored {
2530 error: format!("{e}"),
2531 },
2532 );
2533 }
2534 other => {
2535 tracing::warn!(?other, "aborting COPY FROM");
2536 self.adapter_client
2537 .retire_execute(ctx_extra, StatementEndedExecutionReason::Aborted);
2538 }
2539 }
2540 res
2541 }
2542
2543 async fn copy_from_inner(
2544 &mut self,
2545 target_id: CatalogItemId,
2546 target_name: String,
2547 columns: Vec<ColumnIndex>,
2548 params: CopyFormatParams<'_>,
2549 row_desc: RelationDesc,
2550 ctx_extra: &mut ExecuteContextExtra,
2551 ) -> Result<State, io::Error> {
2552 let typ = row_desc.typ();
2553 let column_formats = vec![Format::Text; typ.column_types.len()];
2554 self.send(BackendMessage::CopyInResponse {
2555 overall_format: Format::Text,
2556 column_formats,
2557 })
2558 .await?;
2559 self.conn.flush().await?;
2560
2561 let system_vars = self.adapter_client.get_system_vars().await;
2562 let max_size = system_vars
2563 .get(MAX_COPY_FROM_SIZE.name())
2564 .ok()
2565 .and_then(|max_size| max_size.value().parse().ok())
2566 .unwrap_or(usize::MAX);
2567 tracing::debug!("COPY FROM max buffer size: {max_size} bytes");
2568
2569 let mut data = Vec::new();
2570 loop {
2571 let message = self.conn.recv().await?;
2572 match message {
2573 Some(FrontendMessage::CopyData(buf)) => {
2574 if (data.len() + buf.len()) > max_size {
2576 return self
2577 .error(ErrorResponse::error(
2578 SqlState::INSUFFICIENT_RESOURCES,
2579 "COPY FROM STDIN too large",
2580 ))
2581 .await;
2582 }
2583 data.extend(buf)
2584 }
2585 Some(FrontendMessage::CopyDone) => break,
2586 Some(FrontendMessage::CopyFail(err)) => {
2587 self.adapter_client.retire_execute(
2588 std::mem::take(ctx_extra),
2589 StatementEndedExecutionReason::Canceled,
2590 );
2591 return self
2592 .error(ErrorResponse::error(
2593 SqlState::QUERY_CANCELED,
2594 format!("COPY from stdin failed: {}", err),
2595 ))
2596 .await;
2597 }
2598 Some(FrontendMessage::Flush) | Some(FrontendMessage::Sync) => {}
2599 Some(_) => {
2600 let msg = "unexpected message type during COPY from stdin";
2601 self.adapter_client.retire_execute(
2602 std::mem::take(ctx_extra),
2603 StatementEndedExecutionReason::Errored {
2604 error: msg.to_string(),
2605 },
2606 );
2607 return self
2608 .error(ErrorResponse::error(SqlState::PROTOCOL_VIOLATION, msg))
2609 .await;
2610 }
2611 None => {
2612 return Ok(State::Done);
2613 }
2614 }
2615 }
2616
2617 let column_types = typ
2618 .column_types
2619 .iter()
2620 .map(|x| &x.scalar_type)
2621 .map(mz_pgrepr::Type::from)
2622 .collect::<Vec<mz_pgrepr::Type>>();
2623
2624 let rows = match mz_pgcopy::decode_copy_format(&data, &column_types, params) {
2625 Ok(rows) => rows,
2626 Err(e) => {
2627 self.adapter_client.retire_execute(
2628 std::mem::take(ctx_extra),
2629 StatementEndedExecutionReason::Errored {
2630 error: e.to_string(),
2631 },
2632 );
2633 return self
2634 .error(ErrorResponse::error(
2635 SqlState::BAD_COPY_FILE_FORMAT,
2636 format!("{}", e),
2637 ))
2638 .await;
2639 }
2640 };
2641
2642 let count = rows.len();
2643
2644 if let Err(e) = self
2645 .adapter_client
2646 .insert_rows(
2647 target_id,
2648 target_name,
2649 columns,
2650 rows,
2651 std::mem::take(ctx_extra),
2652 )
2653 .await
2654 {
2655 self.adapter_client.retire_execute(
2656 std::mem::take(ctx_extra),
2657 StatementEndedExecutionReason::Errored {
2658 error: e.to_string(),
2659 },
2660 );
2661 return self.error(e.into_response(Severity::Error)).await;
2662 }
2663
2664 let tag = format!("COPY {}", count);
2665 self.send(BackendMessage::CommandComplete { tag }).await?;
2666
2667 Ok(State::Ready)
2668 }
2669
2670 #[instrument(level = "debug")]
2671 async fn send_pending_notices(&mut self) -> Result<(), io::Error> {
2672 let notices = self
2673 .adapter_client
2674 .session()
2675 .drain_notices()
2676 .into_iter()
2677 .map(|notice| BackendMessage::ErrorResponse(notice.into_response()));
2678 self.send_all(notices).await?;
2679 Ok(())
2680 }
2681
2682 #[instrument(level = "debug")]
2683 async fn error(&mut self, err: ErrorResponse) -> Result<State, io::Error> {
2684 assert!(err.severity.is_error());
2685 debug!(
2686 "cid={} error code={}",
2687 self.adapter_client.session().conn_id(),
2688 err.code.code()
2689 );
2690 let is_fatal = err.severity.is_fatal();
2691 self.send(BackendMessage::ErrorResponse(err)).await?;
2692
2693 let txn = self.adapter_client.session().transaction();
2694 match txn {
2695 TransactionStatus::Default | TransactionStatus::Failed(_) => {}
2698 TransactionStatus::Started(_) => {
2700 self.rollback_transaction().await?;
2701 }
2702 TransactionStatus::InTransactionImplicit(_) => {
2704 self.rollback_transaction().await?;
2705 }
2706 TransactionStatus::InTransaction(_) => {
2708 self.adapter_client.fail_transaction();
2709 }
2710 };
2711 if is_fatal {
2712 Ok(State::Done)
2713 } else {
2714 Ok(State::Drain)
2715 }
2716 }
2717
2718 #[instrument(level = "debug")]
2719 async fn aborted_txn_error(&mut self) -> Result<State, io::Error> {
2720 self.send(BackendMessage::ErrorResponse(ErrorResponse::error(
2721 SqlState::IN_FAILED_SQL_TRANSACTION,
2722 ABORTED_TXN_MSG,
2723 )))
2724 .await?;
2725 Ok(State::Drain)
2726 }
2727
2728 fn is_aborted_txn(&mut self) -> bool {
2729 matches!(
2730 self.adapter_client.session().transaction(),
2731 TransactionStatus::Failed(_)
2732 )
2733 }
2734}
2735
2736fn pad_formats(formats: Vec<Format>, n: usize) -> Result<Vec<Format>, String> {
2737 match (formats.len(), n) {
2738 (0, e) => Ok(vec![Format::Text; e]),
2739 (1, e) => Ok(iter::repeat(formats[0]).take(e).collect()),
2740 (a, e) if a == e => Ok(formats),
2741 (a, e) => Err(format!(
2742 "expected {} field format specifiers, but got {}",
2743 e, a
2744 )),
2745 }
2746}
2747
2748fn describe_rows(stmt_desc: &StatementDesc, formats: &[Format]) -> BackendMessage {
2749 match &stmt_desc.relation_desc {
2750 Some(desc) if !stmt_desc.is_copy => {
2751 BackendMessage::RowDescription(message::encode_row_description(desc, formats))
2752 }
2753 _ => BackendMessage::NoData,
2754 }
2755}
2756
/// Signature shared by the functions that choose the message ending a round of
/// row delivery (`portal_exec_message`, `fetch_message`).
///
/// Arguments:
/// - `max_rows`: the row cap requested for this round.
/// - `total_sent_rows`: how many rows have been sent so far.
/// - `fetch_portal`: the portal being fetched from, if any. `fetch_message`
///   records its completion tag on it; `portal_exec_message` ignores it.
type GetResponse = fn(
    max_rows: ExecuteCount,
    total_sent_rows: usize,
    fetch_portal: Option<PortalRefMut>,
) -> BackendMessage;
2762
2763fn portal_exec_message(
2766 max_rows: ExecuteCount,
2767 total_sent_rows: usize,
2768 _fetch_portal: Option<PortalRefMut>,
2769) -> BackendMessage {
2770 match max_rows {
2777 ExecuteCount::Count(max_rows) if max_rows <= total_sent_rows => {
2778 BackendMessage::PortalSuspended
2779 }
2780 _ => BackendMessage::CommandComplete {
2781 tag: format!("SELECT {}", total_sent_rows),
2782 },
2783 }
2784}
2785
2786fn fetch_message(
2788 _max_rows: ExecuteCount,
2789 total_sent_rows: usize,
2790 fetch_portal: Option<PortalRefMut>,
2791) -> BackendMessage {
2792 let tag = format!("FETCH {}", total_sent_rows);
2793 if let Some(portal) = fetch_portal {
2794 *portal.state = PortalState::Completed(Some(tag.clone()));
2795 }
2796 BackendMessage::CommandComplete { tag }
2797}
2798
/// How many rows a round of row delivery may send.
#[derive(Debug, Copy, Clone)]
enum ExecuteCount {
    /// No row limit.
    All,
    /// At most this many rows. `portal_exec_message` suspends the portal once
    /// the number of rows sent reaches this value.
    Count(usize),
}
2804
2805fn is_txn_exit_stmt(stmt: Option<&Statement<Raw>>) -> bool {
2807 match stmt {
2808 Some(stmt) => matches!(stmt, Statement::Commit(_) | Statement::Rollback(_)),
2810 None => false,
2811 }
2812}
2813
/// One outcome of fetching rows. The variants suggest it is produced while
/// streaming query results; NOTE(review): confirm against the producer outside
/// this chunk.
#[derive(Debug)]
enum FetchResult {
    /// A batch of rows, as a boxed row iterator. The meaning of `None` (e.g.
    /// "no more rows") is not visible here; TODO confirm.
    Rows(Option<Box<dyn RowIterator + Send + Sync>>),
    /// The fetch was canceled.
    Canceled,
    /// The fetch failed with the given message.
    Error(String),
    /// A notice raised during the fetch.
    Notice(AdapterNotice),
}
2821
#[cfg(test)]
mod test {
    use super::*;

    #[mz_ore::test]
    fn test_parse_options() {
        struct TestCase {
            input: &'static str,
            expect: Result<Vec<(&'static str, &'static str)>, ()>,
        }
        let tests = vec![
            TestCase {
                input: "",
                expect: Ok(vec![]),
            },
            TestCase {
                input: "--key",
                expect: Err(()),
            },
            TestCase {
                input: "--key=val",
                expect: Ok(vec![("key", "val")]),
            },
            TestCase {
                input: r#"--key=val -ckey2=val2 -c key3=val3 -c key4=val4 -ckey5=val5"#,
                expect: Ok(vec![
                    ("key", "val"),
                    ("key2", "val2"),
                    ("key3", "val3"),
                    ("key4", "val4"),
                    ("key5", "val5"),
                ]),
            },
            TestCase {
                input: r#"-c\ key=val"#,
                expect: Ok(vec![(" key", "val")]),
            },
            TestCase {
                input: "--key=val -ckey2 val2",
                expect: Err(()),
            },
            TestCase {
                // An empty value is accepted.
                input: "--key=",
                expect: Ok(vec![("key", "")]),
            },
        ];
        for test in tests {
            let got = parse_options(test.input);
            let expect = test.expect.map(|r| {
                r.into_iter()
                    .map(|(k, v)| (k.to_owned(), v.to_owned()))
                    .collect()
            });
            assert_eq!(got, expect, "input: {}", test.input);
        }
    }

    #[mz_ore::test]
    fn test_parse_option() {
        struct TestCase {
            input: &'static str,
            expect: Result<(&'static str, &'static str), ()>,
        }
        let tests = vec![
            TestCase {
                input: "",
                expect: Err(()),
            },
            TestCase {
                input: "--",
                expect: Err(()),
            },
            TestCase {
                input: "--c",
                expect: Err(()),
            },
            TestCase {
                input: "a=b",
                expect: Err(()),
            },
            TestCase {
                input: "--a=b",
                expect: Ok(("a", "b")),
            },
            TestCase {
                input: "--ca=b",
                expect: Ok(("ca", "b")),
            },
            TestCase {
                input: "-ca=b",
                expect: Ok(("a", "b")),
            },
            TestCase {
                // Both the key and the value may be empty.
                input: "--=",
                expect: Ok(("", "")),
            },
        ];
        for test in tests {
            let got = parse_option(test.input);
            assert_eq!(got, test.expect, "input: {}", test.input);
        }
    }

    #[mz_ore::test]
    fn test_split_options() {
        struct TestCase {
            input: &'static str,
            expect: Vec<&'static str>,
        }
        // Several inputs below rely on runs of multiple spaces. A backslash
        // escapes only the single character after it, so `\` followed by
        // several spaces yields a literal space plus a separator. Collapsing
        // those runs changes the expected split.
        let tests = vec![
            TestCase {
                input: "",
                expect: vec![],
            },
            TestCase {
                input: "  ",
                expect: vec![],
            },
            TestCase {
                input: " a ",
                expect: vec!["a"],
            },
            TestCase {
                input: "  ab     cd   ",
                expect: vec!["ab", "cd"],
            },
            TestCase {
                input: r#"  ab\     cd   "#,
                expect: vec!["ab ", "cd"],
            },
            TestCase {
                input: r#"  ab\\     cd   "#,
                expect: vec![r#"ab\"#, "cd"],
            },
            TestCase {
                input: r#"  ab\\\     cd   "#,
                expect: vec![r#"ab\ "#, "cd"],
            },
            TestCase {
                input: r#"  ab\\\ cd   "#,
                expect: vec![r#"ab\ cd"#],
            },
            TestCase {
                input: r#"  ab\\\cd   "#,
                expect: vec![r#"ab\cd"#],
            },
            TestCase {
                input: r#"a\"#,
                expect: vec!["a"],
            },
            TestCase {
                input: r#"a\ "#,
                expect: vec!["a "],
            },
            TestCase {
                input: r#"\"#,
                expect: vec![],
            },
            TestCase {
                input: r#"\ "#,
                expect: vec![r#" "#],
            },
            TestCase {
                input: r#" \ "#,
                expect: vec![r#" "#],
            },
            TestCase {
                input: r#"\  "#,
                expect: vec![r#" "#],
            },
        ];
        for test in tests {
            let got = split_options(test.input);
            assert_eq!(got, test.expect, "input: {}", test.input);
        }
    }
}