1use std::collections::BTreeMap;
11use std::convert::TryFrom;
12use std::future::Future;
13use std::ops::Deref;
14use std::sync::Arc;
15use std::time::{Duration, Instant};
16use std::{iter, mem};
17
18use base64::prelude::*;
19use byteorder::{ByteOrder, NetworkEndian};
20use futures::future::{BoxFuture, FutureExt, pending};
21use itertools::Itertools;
22use mz_adapter::client::RecordFirstRowStream;
23use mz_adapter::session::{
24 EndTransactionAction, InProgressRows, LifecycleTimestamps, PortalRefMut, PortalState,
25 SessionConfig, TransactionStatus,
26};
27use mz_adapter::statement_logging::{StatementEndedExecutionReason, StatementExecutionStrategy};
28use mz_adapter::{
29 AdapterError, AdapterNotice, ExecuteContextExtra, ExecuteResponse, PeekResponseUnary, metrics,
30 verify_datum_desc,
31};
32use mz_auth::password::Password;
33use mz_authenticator::Authenticator;
34use mz_ore::cast::CastFrom;
35use mz_ore::netio::AsyncReady;
36use mz_ore::now::{EpochMillis, SYSTEM_TIME};
37use mz_ore::str::StrExt;
38use mz_ore::{assert_none, assert_ok, instrument, soft_assert_eq_or_log};
39use mz_pgcopy::{CopyCsvFormatParams, CopyFormatParams, CopyTextFormatParams};
40use mz_pgwire_common::{
41 ConnectionCounter, Cursor, ErrorResponse, Format, FrontendMessage, Severity, VERSION_3,
42 VERSIONS,
43};
44use mz_repr::user::InternalUserMetadata;
45use mz_repr::{
46 CatalogItemId, ColumnIndex, Datum, RelationDesc, RowArena, RowIterator, RowRef,
47 SqlRelationType, SqlScalarType,
48};
49use mz_server_core::TlsMode;
50use mz_server_core::listeners::AllowedRoles;
51use mz_sql::ast::display::AstDisplay;
52use mz_sql::ast::{CopyDirection, CopyStatement, FetchDirection, Ident, Raw, Statement};
53use mz_sql::parse::StatementParseResult;
54use mz_sql::plan::{CopyFormat, ExecuteTimeout, StatementDesc};
55use mz_sql::session::metadata::SessionMetadata;
56use mz_sql::session::user::INTERNAL_USER_NAMES;
57use mz_sql::session::vars::{MAX_COPY_FROM_SIZE, Var, VarInput};
58use postgres::error::SqlState;
59use tokio::io::{self, AsyncRead, AsyncWrite};
60use tokio::select;
61use tokio::time::{self};
62use tokio_metrics::TaskMetrics;
63use tokio_stream::wrappers::UnboundedReceiverStream;
64use tracing::{Instrument, debug, debug_span, warn};
65use uuid::Uuid;
66
67use crate::codec::{
68 FramedConn, decode_password, decode_sasl_initial_response, decode_sasl_response,
69};
70use crate::message::{
71 self, BackendMessage, SASLServerFinalMessage, SASLServerFinalMessageKinds,
72 SASLServerFirstMessage,
73};
74
75pub fn match_handshake(buf: &[u8]) -> bool {
79 if buf.len() < 8 {
89 return false;
90 }
91 let version = NetworkEndian::read_i32(&buf[4..8]);
92 VERSIONS.contains(&version)
93}
94
/// Parameters for [`run`], which drives a single pgwire connection.
pub struct RunParams<'a, A, I>
where
    I: Iterator<Item = TaskMetrics> + Send,
{
    /// The TLS mode required by this listener; checked against the
    /// connection before authentication.
    pub tls_mode: Option<TlsMode>,
    /// A client for the adapter, used to create and start the session.
    pub adapter_client: mz_adapter::Client,
    /// The framed connection to the client.
    pub conn: &'a mut FramedConn<A>,
    /// The unique identifier assigned to this connection.
    pub conn_uuid: Uuid,
    /// The protocol version requested by the client; only `VERSION_3` is
    /// accepted.
    pub version: i32,
    /// The startup parameters sent by the client. `user` selects the role;
    /// the remaining entries (including `options`) are applied as session
    /// variables.
    pub params: BTreeMap<String, String>,
    /// How the client is authenticated.
    pub authenticator: Authenticator,
    /// Counter used to reserve a connection slot for the session's user.
    pub active_connection_counter: ConnectionCounter,
    /// Helm chart version recorded in the session configuration, if any.
    pub helm_chart_version: Option<String>,
    /// Which kinds of roles may log in through this listener.
    pub allowed_roles: AllowedRoles,
    /// Task metrics intervals; must be an endless iterator (it is advanced
    /// with `.expect("infinite iterator")`).
    pub tokio_metrics_intervals: I,
}
123
/// Runs a pgwire connection from startup until the session ends.
///
/// Steps, in order: reject unsupported protocol versions; check that the
/// requested role may log in on this listener (`allowed_roles`); check TLS
/// compatibility; authenticate according to `authenticator`; apply the
/// remaining startup parameters as session variables (failures become
/// notices, not errors); reserve a slot in `active_connection_counter`;
/// start the adapter session; send `AuthenticationOk`, `ParameterStatus`,
/// `BackendKeyData`, any notices and `ReadyForQuery`; then hand the
/// connection to a [`StateMachine`] until it finishes or the authentication
/// session expires.
///
/// Startup rejections are sent to the client as fatal `ErrorResponse`s and
/// the function returns the result of that send.
///
/// # Errors
///
/// Returns an I/O error if communicating with the client fails, or the
/// error returned by the state machine.
#[mz_ore::instrument(level = "debug")]
pub async fn run<'a, A, I>(
    RunParams {
        tls_mode,
        adapter_client,
        conn,
        conn_uuid,
        version,
        mut params,
        authenticator,
        active_connection_counter,
        helm_chart_version,
        allowed_roles,
        tokio_metrics_intervals,
    }: RunParams<'a, A, I>,
) -> Result<(), io::Error>
where
    A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
    I: Iterator<Item = TaskMetrics> + Send,
{
    // Only protocol version 3 is supported.
    if version != VERSION_3 {
        return conn
            .send(ErrorResponse::fatal(
                SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
                "server does not support the client's requested protocol version",
            ))
            .await;
    }

    // A missing `user` parameter becomes the empty user name.
    let user = params.remove("user").unwrap_or_else(String::new);

    // Decide whether this listener accepts logins for the requested role.
    let is_internal_user = INTERNAL_USER_NAMES.contains(&user);
    let is_reserved_user = mz_adapter::catalog::is_reserved_role_name(user.as_str());
    let role_allowed = match allowed_roles {
        AllowedRoles::Normal => !is_reserved_user,
        AllowedRoles::Internal => is_internal_user,
        AllowedRoles::NormalAndInternal => !is_reserved_user || is_internal_user,
    };
    if !role_allowed {
        let msg = format!("unauthorized login to user '{user}'");
        return conn
            .send(ErrorResponse::fatal(SqlState::INSUFFICIENT_PRIVILEGE, msg))
            .await;
    }

    if let Err(err) = conn.inner().ensure_tls_compatibility(&tls_mode) {
        return conn.send(err).await;
    }

    // Authenticate and build the session. `expired` resolves when the
    // authentication session expires; only the Frontegg arm supplies a real
    // future, the others use `pending()`. `left_future`/`right_future` give
    // every arm the same future type.
    let (mut session, expired) = match authenticator {
        // Cleartext password verified by Frontegg.
        Authenticator::Frontegg(frontegg) => {
            conn.send(BackendMessage::AuthenticationCleartextPassword)
                .await?;
            conn.flush().await?;
            let password = match conn.recv().await? {
                Some(FrontendMessage::RawAuthentication(data)) => {
                    match decode_password(Cursor::new(&data)).ok() {
                        Some(FrontendMessage::Password { password }) => password,
                        _ => {
                            return conn
                                .send(ErrorResponse::fatal(
                                    SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                                    "expected Password message",
                                ))
                                .await;
                        }
                    }
                }
                _ => {
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                            "expected Password message",
                        ))
                        .await;
                }
            };

            let auth_response = frontegg.authenticate(&user, &password).await;
            match auth_response {
                Ok(mut auth_session) => {
                    // The session's user name comes from the authenticated
                    // session rather than the startup `user` parameter.
                    let session = adapter_client.new_session(SessionConfig {
                        conn_id: conn.conn_id().clone(),
                        uuid: conn_uuid,
                        user: auth_session.user().into(),
                        client_ip: conn.peer_addr().clone(),
                        external_metadata_rx: Some(auth_session.external_metadata_rx()),
                        internal_user_metadata: None,
                        helm_chart_version,
                    });
                    let expired = async move { auth_session.expired().await };
                    (session, expired.left_future())
                }
                Err(err) => {
                    warn!(?err, "pgwire connection failed authentication");
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_PASSWORD,
                            "invalid password",
                        ))
                        .await;
                }
            }
        }
        // Cleartext password verified by the adapter. Note: this binding
        // shadows the outer `adapter_client` inside this arm.
        Authenticator::Password(adapter_client) => {
            conn.send(BackendMessage::AuthenticationCleartextPassword)
                .await?;
            conn.flush().await?;
            let password = match conn.recv().await? {
                Some(FrontendMessage::RawAuthentication(data)) => {
                    match decode_password(Cursor::new(&data)).ok() {
                        Some(FrontendMessage::Password { password }) => Password(password),
                        _ => {
                            return conn
                                .send(ErrorResponse::fatal(
                                    SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                                    "expected Password message",
                                ))
                                .await;
                        }
                    }
                }
                _ => {
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                            "expected Password message",
                        ))
                        .await;
                }
            };
            let auth_response = match adapter_client.authenticate(&user, &password).await {
                Ok(resp) => resp,
                Err(err) => {
                    warn!(?err, "pgwire connection failed authentication");
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_PASSWORD,
                            "invalid password",
                        ))
                        .await;
                }
            };
            let session = adapter_client.new_session(SessionConfig {
                conn_id: conn.conn_id().clone(),
                uuid: conn_uuid,
                user,
                client_ip: conn.peer_addr().clone(),
                external_metadata_rx: None,
                internal_user_metadata: Some(InternalUserMetadata {
                    superuser: auth_response.superuser,
                }),
                helm_chart_version,
            });
            // Password sessions never expire.
            let auth_session = pending().right_future();
            (session, auth_session)
        }
        // SASL (SCRAM-SHA-256 only, without channel binding). Note: this
        // binding shadows the outer `adapter_client` inside this arm.
        Authenticator::Sasl(adapter_client) => {
            conn.send(BackendMessage::AuthenticationSASL).await?;
            conn.flush().await?;
            let (mechanism, initial_response) = match conn.recv().await? {
                Some(FrontendMessage::RawAuthentication(data)) => {
                    match decode_sasl_initial_response(Cursor::new(&data)).ok() {
                        Some(FrontendMessage::SASLInitialResponse {
                            gs2_header,
                            mechanism,
                            initial_response,
                        }) => {
                            // Channel binding is not supported.
                            if gs2_header.channel_binding_enabled() {
                                return conn
                                    .send(ErrorResponse::fatal(
                                        SqlState::PROTOCOL_VIOLATION,
                                        "channel binding not supported",
                                    ))
                                    .await;
                            }
                            (mechanism, initial_response)
                        }
                        _ => {
                            return conn
                                .send(ErrorResponse::fatal(
                                    SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                                    "expected SASLInitialResponse message",
                                ))
                                .await;
                        }
                    }
                }
                _ => {
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                            "expected SASLInitialResponse message",
                        ))
                        .await;
                }
            };

            if mechanism != "SCRAM-SHA-256" {
                return conn
                    .send(ErrorResponse::fatal(
                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                        "unsupported SASL mechanism",
                    ))
                    .await;
            }

            // Bound the client-supplied nonce before using it.
            if initial_response.nonce.len() > 256 {
                return conn
                    .send(ErrorResponse::fatal(
                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                        "nonce too long",
                    ))
                    .await;
            }

            // Send the server-first message. The raw text is kept because it
            // is part of the SCRAM auth message verified below.
            let (server_first_message_raw, mock_hash) = match adapter_client
                .generate_sasl_challenge(&user, &initial_response.nonce)
                .await
            {
                Ok(response) => {
                    let server_first_message_raw = format!(
                        "r={},s={},i={}",
                        response.nonce, response.salt, response.iteration_count
                    );

                    // NOTE(review): `mock_hash` is a placeholder SCRAM hash built
                    // from fixed dummy keys plus this challenge's salt and
                    // iteration count. It is handed to `verify_sasl_proof`,
                    // presumably so that roles without a stored hash go through
                    // the same verification steps. Confirm against
                    // `verify_sasl_proof`.
                    let client_key = [0u8; 32];
                    let server_key = [1u8; 32];
                    let mock_hash = format!(
                        "SCRAM-SHA-256${}:{}${}:{}",
                        response.iteration_count,
                        response.salt,
                        BASE64_STANDARD.encode(client_key),
                        BASE64_STANDARD.encode(server_key)
                    );

                    conn.send(BackendMessage::AuthenticationSASLContinue(
                        SASLServerFirstMessage {
                            iteration_count: response.iteration_count,
                            nonce: response.nonce,
                            salt: response.salt,
                        },
                    ))
                    .await?;
                    conn.flush().await?;
                    (server_first_message_raw, mock_hash)
                }
                Err(e) => {
                    return conn.send(e.into_response(Severity::Fatal)).await;
                }
            };

            // Receive the client-final message and verify its proof.
            let auth_resp = match conn.recv().await? {
                Some(FrontendMessage::RawAuthentication(data)) => {
                    match decode_sasl_response(Cursor::new(&data)).ok() {
                        Some(FrontendMessage::SASLResponse(response)) => {
                            // SCRAM auth message: client-first-bare,
                            // server-first, client-final-without-proof.
                            let auth_message = format!(
                                "{},{},{}",
                                initial_response.client_first_message_bare_raw,
                                server_first_message_raw,
                                response.client_final_message_bare_raw
                            );
                            if response.proof.len() > 1024 {
                                return conn
                                    .send(ErrorResponse::fatal(
                                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                                        "proof too long",
                                    ))
                                    .await;
                            }
                            match adapter_client
                                .verify_sasl_proof(
                                    &user,
                                    &response.proof,
                                    &auth_message,
                                    &mock_hash,
                                )
                                .await
                            {
                                Ok(resp) => {
                                    conn.send(BackendMessage::AuthenticationSASLFinal(
                                        SASLServerFinalMessage {
                                            kind: SASLServerFinalMessageKinds::Verifier(
                                                resp.verifier,
                                            ),
                                            extensions: vec![],
                                        },
                                    ))
                                    .await?;
                                    conn.flush().await?;
                                    resp.auth_resp
                                }
                                // NOTE(review): unlike the password paths, this
                                // failure is not logged.
                                Err(_) => {
                                    return conn
                                        .send(ErrorResponse::fatal(
                                            SqlState::INVALID_PASSWORD,
                                            "invalid password",
                                        ))
                                        .await;
                                }
                            }
                        }
                        _ => {
                            return conn
                                .send(ErrorResponse::fatal(
                                    SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                                    "expected SASLResponse message",
                                ))
                                .await;
                        }
                    }
                }
                _ => {
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                            "expected SASLResponse message",
                        ))
                        .await;
                }
            };

            let session = adapter_client.new_session(SessionConfig {
                conn_id: conn.conn_id().clone(),
                uuid: conn_uuid,
                user,
                client_ip: conn.peer_addr().clone(),
                external_metadata_rx: None,
                internal_user_metadata: Some(InternalUserMetadata {
                    superuser: auth_resp.superuser,
                }),
                helm_chart_version,
            });
            // SASL sessions never expire.
            let auth_session = pending().right_future();
            (session, auth_session)
        }
        // No authentication: the session is created for the requested user.
        Authenticator::None => {
            let session = adapter_client.new_session(SessionConfig {
                conn_id: conn.conn_id().clone(),
                uuid: conn_uuid,
                user,
                client_ip: conn.peer_addr().clone(),
                external_metadata_rx: None,
                internal_user_metadata: None,
                helm_chart_version,
            });
            let auth_session = pending().right_future();
            (session, auth_session)
        }
    };

    // Apply the remaining startup parameters as session variables. An
    // `options` parameter is expanded into its individual settings. Failures
    // are reported to the client as notices instead of failing the connection.
    let system_vars = adapter_client.get_system_vars().await;
    for (name, value) in params {
        let settings = match name.as_str() {
            "options" => match parse_options(&value) {
                Ok(opts) => opts,
                Err(()) => {
                    session.add_notice(AdapterNotice::BadStartupSetting {
                        name,
                        reason: "could not parse".into(),
                    });
                    continue;
                }
            },
            _ => vec![(name, value)],
        };
        for (key, val) in settings {
            // Startup settings are session-level, not transaction-local.
            const LOCAL: bool = false;
            if let Err(err) =
                session
                    .vars_mut()
                    .set(&system_vars, &key, VarInput::Flat(&val), LOCAL)
            {
                session.add_notice(AdapterNotice::BadStartupSetting {
                    name: key,
                    reason: err.to_string(),
                });
            }
        }
    }
    // Commit the variable changes made above.
    session
        .vars_mut()
        .end_transaction(EndTransactionAction::Commit);

    // Reserve a connection slot; `_guard` is held until this function returns.
    let _guard = match active_connection_counter.allocate_connection(session.user()) {
        Ok(drop_connection) => drop_connection,
        Err(e) => {
            let e: AdapterError = e.into();
            return conn.send(e.into_response(Severity::Fatal)).await;
        }
    };

    let mut adapter_client = match adapter_client.startup(session).await {
        Ok(adapter_client) => adapter_client,
        Err(e) => return conn.send(e.into_response(Severity::Fatal)).await,
    };

    // Send the startup response as one batch.
    let mut buf = vec![BackendMessage::AuthenticationOk];
    for var in adapter_client.session().vars().notify_set() {
        buf.push(BackendMessage::ParameterStatus(var.name(), var.value()));
    }
    buf.push(BackendMessage::BackendKeyData {
        conn_id: adapter_client.session().conn_id().unhandled(),
        secret_key: adapter_client.session().secret_key(),
    });
    buf.extend(
        adapter_client
            .session()
            .drain_notices()
            .into_iter()
            .map(|notice| BackendMessage::ErrorResponse(notice.into_response())),
    );
    buf.push(BackendMessage::ReadyForQuery(
        adapter_client.session().transaction().into(),
    ));
    conn.send_all(buf).await?;
    conn.flush().await?;

    let machine = StateMachine {
        conn,
        adapter_client,
        txn_needs_commit: false,
        tokio_metrics_intervals,
    };

    select! {
        r = machine.run() => {
            // On an internal error, tell the client why the connection is
            // closing (best effort: send/flush failures are ignored) and still
            // return the original error.
            if let Err(err) = &r {
                let _ = conn
                    .send(ErrorResponse::fatal(
                        SqlState::CONNECTION_FAILURE,
                        err.to_string(),
                    ))
                    .await;
                let _ = conn.flush().await;
            }
            r
        },
        _ = expired => {
            conn
                .send(ErrorResponse::fatal(SqlState::INVALID_AUTHORIZATION_SPECIFICATION, "authentication expired"))
                .await?;
            conn.flush().await
        }
    }
}
603
604fn parse_options(value: &str) -> Result<Vec<(String, String)>, ()> {
609 let opts = split_options(value);
610 let mut pairs = Vec::with_capacity(opts.len());
611 let mut seen_prefix = false;
612 for opt in opts {
613 if !seen_prefix {
614 if opt == "-c" {
615 seen_prefix = true;
616 } else {
617 let (key, val) = parse_option(&opt)?;
618 pairs.push((key.to_owned(), val.to_owned()));
619 }
620 } else {
621 let (key, val) = opt.split_once('=').ok_or(())?;
622 pairs.push((key.to_owned(), val.to_owned()));
623 seen_prefix = false;
624 }
625 }
626 Ok(pairs)
627}
628
/// Parses one `-ckey=value` or `--key=value` token into `(key, value)`.
///
/// The token is split at its first `=`. `Err(())` is returned when there is
/// no `=` or when the key has neither the `-c` nor the `--` prefix.
fn parse_option(option: &str) -> Result<(&str, &str), ()> {
    let (raw_key, value) = option.split_once('=').ok_or(())?;
    raw_key
        .strip_prefix("-c")
        .or_else(|| raw_key.strip_prefix("--"))
        .map(|key| (key, value))
        .ok_or(())
}
641
/// Splits the `options` startup parameter value into space-separated tokens.
///
/// Runs of spaces separate tokens and empty tokens are never produced. A
/// backslash escapes the next character: `\ ` yields a literal space inside a
/// token, `\\` a literal backslash, and a backslash before any other
/// character is dropped while that character is kept. A trailing lone
/// backslash is dropped.
fn split_options(value: &str) -> Vec<String> {
    let mut tokens = Vec::new();
    let mut token = String::new();
    let mut chars = value.chars();
    while let Some(ch) = chars.next() {
        match ch {
            ' ' => {
                if !token.is_empty() {
                    tokens.push(std::mem::take(&mut token));
                }
            }
            '\\' => {
                // Take the escaped character verbatim, whatever it is.
                if let Some(escaped) = chars.next() {
                    token.push(escaped);
                }
            }
            other => token.push(other),
        }
    }
    if !token.is_empty() {
        tokens.push(token);
    }
    tokens
}
683}
684
/// What the [`StateMachine`] does with the next frontend message.
#[derive(Debug)]
enum State {
    /// Process the next message normally (`advance_ready`).
    Ready,
    /// Ignore messages until a `Sync` arrives (`advance_drain`).
    Drain,
    /// End the session; `StateMachine::run` returns `Ok(())`.
    Done,
}
691
/// Per-connection protocol state after startup has completed.
struct StateMachine<'a, A, I>
where
    I: Iterator<Item = TaskMetrics> + Send + 'a,
{
    /// The framed connection to the client.
    conn: &'a mut FramedConn<A>,
    /// The adapter client for this connection's session.
    adapter_client: mz_adapter::SessionClient,
    /// Set after an extended-protocol `Execute` in an implicit transaction;
    /// the transaction is then committed by the next `ensure_transaction`.
    /// Cleared by `end_transaction`.
    txn_needs_commit: bool,
    /// Task metrics intervals; expected to be endless.
    tokio_metrics_intervals: I,
}
701
/// How streaming a result's rows to the client ended.
///
/// NOTE(review): constructed and consumed by code outside this view; the
/// descriptions below follow the field names. Confirm at the use sites.
enum SendRowsEndedReason {
    /// All rows were sent.
    Success {
        /// Size of the result (presumably in bytes; confirm).
        result_size: u64,
        /// Number of rows returned to the client.
        rows_returned: u64,
    },
    /// Sending stopped because of an error.
    Errored {
        /// The error message.
        error: String,
    },
    /// Sending was canceled.
    Canceled,
}
712
/// Error text for statements rejected because the current transaction is in
/// an aborted (failed) state.
const ABORTED_TXN_MSG: &str =
    "current transaction is aborted, commands ignored until end of transaction block";
715
716impl<'a, A, I> StateMachine<'a, A, I>
717where
718 A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin + 'a,
719 I: Iterator<Item = TaskMetrics> + Send + 'a,
720{
721 #[allow(clippy::manual_async_fn)]
725 #[mz_ore::instrument(level = "debug")]
726 fn run(mut self) -> impl Future<Output = Result<(), io::Error>> + Send + 'a {
727 async move {
728 let mut state = State::Ready;
729 loop {
730 self.send_pending_notices().await?;
731 state = match state {
732 State::Ready => self.advance_ready().await?,
733 State::Drain => self.advance_drain().await?,
734 State::Done => return Ok(()),
735 };
736 self.adapter_client
737 .add_idle_in_transaction_session_timeout();
738 }
739 }
740 }
741
/// Waits for the next frontend message (or a session timeout), dispatches it
/// to its handler, and returns the next [`State`].
///
/// Also records, per message name, the processing time
/// (`pgwire_message_processing_seconds`, only when a message was received)
/// and the scheduling delay observed while waiting for it
/// (`pgwire_recv_scheduling_delay_ms`).
#[instrument(level = "debug")]
async fn advance_ready(&mut self) -> Result<State, io::Error> {
    // Start a fresh metrics interval before waiting on `recv`, so the interval
    // taken after `recv` covers the wait. The returned value is discarded.
    self.tokio_metrics_intervals
        .next()
        .expect("infinite iterator");

    let message = select! {
        // `biased` polls the branches in order, so a session timeout takes
        // priority over an incoming message.
        biased;

        // The session timed out: send a fatal error and terminate the session.
        Some(timeout) = self.adapter_client.recv_timeout() => {
            let err: AdapterError = timeout.into();
            let conn_id = self.adapter_client.session().conn_id();
            tracing::warn!("session timed out, conn_id {}", conn_id);

            let error_response = err.into_response(Severity::Fatal);
            let error_state = self.error(error_response).await;

            self.adapter_client.terminate().await;

            // NOTE(review): waits for one more message (or EOF) from the client
            // before returning, presumably to let the client read the error
            // before the connection is dropped. Confirm.
            let _ = self.conn.recv().await?;
            return error_state;
        },
        message = self.conn.recv() => message?,
    };

    // Scheduling delay accumulated while waiting for the message, in ms.
    let interval = self
        .tokio_metrics_intervals
        .next()
        .expect("infinite iterator");
    let recv_scheduling_delay_ms = interval.total_scheduled_duration.as_secs_f64() * 1000.0;

    // Wall-clock receive time, passed on as the statement's lifecycle timestamp.
    let received = SYSTEM_TIME();

    self.adapter_client
        .remove_idle_in_transaction_session_timeout();

    // Metric label; empty when the connection reached EOF.
    let message_name = message.as_ref().map(|m| m.name()).unwrap_or_default();

    // Only time processing when a message was actually received.
    let start = message.as_ref().map(|_| Instant::now());
    let next_state = match message {
        // Queries and Executes get their own root span, linked to the
        // current span via `follows_from`.
        Some(FrontendMessage::Query { sql }) => {
            let query_root_span =
                tracing::info_span!(parent: None, "advance_ready", otel.name = message_name);
            query_root_span.follows_from(tracing::Span::current());
            self.query(sql, received)
                .instrument(query_root_span)
                .await?
        }
        Some(FrontendMessage::Parse {
            name,
            sql,
            param_types,
        }) => self.parse(name, sql, param_types).await?,
        Some(FrontendMessage::Bind {
            portal_name,
            statement_name,
            param_formats,
            raw_params,
            result_formats,
        }) => {
            self.bind(
                portal_name,
                statement_name,
                param_formats,
                raw_params,
                result_formats,
            )
            .await?
        }
        Some(FrontendMessage::Execute {
            portal_name,
            max_rows,
        }) => {
            // A `max_rows` of zero, or one that does not fit in `usize`
            // (e.g. negative), means "no limit".
            let max_rows = match usize::try_from(max_rows) {
                Ok(0) | Err(_) => ExecuteCount::All,
                Ok(n) => ExecuteCount::Count(n),
            };
            let execute_root_span =
                tracing::info_span!(parent: None, "advance_ready", otel.name = message_name);
            execute_root_span.follows_from(tracing::Span::current());
            let state = self
                .execute(
                    portal_name,
                    max_rows,
                    portal_exec_message,
                    None,
                    ExecuteTimeout::None,
                    None,
                    Some(received),
                )
                .instrument(execute_root_span)
                .await?;
            // Defer committing an implicit transaction; `ensure_transaction`
            // commits it before the next message starts a transaction.
            if self.adapter_client.session().transaction().is_implicit() {
                self.txn_needs_commit = true;
            }
            state
        }
        Some(FrontendMessage::DescribeStatement { name }) => {
            self.describe_statement(&name).await?
        }
        Some(FrontendMessage::DescribePortal { name }) => self.describe_portal(&name).await?,
        Some(FrontendMessage::CloseStatement { name }) => self.close_statement(name).await?,
        Some(FrontendMessage::ClosePortal { name }) => self.close_portal(name).await?,
        Some(FrontendMessage::Flush) => self.flush().await?,
        Some(FrontendMessage::Sync) => self.sync().await?,
        Some(FrontendMessage::Terminate) => State::Done,

        // Copy and authentication messages are unexpected here: switch to
        // draining until the next `Sync`.
        Some(FrontendMessage::CopyData(_))
        | Some(FrontendMessage::CopyDone)
        | Some(FrontendMessage::CopyFail(_))
        | Some(FrontendMessage::Password { .. })
        | Some(FrontendMessage::RawAuthentication(_))
        | Some(FrontendMessage::SASLInitialResponse { .. })
        | Some(FrontendMessage::SASLResponse(_)) => State::Drain,
        // EOF: the client went away.
        None => State::Done,
    };

    if let Some(start) = start {
        self.adapter_client
            .inner()
            .metrics()
            .pgwire_message_processing_seconds
            .with_label_values(&[message_name])
            .observe(start.elapsed().as_secs_f64());
    }
    self.adapter_client
        .inner()
        .metrics()
        .pgwire_recv_scheduling_delay_ms
        .with_label_values(&[message_name])
        .observe(recv_scheduling_delay_ms);

    Ok(next_state)
}
905
906 async fn advance_drain(&mut self) -> Result<State, io::Error> {
907 let message = self.conn.recv().await?;
908 if message.is_some() {
909 self.adapter_client
910 .remove_idle_in_transaction_session_timeout();
911 }
912 match message {
913 Some(FrontendMessage::Sync) => self.sync().await,
914 None => Ok(State::Done),
915 _ => Ok(State::Drain),
916 }
917 }
918
919 #[instrument(level = "debug")]
923 async fn one_query(
924 &mut self,
925 stmt: Statement<Raw>,
926 sql: String,
927 lifecycle_timestamps: LifecycleTimestamps,
928 ) -> Result<State, io::Error> {
929 const EMPTY_PORTAL: &str = "";
932 if let Err(e) = self
933 .adapter_client
934 .declare(EMPTY_PORTAL.to_string(), stmt, sql)
935 .await
936 {
937 return self.error(e.into_response(Severity::Error)).await;
938 }
939 let portal = self
940 .adapter_client
941 .session()
942 .get_portal_unverified_mut(EMPTY_PORTAL)
943 .expect("unnamed portal should be present");
944
945 *portal.lifecycle_timestamps = Some(lifecycle_timestamps);
946
947 let stmt_desc = portal.desc.clone();
948 if !stmt_desc.param_types.is_empty() {
949 return self
950 .error(ErrorResponse::error(
951 SqlState::UNDEFINED_PARAMETER,
952 "there is no parameter $1",
953 ))
954 .await;
955 }
956
957 if let Some(relation_desc) = &stmt_desc.relation_desc {
959 if !stmt_desc.is_copy {
960 let formats = vec![Format::Text; stmt_desc.arity()];
961 self.send(BackendMessage::RowDescription(
962 message::encode_row_description(relation_desc, &formats),
963 ))
964 .await?;
965 }
966 }
967
968 let result = match self
969 .adapter_client
970 .execute(EMPTY_PORTAL.to_string(), self.conn.wait_closed(), None)
971 .await
972 {
973 Ok((response, execute_started)) => {
974 self.send_pending_notices().await?;
975 self.send_execute_response(
976 response,
977 stmt_desc.relation_desc,
978 EMPTY_PORTAL.to_string(),
979 ExecuteCount::All,
980 portal_exec_message,
981 None,
982 ExecuteTimeout::None,
983 execute_started,
984 )
985 .await
986 }
987 Err(e) => {
988 self.send_pending_notices().await?;
989 self.error(e.into_response(Severity::Error)).await
990 }
991 };
992
993 self.adapter_client.session().remove_portal(EMPTY_PORTAL);
995
996 result
997 }
998
999 async fn ensure_transaction(
1000 &mut self,
1001 num_stmts: usize,
1002 message_type: &str,
1003 ) -> Result<(), io::Error> {
1004 let start = Instant::now();
1005 if self.txn_needs_commit {
1006 self.commit_transaction().await?;
1007 }
1008 let res = self.adapter_client.start_transaction(Some(num_stmts));
1011 assert_ok!(res);
1012 self.adapter_client
1013 .inner()
1014 .metrics()
1015 .pgwire_ensure_transaction_seconds
1016 .with_label_values(&[message_type])
1017 .observe(start.elapsed().as_secs_f64());
1018 Ok(())
1019 }
1020
1021 fn parse_sql<'b>(&self, sql: &'b str) -> Result<Vec<StatementParseResult<'b>>, ErrorResponse> {
1022 let parse_start = Instant::now();
1023 let result = match self.adapter_client.parse(sql) {
1024 Ok(result) => result.map_err(|e| {
1025 let pos = sql[..e.error.pos].chars().count() + 1;
1028 ErrorResponse::error(SqlState::SYNTAX_ERROR, e.error.message).with_position(pos)
1029 }),
1030 Err(msg) => Err(ErrorResponse::error(SqlState::PROGRAM_LIMIT_EXCEEDED, msg)),
1031 };
1032 self.adapter_client
1033 .inner()
1034 .metrics()
1035 .parse_seconds
1036 .with_label_values(&[])
1037 .observe(parse_start.elapsed().as_secs_f64());
1038 result
1039 }
1040
1041 #[instrument(level = "debug")]
1046 async fn query(&mut self, sql: String, received: EpochMillis) -> Result<State, io::Error> {
1047 let stmts = match self.parse_sql(&sql) {
1049 Ok(stmts) => stmts,
1050 Err(err) => {
1051 self.error(err).await?;
1052 return self.ready().await;
1053 }
1054 };
1055
1056 let num_stmts = stmts.len();
1057
1058 for StatementParseResult { ast: stmt, sql } in stmts {
1060 if self.is_aborted_txn() && !is_txn_exit_stmt(Some(&stmt)) {
1062 self.aborted_txn_error().await?;
1063 break;
1064 }
1065
1066 self.ensure_transaction(num_stmts, "query").await?;
1074
1075 match self
1076 .one_query(stmt, sql.to_string(), LifecycleTimestamps { received })
1077 .await?
1078 {
1079 State::Ready => (),
1080 State::Drain => break,
1081 State::Done => return Ok(State::Done),
1082 }
1083 }
1084
1085 {
1087 if self.adapter_client.session().transaction().is_implicit() {
1088 self.commit_transaction().await?;
1089 }
1090 }
1091
1092 if num_stmts == 0 {
1093 self.send(BackendMessage::EmptyQueryResponse).await?;
1094 }
1095
1096 self.ready().await
1097 }
1098
1099 #[instrument(level = "debug")]
1100 async fn parse(
1101 &mut self,
1102 name: String,
1103 sql: String,
1104 param_oids: Vec<u32>,
1105 ) -> Result<State, io::Error> {
1106 self.ensure_transaction(1, "parse").await?;
1108
1109 let mut param_types = vec![];
1110 for oid in param_oids {
1111 match mz_pgrepr::Type::from_oid(oid) {
1112 Ok(ty) => match SqlScalarType::try_from(&ty) {
1113 Ok(ty) => param_types.push(Some(ty)),
1114 Err(err) => {
1115 return self
1116 .error(ErrorResponse::error(
1117 SqlState::INVALID_PARAMETER_VALUE,
1118 err.to_string(),
1119 ))
1120 .await;
1121 }
1122 },
1123 Err(_) if oid == 0 => param_types.push(None),
1124 Err(e) => {
1125 return self
1126 .error(ErrorResponse::error(
1127 SqlState::PROTOCOL_VIOLATION,
1128 e.to_string(),
1129 ))
1130 .await;
1131 }
1132 }
1133 }
1134
1135 let stmts = match self.parse_sql(&sql) {
1136 Ok(stmts) => stmts,
1137 Err(err) => {
1138 return self.error(err).await;
1139 }
1140 };
1141 if stmts.len() > 1 {
1142 return self
1143 .error(ErrorResponse::error(
1144 SqlState::INTERNAL_ERROR,
1145 "cannot insert multiple commands into a prepared statement",
1146 ))
1147 .await;
1148 }
1149 let (maybe_stmt, sql) = match stmts.into_iter().next() {
1150 None => (None, ""),
1151 Some(StatementParseResult { ast, sql }) => (Some(ast), sql),
1152 };
1153 if self.is_aborted_txn() && !is_txn_exit_stmt(maybe_stmt.as_ref()) {
1154 return self.aborted_txn_error().await;
1155 }
1156 match self
1157 .adapter_client
1158 .prepare(name, maybe_stmt, sql.to_string(), param_types)
1159 .await
1160 {
1161 Ok(()) => {
1162 self.send(BackendMessage::ParseComplete).await?;
1163 Ok(State::Ready)
1164 }
1165 Err(e) => self.error(e.into_response(Severity::Error)).await,
1166 }
1167 }
1168
1169 #[instrument(level = "debug")]
1171 async fn commit_transaction(&mut self) -> Result<(), io::Error> {
1172 self.end_transaction(EndTransactionAction::Commit).await
1173 }
1174
1175 #[instrument(level = "debug")]
1177 async fn rollback_transaction(&mut self) -> Result<(), io::Error> {
1178 self.end_transaction(EndTransactionAction::Rollback).await
1179 }
1180
1181 #[instrument(level = "debug")]
1183 async fn end_transaction(&mut self, action: EndTransactionAction) -> Result<(), io::Error> {
1184 self.txn_needs_commit = false;
1185 let resp = self.adapter_client.end_transaction(action).await;
1186 if let Err(err) = resp {
1187 self.send(BackendMessage::ErrorResponse(
1188 err.into_response(Severity::Error),
1189 ))
1190 .await?;
1191 }
1192 Ok(())
1193 }
1194
1195 #[instrument(level = "debug")]
1196 async fn bind(
1197 &mut self,
1198 portal_name: String,
1199 statement_name: String,
1200 param_formats: Vec<Format>,
1201 raw_params: Vec<Option<Vec<u8>>>,
1202 result_formats: Vec<Format>,
1203 ) -> Result<State, io::Error> {
1204 self.ensure_transaction(1, "bind").await?;
1206
1207 let aborted_txn = self.is_aborted_txn();
1208 let stmt = match self
1209 .adapter_client
1210 .get_prepared_statement(&statement_name)
1211 .await
1212 {
1213 Ok(stmt) => stmt,
1214 Err(err) => return self.error(err.into_response(Severity::Error)).await,
1215 };
1216
1217 let param_types = &stmt.desc().param_types;
1218 if param_types.len() != raw_params.len() {
1219 let message = format!(
1220 "bind message supplies {actual} parameters, \
1221 but prepared statement \"{name}\" requires {expected}",
1222 name = statement_name,
1223 actual = raw_params.len(),
1224 expected = param_types.len()
1225 );
1226 return self
1227 .error(ErrorResponse::error(SqlState::PROTOCOL_VIOLATION, message))
1228 .await;
1229 }
1230 let param_formats = match pad_formats(param_formats, raw_params.len()) {
1231 Ok(param_formats) => param_formats,
1232 Err(msg) => {
1233 return self
1234 .error(ErrorResponse::error(SqlState::PROTOCOL_VIOLATION, msg))
1235 .await;
1236 }
1237 };
1238 if aborted_txn && !is_txn_exit_stmt(stmt.stmt()) {
1239 return self.aborted_txn_error().await;
1240 }
1241 let buf = RowArena::new();
1242 let mut params = vec![];
1243 for ((raw_param, mz_typ), format) in raw_params
1244 .into_iter()
1245 .zip_eq(param_types)
1246 .zip_eq(param_formats)
1247 {
1248 let pg_typ = mz_pgrepr::Type::from(mz_typ);
1249 let datum = match raw_param {
1250 None => Datum::Null,
1251 Some(bytes) => match mz_pgrepr::Value::decode(format, &pg_typ, &bytes) {
1252 Ok(param) => param.into_datum(&buf, &pg_typ),
1253 Err(err) => {
1254 let msg = format!("unable to decode parameter: {}", err);
1255 return self
1256 .error(ErrorResponse::error(SqlState::INVALID_PARAMETER_VALUE, msg))
1257 .await;
1258 }
1259 },
1260 };
1261 params.push((datum, mz_typ.clone()))
1262 }
1263
1264 let result_formats = match pad_formats(
1265 result_formats,
1266 stmt.desc()
1267 .relation_desc
1268 .clone()
1269 .map(|desc| desc.typ().column_types.len())
1270 .unwrap_or(0),
1271 ) {
1272 Ok(result_formats) => result_formats,
1273 Err(msg) => {
1274 return self
1275 .error(ErrorResponse::error(SqlState::PROTOCOL_VIOLATION, msg))
1276 .await;
1277 }
1278 };
1279
1280 if !stmt.stmt().map_or(false, |stmt| {
1283 matches!(
1284 stmt,
1285 Statement::Copy(CopyStatement {
1286 direction: CopyDirection::To,
1287 ..
1288 })
1289 )
1290 }) {
1291 if let Some(desc) = stmt.desc().relation_desc.clone() {
1292 for (format, ty) in result_formats.iter().zip_eq(desc.iter_types()) {
1293 match (format, &ty.scalar_type) {
1294 (Format::Binary, mz_repr::SqlScalarType::List { .. }) => {
1295 return self
1296 .error(ErrorResponse::error(
1297 SqlState::PROTOCOL_VIOLATION,
1298 "binary encoding of list types is not implemented",
1299 ))
1300 .await;
1301 }
1302 (Format::Binary, mz_repr::SqlScalarType::Map { .. }) => {
1303 return self
1304 .error(ErrorResponse::error(
1305 SqlState::PROTOCOL_VIOLATION,
1306 "binary encoding of map types is not implemented",
1307 ))
1308 .await;
1309 }
1310 (Format::Binary, mz_repr::SqlScalarType::AclItem) => {
1311 return self
1312 .error(ErrorResponse::error(
1313 SqlState::PROTOCOL_VIOLATION,
1314 "binary encoding of aclitem types does not exist",
1315 ))
1316 .await;
1317 }
1318 _ => (),
1319 }
1320 }
1321 }
1322 }
1323
1324 let desc = stmt.desc().clone();
1325 let logging = Arc::clone(stmt.logging());
1326 let stmt_ast = stmt.stmt().cloned();
1327 let state_revision = stmt.state_revision;
1328 if let Err(err) = self.adapter_client.session().set_portal(
1329 portal_name,
1330 desc,
1331 stmt_ast,
1332 logging,
1333 params,
1334 result_formats,
1335 state_revision,
1336 ) {
1337 return self.error(err.into_response(Severity::Error)).await;
1338 }
1339
1340 self.send(BackendMessage::BindComplete).await?;
1341 Ok(State::Ready)
1342 }
1343
    /// Runs the portal named `portal_name` (or resumes it if it was already
    /// started), sending at most `max_rows` rows to the client.
    ///
    /// Behaviour depends on the portal's [`PortalState`]:
    /// - `NotStarted`: ensures a transaction, asks the adapter to execute the portal
    ///   and hands the response to [`Self::send_execute_response`].
    /// - `InProgress`: continues streaming the rows buffered on the portal via
    ///   [`Self::send_rows`].
    /// - `Completed(Some(tag))`: replays the stored `CommandComplete` tag.
    /// - `Completed(None)`: reports `OBJECT_NOT_IN_PREREQUISITE_STATE`.
    ///
    /// A missing portal is reported as `INVALID_CURSOR_NAME`. Inside an aborted
    /// transaction, every statement except a transaction-exit statement is rejected.
    ///
    /// `outer_ctx_extra` is statement-logging state. It is `Some` when called from
    /// `fetch`, for example. It is retired on every path handled directly here and
    /// is passed to the adapter on the `NotStarted` path. `received`, when set,
    /// initializes the portal's lifecycle timestamps.
    ///
    /// The future is boxed because this function is reachable recursively
    /// (`execute` -> `send_execute_response` -> `fetch` -> `execute`), and a
    /// recursive async call needs an indirection.
    fn execute(
        &mut self,
        portal_name: String,
        max_rows: ExecuteCount,
        get_response: GetResponse,
        fetch_portal_name: Option<String>,
        timeout: ExecuteTimeout,
        outer_ctx_extra: Option<ExecuteContextExtra>,
        received: Option<EpochMillis>,
    ) -> BoxFuture<'_, Result<State, io::Error>> {
        async move {
            // Captured up front, presumably so no other session access is needed
            // while the mutable portal borrow below is alive.
            let aborted_txn = self.is_aborted_txn();

            let portal = match self
                .adapter_client
                .session()
                .get_portal_unverified_mut(&portal_name)
            {
                Some(portal) => portal,
                None => {
                    let msg = format!("portal {} does not exist", portal_name.quoted());
                    if let Some(outer_ctx_extra) = outer_ctx_extra {
                        self.adapter_client.retire_execute(
                            outer_ctx_extra,
                            StatementEndedExecutionReason::Errored { error: msg.clone() },
                        );
                    }
                    return self
                        .error(ErrorResponse::error(SqlState::INVALID_CURSOR_NAME, msg))
                        .await;
                }
            };

            *portal.lifecycle_timestamps = received.map(LifecycleTimestamps::new);

            // In an aborted transaction only statements accepted by
            // `is_txn_exit_stmt` may run; everything else gets the aborted-txn error.
            let txn_exit_stmt = is_txn_exit_stmt(portal.stmt.as_deref());
            if aborted_txn && !txn_exit_stmt {
                if let Some(outer_ctx_extra) = outer_ctx_extra {
                    self.adapter_client.retire_execute(
                        outer_ctx_extra,
                        StatementEndedExecutionReason::Errored {
                            error: ABORTED_TXN_MSG.to_string(),
                        },
                    );
                }
                return self.aborted_txn_error().await;
            }

            let row_desc = portal.desc.relation_desc.clone();
            match portal.state {
                PortalState::NotStarted => {
                    // Make sure a transaction is in place before executing
                    // (see `ensure_transaction`).
                    self.ensure_transaction(1, "execute").await?;
                    // `wait_closed` is handed to the adapter, presumably so execution
                    // can react to the client disconnecting.
                    match self
                        .adapter_client
                        .execute(
                            portal_name.clone(),
                            self.conn.wait_closed(),
                            outer_ctx_extra,
                        )
                        .await
                    {
                        Ok((response, execute_started)) => {
                            self.send_pending_notices().await?;
                            self.send_execute_response(
                                response,
                                row_desc,
                                portal_name,
                                max_rows,
                                get_response,
                                fetch_portal_name,
                                timeout,
                                execute_started,
                            )
                            .await
                        }
                        Err(e) => {
                            self.send_pending_notices().await?;
                            self.error(e.into_response(Severity::Error)).await
                        }
                    }
                }
                PortalState::InProgress(rows) => {
                    // Move the buffered rows out of the portal. `send_rows` stores
                    // them back as `PortalState::InProgress(Some(..))` when it finishes.
                    let rows = rows.take().expect("InProgress rows must be populated");
                    let (result, statement_ended_execution_reason) = match self
                        .send_rows(
                            row_desc.expect("portal missing row desc on resumption"),
                            portal_name,
                            rows,
                            max_rows,
                            get_response,
                            fetch_portal_name,
                            timeout,
                        )
                        .await
                    {
                        // A connection-level I/O error is logged as a cancellation.
                        Err(e) => {
                            (Err(e), StatementEndedExecutionReason::Canceled)
                        }
                        Ok((ok, SendRowsEndedReason::Canceled)) => {
                            (Ok(ok), StatementEndedExecutionReason::Canceled)
                        }
                        // NOTE(review): the sizes reported by `send_rows` are discarded
                        // and `None` is logged for resumed portals. This looks
                        // intentional; confirm against statement-logging semantics.
                        Ok((
                            ok,
                            SendRowsEndedReason::Success {
                                result_size: _,
                                rows_returned: _,
                            },
                        )) => (
                            Ok(ok),
                            StatementEndedExecutionReason::Success {
                                result_size: None,
                                rows_returned: None,
                                execution_strategy: None,
                            },
                        ),
                        Ok((ok, SendRowsEndedReason::Errored { error })) => {
                            (Ok(ok), StatementEndedExecutionReason::Errored { error })
                        }
                    };
                    if let Some(outer_ctx_extra) = outer_ctx_extra {
                        self.adapter_client
                            .retire_execute(outer_ctx_extra, statement_ended_execution_reason);
                    }
                    result
                }
                // The portal already ran to completion; replay its completion tag.
                PortalState::Completed(Some(tag)) => {
                    let tag = tag.to_string();
                    if let Some(outer_ctx_extra) = outer_ctx_extra {
                        self.adapter_client.retire_execute(
                            outer_ctx_extra,
                            StatementEndedExecutionReason::Success {
                                result_size: None,
                                rows_returned: None,
                                execution_strategy: None,
                            },
                        );
                    }
                    self.send(BackendMessage::CommandComplete { tag }).await?;
                    Ok(State::Ready)
                }
                // Completed without a tag (e.g. set by `complete_portal`): cannot be re-run.
                PortalState::Completed(None) => {
                    let error = format!(
                        "portal {} cannot be run",
                        Ident::new_unchecked(portal_name).to_ast_string_stable()
                    );
                    if let Some(outer_ctx_extra) = outer_ctx_extra {
                        self.adapter_client.retire_execute(
                            outer_ctx_extra,
                            StatementEndedExecutionReason::Errored {
                                error: error.clone(),
                            },
                        );
                    }
                    self.error(ErrorResponse::error(
                        SqlState::OBJECT_NOT_IN_PREREQUISITE_STATE,
                        error,
                    ))
                    .await
                }
            }
        }
        .instrument(debug_span!("execute"))
        .boxed()
    }
1531
1532 #[instrument(level = "debug")]
1533 async fn describe_statement(&mut self, name: &str) -> Result<State, io::Error> {
1534 self.ensure_transaction(1, "describe_statement").await?;
1536
1537 let stmt = match self.adapter_client.get_prepared_statement(name).await {
1538 Ok(stmt) => stmt,
1539 Err(err) => return self.error(err.into_response(Severity::Error)).await,
1540 };
1541 let parameter_desc = BackendMessage::ParameterDescription(
1543 stmt.desc()
1544 .param_types
1545 .iter()
1546 .map(mz_pgrepr::Type::from)
1547 .collect(),
1548 );
1549 let formats = vec![Format::Text; stmt.desc().arity()];
1553 let row_desc = describe_rows(stmt.desc(), &formats);
1554 self.send_all([parameter_desc, row_desc]).await?;
1555 Ok(State::Ready)
1556 }
1557
1558 #[instrument(level = "debug")]
1559 async fn describe_portal(&mut self, name: &str) -> Result<State, io::Error> {
1560 self.ensure_transaction(1, "describe_portal").await?;
1562
1563 let session = self.adapter_client.session();
1564 let row_desc = session
1565 .get_portal_unverified(name)
1566 .map(|portal| describe_rows(&portal.desc, &portal.result_formats));
1567 match row_desc {
1568 Some(row_desc) => {
1569 self.send(row_desc).await?;
1570 Ok(State::Ready)
1571 }
1572 None => {
1573 self.error(ErrorResponse::error(
1574 SqlState::INVALID_CURSOR_NAME,
1575 format!("portal {} does not exist", name.quoted()),
1576 ))
1577 .await
1578 }
1579 }
1580 }
1581
1582 #[instrument(level = "debug")]
1583 async fn close_statement(&mut self, name: String) -> Result<State, io::Error> {
1584 self.adapter_client
1585 .session()
1586 .remove_prepared_statement(&name);
1587 self.send(BackendMessage::CloseComplete).await?;
1588 Ok(State::Ready)
1589 }
1590
1591 #[instrument(level = "debug")]
1592 async fn close_portal(&mut self, name: String) -> Result<State, io::Error> {
1593 self.adapter_client.session().remove_portal(&name);
1594 self.send(BackendMessage::CloseComplete).await?;
1595 Ok(State::Ready)
1596 }
1597
1598 fn complete_portal(&mut self, name: &str) {
1599 let portal = self
1600 .adapter_client
1601 .session()
1602 .get_portal_unverified_mut(name)
1603 .expect("portal should exist");
1604 *portal.state = PortalState::Completed(None);
1605 }
1606
1607 async fn fetch(
1608 &mut self,
1609 name: String,
1610 count: Option<FetchDirection>,
1611 max_rows: ExecuteCount,
1612 fetch_portal_name: Option<String>,
1613 timeout: ExecuteTimeout,
1614 ctx_extra: ExecuteContextExtra,
1615 ) -> Result<State, io::Error> {
1616 let count = count.unwrap_or(FetchDirection::ForwardCount(1));
1619
1620 let count = match (max_rows, count) {
1633 (ExecuteCount::Count(max_rows), FetchDirection::ForwardCount(count)) => {
1634 let count = usize::cast_from(count);
1635 if max_rows < count {
1636 let msg = "Execute with max_rows < a FETCH's count is not supported";
1637 self.adapter_client.retire_execute(
1638 ctx_extra,
1639 StatementEndedExecutionReason::Errored {
1640 error: msg.to_string(),
1641 },
1642 );
1643 return self
1644 .error(ErrorResponse::error(SqlState::FEATURE_NOT_SUPPORTED, msg))
1645 .await;
1646 }
1647 ExecuteCount::Count(count)
1648 }
1649 (ExecuteCount::Count(_), FetchDirection::ForwardAll) => {
1650 let msg = "Execute with max_rows of a FETCH ALL is not supported";
1651 self.adapter_client.retire_execute(
1652 ctx_extra,
1653 StatementEndedExecutionReason::Errored {
1654 error: msg.to_string(),
1655 },
1656 );
1657 return self
1658 .error(ErrorResponse::error(SqlState::FEATURE_NOT_SUPPORTED, msg))
1659 .await;
1660 }
1661 (ExecuteCount::All, FetchDirection::ForwardAll) => ExecuteCount::All,
1662 (ExecuteCount::All, FetchDirection::ForwardCount(count)) => {
1663 ExecuteCount::Count(usize::cast_from(count))
1664 }
1665 };
1666 let cursor_name = name.to_string();
1667 self.execute(
1668 cursor_name,
1669 count,
1670 fetch_message,
1671 fetch_portal_name,
1672 timeout,
1673 Some(ctx_extra),
1674 None,
1675 )
1676 .await
1677 }
1678
1679 async fn flush(&mut self) -> Result<State, io::Error> {
1680 self.conn.flush().await?;
1681 Ok(State::Ready)
1682 }
1683
1684 #[instrument(level = "debug")]
1689 async fn send<M>(&mut self, message: M) -> Result<(), io::Error>
1690 where
1691 M: Into<BackendMessage>,
1692 {
1693 let message: BackendMessage = message.into();
1694 let is_error =
1695 matches!(&message, BackendMessage::ErrorResponse(e) if e.severity.is_error());
1696
1697 self.conn.send(message).await?;
1698
1699 if is_error {
1706 self.conn.flush().await?;
1707 }
1708
1709 Ok(())
1710 }
1711
1712 #[instrument(level = "debug")]
1713 pub async fn send_all(
1714 &mut self,
1715 messages: impl IntoIterator<Item = BackendMessage>,
1716 ) -> Result<(), io::Error> {
1717 for m in messages {
1718 self.send(m).await?;
1719 }
1720 Ok(())
1721 }
1722
1723 #[instrument(level = "debug")]
1724 async fn sync(&mut self) -> Result<State, io::Error> {
1725 if self.adapter_client.session().transaction().is_implicit() {
1727 self.commit_transaction().await?;
1728 }
1729 self.ready().await
1730 }
1731
1732 #[instrument(level = "debug")]
1733 async fn ready(&mut self) -> Result<State, io::Error> {
1734 let txn_state = self.adapter_client.session().transaction().into();
1735 self.send(BackendMessage::ReadyForQuery(txn_state)).await?;
1736 self.flush().await
1737 }
1738
    /// Translates an adapter [`ExecuteResponse`] for the portal `portal_name` into
    /// pgwire messages for the client.
    ///
    /// Most variants just produce a `CommandComplete` carrying the response's tag
    /// (`response.tag()`). The `command_complete!` macro below `take`s that tag, so
    /// it can be sent at most once.
    ///
    /// Other variants are handled as follows:
    /// - Row-producing variants are streamed via [`Self::send_rows`], honouring
    ///   `max_rows`, `get_response`, `fetch_portal_name` and `timeout`.
    /// - `COPY TO` variants go via [`Self::copy_rows`].
    /// - `COPY FROM` goes via [`Self::copy_from`].
    /// - `FETCH` recurses into [`Self::fetch`].
    ///
    /// # Panics
    ///
    /// `row_desc` must be `Some` for every row-producing variant; otherwise this
    /// panics. At the end, if a tag was produced but never sent, `assert_none!`
    /// fires. The `Subscribing` and `CopyTo` arms `return` early and skip that check.
    ///
    /// `execute_started` is forwarded to `RecordFirstRowStream`, presumably for
    /// first-row timing. SQL-level failures visible here are reported with
    /// `self.error`. The `io::Error` return is for connection failures.
    #[allow(clippy::too_many_arguments)]
    #[instrument(level = "debug")]
    async fn send_execute_response(
        &mut self,
        response: ExecuteResponse,
        row_desc: Option<RelationDesc>,
        portal_name: String,
        max_rows: ExecuteCount,
        get_response: GetResponse,
        fetch_portal_name: Option<String>,
        timeout: ExecuteTimeout,
        execute_started: Instant,
    ) -> Result<State, io::Error> {
        let mut tag = response.tag();

        // Sends `CommandComplete` with the response's tag, consuming it, and
        // evaluates to `Ok(State::Ready)`.
        macro_rules! command_complete {
            () => {{
                self.send(BackendMessage::CommandComplete {
                    tag: tag
                        .take()
                        .expect("command_complete only called on tag-generating results"),
                })
                .await?;
                Ok(State::Ready)
            }};
        }

        let r = match response {
            // Cursor declaration/closing leaves the enclosing portal finished.
            ExecuteResponse::ClosedCursor => {
                self.complete_portal(&portal_name);
                command_complete!()
            }
            ExecuteResponse::DeclaredCursor => {
                self.complete_portal(&portal_name);
                command_complete!()
            }
            ExecuteResponse::EmptyQuery => {
                self.send(BackendMessage::EmptyQueryResponse).await?;
                Ok(State::Ready)
            }
            // FETCH runs the named cursor; this portal's name is passed as the
            // "fetch portal" whose result formats and response apply.
            ExecuteResponse::Fetch {
                name,
                count,
                timeout,
                ctx_extra,
            } => {
                self.fetch(
                    name,
                    count,
                    max_rows,
                    Some(portal_name.to_string()),
                    timeout,
                    ctx_extra,
                )
                .await
            }
            ExecuteResponse::SendingRowsStreaming {
                rows,
                instance_id,
                strategy,
            } => {
                let row_desc = row_desc
                    .expect("missing row description for ExecuteResponse::SendingRowsStreaming");

                let span = tracing::debug_span!("sending_rows_streaming");

                self.send_rows(
                    row_desc,
                    portal_name,
                    InProgressRows::new(RecordFirstRowStream::new(
                        Box::new(rows),
                        execute_started,
                        &self.adapter_client,
                        Some(instance_id),
                        Some(strategy),
                    )),
                    max_rows,
                    get_response,
                    fetch_portal_name,
                    timeout,
                )
                .instrument(span)
                .await
                .map(|(state, _)| state)
            }
            // Already-materialized rows are wrapped in a one-item stream so they can
            // go through the same `send_rows` path as streamed results.
            ExecuteResponse::SendingRowsImmediate { rows } => {
                let row_desc = row_desc
                    .expect("missing row description for ExecuteResponse::SendingRowsImmediate");

                let span = tracing::debug_span!("sending_rows_immediate");

                let stream =
                    futures::stream::once(futures::future::ready(PeekResponseUnary::Rows(rows)));
                self.send_rows(
                    row_desc,
                    portal_name,
                    InProgressRows::new(RecordFirstRowStream::new(
                        Box::new(stream),
                        execute_started,
                        &self.adapter_client,
                        None,
                        Some(StatementExecutionStrategy::Constant),
                    )),
                    max_rows,
                    get_response,
                    fetch_portal_name,
                    timeout,
                )
                .instrument(span)
                .await
                .map(|(state, _)| state)
            }
            // If the variable is in the session's notify set, report its new value
            // with `ParameterStatus` before completing.
            ExecuteResponse::SetVariable { name, .. } => {
                let qn = name.to_string();
                let msg = if let Some(var) = self
                    .adapter_client
                    .session()
                    .vars_mut()
                    .notify_set()
                    .find(|v| v.name() == qn)
                {
                    Some(BackendMessage::ParameterStatus(var.name(), var.value()))
                } else {
                    None
                };
                if let Some(msg) = msg {
                    self.send(msg).await?;
                }
                command_complete!()
            }
            ExecuteResponse::Subscribing {
                rx,
                ctx_extra,
                instance_id,
            } => {
                // A direct SUBSCRIBE (not via FETCH) gets a warning notice, plus a
                // hint for psql users.
                if fetch_portal_name.is_none() {
                    let mut msg = ErrorResponse::notice(
                        SqlState::WARNING,
                        "streaming SUBSCRIBE rows directly requires a client that does not buffer output",
                    );
                    if self.adapter_client.session().vars().application_name() == "psql" {
                        msg.hint = Some(
                            "Wrap your SUBSCRIBE statement in `COPY (SUBSCRIBE ...) TO STDOUT`."
                                .into(),
                        )
                    }
                    self.send(msg).await?;
                    self.conn.flush().await?;
                }
                let row_desc =
                    row_desc.expect("missing row description for ExecuteResponse::Subscribing");
                let (result, statement_ended_execution_reason) = match self
                    .send_rows(
                        row_desc,
                        portal_name,
                        InProgressRows::new(RecordFirstRowStream::new(
                            Box::new(UnboundedReceiverStream::new(rx)),
                            execute_started,
                            &self.adapter_client,
                            Some(instance_id),
                            None,
                        )),
                        max_rows,
                        get_response,
                        fetch_portal_name,
                        timeout,
                    )
                    .await
                {
                    // A connection-level I/O error is logged as a cancellation.
                    Err(e) => {
                        (Err(e), StatementEndedExecutionReason::Canceled)
                    }
                    Ok((ok, SendRowsEndedReason::Canceled)) => {
                        (Ok(ok), StatementEndedExecutionReason::Canceled)
                    }
                    Ok((
                        ok,
                        SendRowsEndedReason::Success {
                            result_size,
                            rows_returned,
                        },
                    )) => (
                        Ok(ok),
                        StatementEndedExecutionReason::Success {
                            result_size: Some(result_size),
                            rows_returned: Some(rows_returned),
                            execution_strategy: None,
                        },
                    ),
                    Ok((ok, SendRowsEndedReason::Errored { error })) => {
                        (Ok(ok), StatementEndedExecutionReason::Errored { error })
                    }
                };
                self.adapter_client
                    .retire_execute(ctx_extra, statement_ended_execution_reason);
                // Early return: skips the final tag assertion.
                return result;
            }
            // COPY ... TO STDOUT: dispatch on the wrapped inner response.
            ExecuteResponse::CopyTo { format, resp } => {
                let row_desc =
                    row_desc.expect("missing row description for ExecuteResponse::CopyTo");
                match *resp {
                    ExecuteResponse::Subscribing {
                        rx,
                        ctx_extra,
                        instance_id,
                    } => {
                        let (result, statement_ended_execution_reason) = match self
                            .copy_rows(
                                format,
                                row_desc,
                                RecordFirstRowStream::new(
                                    Box::new(UnboundedReceiverStream::new(rx)),
                                    execute_started,
                                    &self.adapter_client,
                                    Some(instance_id),
                                    None,
                                ),
                            )
                            .await
                        {
                            Err(e) => {
                                (Err(e), StatementEndedExecutionReason::Canceled)
                            }
                            Ok((
                                state,
                                SendRowsEndedReason::Success {
                                    result_size,
                                    rows_returned,
                                },
                            )) => (
                                Ok(state),
                                StatementEndedExecutionReason::Success {
                                    result_size: Some(result_size),
                                    rows_returned: Some(rows_returned),
                                    execution_strategy: None,
                                },
                            ),
                            Ok((state, SendRowsEndedReason::Errored { error })) => {
                                (Ok(state), StatementEndedExecutionReason::Errored { error })
                            }
                            Ok((state, SendRowsEndedReason::Canceled)) => {
                                (Ok(state), StatementEndedExecutionReason::Canceled)
                            }
                        };
                        self.adapter_client
                            .retire_execute(ctx_extra, statement_ended_execution_reason);
                        return result;
                    }
                    ExecuteResponse::SendingRowsStreaming {
                        rows,
                        instance_id,
                        strategy,
                    } => {
                        return self
                            .copy_rows(
                                format,
                                row_desc,
                                RecordFirstRowStream::new(
                                    Box::new(rows),
                                    execute_started,
                                    &self.adapter_client,
                                    Some(instance_id),
                                    Some(strategy),
                                ),
                            )
                            .await
                            .map(|(state, _)| state);
                    }
                    ExecuteResponse::SendingRowsImmediate { rows } => {
                        let span = tracing::debug_span!("sending_rows_immediate");

                        let rows = futures::stream::once(futures::future::ready(
                            PeekResponseUnary::Rows(rows),
                        ));
                        return self
                            .copy_rows(
                                format,
                                row_desc,
                                RecordFirstRowStream::new(
                                    Box::new(rows),
                                    execute_started,
                                    &self.adapter_client,
                                    None,
                                    Some(StatementExecutionStrategy::Constant),
                                ),
                            )
                            .instrument(span)
                            .await
                            .map(|(state, _)| state);
                    }
                    // Any other wrapped response is unexpected for COPY TO.
                    _ => {
                        return self
                            .error(ErrorResponse::error(
                                SqlState::INTERNAL_ERROR,
                                "unsupported COPY response type".to_string(),
                            ))
                            .await;
                    }
                };
            }
            ExecuteResponse::CopyFrom {
                target_id,
                target_name,
                columns,
                params,
                ctx_extra,
            } => {
                let row_desc =
                    row_desc.expect("missing row description for ExecuteResponse::CopyFrom");
                self.copy_from(target_id, target_name, columns, params, row_desc, ctx_extra)
                    .await
            }
            // On commit/rollback, report parameters (from `params`) that are in the
            // session's notify set, then complete.
            ExecuteResponse::TransactionCommitted { params }
            | ExecuteResponse::TransactionRolledBack { params } => {
                let notify_set: mz_ore::collections::HashSet<String> = self
                    .adapter_client
                    .session()
                    .vars()
                    .notify_set()
                    .map(|v| v.name().to_string())
                    .collect();

                for (name, value) in params
                    .into_iter()
                    .filter(|(name, _v)| notify_set.contains(*name))
                {
                    let msg = BackendMessage::ParameterStatus(name, value);
                    self.send(msg).await?;
                }
                command_complete!()
            }

            // Everything else only needs its completion tag.
            ExecuteResponse::AlteredDefaultPrivileges
            | ExecuteResponse::AlteredObject(..)
            | ExecuteResponse::AlteredRole
            | ExecuteResponse::AlteredSystemConfiguration
            | ExecuteResponse::CreatedCluster { .. }
            | ExecuteResponse::CreatedClusterReplica { .. }
            | ExecuteResponse::CreatedConnection { .. }
            | ExecuteResponse::CreatedDatabase { .. }
            | ExecuteResponse::CreatedIndex { .. }
            | ExecuteResponse::CreatedIntrospectionSubscribe
            | ExecuteResponse::CreatedMaterializedView { .. }
            | ExecuteResponse::CreatedContinualTask { .. }
            | ExecuteResponse::CreatedRole
            | ExecuteResponse::CreatedSchema { .. }
            | ExecuteResponse::CreatedSecret { .. }
            | ExecuteResponse::CreatedSink { .. }
            | ExecuteResponse::CreatedSource { .. }
            | ExecuteResponse::CreatedTable { .. }
            | ExecuteResponse::CreatedType
            | ExecuteResponse::CreatedView { .. }
            | ExecuteResponse::CreatedViews { .. }
            | ExecuteResponse::CreatedNetworkPolicy
            | ExecuteResponse::Comment
            | ExecuteResponse::Deallocate { .. }
            | ExecuteResponse::Deleted(..)
            | ExecuteResponse::DiscardedAll
            | ExecuteResponse::DiscardedTemp
            | ExecuteResponse::DroppedObject(_)
            | ExecuteResponse::DroppedOwned
            | ExecuteResponse::GrantedPrivilege
            | ExecuteResponse::GrantedRole
            | ExecuteResponse::Inserted(..)
            | ExecuteResponse::Copied(..)
            | ExecuteResponse::Prepare
            | ExecuteResponse::Raised
            | ExecuteResponse::ReassignOwned
            | ExecuteResponse::RevokedPrivilege
            | ExecuteResponse::RevokedRole
            | ExecuteResponse::StartedTransaction { .. }
            | ExecuteResponse::Updated(..)
            | ExecuteResponse::ValidatedConnection => {
                command_complete!()
            }
        };

        // Every arm that reaches here must have consumed the tag it was given.
        assert_none!(tag, "tag created but not consumed: {:?}", tag);
        r
    }
2134
    /// Streams up to `max_rows` rows from `rows` to the client as `DataRow`
    /// messages for the portal `portal_name`.
    ///
    /// Rows are encoded using the result formats of `fetch_portal_name` when given,
    /// otherwise those of `portal_name`.
    ///
    /// `timeout` controls how long to wait for more data:
    /// - `None` waits indefinitely.
    /// - `Seconds(t)` stops waiting at a deadline `t` seconds from now.
    /// - `WaitOnce` sets the deadline to "now" once a non-empty batch arrives, so the
    ///   loop stops waiting after that.
    ///
    /// While waiting, session notices are forwarded to the client. If the connection
    /// closes, this returns `Err`.
    ///
    /// Rows of a batch that exceed the budget are stashed in `rows.current`. `rows`
    /// is then stored back on the portal as `PortalState::InProgress`, so a later
    /// `Execute` can resume. Finally, the message built by `get_response` is sent. If
    /// the stream is exhausted, the first-to-last-row latency metric is recorded.
    ///
    /// Returns the next state together with why sending ended:
    /// - `Success`: byte and row totals.
    /// - `Errored`: datum verification failed, or the stream yielded an error.
    /// - `Canceled`: the stream reported cancellation.
    ///
    /// # Panics
    /// Panics if the referenced portals do not exist, or if `zip_eq` finds that the
    /// column count differs from the number of result formats.
    #[allow(clippy::too_many_arguments)]
    #[mz_ore::instrument(level = "debug")]
    async fn send_rows(
        &mut self,
        row_desc: RelationDesc,
        portal_name: String,
        mut rows: InProgressRows,
        max_rows: ExecuteCount,
        get_response: GetResponse,
        fetch_portal_name: Option<String>,
        timeout: ExecuteTimeout,
    ) -> Result<(State, SendRowsEndedReason), io::Error> {
        // For a FETCH, the result formats come from the fetching portal rather than
        // from the cursor's own portal.
        let result_format_portal_name: &str = if let Some(ref name) = fetch_portal_name {
            name
        } else {
            &portal_name
        };
        let result_formats = self
            .adapter_client
            .session()
            .get_portal_unverified(result_format_portal_name)
            .expect("valid fetch portal name for send rows")
            .result_formats
            .clone();

        let (mut wait_once, mut deadline) = match timeout {
            ExecuteTimeout::None => (false, None),
            ExecuteTimeout::Seconds(t) => (
                false,
                Some(tokio::time::Instant::now() + tokio::time::Duration::from_secs_f64(t)),
            ),
            ExecuteTimeout::WaitOnce => (true, None),
        };

        // Sanity check (soft assertions only): the row description we were handed
        // should agree with the descriptions recorded on the involved portals.
        {
            let portal_name_desc = &self
                .adapter_client
                .session()
                .get_portal_unverified(portal_name.as_str())
                .expect("portal should exist")
                .desc
                .relation_desc;
            if let Some(portal_name_desc) = portal_name_desc {
                soft_assert_eq_or_log!(portal_name_desc, &row_desc);
            }
            if let Some(fetch_portal_name) = &fetch_portal_name {
                let fetch_portal_desc = &self
                    .adapter_client
                    .session()
                    .get_portal_unverified(fetch_portal_name)
                    .expect("portal should exist")
                    .desc
                    .relation_desc;
                if let Some(fetch_portal_desc) = fetch_portal_desc {
                    soft_assert_eq_or_log!(fetch_portal_desc, &row_desc);
                }
            }
        }

        // Configure the connection's row encoder with each column's pg type paired
        // with its result format (`zip_eq` panics on a length mismatch).
        self.conn.set_encode_state(
            row_desc
                .typ()
                .column_types
                .iter()
                .map(|ty| mz_pgrepr::Type::from(&ty.scalar_type))
                .zip_eq(result_formats)
                .collect(),
        );

        let mut total_sent_rows = 0;
        let mut total_sent_bytes = 0;
        // Remaining row budget for this call.
        let mut want_rows = match max_rows {
            ExecuteCount::All => usize::MAX,
            ExecuteCount::Count(count) => count,
        };

        loop {
            // Choose what to handle next:
            // - rows left over from a previous call, first;
            // - nothing, if the row budget is spent;
            // - otherwise, whichever arrives first: connection close, the deadline
            //   (only armed when one is set), a session notice, or a new batch.
            let batch = if rows.current.is_some() {
                FetchResult::Rows(rows.current.take())
            } else if want_rows == 0 {
                FetchResult::Rows(None)
            } else {
                let notice_fut = self.adapter_client.session().recv_notice();
                tokio::select! {
                    err = self.conn.wait_closed() => return Err(err),
                    _ = time::sleep_until(deadline.unwrap_or_else(tokio::time::Instant::now)), if deadline.is_some() => FetchResult::Rows(None),
                    notice = notice_fut => {
                        FetchResult::Notice(notice)
                    }
                    batch = rows.remaining.recv() => match batch {
                        None => FetchResult::Rows(None),
                        Some(PeekResponseUnary::Rows(rows)) => FetchResult::Rows(Some(rows)),
                        Some(PeekResponseUnary::Error(err)) => FetchResult::Error(err),
                        Some(PeekResponseUnary::Canceled) => FetchResult::Canceled,
                    },
                }
            };

            match batch {
                // Budget spent, stream exhausted, or deadline reached.
                FetchResult::Rows(None) => break,
                FetchResult::Rows(Some(mut batch_rows)) => {
                    if let Err(err) = verify_datum_desc(&row_desc, &mut batch_rows) {
                        let msg = err.to_string();
                        return self
                            .error(err.into_response(Severity::Error))
                            .await
                            .map(|state| (state, SendRowsEndedReason::Errored { error: msg }));
                    }

                    // `WaitOnce`: after the first non-empty batch, stop waiting.
                    if wait_once && batch_rows.peek().is_some() {
                        deadline = Some(tokio::time::Instant::now());
                        wait_once = false;
                    }

                    // The iterator is lazy: `take(want_rows)` stops pulling once the
                    // budget is met, so `inspect` only counts rows actually sent and
                    // the rest stay in `batch_rows`.
                    let mut sent_rows = 0;
                    let mut sent_bytes = 0;
                    let messages = (&mut batch_rows)
                        .map(|row| {
                            let row_len = row.byte_len();
                            let values = mz_pgrepr::values_from_row(row, row_desc.typ());
                            (row_len, BackendMessage::DataRow(values))
                        })
                        .inspect(|(row_len, _)| {
                            sent_bytes += row_len;
                            sent_rows += 1
                        })
                        .map(|(_row_len, row)| row)
                        .take(want_rows);
                    self.send_all(messages).await?;

                    total_sent_rows += sent_rows;
                    total_sent_bytes += sent_bytes;
                    want_rows -= sent_rows;

                    if want_rows == 0 {
                        // Keep unsent rows of this batch for the next call.
                        if batch_rows.peek().is_some() {
                            rows.current = Some(batch_rows);
                        }
                        break;
                    }

                    self.conn.flush().await?;
                }
                FetchResult::Notice(notice) => {
                    self.send(notice.into_response()).await?;
                    self.conn.flush().await?;
                }
                FetchResult::Error(text) => {
                    return self
                        .error(ErrorResponse::error(SqlState::INTERNAL_ERROR, text.clone()))
                        .await
                        .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
                }
                FetchResult::Canceled => {
                    return self
                        .error(ErrorResponse::error(
                            SqlState::QUERY_CANCELED,
                            "canceling statement due to user request",
                        ))
                        .await
                        .map(|state| (state, SendRowsEndedReason::Canceled));
                }
            }
        }

        let portal = self
            .adapter_client
            .session()
            .get_portal_unverified_mut(&portal_name)
            .expect("valid portal name for send rows");

        // Read these before `rows` is moved into the portal state.
        let saw_rows = rows.remaining.saw_rows;
        let no_more_rows = rows.no_more_rows();
        let recorded_first_row_instant = rows.remaining.recorded_first_row_instant;

        // Park the remaining rows on the portal so a later `Execute` can resume.
        *portal.state = PortalState::InProgress(Some(rows));

        let fetch_portal = fetch_portal_name.map(|name| {
            self.adapter_client
                .session()
                .get_portal_unverified_mut(&name)
                .expect("valid fetch portal")
        });
        let response_message = get_response(max_rows, total_sent_rows, fetch_portal);
        self.send(response_message).await?;

        // Once the stream is exhausted, record how long it took from the first row
        // to the last (zero if no rows were seen).
        if no_more_rows {
            let statement_type = if let Some(stmt) = &self
                .adapter_client
                .session()
                .get_portal_unverified(&portal_name)
                .expect("valid portal name for send_rows")
                .stmt
            {
                metrics::statement_type_label_value(stmt.deref())
            } else {
                "no-statement"
            };
            let duration = if saw_rows {
                recorded_first_row_instant
                    .expect("recorded_first_row_instant because saw_rows")
                    .elapsed()
            } else {
                Duration::ZERO
            };
            self.adapter_client
                .inner()
                .metrics()
                .result_rows_first_to_last_byte_seconds
                .with_label_values(&[statement_type])
                .observe(duration.as_secs_f64());
        }

        Ok((
            State::Ready,
            SendRowsEndedReason::Success {
                result_size: u64::cast_from(total_sent_bytes),
                rows_returned: u64::cast_from(total_sent_rows),
            },
        ))
    }
2380
    /// Streams the rows from `stream` to the client as `COPY ... TO STDOUT` output
    /// in `format`.
    ///
    /// On the successful path it sends, in order:
    /// 1. `CopyOutResponse`, with every column in the same format.
    /// 2. For binary, the file header.
    /// 3. One `CopyData` per encoded row.
    /// 4. For binary, the trailer.
    /// 5. `CopyDone`.
    /// 6. `CommandComplete` with tag `COPY <n>`.
    ///
    /// Other paths:
    /// - Parquet is rejected with an error before `CopyOutResponse` is sent.
    /// - If the stream yields an error or a cancellation, an `ErrorResponse` is sent
    ///   instead of `CopyDone`.
    /// - Session notices are forwarded as they arrive.
    /// - If the connection closes, this returns `Err`.
    ///
    /// On success, returns the byte and row counts as `SendRowsEndedReason::Success`.
    #[mz_ore::instrument(level = "debug")]
    async fn copy_rows(
        &mut self,
        format: CopyFormat,
        row_desc: RelationDesc,
        mut stream: RecordFirstRowStream,
    ) -> Result<(State, SendRowsEndedReason), io::Error> {
        // Map the COPY format to the row-encoding parameters and the pgwire format
        // advertised in `CopyOutResponse` (CSV is a text format on the wire).
        let (row_format, encode_format) = match format {
            CopyFormat::Text => (
                CopyFormatParams::Text(CopyTextFormatParams::default()),
                Format::Text,
            ),
            CopyFormat::Binary => (CopyFormatParams::Binary, Format::Binary),
            CopyFormat::Csv => (
                CopyFormatParams::Csv(CopyCsvFormatParams::default()),
                Format::Text,
            ),
            CopyFormat::Parquet => {
                let text = "Parquet format is not supported".to_string();
                return self
                    .error(ErrorResponse::error(SqlState::INTERNAL_ERROR, text.clone()))
                    .await
                    .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
            }
        };

        let encode_fn = |row: &RowRef, typ: &SqlRelationType, out: &mut Vec<u8>| {
            mz_pgcopy::encode_copy_format(&row_format, row, typ, out)
        };

        let typ = row_desc.typ();
        let column_formats = iter::repeat(encode_format)
            .take(typ.column_types.len())
            .collect();
        self.send(BackendMessage::CopyOutResponse {
            overall_format: encode_format,
            column_formats,
        })
        .await?;

        // Reusable output buffer; each `CopyData` message takes its contents.
        let mut out = Vec::new();

        // PostgreSQL binary COPY header:
        // - the 11-byte signature;
        // - a 32-bit flags field (0);
        // - a 32-bit header-extension length (0).
        // The header stays in `out` and goes out together with the first row, or
        // with the trailer if there are no rows.
        if let CopyFormat::Binary = format {
            out.extend(b"PGCOPY\n\xFF\r\n\0");
            out.extend([0, 0, 0, 0]);
            out.extend([0, 0, 0, 0]);
        }

        let mut count = 0;
        let mut total_sent_bytes = 0;
        loop {
            tokio::select! {
                e = self.conn.wait_closed() => return Err(e),
                batch = stream.recv() => match batch {
                    None => break,
                    Some(PeekResponseUnary::Error(text)) => {
                        return self
                            .error(ErrorResponse::error(SqlState::INTERNAL_ERROR, text.clone()))
                            .await
                            .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
                    }
                    Some(PeekResponseUnary::Canceled) => {
                        return self.error(ErrorResponse::error(
                                SqlState::QUERY_CANCELED,
                                "canceling statement due to user request",
                            ))
                            .await.map(|state| (state, SendRowsEndedReason::Canceled));
                    }
                    Some(PeekResponseUnary::Rows(mut rows)) => {
                        // The batch's row count is added before encoding. This assumes
                        // `RowIterator::count` reports the rows still to be yielded.
                        count += rows.count();
                        while let Some(row) = rows.next() {
                            total_sent_bytes += row.byte_len();
                            encode_fn(row, typ, &mut out)?;
                            self.send(BackendMessage::CopyData(mem::take(&mut out)))
                                .await?;
                        }
                    }
                },
                notice = self.adapter_client.session().recv_notice() => {
                    self.send(notice.into_response())
                        .await?;
                    self.conn.flush().await?;
                }
            }

            self.conn.flush().await?;
        }
        // Binary COPY trailer: a 16-bit word containing -1.
        if let CopyFormat::Binary = format {
            let trailer: i16 = -1;
            out.extend(trailer.to_be_bytes());
            self.send(BackendMessage::CopyData(mem::take(&mut out)))
                .await?;
        }

        let tag = format!("COPY {}", count);
        self.send(BackendMessage::CopyDone).await?;
        self.send(BackendMessage::CommandComplete { tag }).await?;
        Ok((
            State::Ready,
            SendRowsEndedReason::Success {
                result_size: u64::cast_from(total_sent_bytes),
                rows_returned: u64::cast_from(count),
            },
        ))
    }
2494
    /// Handles `COPY ... FROM STDIN` by running [`Self::copy_from_inner`] and then
    /// retiring whatever statement-logging state remains in `ctx_extra`.
    ///
    /// `ctx_extra` is lent to the inner function by `&mut`. On paths where the inner
    /// function retires the statement itself, it takes the value with
    /// `std::mem::take`. Afterwards, based on the result:
    /// - `Ok(State::Done)`, e.g. the client connection yielded no more messages
    ///   before `CopyDone`: retired as `Canceled`.
    /// - `Err(_)`: retired as `Errored` with the I/O error's text.
    /// - Anything else: logged at warn level and retired as `Aborted`.
    ///
    /// NOTE(review): the catch-all arm also runs when `copy_from_inner` has already
    /// taken and retired `ctx_extra`. Examples are its `CopyFail`, unexpected-message
    /// and decode-error paths, all of which return `self.error(..)`. In those cases
    /// "aborting COPY FROM" may be logged for a statement that was already retired.
    /// The "COPY FROM STDIN too large" path does not retire `ctx_extra` itself, so it
    /// ends up recorded as `Aborted` rather than `Errored`. Confirm two things:
    /// whether retiring a taken (default) `ctx_extra` is a no-op, and whether the
    /// warning should be conditional.
    #[instrument(level = "debug")]
    async fn copy_from(
        &mut self,
        target_id: CatalogItemId,
        target_name: String,
        columns: Vec<ColumnIndex>,
        params: CopyFormatParams<'_>,
        row_desc: RelationDesc,
        mut ctx_extra: ExecuteContextExtra,
    ) -> Result<State, io::Error> {
        let res = self
            .copy_from_inner(
                target_id,
                target_name,
                columns,
                params,
                row_desc,
                &mut ctx_extra,
            )
            .await;
        match &res {
            // The copy ended without the client finishing it; count it as canceled.
            Ok(State::Done) => {
                self.adapter_client
                    .retire_execute(ctx_extra, StatementEndedExecutionReason::Canceled);
            }
            Err(e) => {
                self.adapter_client.retire_execute(
                    ctx_extra,
                    StatementEndedExecutionReason::Errored {
                        error: format!("{e}"),
                    },
                );
            }
            other => {
                tracing::warn!(?other, "aborting COPY FROM");
                self.adapter_client
                    .retire_execute(ctx_extra, StatementEndedExecutionReason::Aborted);
            }
        }
        res
    }
2541
2542 async fn copy_from_inner(
2543 &mut self,
2544 target_id: CatalogItemId,
2545 target_name: String,
2546 columns: Vec<ColumnIndex>,
2547 params: CopyFormatParams<'_>,
2548 row_desc: RelationDesc,
2549 ctx_extra: &mut ExecuteContextExtra,
2550 ) -> Result<State, io::Error> {
2551 let typ = row_desc.typ();
2552 let column_formats = vec![Format::Text; typ.column_types.len()];
2553 self.send(BackendMessage::CopyInResponse {
2554 overall_format: Format::Text,
2555 column_formats,
2556 })
2557 .await?;
2558 self.conn.flush().await?;
2559
2560 let system_vars = self.adapter_client.get_system_vars().await;
2561 let max_size = system_vars
2562 .get(MAX_COPY_FROM_SIZE.name())
2563 .ok()
2564 .and_then(|max_size| max_size.value().parse().ok())
2565 .unwrap_or(usize::MAX);
2566 tracing::debug!("COPY FROM max buffer size: {max_size} bytes");
2567
2568 let mut data = Vec::new();
2569 loop {
2570 let message = self.conn.recv().await?;
2571 match message {
2572 Some(FrontendMessage::CopyData(buf)) => {
2573 if (data.len() + buf.len()) > max_size {
2575 return self
2576 .error(ErrorResponse::error(
2577 SqlState::INSUFFICIENT_RESOURCES,
2578 "COPY FROM STDIN too large",
2579 ))
2580 .await;
2581 }
2582 data.extend(buf)
2583 }
2584 Some(FrontendMessage::CopyDone) => break,
2585 Some(FrontendMessage::CopyFail(err)) => {
2586 self.adapter_client.retire_execute(
2587 std::mem::take(ctx_extra),
2588 StatementEndedExecutionReason::Canceled,
2589 );
2590 return self
2591 .error(ErrorResponse::error(
2592 SqlState::QUERY_CANCELED,
2593 format!("COPY from stdin failed: {}", err),
2594 ))
2595 .await;
2596 }
2597 Some(FrontendMessage::Flush) | Some(FrontendMessage::Sync) => {}
2598 Some(_) => {
2599 let msg = "unexpected message type during COPY from stdin";
2600 self.adapter_client.retire_execute(
2601 std::mem::take(ctx_extra),
2602 StatementEndedExecutionReason::Errored {
2603 error: msg.to_string(),
2604 },
2605 );
2606 return self
2607 .error(ErrorResponse::error(SqlState::PROTOCOL_VIOLATION, msg))
2608 .await;
2609 }
2610 None => {
2611 return Ok(State::Done);
2612 }
2613 }
2614 }
2615
2616 let column_types = typ
2617 .column_types
2618 .iter()
2619 .map(|x| &x.scalar_type)
2620 .map(mz_pgrepr::Type::from)
2621 .collect::<Vec<mz_pgrepr::Type>>();
2622
2623 let rows = match mz_pgcopy::decode_copy_format(&data, &column_types, params) {
2624 Ok(rows) => rows,
2625 Err(e) => {
2626 self.adapter_client.retire_execute(
2627 std::mem::take(ctx_extra),
2628 StatementEndedExecutionReason::Errored {
2629 error: e.to_string(),
2630 },
2631 );
2632 return self
2633 .error(ErrorResponse::error(
2634 SqlState::BAD_COPY_FILE_FORMAT,
2635 format!("{}", e),
2636 ))
2637 .await;
2638 }
2639 };
2640
2641 let count = rows.len();
2642
2643 if let Err(e) = self
2644 .adapter_client
2645 .insert_rows(
2646 target_id,
2647 target_name,
2648 columns,
2649 rows,
2650 std::mem::take(ctx_extra),
2651 )
2652 .await
2653 {
2654 self.adapter_client.retire_execute(
2655 std::mem::take(ctx_extra),
2656 StatementEndedExecutionReason::Errored {
2657 error: e.to_string(),
2658 },
2659 );
2660 return self.error(e.into_response(Severity::Error)).await;
2661 }
2662
2663 let tag = format!("COPY {}", count);
2664 self.send(BackendMessage::CommandComplete { tag }).await?;
2665
2666 Ok(State::Ready)
2667 }
2668
2669 #[instrument(level = "debug")]
2670 async fn send_pending_notices(&mut self) -> Result<(), io::Error> {
2671 let notices = self
2672 .adapter_client
2673 .session()
2674 .drain_notices()
2675 .into_iter()
2676 .map(|notice| BackendMessage::ErrorResponse(notice.into_response()));
2677 self.send_all(notices).await?;
2678 Ok(())
2679 }
2680
2681 #[instrument(level = "debug")]
2682 async fn error(&mut self, err: ErrorResponse) -> Result<State, io::Error> {
2683 assert!(err.severity.is_error());
2684 debug!(
2685 "cid={} error code={}",
2686 self.adapter_client.session().conn_id(),
2687 err.code.code()
2688 );
2689 let is_fatal = err.severity.is_fatal();
2690 self.send(BackendMessage::ErrorResponse(err)).await?;
2691
2692 let txn = self.adapter_client.session().transaction();
2693 match txn {
2694 TransactionStatus::Default | TransactionStatus::Failed(_) => {}
2697 TransactionStatus::Started(_) => {
2699 self.rollback_transaction().await?;
2700 }
2701 TransactionStatus::InTransactionImplicit(_) => {
2703 self.rollback_transaction().await?;
2704 }
2705 TransactionStatus::InTransaction(_) => {
2707 self.adapter_client.fail_transaction();
2708 }
2709 };
2710 if is_fatal {
2711 Ok(State::Done)
2712 } else {
2713 Ok(State::Drain)
2714 }
2715 }
2716
2717 #[instrument(level = "debug")]
2718 async fn aborted_txn_error(&mut self) -> Result<State, io::Error> {
2719 self.send(BackendMessage::ErrorResponse(ErrorResponse::error(
2720 SqlState::IN_FAILED_SQL_TRANSACTION,
2721 ABORTED_TXN_MSG,
2722 )))
2723 .await?;
2724 Ok(State::Drain)
2725 }
2726
2727 fn is_aborted_txn(&mut self) -> bool {
2728 matches!(
2729 self.adapter_client.session().transaction(),
2730 TransactionStatus::Failed(_)
2731 )
2732 }
2733}
2734
2735fn pad_formats(formats: Vec<Format>, n: usize) -> Result<Vec<Format>, String> {
2736 match (formats.len(), n) {
2737 (0, e) => Ok(vec![Format::Text; e]),
2738 (1, e) => Ok(iter::repeat(formats[0]).take(e).collect()),
2739 (a, e) if a == e => Ok(formats),
2740 (a, e) => Err(format!(
2741 "expected {} field format specifiers, but got {}",
2742 e, a
2743 )),
2744 }
2745}
2746
2747fn describe_rows(stmt_desc: &StatementDesc, formats: &[Format]) -> BackendMessage {
2748 match &stmt_desc.relation_desc {
2749 Some(desc) if !stmt_desc.is_copy => {
2750 BackendMessage::RowDescription(message::encode_row_description(desc, formats))
2751 }
2752 _ => BackendMessage::NoData,
2753 }
2754}
2755
/// Signature of the callbacks that build the message ending a row-returning execute or
/// fetch. The inputs are the requested row limit, the number of rows sent, and optionally
/// the portal being fetched from.
///
/// `portal_exec_message` and `fetch_message` in this file both have this shape.
type GetResponse = fn(
    max_rows: ExecuteCount,
    total_sent_rows: usize,
    fetch_portal: Option<PortalRefMut>,
) -> BackendMessage;
2761
2762fn portal_exec_message(
2765 max_rows: ExecuteCount,
2766 total_sent_rows: usize,
2767 _fetch_portal: Option<PortalRefMut>,
2768) -> BackendMessage {
2769 match max_rows {
2776 ExecuteCount::Count(max_rows) if max_rows <= total_sent_rows => {
2777 BackendMessage::PortalSuspended
2778 }
2779 _ => BackendMessage::CommandComplete {
2780 tag: format!("SELECT {}", total_sent_rows),
2781 },
2782 }
2783}
2784
2785fn fetch_message(
2787 _max_rows: ExecuteCount,
2788 total_sent_rows: usize,
2789 fetch_portal: Option<PortalRefMut>,
2790) -> BackendMessage {
2791 let tag = format!("FETCH {}", total_sent_rows);
2792 if let Some(portal) = fetch_portal {
2793 *portal.state = PortalState::Completed(Some(tag.clone()));
2794 }
2795 BackendMessage::CommandComplete { tag }
2796}
2797
/// Row limit attached to an execute or fetch request.
#[derive(Debug, Copy, Clone)]
enum ExecuteCount {
    /// No row limit. `portal_exec_message` never reports `PortalSuspended` for this variant.
    All,
    /// A row limit. `portal_exec_message` reports `PortalSuspended` once the number of rows
    /// sent reaches it.
    Count(usize),
}
2803
2804fn is_txn_exit_stmt(stmt: Option<&Statement<Raw>>) -> bool {
2806 match stmt {
2807 Some(stmt) => matches!(stmt, Statement::Commit(_) | Statement::Rollback(_)),
2809 None => false,
2810 }
2811}
2812
/// Outcome of pulling the next piece of a query result.
///
/// NOTE(review): the producers and consumers of this type are outside this chunk. The
/// variant descriptions below are inferred from names and payloads; confirm them.
#[derive(Debug)]
enum FetchResult {
    /// A batch of rows. Presumably `None` means no further rows; verify against the producer.
    Rows(Option<Box<dyn RowIterator + Send + Sync>>),
    /// The fetch was canceled.
    Canceled,
    /// The fetch failed with this error message.
    Error(String),
    /// A notice to relay to the client (presumed).
    Notice(AdapterNotice),
}
2820
#[cfg(test)]
mod test {
    use super::*;

    #[mz_ore::test]
    fn test_parse_options() {
        struct TestCase {
            input: &'static str,
            expect: Result<Vec<(&'static str, &'static str)>, ()>,
        }
        let tests = vec![
            TestCase {
                input: "",
                expect: Ok(vec![]),
            },
            TestCase {
                input: "--key",
                expect: Err(()),
            },
            TestCase {
                input: "--key=val",
                expect: Ok(vec![("key", "val")]),
            },
            TestCase {
                input: r#"--key=val -ckey2=val2 -c key3=val3 -c key4=val4 -ckey5=val5"#,
                expect: Ok(vec![
                    ("key", "val"),
                    ("key2", "val2"),
                    ("key3", "val3"),
                    ("key4", "val4"),
                    ("key5", "val5"),
                ]),
            },
            TestCase {
                // An escaped leading space becomes part of the key.
                input: r#"-c\ key=val"#,
                expect: Ok(vec![(" key", "val")]),
            },
            TestCase {
                // `-c` options must be `key=value`.
                input: "--key=val -ckey2 val2",
                expect: Err(()),
            },
            TestCase {
                // An empty value is allowed.
                input: "--key=",
                expect: Ok(vec![("key", "")]),
            },
        ];
        for test in tests {
            let got = parse_options(test.input);
            let expect = test.expect.map(|r| {
                r.into_iter()
                    .map(|(k, v)| (k.to_owned(), v.to_owned()))
                    .collect()
            });
            assert_eq!(got, expect, "input: {}", test.input);
        }
    }

    #[mz_ore::test]
    fn test_parse_option() {
        struct TestCase {
            input: &'static str,
            expect: Result<(&'static str, &'static str), ()>,
        }
        let tests = vec![
            TestCase {
                input: "",
                expect: Err(()),
            },
            TestCase {
                input: "--",
                expect: Err(()),
            },
            TestCase {
                input: "--c",
                expect: Err(()),
            },
            TestCase {
                input: "a=b",
                expect: Err(()),
            },
            TestCase {
                input: "--a=b",
                expect: Ok(("a", "b")),
            },
            TestCase {
                input: "--ca=b",
                expect: Ok(("ca", "b")),
            },
            TestCase {
                input: "-ca=b",
                expect: Ok(("a", "b")),
            },
            TestCase {
                // Empty key and value are accepted.
                input: "--=",
                expect: Ok(("", "")),
            },
        ];
        for test in tests {
            let got = parse_option(test.input);
            assert_eq!(got, test.expect, "input: {}", test.input);
        }
    }

    #[mz_ore::test]
    fn test_split_options() {
        struct TestCase {
            input: &'static str,
            expect: Vec<&'static str>,
        }
        let tests = vec![
            TestCase {
                input: "",
                expect: vec![],
            },
            TestCase {
                input: " ",
                expect: vec![],
            },
            TestCase {
                input: " a ",
                expect: vec!["a"],
            },
            TestCase {
                input: " ab cd ",
                expect: vec!["ab", "cd"],
            },
            TestCase {
                // Two spaces after `ab`: the first is escaped (kept), the second separates.
                input: r#" ab\  cd "#,
                expect: vec!["ab ", "cd"],
            },
            TestCase {
                input: r#" ab\\ cd "#,
                expect: vec![r#"ab\"#, "cd"],
            },
            TestCase {
                // Escaped backslash, then an escaped space, then a separating space.
                input: r#" ab\\\  cd "#,
                expect: vec![r#"ab\ "#, "cd"],
            },
            TestCase {
                // Escaped backslash followed by an escaped space: no separator at all.
                input: r#" ab\\\ cd "#,
                expect: vec![r#"ab\ cd"#],
            },
            TestCase {
                input: r#" ab\\\cd "#,
                expect: vec![r#"ab\cd"#],
            },
            TestCase {
                // A trailing lone backslash is dropped.
                input: r#"a\"#,
                expect: vec!["a"],
            },
            TestCase {
                input: r#"a\ "#,
                expect: vec!["a "],
            },
            TestCase {
                input: r#"\"#,
                expect: vec![],
            },
            TestCase {
                input: r#"\ "#,
                expect: vec![r#" "#],
            },
            TestCase {
                input: r#" \ "#,
                expect: vec![r#" "#],
            },
            TestCase {
                // Escaped space followed by a trailing separator.
                input: r#"\  "#,
                expect: vec![r#" "#],
            },
        ];
        for test in tests {
            let got = split_options(test.input);
            assert_eq!(got, test.expect, "input: {}", test.input);
        }
    }
}