1use std::collections::BTreeMap;
11use std::convert::TryFrom;
12use std::future::Future;
13use std::ops::Deref;
14use std::sync::Arc;
15use std::time::{Duration, Instant};
16use std::{iter, mem};
17
18use base64::prelude::*;
19use byteorder::{ByteOrder, NetworkEndian};
20use csv_core::ReadRecordResult;
21use futures::future::{BoxFuture, FutureExt, pending};
22use itertools::Itertools;
23use mz_adapter::client::RecordFirstRowStream;
24use mz_adapter::session::{
25 EndTransactionAction, InProgressRows, LifecycleTimestamps, PortalRefMut, PortalState, Session,
26 SessionConfig, TransactionStatus,
27};
28use mz_adapter::statement_logging::{StatementEndedExecutionReason, StatementExecutionStrategy};
29use mz_adapter::{
30 AdapterError, AdapterNotice, ExecuteContextGuard, ExecuteResponse, PeekResponseUnary, metrics,
31 verify_datum_desc,
32};
33use mz_auth::Authenticated;
34use mz_auth::password::Password;
35use mz_authenticator::{Authenticator, GenericOidcAuthenticator};
36use mz_frontegg_auth::Authenticator as FronteggAuthenticator;
37use mz_ore::cast::CastFrom;
38use mz_ore::netio::AsyncReady;
39use mz_ore::now::{EpochMillis, SYSTEM_TIME};
40use mz_ore::str::StrExt;
41use mz_ore::{assert_none, assert_ok, instrument, soft_assert_eq_or_log, soft_assert_or_log};
42use mz_pgcopy::{CopyCsvFormatParams, CopyFormatParams, CopyTextFormatParams};
43use mz_pgwire_common::{
44 ConnectionCounter, Cursor, ErrorResponse, Format, FrontendMessage, Severity, VERSION_3,
45 VERSIONS,
46};
47use mz_repr::{
48 CatalogItemId, ColumnIndex, Datum, RelationDesc, RowArena, RowIterator, RowRef,
49 SqlRelationType, SqlScalarType,
50};
51use mz_server_core::TlsMode;
52use mz_server_core::listeners;
53use mz_server_core::listeners::AllowedRoles;
54use mz_sql::ast::display::AstDisplay;
55use mz_sql::ast::{
56 CopyDirection, CopyStatement, CopyTarget, FetchDirection, Ident, Raw, Statement,
57};
58use mz_sql::parse::StatementParseResult;
59use mz_sql::plan::{CopyFormat, ExecuteTimeout, StatementDesc};
60use mz_sql::session::metadata::SessionMetadata;
61use mz_sql::session::user::INTERNAL_USER_NAMES;
62use mz_sql::session::vars::VarInput;
63use postgres::error::SqlState;
64use tokio::io::{self, AsyncRead, AsyncWrite};
65use tokio::select;
66use tokio::time::{self};
67use tokio_metrics::TaskMetrics;
68use tokio_stream::wrappers::UnboundedReceiverStream;
69use tracing::{Instrument, debug, debug_span, warn};
70use uuid::Uuid;
71
72use crate::codec::{
73 FramedConn, decode_password, decode_sasl_initial_response, decode_sasl_response,
74};
75use crate::message::{
76 self, BackendMessage, SASLServerFinalMessage, SASLServerFinalMessageKinds,
77 SASLServerFirstMessage,
78};
79
80pub fn match_handshake(buf: &[u8]) -> bool {
84 if buf.len() < 8 {
94 return false;
95 }
96 let version = NetworkEndian::read_i32(&buf[4..8]);
97 VERSIONS.contains(&version)
98}
99
/// Parameters for the `run` function, which drives one pgwire connection.
pub struct RunParams<'a, A, I>
where
    I: Iterator<Item = TaskMetrics> + Send,
{
    /// The TLS mode the connection must be compatible with.
    pub tls_mode: Option<TlsMode>,
    /// A client for the adapter, used to authenticate the user and create the session.
    pub adapter_client: mz_adapter::Client,
    /// The connection to the client.
    pub conn: &'a mut FramedConn<A>,
    /// UUID recorded in the session config for this connection.
    pub conn_uuid: Uuid,
    /// The protocol version requested by the client; only `VERSION_3` is accepted.
    pub version: i32,
    /// The client's startup parameters (e.g. `user`, `options`, session variables).
    pub params: BTreeMap<String, String>,
    /// The Frontegg authenticator, if configured.
    pub frontegg: Option<FronteggAuthenticator>,
    /// The OIDC authenticator.
    pub oidc: GenericOidcAuthenticator,
    /// Which authentication method the listener is configured to use.
    pub authenticator_kind: listeners::AuthenticatorKind,
    /// Counter from which a slot is allocated for this connection; if the
    /// allocation fails, the connection is rejected.
    pub active_connection_counter: ConnectionCounter,
    /// The Helm chart version, if any, recorded in the session config.
    pub helm_chart_version: Option<String>,
    /// Which roles may log in through this listener.
    pub allowed_roles: AllowedRoles,
    /// Source of tokio task metrics intervals; must be infinite.
    pub tokio_metrics_intervals: I,
}
133
/// Runs a pgwire session on `conn` until the client disconnects or an error
/// occurs.
///
/// The startup sequence is:
/// 1. Reject any protocol version other than `VERSION_3`.
/// 2. Reject logins to roles that `allowed_roles` does not permit.
/// 3. Check that the connection satisfies `tls_mode`.
/// 4. Authenticate the user with the configured authenticator and create a
///    session.
/// 5. Apply the remaining startup parameters (including `options`) as session
///    variables, adding a notice for each one that cannot be applied.
/// 6. Allocate a connection slot, start the adapter session, and send
///    `AuthenticationOk`, parameter statuses, `BackendKeyData`, pending
///    notices, and `ReadyForQuery`.
/// 7. Drive the `StateMachine` until it finishes, or until the authentication
///    session expires (only Frontegg sessions can expire).
///
/// Protocol and authentication failures are reported to the client as
/// `ErrorResponse`s, and the result of sending that response is returned. An
/// `Err` is returned only when I/O on `conn` fails or the state machine fails.
#[mz_ore::instrument(level = "debug")]
pub async fn run<'a, A, I>(
    RunParams {
        tls_mode,
        adapter_client,
        conn,
        conn_uuid,
        version,
        mut params,
        frontegg,
        oidc,
        authenticator_kind,
        active_connection_counter,
        helm_chart_version,
        allowed_roles,
        tokio_metrics_intervals,
    }: RunParams<'a, A, I>,
) -> Result<(), io::Error>
where
    A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
    I: Iterator<Item = TaskMetrics> + Send,
{
    // Only protocol version 3 is supported.
    if version != VERSION_3 {
        return conn
            .send(ErrorResponse::fatal(
                SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
                "server does not support the client's requested protocol version",
            ))
            .await;
    }

    // `user` is consumed here; `options` stays in `params` and is expanded
    // when startup parameters are applied below.
    let user = params.remove("user").unwrap_or_else(String::new);
    let options = parse_options(params.get("options").unwrap_or(&String::new()));
    let authenticator =
        get_authenticator(authenticator_kind, frontegg, oidc, adapter_client.clone());
    let is_internal_user = INTERNAL_USER_NAMES.contains(&user);
    let is_reserved_user = mz_adapter::catalog::is_reserved_role_name(user.as_str());
    // Gate the login on the roles this listener accepts.
    let role_allowed = match allowed_roles {
        AllowedRoles::Normal => !is_reserved_user,
        AllowedRoles::Internal => is_internal_user,
        AllowedRoles::NormalAndInternal => !is_reserved_user || is_internal_user,
    };
    if !role_allowed {
        let msg = format!("unauthorized login to user '{user}'");
        return conn
            .send(ErrorResponse::fatal(SqlState::INSUFFICIENT_PRIVILEGE, msg))
            .await;
    }

    if let Err(err) = conn.inner().ensure_tls_compatibility(&tls_mode) {
        return conn.send(err).await;
    }

    // From here on, `authenticator_kind` is the kind of the authenticator
    // actually in use, which is recorded in the session config.
    let authenticator_kind = authenticator.kind();

    // `expired` completes when the authentication session expires. Only the
    // Frontegg branch produces one that can complete; the others use a future
    // that never resolves.
    let (mut session, expired) = match authenticator {
        Authenticator::Frontegg(frontegg) => {
            let password = match request_cleartext_password(conn).await {
                Ok(password) => password,
                Err(PasswordRequestError::IoError(e)) => return Err(e),
                Err(PasswordRequestError::InvalidPasswordError(e)) => {
                    return conn.send(e).await;
                }
            };

            let auth_response = frontegg.authenticate(&user, &password).await;
            match auth_response {
                Ok((mut auth_session, authenticated)) => {
                    let session = adapter_client.new_session(
                        SessionConfig {
                            conn_id: conn.conn_id().clone(),
                            uuid: conn_uuid,
                            user: auth_session.user().into(),
                            client_ip: conn.peer_addr().clone(),
                            external_metadata_rx: Some(auth_session.external_metadata_rx()),
                            helm_chart_version,
                            authenticator_kind,
                            groups: None,
                        },
                        authenticated,
                    );
                    let expired = async move { auth_session.expired().await };
                    (session, expired.left_future())
                }
                Err(err) => {
                    warn!(?err, "pgwire connection failed authentication");
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_PASSWORD,
                            "invalid password",
                        ))
                        .await;
                }
            }
        }
        Authenticator::Oidc(oidc) => {
            let password = match request_cleartext_password(conn).await {
                Ok(password) => password,
                Err(PasswordRequestError::IoError(e)) => return Err(e),
                Err(PasswordRequestError::InvalidPasswordError(e)) => {
                    return conn.send(e).await;
                }
            };
            // A password that parses as a JWT is verified by OIDC; anything
            // else falls back to password authentication.
            if is_jwt(&password) {
                let auth_response = oidc.authenticate(&password, Some(&user)).await;
                match auth_response {
                    Ok((mut claims, authenticated)) => {
                        let groups = claims.groups.take();
                        let session = adapter_client.new_session(
                            SessionConfig {
                                conn_id: conn.conn_id().clone(),
                                uuid: conn_uuid,
                                user: std::mem::take(&mut claims.user),
                                client_ip: conn.peer_addr().clone(),
                                external_metadata_rx: None,
                                helm_chart_version,
                                authenticator_kind,
                                groups,
                            },
                            authenticated,
                        );
                        (session, pending().right_future())
                    }
                    Err(err) => {
                        warn!(?err, "pgwire connection failed authentication");
                        return conn.send(err.into_response()).await;
                    }
                }
            } else {
                let session = match authenticate_with_password(
                    conn,
                    &adapter_client,
                    user,
                    Password(password),
                    conn_uuid,
                    helm_chart_version,
                )
                .await
                {
                    Ok(session) => session,
                    Err(PasswordRequestError::IoError(e)) => return Err(e),
                    Err(PasswordRequestError::InvalidPasswordError(e)) => {
                        return conn.send(e).await;
                    }
                };
                (session, pending().right_future())
            }
        }
        Authenticator::Password(adapter_client) => {
            let password = match request_cleartext_password(conn).await {
                Ok(password) => password,
                Err(PasswordRequestError::IoError(e)) => return Err(e),
                Err(PasswordRequestError::InvalidPasswordError(e)) => {
                    return conn.send(e).await;
                }
            };
            let session = match authenticate_with_password(
                conn,
                &adapter_client,
                user,
                Password(password),
                conn_uuid,
                helm_chart_version,
            )
            .await
            {
                Ok(session) => session,
                Err(PasswordRequestError::IoError(e)) => return Err(e),
                Err(PasswordRequestError::InvalidPasswordError(e)) => {
                    return conn.send(e).await;
                }
            };
            (session, pending().right_future())
        }
        Authenticator::Sasl(adapter_client) => {
            // SCRAM-SHA-256 exchange: advertise SASL, then expect the client's
            // initial response.
            conn.send(BackendMessage::AuthenticationSASL).await?;
            conn.flush().await?;
            let (mechanism, initial_response) = match conn.recv().await? {
                Some(FrontendMessage::RawAuthentication(data)) => {
                    match decode_sasl_initial_response(Cursor::new(&data)).ok() {
                        Some(FrontendMessage::SASLInitialResponse {
                            gs2_header,
                            mechanism,
                            initial_response,
                        }) => {
                            // Channel binding is not supported.
                            if gs2_header.channel_binding_enabled() {
                                return conn
                                    .send(ErrorResponse::fatal(
                                        SqlState::PROTOCOL_VIOLATION,
                                        "channel binding not supported",
                                    ))
                                    .await;
                            }
                            (mechanism, initial_response)
                        }
                        _ => {
                            return conn
                                .send(ErrorResponse::fatal(
                                    SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                                    "expected SASLInitialResponse message",
                                ))
                                .await;
                        }
                    }
                }
                _ => {
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                            "expected SASLInitialResponse message",
                        ))
                        .await;
                }
            };

            if mechanism != "SCRAM-SHA-256" {
                return conn
                    .send(ErrorResponse::fatal(
                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                        "unsupported SASL mechanism",
                    ))
                    .await;
            }

            // Client-supplied lengths are bounded; presumably to cap the work
            // done on untrusted input — confirm.
            if initial_response.nonce.len() > 256 {
                return conn
                    .send(ErrorResponse::fatal(
                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                        "nonce too long",
                    ))
                    .await;
            }

            let (server_first_message_raw, mock_hash) = match adapter_client
                .generate_sasl_challenge(&user, &initial_response.nonce)
                .await
            {
                Ok(response) => {
                    // The raw server-first-message is kept to build the
                    // auth message that the client's proof is checked against.
                    let server_first_message_raw = format!(
                        "r={},s={},i={}",
                        response.nonce, response.salt, response.iteration_count
                    );

                    // NOTE(review): a placeholder SCRAM hash built from fixed
                    // keys and passed to `verify_sasl_proof`; presumably used
                    // when the user has no real credentials so the failure
                    // path looks the same — confirm.
                    let client_key = [0u8; 32];
                    let server_key = [1u8; 32];
                    let mock_hash = format!(
                        "SCRAM-SHA-256${}:{}${}:{}",
                        response.iteration_count,
                        response.salt,
                        BASE64_STANDARD.encode(client_key),
                        BASE64_STANDARD.encode(server_key)
                    );

                    conn.send(BackendMessage::AuthenticationSASLContinue(
                        SASLServerFirstMessage {
                            iteration_count: response.iteration_count,
                            nonce: response.nonce,
                            salt: response.salt,
                        },
                    ))
                    .await?;
                    conn.flush().await?;
                    (server_first_message_raw, mock_hash)
                }
                Err(e) => {
                    return conn.send(e.into_response(Severity::Fatal)).await;
                }
            };

            let authenticated = match conn.recv().await? {
                Some(FrontendMessage::RawAuthentication(data)) => {
                    match decode_sasl_response(Cursor::new(&data)).ok() {
                        Some(FrontendMessage::SASLResponse(response)) => {
                            let auth_message = format!(
                                "{},{},{}",
                                initial_response.client_first_message_bare_raw,
                                server_first_message_raw,
                                response.client_final_message_bare_raw
                            );
                            if response.proof.len() > 1024 {
                                return conn
                                    .send(ErrorResponse::fatal(
                                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                                        "proof too long",
                                    ))
                                    .await;
                            }
                            match adapter_client
                                .verify_sasl_proof(
                                    &user,
                                    &response.proof,
                                    &auth_message,
                                    &mock_hash,
                                )
                                .await
                            {
                                Ok((proof_response, authenticated)) => {
                                    conn.send(BackendMessage::AuthenticationSASLFinal(
                                        SASLServerFinalMessage {
                                            kind: SASLServerFinalMessageKinds::Verifier(
                                                proof_response.verifier,
                                            ),
                                            extensions: vec![],
                                        },
                                    ))
                                    .await?;
                                    conn.flush().await?;
                                    authenticated
                                }
                                Err(_) => {
                                    return conn
                                        .send(ErrorResponse::fatal(
                                            SqlState::INVALID_PASSWORD,
                                            "invalid password",
                                        ))
                                        .await;
                                }
                            }
                        }
                        _ => {
                            return conn
                                .send(ErrorResponse::fatal(
                                    SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                                    "expected SASLResponse message",
                                ))
                                .await;
                        }
                    }
                }
                _ => {
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                            "expected SASLResponse message",
                        ))
                        .await;
                }
            };

            let session = adapter_client.new_session(
                SessionConfig {
                    conn_id: conn.conn_id().clone(),
                    uuid: conn_uuid,
                    user,
                    client_ip: conn.peer_addr().clone(),
                    external_metadata_rx: None,
                    helm_chart_version,
                    authenticator_kind,
                    groups: None,
                },
                authenticated,
            );
            let auth_session = pending().right_future();
            (session, auth_session)
        }

        Authenticator::None => {
            let session = adapter_client.new_session(
                SessionConfig {
                    conn_id: conn.conn_id().clone(),
                    uuid: conn_uuid,
                    user,
                    client_ip: conn.peer_addr().clone(),
                    external_metadata_rx: None,
                    helm_chart_version,
                    authenticator_kind,
                    groups: None,
                },
                Authenticated,
            );
            let auth_session = pending().right_future();
            (session, auth_session)
        }
    };

    // Apply startup parameters as session variables. `options` expands into
    // several settings; any setting that fails becomes a notice, not an error.
    let system_vars = adapter_client.get_system_vars().await;
    for (name, value) in params {
        let settings = match name.as_str() {
            "options" => match &options {
                Ok(opts) => opts,
                Err(()) => {
                    session.add_notice(AdapterNotice::BadStartupSetting {
                        name,
                        reason: "could not parse".into(),
                    });
                    continue;
                }
            },
            _ => &vec![(name, value)],
        };
        for (key, val) in settings {
            const LOCAL: bool = false;
            if let Err(err) = session
                .vars_mut()
                .set(&system_vars, key, VarInput::Flat(val), LOCAL)
            {
                session.add_notice(AdapterNotice::BadStartupSetting {
                    name: key.clone(),
                    reason: err.to_string(),
                });
            }
        }
    }
    // Commit the variable changes made above.
    session
        .vars_mut()
        .end_transaction(EndTransactionAction::Commit);

    // Held until this function returns; presumably releases the connection
    // slot on drop — confirm.
    let _guard = match active_connection_counter.allocate_connection(session.user()) {
        Ok(drop_connection) => drop_connection,
        Err(e) => {
            let e: AdapterError = e.into();
            return conn.send(e.into_response(Severity::Fatal)).await;
        }
    };

    let mut adapter_client = match adapter_client.startup(session).await {
        Ok(adapter_client) => adapter_client,
        Err(e) => return conn.send(e.into_response(Severity::Fatal)).await,
    };

    // Send the post-authentication startup messages in one batch.
    let mut buf = vec![BackendMessage::AuthenticationOk];
    for var in adapter_client.session().vars().notify_set() {
        buf.push(BackendMessage::ParameterStatus(var.name(), var.value()));
    }
    buf.push(BackendMessage::BackendKeyData {
        conn_id: adapter_client.session().conn_id().unhandled(),
        secret_key: adapter_client.session().secret_key(),
    });
    buf.extend(
        adapter_client
            .session()
            .drain_notices()
            .into_iter()
            .map(|notice| BackendMessage::ErrorResponse(notice.into_response())),
    );
    buf.push(BackendMessage::ReadyForQuery(
        adapter_client.session().transaction().into(),
    ));
    conn.send_all(buf).await?;
    conn.flush().await?;

    let machine = StateMachine {
        conn,
        adapter_client,
        txn_needs_commit: false,
        tokio_metrics_intervals,
    };

    select! {
        r = machine.run() => {
            // Best-effort: tell the client why the connection is closing, but
            // ignore send/flush failures and return the original error.
            if let Err(err) = &r {
                let _ = conn
                    .send(ErrorResponse::fatal(
                        SqlState::CONNECTION_FAILURE,
                        err.to_string(),
                    ))
                    .await;
                let _ = conn.flush().await;
            }
            r
        },
        _ = expired => {
            conn
                .send(ErrorResponse::fatal(SqlState::INVALID_AUTHORIZATION_SPECIFICATION, "authentication expired"))
                .await?;
            conn.flush().await
        }
    }
}
639
640fn is_jwt(password: &str) -> bool {
643 jsonwebtoken::decode_header(password).is_ok()
644}
645
/// Returns the `(name, value)` session setting pairs encoded in a startup
/// `options` value.
///
/// The value is tokenized by [`split_options`]. Accepted tokens are
/// `--name=value`, `-cname=value`, and a lone `-c` followed by a separate
/// `name=value` token. This is intended to follow Postgres' handling of the
/// `options` startup parameter.
///
/// Returns `Err(())` if any token is malformed, including a trailing `-c` that
/// has no `name=value` argument after it.
fn parse_options(value: &str) -> Result<Vec<(String, String)>, ()> {
    let opts = split_options(value);
    let mut pairs = Vec::with_capacity(opts.len());
    let mut seen_prefix = false;
    for opt in opts {
        if !seen_prefix {
            if opt == "-c" {
                seen_prefix = true;
            } else {
                let (key, val) = parse_option(&opt)?;
                pairs.push((key.to_owned(), val.to_owned()));
            }
        } else {
            let (key, val) = opt.split_once('=').ok_or(())?;
            pairs.push((key.to_owned(), val.to_owned()));
            seen_prefix = false;
        }
    }
    // A dangling `-c` is missing its required `name=value` argument. Reject
    // the whole value rather than silently ignoring it.
    if seen_prefix {
        return Err(());
    }
    Ok(pairs)
}

/// Parses one `-cname=value` or `--name=value` token into `(name, value)`.
///
/// The value is everything after the first `=`. Returns `Err(())` if the token
/// has no `=` or its name does not start with `-c` or `--`.
fn parse_option(option: &str) -> Result<(&str, &str), ()> {
    let (key, value) = option.split_once('=').ok_or(())?;
    for prefix in &["-c", "--"] {
        if let Some(key) = key.strip_prefix(prefix) {
            return Ok((key, value));
        }
    }
    Err(())
}

/// Splits `value` on spaces, dropping empty tokens so runs of spaces collapse.
///
/// A backslash makes the next character literal: `\ ` keeps a space inside
/// the token and `\\` yields a single backslash. A trailing lone backslash is
/// dropped.
fn split_options(value: &str) -> Vec<String> {
    let mut strs = Vec::new();
    // Tokens are built char by char because escapes mean they cannot simply be
    // subslices of `value`.
    let mut current = String::new();
    let mut was_slash = false;
    for c in value.chars() {
        was_slash = match c {
            ' ' => {
                if was_slash {
                    current.push(' ');
                } else if !current.is_empty() {
                    // Only emit non-empty tokens, so repeated spaces are ignored.
                    strs.push(std::mem::take(&mut current));
                }
                false
            }
            '\\' => {
                if was_slash {
                    // `\\` is an escaped backslash; it does not escape the
                    // following character.
                    current.push('\\');
                    false
                } else {
                    true
                }
            }
            _ => {
                current.push(c);
                false
            }
        };
    }
    if !current.is_empty() {
        strs.push(current);
    }
    strs
}
725}
726
727enum PasswordRequestError {
728 InvalidPasswordError(ErrorResponse),
729 IoError(io::Error),
730}
731
732impl From<io::Error> for PasswordRequestError {
733 fn from(e: io::Error) -> Self {
734 PasswordRequestError::IoError(e)
735 }
736}
737
738async fn request_cleartext_password<A>(
742 conn: &mut FramedConn<A>,
743) -> Result<String, PasswordRequestError>
744where
745 A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
746{
747 conn.send(BackendMessage::AuthenticationCleartextPassword)
748 .await?;
749 conn.flush().await?;
750
751 if let Some(message) = conn.recv().await? {
752 if let FrontendMessage::RawAuthentication(data) = message {
753 if let Some(FrontendMessage::Password { password }) =
754 decode_password(Cursor::new(&data)).ok()
755 {
756 return Ok(password);
757 }
758 }
759 }
760
761 Err(PasswordRequestError::InvalidPasswordError(
762 ErrorResponse::fatal(
763 SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
764 "expected Password message",
765 ),
766 ))
767}
768
769async fn authenticate_with_password<A>(
772 conn: &FramedConn<A>,
773 adapter_client: &mz_adapter::Client,
774 user: String,
775 password: Password,
776 conn_uuid: Uuid,
777 helm_chart_version: Option<String>,
778) -> Result<Session, PasswordRequestError>
779where
780 A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
781{
782 let authenticated = match adapter_client.authenticate(&user, &password).await {
783 Ok(authenticated) => authenticated,
784 Err(err) => {
785 warn!(?err, "pgwire connection failed authentication");
786 return Err(PasswordRequestError::InvalidPasswordError(
787 ErrorResponse::fatal(SqlState::INVALID_PASSWORD, "invalid password"),
788 ));
789 }
790 };
791
792 let session = adapter_client.new_session(
793 SessionConfig {
794 conn_id: conn.conn_id().clone(),
795 uuid: conn_uuid,
796 user,
797 client_ip: conn.peer_addr().clone(),
798 external_metadata_rx: None,
799 helm_chart_version,
800 authenticator_kind: mz_auth::AuthenticatorKind::Password,
801 groups: None,
802 },
803 authenticated,
804 );
805
806 Ok(session)
807}
808
/// The state of the `StateMachine` between steps of its run loop.
#[derive(Debug)]
enum State {
    /// Ready to receive and handle the next frontend message.
    Ready,
    /// Discarding incoming messages until the client sends `Sync`. Entered, for
    /// example, when a Copy or authentication message arrives in the ready state.
    Drain,
    /// The session is over; the run loop returns.
    Done,
}
815
/// Drives a started pgwire session: reads frontend messages from `conn`,
/// executes them through `adapter_client`, and writes responses back.
struct StateMachine<'a, A, I>
where
    I: Iterator<Item = TaskMetrics> + Send + 'a,
{
    /// The client connection.
    conn: &'a mut FramedConn<A>,
    /// The adapter client bound to this connection's session.
    adapter_client: mz_adapter::SessionClient,
    /// Set after an `Execute` runs in an implicit transaction. The next
    /// `ensure_transaction` commits that transaction before starting another.
    txn_needs_commit: bool,
    /// Source of tokio task metrics intervals; expected to be infinite.
    tokio_metrics_intervals: I,
}
825
/// How sending a result's rows to the client ended.
// NOTE(review): inferred from the name and fields; the code that produces and
// consumes this type is outside this view — confirm.
enum SendRowsEndedReason {
    /// All rows were sent. `result_size` is presumably measured in bytes — TODO
    /// confirm; `rows_returned` counts the rows sent.
    Success {
        result_size: u64,
        rows_returned: u64,
    },
    /// Sending stopped because of an error, described by `error`.
    Errored {
        error: String,
    },
    /// Sending was canceled.
    Canceled,
}
836
/// Error message for commands issued while the current transaction is aborted.
const ABORTED_TXN_MSG: &str =
    "current transaction is aborted, commands ignored until end of transaction block";
839
840impl<'a, A, I> StateMachine<'a, A, I>
841where
842 A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin + 'a,
843 I: Iterator<Item = TaskMetrics> + Send + 'a,
844{
845 #[allow(clippy::manual_async_fn)]
849 #[mz_ore::instrument(level = "debug")]
850 fn run(mut self) -> impl Future<Output = Result<(), io::Error>> + Send + 'a {
851 async move {
852 let mut state = State::Ready;
853 loop {
854 self.send_pending_notices().await?;
855 state = match state {
856 State::Ready => self.advance_ready().await?,
857 State::Drain => self.advance_drain().await?,
858 State::Done => return Ok(()),
859 };
860 self.adapter_client
861 .add_idle_in_transaction_session_timeout();
862 }
863 }
864 }
865
    /// Waits for the next frontend message (or a pending session timeout),
    /// handles it, and returns the next state.
    ///
    /// Also records the per-message processing time
    /// (`pgwire_message_processing_seconds`) and the tokio scheduling delay
    /// observed while waiting for the message (`pgwire_recv_scheduling_delay_ms`).
    /// Both metrics are labeled by message name.
    #[instrument(level = "debug")]
    async fn advance_ready(&mut self) -> Result<State, io::Error> {
        // Start a fresh metrics interval just before waiting on the client, so
        // the interval taken after `recv` covers only the wait.
        self.tokio_metrics_intervals
            .next()
            .expect("infinite iterator");

        // `biased` makes the timeout branch win when both are ready, so no new
        // message is handled while a timeout is pending.
        let message = select! {
            biased;

            Some(timeout) = self.adapter_client.recv_timeout() => {
                let err: AdapterError = timeout.into();
                let conn_id = self.adapter_client.session().conn_id();
                tracing::warn!("session timed out, conn_id {}", conn_id);

                // Report the timeout as a fatal error first, then terminate
                // the session.
                let error_response = err.into_response(Severity::Fatal);
                let error_state = self.send_error_and_get_state(error_response).await;

                self.adapter_client.terminate().await;

                // NOTE(review): waits for one more client message before
                // returning; the protocol reason is not visible here — confirm.
                let _ = self.conn.recv().await?;
                return error_state;
            },
            message = self.conn.recv() => message?,
        };

        let interval = self
            .tokio_metrics_intervals
            .next()
            .expect("infinite iterator");
        let recv_scheduling_delay_ms = interval.total_scheduled_duration.as_secs_f64() * 1000.0;

        let received = SYSTEM_TIME();

        // The client is active again, so the idle-in-transaction timeout is
        // removed; `run` adds it back after this step.
        self.adapter_client
            .remove_idle_in_transaction_session_timeout();

        let message_name = message.as_ref().map(|m| m.name()).unwrap_or_default();

        // Processing time is only measured when a message was received.
        let start = message.as_ref().map(|_| Instant::now());
        let next_state = match message {
            Some(FrontendMessage::Query { sql }) => {
                let query_root_span =
                    tracing::info_span!(parent: None, "advance_ready", otel.name = message_name);
                query_root_span.follows_from(tracing::Span::current());
                self.query(sql, received)
                    .instrument(query_root_span)
                    .await?
            }
            Some(FrontendMessage::Parse {
                name,
                sql,
                param_types,
            }) => self.parse(name, sql, param_types).await?,
            Some(FrontendMessage::Bind {
                portal_name,
                statement_name,
                param_formats,
                raw_params,
                result_formats,
            }) => {
                self.bind(
                    portal_name,
                    statement_name,
                    param_formats,
                    raw_params,
                    result_formats,
                )
                .await?
            }
            Some(FrontendMessage::Execute {
                portal_name,
                max_rows,
            }) => {
                let max_rows = match usize::try_from(max_rows) {
                    // Zero, or a value that does not convert (e.g. negative),
                    // means no row limit.
                    Ok(0) | Err(_) => ExecuteCount::All,
                    Ok(n) => ExecuteCount::Count(n),
                };
                let execute_root_span =
                    tracing::info_span!(parent: None, "advance_ready", otel.name = message_name);
                execute_root_span.follows_from(tracing::Span::current());
                let state = self
                    .execute(
                        portal_name,
                        max_rows,
                        portal_exec_message,
                        None,
                        ExecuteTimeout::None,
                        None,
                        Some(received),
                    )
                    .instrument(execute_root_span)
                    .await?;
                // Rather than committing an implicit transaction here, mark it
                // so the next `ensure_transaction` commits it first.
                // NOTE(review): presumably deferred so the portal survives for
                // follow-up Execute messages — confirm.
                if self.adapter_client.session().transaction().is_implicit() {
                    self.txn_needs_commit = true;
                }
                state
            }
            Some(FrontendMessage::DescribeStatement { name }) => {
                self.describe_statement(&name).await?
            }
            Some(FrontendMessage::DescribePortal { name }) => self.describe_portal(&name).await?,
            Some(FrontendMessage::CloseStatement { name }) => self.close_statement(name).await?,
            Some(FrontendMessage::ClosePortal { name }) => self.close_portal(name).await?,
            Some(FrontendMessage::Flush) => self.flush().await?,
            Some(FrontendMessage::Sync) => self.sync().await?,
            Some(FrontendMessage::Terminate) => State::Done,

            // Copy and authentication messages are unexpected in the ready
            // state; discard input until the next Sync.
            Some(FrontendMessage::CopyData(_))
            | Some(FrontendMessage::CopyDone)
            | Some(FrontendMessage::CopyFail(_))
            | Some(FrontendMessage::Password { .. })
            | Some(FrontendMessage::RawAuthentication(_))
            | Some(FrontendMessage::SASLInitialResponse { .. })
            | Some(FrontendMessage::SASLResponse(_)) => State::Drain,
            // `recv` produced no message (presumably the client disconnected).
            None => State::Done,
        };

        if let Some(start) = start {
            self.adapter_client
                .inner()
                .metrics()
                .pgwire_message_processing_seconds
                .with_label_values(&[message_name])
                .observe(start.elapsed().as_secs_f64());
        }
        self.adapter_client
            .inner()
            .metrics()
            .pgwire_recv_scheduling_delay_ms
            .with_label_values(&[message_name])
            .observe(recv_scheduling_delay_ms);

        Ok(next_state)
    }
1029
1030 async fn advance_drain(&mut self) -> Result<State, io::Error> {
1031 let message = self.conn.recv().await?;
1032 if message.is_some() {
1033 self.adapter_client
1034 .remove_idle_in_transaction_session_timeout();
1035 }
1036 match message {
1037 Some(FrontendMessage::Sync) => self.sync().await,
1038 None => Ok(State::Done),
1039 _ => Ok(State::Drain),
1040 }
1041 }
1042
    /// Executes a single statement from a simple-protocol `Query` message.
    ///
    /// The statement is declared as the unnamed portal and stamped with
    /// `lifecycle_timestamps`. For row-returning, non-COPY statements a
    /// text-format `RowDescription` is sent. The statement is then executed and
    /// the unnamed portal removed. Statements that declare parameters are
    /// rejected with `UNDEFINED_PARAMETER`.
    #[instrument(level = "debug")]
    async fn one_query(
        &mut self,
        stmt: Statement<Raw>,
        sql: String,
        lifecycle_timestamps: LifecycleTimestamps,
    ) -> Result<State, io::Error> {
        const EMPTY_PORTAL: &str = "";
        if let Err(e) = self
            .adapter_client
            .declare(EMPTY_PORTAL.to_string(), stmt, sql)
            .await
        {
            return self
                .send_error_and_get_state(e.into_response(Severity::Error))
                .await;
        }
        let portal = self
            .adapter_client
            .session()
            .get_portal_unverified_mut(EMPTY_PORTAL)
            .expect("unnamed portal should be present");

        *portal.lifecycle_timestamps = Some(lifecycle_timestamps);

        let stmt_desc = portal.desc.clone();
        // NOTE(review): this early return leaves the unnamed portal declared;
        // presumably it is replaced by the next `declare` — confirm.
        if !stmt_desc.param_types.is_empty() {
            return self
                .send_error_and_get_state(ErrorResponse::error(
                    SqlState::UNDEFINED_PARAMETER,
                    "there is no parameter $1",
                ))
                .await;
        }

        // Simple-protocol results are always described in text format; COPY
        // statements get no RowDescription here.
        if let Some(relation_desc) = &stmt_desc.relation_desc {
            if !stmt_desc.is_copy {
                let formats = vec![Format::Text; stmt_desc.arity()];
                self.send(BackendMessage::RowDescription(
                    message::encode_row_description(relation_desc, &formats),
                ))
                .await?;
            }
        }

        let result = match self
            .adapter_client
            .execute(EMPTY_PORTAL.to_string(), self.conn.wait_closed(), None)
            .await
        {
            Ok((response, execute_started)) => {
                self.send_pending_notices().await?;
                self.send_execute_response(
                    response,
                    stmt_desc.relation_desc,
                    EMPTY_PORTAL.to_string(),
                    ExecuteCount::All,
                    portal_exec_message,
                    None,
                    ExecuteTimeout::None,
                    execute_started,
                )
                .await
            }
            Err(e) => {
                self.send_pending_notices().await?;
                self.send_error_and_get_state(e.into_response(Severity::Error))
                    .await
            }
        };

        self.adapter_client.session().remove_portal(EMPTY_PORTAL);

        result
    }
1125
    /// Prepares a transaction for the next statement(s).
    ///
    /// If `txn_needs_commit` is set, the pending implicit transaction is
    /// committed first. Then `start_transaction(Some(num_stmts))` is called.
    /// The elapsed time is recorded in `pgwire_ensure_transaction_seconds`,
    /// labeled by `message_type`.
    async fn ensure_transaction(
        &mut self,
        num_stmts: usize,
        message_type: &str,
    ) -> Result<(), io::Error> {
        let start = Instant::now();
        if self.txn_needs_commit {
            self.commit_transaction().await?;
        }
        // Starting the transaction is not expected to fail here; assert it so
        // that a change in that contract is caught.
        let res = self.adapter_client.start_transaction(Some(num_stmts));
        assert_ok!(res);
        self.adapter_client
            .inner()
            .metrics()
            .pgwire_ensure_transaction_seconds
            .with_label_values(&[message_type])
            .observe(start.elapsed().as_secs_f64());
        Ok(())
    }
1147
1148 fn parse_sql<'b>(&self, sql: &'b str) -> Result<Vec<StatementParseResult<'b>>, ErrorResponse> {
1149 let parse_start = Instant::now();
1150 let result = match self.adapter_client.parse(sql) {
1151 Ok(result) => result.map_err(|e| {
1152 let pos = sql[..e.error.pos].chars().count() + 1;
1155 ErrorResponse::error(SqlState::SYNTAX_ERROR, e.error.message).with_position(pos)
1156 }),
1157 Err(msg) => Err(ErrorResponse::error(SqlState::PROGRAM_LIMIT_EXCEEDED, msg)),
1158 };
1159 self.adapter_client
1160 .inner()
1161 .metrics()
1162 .parse_seconds
1163 .observe(parse_start.elapsed().as_secs_f64());
1164 result
1165 }
1166
1167 #[instrument(level = "debug")]
1172 async fn query(&mut self, sql: String, received: EpochMillis) -> Result<State, io::Error> {
1173 let stmts = match self.parse_sql(&sql) {
1175 Ok(stmts) => stmts,
1176 Err(err) => {
1177 self.send_error_and_get_state(err).await?;
1178 return self.ready().await;
1179 }
1180 };
1181
1182 let num_stmts = stmts.len();
1183
1184 for StatementParseResult { ast: stmt, sql } in stmts {
1186 if self.is_aborted_txn() && !is_txn_exit_stmt(Some(&stmt)) {
1188 self.aborted_txn_error().await?;
1189 break;
1190 }
1191
1192 self.ensure_transaction(num_stmts, "query").await?;
1200
1201 match self
1202 .one_query(stmt, sql.to_string(), LifecycleTimestamps { received })
1203 .await?
1204 {
1205 State::Ready => (),
1206 State::Drain => break,
1207 State::Done => return Ok(State::Done),
1208 }
1209 }
1210
1211 {
1213 if self.adapter_client.session().transaction().is_implicit() {
1214 self.commit_transaction().await?;
1215 }
1216 }
1217
1218 if num_stmts == 0 {
1219 self.send(BackendMessage::EmptyQueryResponse).await?;
1220 }
1221
1222 self.ready().await
1223 }
1224
1225 #[instrument(level = "debug")]
1226 async fn parse(
1227 &mut self,
1228 name: String,
1229 sql: String,
1230 param_oids: Vec<u32>,
1231 ) -> Result<State, io::Error> {
1232 self.ensure_transaction(1, "parse").await?;
1234
1235 let mut param_types = vec![];
1236 for oid in param_oids {
1237 match mz_pgrepr::Type::from_oid(oid) {
1238 Ok(ty) => match SqlScalarType::try_from(&ty) {
1239 Ok(ty) => param_types.push(Some(ty)),
1240 Err(err) => {
1241 return self
1242 .send_error_and_get_state(ErrorResponse::error(
1243 SqlState::INVALID_PARAMETER_VALUE,
1244 err.to_string(),
1245 ))
1246 .await;
1247 }
1248 },
1249 Err(_) if oid == 0 => param_types.push(None),
1250 Err(e) => {
1251 return self
1252 .send_error_and_get_state(ErrorResponse::error(
1253 SqlState::PROTOCOL_VIOLATION,
1254 e.to_string(),
1255 ))
1256 .await;
1257 }
1258 }
1259 }
1260
1261 let stmts = match self.parse_sql(&sql) {
1262 Ok(stmts) => stmts,
1263 Err(err) => {
1264 return self.send_error_and_get_state(err).await;
1265 }
1266 };
1267 if stmts.len() > 1 {
1268 return self
1269 .send_error_and_get_state(ErrorResponse::error(
1270 SqlState::INTERNAL_ERROR,
1271 "cannot insert multiple commands into a prepared statement",
1272 ))
1273 .await;
1274 }
1275 let (maybe_stmt, sql) = match stmts.into_iter().next() {
1276 None => (None, ""),
1277 Some(StatementParseResult { ast, sql }) => (Some(ast), sql),
1278 };
1279 if self.is_aborted_txn() && !is_txn_exit_stmt(maybe_stmt.as_ref()) {
1280 return self.aborted_txn_error().await;
1281 }
1282 match self
1283 .adapter_client
1284 .prepare(name, maybe_stmt, sql.to_string(), param_types)
1285 .await
1286 {
1287 Ok(()) => {
1288 self.send(BackendMessage::ParseComplete).await?;
1289 Ok(State::Ready)
1290 }
1291 Err(e) => {
1292 self.send_error_and_get_state(e.into_response(Severity::Error))
1293 .await
1294 }
1295 }
1296 }
1297
1298 #[instrument(level = "debug")]
1300 async fn commit_transaction(&mut self) -> Result<(), io::Error> {
1301 self.end_transaction(EndTransactionAction::Commit).await
1302 }
1303
1304 #[instrument(level = "debug")]
1306 async fn rollback_transaction(&mut self) -> Result<(), io::Error> {
1307 self.end_transaction(EndTransactionAction::Rollback).await
1308 }
1309
1310 #[instrument(level = "debug")]
1312 async fn end_transaction(&mut self, action: EndTransactionAction) -> Result<(), io::Error> {
1313 self.txn_needs_commit = false;
1314 let resp = self.adapter_client.end_transaction(action).await;
1315 if let Err(err) = resp {
1316 self.send(BackendMessage::ErrorResponse(
1317 err.into_response(Severity::Error),
1318 ))
1319 .await?;
1320 }
1321 Ok(())
1322 }
1323
1324 #[instrument(level = "debug")]
1325 async fn bind(
1326 &mut self,
1327 portal_name: String,
1328 statement_name: String,
1329 param_formats: Vec<Format>,
1330 raw_params: Vec<Option<Vec<u8>>>,
1331 result_formats: Vec<Format>,
1332 ) -> Result<State, io::Error> {
1333 self.ensure_transaction(1, "bind").await?;
1335
1336 let aborted_txn = self.is_aborted_txn();
1337 let stmt = match self
1338 .adapter_client
1339 .get_prepared_statement(&statement_name)
1340 .await
1341 {
1342 Ok(stmt) => stmt,
1343 Err(err) => {
1344 return self
1345 .send_error_and_get_state(err.into_response(Severity::Error))
1346 .await;
1347 }
1348 };
1349
1350 let param_types = &stmt.desc().param_types;
1351 if param_types.len() != raw_params.len() {
1352 let message = format!(
1353 "bind message supplies {actual} parameters, \
1354 but prepared statement \"{name}\" requires {expected}",
1355 name = statement_name,
1356 actual = raw_params.len(),
1357 expected = param_types.len()
1358 );
1359 return self
1360 .send_error_and_get_state(ErrorResponse::error(
1361 SqlState::PROTOCOL_VIOLATION,
1362 message,
1363 ))
1364 .await;
1365 }
1366 let param_formats = match pad_formats(param_formats, raw_params.len()) {
1367 Ok(param_formats) => param_formats,
1368 Err(msg) => {
1369 return self
1370 .send_error_and_get_state(ErrorResponse::error(
1371 SqlState::PROTOCOL_VIOLATION,
1372 msg,
1373 ))
1374 .await;
1375 }
1376 };
1377 if aborted_txn && !is_txn_exit_stmt(stmt.stmt()) {
1378 return self.aborted_txn_error().await;
1379 }
1380 let buf = RowArena::new();
1381 let mut params = vec![];
1382 for ((raw_param, mz_typ), format) in raw_params
1383 .into_iter()
1384 .zip_eq(param_types)
1385 .zip_eq(param_formats)
1386 {
1387 let pg_typ = mz_pgrepr::Type::from(mz_typ);
1388 let datum = match raw_param {
1389 None => Datum::Null,
1390 Some(bytes) => match mz_pgrepr::Value::decode(format, &pg_typ, &bytes) {
1391 Ok(param) => match param.into_datum_decode_error(&buf, &pg_typ, "parameter") {
1392 Ok(datum) => datum,
1393 Err(msg) => {
1394 return self
1395 .send_error_and_get_state(ErrorResponse::error(
1396 SqlState::INVALID_PARAMETER_VALUE,
1397 msg,
1398 ))
1399 .await;
1400 }
1401 },
1402 Err(err) => {
1403 let msg = format!("unable to decode parameter: {}", err);
1404 return self
1405 .send_error_and_get_state(ErrorResponse::error(
1406 SqlState::INVALID_PARAMETER_VALUE,
1407 msg,
1408 ))
1409 .await;
1410 }
1411 },
1412 };
1413 params.push((datum, mz_typ.clone()))
1414 }
1415
1416 let result_formats = match pad_formats(
1417 result_formats,
1418 stmt.desc()
1419 .relation_desc
1420 .clone()
1421 .map(|desc| desc.typ().column_types.len())
1422 .unwrap_or(0),
1423 ) {
1424 Ok(result_formats) => result_formats,
1425 Err(msg) => {
1426 return self
1427 .send_error_and_get_state(ErrorResponse::error(
1428 SqlState::PROTOCOL_VIOLATION,
1429 msg,
1430 ))
1431 .await;
1432 }
1433 };
1434
1435 if !stmt.stmt().map_or(false, |stmt| match stmt {
1438 Statement::Copy(CopyStatement {
1439 direction: CopyDirection::To,
1440 ..
1441 }) => true,
1442 Statement::Copy(CopyStatement {
1443 direction: CopyDirection::From,
1444 target: CopyTarget::Expr(_),
1448 ..
1449 }) => true,
1450 _ => false,
1451 }) {
1452 if let Some(desc) = stmt.desc().relation_desc.clone() {
1453 for (format, ty) in result_formats.iter().zip_eq(desc.iter_types()) {
1454 match (format, &ty.scalar_type) {
1455 (Format::Binary, mz_repr::SqlScalarType::List { .. }) => {
1456 return self
1457 .send_error_and_get_state(ErrorResponse::error(
1458 SqlState::PROTOCOL_VIOLATION,
1459 "binary encoding of list types is not implemented",
1460 ))
1461 .await;
1462 }
1463 (Format::Binary, mz_repr::SqlScalarType::Map { .. }) => {
1464 return self
1465 .send_error_and_get_state(ErrorResponse::error(
1466 SqlState::PROTOCOL_VIOLATION,
1467 "binary encoding of map types is not implemented",
1468 ))
1469 .await;
1470 }
1471 (Format::Binary, mz_repr::SqlScalarType::AclItem) => {
1472 return self
1473 .send_error_and_get_state(ErrorResponse::error(
1474 SqlState::PROTOCOL_VIOLATION,
1475 "binary encoding of aclitem types does not exist",
1476 ))
1477 .await;
1478 }
1479 _ => (),
1480 }
1481 }
1482 }
1483 }
1484
1485 let desc = stmt.desc().clone();
1486 let logging = Arc::clone(stmt.logging());
1487 let stmt_ast = stmt.stmt().cloned();
1488 let state_revision = stmt.state_revision;
1489 if let Err(err) = self.adapter_client.session().set_portal(
1490 portal_name,
1491 desc,
1492 stmt_ast,
1493 logging,
1494 params,
1495 result_formats,
1496 state_revision,
1497 ) {
1498 return self
1499 .send_error_and_get_state(err.into_response(Severity::Error))
1500 .await;
1501 }
1502
1503 self.send(BackendMessage::BindComplete).await?;
1504 Ok(State::Ready)
1505 }
1506
/// Executes, or resumes, the portal `portal_name`, sending at most `max_rows` rows.
///
/// Behavior depends on the portal's state:
/// - `NotStarted`: ensures a transaction is open, has the adapter execute the portal, and
///   forwards the response through `send_execute_response`.
/// - `InProgress`: continues streaming rows left over from an earlier Execute.
/// - `Completed(Some(tag))`: re-sends the stored command tag.
/// - `Completed(None)`: reports that the portal cannot be run.
///
/// `fetch_portal_name` is set when this execution was triggered by a FETCH and is passed on to
/// row sending. `received` seeds the portal's lifecycle timestamps.
///
/// Handling of `outer_ctx_extra`, when present:
/// - It is retired here on the early error paths and for `InProgress`/`Completed` portals.
/// - For `NotStarted` portals it is handed to the adapter's `execute`.
///
/// The result is a boxed future because `execute` -> `send_execute_response` -> `fetch` ->
/// `execute` forms a recursive async call cycle, which requires boxing.
fn execute(
    &mut self,
    portal_name: String,
    max_rows: ExecuteCount,
    get_response: GetResponse,
    fetch_portal_name: Option<String>,
    timeout: ExecuteTimeout,
    outer_ctx_extra: Option<ExecuteContextGuard>,
    received: Option<EpochMillis>,
) -> BoxFuture<'_, Result<State, io::Error>> {
    async move {
        // Read up front, before `portal` mutably borrows the session.
        let aborted_txn = self.is_aborted_txn();

        let portal = match self
            .adapter_client
            .session()
            .get_portal_unverified_mut(&portal_name)
        {
            Some(portal) => portal,
            None => {
                let msg = format!("portal {} does not exist", portal_name.quoted());
                if let Some(outer_ctx_extra) = outer_ctx_extra {
                    self.adapter_client.retire_execute(
                        outer_ctx_extra,
                        StatementEndedExecutionReason::Errored { error: msg.clone() },
                    );
                }
                return self
                    .send_error_and_get_state(ErrorResponse::error(
                        SqlState::INVALID_CURSOR_NAME,
                        msg,
                    ))
                    .await;
            }
        };

        *portal.lifecycle_timestamps = received.map(LifecycleTimestamps::new);

        // In an aborted transaction, reject everything `is_txn_exit_stmt` does not accept.
        let txn_exit_stmt = is_txn_exit_stmt(portal.stmt.as_deref());
        if aborted_txn && !txn_exit_stmt {
            if let Some(outer_ctx_extra) = outer_ctx_extra {
                self.adapter_client.retire_execute(
                    outer_ctx_extra,
                    StatementEndedExecutionReason::Errored {
                        error: ABORTED_TXN_MSG.to_string(),
                    },
                );
            }
            return self.aborted_txn_error().await;
        }

        // Cloned so the row description outlives the borrow of `portal`.
        let row_desc = portal.desc.relation_desc.clone();
        match portal.state {
            PortalState::NotStarted => {
                // Start a transaction if one is not already open.
                self.ensure_transaction(1, "execute").await?;
                match self
                    .adapter_client
                    .execute(
                        portal_name.clone(),
                        self.conn.wait_closed(),
                        outer_ctx_extra,
                    )
                    .await
                {
                    Ok((response, execute_started)) => {
                        self.send_pending_notices().await?;
                        self.send_execute_response(
                            response,
                            row_desc,
                            portal_name,
                            max_rows,
                            get_response,
                            fetch_portal_name,
                            timeout,
                            execute_started,
                        )
                        .await
                    }
                    Err(e) => {
                        self.send_pending_notices().await?;
                        self.send_error_and_get_state(e.into_response(Severity::Error))
                            .await
                    }
                }
            }
            PortalState::InProgress(rows) => {
                // Take the pending rows out of the portal. `send_rows` stores them back.
                let rows = rows.take().expect("InProgress rows must be populated");
                let (result, statement_ended_execution_reason) = match self
                    .send_rows(
                        row_desc.expect("portal missing row desc on resumption"),
                        portal_name,
                        rows,
                        max_rows,
                        get_response,
                        fetch_portal_name,
                        timeout,
                    )
                    .await
                {
                    // `send_rows` only returns `Err` for connection I/O failures. Those are
                    // logged as a cancellation rather than a query error.
                    Err(e) => {
                        (Err(e), StatementEndedExecutionReason::Canceled)
                    }
                    Ok((ok, SendRowsEndedReason::Canceled)) => {
                        (Ok(ok), StatementEndedExecutionReason::Canceled)
                    }
                    // Resumed executions log no result size or row count. The reported
                    // sizes are deliberately discarded.
                    Ok((
                        ok,
                        SendRowsEndedReason::Success {
                            result_size: _,
                            rows_returned: _,
                        },
                    )) => (
                        Ok(ok),
                        StatementEndedExecutionReason::Success {
                            result_size: None,
                            rows_returned: None,
                            execution_strategy: None,
                        },
                    ),
                    Ok((ok, SendRowsEndedReason::Errored { error })) => {
                        (Ok(ok), StatementEndedExecutionReason::Errored { error })
                    }
                };
                if let Some(outer_ctx_extra) = outer_ctx_extra {
                    self.adapter_client
                        .retire_execute(outer_ctx_extra, statement_ended_execution_reason);
                }
                result
            }
            // A portal completed with a stored tag replays that tag. The tag is whatever was
            // recorded when the portal completed.
            PortalState::Completed(Some(tag)) => {
                let tag = tag.to_string();
                if let Some(outer_ctx_extra) = outer_ctx_extra {
                    self.adapter_client.retire_execute(
                        outer_ctx_extra,
                        StatementEndedExecutionReason::Success {
                            result_size: None,
                            rows_returned: None,
                            execution_strategy: None,
                        },
                    );
                }
                self.send(BackendMessage::CommandComplete { tag }).await?;
                Ok(State::Ready)
            }
            // Portals completed without a tag (e.g. via `complete_portal`) cannot run again.
            PortalState::Completed(None) => {
                let error = format!(
                    "portal {} cannot be run",
                    Ident::new_unchecked(portal_name).to_ast_string_stable()
                );
                if let Some(outer_ctx_extra) = outer_ctx_extra {
                    self.adapter_client.retire_execute(
                        outer_ctx_extra,
                        StatementEndedExecutionReason::Errored {
                            error: error.clone(),
                        },
                    );
                }
                self.send_error_and_get_state(ErrorResponse::error(
                    SqlState::OBJECT_NOT_IN_PREREQUISITE_STATE,
                    error,
                ))
                .await
            }
        }
    }
    .instrument(debug_span!("execute"))
    .boxed()
}
1700
1701 #[instrument(level = "debug")]
1702 async fn describe_statement(&mut self, name: &str) -> Result<State, io::Error> {
1703 self.ensure_transaction(1, "describe_statement").await?;
1705
1706 let stmt = match self.adapter_client.get_prepared_statement(name).await {
1707 Ok(stmt) => stmt,
1708 Err(err) => {
1709 return self
1710 .send_error_and_get_state(err.into_response(Severity::Error))
1711 .await;
1712 }
1713 };
1714 let parameter_desc = BackendMessage::ParameterDescription(
1716 stmt.desc()
1717 .param_types
1718 .iter()
1719 .map(mz_pgrepr::Type::from)
1720 .collect(),
1721 );
1722 let formats = vec![Format::Text; stmt.desc().arity()];
1726 let row_desc = describe_rows(stmt.desc(), &formats);
1727 self.send_all([parameter_desc, row_desc]).await?;
1728 Ok(State::Ready)
1729 }
1730
1731 #[instrument(level = "debug")]
1732 async fn describe_portal(&mut self, name: &str) -> Result<State, io::Error> {
1733 self.ensure_transaction(1, "describe_portal").await?;
1735
1736 let session = self.adapter_client.session();
1737 let row_desc = session
1738 .get_portal_unverified(name)
1739 .map(|portal| describe_rows(&portal.desc, &portal.result_formats));
1740 match row_desc {
1741 Some(row_desc) => {
1742 self.send(row_desc).await?;
1743 Ok(State::Ready)
1744 }
1745 None => {
1746 self.send_error_and_get_state(ErrorResponse::error(
1747 SqlState::INVALID_CURSOR_NAME,
1748 format!("portal {} does not exist", name.quoted()),
1749 ))
1750 .await
1751 }
1752 }
1753 }
1754
1755 #[instrument(level = "debug")]
1756 async fn close_statement(&mut self, name: String) -> Result<State, io::Error> {
1757 self.adapter_client
1758 .session()
1759 .remove_prepared_statement(&name);
1760 self.send(BackendMessage::CloseComplete).await?;
1761 Ok(State::Ready)
1762 }
1763
1764 #[instrument(level = "debug")]
1765 async fn close_portal(&mut self, name: String) -> Result<State, io::Error> {
1766 self.adapter_client.session().remove_portal(&name);
1767 self.send(BackendMessage::CloseComplete).await?;
1768 Ok(State::Ready)
1769 }
1770
1771 fn complete_portal(&mut self, name: &str) {
1772 let portal = self
1773 .adapter_client
1774 .session()
1775 .get_portal_unverified_mut(name)
1776 .expect("portal should exist");
1777 *portal.state = PortalState::Completed(None);
1778 }
1779
1780 async fn fetch(
1781 &mut self,
1782 name: String,
1783 count: Option<FetchDirection>,
1784 max_rows: ExecuteCount,
1785 fetch_portal_name: Option<String>,
1786 timeout: ExecuteTimeout,
1787 ctx_extra: ExecuteContextGuard,
1788 ) -> Result<State, io::Error> {
1789 let count = count.unwrap_or(FetchDirection::ForwardCount(1));
1792
1793 let count = match (max_rows, count) {
1806 (ExecuteCount::Count(max_rows), FetchDirection::ForwardCount(count)) => {
1807 let count = usize::cast_from(count);
1808 if max_rows < count {
1809 let msg = "Execute with max_rows < a FETCH's count is not supported";
1810 self.adapter_client.retire_execute(
1811 ctx_extra,
1812 StatementEndedExecutionReason::Errored {
1813 error: msg.to_string(),
1814 },
1815 );
1816 return self
1817 .send_error_and_get_state(ErrorResponse::error(
1818 SqlState::FEATURE_NOT_SUPPORTED,
1819 msg,
1820 ))
1821 .await;
1822 }
1823 ExecuteCount::Count(count)
1824 }
1825 (ExecuteCount::Count(_), FetchDirection::ForwardAll) => {
1826 let msg = "Execute with max_rows of a FETCH ALL is not supported";
1827 self.adapter_client.retire_execute(
1828 ctx_extra,
1829 StatementEndedExecutionReason::Errored {
1830 error: msg.to_string(),
1831 },
1832 );
1833 return self
1834 .send_error_and_get_state(ErrorResponse::error(
1835 SqlState::FEATURE_NOT_SUPPORTED,
1836 msg,
1837 ))
1838 .await;
1839 }
1840 (ExecuteCount::All, FetchDirection::ForwardAll) => ExecuteCount::All,
1841 (ExecuteCount::All, FetchDirection::ForwardCount(count)) => {
1842 ExecuteCount::Count(usize::cast_from(count))
1843 }
1844 };
1845 let cursor_name = name.to_string();
1846 self.execute(
1847 cursor_name,
1848 count,
1849 fetch_message,
1850 fetch_portal_name,
1851 timeout,
1852 Some(ctx_extra),
1853 None,
1854 )
1855 .await
1856 }
1857
1858 async fn flush(&mut self) -> Result<State, io::Error> {
1859 self.conn.flush().await?;
1860 Ok(State::Ready)
1861 }
1862
1863 #[instrument(level = "debug")]
1868 async fn send<M>(&mut self, message: M) -> Result<(), io::Error>
1869 where
1870 M: Into<BackendMessage>,
1871 {
1872 let message: BackendMessage = message.into();
1873 let is_error =
1874 matches!(&message, BackendMessage::ErrorResponse(e) if e.severity.is_error());
1875
1876 self.conn.send(message).await?;
1877
1878 if is_error {
1885 self.conn.flush().await?;
1886 }
1887
1888 Ok(())
1889 }
1890
1891 #[instrument(level = "debug")]
1892 pub async fn send_all(
1893 &mut self,
1894 messages: impl IntoIterator<Item = BackendMessage>,
1895 ) -> Result<(), io::Error> {
1896 for m in messages {
1897 self.send(m).await?;
1898 }
1899 Ok(())
1900 }
1901
1902 #[instrument(level = "debug")]
1903 async fn sync(&mut self) -> Result<State, io::Error> {
1904 if self.adapter_client.session().transaction().is_implicit() {
1906 self.commit_transaction().await?;
1907 }
1908 self.ready().await
1909 }
1910
1911 #[instrument(level = "debug")]
1912 async fn ready(&mut self) -> Result<State, io::Error> {
1913 let txn_state = self.adapter_client.session().transaction().into();
1914 self.send(BackendMessage::ReadyForQuery(txn_state)).await?;
1915 self.flush().await
1916 }
1917
/// Translates an adapter `ExecuteResponse` for portal `portal_name` into wire messages.
///
/// Behavior by response kind:
/// - Row-producing responses are streamed with `send_rows`.
/// - COPY TO responses are streamed with `copy_rows`.
/// - COPY FROM is delegated to `copy_from`.
/// - FETCH is delegated to `fetch`.
/// - All remaining responses produce a single `CommandComplete`.
///
/// `row_desc` must be `Some` for row-producing and COPY responses; those arms `expect` it.
///
/// `tag` holds the response's command tag. Arms emitting `CommandComplete` consume it through
/// `command_complete!`. The `assert_none!` after the match requires every arm that reaches it
/// to have consumed any tag the response produced. Arms that `return` early skip that check.
#[allow(clippy::too_many_arguments)]
#[instrument(level = "debug")]
async fn send_execute_response(
    &mut self,
    response: ExecuteResponse,
    row_desc: Option<RelationDesc>,
    portal_name: String,
    max_rows: ExecuteCount,
    get_response: GetResponse,
    fetch_portal_name: Option<String>,
    timeout: ExecuteTimeout,
    execute_started: Instant,
) -> Result<State, io::Error> {
    let mut tag = response.tag();

    // Sends `CommandComplete` with the (now consumed) tag and yields `Ok(State::Ready)`.
    macro_rules! command_complete {
        () => {{
            self.send(BackendMessage::CommandComplete {
                tag: tag
                    .take()
                    .expect("command_complete only called on tag-generating results"),
            })
            .await?;
            Ok(State::Ready)
        }};
    }

    let r = match response {
        // Cursor lifecycle statements mark their own portal as completed.
        ExecuteResponse::ClosedCursor => {
            self.complete_portal(&portal_name);
            command_complete!()
        }
        ExecuteResponse::DeclaredCursor => {
            self.complete_portal(&portal_name);
            command_complete!()
        }
        ExecuteResponse::EmptyQuery => {
            self.send(BackendMessage::EmptyQueryResponse).await?;
            Ok(State::Ready)
        }
        // FETCH runs the target cursor. This portal's name is passed as the fetch portal so
        // its result formats are used. Note that `timeout` here shadows the parameter.
        ExecuteResponse::Fetch {
            name,
            count,
            timeout,
            ctx_extra,
        } => {
            self.fetch(
                name,
                count,
                max_rows,
                Some(portal_name.to_string()),
                timeout,
                ctx_extra,
            )
            .await
        }
        ExecuteResponse::SendingRowsStreaming {
            rows,
            instance_id,
            strategy,
        } => {
            let row_desc = row_desc
                .expect("missing row description for ExecuteResponse::SendingRowsStreaming");

            let span = tracing::debug_span!("sending_rows_streaming");

            self.send_rows(
                row_desc,
                portal_name,
                InProgressRows::new(RecordFirstRowStream::new(
                    Box::new(rows),
                    execute_started,
                    &self.adapter_client,
                    Some(instance_id),
                    Some(strategy),
                )),
                max_rows,
                get_response,
                fetch_portal_name,
                timeout,
            )
            .instrument(span)
            .await
            .map(|(state, _)| state)
        }
        // Rows that are already fully materialized are wrapped in a one-item stream so they
        // go through the same `send_rows` path.
        ExecuteResponse::SendingRowsImmediate { rows } => {
            let row_desc = row_desc
                .expect("missing row description for ExecuteResponse::SendingRowsImmediate");

            let span = tracing::debug_span!("sending_rows_immediate");

            let stream =
                futures::stream::once(futures::future::ready(PeekResponseUnary::Rows(rows)));
            self.send_rows(
                row_desc,
                portal_name,
                InProgressRows::new(RecordFirstRowStream::new(
                    Box::new(stream),
                    execute_started,
                    &self.adapter_client,
                    None,
                    Some(StatementExecutionStrategy::Constant),
                )),
                max_rows,
                get_response,
                fetch_portal_name,
                timeout,
            )
            .instrument(span)
            .await
            .map(|(state, _)| state)
        }
        // If the variable is in the session's notify set, report its new value with
        // `ParameterStatus` before completing.
        ExecuteResponse::SetVariable { name, .. } => {
            let qn = name.to_string();
            let msg = if let Some(var) = self
                .adapter_client
                .session()
                .vars_mut()
                .notify_set()
                .find(|v| v.name() == qn)
            {
                Some(BackendMessage::ParameterStatus(var.name(), var.value()))
            } else {
                None
            };
            if let Some(msg) = msg {
                self.send(msg).await?;
            }
            command_complete!()
        }
        ExecuteResponse::Subscribing {
            rx,
            ctx_extra,
            instance_id,
        } => {
            // Outside of FETCH, warn that rows are streamed directly. psql users additionally
            // get a hint to use COPY instead.
            if fetch_portal_name.is_none() {
                let mut msg = ErrorResponse::notice(
                    SqlState::WARNING,
                    "streaming SUBSCRIBE rows directly requires a client that does not buffer output",
                );
                if self.adapter_client.session().vars().application_name() == "psql" {
                    msg.hint = Some(
                        "Wrap your SUBSCRIBE statement in `COPY (SUBSCRIBE ...) TO STDOUT`."
                            .into(),
                    )
                }
                self.send(msg).await?;
                self.conn.flush().await?;
            }
            let row_desc =
                row_desc.expect("missing row description for ExecuteResponse::Subscribing");
            let (result, statement_ended_execution_reason) = match self
                .send_rows(
                    row_desc,
                    portal_name,
                    InProgressRows::new(RecordFirstRowStream::new(
                        Box::new(UnboundedReceiverStream::new(rx)),
                        execute_started,
                        &self.adapter_client,
                        Some(instance_id),
                        None,
                    )),
                    max_rows,
                    get_response,
                    fetch_portal_name,
                    timeout,
                )
                .await
            {
                // `Err` means the connection failed; it is logged as a cancellation.
                Err(e) => {
                    (Err(e), StatementEndedExecutionReason::Canceled)
                }
                Ok((ok, SendRowsEndedReason::Canceled)) => {
                    (Ok(ok), StatementEndedExecutionReason::Canceled)
                }
                Ok((
                    ok,
                    SendRowsEndedReason::Success {
                        result_size,
                        rows_returned,
                    },
                )) => (
                    Ok(ok),
                    StatementEndedExecutionReason::Success {
                        result_size: Some(result_size),
                        rows_returned: Some(rows_returned),
                        execution_strategy: None,
                    },
                ),
                Ok((ok, SendRowsEndedReason::Errored { error })) => {
                    (Ok(ok), StatementEndedExecutionReason::Errored { error })
                }
            };
            self.adapter_client
                .retire_execute(ctx_extra, statement_ended_execution_reason);
            return result;
        }
        // COPY TO wraps an inner response; only row-producing inner responses are supported.
        ExecuteResponse::CopyTo { format, resp } => {
            let row_desc =
                row_desc.expect("missing row description for ExecuteResponse::CopyTo");
            match *resp {
                ExecuteResponse::Subscribing {
                    rx,
                    ctx_extra,
                    instance_id,
                } => {
                    let (result, statement_ended_execution_reason) = match self
                        .copy_rows(
                            format,
                            row_desc,
                            RecordFirstRowStream::new(
                                Box::new(UnboundedReceiverStream::new(rx)),
                                execute_started,
                                &self.adapter_client,
                                Some(instance_id),
                                None,
                            ),
                        )
                        .await
                    {
                        // `Err` means the connection failed; it is logged as a cancellation.
                        Err(e) => {
                            (Err(e), StatementEndedExecutionReason::Canceled)
                        }
                        Ok((
                            state,
                            SendRowsEndedReason::Success {
                                result_size,
                                rows_returned,
                            },
                        )) => (
                            Ok(state),
                            StatementEndedExecutionReason::Success {
                                result_size: Some(result_size),
                                rows_returned: Some(rows_returned),
                                execution_strategy: None,
                            },
                        ),
                        Ok((state, SendRowsEndedReason::Errored { error })) => {
                            (Ok(state), StatementEndedExecutionReason::Errored { error })
                        }
                        Ok((state, SendRowsEndedReason::Canceled)) => {
                            (Ok(state), StatementEndedExecutionReason::Canceled)
                        }
                    };
                    self.adapter_client
                        .retire_execute(ctx_extra, statement_ended_execution_reason);
                    return result;
                }
                ExecuteResponse::SendingRowsStreaming {
                    rows,
                    instance_id,
                    strategy,
                } => {
                    return self
                        .copy_rows(
                            format,
                            row_desc,
                            RecordFirstRowStream::new(
                                Box::new(rows),
                                execute_started,
                                &self.adapter_client,
                                Some(instance_id),
                                Some(strategy),
                            ),
                        )
                        .await
                        .map(|(state, _)| state);
                }
                ExecuteResponse::SendingRowsImmediate { rows } => {
                    let span = tracing::debug_span!("sending_rows_immediate");

                    let rows = futures::stream::once(futures::future::ready(
                        PeekResponseUnary::Rows(rows),
                    ));
                    return self
                        .copy_rows(
                            format,
                            row_desc,
                            RecordFirstRowStream::new(
                                Box::new(rows),
                                execute_started,
                                &self.adapter_client,
                                None,
                                Some(StatementExecutionStrategy::Constant),
                            ),
                        )
                        .instrument(span)
                        .await
                        .map(|(state, _)| state);
                }
                _ => {
                    return self
                        .send_error_and_get_state(ErrorResponse::error(
                            SqlState::INTERNAL_ERROR,
                            "unsupported COPY response type".to_string(),
                        ))
                        .await;
                }
            };
        }
        ExecuteResponse::CopyFrom {
            target_id,
            target_name,
            columns,
            params,
            ctx_extra,
        } => {
            let row_desc =
                row_desc.expect("missing row description for ExecuteResponse::CopyFrom");
            self.copy_from(target_id, target_name, columns, params, row_desc, ctx_extra)
                .await
        }
        // After a transaction ends, report the returned parameters that are in the session's
        // notify set via `ParameterStatus`.
        ExecuteResponse::TransactionCommitted { params }
        | ExecuteResponse::TransactionRolledBack { params } => {
            let notify_set: mz_ore::collections::HashSet<String> = self
                .adapter_client
                .session()
                .vars()
                .notify_set()
                .map(|v| v.name().to_string())
                .collect();

            for (name, value) in params
                .into_iter()
                .filter(|(name, _v)| notify_set.contains(*name))
            {
                let msg = BackendMessage::ParameterStatus(name, value);
                self.send(msg).await?;
            }
            command_complete!()
        }

        // Every remaining response is acknowledged with just its command tag.
        ExecuteResponse::AlteredDefaultPrivileges
        | ExecuteResponse::AlteredObject(..)
        | ExecuteResponse::AlteredRole
        | ExecuteResponse::AlteredSystemConfiguration
        | ExecuteResponse::CreatedCluster { .. }
        | ExecuteResponse::CreatedClusterReplica { .. }
        | ExecuteResponse::CreatedConnection { .. }
        | ExecuteResponse::CreatedDatabase { .. }
        | ExecuteResponse::CreatedIndex { .. }
        | ExecuteResponse::CreatedIntrospectionSubscribe
        | ExecuteResponse::CreatedMaterializedView { .. }
        | ExecuteResponse::CreatedRole
        | ExecuteResponse::CreatedSchema { .. }
        | ExecuteResponse::CreatedSecret { .. }
        | ExecuteResponse::CreatedSink { .. }
        | ExecuteResponse::CreatedSource { .. }
        | ExecuteResponse::CreatedTable { .. }
        | ExecuteResponse::CreatedType
        | ExecuteResponse::CreatedView { .. }
        | ExecuteResponse::CreatedViews { .. }
        | ExecuteResponse::CreatedNetworkPolicy
        | ExecuteResponse::Comment
        | ExecuteResponse::Deallocate { .. }
        | ExecuteResponse::Deleted(..)
        | ExecuteResponse::DiscardedAll
        | ExecuteResponse::DiscardedTemp
        | ExecuteResponse::DroppedObject(_)
        | ExecuteResponse::DroppedOwned
        | ExecuteResponse::GrantedPrivilege
        | ExecuteResponse::GrantedRole
        | ExecuteResponse::Inserted(..)
        | ExecuteResponse::Copied(..)
        | ExecuteResponse::Prepare
        | ExecuteResponse::Raised
        | ExecuteResponse::ReassignOwned
        | ExecuteResponse::RevokedPrivilege
        | ExecuteResponse::RevokedRole
        | ExecuteResponse::StartedTransaction { .. }
        | ExecuteResponse::Updated(..)
        | ExecuteResponse::ValidatedConnection => {
            command_complete!()
        }
    };

    // Any tag produced by the response must have been sent by now.
    assert_none!(tag, "tag created but not consumed: {:?}", tag);
    r
}
2312
/// Streams rows from `rows` to the client as `DataRow` messages for portal `portal_name`, then
/// sends the completion message built by `get_response`.
///
/// - Result formats come from the `fetch_portal_name` portal when it is set (a FETCH), and
///   from the `portal_name` portal otherwise.
/// - At most `max_rows` rows are sent. Rows left over from the current batch are parked in
///   `rows.current`. The rows are always stored back into the portal as `InProgress`, so a
///   later Execute can resume.
/// - `timeout` controls how long to wait for new batches:
///   - `Seconds(t)` stops waiting after `t` seconds.
///   - `WaitOnce` blocks until a non-empty batch arrives, then stops waiting for more.
///   - `None` waits until the stream ends or enough rows have been sent.
/// - Session notices that arrive while waiting are forwarded to the client.
/// - The first time the stream is found exhausted, the time since the first row is observed
///   in the `result_rows_first_to_last_byte_seconds` metric (zero if no rows were seen).
///
/// Returns the next state with a `SendRowsEndedReason`. Query errors and cancellations are
/// reported to the client and returned as `Ok` with an `Errored`/`Canceled` reason. `Err` is
/// returned only for I/O errors or a closed connection.
///
/// # Panics
///
/// Panics in these cases:
/// - the portal `portal_name` does not exist;
/// - `fetch_portal_name` is set and that portal does not exist;
/// - the number of result formats differs from the number of columns in `row_desc`
///   (`zip_eq`).
#[allow(clippy::too_many_arguments)]
#[mz_ore::instrument(level = "debug")]
async fn send_rows(
    &mut self,
    row_desc: RelationDesc,
    portal_name: String,
    mut rows: InProgressRows,
    max_rows: ExecuteCount,
    get_response: GetResponse,
    fetch_portal_name: Option<String>,
    timeout: ExecuteTimeout,
) -> Result<(State, SendRowsEndedReason), io::Error> {
    // A FETCH sends rows in the result formats bound to the FETCH's portal, not the cursor's.
    let result_format_portal_name: &str = if let Some(ref name) = fetch_portal_name {
        name
    } else {
        &portal_name
    };
    let result_formats = self
        .adapter_client
        .session()
        .get_portal_unverified(result_format_portal_name)
        .expect("valid fetch portal name for send rows")
        .result_formats
        .clone();

    // `wait_once`: block until rows arrive, then behave like a zero timeout.
    let (mut wait_once, mut deadline) = match timeout {
        ExecuteTimeout::None => (false, None),
        ExecuteTimeout::Seconds(t) => (
            false,
            Some(tokio::time::Instant::now() + tokio::time::Duration::from_secs_f64(t)),
        ),
        ExecuteTimeout::WaitOnce => (true, None),
    };

    // Sanity check: both portals' relation descriptions, when present, should match
    // `row_desc`. Mismatches are logged via `soft_assert_eq_or_log!`.
    {
        let portal_name_desc = &self
            .adapter_client
            .session()
            .get_portal_unverified(portal_name.as_str())
            .expect("portal should exist")
            .desc
            .relation_desc;
        if let Some(portal_name_desc) = portal_name_desc {
            soft_assert_eq_or_log!(portal_name_desc, &row_desc);
        }
        if let Some(fetch_portal_name) = &fetch_portal_name {
            let fetch_portal_desc = &self
                .adapter_client
                .session()
                .get_portal_unverified(fetch_portal_name)
                .expect("portal should exist")
                .desc
                .relation_desc;
            if let Some(fetch_portal_desc) = fetch_portal_desc {
                soft_assert_eq_or_log!(fetch_portal_desc, &row_desc);
            }
        }
    }

    // Configure the connection's encoder with one (pg type, format) pair per column.
    self.conn.set_encode_state(
        row_desc
            .typ()
            .column_types
            .iter()
            .map(|ty| mz_pgrepr::Type::from(&ty.scalar_type))
            .zip_eq(result_formats)
            .collect(),
    );

    let mut total_sent_rows = 0;
    let mut total_sent_bytes = 0;
    // Maximum number of rows the client still wants.
    let mut want_rows = match max_rows {
        ExecuteCount::All => usize::MAX,
        ExecuteCount::Count(count) => count,
    };

    loop {
        // Leftover rows from a previous call are drained first. No new batches are waited for
        // once the client's row budget is exhausted.
        let batch = if rows.current.is_some() {
            FetchResult::Rows(rows.current.take())
        } else if want_rows == 0 {
            FetchResult::Rows(None)
        } else {
            let notice_fut = self.adapter_client.session().recv_notice();
            tokio::select! {
                // `biased` polls branches in the order written: a closed connection wins, then
                // ready rows, then notices, and the deadline last.
                biased;
                err = self.conn.wait_closed() => return Err(err),
                batch = rows.remaining.recv() => match batch {
                    None => FetchResult::Rows(None),
                    Some(PeekResponseUnary::Rows(rows)) => FetchResult::Rows(Some(rows)),
                    Some(PeekResponseUnary::Error(err)) => FetchResult::Error(err),
                    Some(PeekResponseUnary::DependencyDropped(dep)) => {
                        FetchResult::Error(dep.query_terminated_error())
                    }
                    Some(PeekResponseUnary::Canceled) => FetchResult::Canceled,
                },
                notice = notice_fut => {
                    FetchResult::Notice(notice)
                }
                // Reaching the deadline ends sending as if the stream had no more rows.
                _ = time::sleep_until(
                    deadline.unwrap_or_else(tokio::time::Instant::now),
                ), if deadline.is_some() => FetchResult::Rows(None),
            }
        };

        match batch {
            FetchResult::Rows(None) => break,
            FetchResult::Rows(Some(mut batch_rows)) => {
                if let Err(err) = verify_datum_desc(&row_desc, &mut batch_rows) {
                    let msg = err.to_string();
                    return self
                        .send_error_and_get_state(err.into_response(Severity::Error))
                        .await
                        .map(|state| (state, SendRowsEndedReason::Errored { error: msg }));
                }

                // WaitOnce: after the first non-empty batch, switch to a zero timeout so later
                // waits return immediately.
                if wait_once && batch_rows.peek().is_some() {
                    deadline = Some(tokio::time::Instant::now());
                    wait_once = false;
                }

                let mut sent_rows = 0;
                let mut sent_bytes = 0;
                // `inspect` runs before `take(want_rows)`, and `take` stops pulling once its
                // budget is spent. So the counters cover exactly the rows sent. Untaken rows
                // stay in `batch_rows`, which is only borrowed here.
                let messages = (&mut batch_rows)
                    .map(|row| {
                        let row_len = row.byte_len();
                        let values = mz_pgrepr::values_from_row(row, row_desc.typ());
                        (row_len, BackendMessage::DataRow(values))
                    })
                    .inspect(|(row_len, _)| {
                        sent_bytes += row_len;
                        sent_rows += 1
                    })
                    .map(|(_row_len, row)| row)
                    .take(want_rows);
                self.send_all(messages).await?;

                total_sent_rows += sent_rows;
                total_sent_bytes += sent_bytes;
                want_rows -= sent_rows;

                if want_rows == 0 {
                    // Keep any unsent rows of this batch for the next Execute.
                    if batch_rows.peek().is_some() {
                        rows.current = Some(batch_rows);
                    }
                    break;
                }

                self.conn.flush().await?;
            }
            FetchResult::Notice(notice) => {
                self.send(notice.into_response()).await?;
                self.conn.flush().await?;
            }
            FetchResult::Error(text) => {
                return self
                    .send_error_and_get_state(ErrorResponse::error(
                        SqlState::INTERNAL_ERROR,
                        text.clone(),
                    ))
                    .await
                    .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
            }
            FetchResult::Canceled => {
                return self
                    .send_error_and_get_state(ErrorResponse::error(
                        SqlState::QUERY_CANCELED,
                        "canceling statement due to user request",
                    ))
                    .await
                    .map(|state| (state, SendRowsEndedReason::Canceled));
            }
        }
    }

    let portal = self
        .adapter_client
        .session()
        .get_portal_unverified_mut(&portal_name)
        .expect("valid portal name for send rows");

    // Snapshot metric bookkeeping before `rows` is moved into the portal.
    let saw_rows = rows.remaining.saw_rows;
    let no_more_rows = rows.no_more_rows();
    let metric_recorded = rows.remaining.metric_recorded;
    let recorded_first_row_instant = rows.remaining.recorded_first_row_instant;

    // Mark the metric as recorded up front, so later resumptions do not observe it again.
    if no_more_rows && !metric_recorded {
        rows.remaining.metric_recorded = true;
    }

    // Always store the rows back, even if exhausted, so the portal stays `InProgress`.
    *portal.state = PortalState::InProgress(Some(rows));

    let fetch_portal = fetch_portal_name.map(|name| {
        self.adapter_client
            .session()
            .get_portal_unverified_mut(&name)
            .expect("valid fetch portal")
    });
    let response_message = get_response(max_rows, total_sent_rows, fetch_portal);
    self.send(response_message).await?;

    if no_more_rows && !metric_recorded {
        let statement_type = if let Some(stmt) = &self
            .adapter_client
            .session()
            .get_portal_unverified(&portal_name)
            .expect("valid portal name for send_rows")
            .stmt
        {
            metrics::statement_type_label_value(stmt.deref())
        } else {
            "no-statement"
        };
        let duration = if saw_rows {
            recorded_first_row_instant
                .expect("recorded_first_row_instant because saw_rows")
                .elapsed()
        } else {
            Duration::ZERO
        };
        self.adapter_client
            .inner()
            .metrics()
            .result_rows_first_to_last_byte_seconds
            .with_label_values(&[statement_type])
            .observe(duration.as_secs_f64());
    }

    Ok((
        State::Ready,
        SendRowsEndedReason::Success {
            result_size: u64::cast_from(total_sent_bytes),
            rows_returned: u64::cast_from(total_sent_rows),
        },
    ))
}
2587
2588 #[mz_ore::instrument(level = "debug")]
2589 async fn copy_rows(
2590 &mut self,
2591 format: CopyFormat,
2592 row_desc: RelationDesc,
2593 mut stream: RecordFirstRowStream,
2594 ) -> Result<(State, SendRowsEndedReason), io::Error> {
2595 let (row_format, encode_format) = match format {
2596 CopyFormat::Text => (
2597 CopyFormatParams::Text(CopyTextFormatParams::default()),
2598 Format::Text,
2599 ),
2600 CopyFormat::Binary => (CopyFormatParams::Binary, Format::Binary),
2601 CopyFormat::Csv => (
2602 CopyFormatParams::Csv(CopyCsvFormatParams::default()),
2603 Format::Text,
2604 ),
2605 CopyFormat::Parquet => {
2606 let text = "Parquet format is not supported".to_string();
2607 return self
2608 .send_error_and_get_state(ErrorResponse::error(
2609 SqlState::INTERNAL_ERROR,
2610 text.clone(),
2611 ))
2612 .await
2613 .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
2614 }
2615 };
2616
2617 let encode_fn = |row: &RowRef, typ: &SqlRelationType, out: &mut Vec<u8>| {
2618 mz_pgcopy::encode_copy_format(&row_format, row, typ, out)
2619 };
2620
2621 let typ = row_desc.typ();
2622 let column_formats = iter::repeat(encode_format)
2623 .take(typ.column_types.len())
2624 .collect();
2625 self.send(BackendMessage::CopyOutResponse {
2626 overall_format: encode_format,
2627 column_formats,
2628 })
2629 .await?;
2630
2631 let mut out = Vec::new();
2636
2637 if let CopyFormat::Binary = format {
2638 out.extend(b"PGCOPY\n\xFF\r\n\0");
2640 out.extend([0, 0, 0, 0]);
2642 out.extend([0, 0, 0, 0]);
2644 }
2645
2646 let mut count = 0;
2647 let mut total_sent_bytes = 0;
2648 loop {
2649 tokio::select! {
2650 e = self.conn.wait_closed() => return Err(e),
2651 batch = stream.recv() => match batch {
2652 None => break,
2653 Some(PeekResponseUnary::Error(text)) => {
2654 let err =
2655 ErrorResponse::error(SqlState::INTERNAL_ERROR, text.clone());
2656 return self
2657 .send_error_and_get_state(err)
2658 .await
2659 .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
2660 }
2661 Some(PeekResponseUnary::DependencyDropped(dep)) => {
2662 let err = dep.to_concurrent_dependency_drop();
2663 let text = err.to_string();
2664 let resp = err.into_response(Severity::Error);
2665 return self
2666 .send_error_and_get_state(resp)
2667 .await
2668 .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
2669 }
2670 Some(PeekResponseUnary::Canceled) => {
2671 return self.send_error_and_get_state(ErrorResponse::error(
2672 SqlState::QUERY_CANCELED,
2673 "canceling statement due to user request",
2674 ))
2675 .await.map(|state| (state, SendRowsEndedReason::Canceled));
2676 }
2677 Some(PeekResponseUnary::Rows(mut rows)) => {
2678 count += rows.count();
2679 while let Some(row) = rows.next() {
2680 total_sent_bytes += row.byte_len();
2681 encode_fn(row, typ, &mut out)?;
2682 self.send(BackendMessage::CopyData(mem::take(&mut out)))
2683 .await?;
2684 }
2685 }
2686 },
2687 notice = self.adapter_client.session().recv_notice() => {
2688 self.send(notice.into_response())
2689 .await?;
2690 self.conn.flush().await?;
2691 }
2692 }
2693
2694 self.conn.flush().await?;
2695 }
2696 if let CopyFormat::Binary = format {
2698 let trailer: i16 = -1;
2699 out.extend(trailer.to_be_bytes());
2700 self.send(BackendMessage::CopyData(mem::take(&mut out)))
2701 .await?;
2702 }
2703
2704 let tag = format!("COPY {}", count);
2705 self.send(BackendMessage::CopyDone).await?;
2706 self.send(BackendMessage::CommandComplete { tag }).await?;
2707 Ok((
2708 State::Ready,
2709 SendRowsEndedReason::Success {
2710 result_size: u64::cast_from(total_sent_bytes),
2711 rows_returned: u64::cast_from(count),
2712 },
2713 ))
2714 }
2715
2716 #[instrument(level = "debug")]
2719 async fn copy_from(
2720 &mut self,
2721 target_id: CatalogItemId,
2722 target_name: String,
2723 columns: Vec<ColumnIndex>,
2724 params: CopyFormatParams<'static>,
2725 row_desc: RelationDesc,
2726 mut ctx_extra: ExecuteContextGuard,
2727 ) -> Result<State, io::Error> {
2728 let res = self
2729 .copy_from_inner(
2730 target_id,
2731 target_name,
2732 columns,
2733 params,
2734 row_desc,
2735 &mut ctx_extra,
2736 )
2737 .await;
2738 match &res {
2739 Ok(State::Ready) => {
2740 self.adapter_client.retire_execute(
2741 ctx_extra,
2742 StatementEndedExecutionReason::Success {
2743 result_size: None,
2744 rows_returned: None,
2745 execution_strategy: None,
2746 },
2747 );
2748 }
2749 Ok(State::Done) => {
2750 self.adapter_client
2754 .retire_execute(ctx_extra, StatementEndedExecutionReason::Canceled);
2755 }
2756 Err(e) => {
2757 self.adapter_client.retire_execute(
2758 ctx_extra,
2759 StatementEndedExecutionReason::Errored {
2760 error: format!("{e}"),
2761 },
2762 );
2763 }
2764 Ok(State::Drain) => {}
2765 }
2766 res
2767 }
2768
2769 async fn copy_from_inner(
2770 &mut self,
2771 target_id: CatalogItemId,
2772 target_name: String,
2773 columns: Vec<ColumnIndex>,
2774 params: CopyFormatParams<'static>,
2775 row_desc: RelationDesc,
2776 ctx_extra: &mut ExecuteContextGuard,
2777 ) -> Result<State, io::Error> {
2778 let typ = row_desc.typ();
2779 let column_formats = vec![Format::Text; typ.column_types.len()];
2780 self.send(BackendMessage::CopyInResponse {
2781 overall_format: Format::Text,
2782 column_formats,
2783 })
2784 .await?;
2785 self.conn.flush().await?;
2786
2787 let writer = match self
2789 .adapter_client
2790 .start_copy_from_stdin(
2791 target_id,
2792 target_name.clone(),
2793 columns.clone(),
2794 row_desc.clone(),
2795 params.clone(),
2796 )
2797 .await
2798 {
2799 Ok(writer) => writer,
2800 Err(e) => {
2801 loop {
2807 match self.conn.recv().await? {
2808 Some(FrontendMessage::CopyData(_)) => {}
2809 Some(FrontendMessage::CopyDone) | Some(FrontendMessage::CopyFail(_)) => {
2810 break;
2811 }
2812 Some(FrontendMessage::Flush) | Some(FrontendMessage::Sync) => {}
2813 Some(_) => break,
2814 None => return Ok(State::Done),
2815 }
2816 }
2817 self.adapter_client.retire_execute(
2818 std::mem::take(ctx_extra),
2819 StatementEndedExecutionReason::Errored {
2820 error: e.to_string(),
2821 },
2822 );
2823 return self
2824 .send_error_and_get_state(e.into_response(Severity::Error))
2825 .await;
2826 }
2827 };
2828
2829 self.conn.set_copy_mode(true);
2831
2832 const BATCH_SIZE: usize = 32 * 1024 * 1024;
2834 let max_copy_from_row_size = self
2835 .adapter_client
2836 .get_system_vars()
2837 .await
2838 .max_copy_from_row_size()
2839 .try_into()
2840 .unwrap_or(usize::MAX);
2841
2842 let mut data = Vec::new();
2843 let mut row_scanner = CopyRowScanner::new(¶ms);
2844 let num_workers = writer.batch_txs.len();
2845 let mut next_worker: usize = 0;
2846 let mut saw_copy_done = false;
2847 let mut saw_end_marker = false;
2848 let mut copy_from_error: Option<(SqlState, String)> = None;
2849
2850 loop {
2853 let message = self.conn.recv().await?;
2854 match message {
2855 Some(FrontendMessage::CopyData(buf)) => {
2856 if saw_end_marker {
2857 continue;
2860 }
2861 data.extend(buf);
2862 row_scanner.scan_new_bytes(&data);
2863
2864 if let Some(end_pos) = row_scanner.end_marker_end() {
2865 data.truncate(end_pos);
2866 row_scanner.on_truncate(end_pos);
2867 saw_end_marker = true;
2868 }
2869
2870 if row_scanner.current_row_size(data.len()) > max_copy_from_row_size {
2872 copy_from_error = Some((
2873 SqlState::INSUFFICIENT_RESOURCES,
2874 format!(
2875 "COPY FROM STDIN row exceeded max_copy_from_row_size \
2876 ({max_copy_from_row_size} bytes)"
2877 ),
2878 ));
2879 break;
2880 }
2881
2882 let mut send_failed = false;
2885 while data.len() >= BATCH_SIZE {
2886 let split_pos = match row_scanner.last_row_end() {
2887 Some(pos) => pos,
2888 None => break, };
2890 let remainder = data.split_off(split_pos);
2891 let chunk = std::mem::replace(&mut data, remainder);
2892 row_scanner.on_split(split_pos);
2893 if writer.batch_txs[next_worker].send(chunk).await.is_err() {
2894 send_failed = true;
2895 break;
2896 }
2897 next_worker = (next_worker + 1) % num_workers;
2898 }
2899 if send_failed {
2902 break;
2903 }
2904 }
2905 Some(FrontendMessage::CopyDone) => {
2906 if !data.is_empty() {
2908 let chunk = std::mem::take(&mut data);
2909 let _ = writer.batch_txs[next_worker].send(chunk).await;
2911 }
2912 saw_copy_done = true;
2913 break;
2914 }
2915 Some(FrontendMessage::CopyFail(err)) => {
2916 self.adapter_client.retire_execute(
2917 std::mem::take(ctx_extra),
2918 StatementEndedExecutionReason::Canceled,
2919 );
2920 drop(writer);
2922 self.conn.set_copy_mode(false);
2923 return self
2924 .send_error_and_get_state(ErrorResponse::error(
2925 SqlState::QUERY_CANCELED,
2926 format!("COPY from stdin failed: {}", err),
2927 ))
2928 .await;
2929 }
2930 Some(FrontendMessage::Flush) | Some(FrontendMessage::Sync) => {}
2931 Some(_) => {
2932 let msg = "unexpected message type during COPY from stdin";
2933 self.adapter_client.retire_execute(
2934 std::mem::take(ctx_extra),
2935 StatementEndedExecutionReason::Errored {
2936 error: msg.to_string(),
2937 },
2938 );
2939 drop(writer);
2940 self.conn.set_copy_mode(false);
2941 return self
2942 .send_error_and_get_state(ErrorResponse::error(
2943 SqlState::PROTOCOL_VIOLATION,
2944 msg,
2945 ))
2946 .await;
2947 }
2948 None => {
2949 drop(writer);
2950 self.conn.set_copy_mode(false);
2951 return Ok(State::Done);
2952 }
2953 }
2954 }
2955
2956 if !saw_copy_done {
2960 loop {
2961 match self.conn.recv().await? {
2962 Some(FrontendMessage::CopyData(_)) => {}
2963 Some(FrontendMessage::CopyDone) | Some(FrontendMessage::CopyFail(_)) => {
2964 break;
2965 }
2966 Some(FrontendMessage::Flush) | Some(FrontendMessage::Sync) => {}
2967 Some(_) => {
2968 let msg = "unexpected message type during COPY from stdin";
2969 self.adapter_client.retire_execute(
2970 std::mem::take(ctx_extra),
2971 StatementEndedExecutionReason::Errored {
2972 error: msg.to_string(),
2973 },
2974 );
2975 drop(writer);
2976 self.conn.set_copy_mode(false);
2977 return self
2978 .send_error_and_get_state(ErrorResponse::error(
2979 SqlState::PROTOCOL_VIOLATION,
2980 msg,
2981 ))
2982 .await;
2983 }
2984 None => {
2985 drop(writer);
2986 self.conn.set_copy_mode(false);
2987 return Ok(State::Done);
2988 }
2989 }
2990 }
2991 }
2992
2993 if let Some((code, msg)) = copy_from_error {
2994 self.adapter_client.retire_execute(
2995 std::mem::take(ctx_extra),
2996 StatementEndedExecutionReason::Errored { error: msg.clone() },
2997 );
2998 drop(writer);
2999 self.conn.set_copy_mode(false);
3000 return self
3001 .send_error_and_get_state(ErrorResponse::error(code, msg))
3002 .await;
3003 }
3004
3005 self.conn.set_copy_mode(false);
3006
3007 drop(writer.batch_txs);
3012
3013 let (proto_batches, row_count) = match writer.completion_rx.await {
3015 Ok(Ok(result)) => result,
3016 Ok(Err(e)) => {
3017 self.adapter_client.retire_execute(
3018 std::mem::take(ctx_extra),
3019 StatementEndedExecutionReason::Errored {
3020 error: e.to_string(),
3021 },
3022 );
3023 return self
3024 .send_error_and_get_state(e.into_response(Severity::Error))
3025 .await;
3026 }
3027 Err(_) => {
3028 let msg = "COPY FROM STDIN: background batch builder tasks dropped";
3029 self.adapter_client.retire_execute(
3030 std::mem::take(ctx_extra),
3031 StatementEndedExecutionReason::Errored {
3032 error: msg.to_string(),
3033 },
3034 );
3035 return self
3036 .send_error_and_get_state(ErrorResponse::error(SqlState::INTERNAL_ERROR, msg))
3037 .await;
3038 }
3039 };
3040
3041 if let Err(e) = self
3043 .adapter_client
3044 .stage_copy_from_stdin_batches(target_id, proto_batches)
3045 {
3046 self.adapter_client.retire_execute(
3047 std::mem::take(ctx_extra),
3048 StatementEndedExecutionReason::Errored {
3049 error: e.to_string(),
3050 },
3051 );
3052 return self
3053 .send_error_and_get_state(e.into_response(Severity::Error))
3054 .await;
3055 }
3056
3057 let tag = format!("COPY {}", row_count);
3058 self.send(BackendMessage::CommandComplete { tag }).await?;
3059
3060 Ok(State::Ready)
3061 }
3062
3063 #[instrument(level = "debug")]
3064 async fn send_pending_notices(&mut self) -> Result<(), io::Error> {
3065 let notices = self
3066 .adapter_client
3067 .session()
3068 .drain_notices()
3069 .into_iter()
3070 .map(|notice| BackendMessage::ErrorResponse(notice.into_response()));
3071 self.send_all(notices).await?;
3072 Ok(())
3073 }
3074
3075 #[instrument(level = "debug")]
3076 async fn send_error_and_get_state(&mut self, err: ErrorResponse) -> Result<State, io::Error> {
3077 assert!(err.severity.is_error());
3078 debug!(
3079 "cid={} error code={}",
3080 self.adapter_client.session().conn_id(),
3081 err.code.code()
3082 );
3083 let is_fatal = err.severity.is_fatal();
3084 self.send(BackendMessage::ErrorResponse(err)).await?;
3085
3086 let txn = self.adapter_client.session().transaction();
3087 match txn {
3088 TransactionStatus::Default | TransactionStatus::Failed(_) => {}
3091 TransactionStatus::Started(_) => {
3093 self.rollback_transaction().await?;
3094 }
3095 TransactionStatus::InTransactionImplicit(_) => {
3097 self.rollback_transaction().await?;
3098 }
3099 TransactionStatus::InTransaction(_) => {
3101 self.adapter_client.fail_transaction();
3102 }
3103 };
3104 if is_fatal {
3105 Ok(State::Done)
3106 } else {
3107 Ok(State::Drain)
3108 }
3109 }
3110
3111 #[instrument(level = "debug")]
3112 async fn aborted_txn_error(&mut self) -> Result<State, io::Error> {
3113 self.send(BackendMessage::ErrorResponse(ErrorResponse::error(
3114 SqlState::IN_FAILED_SQL_TRANSACTION,
3115 ABORTED_TXN_MSG,
3116 )))
3117 .await?;
3118 Ok(State::Drain)
3119 }
3120
3121 fn is_aborted_txn(&mut self) -> bool {
3122 matches!(
3123 self.adapter_client.session().transaction(),
3124 TransactionStatus::Failed(_)
3125 )
3126 }
3127}
3128
3129fn pad_formats(formats: Vec<Format>, n: usize) -> Result<Vec<Format>, String> {
3130 match (formats.len(), n) {
3131 (0, e) => Ok(vec![Format::Text; e]),
3132 (1, e) => Ok(iter::repeat(formats[0]).take(e).collect()),
3133 (a, e) if a == e => Ok(formats),
3134 (a, e) => Err(format!(
3135 "expected {} field format specifiers, but got {}",
3136 e, a
3137 )),
3138 }
3139}
3140
3141fn describe_rows(stmt_desc: &StatementDesc, formats: &[Format]) -> BackendMessage {
3142 match &stmt_desc.relation_desc {
3143 Some(desc) if !stmt_desc.is_copy => {
3144 BackendMessage::RowDescription(message::encode_row_description(desc, formats))
3145 }
3146 _ => BackendMessage::NoData,
3147 }
3148}
3149
/// Builds the message sent after a batch of rows has been written.
///
/// It receives the execution's row limit, the total number of rows sent so
/// far, and (for `FETCH`) the portal being fetched from. `portal_exec_message`
/// and `fetch_message` both have this signature.
type GetResponse = fn(
    max_rows: ExecuteCount,
    total_sent_rows: usize,
    fetch_portal: Option<PortalRefMut>,
) -> BackendMessage;
3155
3156fn portal_exec_message(
3159 max_rows: ExecuteCount,
3160 total_sent_rows: usize,
3161 _fetch_portal: Option<PortalRefMut>,
3162) -> BackendMessage {
3163 match max_rows {
3170 ExecuteCount::Count(max_rows) if max_rows <= total_sent_rows => {
3171 BackendMessage::PortalSuspended
3172 }
3173 _ => BackendMessage::CommandComplete {
3174 tag: format!("SELECT {}", total_sent_rows),
3175 },
3176 }
3177}
3178
3179fn fetch_message(
3181 _max_rows: ExecuteCount,
3182 total_sent_rows: usize,
3183 fetch_portal: Option<PortalRefMut>,
3184) -> BackendMessage {
3185 let tag = format!("FETCH {}", total_sent_rows);
3186 if let Some(portal) = fetch_portal {
3187 *portal.state = PortalState::Completed(Some(tag.clone()));
3188 }
3189 BackendMessage::CommandComplete { tag }
3190}
3191
3192fn get_authenticator(
3193 authenticator_kind: listeners::AuthenticatorKind,
3194 frontegg: Option<FronteggAuthenticator>,
3195 oidc: GenericOidcAuthenticator,
3196 adapter_client: mz_adapter::Client,
3197) -> Authenticator {
3198 match authenticator_kind {
3199 listeners::AuthenticatorKind::Frontegg => Authenticator::Frontegg(frontegg.expect(
3200 "Frontegg authenticator should exist with listeners::AuthenticatorKind::Frontegg",
3201 )),
3202 listeners::AuthenticatorKind::Password => Authenticator::Password(adapter_client),
3203 listeners::AuthenticatorKind::Sasl => Authenticator::Sasl(adapter_client),
3204 listeners::AuthenticatorKind::Oidc => Authenticator::Oidc(oidc),
3205 listeners::AuthenticatorKind::None => Authenticator::None,
3206 }
3207}
3208
/// Row limit for an execution: `All` imposes no limit, and `Count(n)` caps
/// the rows sent before the portal is reported as suspended (see
/// `portal_exec_message`).
#[derive(Debug, Copy, Clone)]
enum ExecuteCount {
    /// No row limit.
    All,
    /// Send at most this many rows.
    Count(usize),
}
3214
3215fn is_txn_exit_stmt(stmt: Option<&Statement<Raw>>) -> bool {
3217 match stmt {
3218 Some(stmt) => matches!(stmt, Statement::Commit(_) | Statement::Rollback(_)),
3220 None => false,
3221 }
3222}
3223
/// One step of reading results for a portal. NOTE(review): the producer and
/// consumer live outside this view; the per-variant notes describe the
/// payloads only.
#[derive(Debug)]
enum FetchResult {
    /// A (possibly absent) iterator over result rows.
    Rows(Option<Box<dyn RowIterator + Send + Sync>>),
    /// The request was canceled.
    Canceled,
    /// The request failed with this message.
    Error(String),
    /// A notice to forward to the client.
    Notice(AdapterNotice),
}
3231
/// Incrementally finds row boundaries and the `\.` end-of-data marker in the
/// byte buffer accumulated during `COPY ... FROM STDIN`.
///
/// All offsets are positions in that buffer. They are rebased by `on_split`
/// and `on_truncate` when the caller removes bytes from it.
#[derive(Debug)]
struct CopyRowScanner {
    /// Offset up to which the buffer has already been scanned.
    scan_pos: usize,
    /// Offset just past the last complete row seen, if any.
    last_row_end: Option<usize>,
    /// Offset just past the end-of-data marker row, once found.
    end_marker_end: Option<usize>,
    /// Offset where the record currently being parsed starts (CSV only).
    record_start: usize,
    /// CSV tokenizer state; `None` for non-CSV formats, which are split on `\n`.
    csv: Option<CsvScanState>,
}
3244
/// `csv_core` state used by `CopyRowScanner` to find CSV record boundaries,
/// so that newlines inside quoted fields are not mistaken for row ends.
#[derive(Debug)]
struct CsvScanState {
    /// Streaming CSV reader configured from the COPY CSV options.
    reader: csv_core::Reader,
    /// Scratch buffer for decoded field bytes; its contents are never read.
    output: Vec<u8>,
    /// Scratch buffer for field end offsets; its contents are never read.
    ends: Vec<usize>,
    /// True until the first record (the header row, when `HEADER` is set)
    /// has been seen; that record is never treated as the end marker.
    skip_first_record: bool,
}
3252
impl CopyRowScanner {
    /// Creates a scanner for the given COPY format.
    ///
    /// CSV input is tokenized with `csv_core`, configured from the delimiter,
    /// quote, escape and header options. Every other format is split on `\n`.
    fn new(params: &CopyFormatParams<'_>) -> Self {
        let csv = match params {
            CopyFormatParams::Csv(CopyCsvFormatParams {
                delimiter,
                quote,
                escape,
                header,
                ..
            }) => Some(CsvScanState::new(*delimiter, *quote, *escape, *header)),
            _ => None,
        };

        CopyRowScanner {
            scan_pos: 0,
            last_row_end: None,
            end_marker_end: None,
            record_start: 0,
            csv,
        }
    }

    /// Scans the bytes appended to `data` since the previous call
    /// (`data[scan_pos..]`).
    ///
    /// Updates the last complete-row boundary and, the first time it is seen,
    /// records the end of the `\.` end-of-data marker row. `data` is expected
    /// to be the same growing buffer on every call, with `on_split` /
    /// `on_truncate` invoked whenever bytes are removed from it. Afterwards
    /// all of `data` counts as scanned.
    fn scan_new_bytes(&mut self, data: &[u8]) {
        if self.scan_pos >= data.len() {
            return;
        }

        if let Some(csv) = self.csv.as_mut() {
            let mut input = &data[self.scan_pos..];
            // Bytes of `data[scan_pos..]` consumed by the reader so far.
            let mut consumed = 0usize;
            while !input.is_empty() {
                // Decoded fields and field ends land in scratch buffers that
                // are never read; only the result kind and the number of input
                // bytes consumed matter here.
                let (result, n_input, _n_output, _n_ends) =
                    csv.reader
                        .read_record(input, &mut csv.output, &mut csv.ends);
                consumed += n_input;
                input = &input[n_input..];

                match result {
                    // Input exhausted mid-record; wait for more bytes.
                    ReadRecordResult::InputEmpty => break,
                    // The scratch buffers are passed whole (overwritten from the
                    // start) on every call, so they are only grown when the reader
                    // could not make any progress.
                    ReadRecordResult::OutputFull => {
                        if n_input == 0 {
                            csv.output
                                .resize(csv.output.len().saturating_mul(2).max(1), 0);
                        }
                    }
                    ReadRecordResult::OutputEndsFull => {
                        if n_input == 0 {
                            csv.ends.resize(csv.ends.len().saturating_mul(2).max(1), 0);
                        }
                    }
                    ReadRecordResult::Record | ReadRecordResult::End => {
                        let row_end = self.scan_pos + consumed;
                        self.last_row_end = Some(row_end);
                        if self.end_marker_end.is_none() {
                            // The first record (the header, when enabled) is
                            // never the marker.
                            let is_marker = if csv.skip_first_record {
                                csv.skip_first_record = false;
                                false
                            } else {
                                // Compare the record's raw bytes, minus CR/LF
                                // line-ending bytes at either end, with `\.`.
                                // Raw bytes keep any quotes, so a quoted `"\."`
                                // does not match.
                                let raw = &data[self.record_start..row_end];
                                let start = raw
                                    .iter()
                                    .take_while(|&&b| b == b'\r' || b == b'\n')
                                    .count();
                                let trailing = raw[start..]
                                    .iter()
                                    .rev()
                                    .take_while(|&&b| b == b'\r' || b == b'\n')
                                    .count();
                                let trimmed = &raw[start..raw.len() - trailing];
                                trimmed == b"\\."
                            };
                            if is_marker {
                                self.end_marker_end = Some(row_end);
                                self.record_start = row_end;
                                break;
                            }
                        }
                        self.record_start = row_end;
                    }
                }
            }
        } else {
            // Non-CSV formats: rows end at `\n`, and a row whose first two
            // bytes are `\.` is taken as the end marker.
            let mut row_start = self.last_row_end.unwrap_or(0);
            for (offset, b) in data[self.scan_pos..].iter().enumerate() {
                if *b == b'\n' {
                    let row_end = self.scan_pos + offset + 1;
                    self.last_row_end = Some(row_end);
                    if self.end_marker_end.is_none() {
                        let row = &data[row_start..row_end];
                        if row.get(0..2) == Some(b"\\.") {
                            self.end_marker_end = Some(row_end);
                            break;
                        }
                    }
                    row_start = row_end;
                }
            }
        }

        self.scan_pos = data.len();
    }

    /// Offset just past the last complete row seen, if any.
    fn last_row_end(&self) -> Option<usize> {
        self.last_row_end
    }

    /// Offset just past the end-of-data marker row, once found.
    fn end_marker_end(&self) -> Option<usize> {
        self.end_marker_end
    }

    /// Number of bytes buffered after the last complete row, i.e. the size of
    /// the row still in progress, for a buffer of `data_len` bytes.
    fn current_row_size(&self, data_len: usize) -> usize {
        data_len.saturating_sub(self.last_row_end.unwrap_or(0))
    }

    /// Rebases offsets after the first `split_pos` bytes were removed from the
    /// front of the buffer. The caller is expected to split at
    /// `last_row_end()`, so no complete row remains and that boundary is
    /// cleared.
    fn on_split(&mut self, split_pos: usize) {
        self.scan_pos = self.scan_pos.saturating_sub(split_pos);
        self.last_row_end = None;
        self.end_marker_end = self
            .end_marker_end
            .and_then(|end| end.checked_sub(split_pos));
        soft_assert_or_log!(
            self.csv.is_none() || self.record_start >= split_pos,
            "split bisected an in-progress CSV record: record_start={} < split_pos={}",
            self.record_start,
            split_pos,
        );
        self.record_start = self.record_start.saturating_sub(split_pos);
    }

    /// Clamps offsets after the buffer was truncated to `new_len` bytes.
    fn on_truncate(&mut self, new_len: usize) {
        self.scan_pos = self.scan_pos.min(new_len);
        self.last_row_end = self.last_row_end.filter(|&end| end <= new_len);
        self.end_marker_end = self.end_marker_end.filter(|&end| end <= new_len);
        self.record_start = self.record_start.min(new_len);
    }
}
3411
3412impl CsvScanState {
3413 fn new(delimiter: u8, quote: u8, escape: u8, header: bool) -> Self {
3414 let (double_quote, escape) = if quote == escape {
3415 (true, None)
3416 } else {
3417 (false, Some(escape))
3418 };
3419 CsvScanState {
3420 reader: csv_core::ReaderBuilder::new()
3421 .delimiter(delimiter)
3422 .quote(quote)
3423 .double_quote(double_quote)
3424 .escape(escape)
3425 .build(),
3426 output: vec![0; 1],
3427 ends: vec![0; 1],
3428 skip_first_record: header,
3429 }
3430 }
3431}
3432
#[cfg(test)]
mod test {
    use super::*;

    #[mz_ore::test]
    fn test_copy_row_scanner_end_marker_line_endings() {
        let params = CopyFormatParams::Csv(CopyCsvFormatParams::default());

        let marker_end = |data: &[u8]| -> Option<usize> {
            let mut scanner = CopyRowScanner::new(&params);
            scanner.scan_new_bytes(data);
            scanner.end_marker_end()
        };

        for eol in [&b"\n"[..], b"\r\n", b"\r"] {
            let join = |lines: &[&str]| -> Vec<u8> {
                let mut out = Vec::new();
                for line in lines {
                    out.extend_from_slice(line.as_bytes());
                    out.extend_from_slice(eol);
                }
                out
            };

            let data = join(&["first", "\\.", "after"]);
            // The marker row ends one byte past `\.` (the first byte of its
            // line ending).
            let mut prefix = Vec::new();
            prefix.extend_from_slice(b"first");
            prefix.extend_from_slice(eol);
            prefix.extend_from_slice(b"\\.");
            assert_eq!(
                marker_end(&data),
                Some(prefix.len() + 1),
                "bare marker, eol={eol:?}"
            );

            let data = join(&["before", "\"\\.\"", "after"]);
            // A quoted `\.` is data, not the end marker.
            assert_eq!(marker_end(&data), None, "quoted marker, eol={eol:?}");
        }
    }

    #[mz_ore::test]
    fn test_copy_row_scanner_non_csv_split() {
        for params in [
            CopyFormatParams::Text(CopyTextFormatParams::default()),
            CopyFormatParams::Binary,
        ] {
            let mut scanner = CopyRowScanner::new(&params);
            let data = b"1\thello world\t2\tsome text value here\n\
                3\thello world\t6\tsome text value here\n";
            scanner.scan_new_bytes(data);
            let split_pos = scanner.last_row_end().expect("a complete row");
            assert!(split_pos > 0, "params={params:?}");
            scanner.on_split(split_pos);
            assert_eq!(scanner.record_start, 0, "params={params:?}");
        }
    }

    #[mz_ore::test]
    fn test_parse_options() {
        struct TestCase {
            input: &'static str,
            expect: Result<Vec<(&'static str, &'static str)>, ()>,
        }
        let tests = vec![
            TestCase {
                input: "",
                expect: Ok(vec![]),
            },
            TestCase {
                input: "--key",
                expect: Err(()),
            },
            TestCase {
                input: "--key=val",
                expect: Ok(vec![("key", "val")]),
            },
            TestCase {
                input: r#"--key=val -ckey2=val2 -c key3=val3 -c key4=val4 -ckey5=val5"#,
                expect: Ok(vec![
                    ("key", "val"),
                    ("key2", "val2"),
                    ("key3", "val3"),
                    ("key4", "val4"),
                    ("key5", "val5"),
                ]),
            },
            TestCase {
                input: r#"-c\ key=val"#,
                expect: Ok(vec![(" key", "val")]),
            },
            TestCase {
                input: "--key=val -ckey2 val2",
                expect: Err(()),
            },
            TestCase {
                input: "--key=",
                expect: Ok(vec![("key", "")]),
            },
        ];
        for test in tests {
            let got = parse_options(test.input);
            let expect = test.expect.map(|r| {
                r.into_iter()
                    .map(|(k, v)| (k.to_owned(), v.to_owned()))
                    .collect()
            });
            assert_eq!(got, expect, "input: {}", test.input);
        }
    }

    #[mz_ore::test]
    fn test_parse_option() {
        struct TestCase {
            input: &'static str,
            expect: Result<(&'static str, &'static str), ()>,
        }
        let tests = vec![
            TestCase {
                input: "",
                expect: Err(()),
            },
            TestCase {
                input: "--",
                expect: Err(()),
            },
            TestCase {
                input: "--c",
                expect: Err(()),
            },
            TestCase {
                input: "a=b",
                expect: Err(()),
            },
            TestCase {
                input: "--a=b",
                expect: Ok(("a", "b")),
            },
            TestCase {
                input: "--ca=b",
                expect: Ok(("ca", "b")),
            },
            TestCase {
                input: "-ca=b",
                expect: Ok(("a", "b")),
            },
            TestCase {
                input: "--=",
                expect: Ok(("", "")),
            },
        ];
        for test in tests {
            let got = parse_option(test.input);
            assert_eq!(got, test.expect, "input: {}", test.input);
        }
    }

    #[mz_ore::test]
    fn test_split_options() {
        struct TestCase {
            input: &'static str,
            expect: Vec<&'static str>,
        }
        let tests = vec![
            TestCase {
                input: "",
                expect: vec![],
            },
            TestCase {
                input: " ",
                expect: vec![],
            },
            TestCase {
                input: " a ",
                expect: vec!["a"],
            },
            TestCase {
                input: " ab cd ",
                expect: vec!["ab", "cd"],
            },
            // An escaped space followed by an unescaped one: the first is kept,
            // the second separates the tokens.
            TestCase {
                input: r#" ab\  cd "#,
                expect: vec!["ab ", "cd"],
            },
            TestCase {
                input: r#" ab\\ cd "#,
                expect: vec![r#"ab\"#, "cd"],
            },
            TestCase {
                input: r#" ab\\\  cd "#,
                expect: vec![r#"ab\ "#, "cd"],
            },
            TestCase {
                input: r#" ab\\\ cd "#,
                expect: vec![r#"ab\ cd"#],
            },
            TestCase {
                input: r#" ab\\\cd "#,
                expect: vec![r#"ab\cd"#],
            },
            TestCase {
                input: r#"a\"#,
                expect: vec!["a"],
            },
            TestCase {
                input: r#"a\ "#,
                expect: vec!["a "],
            },
            TestCase {
                input: r#"\"#,
                expect: vec![],
            },
            TestCase {
                input: r#"\ "#,
                expect: vec![r#" "#],
            },
            TestCase {
                input: r#" \ "#,
                expect: vec![r#" "#],
            },
            TestCase {
                input: r#"\ "#,
                expect: vec![r#" "#],
            },
        ];
        for test in tests {
            let got = split_options(test.input);
            assert_eq!(got, test.expect, "input: {}", test.input);
        }
    }

    #[mz_ore::test]
    fn test_is_jwt() {
        assert!(is_jwt("eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxIn0.signature"));
        for s in [
            "",
            "secure_password",
            "p4ss.w0rd",
            "aaa.bbb.ccc",
            "eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxIn0",
            "eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxIn0.sig.extra",
        ] {
            assert!(!is_jwt(s), "is_jwt({s:?})");
        }
    }
}