1use std::collections::BTreeMap;
11use std::convert::TryFrom;
12use std::future::Future;
13use std::ops::Deref;
14use std::sync::Arc;
15use std::time::{Duration, Instant};
16use std::{iter, mem};
17
18use base64::prelude::*;
19use byteorder::{ByteOrder, NetworkEndian};
20use csv_core::ReadRecordResult;
21use futures::future::{BoxFuture, FutureExt, pending};
22use itertools::Itertools;
23use mz_adapter::client::RecordFirstRowStream;
24use mz_adapter::session::{
25 EndTransactionAction, InProgressRows, LifecycleTimestamps, PortalRefMut, PortalState, Session,
26 SessionConfig, TransactionStatus,
27};
28use mz_adapter::statement_logging::{StatementEndedExecutionReason, StatementExecutionStrategy};
29use mz_adapter::{
30 AdapterError, AdapterNotice, ExecuteContextGuard, ExecuteResponse, PeekResponseUnary, metrics,
31 verify_datum_desc,
32};
33use mz_auth::Authenticated;
34use mz_auth::password::Password;
35use mz_authenticator::{Authenticator, GenericOidcAuthenticator};
36use mz_frontegg_auth::Authenticator as FronteggAuthenticator;
37use mz_ore::cast::CastFrom;
38use mz_ore::netio::AsyncReady;
39use mz_ore::now::{EpochMillis, SYSTEM_TIME};
40use mz_ore::str::StrExt;
41use mz_ore::{assert_none, assert_ok, instrument, soft_assert_eq_or_log};
42use mz_pgcopy::{CopyCsvFormatParams, CopyFormatParams, CopyTextFormatParams};
43use mz_pgwire_common::{
44 ConnectionCounter, Cursor, ErrorResponse, Format, FrontendMessage, Severity, VERSION_3,
45 VERSIONS,
46};
47use mz_repr::{
48 CatalogItemId, ColumnIndex, Datum, RelationDesc, RowArena, RowIterator, RowRef,
49 SqlRelationType, SqlScalarType,
50};
51use mz_server_core::TlsMode;
52use mz_server_core::listeners;
53use mz_server_core::listeners::AllowedRoles;
54use mz_sql::ast::display::AstDisplay;
55use mz_sql::ast::{
56 CopyDirection, CopyStatement, CopyTarget, FetchDirection, Ident, Raw, Statement,
57};
58use mz_sql::parse::StatementParseResult;
59use mz_sql::plan::{CopyFormat, ExecuteTimeout, StatementDesc};
60use mz_sql::session::metadata::SessionMetadata;
61use mz_sql::session::user::INTERNAL_USER_NAMES;
62use mz_sql::session::vars::VarInput;
63use postgres::error::SqlState;
64use tokio::io::{self, AsyncRead, AsyncWrite};
65use tokio::select;
66use tokio::time::{self};
67use tokio_metrics::TaskMetrics;
68use tokio_stream::wrappers::UnboundedReceiverStream;
69use tracing::{Instrument, debug, debug_span, warn};
70use uuid::Uuid;
71
72use crate::codec::{
73 FramedConn, decode_password, decode_sasl_initial_response, decode_sasl_response,
74};
75use crate::message::{
76 self, BackendMessage, SASLServerFinalMessage, SASLServerFinalMessageKinds,
77 SASLServerFirstMessage,
78};
79
80pub fn match_handshake(buf: &[u8]) -> bool {
84 if buf.len() < 8 {
94 return false;
95 }
96 let version = NetworkEndian::read_i32(&buf[4..8]);
97 VERSIONS.contains(&version)
98}
99
/// Everything `run` needs to drive a single pgwire connection.
pub struct RunParams<'a, A, I>
where
    I: Iterator<Item = TaskMetrics> + Send,
{
    /// The listener's TLS mode. It is checked against the connection with
    /// `ensure_tls_compatibility` before authentication.
    pub tls_mode: Option<TlsMode>,
    /// Adapter client used to authenticate, create the session and start it.
    pub adapter_client: mz_adapter::Client,
    /// The framed connection to the client.
    pub conn: &'a mut FramedConn<A>,
    /// Identifier for this connection, recorded in the session config.
    pub conn_uuid: Uuid,
    /// Protocol version requested by the client. Only `VERSION_3` is accepted.
    pub version: i32,
    /// Startup parameters sent by the client:
    /// - `user` selects the user to authenticate.
    /// - `options` is parsed as command-line style `-c key=val` / `--key=val` settings.
    /// - Every other entry is applied as a session variable.
    pub params: BTreeMap<String, String>,
    /// Frontegg authenticator, if configured. Passed to `get_authenticator`.
    pub frontegg: Option<FronteggAuthenticator>,
    /// OIDC authenticator. Passed to `get_authenticator`.
    pub oidc: GenericOidcAuthenticator,
    /// The kind of authentication configured for this listener.
    pub authenticator_kind: listeners::AuthenticatorKind,
    /// Connection tracker. A slot is allocated for the session user and held
    /// for the remainder of `run`.
    pub active_connection_counter: ConnectionCounter,
    /// Helm chart version, if any, recorded in the session config.
    pub helm_chart_version: Option<String>,
    /// Which roles may log in through this listener.
    pub allowed_roles: AllowedRoles,
    /// Task-metric intervals used to measure receive scheduling delay.
    /// The state machine calls `.expect("infinite iterator")` on every
    /// `next()`, so this iterator must never end.
    pub tokio_metrics_intervals: I,
}
133
/// Runs a pgwire session over `conn` until the protocol loop finishes or the
/// authentication session expires.
///
/// Startup proceeds in this order:
/// 1. Validate the requested protocol version.
/// 2. Check that the requested user may log in through this listener.
/// 3. Check TLS compatibility.
/// 4. Authenticate with the configured `Authenticator`.
/// 5. Apply the startup parameters as session variables.
/// 6. Allocate a slot in `active_connection_counter`.
/// 7. Start the session and send the startup response.
///
/// After that, control passes to the `StateMachine`.
///
/// Most startup failures are reported to the client as a fatal
/// `ErrorResponse`. In those cases the returned `Result` is the outcome of
/// sending that response. I/O errors on the connection are propagated.
#[mz_ore::instrument(level = "debug")]
pub async fn run<'a, A, I>(
    RunParams {
        tls_mode,
        adapter_client,
        conn,
        conn_uuid,
        version,
        mut params,
        frontegg,
        oidc,
        authenticator_kind,
        active_connection_counter,
        helm_chart_version,
        allowed_roles,
        tokio_metrics_intervals,
    }: RunParams<'a, A, I>,
) -> Result<(), io::Error>
where
    A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
    I: Iterator<Item = TaskMetrics> + Send,
{
    // Only protocol version 3 is implemented.
    if version != VERSION_3 {
        return conn
            .send(ErrorResponse::fatal(
                SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
                "server does not support the client's requested protocol version",
            ))
            .await;
    }

    // `user` is removed from `params` so it is not applied as a session
    // variable below. `options` stays in `params` and is expanded specially
    // in the settings loop.
    let user = params.remove("user").unwrap_or_else(String::new);
    let options = parse_options(params.get("options").unwrap_or(&String::new()));

    // `oidc_auth_enabled` only selects the authenticator. It is stripped from
    // the parsed options so it is never applied as a session variable.
    let (oidc_auth_enabled, options) = extract_oidc_auth_enabled_from_options(options);
    let authenticator = get_authenticator(
        authenticator_kind,
        frontegg,
        oidc,
        adapter_client.clone(),
        oidc_auth_enabled,
    );
    let is_internal_user = INTERNAL_USER_NAMES.contains(&user);
    let is_reserved_user = mz_adapter::catalog::is_reserved_role_name(user.as_str());
    // The listener's `allowed_roles` decides which classes of users may log in.
    let role_allowed = match allowed_roles {
        AllowedRoles::Normal => !is_reserved_user,
        AllowedRoles::Internal => is_internal_user,
        AllowedRoles::NormalAndInternal => !is_reserved_user || is_internal_user,
    };
    if !role_allowed {
        let msg = format!("unauthorized login to user '{user}'");
        return conn
            .send(ErrorResponse::fatal(SqlState::INSUFFICIENT_PRIVILEGE, msg))
            .await;
    }

    if let Err(err) = conn.inner().ensure_tls_compatibility(&tls_mode) {
        return conn.send(err).await;
    }

    let authenticator_kind = authenticator.kind();

    // Each arm yields the new session and a future that resolves when the
    // authentication expires. Only the Frontegg arm has a real expiry; every
    // other arm uses `pending()`, which never resolves.
    let (mut session, expired) = match authenticator {
        Authenticator::Frontegg(frontegg) => {
            let password = match request_cleartext_password(conn).await {
                Ok(password) => password,
                Err(PasswordRequestError::IoError(e)) => return Err(e),
                Err(PasswordRequestError::InvalidPasswordError(e)) => {
                    return conn.send(e).await;
                }
            };

            let auth_response = frontegg.authenticate(&user, &password).await;
            match auth_response {
                Ok((mut auth_session, authenticated)) => {
                    // The session user comes from the Frontegg auth session,
                    // not from the startup `user` parameter.
                    let session = adapter_client.new_session(
                        SessionConfig {
                            conn_id: conn.conn_id().clone(),
                            uuid: conn_uuid,
                            user: auth_session.user().into(),
                            client_ip: conn.peer_addr().clone(),
                            external_metadata_rx: Some(auth_session.external_metadata_rx()),
                            helm_chart_version,
                            authenticator_kind,
                        },
                        authenticated,
                    );
                    let expired = async move { auth_session.expired().await };
                    (session, expired.left_future())
                }
                Err(err) => {
                    warn!(?err, "pgwire connection failed authentication");
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_PASSWORD,
                            "invalid password",
                        ))
                        .await;
                }
            }
        }
        Authenticator::Oidc(oidc) => {
            // The client's JWT arrives in the cleartext password message.
            let jwt = match request_cleartext_password(conn).await {
                Ok(password) => password,
                Err(PasswordRequestError::IoError(e)) => return Err(e),
                Err(PasswordRequestError::InvalidPasswordError(e)) => {
                    return conn.send(e).await;
                }
            };
            let auth_response = oidc.authenticate(&jwt, Some(&user)).await;
            match auth_response {
                Ok((claims, authenticated)) => {
                    let session = adapter_client.new_session(
                        SessionConfig {
                            conn_id: conn.conn_id().clone(),
                            uuid: conn_uuid,
                            user: claims.user,
                            client_ip: conn.peer_addr().clone(),
                            external_metadata_rx: None,
                            helm_chart_version,
                            authenticator_kind,
                        },
                        authenticated,
                    );
                    // No expiry is enforced from the pgwire side for OIDC sessions.
                    (session, pending().right_future())
                }
                Err(err) => {
                    warn!(?err, "pgwire connection failed authentication");
                    return conn.send(err.into_response()).await;
                }
            }
        }
        Authenticator::Password(adapter_client) => {
            let session = match authenticate_with_password(
                conn,
                &adapter_client,
                user,
                conn_uuid,
                helm_chart_version,
                authenticator_kind,
            )
            .await
            {
                Ok(session) => session,
                Err(PasswordRequestError::IoError(e)) => return Err(e),
                Err(PasswordRequestError::InvalidPasswordError(e)) => {
                    return conn.send(e).await;
                }
            };
            (session, pending().right_future())
        }
        Authenticator::Sasl(adapter_client) => {
            // SCRAM step 1: advertise SASL and read the client-first message.
            conn.send(BackendMessage::AuthenticationSASL).await?;
            conn.flush().await?;
            let (mechanism, initial_response) = match conn.recv().await? {
                Some(FrontendMessage::RawAuthentication(data)) => {
                    match decode_sasl_initial_response(Cursor::new(&data)).ok() {
                        Some(FrontendMessage::SASLInitialResponse {
                            gs2_header,
                            mechanism,
                            initial_response,
                        }) => {
                            // Channel binding is not implemented, so it is rejected.
                            if gs2_header.channel_binding_enabled() {
                                return conn
                                    .send(ErrorResponse::fatal(
                                        SqlState::PROTOCOL_VIOLATION,
                                        "channel binding not supported",
                                    ))
                                    .await;
                            }
                            (mechanism, initial_response)
                        }
                        _ => {
                            return conn
                                .send(ErrorResponse::fatal(
                                    SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                                    "expected SASLInitialResponse message",
                                ))
                                .await;
                        }
                    }
                }
                _ => {
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                            "expected SASLInitialResponse message",
                        ))
                        .await;
                }
            };

            if mechanism != "SCRAM-SHA-256" {
                return conn
                    .send(ErrorResponse::fatal(
                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                        "unsupported SASL mechanism",
                    ))
                    .await;
            }

            // Bound client-supplied input before passing it to the adapter.
            if initial_response.nonce.len() > 256 {
                return conn
                    .send(ErrorResponse::fatal(
                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                        "nonce too long",
                    ))
                    .await;
            }

            // SCRAM step 2: send the server-first message.
            // `server_first_message_raw` is kept for the AuthMessage below.
            // `mock_hash` is a placeholder SCRAM hash built from fixed keys and
            // passed to `verify_sasl_proof`.
            // NOTE(review): presumably used when no real credential exists
            // (e.g. unknown user), so failures look uniform. Confirm in the adapter.
            let (server_first_message_raw, mock_hash) = match adapter_client
                .generate_sasl_challenge(&user, &initial_response.nonce)
                .await
            {
                Ok(response) => {
                    let server_first_message_raw = format!(
                        "r={},s={},i={}",
                        response.nonce, response.salt, response.iteration_count
                    );

                    let client_key = [0u8; 32];
                    let server_key = [1u8; 32];
                    let mock_hash = format!(
                        "SCRAM-SHA-256${}:{}${}:{}",
                        response.iteration_count,
                        response.salt,
                        BASE64_STANDARD.encode(client_key),
                        BASE64_STANDARD.encode(server_key)
                    );

                    conn.send(BackendMessage::AuthenticationSASLContinue(
                        SASLServerFirstMessage {
                            iteration_count: response.iteration_count,
                            nonce: response.nonce,
                            salt: response.salt,
                        },
                    ))
                    .await?;
                    conn.flush().await?;
                    (server_first_message_raw, mock_hash)
                }
                Err(e) => {
                    return conn.send(e.into_response(Severity::Fatal)).await;
                }
            };

            // SCRAM step 3: read the client-final message and verify its proof.
            // The AuthMessage joins the three exchanged messages with commas.
            let authenticated = match conn.recv().await? {
                Some(FrontendMessage::RawAuthentication(data)) => {
                    match decode_sasl_response(Cursor::new(&data)).ok() {
                        Some(FrontendMessage::SASLResponse(response)) => {
                            let auth_message = format!(
                                "{},{},{}",
                                initial_response.client_first_message_bare_raw,
                                server_first_message_raw,
                                response.client_final_message_bare_raw
                            );
                            if response.proof.len() > 1024 {
                                return conn
                                    .send(ErrorResponse::fatal(
                                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                                        "proof too long",
                                    ))
                                    .await;
                            }
                            match adapter_client
                                .verify_sasl_proof(
                                    &user,
                                    &response.proof,
                                    &auth_message,
                                    &mock_hash,
                                )
                                .await
                            {
                                Ok((proof_response, authenticated)) => {
                                    conn.send(BackendMessage::AuthenticationSASLFinal(
                                        SASLServerFinalMessage {
                                            kind: SASLServerFinalMessageKinds::Verifier(
                                                proof_response.verifier,
                                            ),
                                            extensions: vec![],
                                        },
                                    ))
                                    .await?;
                                    conn.flush().await?;
                                    authenticated
                                }
                                Err(_) => {
                                    return conn
                                        .send(ErrorResponse::fatal(
                                            SqlState::INVALID_PASSWORD,
                                            "invalid password",
                                        ))
                                        .await;
                                }
                            }
                        }
                        _ => {
                            return conn
                                .send(ErrorResponse::fatal(
                                    SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                                    "expected SASLResponse message",
                                ))
                                .await;
                        }
                    }
                }
                _ => {
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                            "expected SASLResponse message",
                        ))
                        .await;
                }
            };

            let session = adapter_client.new_session(
                SessionConfig {
                    conn_id: conn.conn_id().clone(),
                    uuid: conn_uuid,
                    user,
                    client_ip: conn.peer_addr().clone(),
                    external_metadata_rx: None,
                    helm_chart_version,
                    authenticator_kind,
                },
                authenticated,
            );
            let auth_session = pending().right_future();
            (session, auth_session)
        }

        Authenticator::None => {
            // No authentication configured: the requested user is trusted as-is.
            let session = adapter_client.new_session(
                SessionConfig {
                    conn_id: conn.conn_id().clone(),
                    uuid: conn_uuid,
                    user,
                    client_ip: conn.peer_addr().clone(),
                    external_metadata_rx: None,
                    helm_chart_version,
                    authenticator_kind,
                },
                Authenticated,
            );
            let auth_session = pending().right_future();
            (session, auth_session)
        }
    };

    // Apply startup parameters as session-level (non-LOCAL) variables.
    // `options` expands into its parsed pairs; every other parameter is set
    // directly. Failures become notices rather than connection errors.
    let system_vars = adapter_client.get_system_vars().await;
    for (name, value) in params {
        let settings = match name.as_str() {
            "options" => match &options {
                Ok(opts) => opts,
                Err(()) => {
                    session.add_notice(AdapterNotice::BadStartupSetting {
                        name,
                        reason: "could not parse".into(),
                    });
                    continue;
                }
            },
            _ => &vec![(name, value)],
        };
        for (key, val) in settings {
            const LOCAL: bool = false;
            if let Err(err) = session
                .vars_mut()
                .set(&system_vars, key, VarInput::Flat(val), LOCAL)
            {
                session.add_notice(AdapterNotice::BadStartupSetting {
                    name: key.clone(),
                    reason: err.to_string(),
                });
            }
        }
    }
    // Commit the variable changes made above.
    session
        .vars_mut()
        .end_transaction(EndTransactionAction::Commit);

    // The guard is held until `run` returns. Allocation can fail, presumably
    // when a connection limit is reached; the error goes to the client as fatal.
    let _guard = match active_connection_counter.allocate_connection(session.user()) {
        Ok(drop_connection) => drop_connection,
        Err(e) => {
            let e: AdapterError = e.into();
            return conn.send(e.into_response(Severity::Fatal)).await;
        }
    };

    let mut adapter_client = match adapter_client.startup(session).await {
        Ok(adapter_client) => adapter_client,
        Err(e) => return conn.send(e.into_response(Severity::Fatal)).await,
    };

    // The startup response is sent as one batch, in this order:
    // AuthenticationOk, ParameterStatus for each variable in `notify_set`,
    // BackendKeyData, queued notices, then ReadyForQuery.
    let mut buf = vec![BackendMessage::AuthenticationOk];
    for var in adapter_client.session().vars().notify_set() {
        buf.push(BackendMessage::ParameterStatus(var.name(), var.value()));
    }
    buf.push(BackendMessage::BackendKeyData {
        conn_id: adapter_client.session().conn_id().unhandled(),
        secret_key: adapter_client.session().secret_key(),
    });
    buf.extend(
        adapter_client
            .session()
            .drain_notices()
            .into_iter()
            .map(|notice| BackendMessage::ErrorResponse(notice.into_response())),
    );
    buf.push(BackendMessage::ReadyForQuery(
        adapter_client.session().transaction().into(),
    ));
    conn.send_all(buf).await?;
    conn.flush().await?;

    let machine = StateMachine {
        conn,
        adapter_client,
        txn_needs_commit: false,
        tokio_metrics_intervals,
    };

    // Run the protocol loop until it finishes or the authentication expires,
    // whichever happens first.
    select! {
        r = machine.run() => {
            // Tell the client why the connection is closing. Errors from these
            // best-effort sends are ignored so the original error is returned.
            if let Err(err) = &r {
                let _ = conn
                    .send(ErrorResponse::fatal(
                        SqlState::CONNECTION_FAILURE,
                        err.to_string(),
                    ))
                    .await;
                let _ = conn.flush().await;
            }
            r
        },
        _ = expired => {
            conn
                .send(ErrorResponse::fatal(SqlState::INVALID_AUTHORIZATION_SPECIFICATION, "authentication expired"))
                .await?;
            conn.flush().await
        }
    }
}
615
/// Pulls the `oidc_auth_enabled` flag out of the parsed startup options.
///
/// Returns the flag and the options with every `oidc_auth_enabled` entry
/// removed. The other entries keep their relative order.
///
/// Flag rules:
/// - If the key appears several times, only the last value counts.
/// - The flag is `true` only when that value parses as the boolean `true`.
/// - Unparseable options (`Err(())`) pass through unchanged with the flag `false`.
fn extract_oidc_auth_enabled_from_options(
    options: Result<Vec<(String, String)>, ()>,
) -> (bool, Result<Vec<(String, String)>, ()>) {
    let Ok(options) = options else {
        return (false, Err(()));
    };

    let (oidc_flags, remaining): (Vec<_>, Vec<_>) = options
        .into_iter()
        .partition(|(key, _)| key == "oidc_auth_enabled");

    let oidc_auth_enabled = oidc_flags
        .last()
        .is_some_and(|(_, value)| matches!(value.parse::<bool>(), Ok(true)));

    (oidc_auth_enabled, Ok(remaining))
}
640
641fn parse_options(value: &str) -> Result<Vec<(String, String)>, ()> {
646 let opts = split_options(value);
647 let mut pairs = Vec::with_capacity(opts.len());
648 let mut seen_prefix = false;
649 for opt in opts {
650 if !seen_prefix {
651 if opt == "-c" {
652 seen_prefix = true;
653 } else {
654 let (key, val) = parse_option(&opt)?;
655 pairs.push((key.to_owned(), val.to_owned()));
656 }
657 } else {
658 let (key, val) = opt.split_once('=').ok_or(())?;
659 pairs.push((key.to_owned(), val.to_owned()));
660 seen_prefix = false;
661 }
662 }
663 Ok(pairs)
664}
665
/// Parses one self-contained option of the form `-ckey=value` or
/// `--key=value` into `(key, value)`.
///
/// The split happens at the first `=`. The `-c` prefix is tried before
/// `--`. Input without `=` or without a recognized prefix yields `Err(())`.
fn parse_option(option: &str) -> Result<(&str, &str), ()> {
    let (raw_key, value) = option.split_once('=').ok_or(())?;
    ["-c", "--"]
        .iter()
        .find_map(|prefix| raw_key.strip_prefix(prefix))
        .map(|key| (key, value))
        .ok_or(())
}
678
/// Splits an `options` string on unescaped spaces.
///
/// Tokenizing rules:
/// - Runs of spaces separate tokens; empty tokens are never produced.
/// - A backslash escapes the next character, so `\ ` keeps a literal space
///   in the token and `\\` keeps a literal backslash.
/// - A trailing lone backslash is dropped.
///
/// Owned strings are built because escapes prevent plain sub-slicing.
fn split_options(value: &str) -> Vec<String> {
    let mut parts = Vec::new();
    let mut token = String::new();
    let mut escaped = false;
    for ch in value.chars() {
        match (ch, escaped) {
            (' ', false) => {
                if !token.is_empty() {
                    parts.push(std::mem::take(&mut token));
                }
            }
            ('\\', false) => {
                escaped = true;
                continue;
            }
            // Ordinary characters, plus any character following a backslash.
            (other, _) => token.push(other),
        }
        escaped = false;
    }
    if !token.is_empty() {
        parts.push(token);
    }
    parts
}
721
/// Failures while obtaining a password from the client.
enum PasswordRequestError {
    /// The client sent something other than a valid password message, or
    /// the password was rejected. The callers in `run` send the contained
    /// response to the client.
    InvalidPasswordError(ErrorResponse),
    /// The connection failed. The callers in `run` return this error directly.
    IoError(io::Error),
}
726
727impl From<io::Error> for PasswordRequestError {
728 fn from(e: io::Error) -> Self {
729 PasswordRequestError::IoError(e)
730 }
731}
732
733async fn request_cleartext_password<A>(
737 conn: &mut FramedConn<A>,
738) -> Result<String, PasswordRequestError>
739where
740 A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
741{
742 conn.send(BackendMessage::AuthenticationCleartextPassword)
743 .await?;
744 conn.flush().await?;
745
746 if let Some(message) = conn.recv().await? {
747 if let FrontendMessage::RawAuthentication(data) = message {
748 if let Some(FrontendMessage::Password { password }) =
749 decode_password(Cursor::new(&data)).ok()
750 {
751 return Ok(password);
752 }
753 }
754 }
755
756 Err(PasswordRequestError::InvalidPasswordError(
757 ErrorResponse::fatal(
758 SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
759 "expected Password message",
760 ),
761 ))
762}
763
764async fn authenticate_with_password<A>(
767 conn: &mut FramedConn<A>,
768 adapter_client: &mz_adapter::Client,
769 user: String,
770 conn_uuid: Uuid,
771 helm_chart_version: Option<String>,
772 authenticator_kind: mz_auth::AuthenticatorKind,
773) -> Result<Session, PasswordRequestError>
774where
775 A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
776{
777 let password = match request_cleartext_password(conn).await {
778 Ok(password) => Password(password),
779 Err(e) => return Err(e),
780 };
781
782 let authenticated = match adapter_client.authenticate(&user, &password).await {
783 Ok(authenticated) => authenticated,
784 Err(err) => {
785 warn!(?err, "pgwire connection failed authentication");
786 return Err(PasswordRequestError::InvalidPasswordError(
787 ErrorResponse::fatal(SqlState::INVALID_PASSWORD, "invalid password"),
788 ));
789 }
790 };
791
792 let session = adapter_client.new_session(
793 SessionConfig {
794 conn_id: conn.conn_id().clone(),
795 uuid: conn_uuid,
796 user,
797 client_ip: conn.peer_addr().clone(),
798 external_metadata_rx: None,
799 helm_chart_version,
800 authenticator_kind,
801 },
802 authenticated,
803 );
804
805 Ok(session)
806}
807
/// Top-level state of the per-connection protocol loop in `StateMachine::run`.
#[derive(Debug)]
enum State {
    /// Wait for the next frontend message and process it.
    Ready,
    /// Discard frontend messages until a `Sync` arrives. This state is
    /// entered, for example, when a message that is invalid outside COPY or
    /// authentication is received.
    Drain,
    /// Stop the loop; `run` returns `Ok(())`.
    Done,
}
814
/// Drives the pgwire protocol for one connection after startup has completed.
struct StateMachine<'a, A, I>
where
    I: Iterator<Item = TaskMetrics> + Send + 'a,
{
    /// The client connection.
    conn: &'a mut FramedConn<A>,
    /// Adapter client bound to this connection's session.
    adapter_client: mz_adapter::SessionClient,
    /// Set when an Execute ran inside an implicit transaction. The next
    /// `ensure_transaction` commits it first, and `end_transaction` clears
    /// the flag.
    txn_needs_commit: bool,
    /// Task-metric intervals used for the receive scheduling-delay metric.
    /// Must never end; `next()` is `expect`ed.
    tokio_metrics_intervals: I,
}
824
/// How a row-sending operation ended.
///
/// NOTE(review): not referenced in this part of the file. Presumably produced
/// by the row-streaming code and used for statement logging; confirm.
enum SendRowsEndedReason {
    /// Rows were sent successfully.
    Success {
        /// Size of the result (units not visible here; presumably bytes).
        result_size: u64,
        /// Number of rows returned to the client.
        rows_returned: u64,
    },
    /// Sending stopped because of the error described by `error`.
    Errored {
        error: String,
    },
    /// Sending was canceled.
    Canceled,
}
835
/// Error text for statements rejected because the current transaction has
/// been aborted. The wording matches PostgreSQL's message for the same case.
const ABORTED_TXN_MSG: &str =
    "current transaction is aborted, commands ignored until end of transaction block";
838
839impl<'a, A, I> StateMachine<'a, A, I>
840where
841 A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin + 'a,
842 I: Iterator<Item = TaskMetrics> + Send + 'a,
843{
844 #[allow(clippy::manual_async_fn)]
848 #[mz_ore::instrument(level = "debug")]
849 fn run(mut self) -> impl Future<Output = Result<(), io::Error>> + Send + 'a {
850 async move {
851 let mut state = State::Ready;
852 loop {
853 self.send_pending_notices().await?;
854 state = match state {
855 State::Ready => self.advance_ready().await?,
856 State::Drain => self.advance_drain().await?,
857 State::Done => return Ok(()),
858 };
859 self.adapter_client
860 .add_idle_in_transaction_session_timeout();
861 }
862 }
863 }
864
/// Waits for the next frontend message, or for a pending session timeout,
/// and dispatches it. Returns the state the protocol loop should enter next.
///
/// Per-message processing time and receive scheduling delay are recorded in
/// metrics labeled by the message name.
#[instrument(level = "debug")]
async fn advance_ready(&mut self) -> Result<State, io::Error> {
    // Consume the interval accumulated so far, so the interval taken after
    // `recv` below covers only the wait for this message.
    // NOTE(review): assumes each item reports activity since the previous
    // call. Confirm against tokio-metrics.
    self.tokio_metrics_intervals
        .next()
        .expect("infinite iterator");

    let message = select! {
        // `biased` polls the timeout branch first, so a pending session
        // timeout wins over an already-queued client message.
        biased;

        Some(timeout) = self.adapter_client.recv_timeout() => {
            let err: AdapterError = timeout.into();
            let conn_id = self.adapter_client.session().conn_id();
            tracing::warn!("session timed out, conn_id {}", conn_id);

            let error_response = err.into_response(Severity::Fatal);
            let error_state = self.send_error_and_get_state(error_response).await;

            // Terminate the session only after the error has been handled.
            self.adapter_client.terminate().await;

            // Wait for the client's next message before returning.
            // NOTE(review): presumably so the error is delivered in reply to a
            // client request. Confirm.
            let _ = self.conn.recv().await?;
            return error_state;
        },
        message = self.conn.recv() => message?,
    };

    let interval = self
        .tokio_metrics_intervals
        .next()
        .expect("infinite iterator");
    let recv_scheduling_delay_ms = interval.total_scheduled_duration.as_secs_f64() * 1000.0;

    // Wall-clock receive time, passed on as the statements' lifecycle
    // `received` timestamp.
    let received = SYSTEM_TIME();

    // A message arrived, so remove the idle-in-transaction timeout.
    // `run` adds it back after this step.
    self.adapter_client
        .remove_idle_in_transaction_session_timeout();

    let message_name = message.as_ref().map(|m| m.name()).unwrap_or_default();

    // Processing time is only measured when a message was actually received.
    let start = message.as_ref().map(|_| Instant::now());
    let next_state = match message {
        Some(FrontendMessage::Query { sql }) => {
            // Each Query gets its own root span, linked to the current span.
            let query_root_span =
                tracing::info_span!(parent: None, "advance_ready", otel.name = message_name);
            query_root_span.follows_from(tracing::Span::current());
            self.query(sql, received)
                .instrument(query_root_span)
                .await?
        }
        Some(FrontendMessage::Parse {
            name,
            sql,
            param_types,
        }) => self.parse(name, sql, param_types).await?,
        Some(FrontendMessage::Bind {
            portal_name,
            statement_name,
            param_formats,
            raw_params,
            result_formats,
        }) => {
            self.bind(
                portal_name,
                statement_name,
                param_formats,
                raw_params,
                result_formats,
            )
            .await?
        }
        Some(FrontendMessage::Execute {
            portal_name,
            max_rows,
        }) => {
            // Zero, or a value that does not fit in `usize` (e.g. negative),
            // means no row limit.
            let max_rows = match usize::try_from(max_rows) {
                Ok(0) | Err(_) => ExecuteCount::All,
                Ok(n) => ExecuteCount::Count(n),
            };
            let execute_root_span =
                tracing::info_span!(parent: None, "advance_ready", otel.name = message_name);
            execute_root_span.follows_from(tracing::Span::current());
            let state = self
                .execute(
                    portal_name,
                    max_rows,
                    portal_exec_message,
                    None,
                    ExecuteTimeout::None,
                    None,
                    Some(received),
                )
                .instrument(execute_root_span)
                .await?;
            // Defer the commit of an implicit transaction. The next
            // `ensure_transaction` performs it.
            if self.adapter_client.session().transaction().is_implicit() {
                self.txn_needs_commit = true;
            }
            state
        }
        Some(FrontendMessage::DescribeStatement { name }) => {
            self.describe_statement(&name).await?
        }
        Some(FrontendMessage::DescribePortal { name }) => self.describe_portal(&name).await?,
        Some(FrontendMessage::CloseStatement { name }) => self.close_statement(name).await?,
        Some(FrontendMessage::ClosePortal { name }) => self.close_portal(name).await?,
        Some(FrontendMessage::Flush) => self.flush().await?,
        Some(FrontendMessage::Sync) => self.sync().await?,
        Some(FrontendMessage::Terminate) => State::Done,

        // COPY and authentication messages are not valid in this state;
        // discard input until the next Sync.
        Some(FrontendMessage::CopyData(_))
        | Some(FrontendMessage::CopyDone)
        | Some(FrontendMessage::CopyFail(_))
        | Some(FrontendMessage::Password { .. })
        | Some(FrontendMessage::RawAuthentication(_))
        | Some(FrontendMessage::SASLInitialResponse { .. })
        | Some(FrontendMessage::SASLResponse(_)) => State::Drain,
        // The connection yielded no message (presumably EOF).
        None => State::Done,
    };

    if let Some(start) = start {
        self.adapter_client
            .inner()
            .metrics()
            .pgwire_message_processing_seconds
            .with_label_values(&[message_name])
            .observe(start.elapsed().as_secs_f64());
    }
    self.adapter_client
        .inner()
        .metrics()
        .pgwire_recv_scheduling_delay_ms
        .with_label_values(&[message_name])
        .observe(recv_scheduling_delay_ms);

    Ok(next_state)
}
1028
1029 async fn advance_drain(&mut self) -> Result<State, io::Error> {
1030 let message = self.conn.recv().await?;
1031 if message.is_some() {
1032 self.adapter_client
1033 .remove_idle_in_transaction_session_timeout();
1034 }
1035 match message {
1036 Some(FrontendMessage::Sync) => self.sync().await,
1037 None => Ok(State::Done),
1038 _ => Ok(State::Drain),
1039 }
1040 }
1041
/// Executes one statement from a simple-protocol Query message through the
/// unnamed portal.
///
/// Steps:
/// 1. Declare the statement into the unnamed portal.
/// 2. Tag the portal with `lifecycle_timestamps`.
/// 3. For row-returning, non-COPY statements, send a `RowDescription`
///    with text formats.
/// 4. Execute the portal and stream the response to the client.
/// 5. Remove the unnamed portal.
///
/// Statements that reference parameters are rejected, since a Query message
/// cannot supply parameter values.
#[instrument(level = "debug")]
async fn one_query(
    &mut self,
    stmt: Statement<Raw>,
    sql: String,
    lifecycle_timestamps: LifecycleTimestamps,
) -> Result<State, io::Error> {
    // The simple protocol always uses the unnamed portal.
    const EMPTY_PORTAL: &str = "";
    if let Err(e) = self
        .adapter_client
        .declare(EMPTY_PORTAL.to_string(), stmt, sql)
        .await
    {
        return self
            .send_error_and_get_state(e.into_response(Severity::Error))
            .await;
    }
    let portal = self
        .adapter_client
        .session()
        .get_portal_unverified_mut(EMPTY_PORTAL)
        .expect("unnamed portal should be present");

    *portal.lifecycle_timestamps = Some(lifecycle_timestamps);

    let stmt_desc = portal.desc.clone();
    // Note: this early return, like the declare failure above, leaves the
    // portal in place. Only the path that reaches execution removes it.
    if !stmt_desc.param_types.is_empty() {
        return self
            .send_error_and_get_state(ErrorResponse::error(
                SqlState::UNDEFINED_PARAMETER,
                "there is no parameter $1",
            ))
            .await;
    }

    if let Some(relation_desc) = &stmt_desc.relation_desc {
        // COPY statements do not get a RowDescription here.
        if !stmt_desc.is_copy {
            let formats = vec![Format::Text; stmt_desc.arity()];
            self.send(BackendMessage::RowDescription(
                message::encode_row_description(relation_desc, &formats),
            ))
            .await?;
        }
    }

    // `wait_closed()` is handed to the adapter alongside the portal.
    // NOTE(review): presumably lets execution be canceled if the client
    // disconnects. Confirm.
    let result = match self
        .adapter_client
        .execute(EMPTY_PORTAL.to_string(), self.conn.wait_closed(), None)
        .await
    {
        Ok((response, execute_started)) => {
            self.send_pending_notices().await?;
            self.send_execute_response(
                response,
                stmt_desc.relation_desc,
                EMPTY_PORTAL.to_string(),
                ExecuteCount::All,
                portal_exec_message,
                None,
                ExecuteTimeout::None,
                execute_started,
            )
            .await
        }
        Err(e) => {
            self.send_pending_notices().await?;
            self.send_error_and_get_state(e.into_response(Severity::Error))
                .await
        }
    };

    // The unnamed portal does not outlive this statement.
    self.adapter_client.session().remove_portal(EMPTY_PORTAL);

    result
}
1124
1125 async fn ensure_transaction(
1126 &mut self,
1127 num_stmts: usize,
1128 message_type: &str,
1129 ) -> Result<(), io::Error> {
1130 let start = Instant::now();
1131 if self.txn_needs_commit {
1132 self.commit_transaction().await?;
1133 }
1134 let res = self.adapter_client.start_transaction(Some(num_stmts));
1137 assert_ok!(res);
1138 self.adapter_client
1139 .inner()
1140 .metrics()
1141 .pgwire_ensure_transaction_seconds
1142 .with_label_values(&[message_type])
1143 .observe(start.elapsed().as_secs_f64());
1144 Ok(())
1145 }
1146
1147 fn parse_sql<'b>(&self, sql: &'b str) -> Result<Vec<StatementParseResult<'b>>, ErrorResponse> {
1148 let parse_start = Instant::now();
1149 let result = match self.adapter_client.parse(sql) {
1150 Ok(result) => result.map_err(|e| {
1151 let pos = sql[..e.error.pos].chars().count() + 1;
1154 ErrorResponse::error(SqlState::SYNTAX_ERROR, e.error.message).with_position(pos)
1155 }),
1156 Err(msg) => Err(ErrorResponse::error(SqlState::PROGRAM_LIMIT_EXCEEDED, msg)),
1157 };
1158 self.adapter_client
1159 .inner()
1160 .metrics()
1161 .parse_seconds
1162 .observe(parse_start.elapsed().as_secs_f64());
1163 result
1164 }
1165
1166 #[instrument(level = "debug")]
1171 async fn query(&mut self, sql: String, received: EpochMillis) -> Result<State, io::Error> {
1172 let stmts = match self.parse_sql(&sql) {
1174 Ok(stmts) => stmts,
1175 Err(err) => {
1176 self.send_error_and_get_state(err).await?;
1177 return self.ready().await;
1178 }
1179 };
1180
1181 let num_stmts = stmts.len();
1182
1183 for StatementParseResult { ast: stmt, sql } in stmts {
1185 if self.is_aborted_txn() && !is_txn_exit_stmt(Some(&stmt)) {
1187 self.aborted_txn_error().await?;
1188 break;
1189 }
1190
1191 self.ensure_transaction(num_stmts, "query").await?;
1199
1200 match self
1201 .one_query(stmt, sql.to_string(), LifecycleTimestamps { received })
1202 .await?
1203 {
1204 State::Ready => (),
1205 State::Drain => break,
1206 State::Done => return Ok(State::Done),
1207 }
1208 }
1209
1210 {
1212 if self.adapter_client.session().transaction().is_implicit() {
1213 self.commit_transaction().await?;
1214 }
1215 }
1216
1217 if num_stmts == 0 {
1218 self.send(BackendMessage::EmptyQueryResponse).await?;
1219 }
1220
1221 self.ready().await
1222 }
1223
1224 #[instrument(level = "debug")]
1225 async fn parse(
1226 &mut self,
1227 name: String,
1228 sql: String,
1229 param_oids: Vec<u32>,
1230 ) -> Result<State, io::Error> {
1231 self.ensure_transaction(1, "parse").await?;
1233
1234 let mut param_types = vec![];
1235 for oid in param_oids {
1236 match mz_pgrepr::Type::from_oid(oid) {
1237 Ok(ty) => match SqlScalarType::try_from(&ty) {
1238 Ok(ty) => param_types.push(Some(ty)),
1239 Err(err) => {
1240 return self
1241 .send_error_and_get_state(ErrorResponse::error(
1242 SqlState::INVALID_PARAMETER_VALUE,
1243 err.to_string(),
1244 ))
1245 .await;
1246 }
1247 },
1248 Err(_) if oid == 0 => param_types.push(None),
1249 Err(e) => {
1250 return self
1251 .send_error_and_get_state(ErrorResponse::error(
1252 SqlState::PROTOCOL_VIOLATION,
1253 e.to_string(),
1254 ))
1255 .await;
1256 }
1257 }
1258 }
1259
1260 let stmts = match self.parse_sql(&sql) {
1261 Ok(stmts) => stmts,
1262 Err(err) => {
1263 return self.send_error_and_get_state(err).await;
1264 }
1265 };
1266 if stmts.len() > 1 {
1267 return self
1268 .send_error_and_get_state(ErrorResponse::error(
1269 SqlState::INTERNAL_ERROR,
1270 "cannot insert multiple commands into a prepared statement",
1271 ))
1272 .await;
1273 }
1274 let (maybe_stmt, sql) = match stmts.into_iter().next() {
1275 None => (None, ""),
1276 Some(StatementParseResult { ast, sql }) => (Some(ast), sql),
1277 };
1278 if self.is_aborted_txn() && !is_txn_exit_stmt(maybe_stmt.as_ref()) {
1279 return self.aborted_txn_error().await;
1280 }
1281 match self
1282 .adapter_client
1283 .prepare(name, maybe_stmt, sql.to_string(), param_types)
1284 .await
1285 {
1286 Ok(()) => {
1287 self.send(BackendMessage::ParseComplete).await?;
1288 Ok(State::Ready)
1289 }
1290 Err(e) => {
1291 self.send_error_and_get_state(e.into_response(Severity::Error))
1292 .await
1293 }
1294 }
1295 }
1296
1297 #[instrument(level = "debug")]
1299 async fn commit_transaction(&mut self) -> Result<(), io::Error> {
1300 self.end_transaction(EndTransactionAction::Commit).await
1301 }
1302
1303 #[instrument(level = "debug")]
1305 async fn rollback_transaction(&mut self) -> Result<(), io::Error> {
1306 self.end_transaction(EndTransactionAction::Rollback).await
1307 }
1308
1309 #[instrument(level = "debug")]
1311 async fn end_transaction(&mut self, action: EndTransactionAction) -> Result<(), io::Error> {
1312 self.txn_needs_commit = false;
1313 let resp = self.adapter_client.end_transaction(action).await;
1314 if let Err(err) = resp {
1315 self.send(BackendMessage::ErrorResponse(
1316 err.into_response(Severity::Error),
1317 ))
1318 .await?;
1319 }
1320 Ok(())
1321 }
1322
1323 #[instrument(level = "debug")]
1324 async fn bind(
1325 &mut self,
1326 portal_name: String,
1327 statement_name: String,
1328 param_formats: Vec<Format>,
1329 raw_params: Vec<Option<Vec<u8>>>,
1330 result_formats: Vec<Format>,
1331 ) -> Result<State, io::Error> {
1332 self.ensure_transaction(1, "bind").await?;
1334
1335 let aborted_txn = self.is_aborted_txn();
1336 let stmt = match self
1337 .adapter_client
1338 .get_prepared_statement(&statement_name)
1339 .await
1340 {
1341 Ok(stmt) => stmt,
1342 Err(err) => {
1343 return self
1344 .send_error_and_get_state(err.into_response(Severity::Error))
1345 .await;
1346 }
1347 };
1348
1349 let param_types = &stmt.desc().param_types;
1350 if param_types.len() != raw_params.len() {
1351 let message = format!(
1352 "bind message supplies {actual} parameters, \
1353 but prepared statement \"{name}\" requires {expected}",
1354 name = statement_name,
1355 actual = raw_params.len(),
1356 expected = param_types.len()
1357 );
1358 return self
1359 .send_error_and_get_state(ErrorResponse::error(
1360 SqlState::PROTOCOL_VIOLATION,
1361 message,
1362 ))
1363 .await;
1364 }
1365 let param_formats = match pad_formats(param_formats, raw_params.len()) {
1366 Ok(param_formats) => param_formats,
1367 Err(msg) => {
1368 return self
1369 .send_error_and_get_state(ErrorResponse::error(
1370 SqlState::PROTOCOL_VIOLATION,
1371 msg,
1372 ))
1373 .await;
1374 }
1375 };
1376 if aborted_txn && !is_txn_exit_stmt(stmt.stmt()) {
1377 return self.aborted_txn_error().await;
1378 }
1379 let buf = RowArena::new();
1380 let mut params = vec![];
1381 for ((raw_param, mz_typ), format) in raw_params
1382 .into_iter()
1383 .zip_eq(param_types)
1384 .zip_eq(param_formats)
1385 {
1386 let pg_typ = mz_pgrepr::Type::from(mz_typ);
1387 let datum = match raw_param {
1388 None => Datum::Null,
1389 Some(bytes) => match mz_pgrepr::Value::decode(format, &pg_typ, &bytes) {
1390 Ok(param) => match param.into_datum_decode_error(&buf, &pg_typ, "parameter") {
1391 Ok(datum) => datum,
1392 Err(msg) => {
1393 return self
1394 .send_error_and_get_state(ErrorResponse::error(
1395 SqlState::INVALID_PARAMETER_VALUE,
1396 msg,
1397 ))
1398 .await;
1399 }
1400 },
1401 Err(err) => {
1402 let msg = format!("unable to decode parameter: {}", err);
1403 return self
1404 .send_error_and_get_state(ErrorResponse::error(
1405 SqlState::INVALID_PARAMETER_VALUE,
1406 msg,
1407 ))
1408 .await;
1409 }
1410 },
1411 };
1412 params.push((datum, mz_typ.clone()))
1413 }
1414
1415 let result_formats = match pad_formats(
1416 result_formats,
1417 stmt.desc()
1418 .relation_desc
1419 .clone()
1420 .map(|desc| desc.typ().column_types.len())
1421 .unwrap_or(0),
1422 ) {
1423 Ok(result_formats) => result_formats,
1424 Err(msg) => {
1425 return self
1426 .send_error_and_get_state(ErrorResponse::error(
1427 SqlState::PROTOCOL_VIOLATION,
1428 msg,
1429 ))
1430 .await;
1431 }
1432 };
1433
1434 if !stmt.stmt().map_or(false, |stmt| match stmt {
1437 Statement::Copy(CopyStatement {
1438 direction: CopyDirection::To,
1439 ..
1440 }) => true,
1441 Statement::Copy(CopyStatement {
1442 direction: CopyDirection::From,
1443 target: CopyTarget::Expr(_),
1447 ..
1448 }) => true,
1449 _ => false,
1450 }) {
1451 if let Some(desc) = stmt.desc().relation_desc.clone() {
1452 for (format, ty) in result_formats.iter().zip_eq(desc.iter_types()) {
1453 match (format, &ty.scalar_type) {
1454 (Format::Binary, mz_repr::SqlScalarType::List { .. }) => {
1455 return self
1456 .send_error_and_get_state(ErrorResponse::error(
1457 SqlState::PROTOCOL_VIOLATION,
1458 "binary encoding of list types is not implemented",
1459 ))
1460 .await;
1461 }
1462 (Format::Binary, mz_repr::SqlScalarType::Map { .. }) => {
1463 return self
1464 .send_error_and_get_state(ErrorResponse::error(
1465 SqlState::PROTOCOL_VIOLATION,
1466 "binary encoding of map types is not implemented",
1467 ))
1468 .await;
1469 }
1470 (Format::Binary, mz_repr::SqlScalarType::AclItem) => {
1471 return self
1472 .send_error_and_get_state(ErrorResponse::error(
1473 SqlState::PROTOCOL_VIOLATION,
1474 "binary encoding of aclitem types does not exist",
1475 ))
1476 .await;
1477 }
1478 _ => (),
1479 }
1480 }
1481 }
1482 }
1483
1484 let desc = stmt.desc().clone();
1485 let logging = Arc::clone(stmt.logging());
1486 let stmt_ast = stmt.stmt().cloned();
1487 let state_revision = stmt.state_revision;
1488 if let Err(err) = self.adapter_client.session().set_portal(
1489 portal_name,
1490 desc,
1491 stmt_ast,
1492 logging,
1493 params,
1494 result_formats,
1495 state_revision,
1496 ) {
1497 return self
1498 .send_error_and_get_state(err.into_response(Severity::Error))
1499 .await;
1500 }
1501
1502 self.send(BackendMessage::BindComplete).await?;
1503 Ok(State::Ready)
1504 }
1505
    /// Handles an `Execute` message for portal `portal_name`, sending at most
    /// `max_rows` rows.
    ///
    /// This is also re-entered by `fetch` when a `FETCH` statement runs.
    ///
    /// Behavior:
    /// - If the portal does not exist, an `INVALID_CURSOR_NAME` error is
    ///   sent.
    /// - The portal's lifecycle timestamps are set from `received`.
    /// - In an aborted transaction, only transaction-exit statements run; any
    ///   other statement gets the aborted-transaction error.
    ///
    /// Otherwise it dispatches on the portal's state:
    /// - `NotStarted`: the statement is executed by the adapter and its
    ///   response is sent via `send_execute_response`.
    /// - `InProgress`: sending resumes from the stored rows.
    /// - `Completed(Some(tag))`: the command tag is sent again.
    /// - `Completed(None)`: an `OBJECT_NOT_IN_PREREQUISITE_STATE` error is
    ///   sent.
    ///
    /// `outer_ctx_extra`, when present, is retired in this function on every
    /// path except `NotStarted`, where it is handed to
    /// `adapter_client.execute`.
    ///
    /// This returns a boxed future instead of being an `async fn`. Presumably
    /// this is because the call chain `execute` -> `send_execute_response` ->
    /// `fetch` -> `execute` is recursive, and recursive async calls need the
    /// indirection.
    fn execute(
        &mut self,
        portal_name: String,
        max_rows: ExecuteCount,
        get_response: GetResponse,
        fetch_portal_name: Option<String>,
        timeout: ExecuteTimeout,
        outer_ctx_extra: Option<ExecuteContextGuard>,
        received: Option<EpochMillis>,
    ) -> BoxFuture<'_, Result<State, io::Error>> {
        async move {
            // Read before the portal is borrowed from the session below.
            let aborted_txn = self.is_aborted_txn();

            let portal = match self
                .adapter_client
                .session()
                .get_portal_unverified_mut(&portal_name)
            {
                Some(portal) => portal,
                None => {
                    let msg = format!("portal {} does not exist", portal_name.quoted());
                    if let Some(outer_ctx_extra) = outer_ctx_extra {
                        self.adapter_client.retire_execute(
                            outer_ctx_extra,
                            StatementEndedExecutionReason::Errored { error: msg.clone() },
                        );
                    }
                    return self
                        .send_error_and_get_state(ErrorResponse::error(
                            SqlState::INVALID_CURSOR_NAME,
                            msg,
                        ))
                        .await;
                }
            };

            *portal.lifecycle_timestamps = received.map(LifecycleTimestamps::new);

            // While the transaction is aborted, only statements that end it
            // (as judged by `is_txn_exit_stmt`) may run.
            let txn_exit_stmt = is_txn_exit_stmt(portal.stmt.as_deref());
            if aborted_txn && !txn_exit_stmt {
                if let Some(outer_ctx_extra) = outer_ctx_extra {
                    self.adapter_client.retire_execute(
                        outer_ctx_extra,
                        StatementEndedExecutionReason::Errored {
                            error: ABORTED_TXN_MSG.to_string(),
                        },
                    );
                }
                return self.aborted_txn_error().await;
            }

            let row_desc = portal.desc.relation_desc.clone();
            match portal.state {
                // First execution: run the statement through the adapter.
                PortalState::NotStarted => {
                    self.ensure_transaction(1, "execute").await?;
                    match self
                        .adapter_client
                        .execute(
                            portal_name.clone(),
                            self.conn.wait_closed(),
                            outer_ctx_extra,
                        )
                        .await
                    {
                        Ok((response, execute_started)) => {
                            self.send_pending_notices().await?;
                            self.send_execute_response(
                                response,
                                row_desc,
                                portal_name,
                                max_rows,
                                get_response,
                                fetch_portal_name,
                                timeout,
                                execute_started,
                            )
                            .await
                        }
                        Err(e) => {
                            self.send_pending_notices().await?;
                            self.send_error_and_get_state(e.into_response(Severity::Error))
                                .await
                        }
                    }
                }
                // Resume a partially sent result set. The statement-ended
                // reason reported here carries no result size or row count.
                PortalState::InProgress(rows) => {
                    let rows = rows.take().expect("InProgress rows must be populated");
                    let (result, statement_ended_execution_reason) = match self
                        .send_rows(
                            row_desc.expect("portal missing row desc on resumption"),
                            portal_name,
                            rows,
                            max_rows,
                            get_response,
                            fetch_portal_name,
                            timeout,
                        )
                        .await
                    {
                        Err(e) => {
                            (Err(e), StatementEndedExecutionReason::Canceled)
                        }
                        Ok((ok, SendRowsEndedReason::Canceled)) => {
                            (Ok(ok), StatementEndedExecutionReason::Canceled)
                        }
                        Ok((
                            ok,
                            SendRowsEndedReason::Success {
                                result_size: _,
                                rows_returned: _,
                            },
                        )) => (
                            Ok(ok),
                            StatementEndedExecutionReason::Success {
                                result_size: None,
                                rows_returned: None,
                                execution_strategy: None,
                            },
                        ),
                        Ok((ok, SendRowsEndedReason::Errored { error })) => {
                            (Ok(ok), StatementEndedExecutionReason::Errored { error })
                        }
                    };
                    if let Some(outer_ctx_extra) = outer_ctx_extra {
                        self.adapter_client
                            .retire_execute(outer_ctx_extra, statement_ended_execution_reason);
                    }
                    result
                }
                // Already ran to completion: repeat the command tag.
                PortalState::Completed(Some(tag)) => {
                    let tag = tag.to_string();
                    if let Some(outer_ctx_extra) = outer_ctx_extra {
                        self.adapter_client.retire_execute(
                            outer_ctx_extra,
                            StatementEndedExecutionReason::Success {
                                result_size: None,
                                rows_returned: None,
                                execution_strategy: None,
                            },
                        );
                    }
                    self.send(BackendMessage::CommandComplete { tag }).await?;
                    Ok(State::Ready)
                }
                // Completed without a tag (e.g. set by `complete_portal`):
                // the portal cannot be run again.
                PortalState::Completed(None) => {
                    let error = format!(
                        "portal {} cannot be run",
                        Ident::new_unchecked(portal_name).to_ast_string_stable()
                    );
                    if let Some(outer_ctx_extra) = outer_ctx_extra {
                        self.adapter_client.retire_execute(
                            outer_ctx_extra,
                            StatementEndedExecutionReason::Errored {
                                error: error.clone(),
                            },
                        );
                    }
                    self.send_error_and_get_state(ErrorResponse::error(
                        SqlState::OBJECT_NOT_IN_PREREQUISITE_STATE,
                        error,
                    ))
                    .await
                }
            }
        }
        .instrument(debug_span!("execute"))
        .boxed()
    }
1699
1700 #[instrument(level = "debug")]
1701 async fn describe_statement(&mut self, name: &str) -> Result<State, io::Error> {
1702 self.ensure_transaction(1, "describe_statement").await?;
1704
1705 let stmt = match self.adapter_client.get_prepared_statement(name).await {
1706 Ok(stmt) => stmt,
1707 Err(err) => {
1708 return self
1709 .send_error_and_get_state(err.into_response(Severity::Error))
1710 .await;
1711 }
1712 };
1713 let parameter_desc = BackendMessage::ParameterDescription(
1715 stmt.desc()
1716 .param_types
1717 .iter()
1718 .map(mz_pgrepr::Type::from)
1719 .collect(),
1720 );
1721 let formats = vec![Format::Text; stmt.desc().arity()];
1725 let row_desc = describe_rows(stmt.desc(), &formats);
1726 self.send_all([parameter_desc, row_desc]).await?;
1727 Ok(State::Ready)
1728 }
1729
1730 #[instrument(level = "debug")]
1731 async fn describe_portal(&mut self, name: &str) -> Result<State, io::Error> {
1732 self.ensure_transaction(1, "describe_portal").await?;
1734
1735 let session = self.adapter_client.session();
1736 let row_desc = session
1737 .get_portal_unverified(name)
1738 .map(|portal| describe_rows(&portal.desc, &portal.result_formats));
1739 match row_desc {
1740 Some(row_desc) => {
1741 self.send(row_desc).await?;
1742 Ok(State::Ready)
1743 }
1744 None => {
1745 self.send_error_and_get_state(ErrorResponse::error(
1746 SqlState::INVALID_CURSOR_NAME,
1747 format!("portal {} does not exist", name.quoted()),
1748 ))
1749 .await
1750 }
1751 }
1752 }
1753
1754 #[instrument(level = "debug")]
1755 async fn close_statement(&mut self, name: String) -> Result<State, io::Error> {
1756 self.adapter_client
1757 .session()
1758 .remove_prepared_statement(&name);
1759 self.send(BackendMessage::CloseComplete).await?;
1760 Ok(State::Ready)
1761 }
1762
1763 #[instrument(level = "debug")]
1764 async fn close_portal(&mut self, name: String) -> Result<State, io::Error> {
1765 self.adapter_client.session().remove_portal(&name);
1766 self.send(BackendMessage::CloseComplete).await?;
1767 Ok(State::Ready)
1768 }
1769
1770 fn complete_portal(&mut self, name: &str) {
1771 let portal = self
1772 .adapter_client
1773 .session()
1774 .get_portal_unverified_mut(name)
1775 .expect("portal should exist");
1776 *portal.state = PortalState::Completed(None);
1777 }
1778
1779 async fn fetch(
1780 &mut self,
1781 name: String,
1782 count: Option<FetchDirection>,
1783 max_rows: ExecuteCount,
1784 fetch_portal_name: Option<String>,
1785 timeout: ExecuteTimeout,
1786 ctx_extra: ExecuteContextGuard,
1787 ) -> Result<State, io::Error> {
1788 let count = count.unwrap_or(FetchDirection::ForwardCount(1));
1791
1792 let count = match (max_rows, count) {
1805 (ExecuteCount::Count(max_rows), FetchDirection::ForwardCount(count)) => {
1806 let count = usize::cast_from(count);
1807 if max_rows < count {
1808 let msg = "Execute with max_rows < a FETCH's count is not supported";
1809 self.adapter_client.retire_execute(
1810 ctx_extra,
1811 StatementEndedExecutionReason::Errored {
1812 error: msg.to_string(),
1813 },
1814 );
1815 return self
1816 .send_error_and_get_state(ErrorResponse::error(
1817 SqlState::FEATURE_NOT_SUPPORTED,
1818 msg,
1819 ))
1820 .await;
1821 }
1822 ExecuteCount::Count(count)
1823 }
1824 (ExecuteCount::Count(_), FetchDirection::ForwardAll) => {
1825 let msg = "Execute with max_rows of a FETCH ALL is not supported";
1826 self.adapter_client.retire_execute(
1827 ctx_extra,
1828 StatementEndedExecutionReason::Errored {
1829 error: msg.to_string(),
1830 },
1831 );
1832 return self
1833 .send_error_and_get_state(ErrorResponse::error(
1834 SqlState::FEATURE_NOT_SUPPORTED,
1835 msg,
1836 ))
1837 .await;
1838 }
1839 (ExecuteCount::All, FetchDirection::ForwardAll) => ExecuteCount::All,
1840 (ExecuteCount::All, FetchDirection::ForwardCount(count)) => {
1841 ExecuteCount::Count(usize::cast_from(count))
1842 }
1843 };
1844 let cursor_name = name.to_string();
1845 self.execute(
1846 cursor_name,
1847 count,
1848 fetch_message,
1849 fetch_portal_name,
1850 timeout,
1851 Some(ctx_extra),
1852 None,
1853 )
1854 .await
1855 }
1856
    /// Flushes any buffered messages on the connection to the client and
    /// returns `State::Ready`. An `Err` signals a connection I/O failure.
    async fn flush(&mut self) -> Result<State, io::Error> {
        self.conn.flush().await?;
        Ok(State::Ready)
    }
1861
1862 #[instrument(level = "debug")]
1867 async fn send<M>(&mut self, message: M) -> Result<(), io::Error>
1868 where
1869 M: Into<BackendMessage>,
1870 {
1871 let message: BackendMessage = message.into();
1872 let is_error =
1873 matches!(&message, BackendMessage::ErrorResponse(e) if e.severity.is_error());
1874
1875 self.conn.send(message).await?;
1876
1877 if is_error {
1884 self.conn.flush().await?;
1885 }
1886
1887 Ok(())
1888 }
1889
1890 #[instrument(level = "debug")]
1891 pub async fn send_all(
1892 &mut self,
1893 messages: impl IntoIterator<Item = BackendMessage>,
1894 ) -> Result<(), io::Error> {
1895 for m in messages {
1896 self.send(m).await?;
1897 }
1898 Ok(())
1899 }
1900
1901 #[instrument(level = "debug")]
1902 async fn sync(&mut self) -> Result<State, io::Error> {
1903 if self.adapter_client.session().transaction().is_implicit() {
1905 self.commit_transaction().await?;
1906 }
1907 self.ready().await
1908 }
1909
1910 #[instrument(level = "debug")]
1911 async fn ready(&mut self) -> Result<State, io::Error> {
1912 let txn_state = self.adapter_client.session().transaction().into();
1913 self.send(BackendMessage::ReadyForQuery(txn_state)).await?;
1914 self.flush().await
1915 }
1916
1917 #[allow(clippy::too_many_arguments)]
1918 #[instrument(level = "debug")]
1919 async fn send_execute_response(
1920 &mut self,
1921 response: ExecuteResponse,
1922 row_desc: Option<RelationDesc>,
1923 portal_name: String,
1924 max_rows: ExecuteCount,
1925 get_response: GetResponse,
1926 fetch_portal_name: Option<String>,
1927 timeout: ExecuteTimeout,
1928 execute_started: Instant,
1929 ) -> Result<State, io::Error> {
1930 let mut tag = response.tag();
1931
1932 macro_rules! command_complete {
1933 () => {{
1934 self.send(BackendMessage::CommandComplete {
1935 tag: tag
1936 .take()
1937 .expect("command_complete only called on tag-generating results"),
1938 })
1939 .await?;
1940 Ok(State::Ready)
1941 }};
1942 }
1943
1944 let r = match response {
1945 ExecuteResponse::ClosedCursor => {
1946 self.complete_portal(&portal_name);
1947 command_complete!()
1948 }
1949 ExecuteResponse::DeclaredCursor => {
1950 self.complete_portal(&portal_name);
1951 command_complete!()
1952 }
1953 ExecuteResponse::EmptyQuery => {
1954 self.send(BackendMessage::EmptyQueryResponse).await?;
1955 Ok(State::Ready)
1956 }
1957 ExecuteResponse::Fetch {
1958 name,
1959 count,
1960 timeout,
1961 ctx_extra,
1962 } => {
1963 self.fetch(
1964 name,
1965 count,
1966 max_rows,
1967 Some(portal_name.to_string()),
1968 timeout,
1969 ctx_extra,
1970 )
1971 .await
1972 }
1973 ExecuteResponse::SendingRowsStreaming {
1974 rows,
1975 instance_id,
1976 strategy,
1977 } => {
1978 let row_desc = row_desc
1979 .expect("missing row description for ExecuteResponse::SendingRowsStreaming");
1980
1981 let span = tracing::debug_span!("sending_rows_streaming");
1982
1983 self.send_rows(
1984 row_desc,
1985 portal_name,
1986 InProgressRows::new(RecordFirstRowStream::new(
1987 Box::new(rows),
1988 execute_started,
1989 &self.adapter_client,
1990 Some(instance_id),
1991 Some(strategy),
1992 )),
1993 max_rows,
1994 get_response,
1995 fetch_portal_name,
1996 timeout,
1997 )
1998 .instrument(span)
1999 .await
2000 .map(|(state, _)| state)
2001 }
2002 ExecuteResponse::SendingRowsImmediate { rows } => {
2003 let row_desc = row_desc
2004 .expect("missing row description for ExecuteResponse::SendingRowsImmediate");
2005
2006 let span = tracing::debug_span!("sending_rows_immediate");
2007
2008 let stream =
2009 futures::stream::once(futures::future::ready(PeekResponseUnary::Rows(rows)));
2010 self.send_rows(
2011 row_desc,
2012 portal_name,
2013 InProgressRows::new(RecordFirstRowStream::new(
2014 Box::new(stream),
2015 execute_started,
2016 &self.adapter_client,
2017 None,
2018 Some(StatementExecutionStrategy::Constant),
2019 )),
2020 max_rows,
2021 get_response,
2022 fetch_portal_name,
2023 timeout,
2024 )
2025 .instrument(span)
2026 .await
2027 .map(|(state, _)| state)
2028 }
2029 ExecuteResponse::SetVariable { name, .. } => {
2030 let qn = name.to_string();
2033 let msg = if let Some(var) = self
2034 .adapter_client
2035 .session()
2036 .vars_mut()
2037 .notify_set()
2038 .find(|v| v.name() == qn)
2039 {
2040 Some(BackendMessage::ParameterStatus(var.name(), var.value()))
2041 } else {
2042 None
2043 };
2044 if let Some(msg) = msg {
2045 self.send(msg).await?;
2046 }
2047 command_complete!()
2048 }
2049 ExecuteResponse::Subscribing {
2050 rx,
2051 ctx_extra,
2052 instance_id,
2053 } => {
2054 if fetch_portal_name.is_none() {
2055 let mut msg = ErrorResponse::notice(
2056 SqlState::WARNING,
2057 "streaming SUBSCRIBE rows directly requires a client that does not buffer output",
2058 );
2059 if self.adapter_client.session().vars().application_name() == "psql" {
2060 msg.hint = Some(
2061 "Wrap your SUBSCRIBE statement in `COPY (SUBSCRIBE ...) TO STDOUT`."
2062 .into(),
2063 )
2064 }
2065 self.send(msg).await?;
2066 self.conn.flush().await?;
2067 }
2068 let row_desc =
2069 row_desc.expect("missing row description for ExecuteResponse::Subscribing");
2070 let (result, statement_ended_execution_reason) = match self
2071 .send_rows(
2072 row_desc,
2073 portal_name,
2074 InProgressRows::new(RecordFirstRowStream::new(
2075 Box::new(UnboundedReceiverStream::new(rx)),
2076 execute_started,
2077 &self.adapter_client,
2078 Some(instance_id),
2079 None,
2080 )),
2081 max_rows,
2082 get_response,
2083 fetch_portal_name,
2084 timeout,
2085 )
2086 .await
2087 {
2088 Err(e) => {
2089 (Err(e), StatementEndedExecutionReason::Canceled)
2092 }
2093 Ok((ok, SendRowsEndedReason::Canceled)) => {
2094 (Ok(ok), StatementEndedExecutionReason::Canceled)
2095 }
2096 Ok((
2097 ok,
2098 SendRowsEndedReason::Success {
2099 result_size,
2100 rows_returned,
2101 },
2102 )) => (
2103 Ok(ok),
2104 StatementEndedExecutionReason::Success {
2105 result_size: Some(result_size),
2106 rows_returned: Some(rows_returned),
2107 execution_strategy: None,
2108 },
2109 ),
2110 Ok((ok, SendRowsEndedReason::Errored { error })) => {
2111 (Ok(ok), StatementEndedExecutionReason::Errored { error })
2112 }
2113 };
2114 self.adapter_client
2115 .retire_execute(ctx_extra, statement_ended_execution_reason);
2116 return result;
2117 }
2118 ExecuteResponse::CopyTo { format, resp } => {
2119 let row_desc =
2120 row_desc.expect("missing row description for ExecuteResponse::CopyTo");
2121 match *resp {
2122 ExecuteResponse::Subscribing {
2123 rx,
2124 ctx_extra,
2125 instance_id,
2126 } => {
2127 let (result, statement_ended_execution_reason) = match self
2128 .copy_rows(
2129 format,
2130 row_desc,
2131 RecordFirstRowStream::new(
2132 Box::new(UnboundedReceiverStream::new(rx)),
2133 execute_started,
2134 &self.adapter_client,
2135 Some(instance_id),
2136 None,
2137 ),
2138 )
2139 .await
2140 {
2141 Err(e) => {
2142 (Err(e), StatementEndedExecutionReason::Canceled)
2145 }
2146 Ok((
2147 state,
2148 SendRowsEndedReason::Success {
2149 result_size,
2150 rows_returned,
2151 },
2152 )) => (
2153 Ok(state),
2154 StatementEndedExecutionReason::Success {
2155 result_size: Some(result_size),
2156 rows_returned: Some(rows_returned),
2157 execution_strategy: None,
2158 },
2159 ),
2160 Ok((state, SendRowsEndedReason::Errored { error })) => {
2161 (Ok(state), StatementEndedExecutionReason::Errored { error })
2162 }
2163 Ok((state, SendRowsEndedReason::Canceled)) => {
2164 (Ok(state), StatementEndedExecutionReason::Canceled)
2165 }
2166 };
2167 self.adapter_client
2168 .retire_execute(ctx_extra, statement_ended_execution_reason);
2169 return result;
2170 }
2171 ExecuteResponse::SendingRowsStreaming {
2172 rows,
2173 instance_id,
2174 strategy,
2175 } => {
2176 return self
2181 .copy_rows(
2182 format,
2183 row_desc,
2184 RecordFirstRowStream::new(
2185 Box::new(rows),
2186 execute_started,
2187 &self.adapter_client,
2188 Some(instance_id),
2189 Some(strategy),
2190 ),
2191 )
2192 .await
2193 .map(|(state, _)| state);
2194 }
2195 ExecuteResponse::SendingRowsImmediate { rows } => {
2196 let span = tracing::debug_span!("sending_rows_immediate");
2197
2198 let rows = futures::stream::once(futures::future::ready(
2199 PeekResponseUnary::Rows(rows),
2200 ));
2201 return self
2206 .copy_rows(
2207 format,
2208 row_desc,
2209 RecordFirstRowStream::new(
2210 Box::new(rows),
2211 execute_started,
2212 &self.adapter_client,
2213 None,
2214 Some(StatementExecutionStrategy::Constant),
2215 ),
2216 )
2217 .instrument(span)
2218 .await
2219 .map(|(state, _)| state);
2220 }
2221 _ => {
2222 return self
2223 .send_error_and_get_state(ErrorResponse::error(
2224 SqlState::INTERNAL_ERROR,
2225 "unsupported COPY response type".to_string(),
2226 ))
2227 .await;
2228 }
2229 };
2230 }
2231 ExecuteResponse::CopyFrom {
2232 target_id,
2233 target_name,
2234 columns,
2235 params,
2236 ctx_extra,
2237 } => {
2238 let row_desc =
2239 row_desc.expect("missing row description for ExecuteResponse::CopyFrom");
2240 self.copy_from(target_id, target_name, columns, params, row_desc, ctx_extra)
2241 .await
2242 }
2243 ExecuteResponse::TransactionCommitted { params }
2244 | ExecuteResponse::TransactionRolledBack { params } => {
2245 let notify_set: mz_ore::collections::HashSet<String> = self
2246 .adapter_client
2247 .session()
2248 .vars()
2249 .notify_set()
2250 .map(|v| v.name().to_string())
2251 .collect();
2252
2253 for (name, value) in params
2255 .into_iter()
2256 .filter(|(name, _v)| notify_set.contains(*name))
2257 {
2258 let msg = BackendMessage::ParameterStatus(name, value);
2259 self.send(msg).await?;
2260 }
2261 command_complete!()
2262 }
2263
2264 ExecuteResponse::AlteredDefaultPrivileges
2265 | ExecuteResponse::AlteredObject(..)
2266 | ExecuteResponse::AlteredRole
2267 | ExecuteResponse::AlteredSystemConfiguration
2268 | ExecuteResponse::CreatedCluster { .. }
2269 | ExecuteResponse::CreatedClusterReplica { .. }
2270 | ExecuteResponse::CreatedConnection { .. }
2271 | ExecuteResponse::CreatedDatabase { .. }
2272 | ExecuteResponse::CreatedIndex { .. }
2273 | ExecuteResponse::CreatedIntrospectionSubscribe
2274 | ExecuteResponse::CreatedMaterializedView { .. }
2275 | ExecuteResponse::CreatedContinualTask { .. }
2276 | ExecuteResponse::CreatedRole
2277 | ExecuteResponse::CreatedSchema { .. }
2278 | ExecuteResponse::CreatedSecret { .. }
2279 | ExecuteResponse::CreatedSink { .. }
2280 | ExecuteResponse::CreatedSource { .. }
2281 | ExecuteResponse::CreatedTable { .. }
2282 | ExecuteResponse::CreatedType
2283 | ExecuteResponse::CreatedView { .. }
2284 | ExecuteResponse::CreatedViews { .. }
2285 | ExecuteResponse::CreatedNetworkPolicy
2286 | ExecuteResponse::Comment
2287 | ExecuteResponse::Deallocate { .. }
2288 | ExecuteResponse::Deleted(..)
2289 | ExecuteResponse::DiscardedAll
2290 | ExecuteResponse::DiscardedTemp
2291 | ExecuteResponse::DroppedObject(_)
2292 | ExecuteResponse::DroppedOwned
2293 | ExecuteResponse::GrantedPrivilege
2294 | ExecuteResponse::GrantedRole
2295 | ExecuteResponse::Inserted(..)
2296 | ExecuteResponse::Copied(..)
2297 | ExecuteResponse::Prepare
2298 | ExecuteResponse::Raised
2299 | ExecuteResponse::ReassignOwned
2300 | ExecuteResponse::RevokedPrivilege
2301 | ExecuteResponse::RevokedRole
2302 | ExecuteResponse::StartedTransaction { .. }
2303 | ExecuteResponse::Updated(..)
2304 | ExecuteResponse::ValidatedConnection => {
2305 command_complete!()
2306 }
2307 };
2308
2309 assert_none!(tag, "tag created but not consumed: {:?}", tag);
2310 r
2311 }
2312
    /// Streams rows from `rows` to the client as `DataRow` messages on behalf
    /// of portal `portal_name`.
    ///
    /// How many rows are sent:
    /// - At most `max_rows` rows are sent. `ExecuteCount::All` means no limit.
    ///   Unsent rows of the last batch are kept in `rows.current` for the next
    ///   call.
    ///
    /// Where result formats come from:
    /// - `fetch_portal_name`, if present, names the portal whose result
    ///   formats are used (the FETCH statement's portal).
    /// - Otherwise the formats of `portal_name` are used.
    ///
    /// How `timeout` controls waiting for more rows:
    /// - `None`: wait indefinitely.
    /// - `Seconds(t)`: stop waiting after `t` seconds.
    /// - `WaitOnce`: wait for the first non-empty batch, then only take what
    ///   is immediately ready.
    ///
    /// On normal completion:
    /// - The stream is stored back on the portal as
    ///   `PortalState::InProgress(Some(rows))`.
    /// - The message built by `get_response` is sent.
    /// - If the stream is exhausted, a first-to-last-row duration metric is
    ///   recorded.
    ///
    /// Returns the next state plus a `SendRowsEndedReason`:
    /// - `Success`: includes the bytes and row count sent in this call.
    /// - `Errored` / `Canceled`: the corresponding error was already sent to
    ///   the client.
    ///
    /// An `Err` signals a connection I/O failure or a closed connection.
    ///
    /// # Panics
    ///
    /// Panics if `portal_name` (or `fetch_portal_name`, when given) does not
    /// name an existing portal.
    #[allow(clippy::too_many_arguments)]
    #[mz_ore::instrument(level = "debug")]
    async fn send_rows(
        &mut self,
        row_desc: RelationDesc,
        portal_name: String,
        mut rows: InProgressRows,
        max_rows: ExecuteCount,
        get_response: GetResponse,
        fetch_portal_name: Option<String>,
        timeout: ExecuteTimeout,
    ) -> Result<(State, SendRowsEndedReason), io::Error> {
        // Result formats come from the FETCH statement's portal when there is
        // one, otherwise from the portal being executed.
        let result_format_portal_name: &str = if let Some(ref name) = fetch_portal_name {
            name
        } else {
            &portal_name
        };
        let result_formats = self
            .adapter_client
            .session()
            .get_portal_unverified(result_format_portal_name)
            .expect("valid fetch portal name for send rows")
            .result_formats
            .clone();

        // `wait_once`: wait indefinitely only until the first non-empty
        // batch. `deadline`: stop waiting for rows at this instant.
        let (mut wait_once, mut deadline) = match timeout {
            ExecuteTimeout::None => (false, None),
            ExecuteTimeout::Seconds(t) => (
                false,
                Some(tokio::time::Instant::now() + tokio::time::Duration::from_secs_f64(t)),
            ),
            ExecuteTimeout::WaitOnce => (true, None),
        };

        // Sanity check: `row_desc` should match the relation descriptions
        // recorded on the executed portal and, if present, the FETCH portal.
        {
            let portal_name_desc = &self
                .adapter_client
                .session()
                .get_portal_unverified(portal_name.as_str())
                .expect("portal should exist")
                .desc
                .relation_desc;
            if let Some(portal_name_desc) = portal_name_desc {
                soft_assert_eq_or_log!(portal_name_desc, &row_desc);
            }
            if let Some(fetch_portal_name) = &fetch_portal_name {
                let fetch_portal_desc = &self
                    .adapter_client
                    .session()
                    .get_portal_unverified(fetch_portal_name)
                    .expect("portal should exist")
                    .desc
                    .relation_desc;
                if let Some(fetch_portal_desc) = fetch_portal_desc {
                    soft_assert_eq_or_log!(fetch_portal_desc, &row_desc);
                }
            }
        }

        // Pair each column's pg type with its result format for encoding.
        self.conn.set_encode_state(
            row_desc
                .typ()
                .column_types
                .iter()
                .map(|ty| mz_pgrepr::Type::from(&ty.scalar_type))
                .zip_eq(result_formats)
                .collect(),
        );

        let mut total_sent_rows = 0;
        let mut total_sent_bytes = 0;
        // Remaining row budget for this call.
        let mut want_rows = match max_rows {
            ExecuteCount::All => usize::MAX,
            ExecuteCount::Count(count) => count,
        };

        loop {
            // Priority order for the next batch:
            // 1. Rows left over from an earlier call.
            // 2. Nothing, if the budget is spent.
            // 3. Otherwise, wait for the next event.
            let batch = if rows.current.is_some() {
                FetchResult::Rows(rows.current.take())
            } else if want_rows == 0 {
                FetchResult::Rows(None)
            } else {
                let notice_fut = self.adapter_client.session().recv_notice();
                tokio::select! {
                    // `biased` polls the branches top to bottom:
                    // - a closed connection wins over data;
                    // - data wins over notices;
                    // - notices win over the deadline.
                    biased;
                    err = self.conn.wait_closed() => return Err(err),
                    batch = rows.remaining.recv() => match batch {
                        None => FetchResult::Rows(None),
                        Some(PeekResponseUnary::Rows(rows)) => FetchResult::Rows(Some(rows)),
                        Some(PeekResponseUnary::Error(err)) => FetchResult::Error(err),
                        Some(PeekResponseUnary::Canceled) => FetchResult::Canceled,
                    },
                    notice = notice_fut => {
                        FetchResult::Notice(notice)
                    }
                    // Hitting the deadline ends this call as if the stream had
                    // no more rows for now.
                    _ = time::sleep_until(
                        deadline.unwrap_or_else(tokio::time::Instant::now),
                    ), if deadline.is_some() => FetchResult::Rows(None),
                }
            };

            match batch {
                FetchResult::Rows(None) => break,
                FetchResult::Rows(Some(mut batch_rows)) => {
                    // NOTE(review): presumably validates the batch's datums
                    // against `row_desc`; a failure is reported as an error.
                    if let Err(err) = verify_datum_desc(&row_desc, &mut batch_rows) {
                        let msg = err.to_string();
                        return self
                            .send_error_and_get_state(err.into_response(Severity::Error))
                            .await
                            .map(|state| (state, SendRowsEndedReason::Errored { error: msg }));
                    }

                    // First non-empty batch under `WaitOnce`. From now on,
                    // stop waiting: setting the deadline to now makes the
                    // timer branch fire as soon as nothing else is ready.
                    if wait_once && batch_rows.peek().is_some() {
                        deadline = Some(tokio::time::Instant::now());
                        wait_once = false;
                    }

                    // `take` stops pulling after `want_rows`, so rows beyond
                    // the budget stay in `batch_rows`. `inspect` only counts
                    // rows that were actually pulled.
                    let mut sent_rows = 0;
                    let mut sent_bytes = 0;
                    let messages = (&mut batch_rows)
                        .map(|row| {
                            let row_len = row.byte_len();
                            let values = mz_pgrepr::values_from_row(row, row_desc.typ());
                            (row_len, BackendMessage::DataRow(values))
                        })
                        .inspect(|(row_len, _)| {
                            sent_bytes += row_len;
                            sent_rows += 1
                        })
                        .map(|(_row_len, row)| row)
                        .take(want_rows);
                    self.send_all(messages).await?;

                    total_sent_rows += sent_rows;
                    total_sent_bytes += sent_bytes;
                    want_rows -= sent_rows;

                    // Budget spent: stash the unsent remainder of this batch
                    // so the next Execute/FETCH resumes from it.
                    if want_rows == 0 {
                        if batch_rows.peek().is_some() {
                            rows.current = Some(batch_rows);
                        }
                        break;
                    }

                    // Still want more rows: flush so the client sees this
                    // batch before we wait for the next one.
                    self.conn.flush().await?;
                }
                FetchResult::Notice(notice) => {
                    self.send(notice.into_response()).await?;
                    self.conn.flush().await?;
                }
                FetchResult::Error(text) => {
                    return self
                        .send_error_and_get_state(ErrorResponse::error(
                            SqlState::INTERNAL_ERROR,
                            text.clone(),
                        ))
                        .await
                        .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
                }
                FetchResult::Canceled => {
                    return self
                        .send_error_and_get_state(ErrorResponse::error(
                            SqlState::QUERY_CANCELED,
                            "canceling statement due to user request",
                        ))
                        .await
                        .map(|state| (state, SendRowsEndedReason::Canceled));
                }
            }
        }

        let portal = self
            .adapter_client
            .session()
            .get_portal_unverified_mut(&portal_name)
            .expect("valid portal name for send rows");

        // Capture what the metrics below need before `rows` moves into the
        // portal.
        let saw_rows = rows.remaining.saw_rows;
        let no_more_rows = rows.no_more_rows();
        let recorded_first_row_instant = rows.remaining.recorded_first_row_instant;

        // Park the (possibly partially consumed) stream on the portal so a
        // later Execute/FETCH can resume it.
        *portal.state = PortalState::InProgress(Some(rows));

        // `get_response` chooses the closing message from the limit, the
        // number of rows sent, and (for FETCH) the FETCH statement's portal.
        let fetch_portal = fetch_portal_name.map(|name| {
            self.adapter_client
                .session()
                .get_portal_unverified_mut(&name)
                .expect("valid fetch portal")
        });
        let response_message = get_response(max_rows, total_sent_rows, fetch_portal);
        self.send(response_message).await?;

        // Once the stream is exhausted, record the time from the first row
        // to now. The duration is zero when no rows were seen. The metric is
        // labeled by statement type.
        if no_more_rows {
            let statement_type = if let Some(stmt) = &self
                .adapter_client
                .session()
                .get_portal_unverified(&portal_name)
                .expect("valid portal name for send_rows")
                .stmt
            {
                metrics::statement_type_label_value(stmt.deref())
            } else {
                "no-statement"
            };
            let duration = if saw_rows {
                recorded_first_row_instant
                    .expect("recorded_first_row_instant because saw_rows")
                    .elapsed()
            } else {
                Duration::ZERO
            };
            self.adapter_client
                .inner()
                .metrics()
                .result_rows_first_to_last_byte_seconds
                .with_label_values(&[statement_type])
                .observe(duration.as_secs_f64());
        }

        Ok((
            State::Ready,
            SendRowsEndedReason::Success {
                result_size: u64::cast_from(total_sent_bytes),
                rows_returned: u64::cast_from(total_sent_rows),
            },
        ))
    }
2578
2579 #[mz_ore::instrument(level = "debug")]
2580 async fn copy_rows(
2581 &mut self,
2582 format: CopyFormat,
2583 row_desc: RelationDesc,
2584 mut stream: RecordFirstRowStream,
2585 ) -> Result<(State, SendRowsEndedReason), io::Error> {
2586 let (row_format, encode_format) = match format {
2587 CopyFormat::Text => (
2588 CopyFormatParams::Text(CopyTextFormatParams::default()),
2589 Format::Text,
2590 ),
2591 CopyFormat::Binary => (CopyFormatParams::Binary, Format::Binary),
2592 CopyFormat::Csv => (
2593 CopyFormatParams::Csv(CopyCsvFormatParams::default()),
2594 Format::Text,
2595 ),
2596 CopyFormat::Parquet => {
2597 let text = "Parquet format is not supported".to_string();
2598 return self
2599 .send_error_and_get_state(ErrorResponse::error(
2600 SqlState::INTERNAL_ERROR,
2601 text.clone(),
2602 ))
2603 .await
2604 .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
2605 }
2606 };
2607
2608 let encode_fn = |row: &RowRef, typ: &SqlRelationType, out: &mut Vec<u8>| {
2609 mz_pgcopy::encode_copy_format(&row_format, row, typ, out)
2610 };
2611
2612 let typ = row_desc.typ();
2613 let column_formats = iter::repeat(encode_format)
2614 .take(typ.column_types.len())
2615 .collect();
2616 self.send(BackendMessage::CopyOutResponse {
2617 overall_format: encode_format,
2618 column_formats,
2619 })
2620 .await?;
2621
2622 let mut out = Vec::new();
2627
2628 if let CopyFormat::Binary = format {
2629 out.extend(b"PGCOPY\n\xFF\r\n\0");
2631 out.extend([0, 0, 0, 0]);
2633 out.extend([0, 0, 0, 0]);
2635 }
2636
2637 let mut count = 0;
2638 let mut total_sent_bytes = 0;
2639 loop {
2640 tokio::select! {
2641 e = self.conn.wait_closed() => return Err(e),
2642 batch = stream.recv() => match batch {
2643 None => break,
2644 Some(PeekResponseUnary::Error(text)) => {
2645 let err =
2646 ErrorResponse::error(SqlState::INTERNAL_ERROR, text.clone());
2647 return self
2648 .send_error_and_get_state(err)
2649 .await
2650 .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
2651 }
2652 Some(PeekResponseUnary::Canceled) => {
2653 return self.send_error_and_get_state(ErrorResponse::error(
2654 SqlState::QUERY_CANCELED,
2655 "canceling statement due to user request",
2656 ))
2657 .await.map(|state| (state, SendRowsEndedReason::Canceled));
2658 }
2659 Some(PeekResponseUnary::Rows(mut rows)) => {
2660 count += rows.count();
2661 while let Some(row) = rows.next() {
2662 total_sent_bytes += row.byte_len();
2663 encode_fn(row, typ, &mut out)?;
2664 self.send(BackendMessage::CopyData(mem::take(&mut out)))
2665 .await?;
2666 }
2667 }
2668 },
2669 notice = self.adapter_client.session().recv_notice() => {
2670 self.send(notice.into_response())
2671 .await?;
2672 self.conn.flush().await?;
2673 }
2674 }
2675
2676 self.conn.flush().await?;
2677 }
2678 if let CopyFormat::Binary = format {
2680 let trailer: i16 = -1;
2681 out.extend(trailer.to_be_bytes());
2682 self.send(BackendMessage::CopyData(mem::take(&mut out)))
2683 .await?;
2684 }
2685
2686 let tag = format!("COPY {}", count);
2687 self.send(BackendMessage::CopyDone).await?;
2688 self.send(BackendMessage::CommandComplete { tag }).await?;
2689 Ok((
2690 State::Ready,
2691 SendRowsEndedReason::Success {
2692 result_size: u64::cast_from(total_sent_bytes),
2693 rows_returned: u64::cast_from(count),
2694 },
2695 ))
2696 }
2697
2698 #[instrument(level = "debug")]
2701 async fn copy_from(
2702 &mut self,
2703 target_id: CatalogItemId,
2704 target_name: String,
2705 columns: Vec<ColumnIndex>,
2706 params: CopyFormatParams<'static>,
2707 row_desc: RelationDesc,
2708 mut ctx_extra: ExecuteContextGuard,
2709 ) -> Result<State, io::Error> {
2710 let res = self
2711 .copy_from_inner(
2712 target_id,
2713 target_name,
2714 columns,
2715 params,
2716 row_desc,
2717 &mut ctx_extra,
2718 )
2719 .await;
2720 match &res {
2721 Ok(State::Ready) => {
2722 self.adapter_client.retire_execute(
2723 ctx_extra,
2724 StatementEndedExecutionReason::Success {
2725 result_size: None,
2726 rows_returned: None,
2727 execution_strategy: None,
2728 },
2729 );
2730 }
2731 Ok(State::Done) => {
2732 self.adapter_client
2736 .retire_execute(ctx_extra, StatementEndedExecutionReason::Canceled);
2737 }
2738 Err(e) => {
2739 self.adapter_client.retire_execute(
2740 ctx_extra,
2741 StatementEndedExecutionReason::Errored {
2742 error: format!("{e}"),
2743 },
2744 );
2745 }
2746 Ok(State::Drain) => {}
2747 }
2748 res
2749 }
2750
2751 async fn copy_from_inner(
2752 &mut self,
2753 target_id: CatalogItemId,
2754 target_name: String,
2755 columns: Vec<ColumnIndex>,
2756 params: CopyFormatParams<'static>,
2757 row_desc: RelationDesc,
2758 ctx_extra: &mut ExecuteContextGuard,
2759 ) -> Result<State, io::Error> {
2760 let typ = row_desc.typ();
2761 let column_formats = vec![Format::Text; typ.column_types.len()];
2762 self.send(BackendMessage::CopyInResponse {
2763 overall_format: Format::Text,
2764 column_formats,
2765 })
2766 .await?;
2767 self.conn.flush().await?;
2768
2769 let writer = match self
2771 .adapter_client
2772 .start_copy_from_stdin(
2773 target_id,
2774 target_name.clone(),
2775 columns.clone(),
2776 row_desc.clone(),
2777 params.clone(),
2778 )
2779 .await
2780 {
2781 Ok(writer) => writer,
2782 Err(e) => {
2783 loop {
2789 match self.conn.recv().await? {
2790 Some(FrontendMessage::CopyData(_)) => {}
2791 Some(FrontendMessage::CopyDone) | Some(FrontendMessage::CopyFail(_)) => {
2792 break;
2793 }
2794 Some(FrontendMessage::Flush) | Some(FrontendMessage::Sync) => {}
2795 Some(_) => break,
2796 None => return Ok(State::Done),
2797 }
2798 }
2799 self.adapter_client.retire_execute(
2800 std::mem::take(ctx_extra),
2801 StatementEndedExecutionReason::Errored {
2802 error: e.to_string(),
2803 },
2804 );
2805 return self
2806 .send_error_and_get_state(e.into_response(Severity::Error))
2807 .await;
2808 }
2809 };
2810
2811 self.conn.set_copy_mode(true);
2813
2814 const BATCH_SIZE: usize = 32 * 1024 * 1024;
2816 let max_copy_from_row_size = self
2817 .adapter_client
2818 .get_system_vars()
2819 .await
2820 .max_copy_from_row_size()
2821 .try_into()
2822 .unwrap_or(usize::MAX);
2823
2824 let mut data = Vec::new();
2825 let mut row_scanner = CopyRowScanner::new(¶ms);
2826 let num_workers = writer.batch_txs.len();
2827 let mut next_worker: usize = 0;
2828 let mut saw_copy_done = false;
2829 let mut saw_end_marker = false;
2830 let mut copy_from_error: Option<(SqlState, String)> = None;
2831
2832 loop {
2835 let message = self.conn.recv().await?;
2836 match message {
2837 Some(FrontendMessage::CopyData(buf)) => {
2838 if saw_end_marker {
2839 continue;
2842 }
2843 data.extend(buf);
2844 row_scanner.scan_new_bytes(&data);
2845
2846 if let Some(end_pos) = row_scanner.end_marker_end() {
2847 data.truncate(end_pos);
2848 row_scanner.on_truncate(end_pos);
2849 saw_end_marker = true;
2850 }
2851
2852 if row_scanner.current_row_size(data.len()) > max_copy_from_row_size {
2854 copy_from_error = Some((
2855 SqlState::INSUFFICIENT_RESOURCES,
2856 format!(
2857 "COPY FROM STDIN row exceeded max_copy_from_row_size \
2858 ({max_copy_from_row_size} bytes)"
2859 ),
2860 ));
2861 break;
2862 }
2863
2864 let mut send_failed = false;
2867 while data.len() >= BATCH_SIZE {
2868 let split_pos = match row_scanner.last_row_end() {
2869 Some(pos) => pos,
2870 None => break, };
2872 let remainder = data.split_off(split_pos);
2873 let chunk = std::mem::replace(&mut data, remainder);
2874 row_scanner.on_split(split_pos);
2875 if writer.batch_txs[next_worker].send(chunk).await.is_err() {
2876 send_failed = true;
2877 break;
2878 }
2879 next_worker = (next_worker + 1) % num_workers;
2880 }
2881 if send_failed {
2884 break;
2885 }
2886 }
2887 Some(FrontendMessage::CopyDone) => {
2888 if !data.is_empty() {
2890 let chunk = std::mem::take(&mut data);
2891 let _ = writer.batch_txs[next_worker].send(chunk).await;
2893 }
2894 saw_copy_done = true;
2895 break;
2896 }
2897 Some(FrontendMessage::CopyFail(err)) => {
2898 self.adapter_client.retire_execute(
2899 std::mem::take(ctx_extra),
2900 StatementEndedExecutionReason::Canceled,
2901 );
2902 drop(writer);
2904 self.conn.set_copy_mode(false);
2905 return self
2906 .send_error_and_get_state(ErrorResponse::error(
2907 SqlState::QUERY_CANCELED,
2908 format!("COPY from stdin failed: {}", err),
2909 ))
2910 .await;
2911 }
2912 Some(FrontendMessage::Flush) | Some(FrontendMessage::Sync) => {}
2913 Some(_) => {
2914 let msg = "unexpected message type during COPY from stdin";
2915 self.adapter_client.retire_execute(
2916 std::mem::take(ctx_extra),
2917 StatementEndedExecutionReason::Errored {
2918 error: msg.to_string(),
2919 },
2920 );
2921 drop(writer);
2922 self.conn.set_copy_mode(false);
2923 return self
2924 .send_error_and_get_state(ErrorResponse::error(
2925 SqlState::PROTOCOL_VIOLATION,
2926 msg,
2927 ))
2928 .await;
2929 }
2930 None => {
2931 drop(writer);
2932 self.conn.set_copy_mode(false);
2933 return Ok(State::Done);
2934 }
2935 }
2936 }
2937
2938 if !saw_copy_done {
2942 loop {
2943 match self.conn.recv().await? {
2944 Some(FrontendMessage::CopyData(_)) => {}
2945 Some(FrontendMessage::CopyDone) | Some(FrontendMessage::CopyFail(_)) => {
2946 break;
2947 }
2948 Some(FrontendMessage::Flush) | Some(FrontendMessage::Sync) => {}
2949 Some(_) => {
2950 let msg = "unexpected message type during COPY from stdin";
2951 self.adapter_client.retire_execute(
2952 std::mem::take(ctx_extra),
2953 StatementEndedExecutionReason::Errored {
2954 error: msg.to_string(),
2955 },
2956 );
2957 drop(writer);
2958 self.conn.set_copy_mode(false);
2959 return self
2960 .send_error_and_get_state(ErrorResponse::error(
2961 SqlState::PROTOCOL_VIOLATION,
2962 msg,
2963 ))
2964 .await;
2965 }
2966 None => {
2967 drop(writer);
2968 self.conn.set_copy_mode(false);
2969 return Ok(State::Done);
2970 }
2971 }
2972 }
2973 }
2974
2975 if let Some((code, msg)) = copy_from_error {
2976 self.adapter_client.retire_execute(
2977 std::mem::take(ctx_extra),
2978 StatementEndedExecutionReason::Errored { error: msg.clone() },
2979 );
2980 drop(writer);
2981 self.conn.set_copy_mode(false);
2982 return self
2983 .send_error_and_get_state(ErrorResponse::error(code, msg))
2984 .await;
2985 }
2986
2987 self.conn.set_copy_mode(false);
2988
2989 drop(writer.batch_txs);
2994
2995 let (proto_batches, row_count) = match writer.completion_rx.await {
2997 Ok(Ok(result)) => result,
2998 Ok(Err(e)) => {
2999 self.adapter_client.retire_execute(
3000 std::mem::take(ctx_extra),
3001 StatementEndedExecutionReason::Errored {
3002 error: e.to_string(),
3003 },
3004 );
3005 return self
3006 .send_error_and_get_state(e.into_response(Severity::Error))
3007 .await;
3008 }
3009 Err(_) => {
3010 let msg = "COPY FROM STDIN: background batch builder tasks dropped";
3011 self.adapter_client.retire_execute(
3012 std::mem::take(ctx_extra),
3013 StatementEndedExecutionReason::Errored {
3014 error: msg.to_string(),
3015 },
3016 );
3017 return self
3018 .send_error_and_get_state(ErrorResponse::error(SqlState::INTERNAL_ERROR, msg))
3019 .await;
3020 }
3021 };
3022
3023 if let Err(e) = self
3025 .adapter_client
3026 .stage_copy_from_stdin_batches(target_id, proto_batches)
3027 {
3028 self.adapter_client.retire_execute(
3029 std::mem::take(ctx_extra),
3030 StatementEndedExecutionReason::Errored {
3031 error: e.to_string(),
3032 },
3033 );
3034 return self
3035 .send_error_and_get_state(e.into_response(Severity::Error))
3036 .await;
3037 }
3038
3039 let tag = format!("COPY {}", row_count);
3040 self.send(BackendMessage::CommandComplete { tag }).await?;
3041
3042 Ok(State::Ready)
3043 }
3044
3045 #[instrument(level = "debug")]
3046 async fn send_pending_notices(&mut self) -> Result<(), io::Error> {
3047 let notices = self
3048 .adapter_client
3049 .session()
3050 .drain_notices()
3051 .into_iter()
3052 .map(|notice| BackendMessage::ErrorResponse(notice.into_response()));
3053 self.send_all(notices).await?;
3054 Ok(())
3055 }
3056
3057 #[instrument(level = "debug")]
3058 async fn send_error_and_get_state(&mut self, err: ErrorResponse) -> Result<State, io::Error> {
3059 assert!(err.severity.is_error());
3060 debug!(
3061 "cid={} error code={}",
3062 self.adapter_client.session().conn_id(),
3063 err.code.code()
3064 );
3065 let is_fatal = err.severity.is_fatal();
3066 self.send(BackendMessage::ErrorResponse(err)).await?;
3067
3068 let txn = self.adapter_client.session().transaction();
3069 match txn {
3070 TransactionStatus::Default | TransactionStatus::Failed(_) => {}
3073 TransactionStatus::Started(_) => {
3075 self.rollback_transaction().await?;
3076 }
3077 TransactionStatus::InTransactionImplicit(_) => {
3079 self.rollback_transaction().await?;
3080 }
3081 TransactionStatus::InTransaction(_) => {
3083 self.adapter_client.fail_transaction();
3084 }
3085 };
3086 if is_fatal {
3087 Ok(State::Done)
3088 } else {
3089 Ok(State::Drain)
3090 }
3091 }
3092
3093 #[instrument(level = "debug")]
3094 async fn aborted_txn_error(&mut self) -> Result<State, io::Error> {
3095 self.send(BackendMessage::ErrorResponse(ErrorResponse::error(
3096 SqlState::IN_FAILED_SQL_TRANSACTION,
3097 ABORTED_TXN_MSG,
3098 )))
3099 .await?;
3100 Ok(State::Drain)
3101 }
3102
3103 fn is_aborted_txn(&mut self) -> bool {
3104 matches!(
3105 self.adapter_client.session().transaction(),
3106 TransactionStatus::Failed(_)
3107 )
3108 }
3109}
3110
3111fn pad_formats(formats: Vec<Format>, n: usize) -> Result<Vec<Format>, String> {
3112 match (formats.len(), n) {
3113 (0, e) => Ok(vec![Format::Text; e]),
3114 (1, e) => Ok(iter::repeat(formats[0]).take(e).collect()),
3115 (a, e) if a == e => Ok(formats),
3116 (a, e) => Err(format!(
3117 "expected {} field format specifiers, but got {}",
3118 e, a
3119 )),
3120 }
3121}
3122
3123fn describe_rows(stmt_desc: &StatementDesc, formats: &[Format]) -> BackendMessage {
3124 match &stmt_desc.relation_desc {
3125 Some(desc) if !stmt_desc.is_copy => {
3126 BackendMessage::RowDescription(message::encode_row_description(desc, formats))
3127 }
3128 _ => BackendMessage::NoData,
3129 }
3130}
3131
3132type GetResponse = fn(
3133 max_rows: ExecuteCount,
3134 total_sent_rows: usize,
3135 fetch_portal: Option<PortalRefMut>,
3136) -> BackendMessage;
3137
3138fn portal_exec_message(
3141 max_rows: ExecuteCount,
3142 total_sent_rows: usize,
3143 _fetch_portal: Option<PortalRefMut>,
3144) -> BackendMessage {
3145 match max_rows {
3152 ExecuteCount::Count(max_rows) if max_rows <= total_sent_rows => {
3153 BackendMessage::PortalSuspended
3154 }
3155 _ => BackendMessage::CommandComplete {
3156 tag: format!("SELECT {}", total_sent_rows),
3157 },
3158 }
3159}
3160
3161fn fetch_message(
3163 _max_rows: ExecuteCount,
3164 total_sent_rows: usize,
3165 fetch_portal: Option<PortalRefMut>,
3166) -> BackendMessage {
3167 let tag = format!("FETCH {}", total_sent_rows);
3168 if let Some(portal) = fetch_portal {
3169 *portal.state = PortalState::Completed(Some(tag.clone()));
3170 }
3171 BackendMessage::CommandComplete { tag }
3172}
3173
3174fn get_authenticator(
3175 authenticator_kind: listeners::AuthenticatorKind,
3176 frontegg: Option<FronteggAuthenticator>,
3177 oidc: GenericOidcAuthenticator,
3178 adapter_client: mz_adapter::Client,
3179 oidc_auth_option_enabled: bool,
3182) -> Authenticator {
3183 match authenticator_kind {
3184 listeners::AuthenticatorKind::Frontegg => Authenticator::Frontegg(frontegg.expect(
3185 "Frontegg authenticator should exist with listeners::AuthenticatorKind::Frontegg",
3186 )),
3187 listeners::AuthenticatorKind::Password => Authenticator::Password(adapter_client),
3188 listeners::AuthenticatorKind::Sasl => Authenticator::Sasl(adapter_client),
3189 listeners::AuthenticatorKind::Oidc => {
3190 if oidc_auth_option_enabled {
3191 Authenticator::Oidc(oidc)
3192 } else {
3193 Authenticator::Password(adapter_client)
3196 }
3197 }
3198 listeners::AuthenticatorKind::None => Authenticator::None,
3199 }
3200}
3201
3202#[derive(Debug, Copy, Clone)]
3203enum ExecuteCount {
3204 All,
3205 Count(usize),
3206}
3207
3208fn is_txn_exit_stmt(stmt: Option<&Statement<Raw>>) -> bool {
3210 match stmt {
3211 Some(stmt) => matches!(stmt, Statement::Commit(_) | Statement::Rollback(_)),
3213 None => false,
3214 }
3215}
3216
3217#[derive(Debug)]
3218enum FetchResult {
3219 Rows(Option<Box<dyn RowIterator + Send + Sync>>),
3220 Canceled,
3221 Error(String),
3222 Notice(AdapterNotice),
3223}
3224
3225#[derive(Debug)]
3226struct CopyRowScanner {
3227 scan_pos: usize,
3228 last_row_end: Option<usize>,
3229 end_marker_end: Option<usize>,
3230 csv: Option<CsvScanState>,
3231}
3232
3233#[derive(Debug)]
3234struct CsvScanState {
3235 reader: csv_core::Reader,
3236 output: Vec<u8>,
3237 ends: Vec<usize>,
3238 record: Vec<u8>,
3239 record_ends: Vec<usize>,
3240 skip_first_record: bool,
3241}
3242
3243impl CopyRowScanner {
3244 fn new(params: &CopyFormatParams<'_>) -> Self {
3245 let csv = match params {
3246 CopyFormatParams::Csv(CopyCsvFormatParams {
3247 delimiter,
3248 quote,
3249 escape,
3250 header,
3251 ..
3252 }) => Some(CsvScanState::new(*delimiter, *quote, *escape, *header)),
3253 _ => None,
3254 };
3255
3256 CopyRowScanner {
3257 scan_pos: 0,
3258 last_row_end: None,
3259 end_marker_end: None,
3260 csv,
3261 }
3262 }
3263
3264 fn scan_new_bytes(&mut self, data: &[u8]) {
3265 if self.scan_pos >= data.len() {
3266 return;
3267 }
3268
3269 if let Some(csv) = self.csv.as_mut() {
3270 let mut input = &data[self.scan_pos..];
3271 let mut consumed = 0usize;
3272 while !input.is_empty() {
3273 let (result, n_input, n_output, n_ends) =
3274 csv.reader
3275 .read_record(input, &mut csv.output, &mut csv.ends);
3276 consumed += n_input;
3277 input = &input[n_input..];
3278 if !csv.output.is_empty() {
3279 csv.record.extend_from_slice(&csv.output[..n_output]);
3280 }
3281 if !csv.ends.is_empty() {
3282 csv.record_ends.extend_from_slice(&csv.ends[..n_ends]);
3283 }
3284
3285 match result {
3286 ReadRecordResult::InputEmpty => break,
3287 ReadRecordResult::OutputFull => {
3288 if n_input == 0 {
3289 csv.output
3290 .resize(csv.output.len().saturating_mul(2).max(1), 0);
3291 }
3292 }
3293 ReadRecordResult::OutputEndsFull => {
3294 if n_input == 0 {
3295 csv.ends.resize(csv.ends.len().saturating_mul(2).max(1), 0);
3296 }
3297 }
3298 ReadRecordResult::Record | ReadRecordResult::End => {
3299 let row_end = self.scan_pos + consumed;
3300 self.last_row_end = Some(row_end);
3301 if self.end_marker_end.is_none() {
3302 let is_marker = if csv.skip_first_record {
3303 csv.skip_first_record = false;
3304 false
3305 } else if csv.record_ends.len() == 1 {
3306 let end = csv.record_ends[0];
3307 end == 2 && csv.record.get(0..end) == Some(b"\\.")
3308 } else {
3309 false
3310 };
3311 if is_marker {
3312 self.end_marker_end = Some(row_end);
3313 break;
3314 }
3315 }
3316 csv.record.clear();
3317 csv.record_ends.clear();
3318 }
3319 }
3320 }
3321 } else {
3322 let mut row_start = self.last_row_end.unwrap_or(0);
3323 for (offset, b) in data[self.scan_pos..].iter().enumerate() {
3324 if *b == b'\n' {
3325 let row_end = self.scan_pos + offset + 1;
3326 self.last_row_end = Some(row_end);
3327 if self.end_marker_end.is_none() {
3328 let row = &data[row_start..row_end];
3329 if row.get(0..2) == Some(b"\\.") {
3330 self.end_marker_end = Some(row_end);
3331 break;
3332 }
3333 }
3334 row_start = row_end;
3335 }
3336 }
3337 }
3338
3339 self.scan_pos = data.len();
3340 }
3341
3342 fn last_row_end(&self) -> Option<usize> {
3343 self.last_row_end
3344 }
3345
3346 fn end_marker_end(&self) -> Option<usize> {
3347 self.end_marker_end
3348 }
3349
3350 fn current_row_size(&self, data_len: usize) -> usize {
3351 data_len.saturating_sub(self.last_row_end.unwrap_or(0))
3352 }
3353
3354 fn on_split(&mut self, split_pos: usize) {
3355 self.scan_pos = self.scan_pos.saturating_sub(split_pos);
3356 self.last_row_end = None;
3357 self.end_marker_end = self
3358 .end_marker_end
3359 .and_then(|end| end.checked_sub(split_pos));
3360 }
3361
3362 fn on_truncate(&mut self, new_len: usize) {
3363 self.scan_pos = self.scan_pos.min(new_len);
3364 self.last_row_end = self.last_row_end.filter(|&end| end <= new_len);
3365 self.end_marker_end = self.end_marker_end.filter(|&end| end <= new_len);
3366 }
3367}
3368
3369impl CsvScanState {
3370 fn new(delimiter: u8, quote: u8, escape: u8, header: bool) -> Self {
3371 let (double_quote, escape) = if quote == escape {
3372 (true, None)
3373 } else {
3374 (false, Some(escape))
3375 };
3376 CsvScanState {
3377 reader: csv_core::ReaderBuilder::new()
3378 .delimiter(delimiter)
3379 .quote(quote)
3380 .double_quote(double_quote)
3381 .escape(escape)
3382 .build(),
3383 output: vec![0; 1],
3384 ends: vec![0; 1],
3385 record: Vec::new(),
3386 record_ends: Vec::new(),
3387 skip_first_record: header,
3388 }
3389 }
3390}
3391
3392#[cfg(test)]
3393mod test {
3394 use super::*;
3395
3396 #[mz_ore::test]
3397 fn test_parse_options() {
3398 struct TestCase {
3399 input: &'static str,
3400 expect: Result<Vec<(&'static str, &'static str)>, ()>,
3401 }
3402 let tests = vec![
3403 TestCase {
3404 input: "",
3405 expect: Ok(vec![]),
3406 },
3407 TestCase {
3408 input: "--key",
3409 expect: Err(()),
3410 },
3411 TestCase {
3412 input: "--key=val",
3413 expect: Ok(vec![("key", "val")]),
3414 },
3415 TestCase {
3416 input: r#"--key=val -ckey2=val2 -c key3=val3 -c key4=val4 -ckey5=val5"#,
3417 expect: Ok(vec![
3418 ("key", "val"),
3419 ("key2", "val2"),
3420 ("key3", "val3"),
3421 ("key4", "val4"),
3422 ("key5", "val5"),
3423 ]),
3424 },
3425 TestCase {
3426 input: r#"-c\ key=val"#,
3427 expect: Ok(vec![(" key", "val")]),
3428 },
3429 TestCase {
3430 input: "--key=val -ckey2 val2",
3431 expect: Err(()),
3432 },
3433 TestCase {
3435 input: "--key=",
3436 expect: Ok(vec![("key", "")]),
3437 },
3438 ];
3439 for test in tests {
3440 let got = parse_options(test.input);
3441 let expect = test.expect.map(|r| {
3442 r.into_iter()
3443 .map(|(k, v)| (k.to_owned(), v.to_owned()))
3444 .collect()
3445 });
3446 assert_eq!(got, expect, "input: {}", test.input);
3447 }
3448 }
3449
3450 #[mz_ore::test]
3451 fn test_parse_option() {
3452 struct TestCase {
3453 input: &'static str,
3454 expect: Result<(&'static str, &'static str), ()>,
3455 }
3456 let tests = vec![
3457 TestCase {
3458 input: "",
3459 expect: Err(()),
3460 },
3461 TestCase {
3462 input: "--",
3463 expect: Err(()),
3464 },
3465 TestCase {
3466 input: "--c",
3467 expect: Err(()),
3468 },
3469 TestCase {
3470 input: "a=b",
3471 expect: Err(()),
3472 },
3473 TestCase {
3474 input: "--a=b",
3475 expect: Ok(("a", "b")),
3476 },
3477 TestCase {
3478 input: "--ca=b",
3479 expect: Ok(("ca", "b")),
3480 },
3481 TestCase {
3482 input: "-ca=b",
3483 expect: Ok(("a", "b")),
3484 },
3485 TestCase {
3487 input: "--=",
3488 expect: Ok(("", "")),
3489 },
3490 ];
3491 for test in tests {
3492 let got = parse_option(test.input);
3493 assert_eq!(got, test.expect, "input: {}", test.input);
3494 }
3495 }
3496
3497 #[mz_ore::test]
3498 fn test_split_options() {
3499 struct TestCase {
3500 input: &'static str,
3501 expect: Vec<&'static str>,
3502 }
3503 let tests = vec![
3504 TestCase {
3505 input: "",
3506 expect: vec![],
3507 },
3508 TestCase {
3509 input: " ",
3510 expect: vec![],
3511 },
3512 TestCase {
3513 input: " a ",
3514 expect: vec!["a"],
3515 },
3516 TestCase {
3517 input: " ab cd ",
3518 expect: vec!["ab", "cd"],
3519 },
3520 TestCase {
3521 input: r#" ab\ cd "#,
3522 expect: vec!["ab ", "cd"],
3523 },
3524 TestCase {
3525 input: r#" ab\\ cd "#,
3526 expect: vec![r#"ab\"#, "cd"],
3527 },
3528 TestCase {
3529 input: r#" ab\\\ cd "#,
3530 expect: vec![r#"ab\ "#, "cd"],
3531 },
3532 TestCase {
3533 input: r#" ab\\\ cd "#,
3534 expect: vec![r#"ab\ cd"#],
3535 },
3536 TestCase {
3537 input: r#" ab\\\cd "#,
3538 expect: vec![r#"ab\cd"#],
3539 },
3540 TestCase {
3541 input: r#"a\"#,
3542 expect: vec!["a"],
3543 },
3544 TestCase {
3545 input: r#"a\ "#,
3546 expect: vec!["a "],
3547 },
3548 TestCase {
3549 input: r#"\"#,
3550 expect: vec![],
3551 },
3552 TestCase {
3553 input: r#"\ "#,
3554 expect: vec![r#" "#],
3555 },
3556 TestCase {
3557 input: r#" \ "#,
3558 expect: vec![r#" "#],
3559 },
3560 TestCase {
3561 input: r#"\ "#,
3562 expect: vec![r#" "#],
3563 },
3564 ];
3565 for test in tests {
3566 let got = split_options(test.input);
3567 assert_eq!(got, test.expect, "input: {}", test.input);
3568 }
3569 }
3570}