// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.
10use std::collections::BTreeMap;
11use std::convert::TryFrom;
12use std::future::Future;
13use std::ops::Deref;
14use std::sync::Arc;
15use std::time::{Duration, Instant};
16use std::{iter, mem};
17
18use base64::prelude::*;
19use byteorder::{ByteOrder, NetworkEndian};
20use csv_core::ReadRecordResult;
21use futures::future::{BoxFuture, FutureExt, pending};
22use itertools::Itertools;
23use mz_adapter::client::RecordFirstRowStream;
24use mz_adapter::session::{
25    EndTransactionAction, InProgressRows, LifecycleTimestamps, PortalRefMut, PortalState, Session,
26    SessionConfig, TransactionStatus,
27};
28use mz_adapter::statement_logging::{StatementEndedExecutionReason, StatementExecutionStrategy};
29use mz_adapter::{
30    AdapterError, AdapterNotice, ExecuteContextGuard, ExecuteResponse, PeekResponseUnary, metrics,
31    verify_datum_desc,
32};
33use mz_auth::Authenticated;
34use mz_auth::password::Password;
35use mz_authenticator::{Authenticator, GenericOidcAuthenticator};
36use mz_frontegg_auth::Authenticator as FronteggAuthenticator;
37use mz_ore::cast::CastFrom;
38use mz_ore::netio::AsyncReady;
39use mz_ore::now::{EpochMillis, SYSTEM_TIME};
40use mz_ore::str::StrExt;
41use mz_ore::{assert_none, assert_ok, instrument, soft_assert_eq_or_log};
42use mz_pgcopy::{CopyCsvFormatParams, CopyFormatParams, CopyTextFormatParams};
43use mz_pgwire_common::{
44    ConnectionCounter, Cursor, ErrorResponse, Format, FrontendMessage, Severity, VERSION_3,
45    VERSIONS,
46};
47use mz_repr::{
48    CatalogItemId, ColumnIndex, Datum, RelationDesc, RowArena, RowIterator, RowRef,
49    SqlRelationType, SqlScalarType,
50};
51use mz_server_core::TlsMode;
52use mz_server_core::listeners;
53use mz_server_core::listeners::AllowedRoles;
54use mz_sql::ast::display::AstDisplay;
55use mz_sql::ast::{
56    CopyDirection, CopyStatement, CopyTarget, FetchDirection, Ident, Raw, Statement,
57};
58use mz_sql::parse::StatementParseResult;
59use mz_sql::plan::{CopyFormat, ExecuteTimeout, StatementDesc};
60use mz_sql::session::metadata::SessionMetadata;
61use mz_sql::session::user::INTERNAL_USER_NAMES;
62use mz_sql::session::vars::VarInput;
63use postgres::error::SqlState;
64use tokio::io::{self, AsyncRead, AsyncWrite};
65use tokio::select;
66use tokio::time::{self};
67use tokio_metrics::TaskMetrics;
68use tokio_stream::wrappers::UnboundedReceiverStream;
69use tracing::{Instrument, debug, debug_span, warn};
70use uuid::Uuid;
71
72use crate::codec::{
73    FramedConn, decode_password, decode_sasl_initial_response, decode_sasl_response,
74};
75use crate::message::{
76    self, BackendMessage, SASLServerFinalMessage, SASLServerFinalMessageKinds,
77    SASLServerFirstMessage,
78};
79
80/// Reports whether the given stream begins with a pgwire handshake.
81///
82/// To avoid false negatives, there must be at least eight bytes in `buf`.
83pub fn match_handshake(buf: &[u8]) -> bool {
84    // The pgwire StartupMessage looks like this:
85    //
86    //     i32 - Length of entire message.
87    //     i32 - Protocol version number.
88    //     [String] - Arbitrary key-value parameters of any length.
89    //
90    // Since arbitrary parameters can be included in the StartupMessage, the
91    // first Int32 is worthless, since the message could have any length.
92    // Instead, we sniff the protocol version number.
93    if buf.len() < 8 {
94        return false;
95    }
96    let version = NetworkEndian::read_i32(&buf[4..8]);
97    VERSIONS.contains(&version)
98}
99
/// Parameters for the [`run`] function.
pub struct RunParams<'a, A, I>
where
    I: Iterator<Item = TaskMetrics> + Send,
{
    /// The TLS mode of the pgwire server.
    pub tls_mode: Option<TlsMode>,
    /// A client for the adapter.
    pub adapter_client: mz_adapter::Client,
    /// The connection to the client.
    pub conn: &'a mut FramedConn<A>,
    /// The universally unique identifier for the connection.
    pub conn_uuid: Uuid,
    /// The protocol version that the client provided in the startup message.
    pub version: i32,
    /// The parameters that the client provided in the startup message.
    pub params: BTreeMap<String, String>,
    /// Frontegg JWT authenticator, if Frontegg authentication is configured.
    pub frontegg: Option<FronteggAuthenticator>,
    /// OIDC authenticator.
    pub oidc: GenericOidcAuthenticator,
    /// The authentication method defined by the server's listener
    /// configuration.
    pub authenticator_kind: listeners::AuthenticatorKind,
    /// Global connection limit and count.
    pub active_connection_counter: ConnectionCounter,
    /// Helm chart version, if running under the Kubernetes Helm chart.
    pub helm_chart_version: Option<String>,
    /// Which classes of roles may log in over this listener
    /// (e.g. whether reserved users such as `mz_system` are allowed).
    pub allowed_roles: AllowedRoles,
    /// Source of Tokio task metrics intervals for this connection.
    pub tokio_metrics_intervals: I,
}
133
/// Runs a pgwire connection to completion.
///
/// This involves responding to `FrontendMessage::StartupMessage` and all future
/// requests until the client terminates the connection or a fatal error occurs.
///
/// Note that this function returns successfully even upon delivering a fatal
/// error to the client. It only returns `Err` if an unexpected I/O error occurs
/// while communicating with the client, e.g., if the connection is severed in
/// the middle of a request.
#[mz_ore::instrument(level = "debug")]
pub async fn run<'a, A, I>(
    RunParams {
        tls_mode,
        adapter_client,
        conn,
        conn_uuid,
        version,
        mut params,
        frontegg,
        oidc,
        authenticator_kind,
        active_connection_counter,
        helm_chart_version,
        allowed_roles,
        tokio_metrics_intervals,
    }: RunParams<'a, A, I>,
) -> Result<(), io::Error>
where
    A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
    I: Iterator<Item = TaskMetrics> + Send,
{
    // Only protocol version 3 is supported; reject anything else up front.
    if version != VERSION_3 {
        return conn
            .send(ErrorResponse::fatal(
                SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
                "server does not support the client's requested protocol version",
            ))
            .await;
    }

    let user = params.remove("user").unwrap_or_else(String::new);
    let options = parse_options(params.get("options").unwrap_or(&String::new()));

    // If oidc_auth_enabled exists as an option, return its value and filter it from
    // the remaining options.
    let (oidc_auth_enabled, options) = extract_oidc_auth_enabled_from_options(options);
    let authenticator = get_authenticator(
        authenticator_kind,
        frontegg,
        oidc,
        adapter_client.clone(),
        oidc_auth_enabled,
    );
    // TODO move this somewhere it can be shared with HTTP
    let is_internal_user = INTERNAL_USER_NAMES.contains(&user);
    // this is a superset of internal users
    let is_reserved_user = mz_adapter::catalog::is_reserved_role_name(user.as_str());
    // Gate the login on the listener's allowed-role policy before doing any
    // authentication work.
    let role_allowed = match allowed_roles {
        AllowedRoles::Normal => !is_reserved_user,
        AllowedRoles::Internal => is_internal_user,
        AllowedRoles::NormalAndInternal => !is_reserved_user || is_internal_user,
    };
    if !role_allowed {
        let msg = format!("unauthorized login to user '{user}'");
        return conn
            .send(ErrorResponse::fatal(SqlState::INSUFFICIENT_PRIVILEGE, msg))
            .await;
    }

    if let Err(err) = conn.inner().ensure_tls_compatibility(&tls_mode) {
        return conn.send(err).await;
    }

    let authenticator_kind = authenticator.kind();

    // Authenticate the connection and build a session. `expired` is a future
    // that resolves when the authentication session lapses (only the Frontegg
    // path ever resolves; all other paths use `pending()`), at which point the
    // `select!` at the bottom terminates the connection.
    let (mut session, expired) = match authenticator {
        Authenticator::Frontegg(frontegg) => {
            let password = match request_cleartext_password(conn).await {
                Ok(password) => password,
                Err(PasswordRequestError::IoError(e)) => return Err(e),
                Err(PasswordRequestError::InvalidPasswordError(e)) => {
                    return conn.send(e).await;
                }
            };

            let auth_response = frontegg.authenticate(&user, &password).await;
            match auth_response {
                // Create a session based on the auth session.
                //
                // In particular, it's important that the username come from the
                // auth session, as Frontegg may return an email address with
                // different casing than the user supplied via the pgwire
                // username field.
                Ok((mut auth_session, authenticated)) => {
                    let session = adapter_client.new_session(
                        SessionConfig {
                            conn_id: conn.conn_id().clone(),
                            uuid: conn_uuid,
                            user: auth_session.user().into(),
                            client_ip: conn.peer_addr().clone(),
                            external_metadata_rx: Some(auth_session.external_metadata_rx()),
                            helm_chart_version,
                            authenticator_kind,
                        },
                        authenticated,
                    );
                    let expired = async move { auth_session.expired().await };
                    (session, expired.left_future())
                }
                Err(err) => {
                    // Report a generic failure to the client; details go to
                    // the server log only.
                    warn!(?err, "pgwire connection failed authentication");
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_PASSWORD,
                            "invalid password",
                        ))
                        .await;
                }
            }
        }
        Authenticator::Oidc(oidc) => {
            // OIDC authentication: JWT sent as password in cleartext flow
            let jwt = match request_cleartext_password(conn).await {
                Ok(password) => password,
                Err(PasswordRequestError::IoError(e)) => return Err(e),
                Err(PasswordRequestError::InvalidPasswordError(e)) => {
                    return conn.send(e).await;
                }
            };
            let auth_response = oidc.authenticate(&jwt, Some(&user)).await;
            match auth_response {
                Ok((claims, authenticated)) => {
                    let session = adapter_client.new_session(
                        SessionConfig {
                            conn_id: conn.conn_id().clone(),
                            uuid: conn_uuid,
                            user: claims.user,
                            client_ip: conn.peer_addr().clone(),
                            external_metadata_rx: None,
                            helm_chart_version,
                            authenticator_kind,
                        },
                        authenticated,
                    );
                    // No invalidation of the auth session once authenticated,
                    // so auth session lasts indefinitely.
                    (session, pending().right_future())
                }
                Err(err) => {
                    warn!(?err, "pgwire connection failed authentication");
                    return conn.send(err.into_response()).await;
                }
            }
        }
        Authenticator::Password(adapter_client) => {
            let session = match authenticate_with_password(
                conn,
                &adapter_client,
                user,
                conn_uuid,
                helm_chart_version,
                authenticator_kind,
            )
            .await
            {
                Ok(session) => session,
                Err(PasswordRequestError::IoError(e)) => return Err(e),
                Err(PasswordRequestError::InvalidPasswordError(e)) => {
                    return conn.send(e).await;
                }
            };
            // No frontegg check, so auth session lasts indefinitely.
            (session, pending().right_future())
        }
        Authenticator::Sasl(adapter_client) => {
            // Start the handshake
            conn.send(BackendMessage::AuthenticationSASL).await?;
            conn.flush().await?;
            // Get the initial response indicating chosen mechanism
            let (mechanism, initial_response) = match conn.recv().await? {
                Some(FrontendMessage::RawAuthentication(data)) => {
                    match decode_sasl_initial_response(Cursor::new(&data)).ok() {
                        Some(FrontendMessage::SASLInitialResponse {
                            gs2_header,
                            mechanism,
                            initial_response,
                        }) => {
                            // We do not support channel binding
                            if gs2_header.channel_binding_enabled() {
                                return conn
                                    .send(ErrorResponse::fatal(
                                        SqlState::PROTOCOL_VIOLATION,
                                        "channel binding not supported",
                                    ))
                                    .await;
                            }
                            (mechanism, initial_response)
                        }
                        _ => {
                            return conn
                                .send(ErrorResponse::fatal(
                                    SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                                    "expected SASLInitialResponse message",
                                ))
                                .await;
                        }
                    }
                }
                _ => {
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                            "expected SASLInitialResponse message",
                        ))
                        .await;
                }
            };

            // SCRAM-SHA-256 is the only mechanism we implement.
            if mechanism != "SCRAM-SHA-256" {
                return conn
                    .send(ErrorResponse::fatal(
                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                        "unsupported SASL mechanism",
                    ))
                    .await;
            }

            // Bound the client nonce to guard against oversized input.
            if initial_response.nonce.len() > 256 {
                return conn
                    .send(ErrorResponse::fatal(
                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                        "nonce too long",
                    ))
                    .await;
            }

            let (server_first_message_raw, mock_hash) = match adapter_client
                .generate_sasl_challenge(&user, &initial_response.nonce)
                .await
            {
                Ok(response) => {
                    let server_first_message_raw = format!(
                        "r={},s={},i={}",
                        response.nonce, response.salt, response.iteration_count
                    );

                    // NOTE(review): these client/server keys are fixed
                    // placeholders; presumably `verify_sasl_proof` derives the
                    // real keys from the stored credential — confirm against
                    // the adapter implementation.
                    let client_key = [0u8; 32];
                    let server_key = [1u8; 32];
                    let mock_hash = format!(
                        "SCRAM-SHA-256${}:{}${}:{}",
                        response.iteration_count,
                        response.salt,
                        BASE64_STANDARD.encode(client_key),
                        BASE64_STANDARD.encode(server_key)
                    );

                    conn.send(BackendMessage::AuthenticationSASLContinue(
                        SASLServerFirstMessage {
                            iteration_count: response.iteration_count,
                            nonce: response.nonce,
                            salt: response.salt,
                        },
                    ))
                    .await?;
                    conn.flush().await?;
                    (server_first_message_raw, mock_hash)
                }
                Err(e) => {
                    return conn.send(e.into_response(Severity::Fatal)).await;
                }
            };

            let authenticated = match conn.recv().await? {
                Some(FrontendMessage::RawAuthentication(data)) => {
                    match decode_sasl_response(Cursor::new(&data)).ok() {
                        Some(FrontendMessage::SASLResponse(response)) => {
                            // The SCRAM auth message is the comma-joined
                            // client-first-bare, server-first, and
                            // client-final-without-proof messages.
                            let auth_message = format!(
                                "{},{},{}",
                                initial_response.client_first_message_bare_raw,
                                server_first_message_raw,
                                response.client_final_message_bare_raw
                            );
                            // Bound the proof to guard against oversized input.
                            if response.proof.len() > 1024 {
                                return conn
                                    .send(ErrorResponse::fatal(
                                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                                        "proof too long",
                                    ))
                                    .await;
                            }
                            match adapter_client
                                .verify_sasl_proof(
                                    &user,
                                    &response.proof,
                                    &auth_message,
                                    &mock_hash,
                                )
                                .await
                            {
                                Ok((proof_response, authenticated)) => {
                                    conn.send(BackendMessage::AuthenticationSASLFinal(
                                        SASLServerFinalMessage {
                                            kind: SASLServerFinalMessageKinds::Verifier(
                                                proof_response.verifier,
                                            ),
                                            extensions: vec![],
                                        },
                                    ))
                                    .await?;
                                    conn.flush().await?;
                                    authenticated
                                }
                                Err(_) => {
                                    return conn
                                        .send(ErrorResponse::fatal(
                                            SqlState::INVALID_PASSWORD,
                                            "invalid password",
                                        ))
                                        .await;
                                }
                            }
                        }
                        _ => {
                            return conn
                                .send(ErrorResponse::fatal(
                                    SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                                    "expected SASLResponse message",
                                ))
                                .await;
                        }
                    }
                }
                _ => {
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                            "expected SASLResponse message",
                        ))
                        .await;
                }
            };

            let session = adapter_client.new_session(
                SessionConfig {
                    conn_id: conn.conn_id().clone(),
                    uuid: conn_uuid,
                    user,
                    client_ip: conn.peer_addr().clone(),
                    external_metadata_rx: None,
                    helm_chart_version,
                    authenticator_kind,
                },
                authenticated,
            );
            // No frontegg check, so auth session lasts indefinitely.
            let auth_session = pending().right_future();
            (session, auth_session)
        }

        Authenticator::None => {
            let session = adapter_client.new_session(
                SessionConfig {
                    conn_id: conn.conn_id().clone(),
                    uuid: conn_uuid,
                    user,
                    client_ip: conn.peer_addr().clone(),
                    external_metadata_rx: None,
                    helm_chart_version,
                    authenticator_kind,
                },
                Authenticated,
            );
            // No frontegg check, so auth session lasts indefinitely.
            let auth_session = pending().right_future();
            (session, auth_session)
        }
    };

    // Apply the remaining startup parameters (including the parsed `options`
    // value) to the session variables. Bad settings become notices rather than
    // connection errors.
    let system_vars = adapter_client.get_system_vars().await;
    for (name, value) in params {
        let settings = match name.as_str() {
            "options" => match &options {
                Ok(opts) => opts,
                Err(()) => {
                    session.add_notice(AdapterNotice::BadStartupSetting {
                        name,
                        reason: "could not parse".into(),
                    });
                    continue;
                }
            },
            _ => &vec![(name, value)],
        };
        for (key, val) in settings {
            const LOCAL: bool = false;
            // TODO: Issuing an error here is better than what we did before
            // (silently ignore errors on set), but erroring the connection
            // might be the better behavior. We maybe need to support more
            // options sent by psql and drivers before we can safely do this.
            if let Err(err) = session
                .vars_mut()
                .set(&system_vars, key, VarInput::Flat(val), LOCAL)
            {
                session.add_notice(AdapterNotice::BadStartupSetting {
                    name: key.clone(),
                    reason: err.to_string(),
                });
            }
        }
    }
    session
        .vars_mut()
        .end_transaction(EndTransactionAction::Commit);

    // Count this connection against the global limit; the guard releases the
    // slot when dropped.
    let _guard = match active_connection_counter.allocate_connection(session.user()) {
        Ok(drop_connection) => drop_connection,
        Err(e) => {
            let e: AdapterError = e.into();
            return conn.send(e.into_response(Severity::Fatal)).await;
        }
    };

    // Register session with adapter.
    let mut adapter_client = match adapter_client.startup(session).await {
        Ok(adapter_client) => adapter_client,
        Err(e) => return conn.send(e.into_response(Severity::Fatal)).await,
    };

    // Complete the startup exchange: AuthenticationOk, parameter statuses,
    // backend key data, any pending notices, then ReadyForQuery.
    let mut buf = vec![BackendMessage::AuthenticationOk];
    for var in adapter_client.session().vars().notify_set() {
        buf.push(BackendMessage::ParameterStatus(var.name(), var.value()));
    }
    buf.push(BackendMessage::BackendKeyData {
        conn_id: adapter_client.session().conn_id().unhandled(),
        secret_key: adapter_client.session().secret_key(),
    });
    buf.extend(
        adapter_client
            .session()
            .drain_notices()
            .into_iter()
            .map(|notice| BackendMessage::ErrorResponse(notice.into_response())),
    );
    buf.push(BackendMessage::ReadyForQuery(
        adapter_client.session().transaction().into(),
    ));
    conn.send_all(buf).await?;
    conn.flush().await?;

    let machine = StateMachine {
        conn,
        adapter_client,
        txn_needs_commit: false,
        tokio_metrics_intervals,
    };

    // Drive the protocol state machine until the connection ends, racing it
    // against authentication expiry.
    select! {
        r = machine.run() => {
            // Errors produced internally (like MAX_REQUEST_SIZE being exceeded) should send an
            // error to the client informing them why the connection was closed. We still want to
            // return the original error up the stack, though, so we skip error checking during conn
            // operations.
            if let Err(err) = &r {
                let _ = conn
                    .send(ErrorResponse::fatal(
                        SqlState::CONNECTION_FAILURE,
                        err.to_string(),
                    ))
                    .await;
                let _ = conn.flush().await;
            }
            r
        },
        _ = expired => {
            conn
                .send(ErrorResponse::fatal(SqlState::INVALID_AUTHORIZATION_SPECIFICATION, "authentication expired"))
                .await?;
            conn.flush().await
        }
    }
}
615
/// Gets `oidc_auth_enabled` from options if it exists.
///
/// Returns the parsed `oidc_auth_enabled` value (defaulting to `false` when
/// the option is absent or unparseable) together with the options list with
/// every `oidc_auth_enabled` entry removed. A parse error in the incoming
/// options is passed through unchanged.
fn extract_oidc_auth_enabled_from_options(
    options: Result<Vec<(String, String)>, ()>,
) -> (bool, Result<Vec<(String, String)>, ()>) {
    let Ok(opts) = options else {
        return (false, Err(()));
    };

    let mut oidc_auth_enabled = false;
    let remaining: Vec<_> = opts
        .into_iter()
        .filter(|(key, value)| {
            if key.as_str() == "oidc_auth_enabled" {
                // Last occurrence wins; values that fail to parse as a bool
                // count as false.
                oidc_auth_enabled = value.parse::<bool>().unwrap_or(false);
                false
            } else {
                true
            }
        })
        .collect();

    (oidc_auth_enabled, Ok(remaining))
}
640
641/// Returns (name, value) session settings pairs from an options value.
642///
643/// From Postgres, see pg_split_opts in postinit.c and process_postgres_switches
644/// in postgres.c.
645fn parse_options(value: &str) -> Result<Vec<(String, String)>, ()> {
646    let opts = split_options(value);
647    let mut pairs = Vec::with_capacity(opts.len());
648    let mut seen_prefix = false;
649    for opt in opts {
650        if !seen_prefix {
651            if opt == "-c" {
652                seen_prefix = true;
653            } else {
654                let (key, val) = parse_option(&opt)?;
655                pairs.push((key.to_owned(), val.to_owned()));
656            }
657        } else {
658            let (key, val) = opt.split_once('=').ok_or(())?;
659            pairs.push((key.to_owned(), val.to_owned()));
660            seen_prefix = false;
661        }
662    }
663    Ok(pairs)
664}
665
/// Returns the parsed key and value from an option of the form `--key=value`
/// or `-ckey=value`. Returns an error when there is no `=` or when the key
/// carries any other prefix.
fn parse_option(option: &str) -> Result<(&str, &str), ()> {
    let (key, value) = option.split_once('=').ok_or(())?;
    // Accept exactly the two recognized key prefixes, stripping whichever
    // one matches first.
    ["-c", "--"]
        .iter()
        .find_map(|prefix| key.strip_prefix(prefix))
        .map(|stripped| (stripped, value))
        .ok_or(())
}
678
/// Splits value by any number of spaces except those preceded by `\`.
fn split_options(value: &str) -> Vec<String> {
    let mut words = Vec::new();
    // The escaping means we cannot simply subslice `value`, so each word is
    // accumulated in an owned buffer. This isn't called often enough to be
    // worth anything smarter.
    let mut word = String::new();
    // True when the previous character was an unescaped backslash.
    let mut escaped = false;
    for ch in value.chars() {
        if escaped {
            // A backslash escapes the following character: an escaped space
            // joins the word, an escaped backslash yields one backslash, and
            // anything else keeps only the character (the backslash is
            // dropped).
            word.push(ch);
            escaped = false;
        } else {
            match ch {
                '\\' => escaped = true,
                ' ' => {
                    // Runs of spaces produce no empty words: only flush a
                    // non-empty buffer.
                    if !word.is_empty() {
                        words.push(std::mem::take(&mut word));
                    }
                }
                other => word.push(other),
            }
        }
    }
    // A trailing `\` is ignored; flush any final word.
    if !word.is_empty() {
        words.push(word);
    }
    words
}
721
/// An error that can occur while soliciting a password from the client.
enum PasswordRequestError {
    /// The client's reply was missing or malformed; carries the fatal error
    /// response to deliver to the client.
    InvalidPasswordError(ErrorResponse),
    /// The underlying transport failed.
    IoError(io::Error),
}
726
727impl From<io::Error> for PasswordRequestError {
728    fn from(e: io::Error) -> Self {
729        PasswordRequestError::IoError(e)
730    }
731}
732
733/// Requests a cleartext password from a connection and returns it if it is valid.
734/// Sends an error response in the connection if the password
735/// is not valid.
736async fn request_cleartext_password<A>(
737    conn: &mut FramedConn<A>,
738) -> Result<String, PasswordRequestError>
739where
740    A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
741{
742    conn.send(BackendMessage::AuthenticationCleartextPassword)
743        .await?;
744    conn.flush().await?;
745
746    if let Some(message) = conn.recv().await? {
747        if let FrontendMessage::RawAuthentication(data) = message {
748            if let Some(FrontendMessage::Password { password }) =
749                decode_password(Cursor::new(&data)).ok()
750            {
751                return Ok(password);
752            }
753        }
754    }
755
756    Err(PasswordRequestError::InvalidPasswordError(
757        ErrorResponse::fatal(
758            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
759            "expected Password message",
760        ),
761    ))
762}
763
764/// Helper for password-based authentication using AdapterClient
765/// and returns an authenticated session.
766async fn authenticate_with_password<A>(
767    conn: &mut FramedConn<A>,
768    adapter_client: &mz_adapter::Client,
769    user: String,
770    conn_uuid: Uuid,
771    helm_chart_version: Option<String>,
772    authenticator_kind: mz_auth::AuthenticatorKind,
773) -> Result<Session, PasswordRequestError>
774where
775    A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
776{
777    let password = match request_cleartext_password(conn).await {
778        Ok(password) => Password(password),
779        Err(e) => return Err(e),
780    };
781
782    let authenticated = match adapter_client.authenticate(&user, &password).await {
783        Ok(authenticated) => authenticated,
784        Err(err) => {
785            warn!(?err, "pgwire connection failed authentication");
786            return Err(PasswordRequestError::InvalidPasswordError(
787                ErrorResponse::fatal(SqlState::INVALID_PASSWORD, "invalid password"),
788            ));
789        }
790    };
791
792    let session = adapter_client.new_session(
793        SessionConfig {
794            conn_id: conn.conn_id().clone(),
795            uuid: conn_uuid,
796            user,
797            client_ip: conn.peer_addr().clone(),
798            external_metadata_rx: None,
799            helm_chart_version,
800            authenticator_kind,
801        },
802        authenticated,
803    );
804
805    Ok(session)
806}
807
/// The top-level protocol state driven by the [`StateMachine`] loop.
#[derive(Debug)]
enum State {
    /// Normal operation: ready to receive and dispatch the next frontend
    /// message.
    Ready,
    /// An error occurred; discard incoming messages until the client sends
    /// a Sync (resynchronizing the extended query protocol) or disconnects.
    Drain,
    /// The connection is finished; the state machine loop exits.
    Done,
}
814
/// Drives the pgwire protocol for a single client connection.
struct StateMachine<'a, A, I>
where
    I: Iterator<Item = TaskMetrics> + Send + 'a,
{
    /// The framed pgwire connection to the client.
    conn: &'a mut FramedConn<A>,
    /// Session-scoped handle to the adapter for this connection.
    adapter_client: mz_adapter::SessionClient,
    /// When true, the current implicit transaction is committed the next
    /// time `ensure_transaction` runs. Set after an extended-protocol
    /// Execute so the portal survives until the next command.
    txn_needs_commit: bool,
    /// Infinite source of tokio task-metrics intervals, sampled around
    /// `recv()` calls to measure scheduling delay.
    tokio_metrics_intervals: I,
}
824
/// Why a stream of rows being sent to the client terminated.
enum SendRowsEndedReason {
    /// All rows were sent successfully.
    Success {
        // Size of the returned result (presumably bytes — confirm at the
        // call sites that populate this).
        result_size: u64,
        // Number of rows returned to the client.
        rows_returned: u64,
    },
    /// Sending failed; carries the error message.
    Errored {
        error: String,
    },
    /// The operation was canceled before completion.
    Canceled,
}
835
/// Error message reported for any command (other than COMMIT/ROLLBACK)
/// issued inside an aborted transaction; mirrors PostgreSQL's wording.
const ABORTED_TXN_MSG: &str =
    "current transaction is aborted, commands ignored until end of transaction block";
838
839impl<'a, A, I> StateMachine<'a, A, I>
840where
841    A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin + 'a,
842    I: Iterator<Item = TaskMetrics> + Send + 'a,
843{
844    // Manually desugar this (don't use `async fn run`) here because a much better
845    // error message is produced if there are problems with Send or other traits
846    // somewhere within the Future.
847    #[allow(clippy::manual_async_fn)]
848    #[mz_ore::instrument(level = "debug")]
849    fn run(mut self) -> impl Future<Output = Result<(), io::Error>> + Send + 'a {
850        async move {
851            let mut state = State::Ready;
852            loop {
853                self.send_pending_notices().await?;
854                state = match state {
855                    State::Ready => self.advance_ready().await?,
856                    State::Drain => self.advance_drain().await?,
857                    State::Done => return Ok(()),
858                };
859                self.adapter_client
860                    .add_idle_in_transaction_session_timeout();
861            }
862        }
863    }
864
    /// Receives and dispatches a single frontend message while in the
    /// `Ready` state, returning the next protocol state.
    ///
    /// A pending session timeout takes priority over any incoming message.
    /// Message-processing latency and the tokio scheduling delay observed
    /// around the `recv()` are recorded as metrics, labeled by message type.
    #[instrument(level = "debug")]
    async fn advance_ready(&mut self) -> Result<State, io::Error> {
        // Start a new metrics interval before the `recv()` call.
        self.tokio_metrics_intervals
            .next()
            .expect("infinite iterator");

        // Handle timeouts first so we don't execute any statements when there's a pending timeout.
        let message = select! {
            biased;

            // `recv_timeout()` is cancel-safe as per it's docs.
            Some(timeout) = self.adapter_client.recv_timeout() => {
                let err: AdapterError = timeout.into();
                let conn_id = self.adapter_client.session().conn_id();
                tracing::warn!("session timed out, conn_id {}", conn_id);

                // Process the error, doing any state cleanup.
                let error_response = err.into_response(Severity::Fatal);
                let error_state = self.send_error_and_get_state(error_response).await;

                // Terminate __after__ we do any cleanup.
                self.adapter_client.terminate().await;

                // We must wait for the client to send a request before we can send the error response.
                // Due to the PG wire protocol, we can't send an ErrorResponse unless it is in response
                // to a client message.
                let _ = self.conn.recv().await?;
                return error_state;
            },
            // `recv()` is cancel-safe as per it's docs.
            message = self.conn.recv() => message?,
        };

        // Take the metrics since just before the `recv`.
        let interval = self
            .tokio_metrics_intervals
            .next()
            .expect("infinite iterator");
        let recv_scheduling_delay_ms = interval.total_scheduled_duration.as_secs_f64() * 1000.0;

        // TODO(ggevay): Consider subtracting the scheduling delay from `received`. It's not obvious
        // whether we should do this, because the result wouldn't exactly correspond to either first
        // byte received or last byte received (for msgs that arrive in more than one network packet).
        let received = SYSTEM_TIME();

        self.adapter_client
            .remove_idle_in_transaction_session_timeout();

        // NOTE(guswynn): we could consider adding spans to all message types. Currently
        // only a few message types seem useful.
        let message_name = message.as_ref().map(|m| m.name()).unwrap_or_default();

        // Time the processing of this message (only when one actually arrived).
        let start = message.as_ref().map(|_| Instant::now());
        let next_state = match message {
            Some(FrontendMessage::Query { sql }) => {
                let query_root_span =
                    tracing::info_span!(parent: None, "advance_ready", otel.name = message_name);
                query_root_span.follows_from(tracing::Span::current());
                self.query(sql, received)
                    .instrument(query_root_span)
                    .await?
            }
            Some(FrontendMessage::Parse {
                name,
                sql,
                param_types,
            }) => self.parse(name, sql, param_types).await?,
            Some(FrontendMessage::Bind {
                portal_name,
                statement_name,
                param_formats,
                raw_params,
                result_formats,
            }) => {
                self.bind(
                    portal_name,
                    statement_name,
                    param_formats,
                    raw_params,
                    result_formats,
                )
                .await?
            }
            Some(FrontendMessage::Execute {
                portal_name,
                max_rows,
            }) => {
                let max_rows = match usize::try_from(max_rows) {
                    Ok(0) | Err(_) => ExecuteCount::All, // If `max_rows < 0`, no limit.
                    Ok(n) => ExecuteCount::Count(n),
                };
                let execute_root_span =
                    tracing::info_span!(parent: None, "advance_ready", otel.name = message_name);
                execute_root_span.follows_from(tracing::Span::current());
                let state = self
                    .execute(
                        portal_name,
                        max_rows,
                        portal_exec_message,
                        None,
                        ExecuteTimeout::None,
                        None,
                        Some(received),
                    )
                    .instrument(execute_root_span)
                    .await?;
                // In PostgreSQL, when using the extended query protocol, some statements may
                // trigger an eager commit of the current implicit transaction,
                // see: <https://git.postgresql.org/gitweb/?p=postgresql.git&a=commitdiff&h=f92944137>.
                //
                // In Materialize, however, we eagerly commit every statement outside of an explicit
                // transaction when using the extended query protocol. This allows us to eliminate
                // the possibility of a multiple statement implicit transaction, which in turn
                // allows us to apply single-statement optimizations to queries issued in implicit
                // transactions in the extended query protocol.
                //
                // We don't immediately commit here to allow users to page through the portal if
                // necessary. Committing the transaction would destroy the portal before the next
                // Execute command has a chance to resume it. So we instead mark the transaction
                // for commit the next time that `ensure_transaction` is called.
                if self.adapter_client.session().transaction().is_implicit() {
                    self.txn_needs_commit = true;
                }
                state
            }
            Some(FrontendMessage::DescribeStatement { name }) => {
                self.describe_statement(&name).await?
            }
            Some(FrontendMessage::DescribePortal { name }) => self.describe_portal(&name).await?,
            Some(FrontendMessage::CloseStatement { name }) => self.close_statement(name).await?,
            Some(FrontendMessage::ClosePortal { name }) => self.close_portal(name).await?,
            Some(FrontendMessage::Flush) => self.flush().await?,
            Some(FrontendMessage::Sync) => self.sync().await?,
            Some(FrontendMessage::Terminate) => State::Done,

            // Messages that are invalid in the Ready state put the connection
            // into Drain until the client resynchronizes with Sync.
            Some(FrontendMessage::CopyData(_))
            | Some(FrontendMessage::CopyDone)
            | Some(FrontendMessage::CopyFail(_))
            | Some(FrontendMessage::Password { .. })
            | Some(FrontendMessage::RawAuthentication(_))
            | Some(FrontendMessage::SASLInitialResponse { .. })
            | Some(FrontendMessage::SASLResponse(_)) => State::Drain,
            None => State::Done,
        };

        // Record processing latency and the recv scheduling delay.
        if let Some(start) = start {
            self.adapter_client
                .inner()
                .metrics()
                .pgwire_message_processing_seconds
                .with_label_values(&[message_name])
                .observe(start.elapsed().as_secs_f64());
        }
        self.adapter_client
            .inner()
            .metrics()
            .pgwire_recv_scheduling_delay_ms
            .with_label_values(&[message_name])
            .observe(recv_scheduling_delay_ms);

        Ok(next_state)
    }
1028
1029    async fn advance_drain(&mut self) -> Result<State, io::Error> {
1030        let message = self.conn.recv().await?;
1031        if message.is_some() {
1032            self.adapter_client
1033                .remove_idle_in_transaction_session_timeout();
1034        }
1035        match message {
1036            Some(FrontendMessage::Sync) => self.sync().await,
1037            None => Ok(State::Done),
1038            _ => Ok(State::Drain),
1039        }
1040    }
1041
    /// Plans and executes a single parsed statement of a Simple Query using
    /// the unnamed portal, returning the next protocol state.
    ///
    /// Note that `lifecycle_timestamps` belongs to the whole "Simple Query", because the whole
    /// Simple Query is received and parsed together. This means that if there are multiple
    /// statements in a Simple Query, then all of them have the same `lifecycle_timestamps`.
    #[instrument(level = "debug")]
    async fn one_query(
        &mut self,
        stmt: Statement<Raw>,
        sql: String,
        lifecycle_timestamps: LifecycleTimestamps,
    ) -> Result<State, io::Error> {
        // Bind the portal. Note that this does not set the empty string prepared
        // statement.
        const EMPTY_PORTAL: &str = "";
        if let Err(e) = self
            .adapter_client
            .declare(EMPTY_PORTAL.to_string(), stmt, sql)
            .await
        {
            return self
                .send_error_and_get_state(e.into_response(Severity::Error))
                .await;
        }
        let portal = self
            .adapter_client
            .session()
            .get_portal_unverified_mut(EMPTY_PORTAL)
            .expect("unnamed portal should be present");

        *portal.lifecycle_timestamps = Some(lifecycle_timestamps);

        // Simple queries cannot supply parameter values, so any statement
        // that requires parameters is an error.
        let stmt_desc = portal.desc.clone();
        if !stmt_desc.param_types.is_empty() {
            return self
                .send_error_and_get_state(ErrorResponse::error(
                    SqlState::UNDEFINED_PARAMETER,
                    "there is no parameter $1",
                ))
                .await;
        }

        // Maybe send row description. COPY statements are skipped here;
        // presumably they carry their own wire-format exchange — confirm in
        // the COPY response path.
        if let Some(relation_desc) = &stmt_desc.relation_desc {
            if !stmt_desc.is_copy {
                let formats = vec![Format::Text; stmt_desc.arity()];
                self.send(BackendMessage::RowDescription(
                    message::encode_row_description(relation_desc, &formats),
                ))
                .await?;
            }
        }

        // Execute the portal; pending notices are flushed before either the
        // response rows or the error response is sent.
        let result = match self
            .adapter_client
            .execute(EMPTY_PORTAL.to_string(), self.conn.wait_closed(), None)
            .await
        {
            Ok((response, execute_started)) => {
                self.send_pending_notices().await?;
                self.send_execute_response(
                    response,
                    stmt_desc.relation_desc,
                    EMPTY_PORTAL.to_string(),
                    ExecuteCount::All,
                    portal_exec_message,
                    None,
                    ExecuteTimeout::None,
                    execute_started,
                )
                .await
            }
            Err(e) => {
                self.send_pending_notices().await?;
                self.send_error_and_get_state(e.into_response(Severity::Error))
                    .await
            }
        };

        // Destroy the portal, whether execution succeeded or failed.
        self.adapter_client.session().remove_portal(EMPTY_PORTAL);

        result
    }
1124
1125    async fn ensure_transaction(
1126        &mut self,
1127        num_stmts: usize,
1128        message_type: &str,
1129    ) -> Result<(), io::Error> {
1130        let start = Instant::now();
1131        if self.txn_needs_commit {
1132            self.commit_transaction().await?;
1133        }
1134        // start_transaction can't error (but assert that just in case it changes in
1135        // the future.
1136        let res = self.adapter_client.start_transaction(Some(num_stmts));
1137        assert_ok!(res);
1138        self.adapter_client
1139            .inner()
1140            .metrics()
1141            .pgwire_ensure_transaction_seconds
1142            .with_label_values(&[message_type])
1143            .observe(start.elapsed().as_secs_f64());
1144        Ok(())
1145    }
1146
1147    fn parse_sql<'b>(&self, sql: &'b str) -> Result<Vec<StatementParseResult<'b>>, ErrorResponse> {
1148        let parse_start = Instant::now();
1149        let result = match self.adapter_client.parse(sql) {
1150            Ok(result) => result.map_err(|e| {
1151                // Convert our 0-based byte position to pgwire's 1-based character
1152                // position.
1153                let pos = sql[..e.error.pos].chars().count() + 1;
1154                ErrorResponse::error(SqlState::SYNTAX_ERROR, e.error.message).with_position(pos)
1155            }),
1156            Err(msg) => Err(ErrorResponse::error(SqlState::PROGRAM_LIMIT_EXCEEDED, msg)),
1157        };
1158        self.adapter_client
1159            .inner()
1160            .metrics()
1161            .parse_seconds
1162            .observe(parse_start.elapsed().as_secs_f64());
1163        result
1164    }
1165
1166    /// Executes a "Simple Query", see
1167    /// <https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-SIMPLE-QUERY>
1168    ///
1169    /// For implicit transaction handling, see "Multiple Statements in a Simple Query" in the above.
1170    #[instrument(level = "debug")]
1171    async fn query(&mut self, sql: String, received: EpochMillis) -> Result<State, io::Error> {
1172        // Parse first before doing any transaction checking.
1173        let stmts = match self.parse_sql(&sql) {
1174            Ok(stmts) => stmts,
1175            Err(err) => {
1176                self.send_error_and_get_state(err).await?;
1177                return self.ready().await;
1178            }
1179        };
1180
1181        let num_stmts = stmts.len();
1182
1183        // Compare with postgres' backend/tcop/postgres.c exec_simple_query.
1184        for StatementParseResult { ast: stmt, sql } in stmts {
1185            // In an aborted transaction, reject all commands except COMMIT/ROLLBACK.
1186            if self.is_aborted_txn() && !is_txn_exit_stmt(Some(&stmt)) {
1187                self.aborted_txn_error().await?;
1188                break;
1189            }
1190
1191            // Start an implicit transaction if we aren't in any transaction and there's
1192            // more than one statement. This mirrors the `use_implicit_block` variable in
1193            // postgres.
1194            //
1195            // This needs to be done in the loop instead of once at the top because
1196            // a COMMIT/ROLLBACK statement needs to start a new transaction on next
1197            // statement.
1198            self.ensure_transaction(num_stmts, "query").await?;
1199
1200            match self
1201                .one_query(stmt, sql.to_string(), LifecycleTimestamps { received })
1202                .await?
1203            {
1204                State::Ready => (),
1205                State::Drain => break,
1206                State::Done => return Ok(State::Done),
1207            }
1208        }
1209
1210        // Implicit transactions are closed at the end of a Query message.
1211        {
1212            if self.adapter_client.session().transaction().is_implicit() {
1213                self.commit_transaction().await?;
1214            }
1215        }
1216
1217        if num_stmts == 0 {
1218            self.send(BackendMessage::EmptyQueryResponse).await?;
1219        }
1220
1221        self.ready().await
1222    }
1223
    /// Handles a Parse message of the extended query protocol: resolves the
    /// client's declared parameter OIDs, parses the SQL (which must contain
    /// at most one statement), and stores the result as a prepared statement
    /// under `name`.
    #[instrument(level = "debug")]
    async fn parse(
        &mut self,
        name: String,
        sql: String,
        param_oids: Vec<u32>,
    ) -> Result<State, io::Error> {
        // Start a transaction if we aren't in one.
        self.ensure_transaction(1, "parse").await?;

        // Resolve each declared parameter OID to a scalar type; `None`
        // entries are parameters whose type is left for the planner to infer.
        let mut param_types = vec![];
        for oid in param_oids {
            match mz_pgrepr::Type::from_oid(oid) {
                Ok(ty) => match SqlScalarType::try_from(&ty) {
                    Ok(ty) => param_types.push(Some(ty)),
                    Err(err) => {
                        return self
                            .send_error_and_get_state(ErrorResponse::error(
                                SqlState::INVALID_PARAMETER_VALUE,
                                err.to_string(),
                            ))
                            .await;
                    }
                },
                // Per the protocol, OID zero means the client leaves the
                // parameter type unspecified.
                Err(_) if oid == 0 => param_types.push(None),
                Err(e) => {
                    return self
                        .send_error_and_get_state(ErrorResponse::error(
                            SqlState::PROTOCOL_VIOLATION,
                            e.to_string(),
                        ))
                        .await;
                }
            }
        }

        let stmts = match self.parse_sql(&sql) {
            Ok(stmts) => stmts,
            Err(err) => {
                return self.send_error_and_get_state(err).await;
            }
        };
        // A prepared statement holds at most one command.
        if stmts.len() > 1 {
            return self
                .send_error_and_get_state(ErrorResponse::error(
                    SqlState::INTERNAL_ERROR,
                    "cannot insert multiple commands into a prepared statement",
                ))
                .await;
        }
        // An empty SQL string is valid and prepares an "empty" statement.
        let (maybe_stmt, sql) = match stmts.into_iter().next() {
            None => (None, ""),
            Some(StatementParseResult { ast, sql }) => (Some(ast), sql),
        };
        // In an aborted transaction, reject everything except COMMIT/ROLLBACK.
        if self.is_aborted_txn() && !is_txn_exit_stmt(maybe_stmt.as_ref()) {
            return self.aborted_txn_error().await;
        }
        match self
            .adapter_client
            .prepare(name, maybe_stmt, sql.to_string(), param_types)
            .await
        {
            Ok(()) => {
                self.send(BackendMessage::ParseComplete).await?;
                Ok(State::Ready)
            }
            Err(e) => {
                self.send_error_and_get_state(e.into_response(Severity::Error))
                    .await
            }
        }
    }
1296
    /// Commits and clears the current transaction.
    ///
    /// Delegates to [`Self::end_transaction`] with a `Commit` action.
    #[instrument(level = "debug")]
    async fn commit_transaction(&mut self) -> Result<(), io::Error> {
        self.end_transaction(EndTransactionAction::Commit).await
    }
1302
    /// Rolls back and clears the current transaction.
    ///
    /// Delegates to [`Self::end_transaction`] with a `Rollback` action.
    #[instrument(level = "debug")]
    async fn rollback_transaction(&mut self) -> Result<(), io::Error> {
        self.end_transaction(EndTransactionAction::Rollback).await
    }
1308
1309    /// End a transaction and report to the user if an error occurred.
1310    #[instrument(level = "debug")]
1311    async fn end_transaction(&mut self, action: EndTransactionAction) -> Result<(), io::Error> {
1312        self.txn_needs_commit = false;
1313        let resp = self.adapter_client.end_transaction(action).await;
1314        if let Err(err) = resp {
1315            self.send(BackendMessage::ErrorResponse(
1316                err.into_response(Severity::Error),
1317            ))
1318            .await?;
1319        }
1320        Ok(())
1321    }
1322
    /// Handles a Bind message of the extended query protocol: decodes the
    /// supplied parameter values against the prepared statement's declared
    /// types, validates the requested result formats, and binds the named
    /// portal.
    #[instrument(level = "debug")]
    async fn bind(
        &mut self,
        portal_name: String,
        statement_name: String,
        param_formats: Vec<Format>,
        raw_params: Vec<Option<Vec<u8>>>,
        result_formats: Vec<Format>,
    ) -> Result<State, io::Error> {
        // Start a transaction if we aren't in one.
        self.ensure_transaction(1, "bind").await?;

        let aborted_txn = self.is_aborted_txn();
        let stmt = match self
            .adapter_client
            .get_prepared_statement(&statement_name)
            .await
        {
            Ok(stmt) => stmt,
            Err(err) => {
                return self
                    .send_error_and_get_state(err.into_response(Severity::Error))
                    .await;
            }
        };

        // The client must supply exactly one value per declared parameter.
        let param_types = &stmt.desc().param_types;
        if param_types.len() != raw_params.len() {
            let message = format!(
                "bind message supplies {actual} parameters, \
                 but prepared statement \"{name}\" requires {expected}",
                name = statement_name,
                actual = raw_params.len(),
                expected = param_types.len()
            );
            return self
                .send_error_and_get_state(ErrorResponse::error(
                    SqlState::PROTOCOL_VIOLATION,
                    message,
                ))
                .await;
        }
        // Expand the format list to one entry per parameter (presumably the
        // protocol's 0-or-1-or-n format-code rule — confirm in `pad_formats`).
        let param_formats = match pad_formats(param_formats, raw_params.len()) {
            Ok(param_formats) => param_formats,
            Err(msg) => {
                return self
                    .send_error_and_get_state(ErrorResponse::error(
                        SqlState::PROTOCOL_VIOLATION,
                        msg,
                    ))
                    .await;
            }
        };
        // In an aborted transaction, reject everything except COMMIT/ROLLBACK.
        if aborted_txn && !is_txn_exit_stmt(stmt.stmt()) {
            return self.aborted_txn_error().await;
        }
        // Decode each raw parameter value into a datum of its declared type.
        let buf = RowArena::new();
        let mut params = vec![];
        for ((raw_param, mz_typ), format) in raw_params
            .into_iter()
            .zip_eq(param_types)
            .zip_eq(param_formats)
        {
            let pg_typ = mz_pgrepr::Type::from(mz_typ);
            let datum = match raw_param {
                // An absent value is a SQL NULL.
                None => Datum::Null,
                Some(bytes) => match mz_pgrepr::Value::decode(format, &pg_typ, &bytes) {
                    Ok(param) => match param.into_datum_decode_error(&buf, &pg_typ, "parameter") {
                        Ok(datum) => datum,
                        Err(msg) => {
                            return self
                                .send_error_and_get_state(ErrorResponse::error(
                                    SqlState::INVALID_PARAMETER_VALUE,
                                    msg,
                                ))
                                .await;
                        }
                    },
                    Err(err) => {
                        let msg = format!("unable to decode parameter: {}", err);
                        return self
                            .send_error_and_get_state(ErrorResponse::error(
                                SqlState::INVALID_PARAMETER_VALUE,
                                msg,
                            ))
                            .await;
                    }
                },
            };
            params.push((datum, mz_typ.clone()))
        }

        // Expand the result-format list to one entry per output column
        // (zero columns for statements that return no rows).
        let result_formats = match pad_formats(
            result_formats,
            stmt.desc()
                .relation_desc
                .clone()
                .map(|desc| desc.typ().column_types.len())
                .unwrap_or(0),
        ) {
            Ok(result_formats) => result_formats,
            Err(msg) => {
                return self
                    .send_error_and_get_state(ErrorResponse::error(
                        SqlState::PROTOCOL_VIOLATION,
                        msg,
                    ))
                    .await;
            }
        };

        // Binary encodings are disabled for list, map, and aclitem types, but this doesn't
        // apply to COPY TO statements.
        if !stmt.stmt().map_or(false, |stmt| match stmt {
            Statement::Copy(CopyStatement {
                direction: CopyDirection::To,
                ..
            }) => true,
            Statement::Copy(CopyStatement {
                direction: CopyDirection::From,
                // To be conservative, we are restricting COPY FROM to only allow list/map/aclitem types if it is not
                // copying from STDIN. It is likely that this works in theory, but is risky and likely to OOM anyways
                // as all the data will be held in a buffer in memory before being processed.
                target: CopyTarget::Expr(_),
                ..
            }) => true,
            _ => false,
        }) {
            if let Some(desc) = stmt.desc().relation_desc.clone() {
                for (format, ty) in result_formats.iter().zip_eq(desc.iter_types()) {
                    match (format, &ty.scalar_type) {
                        (Format::Binary, mz_repr::SqlScalarType::List { .. }) => {
                            return self
                                .send_error_and_get_state(ErrorResponse::error(
                                    SqlState::PROTOCOL_VIOLATION,
                                    "binary encoding of list types is not implemented",
                                ))
                                .await;
                        }
                        (Format::Binary, mz_repr::SqlScalarType::Map { .. }) => {
                            return self
                                .send_error_and_get_state(ErrorResponse::error(
                                    SqlState::PROTOCOL_VIOLATION,
                                    "binary encoding of map types is not implemented",
                                ))
                                .await;
                        }
                        (Format::Binary, mz_repr::SqlScalarType::AclItem) => {
                            return self
                                .send_error_and_get_state(ErrorResponse::error(
                                    SqlState::PROTOCOL_VIOLATION,
                                    "binary encoding of aclitem types does not exist",
                                ))
                                .await;
                        }
                        _ => (),
                    }
                }
            }
        }

        // Everything validated: record the portal in the session.
        let desc = stmt.desc().clone();
        let logging = Arc::clone(stmt.logging());
        let stmt_ast = stmt.stmt().cloned();
        let state_revision = stmt.state_revision;
        if let Err(err) = self.adapter_client.session().set_portal(
            portal_name,
            desc,
            stmt_ast,
            logging,
            params,
            result_formats,
            state_revision,
        ) {
            return self
                .send_error_and_get_state(err.into_response(Severity::Error))
                .await;
        }

        self.send(BackendMessage::BindComplete).await?;
        Ok(State::Ready)
    }
1505
    /// Executes the named portal, driving it through the [`PortalState`]
    /// state machine and sending the resulting rows or command tag to the
    /// client.
    ///
    /// `outer_ctx_extra` is Some when we are executing as part of an outer statement, e.g., a FETCH
    /// triggering the execution of the underlying query. When present, it is
    /// retired here (or handed off to the adapter) so statement logging
    /// records how the outer statement's execution ended.
    ///
    /// `received` is when the triggering frontend message was received; it is
    /// stored in the portal's lifecycle timestamps.
    ///
    /// Returns a boxed future because this function is indirectly recursive:
    /// executing a FETCH calls `fetch`, which calls back into `execute` for
    /// the underlying cursor's portal.
    fn execute(
        &mut self,
        portal_name: String,
        max_rows: ExecuteCount,
        get_response: GetResponse,
        fetch_portal_name: Option<String>,
        timeout: ExecuteTimeout,
        outer_ctx_extra: Option<ExecuteContextGuard>,
        received: Option<EpochMillis>,
    ) -> BoxFuture<'_, Result<State, io::Error>> {
        async move {
            let aborted_txn = self.is_aborted_txn();

            // Check if the portal has been started and can be continued.
            let portal = match self
                .adapter_client
                .session()
                .get_portal_unverified_mut(&portal_name)
            {
                Some(portal) => portal,
                None => {
                    let msg = format!("portal {} does not exist", portal_name.quoted());
                    // Retire the outer statement's logging context (if any)
                    // before reporting the error to the client.
                    if let Some(outer_ctx_extra) = outer_ctx_extra {
                        self.adapter_client.retire_execute(
                            outer_ctx_extra,
                            StatementEndedExecutionReason::Errored { error: msg.clone() },
                        );
                    }
                    return self
                        .send_error_and_get_state(ErrorResponse::error(
                            SqlState::INVALID_CURSOR_NAME,
                            msg,
                        ))
                        .await;
                }
            };

            // Record when the triggering message arrived, for statement
            // lifecycle logging.
            *portal.lifecycle_timestamps = received.map(LifecycleTimestamps::new);

            // In an aborted transaction, reject all commands except COMMIT/ROLLBACK.
            let txn_exit_stmt = is_txn_exit_stmt(portal.stmt.as_deref());
            if aborted_txn && !txn_exit_stmt {
                if let Some(outer_ctx_extra) = outer_ctx_extra {
                    self.adapter_client.retire_execute(
                        outer_ctx_extra,
                        StatementEndedExecutionReason::Errored {
                            error: ABORTED_TXN_MSG.to_string(),
                        },
                    );
                }
                return self.aborted_txn_error().await;
            }

            let row_desc = portal.desc.relation_desc.clone();
            match portal.state {
                // First execution of this portal: hand it off to the adapter
                // and translate the response into wire messages.
                PortalState::NotStarted => {
                    // Start a transaction if we aren't in one.
                    self.ensure_transaction(1, "execute").await?;
                    match self
                        .adapter_client
                        .execute(
                            portal_name.clone(),
                            self.conn.wait_closed(),
                            outer_ctx_extra,
                        )
                        .await
                    {
                        Ok((response, execute_started)) => {
                            self.send_pending_notices().await?;
                            self.send_execute_response(
                                response,
                                row_desc,
                                portal_name,
                                max_rows,
                                get_response,
                                fetch_portal_name,
                                timeout,
                                execute_started,
                            )
                            .await
                        }
                        Err(e) => {
                            self.send_pending_notices().await?;
                            self.send_error_and_get_state(e.into_response(Severity::Error))
                                .await
                        }
                    }
                }
                // The portal was started earlier and still has rows to send;
                // resume streaming them.
                PortalState::InProgress(rows) => {
                    let rows = rows.take().expect("InProgress rows must be populated");
                    let (result, statement_ended_execution_reason) = match self
                        .send_rows(
                            row_desc.expect("portal missing row desc on resumption"),
                            portal_name,
                            rows,
                            max_rows,
                            get_response,
                            fetch_portal_name,
                            timeout,
                        )
                        .await
                    {
                        Err(e) => {
                            // This is an error communicating with the connection.
                            // We consider that to be a cancelation, rather than a query error.
                            (Err(e), StatementEndedExecutionReason::Canceled)
                        }
                        Ok((ok, SendRowsEndedReason::Canceled)) => {
                            (Ok(ok), StatementEndedExecutionReason::Canceled)
                        }
                        // NOTE: For now the values for `result_size` and
                        // `rows_returned` in fetches are a bit confusing.
                        // We record `Some(n)` for the first fetch, where `n` is
                        // the number of bytes/rows returned by the inner
                        // execute (regardless of how many rows the
                        // fetch fetched), and `None` for subsequent fetches.
                        //
                        // This arguably makes sense since the size/rows
                        // returned measures how much work the compute
                        // layer had to do to satisfy the query, but
                        // we should revisit it if/when we start
                        // logging the inner execute separately.
                        Ok((
                            ok,
                            SendRowsEndedReason::Success {
                                result_size: _,
                                rows_returned: _,
                            },
                        )) => (
                            Ok(ok),
                            StatementEndedExecutionReason::Success {
                                result_size: None,
                                rows_returned: None,
                                execution_strategy: None,
                            },
                        ),
                        Ok((ok, SendRowsEndedReason::Errored { error })) => {
                            (Ok(ok), StatementEndedExecutionReason::Errored { error })
                        }
                    };
                    if let Some(outer_ctx_extra) = outer_ctx_extra {
                        self.adapter_client
                            .retire_execute(outer_ctx_extra, statement_ended_execution_reason);
                    }
                    result
                }
                // FETCH is an awkward command for our current architecture. In Postgres it
                // will extract <count> rows from the target portal, cache them, and return
                // them to the user as requested. Its command tag is always FETCH <num rows
                // extracted>. In Materialize, since we have chosen to not fully support FETCH,
                // we must remember the number of rows that were returned. Use this tag to
                // remember that information and return it.
                PortalState::Completed(Some(tag)) => {
                    let tag = tag.to_string();
                    if let Some(outer_ctx_extra) = outer_ctx_extra {
                        self.adapter_client.retire_execute(
                            outer_ctx_extra,
                            StatementEndedExecutionReason::Success {
                                result_size: None,
                                rows_returned: None,
                                execution_strategy: None,
                            },
                        );
                    }
                    self.send(BackendMessage::CommandComplete { tag }).await?;
                    Ok(State::Ready)
                }
                // A completed portal with no saved tag cannot be run again.
                PortalState::Completed(None) => {
                    let error = format!(
                        "portal {} cannot be run",
                        Ident::new_unchecked(portal_name).to_ast_string_stable()
                    );
                    if let Some(outer_ctx_extra) = outer_ctx_extra {
                        self.adapter_client.retire_execute(
                            outer_ctx_extra,
                            StatementEndedExecutionReason::Errored {
                                error: error.clone(),
                            },
                        );
                    }
                    self.send_error_and_get_state(ErrorResponse::error(
                        SqlState::OBJECT_NOT_IN_PREREQUISITE_STATE,
                        error,
                    ))
                    .await
                }
            }
        }
        .instrument(debug_span!("execute"))
        .boxed()
    }
1699
1700    #[instrument(level = "debug")]
1701    async fn describe_statement(&mut self, name: &str) -> Result<State, io::Error> {
1702        // Start a transaction if we aren't in one.
1703        self.ensure_transaction(1, "describe_statement").await?;
1704
1705        let stmt = match self.adapter_client.get_prepared_statement(name).await {
1706            Ok(stmt) => stmt,
1707            Err(err) => {
1708                return self
1709                    .send_error_and_get_state(err.into_response(Severity::Error))
1710                    .await;
1711            }
1712        };
1713        // Cloning to avoid a mutable borrow issue because `send` also uses `adapter_client`
1714        let parameter_desc = BackendMessage::ParameterDescription(
1715            stmt.desc()
1716                .param_types
1717                .iter()
1718                .map(mz_pgrepr::Type::from)
1719                .collect(),
1720        );
1721        // Claim that all results will be output in text format, even
1722        // though the true result formats are not yet known. A bit
1723        // weird, but this is the behavior that PostgreSQL specifies.
1724        let formats = vec![Format::Text; stmt.desc().arity()];
1725        let row_desc = describe_rows(stmt.desc(), &formats);
1726        self.send_all([parameter_desc, row_desc]).await?;
1727        Ok(State::Ready)
1728    }
1729
1730    #[instrument(level = "debug")]
1731    async fn describe_portal(&mut self, name: &str) -> Result<State, io::Error> {
1732        // Start a transaction if we aren't in one.
1733        self.ensure_transaction(1, "describe_portal").await?;
1734
1735        let session = self.adapter_client.session();
1736        let row_desc = session
1737            .get_portal_unverified(name)
1738            .map(|portal| describe_rows(&portal.desc, &portal.result_formats));
1739        match row_desc {
1740            Some(row_desc) => {
1741                self.send(row_desc).await?;
1742                Ok(State::Ready)
1743            }
1744            None => {
1745                self.send_error_and_get_state(ErrorResponse::error(
1746                    SqlState::INVALID_CURSOR_NAME,
1747                    format!("portal {} does not exist", name.quoted()),
1748                ))
1749                .await
1750            }
1751        }
1752    }
1753
1754    #[instrument(level = "debug")]
1755    async fn close_statement(&mut self, name: String) -> Result<State, io::Error> {
1756        self.adapter_client
1757            .session()
1758            .remove_prepared_statement(&name);
1759        self.send(BackendMessage::CloseComplete).await?;
1760        Ok(State::Ready)
1761    }
1762
1763    #[instrument(level = "debug")]
1764    async fn close_portal(&mut self, name: String) -> Result<State, io::Error> {
1765        self.adapter_client.session().remove_portal(&name);
1766        self.send(BackendMessage::CloseComplete).await?;
1767        Ok(State::Ready)
1768    }
1769
1770    fn complete_portal(&mut self, name: &str) {
1771        let portal = self
1772            .adapter_client
1773            .session()
1774            .get_portal_unverified_mut(name)
1775            .expect("portal should exist");
1776        *portal.state = PortalState::Completed(None);
1777    }
1778
1779    async fn fetch(
1780        &mut self,
1781        name: String,
1782        count: Option<FetchDirection>,
1783        max_rows: ExecuteCount,
1784        fetch_portal_name: Option<String>,
1785        timeout: ExecuteTimeout,
1786        ctx_extra: ExecuteContextGuard,
1787    ) -> Result<State, io::Error> {
1788        // Unlike Execute, no count specified in FETCH returns 1 row, and 0 means 0
1789        // instead of All.
1790        let count = count.unwrap_or(FetchDirection::ForwardCount(1));
1791
1792        // Figure out how many rows we should send back by looking at the various
1793        // combinations of the execute and fetch.
1794        //
1795        // In Postgres, Fetch will cache <count> rows from the target portal and
1796        // return those as requested (if, say, an Execute message was sent with a
1797        // max_rows < the Fetch's count). We expect that case to be incredibly rare and
1798        // so have chosen to not support it until users request it. This eases
1799        // implementation difficulty since we don't have to be able to "send" rows to
1800        // a buffer.
1801        //
1802        // TODO(mjibson): Test this somehow? Need to divide up the pgtest files in
1803        // order to have some that are not Postgres compatible.
1804        let count = match (max_rows, count) {
1805            (ExecuteCount::Count(max_rows), FetchDirection::ForwardCount(count)) => {
1806                let count = usize::cast_from(count);
1807                if max_rows < count {
1808                    let msg = "Execute with max_rows < a FETCH's count is not supported";
1809                    self.adapter_client.retire_execute(
1810                        ctx_extra,
1811                        StatementEndedExecutionReason::Errored {
1812                            error: msg.to_string(),
1813                        },
1814                    );
1815                    return self
1816                        .send_error_and_get_state(ErrorResponse::error(
1817                            SqlState::FEATURE_NOT_SUPPORTED,
1818                            msg,
1819                        ))
1820                        .await;
1821                }
1822                ExecuteCount::Count(count)
1823            }
1824            (ExecuteCount::Count(_), FetchDirection::ForwardAll) => {
1825                let msg = "Execute with max_rows of a FETCH ALL is not supported";
1826                self.adapter_client.retire_execute(
1827                    ctx_extra,
1828                    StatementEndedExecutionReason::Errored {
1829                        error: msg.to_string(),
1830                    },
1831                );
1832                return self
1833                    .send_error_and_get_state(ErrorResponse::error(
1834                        SqlState::FEATURE_NOT_SUPPORTED,
1835                        msg,
1836                    ))
1837                    .await;
1838            }
1839            (ExecuteCount::All, FetchDirection::ForwardAll) => ExecuteCount::All,
1840            (ExecuteCount::All, FetchDirection::ForwardCount(count)) => {
1841                ExecuteCount::Count(usize::cast_from(count))
1842            }
1843        };
1844        let cursor_name = name.to_string();
1845        self.execute(
1846            cursor_name,
1847            count,
1848            fetch_message,
1849            fetch_portal_name,
1850            timeout,
1851            Some(ctx_extra),
1852            None,
1853        )
1854        .await
1855    }
1856
    /// Flushes any buffered output on the connection to the client.
    async fn flush(&mut self) -> Result<State, io::Error> {
        self.conn.flush().await?;
        Ok(State::Ready)
    }
1861
1862    /// Sends a backend message to the client, after applying a severity filter.
1863    ///
1864    /// The message is only sent if its severity is above the severity set
1865    /// in the session, with the default value being NOTICE.
1866    #[instrument(level = "debug")]
1867    async fn send<M>(&mut self, message: M) -> Result<(), io::Error>
1868    where
1869        M: Into<BackendMessage>,
1870    {
1871        let message: BackendMessage = message.into();
1872        let is_error =
1873            matches!(&message, BackendMessage::ErrorResponse(e) if e.severity.is_error());
1874
1875        self.conn.send(message).await?;
1876
1877        // Flush immediately after sending an error response, as some clients
1878        // expect to be able to read the error response before sending a Sync
1879        // message. This is arguably in violation of the protocol specification,
1880        // but the specification is somewhat ambiguous, and easier to match
1881        // PostgreSQL here than to fix all the clients that have this
1882        // expectation.
1883        if is_error {
1884            self.conn.flush().await?;
1885        }
1886
1887        Ok(())
1888    }
1889
1890    #[instrument(level = "debug")]
1891    pub async fn send_all(
1892        &mut self,
1893        messages: impl IntoIterator<Item = BackendMessage>,
1894    ) -> Result<(), io::Error> {
1895        for m in messages {
1896            self.send(m).await?;
1897        }
1898        Ok(())
1899    }
1900
1901    #[instrument(level = "debug")]
1902    async fn sync(&mut self) -> Result<State, io::Error> {
1903        // Close the current transaction if we are in an implicit transaction.
1904        if self.adapter_client.session().transaction().is_implicit() {
1905            self.commit_transaction().await?;
1906        }
1907        self.ready().await
1908    }
1909
1910    #[instrument(level = "debug")]
1911    async fn ready(&mut self) -> Result<State, io::Error> {
1912        let txn_state = self.adapter_client.session().transaction().into();
1913        self.send(BackendMessage::ReadyForQuery(txn_state)).await?;
1914        self.flush().await
1915    }
1916
1917    #[allow(clippy::too_many_arguments)]
1918    #[instrument(level = "debug")]
1919    async fn send_execute_response(
1920        &mut self,
1921        response: ExecuteResponse,
1922        row_desc: Option<RelationDesc>,
1923        portal_name: String,
1924        max_rows: ExecuteCount,
1925        get_response: GetResponse,
1926        fetch_portal_name: Option<String>,
1927        timeout: ExecuteTimeout,
1928        execute_started: Instant,
1929    ) -> Result<State, io::Error> {
1930        let mut tag = response.tag();
1931
1932        macro_rules! command_complete {
1933            () => {{
1934                self.send(BackendMessage::CommandComplete {
1935                    tag: tag
1936                        .take()
1937                        .expect("command_complete only called on tag-generating results"),
1938                })
1939                .await?;
1940                Ok(State::Ready)
1941            }};
1942        }
1943
1944        let r = match response {
1945            ExecuteResponse::ClosedCursor => {
1946                self.complete_portal(&portal_name);
1947                command_complete!()
1948            }
1949            ExecuteResponse::DeclaredCursor => {
1950                self.complete_portal(&portal_name);
1951                command_complete!()
1952            }
1953            ExecuteResponse::EmptyQuery => {
1954                self.send(BackendMessage::EmptyQueryResponse).await?;
1955                Ok(State::Ready)
1956            }
1957            ExecuteResponse::Fetch {
1958                name,
1959                count,
1960                timeout,
1961                ctx_extra,
1962            } => {
1963                self.fetch(
1964                    name,
1965                    count,
1966                    max_rows,
1967                    Some(portal_name.to_string()),
1968                    timeout,
1969                    ctx_extra,
1970                )
1971                .await
1972            }
1973            ExecuteResponse::SendingRowsStreaming {
1974                rows,
1975                instance_id,
1976                strategy,
1977            } => {
1978                let row_desc = row_desc
1979                    .expect("missing row description for ExecuteResponse::SendingRowsStreaming");
1980
1981                let span = tracing::debug_span!("sending_rows_streaming");
1982
1983                self.send_rows(
1984                    row_desc,
1985                    portal_name,
1986                    InProgressRows::new(RecordFirstRowStream::new(
1987                        Box::new(rows),
1988                        execute_started,
1989                        &self.adapter_client,
1990                        Some(instance_id),
1991                        Some(strategy),
1992                    )),
1993                    max_rows,
1994                    get_response,
1995                    fetch_portal_name,
1996                    timeout,
1997                )
1998                .instrument(span)
1999                .await
2000                .map(|(state, _)| state)
2001            }
2002            ExecuteResponse::SendingRowsImmediate { rows } => {
2003                let row_desc = row_desc
2004                    .expect("missing row description for ExecuteResponse::SendingRowsImmediate");
2005
2006                let span = tracing::debug_span!("sending_rows_immediate");
2007
2008                let stream =
2009                    futures::stream::once(futures::future::ready(PeekResponseUnary::Rows(rows)));
2010                self.send_rows(
2011                    row_desc,
2012                    portal_name,
2013                    InProgressRows::new(RecordFirstRowStream::new(
2014                        Box::new(stream),
2015                        execute_started,
2016                        &self.adapter_client,
2017                        None,
2018                        Some(StatementExecutionStrategy::Constant),
2019                    )),
2020                    max_rows,
2021                    get_response,
2022                    fetch_portal_name,
2023                    timeout,
2024                )
2025                .instrument(span)
2026                .await
2027                .map(|(state, _)| state)
2028            }
2029            ExecuteResponse::SetVariable { name, .. } => {
2030                // This code is somewhat awkwardly structured because we
2031                // can't hold `var` across an await point.
2032                let qn = name.to_string();
2033                let msg = if let Some(var) = self
2034                    .adapter_client
2035                    .session()
2036                    .vars_mut()
2037                    .notify_set()
2038                    .find(|v| v.name() == qn)
2039                {
2040                    Some(BackendMessage::ParameterStatus(var.name(), var.value()))
2041                } else {
2042                    None
2043                };
2044                if let Some(msg) = msg {
2045                    self.send(msg).await?;
2046                }
2047                command_complete!()
2048            }
2049            ExecuteResponse::Subscribing {
2050                rx,
2051                ctx_extra,
2052                instance_id,
2053            } => {
2054                if fetch_portal_name.is_none() {
2055                    let mut msg = ErrorResponse::notice(
2056                        SqlState::WARNING,
2057                        "streaming SUBSCRIBE rows directly requires a client that does not buffer output",
2058                    );
2059                    if self.adapter_client.session().vars().application_name() == "psql" {
2060                        msg.hint = Some(
2061                            "Wrap your SUBSCRIBE statement in `COPY (SUBSCRIBE ...) TO STDOUT`."
2062                                .into(),
2063                        )
2064                    }
2065                    self.send(msg).await?;
2066                    self.conn.flush().await?;
2067                }
2068                let row_desc =
2069                    row_desc.expect("missing row description for ExecuteResponse::Subscribing");
2070                let (result, statement_ended_execution_reason) = match self
2071                    .send_rows(
2072                        row_desc,
2073                        portal_name,
2074                        InProgressRows::new(RecordFirstRowStream::new(
2075                            Box::new(UnboundedReceiverStream::new(rx)),
2076                            execute_started,
2077                            &self.adapter_client,
2078                            Some(instance_id),
2079                            None,
2080                        )),
2081                        max_rows,
2082                        get_response,
2083                        fetch_portal_name,
2084                        timeout,
2085                    )
2086                    .await
2087                {
2088                    Err(e) => {
2089                        // This is an error communicating with the connection.
2090                        // We consider that to be a cancelation, rather than a query error.
2091                        (Err(e), StatementEndedExecutionReason::Canceled)
2092                    }
2093                    Ok((ok, SendRowsEndedReason::Canceled)) => {
2094                        (Ok(ok), StatementEndedExecutionReason::Canceled)
2095                    }
2096                    Ok((
2097                        ok,
2098                        SendRowsEndedReason::Success {
2099                            result_size,
2100                            rows_returned,
2101                        },
2102                    )) => (
2103                        Ok(ok),
2104                        StatementEndedExecutionReason::Success {
2105                            result_size: Some(result_size),
2106                            rows_returned: Some(rows_returned),
2107                            execution_strategy: None,
2108                        },
2109                    ),
2110                    Ok((ok, SendRowsEndedReason::Errored { error })) => {
2111                        (Ok(ok), StatementEndedExecutionReason::Errored { error })
2112                    }
2113                };
2114                self.adapter_client
2115                    .retire_execute(ctx_extra, statement_ended_execution_reason);
2116                return result;
2117            }
2118            ExecuteResponse::CopyTo { format, resp } => {
2119                let row_desc =
2120                    row_desc.expect("missing row description for ExecuteResponse::CopyTo");
2121                match *resp {
2122                    ExecuteResponse::Subscribing {
2123                        rx,
2124                        ctx_extra,
2125                        instance_id,
2126                    } => {
2127                        let (result, statement_ended_execution_reason) = match self
2128                            .copy_rows(
2129                                format,
2130                                row_desc,
2131                                RecordFirstRowStream::new(
2132                                    Box::new(UnboundedReceiverStream::new(rx)),
2133                                    execute_started,
2134                                    &self.adapter_client,
2135                                    Some(instance_id),
2136                                    None,
2137                                ),
2138                            )
2139                            .await
2140                        {
2141                            Err(e) => {
2142                                // This is an error communicating with the connection.
2143                                // We consider that to be a cancelation, rather than a query error.
2144                                (Err(e), StatementEndedExecutionReason::Canceled)
2145                            }
2146                            Ok((
2147                                state,
2148                                SendRowsEndedReason::Success {
2149                                    result_size,
2150                                    rows_returned,
2151                                },
2152                            )) => (
2153                                Ok(state),
2154                                StatementEndedExecutionReason::Success {
2155                                    result_size: Some(result_size),
2156                                    rows_returned: Some(rows_returned),
2157                                    execution_strategy: None,
2158                                },
2159                            ),
2160                            Ok((state, SendRowsEndedReason::Errored { error })) => {
2161                                (Ok(state), StatementEndedExecutionReason::Errored { error })
2162                            }
2163                            Ok((state, SendRowsEndedReason::Canceled)) => {
2164                                (Ok(state), StatementEndedExecutionReason::Canceled)
2165                            }
2166                        };
2167                        self.adapter_client
2168                            .retire_execute(ctx_extra, statement_ended_execution_reason);
2169                        return result;
2170                    }
2171                    ExecuteResponse::SendingRowsStreaming {
2172                        rows,
2173                        instance_id,
2174                        strategy,
2175                    } => {
2176                        // We don't need to finalize execution here;
2177                        // it was already done in the
2178                        // coordinator. Just extract the state and
2179                        // return that.
2180                        return self
2181                            .copy_rows(
2182                                format,
2183                                row_desc,
2184                                RecordFirstRowStream::new(
2185                                    Box::new(rows),
2186                                    execute_started,
2187                                    &self.adapter_client,
2188                                    Some(instance_id),
2189                                    Some(strategy),
2190                                ),
2191                            )
2192                            .await
2193                            .map(|(state, _)| state);
2194                    }
2195                    ExecuteResponse::SendingRowsImmediate { rows } => {
2196                        let span = tracing::debug_span!("sending_rows_immediate");
2197
2198                        let rows = futures::stream::once(futures::future::ready(
2199                            PeekResponseUnary::Rows(rows),
2200                        ));
2201                        // We don't need to finalize execution here;
2202                        // it was already done in the
2203                        // coordinator. Just extract the state and
2204                        // return that.
2205                        return self
2206                            .copy_rows(
2207                                format,
2208                                row_desc,
2209                                RecordFirstRowStream::new(
2210                                    Box::new(rows),
2211                                    execute_started,
2212                                    &self.adapter_client,
2213                                    None,
2214                                    Some(StatementExecutionStrategy::Constant),
2215                                ),
2216                            )
2217                            .instrument(span)
2218                            .await
2219                            .map(|(state, _)| state);
2220                    }
2221                    _ => {
2222                        return self
2223                            .send_error_and_get_state(ErrorResponse::error(
2224                                SqlState::INTERNAL_ERROR,
2225                                "unsupported COPY response type".to_string(),
2226                            ))
2227                            .await;
2228                    }
2229                };
2230            }
2231            ExecuteResponse::CopyFrom {
2232                target_id,
2233                target_name,
2234                columns,
2235                params,
2236                ctx_extra,
2237            } => {
2238                let row_desc =
2239                    row_desc.expect("missing row description for ExecuteResponse::CopyFrom");
2240                self.copy_from(target_id, target_name, columns, params, row_desc, ctx_extra)
2241                    .await
2242            }
2243            ExecuteResponse::TransactionCommitted { params }
2244            | ExecuteResponse::TransactionRolledBack { params } => {
2245                let notify_set: mz_ore::collections::HashSet<String> = self
2246                    .adapter_client
2247                    .session()
2248                    .vars()
2249                    .notify_set()
2250                    .map(|v| v.name().to_string())
2251                    .collect();
2252
2253                // Only report on parameters that are in the notify set.
2254                for (name, value) in params
2255                    .into_iter()
2256                    .filter(|(name, _v)| notify_set.contains(*name))
2257                {
2258                    let msg = BackendMessage::ParameterStatus(name, value);
2259                    self.send(msg).await?;
2260                }
2261                command_complete!()
2262            }
2263
2264            ExecuteResponse::AlteredDefaultPrivileges
2265            | ExecuteResponse::AlteredObject(..)
2266            | ExecuteResponse::AlteredRole
2267            | ExecuteResponse::AlteredSystemConfiguration
2268            | ExecuteResponse::CreatedCluster { .. }
2269            | ExecuteResponse::CreatedClusterReplica { .. }
2270            | ExecuteResponse::CreatedConnection { .. }
2271            | ExecuteResponse::CreatedDatabase { .. }
2272            | ExecuteResponse::CreatedIndex { .. }
2273            | ExecuteResponse::CreatedIntrospectionSubscribe
2274            | ExecuteResponse::CreatedMaterializedView { .. }
2275            | ExecuteResponse::CreatedContinualTask { .. }
2276            | ExecuteResponse::CreatedRole
2277            | ExecuteResponse::CreatedSchema { .. }
2278            | ExecuteResponse::CreatedSecret { .. }
2279            | ExecuteResponse::CreatedSink { .. }
2280            | ExecuteResponse::CreatedSource { .. }
2281            | ExecuteResponse::CreatedTable { .. }
2282            | ExecuteResponse::CreatedType
2283            | ExecuteResponse::CreatedView { .. }
2284            | ExecuteResponse::CreatedViews { .. }
2285            | ExecuteResponse::CreatedNetworkPolicy
2286            | ExecuteResponse::Comment
2287            | ExecuteResponse::Deallocate { .. }
2288            | ExecuteResponse::Deleted(..)
2289            | ExecuteResponse::DiscardedAll
2290            | ExecuteResponse::DiscardedTemp
2291            | ExecuteResponse::DroppedObject(_)
2292            | ExecuteResponse::DroppedOwned
2293            | ExecuteResponse::GrantedPrivilege
2294            | ExecuteResponse::GrantedRole
2295            | ExecuteResponse::Inserted(..)
2296            | ExecuteResponse::Copied(..)
2297            | ExecuteResponse::Prepare
2298            | ExecuteResponse::Raised
2299            | ExecuteResponse::ReassignOwned
2300            | ExecuteResponse::RevokedPrivilege
2301            | ExecuteResponse::RevokedRole
2302            | ExecuteResponse::StartedTransaction { .. }
2303            | ExecuteResponse::Updated(..)
2304            | ExecuteResponse::ValidatedConnection => {
2305                command_complete!()
2306            }
2307        };
2308
2309        assert_none!(tag, "tag created but not consumed: {:?}", tag);
2310        r
2311    }
2312
    /// Sends rows from the named portal to the client, stopping once
    /// `max_rows` rows have been sent, the row stream is exhausted, or
    /// `timeout` expires.
    ///
    /// If the portal is being driven by a FETCH statement,
    /// `fetch_portal_name` names the outer FETCH portal, whose result formats
    /// override those of `portal_name`. Rows left over when `max_rows` is
    /// reached are stashed back into the portal so a later execute can
    /// resume. The final response message is produced by `get_response`.
    ///
    /// Returns the next protocol state along with why row sending ended
    /// (success with row/byte counts, error, or cancellation).
    #[allow(clippy::too_many_arguments)]
    // TODO(guswynn): figure out how to get it to compile without skip_all
    #[mz_ore::instrument(level = "debug")]
    async fn send_rows(
        &mut self,
        row_desc: RelationDesc,
        portal_name: String,
        mut rows: InProgressRows,
        max_rows: ExecuteCount,
        get_response: GetResponse,
        fetch_portal_name: Option<String>,
        timeout: ExecuteTimeout,
    ) -> Result<(State, SendRowsEndedReason), io::Error> {
        // If this portal is being executed from a FETCH then we need to use the result
        // format type of the outer portal.
        let result_format_portal_name: &str = if let Some(ref name) = fetch_portal_name {
            name
        } else {
            &portal_name
        };
        let result_formats = self
            .adapter_client
            .session()
            .get_portal_unverified(result_format_portal_name)
            .expect("valid fetch portal name for send rows")
            .result_formats
            .clone();

        // Translate the timeout into a (wait_once, deadline) pair. `WaitOnce`
        // starts with no deadline; one is installed after the first non-empty
        // batch arrives (see below), making it act like a 0s timeout from
        // then on.
        let (mut wait_once, mut deadline) = match timeout {
            ExecuteTimeout::None => (false, None),
            ExecuteTimeout::Seconds(t) => (
                false,
                Some(tokio::time::Instant::now() + tokio::time::Duration::from_secs_f64(t)),
            ),
            ExecuteTimeout::WaitOnce => (true, None),
        };

        // Sanity check that the various `RelationDesc`s match up.
        {
            let portal_name_desc = &self
                .adapter_client
                .session()
                .get_portal_unverified(portal_name.as_str())
                .expect("portal should exist")
                .desc
                .relation_desc;
            if let Some(portal_name_desc) = portal_name_desc {
                soft_assert_eq_or_log!(portal_name_desc, &row_desc);
            }
            if let Some(fetch_portal_name) = &fetch_portal_name {
                let fetch_portal_desc = &self
                    .adapter_client
                    .session()
                    .get_portal_unverified(fetch_portal_name)
                    .expect("portal should exist")
                    .desc
                    .relation_desc;
                if let Some(fetch_portal_desc) = fetch_portal_desc {
                    soft_assert_eq_or_log!(fetch_portal_desc, &row_desc);
                }
            }
        }

        // Configure the connection's encoder with each column's pgrepr type
        // paired with the client's requested result format.
        self.conn.set_encode_state(
            row_desc
                .typ()
                .column_types
                .iter()
                .map(|ty| mz_pgrepr::Type::from(&ty.scalar_type))
                .zip_eq(result_formats)
                .collect(),
        );

        let mut total_sent_rows = 0;
        let mut total_sent_bytes = 0;
        // want_rows is the maximum number of rows the client wants.
        let mut want_rows = match max_rows {
            ExecuteCount::All => usize::MAX,
            ExecuteCount::Count(count) => count,
        };

        // Send rows while the client still wants them and there are still rows to send.
        loop {
            // Fetch next batch of rows, waiting for a possible requested
            // timeout or notice.
            let batch = if rows.current.is_some() {
                // Drain a previously stashed, partially-consumed batch first.
                FetchResult::Rows(rows.current.take())
            } else if want_rows == 0 {
                // The client asked for zero (more) rows; don't touch the stream.
                FetchResult::Rows(None)
            } else {
                let notice_fut = self.adapter_client.session().recv_notice();
                // Biased: drain available data before checking the deadline.
                // This is critical for the WaitOnce case, where the deadline
                // is set to `Instant::now()` right after the first batch:
                // without `biased`, `recv()` and the already-expired deadline
                // race nondeterministically, so we might break the loop
                // before `no_more_rows` is set (or even before ready rows
                // are consumed). With an explicit `TIMEOUT`, missing a batch
                // right at the boundary is acceptable, but WaitOnce fires
                // immediately and the race is not.
                //
                // Trade-off: if `recv()` keeps returning Ready (unlikely in
                // practice—row processing + flush is slower than upstream
                // tick granularity), a `TIMEOUT` deadline could be delayed.
                // See database-issues#9470.
                tokio::select! {
                    biased;
                    err = self.conn.wait_closed() => return Err(err),
                    batch = rows.remaining.recv() => match batch {
                        None => FetchResult::Rows(None),
                        Some(PeekResponseUnary::Rows(rows)) => FetchResult::Rows(Some(rows)),
                        Some(PeekResponseUnary::Error(err)) => FetchResult::Error(err),
                        Some(PeekResponseUnary::Canceled) => FetchResult::Canceled,
                    },
                    notice = notice_fut => {
                        FetchResult::Notice(notice)
                    }
                    _ = time::sleep_until(
                        deadline.unwrap_or_else(tokio::time::Instant::now),
                    ), if deadline.is_some() => FetchResult::Rows(None),
                }
            };

            match batch {
                // Stream exhausted, row limit reached, or timeout fired.
                FetchResult::Rows(None) => break,
                FetchResult::Rows(Some(mut batch_rows)) => {
                    if let Err(err) = verify_datum_desc(&row_desc, &mut batch_rows) {
                        let msg = err.to_string();
                        return self
                            .send_error_and_get_state(err.into_response(Severity::Error))
                            .await
                            .map(|state| (state, SendRowsEndedReason::Errored { error: msg }));
                    }

                    // If wait_once is true: the first time this fn is called it blocks (same as
                    // deadline == None). The second time this fn is called it should behave the
                    // same a 0s timeout.
                    if wait_once && batch_rows.peek().is_some() {
                        deadline = Some(tokio::time::Instant::now());
                        wait_once = false;
                    }

                    // Send a portion of the rows.
                    let mut sent_rows = 0;
                    let mut sent_bytes = 0;
                    let messages = (&mut batch_rows)
                        // TODO(parkmycar): This is a fair bit of juggling between iterator types
                        // to count the total number of bytes. Alternatively we could track the
                        // total sent bytes in this .map(...) call, but having side effects in map
                        // is a code smell.
                        .map(|row| {
                            let row_len = row.byte_len();
                            let values = mz_pgrepr::values_from_row(row, row_desc.typ());
                            (row_len, BackendMessage::DataRow(values))
                        })
                        .inspect(|(row_len, _)| {
                            sent_bytes += row_len;
                            sent_rows += 1
                        })
                        .map(|(_row_len, row)| row)
                        .take(want_rows);
                    self.send_all(messages).await?;

                    total_sent_rows += sent_rows;
                    total_sent_bytes += sent_bytes;
                    want_rows -= sent_rows;

                    // If we have sent the number of requested rows, put the remainder of the batch
                    // (if any) back and stop sending.
                    if want_rows == 0 {
                        if batch_rows.peek().is_some() {
                            rows.current = Some(batch_rows);
                        }
                        break;
                    }

                    self.conn.flush().await?;
                }
                FetchResult::Notice(notice) => {
                    // Forward the notice immediately so the client sees it
                    // while rows are still streaming.
                    self.send(notice.into_response()).await?;
                    self.conn.flush().await?;
                }
                FetchResult::Error(text) => {
                    return self
                        .send_error_and_get_state(ErrorResponse::error(
                            SqlState::INTERNAL_ERROR,
                            text.clone(),
                        ))
                        .await
                        .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
                }
                FetchResult::Canceled => {
                    return self
                        .send_error_and_get_state(ErrorResponse::error(
                            SqlState::QUERY_CANCELED,
                            "canceling statement due to user request",
                        ))
                        .await
                        .map(|state| (state, SendRowsEndedReason::Canceled));
                }
            }
        }

        let portal = self
            .adapter_client
            .session()
            .get_portal_unverified_mut(&portal_name)
            .expect("valid portal name for send rows");

        // Capture stream statistics before `rows` is moved back into the portal.
        let saw_rows = rows.remaining.saw_rows;
        let no_more_rows = rows.no_more_rows();
        let recorded_first_row_instant = rows.remaining.recorded_first_row_instant;

        // Always return rows back, even if it's empty. This prevents an unclosed
        // portal from re-executing after it has been emptied.
        *portal.state = PortalState::InProgress(Some(rows));

        let fetch_portal = fetch_portal_name.map(|name| {
            self.adapter_client
                .session()
                .get_portal_unverified_mut(&name)
                .expect("valid fetch portal")
        });
        let response_message = get_response(max_rows, total_sent_rows, fetch_portal);
        self.send(response_message).await?;

        // Attend to metrics if there are no more rows.
        if no_more_rows {
            let statement_type = if let Some(stmt) = &self
                .adapter_client
                .session()
                .get_portal_unverified(&portal_name)
                .expect("valid portal name for send_rows")
                .stmt
            {
                metrics::statement_type_label_value(stmt.deref())
            } else {
                "no-statement"
            };
            let duration = if saw_rows {
                recorded_first_row_instant
                    .expect("recorded_first_row_instant because saw_rows")
                    .elapsed()
            } else {
                // If the result is empty, then we define time from first to last row as 0.
                // (Note that, currently, an empty result involves a PeekResponse with 0 rows, which
                // does flip `saw_rows`, so this code path is currently not exercised.)
                Duration::ZERO
            };
            self.adapter_client
                .inner()
                .metrics()
                .result_rows_first_to_last_byte_seconds
                .with_label_values(&[statement_type])
                .observe(duration.as_secs_f64());
        }

        Ok((
            State::Ready,
            SendRowsEndedReason::Success {
                result_size: u64::cast_from(total_sent_bytes),
                rows_returned: u64::cast_from(total_sent_rows),
            },
        ))
    }
2578
2579    #[mz_ore::instrument(level = "debug")]
2580    async fn copy_rows(
2581        &mut self,
2582        format: CopyFormat,
2583        row_desc: RelationDesc,
2584        mut stream: RecordFirstRowStream,
2585    ) -> Result<(State, SendRowsEndedReason), io::Error> {
2586        let (row_format, encode_format) = match format {
2587            CopyFormat::Text => (
2588                CopyFormatParams::Text(CopyTextFormatParams::default()),
2589                Format::Text,
2590            ),
2591            CopyFormat::Binary => (CopyFormatParams::Binary, Format::Binary),
2592            CopyFormat::Csv => (
2593                CopyFormatParams::Csv(CopyCsvFormatParams::default()),
2594                Format::Text,
2595            ),
2596            CopyFormat::Parquet => {
2597                let text = "Parquet format is not supported".to_string();
2598                return self
2599                    .send_error_and_get_state(ErrorResponse::error(
2600                        SqlState::INTERNAL_ERROR,
2601                        text.clone(),
2602                    ))
2603                    .await
2604                    .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
2605            }
2606        };
2607
2608        let encode_fn = |row: &RowRef, typ: &SqlRelationType, out: &mut Vec<u8>| {
2609            mz_pgcopy::encode_copy_format(&row_format, row, typ, out)
2610        };
2611
2612        let typ = row_desc.typ();
2613        let column_formats = iter::repeat(encode_format)
2614            .take(typ.column_types.len())
2615            .collect();
2616        self.send(BackendMessage::CopyOutResponse {
2617            overall_format: encode_format,
2618            column_formats,
2619        })
2620        .await?;
2621
2622        // In Postgres, binary copy has a header that is followed (in the same
2623        // CopyData) by the first row. In order to replicate their behavior, use a
2624        // common vec that we can extend one time now and then fill up with the encode
2625        // functions.
2626        let mut out = Vec::new();
2627
2628        if let CopyFormat::Binary = format {
2629            // 11-byte signature.
2630            out.extend(b"PGCOPY\n\xFF\r\n\0");
2631            // 32-bit flags field.
2632            out.extend([0, 0, 0, 0]);
2633            // 32-bit header extension length field.
2634            out.extend([0, 0, 0, 0]);
2635        }
2636
2637        let mut count = 0;
2638        let mut total_sent_bytes = 0;
2639        loop {
2640            tokio::select! {
2641                e = self.conn.wait_closed() => return Err(e),
2642                batch = stream.recv() => match batch {
2643                    None => break,
2644                    Some(PeekResponseUnary::Error(text)) => {
2645                        let err =
2646                            ErrorResponse::error(SqlState::INTERNAL_ERROR, text.clone());
2647                        return self
2648                            .send_error_and_get_state(err)
2649                            .await
2650                            .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
2651                    }
2652                    Some(PeekResponseUnary::Canceled) => {
2653                        return self.send_error_and_get_state(ErrorResponse::error(
2654                                SqlState::QUERY_CANCELED,
2655                                "canceling statement due to user request",
2656                            ))
2657                            .await.map(|state| (state, SendRowsEndedReason::Canceled));
2658                    }
2659                    Some(PeekResponseUnary::Rows(mut rows)) => {
2660                        count += rows.count();
2661                        while let Some(row) = rows.next() {
2662                            total_sent_bytes += row.byte_len();
2663                            encode_fn(row, typ, &mut out)?;
2664                            self.send(BackendMessage::CopyData(mem::take(&mut out)))
2665                                .await?;
2666                        }
2667                    }
2668                },
2669                notice = self.adapter_client.session().recv_notice() => {
2670                    self.send(notice.into_response())
2671                        .await?;
2672                    self.conn.flush().await?;
2673                }
2674            }
2675
2676            self.conn.flush().await?;
2677        }
2678        // Send required trailers.
2679        if let CopyFormat::Binary = format {
2680            let trailer: i16 = -1;
2681            out.extend(trailer.to_be_bytes());
2682            self.send(BackendMessage::CopyData(mem::take(&mut out)))
2683                .await?;
2684        }
2685
2686        let tag = format!("COPY {}", count);
2687        self.send(BackendMessage::CopyDone).await?;
2688        self.send(BackendMessage::CommandComplete { tag }).await?;
2689        Ok((
2690            State::Ready,
2691            SendRowsEndedReason::Success {
2692                result_size: u64::cast_from(total_sent_bytes),
2693                rows_returned: u64::cast_from(count),
2694            },
2695        ))
2696    }
2697
2698    /// Handles the copy-in mode of the postgres protocol from transferring
2699    /// data to the server.
2700    #[instrument(level = "debug")]
2701    async fn copy_from(
2702        &mut self,
2703        target_id: CatalogItemId,
2704        target_name: String,
2705        columns: Vec<ColumnIndex>,
2706        params: CopyFormatParams<'static>,
2707        row_desc: RelationDesc,
2708        mut ctx_extra: ExecuteContextGuard,
2709    ) -> Result<State, io::Error> {
2710        let res = self
2711            .copy_from_inner(
2712                target_id,
2713                target_name,
2714                columns,
2715                params,
2716                row_desc,
2717                &mut ctx_extra,
2718            )
2719            .await;
2720        match &res {
2721            Ok(State::Ready) => {
2722                self.adapter_client.retire_execute(
2723                    ctx_extra,
2724                    StatementEndedExecutionReason::Success {
2725                        result_size: None,
2726                        rows_returned: None,
2727                        execution_strategy: None,
2728                    },
2729                );
2730            }
2731            Ok(State::Done) => {
2732                // The connection closed gracefully without sending us a `CopyDone`,
2733                // causing us to just drop the copy request.
2734                // For the purposes of statement logging, we count this as a cancellation.
2735                self.adapter_client
2736                    .retire_execute(ctx_extra, StatementEndedExecutionReason::Canceled);
2737            }
2738            Err(e) => {
2739                self.adapter_client.retire_execute(
2740                    ctx_extra,
2741                    StatementEndedExecutionReason::Errored {
2742                        error: format!("{e}"),
2743                    },
2744                );
2745            }
2746            Ok(State::Drain) => {}
2747        }
2748        res
2749    }
2750
2751    async fn copy_from_inner(
2752        &mut self,
2753        target_id: CatalogItemId,
2754        target_name: String,
2755        columns: Vec<ColumnIndex>,
2756        params: CopyFormatParams<'static>,
2757        row_desc: RelationDesc,
2758        ctx_extra: &mut ExecuteContextGuard,
2759    ) -> Result<State, io::Error> {
2760        let typ = row_desc.typ();
2761        let column_formats = vec![Format::Text; typ.column_types.len()];
2762        self.send(BackendMessage::CopyInResponse {
2763            overall_format: Format::Text,
2764            column_formats,
2765        })
2766        .await?;
2767        self.conn.flush().await?;
2768
2769        // Set up the parallel streaming batch builders in the coordinator.
2770        let writer = match self
2771            .adapter_client
2772            .start_copy_from_stdin(
2773                target_id,
2774                target_name.clone(),
2775                columns.clone(),
2776                row_desc.clone(),
2777                params.clone(),
2778            )
2779            .await
2780        {
2781            Ok(writer) => writer,
2782            Err(e) => {
2783                // Drain remaining CopyData/CopyDone/CopyFail messages from the
2784                // socket. Since CopyInResponse was already sent, the client may
2785                // have pipelined copy data that we must consume before returning
2786                // the error, otherwise they'd be misinterpreted as top-level
2787                // protocol messages and cause a deadlock.
2788                loop {
2789                    match self.conn.recv().await? {
2790                        Some(FrontendMessage::CopyData(_)) => {}
2791                        Some(FrontendMessage::CopyDone) | Some(FrontendMessage::CopyFail(_)) => {
2792                            break;
2793                        }
2794                        Some(FrontendMessage::Flush) | Some(FrontendMessage::Sync) => {}
2795                        Some(_) => break,
2796                        None => return Ok(State::Done),
2797                    }
2798                }
2799                self.adapter_client.retire_execute(
2800                    std::mem::take(ctx_extra),
2801                    StatementEndedExecutionReason::Errored {
2802                        error: e.to_string(),
2803                    },
2804                );
2805                return self
2806                    .send_error_and_get_state(e.into_response(Severity::Error))
2807                    .await;
2808            }
2809        };
2810
2811        // Enable copy mode on the codec to skip aggregate buffer size checks.
2812        self.conn.set_copy_mode(true);
2813
2814        // Batch size for splitting raw data across parallel workers (~32MB).
2815        const BATCH_SIZE: usize = 32 * 1024 * 1024;
2816        let max_copy_from_row_size = self
2817            .adapter_client
2818            .get_system_vars()
2819            .await
2820            .max_copy_from_row_size()
2821            .try_into()
2822            .unwrap_or(usize::MAX);
2823
2824        let mut data = Vec::new();
2825        let mut row_scanner = CopyRowScanner::new(&params);
2826        let num_workers = writer.batch_txs.len();
2827        let mut next_worker: usize = 0;
2828        let mut saw_copy_done = false;
2829        let mut saw_end_marker = false;
2830        let mut copy_from_error: Option<(SqlState, String)> = None;
2831
2832        // Receive loop: accumulate CopyData, split at row boundaries,
2833        // round-robin raw chunks to parallel batch builder workers.
2834        loop {
2835            let message = self.conn.recv().await?;
2836            match message {
2837                Some(FrontendMessage::CopyData(buf)) => {
2838                    if saw_end_marker {
2839                        // Per PostgreSQL COPY behavior, ignore all bytes after
2840                        // the end-of-copy marker until CopyDone.
2841                        continue;
2842                    }
2843                    data.extend(buf);
2844                    row_scanner.scan_new_bytes(&data);
2845
2846                    if let Some(end_pos) = row_scanner.end_marker_end() {
2847                        data.truncate(end_pos);
2848                        row_scanner.on_truncate(end_pos);
2849                        saw_end_marker = true;
2850                    }
2851
2852                    // Guard against pathological single rows that never terminate.
2853                    if row_scanner.current_row_size(data.len()) > max_copy_from_row_size {
2854                        copy_from_error = Some((
2855                            SqlState::INSUFFICIENT_RESOURCES,
2856                            format!(
2857                                "COPY FROM STDIN row exceeded max_copy_from_row_size \
2858                                 ({max_copy_from_row_size} bytes)"
2859                            ),
2860                        ));
2861                        break;
2862                    }
2863
2864                    // When buffer exceeds batch size, split at the last complete row
2865                    // and send the complete rows chunk to the next worker.
2866                    let mut send_failed = false;
2867                    while data.len() >= BATCH_SIZE {
2868                        let split_pos = match row_scanner.last_row_end() {
2869                            Some(pos) => pos,
2870                            None => break, // no complete row yet
2871                        };
2872                        let remainder = data.split_off(split_pos);
2873                        let chunk = std::mem::replace(&mut data, remainder);
2874                        row_scanner.on_split(split_pos);
2875                        if writer.batch_txs[next_worker].send(chunk).await.is_err() {
2876                            send_failed = true;
2877                            break;
2878                        }
2879                        next_worker = (next_worker + 1) % num_workers;
2880                    }
2881                    // Worker dropped (likely errored) — stop sending,
2882                    // fall through to completion_rx for the real error.
2883                    if send_failed {
2884                        break;
2885                    }
2886                }
2887                Some(FrontendMessage::CopyDone) => {
2888                    // Send any remaining data to the next worker.
2889                    if !data.is_empty() {
2890                        let chunk = std::mem::take(&mut data);
2891                        // Ignore send failure — completion_rx will have the error.
2892                        let _ = writer.batch_txs[next_worker].send(chunk).await;
2893                    }
2894                    saw_copy_done = true;
2895                    break;
2896                }
2897                Some(FrontendMessage::CopyFail(err)) => {
2898                    self.adapter_client.retire_execute(
2899                        std::mem::take(ctx_extra),
2900                        StatementEndedExecutionReason::Canceled,
2901                    );
2902                    // Drop the writer to signal cancellation to the background tasks.
2903                    drop(writer);
2904                    self.conn.set_copy_mode(false);
2905                    return self
2906                        .send_error_and_get_state(ErrorResponse::error(
2907                            SqlState::QUERY_CANCELED,
2908                            format!("COPY from stdin failed: {}", err),
2909                        ))
2910                        .await;
2911                }
2912                Some(FrontendMessage::Flush) | Some(FrontendMessage::Sync) => {}
2913                Some(_) => {
2914                    let msg = "unexpected message type during COPY from stdin";
2915                    self.adapter_client.retire_execute(
2916                        std::mem::take(ctx_extra),
2917                        StatementEndedExecutionReason::Errored {
2918                            error: msg.to_string(),
2919                        },
2920                    );
2921                    drop(writer);
2922                    self.conn.set_copy_mode(false);
2923                    return self
2924                        .send_error_and_get_state(ErrorResponse::error(
2925                            SqlState::PROTOCOL_VIOLATION,
2926                            msg,
2927                        ))
2928                        .await;
2929                }
2930                None => {
2931                    drop(writer);
2932                    self.conn.set_copy_mode(false);
2933                    return Ok(State::Done);
2934                }
2935            }
2936        }
2937
2938        // If we exited the receive loop before seeing `CopyDone` (e.g. because
2939        // a worker failed and dropped its channel), keep draining COPY input to
2940        // avoid desynchronizing the protocol state machine.
2941        if !saw_copy_done {
2942            loop {
2943                match self.conn.recv().await? {
2944                    Some(FrontendMessage::CopyData(_)) => {}
2945                    Some(FrontendMessage::CopyDone) | Some(FrontendMessage::CopyFail(_)) => {
2946                        break;
2947                    }
2948                    Some(FrontendMessage::Flush) | Some(FrontendMessage::Sync) => {}
2949                    Some(_) => {
2950                        let msg = "unexpected message type during COPY from stdin";
2951                        self.adapter_client.retire_execute(
2952                            std::mem::take(ctx_extra),
2953                            StatementEndedExecutionReason::Errored {
2954                                error: msg.to_string(),
2955                            },
2956                        );
2957                        drop(writer);
2958                        self.conn.set_copy_mode(false);
2959                        return self
2960                            .send_error_and_get_state(ErrorResponse::error(
2961                                SqlState::PROTOCOL_VIOLATION,
2962                                msg,
2963                            ))
2964                            .await;
2965                    }
2966                    None => {
2967                        drop(writer);
2968                        self.conn.set_copy_mode(false);
2969                        return Ok(State::Done);
2970                    }
2971                }
2972            }
2973        }
2974
2975        if let Some((code, msg)) = copy_from_error {
2976            self.adapter_client.retire_execute(
2977                std::mem::take(ctx_extra),
2978                StatementEndedExecutionReason::Errored { error: msg.clone() },
2979            );
2980            drop(writer);
2981            self.conn.set_copy_mode(false);
2982            return self
2983                .send_error_and_get_state(ErrorResponse::error(code, msg))
2984                .await;
2985        }
2986
2987        self.conn.set_copy_mode(false);
2988
2989        // Drop all senders to signal EOF to the background batch builders.
2990        // If copy_err is set, a worker already failed — dropping the senders
2991        // will cause remaining workers to stop, and we'll get the real error
2992        // from completion_rx below.
2993        drop(writer.batch_txs);
2994
2995        // Wait for all parallel workers to finish building batches.
2996        let (proto_batches, row_count) = match writer.completion_rx.await {
2997            Ok(Ok(result)) => result,
2998            Ok(Err(e)) => {
2999                self.adapter_client.retire_execute(
3000                    std::mem::take(ctx_extra),
3001                    StatementEndedExecutionReason::Errored {
3002                        error: e.to_string(),
3003                    },
3004                );
3005                return self
3006                    .send_error_and_get_state(e.into_response(Severity::Error))
3007                    .await;
3008            }
3009            Err(_) => {
3010                let msg = "COPY FROM STDIN: background batch builder tasks dropped";
3011                self.adapter_client.retire_execute(
3012                    std::mem::take(ctx_extra),
3013                    StatementEndedExecutionReason::Errored {
3014                        error: msg.to_string(),
3015                    },
3016                );
3017                return self
3018                    .send_error_and_get_state(ErrorResponse::error(SqlState::INTERNAL_ERROR, msg))
3019                    .await;
3020            }
3021        };
3022
3023        // Stage all batches in the session's transaction for atomic commit.
3024        if let Err(e) = self
3025            .adapter_client
3026            .stage_copy_from_stdin_batches(target_id, proto_batches)
3027        {
3028            self.adapter_client.retire_execute(
3029                std::mem::take(ctx_extra),
3030                StatementEndedExecutionReason::Errored {
3031                    error: e.to_string(),
3032                },
3033            );
3034            return self
3035                .send_error_and_get_state(e.into_response(Severity::Error))
3036                .await;
3037        }
3038
3039        let tag = format!("COPY {}", row_count);
3040        self.send(BackendMessage::CommandComplete { tag }).await?;
3041
3042        Ok(State::Ready)
3043    }
3044
3045    #[instrument(level = "debug")]
3046    async fn send_pending_notices(&mut self) -> Result<(), io::Error> {
3047        let notices = self
3048            .adapter_client
3049            .session()
3050            .drain_notices()
3051            .into_iter()
3052            .map(|notice| BackendMessage::ErrorResponse(notice.into_response()));
3053        self.send_all(notices).await?;
3054        Ok(())
3055    }
3056
3057    #[instrument(level = "debug")]
3058    async fn send_error_and_get_state(&mut self, err: ErrorResponse) -> Result<State, io::Error> {
3059        assert!(err.severity.is_error());
3060        debug!(
3061            "cid={} error code={}",
3062            self.adapter_client.session().conn_id(),
3063            err.code.code()
3064        );
3065        let is_fatal = err.severity.is_fatal();
3066        self.send(BackendMessage::ErrorResponse(err)).await?;
3067
3068        let txn = self.adapter_client.session().transaction();
3069        match txn {
3070            // Error can be called from describe and parse and so might not be in an active
3071            // transaction.
3072            TransactionStatus::Default | TransactionStatus::Failed(_) => {}
3073            // In Started (i.e., a single statement), cleanup ourselves.
3074            TransactionStatus::Started(_) => {
3075                self.rollback_transaction().await?;
3076            }
3077            // Implicit transactions also clear themselves.
3078            TransactionStatus::InTransactionImplicit(_) => {
3079                self.rollback_transaction().await?;
3080            }
3081            // Explicit transactions move to failed.
3082            TransactionStatus::InTransaction(_) => {
3083                self.adapter_client.fail_transaction();
3084            }
3085        };
3086        if is_fatal {
3087            Ok(State::Done)
3088        } else {
3089            Ok(State::Drain)
3090        }
3091    }
3092
3093    #[instrument(level = "debug")]
3094    async fn aborted_txn_error(&mut self) -> Result<State, io::Error> {
3095        self.send(BackendMessage::ErrorResponse(ErrorResponse::error(
3096            SqlState::IN_FAILED_SQL_TRANSACTION,
3097            ABORTED_TXN_MSG,
3098        )))
3099        .await?;
3100        Ok(State::Drain)
3101    }
3102
3103    fn is_aborted_txn(&mut self) -> bool {
3104        matches!(
3105            self.adapter_client.session().transaction(),
3106            TransactionStatus::Failed(_)
3107        )
3108    }
3109}
3110
3111fn pad_formats(formats: Vec<Format>, n: usize) -> Result<Vec<Format>, String> {
3112    match (formats.len(), n) {
3113        (0, e) => Ok(vec![Format::Text; e]),
3114        (1, e) => Ok(iter::repeat(formats[0]).take(e).collect()),
3115        (a, e) if a == e => Ok(formats),
3116        (a, e) => Err(format!(
3117            "expected {} field format specifiers, but got {}",
3118            e, a
3119        )),
3120    }
3121}
3122
3123fn describe_rows(stmt_desc: &StatementDesc, formats: &[Format]) -> BackendMessage {
3124    match &stmt_desc.relation_desc {
3125        Some(desc) if !stmt_desc.is_copy => {
3126            BackendMessage::RowDescription(message::encode_row_description(desc, formats))
3127        }
3128        _ => BackendMessage::NoData,
3129    }
3130}
3131
/// Callback that produces the final backend message after rows have been
/// sent: `portal_exec_message` for Execute/simple queries, `fetch_message`
/// for FETCH (see below).
type GetResponse = fn(
    max_rows: ExecuteCount,
    total_sent_rows: usize,
    fetch_portal: Option<PortalRefMut>,
) -> BackendMessage;
3137
3138// A GetResponse used by send_rows during execute messages on portals or for
3139// simple query messages.
3140fn portal_exec_message(
3141    max_rows: ExecuteCount,
3142    total_sent_rows: usize,
3143    _fetch_portal: Option<PortalRefMut>,
3144) -> BackendMessage {
3145    // If max_rows is not specified, we will always send back a CommandComplete. If
3146    // max_rows is specified, we only send CommandComplete if there were more rows
3147    // requested than were remaining. That is, if max_rows == number of rows that
3148    // were remaining before sending (not that are remaining after sending), then
3149    // we still send a PortalSuspended. The number of remaining rows after the rows
3150    // have been sent doesn't matter. This matches postgres.
3151    match max_rows {
3152        ExecuteCount::Count(max_rows) if max_rows <= total_sent_rows => {
3153            BackendMessage::PortalSuspended
3154        }
3155        _ => BackendMessage::CommandComplete {
3156            tag: format!("SELECT {}", total_sent_rows),
3157        },
3158    }
3159}
3160
3161// A GetResponse used by send_rows during FETCH queries.
3162fn fetch_message(
3163    _max_rows: ExecuteCount,
3164    total_sent_rows: usize,
3165    fetch_portal: Option<PortalRefMut>,
3166) -> BackendMessage {
3167    let tag = format!("FETCH {}", total_sent_rows);
3168    if let Some(portal) = fetch_portal {
3169        *portal.state = PortalState::Completed(Some(tag.clone()));
3170    }
3171    BackendMessage::CommandComplete { tag }
3172}
3173
3174fn get_authenticator(
3175    authenticator_kind: listeners::AuthenticatorKind,
3176    frontegg: Option<FronteggAuthenticator>,
3177    oidc: GenericOidcAuthenticator,
3178    adapter_client: mz_adapter::Client,
3179    // If oidc_auth_enabled exists as an option in the pgwire connection's
3180    // `option` parameter
3181    oidc_auth_option_enabled: bool,
3182) -> Authenticator {
3183    match authenticator_kind {
3184        listeners::AuthenticatorKind::Frontegg => Authenticator::Frontegg(frontegg.expect(
3185            "Frontegg authenticator should exist with listeners::AuthenticatorKind::Frontegg",
3186        )),
3187        listeners::AuthenticatorKind::Password => Authenticator::Password(adapter_client),
3188        listeners::AuthenticatorKind::Sasl => Authenticator::Sasl(adapter_client),
3189        listeners::AuthenticatorKind::Oidc => {
3190            if oidc_auth_option_enabled {
3191                Authenticator::Oidc(oidc)
3192            } else {
3193                // Fallback to password authentication if oidc auth is not enabled
3194                // through options.
3195                Authenticator::Password(adapter_client)
3196            }
3197        }
3198        listeners::AuthenticatorKind::None => Authenticator::None,
3199    }
3200}
3201
/// The maximum number of rows an Execute/FETCH should return: either
/// unlimited or capped at a fixed count.
#[derive(Debug, Copy, Clone)]
enum ExecuteCount {
    /// Return all remaining rows.
    All,
    /// Return at most this many rows.
    Count(usize),
}
3207
3208// See postgres' backend/tcop/postgres.c IsTransactionExitStmt.
3209fn is_txn_exit_stmt(stmt: Option<&Statement<Raw>>) -> bool {
3210    match stmt {
3211        // Add PREPARE to this if we ever support it.
3212        Some(stmt) => matches!(stmt, Statement::Commit(_) | Statement::Rollback(_)),
3213        None => false,
3214    }
3215}
3216
/// Outcome of awaiting row data for a client-facing fetch.
#[derive(Debug)]
enum FetchResult {
    // Row data became available. NOTE(review): `Rows(None)` presumably
    // signals that the row stream is exhausted — confirm with callers
    // earlier in this file.
    Rows(Option<Box<dyn RowIterator + Send + Sync>>),
    // The in-flight query was canceled.
    Canceled,
    // The query failed; carries the error message text.
    Error(String),
    // A notice arrived that should be forwarded to the client.
    Notice(AdapterNotice),
}
3224
/// Incremental scanner over an accumulating COPY FROM STDIN buffer that
/// tracks row boundaries and the `\.` end-of-data marker across repeated
/// `scan_new_bytes` calls.
#[derive(Debug)]
struct CopyRowScanner {
    // Byte offset up to which the buffer has already been scanned;
    // `scan_new_bytes` resumes from here and advances it to the buffer length.
    scan_pos: usize,
    // Offset one past the end of the last complete row seen, if any.
    last_row_end: Option<usize>,
    // Offset one past the end of the `\.` end-of-data marker row, if seen.
    end_marker_end: Option<usize>,
    // CSV-specific scanning state; `None` for non-CSV formats, which are
    // scanned by splitting on newlines instead.
    csv: Option<CsvScanState>,
}
3232
/// State for incremental CSV record scanning via `csv_core`.
#[derive(Debug)]
struct CsvScanState {
    // Incremental CSV parser; retains parse state between buffer refills.
    reader: csv_core::Reader,
    // Scratch output buffer for `read_record`; grown on demand.
    output: Vec<u8>,
    // Scratch field-end-offset buffer for `read_record`; grown on demand.
    ends: Vec<usize>,
    // Unquoted bytes of the record currently being assembled.
    record: Vec<u8>,
    // Field end offsets of the record currently being assembled.
    record_ends: Vec<usize>,
    // When true, the next complete record (the header row) is exempt from
    // `\.` end-of-data marker detection.
    skip_first_record: bool,
}
3242
impl CopyRowScanner {
    /// Creates a scanner for the given COPY format: CSV gets stateful
    /// `csv_core`-based record scanning; all other formats are scanned by
    /// newline.
    fn new(params: &CopyFormatParams<'_>) -> Self {
        let csv = match params {
            CopyFormatParams::Csv(CopyCsvFormatParams {
                delimiter,
                quote,
                escape,
                header,
                ..
            }) => Some(CsvScanState::new(*delimiter, *quote, *escape, *header)),
            _ => None,
        };

        CopyRowScanner {
            scan_pos: 0,
            last_row_end: None,
            end_marker_end: None,
            csv,
        }
    }

    /// Scans the bytes that arrived since the previous call
    /// (`data[self.scan_pos..]`), updating `last_row_end` and — if the `\.`
    /// end-of-data marker is found — `end_marker_end`. Expects to be handed
    /// the same accumulated buffer each time; use `on_split`/`on_truncate`
    /// to keep offsets consistent when the buffer is cut.
    fn scan_new_bytes(&mut self, data: &[u8]) {
        if self.scan_pos >= data.len() {
            return;
        }

        if let Some(csv) = self.csv.as_mut() {
            let mut input = &data[self.scan_pos..];
            let mut consumed = 0usize;
            while !input.is_empty() {
                // `read_record` is incremental: it may stop early when the
                // input, output, or ends buffer is exhausted.
                let (result, n_input, n_output, n_ends) =
                    csv.reader
                        .read_record(input, &mut csv.output, &mut csv.ends);
                consumed += n_input;
                input = &input[n_input..];
                // Accumulate the partially-decoded record across iterations,
                // since the scratch buffers may be smaller than one record.
                if !csv.output.is_empty() {
                    csv.record.extend_from_slice(&csv.output[..n_output]);
                }
                if !csv.ends.is_empty() {
                    csv.record_ends.extend_from_slice(&csv.ends[..n_ends]);
                }

                match result {
                    ReadRecordResult::InputEmpty => break,
                    ReadRecordResult::OutputFull => {
                        // Only grow the buffer when no input was consumed,
                        // i.e. no forward progress was possible.
                        if n_input == 0 {
                            csv.output
                                .resize(csv.output.len().saturating_mul(2).max(1), 0);
                        }
                    }
                    ReadRecordResult::OutputEndsFull => {
                        if n_input == 0 {
                            csv.ends.resize(csv.ends.len().saturating_mul(2).max(1), 0);
                        }
                    }
                    ReadRecordResult::Record | ReadRecordResult::End => {
                        let row_end = self.scan_pos + consumed;
                        self.last_row_end = Some(row_end);
                        if self.end_marker_end.is_none() {
                            // A header row is never treated as the marker.
                            let is_marker = if csv.skip_first_record {
                                csv.skip_first_record = false;
                                false
                            } else if csv.record_ends.len() == 1 {
                                // The marker is a single-field record whose
                                // unquoted content is exactly `\.`.
                                let end = csv.record_ends[0];
                                end == 2 && csv.record.get(0..end) == Some(b"\\.")
                            } else {
                                false
                            };
                            if is_marker {
                                self.end_marker_end = Some(row_end);
                                break;
                            }
                        }
                        csv.record.clear();
                        csv.record_ends.clear();
                    }
                }
            }
        } else {
            // Non-CSV formats: rows are newline-delimited.
            let mut row_start = self.last_row_end.unwrap_or(0);
            for (offset, b) in data[self.scan_pos..].iter().enumerate() {
                if *b == b'\n' {
                    let row_end = self.scan_pos + offset + 1;
                    self.last_row_end = Some(row_end);
                    if self.end_marker_end.is_none() {
                        // NOTE(review): any row *beginning* with `\.` is
                        // treated as the end-of-data marker here, not only a
                        // row that is exactly `\.` — confirm this is intended.
                        let row = &data[row_start..row_end];
                        if row.get(0..2) == Some(b"\\.") {
                            self.end_marker_end = Some(row_end);
                            break;
                        }
                    }
                    row_start = row_end;
                }
            }
        }

        self.scan_pos = data.len();
    }

    /// Offset one past the last complete row scanned so far, if any.
    fn last_row_end(&self) -> Option<usize> {
        self.last_row_end
    }

    /// Offset one past the `\.` end-of-data marker row, if one was seen.
    fn end_marker_end(&self) -> Option<usize> {
        self.end_marker_end
    }

    /// Number of buffered bytes belonging to the incomplete trailing row.
    fn current_row_size(&self, data_len: usize) -> usize {
        data_len.saturating_sub(self.last_row_end.unwrap_or(0))
    }

    /// Rebases tracked offsets after the first `split_pos` bytes of the
    /// buffer were removed.
    fn on_split(&mut self, split_pos: usize) {
        self.scan_pos = self.scan_pos.saturating_sub(split_pos);
        // Row ends recorded before the split no longer apply; they will be
        // rediscovered on the next scan of the remaining bytes.
        self.last_row_end = None;
        self.end_marker_end = self
            .end_marker_end
            .and_then(|end| end.checked_sub(split_pos));
    }

    /// Discards tracking state beyond `new_len` after the buffer was
    /// truncated to that length.
    fn on_truncate(&mut self, new_len: usize) {
        self.scan_pos = self.scan_pos.min(new_len);
        self.last_row_end = self.last_row_end.filter(|&end| end <= new_len);
        self.end_marker_end = self.end_marker_end.filter(|&end| end <= new_len);
    }
}
3368
3369impl CsvScanState {
3370    fn new(delimiter: u8, quote: u8, escape: u8, header: bool) -> Self {
3371        let (double_quote, escape) = if quote == escape {
3372            (true, None)
3373        } else {
3374            (false, Some(escape))
3375        };
3376        CsvScanState {
3377            reader: csv_core::ReaderBuilder::new()
3378                .delimiter(delimiter)
3379                .quote(quote)
3380                .double_quote(double_quote)
3381                .escape(escape)
3382                .build(),
3383            output: vec![0; 1],
3384            ends: vec![0; 1],
3385            record: Vec::new(),
3386            record_ends: Vec::new(),
3387            skip_first_record: header,
3388        }
3389    }
3390}
3391
#[cfg(test)]
mod test {
    use super::*;

    #[mz_ore::test]
    fn test_parse_options() {
        // (input, expected key/value pairs or Err(())).
        let cases: Vec<(&'static str, Result<Vec<(&'static str, &'static str)>, ()>)> = vec![
            ("", Ok(vec![])),
            ("--key", Err(())),
            ("--key=val", Ok(vec![("key", "val")])),
            (
                r#"--key=val -ckey2=val2 -c key3=val3 -c key4=val4 -ckey5=val5"#,
                Ok(vec![
                    ("key", "val"),
                    ("key2", "val2"),
                    ("key3", "val3"),
                    ("key4", "val4"),
                    ("key5", "val5"),
                ]),
            ),
            (r#"-c\ key=val"#, Ok(vec![(" key", "val")])),
            ("--key=val -ckey2 val2", Err(())),
            // Unclear what this should do.
            ("--key=", Ok(vec![("key", "")])),
        ];
        for (input, expect) in cases {
            let got = parse_options(input);
            let expect = expect.map(|pairs| {
                pairs
                    .into_iter()
                    .map(|(k, v)| (k.to_owned(), v.to_owned()))
                    .collect()
            });
            assert_eq!(got, expect, "input: {}", input);
        }
    }

    #[mz_ore::test]
    fn test_parse_option() {
        // (input, expected key/value pair or Err(())).
        let cases: Vec<(&'static str, Result<(&'static str, &'static str), ()>)> = vec![
            ("", Err(())),
            ("--", Err(())),
            ("--c", Err(())),
            ("a=b", Err(())),
            ("--a=b", Ok(("a", "b"))),
            ("--ca=b", Ok(("ca", "b"))),
            ("-ca=b", Ok(("a", "b"))),
            // Unclear what this should error, but at least test it.
            ("--=", Ok(("", ""))),
        ];
        for (input, expect) in cases {
            let got = parse_option(input);
            assert_eq!(got, expect, "input: {}", input);
        }
    }

    #[mz_ore::test]
    fn test_split_options() {
        // (input, expected whitespace-split tokens honoring `\` escapes).
        let cases: Vec<(&'static str, Vec<&'static str>)> = vec![
            ("", vec![]),
            ("  ", vec![]),
            (" a ", vec!["a"]),
            ("  ab     cd   ", vec!["ab", "cd"]),
            (r#"  ab\     cd   "#, vec!["ab ", "cd"]),
            (r#"  ab\\     cd   "#, vec![r#"ab\"#, "cd"]),
            (r#"  ab\\\     cd   "#, vec![r#"ab\ "#, "cd"]),
            (r#"  ab\\\ cd   "#, vec![r#"ab\ cd"#]),
            (r#"  ab\\\cd   "#, vec![r#"ab\cd"#]),
            (r#"a\"#, vec!["a"]),
            (r#"a\ "#, vec!["a "]),
            (r#"\"#, vec![]),
            (r#"\ "#, vec![r#" "#]),
            (r#" \ "#, vec![r#" "#]),
            (r#"\  "#, vec![r#" "#]),
        ];
        for (input, expect) in cases {
            let got = split_options(input);
            assert_eq!(got, expect, "input: {}", input);
        }
    }
}