1use std::collections::BTreeMap;
11use std::convert::TryFrom;
12use std::future::Future;
13use std::ops::Deref;
14use std::sync::Arc;
15use std::time::{Duration, Instant};
16use std::{iter, mem};
17
18use base64::prelude::*;
19use byteorder::{ByteOrder, NetworkEndian};
20use csv_core::ReadRecordResult;
21use futures::future::{BoxFuture, FutureExt, pending};
22use itertools::Itertools;
23use mz_adapter::client::RecordFirstRowStream;
24use mz_adapter::session::{
25 EndTransactionAction, InProgressRows, LifecycleTimestamps, PortalRefMut, PortalState, Session,
26 SessionConfig, TransactionStatus,
27};
28use mz_adapter::statement_logging::{StatementEndedExecutionReason, StatementExecutionStrategy};
29use mz_adapter::{
30 AdapterError, AdapterNotice, ExecuteContextGuard, ExecuteResponse, PeekResponseUnary, metrics,
31 verify_datum_desc,
32};
33use mz_auth::Authenticated;
34use mz_auth::password::Password;
35use mz_authenticator::{Authenticator, GenericOidcAuthenticator};
36use mz_frontegg_auth::Authenticator as FronteggAuthenticator;
37use mz_ore::cast::CastFrom;
38use mz_ore::netio::AsyncReady;
39use mz_ore::now::{EpochMillis, SYSTEM_TIME};
40use mz_ore::str::StrExt;
41use mz_ore::{assert_none, assert_ok, instrument, soft_assert_eq_or_log};
42use mz_pgcopy::{CopyCsvFormatParams, CopyFormatParams, CopyTextFormatParams};
43use mz_pgwire_common::{
44 ConnectionCounter, Cursor, ErrorResponse, Format, FrontendMessage, Severity, VERSION_3,
45 VERSIONS,
46};
47use mz_repr::{
48 CatalogItemId, ColumnIndex, Datum, RelationDesc, RowArena, RowIterator, RowRef,
49 SqlRelationType, SqlScalarType,
50};
51use mz_server_core::TlsMode;
52use mz_server_core::listeners::{AllowedRoles, AuthenticatorKind};
53use mz_sql::ast::display::AstDisplay;
54use mz_sql::ast::{
55 CopyDirection, CopyStatement, CopyTarget, FetchDirection, Ident, Raw, Statement,
56};
57use mz_sql::parse::StatementParseResult;
58use mz_sql::plan::{CopyFormat, ExecuteTimeout, StatementDesc};
59use mz_sql::session::metadata::SessionMetadata;
60use mz_sql::session::user::INTERNAL_USER_NAMES;
61use mz_sql::session::vars::VarInput;
62use postgres::error::SqlState;
63use tokio::io::{self, AsyncRead, AsyncWrite};
64use tokio::select;
65use tokio::time::{self};
66use tokio_metrics::TaskMetrics;
67use tokio_stream::wrappers::UnboundedReceiverStream;
68use tracing::{Instrument, debug, debug_span, warn};
69use uuid::Uuid;
70
71use crate::codec::{
72 FramedConn, decode_password, decode_sasl_initial_response, decode_sasl_response,
73};
74use crate::message::{
75 self, BackendMessage, SASLServerFinalMessage, SASLServerFinalMessageKinds,
76 SASLServerFirstMessage,
77};
78
79pub fn match_handshake(buf: &[u8]) -> bool {
83 if buf.len() < 8 {
93 return false;
94 }
95 let version = NetworkEndian::read_i32(&buf[4..8]);
96 VERSIONS.contains(&version)
97}
98
/// Inputs to [`run`] for serving a single pgwire connection.
pub struct RunParams<'a, A, I>
where
    I: Iterator<Item = TaskMetrics> + Send,
{
    /// TLS requirement for this listener. It is checked against the
    /// connection with `ensure_tls_compatibility` before authentication.
    pub tls_mode: Option<TlsMode>,
    /// Adapter client used to build the authenticator and to create and start
    /// the session.
    pub adapter_client: mz_adapter::Client,
    /// The framed client connection.
    pub conn: &'a mut FramedConn<A>,
    /// UUID recorded in the new session's `SessionConfig`.
    pub conn_uuid: Uuid,
    /// Protocol version requested by the client. Only `VERSION_3` is accepted.
    pub version: i32,
    /// Startup parameters sent by the client. `user` is taken as the login
    /// name and `options` is parsed as a list of `-c`/`--` settings. Every
    /// other entry is applied as a session variable.
    pub params: BTreeMap<String, String>,
    /// Frontegg authenticator, if one is configured.
    pub frontegg: Option<FronteggAuthenticator>,
    /// OIDC authenticator.
    pub oidc: GenericOidcAuthenticator,
    /// Selects which authenticator `get_authenticator` builds for this
    /// connection.
    pub authenticator_kind: AuthenticatorKind,
    /// Counter from which a connection slot is allocated for the session's
    /// user.
    pub active_connection_counter: ConnectionCounter,
    /// Helm chart version copied into the session configuration.
    pub helm_chart_version: Option<String>,
    /// Which roles may log in through this listener.
    pub allowed_roles: AllowedRoles,
    /// Source of tokio task-metric intervals. The state machine calls
    /// `expect("infinite iterator")` on it, so it must never run out.
    pub tokio_metrics_intervals: I,
}
132
/// Serves one pgwire client connection, from the startup message until the
/// connection ends.
///
/// Steps, in order:
/// 1. Reject protocol versions other than `VERSION_3`.
/// 2. Check the requested `user` against `allowed_roles`, and the connection
///    against `tls_mode`.
/// 3. Authenticate the client with the authenticator selected by
///    `authenticator_kind` and create a session. The options are Frontegg,
///    OIDC, cleartext password, SCRAM-SHA-256 SASL, or none.
/// 4. Apply the remaining startup parameters, including the settings packed
///    into `options`, as session variables. Failures are reported as notices.
/// 5. Allocate a connection slot, register the session with the adapter, and
///    send the startup messages, ending with `ReadyForQuery`.
/// 6. Run the [`StateMachine`] until it finishes, or until a Frontegg
///    authentication session expires.
///
/// Startup rejections are sent to the client as fatal `ErrorResponse`s, and
/// the result of that send is returned. Connection I/O errors are returned as
/// `Err`.
#[mz_ore::instrument(level = "debug")]
pub async fn run<'a, A, I>(
    RunParams {
        tls_mode,
        adapter_client,
        conn,
        conn_uuid,
        version,
        mut params,
        frontegg,
        oidc,
        authenticator_kind,
        active_connection_counter,
        helm_chart_version,
        allowed_roles,
        tokio_metrics_intervals,
    }: RunParams<'a, A, I>,
) -> Result<(), io::Error>
where
    A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
    I: Iterator<Item = TaskMetrics> + Send,
{
    // Only protocol version 3 is supported.
    if version != VERSION_3 {
        return conn
            .send(ErrorResponse::fatal(
                SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
                "server does not support the client's requested protocol version",
            ))
            .await;
    }

    // `user` is removed from `params`. `options` stays in place and is
    // revisited by the startup-parameter loop further down.
    let user = params.remove("user").unwrap_or_else(String::new);
    let options = parse_options(params.get("options").unwrap_or(&String::new()));

    // `oidc_auth_enabled` feeds authenticator selection. It is stripped from
    // the parsed options so the loop below does not apply it as a session
    // variable.
    let (oidc_auth_enabled, options) = extract_oidc_auth_enabled_from_options(options);
    let authenticator = get_authenticator(
        authenticator_kind,
        frontegg,
        oidc,
        adapter_client.clone(),
        oidc_auth_enabled,
    );
    let is_internal_user = INTERNAL_USER_NAMES.contains(&user);
    let is_reserved_user = mz_adapter::catalog::is_reserved_role_name(user.as_str());
    // Decide whether this listener accepts the requested role.
    // NOTE(review): the `NormalAndInternal` arm suggests internal user names
    // are a subset of reserved role names. Confirm.
    let role_allowed = match allowed_roles {
        AllowedRoles::Normal => !is_reserved_user,
        AllowedRoles::Internal => is_internal_user,
        AllowedRoles::NormalAndInternal => !is_reserved_user || is_internal_user,
    };
    if !role_allowed {
        let msg = format!("unauthorized login to user '{user}'");
        return conn
            .send(ErrorResponse::fatal(SqlState::INSUFFICIENT_PRIVILEGE, msg))
            .await;
    }

    if let Err(err) = conn.inner().ensure_tls_compatibility(&tls_mode) {
        return conn.send(err).await;
    }

    // Authenticate and build the session. Each arm also returns an `expired`
    // future, which is raced against the state machine at the end of this
    // function. Only the Frontegg arm returns a future that can complete. The
    // others use `pending()`, which never resolves.
    let (mut session, expired) = match authenticator {
        Authenticator::Frontegg(frontegg) => {
            // Frontegg checks the cleartext password supplied by the client.
            let password = match request_cleartext_password(conn).await {
                Ok(password) => password,
                Err(PasswordRequestError::IoError(e)) => return Err(e),
                Err(PasswordRequestError::InvalidPasswordError(e)) => {
                    return conn.send(e).await;
                }
            };

            let auth_response = frontegg.authenticate(&user, &password).await;
            match auth_response {
                Ok((mut auth_session, authenticated)) => {
                    // The session user and external metadata come from the
                    // Frontegg auth session.
                    let session = adapter_client.new_session(
                        SessionConfig {
                            conn_id: conn.conn_id().clone(),
                            uuid: conn_uuid,
                            user: auth_session.user().into(),
                            client_ip: conn.peer_addr().clone(),
                            external_metadata_rx: Some(auth_session.external_metadata_rx()),
                            helm_chart_version,
                        },
                        authenticated,
                    );
                    let expired = async move { auth_session.expired().await };
                    (session, expired.left_future())
                }
                Err(err) => {
                    warn!(?err, "pgwire connection failed authentication");
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_PASSWORD,
                            "invalid password",
                        ))
                        .await;
                }
            }
        }
        Authenticator::Oidc(oidc) => {
            // The client sends its JWT in the cleartext-password message.
            let jwt = match request_cleartext_password(conn).await {
                Ok(password) => password,
                Err(PasswordRequestError::IoError(e)) => return Err(e),
                Err(PasswordRequestError::InvalidPasswordError(e)) => {
                    return conn.send(e).await;
                }
            };
            let auth_response = oidc.authenticate(&jwt, Some(&user)).await;
            match auth_response {
                Ok((claims, authenticated)) => {
                    // The session user is taken from the token's claims.
                    let session = adapter_client.new_session(
                        SessionConfig {
                            conn_id: conn.conn_id().clone(),
                            uuid: conn_uuid,
                            user: claims.user,
                            client_ip: conn.peer_addr().clone(),
                            external_metadata_rx: None,
                            helm_chart_version,
                        },
                        authenticated,
                    );
                    (session, pending().right_future())
                }
                Err(err) => {
                    warn!(?err, "pgwire connection failed authentication");
                    return conn.send(err.into_response()).await;
                }
            }
        }
        Authenticator::Password(adapter_client) => {
            // The adapter checks the cleartext password and builds the session.
            let session = match authenticate_with_password(
                conn,
                &adapter_client,
                user,
                conn_uuid,
                helm_chart_version,
            )
            .await
            {
                Ok(session) => session,
                Err(PasswordRequestError::IoError(e)) => return Err(e),
                Err(PasswordRequestError::InvalidPasswordError(e)) => {
                    return conn.send(e).await;
                }
            };
            (session, pending().right_future())
        }
        Authenticator::Sasl(adapter_client) => {
            // SCRAM-SHA-256 exchange:
            //   AuthenticationSASL -> client-first -> server-first (SASLContinue)
            //   -> client-final -> server-final (SASLFinal with verifier).
            conn.send(BackendMessage::AuthenticationSASL).await?;
            conn.flush().await?;
            let (mechanism, initial_response) = match conn.recv().await? {
                Some(FrontendMessage::RawAuthentication(data)) => {
                    match decode_sasl_initial_response(Cursor::new(&data)).ok() {
                        Some(FrontendMessage::SASLInitialResponse {
                            gs2_header,
                            mechanism,
                            initial_response,
                        }) => {
                            // Clients that request channel binding are rejected.
                            if gs2_header.channel_binding_enabled() {
                                return conn
                                    .send(ErrorResponse::fatal(
                                        SqlState::PROTOCOL_VIOLATION,
                                        "channel binding not supported",
                                    ))
                                    .await;
                            }
                            (mechanism, initial_response)
                        }
                        _ => {
                            return conn
                                .send(ErrorResponse::fatal(
                                    SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                                    "expected SASLInitialResponse message",
                                ))
                                .await;
                        }
                    }
                }
                _ => {
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                            "expected SASLInitialResponse message",
                        ))
                        .await;
                }
            };

            if mechanism != "SCRAM-SHA-256" {
                return conn
                    .send(ErrorResponse::fatal(
                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                        "unsupported SASL mechanism",
                    ))
                    .await;
            }

            // Limit the size of the client-supplied nonce before using it.
            if initial_response.nonce.len() > 256 {
                return conn
                    .send(ErrorResponse::fatal(
                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                        "nonce too long",
                    ))
                    .await;
            }

            let (server_first_message_raw, mock_hash) = match adapter_client
                .generate_sasl_challenge(&user, &initial_response.nonce)
                .await
            {
                Ok(response) => {
                    // Keep the raw server-first message. It is part of the
                    // AuthMessage that the client proof is computed over.
                    let server_first_message_raw = format!(
                        "r={},s={},i={}",
                        response.nonce, response.salt, response.iteration_count
                    );

                    // A synthetic SCRAM hash with fixed keys, built from the
                    // challenge's salt and iteration count and passed to
                    // `verify_sasl_proof`.
                    // NOTE(review): presumably lets verification of unknown
                    // users behave like that of real users. Confirm.
                    let client_key = [0u8; 32];
                    let server_key = [1u8; 32];
                    let mock_hash = format!(
                        "SCRAM-SHA-256${}:{}${}:{}",
                        response.iteration_count,
                        response.salt,
                        BASE64_STANDARD.encode(client_key),
                        BASE64_STANDARD.encode(server_key)
                    );

                    conn.send(BackendMessage::AuthenticationSASLContinue(
                        SASLServerFirstMessage {
                            iteration_count: response.iteration_count,
                            nonce: response.nonce,
                            salt: response.salt,
                        },
                    ))
                    .await?;
                    conn.flush().await?;
                    (server_first_message_raw, mock_hash)
                }
                Err(e) => {
                    return conn.send(e.into_response(Severity::Fatal)).await;
                }
            };

            let authenticated = match conn.recv().await? {
                Some(FrontendMessage::RawAuthentication(data)) => {
                    match decode_sasl_response(Cursor::new(&data)).ok() {
                        Some(FrontendMessage::SASLResponse(response)) => {
                            // SCRAM AuthMessage (RFC 5802): client-first-bare,
                            // server-first and client-final-without-proof,
                            // joined by commas.
                            let auth_message = format!(
                                "{},{},{}",
                                initial_response.client_first_message_bare_raw,
                                server_first_message_raw,
                                response.client_final_message_bare_raw
                            );
                            // Limit the size of the client-supplied proof.
                            if response.proof.len() > 1024 {
                                return conn
                                    .send(ErrorResponse::fatal(
                                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                                        "proof too long",
                                    ))
                                    .await;
                            }
                            match adapter_client
                                .verify_sasl_proof(
                                    &user,
                                    &response.proof,
                                    &auth_message,
                                    &mock_hash,
                                )
                                .await
                            {
                                Ok((proof_response, authenticated)) => {
                                    conn.send(BackendMessage::AuthenticationSASLFinal(
                                        SASLServerFinalMessage {
                                            kind: SASLServerFinalMessageKinds::Verifier(
                                                proof_response.verifier,
                                            ),
                                            extensions: vec![],
                                        },
                                    ))
                                    .await?;
                                    conn.flush().await?;
                                    authenticated
                                }
                                // Every verification failure is reported to the
                                // client as a generic invalid password.
                                Err(_) => {
                                    return conn
                                        .send(ErrorResponse::fatal(
                                            SqlState::INVALID_PASSWORD,
                                            "invalid password",
                                        ))
                                        .await;
                                }
                            }
                        }
                        _ => {
                            return conn
                                .send(ErrorResponse::fatal(
                                    SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                                    "expected SASLResponse message",
                                ))
                                .await;
                        }
                    }
                }
                _ => {
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                            "expected SASLResponse message",
                        ))
                        .await;
                }
            };

            let session = adapter_client.new_session(
                SessionConfig {
                    conn_id: conn.conn_id().clone(),
                    uuid: conn_uuid,
                    user,
                    client_ip: conn.peer_addr().clone(),
                    external_metadata_rx: None,
                    helm_chart_version,
                },
                authenticated,
            );
            let auth_session = pending().right_future();
            (session, auth_session)
        }

        // No authentication: the session is created for the requested user
        // as-is.
        Authenticator::None => {
            let session = adapter_client.new_session(
                SessionConfig {
                    conn_id: conn.conn_id().clone(),
                    uuid: conn_uuid,
                    user,
                    client_ip: conn.peer_addr().clone(),
                    external_metadata_rx: None,
                    helm_chart_version,
                },
                Authenticated,
            );
            let auth_session = pending().right_future();
            (session, auth_session)
        }
    };

    // Apply the remaining startup parameters as session variables. The
    // `options` entry expands into the settings parsed from it. Every other
    // entry is a single `name = value` setting. Failures become notices, which
    // are sent to the client with the startup messages below.
    let system_vars = adapter_client.get_system_vars().await;
    for (name, value) in params {
        let settings = match name.as_str() {
            "options" => match &options {
                Ok(opts) => opts,
                Err(()) => {
                    session.add_notice(AdapterNotice::BadStartupSetting {
                        name,
                        reason: "could not parse".into(),
                    });
                    continue;
                }
            },
            _ => &vec![(name, value)],
        };
        for (key, val) in settings {
            const LOCAL: bool = false;
            // An invalid setting does not fail the connection. It is only
            // reported as a notice.
            if let Err(err) = session
                .vars_mut()
                .set(&system_vars, key, VarInput::Flat(val), LOCAL)
            {
                session.add_notice(AdapterNotice::BadStartupSetting {
                    name: key.clone(),
                    reason: err.to_string(),
                });
            }
        }
    }
    // NOTE(review): presumably makes the values set above durable for the
    // session. Confirm against `SessionVars::end_transaction`.
    session
        .vars_mut()
        .end_transaction(EndTransactionAction::Commit);

    // The returned guard is bound to `_guard` so it lives until `run` returns.
    let _guard = match active_connection_counter.allocate_connection(session.user()) {
        Ok(drop_connection) => drop_connection,
        Err(e) => {
            let e: AdapterError = e.into();
            return conn.send(e.into_response(Severity::Fatal)).await;
        }
    };

    // Register the session with the adapter. This yields the session client
    // that drives the rest of the connection.
    let mut adapter_client = match adapter_client.startup(session).await {
        Ok(adapter_client) => adapter_client,
        Err(e) => return conn.send(e.into_response(Severity::Fatal)).await,
    };

    // Startup replies, sent as one batch: AuthenticationOk, a ParameterStatus
    // for each reported variable, BackendKeyData, any queued notices, and
    // ReadyForQuery.
    let mut buf = vec![BackendMessage::AuthenticationOk];
    for var in adapter_client.session().vars().notify_set() {
        buf.push(BackendMessage::ParameterStatus(var.name(), var.value()));
    }
    buf.push(BackendMessage::BackendKeyData {
        conn_id: adapter_client.session().conn_id().unhandled(),
        secret_key: adapter_client.session().secret_key(),
    });
    buf.extend(
        adapter_client
            .session()
            .drain_notices()
            .into_iter()
            .map(|notice| BackendMessage::ErrorResponse(notice.into_response())),
    );
    buf.push(BackendMessage::ReadyForQuery(
        adapter_client.session().transaction().into(),
    ));
    conn.send_all(buf).await?;
    conn.flush().await?;

    let machine = StateMachine {
        conn,
        adapter_client,
        txn_needs_commit: false,
        tokio_metrics_intervals,
    };

    // Race the message loop against authentication expiry.
    select! {
        r = machine.run() => {
            // Best effort: tell the client why the connection is closing. The
            // original error is returned whether or not these sends succeed.
            if let Err(err) = &r {
                let _ = conn
                    .send(ErrorResponse::fatal(
                        SqlState::CONNECTION_FAILURE,
                        err.to_string(),
                    ))
                    .await;
                let _ = conn.flush().await;
            }
            r
        },
        _ = expired => {
            conn
                .send(ErrorResponse::fatal(SqlState::INVALID_AUTHORIZATION_SPECIFICATION, "authentication expired"))
                .await?;
            conn.flush().await
        }
    }
}
607
/// Removes every `oidc_auth_enabled` entry from the parsed startup `options`.
///
/// Returns the flag's value together with the remaining options, in their
/// original order. If the key appears more than once, the last occurrence
/// wins. A value that does not parse as a `bool` counts as `false`. If
/// `options` failed to parse, it is returned unchanged with the flag set to
/// `false`.
fn extract_oidc_auth_enabled_from_options(
    options: Result<Vec<(String, String)>, ()>,
) -> (bool, Result<Vec<(String, String)>, ()>) {
    let pairs = match options {
        Ok(pairs) => pairs,
        Err(()) => return (false, Err(())),
    };

    let mut enabled = false;
    let remaining: Vec<(String, String)> = pairs
        .into_iter()
        .filter_map(|(key, value)| {
            if key == "oidc_auth_enabled" {
                enabled = value.parse::<bool>().unwrap_or(false);
                None
            } else {
                Some((key, value))
            }
        })
        .collect();

    (enabled, Ok(remaining))
}
632
633fn parse_options(value: &str) -> Result<Vec<(String, String)>, ()> {
638 let opts = split_options(value);
639 let mut pairs = Vec::with_capacity(opts.len());
640 let mut seen_prefix = false;
641 for opt in opts {
642 if !seen_prefix {
643 if opt == "-c" {
644 seen_prefix = true;
645 } else {
646 let (key, val) = parse_option(&opt)?;
647 pairs.push((key.to_owned(), val.to_owned()));
648 }
649 } else {
650 let (key, val) = opt.split_once('=').ok_or(())?;
651 pairs.push((key.to_owned(), val.to_owned()));
652 seen_prefix = false;
653 }
654 }
655 Ok(pairs)
656}
657
/// Splits a single-token option of the form `--key=value` or `-ckey=value`
/// into `(key, value)`.
///
/// The text is split at the first `=`. The key must start with `-c` or `--`
/// (checked in that order), and that prefix is removed. Returns `Err(())` if
/// there is no `=` or the key has neither prefix.
fn parse_option(option: &str) -> Result<(&str, &str), ()> {
    let (raw_key, value) = option.split_once('=').ok_or(())?;
    ["-c", "--"]
        .iter()
        .find_map(|prefix| raw_key.strip_prefix(prefix))
        .map(|key| (key, value))
        .ok_or(())
}
670
/// Splits `value` on runs of unescaped spaces.
///
/// A backslash escapes the character after it: `\ ` gives a literal space
/// and `\\` a literal backslash. Any other escaped character is kept without
/// its backslash. A lone trailing backslash is dropped, and empty tokens are
/// never produced. The result is built character by character because
/// escapes prevent simply slicing `value`.
fn split_options(value: &str) -> Vec<String> {
    let mut tokens = Vec::new();
    let mut token = String::new();
    let mut escaped = false;
    for ch in value.chars() {
        match (ch, escaped) {
            (' ', true) => {
                token.push(' ');
                escaped = false;
            }
            (' ', false) => {
                // Consecutive spaces never produce empty tokens.
                if !token.is_empty() {
                    tokens.push(std::mem::take(&mut token));
                }
            }
            ('\\', true) => {
                token.push('\\');
                escaped = false;
            }
            ('\\', false) => escaped = true,
            (other, _) => {
                token.push(other);
                escaped = false;
            }
        }
    }
    if !token.is_empty() {
        tokens.push(token);
    }
    tokens
}
713
714enum PasswordRequestError {
715 InvalidPasswordError(ErrorResponse),
716 IoError(io::Error),
717}
718
719impl From<io::Error> for PasswordRequestError {
720 fn from(e: io::Error) -> Self {
721 PasswordRequestError::IoError(e)
722 }
723}
724
725async fn request_cleartext_password<A>(
729 conn: &mut FramedConn<A>,
730) -> Result<String, PasswordRequestError>
731where
732 A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
733{
734 conn.send(BackendMessage::AuthenticationCleartextPassword)
735 .await?;
736 conn.flush().await?;
737
738 if let Some(message) = conn.recv().await? {
739 if let FrontendMessage::RawAuthentication(data) = message {
740 if let Some(FrontendMessage::Password { password }) =
741 decode_password(Cursor::new(&data)).ok()
742 {
743 return Ok(password);
744 }
745 }
746 }
747
748 Err(PasswordRequestError::InvalidPasswordError(
749 ErrorResponse::fatal(
750 SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
751 "expected Password message",
752 ),
753 ))
754}
755
756async fn authenticate_with_password<A>(
759 conn: &mut FramedConn<A>,
760 adapter_client: &mz_adapter::Client,
761 user: String,
762 conn_uuid: Uuid,
763 helm_chart_version: Option<String>,
764) -> Result<Session, PasswordRequestError>
765where
766 A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
767{
768 let password = match request_cleartext_password(conn).await {
769 Ok(password) => Password(password),
770 Err(e) => return Err(e),
771 };
772
773 let authenticated = match adapter_client.authenticate(&user, &password).await {
774 Ok(authenticated) => authenticated,
775 Err(err) => {
776 warn!(?err, "pgwire connection failed authentication");
777 return Err(PasswordRequestError::InvalidPasswordError(
778 ErrorResponse::fatal(SqlState::INVALID_PASSWORD, "invalid password"),
779 ));
780 }
781 };
782
783 let session = adapter_client.new_session(
784 SessionConfig {
785 conn_id: conn.conn_id().clone(),
786 uuid: conn_uuid,
787 user,
788 client_ip: conn.peer_addr().clone(),
789 external_metadata_rx: None,
790 helm_chart_version,
791 },
792 authenticated,
793 );
794
795 Ok(session)
796}
797
/// Position of the connection's message loop (driven by `StateMachine::run`).
#[derive(Debug)]
enum State {
    /// Handle the next frontend message normally (`advance_ready`).
    Ready,
    /// Discard frontend messages until a `Sync` arrives (`advance_drain`).
    Drain,
    /// Leave the message loop; `run` returns `Ok(())`.
    Done,
}
804
/// Per-connection driver of the pgwire message loop after startup completes.
struct StateMachine<'a, A, I>
where
    I: Iterator<Item = TaskMetrics> + Send + 'a,
{
    /// The client connection.
    conn: &'a mut FramedConn<A>,
    /// Adapter session client for this connection.
    adapter_client: mz_adapter::SessionClient,
    /// Set after an `Execute` that ran inside an implicit transaction, so the
    /// next `ensure_transaction` commits first. Cleared by `end_transaction`.
    txn_needs_commit: bool,
    /// Tokio task-metric intervals, advanced around each `recv` in
    /// `advance_ready` to measure the receive scheduling delay.
    tokio_metrics_intervals: I,
}
814
/// How sending result rows to the client ended. The producers and consumers
/// of this type are outside this view.
enum SendRowsEndedReason {
    /// All rows were sent.
    Success {
        /// Size of the result. NOTE(review): units (bytes?) not shown here. Confirm.
        result_size: u64,
        /// Number of rows returned to the client.
        rows_returned: u64,
    },
    /// Sending stopped because of an error, described by `error`.
    Errored {
        error: String,
    },
    /// Sending was canceled.
    Canceled,
}
825
/// Error text for commands rejected while the current transaction is aborted.
const ABORTED_TXN_MSG: &str =
    "current transaction is aborted, commands ignored until end of transaction block";
828
829impl<'a, A, I> StateMachine<'a, A, I>
830where
831 A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin + 'a,
832 I: Iterator<Item = TaskMetrics> + Send + 'a,
833{
    /// Drives the connection's message loop until it reaches [`State::Done`].
    ///
    /// Before each step, pending notices are sent to the client. After each
    /// step, `add_idle_in_transaction_session_timeout` is called on the
    /// session client. The first I/O error ends the loop and is returned.
    // NOTE(review): the hand-written `impl Future + Send` return type (hence
    // the clippy allow) presumably exists to make the `Send` bound explicit.
    // Confirm.
    #[allow(clippy::manual_async_fn)]
    #[mz_ore::instrument(level = "debug")]
    fn run(mut self) -> impl Future<Output = Result<(), io::Error>> + Send + 'a {
        async move {
            let mut state = State::Ready;
            loop {
                self.send_pending_notices().await?;
                state = match state {
                    State::Ready => self.advance_ready().await?,
                    State::Drain => self.advance_drain().await?,
                    State::Done => return Ok(()),
                };
                self.adapter_client
                    .add_idle_in_transaction_session_timeout();
            }
        }
    }
854
    /// Waits for the next frontend message and dispatches it to its handler.
    ///
    /// A pending session timeout is checked before incoming messages (the
    /// `select!` is `biased`), so no statement runs once a timeout has fired.
    /// In that case the timeout is sent as a fatal error, the session is
    /// terminated, and one more client message is read before returning.
    ///
    /// Otherwise the message is handled, and two metrics are recorded: the
    /// processing time per message type, and the tokio scheduling delay seen
    /// while waiting in `recv`.
    #[instrument(level = "debug")]
    async fn advance_ready(&mut self) -> Result<State, io::Error> {
        // Start a fresh metrics interval, so the interval taken after `recv`
        // below covers only the wait for the message.
        self.tokio_metrics_intervals
            .next()
            .expect("infinite iterator");

        let message = select! {
            biased;

            // This branch is disabled when `recv_timeout` yields `None`.
            Some(timeout) = self.adapter_client.recv_timeout() => {
                let err: AdapterError = timeout.into();
                let conn_id = self.adapter_client.session().conn_id();
                tracing::warn!("session timed out, conn_id {}", conn_id);

                // Report the error before terminating the session.
                let error_response = err.into_response(Severity::Fatal);
                let error_state = self.send_error_and_get_state(error_response).await;

                self.adapter_client.terminate().await;

                // NOTE(review): presumably waits here because pgwire only
                // lets the server send an ErrorResponse in reply to a client
                // message. Confirm.
                let _ = self.conn.recv().await?;
                return error_state;
            },
            message = self.conn.recv() => message?,
        };

        // Metrics accumulated since just before `recv`.
        let interval = self
            .tokio_metrics_intervals
            .next()
            .expect("infinite iterator");
        let recv_scheduling_delay_ms = interval.total_scheduled_duration.as_secs_f64() * 1000.0;

        let received = SYSTEM_TIME();

        self.adapter_client
            .remove_idle_in_transaction_session_timeout();

        // Label for the metrics below. It is empty when the connection closed
        // (`message` is `None`).
        let message_name = message.as_ref().map(|m| m.name()).unwrap_or_default();

        let start = message.as_ref().map(|_| Instant::now());
        let next_state = match message {
            Some(FrontendMessage::Query { sql }) => {
                // Give each simple query its own root span, linked to the
                // connection's span.
                let query_root_span =
                    tracing::info_span!(parent: None, "advance_ready", otel.name = message_name);
                query_root_span.follows_from(tracing::Span::current());
                self.query(sql, received)
                    .instrument(query_root_span)
                    .await?
            }
            Some(FrontendMessage::Parse {
                name,
                sql,
                param_types,
            }) => self.parse(name, sql, param_types).await?,
            Some(FrontendMessage::Bind {
                portal_name,
                statement_name,
                param_formats,
                raw_params,
                result_formats,
            }) => {
                self.bind(
                    portal_name,
                    statement_name,
                    param_formats,
                    raw_params,
                    result_formats,
                )
                .await?
            }
            Some(FrontendMessage::Execute {
                portal_name,
                max_rows,
            }) => {
                // A `max_rows` of 0, or a value that does not convert to
                // `usize` (e.g. negative), means "no limit".
                let max_rows = match usize::try_from(max_rows) {
                    Ok(0) | Err(_) => ExecuteCount::All,
                    Ok(n) => ExecuteCount::Count(n),
                };
                let execute_root_span =
                    tracing::info_span!(parent: None, "advance_ready", otel.name = message_name);
                execute_root_span.follows_from(tracing::Span::current());
                let state = self
                    .execute(
                        portal_name,
                        max_rows,
                        portal_exec_message,
                        None,
                        ExecuteTimeout::None,
                        None,
                        Some(received),
                    )
                    .instrument(execute_root_span)
                    .await?;
                // Outside an explicit transaction, each extended-protocol
                // execution is committed eagerly. `ensure_transaction` does
                // the commit before the next message is handled.
                if self.adapter_client.session().transaction().is_implicit() {
                    self.txn_needs_commit = true;
                }
                state
            }
            Some(FrontendMessage::DescribeStatement { name }) => {
                self.describe_statement(&name).await?
            }
            Some(FrontendMessage::DescribePortal { name }) => self.describe_portal(&name).await?,
            Some(FrontendMessage::CloseStatement { name }) => self.close_statement(name).await?,
            Some(FrontendMessage::ClosePortal { name }) => self.close_portal(name).await?,
            Some(FrontendMessage::Flush) => self.flush().await?,
            Some(FrontendMessage::Sync) => self.sync().await?,
            Some(FrontendMessage::Terminate) => State::Done,

            // Messages that are not valid in the ready state: skip input until
            // the next Sync.
            Some(FrontendMessage::CopyData(_))
            | Some(FrontendMessage::CopyDone)
            | Some(FrontendMessage::CopyFail(_))
            | Some(FrontendMessage::Password { .. })
            | Some(FrontendMessage::RawAuthentication(_))
            | Some(FrontendMessage::SASLInitialResponse { .. })
            | Some(FrontendMessage::SASLResponse(_)) => State::Drain,
            // The connection closed.
            None => State::Done,
        };

        if let Some(start) = start {
            self.adapter_client
                .inner()
                .metrics()
                .pgwire_message_processing_seconds
                .with_label_values(&[message_name])
                .observe(start.elapsed().as_secs_f64());
        }
        self.adapter_client
            .inner()
            .metrics()
            .pgwire_recv_scheduling_delay_ms
            .with_label_values(&[message_name])
            .observe(recv_scheduling_delay_ms);

        Ok(next_state)
    }
1018
1019 async fn advance_drain(&mut self) -> Result<State, io::Error> {
1020 let message = self.conn.recv().await?;
1021 if message.is_some() {
1022 self.adapter_client
1023 .remove_idle_in_transaction_session_timeout();
1024 }
1025 match message {
1026 Some(FrontendMessage::Sync) => self.sync().await,
1027 None => Ok(State::Done),
1028 _ => Ok(State::Drain),
1029 }
1030 }
1031
    /// Executes a single statement from a simple-protocol `Query` message.
    ///
    /// Steps:
    /// - Declare `stmt` into the unnamed portal and record
    ///   `lifecycle_timestamps` on it.
    /// - Reject statements that take parameters.
    /// - Send a text-format `RowDescription` for non-`COPY` statements that
    ///   return rows.
    /// - Execute the portal and send the response.
    ///
    /// After execution, whether it succeeded or failed, the unnamed portal is
    /// removed. The early returns before execution (parameter check, I/O
    /// error) leave it in place.
    #[instrument(level = "debug")]
    async fn one_query(
        &mut self,
        stmt: Statement<Raw>,
        sql: String,
        lifecycle_timestamps: LifecycleTimestamps,
    ) -> Result<State, io::Error> {
        // Simple-protocol statements run through the unnamed portal.
        const EMPTY_PORTAL: &str = "";
        if let Err(e) = self
            .adapter_client
            .declare(EMPTY_PORTAL.to_string(), stmt, sql)
            .await
        {
            return self
                .send_error_and_get_state(e.into_response(Severity::Error))
                .await;
        }
        let portal = self
            .adapter_client
            .session()
            .get_portal_unverified_mut(EMPTY_PORTAL)
            .expect("unnamed portal should be present");

        *portal.lifecycle_timestamps = Some(lifecycle_timestamps);

        // The simple protocol cannot bind parameters, so any `$n` is an error.
        let stmt_desc = portal.desc.clone();
        if !stmt_desc.param_types.is_empty() {
            return self
                .send_error_and_get_state(ErrorResponse::error(
                    SqlState::UNDEFINED_PARAMETER,
                    "there is no parameter $1",
                ))
                .await;
        }

        // Row-returning statements get a RowDescription up front, with all
        // columns in text format. COPY statements are excluded.
        if let Some(relation_desc) = &stmt_desc.relation_desc {
            if !stmt_desc.is_copy {
                let formats = vec![Format::Text; stmt_desc.arity()];
                self.send(BackendMessage::RowDescription(
                    message::encode_row_description(relation_desc, &formats),
                ))
                .await?;
            }
        }

        let result = match self
            .adapter_client
            .execute(EMPTY_PORTAL.to_string(), self.conn.wait_closed(), None)
            .await
        {
            Ok((response, execute_started)) => {
                self.send_pending_notices().await?;
                self.send_execute_response(
                    response,
                    stmt_desc.relation_desc,
                    EMPTY_PORTAL.to_string(),
                    ExecuteCount::All,
                    portal_exec_message,
                    None,
                    ExecuteTimeout::None,
                    execute_started,
                )
                .await
            }
            Err(e) => {
                self.send_pending_notices().await?;
                self.send_error_and_get_state(e.into_response(Severity::Error))
                    .await
            }
        };

        // Remove the unnamed portal now that execution has finished.
        self.adapter_client.session().remove_portal(EMPTY_PORTAL);

        result
    }
1114
1115 async fn ensure_transaction(
1116 &mut self,
1117 num_stmts: usize,
1118 message_type: &str,
1119 ) -> Result<(), io::Error> {
1120 let start = Instant::now();
1121 if self.txn_needs_commit {
1122 self.commit_transaction().await?;
1123 }
1124 let res = self.adapter_client.start_transaction(Some(num_stmts));
1127 assert_ok!(res);
1128 self.adapter_client
1129 .inner()
1130 .metrics()
1131 .pgwire_ensure_transaction_seconds
1132 .with_label_values(&[message_type])
1133 .observe(start.elapsed().as_secs_f64());
1134 Ok(())
1135 }
1136
1137 fn parse_sql<'b>(&self, sql: &'b str) -> Result<Vec<StatementParseResult<'b>>, ErrorResponse> {
1138 let parse_start = Instant::now();
1139 let result = match self.adapter_client.parse(sql) {
1140 Ok(result) => result.map_err(|e| {
1141 let pos = sql[..e.error.pos].chars().count() + 1;
1144 ErrorResponse::error(SqlState::SYNTAX_ERROR, e.error.message).with_position(pos)
1145 }),
1146 Err(msg) => Err(ErrorResponse::error(SqlState::PROGRAM_LIMIT_EXCEEDED, msg)),
1147 };
1148 self.adapter_client
1149 .inner()
1150 .metrics()
1151 .parse_seconds
1152 .observe(parse_start.elapsed().as_secs_f64());
1153 result
1154 }
1155
1156 #[instrument(level = "debug")]
1161 async fn query(&mut self, sql: String, received: EpochMillis) -> Result<State, io::Error> {
1162 let stmts = match self.parse_sql(&sql) {
1164 Ok(stmts) => stmts,
1165 Err(err) => {
1166 self.send_error_and_get_state(err).await?;
1167 return self.ready().await;
1168 }
1169 };
1170
1171 let num_stmts = stmts.len();
1172
1173 for StatementParseResult { ast: stmt, sql } in stmts {
1175 if self.is_aborted_txn() && !is_txn_exit_stmt(Some(&stmt)) {
1177 self.aborted_txn_error().await?;
1178 break;
1179 }
1180
1181 self.ensure_transaction(num_stmts, "query").await?;
1189
1190 match self
1191 .one_query(stmt, sql.to_string(), LifecycleTimestamps { received })
1192 .await?
1193 {
1194 State::Ready => (),
1195 State::Drain => break,
1196 State::Done => return Ok(State::Done),
1197 }
1198 }
1199
1200 {
1202 if self.adapter_client.session().transaction().is_implicit() {
1203 self.commit_transaction().await?;
1204 }
1205 }
1206
1207 if num_stmts == 0 {
1208 self.send(BackendMessage::EmptyQueryResponse).await?;
1209 }
1210
1211 self.ready().await
1212 }
1213
1214 #[instrument(level = "debug")]
1215 async fn parse(
1216 &mut self,
1217 name: String,
1218 sql: String,
1219 param_oids: Vec<u32>,
1220 ) -> Result<State, io::Error> {
1221 self.ensure_transaction(1, "parse").await?;
1223
1224 let mut param_types = vec![];
1225 for oid in param_oids {
1226 match mz_pgrepr::Type::from_oid(oid) {
1227 Ok(ty) => match SqlScalarType::try_from(&ty) {
1228 Ok(ty) => param_types.push(Some(ty)),
1229 Err(err) => {
1230 return self
1231 .send_error_and_get_state(ErrorResponse::error(
1232 SqlState::INVALID_PARAMETER_VALUE,
1233 err.to_string(),
1234 ))
1235 .await;
1236 }
1237 },
1238 Err(_) if oid == 0 => param_types.push(None),
1239 Err(e) => {
1240 return self
1241 .send_error_and_get_state(ErrorResponse::error(
1242 SqlState::PROTOCOL_VIOLATION,
1243 e.to_string(),
1244 ))
1245 .await;
1246 }
1247 }
1248 }
1249
1250 let stmts = match self.parse_sql(&sql) {
1251 Ok(stmts) => stmts,
1252 Err(err) => {
1253 return self.send_error_and_get_state(err).await;
1254 }
1255 };
1256 if stmts.len() > 1 {
1257 return self
1258 .send_error_and_get_state(ErrorResponse::error(
1259 SqlState::INTERNAL_ERROR,
1260 "cannot insert multiple commands into a prepared statement",
1261 ))
1262 .await;
1263 }
1264 let (maybe_stmt, sql) = match stmts.into_iter().next() {
1265 None => (None, ""),
1266 Some(StatementParseResult { ast, sql }) => (Some(ast), sql),
1267 };
1268 if self.is_aborted_txn() && !is_txn_exit_stmt(maybe_stmt.as_ref()) {
1269 return self.aborted_txn_error().await;
1270 }
1271 match self
1272 .adapter_client
1273 .prepare(name, maybe_stmt, sql.to_string(), param_types)
1274 .await
1275 {
1276 Ok(()) => {
1277 self.send(BackendMessage::ParseComplete).await?;
1278 Ok(State::Ready)
1279 }
1280 Err(e) => {
1281 self.send_error_and_get_state(e.into_response(Severity::Error))
1282 .await
1283 }
1284 }
1285 }
1286
1287 #[instrument(level = "debug")]
1289 async fn commit_transaction(&mut self) -> Result<(), io::Error> {
1290 self.end_transaction(EndTransactionAction::Commit).await
1291 }
1292
1293 #[instrument(level = "debug")]
1295 async fn rollback_transaction(&mut self) -> Result<(), io::Error> {
1296 self.end_transaction(EndTransactionAction::Rollback).await
1297 }
1298
1299 #[instrument(level = "debug")]
1301 async fn end_transaction(&mut self, action: EndTransactionAction) -> Result<(), io::Error> {
1302 self.txn_needs_commit = false;
1303 let resp = self.adapter_client.end_transaction(action).await;
1304 if let Err(err) = resp {
1305 self.send(BackendMessage::ErrorResponse(
1306 err.into_response(Severity::Error),
1307 ))
1308 .await?;
1309 }
1310 Ok(())
1311 }
1312
1313 #[instrument(level = "debug")]
1314 async fn bind(
1315 &mut self,
1316 portal_name: String,
1317 statement_name: String,
1318 param_formats: Vec<Format>,
1319 raw_params: Vec<Option<Vec<u8>>>,
1320 result_formats: Vec<Format>,
1321 ) -> Result<State, io::Error> {
1322 self.ensure_transaction(1, "bind").await?;
1324
1325 let aborted_txn = self.is_aborted_txn();
1326 let stmt = match self
1327 .adapter_client
1328 .get_prepared_statement(&statement_name)
1329 .await
1330 {
1331 Ok(stmt) => stmt,
1332 Err(err) => {
1333 return self
1334 .send_error_and_get_state(err.into_response(Severity::Error))
1335 .await;
1336 }
1337 };
1338
1339 let param_types = &stmt.desc().param_types;
1340 if param_types.len() != raw_params.len() {
1341 let message = format!(
1342 "bind message supplies {actual} parameters, \
1343 but prepared statement \"{name}\" requires {expected}",
1344 name = statement_name,
1345 actual = raw_params.len(),
1346 expected = param_types.len()
1347 );
1348 return self
1349 .send_error_and_get_state(ErrorResponse::error(
1350 SqlState::PROTOCOL_VIOLATION,
1351 message,
1352 ))
1353 .await;
1354 }
1355 let param_formats = match pad_formats(param_formats, raw_params.len()) {
1356 Ok(param_formats) => param_formats,
1357 Err(msg) => {
1358 return self
1359 .send_error_and_get_state(ErrorResponse::error(
1360 SqlState::PROTOCOL_VIOLATION,
1361 msg,
1362 ))
1363 .await;
1364 }
1365 };
1366 if aborted_txn && !is_txn_exit_stmt(stmt.stmt()) {
1367 return self.aborted_txn_error().await;
1368 }
1369 let buf = RowArena::new();
1370 let mut params = vec![];
1371 for ((raw_param, mz_typ), format) in raw_params
1372 .into_iter()
1373 .zip_eq(param_types)
1374 .zip_eq(param_formats)
1375 {
1376 let pg_typ = mz_pgrepr::Type::from(mz_typ);
1377 let datum = match raw_param {
1378 None => Datum::Null,
1379 Some(bytes) => match mz_pgrepr::Value::decode(format, &pg_typ, &bytes) {
1380 Ok(param) => match param.into_datum_decode_error(&buf, &pg_typ, "parameter") {
1381 Ok(datum) => datum,
1382 Err(msg) => {
1383 return self
1384 .send_error_and_get_state(ErrorResponse::error(
1385 SqlState::INVALID_PARAMETER_VALUE,
1386 msg,
1387 ))
1388 .await;
1389 }
1390 },
1391 Err(err) => {
1392 let msg = format!("unable to decode parameter: {}", err);
1393 return self
1394 .send_error_and_get_state(ErrorResponse::error(
1395 SqlState::INVALID_PARAMETER_VALUE,
1396 msg,
1397 ))
1398 .await;
1399 }
1400 },
1401 };
1402 params.push((datum, mz_typ.clone()))
1403 }
1404
1405 let result_formats = match pad_formats(
1406 result_formats,
1407 stmt.desc()
1408 .relation_desc
1409 .clone()
1410 .map(|desc| desc.typ().column_types.len())
1411 .unwrap_or(0),
1412 ) {
1413 Ok(result_formats) => result_formats,
1414 Err(msg) => {
1415 return self
1416 .send_error_and_get_state(ErrorResponse::error(
1417 SqlState::PROTOCOL_VIOLATION,
1418 msg,
1419 ))
1420 .await;
1421 }
1422 };
1423
1424 if !stmt.stmt().map_or(false, |stmt| match stmt {
1427 Statement::Copy(CopyStatement {
1428 direction: CopyDirection::To,
1429 ..
1430 }) => true,
1431 Statement::Copy(CopyStatement {
1432 direction: CopyDirection::From,
1433 target: CopyTarget::Expr(_),
1437 ..
1438 }) => true,
1439 _ => false,
1440 }) {
1441 if let Some(desc) = stmt.desc().relation_desc.clone() {
1442 for (format, ty) in result_formats.iter().zip_eq(desc.iter_types()) {
1443 match (format, &ty.scalar_type) {
1444 (Format::Binary, mz_repr::SqlScalarType::List { .. }) => {
1445 return self
1446 .send_error_and_get_state(ErrorResponse::error(
1447 SqlState::PROTOCOL_VIOLATION,
1448 "binary encoding of list types is not implemented",
1449 ))
1450 .await;
1451 }
1452 (Format::Binary, mz_repr::SqlScalarType::Map { .. }) => {
1453 return self
1454 .send_error_and_get_state(ErrorResponse::error(
1455 SqlState::PROTOCOL_VIOLATION,
1456 "binary encoding of map types is not implemented",
1457 ))
1458 .await;
1459 }
1460 (Format::Binary, mz_repr::SqlScalarType::AclItem) => {
1461 return self
1462 .send_error_and_get_state(ErrorResponse::error(
1463 SqlState::PROTOCOL_VIOLATION,
1464 "binary encoding of aclitem types does not exist",
1465 ))
1466 .await;
1467 }
1468 _ => (),
1469 }
1470 }
1471 }
1472 }
1473
1474 let desc = stmt.desc().clone();
1475 let logging = Arc::clone(stmt.logging());
1476 let stmt_ast = stmt.stmt().cloned();
1477 let state_revision = stmt.state_revision;
1478 if let Err(err) = self.adapter_client.session().set_portal(
1479 portal_name,
1480 desc,
1481 stmt_ast,
1482 logging,
1483 params,
1484 result_formats,
1485 state_revision,
1486 ) {
1487 return self
1488 .send_error_and_get_state(err.into_response(Severity::Error))
1489 .await;
1490 }
1491
1492 self.send(BackendMessage::BindComplete).await?;
1493 Ok(State::Ready)
1494 }
1495
/// Runs (or resumes) the portal `portal_name` in response to an `Execute`
/// message, or to a SQL `FETCH` routed here by `fetch`.
///
/// Behavior depends on the portal's state:
/// - `NotStarted`: execution is started through the adapter and the response
///   is handed to `send_execute_response`.
/// - `InProgress`: streaming resumes from the rows left by an earlier partial
///   execution.
/// - `Completed(Some(tag))`: the stored completion tag is sent again.
/// - `Completed(None)`: an `OBJECT_NOT_IN_PREREQUISITE_STATE` error is sent.
///
/// * `max_rows` — row limit passed on to row sending.
/// * `get_response` — builds the message sent after rows (passed to `send_rows`).
/// * `fetch_portal_name` — the portal that issued a `FETCH`, if any.
/// * `outer_ctx_extra` — an optional statement-logging guard. Every path either
///   retires it with the outcome or hands it to `adapter_client.execute`.
/// * `received` — when the message was received; stored on the portal.
///
/// The future is boxed because this function is reachable recursively
/// (`execute` -> `send_execute_response` -> `fetch` -> `execute`), and
/// recursive async calls need indirection.
fn execute(
    &mut self,
    portal_name: String,
    max_rows: ExecuteCount,
    get_response: GetResponse,
    fetch_portal_name: Option<String>,
    timeout: ExecuteTimeout,
    outer_ctx_extra: Option<ExecuteContextGuard>,
    received: Option<EpochMillis>,
) -> BoxFuture<'_, Result<State, io::Error>> {
    async move {
        // Read before the portal is borrowed mutably from the session below.
        let aborted_txn = self.is_aborted_txn();

        let portal = match self
            .adapter_client
            .session()
            .get_portal_unverified_mut(&portal_name)
        {
            Some(portal) => portal,
            None => {
                let msg = format!("portal {} does not exist", portal_name.quoted());
                if let Some(outer_ctx_extra) = outer_ctx_extra {
                    self.adapter_client.retire_execute(
                        outer_ctx_extra,
                        StatementEndedExecutionReason::Errored { error: msg.clone() },
                    );
                }
                return self
                    .send_error_and_get_state(ErrorResponse::error(
                        SqlState::INVALID_CURSOR_NAME,
                        msg,
                    ))
                    .await;
            }
        };

        // Overwrite the portal's lifecycle timestamps with the receive time
        // (or clear them when `received` is `None`, e.g. for FETCH).
        *portal.lifecycle_timestamps = received.map(LifecycleTimestamps::new);

        // In an aborted transaction only transaction-exit statements may run.
        let txn_exit_stmt = is_txn_exit_stmt(portal.stmt.as_deref());
        if aborted_txn && !txn_exit_stmt {
            if let Some(outer_ctx_extra) = outer_ctx_extra {
                self.adapter_client.retire_execute(
                    outer_ctx_extra,
                    StatementEndedExecutionReason::Errored {
                        error: ABORTED_TXN_MSG.to_string(),
                    },
                );
            }
            return self.aborted_txn_error().await;
        }

        let row_desc = portal.desc.relation_desc.clone();
        match portal.state {
            PortalState::NotStarted => {
                self.ensure_transaction(1, "execute").await?;
                // The logging guard is handed to the adapter here rather than
                // retired locally.
                match self
                    .adapter_client
                    .execute(
                        portal_name.clone(),
                        self.conn.wait_closed(),
                        outer_ctx_extra,
                    )
                    .await
                {
                    Ok((response, execute_started)) => {
                        self.send_pending_notices().await?;
                        self.send_execute_response(
                            response,
                            row_desc,
                            portal_name,
                            max_rows,
                            get_response,
                            fetch_portal_name,
                            timeout,
                            execute_started,
                        )
                        .await
                    }
                    Err(e) => {
                        self.send_pending_notices().await?;
                        self.send_error_and_get_state(e.into_response(Severity::Error))
                            .await
                    }
                }
            }
            PortalState::InProgress(rows) => {
                // Move the pending rows out of the portal; `send_rows` stores
                // them back on the portal on its success path.
                let rows = rows.take().expect("InProgress rows must be populated");
                let (result, statement_ended_execution_reason) = match self
                    .send_rows(
                        row_desc.expect("portal missing row desc on resumption"),
                        portal_name,
                        rows,
                        max_rows,
                        get_response,
                        fetch_portal_name,
                        timeout,
                    )
                    .await
                {
                    // An I/O failure while sending is logged as a cancellation.
                    Err(e) => {
                        (Err(e), StatementEndedExecutionReason::Canceled)
                    }
                    Ok((ok, SendRowsEndedReason::Canceled)) => {
                        (Ok(ok), StatementEndedExecutionReason::Canceled)
                    }
                    // Size and row counts are dropped for resumed executions.
                    // NOTE(review): presumably because they cover only this
                    // page of results — confirm.
                    Ok((
                        ok,
                        SendRowsEndedReason::Success {
                            result_size: _,
                            rows_returned: _,
                        },
                    )) => (
                        Ok(ok),
                        StatementEndedExecutionReason::Success {
                            result_size: None,
                            rows_returned: None,
                            execution_strategy: None,
                        },
                    ),
                    Ok((ok, SendRowsEndedReason::Errored { error })) => {
                        (Ok(ok), StatementEndedExecutionReason::Errored { error })
                    }
                };
                if let Some(outer_ctx_extra) = outer_ctx_extra {
                    self.adapter_client
                        .retire_execute(outer_ctx_extra, statement_ended_execution_reason);
                }
                result
            }
            PortalState::Completed(Some(tag)) => {
                // The portal already finished; repeat its completion tag.
                let tag = tag.to_string();
                if let Some(outer_ctx_extra) = outer_ctx_extra {
                    self.adapter_client.retire_execute(
                        outer_ctx_extra,
                        StatementEndedExecutionReason::Success {
                            result_size: None,
                            rows_returned: None,
                            execution_strategy: None,
                        },
                    );
                }
                self.send(BackendMessage::CommandComplete { tag }).await?;
                Ok(State::Ready)
            }
            PortalState::Completed(None) => {
                let error = format!(
                    "portal {} cannot be run",
                    Ident::new_unchecked(portal_name).to_ast_string_stable()
                );
                if let Some(outer_ctx_extra) = outer_ctx_extra {
                    self.adapter_client.retire_execute(
                        outer_ctx_extra,
                        StatementEndedExecutionReason::Errored {
                            error: error.clone(),
                        },
                    );
                }
                self.send_error_and_get_state(ErrorResponse::error(
                    SqlState::OBJECT_NOT_IN_PREREQUISITE_STATE,
                    error,
                ))
                .await
            }
        }
    }
    .instrument(debug_span!("execute"))
    .boxed()
}
1689
1690 #[instrument(level = "debug")]
1691 async fn describe_statement(&mut self, name: &str) -> Result<State, io::Error> {
1692 self.ensure_transaction(1, "describe_statement").await?;
1694
1695 let stmt = match self.adapter_client.get_prepared_statement(name).await {
1696 Ok(stmt) => stmt,
1697 Err(err) => {
1698 return self
1699 .send_error_and_get_state(err.into_response(Severity::Error))
1700 .await;
1701 }
1702 };
1703 let parameter_desc = BackendMessage::ParameterDescription(
1705 stmt.desc()
1706 .param_types
1707 .iter()
1708 .map(mz_pgrepr::Type::from)
1709 .collect(),
1710 );
1711 let formats = vec![Format::Text; stmt.desc().arity()];
1715 let row_desc = describe_rows(stmt.desc(), &formats);
1716 self.send_all([parameter_desc, row_desc]).await?;
1717 Ok(State::Ready)
1718 }
1719
1720 #[instrument(level = "debug")]
1721 async fn describe_portal(&mut self, name: &str) -> Result<State, io::Error> {
1722 self.ensure_transaction(1, "describe_portal").await?;
1724
1725 let session = self.adapter_client.session();
1726 let row_desc = session
1727 .get_portal_unverified(name)
1728 .map(|portal| describe_rows(&portal.desc, &portal.result_formats));
1729 match row_desc {
1730 Some(row_desc) => {
1731 self.send(row_desc).await?;
1732 Ok(State::Ready)
1733 }
1734 None => {
1735 self.send_error_and_get_state(ErrorResponse::error(
1736 SqlState::INVALID_CURSOR_NAME,
1737 format!("portal {} does not exist", name.quoted()),
1738 ))
1739 .await
1740 }
1741 }
1742 }
1743
1744 #[instrument(level = "debug")]
1745 async fn close_statement(&mut self, name: String) -> Result<State, io::Error> {
1746 self.adapter_client
1747 .session()
1748 .remove_prepared_statement(&name);
1749 self.send(BackendMessage::CloseComplete).await?;
1750 Ok(State::Ready)
1751 }
1752
1753 #[instrument(level = "debug")]
1754 async fn close_portal(&mut self, name: String) -> Result<State, io::Error> {
1755 self.adapter_client.session().remove_portal(&name);
1756 self.send(BackendMessage::CloseComplete).await?;
1757 Ok(State::Ready)
1758 }
1759
1760 fn complete_portal(&mut self, name: &str) {
1761 let portal = self
1762 .adapter_client
1763 .session()
1764 .get_portal_unverified_mut(name)
1765 .expect("portal should exist");
1766 *portal.state = PortalState::Completed(None);
1767 }
1768
1769 async fn fetch(
1770 &mut self,
1771 name: String,
1772 count: Option<FetchDirection>,
1773 max_rows: ExecuteCount,
1774 fetch_portal_name: Option<String>,
1775 timeout: ExecuteTimeout,
1776 ctx_extra: ExecuteContextGuard,
1777 ) -> Result<State, io::Error> {
1778 let count = count.unwrap_or(FetchDirection::ForwardCount(1));
1781
1782 let count = match (max_rows, count) {
1795 (ExecuteCount::Count(max_rows), FetchDirection::ForwardCount(count)) => {
1796 let count = usize::cast_from(count);
1797 if max_rows < count {
1798 let msg = "Execute with max_rows < a FETCH's count is not supported";
1799 self.adapter_client.retire_execute(
1800 ctx_extra,
1801 StatementEndedExecutionReason::Errored {
1802 error: msg.to_string(),
1803 },
1804 );
1805 return self
1806 .send_error_and_get_state(ErrorResponse::error(
1807 SqlState::FEATURE_NOT_SUPPORTED,
1808 msg,
1809 ))
1810 .await;
1811 }
1812 ExecuteCount::Count(count)
1813 }
1814 (ExecuteCount::Count(_), FetchDirection::ForwardAll) => {
1815 let msg = "Execute with max_rows of a FETCH ALL is not supported";
1816 self.adapter_client.retire_execute(
1817 ctx_extra,
1818 StatementEndedExecutionReason::Errored {
1819 error: msg.to_string(),
1820 },
1821 );
1822 return self
1823 .send_error_and_get_state(ErrorResponse::error(
1824 SqlState::FEATURE_NOT_SUPPORTED,
1825 msg,
1826 ))
1827 .await;
1828 }
1829 (ExecuteCount::All, FetchDirection::ForwardAll) => ExecuteCount::All,
1830 (ExecuteCount::All, FetchDirection::ForwardCount(count)) => {
1831 ExecuteCount::Count(usize::cast_from(count))
1832 }
1833 };
1834 let cursor_name = name.to_string();
1835 self.execute(
1836 cursor_name,
1837 count,
1838 fetch_message,
1839 fetch_portal_name,
1840 timeout,
1841 Some(ctx_extra),
1842 None,
1843 )
1844 .await
1845 }
1846
1847 async fn flush(&mut self) -> Result<State, io::Error> {
1848 self.conn.flush().await?;
1849 Ok(State::Ready)
1850 }
1851
1852 #[instrument(level = "debug")]
1857 async fn send<M>(&mut self, message: M) -> Result<(), io::Error>
1858 where
1859 M: Into<BackendMessage>,
1860 {
1861 let message: BackendMessage = message.into();
1862 let is_error =
1863 matches!(&message, BackendMessage::ErrorResponse(e) if e.severity.is_error());
1864
1865 self.conn.send(message).await?;
1866
1867 if is_error {
1874 self.conn.flush().await?;
1875 }
1876
1877 Ok(())
1878 }
1879
1880 #[instrument(level = "debug")]
1881 pub async fn send_all(
1882 &mut self,
1883 messages: impl IntoIterator<Item = BackendMessage>,
1884 ) -> Result<(), io::Error> {
1885 for m in messages {
1886 self.send(m).await?;
1887 }
1888 Ok(())
1889 }
1890
1891 #[instrument(level = "debug")]
1892 async fn sync(&mut self) -> Result<State, io::Error> {
1893 if self.adapter_client.session().transaction().is_implicit() {
1895 self.commit_transaction().await?;
1896 }
1897 self.ready().await
1898 }
1899
1900 #[instrument(level = "debug")]
1901 async fn ready(&mut self) -> Result<State, io::Error> {
1902 let txn_state = self.adapter_client.session().transaction().into();
1903 self.send(BackendMessage::ReadyForQuery(txn_state)).await?;
1904 self.flush().await
1905 }
1906
/// Translates an adapter `ExecuteResponse` for portal `portal_name` into
/// pgwire messages.
///
/// Dispatch by response kind:
/// - Row-producing responses (`SendingRowsStreaming`, `SendingRowsImmediate`,
///   `Subscribing`) are streamed with `send_rows`.
/// - `CopyTo` is streamed with `copy_rows`.
/// - `Fetch` re-enters execution through `fetch`.
/// - `CopyFrom` is delegated to `copy_from`.
/// - Every other response ends with a `CommandComplete` carrying
///   `response.tag()`.
///
/// `row_desc` must be `Some` for row-producing and COPY responses (it is
/// `expect`ed there). Statement-logging guards carried inside `Subscribing`
/// responses are retired here with the outcome of sending.
///
/// # Panics
///
/// Panics if a required `row_desc` is missing. Also panics (via
/// `assert_none!`) if a tag was produced but not consumed, on paths that reach
/// the end of the function; the `Subscribing` and `CopyTo` arms return early
/// and skip that check.
#[allow(clippy::too_many_arguments)]
#[instrument(level = "debug")]
async fn send_execute_response(
    &mut self,
    response: ExecuteResponse,
    row_desc: Option<RelationDesc>,
    portal_name: String,
    max_rows: ExecuteCount,
    get_response: GetResponse,
    fetch_portal_name: Option<String>,
    timeout: ExecuteTimeout,
    execute_started: Instant,
) -> Result<State, io::Error> {
    let mut tag = response.tag();

    // Sends `CommandComplete` with the response's tag (consuming it) and
    // yields `Ok(State::Ready)`.
    macro_rules! command_complete {
        () => {{
            self.send(BackendMessage::CommandComplete {
                tag: tag
                    .take()
                    .expect("command_complete only called on tag-generating results"),
            })
            .await?;
            Ok(State::Ready)
        }};
    }

    let r = match response {
        ExecuteResponse::ClosedCursor => {
            self.complete_portal(&portal_name);
            command_complete!()
        }
        ExecuteResponse::DeclaredCursor => {
            self.complete_portal(&portal_name);
            command_complete!()
        }
        ExecuteResponse::EmptyQuery => {
            self.send(BackendMessage::EmptyQueryResponse).await?;
            Ok(State::Ready)
        }
        ExecuteResponse::Fetch {
            name,
            count,
            timeout,
            ctx_extra,
        } => {
            // The executing portal becomes the "fetch portal" whose result
            // formats are used for the fetched rows.
            self.fetch(
                name,
                count,
                max_rows,
                Some(portal_name.to_string()),
                timeout,
                ctx_extra,
            )
            .await
        }
        ExecuteResponse::SendingRowsStreaming {
            rows,
            instance_id,
            strategy,
        } => {
            let row_desc = row_desc
                .expect("missing row description for ExecuteResponse::SendingRowsStreaming");

            let span = tracing::debug_span!("sending_rows_streaming");

            self.send_rows(
                row_desc,
                portal_name,
                InProgressRows::new(RecordFirstRowStream::new(
                    Box::new(rows),
                    execute_started,
                    &self.adapter_client,
                    Some(instance_id),
                    Some(strategy),
                )),
                max_rows,
                get_response,
                fetch_portal_name,
                timeout,
            )
            .instrument(span)
            .await
            .map(|(state, _)| state)
        }
        ExecuteResponse::SendingRowsImmediate { rows } => {
            let row_desc = row_desc
                .expect("missing row description for ExecuteResponse::SendingRowsImmediate");

            let span = tracing::debug_span!("sending_rows_immediate");

            // Wrap the already-available rows in a one-element stream so the
            // same streaming path can be reused.
            let stream =
                futures::stream::once(futures::future::ready(PeekResponseUnary::Rows(rows)));
            self.send_rows(
                row_desc,
                portal_name,
                InProgressRows::new(RecordFirstRowStream::new(
                    Box::new(stream),
                    execute_started,
                    &self.adapter_client,
                    None,
                    Some(StatementExecutionStrategy::Constant),
                )),
                max_rows,
                get_response,
                fetch_portal_name,
                timeout,
            )
            .instrument(span)
            .await
            .map(|(state, _)| state)
        }
        ExecuteResponse::SetVariable { name, .. } => {
            // If the variable is one the client is notified about, report its
            // new value with `ParameterStatus` before completing.
            let qn = name.to_string();
            let msg = if let Some(var) = self
                .adapter_client
                .session()
                .vars_mut()
                .notify_set()
                .find(|v| v.name() == qn)
            {
                Some(BackendMessage::ParameterStatus(var.name(), var.value()))
            } else {
                None
            };
            if let Some(msg) = msg {
                self.send(msg).await?;
            }
            command_complete!()
        }
        ExecuteResponse::Subscribing {
            rx,
            ctx_extra,
            instance_id,
        } => {
            // Outside a FETCH, warn that a buffering client may never display
            // the unbounded SUBSCRIBE output; add a hint for psql.
            if fetch_portal_name.is_none() {
                let mut msg = ErrorResponse::notice(
                    SqlState::WARNING,
                    "streaming SUBSCRIBE rows directly requires a client that does not buffer output",
                );
                if self.adapter_client.session().vars().application_name() == "psql" {
                    msg.hint = Some(
                        "Wrap your SUBSCRIBE statement in `COPY (SUBSCRIBE ...) TO STDOUT`."
                            .into(),
                    )
                }
                self.send(msg).await?;
                self.conn.flush().await?;
            }
            let row_desc =
                row_desc.expect("missing row description for ExecuteResponse::Subscribing");
            let (result, statement_ended_execution_reason) = match self
                .send_rows(
                    row_desc,
                    portal_name,
                    InProgressRows::new(RecordFirstRowStream::new(
                        Box::new(UnboundedReceiverStream::new(rx)),
                        execute_started,
                        &self.adapter_client,
                        Some(instance_id),
                        None,
                    )),
                    max_rows,
                    get_response,
                    fetch_portal_name,
                    timeout,
                )
                .await
            {
                // An I/O failure while sending is logged as a cancellation.
                Err(e) => {
                    (Err(e), StatementEndedExecutionReason::Canceled)
                }
                Ok((ok, SendRowsEndedReason::Canceled)) => {
                    (Ok(ok), StatementEndedExecutionReason::Canceled)
                }
                Ok((
                    ok,
                    SendRowsEndedReason::Success {
                        result_size,
                        rows_returned,
                    },
                )) => (
                    Ok(ok),
                    StatementEndedExecutionReason::Success {
                        result_size: Some(result_size),
                        rows_returned: Some(rows_returned),
                        execution_strategy: None,
                    },
                ),
                Ok((ok, SendRowsEndedReason::Errored { error })) => {
                    (Ok(ok), StatementEndedExecutionReason::Errored { error })
                }
            };
            self.adapter_client
                .retire_execute(ctx_extra, statement_ended_execution_reason);
            // Early return: skips the trailing tag assertion.
            return result;
        }
        ExecuteResponse::CopyTo { format, resp } => {
            let row_desc =
                row_desc.expect("missing row description for ExecuteResponse::CopyTo");
            // The wrapped response determines where the COPY rows come from;
            // every arm returns early.
            match *resp {
                ExecuteResponse::Subscribing {
                    rx,
                    ctx_extra,
                    instance_id,
                } => {
                    let (result, statement_ended_execution_reason) = match self
                        .copy_rows(
                            format,
                            row_desc,
                            RecordFirstRowStream::new(
                                Box::new(UnboundedReceiverStream::new(rx)),
                                execute_started,
                                &self.adapter_client,
                                Some(instance_id),
                                None,
                            ),
                        )
                        .await
                    {
                        Err(e) => {
                            (Err(e), StatementEndedExecutionReason::Canceled)
                        }
                        Ok((
                            state,
                            SendRowsEndedReason::Success {
                                result_size,
                                rows_returned,
                            },
                        )) => (
                            Ok(state),
                            StatementEndedExecutionReason::Success {
                                result_size: Some(result_size),
                                rows_returned: Some(rows_returned),
                                execution_strategy: None,
                            },
                        ),
                        Ok((state, SendRowsEndedReason::Errored { error })) => {
                            (Ok(state), StatementEndedExecutionReason::Errored { error })
                        }
                        Ok((state, SendRowsEndedReason::Canceled)) => {
                            (Ok(state), StatementEndedExecutionReason::Canceled)
                        }
                    };
                    self.adapter_client
                        .retire_execute(ctx_extra, statement_ended_execution_reason);
                    return result;
                }
                ExecuteResponse::SendingRowsStreaming {
                    rows,
                    instance_id,
                    strategy,
                } => {
                    return self
                        .copy_rows(
                            format,
                            row_desc,
                            RecordFirstRowStream::new(
                                Box::new(rows),
                                execute_started,
                                &self.adapter_client,
                                Some(instance_id),
                                Some(strategy),
                            ),
                        )
                        .await
                        .map(|(state, _)| state);
                }
                ExecuteResponse::SendingRowsImmediate { rows } => {
                    let span = tracing::debug_span!("sending_rows_immediate");

                    let rows = futures::stream::once(futures::future::ready(
                        PeekResponseUnary::Rows(rows),
                    ));
                    return self
                        .copy_rows(
                            format,
                            row_desc,
                            RecordFirstRowStream::new(
                                Box::new(rows),
                                execute_started,
                                &self.adapter_client,
                                None,
                                Some(StatementExecutionStrategy::Constant),
                            ),
                        )
                        .instrument(span)
                        .await
                        .map(|(state, _)| state);
                }
                _ => {
                    return self
                        .send_error_and_get_state(ErrorResponse::error(
                            SqlState::INTERNAL_ERROR,
                            "unsupported COPY response type".to_string(),
                        ))
                        .await;
                }
            };
        }
        ExecuteResponse::CopyFrom {
            target_id,
            target_name,
            columns,
            params,
            ctx_extra,
        } => {
            let row_desc =
                row_desc.expect("missing row description for ExecuteResponse::CopyFrom");
            self.copy_from(target_id, target_name, columns, params, row_desc, ctx_extra)
                .await
        }
        ExecuteResponse::TransactionCommitted { params }
        | ExecuteResponse::TransactionRolledBack { params } => {
            // Report parameter values changed by the transaction, but only for
            // parameters in the session's notify set.
            let notify_set: mz_ore::collections::HashSet<String> = self
                .adapter_client
                .session()
                .vars()
                .notify_set()
                .map(|v| v.name().to_string())
                .collect();

            for (name, value) in params
                .into_iter()
                .filter(|(name, _v)| notify_set.contains(*name))
            {
                let msg = BackendMessage::ParameterStatus(name, value);
                self.send(msg).await?;
            }
            command_complete!()
        }

        ExecuteResponse::AlteredDefaultPrivileges
        | ExecuteResponse::AlteredObject(..)
        | ExecuteResponse::AlteredRole
        | ExecuteResponse::AlteredSystemConfiguration
        | ExecuteResponse::CreatedCluster { .. }
        | ExecuteResponse::CreatedClusterReplica { .. }
        | ExecuteResponse::CreatedConnection { .. }
        | ExecuteResponse::CreatedDatabase { .. }
        | ExecuteResponse::CreatedIndex { .. }
        | ExecuteResponse::CreatedIntrospectionSubscribe
        | ExecuteResponse::CreatedMaterializedView { .. }
        | ExecuteResponse::CreatedContinualTask { .. }
        | ExecuteResponse::CreatedRole
        | ExecuteResponse::CreatedSchema { .. }
        | ExecuteResponse::CreatedSecret { .. }
        | ExecuteResponse::CreatedSink { .. }
        | ExecuteResponse::CreatedSource { .. }
        | ExecuteResponse::CreatedTable { .. }
        | ExecuteResponse::CreatedType
        | ExecuteResponse::CreatedView { .. }
        | ExecuteResponse::CreatedViews { .. }
        | ExecuteResponse::CreatedNetworkPolicy
        | ExecuteResponse::Comment
        | ExecuteResponse::Deallocate { .. }
        | ExecuteResponse::Deleted(..)
        | ExecuteResponse::DiscardedAll
        | ExecuteResponse::DiscardedTemp
        | ExecuteResponse::DroppedObject(_)
        | ExecuteResponse::DroppedOwned
        | ExecuteResponse::GrantedPrivilege
        | ExecuteResponse::GrantedRole
        | ExecuteResponse::Inserted(..)
        | ExecuteResponse::Copied(..)
        | ExecuteResponse::Prepare
        | ExecuteResponse::Raised
        | ExecuteResponse::ReassignOwned
        | ExecuteResponse::RevokedPrivilege
        | ExecuteResponse::RevokedRole
        | ExecuteResponse::StartedTransaction { .. }
        | ExecuteResponse::Updated(..)
        | ExecuteResponse::ValidatedConnection => {
            command_complete!()
        }
    };

    // Every tag produced by `response.tag()` must have been sent on the paths
    // that reach this point.
    assert_none!(tag, "tag created but not consumed: {:?}", tag);
    r
}
2302
/// Streams rows from `rows` to the client as `DataRow` messages for portal
/// `portal_name`.
///
/// At most `max_rows` rows are sent. If the limit is hit mid-batch, the unsent
/// remainder of that batch is stashed in `rows.current`. On the success path
/// `rows` is then stored back on the portal as `PortalState::InProgress`, so a
/// later Execute can resume. Finally the message built by `get_response` is
/// sent.
///
/// Result formats come from `fetch_portal_name` when given (the portal that
/// issued a `FETCH`), otherwise from `portal_name`.
///
/// `timeout` controls how long to wait for further batches:
/// - `ExecuteTimeout::None`: until the stream ends or the row limit is reached.
/// - `ExecuteTimeout::Seconds(t)`: until `t` seconds after this call started.
/// - `ExecuteTimeout::WaitOnce`: after the first non-empty batch the deadline
///   is set to "now", so the loop stops waiting for more.
///
/// Notices are forwarded to the client while waiting. Stream errors,
/// cancellations and row/description mismatches are sent as
/// `ErrorResponse`s, and the matching `SendRowsEndedReason` is returned
/// alongside the next state. When the stream is exhausted, the
/// first-row-to-last-row duration is recorded in the
/// `result_rows_first_to_last_byte_seconds` metric.
///
/// # Errors
///
/// Returns an `io::Error` if the connection closes or a write fails.
///
/// # Panics
///
/// Panics if `portal_name` (or `fetch_portal_name`, when given) does not name
/// an existing portal.
#[allow(clippy::too_many_arguments)]
#[mz_ore::instrument(level = "debug")]
async fn send_rows(
    &mut self,
    row_desc: RelationDesc,
    portal_name: String,
    mut rows: InProgressRows,
    max_rows: ExecuteCount,
    get_response: GetResponse,
    fetch_portal_name: Option<String>,
    timeout: ExecuteTimeout,
) -> Result<(State, SendRowsEndedReason), io::Error> {
    // Rows are encoded with the formats bound on the FETCH-issuing portal if
    // there is one, otherwise on this portal.
    let result_format_portal_name: &str = if let Some(ref name) = fetch_portal_name {
        name
    } else {
        &portal_name
    };
    let result_formats = self
        .adapter_client
        .session()
        .get_portal_unverified(result_format_portal_name)
        .expect("valid fetch portal name for send rows")
        .result_formats
        .clone();

    let (mut wait_once, mut deadline) = match timeout {
        ExecuteTimeout::None => (false, None),
        ExecuteTimeout::Seconds(t) => (
            false,
            Some(tokio::time::Instant::now() + tokio::time::Duration::from_secs_f64(t)),
        ),
        ExecuteTimeout::WaitOnce => (true, None),
    };

    // Sanity-check that the row description matches the portal(s) involved;
    // mismatches are logged rather than treated as fatal.
    {
        let portal_name_desc = &self
            .adapter_client
            .session()
            .get_portal_unverified(portal_name.as_str())
            .expect("portal should exist")
            .desc
            .relation_desc;
        if let Some(portal_name_desc) = portal_name_desc {
            soft_assert_eq_or_log!(portal_name_desc, &row_desc);
        }
        if let Some(fetch_portal_name) = &fetch_portal_name {
            let fetch_portal_desc = &self
                .adapter_client
                .session()
                .get_portal_unverified(fetch_portal_name)
                .expect("portal should exist")
                .desc
                .relation_desc;
            if let Some(fetch_portal_desc) = fetch_portal_desc {
                soft_assert_eq_or_log!(fetch_portal_desc, &row_desc);
            }
        }
    }

    // Configure the connection's per-column (type, format) pairs used when
    // encoding `DataRow`s.
    self.conn.set_encode_state(
        row_desc
            .typ()
            .column_types
            .iter()
            .map(|ty| mz_pgrepr::Type::from(&ty.scalar_type))
            .zip_eq(result_formats)
            .collect(),
    );

    let mut total_sent_rows = 0;
    let mut total_sent_bytes = 0;
    let mut want_rows = match max_rows {
        ExecuteCount::All => usize::MAX,
        ExecuteCount::Count(count) => count,
    };

    loop {
        // Next item to process: a batch stashed by a previous call, a stop
        // signal once the limit is reached, or whichever of close / deadline /
        // notice / new batch arrives first. `sleep_until` is still constructed
        // when `deadline` is `None` (select! builds disabled branches too),
        // hence the `unwrap_or_else` fallback.
        let batch = if rows.current.is_some() {
            FetchResult::Rows(rows.current.take())
        } else if want_rows == 0 {
            FetchResult::Rows(None)
        } else {
            let notice_fut = self.adapter_client.session().recv_notice();
            tokio::select! {
                err = self.conn.wait_closed() => return Err(err),
                _ = time::sleep_until(
                    deadline.unwrap_or_else(tokio::time::Instant::now),
                ), if deadline.is_some() => FetchResult::Rows(None),
                notice = notice_fut => {
                    FetchResult::Notice(notice)
                }
                batch = rows.remaining.recv() => match batch {
                    None => FetchResult::Rows(None),
                    Some(PeekResponseUnary::Rows(rows)) => FetchResult::Rows(Some(rows)),
                    Some(PeekResponseUnary::Error(err)) => FetchResult::Error(err),
                    Some(PeekResponseUnary::Canceled) => FetchResult::Canceled,
                },
            }
        };

        match batch {
            FetchResult::Rows(None) => break,
            FetchResult::Rows(Some(mut batch_rows)) => {
                if let Err(err) = verify_datum_desc(&row_desc, &mut batch_rows) {
                    let msg = err.to_string();
                    return self
                        .send_error_and_get_state(err.into_response(Severity::Error))
                        .await
                        .map(|state| (state, SendRowsEndedReason::Errored { error: msg }));
                }

                // WaitOnce: the first non-empty batch arms an immediate
                // deadline so no further batches are waited for.
                if wait_once && batch_rows.peek().is_some() {
                    deadline = Some(tokio::time::Instant::now());
                    wait_once = false;
                }

                // `take(want_rows)` stops pulling rows once the limit is
                // reached, so `inspect` only counts rows that are actually sent.
                let mut sent_rows = 0;
                let mut sent_bytes = 0;
                let messages = (&mut batch_rows)
                    .map(|row| {
                        let row_len = row.byte_len();
                        let values = mz_pgrepr::values_from_row(row, row_desc.typ());
                        (row_len, BackendMessage::DataRow(values))
                    })
                    .inspect(|(row_len, _)| {
                        sent_bytes += row_len;
                        sent_rows += 1
                    })
                    .map(|(_row_len, row)| row)
                    .take(want_rows);
                self.send_all(messages).await?;

                total_sent_rows += sent_rows;
                total_sent_bytes += sent_bytes;
                want_rows -= sent_rows;

                // Limit reached: keep any unsent rows of this batch for the
                // next Execute on this portal.
                if want_rows == 0 {
                    if batch_rows.peek().is_some() {
                        rows.current = Some(batch_rows);
                    }
                    break;
                }

                self.conn.flush().await?;
            }
            FetchResult::Notice(notice) => {
                self.send(notice.into_response()).await?;
                self.conn.flush().await?;
            }
            FetchResult::Error(text) => {
                return self
                    .send_error_and_get_state(ErrorResponse::error(
                        SqlState::INTERNAL_ERROR,
                        text.clone(),
                    ))
                    .await
                    .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
            }
            FetchResult::Canceled => {
                return self
                    .send_error_and_get_state(ErrorResponse::error(
                        SqlState::QUERY_CANCELED,
                        "canceling statement due to user request",
                    ))
                    .await
                    .map(|state| (state, SendRowsEndedReason::Canceled));
            }
        }
    }

    let portal = self
        .adapter_client
        .session()
        .get_portal_unverified_mut(&portal_name)
        .expect("valid portal name for send rows");

    // Capture what the metrics below need before `rows` moves into the portal.
    let saw_rows = rows.remaining.saw_rows;
    let no_more_rows = rows.no_more_rows();
    let recorded_first_row_instant = rows.remaining.recorded_first_row_instant;

    // Park the (possibly partially consumed) rows on the portal so that a
    // later Execute can resume from here.
    *portal.state = PortalState::InProgress(Some(rows));

    let fetch_portal = fetch_portal_name.map(|name| {
        self.adapter_client
            .session()
            .get_portal_unverified_mut(&name)
            .expect("valid fetch portal")
    });
    let response_message = get_response(max_rows, total_sent_rows, fetch_portal);
    self.send(response_message).await?;

    // Once the stream is exhausted, record the first-row-to-last-row latency,
    // labelled by statement type (zero if no rows were seen).
    if no_more_rows {
        let statement_type = if let Some(stmt) = &self
            .adapter_client
            .session()
            .get_portal_unverified(&portal_name)
            .expect("valid portal name for send_rows")
            .stmt
        {
            metrics::statement_type_label_value(stmt.deref())
        } else {
            "no-statement"
        };
        let duration = if saw_rows {
            recorded_first_row_instant
                .expect("recorded_first_row_instant because saw_rows")
                .elapsed()
        } else {
            Duration::ZERO
        };
        self.adapter_client
            .inner()
            .metrics()
            .result_rows_first_to_last_byte_seconds
            .with_label_values(&[statement_type])
            .observe(duration.as_secs_f64());
    }

    Ok((
        State::Ready,
        SendRowsEndedReason::Success {
            result_size: u64::cast_from(total_sent_bytes),
            rows_returned: u64::cast_from(total_sent_rows),
        },
    ))
}
2553
2554 #[mz_ore::instrument(level = "debug")]
2555 async fn copy_rows(
2556 &mut self,
2557 format: CopyFormat,
2558 row_desc: RelationDesc,
2559 mut stream: RecordFirstRowStream,
2560 ) -> Result<(State, SendRowsEndedReason), io::Error> {
2561 let (row_format, encode_format) = match format {
2562 CopyFormat::Text => (
2563 CopyFormatParams::Text(CopyTextFormatParams::default()),
2564 Format::Text,
2565 ),
2566 CopyFormat::Binary => (CopyFormatParams::Binary, Format::Binary),
2567 CopyFormat::Csv => (
2568 CopyFormatParams::Csv(CopyCsvFormatParams::default()),
2569 Format::Text,
2570 ),
2571 CopyFormat::Parquet => {
2572 let text = "Parquet format is not supported".to_string();
2573 return self
2574 .send_error_and_get_state(ErrorResponse::error(
2575 SqlState::INTERNAL_ERROR,
2576 text.clone(),
2577 ))
2578 .await
2579 .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
2580 }
2581 };
2582
2583 let encode_fn = |row: &RowRef, typ: &SqlRelationType, out: &mut Vec<u8>| {
2584 mz_pgcopy::encode_copy_format(&row_format, row, typ, out)
2585 };
2586
2587 let typ = row_desc.typ();
2588 let column_formats = iter::repeat(encode_format)
2589 .take(typ.column_types.len())
2590 .collect();
2591 self.send(BackendMessage::CopyOutResponse {
2592 overall_format: encode_format,
2593 column_formats,
2594 })
2595 .await?;
2596
2597 let mut out = Vec::new();
2602
2603 if let CopyFormat::Binary = format {
2604 out.extend(b"PGCOPY\n\xFF\r\n\0");
2606 out.extend([0, 0, 0, 0]);
2608 out.extend([0, 0, 0, 0]);
2610 }
2611
2612 let mut count = 0;
2613 let mut total_sent_bytes = 0;
2614 loop {
2615 tokio::select! {
2616 e = self.conn.wait_closed() => return Err(e),
2617 batch = stream.recv() => match batch {
2618 None => break,
2619 Some(PeekResponseUnary::Error(text)) => {
2620 let err =
2621 ErrorResponse::error(SqlState::INTERNAL_ERROR, text.clone());
2622 return self
2623 .send_error_and_get_state(err)
2624 .await
2625 .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
2626 }
2627 Some(PeekResponseUnary::Canceled) => {
2628 return self.send_error_and_get_state(ErrorResponse::error(
2629 SqlState::QUERY_CANCELED,
2630 "canceling statement due to user request",
2631 ))
2632 .await.map(|state| (state, SendRowsEndedReason::Canceled));
2633 }
2634 Some(PeekResponseUnary::Rows(mut rows)) => {
2635 count += rows.count();
2636 while let Some(row) = rows.next() {
2637 total_sent_bytes += row.byte_len();
2638 encode_fn(row, typ, &mut out)?;
2639 self.send(BackendMessage::CopyData(mem::take(&mut out)))
2640 .await?;
2641 }
2642 }
2643 },
2644 notice = self.adapter_client.session().recv_notice() => {
2645 self.send(notice.into_response())
2646 .await?;
2647 self.conn.flush().await?;
2648 }
2649 }
2650
2651 self.conn.flush().await?;
2652 }
2653 if let CopyFormat::Binary = format {
2655 let trailer: i16 = -1;
2656 out.extend(trailer.to_be_bytes());
2657 self.send(BackendMessage::CopyData(mem::take(&mut out)))
2658 .await?;
2659 }
2660
2661 let tag = format!("COPY {}", count);
2662 self.send(BackendMessage::CopyDone).await?;
2663 self.send(BackendMessage::CommandComplete { tag }).await?;
2664 Ok((
2665 State::Ready,
2666 SendRowsEndedReason::Success {
2667 result_size: u64::cast_from(total_sent_bytes),
2668 rows_returned: u64::cast_from(count),
2669 },
2670 ))
2671 }
2672
2673 #[instrument(level = "debug")]
2676 async fn copy_from(
2677 &mut self,
2678 target_id: CatalogItemId,
2679 target_name: String,
2680 columns: Vec<ColumnIndex>,
2681 params: CopyFormatParams<'static>,
2682 row_desc: RelationDesc,
2683 mut ctx_extra: ExecuteContextGuard,
2684 ) -> Result<State, io::Error> {
2685 let res = self
2686 .copy_from_inner(
2687 target_id,
2688 target_name,
2689 columns,
2690 params,
2691 row_desc,
2692 &mut ctx_extra,
2693 )
2694 .await;
2695 match &res {
2696 Ok(State::Ready) => {
2697 self.adapter_client.retire_execute(
2698 ctx_extra,
2699 StatementEndedExecutionReason::Success {
2700 result_size: None,
2701 rows_returned: None,
2702 execution_strategy: None,
2703 },
2704 );
2705 }
2706 Ok(State::Done) => {
2707 self.adapter_client
2711 .retire_execute(ctx_extra, StatementEndedExecutionReason::Canceled);
2712 }
2713 Err(e) => {
2714 self.adapter_client.retire_execute(
2715 ctx_extra,
2716 StatementEndedExecutionReason::Errored {
2717 error: format!("{e}"),
2718 },
2719 );
2720 }
2721 Ok(state) if matches!(state, State::Drain) => {}
2722 other => {
2723 mz_ore::soft_panic_or_log!(
2724 "unexpected COPY FROM state in copy_from: {other:?}; \
2725 relying on ExecuteContextGuard::drop to retire as Aborted"
2726 );
2727 }
2728 }
2729 res
2730 }
2731
2732 async fn copy_from_inner(
2733 &mut self,
2734 target_id: CatalogItemId,
2735 target_name: String,
2736 columns: Vec<ColumnIndex>,
2737 params: CopyFormatParams<'static>,
2738 row_desc: RelationDesc,
2739 ctx_extra: &mut ExecuteContextGuard,
2740 ) -> Result<State, io::Error> {
2741 let typ = row_desc.typ();
2742 let column_formats = vec![Format::Text; typ.column_types.len()];
2743 self.send(BackendMessage::CopyInResponse {
2744 overall_format: Format::Text,
2745 column_formats,
2746 })
2747 .await?;
2748 self.conn.flush().await?;
2749
2750 let writer = match self
2752 .adapter_client
2753 .start_copy_from_stdin(
2754 target_id,
2755 target_name.clone(),
2756 columns.clone(),
2757 row_desc.clone(),
2758 params.clone(),
2759 )
2760 .await
2761 {
2762 Ok(writer) => writer,
2763 Err(e) => {
2764 loop {
2770 match self.conn.recv().await? {
2771 Some(FrontendMessage::CopyData(_)) => {}
2772 Some(FrontendMessage::CopyDone) | Some(FrontendMessage::CopyFail(_)) => {
2773 break;
2774 }
2775 Some(FrontendMessage::Flush) | Some(FrontendMessage::Sync) => {}
2776 Some(_) => break,
2777 None => return Ok(State::Done),
2778 }
2779 }
2780 self.adapter_client.retire_execute(
2781 std::mem::take(ctx_extra),
2782 StatementEndedExecutionReason::Errored {
2783 error: e.to_string(),
2784 },
2785 );
2786 return self
2787 .send_error_and_get_state(e.into_response(Severity::Error))
2788 .await;
2789 }
2790 };
2791
2792 self.conn.set_copy_mode(true);
2794
2795 const BATCH_SIZE: usize = 32 * 1024 * 1024;
2797 let max_copy_from_row_size = self
2798 .adapter_client
2799 .get_system_vars()
2800 .await
2801 .max_copy_from_row_size()
2802 .try_into()
2803 .unwrap_or(usize::MAX);
2804
2805 let mut data = Vec::new();
2806 let mut row_scanner = CopyRowScanner::new(¶ms);
2807 let num_workers = writer.batch_txs.len();
2808 let mut next_worker: usize = 0;
2809 let mut saw_copy_done = false;
2810 let mut saw_end_marker = false;
2811 let mut copy_from_error: Option<(SqlState, String)> = None;
2812
2813 loop {
2816 let message = self.conn.recv().await?;
2817 match message {
2818 Some(FrontendMessage::CopyData(buf)) => {
2819 if saw_end_marker {
2820 continue;
2823 }
2824 data.extend(buf);
2825 row_scanner.scan_new_bytes(&data);
2826
2827 if let Some(end_pos) = row_scanner.end_marker_end() {
2828 data.truncate(end_pos);
2829 row_scanner.on_truncate(end_pos);
2830 saw_end_marker = true;
2831 }
2832
2833 if row_scanner.current_row_size(data.len()) > max_copy_from_row_size {
2835 copy_from_error = Some((
2836 SqlState::INSUFFICIENT_RESOURCES,
2837 format!(
2838 "COPY FROM STDIN row exceeded max_copy_from_row_size \
2839 ({max_copy_from_row_size} bytes)"
2840 ),
2841 ));
2842 break;
2843 }
2844
2845 let mut send_failed = false;
2848 while data.len() >= BATCH_SIZE {
2849 let split_pos = match row_scanner.last_row_end() {
2850 Some(pos) => pos,
2851 None => break, };
2853 let remainder = data.split_off(split_pos);
2854 let chunk = std::mem::replace(&mut data, remainder);
2855 row_scanner.on_split(split_pos);
2856 if writer.batch_txs[next_worker].send(chunk).await.is_err() {
2857 send_failed = true;
2858 break;
2859 }
2860 next_worker = (next_worker + 1) % num_workers;
2861 }
2862 if send_failed {
2865 break;
2866 }
2867 }
2868 Some(FrontendMessage::CopyDone) => {
2869 if !data.is_empty() {
2871 let chunk = std::mem::take(&mut data);
2872 let _ = writer.batch_txs[next_worker].send(chunk).await;
2874 }
2875 saw_copy_done = true;
2876 break;
2877 }
2878 Some(FrontendMessage::CopyFail(err)) => {
2879 self.adapter_client.retire_execute(
2880 std::mem::take(ctx_extra),
2881 StatementEndedExecutionReason::Canceled,
2882 );
2883 drop(writer);
2885 self.conn.set_copy_mode(false);
2886 return self
2887 .send_error_and_get_state(ErrorResponse::error(
2888 SqlState::QUERY_CANCELED,
2889 format!("COPY from stdin failed: {}", err),
2890 ))
2891 .await;
2892 }
2893 Some(FrontendMessage::Flush) | Some(FrontendMessage::Sync) => {}
2894 Some(_) => {
2895 let msg = "unexpected message type during COPY from stdin";
2896 self.adapter_client.retire_execute(
2897 std::mem::take(ctx_extra),
2898 StatementEndedExecutionReason::Errored {
2899 error: msg.to_string(),
2900 },
2901 );
2902 drop(writer);
2903 self.conn.set_copy_mode(false);
2904 return self
2905 .send_error_and_get_state(ErrorResponse::error(
2906 SqlState::PROTOCOL_VIOLATION,
2907 msg,
2908 ))
2909 .await;
2910 }
2911 None => {
2912 drop(writer);
2913 self.conn.set_copy_mode(false);
2914 return Ok(State::Done);
2915 }
2916 }
2917 }
2918
2919 if !saw_copy_done {
2923 loop {
2924 match self.conn.recv().await? {
2925 Some(FrontendMessage::CopyData(_)) => {}
2926 Some(FrontendMessage::CopyDone) | Some(FrontendMessage::CopyFail(_)) => {
2927 break;
2928 }
2929 Some(FrontendMessage::Flush) | Some(FrontendMessage::Sync) => {}
2930 Some(_) => {
2931 let msg = "unexpected message type during COPY from stdin";
2932 self.adapter_client.retire_execute(
2933 std::mem::take(ctx_extra),
2934 StatementEndedExecutionReason::Errored {
2935 error: msg.to_string(),
2936 },
2937 );
2938 drop(writer);
2939 self.conn.set_copy_mode(false);
2940 return self
2941 .send_error_and_get_state(ErrorResponse::error(
2942 SqlState::PROTOCOL_VIOLATION,
2943 msg,
2944 ))
2945 .await;
2946 }
2947 None => {
2948 drop(writer);
2949 self.conn.set_copy_mode(false);
2950 return Ok(State::Done);
2951 }
2952 }
2953 }
2954 }
2955
2956 if let Some((code, msg)) = copy_from_error {
2957 self.adapter_client.retire_execute(
2958 std::mem::take(ctx_extra),
2959 StatementEndedExecutionReason::Errored { error: msg.clone() },
2960 );
2961 drop(writer);
2962 self.conn.set_copy_mode(false);
2963 return self
2964 .send_error_and_get_state(ErrorResponse::error(code, msg))
2965 .await;
2966 }
2967
2968 self.conn.set_copy_mode(false);
2969
2970 drop(writer.batch_txs);
2975
2976 let (proto_batches, row_count) = match writer.completion_rx.await {
2978 Ok(Ok(result)) => result,
2979 Ok(Err(e)) => {
2980 self.adapter_client.retire_execute(
2981 std::mem::take(ctx_extra),
2982 StatementEndedExecutionReason::Errored {
2983 error: e.to_string(),
2984 },
2985 );
2986 return self
2987 .send_error_and_get_state(e.into_response(Severity::Error))
2988 .await;
2989 }
2990 Err(_) => {
2991 let msg = "COPY FROM STDIN: background batch builder tasks dropped";
2992 self.adapter_client.retire_execute(
2993 std::mem::take(ctx_extra),
2994 StatementEndedExecutionReason::Errored {
2995 error: msg.to_string(),
2996 },
2997 );
2998 return self
2999 .send_error_and_get_state(ErrorResponse::error(SqlState::INTERNAL_ERROR, msg))
3000 .await;
3001 }
3002 };
3003
3004 if let Err(e) = self
3006 .adapter_client
3007 .stage_copy_from_stdin_batches(target_id, proto_batches)
3008 {
3009 self.adapter_client.retire_execute(
3010 std::mem::take(ctx_extra),
3011 StatementEndedExecutionReason::Errored {
3012 error: e.to_string(),
3013 },
3014 );
3015 return self
3016 .send_error_and_get_state(e.into_response(Severity::Error))
3017 .await;
3018 }
3019
3020 let tag = format!("COPY {}", row_count);
3021 self.send(BackendMessage::CommandComplete { tag }).await?;
3022
3023 Ok(State::Ready)
3024 }
3025
3026 #[instrument(level = "debug")]
3027 async fn send_pending_notices(&mut self) -> Result<(), io::Error> {
3028 let notices = self
3029 .adapter_client
3030 .session()
3031 .drain_notices()
3032 .into_iter()
3033 .map(|notice| BackendMessage::ErrorResponse(notice.into_response()));
3034 self.send_all(notices).await?;
3035 Ok(())
3036 }
3037
3038 #[instrument(level = "debug")]
3039 async fn send_error_and_get_state(&mut self, err: ErrorResponse) -> Result<State, io::Error> {
3040 assert!(err.severity.is_error());
3041 debug!(
3042 "cid={} error code={}",
3043 self.adapter_client.session().conn_id(),
3044 err.code.code()
3045 );
3046 let is_fatal = err.severity.is_fatal();
3047 self.send(BackendMessage::ErrorResponse(err)).await?;
3048
3049 let txn = self.adapter_client.session().transaction();
3050 match txn {
3051 TransactionStatus::Default | TransactionStatus::Failed(_) => {}
3054 TransactionStatus::Started(_) => {
3056 self.rollback_transaction().await?;
3057 }
3058 TransactionStatus::InTransactionImplicit(_) => {
3060 self.rollback_transaction().await?;
3061 }
3062 TransactionStatus::InTransaction(_) => {
3064 self.adapter_client.fail_transaction();
3065 }
3066 };
3067 if is_fatal {
3068 Ok(State::Done)
3069 } else {
3070 Ok(State::Drain)
3071 }
3072 }
3073
3074 #[instrument(level = "debug")]
3075 async fn aborted_txn_error(&mut self) -> Result<State, io::Error> {
3076 self.send(BackendMessage::ErrorResponse(ErrorResponse::error(
3077 SqlState::IN_FAILED_SQL_TRANSACTION,
3078 ABORTED_TXN_MSG,
3079 )))
3080 .await?;
3081 Ok(State::Drain)
3082 }
3083
3084 fn is_aborted_txn(&mut self) -> bool {
3085 matches!(
3086 self.adapter_client.session().transaction(),
3087 TransactionStatus::Failed(_)
3088 )
3089 }
3090}
3091
3092fn pad_formats(formats: Vec<Format>, n: usize) -> Result<Vec<Format>, String> {
3093 match (formats.len(), n) {
3094 (0, e) => Ok(vec![Format::Text; e]),
3095 (1, e) => Ok(iter::repeat(formats[0]).take(e).collect()),
3096 (a, e) if a == e => Ok(formats),
3097 (a, e) => Err(format!(
3098 "expected {} field format specifiers, but got {}",
3099 e, a
3100 )),
3101 }
3102}
3103
3104fn describe_rows(stmt_desc: &StatementDesc, formats: &[Format]) -> BackendMessage {
3105 match &stmt_desc.relation_desc {
3106 Some(desc) if !stmt_desc.is_copy => {
3107 BackendMessage::RowDescription(message::encode_row_description(desc, formats))
3108 }
3109 _ => BackendMessage::NoData,
3110 }
3111}
3112
3113type GetResponse = fn(
3114 max_rows: ExecuteCount,
3115 total_sent_rows: usize,
3116 fetch_portal: Option<PortalRefMut>,
3117) -> BackendMessage;
3118
3119fn portal_exec_message(
3122 max_rows: ExecuteCount,
3123 total_sent_rows: usize,
3124 _fetch_portal: Option<PortalRefMut>,
3125) -> BackendMessage {
3126 match max_rows {
3133 ExecuteCount::Count(max_rows) if max_rows <= total_sent_rows => {
3134 BackendMessage::PortalSuspended
3135 }
3136 _ => BackendMessage::CommandComplete {
3137 tag: format!("SELECT {}", total_sent_rows),
3138 },
3139 }
3140}
3141
3142fn fetch_message(
3144 _max_rows: ExecuteCount,
3145 total_sent_rows: usize,
3146 fetch_portal: Option<PortalRefMut>,
3147) -> BackendMessage {
3148 let tag = format!("FETCH {}", total_sent_rows);
3149 if let Some(portal) = fetch_portal {
3150 *portal.state = PortalState::Completed(Some(tag.clone()));
3151 }
3152 BackendMessage::CommandComplete { tag }
3153}
3154
3155fn get_authenticator(
3156 authenticator_kind: AuthenticatorKind,
3157 frontegg: Option<FronteggAuthenticator>,
3158 oidc: GenericOidcAuthenticator,
3159 adapter_client: mz_adapter::Client,
3160 oidc_auth_option_enabled: bool,
3163) -> Authenticator {
3164 match authenticator_kind {
3165 AuthenticatorKind::Frontegg => Authenticator::Frontegg(
3166 frontegg.expect("Frontegg authenticator should exist with AuthenticatorKind::Frontegg"),
3167 ),
3168 AuthenticatorKind::Password => Authenticator::Password(adapter_client),
3169 AuthenticatorKind::Sasl => Authenticator::Sasl(adapter_client),
3170 AuthenticatorKind::Oidc => {
3171 if oidc_auth_option_enabled {
3172 Authenticator::Oidc(oidc)
3173 } else {
3174 Authenticator::Password(adapter_client)
3177 }
3178 }
3179 AuthenticatorKind::None => Authenticator::None,
3180 }
3181}
3182
3183#[derive(Debug, Copy, Clone)]
3184enum ExecuteCount {
3185 All,
3186 Count(usize),
3187}
3188
3189fn is_txn_exit_stmt(stmt: Option<&Statement<Raw>>) -> bool {
3191 match stmt {
3192 Some(stmt) => matches!(stmt, Statement::Commit(_) | Statement::Rollback(_)),
3194 None => false,
3195 }
3196}
3197
3198#[derive(Debug)]
3199enum FetchResult {
3200 Rows(Option<Box<dyn RowIterator + Send + Sync>>),
3201 Canceled,
3202 Error(String),
3203 Notice(AdapterNotice),
3204}
3205
3206#[derive(Debug)]
3207struct CopyRowScanner {
3208 scan_pos: usize,
3209 last_row_end: Option<usize>,
3210 end_marker_end: Option<usize>,
3211 csv: Option<CsvScanState>,
3212}
3213
3214#[derive(Debug)]
3215struct CsvScanState {
3216 reader: csv_core::Reader,
3217 output: Vec<u8>,
3218 ends: Vec<usize>,
3219 record: Vec<u8>,
3220 record_ends: Vec<usize>,
3221 skip_first_record: bool,
3222}
3223
3224impl CopyRowScanner {
3225 fn new(params: &CopyFormatParams<'_>) -> Self {
3226 let csv = match params {
3227 CopyFormatParams::Csv(CopyCsvFormatParams {
3228 delimiter,
3229 quote,
3230 escape,
3231 header,
3232 ..
3233 }) => Some(CsvScanState::new(*delimiter, *quote, *escape, *header)),
3234 _ => None,
3235 };
3236
3237 CopyRowScanner {
3238 scan_pos: 0,
3239 last_row_end: None,
3240 end_marker_end: None,
3241 csv,
3242 }
3243 }
3244
3245 fn scan_new_bytes(&mut self, data: &[u8]) {
3246 if self.scan_pos >= data.len() {
3247 return;
3248 }
3249
3250 if let Some(csv) = self.csv.as_mut() {
3251 let mut input = &data[self.scan_pos..];
3252 let mut consumed = 0usize;
3253 while !input.is_empty() {
3254 let (result, n_input, n_output, n_ends) =
3255 csv.reader
3256 .read_record(input, &mut csv.output, &mut csv.ends);
3257 consumed += n_input;
3258 input = &input[n_input..];
3259 if !csv.output.is_empty() {
3260 csv.record.extend_from_slice(&csv.output[..n_output]);
3261 }
3262 if !csv.ends.is_empty() {
3263 csv.record_ends.extend_from_slice(&csv.ends[..n_ends]);
3264 }
3265
3266 match result {
3267 ReadRecordResult::InputEmpty => break,
3268 ReadRecordResult::OutputFull => {
3269 if n_input == 0 {
3270 csv.output
3271 .resize(csv.output.len().saturating_mul(2).max(1), 0);
3272 }
3273 }
3274 ReadRecordResult::OutputEndsFull => {
3275 if n_input == 0 {
3276 csv.ends.resize(csv.ends.len().saturating_mul(2).max(1), 0);
3277 }
3278 }
3279 ReadRecordResult::Record | ReadRecordResult::End => {
3280 let row_end = self.scan_pos + consumed;
3281 self.last_row_end = Some(row_end);
3282 if self.end_marker_end.is_none() {
3283 let is_marker = if csv.skip_first_record {
3284 csv.skip_first_record = false;
3285 false
3286 } else if csv.record_ends.len() == 1 {
3287 let end = csv.record_ends[0];
3288 end == 2 && csv.record.get(0..end) == Some(b"\\.")
3289 } else {
3290 false
3291 };
3292 if is_marker {
3293 self.end_marker_end = Some(row_end);
3294 break;
3295 }
3296 }
3297 csv.record.clear();
3298 csv.record_ends.clear();
3299 }
3300 }
3301 }
3302 } else {
3303 let mut row_start = self.last_row_end.unwrap_or(0);
3304 for (offset, b) in data[self.scan_pos..].iter().enumerate() {
3305 if *b == b'\n' {
3306 let row_end = self.scan_pos + offset + 1;
3307 self.last_row_end = Some(row_end);
3308 if self.end_marker_end.is_none() {
3309 let row = &data[row_start..row_end];
3310 if row.get(0..2) == Some(b"\\.") {
3311 self.end_marker_end = Some(row_end);
3312 break;
3313 }
3314 }
3315 row_start = row_end;
3316 }
3317 }
3318 }
3319
3320 self.scan_pos = data.len();
3321 }
3322
3323 fn last_row_end(&self) -> Option<usize> {
3324 self.last_row_end
3325 }
3326
3327 fn end_marker_end(&self) -> Option<usize> {
3328 self.end_marker_end
3329 }
3330
3331 fn current_row_size(&self, data_len: usize) -> usize {
3332 data_len.saturating_sub(self.last_row_end.unwrap_or(0))
3333 }
3334
3335 fn on_split(&mut self, split_pos: usize) {
3336 self.scan_pos = self.scan_pos.saturating_sub(split_pos);
3337 self.last_row_end = None;
3338 self.end_marker_end = self
3339 .end_marker_end
3340 .and_then(|end| end.checked_sub(split_pos));
3341 }
3342
3343 fn on_truncate(&mut self, new_len: usize) {
3344 self.scan_pos = self.scan_pos.min(new_len);
3345 self.last_row_end = self.last_row_end.filter(|&end| end <= new_len);
3346 self.end_marker_end = self.end_marker_end.filter(|&end| end <= new_len);
3347 }
3348}
3349
3350impl CsvScanState {
3351 fn new(delimiter: u8, quote: u8, escape: u8, header: bool) -> Self {
3352 let (double_quote, escape) = if quote == escape {
3353 (true, None)
3354 } else {
3355 (false, Some(escape))
3356 };
3357 CsvScanState {
3358 reader: csv_core::ReaderBuilder::new()
3359 .delimiter(delimiter)
3360 .quote(quote)
3361 .double_quote(double_quote)
3362 .escape(escape)
3363 .build(),
3364 output: vec![0; 1],
3365 ends: vec![0; 1],
3366 record: Vec::new(),
3367 record_ends: Vec::new(),
3368 skip_first_record: header,
3369 }
3370 }
3371}
3372
3373#[cfg(test)]
3374mod test {
3375 use super::*;
3376
3377 #[mz_ore::test]
3378 fn test_parse_options() {
3379 struct TestCase {
3380 input: &'static str,
3381 expect: Result<Vec<(&'static str, &'static str)>, ()>,
3382 }
3383 let tests = vec![
3384 TestCase {
3385 input: "",
3386 expect: Ok(vec![]),
3387 },
3388 TestCase {
3389 input: "--key",
3390 expect: Err(()),
3391 },
3392 TestCase {
3393 input: "--key=val",
3394 expect: Ok(vec![("key", "val")]),
3395 },
3396 TestCase {
3397 input: r#"--key=val -ckey2=val2 -c key3=val3 -c key4=val4 -ckey5=val5"#,
3398 expect: Ok(vec![
3399 ("key", "val"),
3400 ("key2", "val2"),
3401 ("key3", "val3"),
3402 ("key4", "val4"),
3403 ("key5", "val5"),
3404 ]),
3405 },
3406 TestCase {
3407 input: r#"-c\ key=val"#,
3408 expect: Ok(vec![(" key", "val")]),
3409 },
3410 TestCase {
3411 input: "--key=val -ckey2 val2",
3412 expect: Err(()),
3413 },
3414 TestCase {
3416 input: "--key=",
3417 expect: Ok(vec![("key", "")]),
3418 },
3419 ];
3420 for test in tests {
3421 let got = parse_options(test.input);
3422 let expect = test.expect.map(|r| {
3423 r.into_iter()
3424 .map(|(k, v)| (k.to_owned(), v.to_owned()))
3425 .collect()
3426 });
3427 assert_eq!(got, expect, "input: {}", test.input);
3428 }
3429 }
3430
3431 #[mz_ore::test]
3432 fn test_parse_option() {
3433 struct TestCase {
3434 input: &'static str,
3435 expect: Result<(&'static str, &'static str), ()>,
3436 }
3437 let tests = vec![
3438 TestCase {
3439 input: "",
3440 expect: Err(()),
3441 },
3442 TestCase {
3443 input: "--",
3444 expect: Err(()),
3445 },
3446 TestCase {
3447 input: "--c",
3448 expect: Err(()),
3449 },
3450 TestCase {
3451 input: "a=b",
3452 expect: Err(()),
3453 },
3454 TestCase {
3455 input: "--a=b",
3456 expect: Ok(("a", "b")),
3457 },
3458 TestCase {
3459 input: "--ca=b",
3460 expect: Ok(("ca", "b")),
3461 },
3462 TestCase {
3463 input: "-ca=b",
3464 expect: Ok(("a", "b")),
3465 },
3466 TestCase {
3468 input: "--=",
3469 expect: Ok(("", "")),
3470 },
3471 ];
3472 for test in tests {
3473 let got = parse_option(test.input);
3474 assert_eq!(got, test.expect, "input: {}", test.input);
3475 }
3476 }
3477
3478 #[mz_ore::test]
3479 fn test_split_options() {
3480 struct TestCase {
3481 input: &'static str,
3482 expect: Vec<&'static str>,
3483 }
3484 let tests = vec![
3485 TestCase {
3486 input: "",
3487 expect: vec![],
3488 },
3489 TestCase {
3490 input: " ",
3491 expect: vec![],
3492 },
3493 TestCase {
3494 input: " a ",
3495 expect: vec!["a"],
3496 },
3497 TestCase {
3498 input: " ab cd ",
3499 expect: vec!["ab", "cd"],
3500 },
3501 TestCase {
3502 input: r#" ab\ cd "#,
3503 expect: vec!["ab ", "cd"],
3504 },
3505 TestCase {
3506 input: r#" ab\\ cd "#,
3507 expect: vec![r#"ab\"#, "cd"],
3508 },
3509 TestCase {
3510 input: r#" ab\\\ cd "#,
3511 expect: vec![r#"ab\ "#, "cd"],
3512 },
3513 TestCase {
3514 input: r#" ab\\\ cd "#,
3515 expect: vec![r#"ab\ cd"#],
3516 },
3517 TestCase {
3518 input: r#" ab\\\cd "#,
3519 expect: vec![r#"ab\cd"#],
3520 },
3521 TestCase {
3522 input: r#"a\"#,
3523 expect: vec!["a"],
3524 },
3525 TestCase {
3526 input: r#"a\ "#,
3527 expect: vec!["a "],
3528 },
3529 TestCase {
3530 input: r#"\"#,
3531 expect: vec![],
3532 },
3533 TestCase {
3534 input: r#"\ "#,
3535 expect: vec![r#" "#],
3536 },
3537 TestCase {
3538 input: r#" \ "#,
3539 expect: vec![r#" "#],
3540 },
3541 TestCase {
3542 input: r#"\ "#,
3543 expect: vec![r#" "#],
3544 },
3545 ];
3546 for test in tests {
3547 let got = split_options(test.input);
3548 assert_eq!(got, test.expect, "input: {}", test.input);
3549 }
3550 }
3551}