1use std::collections::BTreeMap;
11use std::convert::TryFrom;
12use std::future::Future;
13use std::ops::Deref;
14use std::sync::Arc;
15use std::time::{Duration, Instant};
16use std::{iter, mem};
17
18use base64::prelude::*;
19use byteorder::{ByteOrder, NetworkEndian};
20use futures::future::{BoxFuture, FutureExt, pending};
21use itertools::Itertools;
22use mz_adapter::client::RecordFirstRowStream;
23use mz_adapter::session::{
24 EndTransactionAction, InProgressRows, LifecycleTimestamps, PortalRefMut, PortalState, Session,
25 SessionConfig, TransactionStatus,
26};
27use mz_adapter::statement_logging::{StatementEndedExecutionReason, StatementExecutionStrategy};
28use mz_adapter::{
29 AdapterError, AdapterNotice, ExecuteContextGuard, ExecuteResponse, PeekResponseUnary, metrics,
30 verify_datum_desc,
31};
32use mz_auth::Authenticated;
33use mz_auth::password::Password;
34use mz_authenticator::{Authenticator, GenericOidcAuthenticator};
35use mz_frontegg_auth::Authenticator as FronteggAuthenticator;
36use mz_ore::cast::CastFrom;
37use mz_ore::netio::AsyncReady;
38use mz_ore::now::{EpochMillis, SYSTEM_TIME};
39use mz_ore::str::StrExt;
40use mz_ore::{assert_none, assert_ok, instrument, soft_assert_eq_or_log};
41use mz_pgcopy::{CopyCsvFormatParams, CopyFormatParams, CopyTextFormatParams};
42use mz_pgwire_common::{
43 ConnectionCounter, Cursor, ErrorResponse, Format, FrontendMessage, Severity, VERSION_3,
44 VERSIONS,
45};
46use mz_repr::{
47 CatalogItemId, ColumnIndex, Datum, RelationDesc, RowArena, RowIterator, RowRef,
48 SqlRelationType, SqlScalarType,
49};
50use mz_server_core::TlsMode;
51use mz_server_core::listeners::{AllowedRoles, AuthenticatorKind};
52use mz_sql::ast::display::AstDisplay;
53use mz_sql::ast::{CopyDirection, CopyStatement, FetchDirection, Ident, Raw, Statement};
54use mz_sql::parse::StatementParseResult;
55use mz_sql::plan::{CopyFormat, ExecuteTimeout, StatementDesc};
56use mz_sql::session::metadata::SessionMetadata;
57use mz_sql::session::user::INTERNAL_USER_NAMES;
58use mz_sql::session::vars::{MAX_COPY_FROM_SIZE, Var, VarInput};
59use postgres::error::SqlState;
60use tokio::io::{self, AsyncRead, AsyncWrite};
61use tokio::select;
62use tokio::time::{self};
63use tokio_metrics::TaskMetrics;
64use tokio_stream::wrappers::UnboundedReceiverStream;
65use tracing::{Instrument, debug, debug_span, warn};
66use uuid::Uuid;
67
68use crate::codec::{
69 FramedConn, decode_password, decode_sasl_initial_response, decode_sasl_response,
70};
71use crate::message::{
72 self, BackendMessage, SASLServerFinalMessage, SASLServerFinalMessageKinds,
73 SASLServerFirstMessage,
74};
75
76pub fn match_handshake(buf: &[u8]) -> bool {
80 if buf.len() < 8 {
90 return false;
91 }
92 let version = NetworkEndian::read_i32(&buf[4..8]);
93 VERSIONS.contains(&version)
94}
95
/// Everything [`run`] needs to serve one client connection: the connection
/// itself, what the client sent in its startup message, and the server-side
/// configuration and resources used to authenticate and start a session.
pub struct RunParams<'a, A, I>
where
    I: Iterator<Item = TaskMetrics> + Send,
{
    /// The server's TLS mode. `run` checks the connection against it (via
    /// `ensure_tls_compatibility`) before authenticating.
    pub tls_mode: Option<TlsMode>,
    /// Adapter client used to authenticate the user and to create and start
    /// the session.
    pub adapter_client: mz_adapter::Client,
    /// The framed connection to the client.
    pub conn: &'a mut FramedConn<A>,
    /// Unique identifier for this connection; recorded in the session config.
    pub conn_uuid: Uuid,
    /// Protocol version from the client's startup message. Only `VERSION_3`
    /// is accepted.
    pub version: i32,
    /// Startup parameters sent by the client. `user` selects the login role.
    /// `options` is parsed as `-c key=value`-style settings. Every entry other
    /// than `user` is applied as a session variable, with `options` expanded
    /// into its individual settings.
    pub params: BTreeMap<String, String>,
    /// Frontegg authenticator, if configured; handed to `get_authenticator`.
    pub frontegg: Option<FronteggAuthenticator>,
    /// OIDC authenticator; handed to `get_authenticator`.
    pub oidc: GenericOidcAuthenticator,
    /// The authentication method this listener is configured to use.
    pub authenticator_kind: AuthenticatorKind,
    /// Counter used to reserve a connection slot for the session's user.
    /// Failing to reserve one rejects the connection.
    pub active_connection_counter: ConnectionCounter,
    /// Helm chart version, recorded in the session config.
    pub helm_chart_version: Option<String>,
    /// Which kinds of roles may log in through this listener.
    pub allowed_roles: AllowedRoles,
    /// Tokio task-metrics intervals. The state machine treats this as an
    /// infinite iterator and uses it to measure scheduling delay while
    /// receiving messages.
    pub tokio_metrics_intervals: I,
}
129
130#[mz_ore::instrument(level = "debug")]
140pub async fn run<'a, A, I>(
141 RunParams {
142 tls_mode,
143 adapter_client,
144 conn,
145 conn_uuid,
146 version,
147 mut params,
148 frontegg,
149 oidc,
150 authenticator_kind,
151 active_connection_counter,
152 helm_chart_version,
153 allowed_roles,
154 tokio_metrics_intervals,
155 }: RunParams<'a, A, I>,
156) -> Result<(), io::Error>
157where
158 A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
159 I: Iterator<Item = TaskMetrics> + Send,
160{
161 if version != VERSION_3 {
162 return conn
163 .send(ErrorResponse::fatal(
164 SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
165 "server does not support the client's requested protocol version",
166 ))
167 .await;
168 }
169
170 let user = params.remove("user").unwrap_or_else(String::new);
171 let options = parse_options(params.get("options").unwrap_or(&String::new()));
172
173 let (oidc_auth_enabled, options) = extract_oidc_auth_enabled_from_options(options);
176 let authenticator = get_authenticator(
177 authenticator_kind,
178 frontegg,
179 oidc,
180 adapter_client.clone(),
181 oidc_auth_enabled,
182 );
183 let is_internal_user = INTERNAL_USER_NAMES.contains(&user);
185 let is_reserved_user = mz_adapter::catalog::is_reserved_role_name(user.as_str());
187 let role_allowed = match allowed_roles {
188 AllowedRoles::Normal => !is_reserved_user,
189 AllowedRoles::Internal => is_internal_user,
190 AllowedRoles::NormalAndInternal => !is_reserved_user || is_internal_user,
191 };
192 if !role_allowed {
193 let msg = format!("unauthorized login to user '{user}'");
194 return conn
195 .send(ErrorResponse::fatal(SqlState::INSUFFICIENT_PRIVILEGE, msg))
196 .await;
197 }
198
199 if let Err(err) = conn.inner().ensure_tls_compatibility(&tls_mode) {
200 return conn.send(err).await;
201 }
202
203 let (mut session, expired) = match authenticator {
204 Authenticator::Frontegg(frontegg) => {
205 let password = match request_cleartext_password(conn).await {
206 Ok(password) => password,
207 Err(PasswordRequestError::IoError(e)) => return Err(e),
208 Err(PasswordRequestError::InvalidPasswordError(e)) => {
209 return conn.send(e).await;
210 }
211 };
212
213 let auth_response = frontegg.authenticate(&user, &password).await;
214 match auth_response {
215 Ok((mut auth_session, authenticated)) => {
222 let session = adapter_client.new_session(
223 SessionConfig {
224 conn_id: conn.conn_id().clone(),
225 uuid: conn_uuid,
226 user: auth_session.user().into(),
227 client_ip: conn.peer_addr().clone(),
228 external_metadata_rx: Some(auth_session.external_metadata_rx()),
229 helm_chart_version,
230 },
231 authenticated,
232 );
233 let expired = async move { auth_session.expired().await };
234 (session, expired.left_future())
235 }
236 Err(err) => {
237 warn!(?err, "pgwire connection failed authentication");
238 return conn
239 .send(ErrorResponse::fatal(
240 SqlState::INVALID_PASSWORD,
241 "invalid password",
242 ))
243 .await;
244 }
245 }
246 }
247 Authenticator::Oidc(oidc) => {
248 let jwt = match request_cleartext_password(conn).await {
250 Ok(password) => password,
251 Err(PasswordRequestError::IoError(e)) => return Err(e),
252 Err(PasswordRequestError::InvalidPasswordError(e)) => {
253 return conn.send(e).await;
254 }
255 };
256 let auth_response = oidc.authenticate(&jwt, Some(&user)).await;
257 match auth_response {
258 Ok((claims, authenticated)) => {
259 let session = adapter_client.new_session(
260 SessionConfig {
261 conn_id: conn.conn_id().clone(),
262 uuid: conn_uuid,
263 user: claims.username().into(),
264 client_ip: conn.peer_addr().clone(),
265 external_metadata_rx: None,
266 helm_chart_version,
267 },
268 authenticated,
269 );
270 (session, pending().right_future())
273 }
274 Err(err) => {
275 warn!(?err, "pgwire connection failed authentication");
276 return conn
277 .send(ErrorResponse::fatal(
278 SqlState::INVALID_PASSWORD,
279 "invalid password",
280 ))
281 .await;
282 }
283 }
284 }
285 Authenticator::Password(adapter_client) => {
286 let session = match authenticate_with_password(
287 conn,
288 &adapter_client,
289 user,
290 conn_uuid,
291 helm_chart_version,
292 )
293 .await
294 {
295 Ok(session) => session,
296 Err(PasswordRequestError::IoError(e)) => return Err(e),
297 Err(PasswordRequestError::InvalidPasswordError(e)) => {
298 return conn.send(e).await;
299 }
300 };
301 (session, pending().right_future())
303 }
304 Authenticator::Sasl(adapter_client) => {
305 conn.send(BackendMessage::AuthenticationSASL).await?;
307 conn.flush().await?;
308 let (mechanism, initial_response) = match conn.recv().await? {
310 Some(FrontendMessage::RawAuthentication(data)) => {
311 match decode_sasl_initial_response(Cursor::new(&data)).ok() {
312 Some(FrontendMessage::SASLInitialResponse {
313 gs2_header,
314 mechanism,
315 initial_response,
316 }) => {
317 if gs2_header.channel_binding_enabled() {
319 return conn
320 .send(ErrorResponse::fatal(
321 SqlState::PROTOCOL_VIOLATION,
322 "channel binding not supported",
323 ))
324 .await;
325 }
326 (mechanism, initial_response)
327 }
328 _ => {
329 return conn
330 .send(ErrorResponse::fatal(
331 SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
332 "expected SASLInitialResponse message",
333 ))
334 .await;
335 }
336 }
337 }
338 _ => {
339 return conn
340 .send(ErrorResponse::fatal(
341 SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
342 "expected SASLInitialResponse message",
343 ))
344 .await;
345 }
346 };
347
348 if mechanism != "SCRAM-SHA-256" {
349 return conn
350 .send(ErrorResponse::fatal(
351 SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
352 "unsupported SASL mechanism",
353 ))
354 .await;
355 }
356
357 if initial_response.nonce.len() > 256 {
358 return conn
359 .send(ErrorResponse::fatal(
360 SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
361 "nonce too long",
362 ))
363 .await;
364 }
365
366 let (server_first_message_raw, mock_hash) = match adapter_client
367 .generate_sasl_challenge(&user, &initial_response.nonce)
368 .await
369 {
370 Ok(response) => {
371 let server_first_message_raw = format!(
372 "r={},s={},i={}",
373 response.nonce, response.salt, response.iteration_count
374 );
375
376 let client_key = [0u8; 32];
377 let server_key = [1u8; 32];
378 let mock_hash = format!(
379 "SCRAM-SHA-256${}:{}${}:{}",
380 response.iteration_count,
381 response.salt,
382 BASE64_STANDARD.encode(client_key),
383 BASE64_STANDARD.encode(server_key)
384 );
385
386 conn.send(BackendMessage::AuthenticationSASLContinue(
387 SASLServerFirstMessage {
388 iteration_count: response.iteration_count,
389 nonce: response.nonce,
390 salt: response.salt,
391 },
392 ))
393 .await?;
394 conn.flush().await?;
395 (server_first_message_raw, mock_hash)
396 }
397 Err(e) => {
398 return conn.send(e.into_response(Severity::Fatal)).await;
399 }
400 };
401
402 let authenticated = match conn.recv().await? {
403 Some(FrontendMessage::RawAuthentication(data)) => {
404 match decode_sasl_response(Cursor::new(&data)).ok() {
405 Some(FrontendMessage::SASLResponse(response)) => {
406 let auth_message = format!(
407 "{},{},{}",
408 initial_response.client_first_message_bare_raw,
409 server_first_message_raw,
410 response.client_final_message_bare_raw
411 );
412 if response.proof.len() > 1024 {
413 return conn
414 .send(ErrorResponse::fatal(
415 SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
416 "proof too long",
417 ))
418 .await;
419 }
420 match adapter_client
421 .verify_sasl_proof(
422 &user,
423 &response.proof,
424 &auth_message,
425 &mock_hash,
426 )
427 .await
428 {
429 Ok((proof_response, authenticated)) => {
430 conn.send(BackendMessage::AuthenticationSASLFinal(
431 SASLServerFinalMessage {
432 kind: SASLServerFinalMessageKinds::Verifier(
433 proof_response.verifier,
434 ),
435 extensions: vec![],
436 },
437 ))
438 .await?;
439 conn.flush().await?;
440 authenticated
441 }
442 Err(_) => {
443 return conn
444 .send(ErrorResponse::fatal(
445 SqlState::INVALID_PASSWORD,
446 "invalid password",
447 ))
448 .await;
449 }
450 }
451 }
452 _ => {
453 return conn
454 .send(ErrorResponse::fatal(
455 SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
456 "expected SASLResponse message",
457 ))
458 .await;
459 }
460 }
461 }
462 _ => {
463 return conn
464 .send(ErrorResponse::fatal(
465 SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
466 "expected SASLResponse message",
467 ))
468 .await;
469 }
470 };
471
472 let session = adapter_client.new_session(
473 SessionConfig {
474 conn_id: conn.conn_id().clone(),
475 uuid: conn_uuid,
476 user,
477 client_ip: conn.peer_addr().clone(),
478 external_metadata_rx: None,
479 helm_chart_version,
480 },
481 authenticated,
482 );
483 let auth_session = pending().right_future();
485 (session, auth_session)
486 }
487
488 Authenticator::None => {
489 let session = adapter_client.new_session(
490 SessionConfig {
491 conn_id: conn.conn_id().clone(),
492 uuid: conn_uuid,
493 user,
494 client_ip: conn.peer_addr().clone(),
495 external_metadata_rx: None,
496 helm_chart_version,
497 },
498 Authenticated,
499 );
500 let auth_session = pending().right_future();
502 (session, auth_session)
503 }
504 };
505
506 let system_vars = adapter_client.get_system_vars().await;
507 for (name, value) in params {
508 let settings = match name.as_str() {
509 "options" => match &options {
510 Ok(opts) => opts,
511 Err(()) => {
512 session.add_notice(AdapterNotice::BadStartupSetting {
513 name,
514 reason: "could not parse".into(),
515 });
516 continue;
517 }
518 },
519 _ => &vec![(name, value)],
520 };
521 for (key, val) in settings {
522 const LOCAL: bool = false;
523 if let Err(err) = session
528 .vars_mut()
529 .set(&system_vars, key, VarInput::Flat(val), LOCAL)
530 {
531 session.add_notice(AdapterNotice::BadStartupSetting {
532 name: key.clone(),
533 reason: err.to_string(),
534 });
535 }
536 }
537 }
538 session
539 .vars_mut()
540 .end_transaction(EndTransactionAction::Commit);
541
542 let _guard = match active_connection_counter.allocate_connection(session.user()) {
543 Ok(drop_connection) => drop_connection,
544 Err(e) => {
545 let e: AdapterError = e.into();
546 return conn.send(e.into_response(Severity::Fatal)).await;
547 }
548 };
549
550 let mut adapter_client = match adapter_client.startup(session).await {
552 Ok(adapter_client) => adapter_client,
553 Err(e) => return conn.send(e.into_response(Severity::Fatal)).await,
554 };
555
556 let mut buf = vec![BackendMessage::AuthenticationOk];
557 for var in adapter_client.session().vars().notify_set() {
558 buf.push(BackendMessage::ParameterStatus(var.name(), var.value()));
559 }
560 buf.push(BackendMessage::BackendKeyData {
561 conn_id: adapter_client.session().conn_id().unhandled(),
562 secret_key: adapter_client.session().secret_key(),
563 });
564 buf.extend(
565 adapter_client
566 .session()
567 .drain_notices()
568 .into_iter()
569 .map(|notice| BackendMessage::ErrorResponse(notice.into_response())),
570 );
571 buf.push(BackendMessage::ReadyForQuery(
572 adapter_client.session().transaction().into(),
573 ));
574 conn.send_all(buf).await?;
575 conn.flush().await?;
576
577 let machine = StateMachine {
578 conn,
579 adapter_client,
580 txn_needs_commit: false,
581 tokio_metrics_intervals,
582 };
583
584 select! {
585 r = machine.run() => {
586 if let Err(err) = &r {
591 let _ = conn
592 .send(ErrorResponse::fatal(
593 SqlState::CONNECTION_FAILURE,
594 err.to_string(),
595 ))
596 .await;
597 let _ = conn.flush().await;
598 }
599 r
600 },
601 _ = expired => {
602 conn
603 .send(ErrorResponse::fatal(SqlState::INVALID_AUTHORIZATION_SPECIFICATION, "authentication expired"))
604 .await?;
605 conn.flush().await
606 }
607 }
608}
609
/// Pulls the `oidc_auth_enabled` setting out of parsed startup options.
///
/// Returns `(enabled, remaining)`. `enabled` is the last `oidc_auth_enabled`
/// value parsed as a `bool`, where anything other than `"true"`/`"false"`
/// counts as `false`. `remaining` holds every other option in its original
/// order.
///
/// If `options` failed to parse (`Err(())`), it is passed through unchanged
/// and `enabled` is `false`.
fn extract_oidc_auth_enabled_from_options(
    options: Result<Vec<(String, String)>, ()>,
) -> (bool, Result<Vec<(String, String)>, ()>) {
    let mut remaining = match options {
        Ok(opts) => opts,
        Err(()) => return (false, Err(())),
    };

    let mut oidc_auth_enabled = false;
    // Drop every `oidc_auth_enabled` entry; later entries overwrite earlier ones.
    remaining.retain(|(key, value)| {
        if key.as_str() != "oidc_auth_enabled" {
            return true;
        }
        oidc_auth_enabled = value.parse::<bool>().unwrap_or(false);
        false
    });

    (oidc_auth_enabled, Ok(remaining))
}
634
635fn parse_options(value: &str) -> Result<Vec<(String, String)>, ()> {
640 let opts = split_options(value);
641 let mut pairs = Vec::with_capacity(opts.len());
642 let mut seen_prefix = false;
643 for opt in opts {
644 if !seen_prefix {
645 if opt == "-c" {
646 seen_prefix = true;
647 } else {
648 let (key, val) = parse_option(&opt)?;
649 pairs.push((key.to_owned(), val.to_owned()));
650 }
651 } else {
652 let (key, val) = opt.split_once('=').ok_or(())?;
653 pairs.push((key.to_owned(), val.to_owned()));
654 seen_prefix = false;
655 }
656 }
657 Ok(pairs)
658}
659
/// Parses a single `-ckey=value` or `--key=value` option into `(key, value)`.
///
/// The input is split at the first `=`, so the value may itself contain `=`.
/// Returns `Err(())` if there is no `=` or the key has neither prefix.
fn parse_option(option: &str) -> Result<(&str, &str), ()> {
    let (raw_key, value) = option.split_once('=').ok_or(())?;
    ["-c", "--"]
        .iter()
        .copied()
        .find_map(|prefix| raw_key.strip_prefix(prefix))
        .map(|key| (key, value))
        .ok_or(())
}
672
/// Splits an `options` startup parameter into words.
///
/// Words are separated by runs of spaces, and empty words are dropped. A
/// backslash makes the next character literal: `\ ` yields a space inside a
/// word and `\\` yields a backslash. A backslash before any other character
/// yields just that character, and a trailing lone backslash is dropped.
fn split_options(value: &str) -> Vec<String> {
    let mut words = Vec::new();
    let mut word = String::new();
    let mut escaped = false;
    for ch in value.chars() {
        if escaped {
            // Whatever follows a backslash is taken literally.
            word.push(ch);
            escaped = false;
        } else if ch == '\\' {
            escaped = true;
        } else if ch == ' ' {
            if !word.is_empty() {
                words.push(mem::take(&mut word));
            }
        } else {
            word.push(ch);
        }
    }
    if !word.is_empty() {
        words.push(word);
    }
    words
}
714}
715
716enum PasswordRequestError {
717 InvalidPasswordError(ErrorResponse),
718 IoError(io::Error),
719}
720
721impl From<io::Error> for PasswordRequestError {
722 fn from(e: io::Error) -> Self {
723 PasswordRequestError::IoError(e)
724 }
725}
726
727async fn request_cleartext_password<A>(
731 conn: &mut FramedConn<A>,
732) -> Result<String, PasswordRequestError>
733where
734 A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
735{
736 conn.send(BackendMessage::AuthenticationCleartextPassword)
737 .await?;
738 conn.flush().await?;
739
740 if let Some(message) = conn.recv().await? {
741 if let FrontendMessage::RawAuthentication(data) = message {
742 if let Some(FrontendMessage::Password { password }) =
743 decode_password(Cursor::new(&data)).ok()
744 {
745 return Ok(password);
746 }
747 }
748 }
749
750 Err(PasswordRequestError::InvalidPasswordError(
751 ErrorResponse::fatal(
752 SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
753 "expected Password message",
754 ),
755 ))
756}
757
758async fn authenticate_with_password<A>(
761 conn: &mut FramedConn<A>,
762 adapter_client: &mz_adapter::Client,
763 user: String,
764 conn_uuid: Uuid,
765 helm_chart_version: Option<String>,
766) -> Result<Session, PasswordRequestError>
767where
768 A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
769{
770 let password = match request_cleartext_password(conn).await {
771 Ok(password) => Password(password),
772 Err(e) => return Err(e),
773 };
774
775 let authenticated = match adapter_client.authenticate(&user, &password).await {
776 Ok(authenticated) => authenticated,
777 Err(err) => {
778 warn!(?err, "pgwire connection failed authentication");
779 return Err(PasswordRequestError::InvalidPasswordError(
780 ErrorResponse::fatal(SqlState::INVALID_PASSWORD, "invalid password"),
781 ));
782 }
783 };
784
785 let session = adapter_client.new_session(
786 SessionConfig {
787 conn_id: conn.conn_id().clone(),
788 uuid: conn_uuid,
789 user,
790 client_ip: conn.peer_addr().clone(),
791 external_metadata_rx: None,
792 helm_chart_version,
793 },
794 authenticated,
795 );
796
797 Ok(session)
798}
799
/// The state of the [`StateMachine`] between frontend messages.
#[derive(Debug)]
enum State {
    /// Handle the next frontend message normally (`advance_ready`).
    Ready,
    /// Discard frontend messages until a `Sync` arrives (`advance_drain`).
    /// `advance_ready` moves here on unexpected messages such as stray
    /// `CopyData` or authentication messages.
    Drain,
    /// The connection is finished; `StateMachine::run` returns.
    Done,
}
806
/// Drives the pgwire message protocol for a started session.
struct StateMachine<'a, A, I>
where
    I: Iterator<Item = TaskMetrics> + Send + 'a,
{
    /// The connection to the client.
    conn: &'a mut FramedConn<A>,
    /// The adapter session client for this connection.
    adapter_client: mz_adapter::SessionClient,
    /// Set after an `Execute` that ran in an implicit transaction. The next
    /// `ensure_transaction` then commits before starting a new transaction.
    /// Cleared by `end_transaction`.
    txn_needs_commit: bool,
    /// Tokio task-metrics intervals, treated as an infinite iterator.
    tokio_metrics_intervals: I,
}
816
/// Why sending result rows to the client stopped.
///
/// NOTE(review): use sites are outside this view. This is presumably fed into
/// statement-execution logging — confirm there.
enum SendRowsEndedReason {
    /// All rows were sent.
    Success {
        // NOTE(review): units of `result_size` not visible here (bytes?) — confirm.
        result_size: u64,
        rows_returned: u64,
    },
    /// Sending stopped because of an error, described by `error`.
    Errored {
        error: String,
    },
    /// Sending stopped because the statement was canceled.
    Canceled,
}
827
/// Error text for commands issued while the current transaction is aborted.
/// The wording matches PostgreSQL's message for the same situation.
const ABORTED_TXN_MSG: &str =
    "current transaction is aborted, commands ignored until end of transaction block";
830
831impl<'a, A, I> StateMachine<'a, A, I>
832where
833 A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin + 'a,
834 I: Iterator<Item = TaskMetrics> + Send + 'a,
835{
836 #[allow(clippy::manual_async_fn)]
840 #[mz_ore::instrument(level = "debug")]
841 fn run(mut self) -> impl Future<Output = Result<(), io::Error>> + Send + 'a {
842 async move {
843 let mut state = State::Ready;
844 loop {
845 self.send_pending_notices().await?;
846 state = match state {
847 State::Ready => self.advance_ready().await?,
848 State::Drain => self.advance_drain().await?,
849 State::Done => return Ok(()),
850 };
851 self.adapter_client
852 .add_idle_in_transaction_session_timeout();
853 }
854 }
855 }
856
    /// Waits for the next frontend message and dispatches it to its handler,
    /// returning the state to move to.
    ///
    /// While waiting, a session timeout from `recv_timeout` takes priority
    /// over client messages (`biased`). On timeout, a fatal error is reported,
    /// the session is terminated, one more client message is read, and the
    /// result of reporting the error is returned.
    ///
    /// Also records the per-message metrics
    /// `pgwire_message_processing_seconds` and
    /// `pgwire_recv_scheduling_delay_ms`, labeled by message name.
    #[instrument(level = "debug")]
    async fn advance_ready(&mut self) -> Result<State, io::Error> {
        // Advance the metrics interval before waiting, so that the interval
        // taken after `recv` below covers the wait for this message.
        // NOTE(review): relies on each `next()` reporting metrics since the
        // previous call — confirm against the iterator's source.
        self.tokio_metrics_intervals
            .next()
            .expect("infinite iterator");

        let message = select! {
            // Check for session timeouts before client messages.
            biased;

            Some(timeout) = self.adapter_client.recv_timeout() => {
                let err: AdapterError = timeout.into();
                let conn_id = self.adapter_client.session().conn_id();
                tracing::warn!("session timed out, conn_id {}", conn_id);

                // Report the timeout to the client as a fatal error.
                let error_response = err.into_response(Severity::Fatal);
                let error_state = self.error(error_response).await;

                self.adapter_client.terminate().await;

                // Read (and discard) one more message from the client before
                // returning; I/O errors from this read are propagated.
                let _ = self.conn.recv().await?;
                return error_state;
            },
            message = self.conn.recv() => message?,
        };

        // Scheduling delay observed while waiting for this message.
        let interval = self
            .tokio_metrics_intervals
            .next()
            .expect("infinite iterator");
        let recv_scheduling_delay_ms = interval.total_scheduled_duration.as_secs_f64() * 1000.0;

        // Wall-clock receive time, passed on to query/execute handling.
        let received = SYSTEM_TIME();

        // A message arrived, so the session is no longer idle.
        self.adapter_client
            .remove_idle_in_transaction_session_timeout();

        // Metric label; empty when the connection returned no message.
        let message_name = message.as_ref().map(|m| m.name()).unwrap_or_default();

        // Processing time is only measured when there is a message.
        let start = message.as_ref().map(|_| Instant::now());
        let next_state = match message {
            Some(FrontendMessage::Query { sql }) => {
                // Each query gets its own root span, linked to the current one.
                let query_root_span =
                    tracing::info_span!(parent: None, "advance_ready", otel.name = message_name);
                query_root_span.follows_from(tracing::Span::current());
                self.query(sql, received)
                    .instrument(query_root_span)
                    .await?
            }
            Some(FrontendMessage::Parse {
                name,
                sql,
                param_types,
            }) => self.parse(name, sql, param_types).await?,
            Some(FrontendMessage::Bind {
                portal_name,
                statement_name,
                param_formats,
                raw_params,
                result_formats,
            }) => {
                self.bind(
                    portal_name,
                    statement_name,
                    param_formats,
                    raw_params,
                    result_formats,
                )
                .await?
            }
            Some(FrontendMessage::Execute {
                portal_name,
                max_rows,
            }) => {
                let max_rows = match usize::try_from(max_rows) {
                    // Zero, or a value that doesn't fit in `usize` (e.g.
                    // negative), means "no row limit".
                    Ok(0) | Err(_) => ExecuteCount::All,
                    Ok(n) => ExecuteCount::Count(n),
                };
                let execute_root_span =
                    tracing::info_span!(parent: None, "advance_ready", otel.name = message_name);
                execute_root_span.follows_from(tracing::Span::current());
                let state = self
                    .execute(
                        portal_name,
                        max_rows,
                        portal_exec_message,
                        None,
                        ExecuteTimeout::None,
                        None,
                        Some(received),
                    )
                    .instrument(execute_root_span)
                    .await?;
                // An implicit transaction left open by this Execute is
                // committed by the next `ensure_transaction` call.
                if self.adapter_client.session().transaction().is_implicit() {
                    self.txn_needs_commit = true;
                }
                state
            }
            Some(FrontendMessage::DescribeStatement { name }) => {
                self.describe_statement(&name).await?
            }
            Some(FrontendMessage::DescribePortal { name }) => self.describe_portal(&name).await?,
            Some(FrontendMessage::CloseStatement { name }) => self.close_statement(name).await?,
            Some(FrontendMessage::ClosePortal { name }) => self.close_portal(name).await?,
            Some(FrontendMessage::Flush) => self.flush().await?,
            Some(FrontendMessage::Sync) => self.sync().await?,
            Some(FrontendMessage::Terminate) => State::Done,

            // Messages that are not valid in this state: skip until Sync.
            Some(FrontendMessage::CopyData(_))
            | Some(FrontendMessage::CopyDone)
            | Some(FrontendMessage::CopyFail(_))
            | Some(FrontendMessage::Password { .. })
            | Some(FrontendMessage::RawAuthentication(_))
            | Some(FrontendMessage::SASLInitialResponse { .. })
            | Some(FrontendMessage::SASLResponse(_)) => State::Drain,
            None => State::Done,
        };

        if let Some(start) = start {
            self.adapter_client
                .inner()
                .metrics()
                .pgwire_message_processing_seconds
                .with_label_values(&[message_name])
                .observe(start.elapsed().as_secs_f64());
        }
        self.adapter_client
            .inner()
            .metrics()
            .pgwire_recv_scheduling_delay_ms
            .with_label_values(&[message_name])
            .observe(recv_scheduling_delay_ms);

        Ok(next_state)
    }
1020
1021 async fn advance_drain(&mut self) -> Result<State, io::Error> {
1022 let message = self.conn.recv().await?;
1023 if message.is_some() {
1024 self.adapter_client
1025 .remove_idle_in_transaction_session_timeout();
1026 }
1027 match message {
1028 Some(FrontendMessage::Sync) => self.sync().await,
1029 None => Ok(State::Done),
1030 _ => Ok(State::Drain),
1031 }
1032 }
1033
    /// Executes one statement from a simple-protocol `Query` message.
    ///
    /// Steps:
    ///
    /// 1. Declare the statement as the unnamed portal and record
    ///    `lifecycle_timestamps` on it.
    /// 2. Reject it if it declares parameters; the simple protocol supplies
    ///    none.
    /// 3. Unless it is a `COPY`, send a `RowDescription` with text formats
    ///    when the statement returns rows.
    /// 4. Execute the portal and stream the response via
    ///    `send_execute_response`.
    ///
    /// After execution has been attempted, the unnamed portal is removed. The
    /// early-return error paths before execution do not remove it.
    ///
    /// Statement errors are reported via `Self::error`; only I/O errors are
    /// returned as `Err`.
    #[instrument(level = "debug")]
    async fn one_query(
        &mut self,
        stmt: Statement<Raw>,
        sql: String,
        lifecycle_timestamps: LifecycleTimestamps,
    ) -> Result<State, io::Error> {
        // Simple-protocol statements are executed through the unnamed portal.
        const EMPTY_PORTAL: &str = "";
        if let Err(e) = self
            .adapter_client
            .declare(EMPTY_PORTAL.to_string(), stmt, sql)
            .await
        {
            return self.error(e.into_response(Severity::Error)).await;
        }
        let portal = self
            .adapter_client
            .session()
            .get_portal_unverified_mut(EMPTY_PORTAL)
            .expect("unnamed portal should be present");

        *portal.lifecycle_timestamps = Some(lifecycle_timestamps);

        // Clone the description so the mutable borrow of the portal ends here.
        let stmt_desc = portal.desc.clone();
        if !stmt_desc.param_types.is_empty() {
            return self
                .error(ErrorResponse::error(
                    SqlState::UNDEFINED_PARAMETER,
                    "there is no parameter $1",
                ))
                .await;
        }

        if let Some(relation_desc) = &stmt_desc.relation_desc {
            // COPY statements do not get a RowDescription here.
            if !stmt_desc.is_copy {
                let formats = vec![Format::Text; stmt_desc.arity()];
                self.send(BackendMessage::RowDescription(
                    message::encode_row_description(relation_desc, &formats),
                ))
                .await?;
            }
        }

        // NOTE(review): `wait_closed()` is presumably a signal that the client
        // connection closed, used to cancel execution — confirm.
        let result = match self
            .adapter_client
            .execute(EMPTY_PORTAL.to_string(), self.conn.wait_closed(), None)
            .await
        {
            Ok((response, execute_started)) => {
                self.send_pending_notices().await?;
                self.send_execute_response(
                    response,
                    stmt_desc.relation_desc,
                    EMPTY_PORTAL.to_string(),
                    ExecuteCount::All,
                    portal_exec_message,
                    None,
                    ExecuteTimeout::None,
                    execute_started,
                )
                .await
            }
            Err(e) => {
                self.send_pending_notices().await?;
                self.error(e.into_response(Severity::Error)).await
            }
        };

        self.adapter_client.session().remove_portal(EMPTY_PORTAL);

        result
    }
1113
1114 async fn ensure_transaction(
1115 &mut self,
1116 num_stmts: usize,
1117 message_type: &str,
1118 ) -> Result<(), io::Error> {
1119 let start = Instant::now();
1120 if self.txn_needs_commit {
1121 self.commit_transaction().await?;
1122 }
1123 let res = self.adapter_client.start_transaction(Some(num_stmts));
1126 assert_ok!(res);
1127 self.adapter_client
1128 .inner()
1129 .metrics()
1130 .pgwire_ensure_transaction_seconds
1131 .with_label_values(&[message_type])
1132 .observe(start.elapsed().as_secs_f64());
1133 Ok(())
1134 }
1135
1136 fn parse_sql<'b>(&self, sql: &'b str) -> Result<Vec<StatementParseResult<'b>>, ErrorResponse> {
1137 let parse_start = Instant::now();
1138 let result = match self.adapter_client.parse(sql) {
1139 Ok(result) => result.map_err(|e| {
1140 let pos = sql[..e.error.pos].chars().count() + 1;
1143 ErrorResponse::error(SqlState::SYNTAX_ERROR, e.error.message).with_position(pos)
1144 }),
1145 Err(msg) => Err(ErrorResponse::error(SqlState::PROGRAM_LIMIT_EXCEEDED, msg)),
1146 };
1147 self.adapter_client
1148 .inner()
1149 .metrics()
1150 .parse_seconds
1151 .observe(parse_start.elapsed().as_secs_f64());
1152 result
1153 }
1154
1155 #[instrument(level = "debug")]
1160 async fn query(&mut self, sql: String, received: EpochMillis) -> Result<State, io::Error> {
1161 let stmts = match self.parse_sql(&sql) {
1163 Ok(stmts) => stmts,
1164 Err(err) => {
1165 self.error(err).await?;
1166 return self.ready().await;
1167 }
1168 };
1169
1170 let num_stmts = stmts.len();
1171
1172 for StatementParseResult { ast: stmt, sql } in stmts {
1174 if self.is_aborted_txn() && !is_txn_exit_stmt(Some(&stmt)) {
1176 self.aborted_txn_error().await?;
1177 break;
1178 }
1179
1180 self.ensure_transaction(num_stmts, "query").await?;
1188
1189 match self
1190 .one_query(stmt, sql.to_string(), LifecycleTimestamps { received })
1191 .await?
1192 {
1193 State::Ready => (),
1194 State::Drain => break,
1195 State::Done => return Ok(State::Done),
1196 }
1197 }
1198
1199 {
1201 if self.adapter_client.session().transaction().is_implicit() {
1202 self.commit_transaction().await?;
1203 }
1204 }
1205
1206 if num_stmts == 0 {
1207 self.send(BackendMessage::EmptyQueryResponse).await?;
1208 }
1209
1210 self.ready().await
1211 }
1212
1213 #[instrument(level = "debug")]
1214 async fn parse(
1215 &mut self,
1216 name: String,
1217 sql: String,
1218 param_oids: Vec<u32>,
1219 ) -> Result<State, io::Error> {
1220 self.ensure_transaction(1, "parse").await?;
1222
1223 let mut param_types = vec![];
1224 for oid in param_oids {
1225 match mz_pgrepr::Type::from_oid(oid) {
1226 Ok(ty) => match SqlScalarType::try_from(&ty) {
1227 Ok(ty) => param_types.push(Some(ty)),
1228 Err(err) => {
1229 return self
1230 .error(ErrorResponse::error(
1231 SqlState::INVALID_PARAMETER_VALUE,
1232 err.to_string(),
1233 ))
1234 .await;
1235 }
1236 },
1237 Err(_) if oid == 0 => param_types.push(None),
1238 Err(e) => {
1239 return self
1240 .error(ErrorResponse::error(
1241 SqlState::PROTOCOL_VIOLATION,
1242 e.to_string(),
1243 ))
1244 .await;
1245 }
1246 }
1247 }
1248
1249 let stmts = match self.parse_sql(&sql) {
1250 Ok(stmts) => stmts,
1251 Err(err) => {
1252 return self.error(err).await;
1253 }
1254 };
1255 if stmts.len() > 1 {
1256 return self
1257 .error(ErrorResponse::error(
1258 SqlState::INTERNAL_ERROR,
1259 "cannot insert multiple commands into a prepared statement",
1260 ))
1261 .await;
1262 }
1263 let (maybe_stmt, sql) = match stmts.into_iter().next() {
1264 None => (None, ""),
1265 Some(StatementParseResult { ast, sql }) => (Some(ast), sql),
1266 };
1267 if self.is_aborted_txn() && !is_txn_exit_stmt(maybe_stmt.as_ref()) {
1268 return self.aborted_txn_error().await;
1269 }
1270 match self
1271 .adapter_client
1272 .prepare(name, maybe_stmt, sql.to_string(), param_types)
1273 .await
1274 {
1275 Ok(()) => {
1276 self.send(BackendMessage::ParseComplete).await?;
1277 Ok(State::Ready)
1278 }
1279 Err(e) => self.error(e.into_response(Severity::Error)).await,
1280 }
1281 }
1282
1283 #[instrument(level = "debug")]
1285 async fn commit_transaction(&mut self) -> Result<(), io::Error> {
1286 self.end_transaction(EndTransactionAction::Commit).await
1287 }
1288
1289 #[instrument(level = "debug")]
1291 async fn rollback_transaction(&mut self) -> Result<(), io::Error> {
1292 self.end_transaction(EndTransactionAction::Rollback).await
1293 }
1294
1295 #[instrument(level = "debug")]
1297 async fn end_transaction(&mut self, action: EndTransactionAction) -> Result<(), io::Error> {
1298 self.txn_needs_commit = false;
1299 let resp = self.adapter_client.end_transaction(action).await;
1300 if let Err(err) = resp {
1301 self.send(BackendMessage::ErrorResponse(
1302 err.into_response(Severity::Error),
1303 ))
1304 .await?;
1305 }
1306 Ok(())
1307 }
1308
1309 #[instrument(level = "debug")]
1310 async fn bind(
1311 &mut self,
1312 portal_name: String,
1313 statement_name: String,
1314 param_formats: Vec<Format>,
1315 raw_params: Vec<Option<Vec<u8>>>,
1316 result_formats: Vec<Format>,
1317 ) -> Result<State, io::Error> {
1318 self.ensure_transaction(1, "bind").await?;
1320
1321 let aborted_txn = self.is_aborted_txn();
1322 let stmt = match self
1323 .adapter_client
1324 .get_prepared_statement(&statement_name)
1325 .await
1326 {
1327 Ok(stmt) => stmt,
1328 Err(err) => return self.error(err.into_response(Severity::Error)).await,
1329 };
1330
1331 let param_types = &stmt.desc().param_types;
1332 if param_types.len() != raw_params.len() {
1333 let message = format!(
1334 "bind message supplies {actual} parameters, \
1335 but prepared statement \"{name}\" requires {expected}",
1336 name = statement_name,
1337 actual = raw_params.len(),
1338 expected = param_types.len()
1339 );
1340 return self
1341 .error(ErrorResponse::error(SqlState::PROTOCOL_VIOLATION, message))
1342 .await;
1343 }
1344 let param_formats = match pad_formats(param_formats, raw_params.len()) {
1345 Ok(param_formats) => param_formats,
1346 Err(msg) => {
1347 return self
1348 .error(ErrorResponse::error(SqlState::PROTOCOL_VIOLATION, msg))
1349 .await;
1350 }
1351 };
1352 if aborted_txn && !is_txn_exit_stmt(stmt.stmt()) {
1353 return self.aborted_txn_error().await;
1354 }
1355 let buf = RowArena::new();
1356 let mut params = vec![];
1357 for ((raw_param, mz_typ), format) in raw_params
1358 .into_iter()
1359 .zip_eq(param_types)
1360 .zip_eq(param_formats)
1361 {
1362 let pg_typ = mz_pgrepr::Type::from(mz_typ);
1363 let datum = match raw_param {
1364 None => Datum::Null,
1365 Some(bytes) => match mz_pgrepr::Value::decode(format, &pg_typ, &bytes) {
1366 Ok(param) => param.into_datum(&buf, &pg_typ),
1367 Err(err) => {
1368 let msg = format!("unable to decode parameter: {}", err);
1369 return self
1370 .error(ErrorResponse::error(SqlState::INVALID_PARAMETER_VALUE, msg))
1371 .await;
1372 }
1373 },
1374 };
1375 params.push((datum, mz_typ.clone()))
1376 }
1377
1378 let result_formats = match pad_formats(
1379 result_formats,
1380 stmt.desc()
1381 .relation_desc
1382 .clone()
1383 .map(|desc| desc.typ().column_types.len())
1384 .unwrap_or(0),
1385 ) {
1386 Ok(result_formats) => result_formats,
1387 Err(msg) => {
1388 return self
1389 .error(ErrorResponse::error(SqlState::PROTOCOL_VIOLATION, msg))
1390 .await;
1391 }
1392 };
1393
1394 if !stmt.stmt().map_or(false, |stmt| {
1397 matches!(
1398 stmt,
1399 Statement::Copy(CopyStatement {
1400 direction: CopyDirection::To,
1401 ..
1402 })
1403 )
1404 }) {
1405 if let Some(desc) = stmt.desc().relation_desc.clone() {
1406 for (format, ty) in result_formats.iter().zip_eq(desc.iter_types()) {
1407 match (format, &ty.scalar_type) {
1408 (Format::Binary, mz_repr::SqlScalarType::List { .. }) => {
1409 return self
1410 .error(ErrorResponse::error(
1411 SqlState::PROTOCOL_VIOLATION,
1412 "binary encoding of list types is not implemented",
1413 ))
1414 .await;
1415 }
1416 (Format::Binary, mz_repr::SqlScalarType::Map { .. }) => {
1417 return self
1418 .error(ErrorResponse::error(
1419 SqlState::PROTOCOL_VIOLATION,
1420 "binary encoding of map types is not implemented",
1421 ))
1422 .await;
1423 }
1424 (Format::Binary, mz_repr::SqlScalarType::AclItem) => {
1425 return self
1426 .error(ErrorResponse::error(
1427 SqlState::PROTOCOL_VIOLATION,
1428 "binary encoding of aclitem types does not exist",
1429 ))
1430 .await;
1431 }
1432 _ => (),
1433 }
1434 }
1435 }
1436 }
1437
1438 let desc = stmt.desc().clone();
1439 let logging = Arc::clone(stmt.logging());
1440 let stmt_ast = stmt.stmt().cloned();
1441 let state_revision = stmt.state_revision;
1442 if let Err(err) = self.adapter_client.session().set_portal(
1443 portal_name,
1444 desc,
1445 stmt_ast,
1446 logging,
1447 params,
1448 result_formats,
1449 state_revision,
1450 ) {
1451 return self.error(err.into_response(Severity::Error)).await;
1452 }
1453
1454 self.send(BackendMessage::BindComplete).await?;
1455 Ok(State::Ready)
1456 }
1457
/// Executes, or resumes, the portal named `portal_name` and sends its results
/// to the client.
///
/// Dispatches on the portal's current state:
/// - `NotStarted`: ensures a transaction is active, asks the adapter to
///   execute the portal, then forwards the response via
///   `send_execute_response`.
/// - `InProgress`: resumes streaming the rows stashed by an earlier Execute.
/// - `Completed(Some(tag))`: re-sends the stored command tag without running
///   anything.
/// - `Completed(None)`: reports that the portal cannot be run.
///
/// Handling of `outer_ctx_extra` (the statement-logging guard, if any):
/// - On every path except `NotStarted`, it is retired here with an outcome.
/// - On the `NotStarted` path, it is handed to the adapter's `execute`.
///
/// `received`, when present, seeds the portal's lifecycle timestamps.
///
/// The function returns a boxed future instead of being an `async fn`. The
/// call chain `execute` -> `send_execute_response` -> `fetch` -> `execute` is
/// recursive, and recursive async calls need boxing.
fn execute(
    &mut self,
    portal_name: String,
    max_rows: ExecuteCount,
    get_response: GetResponse,
    fetch_portal_name: Option<String>,
    timeout: ExecuteTimeout,
    outer_ctx_extra: Option<ExecuteContextGuard>,
    received: Option<EpochMillis>,
) -> BoxFuture<'_, Result<State, io::Error>> {
    async move {
        let aborted_txn = self.is_aborted_txn();

        // Look up the portal; an unknown name is a client error and also
        // ends the logged statement as errored.
        let portal = match self
            .adapter_client
            .session()
            .get_portal_unverified_mut(&portal_name)
        {
            Some(portal) => portal,
            None => {
                let msg = format!("portal {} does not exist", portal_name.quoted());
                if let Some(outer_ctx_extra) = outer_ctx_extra {
                    self.adapter_client.retire_execute(
                        outer_ctx_extra,
                        StatementEndedExecutionReason::Errored { error: msg.clone() },
                    );
                }
                return self
                    .error(ErrorResponse::error(SqlState::INVALID_CURSOR_NAME, msg))
                    .await;
            }
        };

        *portal.lifecycle_timestamps = received.map(LifecycleTimestamps::new);

        // In an aborted transaction only transaction-exit statements may run.
        let txn_exit_stmt = is_txn_exit_stmt(portal.stmt.as_deref());
        if aborted_txn && !txn_exit_stmt {
            if let Some(outer_ctx_extra) = outer_ctx_extra {
                self.adapter_client.retire_execute(
                    outer_ctx_extra,
                    StatementEndedExecutionReason::Errored {
                        error: ABORTED_TXN_MSG.to_string(),
                    },
                );
            }
            return self.aborted_txn_error().await;
        }

        let row_desc = portal.desc.relation_desc.clone();
        match portal.state {
            PortalState::NotStarted => {
                // Ensure a transaction is active before executing.
                self.ensure_transaction(1, "execute").await?;
                // The logging guard is passed to the adapter here, so this
                // branch does not retire it itself.
                match self
                    .adapter_client
                    .execute(
                        portal_name.clone(),
                        self.conn.wait_closed(),
                        outer_ctx_extra,
                    )
                    .await
                {
                    Ok((response, execute_started)) => {
                        self.send_pending_notices().await?;
                        self.send_execute_response(
                            response,
                            row_desc,
                            portal_name,
                            max_rows,
                            get_response,
                            fetch_portal_name,
                            timeout,
                            execute_started,
                        )
                        .await
                    }
                    Err(e) => {
                        self.send_pending_notices().await?;
                        self.error(e.into_response(Severity::Error)).await
                    }
                }
            }
            PortalState::InProgress(rows) => {
                // Resume the rows left over from a previous Execute.
                let rows = rows.take().expect("InProgress rows must be populated");
                let (result, statement_ended_execution_reason) = match self
                    .send_rows(
                        row_desc.expect("portal missing row desc on resumption"),
                        portal_name,
                        rows,
                        max_rows,
                        get_response,
                        fetch_portal_name,
                        timeout,
                    )
                    .await
                {
                    // A connection I/O error is logged as a cancellation,
                    // not as a query error.
                    Err(e) => {
                        (Err(e), StatementEndedExecutionReason::Canceled)
                    }
                    Ok((ok, SendRowsEndedReason::Canceled)) => {
                        (Ok(ok), StatementEndedExecutionReason::Canceled)
                    }
                    // Size/row counts from a resumed portal are deliberately
                    // not reported (logged as `None`).
                    // NOTE(review): presumably because they cover only this
                    // slice of the result - confirm.
                    Ok((
                        ok,
                        SendRowsEndedReason::Success {
                            result_size: _,
                            rows_returned: _,
                        },
                    )) => (
                        Ok(ok),
                        StatementEndedExecutionReason::Success {
                            result_size: None,
                            rows_returned: None,
                            execution_strategy: None,
                        },
                    ),
                    Ok((ok, SendRowsEndedReason::Errored { error })) => {
                        (Ok(ok), StatementEndedExecutionReason::Errored { error })
                    }
                };
                if let Some(outer_ctx_extra) = outer_ctx_extra {
                    self.adapter_client
                        .retire_execute(outer_ctx_extra, statement_ended_execution_reason);
                }
                result
            }
            PortalState::Completed(Some(tag)) => {
                // The portal already finished with a command tag; report that
                // tag again without re-executing.
                let tag = tag.to_string();
                if let Some(outer_ctx_extra) = outer_ctx_extra {
                    self.adapter_client.retire_execute(
                        outer_ctx_extra,
                        StatementEndedExecutionReason::Success {
                            result_size: None,
                            rows_returned: None,
                            execution_strategy: None,
                        },
                    );
                }
                self.send(BackendMessage::CommandComplete { tag }).await?;
                Ok(State::Ready)
            }
            PortalState::Completed(None) => {
                // Completed without a tag (e.g. set by `complete_portal`):
                // the portal cannot be run again.
                let error = format!(
                    "portal {} cannot be run",
                    Ident::new_unchecked(portal_name).to_ast_string_stable()
                );
                if let Some(outer_ctx_extra) = outer_ctx_extra {
                    self.adapter_client.retire_execute(
                        outer_ctx_extra,
                        StatementEndedExecutionReason::Errored {
                            error: error.clone(),
                        },
                    );
                }
                self.error(ErrorResponse::error(
                    SqlState::OBJECT_NOT_IN_PREREQUISITE_STATE,
                    error,
                ))
                .await
            }
        }
    }
    .instrument(debug_span!("execute"))
    .boxed()
}
1647
1648 #[instrument(level = "debug")]
1649 async fn describe_statement(&mut self, name: &str) -> Result<State, io::Error> {
1650 self.ensure_transaction(1, "describe_statement").await?;
1652
1653 let stmt = match self.adapter_client.get_prepared_statement(name).await {
1654 Ok(stmt) => stmt,
1655 Err(err) => return self.error(err.into_response(Severity::Error)).await,
1656 };
1657 let parameter_desc = BackendMessage::ParameterDescription(
1659 stmt.desc()
1660 .param_types
1661 .iter()
1662 .map(mz_pgrepr::Type::from)
1663 .collect(),
1664 );
1665 let formats = vec![Format::Text; stmt.desc().arity()];
1669 let row_desc = describe_rows(stmt.desc(), &formats);
1670 self.send_all([parameter_desc, row_desc]).await?;
1671 Ok(State::Ready)
1672 }
1673
1674 #[instrument(level = "debug")]
1675 async fn describe_portal(&mut self, name: &str) -> Result<State, io::Error> {
1676 self.ensure_transaction(1, "describe_portal").await?;
1678
1679 let session = self.adapter_client.session();
1680 let row_desc = session
1681 .get_portal_unverified(name)
1682 .map(|portal| describe_rows(&portal.desc, &portal.result_formats));
1683 match row_desc {
1684 Some(row_desc) => {
1685 self.send(row_desc).await?;
1686 Ok(State::Ready)
1687 }
1688 None => {
1689 self.error(ErrorResponse::error(
1690 SqlState::INVALID_CURSOR_NAME,
1691 format!("portal {} does not exist", name.quoted()),
1692 ))
1693 .await
1694 }
1695 }
1696 }
1697
1698 #[instrument(level = "debug")]
1699 async fn close_statement(&mut self, name: String) -> Result<State, io::Error> {
1700 self.adapter_client
1701 .session()
1702 .remove_prepared_statement(&name);
1703 self.send(BackendMessage::CloseComplete).await?;
1704 Ok(State::Ready)
1705 }
1706
1707 #[instrument(level = "debug")]
1708 async fn close_portal(&mut self, name: String) -> Result<State, io::Error> {
1709 self.adapter_client.session().remove_portal(&name);
1710 self.send(BackendMessage::CloseComplete).await?;
1711 Ok(State::Ready)
1712 }
1713
1714 fn complete_portal(&mut self, name: &str) {
1715 let portal = self
1716 .adapter_client
1717 .session()
1718 .get_portal_unverified_mut(name)
1719 .expect("portal should exist");
1720 *portal.state = PortalState::Completed(None);
1721 }
1722
1723 async fn fetch(
1724 &mut self,
1725 name: String,
1726 count: Option<FetchDirection>,
1727 max_rows: ExecuteCount,
1728 fetch_portal_name: Option<String>,
1729 timeout: ExecuteTimeout,
1730 ctx_extra: ExecuteContextGuard,
1731 ) -> Result<State, io::Error> {
1732 let count = count.unwrap_or(FetchDirection::ForwardCount(1));
1735
1736 let count = match (max_rows, count) {
1749 (ExecuteCount::Count(max_rows), FetchDirection::ForwardCount(count)) => {
1750 let count = usize::cast_from(count);
1751 if max_rows < count {
1752 let msg = "Execute with max_rows < a FETCH's count is not supported";
1753 self.adapter_client.retire_execute(
1754 ctx_extra,
1755 StatementEndedExecutionReason::Errored {
1756 error: msg.to_string(),
1757 },
1758 );
1759 return self
1760 .error(ErrorResponse::error(SqlState::FEATURE_NOT_SUPPORTED, msg))
1761 .await;
1762 }
1763 ExecuteCount::Count(count)
1764 }
1765 (ExecuteCount::Count(_), FetchDirection::ForwardAll) => {
1766 let msg = "Execute with max_rows of a FETCH ALL is not supported";
1767 self.adapter_client.retire_execute(
1768 ctx_extra,
1769 StatementEndedExecutionReason::Errored {
1770 error: msg.to_string(),
1771 },
1772 );
1773 return self
1774 .error(ErrorResponse::error(SqlState::FEATURE_NOT_SUPPORTED, msg))
1775 .await;
1776 }
1777 (ExecuteCount::All, FetchDirection::ForwardAll) => ExecuteCount::All,
1778 (ExecuteCount::All, FetchDirection::ForwardCount(count)) => {
1779 ExecuteCount::Count(usize::cast_from(count))
1780 }
1781 };
1782 let cursor_name = name.to_string();
1783 self.execute(
1784 cursor_name,
1785 count,
1786 fetch_message,
1787 fetch_portal_name,
1788 timeout,
1789 Some(ctx_extra),
1790 None,
1791 )
1792 .await
1793 }
1794
1795 async fn flush(&mut self) -> Result<State, io::Error> {
1796 self.conn.flush().await?;
1797 Ok(State::Ready)
1798 }
1799
1800 #[instrument(level = "debug")]
1805 async fn send<M>(&mut self, message: M) -> Result<(), io::Error>
1806 where
1807 M: Into<BackendMessage>,
1808 {
1809 let message: BackendMessage = message.into();
1810 let is_error =
1811 matches!(&message, BackendMessage::ErrorResponse(e) if e.severity.is_error());
1812
1813 self.conn.send(message).await?;
1814
1815 if is_error {
1822 self.conn.flush().await?;
1823 }
1824
1825 Ok(())
1826 }
1827
1828 #[instrument(level = "debug")]
1829 pub async fn send_all(
1830 &mut self,
1831 messages: impl IntoIterator<Item = BackendMessage>,
1832 ) -> Result<(), io::Error> {
1833 for m in messages {
1834 self.send(m).await?;
1835 }
1836 Ok(())
1837 }
1838
1839 #[instrument(level = "debug")]
1840 async fn sync(&mut self) -> Result<State, io::Error> {
1841 if self.adapter_client.session().transaction().is_implicit() {
1843 self.commit_transaction().await?;
1844 }
1845 self.ready().await
1846 }
1847
1848 #[instrument(level = "debug")]
1849 async fn ready(&mut self) -> Result<State, io::Error> {
1850 let txn_state = self.adapter_client.session().transaction().into();
1851 self.send(BackendMessage::ReadyForQuery(txn_state)).await?;
1852 self.flush().await
1853 }
1854
/// Translates the adapter's `ExecuteResponse` for portal `portal_name` into
/// pgwire messages.
///
/// Row-producing responses are streamed with `send_rows` (or `copy_rows` for
/// `COPY ... TO`). Both are wrapped in a `RecordFirstRowStream` timed from
/// `execute_started`. `max_rows`, `get_response`, `fetch_portal_name` and
/// `timeout` are forwarded to the row senders. Other responses are answered
/// with `CommandComplete` using the response's tag.
///
/// The tag is captured up front and consumed by `command_complete!`. The
/// trailing `assert_none!` checks that every tag-producing branch that falls
/// through used it. Branches that `return` early skip that check.
#[allow(clippy::too_many_arguments)]
#[instrument(level = "debug")]
async fn send_execute_response(
    &mut self,
    response: ExecuteResponse,
    row_desc: Option<RelationDesc>,
    portal_name: String,
    max_rows: ExecuteCount,
    get_response: GetResponse,
    fetch_portal_name: Option<String>,
    timeout: ExecuteTimeout,
    execute_started: Instant,
) -> Result<State, io::Error> {
    let mut tag = response.tag();

    // Sends `CommandComplete` with the captured tag (taking it) and yields
    // `Ok(State::Ready)`. Panics if the response produced no tag.
    macro_rules! command_complete {
        () => {{
            self.send(BackendMessage::CommandComplete {
                tag: tag
                    .take()
                    .expect("command_complete only called on tag-generating results"),
            })
            .await?;
            Ok(State::Ready)
        }};
    }

    let r = match response {
        // Closing or declaring a cursor finishes this portal for good.
        ExecuteResponse::ClosedCursor => {
            self.complete_portal(&portal_name);
            command_complete!()
        }
        ExecuteResponse::DeclaredCursor => {
            self.complete_portal(&portal_name);
            command_complete!()
        }
        ExecuteResponse::EmptyQuery => {
            self.send(BackendMessage::EmptyQueryResponse).await?;
            Ok(State::Ready)
        }
        // FETCH executes the named cursor. The current portal's name is
        // passed as the fetch portal so its result formats are used. The
        // FETCH's own `timeout` shadows this function's parameter.
        ExecuteResponse::Fetch {
            name,
            count,
            timeout,
            ctx_extra,
        } => {
            self.fetch(
                name,
                count,
                max_rows,
                Some(portal_name.to_string()),
                timeout,
                ctx_extra,
            )
            .await
        }
        ExecuteResponse::SendingRowsStreaming {
            rows,
            instance_id,
            strategy,
        } => {
            let row_desc = row_desc
                .expect("missing row description for ExecuteResponse::SendingRowsStreaming");

            let span = tracing::debug_span!("sending_rows_streaming");

            self.send_rows(
                row_desc,
                portal_name,
                InProgressRows::new(RecordFirstRowStream::new(
                    Box::new(rows),
                    execute_started,
                    &self.adapter_client,
                    Some(instance_id),
                    Some(strategy),
                )),
                max_rows,
                get_response,
                fetch_portal_name,
                timeout,
            )
            .instrument(span)
            .await
            .map(|(state, _)| state)
        }
        // Rows that are already materialized are wrapped in a one-item
        // stream so they go through the same `send_rows` path.
        ExecuteResponse::SendingRowsImmediate { rows } => {
            let row_desc = row_desc
                .expect("missing row description for ExecuteResponse::SendingRowsImmediate");

            let span = tracing::debug_span!("sending_rows_immediate");

            let stream =
                futures::stream::once(futures::future::ready(PeekResponseUnary::Rows(rows)));
            self.send_rows(
                row_desc,
                portal_name,
                InProgressRows::new(RecordFirstRowStream::new(
                    Box::new(stream),
                    execute_started,
                    &self.adapter_client,
                    None,
                    Some(StatementExecutionStrategy::Constant),
                )),
                max_rows,
                get_response,
                fetch_portal_name,
                timeout,
            )
            .instrument(span)
            .await
            .map(|(state, _)| state)
        }
        ExecuteResponse::SetVariable { name, .. } => {
            // If the variable is in the session's notify set, report its new
            // value with ParameterStatus. The message is built first so the
            // borrow of the session's variables ends before `self.send`.
            let qn = name.to_string();
            let msg = if let Some(var) = self
                .adapter_client
                .session()
                .vars_mut()
                .notify_set()
                .find(|v| v.name() == qn)
            {
                Some(BackendMessage::ParameterStatus(var.name(), var.value()))
            } else {
                None
            };
            if let Some(msg) = msg {
                self.send(msg).await?;
            }
            command_complete!()
        }
        ExecuteResponse::Subscribing {
            rx,
            ctx_extra,
            instance_id,
        } => {
            // When SUBSCRIBE is not being read through FETCH, warn that the
            // client must not buffer output (with a COPY hint for psql).
            if fetch_portal_name.is_none() {
                let mut msg = ErrorResponse::notice(
                    SqlState::WARNING,
                    "streaming SUBSCRIBE rows directly requires a client that does not buffer output",
                );
                if self.adapter_client.session().vars().application_name() == "psql" {
                    msg.hint = Some(
                        "Wrap your SUBSCRIBE statement in `COPY (SUBSCRIBE ...) TO STDOUT`."
                            .into(),
                    )
                }
                self.send(msg).await?;
                self.conn.flush().await?;
            }
            let row_desc =
                row_desc.expect("missing row description for ExecuteResponse::Subscribing");
            let (result, statement_ended_execution_reason) = match self
                .send_rows(
                    row_desc,
                    portal_name,
                    InProgressRows::new(RecordFirstRowStream::new(
                        Box::new(UnboundedReceiverStream::new(rx)),
                        execute_started,
                        &self.adapter_client,
                        Some(instance_id),
                        None,
                    )),
                    max_rows,
                    get_response,
                    fetch_portal_name,
                    timeout,
                )
                .await
            {
                // A connection I/O error is logged as a cancellation.
                Err(e) => {
                    (Err(e), StatementEndedExecutionReason::Canceled)
                }
                Ok((ok, SendRowsEndedReason::Canceled)) => {
                    (Ok(ok), StatementEndedExecutionReason::Canceled)
                }
                Ok((
                    ok,
                    SendRowsEndedReason::Success {
                        result_size,
                        rows_returned,
                    },
                )) => (
                    Ok(ok),
                    StatementEndedExecutionReason::Success {
                        result_size: Some(result_size),
                        rows_returned: Some(rows_returned),
                        execution_strategy: None,
                    },
                ),
                Ok((ok, SendRowsEndedReason::Errored { error })) => {
                    (Ok(ok), StatementEndedExecutionReason::Errored { error })
                }
            };
            // This branch owns the logging guard and retires it here.
            self.adapter_client
                .retire_execute(ctx_extra, statement_ended_execution_reason);
            return result;
        }
        // `COPY ... TO STDOUT`: stream the wrapped response in COPY format.
        ExecuteResponse::CopyTo { format, resp } => {
            let row_desc =
                row_desc.expect("missing row description for ExecuteResponse::CopyTo");
            match *resp {
                ExecuteResponse::Subscribing {
                    rx,
                    ctx_extra,
                    instance_id,
                } => {
                    let (result, statement_ended_execution_reason) = match self
                        .copy_rows(
                            format,
                            row_desc,
                            RecordFirstRowStream::new(
                                Box::new(UnboundedReceiverStream::new(rx)),
                                execute_started,
                                &self.adapter_client,
                                Some(instance_id),
                                None,
                            ),
                        )
                        .await
                    {
                        // A connection I/O error is logged as a cancellation.
                        Err(e) => {
                            (Err(e), StatementEndedExecutionReason::Canceled)
                        }
                        Ok((
                            state,
                            SendRowsEndedReason::Success {
                                result_size,
                                rows_returned,
                            },
                        )) => (
                            Ok(state),
                            StatementEndedExecutionReason::Success {
                                result_size: Some(result_size),
                                rows_returned: Some(rows_returned),
                                execution_strategy: None,
                            },
                        ),
                        Ok((state, SendRowsEndedReason::Errored { error })) => {
                            (Ok(state), StatementEndedExecutionReason::Errored { error })
                        }
                        Ok((state, SendRowsEndedReason::Canceled)) => {
                            (Ok(state), StatementEndedExecutionReason::Canceled)
                        }
                    };
                    self.adapter_client
                        .retire_execute(ctx_extra, statement_ended_execution_reason);
                    return result;
                }
                // No guard is retired in this arm; it carries no `ctx_extra`.
                // NOTE(review): presumably retirement for peeks happens
                // elsewhere - confirm.
                ExecuteResponse::SendingRowsStreaming {
                    rows,
                    instance_id,
                    strategy,
                } => {
                    return self
                        .copy_rows(
                            format,
                            row_desc,
                            RecordFirstRowStream::new(
                                Box::new(rows),
                                execute_started,
                                &self.adapter_client,
                                Some(instance_id),
                                Some(strategy),
                            ),
                        )
                        .await
                        .map(|(state, _)| state);
                }
                ExecuteResponse::SendingRowsImmediate { rows } => {
                    let span = tracing::debug_span!("sending_rows_immediate");

                    let rows = futures::stream::once(futures::future::ready(
                        PeekResponseUnary::Rows(rows),
                    ));
                    return self
                        .copy_rows(
                            format,
                            row_desc,
                            RecordFirstRowStream::new(
                                Box::new(rows),
                                execute_started,
                                &self.adapter_client,
                                None,
                                Some(StatementExecutionStrategy::Constant),
                            ),
                        )
                        .instrument(span)
                        .await
                        .map(|(state, _)| state);
                }
                _ => {
                    return self
                        .error(ErrorResponse::error(
                            SqlState::INTERNAL_ERROR,
                            "unsupported COPY response type".to_string(),
                        ))
                        .await;
                }
            };
        }
        // `COPY ... FROM STDIN`; `copy_from` retires the logging guard.
        ExecuteResponse::CopyFrom {
            target_id,
            target_name,
            columns,
            params,
            ctx_extra,
        } => {
            let row_desc =
                row_desc.expect("missing row description for ExecuteResponse::CopyFrom");
            self.copy_from(target_id, target_name, columns, params, row_desc, ctx_extra)
                .await
        }
        ExecuteResponse::TransactionCommitted { params }
        | ExecuteResponse::TransactionRolledBack { params } => {
            // Report parameters changed by the transaction, limited to the
            // ones in the session's notify set.
            let notify_set: mz_ore::collections::HashSet<String> = self
                .adapter_client
                .session()
                .vars()
                .notify_set()
                .map(|v| v.name().to_string())
                .collect();

            for (name, value) in params
                .into_iter()
                .filter(|(name, _v)| notify_set.contains(*name))
            {
                let msg = BackendMessage::ParameterStatus(name, value);
                self.send(msg).await?;
            }
            command_complete!()
        }

        // Responses that only need a CommandComplete with their tag.
        ExecuteResponse::AlteredDefaultPrivileges
        | ExecuteResponse::AlteredObject(..)
        | ExecuteResponse::AlteredRole
        | ExecuteResponse::AlteredSystemConfiguration
        | ExecuteResponse::CreatedCluster { .. }
        | ExecuteResponse::CreatedClusterReplica { .. }
        | ExecuteResponse::CreatedConnection { .. }
        | ExecuteResponse::CreatedDatabase { .. }
        | ExecuteResponse::CreatedIndex { .. }
        | ExecuteResponse::CreatedIntrospectionSubscribe
        | ExecuteResponse::CreatedMaterializedView { .. }
        | ExecuteResponse::CreatedContinualTask { .. }
        | ExecuteResponse::CreatedRole
        | ExecuteResponse::CreatedSchema { .. }
        | ExecuteResponse::CreatedSecret { .. }
        | ExecuteResponse::CreatedSink { .. }
        | ExecuteResponse::CreatedSource { .. }
        | ExecuteResponse::CreatedTable { .. }
        | ExecuteResponse::CreatedType
        | ExecuteResponse::CreatedView { .. }
        | ExecuteResponse::CreatedViews { .. }
        | ExecuteResponse::CreatedNetworkPolicy
        | ExecuteResponse::Comment
        | ExecuteResponse::Deallocate { .. }
        | ExecuteResponse::Deleted(..)
        | ExecuteResponse::DiscardedAll
        | ExecuteResponse::DiscardedTemp
        | ExecuteResponse::DroppedObject(_)
        | ExecuteResponse::DroppedOwned
        | ExecuteResponse::GrantedPrivilege
        | ExecuteResponse::GrantedRole
        | ExecuteResponse::Inserted(..)
        | ExecuteResponse::Copied(..)
        | ExecuteResponse::Prepare
        | ExecuteResponse::Raised
        | ExecuteResponse::ReassignOwned
        | ExecuteResponse::RevokedPrivilege
        | ExecuteResponse::RevokedRole
        | ExecuteResponse::StartedTransaction { .. }
        | ExecuteResponse::Updated(..)
        | ExecuteResponse::ValidatedConnection => {
            command_complete!()
        }
    };

    // Every fall-through branch that had a tag must have consumed it.
    assert_none!(tag, "tag created but not consumed: {:?}", tag);
    r
}
2250
/// Streams rows from `rows` to the client as `DataRow` messages for portal
/// `portal_name`.
///
/// Inputs:
/// - `max_rows` limits how many rows this call sends. Rows left over from a
///   partially-sent batch are stashed in `rows.current` for a later resume.
/// - Result formats come from the portal named `fetch_portal_name` if given,
///   otherwise from `portal_name`.
/// - `timeout` controls waiting for more rows:
///   - `None`: no deadline.
///   - `Seconds(t)`: stop at now + t.
///   - `WaitOnce`: wait without a deadline until the first non-empty batch,
///     then stop waiting for more.
///
/// Session notices that arrive while waiting are forwarded. When sending
/// finishes normally:
/// - the remaining rows are stored back on the portal as `InProgress`;
/// - the message built by `get_response` is sent;
/// - if the stream is exhausted, the first-to-last-row latency metric is
///   recorded.
///
/// Returns the next state plus why sending ended. Row-validation errors,
/// stream errors, and cancellation are reported to the client and returned
/// as `Errored`/`Canceled`. A closed connection is returned as `Err`.
#[allow(clippy::too_many_arguments)]
#[mz_ore::instrument(level = "debug")]
async fn send_rows(
    &mut self,
    row_desc: RelationDesc,
    portal_name: String,
    mut rows: InProgressRows,
    max_rows: ExecuteCount,
    get_response: GetResponse,
    fetch_portal_name: Option<String>,
    timeout: ExecuteTimeout,
) -> Result<(State, SendRowsEndedReason), io::Error> {
    // When executing via FETCH, the fetching portal's result formats apply.
    let result_format_portal_name: &str = if let Some(ref name) = fetch_portal_name {
        name
    } else {
        &portal_name
    };
    let result_formats = self
        .adapter_client
        .session()
        .get_portal_unverified(result_format_portal_name)
        .expect("valid fetch portal name for send rows")
        .result_formats
        .clone();

    // `wait_once`: no deadline until rows show up; `deadline`: when to stop
    // waiting for further batches.
    let (mut wait_once, mut deadline) = match timeout {
        ExecuteTimeout::None => (false, None),
        ExecuteTimeout::Seconds(t) => (
            false,
            Some(tokio::time::Instant::now() + tokio::time::Duration::from_secs_f64(t)),
        ),
        ExecuteTimeout::WaitOnce => (true, None),
    };

    // Sanity check (soft assertion only): both portals' relation
    // descriptions, when present, should match `row_desc`.
    {
        let portal_name_desc = &self
            .adapter_client
            .session()
            .get_portal_unverified(portal_name.as_str())
            .expect("portal should exist")
            .desc
            .relation_desc;
        if let Some(portal_name_desc) = portal_name_desc {
            soft_assert_eq_or_log!(portal_name_desc, &row_desc);
        }
        if let Some(fetch_portal_name) = &fetch_portal_name {
            let fetch_portal_desc = &self
                .adapter_client
                .session()
                .get_portal_unverified(fetch_portal_name)
                .expect("portal should exist")
                .desc
                .relation_desc;
            if let Some(fetch_portal_desc) = fetch_portal_desc {
                soft_assert_eq_or_log!(fetch_portal_desc, &row_desc);
            }
        }
    }

    // Configure the connection's encoder with each column's pg type and
    // chosen format. `zip_eq` panics if the counts differ.
    self.conn.set_encode_state(
        row_desc
            .typ()
            .column_types
            .iter()
            .map(|ty| mz_pgrepr::Type::from(&ty.scalar_type))
            .zip_eq(result_formats)
            .collect(),
    );

    let mut total_sent_rows = 0;
    let mut total_sent_bytes = 0;
    let mut want_rows = match max_rows {
        ExecuteCount::All => usize::MAX,
        ExecuteCount::Count(count) => count,
    };

    loop {
        // Prefer a batch stashed by an earlier call; stop if the row budget
        // is spent; otherwise wait for the connection to close, the
        // deadline, a notice, or the next batch.
        let batch = if rows.current.is_some() {
            FetchResult::Rows(rows.current.take())
        } else if want_rows == 0 {
            FetchResult::Rows(None)
        } else {
            let notice_fut = self.adapter_client.session().recv_notice();
            tokio::select! {
                err = self.conn.wait_closed() => return Err(err),
                // Only armed when a deadline is set; reaching it ends sending.
                _ = time::sleep_until(
                    deadline.unwrap_or_else(tokio::time::Instant::now),
                ), if deadline.is_some() => FetchResult::Rows(None),
                notice = notice_fut => {
                    FetchResult::Notice(notice)
                }
                batch = rows.remaining.recv() => match batch {
                    None => FetchResult::Rows(None),
                    Some(PeekResponseUnary::Rows(rows)) => FetchResult::Rows(Some(rows)),
                    Some(PeekResponseUnary::Error(err)) => FetchResult::Error(err),
                    Some(PeekResponseUnary::Canceled) => FetchResult::Canceled,
                },
            }
        };

        match batch {
            FetchResult::Rows(None) => break,
            FetchResult::Rows(Some(mut batch_rows)) => {
                if let Err(err) = verify_datum_desc(&row_desc, &mut batch_rows) {
                    let msg = err.to_string();
                    return self
                        .error(err.into_response(Severity::Error))
                        .await
                        .map(|state| (state, SendRowsEndedReason::Errored { error: msg }));
                }

                // WaitOnce: once a non-empty batch arrives, set the deadline
                // to "now" so later waits do not block for more data.
                if wait_once && batch_rows.peek().is_some() {
                    deadline = Some(tokio::time::Instant::now());
                    wait_once = false;
                }

                // Encode up to `want_rows` rows, counting rows and bytes as
                // they are pulled. `take` stops before pulling the next row,
                // so unsent rows remain in `batch_rows`.
                let mut sent_rows = 0;
                let mut sent_bytes = 0;
                let messages = (&mut batch_rows)
                    .map(|row| {
                        let row_len = row.byte_len();
                        let values = mz_pgrepr::values_from_row(row, row_desc.typ());
                        (row_len, BackendMessage::DataRow(values))
                    })
                    .inspect(|(row_len, _)| {
                        sent_bytes += row_len;
                        sent_rows += 1
                    })
                    .map(|(_row_len, row)| row)
                    .take(want_rows);
                self.send_all(messages).await?;

                total_sent_rows += sent_rows;
                total_sent_bytes += sent_bytes;
                want_rows -= sent_rows;

                if want_rows == 0 {
                    // Budget exhausted: stash any unsent rows for the next
                    // Execute on this portal.
                    if batch_rows.peek().is_some() {
                        rows.current = Some(batch_rows);
                    }
                    break;
                }

                self.conn.flush().await?;
            }
            FetchResult::Notice(notice) => {
                self.send(notice.into_response()).await?;
                self.conn.flush().await?;
            }
            FetchResult::Error(text) => {
                return self
                    .error(ErrorResponse::error(SqlState::INTERNAL_ERROR, text.clone()))
                    .await
                    .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
            }
            FetchResult::Canceled => {
                return self
                    .error(ErrorResponse::error(
                        SqlState::QUERY_CANCELED,
                        "canceling statement due to user request",
                    ))
                    .await
                    .map(|state| (state, SendRowsEndedReason::Canceled));
            }
        }
    }

    let portal = self
        .adapter_client
        .session()
        .get_portal_unverified_mut(&portal_name)
        .expect("valid portal name for send rows");

    // Capture what the metrics need before `rows` is moved into the portal.
    let saw_rows = rows.remaining.saw_rows;
    let no_more_rows = rows.no_more_rows();
    let recorded_first_row_instant = rows.remaining.recorded_first_row_instant;

    // Always store the rows back, even when exhausted, so the portal stays
    // `InProgress` rather than becoming runnable from scratch again.
    *portal.state = PortalState::InProgress(Some(rows));

    let fetch_portal = fetch_portal_name.map(|name| {
        self.adapter_client
            .session()
            .get_portal_unverified_mut(&name)
            .expect("valid fetch portal")
    });
    let response_message = get_response(max_rows, total_sent_rows, fetch_portal);
    self.send(response_message).await?;

    if no_more_rows {
        // Record the first-row-to-last-row latency, labelled by statement
        // type.
        let statement_type = if let Some(stmt) = &self
            .adapter_client
            .session()
            .get_portal_unverified(&portal_name)
            .expect("valid portal name for send_rows")
            .stmt
        {
            metrics::statement_type_label_value(stmt.deref())
        } else {
            "no-statement"
        };
        // An empty result is recorded as zero duration.
        let duration = if saw_rows {
            recorded_first_row_instant
                .expect("recorded_first_row_instant because saw_rows")
                .elapsed()
        } else {
            Duration::ZERO
        };
        self.adapter_client
            .inner()
            .metrics()
            .result_rows_first_to_last_byte_seconds
            .with_label_values(&[statement_type])
            .observe(duration.as_secs_f64());
    }

    Ok((
        State::Ready,
        SendRowsEndedReason::Success {
            result_size: u64::cast_from(total_sent_bytes),
            rows_returned: u64::cast_from(total_sent_rows),
        },
    ))
}
2498
/// Streams the rows from `stream` to the client using the COPY-out
/// sub-protocol in `format`.
///
/// Message sequence:
/// 1. `CopyOutResponse`, with every column using the same wire format.
/// 2. One `CopyData` per encoded row.
/// 3. `CopyDone`.
/// 4. `CommandComplete` with tag `COPY <n>`.
///
/// Binary format:
/// - The file header is buffered so it goes out in the same `CopyData` as
///   the first row.
/// - A trailer is sent at the end.
///
/// Other behavior:
/// - Session notices are forwarded while waiting.
/// - Parquet, stream errors, and cancellation are reported to the client and
///   returned as `Errored`/`Canceled`.
/// - A closed connection is returned as `Err`.
#[mz_ore::instrument(level = "debug")]
async fn copy_rows(
    &mut self,
    format: CopyFormat,
    row_desc: RelationDesc,
    mut stream: RecordFirstRowStream,
) -> Result<(State, SendRowsEndedReason), io::Error> {
    // Per-row encoding parameters, plus the pgwire format advertised in
    // CopyOutResponse (CSV travels as text).
    let (row_format, encode_format) = match format {
        CopyFormat::Text => (
            CopyFormatParams::Text(CopyTextFormatParams::default()),
            Format::Text,
        ),
        CopyFormat::Binary => (CopyFormatParams::Binary, Format::Binary),
        CopyFormat::Csv => (
            CopyFormatParams::Csv(CopyCsvFormatParams::default()),
            Format::Text,
        ),
        CopyFormat::Parquet => {
            let text = "Parquet format is not supported".to_string();
            return self
                .error(ErrorResponse::error(SqlState::INTERNAL_ERROR, text.clone()))
                .await
                .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
        }
    };

    let encode_fn = |row: &RowRef, typ: &SqlRelationType, out: &mut Vec<u8>| {
        mz_pgcopy::encode_copy_format(&row_format, row, typ, out)
    };

    let typ = row_desc.typ();
    let column_formats = iter::repeat(encode_format)
        .take(typ.column_types.len())
        .collect();
    self.send(BackendMessage::CopyOutResponse {
        overall_format: encode_format,
        column_formats,
    })
    .await?;

    // Shared output buffer. Each row is appended and then sent with
    // `mem::take`, so anything already in it (the binary header) goes out
    // together with the first row.
    let mut out = Vec::new();

    if let CopyFormat::Binary = format {
        // PostgreSQL binary COPY header: the 11-byte signature, a 32-bit
        // flags field, and a 32-bit header-extension length (both zero).
        out.extend(b"PGCOPY\n\xFF\r\n\0");
        out.extend([0, 0, 0, 0]);
        out.extend([0, 0, 0, 0]);
    }

    let mut count = 0;
    let mut total_sent_bytes = 0;
    loop {
        tokio::select! {
            e = self.conn.wait_closed() => return Err(e),
            batch = stream.recv() => match batch {
                // Stream exhausted: leave the loop and finish the copy.
                None => break,
                Some(PeekResponseUnary::Error(text)) => {
                    return self
                        .error(ErrorResponse::error(SqlState::INTERNAL_ERROR, text.clone()))
                        .await
                        .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
                }
                Some(PeekResponseUnary::Canceled) => {
                    return self.error(ErrorResponse::error(
                        SqlState::QUERY_CANCELED,
                        "canceling statement due to user request",
                    ))
                    .await.map(|state| (state, SendRowsEndedReason::Canceled));
                }
                Some(PeekResponseUnary::Rows(mut rows)) => {
                    // `count()` is read before the rows are drained below,
                    // so it presumably does not consume the iterator - confirm.
                    count += rows.count();
                    while let Some(row) = rows.next() {
                        total_sent_bytes += row.byte_len();
                        encode_fn(row, typ, &mut out)?;
                        self.send(BackendMessage::CopyData(mem::take(&mut out)))
                            .await?;
                    }
                }
            },
            notice = self.adapter_client.session().recv_notice() => {
                self.send(notice.into_response())
                    .await?;
                self.conn.flush().await?;
            }
        }

        self.conn.flush().await?;
    }
    if let CopyFormat::Binary = format {
        // Binary COPY trailer: a 16-bit -1 in network byte order. `out` also
        // still holds the header if no rows were sent.
        let trailer: i16 = -1;
        out.extend(trailer.to_be_bytes());
        self.send(BackendMessage::CopyData(mem::take(&mut out)))
            .await?;
    }

    let tag = format!("COPY {}", count);
    self.send(BackendMessage::CopyDone).await?;
    self.send(BackendMessage::CommandComplete { tag }).await?;
    Ok((
        State::Ready,
        SendRowsEndedReason::Success {
            result_size: u64::cast_from(total_sent_bytes),
            rows_returned: u64::cast_from(count),
        },
    ))
}
2612
2613 #[instrument(level = "debug")]
2616 async fn copy_from(
2617 &mut self,
2618 target_id: CatalogItemId,
2619 target_name: String,
2620 columns: Vec<ColumnIndex>,
2621 params: CopyFormatParams<'_>,
2622 row_desc: RelationDesc,
2623 mut ctx_extra: ExecuteContextGuard,
2624 ) -> Result<State, io::Error> {
2625 let res = self
2626 .copy_from_inner(
2627 target_id,
2628 target_name,
2629 columns,
2630 params,
2631 row_desc,
2632 &mut ctx_extra,
2633 )
2634 .await;
2635 match &res {
2636 Ok(State::Done) => {
2637 self.adapter_client
2641 .retire_execute(ctx_extra, StatementEndedExecutionReason::Canceled);
2642 }
2643 Err(e) => {
2644 self.adapter_client.retire_execute(
2645 ctx_extra,
2646 StatementEndedExecutionReason::Errored {
2647 error: format!("{e}"),
2648 },
2649 );
2650 }
2651 other => {
2652 tracing::warn!(?other, "aborting COPY FROM");
2653 self.adapter_client
2654 .retire_execute(ctx_extra, StatementEndedExecutionReason::Aborted);
2655 }
2656 }
2657 res
2658 }
2659
2660 async fn copy_from_inner(
2661 &mut self,
2662 target_id: CatalogItemId,
2663 target_name: String,
2664 columns: Vec<ColumnIndex>,
2665 params: CopyFormatParams<'_>,
2666 row_desc: RelationDesc,
2667 ctx_extra: &mut ExecuteContextGuard,
2668 ) -> Result<State, io::Error> {
2669 let typ = row_desc.typ();
2670 let column_formats = vec![Format::Text; typ.column_types.len()];
2671 self.send(BackendMessage::CopyInResponse {
2672 overall_format: Format::Text,
2673 column_formats,
2674 })
2675 .await?;
2676 self.conn.flush().await?;
2677
2678 let system_vars = self.adapter_client.get_system_vars().await;
2679 let max_size = system_vars
2680 .get(MAX_COPY_FROM_SIZE.name())
2681 .ok()
2682 .and_then(|max_size| max_size.value().parse().ok())
2683 .unwrap_or(usize::MAX);
2684 tracing::debug!("COPY FROM max buffer size: {max_size} bytes");
2685
2686 let mut data = Vec::new();
2687 loop {
2688 let message = self.conn.recv().await?;
2689 match message {
2690 Some(FrontendMessage::CopyData(buf)) => {
2691 if (data.len() + buf.len()) > max_size {
2693 return self
2694 .error(ErrorResponse::error(
2695 SqlState::INSUFFICIENT_RESOURCES,
2696 "COPY FROM STDIN too large",
2697 ))
2698 .await;
2699 }
2700 data.extend(buf)
2701 }
2702 Some(FrontendMessage::CopyDone) => break,
2703 Some(FrontendMessage::CopyFail(err)) => {
2704 self.adapter_client.retire_execute(
2705 std::mem::take(ctx_extra),
2706 StatementEndedExecutionReason::Canceled,
2707 );
2708 return self
2709 .error(ErrorResponse::error(
2710 SqlState::QUERY_CANCELED,
2711 format!("COPY from stdin failed: {}", err),
2712 ))
2713 .await;
2714 }
2715 Some(FrontendMessage::Flush) | Some(FrontendMessage::Sync) => {}
2716 Some(_) => {
2717 let msg = "unexpected message type during COPY from stdin";
2718 self.adapter_client.retire_execute(
2719 std::mem::take(ctx_extra),
2720 StatementEndedExecutionReason::Errored {
2721 error: msg.to_string(),
2722 },
2723 );
2724 return self
2725 .error(ErrorResponse::error(SqlState::PROTOCOL_VIOLATION, msg))
2726 .await;
2727 }
2728 None => {
2729 return Ok(State::Done);
2730 }
2731 }
2732 }
2733
2734 let column_types = typ
2735 .column_types
2736 .iter()
2737 .map(|x| &x.scalar_type)
2738 .map(mz_pgrepr::Type::from)
2739 .collect::<Vec<mz_pgrepr::Type>>();
2740
2741 let rows = match mz_pgcopy::decode_copy_format(&data, &column_types, params) {
2742 Ok(rows) => rows,
2743 Err(e) => {
2744 self.adapter_client.retire_execute(
2745 std::mem::take(ctx_extra),
2746 StatementEndedExecutionReason::Errored {
2747 error: e.to_string(),
2748 },
2749 );
2750 return self
2751 .error(ErrorResponse::error(
2752 SqlState::BAD_COPY_FILE_FORMAT,
2753 format!("{}", e),
2754 ))
2755 .await;
2756 }
2757 };
2758
2759 let count = rows.len();
2760
2761 if let Err(e) = self
2762 .adapter_client
2763 .insert_rows(
2764 target_id,
2765 target_name,
2766 columns,
2767 rows,
2768 std::mem::take(ctx_extra),
2769 )
2770 .await
2771 {
2772 self.adapter_client.retire_execute(
2773 std::mem::take(ctx_extra),
2774 StatementEndedExecutionReason::Errored {
2775 error: e.to_string(),
2776 },
2777 );
2778 return self.error(e.into_response(Severity::Error)).await;
2779 }
2780
2781 let tag = format!("COPY {}", count);
2782 self.send(BackendMessage::CommandComplete { tag }).await?;
2783
2784 Ok(State::Ready)
2785 }
2786
2787 #[instrument(level = "debug")]
2788 async fn send_pending_notices(&mut self) -> Result<(), io::Error> {
2789 let notices = self
2790 .adapter_client
2791 .session()
2792 .drain_notices()
2793 .into_iter()
2794 .map(|notice| BackendMessage::ErrorResponse(notice.into_response()));
2795 self.send_all(notices).await?;
2796 Ok(())
2797 }
2798
2799 #[instrument(level = "debug")]
2800 async fn error(&mut self, err: ErrorResponse) -> Result<State, io::Error> {
2801 assert!(err.severity.is_error());
2802 debug!(
2803 "cid={} error code={}",
2804 self.adapter_client.session().conn_id(),
2805 err.code.code()
2806 );
2807 let is_fatal = err.severity.is_fatal();
2808 self.send(BackendMessage::ErrorResponse(err)).await?;
2809
2810 let txn = self.adapter_client.session().transaction();
2811 match txn {
2812 TransactionStatus::Default | TransactionStatus::Failed(_) => {}
2815 TransactionStatus::Started(_) => {
2817 self.rollback_transaction().await?;
2818 }
2819 TransactionStatus::InTransactionImplicit(_) => {
2821 self.rollback_transaction().await?;
2822 }
2823 TransactionStatus::InTransaction(_) => {
2825 self.adapter_client.fail_transaction();
2826 }
2827 };
2828 if is_fatal {
2829 Ok(State::Done)
2830 } else {
2831 Ok(State::Drain)
2832 }
2833 }
2834
2835 #[instrument(level = "debug")]
2836 async fn aborted_txn_error(&mut self) -> Result<State, io::Error> {
2837 self.send(BackendMessage::ErrorResponse(ErrorResponse::error(
2838 SqlState::IN_FAILED_SQL_TRANSACTION,
2839 ABORTED_TXN_MSG,
2840 )))
2841 .await?;
2842 Ok(State::Drain)
2843 }
2844
2845 fn is_aborted_txn(&mut self) -> bool {
2846 matches!(
2847 self.adapter_client.session().transaction(),
2848 TransactionStatus::Failed(_)
2849 )
2850 }
2851}
2852
2853fn pad_formats(formats: Vec<Format>, n: usize) -> Result<Vec<Format>, String> {
2854 match (formats.len(), n) {
2855 (0, e) => Ok(vec![Format::Text; e]),
2856 (1, e) => Ok(iter::repeat(formats[0]).take(e).collect()),
2857 (a, e) if a == e => Ok(formats),
2858 (a, e) => Err(format!(
2859 "expected {} field format specifiers, but got {}",
2860 e, a
2861 )),
2862 }
2863}
2864
2865fn describe_rows(stmt_desc: &StatementDesc, formats: &[Format]) -> BackendMessage {
2866 match &stmt_desc.relation_desc {
2867 Some(desc) if !stmt_desc.is_copy => {
2868 BackendMessage::RowDescription(message::encode_row_description(desc, formats))
2869 }
2870 _ => BackendMessage::NoData,
2871 }
2872}
2873
2874type GetResponse = fn(
2875 max_rows: ExecuteCount,
2876 total_sent_rows: usize,
2877 fetch_portal: Option<PortalRefMut>,
2878) -> BackendMessage;
2879
2880fn portal_exec_message(
2883 max_rows: ExecuteCount,
2884 total_sent_rows: usize,
2885 _fetch_portal: Option<PortalRefMut>,
2886) -> BackendMessage {
2887 match max_rows {
2894 ExecuteCount::Count(max_rows) if max_rows <= total_sent_rows => {
2895 BackendMessage::PortalSuspended
2896 }
2897 _ => BackendMessage::CommandComplete {
2898 tag: format!("SELECT {}", total_sent_rows),
2899 },
2900 }
2901}
2902
2903fn fetch_message(
2905 _max_rows: ExecuteCount,
2906 total_sent_rows: usize,
2907 fetch_portal: Option<PortalRefMut>,
2908) -> BackendMessage {
2909 let tag = format!("FETCH {}", total_sent_rows);
2910 if let Some(portal) = fetch_portal {
2911 *portal.state = PortalState::Completed(Some(tag.clone()));
2912 }
2913 BackendMessage::CommandComplete { tag }
2914}
2915
2916fn get_authenticator(
2917 authenticator_kind: AuthenticatorKind,
2918 frontegg: Option<FronteggAuthenticator>,
2919 oidc: GenericOidcAuthenticator,
2920 adapter_client: mz_adapter::Client,
2921 oidc_auth_option_enabled: bool,
2924) -> Authenticator {
2925 match authenticator_kind {
2926 AuthenticatorKind::Frontegg => Authenticator::Frontegg(
2927 frontegg.expect("Frontegg authenticator should exist with AuthenticatorKind::Frontegg"),
2928 ),
2929 AuthenticatorKind::Password => Authenticator::Password(adapter_client),
2930 AuthenticatorKind::Sasl => Authenticator::Sasl(adapter_client),
2931 AuthenticatorKind::Oidc => {
2932 if oidc_auth_option_enabled {
2933 Authenticator::Oidc(oidc)
2934 } else {
2935 Authenticator::Password(adapter_client)
2938 }
2939 }
2940 AuthenticatorKind::None => Authenticator::None,
2941 }
2942}
2943
2944#[derive(Debug, Copy, Clone)]
2945enum ExecuteCount {
2946 All,
2947 Count(usize),
2948}
2949
2950fn is_txn_exit_stmt(stmt: Option<&Statement<Raw>>) -> bool {
2952 match stmt {
2953 Some(stmt) => matches!(stmt, Statement::Commit(_) | Statement::Rollback(_)),
2955 None => false,
2956 }
2957}
2958
2959#[derive(Debug)]
2960enum FetchResult {
2961 Rows(Option<Box<dyn RowIterator + Send + Sync>>),
2962 Canceled,
2963 Error(String),
2964 Notice(AdapterNotice),
2965}
2966
2967#[cfg(test)]
2968mod test {
2969 use super::*;
2970
2971 #[mz_ore::test]
2972 fn test_parse_options() {
2973 struct TestCase {
2974 input: &'static str,
2975 expect: Result<Vec<(&'static str, &'static str)>, ()>,
2976 }
2977 let tests = vec![
2978 TestCase {
2979 input: "",
2980 expect: Ok(vec![]),
2981 },
2982 TestCase {
2983 input: "--key",
2984 expect: Err(()),
2985 },
2986 TestCase {
2987 input: "--key=val",
2988 expect: Ok(vec![("key", "val")]),
2989 },
2990 TestCase {
2991 input: r#"--key=val -ckey2=val2 -c key3=val3 -c key4=val4 -ckey5=val5"#,
2992 expect: Ok(vec![
2993 ("key", "val"),
2994 ("key2", "val2"),
2995 ("key3", "val3"),
2996 ("key4", "val4"),
2997 ("key5", "val5"),
2998 ]),
2999 },
3000 TestCase {
3001 input: r#"-c\ key=val"#,
3002 expect: Ok(vec![(" key", "val")]),
3003 },
3004 TestCase {
3005 input: "--key=val -ckey2 val2",
3006 expect: Err(()),
3007 },
3008 TestCase {
3010 input: "--key=",
3011 expect: Ok(vec![("key", "")]),
3012 },
3013 ];
3014 for test in tests {
3015 let got = parse_options(test.input);
3016 let expect = test.expect.map(|r| {
3017 r.into_iter()
3018 .map(|(k, v)| (k.to_owned(), v.to_owned()))
3019 .collect()
3020 });
3021 assert_eq!(got, expect, "input: {}", test.input);
3022 }
3023 }
3024
3025 #[mz_ore::test]
3026 fn test_parse_option() {
3027 struct TestCase {
3028 input: &'static str,
3029 expect: Result<(&'static str, &'static str), ()>,
3030 }
3031 let tests = vec![
3032 TestCase {
3033 input: "",
3034 expect: Err(()),
3035 },
3036 TestCase {
3037 input: "--",
3038 expect: Err(()),
3039 },
3040 TestCase {
3041 input: "--c",
3042 expect: Err(()),
3043 },
3044 TestCase {
3045 input: "a=b",
3046 expect: Err(()),
3047 },
3048 TestCase {
3049 input: "--a=b",
3050 expect: Ok(("a", "b")),
3051 },
3052 TestCase {
3053 input: "--ca=b",
3054 expect: Ok(("ca", "b")),
3055 },
3056 TestCase {
3057 input: "-ca=b",
3058 expect: Ok(("a", "b")),
3059 },
3060 TestCase {
3062 input: "--=",
3063 expect: Ok(("", "")),
3064 },
3065 ];
3066 for test in tests {
3067 let got = parse_option(test.input);
3068 assert_eq!(got, test.expect, "input: {}", test.input);
3069 }
3070 }
3071
3072 #[mz_ore::test]
3073 fn test_split_options() {
3074 struct TestCase {
3075 input: &'static str,
3076 expect: Vec<&'static str>,
3077 }
3078 let tests = vec![
3079 TestCase {
3080 input: "",
3081 expect: vec![],
3082 },
3083 TestCase {
3084 input: " ",
3085 expect: vec![],
3086 },
3087 TestCase {
3088 input: " a ",
3089 expect: vec!["a"],
3090 },
3091 TestCase {
3092 input: " ab cd ",
3093 expect: vec!["ab", "cd"],
3094 },
3095 TestCase {
3096 input: r#" ab\ cd "#,
3097 expect: vec!["ab ", "cd"],
3098 },
3099 TestCase {
3100 input: r#" ab\\ cd "#,
3101 expect: vec![r#"ab\"#, "cd"],
3102 },
3103 TestCase {
3104 input: r#" ab\\\ cd "#,
3105 expect: vec![r#"ab\ "#, "cd"],
3106 },
3107 TestCase {
3108 input: r#" ab\\\ cd "#,
3109 expect: vec![r#"ab\ cd"#],
3110 },
3111 TestCase {
3112 input: r#" ab\\\cd "#,
3113 expect: vec![r#"ab\cd"#],
3114 },
3115 TestCase {
3116 input: r#"a\"#,
3117 expect: vec!["a"],
3118 },
3119 TestCase {
3120 input: r#"a\ "#,
3121 expect: vec!["a "],
3122 },
3123 TestCase {
3124 input: r#"\"#,
3125 expect: vec![],
3126 },
3127 TestCase {
3128 input: r#"\ "#,
3129 expect: vec![r#" "#],
3130 },
3131 TestCase {
3132 input: r#" \ "#,
3133 expect: vec![r#" "#],
3134 },
3135 TestCase {
3136 input: r#"\ "#,
3137 expect: vec![r#" "#],
3138 },
3139 ];
3140 for test in tests {
3141 let got = split_options(test.input);
3142 assert_eq!(got, test.expect, "input: {}", test.input);
3143 }
3144 }
3145}