1use std::collections::BTreeMap;
11use std::convert::TryFrom;
12use std::future::Future;
13use std::ops::Deref;
14use std::sync::Arc;
15use std::time::{Duration, Instant};
16use std::{iter, mem};
17
18use base64::prelude::*;
19use byteorder::{ByteOrder, NetworkEndian};
20use csv_core::ReadRecordResult;
21use futures::future::{BoxFuture, FutureExt, pending};
22use itertools::Itertools;
23use mz_adapter::client::RecordFirstRowStream;
24use mz_adapter::session::{
25 EndTransactionAction, InProgressRows, LifecycleTimestamps, PortalRefMut, PortalState, Session,
26 SessionConfig, TransactionStatus,
27};
28use mz_adapter::statement_logging::{StatementEndedExecutionReason, StatementExecutionStrategy};
29use mz_adapter::{
30 AdapterError, AdapterNotice, ExecuteContextGuard, ExecuteResponse, PeekResponseUnary, metrics,
31 verify_datum_desc,
32};
33use mz_adapter_types::dyncfgs::OIDC_GROUP_CLAIM;
34use mz_auth::Authenticated;
35use mz_auth::password::Password;
36use mz_authenticator::{Authenticator, GenericOidcAuthenticator};
37use mz_frontegg_auth::Authenticator as FronteggAuthenticator;
38use mz_ore::cast::CastFrom;
39use mz_ore::netio::AsyncReady;
40use mz_ore::now::{EpochMillis, SYSTEM_TIME};
41use mz_ore::str::StrExt;
42use mz_ore::{assert_none, assert_ok, instrument, soft_assert_eq_or_log, soft_assert_or_log};
43use mz_pgcopy::{CopyCsvFormatParams, CopyFormatParams, CopyTextFormatParams};
44use mz_pgwire_common::{
45 ConnectionCounter, Cursor, ErrorResponse, Format, FrontendMessage, Severity, VERSION_3,
46 VERSIONS,
47};
48use mz_repr::{
49 CatalogItemId, ColumnIndex, Datum, RelationDesc, RowArena, RowIterator, RowRef,
50 SqlRelationType, SqlScalarType,
51};
52use mz_server_core::TlsMode;
53use mz_server_core::listeners;
54use mz_server_core::listeners::AllowedRoles;
55use mz_sql::ast::display::AstDisplay;
56use mz_sql::ast::{
57 CopyDirection, CopyStatement, CopyTarget, FetchDirection, Ident, Raw, Statement,
58};
59use mz_sql::parse::StatementParseResult;
60use mz_sql::plan::{CopyFormat, ExecuteTimeout, StatementDesc};
61use mz_sql::session::metadata::SessionMetadata;
62use mz_sql::session::user::INTERNAL_USER_NAMES;
63use mz_sql::session::vars::VarInput;
64use postgres::error::SqlState;
65use tokio::io::{self, AsyncRead, AsyncWrite};
66use tokio::select;
67use tokio::time::{self};
68use tokio_metrics::TaskMetrics;
69use tokio_stream::wrappers::UnboundedReceiverStream;
70use tracing::{Instrument, debug, debug_span, warn};
71use uuid::Uuid;
72
73use crate::codec::{
74 FramedConn, decode_password, decode_sasl_initial_response, decode_sasl_response,
75};
76use crate::message::{
77 self, BackendMessage, SASLServerFinalMessage, SASLServerFinalMessageKinds,
78 SASLServerFirstMessage,
79};
80
81pub fn match_handshake(buf: &[u8]) -> bool {
85 if buf.len() < 8 {
95 return false;
96 }
97 let version = NetworkEndian::read_i32(&buf[4..8]);
98 VERSIONS.contains(&version)
99}
100
/// Everything [`run`] needs to serve one pgwire connection.
pub struct RunParams<'a, A, I>
where
    I: Iterator<Item = TaskMetrics> + Send,
{
    /// TLS requirement for this listener. Before authentication, the
    /// connection is checked against it with `ensure_tls_compatibility`.
    pub tls_mode: Option<TlsMode>,
    /// Adapter client, used to authenticate users and to create and start sessions.
    pub adapter_client: mz_adapter::Client,
    /// The client connection.
    pub conn: &'a mut FramedConn<A>,
    /// UUID of this connection; recorded in the session config.
    pub conn_uuid: Uuid,
    /// Protocol version requested by the client; only `VERSION_3` is accepted.
    pub version: i32,
    /// Startup parameters sent by the client.
    ///
    /// `user` and `options` are interpreted specially; every other entry is
    /// applied as a session variable.
    pub params: BTreeMap<String, String>,
    /// Frontegg authenticator, if one is configured.
    pub frontegg: Option<FronteggAuthenticator>,
    /// OIDC authenticator.
    pub oidc: GenericOidcAuthenticator,
    /// Selects which authenticator is used (passed to `get_authenticator`).
    pub authenticator_kind: listeners::AuthenticatorKind,
    /// Tracks active connections. If no connection can be allocated for the
    /// session's user, the connection is rejected.
    pub active_connection_counter: ConnectionCounter,
    /// Helm chart version, if any; recorded in the session config.
    pub helm_chart_version: Option<String>,
    /// Which classes of roles (normal and/or internal) may log in here.
    pub allowed_roles: AllowedRoles,
    /// Source of tokio task-metrics intervals. It must never run out, because
    /// the state machine calls `.expect("infinite iterator")` on it.
    pub tokio_metrics_intervals: I,
}
134
#[mz_ore::instrument(level = "debug")]
/// Runs the pgwire protocol for one client connection.
///
/// Given the client's requested protocol `version` and startup `params`, this:
///
/// 1. rejects protocol versions other than [`VERSION_3`];
/// 2. rejects users that `allowed_roles` does not admit on this listener;
/// 3. checks the connection against `tls_mode`;
/// 4. authenticates the user with the authenticator selected by
///    `authenticator_kind`, creating a session;
/// 5. applies the startup parameters (including the settings in `options`)
///    to the session, turning invalid ones into `BadStartupSetting` notices;
/// 6. allocates a connection slot and starts the session in the adapter;
/// 7. sends `AuthenticationOk`, `ParameterStatus` messages, `BackendKeyData`,
///    any queued notices, and `ReadyForQuery`;
/// 8. runs the [`StateMachine`], racing it against expiry of the
///    authentication session.
///
/// Rejections in steps 1-4 and 6 are sent to the client as a fatal
/// `ErrorResponse`, and the function returns the result of sending it.
///
/// # Errors
///
/// Returns I/O errors from the connection, including an error that ends the
/// state machine.
pub async fn run<'a, A, I>(
    RunParams {
        tls_mode,
        adapter_client,
        conn,
        conn_uuid,
        version,
        mut params,
        frontegg,
        oidc,
        authenticator_kind,
        active_connection_counter,
        helm_chart_version,
        allowed_roles,
        tokio_metrics_intervals,
    }: RunParams<'a, A, I>,
) -> Result<(), io::Error>
where
    A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
    I: Iterator<Item = TaskMetrics> + Send,
{
    if version != VERSION_3 {
        return conn
            .send(ErrorResponse::fatal(
                SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
                "server does not support the client's requested protocol version",
            ))
            .await;
    }

    let user = params.remove("user").unwrap_or_else(String::new);
    // `options` is parsed here but stays in `params`. It is applied (or
    // reported as unparseable) with the other startup parameters below.
    let options = parse_options(params.get("options").unwrap_or(&String::new()));
    let authenticator =
        get_authenticator(authenticator_kind, frontegg, oidc, adapter_client.clone());
    let is_internal_user = INTERNAL_USER_NAMES.contains(&user);
    let is_reserved_user = mz_adapter::catalog::is_reserved_role_name(user.as_str());
    // Decide whether this listener admits the requested user.
    let role_allowed = match allowed_roles {
        AllowedRoles::Normal => !is_reserved_user,
        AllowedRoles::Internal => is_internal_user,
        AllowedRoles::NormalAndInternal => !is_reserved_user || is_internal_user,
    };
    if !role_allowed {
        let msg = format!("unauthorized login to user '{user}'");
        return conn
            .send(ErrorResponse::fatal(SqlState::INSUFFICIENT_PRIVILEGE, msg))
            .await;
    }

    if let Err(err) = conn.inner().ensure_tls_compatibility(&tls_mode) {
        return conn.send(err).await;
    }

    // Shadows the `authenticator_kind` parameter with the kind reported by the
    // selected authenticator.
    let authenticator_kind = authenticator.kind();

    // Authenticate and build the session. `expired` resolves when the
    // authentication session expires. Only the Frontegg path can expire; every
    // other path uses `pending()`, which never resolves.
    // `left_future`/`right_future` give both kinds of future a common type.
    let (mut session, expired) = match authenticator {
        Authenticator::Frontegg(frontegg) => {
            let password = match request_cleartext_password(conn).await {
                Ok(password) => password,
                Err(PasswordRequestError::IoError(e)) => return Err(e),
                Err(PasswordRequestError::InvalidPasswordError(e)) => {
                    return conn.send(e).await;
                }
            };

            // The name of the group claim is read from the current system dyncfgs.
            let group_claim =
                OIDC_GROUP_CLAIM.get(adapter_client.get_system_vars().await.dyncfgs());
            let auth_response = frontegg
                .authenticate(&user, &password, Some(&group_claim))
                .await;
            match auth_response {
                Ok((mut auth_session, authenticated)) => {
                    // The session's user name is taken from the auth session,
                    // not from the client-supplied `user` parameter.
                    let groups = auth_session.groups();
                    let session = adapter_client.new_session(
                        SessionConfig {
                            conn_id: conn.conn_id().clone(),
                            uuid: conn_uuid,
                            user: auth_session.user().into(),
                            client_ip: conn.peer_addr().clone(),
                            external_metadata_rx: Some(auth_session.external_metadata_rx()),
                            helm_chart_version,
                            authenticator_kind,
                            groups,
                        },
                        authenticated,
                    );
                    let expired = async move { auth_session.expired().await };
                    (session, expired.left_future())
                }
                Err(err) => {
                    warn!(?err, "pgwire connection failed authentication");
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_PASSWORD,
                            "invalid password",
                        ))
                        .await;
                }
            }
        }
        Authenticator::Oidc(oidc) => {
            // The password message may carry either a JWT or a regular password.
            let password = match request_cleartext_password(conn).await {
                Ok(password) => password,
                Err(PasswordRequestError::IoError(e)) => return Err(e),
                Err(PasswordRequestError::InvalidPasswordError(e)) => {
                    return conn.send(e).await;
                }
            };
            if is_jwt(&password) {
                let auth_response = oidc.authenticate(&password, Some(&user)).await;
                match auth_response {
                    Ok((mut claims, authenticated)) => {
                        let groups = claims.groups.take();
                        let session = adapter_client.new_session(
                            SessionConfig {
                                conn_id: conn.conn_id().clone(),
                                uuid: conn_uuid,
                                user: std::mem::take(&mut claims.user),
                                client_ip: conn.peer_addr().clone(),
                                external_metadata_rx: None,
                                helm_chart_version,
                                authenticator_kind,
                                groups,
                            },
                            authenticated,
                        );
                        (session, pending().right_future())
                    }
                    Err(err) => {
                        warn!(?err, "pgwire connection failed authentication");
                        return conn.send(err.into_response()).await;
                    }
                }
            } else {
                // Not a JWT: fall back to regular password authentication.
                let session = match authenticate_with_password(
                    conn,
                    &adapter_client,
                    user,
                    Password(password),
                    conn_uuid,
                    helm_chart_version,
                )
                .await
                {
                    Ok(session) => session,
                    Err(PasswordRequestError::IoError(e)) => return Err(e),
                    Err(PasswordRequestError::InvalidPasswordError(e)) => {
                        return conn.send(e).await;
                    }
                };
                (session, pending().right_future())
            }
        }
        Authenticator::Password(adapter_client) => {
            let password = match request_cleartext_password(conn).await {
                Ok(password) => password,
                Err(PasswordRequestError::IoError(e)) => return Err(e),
                Err(PasswordRequestError::InvalidPasswordError(e)) => {
                    return conn.send(e).await;
                }
            };
            let session = match authenticate_with_password(
                conn,
                &adapter_client,
                user,
                Password(password),
                conn_uuid,
                helm_chart_version,
            )
            .await
            {
                Ok(session) => session,
                Err(PasswordRequestError::IoError(e)) => return Err(e),
                Err(PasswordRequestError::InvalidPasswordError(e)) => {
                    return conn.send(e).await;
                }
            };
            (session, pending().right_future())
        }
        Authenticator::Sasl(adapter_client) => {
            // Tell the client SASL is required, then read its SASLInitialResponse.
            conn.send(BackendMessage::AuthenticationSASL).await?;
            conn.flush().await?;
            let (mechanism, initial_response) = match conn.recv().await? {
                Some(FrontendMessage::RawAuthentication(data)) => {
                    match decode_sasl_initial_response(Cursor::new(&data)).ok() {
                        Some(FrontendMessage::SASLInitialResponse {
                            gs2_header,
                            mechanism,
                            initial_response,
                        }) => {
                            // Channel binding is not supported.
                            if gs2_header.channel_binding_enabled() {
                                return conn
                                    .send(ErrorResponse::fatal(
                                        SqlState::PROTOCOL_VIOLATION,
                                        "channel binding not supported",
                                    ))
                                    .await;
                            }
                            (mechanism, initial_response)
                        }
                        _ => {
                            return conn
                                .send(ErrorResponse::fatal(
                                    SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                                    "expected SASLInitialResponse message",
                                ))
                                .await;
                        }
                    }
                }
                _ => {
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                            "expected SASLInitialResponse message",
                        ))
                        .await;
                }
            };

            // Only SCRAM-SHA-256 is supported.
            if mechanism != "SCRAM-SHA-256" {
                return conn
                    .send(ErrorResponse::fatal(
                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                        "unsupported SASL mechanism",
                    ))
                    .await;
            }

            // Bound the client-supplied nonce length.
            if initial_response.nonce.len() > 256 {
                return conn
                    .send(ErrorResponse::fatal(
                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                        "nonce too long",
                    ))
                    .await;
            }

            // Obtain the server-first challenge (nonce, salt, iteration count)
            // and send it to the client. The raw server-first message is kept
            // because it is part of the SCRAM auth message verified below.
            let (server_first_message_raw, mock_hash) = match adapter_client
                .generate_sasl_challenge(&user, &initial_response.nonce)
                .await
            {
                Ok(response) => {
                    let server_first_message_raw = format!(
                        "r={},s={},i={}",
                        response.nonce, response.salt, response.iteration_count
                    );

                    // NOTE(review): `mock_hash` is built from fixed dummy keys.
                    // Presumably `verify_sasl_proof` falls back to it when the
                    // user has no stored hash, so the flow looks the same either
                    // way. Confirm in the adapter.
                    let client_key = [0u8; 32];
                    let server_key = [1u8; 32];
                    let mock_hash = format!(
                        "SCRAM-SHA-256${}:{}${}:{}",
                        response.iteration_count,
                        response.salt,
                        BASE64_STANDARD.encode(client_key),
                        BASE64_STANDARD.encode(server_key)
                    );

                    conn.send(BackendMessage::AuthenticationSASLContinue(
                        SASLServerFirstMessage {
                            iteration_count: response.iteration_count,
                            nonce: response.nonce,
                            salt: response.salt,
                        },
                    ))
                    .await?;
                    conn.flush().await?;
                    (server_first_message_raw, mock_hash)
                }
                Err(e) => {
                    return conn.send(e.into_response(Severity::Fatal)).await;
                }
            };

            // Read the client-final message and verify its proof.
            let authenticated = match conn.recv().await? {
                Some(FrontendMessage::RawAuthentication(data)) => {
                    match decode_sasl_response(Cursor::new(&data)).ok() {
                        Some(FrontendMessage::SASLResponse(response)) => {
                            // The SCRAM auth message: client-first-bare,
                            // server-first, and (presumably, per RFC 5802) the
                            // client-final message without its proof.
                            let auth_message = format!(
                                "{},{},{}",
                                initial_response.client_first_message_bare_raw,
                                server_first_message_raw,
                                response.client_final_message_bare_raw
                            );
                            // Bound the client-supplied proof length.
                            if response.proof.len() > 1024 {
                                return conn
                                    .send(ErrorResponse::fatal(
                                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                                        "proof too long",
                                    ))
                                    .await;
                            }
                            match adapter_client
                                .verify_sasl_proof(
                                    &user,
                                    &response.proof,
                                    &auth_message,
                                    &mock_hash,
                                )
                                .await
                            {
                                Ok((proof_response, authenticated)) => {
                                    conn.send(BackendMessage::AuthenticationSASLFinal(
                                        SASLServerFinalMessage {
                                            kind: SASLServerFinalMessageKinds::Verifier(
                                                proof_response.verifier,
                                            ),
                                            extensions: vec![],
                                        },
                                    ))
                                    .await?;
                                    conn.flush().await?;
                                    authenticated
                                }
                                // Any verification failure is reported as an
                                // invalid password, without further detail.
                                Err(_) => {
                                    return conn
                                        .send(ErrorResponse::fatal(
                                            SqlState::INVALID_PASSWORD,
                                            "invalid password",
                                        ))
                                        .await;
                                }
                            }
                        }
                        _ => {
                            return conn
                                .send(ErrorResponse::fatal(
                                    SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                                    "expected SASLResponse message",
                                ))
                                .await;
                        }
                    }
                }
                _ => {
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                            "expected SASLResponse message",
                        ))
                        .await;
                }
            };

            let session = adapter_client.new_session(
                SessionConfig {
                    conn_id: conn.conn_id().clone(),
                    uuid: conn_uuid,
                    user,
                    client_ip: conn.peer_addr().clone(),
                    external_metadata_rx: None,
                    helm_chart_version,
                    authenticator_kind,
                    groups: None,
                },
                authenticated,
            );
            let auth_session = pending().right_future();
            (session, auth_session)
        }

        // No authentication: create the session directly.
        Authenticator::None => {
            let session = adapter_client.new_session(
                SessionConfig {
                    conn_id: conn.conn_id().clone(),
                    uuid: conn_uuid,
                    user,
                    client_ip: conn.peer_addr().clone(),
                    external_metadata_rx: None,
                    helm_chart_version,
                    authenticator_kind,
                    groups: None,
                },
                Authenticated,
            );
            let auth_session = pending().right_future();
            (session, auth_session)
        }
    };

    // Apply startup parameters to the session. `options` expands into the
    // settings parsed from it; every other parameter is set as-is. Failures are
    // recorded as `BadStartupSetting` notices, not as connection errors.
    let system_vars = adapter_client.get_system_vars().await;
    for (name, value) in params {
        let settings = match name.as_str() {
            "options" => match &options {
                Ok(opts) => opts,
                Err(()) => {
                    session.add_notice(AdapterNotice::BadStartupSetting {
                        name,
                        reason: "could not parse".into(),
                    });
                    continue;
                }
            },
            _ => &vec![(name, value)],
        };
        for (key, val) in settings {
            const LOCAL: bool = false;
            if let Err(err) = session
                .vars_mut()
                .set(&system_vars, key, VarInput::Flat(val), LOCAL)
            {
                session.add_notice(AdapterNotice::BadStartupSetting {
                    name: key.clone(),
                    reason: err.to_string(),
                });
            }
        }
    }
    // Commit the variable changes made above.
    session
        .vars_mut()
        .end_transaction(EndTransactionAction::Commit);

    // Allocate a connection slot for the session's user. The result is bound
    // to `_guard` (not `_`) so it lives until the end of this function;
    // presumably dropping it releases the slot.
    let _guard = match active_connection_counter.allocate_connection(session.user()) {
        Ok(drop_connection) => drop_connection,
        Err(e) => {
            let e: AdapterError = e.into();
            return conn.send(e.into_response(Severity::Fatal)).await;
        }
    };

    // Hand the session to the adapter. From here on it is reached through the
    // returned session client.
    let mut adapter_client = match adapter_client.startup(session).await {
        Ok(adapter_client) => adapter_client,
        Err(e) => return conn.send(e.into_response(Severity::Fatal)).await,
    };

    // Startup response: AuthenticationOk, ParameterStatus for each notify
    // variable, BackendKeyData, queued notices, then ReadyForQuery.
    let mut buf = vec![BackendMessage::AuthenticationOk];
    for var in adapter_client.session().vars().notify_set() {
        buf.push(BackendMessage::ParameterStatus(var.name(), var.value()));
    }
    buf.push(BackendMessage::BackendKeyData {
        conn_id: adapter_client.session().conn_id().unhandled(),
        secret_key: adapter_client.session().secret_key(),
    });
    buf.extend(
        adapter_client
            .session()
            .drain_notices()
            .into_iter()
            .map(|notice| BackendMessage::ErrorResponse(notice.into_response())),
    );
    buf.push(BackendMessage::ReadyForQuery(
        adapter_client.session().transaction().into(),
    ));
    conn.send_all(buf).await?;
    conn.flush().await?;

    let machine = StateMachine {
        conn,
        adapter_client,
        txn_needs_commit: false,
        tokio_metrics_intervals,
    };

    select! {
        r = machine.run() => {
            // If the state machine failed, make a best-effort attempt to tell
            // the client why (send/flush errors are ignored). The original
            // error is still returned.
            if let Err(err) = &r {
                let _ = conn
                    .send(ErrorResponse::fatal(
                        SqlState::CONNECTION_FAILURE,
                        err.to_string(),
                    ))
                    .await;
                let _ = conn.flush().await;
            }
            r
        },
        _ = expired => {
            conn
                .send(ErrorResponse::fatal(SqlState::INVALID_AUTHORIZATION_SPECIFICATION, "authentication expired"))
                .await?;
            conn.flush().await
        }
    }
}
645
646fn is_jwt(password: &str) -> bool {
649 jsonwebtoken::decode_header(password).is_ok()
650}
651
/// Returns `(name, value)` session-setting pairs parsed from a startup
/// `options` value.
///
/// Modeled on PostgreSQL's `pg_split_opts` (postinit.c) and
/// `process_postgres_switches` (postgres.c). The accepted forms are
/// `-c name=value`, `-cname=value`, and `--name=value`.
///
/// # Errors
///
/// Returns `Err(())` if any option is malformed: it has no `=`, has an
/// unrecognized prefix, or is a trailing `-c` with no setting after it.
fn parse_options(value: &str) -> Result<Vec<(String, String)>, ()> {
    let opts = split_options(value);
    let mut pairs = Vec::with_capacity(opts.len());
    // True right after a standalone `-c`: the next token is its `name=value`.
    let mut seen_prefix = false;
    for opt in opts {
        if seen_prefix {
            let (key, val) = opt.split_once('=').ok_or(())?;
            pairs.push((key.to_owned(), val.to_owned()));
            seen_prefix = false;
        } else if opt == "-c" {
            seen_prefix = true;
        } else {
            let (key, val) = parse_option(&opt)?;
            pairs.push((key.to_owned(), val.to_owned()));
        }
    }
    // A trailing `-c` is missing its argument. Reject it, as PostgreSQL does,
    // instead of silently ignoring it.
    if seen_prefix {
        return Err(());
    }
    Ok(pairs)
}

/// Parses one self-contained option of the form `-cname=value` or
/// `--name=value`, returning `(name, value)`.
///
/// The value is everything after the first `=`. Returns `Err(())` when there is
/// no `=` or the key has neither the `-c` nor the `--` prefix.
fn parse_option(option: &str) -> Result<(&str, &str), ()> {
    let (key, value) = option.split_once('=').ok_or(())?;
    for prefix in &["-c", "--"] {
        if let Some(key) = key.strip_prefix(prefix) {
            return Ok((key, value));
        }
    }
    Err(())
}

/// Splits `value` into options separated by runs of unescaped whitespace.
///
/// A backslash escapes the next character: `\ ` is a literal space, `\\` a
/// literal backslash, and `\x` is just `x`. A backslash at the very end is
/// dropped. Like PostgreSQL's `pg_split_opts`, which tests with C `isspace`,
/// any of space, `\t`, `\n`, `\r`, `\x0B` (vertical tab) and `\x0C` (form feed)
/// separates options. Empty options are never produced.
fn split_options(value: &str) -> Vec<String> {
    let mut strs = Vec::new();
    // Escapes mean each option must be built up, rather than sliced from `value`.
    let mut current = String::new();
    let mut was_slash = false;
    for c in value.chars() {
        was_slash = match c {
            ' ' | '\t' | '\n' | '\r' | '\x0B' | '\x0C' => {
                if was_slash {
                    // Escaped separator: keep it as part of the option.
                    current.push(c);
                } else if !current.is_empty() {
                    // Only emit non-empty options, so runs of separators collapse.
                    strs.push(std::mem::take(&mut current));
                }
                false
            }
            '\\' => {
                if was_slash {
                    // `\\` adds one backslash and does not escape the next char.
                    current.push('\\');
                    false
                } else {
                    true
                }
            }
            _ => {
                current.push(c);
                false
            }
        };
    }
    if !current.is_empty() {
        strs.push(current);
    }
    strs
}
731}
732
/// Failure while requesting or checking a client's password.
enum PasswordRequestError {
    /// The client sent an unexpected message or the credentials were rejected.
    /// Carries the fatal `ErrorResponse` to send to the client.
    InvalidPasswordError(ErrorResponse),
    /// An I/O error on the connection.
    IoError(io::Error),
}
737
738impl From<io::Error> for PasswordRequestError {
739 fn from(e: io::Error) -> Self {
740 PasswordRequestError::IoError(e)
741 }
742}
743
744async fn request_cleartext_password<A>(
748 conn: &mut FramedConn<A>,
749) -> Result<String, PasswordRequestError>
750where
751 A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
752{
753 conn.send(BackendMessage::AuthenticationCleartextPassword)
754 .await?;
755 conn.flush().await?;
756
757 if let Some(message) = conn.recv().await? {
758 if let FrontendMessage::RawAuthentication(data) = message {
759 if let Some(FrontendMessage::Password { password }) =
760 decode_password(Cursor::new(&data)).ok()
761 {
762 return Ok(password);
763 }
764 }
765 }
766
767 Err(PasswordRequestError::InvalidPasswordError(
768 ErrorResponse::fatal(
769 SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
770 "expected Password message",
771 ),
772 ))
773}
774
775async fn authenticate_with_password<A>(
778 conn: &FramedConn<A>,
779 adapter_client: &mz_adapter::Client,
780 user: String,
781 password: Password,
782 conn_uuid: Uuid,
783 helm_chart_version: Option<String>,
784) -> Result<Session, PasswordRequestError>
785where
786 A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
787{
788 let authenticated = match adapter_client.authenticate(&user, &password).await {
789 Ok(authenticated) => authenticated,
790 Err(err) => {
791 warn!(?err, "pgwire connection failed authentication");
792 return Err(PasswordRequestError::InvalidPasswordError(
793 ErrorResponse::fatal(SqlState::INVALID_PASSWORD, "invalid password"),
794 ));
795 }
796 };
797
798 let session = adapter_client.new_session(
799 SessionConfig {
800 conn_id: conn.conn_id().clone(),
801 uuid: conn_uuid,
802 user,
803 client_ip: conn.peer_addr().clone(),
804 external_metadata_rx: None,
805 helm_chart_version,
806 authenticator_kind: mz_auth::AuthenticatorKind::Password,
807 groups: None,
808 },
809 authenticated,
810 );
811
812 Ok(session)
813}
814
/// State of the [`StateMachine`] message loop.
#[derive(Debug)]
enum State {
    /// Wait for the next frontend message and process it.
    Ready,
    /// After an error, discard messages until the client sends `Sync`.
    Drain,
    /// Stop the message loop.
    Done,
}
821
/// Processes a connection's frontend messages once startup has completed.
struct StateMachine<'a, A, I>
where
    I: Iterator<Item = TaskMetrics> + Send + 'a,
{
    /// The client connection.
    conn: &'a mut FramedConn<A>,
    /// Adapter client bound to this connection's session.
    adapter_client: mz_adapter::SessionClient,
    /// Set when an `Execute` in an implicit transaction defers its commit.
    /// `ensure_transaction` performs that commit, and `end_transaction`
    /// clears the flag.
    txn_needs_commit: bool,
    /// Tokio task-metrics intervals. Expected to be infinite, because callers
    /// `expect("infinite iterator")` on it.
    tokio_metrics_intervals: I,
}
831
/// Why sending a result's rows to the client stopped.
///
/// NOTE(review): its consumers are outside this view; presumably it feeds
/// statement-execution logging. Confirm.
enum SendRowsEndedReason {
    /// All requested rows were sent.
    Success {
        /// Size of the result sent (units not visible here; confirm).
        result_size: u64,
        /// Number of rows sent.
        rows_returned: u64,
    },
    /// Sending stopped because of an error, described by `error`.
    Errored {
        error: String,
    },
    /// Sending stopped because the operation was canceled.
    Canceled,
}
842
/// Error message for commands issued while the current transaction is aborted.
const ABORTED_TXN_MSG: &str =
    "current transaction is aborted, commands ignored until end of transaction block";
845
846impl<'a, A, I> StateMachine<'a, A, I>
847where
848 A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin + 'a,
849 I: Iterator<Item = TaskMetrics> + Send + 'a,
850{
851 #[allow(clippy::manual_async_fn)]
855 #[mz_ore::instrument(level = "debug")]
856 fn run(mut self) -> impl Future<Output = Result<(), io::Error>> + Send + 'a {
857 async move {
858 let mut state = State::Ready;
859 loop {
860 self.send_pending_notices().await?;
861 state = match state {
862 State::Ready => self.advance_ready().await?,
863 State::Drain => self.advance_drain().await?,
864 State::Done => return Ok(()),
865 };
866 self.adapter_client
867 .add_idle_in_transaction_session_timeout();
868 }
869 }
870 }
871
    #[instrument(level = "debug")]
    /// Waits for the next frontend message, or a session timeout, and handles it.
    ///
    /// Returns the state to move to next. It also records, labelled by message
    /// name, the message-processing time and the receive scheduling delay.
    async fn advance_ready(&mut self) -> Result<State, io::Error> {
        // Advance the metrics interval now, so the interval taken after the
        // `recv` below covers (roughly) just the wait for this message. This
        // assumes each `next()` yields the metrics since the previous call.
        self.tokio_metrics_intervals
            .next()
            .expect("infinite iterator");

        // `biased` makes `select!` poll the arms in order, so a pending session
        // timeout is handled before a newly arrived message is processed.
        // NOTE(review): both futures are raced, so they must be cancel-safe.
        let message = select! {
            biased;

            Some(timeout) = self.adapter_client.recv_timeout() => {
                let err: AdapterError = timeout.into();
                let conn_id = self.adapter_client.session().conn_id();
                tracing::warn!("session timed out, conn_id {}", conn_id);

                // Send the error first; the session is terminated only afterwards.
                let error_response = err.into_response(Severity::Fatal);
                let error_state = self.send_error_and_get_state(error_response).await;

                self.adapter_client.terminate().await;

                // Wait for the client's next message (its contents are ignored)
                // before returning the state produced by the error.
                let _ = self.conn.recv().await?;
                return error_state;
            },
            message = self.conn.recv() => message?,
        };

        // Metrics for the interval that spans the `recv` above.
        let interval = self
            .tokio_metrics_intervals
            .next()
            .expect("infinite iterator");
        let recv_scheduling_delay_ms = interval.total_scheduled_duration.as_secs_f64() * 1000.0;

        // Wall-clock receive time, used for statement lifecycle timestamps.
        let received = SYSTEM_TIME();

        // A message arrived, so the session is no longer idle.
        self.adapter_client
            .remove_idle_in_transaction_session_timeout();

        // Used as the metric label and span name. It is empty when `recv`
        // returned `None`.
        let message_name = message.as_ref().map(|m| m.name()).unwrap_or_default();

        // Processing time is only measured when there is a message to process.
        let start = message.as_ref().map(|_| Instant::now());
        let next_state = match message {
            Some(FrontendMessage::Query { sql }) => {
                // A new root span per query, linked to the current span.
                let query_root_span =
                    tracing::info_span!(parent: None, "advance_ready", otel.name = message_name);
                query_root_span.follows_from(tracing::Span::current());
                self.query(sql, received)
                    .instrument(query_root_span)
                    .await?
            }
            Some(FrontendMessage::Parse {
                name,
                sql,
                param_types,
            }) => self.parse(name, sql, param_types).await?,
            Some(FrontendMessage::Bind {
                portal_name,
                statement_name,
                param_formats,
                raw_params,
                result_formats,
            }) => {
                self.bind(
                    portal_name,
                    statement_name,
                    param_formats,
                    raw_params,
                    result_formats,
                )
                .await?
            }
            Some(FrontendMessage::Execute {
                portal_name,
                max_rows,
            }) => {
                let max_rows = match usize::try_from(max_rows) {
                    // Zero, or a value that doesn't fit in `usize` (e.g. a
                    // negative one), means no row limit.
                    Ok(0) | Err(_) => ExecuteCount::All,
                    Ok(n) => ExecuteCount::Count(n),
                };
                let execute_root_span =
                    tracing::info_span!(parent: None, "advance_ready", otel.name = message_name);
                execute_root_span.follows_from(tracing::Span::current());
                let state = self
                    .execute(
                        portal_name,
                        max_rows,
                        portal_exec_message,
                        None,
                        ExecuteTimeout::None,
                        None,
                        Some(received),
                    )
                    .instrument(execute_root_span)
                    .await?;
                // In an implicit transaction, don't commit now. Mark the
                // transaction so the next `ensure_transaction` commits it.
                // Presumably this keeps the portal alive so a later Execute can
                // resume it.
                if self.adapter_client.session().transaction().is_implicit() {
                    self.txn_needs_commit = true;
                }
                state
            }
            Some(FrontendMessage::DescribeStatement { name }) => {
                self.describe_statement(&name).await?
            }
            Some(FrontendMessage::DescribePortal { name }) => self.describe_portal(&name).await?,
            Some(FrontendMessage::CloseStatement { name }) => self.close_statement(name).await?,
            Some(FrontendMessage::ClosePortal { name }) => self.close_portal(name).await?,
            Some(FrontendMessage::Flush) => self.flush().await?,
            Some(FrontendMessage::Sync) => self.sync().await?,
            Some(FrontendMessage::Terminate) => State::Done,

            // Messages that are not expected in the ready state are not handled
            // here; the state machine switches to draining until the next Sync.
            Some(FrontendMessage::CopyData(_))
            | Some(FrontendMessage::CopyDone)
            | Some(FrontendMessage::CopyFail(_))
            | Some(FrontendMessage::Password { .. })
            | Some(FrontendMessage::RawAuthentication(_))
            | Some(FrontendMessage::SASLInitialResponse { .. })
            | Some(FrontendMessage::SASLResponse(_)) => State::Drain,
            None => State::Done,
        };

        if let Some(start) = start {
            self.adapter_client
                .inner()
                .metrics()
                .pgwire_message_processing_seconds
                .with_label_values(&[message_name])
                .observe(start.elapsed().as_secs_f64());
        }
        self.adapter_client
            .inner()
            .metrics()
            .pgwire_recv_scheduling_delay_ms
            .with_label_values(&[message_name])
            .observe(recv_scheduling_delay_ms);

        Ok(next_state)
    }
1035
1036 async fn advance_drain(&mut self) -> Result<State, io::Error> {
1037 let message = self.conn.recv().await?;
1038 if message.is_some() {
1039 self.adapter_client
1040 .remove_idle_in_transaction_session_timeout();
1041 }
1042 match message {
1043 Some(FrontendMessage::Sync) => self.sync().await,
1044 None => Ok(State::Done),
1045 _ => Ok(State::Drain),
1046 }
1047 }
1048
    #[instrument(level = "debug")]
    /// Executes one statement of a simple-query (`Query`) message.
    ///
    /// Binds `stmt` to the unnamed portal and stamps it with
    /// `lifecycle_timestamps`. It then sends a text-format `RowDescription`
    /// (unless the statement is a `COPY`), executes the portal, and sends the
    /// response. The unnamed portal is removed on the paths that reach the end
    /// of the function.
    async fn one_query(
        &mut self,
        stmt: Statement<Raw>,
        sql: String,
        lifecycle_timestamps: LifecycleTimestamps,
    ) -> Result<State, io::Error> {
        // Simple-query statements always run through the unnamed portal.
        const EMPTY_PORTAL: &str = "";
        if let Err(e) = self
            .adapter_client
            .declare(EMPTY_PORTAL.to_string(), stmt, sql)
            .await
        {
            return self
                .send_error_and_get_state(e.into_response(Severity::Error))
                .await;
        }
        let portal = self
            .adapter_client
            .session()
            .get_portal_unverified_mut(EMPTY_PORTAL)
            .expect("unnamed portal should be present");

        *portal.lifecycle_timestamps = Some(lifecycle_timestamps);

        // Clone the description: `portal` borrows the session mutably, and the
        // description is still needed after `self` is used again below.
        let stmt_desc = portal.desc.clone();
        // A simple query has no way to supply parameter values.
        // NOTE(review): this early return, and the `?` on the RowDescription
        // send below, leave the unnamed portal in place. Confirm that later
        // cleanup removes it.
        if !stmt_desc.param_types.is_empty() {
            return self
                .send_error_and_get_state(ErrorResponse::error(
                    SqlState::UNDEFINED_PARAMETER,
                    "there is no parameter $1",
                ))
                .await;
        }

        // Send the row description for row-returning, non-COPY statements.
        // Simple-query results always use the text format.
        if let Some(relation_desc) = &stmt_desc.relation_desc {
            if !stmt_desc.is_copy {
                let formats = vec![Format::Text; stmt_desc.arity()];
                self.send(BackendMessage::RowDescription(
                    message::encode_row_description(relation_desc, &formats),
                ))
                .await?;
            }
        }

        let result = match self
            .adapter_client
            .execute(EMPTY_PORTAL.to_string(), self.conn.wait_closed(), None)
            .await
        {
            Ok((response, execute_started)) => {
                self.send_pending_notices().await?;
                self.send_execute_response(
                    response,
                    stmt_desc.relation_desc,
                    EMPTY_PORTAL.to_string(),
                    ExecuteCount::All,
                    portal_exec_message,
                    None,
                    ExecuteTimeout::None,
                    execute_started,
                )
                .await
            }
            Err(e) => {
                self.send_pending_notices().await?;
                self.send_error_and_get_state(e.into_response(Severity::Error))
                    .await
            }
        };

        // Remove the unnamed portal whether execution succeeded or not.
        self.adapter_client.session().remove_portal(EMPTY_PORTAL);

        result
    }
1131
1132 async fn ensure_transaction(
1133 &mut self,
1134 num_stmts: usize,
1135 message_type: &str,
1136 ) -> Result<(), io::Error> {
1137 let start = Instant::now();
1138 if self.txn_needs_commit {
1139 self.commit_transaction().await?;
1140 }
1141 let res = self.adapter_client.start_transaction(Some(num_stmts));
1144 assert_ok!(res);
1145 self.adapter_client
1146 .inner()
1147 .metrics()
1148 .pgwire_ensure_transaction_seconds
1149 .with_label_values(&[message_type])
1150 .observe(start.elapsed().as_secs_f64());
1151 Ok(())
1152 }
1153
1154 fn parse_sql<'b>(&self, sql: &'b str) -> Result<Vec<StatementParseResult<'b>>, ErrorResponse> {
1155 let parse_start = Instant::now();
1156 let result = match self.adapter_client.parse(sql) {
1157 Ok(result) => result.map_err(|e| {
1158 let pos = sql[..e.error.pos].chars().count() + 1;
1161 ErrorResponse::error(SqlState::SYNTAX_ERROR, e.error.message).with_position(pos)
1162 }),
1163 Err(msg) => Err(ErrorResponse::error(SqlState::PROGRAM_LIMIT_EXCEEDED, msg)),
1164 };
1165 self.adapter_client
1166 .inner()
1167 .metrics()
1168 .parse_seconds
1169 .observe(parse_start.elapsed().as_secs_f64());
1170 result
1171 }
1172
1173 #[instrument(level = "debug")]
1178 async fn query(&mut self, sql: String, received: EpochMillis) -> Result<State, io::Error> {
1179 let stmts = match self.parse_sql(&sql) {
1181 Ok(stmts) => stmts,
1182 Err(err) => {
1183 self.send_error_and_get_state(err).await?;
1184 return self.ready().await;
1185 }
1186 };
1187
1188 let num_stmts = stmts.len();
1189
1190 for StatementParseResult { ast: stmt, sql } in stmts {
1192 if self.is_aborted_txn() && !is_txn_exit_stmt(Some(&stmt)) {
1194 self.aborted_txn_error().await?;
1195 break;
1196 }
1197
1198 self.ensure_transaction(num_stmts, "query").await?;
1206
1207 match self
1208 .one_query(stmt, sql.to_string(), LifecycleTimestamps { received })
1209 .await?
1210 {
1211 State::Ready => (),
1212 State::Drain => break,
1213 State::Done => return Ok(State::Done),
1214 }
1215 }
1216
1217 {
1219 if self.adapter_client.session().transaction().is_implicit() {
1220 self.commit_transaction().await?;
1221 }
1222 }
1223
1224 if num_stmts == 0 {
1225 self.send(BackendMessage::EmptyQueryResponse).await?;
1226 }
1227
1228 self.ready().await
1229 }
1230
1231 #[instrument(level = "debug")]
1232 async fn parse(
1233 &mut self,
1234 name: String,
1235 sql: String,
1236 param_oids: Vec<u32>,
1237 ) -> Result<State, io::Error> {
1238 self.ensure_transaction(1, "parse").await?;
1240
1241 let mut param_types = vec![];
1242 for oid in param_oids {
1243 match mz_pgrepr::Type::from_oid(oid) {
1244 Ok(ty) => match SqlScalarType::try_from(&ty) {
1245 Ok(ty) => param_types.push(Some(ty)),
1246 Err(err) => {
1247 return self
1248 .send_error_and_get_state(ErrorResponse::error(
1249 SqlState::INVALID_PARAMETER_VALUE,
1250 err.to_string(),
1251 ))
1252 .await;
1253 }
1254 },
1255 Err(_) if oid == 0 => param_types.push(None),
1256 Err(e) => {
1257 return self
1258 .send_error_and_get_state(ErrorResponse::error(
1259 SqlState::PROTOCOL_VIOLATION,
1260 e.to_string(),
1261 ))
1262 .await;
1263 }
1264 }
1265 }
1266
1267 let stmts = match self.parse_sql(&sql) {
1268 Ok(stmts) => stmts,
1269 Err(err) => {
1270 return self.send_error_and_get_state(err).await;
1271 }
1272 };
1273 if stmts.len() > 1 {
1274 return self
1275 .send_error_and_get_state(ErrorResponse::error(
1276 SqlState::INTERNAL_ERROR,
1277 "cannot insert multiple commands into a prepared statement",
1278 ))
1279 .await;
1280 }
1281 let (maybe_stmt, sql) = match stmts.into_iter().next() {
1282 None => (None, ""),
1283 Some(StatementParseResult { ast, sql }) => (Some(ast), sql),
1284 };
1285 if self.is_aborted_txn() && !is_txn_exit_stmt(maybe_stmt.as_ref()) {
1286 return self.aborted_txn_error().await;
1287 }
1288 match self
1289 .adapter_client
1290 .prepare(name, maybe_stmt, sql.to_string(), param_types)
1291 .await
1292 {
1293 Ok(()) => {
1294 self.send(BackendMessage::ParseComplete).await?;
1295 Ok(State::Ready)
1296 }
1297 Err(e) => {
1298 self.send_error_and_get_state(e.into_response(Severity::Error))
1299 .await
1300 }
1301 }
1302 }
1303
1304 #[instrument(level = "debug")]
1306 async fn commit_transaction(&mut self) -> Result<(), io::Error> {
1307 self.end_transaction(EndTransactionAction::Commit).await
1308 }
1309
1310 #[instrument(level = "debug")]
1312 async fn rollback_transaction(&mut self) -> Result<(), io::Error> {
1313 self.end_transaction(EndTransactionAction::Rollback).await
1314 }
1315
1316 #[instrument(level = "debug")]
1318 async fn end_transaction(&mut self, action: EndTransactionAction) -> Result<(), io::Error> {
1319 self.txn_needs_commit = false;
1320 let resp = self.adapter_client.end_transaction(action).await;
1321 if let Err(err) = resp {
1322 self.send(BackendMessage::ErrorResponse(
1323 err.into_response(Severity::Error),
1324 ))
1325 .await?;
1326 }
1327 Ok(())
1328 }
1329
1330 #[instrument(level = "debug")]
1331 async fn bind(
1332 &mut self,
1333 portal_name: String,
1334 statement_name: String,
1335 param_formats: Vec<Format>,
1336 raw_params: Vec<Option<Vec<u8>>>,
1337 result_formats: Vec<Format>,
1338 ) -> Result<State, io::Error> {
1339 self.ensure_transaction(1, "bind").await?;
1341
1342 let aborted_txn = self.is_aborted_txn();
1343 let stmt = match self
1344 .adapter_client
1345 .get_prepared_statement(&statement_name)
1346 .await
1347 {
1348 Ok(stmt) => stmt,
1349 Err(err) => {
1350 return self
1351 .send_error_and_get_state(err.into_response(Severity::Error))
1352 .await;
1353 }
1354 };
1355
1356 let param_types = &stmt.desc().param_types;
1357 if param_types.len() != raw_params.len() {
1358 let message = format!(
1359 "bind message supplies {actual} parameters, \
1360 but prepared statement \"{name}\" requires {expected}",
1361 name = statement_name,
1362 actual = raw_params.len(),
1363 expected = param_types.len()
1364 );
1365 return self
1366 .send_error_and_get_state(ErrorResponse::error(
1367 SqlState::PROTOCOL_VIOLATION,
1368 message,
1369 ))
1370 .await;
1371 }
1372 let param_formats = match pad_formats(param_formats, raw_params.len()) {
1373 Ok(param_formats) => param_formats,
1374 Err(msg) => {
1375 return self
1376 .send_error_and_get_state(ErrorResponse::error(
1377 SqlState::PROTOCOL_VIOLATION,
1378 msg,
1379 ))
1380 .await;
1381 }
1382 };
1383 if aborted_txn && !is_txn_exit_stmt(stmt.stmt()) {
1384 return self.aborted_txn_error().await;
1385 }
1386 let buf = RowArena::new();
1387 let mut params = vec![];
1388 for ((raw_param, mz_typ), format) in raw_params
1389 .into_iter()
1390 .zip_eq(param_types)
1391 .zip_eq(param_formats)
1392 {
1393 let pg_typ = mz_pgrepr::Type::from(mz_typ);
1394 let datum = match raw_param {
1395 None => Datum::Null,
1396 Some(bytes) => match mz_pgrepr::Value::decode(format, &pg_typ, &bytes) {
1397 Ok(param) => match param.into_datum_decode_error(&buf, &pg_typ, "parameter") {
1398 Ok(datum) => datum,
1399 Err(msg) => {
1400 return self
1401 .send_error_and_get_state(ErrorResponse::error(
1402 SqlState::INVALID_PARAMETER_VALUE,
1403 msg,
1404 ))
1405 .await;
1406 }
1407 },
1408 Err(err) => {
1409 let msg = format!("unable to decode parameter: {}", err);
1410 return self
1411 .send_error_and_get_state(ErrorResponse::error(
1412 SqlState::INVALID_PARAMETER_VALUE,
1413 msg,
1414 ))
1415 .await;
1416 }
1417 },
1418 };
1419 params.push((datum, mz_typ.clone()))
1420 }
1421
1422 let result_formats = match pad_formats(
1423 result_formats,
1424 stmt.desc()
1425 .relation_desc
1426 .clone()
1427 .map(|desc| desc.typ().column_types.len())
1428 .unwrap_or(0),
1429 ) {
1430 Ok(result_formats) => result_formats,
1431 Err(msg) => {
1432 return self
1433 .send_error_and_get_state(ErrorResponse::error(
1434 SqlState::PROTOCOL_VIOLATION,
1435 msg,
1436 ))
1437 .await;
1438 }
1439 };
1440
1441 if !stmt.stmt().map_or(false, |stmt| match stmt {
1444 Statement::Copy(CopyStatement {
1445 direction: CopyDirection::To,
1446 ..
1447 }) => true,
1448 Statement::Copy(CopyStatement {
1449 direction: CopyDirection::From,
1450 target: CopyTarget::Expr(_),
1454 ..
1455 }) => true,
1456 _ => false,
1457 }) {
1458 if let Some(desc) = stmt.desc().relation_desc.clone() {
1459 for (format, ty) in result_formats.iter().zip_eq(desc.iter_types()) {
1460 match (format, &ty.scalar_type) {
1461 (Format::Binary, mz_repr::SqlScalarType::List { .. }) => {
1462 return self
1463 .send_error_and_get_state(ErrorResponse::error(
1464 SqlState::PROTOCOL_VIOLATION,
1465 "binary encoding of list types is not implemented",
1466 ))
1467 .await;
1468 }
1469 (Format::Binary, mz_repr::SqlScalarType::Map { .. }) => {
1470 return self
1471 .send_error_and_get_state(ErrorResponse::error(
1472 SqlState::PROTOCOL_VIOLATION,
1473 "binary encoding of map types is not implemented",
1474 ))
1475 .await;
1476 }
1477 (Format::Binary, mz_repr::SqlScalarType::AclItem) => {
1478 return self
1479 .send_error_and_get_state(ErrorResponse::error(
1480 SqlState::PROTOCOL_VIOLATION,
1481 "binary encoding of aclitem types does not exist",
1482 ))
1483 .await;
1484 }
1485 _ => (),
1486 }
1487 }
1488 }
1489 }
1490
1491 let desc = stmt.desc().clone();
1492 let logging = Arc::clone(stmt.logging());
1493 let stmt_ast = stmt.stmt().cloned();
1494 let state_revision = stmt.state_revision;
1495 if let Err(err) = self.adapter_client.session().set_portal(
1496 portal_name,
1497 desc,
1498 stmt_ast,
1499 logging,
1500 params,
1501 result_formats,
1502 state_revision,
1503 ) {
1504 return self
1505 .send_error_and_get_state(err.into_response(Severity::Error))
1506 .await;
1507 }
1508
1509 self.send(BackendMessage::BindComplete).await?;
1510 Ok(State::Ready)
1511 }
1512
/// Executes the portal `portal_name`, sending at most `max_rows` result rows.
///
/// What happens depends on the portal's state:
/// - `NotStarted`: starts execution through the adapter.
/// - `InProgress`: resumes streaming rows left over from an earlier Execute/FETCH.
/// - `Completed(Some(tag))`: sends the stored completion tag.
/// - `Completed(None)`: reports that the portal cannot be run.
///
/// `fetch_portal_name` is `Some` when this execution services a `FETCH`, and is
/// forwarded to row sending. On the `NotStarted` path, `outer_ctx_extra` (when
/// present) is handed to the adapter's `execute`. On every other path it is
/// retired here with the matching outcome. `received` (when `Some`) seeds the
/// portal's lifecycle timestamps.
///
/// The future is boxed because this method is reachable recursively
/// (`execute` -> `send_execute_response` -> `fetch` -> `execute`), and a
/// recursive async computation needs an indirection to have a finite size.
fn execute(
    &mut self,
    portal_name: String,
    max_rows: ExecuteCount,
    get_response: GetResponse,
    fetch_portal_name: Option<String>,
    timeout: ExecuteTimeout,
    outer_ctx_extra: Option<ExecuteContextGuard>,
    received: Option<EpochMillis>,
) -> BoxFuture<'_, Result<State, io::Error>> {
    async move {
        // Sample the aborted-transaction status before the portal borrow below.
        let aborted_txn = self.is_aborted_txn();

        // A missing portal is a client error; any outer statement-logging
        // context is retired as errored with the same message.
        let portal = match self
            .adapter_client
            .session()
            .get_portal_unverified_mut(&portal_name)
        {
            Some(portal) => portal,
            None => {
                let msg = format!("portal {} does not exist", portal_name.quoted());
                if let Some(outer_ctx_extra) = outer_ctx_extra {
                    self.adapter_client.retire_execute(
                        outer_ctx_extra,
                        StatementEndedExecutionReason::Errored { error: msg.clone() },
                    );
                }
                return self
                    .send_error_and_get_state(ErrorResponse::error(
                        SqlState::INVALID_CURSOR_NAME,
                        msg,
                    ))
                    .await;
            }
        };

        // Record when the triggering message was received, if known.
        *portal.lifecycle_timestamps = received.map(LifecycleTimestamps::new);

        // In an aborted transaction only transaction-exiting statements may run.
        let txn_exit_stmt = is_txn_exit_stmt(portal.stmt.as_deref());
        if aborted_txn && !txn_exit_stmt {
            if let Some(outer_ctx_extra) = outer_ctx_extra {
                self.adapter_client.retire_execute(
                    outer_ctx_extra,
                    StatementEndedExecutionReason::Errored {
                        error: ABORTED_TXN_MSG.to_string(),
                    },
                );
            }
            return self.aborted_txn_error().await;
        }

        // Owned copy of the row description; it is passed by value to
        // `send_execute_response` / `send_rows` below.
        let row_desc = portal.desc.relation_desc.clone();
        match portal.state {
            PortalState::NotStarted => {
                // First execution: hand the portal (and the logging context)
                // to the adapter, then translate its response into messages.
                self.ensure_transaction(1, "execute").await?;
                match self
                    .adapter_client
                    .execute(
                        portal_name.clone(),
                        self.conn.wait_closed(),
                        outer_ctx_extra,
                    )
                    .await
                {
                    Ok((response, execute_started)) => {
                        self.send_pending_notices().await?;
                        self.send_execute_response(
                            response,
                            row_desc,
                            portal_name,
                            max_rows,
                            get_response,
                            fetch_portal_name,
                            timeout,
                            execute_started,
                        )
                        .await
                    }
                    Err(e) => {
                        self.send_pending_notices().await?;
                        self.send_error_and_get_state(e.into_response(Severity::Error))
                            .await
                    }
                }
            }
            PortalState::InProgress(rows) => {
                // Resume the rows a previous Execute/FETCH left behind.
                let rows = rows.take().expect("InProgress rows must be populated");
                let (result, statement_ended_execution_reason) = match self
                    .send_rows(
                        row_desc.expect("portal missing row desc on resumption"),
                        portal_name,
                        rows,
                        max_rows,
                        get_response,
                        fetch_portal_name,
                        timeout,
                    )
                    .await
                {
                    Err(e) => {
                        // An I/O failure (e.g. the connection closed) is logged
                        // as a cancellation.
                        (Err(e), StatementEndedExecutionReason::Canceled)
                    }
                    Ok((ok, SendRowsEndedReason::Canceled)) => {
                        (Ok(ok), StatementEndedExecutionReason::Canceled)
                    }
                    // The byte/row counts reported by `send_rows` are not
                    // propagated on this resumption path (recorded as `None`).
                    // NOTE(review): presumably because they would cover only
                    // this batch, not the whole statement -- confirm.
                    Ok((
                        ok,
                        SendRowsEndedReason::Success {
                            result_size: _,
                            rows_returned: _,
                        },
                    )) => (
                        Ok(ok),
                        StatementEndedExecutionReason::Success {
                            result_size: None,
                            rows_returned: None,
                            execution_strategy: None,
                        },
                    ),
                    Ok((ok, SendRowsEndedReason::Errored { error })) => {
                        (Ok(ok), StatementEndedExecutionReason::Errored { error })
                    }
                };
                if let Some(outer_ctx_extra) = outer_ctx_extra {
                    self.adapter_client
                        .retire_execute(outer_ctx_extra, statement_ended_execution_reason);
                }
                result
            }
            PortalState::Completed(Some(tag)) => {
                // The portal already finished with a tag; send that tag.
                let tag = tag.to_string();
                if let Some(outer_ctx_extra) = outer_ctx_extra {
                    self.adapter_client.retire_execute(
                        outer_ctx_extra,
                        StatementEndedExecutionReason::Success {
                            result_size: None,
                            rows_returned: None,
                            execution_strategy: None,
                        },
                    );
                }
                self.send(BackendMessage::CommandComplete { tag }).await?;
                Ok(State::Ready)
            }
            PortalState::Completed(None) => {
                // The portal finished without a tag (see `complete_portal`)
                // and cannot be executed again.
                let error = format!(
                    "portal {} cannot be run",
                    Ident::new_unchecked(portal_name).to_ast_string_stable()
                );
                if let Some(outer_ctx_extra) = outer_ctx_extra {
                    self.adapter_client.retire_execute(
                        outer_ctx_extra,
                        StatementEndedExecutionReason::Errored {
                            error: error.clone(),
                        },
                    );
                }
                self.send_error_and_get_state(ErrorResponse::error(
                    SqlState::OBJECT_NOT_IN_PREREQUISITE_STATE,
                    error,
                ))
                .await
            }
        }
    }
    .instrument(debug_span!("execute"))
    .boxed()
}
1706
1707 #[instrument(level = "debug")]
1708 async fn describe_statement(&mut self, name: &str) -> Result<State, io::Error> {
1709 self.ensure_transaction(1, "describe_statement").await?;
1711
1712 let stmt = match self.adapter_client.get_prepared_statement(name).await {
1713 Ok(stmt) => stmt,
1714 Err(err) => {
1715 return self
1716 .send_error_and_get_state(err.into_response(Severity::Error))
1717 .await;
1718 }
1719 };
1720 let parameter_desc = BackendMessage::ParameterDescription(
1722 stmt.desc()
1723 .param_types
1724 .iter()
1725 .map(mz_pgrepr::Type::from)
1726 .collect(),
1727 );
1728 let formats = vec![Format::Text; stmt.desc().arity()];
1732 let row_desc = describe_rows(stmt.desc(), &formats);
1733 self.send_all([parameter_desc, row_desc]).await?;
1734 Ok(State::Ready)
1735 }
1736
1737 #[instrument(level = "debug")]
1738 async fn describe_portal(&mut self, name: &str) -> Result<State, io::Error> {
1739 self.ensure_transaction(1, "describe_portal").await?;
1741
1742 let session = self.adapter_client.session();
1743 let row_desc = session
1744 .get_portal_unverified(name)
1745 .map(|portal| describe_rows(&portal.desc, &portal.result_formats));
1746 match row_desc {
1747 Some(row_desc) => {
1748 self.send(row_desc).await?;
1749 Ok(State::Ready)
1750 }
1751 None => {
1752 self.send_error_and_get_state(ErrorResponse::error(
1753 SqlState::INVALID_CURSOR_NAME,
1754 format!("portal {} does not exist", name.quoted()),
1755 ))
1756 .await
1757 }
1758 }
1759 }
1760
1761 #[instrument(level = "debug")]
1762 async fn close_statement(&mut self, name: String) -> Result<State, io::Error> {
1763 self.adapter_client
1764 .session()
1765 .remove_prepared_statement(&name);
1766 self.send(BackendMessage::CloseComplete).await?;
1767 Ok(State::Ready)
1768 }
1769
1770 #[instrument(level = "debug")]
1771 async fn close_portal(&mut self, name: String) -> Result<State, io::Error> {
1772 self.adapter_client.session().remove_portal(&name);
1773 self.send(BackendMessage::CloseComplete).await?;
1774 Ok(State::Ready)
1775 }
1776
1777 fn complete_portal(&mut self, name: &str) {
1778 let portal = self
1779 .adapter_client
1780 .session()
1781 .get_portal_unverified_mut(name)
1782 .expect("portal should exist");
1783 *portal.state = PortalState::Completed(None);
1784 }
1785
1786 async fn fetch(
1787 &mut self,
1788 name: String,
1789 count: Option<FetchDirection>,
1790 max_rows: ExecuteCount,
1791 fetch_portal_name: Option<String>,
1792 timeout: ExecuteTimeout,
1793 ctx_extra: ExecuteContextGuard,
1794 ) -> Result<State, io::Error> {
1795 let count = count.unwrap_or(FetchDirection::ForwardCount(1));
1798
1799 let count = match (max_rows, count) {
1812 (ExecuteCount::Count(max_rows), FetchDirection::ForwardCount(count)) => {
1813 let count = usize::cast_from(count);
1814 if max_rows < count {
1815 let msg = "Execute with max_rows < a FETCH's count is not supported";
1816 self.adapter_client.retire_execute(
1817 ctx_extra,
1818 StatementEndedExecutionReason::Errored {
1819 error: msg.to_string(),
1820 },
1821 );
1822 return self
1823 .send_error_and_get_state(ErrorResponse::error(
1824 SqlState::FEATURE_NOT_SUPPORTED,
1825 msg,
1826 ))
1827 .await;
1828 }
1829 ExecuteCount::Count(count)
1830 }
1831 (ExecuteCount::Count(_), FetchDirection::ForwardAll) => {
1832 let msg = "Execute with max_rows of a FETCH ALL is not supported";
1833 self.adapter_client.retire_execute(
1834 ctx_extra,
1835 StatementEndedExecutionReason::Errored {
1836 error: msg.to_string(),
1837 },
1838 );
1839 return self
1840 .send_error_and_get_state(ErrorResponse::error(
1841 SqlState::FEATURE_NOT_SUPPORTED,
1842 msg,
1843 ))
1844 .await;
1845 }
1846 (ExecuteCount::All, FetchDirection::ForwardAll) => ExecuteCount::All,
1847 (ExecuteCount::All, FetchDirection::ForwardCount(count)) => {
1848 ExecuteCount::Count(usize::cast_from(count))
1849 }
1850 };
1851 let cursor_name = name.to_string();
1852 self.execute(
1853 cursor_name,
1854 count,
1855 fetch_message,
1856 fetch_portal_name,
1857 timeout,
1858 Some(ctx_extra),
1859 None,
1860 )
1861 .await
1862 }
1863
1864 async fn flush(&mut self) -> Result<State, io::Error> {
1865 self.conn.flush().await?;
1866 Ok(State::Ready)
1867 }
1868
1869 #[instrument(level = "debug")]
1874 async fn send<M>(&mut self, message: M) -> Result<(), io::Error>
1875 where
1876 M: Into<BackendMessage>,
1877 {
1878 let message: BackendMessage = message.into();
1879 let is_error =
1880 matches!(&message, BackendMessage::ErrorResponse(e) if e.severity.is_error());
1881
1882 self.conn.send(message).await?;
1883
1884 if is_error {
1891 self.conn.flush().await?;
1892 }
1893
1894 Ok(())
1895 }
1896
1897 #[instrument(level = "debug")]
1898 pub async fn send_all(
1899 &mut self,
1900 messages: impl IntoIterator<Item = BackendMessage>,
1901 ) -> Result<(), io::Error> {
1902 for m in messages {
1903 self.send(m).await?;
1904 }
1905 Ok(())
1906 }
1907
1908 #[instrument(level = "debug")]
1909 async fn sync(&mut self) -> Result<State, io::Error> {
1910 if self.adapter_client.session().transaction().is_implicit() {
1912 self.commit_transaction().await?;
1913 }
1914 self.ready().await
1915 }
1916
1917 #[instrument(level = "debug")]
1918 async fn ready(&mut self) -> Result<State, io::Error> {
1919 let txn_state = self.adapter_client.session().transaction().into();
1920 self.send(BackendMessage::ReadyForQuery(txn_state)).await?;
1921 self.flush().await
1922 }
1923
/// Translates an adapter `ExecuteResponse` for the portal `portal_name` into
/// pgwire messages.
///
/// How each kind of response is handled:
/// - Row-producing responses are streamed with `send_rows`, honoring
///   `max_rows`, `get_response`, `fetch_portal_name` and `timeout`.
/// - `COPY ... TO` responses are streamed with `copy_rows`.
/// - `Fetch` is forwarded to `fetch`.
/// - `CopyFrom` is handed to `copy_from`.
/// - All remaining responses are acknowledged with a `CommandComplete`
///   carrying the response's tag.
///
/// `execute_started` is passed to `RecordFirstRowStream` on the row-producing
/// paths.
#[allow(clippy::too_many_arguments)]
#[instrument(level = "debug")]
async fn send_execute_response(
    &mut self,
    response: ExecuteResponse,
    row_desc: Option<RelationDesc>,
    portal_name: String,
    max_rows: ExecuteCount,
    get_response: GetResponse,
    fetch_portal_name: Option<String>,
    timeout: ExecuteTimeout,
    execute_started: Instant,
) -> Result<State, io::Error> {
    // The response's command tag, if any. Arms that end in `command_complete!`
    // consume it, and `assert_none!` at the end checks that no tag was left
    // unsent on the non-returning paths.
    let mut tag = response.tag();

    // Sends `CommandComplete` with `tag`, taking it out of the `Option`.
    // Panics if this response kind produced no tag.
    macro_rules! command_complete {
        () => {{
            self.send(BackendMessage::CommandComplete {
                tag: tag
                    .take()
                    .expect("command_complete only called on tag-generating results"),
            })
            .await?;
            Ok(State::Ready)
        }};
    }

    let r = match response {
        // These statements return no rows, so their portal is finished.
        ExecuteResponse::ClosedCursor => {
            self.complete_portal(&portal_name);
            command_complete!()
        }
        ExecuteResponse::DeclaredCursor => {
            self.complete_portal(&portal_name);
            command_complete!()
        }
        ExecuteResponse::EmptyQuery => {
            self.send(BackendMessage::EmptyQueryResponse).await?;
            Ok(State::Ready)
        }
        // The FETCH's own `timeout` shadows this function's `timeout`
        // parameter. The portal that ran the FETCH statement becomes the
        // `fetch_portal_name` for reading the cursor.
        ExecuteResponse::Fetch {
            name,
            count,
            timeout,
            ctx_extra,
        } => {
            self.fetch(
                name,
                count,
                max_rows,
                Some(portal_name.to_string()),
                timeout,
                ctx_extra,
            )
            .await
        }
        ExecuteResponse::SendingRowsStreaming {
            rows,
            instance_id,
            strategy,
        } => {
            let row_desc = row_desc
                .expect("missing row description for ExecuteResponse::SendingRowsStreaming");

            let span = tracing::debug_span!("sending_rows_streaming");

            self.send_rows(
                row_desc,
                portal_name,
                InProgressRows::new(RecordFirstRowStream::new(
                    Box::new(rows),
                    execute_started,
                    &self.adapter_client,
                    Some(instance_id),
                    Some(strategy),
                )),
                max_rows,
                get_response,
                fetch_portal_name,
                timeout,
            )
            .instrument(span)
            .await
            .map(|(state, _)| state)
        }
        ExecuteResponse::SendingRowsImmediate { rows } => {
            let row_desc = row_desc
                .expect("missing row description for ExecuteResponse::SendingRowsImmediate");

            let span = tracing::debug_span!("sending_rows_immediate");

            // Wrap the already-computed rows in a one-item stream so they go
            // through the same `send_rows` path as streamed results.
            let stream =
                futures::stream::once(futures::future::ready(PeekResponseUnary::Rows(rows)));
            self.send_rows(
                row_desc,
                portal_name,
                InProgressRows::new(RecordFirstRowStream::new(
                    Box::new(stream),
                    execute_started,
                    &self.adapter_client,
                    None,
                    Some(StatementExecutionStrategy::Constant),
                )),
                max_rows,
                get_response,
                fetch_portal_name,
                timeout,
            )
            .instrument(span)
            .await
            .map(|(state, _)| state)
        }
        ExecuteResponse::SetVariable { name, .. } => {
            let qn = name.to_string();
            // Build an owned message first and send it only afterwards, so the
            // session borrow behind `var` has ended before `self.send` is
            // awaited.
            let msg = if let Some(var) = self
                .adapter_client
                .session()
                .vars_mut()
                .notify_set()
                .find(|v| v.name() == qn)
            {
                Some(BackendMessage::ParameterStatus(var.name(), var.value()))
            } else {
                None
            };
            if let Some(msg) = msg {
                self.send(msg).await?;
            }
            command_complete!()
        }
        ExecuteResponse::Subscribing {
            rx,
            ctx_extra,
            instance_id,
        } => {
            // Outside a FETCH, warn that streaming SUBSCRIBE output needs a
            // non-buffering client. psql users get a COPY-based hint.
            if fetch_portal_name.is_none() {
                let mut msg = ErrorResponse::notice(
                    SqlState::WARNING,
                    "streaming SUBSCRIBE rows directly requires a client that does not buffer output",
                );
                if self.adapter_client.session().vars().application_name() == "psql" {
                    msg.hint = Some(
                        "Wrap your SUBSCRIBE statement in `COPY (SUBSCRIBE ...) TO STDOUT`."
                            .into(),
                    )
                }
                self.send(msg).await?;
                self.conn.flush().await?;
            }
            let row_desc =
                row_desc.expect("missing row description for ExecuteResponse::Subscribing");
            let (result, statement_ended_execution_reason) = match self
                .send_rows(
                    row_desc,
                    portal_name,
                    InProgressRows::new(RecordFirstRowStream::new(
                        Box::new(UnboundedReceiverStream::new(rx)),
                        execute_started,
                        &self.adapter_client,
                        Some(instance_id),
                        None,
                    )),
                    max_rows,
                    get_response,
                    fetch_portal_name,
                    timeout,
                )
                .await
            {
                Err(e) => {
                    // An I/O failure (e.g. the connection closed) is logged as
                    // a cancellation.
                    (Err(e), StatementEndedExecutionReason::Canceled)
                }
                Ok((ok, SendRowsEndedReason::Canceled)) => {
                    (Ok(ok), StatementEndedExecutionReason::Canceled)
                }
                Ok((
                    ok,
                    SendRowsEndedReason::Success {
                        result_size,
                        rows_returned,
                    },
                )) => (
                    Ok(ok),
                    StatementEndedExecutionReason::Success {
                        result_size: Some(result_size),
                        rows_returned: Some(rows_returned),
                        execution_strategy: None,
                    },
                ),
                Ok((ok, SendRowsEndedReason::Errored { error })) => {
                    (Ok(ok), StatementEndedExecutionReason::Errored { error })
                }
            };
            self.adapter_client
                .retire_execute(ctx_extra, statement_ended_execution_reason);
            return result;
        }
        ExecuteResponse::CopyTo { format, resp } => {
            let row_desc =
                row_desc.expect("missing row description for ExecuteResponse::CopyTo");
            // The wrapped response decides where the rows come from. They are
            // encoded with the COPY `format`, not the portal's result formats.
            match *resp {
                ExecuteResponse::Subscribing {
                    rx,
                    ctx_extra,
                    instance_id,
                } => {
                    let (result, statement_ended_execution_reason) = match self
                        .copy_rows(
                            format,
                            row_desc,
                            RecordFirstRowStream::new(
                                Box::new(UnboundedReceiverStream::new(rx)),
                                execute_started,
                                &self.adapter_client,
                                Some(instance_id),
                                None,
                            ),
                        )
                        .await
                    {
                        Err(e) => {
                            // An I/O failure is logged as a cancellation.
                            (Err(e), StatementEndedExecutionReason::Canceled)
                        }
                        Ok((
                            state,
                            SendRowsEndedReason::Success {
                                result_size,
                                rows_returned,
                            },
                        )) => (
                            Ok(state),
                            StatementEndedExecutionReason::Success {
                                result_size: Some(result_size),
                                rows_returned: Some(rows_returned),
                                execution_strategy: None,
                            },
                        ),
                        Ok((state, SendRowsEndedReason::Errored { error })) => {
                            (Ok(state), StatementEndedExecutionReason::Errored { error })
                        }
                        Ok((state, SendRowsEndedReason::Canceled)) => {
                            (Ok(state), StatementEndedExecutionReason::Canceled)
                        }
                    };
                    self.adapter_client
                        .retire_execute(ctx_extra, statement_ended_execution_reason);
                    return result;
                }
                ExecuteResponse::SendingRowsStreaming {
                    rows,
                    instance_id,
                    strategy,
                } => {
                    return self
                        .copy_rows(
                            format,
                            row_desc,
                            RecordFirstRowStream::new(
                                Box::new(rows),
                                execute_started,
                                &self.adapter_client,
                                Some(instance_id),
                                Some(strategy),
                            ),
                        )
                        .await
                        .map(|(state, _)| state);
                }
                ExecuteResponse::SendingRowsImmediate { rows } => {
                    let span = tracing::debug_span!("sending_rows_immediate");

                    let rows = futures::stream::once(futures::future::ready(
                        PeekResponseUnary::Rows(rows),
                    ));
                    return self
                        .copy_rows(
                            format,
                            row_desc,
                            RecordFirstRowStream::new(
                                Box::new(rows),
                                execute_started,
                                &self.adapter_client,
                                None,
                                Some(StatementExecutionStrategy::Constant),
                            ),
                        )
                        .instrument(span)
                        .await
                        .map(|(state, _)| state);
                }
                // No other wrapped response kind is supported for COPY TO.
                _ => {
                    return self
                        .send_error_and_get_state(ErrorResponse::error(
                            SqlState::INTERNAL_ERROR,
                            "unsupported COPY response type".to_string(),
                        ))
                        .await;
                }
            };
        }
        ExecuteResponse::CopyFrom {
            target_id,
            target_name,
            columns,
            params,
            ctx_extra,
        } => {
            let row_desc =
                row_desc.expect("missing row description for ExecuteResponse::CopyFrom");
            self.copy_from(target_id, target_name, columns, params, row_desc, ctx_extra)
                .await
        }
        ExecuteResponse::TransactionCommitted { params }
        | ExecuteResponse::TransactionRolledBack { params } => {
            // Report the changed parameters the client is subscribed to (the
            // session's notify set) before completing.
            let notify_set: mz_ore::collections::HashSet<String> = self
                .adapter_client
                .session()
                .vars()
                .notify_set()
                .map(|v| v.name().to_string())
                .collect();

            for (name, value) in params
                .into_iter()
                .filter(|(name, _v)| notify_set.contains(*name))
            {
                let msg = BackendMessage::ParameterStatus(name, value);
                self.send(msg).await?;
            }
            command_complete!()
        }

        // Everything else only needs its command tag acknowledged.
        ExecuteResponse::AlteredDefaultPrivileges
        | ExecuteResponse::AlteredObject(..)
        | ExecuteResponse::AlteredRole
        | ExecuteResponse::AlteredSystemConfiguration
        | ExecuteResponse::CreatedCluster { .. }
        | ExecuteResponse::CreatedClusterReplica { .. }
        | ExecuteResponse::CreatedConnection { .. }
        | ExecuteResponse::CreatedDatabase { .. }
        | ExecuteResponse::CreatedIndex { .. }
        | ExecuteResponse::CreatedIntrospectionSubscribe
        | ExecuteResponse::CreatedMaterializedView { .. }
        | ExecuteResponse::CreatedRole
        | ExecuteResponse::CreatedSchema { .. }
        | ExecuteResponse::CreatedSecret { .. }
        | ExecuteResponse::CreatedSink { .. }
        | ExecuteResponse::CreatedSource { .. }
        | ExecuteResponse::CreatedTable { .. }
        | ExecuteResponse::CreatedType
        | ExecuteResponse::CreatedView { .. }
        | ExecuteResponse::CreatedViews { .. }
        | ExecuteResponse::CreatedNetworkPolicy
        | ExecuteResponse::Comment
        | ExecuteResponse::Deallocate { .. }
        | ExecuteResponse::Deleted(..)
        | ExecuteResponse::DiscardedAll
        | ExecuteResponse::DiscardedTemp
        | ExecuteResponse::DroppedObject(_)
        | ExecuteResponse::DroppedOwned
        | ExecuteResponse::GrantedPrivilege
        | ExecuteResponse::GrantedRole
        | ExecuteResponse::Inserted(..)
        | ExecuteResponse::Copied(..)
        | ExecuteResponse::Prepare
        | ExecuteResponse::Raised
        | ExecuteResponse::ReassignOwned
        | ExecuteResponse::RevokedPrivilege
        | ExecuteResponse::RevokedRole
        | ExecuteResponse::StartedTransaction { .. }
        | ExecuteResponse::Updated(..)
        | ExecuteResponse::ValidatedConnection => {
            command_complete!()
        }
    };

    // Paths that reach here must not leave a tag unsent. The arms that
    // `return` early above skip this check.
    assert_none!(tag, "tag created but not consumed: {:?}", tag);
    r
}
2318
/// Streams rows from `rows` to the client as `DataRow` messages.
///
/// Sending stops once `max_rows` rows have been sent, when the stream is
/// exhausted, or when the `timeout` deadline passes. Result formats come from
/// `fetch_portal_name` when given, otherwise from `portal_name`.
///
/// Whatever remains of `rows`, including an unsent part of the current batch,
/// is stored back on `portal_name` as `PortalState::InProgress`, so a later
/// Execute/FETCH can resume it. The message built by `get_response` is then
/// sent.
///
/// Returns the next state together with why sending ended. Client-visible
/// failures are sent as `ErrorResponse`s and reported through
/// `SendRowsEndedReason`; these are datum verification errors, peek errors and
/// cancellation. `Err` is reserved for I/O failures and the connection closing.
#[allow(clippy::too_many_arguments)]
#[mz_ore::instrument(level = "debug")]
async fn send_rows(
    &mut self,
    row_desc: RelationDesc,
    portal_name: String,
    mut rows: InProgressRows,
    max_rows: ExecuteCount,
    get_response: GetResponse,
    fetch_portal_name: Option<String>,
    timeout: ExecuteTimeout,
) -> Result<(State, SendRowsEndedReason), io::Error> {
    // For a FETCH, encode with the result formats of the portal the FETCH
    // statement was bound on, not those of the cursor being read.
    let result_format_portal_name: &str = if let Some(ref name) = fetch_portal_name {
        name
    } else {
        &portal_name
    };
    let result_formats = self
        .adapter_client
        .session()
        .get_portal_unverified(result_format_portal_name)
        .expect("valid fetch portal name for send rows")
        .result_formats
        .clone();

    // Timeout handling:
    // - `None`: wait indefinitely.
    // - `Seconds(t)`: stop at a fixed deadline.
    // - `WaitOnce`: wait with no deadline until the first non-empty batch
    //   arrives, then only take what is immediately available (see below).
    let (mut wait_once, mut deadline) = match timeout {
        ExecuteTimeout::None => (false, None),
        ExecuteTimeout::Seconds(t) => (
            false,
            Some(tokio::time::Instant::now() + tokio::time::Duration::from_secs_f64(t)),
        ),
        ExecuteTimeout::WaitOnce => (true, None),
    };

    // Sanity check: the row description being sent should match the one
    // recorded on the portal(s). These are soft assertions.
    // NOTE(review): per the macro name, they presumably log instead of
    // panicking when soft assertions are disabled -- confirm.
    {
        let portal_name_desc = &self
            .adapter_client
            .session()
            .get_portal_unverified(portal_name.as_str())
            .expect("portal should exist")
            .desc
            .relation_desc;
        if let Some(portal_name_desc) = portal_name_desc {
            soft_assert_eq_or_log!(portal_name_desc, &row_desc);
        }
        if let Some(fetch_portal_name) = &fetch_portal_name {
            let fetch_portal_desc = &self
                .adapter_client
                .session()
                .get_portal_unverified(fetch_portal_name)
                .expect("portal should exist")
                .desc
                .relation_desc;
            if let Some(fetch_portal_desc) = fetch_portal_desc {
                soft_assert_eq_or_log!(fetch_portal_desc, &row_desc);
            }
        }
    }

    // Configure the connection's encoder with one (pg type, format) pair per
    // column. `zip_eq` panics if the counts differ.
    self.conn.set_encode_state(
        row_desc
            .typ()
            .column_types
            .iter()
            .map(|ty| mz_pgrepr::Type::from(&ty.scalar_type))
            .zip_eq(result_formats)
            .collect(),
    );

    let mut total_sent_rows = 0;
    let mut total_sent_bytes = 0;
    // Remaining row budget for this call; `All` is modeled as `usize::MAX`.
    let mut want_rows = match max_rows {
        ExecuteCount::All => usize::MAX,
        ExecuteCount::Count(count) => count,
    };

    loop {
        // Prefer a leftover batch from a previous call. Stop when the budget is
        // spent. Otherwise wait for whichever comes first: connection close, a
        // new batch, a session notice, or the deadline. `biased;` polls the
        // branches in that order. When `deadline` is `None`, the sleep branch is
        // disabled by its `if` precondition, but `select!` still evaluates its
        // future expression, hence the `unwrap_or_else` fallback. Hitting the
        // deadline is treated like end of data for this call.
        let batch = if rows.current.is_some() {
            FetchResult::Rows(rows.current.take())
        } else if want_rows == 0 {
            FetchResult::Rows(None)
        } else {
            let notice_fut = self.adapter_client.session().recv_notice();
            tokio::select! {
                biased;
                err = self.conn.wait_closed() => return Err(err),
                batch = rows.remaining.recv() => match batch {
                    None => FetchResult::Rows(None),
                    Some(PeekResponseUnary::Rows(rows)) => FetchResult::Rows(Some(rows)),
                    Some(PeekResponseUnary::Error(err)) => FetchResult::Error(err),
                    Some(PeekResponseUnary::DependencyDropped(dep)) => {
                        FetchResult::Error(dep.query_terminated_error())
                    }
                    Some(PeekResponseUnary::Canceled) => FetchResult::Canceled,
                },
                notice = notice_fut => {
                    FetchResult::Notice(notice)
                }
                _ = time::sleep_until(
                    deadline.unwrap_or_else(tokio::time::Instant::now),
                ), if deadline.is_some() => FetchResult::Rows(None),
            }
        };

        match batch {
            FetchResult::Rows(None) => break,
            FetchResult::Rows(Some(mut batch_rows)) => {
                // Reject batches whose datums don't match the row description.
                if let Err(err) = verify_datum_desc(&row_desc, &mut batch_rows) {
                    let msg = err.to_string();
                    return self
                        .send_error_and_get_state(err.into_response(Severity::Error))
                        .await
                        .map(|state| (state, SendRowsEndedReason::Errored { error: msg }));
                }

                // `WaitOnce`: once a non-empty batch arrives, set the deadline
                // to "now". Later waits then return only what is already ready.
                if wait_once && batch_rows.peek().is_some() {
                    deadline = Some(tokio::time::Instant::now());
                    wait_once = false;
                }

                // Encode rows lazily while counting rows and bytes. `take` comes
                // last, so rows beyond `want_rows` are never pulled out of
                // `batch_rows` and stay there for a later call.
                let mut sent_rows = 0;
                let mut sent_bytes = 0;
                let messages = (&mut batch_rows)
                    .map(|row| {
                        let row_len = row.byte_len();
                        let values = mz_pgrepr::values_from_row(row, row_desc.typ());
                        (row_len, BackendMessage::DataRow(values))
                    })
                    .inspect(|(row_len, _)| {
                        sent_bytes += row_len;
                        sent_rows += 1
                    })
                    .map(|(_row_len, row)| row)
                    .take(want_rows);
                self.send_all(messages).await?;

                total_sent_rows += sent_rows;
                total_sent_bytes += sent_bytes;
                want_rows -= sent_rows;

                // Budget exhausted: keep any unsent remainder of this batch so
                // it is served first next time.
                if want_rows == 0 {
                    if batch_rows.peek().is_some() {
                        rows.current = Some(batch_rows);
                    }
                    break;
                }

                self.conn.flush().await?;
            }
            FetchResult::Notice(notice) => {
                self.send(notice.into_response()).await?;
                self.conn.flush().await?;
            }
            FetchResult::Error(text) => {
                return self
                    .send_error_and_get_state(ErrorResponse::error(
                        SqlState::INTERNAL_ERROR,
                        text.clone(),
                    ))
                    .await
                    .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
            }
            FetchResult::Canceled => {
                return self
                    .send_error_and_get_state(ErrorResponse::error(
                        SqlState::QUERY_CANCELED,
                        "canceling statement due to user request",
                    ))
                    .await
                    .map(|state| (state, SendRowsEndedReason::Canceled));
            }
        }
    }

    let portal = self
        .adapter_client
        .session()
        .get_portal_unverified_mut(&portal_name)
        .expect("valid portal name for send rows");

    // Snapshot the metric-related state before `rows` moves into the portal.
    let saw_rows = rows.remaining.saw_rows;
    let no_more_rows = rows.no_more_rows();
    let metric_recorded = rows.remaining.metric_recorded;
    let recorded_first_row_instant = rows.remaining.recorded_first_row_instant;

    // Mark the first-to-last-byte metric as recorded before storing `rows`, so
    // a later resumption of this portal doesn't observe it again.
    if no_more_rows && !metric_recorded {
        rows.remaining.metric_recorded = true;
    }

    // Keep the remaining rows on the portal so Execute/FETCH can resume them.
    *portal.state = PortalState::InProgress(Some(rows));

    let fetch_portal = fetch_portal_name.map(|name| {
        self.adapter_client
            .session()
            .get_portal_unverified_mut(&name)
            .expect("valid fetch portal")
    });
    let response_message = get_response(max_rows, total_sent_rows, fetch_portal);
    self.send(response_message).await?;

    // When the stream is exhausted, record once how long it took from the
    // first row to now (zero if no rows were seen), labeled by statement type.
    if no_more_rows && !metric_recorded {
        let statement_type = if let Some(stmt) = &self
            .adapter_client
            .session()
            .get_portal_unverified(&portal_name)
            .expect("valid portal name for send_rows")
            .stmt
        {
            metrics::statement_type_label_value(stmt.deref())
        } else {
            "no-statement"
        };
        let duration = if saw_rows {
            recorded_first_row_instant
                .expect("recorded_first_row_instant because saw_rows")
                .elapsed()
        } else {
            Duration::ZERO
        };
        self.adapter_client
            .inner()
            .metrics()
            .result_rows_first_to_last_byte_seconds
            .with_label_values(&[statement_type])
            .observe(duration.as_secs_f64());
    }

    Ok((
        State::Ready,
        SendRowsEndedReason::Success {
            result_size: u64::cast_from(total_sent_bytes),
            rows_returned: u64::cast_from(total_sent_rows),
        },
    ))
}
2593
    /// Streams the result of a `COPY ... TO STDOUT` to the client.
    ///
    /// Sends a `CopyOutResponse`, then one `CopyData` message per row encoded
    /// in the requested `format`, then `CopyDone` and a `CommandComplete`
    /// tagged `COPY <n>`. Adapter notices that arrive while streaming are
    /// forwarded to the client as they come in.
    ///
    /// Returns the next connection state paired with why the stream ended:
    /// success (with the summed `byte_len` of the rows and the row count), an
    /// error, or cancellation. An `io::Error` is returned if writing to the
    /// client fails or the connection closes while streaming.
    #[mz_ore::instrument(level = "debug")]
    async fn copy_rows(
        &mut self,
        format: CopyFormat,
        row_desc: RelationDesc,
        mut stream: RecordFirstRowStream,
    ) -> Result<(State, SendRowsEndedReason), io::Error> {
        let (row_format, encode_format) = match format {
            CopyFormat::Text => (
                CopyFormatParams::Text(CopyTextFormatParams::default()),
                Format::Text,
            ),
            CopyFormat::Binary => (CopyFormatParams::Binary, Format::Binary),
            CopyFormat::Csv => (
                CopyFormatParams::Csv(CopyCsvFormatParams::default()),
                Format::Text,
            ),
            // Parquet cannot be streamed to STDOUT; report the error before
            // entering copy-out mode.
            CopyFormat::Parquet => {
                let text = "Parquet format is not supported".to_string();
                return self
                    .send_error_and_get_state(ErrorResponse::error(
                        SqlState::INTERNAL_ERROR,
                        text.clone(),
                    ))
                    .await
                    .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
            }
        };

        let encode_fn = |row: &RowRef, typ: &SqlRelationType, out: &mut Vec<u8>| {
            mz_pgcopy::encode_copy_format(&row_format, row, typ, out)
        };

        let typ = row_desc.typ();
        // Every column uses the same format as the overall copy.
        let column_formats = iter::repeat(encode_format)
            .take(typ.column_types.len())
            .collect();
        self.send(BackendMessage::CopyOutResponse {
            overall_format: encode_format,
            column_formats,
        })
        .await?;

        // Shared encode buffer. The binary header is written into it now, so
        // it goes out in the same `CopyData` message as the first row (or with
        // the trailer, if there are no rows).
        let mut out = Vec::new();

        if let CopyFormat::Binary = format {
            // 11-byte binary COPY signature.
            out.extend(b"PGCOPY\n\xFF\r\n\0");
            // 32-bit flags field (no flags set).
            out.extend([0, 0, 0, 0]);
            // 32-bit header extension length (no extension).
            out.extend([0, 0, 0, 0]);
        }

        let mut count = 0;
        let mut total_sent_bytes = 0;
        loop {
            tokio::select! {
                // Abort if the client disconnects while we are waiting.
                e = self.conn.wait_closed() => return Err(e),
                batch = stream.recv() => match batch {
                    // The result stream is exhausted.
                    None => break,
                    Some(PeekResponseUnary::Error(text)) => {
                        let err =
                            ErrorResponse::error(SqlState::INTERNAL_ERROR, text.clone());
                        return self
                            .send_error_and_get_state(err)
                            .await
                            .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
                    }
                    Some(PeekResponseUnary::DependencyDropped(dep)) => {
                        let err = dep.to_concurrent_dependency_drop();
                        let text = err.to_string();
                        let resp = err.into_response(Severity::Error);
                        return self
                            .send_error_and_get_state(resp)
                            .await
                            .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
                    }
                    Some(PeekResponseUnary::Canceled) => {
                        return self.send_error_and_get_state(ErrorResponse::error(
                            SqlState::QUERY_CANCELED,
                            "canceling statement due to user request",
                        ))
                        .await.map(|state| (state, SendRowsEndedReason::Canceled));
                    }
                    Some(PeekResponseUnary::Rows(mut rows)) => {
                        // Count the batch up front, then send each row as
                        // its own `CopyData` message.
                        count += rows.count();
                        while let Some(row) = rows.next() {
                            total_sent_bytes += row.byte_len();
                            encode_fn(row, typ, &mut out)?;
                            self.send(BackendMessage::CopyData(mem::take(&mut out)))
                                .await?;
                        }
                    }
                },
                // Forward notices raised while the copy is in progress.
                notice = self.adapter_client.session().recv_notice() => {
                    self.send(notice.into_response())
                        .await?;
                    self.conn.flush().await?;
                }
            }

            self.conn.flush().await?;
        }
        if let CopyFormat::Binary = format {
            // Binary COPY trailer: a 16-bit field count of -1.
            let trailer: i16 = -1;
            out.extend(trailer.to_be_bytes());
            self.send(BackendMessage::CopyData(mem::take(&mut out)))
                .await?;
        }

        let tag = format!("COPY {}", count);
        self.send(BackendMessage::CopyDone).await?;
        self.send(BackendMessage::CommandComplete { tag }).await?;
        Ok((
            State::Ready,
            SendRowsEndedReason::Success {
                result_size: u64::cast_from(total_sent_bytes),
                rows_returned: u64::cast_from(count),
            },
        ))
    }
2721
2722 #[instrument(level = "debug")]
2725 async fn copy_from(
2726 &mut self,
2727 target_id: CatalogItemId,
2728 target_name: String,
2729 columns: Vec<ColumnIndex>,
2730 params: CopyFormatParams<'static>,
2731 row_desc: RelationDesc,
2732 mut ctx_extra: ExecuteContextGuard,
2733 ) -> Result<State, io::Error> {
2734 let res = self
2735 .copy_from_inner(
2736 target_id,
2737 target_name,
2738 columns,
2739 params,
2740 row_desc,
2741 &mut ctx_extra,
2742 )
2743 .await;
2744 match &res {
2745 Ok(State::Ready) => {
2746 self.adapter_client.retire_execute(
2747 ctx_extra,
2748 StatementEndedExecutionReason::Success {
2749 result_size: None,
2750 rows_returned: None,
2751 execution_strategy: None,
2752 },
2753 );
2754 }
2755 Ok(State::Done) => {
2756 self.adapter_client
2760 .retire_execute(ctx_extra, StatementEndedExecutionReason::Canceled);
2761 }
2762 Err(e) => {
2763 self.adapter_client.retire_execute(
2764 ctx_extra,
2765 StatementEndedExecutionReason::Errored {
2766 error: format!("{e}"),
2767 },
2768 );
2769 }
2770 Ok(State::Drain) => {}
2771 }
2772 res
2773 }
2774
2775 async fn copy_from_inner(
2776 &mut self,
2777 target_id: CatalogItemId,
2778 target_name: String,
2779 columns: Vec<ColumnIndex>,
2780 params: CopyFormatParams<'static>,
2781 row_desc: RelationDesc,
2782 ctx_extra: &mut ExecuteContextGuard,
2783 ) -> Result<State, io::Error> {
2784 let typ = row_desc.typ();
2785 let column_formats = vec![Format::Text; typ.column_types.len()];
2786 self.send(BackendMessage::CopyInResponse {
2787 overall_format: Format::Text,
2788 column_formats,
2789 })
2790 .await?;
2791 self.conn.flush().await?;
2792
2793 let writer = match self
2795 .adapter_client
2796 .start_copy_from_stdin(
2797 target_id,
2798 target_name.clone(),
2799 columns.clone(),
2800 row_desc.clone(),
2801 params.clone(),
2802 )
2803 .await
2804 {
2805 Ok(writer) => writer,
2806 Err(e) => {
2807 loop {
2813 match self.conn.recv().await? {
2814 Some(FrontendMessage::CopyData(_)) => {}
2815 Some(FrontendMessage::CopyDone) | Some(FrontendMessage::CopyFail(_)) => {
2816 break;
2817 }
2818 Some(FrontendMessage::Flush) | Some(FrontendMessage::Sync) => {}
2819 Some(_) => break,
2820 None => return Ok(State::Done),
2821 }
2822 }
2823 self.adapter_client.retire_execute(
2824 std::mem::take(ctx_extra),
2825 StatementEndedExecutionReason::Errored {
2826 error: e.to_string(),
2827 },
2828 );
2829 return self
2830 .send_error_and_get_state(e.into_response(Severity::Error))
2831 .await;
2832 }
2833 };
2834
2835 self.conn.set_copy_mode(true);
2837
2838 const BATCH_SIZE: usize = 32 * 1024 * 1024;
2840 let max_copy_from_row_size = self
2841 .adapter_client
2842 .get_system_vars()
2843 .await
2844 .max_copy_from_row_size()
2845 .try_into()
2846 .unwrap_or(usize::MAX);
2847
2848 let mut data = Vec::new();
2849 let mut row_scanner = CopyRowScanner::new(¶ms);
2850 let num_workers = writer.batch_txs.len();
2851 let mut next_worker: usize = 0;
2852 let mut saw_copy_done = false;
2853 let mut saw_end_marker = false;
2854 let mut copy_from_error: Option<(SqlState, String)> = None;
2855
2856 loop {
2859 let message = self.conn.recv().await?;
2860 match message {
2861 Some(FrontendMessage::CopyData(buf)) => {
2862 if saw_end_marker {
2863 continue;
2866 }
2867 data.extend(buf);
2868 row_scanner.scan_new_bytes(&data);
2869
2870 if let Some(end_pos) = row_scanner.end_marker_end() {
2871 data.truncate(end_pos);
2872 row_scanner.on_truncate(end_pos);
2873 saw_end_marker = true;
2874 }
2875
2876 if row_scanner.current_row_size(data.len()) > max_copy_from_row_size {
2878 copy_from_error = Some((
2879 SqlState::INSUFFICIENT_RESOURCES,
2880 format!(
2881 "COPY FROM STDIN row exceeded max_copy_from_row_size \
2882 ({max_copy_from_row_size} bytes)"
2883 ),
2884 ));
2885 break;
2886 }
2887
2888 let mut send_failed = false;
2891 while data.len() >= BATCH_SIZE {
2892 let split_pos = match row_scanner.last_row_end() {
2893 Some(pos) => pos,
2894 None => break, };
2896 let remainder = data.split_off(split_pos);
2897 let chunk = std::mem::replace(&mut data, remainder);
2898 row_scanner.on_split(split_pos);
2899 if writer.batch_txs[next_worker].send(chunk).await.is_err() {
2900 send_failed = true;
2901 break;
2902 }
2903 next_worker = (next_worker + 1) % num_workers;
2904 }
2905 if send_failed {
2908 break;
2909 }
2910 }
2911 Some(FrontendMessage::CopyDone) => {
2912 if !data.is_empty() {
2914 let chunk = std::mem::take(&mut data);
2915 let _ = writer.batch_txs[next_worker].send(chunk).await;
2917 }
2918 saw_copy_done = true;
2919 break;
2920 }
2921 Some(FrontendMessage::CopyFail(err)) => {
2922 self.adapter_client.retire_execute(
2923 std::mem::take(ctx_extra),
2924 StatementEndedExecutionReason::Canceled,
2925 );
2926 drop(writer);
2928 self.conn.set_copy_mode(false);
2929 return self
2930 .send_error_and_get_state(ErrorResponse::error(
2931 SqlState::QUERY_CANCELED,
2932 format!("COPY from stdin failed: {}", err),
2933 ))
2934 .await;
2935 }
2936 Some(FrontendMessage::Flush) | Some(FrontendMessage::Sync) => {}
2937 Some(_) => {
2938 let msg = "unexpected message type during COPY from stdin";
2939 self.adapter_client.retire_execute(
2940 std::mem::take(ctx_extra),
2941 StatementEndedExecutionReason::Errored {
2942 error: msg.to_string(),
2943 },
2944 );
2945 drop(writer);
2946 self.conn.set_copy_mode(false);
2947 return self
2948 .send_error_and_get_state(ErrorResponse::error(
2949 SqlState::PROTOCOL_VIOLATION,
2950 msg,
2951 ))
2952 .await;
2953 }
2954 None => {
2955 drop(writer);
2956 self.conn.set_copy_mode(false);
2957 return Ok(State::Done);
2958 }
2959 }
2960 }
2961
2962 if !saw_copy_done {
2966 loop {
2967 match self.conn.recv().await? {
2968 Some(FrontendMessage::CopyData(_)) => {}
2969 Some(FrontendMessage::CopyDone) | Some(FrontendMessage::CopyFail(_)) => {
2970 break;
2971 }
2972 Some(FrontendMessage::Flush) | Some(FrontendMessage::Sync) => {}
2973 Some(_) => {
2974 let msg = "unexpected message type during COPY from stdin";
2975 self.adapter_client.retire_execute(
2976 std::mem::take(ctx_extra),
2977 StatementEndedExecutionReason::Errored {
2978 error: msg.to_string(),
2979 },
2980 );
2981 drop(writer);
2982 self.conn.set_copy_mode(false);
2983 return self
2984 .send_error_and_get_state(ErrorResponse::error(
2985 SqlState::PROTOCOL_VIOLATION,
2986 msg,
2987 ))
2988 .await;
2989 }
2990 None => {
2991 drop(writer);
2992 self.conn.set_copy_mode(false);
2993 return Ok(State::Done);
2994 }
2995 }
2996 }
2997 }
2998
2999 if let Some((code, msg)) = copy_from_error {
3000 self.adapter_client.retire_execute(
3001 std::mem::take(ctx_extra),
3002 StatementEndedExecutionReason::Errored { error: msg.clone() },
3003 );
3004 drop(writer);
3005 self.conn.set_copy_mode(false);
3006 return self
3007 .send_error_and_get_state(ErrorResponse::error(code, msg))
3008 .await;
3009 }
3010
3011 self.conn.set_copy_mode(false);
3012
3013 drop(writer.batch_txs);
3018
3019 let (proto_batches, row_count) = match writer.completion_rx.await {
3021 Ok(Ok(result)) => result,
3022 Ok(Err(e)) => {
3023 self.adapter_client.retire_execute(
3024 std::mem::take(ctx_extra),
3025 StatementEndedExecutionReason::Errored {
3026 error: e.to_string(),
3027 },
3028 );
3029 return self
3030 .send_error_and_get_state(e.into_response(Severity::Error))
3031 .await;
3032 }
3033 Err(_) => {
3034 let msg = "COPY FROM STDIN: background batch builder tasks dropped";
3035 self.adapter_client.retire_execute(
3036 std::mem::take(ctx_extra),
3037 StatementEndedExecutionReason::Errored {
3038 error: msg.to_string(),
3039 },
3040 );
3041 return self
3042 .send_error_and_get_state(ErrorResponse::error(SqlState::INTERNAL_ERROR, msg))
3043 .await;
3044 }
3045 };
3046
3047 if let Err(e) = self
3049 .adapter_client
3050 .stage_copy_from_stdin_batches(target_id, proto_batches)
3051 {
3052 self.adapter_client.retire_execute(
3053 std::mem::take(ctx_extra),
3054 StatementEndedExecutionReason::Errored {
3055 error: e.to_string(),
3056 },
3057 );
3058 return self
3059 .send_error_and_get_state(e.into_response(Severity::Error))
3060 .await;
3061 }
3062
3063 let tag = format!("COPY {}", row_count);
3064 self.send(BackendMessage::CommandComplete { tag }).await?;
3065
3066 Ok(State::Ready)
3067 }
3068
3069 #[instrument(level = "debug")]
3070 async fn send_pending_notices(&mut self) -> Result<(), io::Error> {
3071 let notices = self
3072 .adapter_client
3073 .session()
3074 .drain_notices()
3075 .into_iter()
3076 .map(|notice| BackendMessage::ErrorResponse(notice.into_response()));
3077 self.send_all(notices).await?;
3078 Ok(())
3079 }
3080
3081 #[instrument(level = "debug")]
3082 async fn send_error_and_get_state(&mut self, err: ErrorResponse) -> Result<State, io::Error> {
3083 assert!(err.severity.is_error());
3084 debug!(
3085 "cid={} error code={}",
3086 self.adapter_client.session().conn_id(),
3087 err.code.code()
3088 );
3089 let is_fatal = err.severity.is_fatal();
3090 self.send(BackendMessage::ErrorResponse(err)).await?;
3091
3092 let txn = self.adapter_client.session().transaction();
3093 match txn {
3094 TransactionStatus::Default | TransactionStatus::Failed(_) => {}
3097 TransactionStatus::Started(_) => {
3099 self.rollback_transaction().await?;
3100 }
3101 TransactionStatus::InTransactionImplicit(_) => {
3103 self.rollback_transaction().await?;
3104 }
3105 TransactionStatus::InTransaction(_) => {
3107 self.adapter_client.fail_transaction();
3108 }
3109 };
3110 if is_fatal {
3111 Ok(State::Done)
3112 } else {
3113 Ok(State::Drain)
3114 }
3115 }
3116
3117 #[instrument(level = "debug")]
3118 async fn aborted_txn_error(&mut self) -> Result<State, io::Error> {
3119 self.send(BackendMessage::ErrorResponse(ErrorResponse::error(
3120 SqlState::IN_FAILED_SQL_TRANSACTION,
3121 ABORTED_TXN_MSG,
3122 )))
3123 .await?;
3124 Ok(State::Drain)
3125 }
3126
3127 fn is_aborted_txn(&mut self) -> bool {
3128 matches!(
3129 self.adapter_client.session().transaction(),
3130 TransactionStatus::Failed(_)
3131 )
3132 }
3133}
3134
3135fn pad_formats(formats: Vec<Format>, n: usize) -> Result<Vec<Format>, String> {
3136 match (formats.len(), n) {
3137 (0, e) => Ok(vec![Format::Text; e]),
3138 (1, e) => Ok(iter::repeat(formats[0]).take(e).collect()),
3139 (a, e) if a == e => Ok(formats),
3140 (a, e) => Err(format!(
3141 "expected {} field format specifiers, but got {}",
3142 e, a
3143 )),
3144 }
3145}
3146
3147fn describe_rows(stmt_desc: &StatementDesc, formats: &[Format]) -> BackendMessage {
3148 match &stmt_desc.relation_desc {
3149 Some(desc) if !stmt_desc.is_copy => {
3150 BackendMessage::RowDescription(message::encode_row_description(desc, formats))
3151 }
3152 _ => BackendMessage::NoData,
3153 }
3154}
3155
/// Builds the message sent after a batch of rows has been written.
///
/// It receives the requested row limit, the number of rows sent so far, and
/// an optional mutable handle to a portal being fetched from. Implemented by
/// `portal_exec_message` and `fetch_message`.
type GetResponse = fn(
    max_rows: ExecuteCount,
    total_sent_rows: usize,
    fetch_portal: Option<PortalRefMut>,
) -> BackendMessage;
3161
3162fn portal_exec_message(
3165 max_rows: ExecuteCount,
3166 total_sent_rows: usize,
3167 _fetch_portal: Option<PortalRefMut>,
3168) -> BackendMessage {
3169 match max_rows {
3176 ExecuteCount::Count(max_rows) if max_rows <= total_sent_rows => {
3177 BackendMessage::PortalSuspended
3178 }
3179 _ => BackendMessage::CommandComplete {
3180 tag: format!("SELECT {}", total_sent_rows),
3181 },
3182 }
3183}
3184
3185fn fetch_message(
3187 _max_rows: ExecuteCount,
3188 total_sent_rows: usize,
3189 fetch_portal: Option<PortalRefMut>,
3190) -> BackendMessage {
3191 let tag = format!("FETCH {}", total_sent_rows);
3192 if let Some(portal) = fetch_portal {
3193 *portal.state = PortalState::Completed(Some(tag.clone()));
3194 }
3195 BackendMessage::CommandComplete { tag }
3196}
3197
3198fn get_authenticator(
3199 authenticator_kind: listeners::AuthenticatorKind,
3200 frontegg: Option<FronteggAuthenticator>,
3201 oidc: GenericOidcAuthenticator,
3202 adapter_client: mz_adapter::Client,
3203) -> Authenticator {
3204 match authenticator_kind {
3205 listeners::AuthenticatorKind::Frontegg => Authenticator::Frontegg(frontegg.expect(
3206 "Frontegg authenticator should exist with listeners::AuthenticatorKind::Frontegg",
3207 )),
3208 listeners::AuthenticatorKind::Password => Authenticator::Password(adapter_client),
3209 listeners::AuthenticatorKind::Sasl => Authenticator::Sasl(adapter_client),
3210 listeners::AuthenticatorKind::Oidc => Authenticator::Oidc(oidc),
3211 listeners::AuthenticatorKind::None => Authenticator::None,
3212 }
3213}
3214
/// Limit on how many rows an `Execute` (or `FETCH`) may return.
#[derive(Debug, Copy, Clone)]
enum ExecuteCount {
    /// No limit: every remaining row is returned.
    All,
    /// At most this many rows are returned. Reaching the limit suspends an
    /// `Execute` portal (see `portal_exec_message`).
    Count(usize),
}
3220
3221fn is_txn_exit_stmt(stmt: Option<&Statement<Raw>>) -> bool {
3223 match stmt {
3224 Some(stmt) => matches!(stmt, Statement::Commit(_) | Statement::Rollback(_)),
3226 None => false,
3227 }
3228}
3229
/// Outcome of one step of fetching query results.
// NOTE(review): consumers are outside this view; the meaning of `Rows(None)`
// is not established here — confirm at the use sites.
#[derive(Debug)]
enum FetchResult {
    /// A batch of rows, if any.
    Rows(Option<Box<dyn RowIterator + Send + Sync>>),
    /// The query was canceled.
    Canceled,
    /// The query failed with this message.
    Error(String),
    /// A notice raised while fetching.
    Notice(AdapterNotice),
}
3237
/// Incrementally finds row boundaries in buffered `COPY FROM STDIN` input, so
/// the buffer can be split into batches only between complete rows. It also
/// detects the `\.` end-of-data marker row.
///
/// All offsets are byte positions in the caller's buffer. They are rebased
/// when the caller removes a prefix (`on_split`) or truncates the buffer
/// (`on_truncate`).
#[derive(Debug)]
struct CopyRowScanner {
    /// Bytes before this offset have already been scanned.
    scan_pos: usize,
    /// Exclusive end of the last complete row seen, if any.
    last_row_end: Option<usize>,
    /// Exclusive end of the end-of-data marker row, once found.
    end_marker_end: Option<usize>,
    /// Start of the CSV record currently being scanned. Only used for CSV
    /// input; it stays 0 for other formats.
    record_start: usize,
    /// CSV parsing state; `None` for non-CSV formats, which are split on `\n`.
    csv: Option<CsvScanState>,
}
3250
/// `csv_core` state used by `CopyRowScanner` to locate CSV record boundaries.
/// Decoded field contents are never used.
#[derive(Debug)]
struct CsvScanState {
    /// Incremental CSV reader configured with the COPY delimiter, quote, and
    /// escape bytes.
    reader: csv_core::Reader,
    /// Scratch space that `csv_core` writes decoded field bytes into; never read.
    output: Vec<u8>,
    /// Scratch space for field end offsets; never read.
    ends: Vec<usize>,
    /// Whether the next record is a header row, which is never treated as the
    /// end-of-data marker.
    skip_first_record: bool,
}
3258
impl CopyRowScanner {
    /// Creates a scanner for the given COPY format.
    ///
    /// CSV input is parsed with `csv_core`, honoring the delimiter, quote,
    /// escape and header settings. Every other format is split on `\n`.
    fn new(params: &CopyFormatParams<'_>) -> Self {
        let csv = match params {
            CopyFormatParams::Csv(CopyCsvFormatParams {
                delimiter,
                quote,
                escape,
                header,
                ..
            }) => Some(CsvScanState::new(*delimiter, *quote, *escape, *header)),
            _ => None,
        };

        CopyRowScanner {
            scan_pos: 0,
            last_row_end: None,
            end_marker_end: None,
            record_start: 0,
            csv,
        }
    }

    /// Scans the not-yet-scanned tail of `data` (from `scan_pos` on).
    ///
    /// Updates `last_row_end` for every completed row. The first time the
    /// `\.` end-of-data marker row is seen, records `end_marker_end` and stops
    /// scanning.
    ///
    /// `data` is the caller's whole buffer. Previously scanned bytes must be
    /// unchanged, except for changes reported through `on_split` or
    /// `on_truncate`.
    fn scan_new_bytes(&mut self, data: &[u8]) {
        if self.scan_pos >= data.len() {
            return;
        }

        if let Some(csv) = self.csv.as_mut() {
            let mut input = &data[self.scan_pos..];
            // Bytes of the new input consumed by the reader so far.
            let mut consumed = 0usize;
            while !input.is_empty() {
                // Only record boundaries matter here; decoded bytes written to
                // the scratch buffers are ignored.
                let (result, n_input, _n_output, _n_ends) =
                    csv.reader
                        .read_record(input, &mut csv.output, &mut csv.ends);
                consumed += n_input;
                input = &input[n_input..];

                match result {
                    ReadRecordResult::InputEmpty => break,
                    // Grow a scratch buffer only when the reader could not make
                    // any progress with its current size.
                    ReadRecordResult::OutputFull => {
                        if n_input == 0 {
                            csv.output
                                .resize(csv.output.len().saturating_mul(2).max(1), 0);
                        }
                    }
                    ReadRecordResult::OutputEndsFull => {
                        if n_input == 0 {
                            csv.ends.resize(csv.ends.len().saturating_mul(2).max(1), 0);
                        }
                    }
                    ReadRecordResult::Record | ReadRecordResult::End => {
                        let row_end = self.scan_pos + consumed;
                        self.last_row_end = Some(row_end);
                        if self.end_marker_end.is_none() {
                            // The header row is never the end marker.
                            let is_marker = if csv.skip_first_record {
                                csv.skip_first_record = false;
                                false
                            } else {
                                // Compare the record's raw bytes, so a quoted
                                // `"\."` field is not mistaken for the marker.
                                // Line-terminator bytes are trimmed from both
                                // ends. With `\r\n`, a record ends after the
                                // `\r`, and the `\n` lands at the start of the
                                // next record.
                                let raw = &data[self.record_start..row_end];
                                let start = raw
                                    .iter()
                                    .take_while(|&&b| b == b'\r' || b == b'\n')
                                    .count();
                                let trailing = raw[start..]
                                    .iter()
                                    .rev()
                                    .take_while(|&&b| b == b'\r' || b == b'\n')
                                    .count();
                                let trimmed = &raw[start..raw.len() - trailing];
                                trimmed == b"\\."
                            };
                            if is_marker {
                                self.end_marker_end = Some(row_end);
                                self.record_start = row_end;
                                break;
                            }
                        }
                        self.record_start = row_end;
                    }
                }
            }
        } else {
            // Non-CSV formats: every `\n` ends a row, and a row that starts
            // with `\.` is the end-of-data marker.
            let mut row_start = self.last_row_end.unwrap_or(0);
            for (offset, b) in data[self.scan_pos..].iter().enumerate() {
                if *b == b'\n' {
                    let row_end = self.scan_pos + offset + 1;
                    self.last_row_end = Some(row_end);
                    if self.end_marker_end.is_none() {
                        let row = &data[row_start..row_end];
                        if row.get(0..2) == Some(b"\\.") {
                            self.end_marker_end = Some(row_end);
                            break;
                        }
                    }
                    row_start = row_end;
                }
            }
        }

        // Everything currently in `data` is now considered scanned, including
        // any bytes after the marker, which the caller discards.
        self.scan_pos = data.len();
    }

    /// Exclusive end of the last complete row, i.e. the latest safe split point.
    fn last_row_end(&self) -> Option<usize> {
        self.last_row_end
    }

    /// Exclusive end of the end-of-data marker row, once it has been seen.
    fn end_marker_end(&self) -> Option<usize> {
        self.end_marker_end
    }

    /// Size of the trailing, still-incomplete row in a buffer of `data_len` bytes.
    fn current_row_size(&self, data_len: usize) -> usize {
        data_len.saturating_sub(self.last_row_end.unwrap_or(0))
    }

    /// Rebases offsets after the caller removes `data[..split_pos]` from its buffer.
    ///
    /// The caller splits at `last_row_end`, so no complete row remains in the
    /// bytes that are left, and `last_row_end` is reset.
    fn on_split(&mut self, split_pos: usize) {
        self.scan_pos = self.scan_pos.saturating_sub(split_pos);
        self.last_row_end = None;
        self.end_marker_end = self
            .end_marker_end
            .and_then(|end| end.checked_sub(split_pos));
        // For CSV, the split must not cut through a record whose scan is in
        // progress.
        soft_assert_or_log!(
            self.csv.is_none() || self.record_start >= split_pos,
            "split bisected an in-progress CSV record: record_start={} < split_pos={}",
            self.record_start,
            split_pos,
        );
        self.record_start = self.record_start.saturating_sub(split_pos);
    }

    /// Clamps offsets after the caller truncates its buffer to `new_len` bytes.
    fn on_truncate(&mut self, new_len: usize) {
        self.scan_pos = self.scan_pos.min(new_len);
        self.last_row_end = self.last_row_end.filter(|&end| end <= new_len);
        self.end_marker_end = self.end_marker_end.filter(|&end| end <= new_len);
        self.record_start = self.record_start.min(new_len);
    }
}
3417
3418impl CsvScanState {
3419 fn new(delimiter: u8, quote: u8, escape: u8, header: bool) -> Self {
3420 let (double_quote, escape) = if quote == escape {
3421 (true, None)
3422 } else {
3423 (false, Some(escape))
3424 };
3425 CsvScanState {
3426 reader: csv_core::ReaderBuilder::new()
3427 .delimiter(delimiter)
3428 .quote(quote)
3429 .double_quote(double_quote)
3430 .escape(escape)
3431 .build(),
3432 output: vec![0; 1],
3433 ends: vec![0; 1],
3434 skip_first_record: header,
3435 }
3436 }
3437}
3438
#[cfg(test)]
mod test {
    use super::*;

    #[mz_ore::test]
    fn test_copy_row_scanner_end_marker_line_endings() {
        // The CSV end-of-data marker must be recognized with every supported
        // line terminator, and must not be recognized when it appears as a
        // quoted field value.
        let params = CopyFormatParams::Csv(CopyCsvFormatParams::default());

        let marker_end = |data: &[u8]| -> Option<usize> {
            let mut scanner = CopyRowScanner::new(&params);
            scanner.scan_new_bytes(data);
            scanner.end_marker_end()
        };

        for eol in [&b"\n"[..], b"\r\n", b"\r"] {
            let join = |lines: &[&str]| -> Vec<u8> {
                let mut out = Vec::new();
                for line in lines {
                    out.extend_from_slice(line.as_bytes());
                    out.extend_from_slice(eol);
                }
                out
            };

            let data = join(&["first", "\\.", "after"]);
            // The marker row ends one byte past `\.`. For `\r\n`, that byte is
            // the `\r`.
            let mut prefix = Vec::new();
            prefix.extend_from_slice(b"first");
            prefix.extend_from_slice(eol);
            prefix.extend_from_slice(b"\\.");
            assert_eq!(
                marker_end(&data),
                Some(prefix.len() + 1),
                "bare marker, eol={eol:?}"
            );

            // A quoted `\.` is ordinary data.
            let data = join(&["before", "\"\\.\"", "after"]);
            assert_eq!(marker_end(&data), None, "quoted marker, eol={eol:?}");
        }
    }

    #[mz_ore::test]
    fn test_copy_row_scanner_csv_long_record_incremental() {
        // A record far longer than any scratch buffer, containing a quoted
        // newline and a quoted `\.`, fed to the scanner in pieces. The quoted
        // content must not end the record or be taken for the marker.
        let params = CopyFormatParams::Csv(CopyCsvFormatParams::default());
        let mut data = b"\"".to_vec();
        data.resize(data.len() + 100_000, b'x');
        data.extend_from_slice(b"\n\\.\n\",1\n");
        let first_row_end = data.len();
        data.extend_from_slice(b"\\.\n");
        let expected_marker_end = data.len();
        data.extend_from_slice(b"after\n");

        let mut scanner = CopyRowScanner::new(&params);
        let mut buf = Vec::new();
        for piece in data.chunks(7_001) {
            buf.extend_from_slice(piece);
            scanner.scan_new_bytes(&buf);
            if buf.len() < first_row_end {
                assert_eq!(scanner.last_row_end(), None, "len={}", buf.len());
                assert_eq!(scanner.current_row_size(buf.len()), buf.len());
            }
        }
        assert_eq!(scanner.end_marker_end(), Some(expected_marker_end));
    }

    #[mz_ore::test]
    fn test_copy_row_scanner_non_csv_split() {
        // Non-CSV formats never track a CSV record start, so splitting at a
        // row boundary must leave it at zero.
        for params in [
            CopyFormatParams::Text(CopyTextFormatParams::default()),
            CopyFormatParams::Binary,
        ] {
            let mut scanner = CopyRowScanner::new(&params);
            let data = b"1\thello world\t2\tsome text value here\n\
                         3\thello world\t6\tsome text value here\n";
            scanner.scan_new_bytes(data);
            let split_pos = scanner.last_row_end().expect("a complete row");
            assert!(split_pos > 0, "params={params:?}");
            scanner.on_split(split_pos);
            assert_eq!(scanner.record_start, 0, "params={params:?}");
        }
    }

    #[mz_ore::test]
    fn test_parse_options() {
        struct TestCase {
            input: &'static str,
            expect: Result<Vec<(&'static str, &'static str)>, ()>,
        }
        let tests = vec![
            TestCase {
                input: "",
                expect: Ok(vec![]),
            },
            TestCase {
                input: "--key",
                expect: Err(()),
            },
            TestCase {
                input: "--key=val",
                expect: Ok(vec![("key", "val")]),
            },
            TestCase {
                input: r#"--key=val -ckey2=val2 -c key3=val3 -c key4=val4 -ckey5=val5"#,
                expect: Ok(vec![
                    ("key", "val"),
                    ("key2", "val2"),
                    ("key3", "val3"),
                    ("key4", "val4"),
                    ("key5", "val5"),
                ]),
            },
            TestCase {
                input: r#"-c\ key=val"#,
                expect: Ok(vec![(" key", "val")]),
            },
            TestCase {
                input: "--key=val -ckey2 val2",
                expect: Err(()),
            },
            TestCase {
                // An empty value is allowed.
                input: "--key=",
                expect: Ok(vec![("key", "")]),
            },
        ];
        for test in tests {
            let got = parse_options(test.input);
            let expect = test.expect.map(|r| {
                r.into_iter()
                    .map(|(k, v)| (k.to_owned(), v.to_owned()))
                    .collect()
            });
            assert_eq!(got, expect, "input: {}", test.input);
        }
    }

    #[mz_ore::test]
    fn test_parse_option() {
        struct TestCase {
            input: &'static str,
            expect: Result<(&'static str, &'static str), ()>,
        }
        let tests = vec![
            TestCase {
                input: "",
                expect: Err(()),
            },
            TestCase {
                input: "--",
                expect: Err(()),
            },
            TestCase {
                input: "--c",
                expect: Err(()),
            },
            TestCase {
                input: "a=b",
                expect: Err(()),
            },
            TestCase {
                input: "--a=b",
                expect: Ok(("a", "b")),
            },
            TestCase {
                input: "--ca=b",
                expect: Ok(("ca", "b")),
            },
            TestCase {
                input: "-ca=b",
                expect: Ok(("a", "b")),
            },
            TestCase {
                // An empty key and value are accepted.
                input: "--=",
                expect: Ok(("", "")),
            },
        ];
        for test in tests {
            let got = parse_option(test.input);
            assert_eq!(got, test.expect, "input: {}", test.input);
        }
    }

    #[mz_ore::test]
    fn test_split_options() {
        struct TestCase {
            input: &'static str,
            expect: Vec<&'static str>,
        }
        let tests = vec![
            TestCase {
                input: "",
                expect: vec![],
            },
            TestCase {
                input: " ",
                expect: vec![],
            },
            TestCase {
                input: " a ",
                expect: vec!["a"],
            },
            TestCase {
                input: " ab cd ",
                expect: vec!["ab", "cd"],
            },
            TestCase {
                input: r#" ab\ cd "#,
                expect: vec!["ab ", "cd"],
            },
            TestCase {
                input: r#" ab\\ cd "#,
                expect: vec![r#"ab\"#, "cd"],
            },
            TestCase {
                // Escaped backslash, escaped space, then an unescaped
                // separator: two options.
                input: r#" ab\\\  cd "#,
                expect: vec![r#"ab\ "#, "cd"],
            },
            TestCase {
                // Escaped backslash and escaped space with no separator: a
                // single option.
                input: r#" ab\\\ cd "#,
                expect: vec![r#"ab\ cd"#],
            },
            TestCase {
                input: r#" ab\\\cd "#,
                expect: vec![r#"ab\cd"#],
            },
            TestCase {
                input: r#"a\"#,
                expect: vec!["a"],
            },
            TestCase {
                input: r#"a\ "#,
                expect: vec!["a "],
            },
            TestCase {
                input: r#"\"#,
                expect: vec![],
            },
            TestCase {
                input: r#"\ "#,
                expect: vec![r#" "#],
            },
            TestCase {
                input: r#" \ "#,
                expect: vec![r#" "#],
            },
            TestCase {
                input: r#"\ "#,
                expect: vec![r#" "#],
            },
        ];
        for test in tests {
            let got = split_options(test.input);
            assert_eq!(got, test.expect, "input: {}", test.input);
        }
    }

    #[mz_ore::test]
    fn test_is_jwt() {
        assert!(is_jwt("eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxIn0.signature"));
        // Strings that must not be classified as JWTs.
        for s in [
            "",
            "secure_password",
            "p4ss.w0rd",
            "aaa.bbb.ccc",
            "eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxIn0",
            "eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxIn0.sig.extra",
        ] {
            assert!(!is_jwt(s), "is_jwt({s:?})");
        }
    }
}