1use std::collections::BTreeMap;
11use std::convert::TryFrom;
12use std::future::Future;
13use std::ops::Deref;
14use std::sync::Arc;
15use std::time::{Duration, Instant};
16use std::{iter, mem};
17
18use base64::prelude::*;
19use byteorder::{ByteOrder, NetworkEndian};
20use csv_core::ReadRecordResult;
21use futures::future::{BoxFuture, FutureExt, pending};
22use itertools::Itertools;
23use mz_adapter::client::RecordFirstRowStream;
24use mz_adapter::session::{
25 EndTransactionAction, InProgressRows, LifecycleTimestamps, PortalRefMut, PortalState, Session,
26 SessionConfig, TransactionStatus,
27};
28use mz_adapter::statement_logging::{StatementEndedExecutionReason, StatementExecutionStrategy};
29use mz_adapter::{
30 AdapterError, AdapterNotice, ExecuteContextGuard, ExecuteResponse, PeekResponseUnary, metrics,
31 verify_datum_desc,
32};
33use mz_auth::Authenticated;
34use mz_auth::password::Password;
35use mz_authenticator::{Authenticator, GenericOidcAuthenticator};
36use mz_frontegg_auth::Authenticator as FronteggAuthenticator;
37use mz_ore::cast::CastFrom;
38use mz_ore::netio::AsyncReady;
39use mz_ore::now::{EpochMillis, SYSTEM_TIME};
40use mz_ore::str::StrExt;
41use mz_ore::{assert_none, assert_ok, instrument, soft_assert_eq_or_log};
42use mz_pgcopy::{CopyCsvFormatParams, CopyFormatParams, CopyTextFormatParams};
43use mz_pgwire_common::{
44 ConnectionCounter, Cursor, ErrorResponse, Format, FrontendMessage, Severity, VERSION_3,
45 VERSIONS,
46};
47use mz_repr::{
48 CatalogItemId, ColumnIndex, Datum, RelationDesc, RowArena, RowIterator, RowRef,
49 SqlRelationType, SqlScalarType,
50};
51use mz_server_core::TlsMode;
52use mz_server_core::listeners;
53use mz_server_core::listeners::AllowedRoles;
54use mz_sql::ast::display::AstDisplay;
55use mz_sql::ast::{
56 CopyDirection, CopyStatement, CopyTarget, FetchDirection, Ident, Raw, Statement,
57};
58use mz_sql::parse::StatementParseResult;
59use mz_sql::plan::{CopyFormat, ExecuteTimeout, StatementDesc};
60use mz_sql::session::metadata::SessionMetadata;
61use mz_sql::session::user::INTERNAL_USER_NAMES;
62use mz_sql::session::vars::VarInput;
63use postgres::error::SqlState;
64use tokio::io::{self, AsyncRead, AsyncWrite};
65use tokio::select;
66use tokio::time::{self};
67use tokio_metrics::TaskMetrics;
68use tokio_stream::wrappers::UnboundedReceiverStream;
69use tracing::{Instrument, debug, debug_span, warn};
70use uuid::Uuid;
71
72use crate::codec::{
73 FramedConn, decode_password, decode_sasl_initial_response, decode_sasl_response,
74};
75use crate::message::{
76 self, BackendMessage, SASLServerFinalMessage, SASLServerFinalMessageKinds,
77 SASLServerFirstMessage,
78};
79
80pub fn match_handshake(buf: &[u8]) -> bool {
84 if buf.len() < 8 {
94 return false;
95 }
96 let version = NetworkEndian::read_i32(&buf[4..8]);
97 VERSIONS.contains(&version)
98}
99
100pub struct RunParams<'a, A, I>
102where
103 I: Iterator<Item = TaskMetrics> + Send,
104{
105 pub tls_mode: Option<TlsMode>,
107 pub adapter_client: mz_adapter::Client,
109 pub conn: &'a mut FramedConn<A>,
111 pub conn_uuid: Uuid,
113 pub version: i32,
115 pub params: BTreeMap<String, String>,
117 pub frontegg: Option<FronteggAuthenticator>,
119 pub oidc: GenericOidcAuthenticator,
121 pub authenticator_kind: listeners::AuthenticatorKind,
124 pub active_connection_counter: ConnectionCounter,
126 pub helm_chart_version: Option<String>,
128 pub allowed_roles: AllowedRoles,
130 pub tokio_metrics_intervals: I,
132}
133
/// Serves a pgwire connection whose startup parameters (`version`, `params`)
/// have already been received.
///
/// The startup phase works through these steps in order:
/// - reject unsupported protocol versions, disallowed roles, and
///   TLS-incompatible connections;
/// - authenticate the client with the authenticator chosen by
///   `get_authenticator`;
/// - apply the remaining startup parameters as session variables;
/// - reserve a connection slot and start the session on the adapter.
///
/// It then sends `AuthenticationOk`, `ParameterStatus` messages,
/// `BackendKeyData`, pending notices and `ReadyForQuery`. Finally it drives a
/// `StateMachine` until the machine finishes or the authentication session
/// expires, whichever happens first.
///
/// Startup failures are reported to the client as a fatal `ErrorResponse`; in
/// that case the result of sending that response is returned. I/O errors from
/// the connection are returned as `Err`.
#[mz_ore::instrument(level = "debug")]
pub async fn run<'a, A, I>(
    RunParams {
        tls_mode,
        adapter_client,
        conn,
        conn_uuid,
        version,
        mut params,
        frontegg,
        oidc,
        authenticator_kind,
        active_connection_counter,
        helm_chart_version,
        allowed_roles,
        tokio_metrics_intervals,
    }: RunParams<'a, A, I>,
) -> Result<(), io::Error>
where
    A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
    I: Iterator<Item = TaskMetrics> + Send,
{
    // Only protocol version 3 is served; anything else is rejected up front.
    if version != VERSION_3 {
        return conn
            .send(ErrorResponse::fatal(
                SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
                "server does not support the client's requested protocol version",
            ))
            .await;
    }

    // `user` is removed from `params` so it is not applied as a session
    // variable below; a missing `user` parameter becomes the empty string.
    let user = params.remove("user").unwrap_or_else(String::new);
    // `options` stays in `params`; its parsed form is what the startup
    // parameter loop below applies.
    let options = parse_options(params.get("options").unwrap_or(&String::new()));

    // `oidc_auth_enabled` only influences authenticator selection and is
    // stripped from the options that become session variables.
    let (oidc_auth_enabled, options) = extract_oidc_auth_enabled_from_options(options);
    let authenticator = get_authenticator(
        authenticator_kind,
        frontegg,
        oidc,
        adapter_client.clone(),
        oidc_auth_enabled,
    );
    // Check the requested role against this listener's allow-list before any
    // authentication work is done.
    let is_internal_user = INTERNAL_USER_NAMES.contains(&user);
    let is_reserved_user = mz_adapter::catalog::is_reserved_role_name(user.as_str());
    let role_allowed = match allowed_roles {
        AllowedRoles::Normal => !is_reserved_user,
        AllowedRoles::Internal => is_internal_user,
        AllowedRoles::NormalAndInternal => !is_reserved_user || is_internal_user,
    };
    if !role_allowed {
        let msg = format!("unauthorized login to user '{user}'");
        return conn
            .send(ErrorResponse::fatal(SqlState::INSUFFICIENT_PRIVILEGE, msg))
            .await;
    }

    if let Err(err) = conn.inner().ensure_tls_compatibility(&tls_mode) {
        return conn.send(err).await;
    }

    // Shadows the listener-level kind with the kind of the authenticator that
    // was actually selected; this is the kind recorded on the session.
    let authenticator_kind = authenticator.kind();

    // Authenticate and build the session. `expired` resolves when the
    // authentication session expires. Only the Frontegg arm supplies a real
    // expiry future; every other arm uses `pending()`, which never resolves.
    let (mut session, expired) = match authenticator {
        Authenticator::Frontegg(frontegg) => {
            let password = match request_cleartext_password(conn).await {
                Ok(password) => password,
                Err(PasswordRequestError::IoError(e)) => return Err(e),
                Err(PasswordRequestError::InvalidPasswordError(e)) => {
                    return conn.send(e).await;
                }
            };

            let auth_response = frontegg.authenticate(&user, &password).await;
            match auth_response {
                Ok((mut auth_session, authenticated)) => {
                    // The session user comes from the Frontegg auth session,
                    // not from the startup `user` parameter.
                    let session = adapter_client.new_session(
                        SessionConfig {
                            conn_id: conn.conn_id().clone(),
                            uuid: conn_uuid,
                            user: auth_session.user().into(),
                            client_ip: conn.peer_addr().clone(),
                            external_metadata_rx: Some(auth_session.external_metadata_rx()),
                            helm_chart_version,
                            authenticator_kind,
                            groups: None,
                        },
                        authenticated,
                    );
                    // `auth_session` is moved into the future so it lives as
                    // long as the connection is being served.
                    let expired = async move { auth_session.expired().await };
                    (session, expired.left_future())
                }
                Err(err) => {
                    warn!(?err, "pgwire connection failed authentication");
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_PASSWORD,
                            "invalid password",
                        ))
                        .await;
                }
            }
        }
        Authenticator::Oidc(oidc) => {
            // The JWT is carried in the cleartext password message.
            let jwt = match request_cleartext_password(conn).await {
                Ok(password) => password,
                Err(PasswordRequestError::IoError(e)) => return Err(e),
                Err(PasswordRequestError::InvalidPasswordError(e)) => {
                    return conn.send(e).await;
                }
            };
            let auth_response = oidc.authenticate(&jwt, Some(&user)).await;
            match auth_response {
                Ok((mut claims, authenticated)) => {
                    // User and groups are taken from the validated claims.
                    let groups = claims.groups.take();
                    let session = adapter_client.new_session(
                        SessionConfig {
                            conn_id: conn.conn_id().clone(),
                            uuid: conn_uuid,
                            user: std::mem::take(&mut claims.user),
                            client_ip: conn.peer_addr().clone(),
                            external_metadata_rx: None,
                            helm_chart_version,
                            authenticator_kind,
                            groups,
                        },
                        authenticated,
                    );
                    (session, pending().right_future())
                }
                Err(err) => {
                    warn!(?err, "pgwire connection failed authentication");
                    return conn.send(err.into_response()).await;
                }
            }
        }
        Authenticator::Password(adapter_client) => {
            let session = match authenticate_with_password(
                conn,
                &adapter_client,
                user,
                conn_uuid,
                helm_chart_version,
                authenticator_kind,
            )
            .await
            {
                Ok(session) => session,
                Err(PasswordRequestError::IoError(e)) => return Err(e),
                Err(PasswordRequestError::InvalidPasswordError(e)) => {
                    return conn.send(e).await;
                }
            };
            (session, pending().right_future())
        }
        Authenticator::Sasl(adapter_client) => {
            // SCRAM exchange: advertise SASL, then read the client's initial
            // response (mechanism + client-first message).
            conn.send(BackendMessage::AuthenticationSASL).await?;
            conn.flush().await?;
            let (mechanism, initial_response) = match conn.recv().await? {
                Some(FrontendMessage::RawAuthentication(data)) => {
                    match decode_sasl_initial_response(Cursor::new(&data)).ok() {
                        Some(FrontendMessage::SASLInitialResponse {
                            gs2_header,
                            mechanism,
                            initial_response,
                        }) => {
                            // Channel binding is not supported.
                            if gs2_header.channel_binding_enabled() {
                                return conn
                                    .send(ErrorResponse::fatal(
                                        SqlState::PROTOCOL_VIOLATION,
                                        "channel binding not supported",
                                    ))
                                    .await;
                            }
                            (mechanism, initial_response)
                        }
                        _ => {
                            return conn
                                .send(ErrorResponse::fatal(
                                    SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                                    "expected SASLInitialResponse message",
                                ))
                                .await;
                        }
                    }
                }
                _ => {
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                            "expected SASLInitialResponse message",
                        ))
                        .await;
                }
            };

            if mechanism != "SCRAM-SHA-256" {
                return conn
                    .send(ErrorResponse::fatal(
                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                        "unsupported SASL mechanism",
                    ))
                    .await;
            }

            // Bound the client-supplied nonce before doing any work with it.
            if initial_response.nonce.len() > 256 {
                return conn
                    .send(ErrorResponse::fatal(
                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                        "nonce too long",
                    ))
                    .await;
            }

            let (server_first_message_raw, mock_hash) = match adapter_client
                .generate_sasl_challenge(&user, &initial_response.nonce)
                .await
            {
                Ok(response) => {
                    // The raw server-first message is kept because it is part
                    // of the AuthMessage the proof is computed over.
                    let server_first_message_raw = format!(
                        "r={},s={},i={}",
                        response.nonce, response.salt, response.iteration_count
                    );

                    // A placeholder SCRAM hash built from the challenge's
                    // iteration count and salt with fixed dummy keys. It is
                    // passed to `verify_sasl_proof`; presumably it stands in
                    // for a missing stored hash so unknown roles take the same
                    // verification path (TODO confirm).
                    let client_key = [0u8; 32];
                    let server_key = [1u8; 32];
                    let mock_hash = format!(
                        "SCRAM-SHA-256${}:{}${}:{}",
                        response.iteration_count,
                        response.salt,
                        BASE64_STANDARD.encode(client_key),
                        BASE64_STANDARD.encode(server_key)
                    );

                    conn.send(BackendMessage::AuthenticationSASLContinue(
                        SASLServerFirstMessage {
                            iteration_count: response.iteration_count,
                            nonce: response.nonce,
                            salt: response.salt,
                        },
                    ))
                    .await?;
                    conn.flush().await?;
                    (server_first_message_raw, mock_hash)
                }
                Err(e) => {
                    return conn.send(e.into_response(Severity::Fatal)).await;
                }
            };

            let authenticated = match conn.recv().await? {
                Some(FrontendMessage::RawAuthentication(data)) => {
                    match decode_sasl_response(Cursor::new(&data)).ok() {
                        Some(FrontendMessage::SASLResponse(response)) => {
                            // SCRAM AuthMessage: client-first-bare,
                            // server-first, and client-final-without-proof,
                            // joined by commas.
                            let auth_message = format!(
                                "{},{},{}",
                                initial_response.client_first_message_bare_raw,
                                server_first_message_raw,
                                response.client_final_message_bare_raw
                            );
                            // Bound the client-supplied proof before verifying it.
                            if response.proof.len() > 1024 {
                                return conn
                                    .send(ErrorResponse::fatal(
                                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                                        "proof too long",
                                    ))
                                    .await;
                            }
                            match adapter_client
                                .verify_sasl_proof(
                                    &user,
                                    &response.proof,
                                    &auth_message,
                                    &mock_hash,
                                )
                                .await
                            {
                                Ok((proof_response, authenticated)) => {
                                    conn.send(BackendMessage::AuthenticationSASLFinal(
                                        SASLServerFinalMessage {
                                            kind: SASLServerFinalMessageKinds::Verifier(
                                                proof_response.verifier,
                                            ),
                                            extensions: vec![],
                                        },
                                    ))
                                    .await?;
                                    conn.flush().await?;
                                    authenticated
                                }
                                Err(_) => {
                                    return conn
                                        .send(ErrorResponse::fatal(
                                            SqlState::INVALID_PASSWORD,
                                            "invalid password",
                                        ))
                                        .await;
                                }
                            }
                        }
                        _ => {
                            return conn
                                .send(ErrorResponse::fatal(
                                    SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                                    "expected SASLResponse message",
                                ))
                                .await;
                        }
                    }
                }
                _ => {
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                            "expected SASLResponse message",
                        ))
                        .await;
                }
            };

            let session = adapter_client.new_session(
                SessionConfig {
                    conn_id: conn.conn_id().clone(),
                    uuid: conn_uuid,
                    user,
                    client_ip: conn.peer_addr().clone(),
                    external_metadata_rx: None,
                    helm_chart_version,
                    authenticator_kind,
                    groups: None,
                },
                authenticated,
            );
            let auth_session = pending().right_future();
            (session, auth_session)
        }

        // No authentication: the session is created directly.
        Authenticator::None => {
            let session = adapter_client.new_session(
                SessionConfig {
                    conn_id: conn.conn_id().clone(),
                    uuid: conn_uuid,
                    user,
                    client_ip: conn.peer_addr().clone(),
                    external_metadata_rx: None,
                    helm_chart_version,
                    authenticator_kind,
                    groups: None,
                },
                Authenticated,
            );
            let auth_session = pending().right_future();
            (session, auth_session)
        }
    };

    // Apply the remaining startup parameters as session variables. A parameter
    // that cannot be parsed or set produces a `BadStartupSetting` notice rather
    // than failing the connection.
    let system_vars = adapter_client.get_system_vars().await;
    for (name, value) in params {
        let settings = match name.as_str() {
            "options" => match &options {
                Ok(opts) => opts,
                Err(()) => {
                    session.add_notice(AdapterNotice::BadStartupSetting {
                        name,
                        reason: "could not parse".into(),
                    });
                    continue;
                }
            },
            _ => &vec![(name, value)],
        };
        for (key, val) in settings {
            // Startup settings are set non-locally (not transaction-scoped).
            const LOCAL: bool = false;
            if let Err(err) = session
                .vars_mut()
                .set(&system_vars, key, VarInput::Flat(val), LOCAL)
            {
                session.add_notice(AdapterNotice::BadStartupSetting {
                    name: key.clone(),
                    reason: err.to_string(),
                });
            }
        }
    }
    // Commit the session-variable changes staged above.
    session
        .vars_mut()
        .end_transaction(EndTransactionAction::Commit);

    // Reserve a connection slot for this user. Binding to `_guard` (not `_`)
    // keeps the reservation alive until this function returns.
    let _guard = match active_connection_counter.allocate_connection(session.user()) {
        Ok(drop_connection) => drop_connection,
        Err(e) => {
            let e: AdapterError = e.into();
            return conn.send(e.into_response(Severity::Fatal)).await;
        }
    };

    let mut adapter_client = match adapter_client.startup(session).await {
        Ok(adapter_client) => adapter_client,
        Err(e) => return conn.send(e.into_response(Severity::Fatal)).await,
    };

    // Startup response: AuthenticationOk, reported parameters, the cancel key,
    // any notices accumulated during startup, and ReadyForQuery.
    let mut buf = vec![BackendMessage::AuthenticationOk];
    for var in adapter_client.session().vars().notify_set() {
        buf.push(BackendMessage::ParameterStatus(var.name(), var.value()));
    }
    buf.push(BackendMessage::BackendKeyData {
        conn_id: adapter_client.session().conn_id().unhandled(),
        secret_key: adapter_client.session().secret_key(),
    });
    buf.extend(
        adapter_client
            .session()
            .drain_notices()
            .into_iter()
            .map(|notice| BackendMessage::ErrorResponse(notice.into_response())),
    );
    buf.push(BackendMessage::ReadyForQuery(
        adapter_client.session().transaction().into(),
    ));
    conn.send_all(buf).await?;
    conn.flush().await?;

    let machine = StateMachine {
        conn,
        adapter_client,
        txn_needs_commit: false,
        tokio_metrics_intervals,
    };

    // Serve queries until the state machine finishes or the authentication
    // session expires, whichever comes first.
    select! {
        r = machine.run() => {
            // Best-effort report of a state-machine error to the client before
            // returning it; failures to send are ignored.
            if let Err(err) = &r {
                let _ = conn
                    .send(ErrorResponse::fatal(
                        SqlState::CONNECTION_FAILURE,
                        err.to_string(),
                    ))
                    .await;
                let _ = conn.flush().await;
            }
            r
        },
        _ = expired => {
            conn
                .send(ErrorResponse::fatal(SqlState::INVALID_AUTHORIZATION_SPECIFICATION, "authentication expired"))
                .await?;
            conn.flush().await
        }
    }
}
620
/// Removes every `oidc_auth_enabled` entry from parsed startup options and
/// reports whether OIDC authentication was requested.
///
/// When the key appears several times, the last occurrence decides. A value
/// that does not parse as a Rust `bool` (`"true"`/`"false"`) counts as
/// `false`. Options that failed to parse are passed through unchanged, with
/// the flag reported as `false`.
fn extract_oidc_auth_enabled_from_options(
    options: Result<Vec<(String, String)>, ()>,
) -> (bool, Result<Vec<(String, String)>, ()>) {
    let pairs = match options {
        Err(()) => return (false, Err(())),
        Ok(pairs) => pairs,
    };

    // `partition` keeps the relative order of both halves.
    let (flags, remaining): (Vec<(String, String)>, Vec<(String, String)>) = pairs
        .into_iter()
        .partition(|(key, _)| key == "oidc_auth_enabled");

    let enabled = flags
        .last()
        .and_then(|(_, value)| value.parse::<bool>().ok())
        .unwrap_or(false);

    (enabled, Ok(remaining))
}
645
646fn parse_options(value: &str) -> Result<Vec<(String, String)>, ()> {
651 let opts = split_options(value);
652 let mut pairs = Vec::with_capacity(opts.len());
653 let mut seen_prefix = false;
654 for opt in opts {
655 if !seen_prefix {
656 if opt == "-c" {
657 seen_prefix = true;
658 } else {
659 let (key, val) = parse_option(&opt)?;
660 pairs.push((key.to_owned(), val.to_owned()));
661 }
662 } else {
663 let (key, val) = opt.split_once('=').ok_or(())?;
664 pairs.push((key.to_owned(), val.to_owned()));
665 seen_prefix = false;
666 }
667 }
668 Ok(pairs)
669}
670
/// Parses a single `-ckey=value` or `--key=value` word into `(key, value)`.
///
/// Splits on the first `=`, so the value may itself contain `=`. Returns
/// `Err(())` if there is no `=` or the key has neither prefix.
fn parse_option(option: &str) -> Result<(&str, &str), ()> {
    let (raw_key, value) = option.split_once('=').ok_or(())?;
    ["-c", "--"]
        .iter()
        .find_map(|prefix| raw_key.strip_prefix(prefix))
        .map(|key| (key, value))
        .ok_or(())
}
683
/// Splits `value` into words on runs of spaces, honoring backslash escapes.
///
/// A backslash makes the following character literal, so `\ ` embeds a space
/// in a word and `\\` yields a single backslash. A trailing, unpaired
/// backslash is discarded. Empty words are never produced.
fn split_options(value: &str) -> Vec<String> {
    let mut words = Vec::new();
    // Because of the escapes, words cannot simply be sub-slices of `value`,
    // so each one is accumulated into an owned buffer.
    let mut word = String::new();
    let mut chars = value.chars();
    while let Some(ch) = chars.next() {
        match ch {
            // Take the escaped character verbatim, whatever it is; at end of
            // input there is nothing to take and the backslash is dropped.
            '\\' => word.extend(chars.next()),
            ' ' => {
                if !word.is_empty() {
                    words.push(mem::take(&mut word));
                }
            }
            other => word.push(other),
        }
    }
    if !word.is_empty() {
        words.push(word);
    }
    words
}
726
727enum PasswordRequestError {
728 InvalidPasswordError(ErrorResponse),
729 IoError(io::Error),
730}
731
732impl From<io::Error> for PasswordRequestError {
733 fn from(e: io::Error) -> Self {
734 PasswordRequestError::IoError(e)
735 }
736}
737
738async fn request_cleartext_password<A>(
742 conn: &mut FramedConn<A>,
743) -> Result<String, PasswordRequestError>
744where
745 A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
746{
747 conn.send(BackendMessage::AuthenticationCleartextPassword)
748 .await?;
749 conn.flush().await?;
750
751 if let Some(message) = conn.recv().await? {
752 if let FrontendMessage::RawAuthentication(data) = message {
753 if let Some(FrontendMessage::Password { password }) =
754 decode_password(Cursor::new(&data)).ok()
755 {
756 return Ok(password);
757 }
758 }
759 }
760
761 Err(PasswordRequestError::InvalidPasswordError(
762 ErrorResponse::fatal(
763 SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
764 "expected Password message",
765 ),
766 ))
767}
768
769async fn authenticate_with_password<A>(
772 conn: &mut FramedConn<A>,
773 adapter_client: &mz_adapter::Client,
774 user: String,
775 conn_uuid: Uuid,
776 helm_chart_version: Option<String>,
777 authenticator_kind: mz_auth::AuthenticatorKind,
778) -> Result<Session, PasswordRequestError>
779where
780 A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
781{
782 let password = match request_cleartext_password(conn).await {
783 Ok(password) => Password(password),
784 Err(e) => return Err(e),
785 };
786
787 let authenticated = match adapter_client.authenticate(&user, &password).await {
788 Ok(authenticated) => authenticated,
789 Err(err) => {
790 warn!(?err, "pgwire connection failed authentication");
791 return Err(PasswordRequestError::InvalidPasswordError(
792 ErrorResponse::fatal(SqlState::INVALID_PASSWORD, "invalid password"),
793 ));
794 }
795 };
796
797 let session = adapter_client.new_session(
798 SessionConfig {
799 conn_id: conn.conn_id().clone(),
800 uuid: conn_uuid,
801 user,
802 client_ip: conn.peer_addr().clone(),
803 external_metadata_rx: None,
804 helm_chart_version,
805 authenticator_kind,
806 groups: None,
807 },
808 authenticated,
809 );
810
811 Ok(session)
812}
813
/// Position of the connection's state machine in its message loop.
#[derive(Debug)]
enum State {
    /// Read the next frontend message and dispatch it.
    Ready,
    /// Discard frontend messages until the next `Sync` (or end of input).
    Drain,
    /// Stop serving the connection.
    Done,
}
820
821struct StateMachine<'a, A, I>
822where
823 I: Iterator<Item = TaskMetrics> + Send + 'a,
824{
825 conn: &'a mut FramedConn<A>,
826 adapter_client: mz_adapter::SessionClient,
827 txn_needs_commit: bool,
828 tokio_metrics_intervals: I,
829}
830
/// How streaming a result's rows to the client finished.
enum SendRowsEndedReason {
    /// All requested rows were sent.
    Success {
        /// Size of the result that was sent (units are determined by the producer).
        result_size: u64,
        /// Number of rows returned to the client.
        rows_returned: u64,
    },
    /// Sending stopped because of an error, described by `error`.
    Errored { error: String },
    /// Sending stopped because the operation was canceled.
    Canceled,
}
841
/// Message text for errors about commands issued inside an aborted
/// transaction (presumably used by `aborted_txn_error`).
const ABORTED_TXN_MSG: &str = concat!(
    "current transaction is aborted, ",
    "commands ignored until end of transaction block"
);
844
845impl<'a, A, I> StateMachine<'a, A, I>
846where
847 A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin + 'a,
848 I: Iterator<Item = TaskMetrics> + Send + 'a,
849{
850 #[allow(clippy::manual_async_fn)]
854 #[mz_ore::instrument(level = "debug")]
855 fn run(mut self) -> impl Future<Output = Result<(), io::Error>> + Send + 'a {
856 async move {
857 let mut state = State::Ready;
858 loop {
859 self.send_pending_notices().await?;
860 state = match state {
861 State::Ready => self.advance_ready().await?,
862 State::Drain => self.advance_drain().await?,
863 State::Done => return Ok(()),
864 };
865 self.adapter_client
866 .add_idle_in_transaction_session_timeout();
867 }
868 }
869 }
870
    /// Reads one frontend message and dispatches it, returning the next state.
    ///
    /// A pending session timeout from `recv_timeout` takes priority over the
    /// read. It is reported to the client as a fatal error, the session is
    /// terminated, and one more `recv()` is awaited before returning.
    ///
    /// Otherwise two metrics are recorded, both labeled with the message name:
    /// the handling time of the message, and the tokio scheduling delay
    /// observed around the read.
    #[instrument(level = "debug")]
    async fn advance_ready(&mut self) -> Result<State, io::Error> {
        // Discard the metrics interval accumulated so far, so that the
        // interval taken after `recv()` below roughly covers only the wait for
        // this message.
        self.tokio_metrics_intervals
            .next()
            .expect("infinite iterator");

        let message = select! {
            // Poll branches in order, so a pending timeout wins over a message
            // that is ready at the same time. If `recv_timeout` yields `None`,
            // that branch is disabled and only `recv()` is awaited.
            biased;

            Some(timeout) = self.adapter_client.recv_timeout() => {
                let err: AdapterError = timeout.into();
                let conn_id = self.adapter_client.session().conn_id();
                tracing::warn!("session timed out, conn_id {}", conn_id);

                let error_response = err.into_response(Severity::Fatal);
                let error_state = self.send_error_and_get_state(error_response).await;

                self.adapter_client.terminate().await;

                // Wait for the client's next message or disconnect before
                // returning; presumably this gives the client a chance to read
                // the error before the connection is dropped (TODO confirm).
                let _ = self.conn.recv().await?;
                return error_state;
            },
            message = self.conn.recv() => message?,
        };

        let interval = self
            .tokio_metrics_intervals
            .next()
            .expect("infinite iterator");
        let recv_scheduling_delay_ms = interval.total_scheduled_duration.as_secs_f64() * 1000.0;

        // Wall-clock receipt time, passed to `query` and `execute` below.
        let received = SYSTEM_TIME();

        // Suspend the idle-in-transaction timeout while this message is
        // handled; `run` re-adds it after each state transition.
        self.adapter_client
            .remove_idle_in_transaction_session_timeout();

        let message_name = message.as_ref().map(|m| m.name()).unwrap_or_default();

        // Processing time is only measured when a message actually arrived.
        let start = message.as_ref().map(|_| Instant::now());
        let next_state = match message {
            Some(FrontendMessage::Query { sql }) => {
                // Each query gets its own root span, linked to the current one.
                let query_root_span =
                    tracing::info_span!(parent: None, "advance_ready", otel.name = message_name);
                query_root_span.follows_from(tracing::Span::current());
                self.query(sql, received)
                    .instrument(query_root_span)
                    .await?
            }
            Some(FrontendMessage::Parse {
                name,
                sql,
                param_types,
            }) => self.parse(name, sql, param_types).await?,
            Some(FrontendMessage::Bind {
                portal_name,
                statement_name,
                param_formats,
                raw_params,
                result_formats,
            }) => {
                self.bind(
                    portal_name,
                    statement_name,
                    param_formats,
                    raw_params,
                    result_formats,
                )
                .await?
            }
            Some(FrontendMessage::Execute {
                portal_name,
                max_rows,
            }) => {
                // Zero, or any value that does not fit in `usize` (e.g. a
                // negative one), means "no row limit".
                let max_rows = match usize::try_from(max_rows) {
                    Ok(0) | Err(_) => ExecuteCount::All,
                    Ok(n) => ExecuteCount::Count(n),
                };
                let execute_root_span =
                    tracing::info_span!(parent: None, "advance_ready", otel.name = message_name);
                execute_root_span.follows_from(tracing::Span::current());
                let state = self
                    .execute(
                        portal_name,
                        max_rows,
                        portal_exec_message,
                        None,
                        ExecuteTimeout::None,
                        None,
                        Some(received),
                    )
                    .instrument(execute_root_span)
                    .await?;
                // An implicit transaction stays open after Execute; remember to
                // commit it before the next transaction is started.
                if self.adapter_client.session().transaction().is_implicit() {
                    self.txn_needs_commit = true;
                }
                state
            }
            Some(FrontendMessage::DescribeStatement { name }) => {
                self.describe_statement(&name).await?
            }
            Some(FrontendMessage::DescribePortal { name }) => self.describe_portal(&name).await?,
            Some(FrontendMessage::CloseStatement { name }) => self.close_statement(name).await?,
            Some(FrontendMessage::ClosePortal { name }) => self.close_portal(name).await?,
            Some(FrontendMessage::Flush) => self.flush().await?,
            Some(FrontendMessage::Sync) => self.sync().await?,
            Some(FrontendMessage::Terminate) => State::Done,

            // Copy and authentication messages are not handled in the ready
            // state; switch to draining, which discards input until `Sync`.
            Some(FrontendMessage::CopyData(_))
            | Some(FrontendMessage::CopyDone)
            | Some(FrontendMessage::CopyFail(_))
            | Some(FrontendMessage::Password { .. })
            | Some(FrontendMessage::RawAuthentication(_))
            | Some(FrontendMessage::SASLInitialResponse { .. })
            | Some(FrontendMessage::SASLResponse(_)) => State::Drain,
            // The client closed the connection.
            None => State::Done,
        };

        if let Some(start) = start {
            self.adapter_client
                .inner()
                .metrics()
                .pgwire_message_processing_seconds
                .with_label_values(&[message_name])
                .observe(start.elapsed().as_secs_f64());
        }
        self.adapter_client
            .inner()
            .metrics()
            .pgwire_recv_scheduling_delay_ms
            .with_label_values(&[message_name])
            .observe(recv_scheduling_delay_ms);

        Ok(next_state)
    }
1034
1035 async fn advance_drain(&mut self) -> Result<State, io::Error> {
1036 let message = self.conn.recv().await?;
1037 if message.is_some() {
1038 self.adapter_client
1039 .remove_idle_in_transaction_session_timeout();
1040 }
1041 match message {
1042 Some(FrontendMessage::Sync) => self.sync().await,
1043 None => Ok(State::Done),
1044 _ => Ok(State::Drain),
1045 }
1046 }
1047
1048 #[instrument(level = "debug")]
1052 async fn one_query(
1053 &mut self,
1054 stmt: Statement<Raw>,
1055 sql: String,
1056 lifecycle_timestamps: LifecycleTimestamps,
1057 ) -> Result<State, io::Error> {
1058 const EMPTY_PORTAL: &str = "";
1061 if let Err(e) = self
1062 .adapter_client
1063 .declare(EMPTY_PORTAL.to_string(), stmt, sql)
1064 .await
1065 {
1066 return self
1067 .send_error_and_get_state(e.into_response(Severity::Error))
1068 .await;
1069 }
1070 let portal = self
1071 .adapter_client
1072 .session()
1073 .get_portal_unverified_mut(EMPTY_PORTAL)
1074 .expect("unnamed portal should be present");
1075
1076 *portal.lifecycle_timestamps = Some(lifecycle_timestamps);
1077
1078 let stmt_desc = portal.desc.clone();
1079 if !stmt_desc.param_types.is_empty() {
1080 return self
1081 .send_error_and_get_state(ErrorResponse::error(
1082 SqlState::UNDEFINED_PARAMETER,
1083 "there is no parameter $1",
1084 ))
1085 .await;
1086 }
1087
1088 if let Some(relation_desc) = &stmt_desc.relation_desc {
1090 if !stmt_desc.is_copy {
1091 let formats = vec![Format::Text; stmt_desc.arity()];
1092 self.send(BackendMessage::RowDescription(
1093 message::encode_row_description(relation_desc, &formats),
1094 ))
1095 .await?;
1096 }
1097 }
1098
1099 let result = match self
1100 .adapter_client
1101 .execute(EMPTY_PORTAL.to_string(), self.conn.wait_closed(), None)
1102 .await
1103 {
1104 Ok((response, execute_started)) => {
1105 self.send_pending_notices().await?;
1106 self.send_execute_response(
1107 response,
1108 stmt_desc.relation_desc,
1109 EMPTY_PORTAL.to_string(),
1110 ExecuteCount::All,
1111 portal_exec_message,
1112 None,
1113 ExecuteTimeout::None,
1114 execute_started,
1115 )
1116 .await
1117 }
1118 Err(e) => {
1119 self.send_pending_notices().await?;
1120 self.send_error_and_get_state(e.into_response(Severity::Error))
1121 .await
1122 }
1123 };
1124
1125 self.adapter_client.session().remove_portal(EMPTY_PORTAL);
1127
1128 result
1129 }
1130
1131 async fn ensure_transaction(
1132 &mut self,
1133 num_stmts: usize,
1134 message_type: &str,
1135 ) -> Result<(), io::Error> {
1136 let start = Instant::now();
1137 if self.txn_needs_commit {
1138 self.commit_transaction().await?;
1139 }
1140 let res = self.adapter_client.start_transaction(Some(num_stmts));
1143 assert_ok!(res);
1144 self.adapter_client
1145 .inner()
1146 .metrics()
1147 .pgwire_ensure_transaction_seconds
1148 .with_label_values(&[message_type])
1149 .observe(start.elapsed().as_secs_f64());
1150 Ok(())
1151 }
1152
1153 fn parse_sql<'b>(&self, sql: &'b str) -> Result<Vec<StatementParseResult<'b>>, ErrorResponse> {
1154 let parse_start = Instant::now();
1155 let result = match self.adapter_client.parse(sql) {
1156 Ok(result) => result.map_err(|e| {
1157 let pos = sql[..e.error.pos].chars().count() + 1;
1160 ErrorResponse::error(SqlState::SYNTAX_ERROR, e.error.message).with_position(pos)
1161 }),
1162 Err(msg) => Err(ErrorResponse::error(SqlState::PROGRAM_LIMIT_EXCEEDED, msg)),
1163 };
1164 self.adapter_client
1165 .inner()
1166 .metrics()
1167 .parse_seconds
1168 .observe(parse_start.elapsed().as_secs_f64());
1169 result
1170 }
1171
1172 #[instrument(level = "debug")]
1177 async fn query(&mut self, sql: String, received: EpochMillis) -> Result<State, io::Error> {
1178 let stmts = match self.parse_sql(&sql) {
1180 Ok(stmts) => stmts,
1181 Err(err) => {
1182 self.send_error_and_get_state(err).await?;
1183 return self.ready().await;
1184 }
1185 };
1186
1187 let num_stmts = stmts.len();
1188
1189 for StatementParseResult { ast: stmt, sql } in stmts {
1191 if self.is_aborted_txn() && !is_txn_exit_stmt(Some(&stmt)) {
1193 self.aborted_txn_error().await?;
1194 break;
1195 }
1196
1197 self.ensure_transaction(num_stmts, "query").await?;
1205
1206 match self
1207 .one_query(stmt, sql.to_string(), LifecycleTimestamps { received })
1208 .await?
1209 {
1210 State::Ready => (),
1211 State::Drain => break,
1212 State::Done => return Ok(State::Done),
1213 }
1214 }
1215
1216 {
1218 if self.adapter_client.session().transaction().is_implicit() {
1219 self.commit_transaction().await?;
1220 }
1221 }
1222
1223 if num_stmts == 0 {
1224 self.send(BackendMessage::EmptyQueryResponse).await?;
1225 }
1226
1227 self.ready().await
1228 }
1229
1230 #[instrument(level = "debug")]
1231 async fn parse(
1232 &mut self,
1233 name: String,
1234 sql: String,
1235 param_oids: Vec<u32>,
1236 ) -> Result<State, io::Error> {
1237 self.ensure_transaction(1, "parse").await?;
1239
1240 let mut param_types = vec![];
1241 for oid in param_oids {
1242 match mz_pgrepr::Type::from_oid(oid) {
1243 Ok(ty) => match SqlScalarType::try_from(&ty) {
1244 Ok(ty) => param_types.push(Some(ty)),
1245 Err(err) => {
1246 return self
1247 .send_error_and_get_state(ErrorResponse::error(
1248 SqlState::INVALID_PARAMETER_VALUE,
1249 err.to_string(),
1250 ))
1251 .await;
1252 }
1253 },
1254 Err(_) if oid == 0 => param_types.push(None),
1255 Err(e) => {
1256 return self
1257 .send_error_and_get_state(ErrorResponse::error(
1258 SqlState::PROTOCOL_VIOLATION,
1259 e.to_string(),
1260 ))
1261 .await;
1262 }
1263 }
1264 }
1265
1266 let stmts = match self.parse_sql(&sql) {
1267 Ok(stmts) => stmts,
1268 Err(err) => {
1269 return self.send_error_and_get_state(err).await;
1270 }
1271 };
1272 if stmts.len() > 1 {
1273 return self
1274 .send_error_and_get_state(ErrorResponse::error(
1275 SqlState::INTERNAL_ERROR,
1276 "cannot insert multiple commands into a prepared statement",
1277 ))
1278 .await;
1279 }
1280 let (maybe_stmt, sql) = match stmts.into_iter().next() {
1281 None => (None, ""),
1282 Some(StatementParseResult { ast, sql }) => (Some(ast), sql),
1283 };
1284 if self.is_aborted_txn() && !is_txn_exit_stmt(maybe_stmt.as_ref()) {
1285 return self.aborted_txn_error().await;
1286 }
1287 match self
1288 .adapter_client
1289 .prepare(name, maybe_stmt, sql.to_string(), param_types)
1290 .await
1291 {
1292 Ok(()) => {
1293 self.send(BackendMessage::ParseComplete).await?;
1294 Ok(State::Ready)
1295 }
1296 Err(e) => {
1297 self.send_error_and_get_state(e.into_response(Severity::Error))
1298 .await
1299 }
1300 }
1301 }
1302
1303 #[instrument(level = "debug")]
1305 async fn commit_transaction(&mut self) -> Result<(), io::Error> {
1306 self.end_transaction(EndTransactionAction::Commit).await
1307 }
1308
1309 #[instrument(level = "debug")]
1311 async fn rollback_transaction(&mut self) -> Result<(), io::Error> {
1312 self.end_transaction(EndTransactionAction::Rollback).await
1313 }
1314
1315 #[instrument(level = "debug")]
1317 async fn end_transaction(&mut self, action: EndTransactionAction) -> Result<(), io::Error> {
1318 self.txn_needs_commit = false;
1319 let resp = self.adapter_client.end_transaction(action).await;
1320 if let Err(err) = resp {
1321 self.send(BackendMessage::ErrorResponse(
1322 err.into_response(Severity::Error),
1323 ))
1324 .await?;
1325 }
1326 Ok(())
1327 }
1328
1329 #[instrument(level = "debug")]
1330 async fn bind(
1331 &mut self,
1332 portal_name: String,
1333 statement_name: String,
1334 param_formats: Vec<Format>,
1335 raw_params: Vec<Option<Vec<u8>>>,
1336 result_formats: Vec<Format>,
1337 ) -> Result<State, io::Error> {
1338 self.ensure_transaction(1, "bind").await?;
1340
1341 let aborted_txn = self.is_aborted_txn();
1342 let stmt = match self
1343 .adapter_client
1344 .get_prepared_statement(&statement_name)
1345 .await
1346 {
1347 Ok(stmt) => stmt,
1348 Err(err) => {
1349 return self
1350 .send_error_and_get_state(err.into_response(Severity::Error))
1351 .await;
1352 }
1353 };
1354
1355 let param_types = &stmt.desc().param_types;
1356 if param_types.len() != raw_params.len() {
1357 let message = format!(
1358 "bind message supplies {actual} parameters, \
1359 but prepared statement \"{name}\" requires {expected}",
1360 name = statement_name,
1361 actual = raw_params.len(),
1362 expected = param_types.len()
1363 );
1364 return self
1365 .send_error_and_get_state(ErrorResponse::error(
1366 SqlState::PROTOCOL_VIOLATION,
1367 message,
1368 ))
1369 .await;
1370 }
1371 let param_formats = match pad_formats(param_formats, raw_params.len()) {
1372 Ok(param_formats) => param_formats,
1373 Err(msg) => {
1374 return self
1375 .send_error_and_get_state(ErrorResponse::error(
1376 SqlState::PROTOCOL_VIOLATION,
1377 msg,
1378 ))
1379 .await;
1380 }
1381 };
1382 if aborted_txn && !is_txn_exit_stmt(stmt.stmt()) {
1383 return self.aborted_txn_error().await;
1384 }
1385 let buf = RowArena::new();
1386 let mut params = vec![];
1387 for ((raw_param, mz_typ), format) in raw_params
1388 .into_iter()
1389 .zip_eq(param_types)
1390 .zip_eq(param_formats)
1391 {
1392 let pg_typ = mz_pgrepr::Type::from(mz_typ);
1393 let datum = match raw_param {
1394 None => Datum::Null,
1395 Some(bytes) => match mz_pgrepr::Value::decode(format, &pg_typ, &bytes) {
1396 Ok(param) => match param.into_datum_decode_error(&buf, &pg_typ, "parameter") {
1397 Ok(datum) => datum,
1398 Err(msg) => {
1399 return self
1400 .send_error_and_get_state(ErrorResponse::error(
1401 SqlState::INVALID_PARAMETER_VALUE,
1402 msg,
1403 ))
1404 .await;
1405 }
1406 },
1407 Err(err) => {
1408 let msg = format!("unable to decode parameter: {}", err);
1409 return self
1410 .send_error_and_get_state(ErrorResponse::error(
1411 SqlState::INVALID_PARAMETER_VALUE,
1412 msg,
1413 ))
1414 .await;
1415 }
1416 },
1417 };
1418 params.push((datum, mz_typ.clone()))
1419 }
1420
1421 let result_formats = match pad_formats(
1422 result_formats,
1423 stmt.desc()
1424 .relation_desc
1425 .clone()
1426 .map(|desc| desc.typ().column_types.len())
1427 .unwrap_or(0),
1428 ) {
1429 Ok(result_formats) => result_formats,
1430 Err(msg) => {
1431 return self
1432 .send_error_and_get_state(ErrorResponse::error(
1433 SqlState::PROTOCOL_VIOLATION,
1434 msg,
1435 ))
1436 .await;
1437 }
1438 };
1439
1440 if !stmt.stmt().map_or(false, |stmt| match stmt {
1443 Statement::Copy(CopyStatement {
1444 direction: CopyDirection::To,
1445 ..
1446 }) => true,
1447 Statement::Copy(CopyStatement {
1448 direction: CopyDirection::From,
1449 target: CopyTarget::Expr(_),
1453 ..
1454 }) => true,
1455 _ => false,
1456 }) {
1457 if let Some(desc) = stmt.desc().relation_desc.clone() {
1458 for (format, ty) in result_formats.iter().zip_eq(desc.iter_types()) {
1459 match (format, &ty.scalar_type) {
1460 (Format::Binary, mz_repr::SqlScalarType::List { .. }) => {
1461 return self
1462 .send_error_and_get_state(ErrorResponse::error(
1463 SqlState::PROTOCOL_VIOLATION,
1464 "binary encoding of list types is not implemented",
1465 ))
1466 .await;
1467 }
1468 (Format::Binary, mz_repr::SqlScalarType::Map { .. }) => {
1469 return self
1470 .send_error_and_get_state(ErrorResponse::error(
1471 SqlState::PROTOCOL_VIOLATION,
1472 "binary encoding of map types is not implemented",
1473 ))
1474 .await;
1475 }
1476 (Format::Binary, mz_repr::SqlScalarType::AclItem) => {
1477 return self
1478 .send_error_and_get_state(ErrorResponse::error(
1479 SqlState::PROTOCOL_VIOLATION,
1480 "binary encoding of aclitem types does not exist",
1481 ))
1482 .await;
1483 }
1484 _ => (),
1485 }
1486 }
1487 }
1488 }
1489
1490 let desc = stmt.desc().clone();
1491 let logging = Arc::clone(stmt.logging());
1492 let stmt_ast = stmt.stmt().cloned();
1493 let state_revision = stmt.state_revision;
1494 if let Err(err) = self.adapter_client.session().set_portal(
1495 portal_name,
1496 desc,
1497 stmt_ast,
1498 logging,
1499 params,
1500 result_formats,
1501 state_revision,
1502 ) {
1503 return self
1504 .send_error_and_get_state(err.into_response(Severity::Error))
1505 .await;
1506 }
1507
1508 self.send(BackendMessage::BindComplete).await?;
1509 Ok(State::Ready)
1510 }
1511
/// Runs or resumes the portal `portal_name`. It is used both for an
/// extended-protocol `Execute` and when a `FETCH` resumes a cursor.
///
/// Behavior depends on the portal's state:
/// - `NotStarted`: ensure a transaction exists, ask the adapter to execute
///   the portal, then hand the response to `send_execute_response`.
/// - `InProgress`: resume streaming the buffered rows via `send_rows`.
/// - `Completed(Some(tag))`: resend the stored command tag.
/// - `Completed(None)`: report that the portal cannot be run.
///
/// If the portal does not exist, an error is sent to the client. An error is
/// also sent if the transaction is aborted and the portal's statement does
/// not exit it.
///
/// In the `NotStarted` case `outer_ctx_extra` is passed to the adapter. On
/// every other outcome reported here it is retired with an end-of-execution
/// reason matching what was sent to the client. When `received` is set, it
/// becomes the portal's lifecycle timestamps.
///
/// The future is boxed because `execute` -> `send_execute_response` ->
/// `fetch` -> `execute` forms a recursive async call cycle.
fn execute(
    &mut self,
    portal_name: String,
    max_rows: ExecuteCount,
    get_response: GetResponse,
    fetch_portal_name: Option<String>,
    timeout: ExecuteTimeout,
    outer_ctx_extra: Option<ExecuteContextGuard>,
    received: Option<EpochMillis>,
) -> BoxFuture<'_, Result<State, io::Error>> {
    async move {
        // Read before `portal` takes a mutable borrow of the session below.
        let aborted_txn = self.is_aborted_txn();

        let portal = match self
            .adapter_client
            .session()
            .get_portal_unverified_mut(&portal_name)
        {
            Some(portal) => portal,
            None => {
                let msg = format!("portal {} does not exist", portal_name.quoted());
                if let Some(outer_ctx_extra) = outer_ctx_extra {
                    self.adapter_client.retire_execute(
                        outer_ctx_extra,
                        StatementEndedExecutionReason::Errored { error: msg.clone() },
                    );
                }
                return self
                    .send_error_and_get_state(ErrorResponse::error(
                        SqlState::INVALID_CURSOR_NAME,
                        msg,
                    ))
                    .await;
            }
        };

        *portal.lifecycle_timestamps = received.map(LifecycleTimestamps::new);

        // In an aborted transaction only transaction-exiting statements may run.
        let txn_exit_stmt = is_txn_exit_stmt(portal.stmt.as_deref());
        if aborted_txn && !txn_exit_stmt {
            if let Some(outer_ctx_extra) = outer_ctx_extra {
                self.adapter_client.retire_execute(
                    outer_ctx_extra,
                    StatementEndedExecutionReason::Errored {
                        error: ABORTED_TXN_MSG.to_string(),
                    },
                );
            }
            return self.aborted_txn_error().await;
        }

        // Cloned so the row description stays usable after the portal
        // borrow ends.
        let row_desc = portal.desc.relation_desc.clone();
        match portal.state {
            PortalState::NotStarted => {
                self.ensure_transaction(1, "execute").await?;
                match self
                    .adapter_client
                    .execute(
                        portal_name.clone(),
                        self.conn.wait_closed(),
                        outer_ctx_extra,
                    )
                    .await
                {
                    Ok((response, execute_started)) => {
                        self.send_pending_notices().await?;
                        self.send_execute_response(
                            response,
                            row_desc,
                            portal_name,
                            max_rows,
                            get_response,
                            fetch_portal_name,
                            timeout,
                            execute_started,
                        )
                        .await
                    }
                    Err(e) => {
                        self.send_pending_notices().await?;
                        self.send_error_and_get_state(e.into_response(Severity::Error))
                            .await
                    }
                }
            }
            PortalState::InProgress(rows) => {
                let rows = rows.take().expect("InProgress rows must be populated");
                let (result, statement_ended_execution_reason) = match self
                    .send_rows(
                        row_desc.expect("portal missing row desc on resumption"),
                        portal_name,
                        rows,
                        max_rows,
                        get_response,
                        fetch_portal_name,
                        timeout,
                    )
                    .await
                {
                    // A connection I/O error while streaming is recorded as
                    // a cancellation.
                    Err(e) => {
                        (Err(e), StatementEndedExecutionReason::Canceled)
                    }
                    Ok((ok, SendRowsEndedReason::Canceled)) => {
                        (Ok(ok), StatementEndedExecutionReason::Canceled)
                    }
                    // NOTE(review): the sizes reported for this resumption
                    // are discarded. Presumably they cover only this batch
                    // rather than the whole statement; confirm this.
                    Ok((
                        ok,
                        SendRowsEndedReason::Success {
                            result_size: _,
                            rows_returned: _,
                        },
                    )) => (
                        Ok(ok),
                        StatementEndedExecutionReason::Success {
                            result_size: None,
                            rows_returned: None,
                            execution_strategy: None,
                        },
                    ),
                    Ok((ok, SendRowsEndedReason::Errored { error })) => {
                        (Ok(ok), StatementEndedExecutionReason::Errored { error })
                    }
                };
                if let Some(outer_ctx_extra) = outer_ctx_extra {
                    self.adapter_client
                        .retire_execute(outer_ctx_extra, statement_ended_execution_reason);
                }
                result
            }
            PortalState::Completed(Some(tag)) => {
                let tag = tag.to_string();
                if let Some(outer_ctx_extra) = outer_ctx_extra {
                    self.adapter_client.retire_execute(
                        outer_ctx_extra,
                        StatementEndedExecutionReason::Success {
                            result_size: None,
                            rows_returned: None,
                            execution_strategy: None,
                        },
                    );
                }
                self.send(BackendMessage::CommandComplete { tag }).await?;
                Ok(State::Ready)
            }
            PortalState::Completed(None) => {
                let error = format!(
                    "portal {} cannot be run",
                    Ident::new_unchecked(portal_name).to_ast_string_stable()
                );
                if let Some(outer_ctx_extra) = outer_ctx_extra {
                    self.adapter_client.retire_execute(
                        outer_ctx_extra,
                        StatementEndedExecutionReason::Errored {
                            error: error.clone(),
                        },
                    );
                }
                self.send_error_and_get_state(ErrorResponse::error(
                    SqlState::OBJECT_NOT_IN_PREREQUISITE_STATE,
                    error,
                ))
                .await
            }
        }
    }
    .instrument(debug_span!("execute"))
    .boxed()
}
1705
1706 #[instrument(level = "debug")]
1707 async fn describe_statement(&mut self, name: &str) -> Result<State, io::Error> {
1708 self.ensure_transaction(1, "describe_statement").await?;
1710
1711 let stmt = match self.adapter_client.get_prepared_statement(name).await {
1712 Ok(stmt) => stmt,
1713 Err(err) => {
1714 return self
1715 .send_error_and_get_state(err.into_response(Severity::Error))
1716 .await;
1717 }
1718 };
1719 let parameter_desc = BackendMessage::ParameterDescription(
1721 stmt.desc()
1722 .param_types
1723 .iter()
1724 .map(mz_pgrepr::Type::from)
1725 .collect(),
1726 );
1727 let formats = vec![Format::Text; stmt.desc().arity()];
1731 let row_desc = describe_rows(stmt.desc(), &formats);
1732 self.send_all([parameter_desc, row_desc]).await?;
1733 Ok(State::Ready)
1734 }
1735
1736 #[instrument(level = "debug")]
1737 async fn describe_portal(&mut self, name: &str) -> Result<State, io::Error> {
1738 self.ensure_transaction(1, "describe_portal").await?;
1740
1741 let session = self.adapter_client.session();
1742 let row_desc = session
1743 .get_portal_unverified(name)
1744 .map(|portal| describe_rows(&portal.desc, &portal.result_formats));
1745 match row_desc {
1746 Some(row_desc) => {
1747 self.send(row_desc).await?;
1748 Ok(State::Ready)
1749 }
1750 None => {
1751 self.send_error_and_get_state(ErrorResponse::error(
1752 SqlState::INVALID_CURSOR_NAME,
1753 format!("portal {} does not exist", name.quoted()),
1754 ))
1755 .await
1756 }
1757 }
1758 }
1759
1760 #[instrument(level = "debug")]
1761 async fn close_statement(&mut self, name: String) -> Result<State, io::Error> {
1762 self.adapter_client
1763 .session()
1764 .remove_prepared_statement(&name);
1765 self.send(BackendMessage::CloseComplete).await?;
1766 Ok(State::Ready)
1767 }
1768
1769 #[instrument(level = "debug")]
1770 async fn close_portal(&mut self, name: String) -> Result<State, io::Error> {
1771 self.adapter_client.session().remove_portal(&name);
1772 self.send(BackendMessage::CloseComplete).await?;
1773 Ok(State::Ready)
1774 }
1775
1776 fn complete_portal(&mut self, name: &str) {
1777 let portal = self
1778 .adapter_client
1779 .session()
1780 .get_portal_unverified_mut(name)
1781 .expect("portal should exist");
1782 *portal.state = PortalState::Completed(None);
1783 }
1784
1785 async fn fetch(
1786 &mut self,
1787 name: String,
1788 count: Option<FetchDirection>,
1789 max_rows: ExecuteCount,
1790 fetch_portal_name: Option<String>,
1791 timeout: ExecuteTimeout,
1792 ctx_extra: ExecuteContextGuard,
1793 ) -> Result<State, io::Error> {
1794 let count = count.unwrap_or(FetchDirection::ForwardCount(1));
1797
1798 let count = match (max_rows, count) {
1811 (ExecuteCount::Count(max_rows), FetchDirection::ForwardCount(count)) => {
1812 let count = usize::cast_from(count);
1813 if max_rows < count {
1814 let msg = "Execute with max_rows < a FETCH's count is not supported";
1815 self.adapter_client.retire_execute(
1816 ctx_extra,
1817 StatementEndedExecutionReason::Errored {
1818 error: msg.to_string(),
1819 },
1820 );
1821 return self
1822 .send_error_and_get_state(ErrorResponse::error(
1823 SqlState::FEATURE_NOT_SUPPORTED,
1824 msg,
1825 ))
1826 .await;
1827 }
1828 ExecuteCount::Count(count)
1829 }
1830 (ExecuteCount::Count(_), FetchDirection::ForwardAll) => {
1831 let msg = "Execute with max_rows of a FETCH ALL is not supported";
1832 self.adapter_client.retire_execute(
1833 ctx_extra,
1834 StatementEndedExecutionReason::Errored {
1835 error: msg.to_string(),
1836 },
1837 );
1838 return self
1839 .send_error_and_get_state(ErrorResponse::error(
1840 SqlState::FEATURE_NOT_SUPPORTED,
1841 msg,
1842 ))
1843 .await;
1844 }
1845 (ExecuteCount::All, FetchDirection::ForwardAll) => ExecuteCount::All,
1846 (ExecuteCount::All, FetchDirection::ForwardCount(count)) => {
1847 ExecuteCount::Count(usize::cast_from(count))
1848 }
1849 };
1850 let cursor_name = name.to_string();
1851 self.execute(
1852 cursor_name,
1853 count,
1854 fetch_message,
1855 fetch_portal_name,
1856 timeout,
1857 Some(ctx_extra),
1858 None,
1859 )
1860 .await
1861 }
1862
1863 async fn flush(&mut self) -> Result<State, io::Error> {
1864 self.conn.flush().await?;
1865 Ok(State::Ready)
1866 }
1867
1868 #[instrument(level = "debug")]
1873 async fn send<M>(&mut self, message: M) -> Result<(), io::Error>
1874 where
1875 M: Into<BackendMessage>,
1876 {
1877 let message: BackendMessage = message.into();
1878 let is_error =
1879 matches!(&message, BackendMessage::ErrorResponse(e) if e.severity.is_error());
1880
1881 self.conn.send(message).await?;
1882
1883 if is_error {
1890 self.conn.flush().await?;
1891 }
1892
1893 Ok(())
1894 }
1895
1896 #[instrument(level = "debug")]
1897 pub async fn send_all(
1898 &mut self,
1899 messages: impl IntoIterator<Item = BackendMessage>,
1900 ) -> Result<(), io::Error> {
1901 for m in messages {
1902 self.send(m).await?;
1903 }
1904 Ok(())
1905 }
1906
1907 #[instrument(level = "debug")]
1908 async fn sync(&mut self) -> Result<State, io::Error> {
1909 if self.adapter_client.session().transaction().is_implicit() {
1911 self.commit_transaction().await?;
1912 }
1913 self.ready().await
1914 }
1915
1916 #[instrument(level = "debug")]
1917 async fn ready(&mut self) -> Result<State, io::Error> {
1918 let txn_state = self.adapter_client.session().transaction().into();
1919 self.send(BackendMessage::ReadyForQuery(txn_state)).await?;
1920 self.flush().await
1921 }
1922
/// Translates the adapter's `ExecuteResponse` for portal `portal_name` into
/// backend messages for the client.
///
/// - Row-producing responses (`SendingRowsStreaming`, `SendingRowsImmediate`,
///   `Subscribing`) are streamed with `send_rows`, honoring `max_rows`,
///   `timeout` and `fetch_portal_name`.
/// - `CopyTo` streams rows in COPY format via `copy_rows`. `CopyFrom` is
///   handed to `copy_from`.
/// - `Fetch` resumes the cursor's portal through `fetch`.
/// - All other responses send their `CommandComplete` tag. Some first send
///   `ParameterStatus` messages for session variables the client is notified
///   about.
///
/// `execute_started` is passed to the row streams. NOTE(review): presumably
/// it is used for first-row timing metrics; confirm this.
///
/// # Panics
///
/// Panics in these cases:
/// - a row-producing or COPY response arrives without `row_desc`;
/// - `command_complete!` runs for a response without a tag;
/// - a tag-producing response reaches the end without consuming its tag.
#[allow(clippy::too_many_arguments)]
#[instrument(level = "debug")]
async fn send_execute_response(
    &mut self,
    response: ExecuteResponse,
    row_desc: Option<RelationDesc>,
    portal_name: String,
    max_rows: ExecuteCount,
    get_response: GetResponse,
    fetch_portal_name: Option<String>,
    timeout: ExecuteTimeout,
    execute_started: Instant,
) -> Result<State, io::Error> {
    // Consumed by `command_complete!`. The assertion at the end checks that no
    // tag-generating response forgot to send it.
    let mut tag = response.tag();

    // Sends `CommandComplete` with the response's tag and evaluates to
    // `Ok(State::Ready)`. It uses `?`, so it only works inside this function.
    macro_rules! command_complete {
        () => {{
            self.send(BackendMessage::CommandComplete {
                tag: tag
                    .take()
                    .expect("command_complete only called on tag-generating results"),
            })
            .await?;
            Ok(State::Ready)
        }};
    }

    let r = match response {
        // The portal that ran CLOSE/DECLARE is finished. Mark it completed so
        // a later Execute reports that it cannot be run.
        ExecuteResponse::ClosedCursor => {
            self.complete_portal(&portal_name);
            command_complete!()
        }
        ExecuteResponse::DeclaredCursor => {
            self.complete_portal(&portal_name);
            command_complete!()
        }
        ExecuteResponse::EmptyQuery => {
            self.send(BackendMessage::EmptyQueryResponse).await?;
            Ok(State::Ready)
        }
        // FETCH resumes the cursor's portal (`name`). This `timeout` is the
        // FETCH's own and shadows the parameter. The current portal is passed
        // as the fetch portal, so its result formats are used for the rows.
        ExecuteResponse::Fetch {
            name,
            count,
            timeout,
            ctx_extra,
        } => {
            self.fetch(
                name,
                count,
                max_rows,
                Some(portal_name.to_string()),
                timeout,
                ctx_extra,
            )
            .await
        }
        ExecuteResponse::SendingRowsStreaming {
            rows,
            instance_id,
            strategy,
        } => {
            let row_desc = row_desc
                .expect("missing row description for ExecuteResponse::SendingRowsStreaming");

            let span = tracing::debug_span!("sending_rows_streaming");

            // The stream is wrapped in `InProgressRows` so `send_rows` can
            // store any unsent rows back in the portal for a later Execute.
            // The end reason returned by `send_rows` is discarded here.
            self.send_rows(
                row_desc,
                portal_name,
                InProgressRows::new(RecordFirstRowStream::new(
                    Box::new(rows),
                    execute_started,
                    &self.adapter_client,
                    Some(instance_id),
                    Some(strategy),
                )),
                max_rows,
                get_response,
                fetch_portal_name,
                timeout,
            )
            .instrument(span)
            .await
            .map(|(state, _)| state)
        }
        ExecuteResponse::SendingRowsImmediate { rows } => {
            let row_desc = row_desc
                .expect("missing row description for ExecuteResponse::SendingRowsImmediate");

            let span = tracing::debug_span!("sending_rows_immediate");

            // The rows are already materialized. Present them as a
            // single-item stream so they share the streaming send path.
            let stream =
                futures::stream::once(futures::future::ready(PeekResponseUnary::Rows(rows)));
            self.send_rows(
                row_desc,
                portal_name,
                InProgressRows::new(RecordFirstRowStream::new(
                    Box::new(stream),
                    execute_started,
                    &self.adapter_client,
                    None,
                    Some(StatementExecutionStrategy::Constant),
                )),
                max_rows,
                get_response,
                fetch_portal_name,
                timeout,
            )
            .instrument(span)
            .await
            .map(|(state, _)| state)
        }
        ExecuteResponse::SetVariable { name, .. } => {
            let qn = name.to_string();
            // Report the new value only if the variable is in the session's
            // notify set.
            let msg = if let Some(var) = self
                .adapter_client
                .session()
                .vars_mut()
                .notify_set()
                .find(|v| v.name() == qn)
            {
                Some(BackendMessage::ParameterStatus(var.name(), var.value()))
            } else {
                None
            };
            if let Some(msg) = msg {
                self.send(msg).await?;
            }
            command_complete!()
        }
        ExecuteResponse::Subscribing {
            rx,
            ctx_extra,
            instance_id,
        } => {
            // `fetch_portal_name` is unset, so the SUBSCRIBE is being executed
            // directly rather than through FETCH on a cursor. Send a warning
            // (with a psql-specific hint) and flush it before any rows.
            if fetch_portal_name.is_none() {
                let mut msg = ErrorResponse::notice(
                    SqlState::WARNING,
                    "streaming SUBSCRIBE rows directly requires a client that does not buffer output",
                );
                if self.adapter_client.session().vars().application_name() == "psql" {
                    msg.hint = Some(
                        "Wrap your SUBSCRIBE statement in `COPY (SUBSCRIBE ...) TO STDOUT`."
                            .into(),
                    )
                }
                self.send(msg).await?;
                self.conn.flush().await?;
            }
            let row_desc =
                row_desc.expect("missing row description for ExecuteResponse::Subscribing");
            let (result, statement_ended_execution_reason) = match self
                .send_rows(
                    row_desc,
                    portal_name,
                    InProgressRows::new(RecordFirstRowStream::new(
                        Box::new(UnboundedReceiverStream::new(rx)),
                        execute_started,
                        &self.adapter_client,
                        Some(instance_id),
                        None,
                    )),
                    max_rows,
                    get_response,
                    fetch_portal_name,
                    timeout,
                )
                .await
            {
                // A connection I/O error is recorded as a cancellation.
                Err(e) => {
                    (Err(e), StatementEndedExecutionReason::Canceled)
                }
                Ok((ok, SendRowsEndedReason::Canceled)) => {
                    (Ok(ok), StatementEndedExecutionReason::Canceled)
                }
                Ok((
                    ok,
                    SendRowsEndedReason::Success {
                        result_size,
                        rows_returned,
                    },
                )) => (
                    Ok(ok),
                    StatementEndedExecutionReason::Success {
                        result_size: Some(result_size),
                        rows_returned: Some(rows_returned),
                        execution_strategy: None,
                    },
                ),
                Ok((ok, SendRowsEndedReason::Errored { error })) => {
                    (Ok(ok), StatementEndedExecutionReason::Errored { error })
                }
            };
            // SUBSCRIBE's execution context is retired here with the
            // streaming outcome. Returning early skips the final tag check.
            self.adapter_client
                .retire_execute(ctx_extra, statement_ended_execution_reason);
            return result;
        }
        ExecuteResponse::CopyTo { format, resp } => {
            let row_desc =
                row_desc.expect("missing row description for ExecuteResponse::CopyTo");
            // The inner response determines where the COPY rows come from.
            // Only subscribe and row-producing responses are supported.
            match *resp {
                ExecuteResponse::Subscribing {
                    rx,
                    ctx_extra,
                    instance_id,
                } => {
                    let (result, statement_ended_execution_reason) = match self
                        .copy_rows(
                            format,
                            row_desc,
                            RecordFirstRowStream::new(
                                Box::new(UnboundedReceiverStream::new(rx)),
                                execute_started,
                                &self.adapter_client,
                                Some(instance_id),
                                None,
                            ),
                        )
                        .await
                    {
                        // A connection I/O error is recorded as a cancellation.
                        Err(e) => {
                            (Err(e), StatementEndedExecutionReason::Canceled)
                        }
                        Ok((
                            state,
                            SendRowsEndedReason::Success {
                                result_size,
                                rows_returned,
                            },
                        )) => (
                            Ok(state),
                            StatementEndedExecutionReason::Success {
                                result_size: Some(result_size),
                                rows_returned: Some(rows_returned),
                                execution_strategy: None,
                            },
                        ),
                        Ok((state, SendRowsEndedReason::Errored { error })) => {
                            (Ok(state), StatementEndedExecutionReason::Errored { error })
                        }
                        Ok((state, SendRowsEndedReason::Canceled)) => {
                            (Ok(state), StatementEndedExecutionReason::Canceled)
                        }
                    };
                    self.adapter_client
                        .retire_execute(ctx_extra, statement_ended_execution_reason);
                    return result;
                }
                ExecuteResponse::SendingRowsStreaming {
                    rows,
                    instance_id,
                    strategy,
                } => {
                    return self
                        .copy_rows(
                            format,
                            row_desc,
                            RecordFirstRowStream::new(
                                Box::new(rows),
                                execute_started,
                                &self.adapter_client,
                                Some(instance_id),
                                Some(strategy),
                            ),
                        )
                        .await
                        .map(|(state, _)| state);
                }
                ExecuteResponse::SendingRowsImmediate { rows } => {
                    let span = tracing::debug_span!("sending_rows_immediate");

                    let rows = futures::stream::once(futures::future::ready(
                        PeekResponseUnary::Rows(rows),
                    ));
                    return self
                        .copy_rows(
                            format,
                            row_desc,
                            RecordFirstRowStream::new(
                                Box::new(rows),
                                execute_started,
                                &self.adapter_client,
                                None,
                                Some(StatementExecutionStrategy::Constant),
                            ),
                        )
                        .instrument(span)
                        .await
                        .map(|(state, _)| state);
                }
                _ => {
                    return self
                        .send_error_and_get_state(ErrorResponse::error(
                            SqlState::INTERNAL_ERROR,
                            "unsupported COPY response type".to_string(),
                        ))
                        .await;
                }
            };
        }
        ExecuteResponse::CopyFrom {
            target_id,
            target_name,
            columns,
            params,
            ctx_extra,
        } => {
            let row_desc =
                row_desc.expect("missing row description for ExecuteResponse::CopyFrom");
            self.copy_from(target_id, target_name, columns, params, row_desc, ctx_extra)
                .await
        }
        ExecuteResponse::TransactionCommitted { params }
        | ExecuteResponse::TransactionRolledBack { params } => {
            // Report the parameters returned with the transaction outcome,
            // limited to those in the session's notify set.
            let notify_set: mz_ore::collections::HashSet<String> = self
                .adapter_client
                .session()
                .vars()
                .notify_set()
                .map(|v| v.name().to_string())
                .collect();

            for (name, value) in params
                .into_iter()
                .filter(|(name, _v)| notify_set.contains(*name))
            {
                let msg = BackendMessage::ParameterStatus(name, value);
                self.send(msg).await?;
            }
            command_complete!()
        }

        // These responses need nothing beyond their command tag.
        ExecuteResponse::AlteredDefaultPrivileges
        | ExecuteResponse::AlteredObject(..)
        | ExecuteResponse::AlteredRole
        | ExecuteResponse::AlteredSystemConfiguration
        | ExecuteResponse::CreatedCluster { .. }
        | ExecuteResponse::CreatedClusterReplica { .. }
        | ExecuteResponse::CreatedConnection { .. }
        | ExecuteResponse::CreatedDatabase { .. }
        | ExecuteResponse::CreatedIndex { .. }
        | ExecuteResponse::CreatedIntrospectionSubscribe
        | ExecuteResponse::CreatedMaterializedView { .. }
        | ExecuteResponse::CreatedRole
        | ExecuteResponse::CreatedSchema { .. }
        | ExecuteResponse::CreatedSecret { .. }
        | ExecuteResponse::CreatedSink { .. }
        | ExecuteResponse::CreatedSource { .. }
        | ExecuteResponse::CreatedTable { .. }
        | ExecuteResponse::CreatedType
        | ExecuteResponse::CreatedView { .. }
        | ExecuteResponse::CreatedViews { .. }
        | ExecuteResponse::CreatedNetworkPolicy
        | ExecuteResponse::Comment
        | ExecuteResponse::Deallocate { .. }
        | ExecuteResponse::Deleted(..)
        | ExecuteResponse::DiscardedAll
        | ExecuteResponse::DiscardedTemp
        | ExecuteResponse::DroppedObject(_)
        | ExecuteResponse::DroppedOwned
        | ExecuteResponse::GrantedPrivilege
        | ExecuteResponse::GrantedRole
        | ExecuteResponse::Inserted(..)
        | ExecuteResponse::Copied(..)
        | ExecuteResponse::Prepare
        | ExecuteResponse::Raised
        | ExecuteResponse::ReassignOwned
        | ExecuteResponse::RevokedPrivilege
        | ExecuteResponse::RevokedRole
        | ExecuteResponse::StartedTransaction { .. }
        | ExecuteResponse::Updated(..)
        | ExecuteResponse::ValidatedConnection => {
            command_complete!()
        }
    };

    // Every response that falls through to here must have sent its tag.
    assert_none!(tag, "tag created but not consumed: {:?}", tag);
    r
}
2317
/// Streams rows from `rows` to the client as `DataRow` messages for portal
/// `portal_name`. It then sends the message built by `get_response`.
///
/// - Result formats come from `fetch_portal_name` when set, otherwise from
///   `portal_name`.
/// - At most `max_rows` rows are sent. A partially consumed batch is stored
///   back into `rows.current`, so a later call can resume from it.
/// - `timeout` bounds waiting for more rows:
///   - `Seconds(t)` stops waiting once the deadline passes.
///   - `WaitOnce` waits for the first non-empty batch. After that it sends
///     only batches that are immediately ready.
/// - Session notices that arrive while waiting are forwarded to the client.
/// - The following are reported to the client and returned as the
///   corresponding `SendRowsEndedReason`:
///   - a batch that fails `verify_datum_desc`;
///   - an error from the row stream;
///   - a cancellation.
///
/// On normal completion the portal is left `InProgress` with the remaining
/// rows, and `Success` is returned with the number of rows and bytes sent.
/// When the stream is exhausted for the first time, the first-to-last-row
/// latency metric is recorded.
///
/// If the connection closes while waiting, an I/O error is returned.
#[allow(clippy::too_many_arguments)]
#[mz_ore::instrument(level = "debug")]
async fn send_rows(
    &mut self,
    row_desc: RelationDesc,
    portal_name: String,
    mut rows: InProgressRows,
    max_rows: ExecuteCount,
    get_response: GetResponse,
    fetch_portal_name: Option<String>,
    timeout: ExecuteTimeout,
) -> Result<(State, SendRowsEndedReason), io::Error> {
    // When resuming via FETCH, rows are encoded in the fetching portal's formats.
    let result_format_portal_name: &str = if let Some(ref name) = fetch_portal_name {
        name
    } else {
        &portal_name
    };
    let result_formats = self
        .adapter_client
        .session()
        .get_portal_unverified(result_format_portal_name)
        .expect("valid fetch portal name for send rows")
        .result_formats
        .clone();

    // `wait_once` is consumed by the first non-empty batch, which then sets
    // `deadline` to "now".
    let (mut wait_once, mut deadline) = match timeout {
        ExecuteTimeout::None => (false, None),
        ExecuteTimeout::Seconds(t) => (
            false,
            Some(tokio::time::Instant::now() + tokio::time::Duration::from_secs_f64(t)),
        ),
        ExecuteTimeout::WaitOnce => (true, None),
    };

    // Sanity check: the row description must agree with the portals involved.
    {
        let portal_name_desc = &self
            .adapter_client
            .session()
            .get_portal_unverified(portal_name.as_str())
            .expect("portal should exist")
            .desc
            .relation_desc;
        if let Some(portal_name_desc) = portal_name_desc {
            soft_assert_eq_or_log!(portal_name_desc, &row_desc);
        }
        if let Some(fetch_portal_name) = &fetch_portal_name {
            let fetch_portal_desc = &self
                .adapter_client
                .session()
                .get_portal_unverified(fetch_portal_name)
                .expect("portal should exist")
                .desc
                .relation_desc;
            if let Some(fetch_portal_desc) = fetch_portal_desc {
                soft_assert_eq_or_log!(fetch_portal_desc, &row_desc);
            }
        }
    }

    // Pair each column's pgwire type with its result format for encoding.
    self.conn.set_encode_state(
        row_desc
            .typ()
            .column_types
            .iter()
            .map(|ty| mz_pgrepr::Type::from(&ty.scalar_type))
            .zip_eq(result_formats)
            .collect(),
    );

    let mut total_sent_rows = 0;
    let mut total_sent_bytes = 0;
    let mut want_rows = match max_rows {
        ExecuteCount::All => usize::MAX,
        ExecuteCount::Count(count) => count,
    };

    loop {
        let batch = if rows.current.is_some() {
            // First finish a batch left over from a previous call.
            FetchResult::Rows(rows.current.take())
        } else if want_rows == 0 {
            // No rows are wanted, so don't wait on the stream.
            FetchResult::Rows(None)
        } else {
            let notice_fut = self.adapter_client.session().recv_notice();
            tokio::select! {
                // Branches are polled in order: connection closure, then
                // rows, then notices, then the deadline (if one is set).
                biased;
                err = self.conn.wait_closed() => return Err(err),
                batch = rows.remaining.recv() => match batch {
                    None => FetchResult::Rows(None),
                    Some(PeekResponseUnary::Rows(rows)) => FetchResult::Rows(Some(rows)),
                    Some(PeekResponseUnary::Error(err)) => FetchResult::Error(err),
                    Some(PeekResponseUnary::DependencyDropped(dep)) => {
                        FetchResult::Error(dep.query_terminated_error())
                    }
                    Some(PeekResponseUnary::Canceled) => FetchResult::Canceled,
                },
                notice = notice_fut => {
                    FetchResult::Notice(notice)
                }
                _ = time::sleep_until(
                    deadline.unwrap_or_else(tokio::time::Instant::now),
                ), if deadline.is_some() => FetchResult::Rows(None),
            }
        };

        match batch {
            FetchResult::Rows(None) => break,
            FetchResult::Rows(Some(mut batch_rows)) => {
                if let Err(err) = verify_datum_desc(&row_desc, &mut batch_rows) {
                    let msg = err.to_string();
                    return self
                        .send_error_and_get_state(err.into_response(Severity::Error))
                        .await
                        .map(|state| (state, SendRowsEndedReason::Errored { error: msg }));
                }

                // WaitOnce: after the first non-empty batch, expire the
                // deadline so only already-available batches are sent.
                if wait_once && batch_rows.peek().is_some() {
                    deadline = Some(tokio::time::Instant::now());
                    wait_once = false;
                }

                let mut sent_rows = 0;
                let mut sent_bytes = 0;
                // `take` is applied last and the chain is lazy, so `inspect`
                // only counts the rows that are actually sent.
                let messages = (&mut batch_rows)
                    .map(|row| {
                        let row_len = row.byte_len();
                        let values = mz_pgrepr::values_from_row(row, row_desc.typ());
                        (row_len, BackendMessage::DataRow(values))
                    })
                    .inspect(|(row_len, _)| {
                        sent_bytes += row_len;
                        sent_rows += 1
                    })
                    .map(|(_row_len, row)| row)
                    .take(want_rows);
                self.send_all(messages).await?;

                total_sent_rows += sent_rows;
                total_sent_bytes += sent_bytes;
                want_rows -= sent_rows;

                if want_rows == 0 {
                    // Keep unsent rows of this batch for the next resumption.
                    if batch_rows.peek().is_some() {
                        rows.current = Some(batch_rows);
                    }
                    break;
                }

                self.conn.flush().await?;
            }
            FetchResult::Notice(notice) => {
                self.send(notice.into_response()).await?;
                self.conn.flush().await?;
            }
            FetchResult::Error(text) => {
                return self
                    .send_error_and_get_state(ErrorResponse::error(
                        SqlState::INTERNAL_ERROR,
                        text.clone(),
                    ))
                    .await
                    .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
            }
            FetchResult::Canceled => {
                return self
                    .send_error_and_get_state(ErrorResponse::error(
                        SqlState::QUERY_CANCELED,
                        "canceling statement due to user request",
                    ))
                    .await
                    .map(|state| (state, SendRowsEndedReason::Canceled));
            }
        }
    }

    let portal = self
        .adapter_client
        .session()
        .get_portal_unverified_mut(&portal_name)
        .expect("valid portal name for send rows");

    // Capture metric inputs before `rows` is moved into the portal.
    let saw_rows = rows.remaining.saw_rows;
    let no_more_rows = rows.no_more_rows();
    let metric_recorded = rows.remaining.metric_recorded;
    let recorded_first_row_instant = rows.remaining.recorded_first_row_instant;

    // Mark the metric as recorded now so it is observed only once per stream.
    if no_more_rows && !metric_recorded {
        rows.remaining.metric_recorded = true;
    }

    *portal.state = PortalState::InProgress(Some(rows));

    let fetch_portal = fetch_portal_name.map(|name| {
        self.adapter_client
            .session()
            .get_portal_unverified_mut(&name)
            .expect("valid fetch portal")
    });
    let response_message = get_response(max_rows, total_sent_rows, fetch_portal);
    self.send(response_message).await?;

    if no_more_rows && !metric_recorded {
        let statement_type = if let Some(stmt) = &self
            .adapter_client
            .session()
            .get_portal_unverified(&portal_name)
            .expect("valid portal name for send_rows")
            .stmt
        {
            metrics::statement_type_label_value(stmt.deref())
        } else {
            "no-statement"
        };
        let duration = if saw_rows {
            recorded_first_row_instant
                .expect("recorded_first_row_instant because saw_rows")
                .elapsed()
        } else {
            Duration::ZERO
        };
        self.adapter_client
            .inner()
            .metrics()
            .result_rows_first_to_last_byte_seconds
            .with_label_values(&[statement_type])
            .observe(duration.as_secs_f64());
    }

    Ok((
        State::Ready,
        SendRowsEndedReason::Success {
            result_size: u64::cast_from(total_sent_bytes),
            rows_returned: u64::cast_from(total_sent_rows),
        },
    ))
}
2592
    /// Streams the rows produced by `stream` to the client as a `COPY ... TO STDOUT`
    /// response.
    ///
    /// Sends a `CopyOutResponse` advertising `format` for every column of `row_desc`,
    /// then one `CopyData` message per encoded row, then `CopyDone` followed by a
    /// `COPY <n>` `CommandComplete`. Notices raised on the session while streaming
    /// are forwarded to the client as they arrive.
    ///
    /// Returns the next protocol [`State`] together with the reason the send ended
    /// (for statement logging). Errors, dependency drops and cancellations reported
    /// by the stream are sent to the client via `send_error_and_get_state`; Parquet
    /// is rejected up front the same way. An `Err` is returned for I/O failures on
    /// the connection (including the connection closing mid-stream) and for
    /// failures encoding a row.
    #[mz_ore::instrument(level = "debug")]
    async fn copy_rows(
        &mut self,
        format: CopyFormat,
        row_desc: RelationDesc,
        mut stream: RecordFirstRowStream,
    ) -> Result<(State, SendRowsEndedReason), io::Error> {
        // Map the requested COPY format to the row encoder's parameters and to the
        // per-column format code advertised in `CopyOutResponse`.
        let (row_format, encode_format) = match format {
            CopyFormat::Text => (
                CopyFormatParams::Text(CopyTextFormatParams::default()),
                Format::Text,
            ),
            CopyFormat::Binary => (CopyFormatParams::Binary, Format::Binary),
            CopyFormat::Csv => (
                CopyFormatParams::Csv(CopyCsvFormatParams::default()),
                Format::Text,
            ),
            CopyFormat::Parquet => {
                // Parquet cannot be streamed over the wire here; report an error
                // before anything COPY-specific has been sent.
                let text = "Parquet format is not supported".to_string();
                return self
                    .send_error_and_get_state(ErrorResponse::error(
                        SqlState::INTERNAL_ERROR,
                        text.clone(),
                    ))
                    .await
                    .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
            }
        };

        // Encodes one row in the selected COPY format, appending to `out`.
        let encode_fn = |row: &RowRef, typ: &SqlRelationType, out: &mut Vec<u8>| {
            mz_pgcopy::encode_copy_format(&row_format, row, typ, out)
        };

        let typ = row_desc.typ();
        let column_formats = iter::repeat(encode_format)
            .take(typ.column_types.len())
            .collect();
        self.send(BackendMessage::CopyOutResponse {
            overall_format: encode_format,
            column_formats,
        })
        .await?;

        // Output buffer; its contents are taken and sent as one `CopyData` per row.
        let mut out = Vec::new();

        // Binary COPY header. It is not sent on its own: it stays in `out` and goes
        // out with the first row's `CopyData` (or with the trailer if no rows).
        if let CopyFormat::Binary = format {
            // 11-byte signature.
            out.extend(b"PGCOPY\n\xFF\r\n\0");
            // 32-bit flags field (no flags set).
            out.extend([0, 0, 0, 0]);
            // 32-bit header extension area length (no extension).
            out.extend([0, 0, 0, 0]);
        }

        let mut count = 0;
        let mut total_sent_bytes = 0;
        loop {
            tokio::select! {
                // The client went away; surface the connection error.
                e = self.conn.wait_closed() => return Err(e),
                batch = stream.recv() => match batch {
                    // Stream exhausted: all rows have been sent.
                    None => break,
                    Some(PeekResponseUnary::Error(text)) => {
                        let err =
                            ErrorResponse::error(SqlState::INTERNAL_ERROR, text.clone());
                        return self
                            .send_error_and_get_state(err)
                            .await
                            .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
                    }
                    Some(PeekResponseUnary::DependencyDropped(dep)) => {
                        let err = dep.to_concurrent_dependency_drop();
                        let text = err.to_string();
                        let resp = err.into_response(Severity::Error);
                        return self
                            .send_error_and_get_state(resp)
                            .await
                            .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
                    }
                    Some(PeekResponseUnary::Canceled) => {
                        return self.send_error_and_get_state(ErrorResponse::error(
                            SqlState::QUERY_CANCELED,
                            "canceling statement due to user request",
                        ))
                        .await.map(|state| (state, SendRowsEndedReason::Canceled));
                    }
                    Some(PeekResponseUnary::Rows(mut rows)) => {
                        // The batch is counted up front; `count()` is presumably
                        // non-consuming since the rows are iterated right after.
                        count += rows.count();
                        while let Some(row) = rows.next() {
                            total_sent_bytes += row.byte_len();
                            encode_fn(row, typ, &mut out)?;
                            self.send(BackendMessage::CopyData(mem::take(&mut out)))
                                .await?;
                        }
                    }
                },
                // Forward notices immediately so the client sees them mid-stream.
                notice = self.adapter_client.session().recv_notice() => {
                    self.send(notice.into_response())
                        .await?;
                    self.conn.flush().await?;
                }
            }

            self.conn.flush().await?;
        }
        // Binary COPY trailer: a 16-bit field count of -1.
        if let CopyFormat::Binary = format {
            let trailer: i16 = -1;
            out.extend(trailer.to_be_bytes());
            self.send(BackendMessage::CopyData(mem::take(&mut out)))
                .await?;
        }

        let tag = format!("COPY {}", count);
        self.send(BackendMessage::CopyDone).await?;
        self.send(BackendMessage::CommandComplete { tag }).await?;
        Ok((
            State::Ready,
            SendRowsEndedReason::Success {
                result_size: u64::cast_from(total_sent_bytes),
                rows_returned: u64::cast_from(count),
            },
        ))
    }
2720
2721 #[instrument(level = "debug")]
2724 async fn copy_from(
2725 &mut self,
2726 target_id: CatalogItemId,
2727 target_name: String,
2728 columns: Vec<ColumnIndex>,
2729 params: CopyFormatParams<'static>,
2730 row_desc: RelationDesc,
2731 mut ctx_extra: ExecuteContextGuard,
2732 ) -> Result<State, io::Error> {
2733 let res = self
2734 .copy_from_inner(
2735 target_id,
2736 target_name,
2737 columns,
2738 params,
2739 row_desc,
2740 &mut ctx_extra,
2741 )
2742 .await;
2743 match &res {
2744 Ok(State::Ready) => {
2745 self.adapter_client.retire_execute(
2746 ctx_extra,
2747 StatementEndedExecutionReason::Success {
2748 result_size: None,
2749 rows_returned: None,
2750 execution_strategy: None,
2751 },
2752 );
2753 }
2754 Ok(State::Done) => {
2755 self.adapter_client
2759 .retire_execute(ctx_extra, StatementEndedExecutionReason::Canceled);
2760 }
2761 Err(e) => {
2762 self.adapter_client.retire_execute(
2763 ctx_extra,
2764 StatementEndedExecutionReason::Errored {
2765 error: format!("{e}"),
2766 },
2767 );
2768 }
2769 Ok(State::Drain) => {}
2770 }
2771 res
2772 }
2773
2774 async fn copy_from_inner(
2775 &mut self,
2776 target_id: CatalogItemId,
2777 target_name: String,
2778 columns: Vec<ColumnIndex>,
2779 params: CopyFormatParams<'static>,
2780 row_desc: RelationDesc,
2781 ctx_extra: &mut ExecuteContextGuard,
2782 ) -> Result<State, io::Error> {
2783 let typ = row_desc.typ();
2784 let column_formats = vec![Format::Text; typ.column_types.len()];
2785 self.send(BackendMessage::CopyInResponse {
2786 overall_format: Format::Text,
2787 column_formats,
2788 })
2789 .await?;
2790 self.conn.flush().await?;
2791
2792 let writer = match self
2794 .adapter_client
2795 .start_copy_from_stdin(
2796 target_id,
2797 target_name.clone(),
2798 columns.clone(),
2799 row_desc.clone(),
2800 params.clone(),
2801 )
2802 .await
2803 {
2804 Ok(writer) => writer,
2805 Err(e) => {
2806 loop {
2812 match self.conn.recv().await? {
2813 Some(FrontendMessage::CopyData(_)) => {}
2814 Some(FrontendMessage::CopyDone) | Some(FrontendMessage::CopyFail(_)) => {
2815 break;
2816 }
2817 Some(FrontendMessage::Flush) | Some(FrontendMessage::Sync) => {}
2818 Some(_) => break,
2819 None => return Ok(State::Done),
2820 }
2821 }
2822 self.adapter_client.retire_execute(
2823 std::mem::take(ctx_extra),
2824 StatementEndedExecutionReason::Errored {
2825 error: e.to_string(),
2826 },
2827 );
2828 return self
2829 .send_error_and_get_state(e.into_response(Severity::Error))
2830 .await;
2831 }
2832 };
2833
2834 self.conn.set_copy_mode(true);
2836
2837 const BATCH_SIZE: usize = 32 * 1024 * 1024;
2839 let max_copy_from_row_size = self
2840 .adapter_client
2841 .get_system_vars()
2842 .await
2843 .max_copy_from_row_size()
2844 .try_into()
2845 .unwrap_or(usize::MAX);
2846
2847 let mut data = Vec::new();
2848 let mut row_scanner = CopyRowScanner::new(¶ms);
2849 let num_workers = writer.batch_txs.len();
2850 let mut next_worker: usize = 0;
2851 let mut saw_copy_done = false;
2852 let mut saw_end_marker = false;
2853 let mut copy_from_error: Option<(SqlState, String)> = None;
2854
2855 loop {
2858 let message = self.conn.recv().await?;
2859 match message {
2860 Some(FrontendMessage::CopyData(buf)) => {
2861 if saw_end_marker {
2862 continue;
2865 }
2866 data.extend(buf);
2867 row_scanner.scan_new_bytes(&data);
2868
2869 if let Some(end_pos) = row_scanner.end_marker_end() {
2870 data.truncate(end_pos);
2871 row_scanner.on_truncate(end_pos);
2872 saw_end_marker = true;
2873 }
2874
2875 if row_scanner.current_row_size(data.len()) > max_copy_from_row_size {
2877 copy_from_error = Some((
2878 SqlState::INSUFFICIENT_RESOURCES,
2879 format!(
2880 "COPY FROM STDIN row exceeded max_copy_from_row_size \
2881 ({max_copy_from_row_size} bytes)"
2882 ),
2883 ));
2884 break;
2885 }
2886
2887 let mut send_failed = false;
2890 while data.len() >= BATCH_SIZE {
2891 let split_pos = match row_scanner.last_row_end() {
2892 Some(pos) => pos,
2893 None => break, };
2895 let remainder = data.split_off(split_pos);
2896 let chunk = std::mem::replace(&mut data, remainder);
2897 row_scanner.on_split(split_pos);
2898 if writer.batch_txs[next_worker].send(chunk).await.is_err() {
2899 send_failed = true;
2900 break;
2901 }
2902 next_worker = (next_worker + 1) % num_workers;
2903 }
2904 if send_failed {
2907 break;
2908 }
2909 }
2910 Some(FrontendMessage::CopyDone) => {
2911 if !data.is_empty() {
2913 let chunk = std::mem::take(&mut data);
2914 let _ = writer.batch_txs[next_worker].send(chunk).await;
2916 }
2917 saw_copy_done = true;
2918 break;
2919 }
2920 Some(FrontendMessage::CopyFail(err)) => {
2921 self.adapter_client.retire_execute(
2922 std::mem::take(ctx_extra),
2923 StatementEndedExecutionReason::Canceled,
2924 );
2925 drop(writer);
2927 self.conn.set_copy_mode(false);
2928 return self
2929 .send_error_and_get_state(ErrorResponse::error(
2930 SqlState::QUERY_CANCELED,
2931 format!("COPY from stdin failed: {}", err),
2932 ))
2933 .await;
2934 }
2935 Some(FrontendMessage::Flush) | Some(FrontendMessage::Sync) => {}
2936 Some(_) => {
2937 let msg = "unexpected message type during COPY from stdin";
2938 self.adapter_client.retire_execute(
2939 std::mem::take(ctx_extra),
2940 StatementEndedExecutionReason::Errored {
2941 error: msg.to_string(),
2942 },
2943 );
2944 drop(writer);
2945 self.conn.set_copy_mode(false);
2946 return self
2947 .send_error_and_get_state(ErrorResponse::error(
2948 SqlState::PROTOCOL_VIOLATION,
2949 msg,
2950 ))
2951 .await;
2952 }
2953 None => {
2954 drop(writer);
2955 self.conn.set_copy_mode(false);
2956 return Ok(State::Done);
2957 }
2958 }
2959 }
2960
2961 if !saw_copy_done {
2965 loop {
2966 match self.conn.recv().await? {
2967 Some(FrontendMessage::CopyData(_)) => {}
2968 Some(FrontendMessage::CopyDone) | Some(FrontendMessage::CopyFail(_)) => {
2969 break;
2970 }
2971 Some(FrontendMessage::Flush) | Some(FrontendMessage::Sync) => {}
2972 Some(_) => {
2973 let msg = "unexpected message type during COPY from stdin";
2974 self.adapter_client.retire_execute(
2975 std::mem::take(ctx_extra),
2976 StatementEndedExecutionReason::Errored {
2977 error: msg.to_string(),
2978 },
2979 );
2980 drop(writer);
2981 self.conn.set_copy_mode(false);
2982 return self
2983 .send_error_and_get_state(ErrorResponse::error(
2984 SqlState::PROTOCOL_VIOLATION,
2985 msg,
2986 ))
2987 .await;
2988 }
2989 None => {
2990 drop(writer);
2991 self.conn.set_copy_mode(false);
2992 return Ok(State::Done);
2993 }
2994 }
2995 }
2996 }
2997
2998 if let Some((code, msg)) = copy_from_error {
2999 self.adapter_client.retire_execute(
3000 std::mem::take(ctx_extra),
3001 StatementEndedExecutionReason::Errored { error: msg.clone() },
3002 );
3003 drop(writer);
3004 self.conn.set_copy_mode(false);
3005 return self
3006 .send_error_and_get_state(ErrorResponse::error(code, msg))
3007 .await;
3008 }
3009
3010 self.conn.set_copy_mode(false);
3011
3012 drop(writer.batch_txs);
3017
3018 let (proto_batches, row_count) = match writer.completion_rx.await {
3020 Ok(Ok(result)) => result,
3021 Ok(Err(e)) => {
3022 self.adapter_client.retire_execute(
3023 std::mem::take(ctx_extra),
3024 StatementEndedExecutionReason::Errored {
3025 error: e.to_string(),
3026 },
3027 );
3028 return self
3029 .send_error_and_get_state(e.into_response(Severity::Error))
3030 .await;
3031 }
3032 Err(_) => {
3033 let msg = "COPY FROM STDIN: background batch builder tasks dropped";
3034 self.adapter_client.retire_execute(
3035 std::mem::take(ctx_extra),
3036 StatementEndedExecutionReason::Errored {
3037 error: msg.to_string(),
3038 },
3039 );
3040 return self
3041 .send_error_and_get_state(ErrorResponse::error(SqlState::INTERNAL_ERROR, msg))
3042 .await;
3043 }
3044 };
3045
3046 if let Err(e) = self
3048 .adapter_client
3049 .stage_copy_from_stdin_batches(target_id, proto_batches)
3050 {
3051 self.adapter_client.retire_execute(
3052 std::mem::take(ctx_extra),
3053 StatementEndedExecutionReason::Errored {
3054 error: e.to_string(),
3055 },
3056 );
3057 return self
3058 .send_error_and_get_state(e.into_response(Severity::Error))
3059 .await;
3060 }
3061
3062 let tag = format!("COPY {}", row_count);
3063 self.send(BackendMessage::CommandComplete { tag }).await?;
3064
3065 Ok(State::Ready)
3066 }
3067
3068 #[instrument(level = "debug")]
3069 async fn send_pending_notices(&mut self) -> Result<(), io::Error> {
3070 let notices = self
3071 .adapter_client
3072 .session()
3073 .drain_notices()
3074 .into_iter()
3075 .map(|notice| BackendMessage::ErrorResponse(notice.into_response()));
3076 self.send_all(notices).await?;
3077 Ok(())
3078 }
3079
3080 #[instrument(level = "debug")]
3081 async fn send_error_and_get_state(&mut self, err: ErrorResponse) -> Result<State, io::Error> {
3082 assert!(err.severity.is_error());
3083 debug!(
3084 "cid={} error code={}",
3085 self.adapter_client.session().conn_id(),
3086 err.code.code()
3087 );
3088 let is_fatal = err.severity.is_fatal();
3089 self.send(BackendMessage::ErrorResponse(err)).await?;
3090
3091 let txn = self.adapter_client.session().transaction();
3092 match txn {
3093 TransactionStatus::Default | TransactionStatus::Failed(_) => {}
3096 TransactionStatus::Started(_) => {
3098 self.rollback_transaction().await?;
3099 }
3100 TransactionStatus::InTransactionImplicit(_) => {
3102 self.rollback_transaction().await?;
3103 }
3104 TransactionStatus::InTransaction(_) => {
3106 self.adapter_client.fail_transaction();
3107 }
3108 };
3109 if is_fatal {
3110 Ok(State::Done)
3111 } else {
3112 Ok(State::Drain)
3113 }
3114 }
3115
3116 #[instrument(level = "debug")]
3117 async fn aborted_txn_error(&mut self) -> Result<State, io::Error> {
3118 self.send(BackendMessage::ErrorResponse(ErrorResponse::error(
3119 SqlState::IN_FAILED_SQL_TRANSACTION,
3120 ABORTED_TXN_MSG,
3121 )))
3122 .await?;
3123 Ok(State::Drain)
3124 }
3125
3126 fn is_aborted_txn(&mut self) -> bool {
3127 matches!(
3128 self.adapter_client.session().transaction(),
3129 TransactionStatus::Failed(_)
3130 )
3131 }
3132}
3133
3134fn pad_formats(formats: Vec<Format>, n: usize) -> Result<Vec<Format>, String> {
3135 match (formats.len(), n) {
3136 (0, e) => Ok(vec![Format::Text; e]),
3137 (1, e) => Ok(iter::repeat(formats[0]).take(e).collect()),
3138 (a, e) if a == e => Ok(formats),
3139 (a, e) => Err(format!(
3140 "expected {} field format specifiers, but got {}",
3141 e, a
3142 )),
3143 }
3144}
3145
3146fn describe_rows(stmt_desc: &StatementDesc, formats: &[Format]) -> BackendMessage {
3147 match &stmt_desc.relation_desc {
3148 Some(desc) if !stmt_desc.is_copy => {
3149 BackendMessage::RowDescription(message::encode_row_description(desc, formats))
3150 }
3151 _ => BackendMessage::NoData,
3152 }
3153}
3154
/// Signature shared by `portal_exec_message` and `fetch_message`.
///
/// Builds the message sent after a round of rows has been delivered. Its inputs are
/// the execute row limit, the number of rows sent so far, and (for FETCH) the portal
/// being fetched from.
type GetResponse = fn(
    max_rows: ExecuteCount,
    total_sent_rows: usize,
    fetch_portal: Option<PortalRefMut>,
) -> BackendMessage;
3160
3161fn portal_exec_message(
3164 max_rows: ExecuteCount,
3165 total_sent_rows: usize,
3166 _fetch_portal: Option<PortalRefMut>,
3167) -> BackendMessage {
3168 match max_rows {
3175 ExecuteCount::Count(max_rows) if max_rows <= total_sent_rows => {
3176 BackendMessage::PortalSuspended
3177 }
3178 _ => BackendMessage::CommandComplete {
3179 tag: format!("SELECT {}", total_sent_rows),
3180 },
3181 }
3182}
3183
3184fn fetch_message(
3186 _max_rows: ExecuteCount,
3187 total_sent_rows: usize,
3188 fetch_portal: Option<PortalRefMut>,
3189) -> BackendMessage {
3190 let tag = format!("FETCH {}", total_sent_rows);
3191 if let Some(portal) = fetch_portal {
3192 *portal.state = PortalState::Completed(Some(tag.clone()));
3193 }
3194 BackendMessage::CommandComplete { tag }
3195}
3196
3197fn get_authenticator(
3198 authenticator_kind: listeners::AuthenticatorKind,
3199 frontegg: Option<FronteggAuthenticator>,
3200 oidc: GenericOidcAuthenticator,
3201 adapter_client: mz_adapter::Client,
3202 oidc_auth_option_enabled: bool,
3205) -> Authenticator {
3206 match authenticator_kind {
3207 listeners::AuthenticatorKind::Frontegg => Authenticator::Frontegg(frontegg.expect(
3208 "Frontegg authenticator should exist with listeners::AuthenticatorKind::Frontegg",
3209 )),
3210 listeners::AuthenticatorKind::Password => Authenticator::Password(adapter_client),
3211 listeners::AuthenticatorKind::Sasl => Authenticator::Sasl(adapter_client),
3212 listeners::AuthenticatorKind::Oidc => {
3213 if oidc_auth_option_enabled {
3214 Authenticator::Oidc(oidc)
3215 } else {
3216 Authenticator::Password(adapter_client)
3219 }
3220 }
3221 listeners::AuthenticatorKind::None => Authenticator::None,
3222 }
3223}
3224
/// Row limit for one round of sending a portal's results.
#[derive(Debug, Copy, Clone)]
enum ExecuteCount {
    /// No row limit.
    All,
    /// Send at most this many rows; reaching it suspends the portal.
    Count(usize),
}
3230
3231fn is_txn_exit_stmt(stmt: Option<&Statement<Raw>>) -> bool {
3233 match stmt {
3234 Some(stmt) => matches!(stmt, Statement::Commit(_) | Statement::Rollback(_)),
3236 None => false,
3237 }
3238}
3239
/// Outcome of one step of fetching query output.
///
/// NOTE(review): this type's producer and consumer are outside this view; the
/// variant descriptions below are inferred from names and payloads — confirm.
#[derive(Debug)]
enum FetchResult {
    /// A batch of rows. Presumably `None` means no further rows — TODO confirm.
    Rows(Option<Box<dyn RowIterator + Send + Sync>>),
    /// The query was canceled.
    Canceled,
    /// The query failed with this message.
    Error(String),
    /// A notice to relay to the client.
    Notice(AdapterNotice),
}
3247
/// Incrementally scans buffered `COPY FROM STDIN` data for complete-row boundaries
/// and for the `\.` end-of-data marker row.
///
/// All offsets are byte positions in the caller's buffer. The caller reports
/// splits and truncations of that buffer through `on_split` / `on_truncate` so the
/// offsets stay valid.
#[derive(Debug)]
struct CopyRowScanner {
    /// Offset up to which the buffer has already been scanned.
    scan_pos: usize,
    /// Offset just past the last complete row seen, if any.
    last_row_end: Option<usize>,
    /// Offset just past the end-of-data marker row, once seen.
    end_marker_end: Option<usize>,
    /// CSV parsing state; `None` for non-CSV formats, which are split on `\n`.
    csv: Option<CsvScanState>,
}
3255
/// State for finding record boundaries in CSV `COPY FROM STDIN` data.
#[derive(Debug)]
struct CsvScanState {
    /// Incremental CSV parser; carries partial-record state between scans.
    reader: csv_core::Reader,
    /// Scratch buffer receiving the parser's field bytes on each call.
    output: Vec<u8>,
    /// Scratch buffer receiving the parser's field end positions on each call.
    ends: Vec<usize>,
    /// Field bytes of the record currently being parsed, accumulated across calls.
    record: Vec<u8>,
    /// Field end positions of the current record, used as offsets into `record`.
    record_ends: Vec<usize>,
    /// Whether the next completed record is a header that must not be treated as
    /// the end-of-data marker.
    skip_first_record: bool,
}
3265
impl CopyRowScanner {
    /// Creates a scanner for data in the given COPY format.
    ///
    /// CSV input is parsed with the format's delimiter, quote and escape bytes, and
    /// the first record is skipped for marker detection when `header` is set. Every
    /// other format is split on `\n`.
    fn new(params: &CopyFormatParams<'_>) -> Self {
        let csv = match params {
            CopyFormatParams::Csv(CopyCsvFormatParams {
                delimiter,
                quote,
                escape,
                header,
                ..
            }) => Some(CsvScanState::new(*delimiter, *quote, *escape, *header)),
            _ => None,
        };

        CopyRowScanner {
            scan_pos: 0,
            last_row_end: None,
            end_marker_end: None,
            csv,
        }
    }

    /// Scans the bytes of `data` appended since the previous call.
    ///
    /// `data` must be the previously scanned buffer (adjusted only via `on_split` /
    /// `on_truncate`) with new bytes appended. Updates `last_row_end` for each
    /// complete row found. Scanning stops at the first end-of-data marker row, whose
    /// end is recorded in `end_marker_end`.
    fn scan_new_bytes(&mut self, data: &[u8]) {
        // Nothing new since the last scan.
        if self.scan_pos >= data.len() {
            return;
        }

        if let Some(csv) = self.csv.as_mut() {
            // CSV: run the unscanned suffix through the incremental parser so that
            // newlines inside quoted fields are not mistaken for row ends.
            let mut input = &data[self.scan_pos..];
            // Bytes of the unscanned suffix consumed so far in this call.
            let mut consumed = 0usize;
            while !input.is_empty() {
                let (result, n_input, n_output, n_ends) =
                    csv.reader
                        .read_record(input, &mut csv.output, &mut csv.ends);
                consumed += n_input;
                input = &input[n_input..];
                // Accumulate this call's output so the whole record is available to
                // the marker check once the record completes.
                if !csv.output.is_empty() {
                    csv.record.extend_from_slice(&csv.output[..n_output]);
                }
                if !csv.ends.is_empty() {
                    csv.record_ends.extend_from_slice(&csv.ends[..n_ends]);
                }

                match result {
                    ReadRecordResult::InputEmpty => break,
                    // The scratch buffers are drained into `record`/`record_ends`
                    // after every call, so they only need to grow when the parser
                    // could make no progress at all.
                    ReadRecordResult::OutputFull => {
                        if n_input == 0 {
                            csv.output
                                .resize(csv.output.len().saturating_mul(2).max(1), 0);
                        }
                    }
                    ReadRecordResult::OutputEndsFull => {
                        if n_input == 0 {
                            csv.ends.resize(csv.ends.len().saturating_mul(2).max(1), 0);
                        }
                    }
                    ReadRecordResult::Record | ReadRecordResult::End => {
                        let row_end = self.scan_pos + consumed;
                        self.last_row_end = Some(row_end);
                        if self.end_marker_end.is_none() {
                            // The header record never counts as the marker. Otherwise
                            // the marker is a record with one field equal to `\.`.
                            // NOTE(review): this compares parsed field bytes, so a
                            // quoted `"\."` field also matches — confirm that is
                            // consistent with the row decoder.
                            let is_marker = if csv.skip_first_record {
                                csv.skip_first_record = false;
                                false
                            } else if csv.record_ends.len() == 1 {
                                let end = csv.record_ends[0];
                                end == 2 && csv.record.get(0..end) == Some(b"\\.")
                            } else {
                                false
                            };
                            if is_marker {
                                self.end_marker_end = Some(row_end);
                                break;
                            }
                        }
                        csv.record.clear();
                        csv.record_ends.clear();
                    }
                }
            }
        } else {
            // Non-CSV formats: every `\n` ends a row, and a row whose first two
            // bytes are `\.` is the end-of-data marker. The current row starts at
            // the last known row end (or offset 0 after a split).
            let mut row_start = self.last_row_end.unwrap_or(0);
            for (offset, b) in data[self.scan_pos..].iter().enumerate() {
                if *b == b'\n' {
                    let row_end = self.scan_pos + offset + 1;
                    self.last_row_end = Some(row_end);
                    if self.end_marker_end.is_none() {
                        let row = &data[row_start..row_end];
                        if row.get(0..2) == Some(b"\\.") {
                            self.end_marker_end = Some(row_end);
                            break;
                        }
                    }
                    row_start = row_end;
                }
            }
        }

        self.scan_pos = data.len();
    }

    /// Offset just past the last complete row, if any row has completed.
    fn last_row_end(&self) -> Option<usize> {
        self.last_row_end
    }

    /// Offset just past the end-of-data marker row, once it has been seen.
    fn end_marker_end(&self) -> Option<usize> {
        self.end_marker_end
    }

    /// Size of the trailing partial row: the bytes after the last complete row, or
    /// the whole buffer if no row has completed.
    fn current_row_size(&self, data_len: usize) -> usize {
        data_len.saturating_sub(self.last_row_end.unwrap_or(0))
    }

    /// Rebases offsets after the first `split_pos` bytes were removed from the
    /// front of the buffer.
    ///
    /// Callers split at `last_row_end`, so no complete row remains in the kept
    /// suffix and `last_row_end` is cleared.
    fn on_split(&mut self, split_pos: usize) {
        self.scan_pos = self.scan_pos.saturating_sub(split_pos);
        self.last_row_end = None;
        self.end_marker_end = self
            .end_marker_end
            .and_then(|end| end.checked_sub(split_pos));
    }

    /// Clamps offsets after the buffer was truncated to `new_len` bytes, dropping
    /// positions that no longer exist.
    fn on_truncate(&mut self, new_len: usize) {
        self.scan_pos = self.scan_pos.min(new_len);
        self.last_row_end = self.last_row_end.filter(|&end| end <= new_len);
        self.end_marker_end = self.end_marker_end.filter(|&end| end <= new_len);
    }
}
3391
3392impl CsvScanState {
3393 fn new(delimiter: u8, quote: u8, escape: u8, header: bool) -> Self {
3394 let (double_quote, escape) = if quote == escape {
3395 (true, None)
3396 } else {
3397 (false, Some(escape))
3398 };
3399 CsvScanState {
3400 reader: csv_core::ReaderBuilder::new()
3401 .delimiter(delimiter)
3402 .quote(quote)
3403 .double_quote(double_quote)
3404 .escape(escape)
3405 .build(),
3406 output: vec![0; 1],
3407 ends: vec![0; 1],
3408 record: Vec::new(),
3409 record_ends: Vec::new(),
3410 skip_first_record: header,
3411 }
3412 }
3413}
3414
#[cfg(test)]
mod test {
    use super::*;

    #[mz_ore::test]
    fn test_parse_options() {
        struct TestCase {
            input: &'static str,
            expect: Result<Vec<(&'static str, &'static str)>, ()>,
        }
        let tests = vec![
            TestCase {
                input: "",
                expect: Ok(vec![]),
            },
            TestCase {
                input: "--key",
                expect: Err(()),
            },
            TestCase {
                input: "--key=val",
                expect: Ok(vec![("key", "val")]),
            },
            TestCase {
                input: r#"--key=val -ckey2=val2 -c key3=val3 -c key4=val4 -ckey5=val5"#,
                expect: Ok(vec![
                    ("key", "val"),
                    ("key2", "val2"),
                    ("key3", "val3"),
                    ("key4", "val4"),
                    ("key5", "val5"),
                ]),
            },
            TestCase {
                input: r#"-c\ key=val"#,
                expect: Ok(vec![(" key", "val")]),
            },
            TestCase {
                input: "--key=val -ckey2 val2",
                expect: Err(()),
            },
            // An option with an empty value is accepted.
            TestCase {
                input: "--key=",
                expect: Ok(vec![("key", "")]),
            },
        ];
        for test in tests {
            let got = parse_options(test.input);
            let expect = test.expect.map(|r| {
                r.into_iter()
                    .map(|(k, v)| (k.to_owned(), v.to_owned()))
                    .collect()
            });
            assert_eq!(got, expect, "input: {}", test.input);
        }
    }

    #[mz_ore::test]
    fn test_parse_option() {
        struct TestCase {
            input: &'static str,
            expect: Result<(&'static str, &'static str), ()>,
        }
        let tests = vec![
            TestCase {
                input: "",
                expect: Err(()),
            },
            TestCase {
                input: "--",
                expect: Err(()),
            },
            TestCase {
                input: "--c",
                expect: Err(()),
            },
            TestCase {
                input: "a=b",
                expect: Err(()),
            },
            TestCase {
                input: "--a=b",
                expect: Ok(("a", "b")),
            },
            TestCase {
                input: "--ca=b",
                expect: Ok(("ca", "b")),
            },
            TestCase {
                input: "-ca=b",
                expect: Ok(("a", "b")),
            },
            // An empty key and empty value are accepted.
            TestCase {
                input: "--=",
                expect: Ok(("", "")),
            },
        ];
        for test in tests {
            let got = parse_option(test.input);
            assert_eq!(got, test.expect, "input: {}", test.input);
        }
    }

    #[mz_ore::test]
    fn test_split_options() {
        struct TestCase {
            input: &'static str,
            expect: Vec<&'static str>,
        }
        let tests = vec![
            TestCase {
                input: "",
                expect: vec![],
            },
            TestCase {
                input: " ",
                expect: vec![],
            },
            TestCase {
                input: " a ",
                expect: vec!["a"],
            },
            TestCase {
                input: " ab cd ",
                expect: vec!["ab", "cd"],
            },
            TestCase {
                input: r#" ab\ cd "#,
                expect: vec!["ab ", "cd"],
            },
            TestCase {
                input: r#" ab\\ cd "#,
                expect: vec![r#"ab\"#, "cd"],
            },
            // Escaped backslash, escaped space, then an unescaped separator.
            TestCase {
                input: r#" ab\\\  cd "#,
                expect: vec![r#"ab\ "#, "cd"],
            },
            // Escaped backslash and escaped space join both halves into one token.
            TestCase {
                input: r#" ab\\\ cd "#,
                expect: vec![r#"ab\ cd"#],
            },
            TestCase {
                input: r#" ab\\\cd "#,
                expect: vec![r#"ab\cd"#],
            },
            TestCase {
                input: r#"a\"#,
                expect: vec!["a"],
            },
            TestCase {
                input: r#"a\ "#,
                expect: vec!["a "],
            },
            TestCase {
                input: r#"\"#,
                expect: vec![],
            },
            TestCase {
                input: r#"\ "#,
                expect: vec![r#" "#],
            },
            TestCase {
                input: r#" \ "#,
                expect: vec![r#" "#],
            },
            // An escaped space followed by an unescaped separator.
            TestCase {
                input: r#"\  "#,
                expect: vec![r#" "#],
            },
        ];
        for test in tests {
            let got = split_options(test.input);
            assert_eq!(got, test.expect, "input: {}", test.input);
        }
    }
}