1use std::collections::BTreeMap;
11use std::convert::TryFrom;
12use std::future::Future;
13use std::ops::Deref;
14use std::sync::Arc;
15use std::time::{Duration, Instant};
16use std::{iter, mem};
17
18use base64::prelude::*;
19use byteorder::{ByteOrder, NetworkEndian};
20use csv_core::ReadRecordResult;
21use futures::future::{BoxFuture, FutureExt, pending};
22use itertools::Itertools;
23use mz_adapter::client::RecordFirstRowStream;
24use mz_adapter::session::{
25 EndTransactionAction, InProgressRows, LifecycleTimestamps, PortalRefMut, PortalState, Session,
26 SessionConfig, TransactionStatus,
27};
28use mz_adapter::statement_logging::{StatementEndedExecutionReason, StatementExecutionStrategy};
29use mz_adapter::{
30 AdapterError, AdapterNotice, ExecuteContextGuard, ExecuteResponse, PeekResponseUnary, metrics,
31 verify_datum_desc,
32};
33use mz_adapter_types::dyncfgs::OIDC_GROUP_CLAIM;
34use mz_auth::Authenticated;
35use mz_auth::password::Password;
36use mz_authenticator::{Authenticator, GenericOidcAuthenticator};
37use mz_frontegg_auth::Authenticator as FronteggAuthenticator;
38use mz_ore::cast::CastFrom;
39use mz_ore::netio::AsyncReady;
40use mz_ore::now::{EpochMillis, SYSTEM_TIME};
41use mz_ore::str::StrExt;
42use mz_ore::{assert_none, assert_ok, instrument, soft_assert_eq_or_log, soft_assert_or_log};
43use mz_pgcopy::{CopyCsvFormatParams, CopyFormatParams, CopyTextFormatParams};
44use mz_pgwire_common::{
45 ConnectionCounter, Cursor, ErrorResponse, Format, FrontendMessage, Severity, VERSION_3,
46 VERSIONS,
47};
48use mz_repr::{
49 CatalogItemId, ColumnIndex, Datum, RelationDesc, RowArena, RowIterator, RowRef,
50 SqlRelationType, SqlScalarType,
51};
52use mz_server_core::TlsMode;
53use mz_server_core::listeners;
54use mz_server_core::listeners::AllowedRoles;
55use mz_sql::ast::display::AstDisplay;
56use mz_sql::ast::{
57 CopyDirection, CopyStatement, CopyTarget, FetchDirection, Ident, Raw, Statement,
58};
59use mz_sql::parse::StatementParseResult;
60use mz_sql::plan::{CopyFormat, ExecuteTimeout, StatementDesc};
61use mz_sql::session::metadata::SessionMetadata;
62use mz_sql::session::user::INTERNAL_USER_NAMES;
63use mz_sql::session::vars::VarInput;
64use postgres::error::SqlState;
65use tokio::io::{self, AsyncRead, AsyncWrite};
66use tokio::select;
67use tokio::time::{self};
68use tokio_metrics::TaskMetrics;
69use tokio_stream::wrappers::UnboundedReceiverStream;
70use tracing::{Instrument, debug, debug_span, warn};
71use uuid::Uuid;
72
73use crate::codec::{
74 FramedConn, decode_password, decode_sasl_initial_response, decode_sasl_response,
75};
76use crate::message::{
77 self, BackendMessage, SASLServerFinalMessage, SASLServerFinalMessageKinds,
78 SASLServerFirstMessage,
79};
80
81pub fn match_handshake(buf: &[u8]) -> bool {
85 if buf.len() < 8 {
95 return false;
96 }
97 let version = NetworkEndian::read_i32(&buf[4..8]);
98 VERSIONS.contains(&version)
99}
100
101pub struct RunParams<'a, A, I>
103where
104 I: Iterator<Item = TaskMetrics> + Send,
105{
106 pub tls_mode: Option<TlsMode>,
108 pub adapter_client: mz_adapter::Client,
110 pub conn: &'a mut FramedConn<A>,
112 pub conn_uuid: Uuid,
114 pub version: i32,
116 pub params: BTreeMap<String, String>,
118 pub frontegg: Option<FronteggAuthenticator>,
120 pub oidc: GenericOidcAuthenticator,
122 pub authenticator_kind: listeners::AuthenticatorKind,
125 pub active_connection_counter: ConnectionCounter,
127 pub helm_chart_version: Option<String>,
129 pub allowed_roles: AllowedRoles,
131 pub tokio_metrics_intervals: I,
133}
134
/// Serves one pgwire connection from the startup message through termination.
///
/// In order, this:
/// 1. rejects any protocol version other than `VERSION_3`;
/// 2. rejects users that `allowed_roles` does not permit on this listener;
/// 3. checks the connection against `tls_mode`;
/// 4. authenticates the user with the configured authenticator (Frontegg,
///    OIDC, cleartext password, SASL SCRAM-SHA-256, or none) and creates a
///    session;
/// 5. applies the remaining startup parameters (with `options` expanded into
///    its `key=value` pairs) as session variables, turning each failure into
///    a `BadStartupSetting` notice;
/// 6. reserves a slot in `active_connection_counter` and starts the session
///    in the adapter;
/// 7. sends `AuthenticationOk`, one `ParameterStatus` per variable in the
///    notify set, `BackendKeyData`, any pending notices, and `ReadyForQuery`;
/// 8. runs the [`StateMachine`] concurrently with the `expired` future. For
///    Frontegg sessions that future completes when the auth session expires;
///    for all other authenticators it is `pending()` and never completes.
///
/// Most failures are reported to the client as a fatal `ErrorResponse`, and
/// the function then returns the result of sending it.
#[mz_ore::instrument(level = "debug")]
pub async fn run<'a, A, I>(
    RunParams {
        tls_mode,
        adapter_client,
        conn,
        conn_uuid,
        version,
        mut params,
        frontegg,
        oidc,
        authenticator_kind,
        active_connection_counter,
        helm_chart_version,
        allowed_roles,
        tokio_metrics_intervals,
    }: RunParams<'a, A, I>,
) -> Result<(), io::Error>
where
    A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
    I: Iterator<Item = TaskMetrics> + Send,
{
    // Only protocol version 3 is handled past this point.
    if version != VERSION_3 {
        return conn
            .send(ErrorResponse::fatal(
                SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
                "server does not support the client's requested protocol version",
            ))
            .await;
    }

    // `user` is removed from `params` so the loop below does not try to apply
    // it as a session variable. `options` stays in `params`; the loop replaces
    // it with the pairs parsed here.
    let user = params.remove("user").unwrap_or_else(String::new);
    let options = parse_options(params.get("options").unwrap_or(&String::new()));
    let authenticator =
        get_authenticator(authenticator_kind, frontegg, oidc, adapter_client.clone());
    let is_internal_user = INTERNAL_USER_NAMES.contains(&user);
    let is_reserved_user = mz_adapter::catalog::is_reserved_role_name(user.as_str());
    // Which users may log in depends on the kind of listener.
    let role_allowed = match allowed_roles {
        AllowedRoles::Normal => !is_reserved_user,
        AllowedRoles::Internal => is_internal_user,
        AllowedRoles::NormalAndInternal => !is_reserved_user || is_internal_user,
    };
    if !role_allowed {
        let msg = format!("unauthorized login to user '{user}'");
        return conn
            .send(ErrorResponse::fatal(SqlState::INSUFFICIENT_PRIVILEGE, msg))
            .await;
    }

    if let Err(err) = conn.inner().ensure_tls_compatibility(&tls_mode) {
        return conn.send(err).await;
    }

    let authenticator_kind = authenticator.kind();

    // Each arm produces the new session plus an `expired` future that the
    // final `select!` races against the state machine.
    let (mut session, expired) = match authenticator {
        Authenticator::Frontegg(frontegg) => {
            let password = match request_cleartext_password(conn).await {
                Ok(password) => password,
                Err(PasswordRequestError::IoError(e)) => return Err(e),
                Err(PasswordRequestError::InvalidPasswordError(e)) => {
                    return conn.send(e).await;
                }
            };

            // The name of the group claim is read from the OIDC_GROUP_CLAIM dyncfg.
            let group_claim =
                OIDC_GROUP_CLAIM.get(adapter_client.get_system_vars().await.dyncfgs());
            let auth_response = frontegg
                .authenticate(&user, &password, Some(&group_claim))
                .await;
            match auth_response {
                Ok((mut auth_session, authenticated)) => {
                    let groups = auth_session.groups();
                    let session = adapter_client.new_session(
                        SessionConfig {
                            conn_id: conn.conn_id().clone(),
                            uuid: conn_uuid,
                            user: auth_session.user().into(),
                            client_ip: conn.peer_addr().clone(),
                            external_metadata_rx: Some(auth_session.external_metadata_rx()),
                            helm_chart_version,
                            authenticator_kind,
                            groups,
                        },
                        authenticated,
                    );
                    // Completes when the Frontegg auth session expires.
                    let expired = async move { auth_session.expired().await };
                    (session, expired.left_future())
                }
                Err(err) => {
                    warn!(?err, "pgwire connection failed authentication");
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_PASSWORD,
                            "invalid password",
                        ))
                        .await;
                }
            }
        }
        Authenticator::Oidc(oidc) => {
            let password = match request_cleartext_password(conn).await {
                Ok(password) => password,
                Err(PasswordRequestError::IoError(e)) => return Err(e),
                Err(PasswordRequestError::InvalidPasswordError(e)) => {
                    return conn.send(e).await;
                }
            };
            // A "password" that parses as a JWT is validated by the OIDC
            // authenticator. Anything else falls back to ordinary password
            // authentication through the adapter.
            if is_jwt(&password) {
                let auth_response = oidc.authenticate(&password, Some(&user)).await;
                match auth_response {
                    Ok((mut claims, authenticated)) => {
                        let groups = claims.groups.take();
                        let session = adapter_client.new_session(
                            SessionConfig {
                                conn_id: conn.conn_id().clone(),
                                uuid: conn_uuid,
                                user: std::mem::take(&mut claims.user),
                                client_ip: conn.peer_addr().clone(),
                                external_metadata_rx: None,
                                helm_chart_version,
                                authenticator_kind,
                                groups,
                            },
                            authenticated,
                        );
                        (session, pending().right_future())
                    }
                    Err(err) => {
                        warn!(?err, "pgwire connection failed authentication");
                        return conn.send(err.into_response()).await;
                    }
                }
            } else {
                let session = match authenticate_with_password(
                    conn,
                    &adapter_client,
                    user,
                    Password(password),
                    conn_uuid,
                    helm_chart_version,
                )
                .await
                {
                    Ok(session) => session,
                    Err(PasswordRequestError::IoError(e)) => return Err(e),
                    Err(PasswordRequestError::InvalidPasswordError(e)) => {
                        return conn.send(e).await;
                    }
                };
                (session, pending().right_future())
            }
        }
        Authenticator::Password(adapter_client) => {
            let password = match request_cleartext_password(conn).await {
                Ok(password) => password,
                Err(PasswordRequestError::IoError(e)) => return Err(e),
                Err(PasswordRequestError::InvalidPasswordError(e)) => {
                    return conn.send(e).await;
                }
            };
            let session = match authenticate_with_password(
                conn,
                &adapter_client,
                user,
                Password(password),
                conn_uuid,
                helm_chart_version,
            )
            .await
            {
                Ok(session) => session,
                Err(PasswordRequestError::IoError(e)) => return Err(e),
                Err(PasswordRequestError::InvalidPasswordError(e)) => {
                    return conn.send(e).await;
                }
            };
            (session, pending().right_future())
        }
        Authenticator::Sasl(adapter_client) => {
            // SCRAM exchange: AuthenticationSASL -> SASLInitialResponse ->
            // AuthenticationSASLContinue -> SASLResponse -> AuthenticationSASLFinal.
            conn.send(BackendMessage::AuthenticationSASL).await?;
            conn.flush().await?;
            let (mechanism, initial_response) = match conn.recv().await? {
                Some(FrontendMessage::RawAuthentication(data)) => {
                    match decode_sasl_initial_response(Cursor::new(&data)).ok() {
                        Some(FrontendMessage::SASLInitialResponse {
                            gs2_header,
                            mechanism,
                            initial_response,
                        }) => {
                            if gs2_header.channel_binding_enabled() {
                                return conn
                                    .send(ErrorResponse::fatal(
                                        SqlState::PROTOCOL_VIOLATION,
                                        "channel binding not supported",
                                    ))
                                    .await;
                            }
                            (mechanism, initial_response)
                        }
                        _ => {
                            return conn
                                .send(ErrorResponse::fatal(
                                    SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                                    "expected SASLInitialResponse message",
                                ))
                                .await;
                        }
                    }
                }
                _ => {
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                            "expected SASLInitialResponse message",
                        ))
                        .await;
                }
            };

            if mechanism != "SCRAM-SHA-256" {
                return conn
                    .send(ErrorResponse::fatal(
                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                        "unsupported SASL mechanism",
                    ))
                    .await;
            }

            // Bound the size of the client-supplied nonce.
            if initial_response.nonce.len() > 256 {
                return conn
                    .send(ErrorResponse::fatal(
                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                        "nonce too long",
                    ))
                    .await;
            }

            let (server_first_message_raw, mock_hash) = match adapter_client
                .generate_sasl_challenge(&user, &initial_response.nonce)
                .await
            {
                Ok(response) => {
                    // Raw server-first-message text; it is part of the SCRAM
                    // AuthMessage assembled below.
                    let server_first_message_raw = format!(
                        "r={},s={},i={}",
                        response.nonce, response.salt, response.iteration_count
                    );

                    // A well-formed but fake SCRAM hash built from the
                    // challenge's salt and iteration count; it is passed to
                    // `verify_sasl_proof` below.
                    // NOTE(review): presumably used when the role has no real
                    // password hash, so verification behaves uniformly — confirm.
                    let client_key = [0u8; 32];
                    let server_key = [1u8; 32];
                    let mock_hash = format!(
                        "SCRAM-SHA-256${}:{}${}:{}",
                        response.iteration_count,
                        response.salt,
                        BASE64_STANDARD.encode(client_key),
                        BASE64_STANDARD.encode(server_key)
                    );

                    conn.send(BackendMessage::AuthenticationSASLContinue(
                        SASLServerFirstMessage {
                            iteration_count: response.iteration_count,
                            nonce: response.nonce,
                            salt: response.salt,
                        },
                    ))
                    .await?;
                    conn.flush().await?;
                    (server_first_message_raw, mock_hash)
                }
                Err(e) => {
                    return conn.send(e.into_response(Severity::Fatal)).await;
                }
            };

            let authenticated = match conn.recv().await? {
                Some(FrontendMessage::RawAuthentication(data)) => {
                    match decode_sasl_response(Cursor::new(&data)).ok() {
                        Some(FrontendMessage::SASLResponse(response)) => {
                            // SCRAM AuthMessage: client-first-bare,
                            // server-first, client-final-without-proof.
                            let auth_message = format!(
                                "{},{},{}",
                                initial_response.client_first_message_bare_raw,
                                server_first_message_raw,
                                response.client_final_message_bare_raw
                            );
                            // Bound the size of the client-supplied proof.
                            if response.proof.len() > 1024 {
                                return conn
                                    .send(ErrorResponse::fatal(
                                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                                        "proof too long",
                                    ))
                                    .await;
                            }
                            match adapter_client
                                .verify_sasl_proof(
                                    &user,
                                    &response.proof,
                                    &auth_message,
                                    &mock_hash,
                                )
                                .await
                            {
                                Ok((proof_response, authenticated)) => {
                                    conn.send(BackendMessage::AuthenticationSASLFinal(
                                        SASLServerFinalMessage {
                                            kind: SASLServerFinalMessageKinds::Verifier(
                                                proof_response.verifier,
                                            ),
                                            extensions: vec![],
                                        },
                                    ))
                                    .await?;
                                    conn.flush().await?;
                                    authenticated
                                }
                                Err(_) => {
                                    return conn
                                        .send(ErrorResponse::fatal(
                                            SqlState::INVALID_PASSWORD,
                                            "invalid password",
                                        ))
                                        .await;
                                }
                            }
                        }
                        _ => {
                            return conn
                                .send(ErrorResponse::fatal(
                                    SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                                    "expected SASLResponse message",
                                ))
                                .await;
                        }
                    }
                }
                _ => {
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                            "expected SASLResponse message",
                        ))
                        .await;
                }
            };

            let session = adapter_client.new_session(
                SessionConfig {
                    conn_id: conn.conn_id().clone(),
                    uuid: conn_uuid,
                    user,
                    client_ip: conn.peer_addr().clone(),
                    external_metadata_rx: None,
                    helm_chart_version,
                    authenticator_kind,
                    groups: None,
                },
                authenticated,
            );
            // SASL sessions never expire on their own.
            let auth_session = pending().right_future();
            (session, auth_session)
        }

        Authenticator::None => {
            // No authentication is configured: the session is created directly.
            let session = adapter_client.new_session(
                SessionConfig {
                    conn_id: conn.conn_id().clone(),
                    uuid: conn_uuid,
                    user,
                    client_ip: conn.peer_addr().clone(),
                    external_metadata_rx: None,
                    helm_chart_version,
                    authenticator_kind,
                    groups: None,
                },
                Authenticated,
            );
            let auth_session = pending().right_future();
            (session, auth_session)
        }
    };

    // Apply each startup parameter as a session variable. The `options`
    // parameter is replaced by the pairs parsed from it. Failures are reported
    // as notices rather than aborting the connection.
    let system_vars = adapter_client.get_system_vars().await;
    for (name, value) in params {
        let settings = match name.as_str() {
            "options" => match &options {
                Ok(opts) => opts,
                Err(()) => {
                    session.add_notice(AdapterNotice::BadStartupSetting {
                        name,
                        reason: "could not parse".into(),
                    });
                    continue;
                }
            },
            _ => &vec![(name, value)],
        };
        for (key, val) in settings {
            const LOCAL: bool = false;
            if let Err(err) = session
                .vars_mut()
                .set(&system_vars, key, VarInput::Flat(val), LOCAL)
            {
                session.add_notice(AdapterNotice::BadStartupSetting {
                    name: key.clone(),
                    reason: err.to_string(),
                });
            }
        }
    }
    // NOTE(review): presumably commits the startup settings so they persist
    // for the session rather than being treated as transaction-local — confirm.
    session
        .vars_mut()
        .end_transaction(EndTransactionAction::Commit);

    // Held for the rest of this function, keeping the connection slot reserved.
    let _guard = match active_connection_counter.allocate_connection(session.user()) {
        Ok(drop_connection) => drop_connection,
        Err(e) => {
            let e: AdapterError = e.into();
            return conn.send(e.into_response(Severity::Fatal)).await;
        }
    };

    let mut adapter_client = match adapter_client.startup(session).await {
        Ok(adapter_client) => adapter_client,
        Err(e) => return conn.send(e.into_response(Severity::Fatal)).await,
    };

    // Startup reply: AuthenticationOk, ParameterStatus for every variable in
    // the notify set, BackendKeyData, queued notices, then ReadyForQuery.
    let mut buf = vec![BackendMessage::AuthenticationOk];
    for var in adapter_client.session().vars().notify_set() {
        buf.push(BackendMessage::ParameterStatus(var.name(), var.value()));
    }
    buf.push(BackendMessage::BackendKeyData {
        conn_id: adapter_client.session().conn_id().unhandled(),
        secret_key: adapter_client.session().secret_key(),
    });
    buf.extend(
        adapter_client
            .session()
            .drain_notices()
            .into_iter()
            .map(|notice| BackendMessage::ErrorResponse(notice.into_response())),
    );
    buf.push(BackendMessage::ReadyForQuery(
        adapter_client.session().transaction().into(),
    ));
    conn.send_all(buf).await?;
    conn.flush().await?;

    let machine = StateMachine {
        conn,
        adapter_client,
        txn_needs_commit: false,
        tokio_metrics_intervals,
    };

    // Run the connection until the state machine finishes or authentication
    // expires, whichever happens first.
    select! {
        r = machine.run() => {
            if let Err(err) = &r {
                // Best effort: tell the client why the connection is closing.
                // Errors from this send/flush are deliberately ignored.
                let _ = conn
                    .send(ErrorResponse::fatal(
                        SqlState::CONNECTION_FAILURE,
                        err.to_string(),
                    ))
                    .await;
                let _ = conn.flush().await;
            }
            r
        },
        _ = expired => {
            conn
                .send(ErrorResponse::fatal(SqlState::INVALID_AUTHORIZATION_SPECIFICATION, "authentication expired"))
                .await?;
            conn.flush().await
        }
    }
}
645
646fn is_jwt(password: &str) -> bool {
649 jsonwebtoken::decode_header(password).is_ok()
650}
651
/// Parses the value of the `options` startup parameter into `(key, value)`
/// pairs.
///
/// The value is first split into arguments by [`split_options`]. Each
/// argument must then be one of:
///
/// * `-c key=value`: a lone `-c` followed by a separate `key=value` argument;
/// * `-ckey=value`: flag and setting in one argument;
/// * `--key=value`: the long form.
///
/// Returns `Err(())` if any argument is malformed. This includes a trailing
/// `-c` with no setting after it, which PostgreSQL also rejects; it used to
/// be silently ignored.
fn parse_options(value: &str) -> Result<Vec<(String, String)>, ()> {
    let opts = split_options(value);
    let mut pairs = Vec::with_capacity(opts.len());
    // True when the previous argument was a lone `-c`, so this argument must
    // be the `key=value` setting it introduces.
    let mut seen_prefix = false;
    for opt in opts {
        if !seen_prefix {
            if opt == "-c" {
                seen_prefix = true;
            } else {
                let (key, val) = parse_option(&opt)?;
                pairs.push((key.to_owned(), val.to_owned()));
            }
        } else {
            let (key, val) = opt.split_once('=').ok_or(())?;
            pairs.push((key.to_owned(), val.to_owned()));
            seen_prefix = false;
        }
    }
    if seen_prefix {
        // A dangling `-c` has no setting to apply; report it as malformed
        // instead of silently dropping it.
        return Err(());
    }
    Ok(pairs)
}

/// Parses a single `-ckey=value` or `--key=value` argument into `(key, value)`.
///
/// Returns `Err(())` if the argument contains no `=`, or if its key does not
/// start with `-c` or `--`.
fn parse_option(option: &str) -> Result<(&str, &str), ()> {
    let (key, value) = option.split_once('=').ok_or(())?;
    ["-c", "--"]
        .iter()
        .find_map(|prefix| key.strip_prefix(*prefix))
        .map(|key| (key, value))
        .ok_or(())
}

/// Splits an `options` string into arguments on unescaped spaces.
///
/// A backslash escapes the next character: `\ ` yields a literal space inside
/// the current argument, `\\` yields a backslash, and `\x` yields `x`. A
/// trailing lone backslash is dropped. Runs of spaces never produce empty
/// arguments.
fn split_options(value: &str) -> Vec<String> {
    let mut strs = Vec::new();
    let mut current = String::new();
    let mut was_slash = false;
    for c in value.chars() {
        was_slash = match c {
            ' ' => {
                if was_slash {
                    current.push(' ');
                } else if !current.is_empty() {
                    strs.push(std::mem::take(&mut current));
                }
                false
            }
            '\\' => {
                if was_slash {
                    current.push('\\');
                    false
                } else {
                    true
                }
            }
            _ => {
                current.push(c);
                false
            }
        };
    }
    if !current.is_empty() {
        strs.push(current);
    }
    strs
}
732
733enum PasswordRequestError {
734 InvalidPasswordError(ErrorResponse),
735 IoError(io::Error),
736}
737
738impl From<io::Error> for PasswordRequestError {
739 fn from(e: io::Error) -> Self {
740 PasswordRequestError::IoError(e)
741 }
742}
743
744async fn request_cleartext_password<A>(
748 conn: &mut FramedConn<A>,
749) -> Result<String, PasswordRequestError>
750where
751 A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
752{
753 conn.send(BackendMessage::AuthenticationCleartextPassword)
754 .await?;
755 conn.flush().await?;
756
757 if let Some(message) = conn.recv().await? {
758 if let FrontendMessage::RawAuthentication(data) = message {
759 if let Some(FrontendMessage::Password { password }) =
760 decode_password(Cursor::new(&data)).ok()
761 {
762 return Ok(password);
763 }
764 }
765 }
766
767 Err(PasswordRequestError::InvalidPasswordError(
768 ErrorResponse::fatal(
769 SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
770 "expected Password message",
771 ),
772 ))
773}
774
775async fn authenticate_with_password<A>(
778 conn: &FramedConn<A>,
779 adapter_client: &mz_adapter::Client,
780 user: String,
781 password: Password,
782 conn_uuid: Uuid,
783 helm_chart_version: Option<String>,
784) -> Result<Session, PasswordRequestError>
785where
786 A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
787{
788 let authenticated = match adapter_client.authenticate(&user, &password).await {
789 Ok(authenticated) => authenticated,
790 Err(err) => {
791 warn!(?err, "pgwire connection failed authentication");
792 return Err(PasswordRequestError::InvalidPasswordError(
793 ErrorResponse::fatal(SqlState::INVALID_PASSWORD, "invalid password"),
794 ));
795 }
796 };
797
798 let session = adapter_client.new_session(
799 SessionConfig {
800 conn_id: conn.conn_id().clone(),
801 uuid: conn_uuid,
802 user,
803 client_ip: conn.peer_addr().clone(),
804 external_metadata_rx: None,
805 helm_chart_version,
806 authenticator_kind: mz_auth::AuthenticatorKind::Password,
807 groups: None,
808 },
809 authenticated,
810 );
811
812 Ok(session)
813}
814
/// The next step for the [`StateMachine`] main loop.
#[derive(Debug)]
enum State {
    /// Read and process the next frontend message.
    Ready,
    /// Discard frontend messages until a `Sync` arrives.
    Drain,
    /// The connection is finished; the main loop returns.
    Done,
}
821
822struct StateMachine<'a, A, I>
823where
824 I: Iterator<Item = TaskMetrics> + Send + 'a,
825{
826 conn: &'a mut FramedConn<A>,
827 adapter_client: mz_adapter::SessionClient,
828 txn_needs_commit: bool,
829 tokio_metrics_intervals: I,
830}
831
/// How a result-sending operation finished.
///
/// NOTE(review): the producers and consumers of this enum are outside the
/// visible section; the variant descriptions below follow from the names only.
enum SendRowsEndedReason {
    /// All rows were sent.
    Success {
        /// Size of the result (presumably in bytes — confirm).
        result_size: u64,
        /// Number of rows returned to the client.
        rows_returned: u64,
    },
    /// Sending stopped with the given error message.
    Errored { error: String },
    /// Sending stopped because the operation was canceled.
    Canceled,
}
842
/// Message for statements rejected because the current transaction has
/// already failed.
const ABORTED_TXN_MSG: &str = concat!(
    "current transaction is aborted, ",
    "commands ignored until end of transaction block"
);
845
846impl<'a, A, I> StateMachine<'a, A, I>
847where
848 A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin + 'a,
849 I: Iterator<Item = TaskMetrics> + Send + 'a,
850{
851 #[allow(clippy::manual_async_fn)]
855 #[mz_ore::instrument(level = "debug")]
856 fn run(mut self) -> impl Future<Output = Result<(), io::Error>> + Send + 'a {
857 async move {
858 let mut state = State::Ready;
859 loop {
860 self.send_pending_notices().await?;
861 state = match state {
862 State::Ready => self.advance_ready().await?,
863 State::Drain => self.advance_drain().await?,
864 State::Done => return Ok(()),
865 };
866 self.adapter_client
867 .add_idle_in_transaction_session_timeout();
868 }
869 }
870 }
871
    /// Waits for the next frontend message, or a session timeout, and handles
    /// it. Returns the state the main loop should move to.
    ///
    /// Also records the `pgwire_message_processing_seconds` and
    /// `pgwire_recv_scheduling_delay_ms` metrics, both labelled with the
    /// message name.
    #[instrument(level = "debug")]
    async fn advance_ready(&mut self) -> Result<State, io::Error> {
        // Discard the current metrics interval so the interval taken after the
        // receive below (presumably) covers only the wait for the message.
        self.tokio_metrics_intervals
            .next()
            .expect("infinite iterator");

        let message = select! {
            // `biased` polls the branches in the order written, so a pending
            // session timeout is handled before an incoming client message.
            biased;

            Some(timeout) = self.adapter_client.recv_timeout() => {
                let err: AdapterError = timeout.into();
                let conn_id = self.adapter_client.session().conn_id();
                tracing::warn!("session timed out, conn_id {}", conn_id);

                // Report the error (and do any associated state cleanup)
                // before terminating.
                let error_response = err.into_response(Severity::Fatal);
                let error_state = self.send_error_and_get_state(error_response).await;

                // Terminate only after the error has been processed.
                self.adapter_client.terminate().await;

                // Wait for the client's next message, whose contents are
                // ignored, before returning.
                // NOTE(review): presumably so the error is delivered in
                // response to a client message, as the protocol expects — confirm.
                let _ = self.conn.recv().await?;
                return error_state;
            },
            message = self.conn.recv() => message?,
        };

        // Tokio scheduling delay accumulated while waiting for the message.
        let interval = self
            .tokio_metrics_intervals
            .next()
            .expect("infinite iterator");
        let recv_scheduling_delay_ms = interval.total_scheduled_duration.as_secs_f64() * 1000.0;

        let received = SYSTEM_TIME();

        // Activity arrived, so the session is no longer idle.
        self.adapter_client
            .remove_idle_in_transaction_session_timeout();

        let message_name = message.as_ref().map(|m| m.name()).unwrap_or_default();

        // Processing time is only measured when a message actually arrived.
        let start = message.as_ref().map(|_| Instant::now());
        let next_state = match message {
            Some(FrontendMessage::Query { sql }) => {
                let query_root_span =
                    tracing::info_span!(parent: None, "advance_ready", otel.name = message_name);
                query_root_span.follows_from(tracing::Span::current());
                self.query(sql, received)
                    .instrument(query_root_span)
                    .await?
            }
            Some(FrontendMessage::Parse {
                name,
                sql,
                param_types,
            }) => self.parse(name, sql, param_types).await?,
            Some(FrontendMessage::Bind {
                portal_name,
                statement_name,
                param_formats,
                raw_params,
                result_formats,
            }) => {
                self.bind(
                    portal_name,
                    statement_name,
                    param_formats,
                    raw_params,
                    result_formats,
                )
                .await?
            }
            Some(FrontendMessage::Execute {
                portal_name,
                max_rows,
            }) => {
                // A row limit of zero, or one that does not fit in `usize`
                // (e.g. negative), means "return all rows".
                let max_rows = match usize::try_from(max_rows) {
                    Ok(0) | Err(_) => ExecuteCount::All,
                    Ok(n) => ExecuteCount::Count(n),
                };
                let execute_root_span =
                    tracing::info_span!(parent: None, "advance_ready", otel.name = message_name);
                execute_root_span.follows_from(tracing::Span::current());
                let state = self
                    .execute(
                        portal_name,
                        max_rows,
                        portal_exec_message,
                        None,
                        ExecuteTimeout::None,
                        None,
                        Some(received),
                    )
                    .instrument(execute_root_span)
                    .await?;
                // Defer committing an implicit transaction. The commit happens
                // the next time `txn_needs_commit` is honoured, e.g. in
                // `ensure_transaction`.
                if self.adapter_client.session().transaction().is_implicit() {
                    self.txn_needs_commit = true;
                }
                state
            }
            Some(FrontendMessage::DescribeStatement { name }) => {
                self.describe_statement(&name).await?
            }
            Some(FrontendMessage::DescribePortal { name }) => self.describe_portal(&name).await?,
            Some(FrontendMessage::CloseStatement { name }) => self.close_statement(name).await?,
            Some(FrontendMessage::ClosePortal { name }) => self.close_portal(name).await?,
            Some(FrontendMessage::Flush) => self.flush().await?,
            Some(FrontendMessage::Sync) => self.sync().await?,
            Some(FrontendMessage::Terminate) => State::Done,

            // Messages that are not expected in this state: switch to Drain,
            // which discards input until the next Sync.
            Some(FrontendMessage::CopyData(_))
            | Some(FrontendMessage::CopyDone)
            | Some(FrontendMessage::CopyFail(_))
            | Some(FrontendMessage::Password { .. })
            | Some(FrontendMessage::RawAuthentication(_))
            | Some(FrontendMessage::SASLInitialResponse { .. })
            | Some(FrontendMessage::SASLResponse(_)) => State::Drain,
            // The client closed the connection.
            None => State::Done,
        };

        if let Some(start) = start {
            self.adapter_client
                .inner()
                .metrics()
                .pgwire_message_processing_seconds
                .with_label_values(&[message_name])
                .observe(start.elapsed().as_secs_f64());
        }
        self.adapter_client
            .inner()
            .metrics()
            .pgwire_recv_scheduling_delay_ms
            .with_label_values(&[message_name])
            .observe(recv_scheduling_delay_ms);

        Ok(next_state)
    }
1035
1036 async fn advance_drain(&mut self) -> Result<State, io::Error> {
1037 let message = self.conn.recv().await?;
1038 if message.is_some() {
1039 self.adapter_client
1040 .remove_idle_in_transaction_session_timeout();
1041 }
1042 match message {
1043 Some(FrontendMessage::Sync) => self.sync().await,
1044 None => Ok(State::Done),
1045 _ => Ok(State::Drain),
1046 }
1047 }
1048
    /// Executes one statement from a simple-protocol `Query` message.
    ///
    /// Declares the statement into the unnamed portal and records
    /// `lifecycle_timestamps` on it. Statements with parameters are rejected.
    /// For non-COPY statements that return rows, a text-format
    /// `RowDescription` is sent first. The portal is then executed, the
    /// response or error is sent, and the unnamed portal is removed.
    ///
    /// NOTE(review): the early-return error paths (failed declare, statement
    /// with parameters) do not remove the unnamed portal. Presumably the next
    /// `declare` replaces it — confirm.
    #[instrument(level = "debug")]
    async fn one_query(
        &mut self,
        stmt: Statement<Raw>,
        sql: String,
        lifecycle_timestamps: LifecycleTimestamps,
    ) -> Result<State, io::Error> {
        // The simple query protocol always runs through the unnamed portal.
        const EMPTY_PORTAL: &str = "";
        if let Err(e) = self
            .adapter_client
            .declare(EMPTY_PORTAL.to_string(), stmt, sql)
            .await
        {
            return self
                .send_error_and_get_state(e.into_response(Severity::Error))
                .await;
        }
        let portal = self
            .adapter_client
            .session()
            .get_portal_unverified_mut(EMPTY_PORTAL)
            .expect("unnamed portal should be present");

        *portal.lifecycle_timestamps = Some(lifecycle_timestamps);

        // A simple query has no way to supply parameter values.
        let stmt_desc = portal.desc.clone();
        if !stmt_desc.param_types.is_empty() {
            return self
                .send_error_and_get_state(ErrorResponse::error(
                    SqlState::UNDEFINED_PARAMETER,
                    "there is no parameter $1",
                ))
                .await;
        }

        // Simple-query results are always described in text format. COPY
        // statements do not get a RowDescription here.
        if let Some(relation_desc) = &stmt_desc.relation_desc {
            if !stmt_desc.is_copy {
                let formats = vec![Format::Text; stmt_desc.arity()];
                self.send(BackendMessage::RowDescription(
                    message::encode_row_description(relation_desc, &formats),
                ))
                .await?;
            }
        }

        let result = match self
            .adapter_client
            .execute(EMPTY_PORTAL.to_string(), self.conn.wait_closed(), None)
            .await
        {
            Ok((response, execute_started)) => {
                self.send_pending_notices().await?;
                self.send_execute_response(
                    response,
                    stmt_desc.relation_desc,
                    EMPTY_PORTAL.to_string(),
                    ExecuteCount::All,
                    portal_exec_message,
                    None,
                    ExecuteTimeout::None,
                    execute_started,
                )
                .await
            }
            Err(e) => {
                self.send_pending_notices().await?;
                self.send_error_and_get_state(e.into_response(Severity::Error))
                    .await
            }
        };

        self.adapter_client.session().remove_portal(EMPTY_PORTAL);

        result
    }
1131
1132 async fn ensure_transaction(
1133 &mut self,
1134 num_stmts: usize,
1135 message_type: &str,
1136 ) -> Result<(), io::Error> {
1137 let start = Instant::now();
1138 if self.txn_needs_commit {
1139 self.commit_transaction().await?;
1140 }
1141 let res = self.adapter_client.start_transaction(Some(num_stmts));
1144 assert_ok!(res);
1145 self.adapter_client
1146 .inner()
1147 .metrics()
1148 .pgwire_ensure_transaction_seconds
1149 .with_label_values(&[message_type])
1150 .observe(start.elapsed().as_secs_f64());
1151 Ok(())
1152 }
1153
1154 fn parse_sql<'b>(&self, sql: &'b str) -> Result<Vec<StatementParseResult<'b>>, ErrorResponse> {
1155 let parse_start = Instant::now();
1156 let result = match self.adapter_client.parse(sql) {
1157 Ok(result) => result.map_err(|e| {
1158 let pos = sql[..e.error.pos].chars().count() + 1;
1161 ErrorResponse::error(SqlState::SYNTAX_ERROR, e.error.message).with_position(pos)
1162 }),
1163 Err(msg) => Err(ErrorResponse::error(SqlState::PROGRAM_LIMIT_EXCEEDED, msg)),
1164 };
1165 self.adapter_client
1166 .inner()
1167 .metrics()
1168 .parse_seconds
1169 .observe(parse_start.elapsed().as_secs_f64());
1170 result
1171 }
1172
1173 #[instrument(level = "debug")]
1178 async fn query(&mut self, sql: String, received: EpochMillis) -> Result<State, io::Error> {
1179 let stmts = match self.parse_sql(&sql) {
1181 Ok(stmts) => stmts,
1182 Err(err) => {
1183 self.send_error_and_get_state(err).await?;
1184 return self.ready().await;
1185 }
1186 };
1187
1188 let num_stmts = stmts.len();
1189
1190 for StatementParseResult { ast: stmt, sql } in stmts {
1192 if self.is_aborted_txn() && !is_txn_exit_stmt(Some(&stmt)) {
1194 self.aborted_txn_error().await?;
1195 break;
1196 }
1197
1198 self.ensure_transaction(num_stmts, "query").await?;
1206
1207 match self
1208 .one_query(stmt, sql.to_string(), LifecycleTimestamps { received })
1209 .await?
1210 {
1211 State::Ready => (),
1212 State::Drain => break,
1213 State::Done => return Ok(State::Done),
1214 }
1215 }
1216
1217 {
1219 if self.adapter_client.session().transaction().is_implicit() {
1220 self.commit_transaction().await?;
1221 }
1222 }
1223
1224 if num_stmts == 0 {
1225 self.send(BackendMessage::EmptyQueryResponse).await?;
1226 }
1227
1228 self.ready().await
1229 }
1230
1231 #[instrument(level = "debug")]
1232 async fn parse(
1233 &mut self,
1234 name: String,
1235 sql: String,
1236 param_oids: Vec<u32>,
1237 ) -> Result<State, io::Error> {
1238 self.ensure_transaction(1, "parse").await?;
1240
1241 let mut param_types = vec![];
1242 for oid in param_oids {
1243 match mz_pgrepr::Type::from_oid(oid) {
1244 Ok(ty) => match SqlScalarType::try_from(&ty) {
1245 Ok(ty) => param_types.push(Some(ty)),
1246 Err(err) => {
1247 return self
1248 .send_error_and_get_state(ErrorResponse::error(
1249 SqlState::INVALID_PARAMETER_VALUE,
1250 err.to_string(),
1251 ))
1252 .await;
1253 }
1254 },
1255 Err(_) if oid == 0 => param_types.push(None),
1256 Err(e) => {
1257 return self
1258 .send_error_and_get_state(ErrorResponse::error(
1259 SqlState::PROTOCOL_VIOLATION,
1260 e.to_string(),
1261 ))
1262 .await;
1263 }
1264 }
1265 }
1266
1267 let stmts = match self.parse_sql(&sql) {
1268 Ok(stmts) => stmts,
1269 Err(err) => {
1270 return self.send_error_and_get_state(err).await;
1271 }
1272 };
1273 if stmts.len() > 1 {
1274 return self
1275 .send_error_and_get_state(ErrorResponse::error(
1276 SqlState::INTERNAL_ERROR,
1277 "cannot insert multiple commands into a prepared statement",
1278 ))
1279 .await;
1280 }
1281 let (maybe_stmt, sql) = match stmts.into_iter().next() {
1282 None => (None, ""),
1283 Some(StatementParseResult { ast, sql }) => (Some(ast), sql),
1284 };
1285 if self.is_aborted_txn() && !is_txn_exit_stmt(maybe_stmt.as_ref()) {
1286 return self.aborted_txn_error().await;
1287 }
1288 match self
1289 .adapter_client
1290 .prepare(name, maybe_stmt, sql.to_string(), param_types)
1291 .await
1292 {
1293 Ok(()) => {
1294 self.send(BackendMessage::ParseComplete).await?;
1295 Ok(State::Ready)
1296 }
1297 Err(e) => {
1298 self.send_error_and_get_state(e.into_response(Severity::Error))
1299 .await
1300 }
1301 }
1302 }
1303
1304 #[instrument(level = "debug")]
1306 async fn commit_transaction(&mut self) -> Result<(), io::Error> {
1307 self.end_transaction(EndTransactionAction::Commit).await
1308 }
1309
1310 #[instrument(level = "debug")]
1312 async fn rollback_transaction(&mut self) -> Result<(), io::Error> {
1313 self.end_transaction(EndTransactionAction::Rollback).await
1314 }
1315
1316 #[instrument(level = "debug")]
1318 async fn end_transaction(&mut self, action: EndTransactionAction) -> Result<(), io::Error> {
1319 self.txn_needs_commit = false;
1320 let resp = self.adapter_client.end_transaction(action).await;
1321 if let Err(err) = resp {
1322 self.send(BackendMessage::ErrorResponse(
1323 err.into_response(Severity::Error),
1324 ))
1325 .await?;
1326 }
1327 Ok(())
1328 }
1329
1330 #[instrument(level = "debug")]
1331 async fn bind(
1332 &mut self,
1333 portal_name: String,
1334 statement_name: String,
1335 param_formats: Vec<Format>,
1336 raw_params: Vec<Option<Vec<u8>>>,
1337 result_formats: Vec<Format>,
1338 ) -> Result<State, io::Error> {
1339 self.ensure_transaction(1, "bind").await?;
1341
1342 let aborted_txn = self.is_aborted_txn();
1343 let stmt = match self
1344 .adapter_client
1345 .get_prepared_statement(&statement_name)
1346 .await
1347 {
1348 Ok(stmt) => stmt,
1349 Err(err) => {
1350 return self
1351 .send_error_and_get_state(err.into_response(Severity::Error))
1352 .await;
1353 }
1354 };
1355
1356 let param_types = &stmt.desc().param_types;
1357 if param_types.len() != raw_params.len() {
1358 let message = format!(
1359 "bind message supplies {actual} parameters, \
1360 but prepared statement \"{name}\" requires {expected}",
1361 name = statement_name,
1362 actual = raw_params.len(),
1363 expected = param_types.len()
1364 );
1365 return self
1366 .send_error_and_get_state(ErrorResponse::error(
1367 SqlState::PROTOCOL_VIOLATION,
1368 message,
1369 ))
1370 .await;
1371 }
1372 let param_formats = match pad_formats(param_formats, raw_params.len()) {
1373 Ok(param_formats) => param_formats,
1374 Err(msg) => {
1375 return self
1376 .send_error_and_get_state(ErrorResponse::error(
1377 SqlState::PROTOCOL_VIOLATION,
1378 msg,
1379 ))
1380 .await;
1381 }
1382 };
1383 if aborted_txn && !is_txn_exit_stmt(stmt.stmt()) {
1384 return self.aborted_txn_error().await;
1385 }
1386 let buf = RowArena::new();
1387 let mut params = vec![];
1388 for ((raw_param, mz_typ), format) in raw_params
1389 .into_iter()
1390 .zip_eq(param_types)
1391 .zip_eq(param_formats)
1392 {
1393 let pg_typ = mz_pgrepr::Type::from(mz_typ);
1394 let datum = match raw_param {
1395 None => Datum::Null,
1396 Some(bytes) => match mz_pgrepr::Value::decode(format, &pg_typ, &bytes) {
1397 Ok(param) => match param.into_datum_decode_error(&buf, &pg_typ, "parameter") {
1398 Ok(datum) => datum,
1399 Err(msg) => {
1400 return self
1401 .send_error_and_get_state(ErrorResponse::error(
1402 SqlState::INVALID_PARAMETER_VALUE,
1403 msg,
1404 ))
1405 .await;
1406 }
1407 },
1408 Err(err) => {
1409 let msg = format!("unable to decode parameter: {}", err);
1410 return self
1411 .send_error_and_get_state(ErrorResponse::error(
1412 SqlState::INVALID_PARAMETER_VALUE,
1413 msg,
1414 ))
1415 .await;
1416 }
1417 },
1418 };
1419 params.push((datum, mz_typ.clone()))
1420 }
1421
1422 let result_formats = match pad_formats(
1423 result_formats,
1424 stmt.desc()
1425 .relation_desc
1426 .clone()
1427 .map(|desc| desc.typ().column_types.len())
1428 .unwrap_or(0),
1429 ) {
1430 Ok(result_formats) => result_formats,
1431 Err(msg) => {
1432 return self
1433 .send_error_and_get_state(ErrorResponse::error(
1434 SqlState::PROTOCOL_VIOLATION,
1435 msg,
1436 ))
1437 .await;
1438 }
1439 };
1440
1441 if !stmt.stmt().map_or(false, |stmt| match stmt {
1444 Statement::Copy(CopyStatement {
1445 direction: CopyDirection::To,
1446 ..
1447 }) => true,
1448 Statement::Copy(CopyStatement {
1449 direction: CopyDirection::From,
1450 target: CopyTarget::Expr(_),
1454 ..
1455 }) => true,
1456 _ => false,
1457 }) {
1458 if let Some(desc) = stmt.desc().relation_desc.clone() {
1459 for (format, ty) in result_formats.iter().zip_eq(desc.iter_types()) {
1460 if let Format::Binary = format {
1461 if let Err(msg) = mz_pgrepr::Value::binary_encoding_error(&ty.scalar_type) {
1462 return self
1463 .send_error_and_get_state(ErrorResponse::error(
1464 SqlState::PROTOCOL_VIOLATION,
1465 msg,
1466 ))
1467 .await;
1468 }
1469 }
1470 }
1471 }
1472 }
1473
1474 let desc = stmt.desc().clone();
1475 let logging = Arc::clone(stmt.logging());
1476 let stmt_ast = stmt.stmt().cloned();
1477 let state_revision = stmt.state_revision;
1478 if let Err(err) = self.adapter_client.session().set_portal(
1479 portal_name,
1480 desc,
1481 stmt_ast,
1482 logging,
1483 params,
1484 result_formats,
1485 state_revision,
1486 ) {
1487 return self
1488 .send_error_and_get_state(err.into_response(Severity::Error))
1489 .await;
1490 }
1491
1492 self.send(BackendMessage::BindComplete).await?;
1493 Ok(State::Ready)
1494 }
1495
/// Handles an `Execute` message (or a FETCH re-entering here) for portal `portal_name`.
///
/// The behavior depends on the portal's state:
/// - `NotStarted`: execution begins through the adapter.
/// - `InProgress`: the rows left over from an earlier call resume streaming.
/// - `Completed(Some(tag))`: the stored completion tag is replayed.
/// - `Completed(None)`: the client is told that the portal cannot be run.
///
/// `outer_ctx_extra` is the statement-logging context. On a first execution it is passed to
/// the adapter's `execute`. On every other path that handles it here, it is retired with a
/// matching `StatementEndedExecutionReason`.
///
/// `received` (when known) is stored as the portal's lifecycle timestamps. `max_rows`,
/// `get_response`, `fetch_portal_name` and `timeout` are forwarded to row sending.
///
/// The future is boxed because this method is reached recursively: `send_execute_response`
/// calls `fetch`, which calls back into `execute`.
fn execute(
    &mut self,
    portal_name: String,
    max_rows: ExecuteCount,
    get_response: GetResponse,
    fetch_portal_name: Option<String>,
    timeout: ExecuteTimeout,
    outer_ctx_extra: Option<ExecuteContextGuard>,
    received: Option<EpochMillis>,
) -> BoxFuture<'_, Result<State, io::Error>> {
    async move {
        // Sampled before `portal` below takes a mutable borrow into the session.
        let aborted_txn = self.is_aborted_txn();

        let portal = match self
            .adapter_client
            .session()
            .get_portal_unverified_mut(&portal_name)
        {
            Some(portal) => portal,
            None => {
                let msg = format!("portal {} does not exist", portal_name.quoted());
                if let Some(outer_ctx_extra) = outer_ctx_extra {
                    self.adapter_client.retire_execute(
                        outer_ctx_extra,
                        StatementEndedExecutionReason::Errored { error: msg.clone() },
                    );
                }
                return self
                    .send_error_and_get_state(ErrorResponse::error(
                        SqlState::INVALID_CURSOR_NAME,
                        msg,
                    ))
                    .await;
            }
        };

        // Attach the client-message receipt time (if any) to the portal.
        *portal.lifecycle_timestamps = received.map(LifecycleTimestamps::new);

        // In an aborted transaction only statements accepted by `is_txn_exit_stmt` may run.
        let txn_exit_stmt = is_txn_exit_stmt(portal.stmt.as_deref());
        if aborted_txn && !txn_exit_stmt {
            if let Some(outer_ctx_extra) = outer_ctx_extra {
                self.adapter_client.retire_execute(
                    outer_ctx_extra,
                    StatementEndedExecutionReason::Errored {
                        error: ABORTED_TXN_MSG.to_string(),
                    },
                );
            }
            return self.aborted_txn_error().await;
        }

        let row_desc = portal.desc.relation_desc.clone();
        match portal.state {
            // First execution: the adapter runs the statement and takes the logging context.
            PortalState::NotStarted => {
                self.ensure_transaction(1, "execute").await?;
                match self
                    .adapter_client
                    .execute(
                        portal_name.clone(),
                        self.conn.wait_closed(),
                        outer_ctx_extra,
                    )
                    .await
                {
                    Ok((response, execute_started)) => {
                        self.send_pending_notices().await?;
                        self.send_execute_response(
                            response,
                            row_desc,
                            portal_name,
                            max_rows,
                            get_response,
                            fetch_portal_name,
                            timeout,
                            execute_started,
                        )
                        .await
                    }
                    Err(e) => {
                        self.send_pending_notices().await?;
                        self.send_error_and_get_state(e.into_response(Severity::Error))
                            .await
                    }
                }
            }
            // Resume streaming rows parked in the portal by an earlier call.
            PortalState::InProgress(rows) => {
                let rows = rows.take().expect("InProgress rows must be populated");
                let (result, statement_ended_execution_reason) = match self
                    .send_rows(
                        row_desc.expect("portal missing row desc on resumption"),
                        portal_name,
                        rows,
                        max_rows,
                        get_response,
                        fetch_portal_name,
                        timeout,
                    )
                    .await
                {
                    // An I/O error while sending rows is logged as a cancellation.
                    Err(e) => {
                        (Err(e), StatementEndedExecutionReason::Canceled)
                    }
                    Ok((ok, SendRowsEndedReason::Canceled)) => {
                        (Ok(ok), StatementEndedExecutionReason::Canceled)
                    }
                    // Success on resumption is logged without result size, row count or
                    // execution strategy.
                    Ok((
                        ok,
                        SendRowsEndedReason::Success {
                            result_size: _,
                            rows_returned: _,
                        },
                    )) => (
                        Ok(ok),
                        StatementEndedExecutionReason::Success {
                            result_size: None,
                            rows_returned: None,
                            execution_strategy: None,
                        },
                    ),
                    Ok((ok, SendRowsEndedReason::Errored { error })) => {
                        (Ok(ok), StatementEndedExecutionReason::Errored { error })
                    }
                };
                if let Some(outer_ctx_extra) = outer_ctx_extra {
                    self.adapter_client
                        .retire_execute(outer_ctx_extra, statement_ended_execution_reason);
                }
                result
            }
            // Already finished with a stored tag: replay `CommandComplete`.
            PortalState::Completed(Some(tag)) => {
                let tag = tag.to_string();
                if let Some(outer_ctx_extra) = outer_ctx_extra {
                    self.adapter_client.retire_execute(
                        outer_ctx_extra,
                        StatementEndedExecutionReason::Success {
                            result_size: None,
                            rows_returned: None,
                            execution_strategy: None,
                        },
                    );
                }
                self.send(BackendMessage::CommandComplete { tag }).await?;
                Ok(State::Ready)
            }
            // Completed without a tag (e.g. via `complete_portal`): cannot be executed again.
            PortalState::Completed(None) => {
                let error = format!(
                    "portal {} cannot be run",
                    Ident::new_unchecked(portal_name).to_ast_string_stable()
                );
                if let Some(outer_ctx_extra) = outer_ctx_extra {
                    self.adapter_client.retire_execute(
                        outer_ctx_extra,
                        StatementEndedExecutionReason::Errored {
                            error: error.clone(),
                        },
                    );
                }
                self.send_error_and_get_state(ErrorResponse::error(
                    SqlState::OBJECT_NOT_IN_PREREQUISITE_STATE,
                    error,
                ))
                .await
            }
        }
    }
    .instrument(debug_span!("execute"))
    .boxed()
}
1689
1690 #[instrument(level = "debug")]
1691 async fn describe_statement(&mut self, name: &str) -> Result<State, io::Error> {
1692 self.ensure_transaction(1, "describe_statement").await?;
1694
1695 let stmt = match self.adapter_client.get_prepared_statement(name).await {
1696 Ok(stmt) => stmt,
1697 Err(err) => {
1698 return self
1699 .send_error_and_get_state(err.into_response(Severity::Error))
1700 .await;
1701 }
1702 };
1703 let parameter_desc = BackendMessage::ParameterDescription(
1705 stmt.desc()
1706 .param_types
1707 .iter()
1708 .map(mz_pgrepr::Type::from)
1709 .collect(),
1710 );
1711 let formats = vec![Format::Text; stmt.desc().arity()];
1715 let row_desc = describe_rows(stmt.desc(), &formats);
1716 self.send_all([parameter_desc, row_desc]).await?;
1717 Ok(State::Ready)
1718 }
1719
1720 #[instrument(level = "debug")]
1721 async fn describe_portal(&mut self, name: &str) -> Result<State, io::Error> {
1722 self.ensure_transaction(1, "describe_portal").await?;
1724
1725 let session = self.adapter_client.session();
1726 let row_desc = session
1727 .get_portal_unverified(name)
1728 .map(|portal| describe_rows(&portal.desc, &portal.result_formats));
1729 match row_desc {
1730 Some(row_desc) => {
1731 self.send(row_desc).await?;
1732 Ok(State::Ready)
1733 }
1734 None => {
1735 self.send_error_and_get_state(ErrorResponse::error(
1736 SqlState::INVALID_CURSOR_NAME,
1737 format!("portal {} does not exist", name.quoted()),
1738 ))
1739 .await
1740 }
1741 }
1742 }
1743
1744 #[instrument(level = "debug")]
1745 async fn close_statement(&mut self, name: String) -> Result<State, io::Error> {
1746 self.adapter_client
1747 .session()
1748 .remove_prepared_statement(&name);
1749 self.send(BackendMessage::CloseComplete).await?;
1750 Ok(State::Ready)
1751 }
1752
1753 #[instrument(level = "debug")]
1754 async fn close_portal(&mut self, name: String) -> Result<State, io::Error> {
1755 self.adapter_client.session().remove_portal(&name);
1756 self.send(BackendMessage::CloseComplete).await?;
1757 Ok(State::Ready)
1758 }
1759
1760 fn complete_portal(&mut self, name: &str) {
1761 let portal = self
1762 .adapter_client
1763 .session()
1764 .get_portal_unverified_mut(name)
1765 .expect("portal should exist");
1766 *portal.state = PortalState::Completed(None);
1767 }
1768
1769 async fn fetch(
1770 &mut self,
1771 name: String,
1772 count: Option<FetchDirection>,
1773 max_rows: ExecuteCount,
1774 fetch_portal_name: Option<String>,
1775 timeout: ExecuteTimeout,
1776 ctx_extra: ExecuteContextGuard,
1777 ) -> Result<State, io::Error> {
1778 let count = count.unwrap_or(FetchDirection::ForwardCount(1));
1781
1782 let count = match (max_rows, count) {
1795 (ExecuteCount::Count(max_rows), FetchDirection::ForwardCount(count)) => {
1796 let count = usize::cast_from(count);
1797 if max_rows < count {
1798 let msg = "Execute with max_rows < a FETCH's count is not supported";
1799 self.adapter_client.retire_execute(
1800 ctx_extra,
1801 StatementEndedExecutionReason::Errored {
1802 error: msg.to_string(),
1803 },
1804 );
1805 return self
1806 .send_error_and_get_state(ErrorResponse::error(
1807 SqlState::FEATURE_NOT_SUPPORTED,
1808 msg,
1809 ))
1810 .await;
1811 }
1812 ExecuteCount::Count(count)
1813 }
1814 (ExecuteCount::Count(_), FetchDirection::ForwardAll) => {
1815 let msg = "Execute with max_rows of a FETCH ALL is not supported";
1816 self.adapter_client.retire_execute(
1817 ctx_extra,
1818 StatementEndedExecutionReason::Errored {
1819 error: msg.to_string(),
1820 },
1821 );
1822 return self
1823 .send_error_and_get_state(ErrorResponse::error(
1824 SqlState::FEATURE_NOT_SUPPORTED,
1825 msg,
1826 ))
1827 .await;
1828 }
1829 (ExecuteCount::All, FetchDirection::ForwardAll) => ExecuteCount::All,
1830 (ExecuteCount::All, FetchDirection::ForwardCount(count)) => {
1831 ExecuteCount::Count(usize::cast_from(count))
1832 }
1833 };
1834 let cursor_name = name.to_string();
1835 self.execute(
1836 cursor_name,
1837 count,
1838 fetch_message,
1839 fetch_portal_name,
1840 timeout,
1841 Some(ctx_extra),
1842 None,
1843 )
1844 .await
1845 }
1846
1847 async fn flush(&mut self) -> Result<State, io::Error> {
1848 self.conn.flush().await?;
1849 Ok(State::Ready)
1850 }
1851
1852 #[instrument(level = "debug")]
1857 async fn send<M>(&mut self, message: M) -> Result<(), io::Error>
1858 where
1859 M: Into<BackendMessage>,
1860 {
1861 let message: BackendMessage = message.into();
1862 let is_error =
1863 matches!(&message, BackendMessage::ErrorResponse(e) if e.severity.is_error());
1864
1865 self.conn.send(message).await?;
1866
1867 if is_error {
1874 self.conn.flush().await?;
1875 }
1876
1877 Ok(())
1878 }
1879
1880 #[instrument(level = "debug")]
1881 pub async fn send_all(
1882 &mut self,
1883 messages: impl IntoIterator<Item = BackendMessage>,
1884 ) -> Result<(), io::Error> {
1885 for m in messages {
1886 self.send(m).await?;
1887 }
1888 Ok(())
1889 }
1890
1891 #[instrument(level = "debug")]
1892 async fn sync(&mut self) -> Result<State, io::Error> {
1893 if self.adapter_client.session().transaction().is_implicit() {
1895 self.commit_transaction().await?;
1896 }
1897 self.ready().await
1898 }
1899
1900 #[instrument(level = "debug")]
1901 async fn ready(&mut self) -> Result<State, io::Error> {
1902 let txn_state = self.adapter_client.session().transaction().into();
1903 self.send(BackendMessage::ReadyForQuery(txn_state)).await?;
1904 self.flush().await
1905 }
1906
1907 #[allow(clippy::too_many_arguments)]
1908 #[instrument(level = "debug")]
1909 async fn send_execute_response(
1910 &mut self,
1911 response: ExecuteResponse,
1912 row_desc: Option<RelationDesc>,
1913 portal_name: String,
1914 max_rows: ExecuteCount,
1915 get_response: GetResponse,
1916 fetch_portal_name: Option<String>,
1917 timeout: ExecuteTimeout,
1918 execute_started: Instant,
1919 ) -> Result<State, io::Error> {
1920 let mut tag = response.tag();
1921
1922 macro_rules! command_complete {
1923 () => {{
1924 self.send(BackendMessage::CommandComplete {
1925 tag: tag
1926 .take()
1927 .expect("command_complete only called on tag-generating results"),
1928 })
1929 .await?;
1930 Ok(State::Ready)
1931 }};
1932 }
1933
1934 let r = match response {
1935 ExecuteResponse::ClosedCursor => {
1936 self.complete_portal(&portal_name);
1937 command_complete!()
1938 }
1939 ExecuteResponse::DeclaredCursor => {
1940 self.complete_portal(&portal_name);
1941 command_complete!()
1942 }
1943 ExecuteResponse::EmptyQuery => {
1944 self.send(BackendMessage::EmptyQueryResponse).await?;
1945 Ok(State::Ready)
1946 }
1947 ExecuteResponse::Fetch {
1948 name,
1949 count,
1950 timeout,
1951 ctx_extra,
1952 } => {
1953 self.fetch(
1954 name,
1955 count,
1956 max_rows,
1957 Some(portal_name.to_string()),
1958 timeout,
1959 ctx_extra,
1960 )
1961 .await
1962 }
1963 ExecuteResponse::SendingRowsStreaming {
1964 rows,
1965 instance_id,
1966 strategy,
1967 } => {
1968 let row_desc = row_desc
1969 .expect("missing row description for ExecuteResponse::SendingRowsStreaming");
1970
1971 let span = tracing::debug_span!("sending_rows_streaming");
1972
1973 self.send_rows(
1974 row_desc,
1975 portal_name,
1976 InProgressRows::new(RecordFirstRowStream::new(
1977 Box::new(rows),
1978 execute_started,
1979 &self.adapter_client,
1980 Some(instance_id),
1981 Some(strategy),
1982 )),
1983 max_rows,
1984 get_response,
1985 fetch_portal_name,
1986 timeout,
1987 )
1988 .instrument(span)
1989 .await
1990 .map(|(state, _)| state)
1991 }
1992 ExecuteResponse::SendingRowsImmediate { rows } => {
1993 let row_desc = row_desc
1994 .expect("missing row description for ExecuteResponse::SendingRowsImmediate");
1995
1996 let span = tracing::debug_span!("sending_rows_immediate");
1997
1998 let stream =
1999 futures::stream::once(futures::future::ready(PeekResponseUnary::Rows(rows)));
2000 self.send_rows(
2001 row_desc,
2002 portal_name,
2003 InProgressRows::new(RecordFirstRowStream::new(
2004 Box::new(stream),
2005 execute_started,
2006 &self.adapter_client,
2007 None,
2008 Some(StatementExecutionStrategy::Constant),
2009 )),
2010 max_rows,
2011 get_response,
2012 fetch_portal_name,
2013 timeout,
2014 )
2015 .instrument(span)
2016 .await
2017 .map(|(state, _)| state)
2018 }
2019 ExecuteResponse::SetVariable { name, .. } => {
2020 let qn = name.to_string();
2023 let msg = if let Some(var) = self
2024 .adapter_client
2025 .session()
2026 .vars_mut()
2027 .notify_set()
2028 .find(|v| v.name() == qn)
2029 {
2030 Some(BackendMessage::ParameterStatus(var.name(), var.value()))
2031 } else {
2032 None
2033 };
2034 if let Some(msg) = msg {
2035 self.send(msg).await?;
2036 }
2037 command_complete!()
2038 }
2039 ExecuteResponse::Subscribing {
2040 rx,
2041 ctx_extra,
2042 instance_id,
2043 } => {
2044 if fetch_portal_name.is_none() {
2045 let mut msg = ErrorResponse::notice(
2046 SqlState::WARNING,
2047 "streaming SUBSCRIBE rows directly requires a client that does not buffer output",
2048 );
2049 if self.adapter_client.session().vars().application_name() == "psql" {
2050 msg.hint = Some(
2051 "Wrap your SUBSCRIBE statement in `COPY (SUBSCRIBE ...) TO STDOUT`."
2052 .into(),
2053 )
2054 }
2055 self.send(msg).await?;
2056 self.conn.flush().await?;
2057 }
2058 let row_desc =
2059 row_desc.expect("missing row description for ExecuteResponse::Subscribing");
2060 let (result, statement_ended_execution_reason) = match self
2061 .send_rows(
2062 row_desc,
2063 portal_name,
2064 InProgressRows::new(RecordFirstRowStream::new(
2065 Box::new(UnboundedReceiverStream::new(rx)),
2066 execute_started,
2067 &self.adapter_client,
2068 Some(instance_id),
2069 None,
2070 )),
2071 max_rows,
2072 get_response,
2073 fetch_portal_name,
2074 timeout,
2075 )
2076 .await
2077 {
2078 Err(e) => {
2079 (Err(e), StatementEndedExecutionReason::Canceled)
2082 }
2083 Ok((ok, SendRowsEndedReason::Canceled)) => {
2084 (Ok(ok), StatementEndedExecutionReason::Canceled)
2085 }
2086 Ok((
2087 ok,
2088 SendRowsEndedReason::Success {
2089 result_size,
2090 rows_returned,
2091 },
2092 )) => (
2093 Ok(ok),
2094 StatementEndedExecutionReason::Success {
2095 result_size: Some(result_size),
2096 rows_returned: Some(rows_returned),
2097 execution_strategy: None,
2098 },
2099 ),
2100 Ok((ok, SendRowsEndedReason::Errored { error })) => {
2101 (Ok(ok), StatementEndedExecutionReason::Errored { error })
2102 }
2103 };
2104 self.adapter_client
2105 .retire_execute(ctx_extra, statement_ended_execution_reason);
2106 return result;
2107 }
2108 ExecuteResponse::CopyTo { format, resp } => {
2109 let row_desc =
2110 row_desc.expect("missing row description for ExecuteResponse::CopyTo");
2111 match *resp {
2112 ExecuteResponse::Subscribing {
2113 rx,
2114 ctx_extra,
2115 instance_id,
2116 } => {
2117 let (result, statement_ended_execution_reason) = match self
2118 .copy_rows(
2119 format,
2120 row_desc,
2121 RecordFirstRowStream::new(
2122 Box::new(UnboundedReceiverStream::new(rx)),
2123 execute_started,
2124 &self.adapter_client,
2125 Some(instance_id),
2126 None,
2127 ),
2128 )
2129 .await
2130 {
2131 Err(e) => {
2132 (Err(e), StatementEndedExecutionReason::Canceled)
2135 }
2136 Ok((
2137 state,
2138 SendRowsEndedReason::Success {
2139 result_size,
2140 rows_returned,
2141 },
2142 )) => (
2143 Ok(state),
2144 StatementEndedExecutionReason::Success {
2145 result_size: Some(result_size),
2146 rows_returned: Some(rows_returned),
2147 execution_strategy: None,
2148 },
2149 ),
2150 Ok((state, SendRowsEndedReason::Errored { error })) => {
2151 (Ok(state), StatementEndedExecutionReason::Errored { error })
2152 }
2153 Ok((state, SendRowsEndedReason::Canceled)) => {
2154 (Ok(state), StatementEndedExecutionReason::Canceled)
2155 }
2156 };
2157 self.adapter_client
2158 .retire_execute(ctx_extra, statement_ended_execution_reason);
2159 return result;
2160 }
2161 ExecuteResponse::SendingRowsStreaming {
2162 rows,
2163 instance_id,
2164 strategy,
2165 } => {
2166 return self
2171 .copy_rows(
2172 format,
2173 row_desc,
2174 RecordFirstRowStream::new(
2175 Box::new(rows),
2176 execute_started,
2177 &self.adapter_client,
2178 Some(instance_id),
2179 Some(strategy),
2180 ),
2181 )
2182 .await
2183 .map(|(state, _)| state);
2184 }
2185 ExecuteResponse::SendingRowsImmediate { rows } => {
2186 let span = tracing::debug_span!("sending_rows_immediate");
2187
2188 let rows = futures::stream::once(futures::future::ready(
2189 PeekResponseUnary::Rows(rows),
2190 ));
2191 return self
2196 .copy_rows(
2197 format,
2198 row_desc,
2199 RecordFirstRowStream::new(
2200 Box::new(rows),
2201 execute_started,
2202 &self.adapter_client,
2203 None,
2204 Some(StatementExecutionStrategy::Constant),
2205 ),
2206 )
2207 .instrument(span)
2208 .await
2209 .map(|(state, _)| state);
2210 }
2211 _ => {
2212 return self
2213 .send_error_and_get_state(ErrorResponse::error(
2214 SqlState::INTERNAL_ERROR,
2215 "unsupported COPY response type".to_string(),
2216 ))
2217 .await;
2218 }
2219 };
2220 }
2221 ExecuteResponse::CopyFrom {
2222 target_id,
2223 target_name,
2224 columns,
2225 params,
2226 ctx_extra,
2227 } => {
2228 let row_desc =
2229 row_desc.expect("missing row description for ExecuteResponse::CopyFrom");
2230 self.copy_from(target_id, target_name, columns, params, row_desc, ctx_extra)
2231 .await
2232 }
2233 ExecuteResponse::TransactionCommitted { params }
2234 | ExecuteResponse::TransactionRolledBack { params } => {
2235 let notify_set: mz_ore::collections::HashSet<String> = self
2236 .adapter_client
2237 .session()
2238 .vars()
2239 .notify_set()
2240 .map(|v| v.name().to_string())
2241 .collect();
2242
2243 for (name, value) in params
2245 .into_iter()
2246 .filter(|(name, _v)| notify_set.contains(*name))
2247 {
2248 let msg = BackendMessage::ParameterStatus(name, value);
2249 self.send(msg).await?;
2250 }
2251 command_complete!()
2252 }
2253
2254 ExecuteResponse::AlteredDefaultPrivileges
2255 | ExecuteResponse::AlteredObject(..)
2256 | ExecuteResponse::AlteredRole
2257 | ExecuteResponse::AlteredSystemConfiguration
2258 | ExecuteResponse::CreatedCluster { .. }
2259 | ExecuteResponse::CreatedClusterReplica { .. }
2260 | ExecuteResponse::CreatedConnection { .. }
2261 | ExecuteResponse::CreatedDatabase { .. }
2262 | ExecuteResponse::CreatedIndex { .. }
2263 | ExecuteResponse::CreatedIntrospectionSubscribe
2264 | ExecuteResponse::CreatedMaterializedView { .. }
2265 | ExecuteResponse::CreatedRole
2266 | ExecuteResponse::CreatedSchema { .. }
2267 | ExecuteResponse::CreatedSecret { .. }
2268 | ExecuteResponse::CreatedSink { .. }
2269 | ExecuteResponse::CreatedSource { .. }
2270 | ExecuteResponse::CreatedTable { .. }
2271 | ExecuteResponse::CreatedType
2272 | ExecuteResponse::CreatedView { .. }
2273 | ExecuteResponse::CreatedViews { .. }
2274 | ExecuteResponse::CreatedNetworkPolicy
2275 | ExecuteResponse::Comment
2276 | ExecuteResponse::Deallocate { .. }
2277 | ExecuteResponse::Deleted(..)
2278 | ExecuteResponse::DiscardedAll
2279 | ExecuteResponse::DiscardedTemp
2280 | ExecuteResponse::DroppedObject(_)
2281 | ExecuteResponse::DroppedOwned
2282 | ExecuteResponse::GrantedPrivilege
2283 | ExecuteResponse::GrantedRole
2284 | ExecuteResponse::Inserted(..)
2285 | ExecuteResponse::Copied(..)
2286 | ExecuteResponse::Prepare
2287 | ExecuteResponse::Raised
2288 | ExecuteResponse::ReassignOwned
2289 | ExecuteResponse::RevokedPrivilege
2290 | ExecuteResponse::RevokedRole
2291 | ExecuteResponse::StartedTransaction { .. }
2292 | ExecuteResponse::Updated(..)
2293 | ExecuteResponse::ValidatedConnection => {
2294 command_complete!()
2295 }
2296 };
2297
2298 assert_none!(tag, "tag created but not consumed: {:?}", tag);
2299 r
2300 }
2301
2302 #[allow(clippy::too_many_arguments)]
2303 #[mz_ore::instrument(level = "debug")]
2305 async fn send_rows(
2306 &mut self,
2307 row_desc: RelationDesc,
2308 portal_name: String,
2309 mut rows: InProgressRows,
2310 max_rows: ExecuteCount,
2311 get_response: GetResponse,
2312 fetch_portal_name: Option<String>,
2313 timeout: ExecuteTimeout,
2314 ) -> Result<(State, SendRowsEndedReason), io::Error> {
2315 let result_format_portal_name: &str = if let Some(ref name) = fetch_portal_name {
2318 name
2319 } else {
2320 &portal_name
2321 };
2322 let result_formats = self
2323 .adapter_client
2324 .session()
2325 .get_portal_unverified(result_format_portal_name)
2326 .expect("valid fetch portal name for send rows")
2327 .result_formats
2328 .clone();
2329
2330 let (mut wait_once, mut deadline) = match timeout {
2331 ExecuteTimeout::None => (false, None),
2332 ExecuteTimeout::Seconds(t) => (
2333 false,
2334 Some(tokio::time::Instant::now() + tokio::time::Duration::from_secs_f64(t)),
2335 ),
2336 ExecuteTimeout::WaitOnce => (true, None),
2337 };
2338
2339 {
2341 let portal_name_desc = &self
2342 .adapter_client
2343 .session()
2344 .get_portal_unverified(portal_name.as_str())
2345 .expect("portal should exist")
2346 .desc
2347 .relation_desc;
2348 if let Some(portal_name_desc) = portal_name_desc {
2349 soft_assert_eq_or_log!(portal_name_desc, &row_desc);
2350 }
2351 if let Some(fetch_portal_name) = &fetch_portal_name {
2352 let fetch_portal_desc = &self
2353 .adapter_client
2354 .session()
2355 .get_portal_unverified(fetch_portal_name)
2356 .expect("portal should exist")
2357 .desc
2358 .relation_desc;
2359 if let Some(fetch_portal_desc) = fetch_portal_desc {
2360 soft_assert_eq_or_log!(fetch_portal_desc, &row_desc);
2361 }
2362 }
2363 }
2364
2365 self.conn.set_encode_state(
2366 row_desc
2367 .typ()
2368 .column_types
2369 .iter()
2370 .map(|ty| mz_pgrepr::Type::from(&ty.scalar_type))
2371 .zip_eq(result_formats)
2372 .collect(),
2373 );
2374
2375 let mut total_sent_rows = 0;
2376 let mut total_sent_bytes = 0;
2377 let mut want_rows = match max_rows {
2379 ExecuteCount::All => usize::MAX,
2380 ExecuteCount::Count(count) => count,
2381 };
2382
2383 loop {
2385 let batch = if rows.current.is_some() {
2388 FetchResult::Rows(rows.current.take())
2389 } else if want_rows == 0 {
2390 FetchResult::Rows(None)
2391 } else {
2392 let notice_fut = self.adapter_client.session().recv_notice();
2393 tokio::select! {
2408 biased;
2409 err = self.conn.wait_closed() => return Err(err),
2410 batch = rows.remaining.recv() => match batch {
2411 None => FetchResult::Rows(None),
2412 Some(PeekResponseUnary::Rows(rows)) => FetchResult::Rows(Some(rows)),
2413 Some(PeekResponseUnary::Error(err)) => FetchResult::Error(err),
2414 Some(PeekResponseUnary::DependencyDropped(dep)) => {
2415 FetchResult::Error(dep.query_terminated_error())
2416 }
2417 Some(PeekResponseUnary::Canceled) => FetchResult::Canceled,
2418 },
2419 notice = notice_fut => {
2420 FetchResult::Notice(notice)
2421 }
2422 _ = time::sleep_until(
2423 deadline.unwrap_or_else(tokio::time::Instant::now),
2424 ), if deadline.is_some() => FetchResult::Rows(None),
2425 }
2426 };
2427
2428 match batch {
2429 FetchResult::Rows(None) => break,
2430 FetchResult::Rows(Some(mut batch_rows)) => {
2431 if let Err(err) = verify_datum_desc(&row_desc, &mut batch_rows) {
2432 let msg = err.to_string();
2433 return self
2434 .send_error_and_get_state(err.into_response(Severity::Error))
2435 .await
2436 .map(|state| (state, SendRowsEndedReason::Errored { error: msg }));
2437 }
2438
2439 if wait_once && batch_rows.peek().is_some() {
2443 deadline = Some(tokio::time::Instant::now());
2444 wait_once = false;
2445 }
2446
2447 let mut sent_rows = 0;
2449 let mut sent_bytes = 0;
2450 let messages = (&mut batch_rows)
2451 .map(|row| {
2456 let row_len = row.byte_len();
2457 let values = mz_pgrepr::values_from_row(row, row_desc.typ());
2458 (row_len, BackendMessage::DataRow(values))
2459 })
2460 .inspect(|(row_len, _)| {
2461 sent_bytes += row_len;
2462 sent_rows += 1
2463 })
2464 .map(|(_row_len, row)| row)
2465 .take(want_rows);
2466 self.send_all(messages).await?;
2467
2468 total_sent_rows += sent_rows;
2469 total_sent_bytes += sent_bytes;
2470 want_rows -= sent_rows;
2471
2472 if want_rows == 0 {
2475 if batch_rows.peek().is_some() {
2476 rows.current = Some(batch_rows);
2477 }
2478 break;
2479 }
2480
2481 self.conn.flush().await?;
2482 }
2483 FetchResult::Notice(notice) => {
2484 self.send(notice.into_response()).await?;
2485 self.conn.flush().await?;
2486 }
2487 FetchResult::Error(text) => {
2488 return self
2489 .send_error_and_get_state(ErrorResponse::error(
2490 SqlState::INTERNAL_ERROR,
2491 text.clone(),
2492 ))
2493 .await
2494 .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
2495 }
2496 FetchResult::Canceled => {
2497 return self
2498 .send_error_and_get_state(ErrorResponse::error(
2499 SqlState::QUERY_CANCELED,
2500 "canceling statement due to user request",
2501 ))
2502 .await
2503 .map(|state| (state, SendRowsEndedReason::Canceled));
2504 }
2505 }
2506 }
2507
2508 let portal = self
2509 .adapter_client
2510 .session()
2511 .get_portal_unverified_mut(&portal_name)
2512 .expect("valid portal name for send rows");
2513
2514 let saw_rows = rows.remaining.saw_rows;
2515 let no_more_rows = rows.no_more_rows();
2516 let metric_recorded = rows.remaining.metric_recorded;
2517 let recorded_first_row_instant = rows.remaining.recorded_first_row_instant;
2518
2519 if no_more_rows && !metric_recorded {
2520 rows.remaining.metric_recorded = true;
2521 }
2522
2523 *portal.state = PortalState::InProgress(Some(rows));
2526
2527 let fetch_portal = fetch_portal_name.map(|name| {
2528 self.adapter_client
2529 .session()
2530 .get_portal_unverified_mut(&name)
2531 .expect("valid fetch portal")
2532 });
2533 let response_message = get_response(max_rows, total_sent_rows, fetch_portal);
2534 self.send(response_message).await?;
2535
2536 if no_more_rows && !metric_recorded {
2539 let statement_type = if let Some(stmt) = &self
2540 .adapter_client
2541 .session()
2542 .get_portal_unverified(&portal_name)
2543 .expect("valid portal name for send_rows")
2544 .stmt
2545 {
2546 metrics::statement_type_label_value(stmt.deref())
2547 } else {
2548 "no-statement"
2549 };
2550 let duration = if saw_rows {
2551 recorded_first_row_instant
2552 .expect("recorded_first_row_instant because saw_rows")
2553 .elapsed()
2554 } else {
2555 Duration::ZERO
2559 };
2560 self.adapter_client
2561 .inner()
2562 .metrics()
2563 .result_rows_first_to_last_byte_seconds
2564 .with_label_values(&[statement_type])
2565 .observe(duration.as_secs_f64());
2566 }
2567
2568 Ok((
2569 State::Ready,
2570 SendRowsEndedReason::Success {
2571 result_size: u64::cast_from(total_sent_bytes),
2572 rows_returned: u64::cast_from(total_sent_rows),
2573 },
2574 ))
2575 }
2576
    /// Streams the result of a `COPY ... TO STDOUT` to the client.
    ///
    /// Sends `CopyOutResponse` (every column advertised in the format chosen for `format`),
    /// then one `CopyData` message per row received from `stream`, then `CopyDone` and a
    /// `CommandComplete` tagged `COPY <row count>`. For `CopyFormat::Binary` the PGCOPY
    /// header is emitted ahead of the first row (or together with the trailer when there
    /// are no rows) and the `-1` trailer after the last row.
    ///
    /// Stream errors, dropped dependencies, cancellation, the unsupported Parquet format,
    /// and column types without a binary encoding are reported to the client through
    /// `send_error_and_get_state`. The returned `SendRowsEndedReason` records why sending
    /// stopped. An `io::Error` is returned if a send/flush fails or the connection closes
    /// while waiting for rows.
    #[mz_ore::instrument(level = "debug")]
    async fn copy_rows(
        &mut self,
        format: CopyFormat,
        row_desc: RelationDesc,
        mut stream: RecordFirstRowStream,
    ) -> Result<(State, SendRowsEndedReason), io::Error> {
        // Map the SQL-level COPY format to the per-row encoder parameters and to the
        // format code advertised in CopyOutResponse.
        let (row_format, encode_format) = match format {
            CopyFormat::Text => (
                CopyFormatParams::Text(CopyTextFormatParams::default()),
                Format::Text,
            ),
            CopyFormat::Binary => (CopyFormatParams::Binary, Format::Binary),
            CopyFormat::Csv => (
                CopyFormatParams::Csv(CopyCsvFormatParams::default()),
                Format::Text,
            ),
            CopyFormat::Parquet => {
                // Parquet output is rejected on this code path.
                let text = "Parquet format is not supported".to_string();
                return self
                    .send_error_and_get_state(ErrorResponse::error(
                        SqlState::INTERNAL_ERROR,
                        text.clone(),
                    ))
                    .await
                    .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
            }
        };

        if let CopyFormat::Binary = format {
            // Fail before CopyOutResponse is sent if any column type has no binary
            // encoding.
            if let Some(msg) = row_desc
                .iter_types()
                .find_map(|ty| mz_pgrepr::Value::binary_encoding_error(&ty.scalar_type).err())
            {
                return self
                    .send_error_and_get_state(ErrorResponse::error(
                        SqlState::PROTOCOL_VIOLATION,
                        msg,
                    ))
                    .await
                    .map(|state| {
                        (
                            state,
                            SendRowsEndedReason::Errored {
                                error: msg.to_string(),
                            },
                        )
                    });
            }
        }

        // Appends the COPY encoding of one row to `out`.
        let encode_fn = |row: &RowRef, typ: &SqlRelationType, out: &mut Vec<u8>| {
            mz_pgcopy::encode_copy_format(&row_format, row, typ, out)
        };

        let typ = row_desc.typ();
        // Every column uses the same format as the overall copy.
        let column_formats = iter::repeat(encode_format)
            .take(typ.column_types.len())
            .collect();
        self.send(BackendMessage::CopyOutResponse {
            overall_format: encode_format,
            column_formats,
        })
        .await?;

        // Bytes of the next CopyData message; drained with `mem::take` on every send.
        let mut out = Vec::new();

        if let CopyFormat::Binary = format {
            // 11-byte PGCOPY signature.
            out.extend(b"PGCOPY\n\xFF\r\n\0");
            // 32-bit flags field, all flags clear.
            out.extend([0, 0, 0, 0]);
            // 32-bit header extension length: no extension data.
            out.extend([0, 0, 0, 0]);
        }

        // Number of rows reported by the stream, and the summed `byte_len` of the rows
        // sent (measured on the rows before encoding).
        let mut count = 0;
        let mut total_sent_bytes = 0;
        loop {
            // Wait for whichever comes first: the connection closing, the next batch from
            // the stream, or a session notice to forward.
            tokio::select! {
                e = self.conn.wait_closed() => return Err(e),
                batch = stream.recv() => match batch {
                    // Stream exhausted: finish the copy below.
                    None => break,
                    Some(PeekResponseUnary::Error(text)) => {
                        let err =
                            ErrorResponse::error(SqlState::INTERNAL_ERROR, text.clone());
                        return self
                            .send_error_and_get_state(err)
                            .await
                            .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
                    }
                    Some(PeekResponseUnary::DependencyDropped(dep)) => {
                        let err = dep.to_concurrent_dependency_drop();
                        let text = err.to_string();
                        let resp = err.into_response(Severity::Error);
                        return self
                            .send_error_and_get_state(resp)
                            .await
                            .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
                    }
                    Some(PeekResponseUnary::Canceled) => {
                        return self.send_error_and_get_state(ErrorResponse::error(
                            SqlState::QUERY_CANCELED,
                            "canceling statement due to user request",
                        ))
                        .await.map(|state| (state, SendRowsEndedReason::Canceled));
                    }
                    Some(PeekResponseUnary::Rows(mut rows)) => {
                        // `count()` is read before iterating, so it presumably does not
                        // advance `rows` -- confirm against `RowIterator`.
                        count += rows.count();
                        while let Some(row) = rows.next() {
                            total_sent_bytes += row.byte_len();
                            encode_fn(row, typ, &mut out)?;
                            // One CopyData message per row; the first one also carries
                            // the binary header, if any.
                            self.send(BackendMessage::CopyData(mem::take(&mut out)))
                                .await?;
                        }
                    }
                },
                notice = self.adapter_client.session().recv_notice() => {
                    self.send(notice.into_response())
                        .await?;
                    self.conn.flush().await?;
                }
            }

            self.conn.flush().await?;
        }
        if let CopyFormat::Binary = format {
            // 16-bit trailer: a tuple field count of -1 ends binary COPY data.
            let trailer: i16 = -1;
            out.extend(trailer.to_be_bytes());
            self.send(BackendMessage::CopyData(mem::take(&mut out)))
                .await?;
        }

        let tag = format!("COPY {}", count);
        self.send(BackendMessage::CopyDone).await?;
        self.send(BackendMessage::CommandComplete { tag }).await?;
        Ok((
            State::Ready,
            SendRowsEndedReason::Success {
                result_size: u64::cast_from(total_sent_bytes),
                rows_returned: u64::cast_from(count),
            },
        ))
    }
2733
2734 #[instrument(level = "debug")]
2737 async fn copy_from(
2738 &mut self,
2739 target_id: CatalogItemId,
2740 target_name: String,
2741 columns: Vec<ColumnIndex>,
2742 params: CopyFormatParams<'static>,
2743 row_desc: RelationDesc,
2744 mut ctx_extra: ExecuteContextGuard,
2745 ) -> Result<State, io::Error> {
2746 let res = self
2747 .copy_from_inner(
2748 target_id,
2749 target_name,
2750 columns,
2751 params,
2752 row_desc,
2753 &mut ctx_extra,
2754 )
2755 .await;
2756 match &res {
2757 Ok(State::Ready) => {
2758 self.adapter_client.retire_execute(
2759 ctx_extra,
2760 StatementEndedExecutionReason::Success {
2761 result_size: None,
2762 rows_returned: None,
2763 execution_strategy: None,
2764 },
2765 );
2766 }
2767 Ok(State::Done) => {
2768 self.adapter_client
2772 .retire_execute(ctx_extra, StatementEndedExecutionReason::Canceled);
2773 }
2774 Err(e) => {
2775 self.adapter_client.retire_execute(
2776 ctx_extra,
2777 StatementEndedExecutionReason::Errored {
2778 error: format!("{e}"),
2779 },
2780 );
2781 }
2782 Ok(State::Drain) => {}
2783 }
2784 res
2785 }
2786
2787 async fn copy_from_inner(
2788 &mut self,
2789 target_id: CatalogItemId,
2790 target_name: String,
2791 columns: Vec<ColumnIndex>,
2792 params: CopyFormatParams<'static>,
2793 row_desc: RelationDesc,
2794 ctx_extra: &mut ExecuteContextGuard,
2795 ) -> Result<State, io::Error> {
2796 let typ = row_desc.typ();
2797 let column_formats = vec![Format::Text; typ.column_types.len()];
2798 self.send(BackendMessage::CopyInResponse {
2799 overall_format: Format::Text,
2800 column_formats,
2801 })
2802 .await?;
2803 self.conn.flush().await?;
2804
2805 let writer = match self
2807 .adapter_client
2808 .start_copy_from_stdin(
2809 target_id,
2810 target_name.clone(),
2811 columns.clone(),
2812 row_desc.clone(),
2813 params.clone(),
2814 )
2815 .await
2816 {
2817 Ok(writer) => writer,
2818 Err(e) => {
2819 loop {
2825 match self.conn.recv().await? {
2826 Some(FrontendMessage::CopyData(_)) => {}
2827 Some(FrontendMessage::CopyDone) | Some(FrontendMessage::CopyFail(_)) => {
2828 break;
2829 }
2830 Some(FrontendMessage::Flush) | Some(FrontendMessage::Sync) => {}
2831 Some(_) => break,
2832 None => return Ok(State::Done),
2833 }
2834 }
2835 self.adapter_client.retire_execute(
2836 std::mem::take(ctx_extra),
2837 StatementEndedExecutionReason::Errored {
2838 error: e.to_string(),
2839 },
2840 );
2841 return self
2842 .send_error_and_get_state(e.into_response(Severity::Error))
2843 .await;
2844 }
2845 };
2846
2847 self.conn.set_copy_mode(true);
2849
2850 const BATCH_SIZE: usize = 32 * 1024 * 1024;
2852 let max_copy_from_row_size = self
2853 .adapter_client
2854 .get_system_vars()
2855 .await
2856 .max_copy_from_row_size()
2857 .try_into()
2858 .unwrap_or(usize::MAX);
2859
2860 let mut data = Vec::new();
2861 let mut row_scanner = CopyRowScanner::new(¶ms);
2862 let num_workers = writer.batch_txs.len();
2863 let mut next_worker: usize = 0;
2864 let mut saw_copy_done = false;
2865 let mut saw_end_marker = false;
2866 let mut copy_from_error: Option<(SqlState, String)> = None;
2867
2868 loop {
2871 let message = self.conn.recv().await?;
2872 match message {
2873 Some(FrontendMessage::CopyData(buf)) => {
2874 if saw_end_marker {
2875 continue;
2878 }
2879 data.extend(buf);
2880 row_scanner.scan_new_bytes(&data);
2881
2882 if let Some(end_pos) = row_scanner.end_marker_end() {
2883 data.truncate(end_pos);
2884 row_scanner.on_truncate(end_pos);
2885 saw_end_marker = true;
2886 }
2887
2888 if row_scanner.current_row_size(data.len()) > max_copy_from_row_size {
2890 copy_from_error = Some((
2891 SqlState::INSUFFICIENT_RESOURCES,
2892 format!(
2893 "COPY FROM STDIN row exceeded max_copy_from_row_size \
2894 ({max_copy_from_row_size} bytes)"
2895 ),
2896 ));
2897 break;
2898 }
2899
2900 let mut send_failed = false;
2903 while data.len() >= BATCH_SIZE {
2904 let split_pos = match row_scanner.last_row_end() {
2905 Some(pos) => pos,
2906 None => break, };
2908 let remainder = data.split_off(split_pos);
2909 let chunk = std::mem::replace(&mut data, remainder);
2910 row_scanner.on_split(split_pos);
2911 if writer.batch_txs[next_worker].send(chunk).await.is_err() {
2912 send_failed = true;
2913 break;
2914 }
2915 next_worker = (next_worker + 1) % num_workers;
2916 }
2917 if send_failed {
2920 break;
2921 }
2922 }
2923 Some(FrontendMessage::CopyDone) => {
2924 if !data.is_empty() {
2926 let chunk = std::mem::take(&mut data);
2927 let _ = writer.batch_txs[next_worker].send(chunk).await;
2929 }
2930 saw_copy_done = true;
2931 break;
2932 }
2933 Some(FrontendMessage::CopyFail(err)) => {
2934 self.adapter_client.retire_execute(
2935 std::mem::take(ctx_extra),
2936 StatementEndedExecutionReason::Canceled,
2937 );
2938 drop(writer);
2940 self.conn.set_copy_mode(false);
2941 return self
2942 .send_error_and_get_state(ErrorResponse::error(
2943 SqlState::QUERY_CANCELED,
2944 format!("COPY from stdin failed: {}", err),
2945 ))
2946 .await;
2947 }
2948 Some(FrontendMessage::Flush) | Some(FrontendMessage::Sync) => {}
2949 Some(_) => {
2950 let msg = "unexpected message type during COPY from stdin";
2951 self.adapter_client.retire_execute(
2952 std::mem::take(ctx_extra),
2953 StatementEndedExecutionReason::Errored {
2954 error: msg.to_string(),
2955 },
2956 );
2957 drop(writer);
2958 self.conn.set_copy_mode(false);
2959 return self
2960 .send_error_and_get_state(ErrorResponse::error(
2961 SqlState::PROTOCOL_VIOLATION,
2962 msg,
2963 ))
2964 .await;
2965 }
2966 None => {
2967 drop(writer);
2968 self.conn.set_copy_mode(false);
2969 return Ok(State::Done);
2970 }
2971 }
2972 }
2973
2974 if !saw_copy_done {
2978 loop {
2979 match self.conn.recv().await? {
2980 Some(FrontendMessage::CopyData(_)) => {}
2981 Some(FrontendMessage::CopyDone) | Some(FrontendMessage::CopyFail(_)) => {
2982 break;
2983 }
2984 Some(FrontendMessage::Flush) | Some(FrontendMessage::Sync) => {}
2985 Some(_) => {
2986 let msg = "unexpected message type during COPY from stdin";
2987 self.adapter_client.retire_execute(
2988 std::mem::take(ctx_extra),
2989 StatementEndedExecutionReason::Errored {
2990 error: msg.to_string(),
2991 },
2992 );
2993 drop(writer);
2994 self.conn.set_copy_mode(false);
2995 return self
2996 .send_error_and_get_state(ErrorResponse::error(
2997 SqlState::PROTOCOL_VIOLATION,
2998 msg,
2999 ))
3000 .await;
3001 }
3002 None => {
3003 drop(writer);
3004 self.conn.set_copy_mode(false);
3005 return Ok(State::Done);
3006 }
3007 }
3008 }
3009 }
3010
3011 if let Some((code, msg)) = copy_from_error {
3012 self.adapter_client.retire_execute(
3013 std::mem::take(ctx_extra),
3014 StatementEndedExecutionReason::Errored { error: msg.clone() },
3015 );
3016 drop(writer);
3017 self.conn.set_copy_mode(false);
3018 return self
3019 .send_error_and_get_state(ErrorResponse::error(code, msg))
3020 .await;
3021 }
3022
3023 self.conn.set_copy_mode(false);
3024
3025 drop(writer.batch_txs);
3030
3031 let (proto_batches, row_count) = match writer.completion_rx.await {
3033 Ok(Ok(result)) => result,
3034 Ok(Err(e)) => {
3035 self.adapter_client.retire_execute(
3036 std::mem::take(ctx_extra),
3037 StatementEndedExecutionReason::Errored {
3038 error: e.to_string(),
3039 },
3040 );
3041 return self
3042 .send_error_and_get_state(e.into_response(Severity::Error))
3043 .await;
3044 }
3045 Err(_) => {
3046 let msg = "COPY FROM STDIN: background batch builder tasks dropped";
3047 self.adapter_client.retire_execute(
3048 std::mem::take(ctx_extra),
3049 StatementEndedExecutionReason::Errored {
3050 error: msg.to_string(),
3051 },
3052 );
3053 return self
3054 .send_error_and_get_state(ErrorResponse::error(SqlState::INTERNAL_ERROR, msg))
3055 .await;
3056 }
3057 };
3058
3059 if let Err(e) = self
3061 .adapter_client
3062 .stage_copy_from_stdin_batches(target_id, proto_batches)
3063 {
3064 self.adapter_client.retire_execute(
3065 std::mem::take(ctx_extra),
3066 StatementEndedExecutionReason::Errored {
3067 error: e.to_string(),
3068 },
3069 );
3070 return self
3071 .send_error_and_get_state(e.into_response(Severity::Error))
3072 .await;
3073 }
3074
3075 let tag = format!("COPY {}", row_count);
3076 self.send(BackendMessage::CommandComplete { tag }).await?;
3077
3078 Ok(State::Ready)
3079 }
3080
3081 #[instrument(level = "debug")]
3082 async fn send_pending_notices(&mut self) -> Result<(), io::Error> {
3083 let notices = self
3084 .adapter_client
3085 .session()
3086 .drain_notices()
3087 .into_iter()
3088 .map(|notice| BackendMessage::ErrorResponse(notice.into_response()));
3089 self.send_all(notices).await?;
3090 Ok(())
3091 }
3092
3093 #[instrument(level = "debug")]
3094 async fn send_error_and_get_state(&mut self, err: ErrorResponse) -> Result<State, io::Error> {
3095 assert!(err.severity.is_error());
3096 debug!(
3097 "cid={} error code={}",
3098 self.adapter_client.session().conn_id(),
3099 err.code.code()
3100 );
3101 let is_fatal = err.severity.is_fatal();
3102 self.send(BackendMessage::ErrorResponse(err)).await?;
3103
3104 let txn = self.adapter_client.session().transaction();
3105 match txn {
3106 TransactionStatus::Default | TransactionStatus::Failed(_) => {}
3109 TransactionStatus::Started(_) => {
3111 self.rollback_transaction().await?;
3112 }
3113 TransactionStatus::InTransactionImplicit(_) => {
3115 self.rollback_transaction().await?;
3116 }
3117 TransactionStatus::InTransaction(_) => {
3119 self.adapter_client.fail_transaction();
3120 }
3121 };
3122 if is_fatal {
3123 Ok(State::Done)
3124 } else {
3125 Ok(State::Drain)
3126 }
3127 }
3128
3129 #[instrument(level = "debug")]
3130 async fn aborted_txn_error(&mut self) -> Result<State, io::Error> {
3131 self.send(BackendMessage::ErrorResponse(ErrorResponse::error(
3132 SqlState::IN_FAILED_SQL_TRANSACTION,
3133 ABORTED_TXN_MSG,
3134 )))
3135 .await?;
3136 Ok(State::Drain)
3137 }
3138
3139 fn is_aborted_txn(&mut self) -> bool {
3140 matches!(
3141 self.adapter_client.session().transaction(),
3142 TransactionStatus::Failed(_)
3143 )
3144 }
3145}
3146
3147fn pad_formats(formats: Vec<Format>, n: usize) -> Result<Vec<Format>, String> {
3148 match (formats.len(), n) {
3149 (0, e) => Ok(vec![Format::Text; e]),
3150 (1, e) => Ok(iter::repeat(formats[0]).take(e).collect()),
3151 (a, e) if a == e => Ok(formats),
3152 (a, e) => Err(format!(
3153 "expected {} field format specifiers, but got {}",
3154 e, a
3155 )),
3156 }
3157}
3158
3159fn describe_rows(stmt_desc: &StatementDesc, formats: &[Format]) -> BackendMessage {
3160 match &stmt_desc.relation_desc {
3161 Some(desc) if !stmt_desc.is_copy => {
3162 BackendMessage::RowDescription(message::encode_row_description(desc, formats))
3163 }
3164 _ => BackendMessage::NoData,
3165 }
3166}
3167
/// Builds the message sent once rows have been written for a portal.
///
/// - `max_rows` is the row limit that applied.
/// - `total_sent_rows` is the number of rows sent.
/// - `fetch_portal` is an optional portal whose state the callee may update.
///
/// `portal_exec_message` and `fetch_message` both have this signature.
type GetResponse = fn(
    max_rows: ExecuteCount,
    total_sent_rows: usize,
    fetch_portal: Option<PortalRefMut>,
) -> BackendMessage;
3173
3174fn portal_exec_message(
3177 max_rows: ExecuteCount,
3178 total_sent_rows: usize,
3179 _fetch_portal: Option<PortalRefMut>,
3180) -> BackendMessage {
3181 match max_rows {
3188 ExecuteCount::Count(max_rows) if max_rows <= total_sent_rows => {
3189 BackendMessage::PortalSuspended
3190 }
3191 _ => BackendMessage::CommandComplete {
3192 tag: format!("SELECT {}", total_sent_rows),
3193 },
3194 }
3195}
3196
3197fn fetch_message(
3199 _max_rows: ExecuteCount,
3200 total_sent_rows: usize,
3201 fetch_portal: Option<PortalRefMut>,
3202) -> BackendMessage {
3203 let tag = format!("FETCH {}", total_sent_rows);
3204 if let Some(portal) = fetch_portal {
3205 *portal.state = PortalState::Completed(Some(tag.clone()));
3206 }
3207 BackendMessage::CommandComplete { tag }
3208}
3209
3210fn get_authenticator(
3211 authenticator_kind: listeners::AuthenticatorKind,
3212 frontegg: Option<FronteggAuthenticator>,
3213 oidc: GenericOidcAuthenticator,
3214 adapter_client: mz_adapter::Client,
3215) -> Authenticator {
3216 match authenticator_kind {
3217 listeners::AuthenticatorKind::Frontegg => Authenticator::Frontegg(frontegg.expect(
3218 "Frontegg authenticator should exist with listeners::AuthenticatorKind::Frontegg",
3219 )),
3220 listeners::AuthenticatorKind::Password => Authenticator::Password(adapter_client),
3221 listeners::AuthenticatorKind::Sasl => Authenticator::Sasl(adapter_client),
3222 listeners::AuthenticatorKind::Oidc => Authenticator::Oidc(oidc),
3223 listeners::AuthenticatorKind::None => Authenticator::None,
3224 }
3225}
3226
/// Row limit applied when sending a portal's results.
#[derive(Clone, Copy, Debug)]
enum ExecuteCount {
    /// No limit: every remaining row may be sent.
    All,
    /// At most this many rows may be sent.
    Count(usize),
}
3232
3233fn is_txn_exit_stmt(stmt: Option<&Statement<Raw>>) -> bool {
3235 match stmt {
3236 Some(stmt) => matches!(stmt, Statement::Commit(_) | Statement::Rollback(_)),
3238 None => false,
3239 }
3240}
3241
/// One item produced while fetching results: rows, a cancellation, an error, or a notice.
#[derive(Debug)]
enum FetchResult {
    /// A batch of rows. `None` presumably means there are no rows to return -- confirm at
    /// the use sites.
    Rows(Option<Box<dyn RowIterator + Send + Sync>>),
    /// The fetch was canceled.
    Canceled,
    /// The fetch failed with this message.
    Error(String),
    /// A notice from the adapter (presumably forwarded to the client).
    Notice(AdapterNotice),
}
3249
/// Incrementally locates row boundaries in `COPY FROM STDIN` input and detects the `\.`
/// end-of-data marker.
///
/// The row boundaries let buffered bytes be cut into chunks without splitting a row. All
/// offsets index into the caller's buffer. The caller reports when it removes a prefix
/// (`on_split`) or a suffix (`on_truncate`) so the offsets can be rebased.
#[derive(Debug)]
struct CopyRowScanner {
    /// Length of the buffer prefix that has already been scanned.
    scan_pos: usize,
    /// Offset just past the last complete row found, if any.
    last_row_end: Option<usize>,
    /// Offset just past the end-of-data marker row, once found.
    end_marker_end: Option<usize>,
    /// Start offset of the CSV record currently being read. Only advanced while
    /// scanning CSV input.
    record_start: usize,
    /// CSV parsing state. `None` for non-CSV formats, which are split on `\n`.
    csv: Option<CsvScanState>,
}
3262
/// CSV parser state that `CopyRowScanner` uses to find record boundaries in CSV input.
#[derive(Debug)]
struct CsvScanState {
    /// Incremental reader configured with the COPY CSV delimiter, quote, and escape.
    reader: csv_core::Reader,
    /// Scratch buffer for decoded field bytes; its contents are never read.
    output: Vec<u8>,
    /// Scratch buffer for field end offsets; its contents are never read.
    ends: Vec<usize>,
    /// Whether the next record is a HEADER row that must not be treated as `\.`.
    skip_first_record: bool,
}
3270
impl CopyRowScanner {
    /// Creates a scanner for `params`.
    ///
    /// CSV input gets a `csv_core` reader configured with the same delimiter, quote,
    /// escape, and header settings. Every other format (including
    /// `CopyFormatParams::Binary`) is scanned for `\n` bytes.
    fn new(params: &CopyFormatParams<'_>) -> Self {
        let csv = match params {
            CopyFormatParams::Csv(CopyCsvFormatParams {
                delimiter,
                quote,
                escape,
                header,
                ..
            }) => Some(CsvScanState::new(*delimiter, *quote, *escape, *header)),
            _ => None,
        };

        CopyRowScanner {
            scan_pos: 0,
            last_row_end: None,
            end_marker_end: None,
            record_start: 0,
            csv,
        }
    }

    /// Scans the part of `data` not yet seen (offsets `scan_pos..`), updating
    /// `last_row_end` and, once found, `end_marker_end`.
    ///
    /// `data` is expected to be the same buffer as on previous calls, possibly with bytes
    /// appended; any prefix or suffix removal must have been reported through
    /// `on_split` / `on_truncate`.
    fn scan_new_bytes(&mut self, data: &[u8]) {
        // Nothing new since the last call.
        if self.scan_pos >= data.len() {
            return;
        }

        if let Some(csv) = self.csv.as_mut() {
            // Feed only the unscanned suffix; the reader keeps its parse state (e.g.
            // being inside a quoted field) across calls.
            let mut input = &data[self.scan_pos..];
            // Bytes of `input` consumed so far; `scan_pos + consumed` indexes `data`.
            let mut consumed = 0usize;
            while !input.is_empty() {
                // Decoded output and field ends are discarded: only record boundaries
                // matter here.
                let (result, n_input, _n_output, _n_ends) =
                    csv.reader
                        .read_record(input, &mut csv.output, &mut csv.ends);
                consumed += n_input;
                input = &input[n_input..];

                match result {
                    ReadRecordResult::InputEmpty => break,
                    // A scratch buffer filled up. Keep going with the rest of the input,
                    // and only grow the buffer if the reader could not consume anything.
                    ReadRecordResult::OutputFull => {
                        if n_input == 0 {
                            csv.output
                                .resize(csv.output.len().saturating_mul(2).max(1), 0);
                        }
                    }
                    ReadRecordResult::OutputEndsFull => {
                        if n_input == 0 {
                            csv.ends.resize(csv.ends.len().saturating_mul(2).max(1), 0);
                        }
                    }
                    ReadRecordResult::Record | ReadRecordResult::End => {
                        // Offset in `data` just past the bytes consumed for this record.
                        let row_end = self.scan_pos + consumed;
                        self.last_row_end = Some(row_end);
                        if self.end_marker_end.is_none() {
                            // The HEADER record is never considered as the marker.
                            let is_marker = if csv.skip_first_record {
                                csv.skip_first_record = false;
                                false
                            } else {
                                // Compare the raw (undecoded) record bytes, so a quoted
                                // `"\."` field does not count as the marker. Leading and
                                // trailing CR/LF bytes are ignored: the terminator
                                // consumed with this record, and line-ending bytes left
                                // over from the previous one (e.g. the `\n` of `\r\n`).
                                let raw = &data[self.record_start..row_end];
                                let start = raw
                                    .iter()
                                    .take_while(|&&b| b == b'\r' || b == b'\n')
                                    .count();
                                let trailing = raw[start..]
                                    .iter()
                                    .rev()
                                    .take_while(|&&b| b == b'\r' || b == b'\n')
                                    .count();
                                let trimmed = &raw[start..raw.len() - trailing];
                                trimmed == b"\\."
                            };
                            if is_marker {
                                // Stop scanning at the marker; later input is ignored.
                                self.end_marker_end = Some(row_end);
                                self.record_start = row_end;
                                break;
                            }
                        }
                        self.record_start = row_end;
                    }
                }
            }
        } else {
            // Non-CSV input: every `\n` ends a row, and a row whose first two bytes are
            // `\.` is the end-of-data marker.
            // NOTE(review): `CopyFormatParams::Binary` also lands here even though binary
            // tuples are not newline-delimited; confirm binary input never reaches this.
            // Rows before `last_row_end` were already scanned, so the current row starts
            // there (or at 0 if no row has completed yet).
            let mut row_start = self.last_row_end.unwrap_or(0);
            for (offset, b) in data[self.scan_pos..].iter().enumerate() {
                if *b == b'\n' {
                    let row_end = self.scan_pos + offset + 1;
                    self.last_row_end = Some(row_end);
                    if self.end_marker_end.is_none() {
                        let row = &data[row_start..row_end];
                        if row.get(0..2) == Some(b"\\.") {
                            self.end_marker_end = Some(row_end);
                            break;
                        }
                    }
                    row_start = row_end;
                }
            }
        }

        self.scan_pos = data.len();
    }

    /// Offset just past the last complete row seen, if any.
    fn last_row_end(&self) -> Option<usize> {
        self.last_row_end
    }

    /// Offset just past the end-of-data marker row, if it has been seen.
    fn end_marker_end(&self) -> Option<usize> {
        self.end_marker_end
    }

    /// Bytes buffered after the last complete row: the size so far of the row still
    /// being received.
    fn current_row_size(&self, data_len: usize) -> usize {
        data_len.saturating_sub(self.last_row_end.unwrap_or(0))
    }

    /// Rebases offsets after the caller removed the first `split_pos` bytes from its
    /// buffer.
    ///
    /// The caller in `copy_from_inner` splits at `last_row_end`, so no complete row
    /// remains afterwards and `last_row_end` is cleared.
    fn on_split(&mut self, split_pos: usize) {
        self.scan_pos = self.scan_pos.saturating_sub(split_pos);
        self.last_row_end = None;
        self.end_marker_end = self
            .end_marker_end
            .and_then(|end| end.checked_sub(split_pos));
        // For CSV, a split inside the record being read would lose its start.
        soft_assert_or_log!(
            self.csv.is_none() || self.record_start >= split_pos,
            "split bisected an in-progress CSV record: record_start={} < split_pos={}",
            self.record_start,
            split_pos,
        );
        self.record_start = self.record_start.saturating_sub(split_pos);
    }

    /// Clamps offsets after the caller truncated its buffer to `new_len` bytes.
    fn on_truncate(&mut self, new_len: usize) {
        self.scan_pos = self.scan_pos.min(new_len);
        self.last_row_end = self.last_row_end.filter(|&end| end <= new_len);
        self.end_marker_end = self.end_marker_end.filter(|&end| end <= new_len);
        self.record_start = self.record_start.min(new_len);
    }
}
3429
3430impl CsvScanState {
3431 fn new(delimiter: u8, quote: u8, escape: u8, header: bool) -> Self {
3432 let (double_quote, escape) = if quote == escape {
3433 (true, None)
3434 } else {
3435 (false, Some(escape))
3436 };
3437 CsvScanState {
3438 reader: csv_core::ReaderBuilder::new()
3439 .delimiter(delimiter)
3440 .quote(quote)
3441 .double_quote(double_quote)
3442 .escape(escape)
3443 .build(),
3444 output: vec![0; 1],
3445 ends: vec![0; 1],
3446 skip_first_record: header,
3447 }
3448 }
3449}
3450
#[cfg(test)]
mod test {
    use super::*;

    /// The CSV scanner must find a bare `\.` end-of-data marker for every supported line
    /// ending, and must not mistake a quoted `"\."` field for the marker.
    #[mz_ore::test]
    fn test_copy_row_scanner_end_marker_line_endings() {
        let params = CopyFormatParams::Csv(CopyCsvFormatParams::default());

        // Scans `data` in a single call and returns where the marker row ends, if found.
        let marker_end = |data: &[u8]| -> Option<usize> {
            let mut scanner = CopyRowScanner::new(&params);
            scanner.scan_new_bytes(data);
            scanner.end_marker_end()
        };

        for eol in [&b"\n"[..], b"\r\n", b"\r"] {
            // Terminates every line with `eol`.
            let join = |lines: &[&str]| -> Vec<u8> {
                let mut out = Vec::new();
                for line in lines {
                    out.extend_from_slice(line.as_bytes());
                    out.extend_from_slice(eol);
                }
                out
            };

            let data = join(&["first", "\\.", "after"]);
            // The marker row ends one byte past `\.`: the record ends at the first byte
            // of the terminator, so with `\r\n` the `\n` is left for the next record.
            let mut prefix = Vec::new();
            prefix.extend_from_slice(b"first");
            prefix.extend_from_slice(eol);
            prefix.extend_from_slice(b"\\.");
            assert_eq!(
                marker_end(&data),
                Some(prefix.len() + 1),
                "bare marker, eol={eol:?}"
            );

            // A quoted `\.` is ordinary field data.
            let data = join(&["before", "\"\\.\"", "after"]);
            assert_eq!(marker_end(&data), None, "quoted marker, eol={eol:?}");
        }
    }

    /// Splitting non-CSV input at a row boundary must leave `record_start` at 0.
    #[mz_ore::test]
    fn test_copy_row_scanner_non_csv_split() {
        for params in [
            CopyFormatParams::Text(CopyTextFormatParams::default()),
            CopyFormatParams::Binary,
        ] {
            let mut scanner = CopyRowScanner::new(&params);
            let data = b"1\thello world\t2\tsome text value here\n\
                   3\thello world\t6\tsome text value here\n";
            scanner.scan_new_bytes(data);
            let split_pos = scanner.last_row_end().expect("a complete row");
            assert!(split_pos > 0, "params={params:?}");
            scanner.on_split(split_pos);
            assert_eq!(scanner.record_start, 0, "params={params:?}");
        }
    }

    #[mz_ore::test]
    fn test_parse_options() {
        struct TestCase {
            input: &'static str,
            expect: Result<Vec<(&'static str, &'static str)>, ()>,
        }
        let tests = vec![
            TestCase {
                input: "",
                expect: Ok(vec![]),
            },
            TestCase {
                input: "--key",
                expect: Err(()),
            },
            TestCase {
                input: "--key=val",
                expect: Ok(vec![("key", "val")]),
            },
            TestCase {
                input: r#"--key=val -ckey2=val2 -c key3=val3 -c key4=val4 -ckey5=val5"#,
                expect: Ok(vec![
                    ("key", "val"),
                    ("key2", "val2"),
                    ("key3", "val3"),
                    ("key4", "val4"),
                    ("key5", "val5"),
                ]),
            },
            TestCase {
                input: r#"-c\ key=val"#,
                expect: Ok(vec![(" key", "val")]),
            },
            TestCase {
                input: "--key=val -ckey2 val2",
                expect: Err(()),
            },
            TestCase {
                // An empty value is accepted.
                input: "--key=",
                expect: Ok(vec![("key", "")]),
            },
        ];
        for test in tests {
            let got = parse_options(test.input);
            let expect = test.expect.map(|r| {
                r.into_iter()
                    .map(|(k, v)| (k.to_owned(), v.to_owned()))
                    .collect()
            });
            assert_eq!(got, expect, "input: {}", test.input);
        }
    }

    #[mz_ore::test]
    fn test_parse_option() {
        struct TestCase {
            input: &'static str,
            expect: Result<(&'static str, &'static str), ()>,
        }
        let tests = vec![
            TestCase {
                input: "",
                expect: Err(()),
            },
            TestCase {
                input: "--",
                expect: Err(()),
            },
            TestCase {
                input: "--c",
                expect: Err(()),
            },
            TestCase {
                input: "a=b",
                expect: Err(()),
            },
            TestCase {
                input: "--a=b",
                expect: Ok(("a", "b")),
            },
            TestCase {
                input: "--ca=b",
                expect: Ok(("ca", "b")),
            },
            TestCase {
                input: "-ca=b",
                expect: Ok(("a", "b")),
            },
            TestCase {
                // An empty key and value are accepted.
                input: "--=",
                expect: Ok(("", "")),
            },
        ];
        for test in tests {
            let got = parse_option(test.input);
            assert_eq!(got, test.expect, "input: {}", test.input);
        }
    }

    /// Whitespace separates options; a backslash escapes the following character
    /// (including a space or another backslash), and a trailing lone backslash is dropped.
    #[mz_ore::test]
    fn test_split_options() {
        struct TestCase {
            input: &'static str,
            expect: Vec<&'static str>,
        }
        let tests = vec![
            TestCase {
                input: "",
                expect: vec![],
            },
            TestCase {
                input: " ",
                expect: vec![],
            },
            TestCase {
                input: " a ",
                expect: vec!["a"],
            },
            TestCase {
                input: " ab cd ",
                expect: vec!["ab", "cd"],
            },
            TestCase {
                // Escaped space, then an unescaped separator.
                input: r#" ab\  cd "#,
                expect: vec!["ab ", "cd"],
            },
            TestCase {
                input: r#" ab\\ cd "#,
                expect: vec![r#"ab\"#, "cd"],
            },
            TestCase {
                // Escaped backslash, escaped space, then an unescaped separator.
                input: r#" ab\\\  cd "#,
                expect: vec![r#"ab\ "#, "cd"],
            },
            TestCase {
                // Escaped backslash and escaped space, no separator.
                input: r#" ab\\\ cd "#,
                expect: vec![r#"ab\ cd"#],
            },
            TestCase {
                input: r#" ab\\\cd "#,
                expect: vec![r#"ab\cd"#],
            },
            TestCase {
                input: r#"a\"#,
                expect: vec!["a"],
            },
            TestCase {
                input: r#"a\ "#,
                expect: vec!["a "],
            },
            TestCase {
                input: r#"\"#,
                expect: vec![],
            },
            TestCase {
                input: r#"\ "#,
                expect: vec![r#" "#],
            },
            TestCase {
                input: r#" \ "#,
                expect: vec![r#" "#],
            },
            TestCase {
                // Escaped space followed by a separator.
                input: r#"\  "#,
                expect: vec![r#" "#],
            },
        ];
        for test in tests {
            let got = split_options(test.input);
            assert_eq!(got, test.expect, "input: {}", test.input);
        }
    }

    #[mz_ore::test]
    fn test_is_jwt() {
        assert!(is_jwt("eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxIn0.signature"));
        // Strings that must not be mistaken for a JWT.
        for s in [
            "",
            "secure_password",
            "p4ss.w0rd",
            "aaa.bbb.ccc",
            "eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxIn0",
            "eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxIn0.sig.extra",
        ] {
            assert!(!is_jwt(s), "is_jwt({s:?})");
        }
    }
}