// mz_pgwire/protocol.rs
1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::collections::BTreeMap;
11use std::convert::TryFrom;
12use std::future::Future;
13use std::ops::Deref;
14use std::sync::Arc;
15use std::time::{Duration, Instant};
16use std::{iter, mem};
17
18use base64::prelude::*;
19use byteorder::{ByteOrder, NetworkEndian};
20use csv_core::ReadRecordResult;
21use futures::future::{BoxFuture, FutureExt, pending};
22use itertools::Itertools;
23use mz_adapter::client::RecordFirstRowStream;
24use mz_adapter::session::{
25    EndTransactionAction, InProgressRows, LifecycleTimestamps, PortalRefMut, PortalState, Session,
26    SessionConfig, TransactionStatus,
27};
28use mz_adapter::statement_logging::{StatementEndedExecutionReason, StatementExecutionStrategy};
29use mz_adapter::{
30    AdapterError, AdapterNotice, ExecuteContextGuard, ExecuteResponse, PeekResponseUnary, metrics,
31    verify_datum_desc,
32};
33use mz_auth::Authenticated;
34use mz_auth::password::Password;
35use mz_authenticator::{Authenticator, GenericOidcAuthenticator};
36use mz_frontegg_auth::Authenticator as FronteggAuthenticator;
37use mz_ore::cast::CastFrom;
38use mz_ore::netio::AsyncReady;
39use mz_ore::now::{EpochMillis, SYSTEM_TIME};
40use mz_ore::str::StrExt;
41use mz_ore::{assert_none, assert_ok, instrument, soft_assert_eq_or_log};
42use mz_pgcopy::{CopyCsvFormatParams, CopyFormatParams, CopyTextFormatParams};
43use mz_pgwire_common::{
44    ConnectionCounter, Cursor, ErrorResponse, Format, FrontendMessage, Severity, VERSION_3,
45    VERSIONS,
46};
47use mz_repr::{
48    CatalogItemId, ColumnIndex, Datum, RelationDesc, RowArena, RowIterator, RowRef,
49    SqlRelationType, SqlScalarType,
50};
51use mz_server_core::TlsMode;
52use mz_server_core::listeners::{AllowedRoles, AuthenticatorKind};
53use mz_sql::ast::display::AstDisplay;
54use mz_sql::ast::{
55    CopyDirection, CopyStatement, CopyTarget, FetchDirection, Ident, Raw, Statement,
56};
57use mz_sql::parse::StatementParseResult;
58use mz_sql::plan::{CopyFormat, ExecuteTimeout, StatementDesc};
59use mz_sql::session::metadata::SessionMetadata;
60use mz_sql::session::user::INTERNAL_USER_NAMES;
61use mz_sql::session::vars::VarInput;
62use postgres::error::SqlState;
63use tokio::io::{self, AsyncRead, AsyncWrite};
64use tokio::select;
65use tokio::time::{self};
66use tokio_metrics::TaskMetrics;
67use tokio_stream::wrappers::UnboundedReceiverStream;
68use tracing::{Instrument, debug, debug_span, warn};
69use uuid::Uuid;
70
71use crate::codec::{
72    FramedConn, decode_password, decode_sasl_initial_response, decode_sasl_response,
73};
74use crate::message::{
75    self, BackendMessage, SASLServerFinalMessage, SASLServerFinalMessageKinds,
76    SASLServerFirstMessage,
77};
78
79/// Reports whether the given stream begins with a pgwire handshake.
80///
81/// To avoid false negatives, there must be at least eight bytes in `buf`.
82pub fn match_handshake(buf: &[u8]) -> bool {
83    // The pgwire StartupMessage looks like this:
84    //
85    //     i32 - Length of entire message.
86    //     i32 - Protocol version number.
87    //     [String] - Arbitrary key-value parameters of any length.
88    //
89    // Since arbitrary parameters can be included in the StartupMessage, the
90    // first Int32 is worthless, since the message could have any length.
91    // Instead, we sniff the protocol version number.
92    if buf.len() < 8 {
93        return false;
94    }
95    let version = NetworkEndian::read_i32(&buf[4..8]);
96    VERSIONS.contains(&version)
97}
98
/// Parameters for the [`run`] function.
pub struct RunParams<'a, A, I>
where
    I: Iterator<Item = TaskMetrics> + Send,
{
    /// The TLS mode of the pgwire server.
    pub tls_mode: Option<TlsMode>,
    /// A client for the adapter.
    pub adapter_client: mz_adapter::Client,
    /// The connection to the client.
    pub conn: &'a mut FramedConn<A>,
    /// The universally unique identifier for the connection.
    pub conn_uuid: Uuid,
    /// The protocol version that the client provided in the startup message.
    pub version: i32,
    /// The parameters that the client provided in the startup message.
    pub params: BTreeMap<String, String>,
    /// Frontegg JWT authenticator, if Frontegg authentication is configured.
    pub frontegg: Option<FronteggAuthenticator>,
    /// OIDC authenticator.
    pub oidc: GenericOidcAuthenticator,
    /// The authentication method defined by the server's listener
    /// configuration.
    pub authenticator_kind: AuthenticatorKind,
    /// Global connection limit and count.
    pub active_connection_counter: ConnectionCounter,
    /// Helm chart version, if deployed via the Helm chart.
    pub helm_chart_version: Option<String>,
    /// Which classes of roles are permitted to log in on this listener
    /// (e.g. whether reserved users like `mz_system` are allowed).
    pub allowed_roles: AllowedRoles,
    /// Tokio metrics sampling intervals for this connection's task.
    pub tokio_metrics_intervals: I,
}
132
/// Runs a pgwire connection to completion.
///
/// This involves responding to `FrontendMessage::StartupMessage` and all future
/// requests until the client terminates the connection or a fatal error occurs.
///
/// Note that this function returns successfully even upon delivering a fatal
/// error to the client. It only returns `Err` if an unexpected I/O error occurs
/// while communicating with the client, e.g., if the connection is severed in
/// the middle of a request.
#[mz_ore::instrument(level = "debug")]
pub async fn run<'a, A, I>(
    RunParams {
        tls_mode,
        adapter_client,
        conn,
        conn_uuid,
        version,
        mut params,
        frontegg,
        oidc,
        authenticator_kind,
        active_connection_counter,
        helm_chart_version,
        allowed_roles,
        tokio_metrics_intervals,
    }: RunParams<'a, A, I>,
) -> Result<(), io::Error>
where
    A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
    I: Iterator<Item = TaskMetrics> + Send,
{
    // Reject any protocol version other than version 3; we never speak
    // older wire protocol revisions.
    if version != VERSION_3 {
        return conn
            .send(ErrorResponse::fatal(
                SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
                "server does not support the client's requested protocol version",
            ))
            .await;
    }

    // Pull the username out of the startup parameters; remaining parameters
    // are applied as session settings after authentication.
    let user = params.remove("user").unwrap_or_else(String::new);
    let options = parse_options(params.get("options").unwrap_or(&String::new()));

    // If oidc_auth_enabled exists as an option, return its value and filter it from
    // the remaining options.
    let (oidc_auth_enabled, options) = extract_oidc_auth_enabled_from_options(options);
    let authenticator = get_authenticator(
        authenticator_kind,
        frontegg,
        oidc,
        adapter_client.clone(),
        oidc_auth_enabled,
    );
    // TODO move this somewhere it can be shared with HTTP
    let is_internal_user = INTERNAL_USER_NAMES.contains(&user);
    // this is a superset of internal users
    let is_reserved_user = mz_adapter::catalog::is_reserved_role_name(user.as_str());
    // Gate the login on the listener's role policy before doing any
    // authentication work.
    let role_allowed = match allowed_roles {
        AllowedRoles::Normal => !is_reserved_user,
        AllowedRoles::Internal => is_internal_user,
        AllowedRoles::NormalAndInternal => !is_reserved_user || is_internal_user,
    };
    if !role_allowed {
        let msg = format!("unauthorized login to user '{user}'");
        return conn
            .send(ErrorResponse::fatal(SqlState::INSUFFICIENT_PRIVILEGE, msg))
            .await;
    }

    if let Err(err) = conn.inner().ensure_tls_compatibility(&tls_mode) {
        return conn.send(err).await;
    }

    // Authenticate the user via whichever mechanism the listener is
    // configured for, producing the session plus a future that resolves when
    // the authentication expires (or never, for mechanisms without expiry).
    let (mut session, expired) = match authenticator {
        Authenticator::Frontegg(frontegg) => {
            let password = match request_cleartext_password(conn).await {
                Ok(password) => password,
                Err(PasswordRequestError::IoError(e)) => return Err(e),
                Err(PasswordRequestError::InvalidPasswordError(e)) => {
                    return conn.send(e).await;
                }
            };

            let auth_response = frontegg.authenticate(&user, &password).await;
            match auth_response {
                // Create a session based on the auth session.
                //
                // In particular, it's important that the username come from the
                // auth session, as Frontegg may return an email address with
                // different casing than the user supplied via the pgwire
                // username field.
                Ok((mut auth_session, authenticated)) => {
                    let session = adapter_client.new_session(
                        SessionConfig {
                            conn_id: conn.conn_id().clone(),
                            uuid: conn_uuid,
                            user: auth_session.user().into(),
                            client_ip: conn.peer_addr().clone(),
                            external_metadata_rx: Some(auth_session.external_metadata_rx()),
                            helm_chart_version,
                        },
                        authenticated,
                    );
                    let expired = async move { auth_session.expired().await };
                    (session, expired.left_future())
                }
                Err(err) => {
                    warn!(?err, "pgwire connection failed authentication");
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_PASSWORD,
                            "invalid password",
                        ))
                        .await;
                }
            }
        }
        Authenticator::Oidc(oidc) => {
            // OIDC authentication: JWT sent as password in cleartext flow
            let jwt = match request_cleartext_password(conn).await {
                Ok(password) => password,
                Err(PasswordRequestError::IoError(e)) => return Err(e),
                Err(PasswordRequestError::InvalidPasswordError(e)) => {
                    return conn.send(e).await;
                }
            };
            let auth_response = oidc.authenticate(&jwt, Some(&user)).await;
            match auth_response {
                Ok((claims, authenticated)) => {
                    let session = adapter_client.new_session(
                        SessionConfig {
                            conn_id: conn.conn_id().clone(),
                            uuid: conn_uuid,
                            user: claims.user,
                            client_ip: conn.peer_addr().clone(),
                            external_metadata_rx: None,
                            helm_chart_version,
                        },
                        authenticated,
                    );
                    // No invalidation of the auth session once authenticated,
                    // so auth session lasts indefinitely.
                    (session, pending().right_future())
                }
                Err(err) => {
                    warn!(?err, "pgwire connection failed authentication");
                    return conn.send(err.into_response()).await;
                }
            }
        }
        Authenticator::Password(adapter_client) => {
            let session = match authenticate_with_password(
                conn,
                &adapter_client,
                user,
                conn_uuid,
                helm_chart_version,
            )
            .await
            {
                Ok(session) => session,
                Err(PasswordRequestError::IoError(e)) => return Err(e),
                Err(PasswordRequestError::InvalidPasswordError(e)) => {
                    return conn.send(e).await;
                }
            };
            // No frontegg check, so auth session lasts indefinitely.
            (session, pending().right_future())
        }
        Authenticator::Sasl(adapter_client) => {
            // Start the handshake
            conn.send(BackendMessage::AuthenticationSASL).await?;
            conn.flush().await?;
            // Get the initial response indicating chosen mechanism
            let (mechanism, initial_response) = match conn.recv().await? {
                Some(FrontendMessage::RawAuthentication(data)) => {
                    match decode_sasl_initial_response(Cursor::new(&data)).ok() {
                        Some(FrontendMessage::SASLInitialResponse {
                            gs2_header,
                            mechanism,
                            initial_response,
                        }) => {
                            // We do not support channel binding
                            if gs2_header.channel_binding_enabled() {
                                return conn
                                    .send(ErrorResponse::fatal(
                                        SqlState::PROTOCOL_VIOLATION,
                                        "channel binding not supported",
                                    ))
                                    .await;
                            }
                            (mechanism, initial_response)
                        }
                        _ => {
                            return conn
                                .send(ErrorResponse::fatal(
                                    SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                                    "expected SASLInitialResponse message",
                                ))
                                .await;
                        }
                    }
                }
                _ => {
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                            "expected SASLInitialResponse message",
                        ))
                        .await;
                }
            };

            // SCRAM-SHA-256 is the only SASL mechanism we implement.
            if mechanism != "SCRAM-SHA-256" {
                return conn
                    .send(ErrorResponse::fatal(
                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                        "unsupported SASL mechanism",
                    ))
                    .await;
            }

            // Bound the client-supplied nonce to guard against abusive input.
            if initial_response.nonce.len() > 256 {
                return conn
                    .send(ErrorResponse::fatal(
                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                        "nonce too long",
                    ))
                    .await;
            }

            // Ask the adapter for the server-first message (combined nonce,
            // salt, iteration count) and send it as AuthenticationSASLContinue.
            let (server_first_message_raw, mock_hash) = match adapter_client
                .generate_sasl_challenge(&user, &initial_response.nonce)
                .await
            {
                Ok(response) => {
                    let server_first_message_raw = format!(
                        "r={},s={},i={}",
                        response.nonce, response.salt, response.iteration_count
                    );

                    // NOTE(review): client_key/server_key are fixed placeholder
                    // values used to build a mock stored hash — presumably so
                    // verification proceeds uniformly even when the user has no
                    // stored credentials; confirm against `verify_sasl_proof`.
                    let client_key = [0u8; 32];
                    let server_key = [1u8; 32];
                    let mock_hash = format!(
                        "SCRAM-SHA-256${}:{}${}:{}",
                        response.iteration_count,
                        response.salt,
                        BASE64_STANDARD.encode(client_key),
                        BASE64_STANDARD.encode(server_key)
                    );

                    conn.send(BackendMessage::AuthenticationSASLContinue(
                        SASLServerFirstMessage {
                            iteration_count: response.iteration_count,
                            nonce: response.nonce,
                            salt: response.salt,
                        },
                    ))
                    .await?;
                    conn.flush().await?;
                    (server_first_message_raw, mock_hash)
                }
                Err(e) => {
                    return conn.send(e.into_response(Severity::Fatal)).await;
                }
            };

            // Receive the client-final message, verify the proof, and send
            // the server verifier in AuthenticationSASLFinal.
            let authenticated = match conn.recv().await? {
                Some(FrontendMessage::RawAuthentication(data)) => {
                    match decode_sasl_response(Cursor::new(&data)).ok() {
                        Some(FrontendMessage::SASLResponse(response)) => {
                            // The SCRAM auth message is the concatenation of
                            // client-first-bare, server-first, and
                            // client-final-without-proof.
                            let auth_message = format!(
                                "{},{},{}",
                                initial_response.client_first_message_bare_raw,
                                server_first_message_raw,
                                response.client_final_message_bare_raw
                            );
                            if response.proof.len() > 1024 {
                                return conn
                                    .send(ErrorResponse::fatal(
                                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                                        "proof too long",
                                    ))
                                    .await;
                            }
                            match adapter_client
                                .verify_sasl_proof(
                                    &user,
                                    &response.proof,
                                    &auth_message,
                                    &mock_hash,
                                )
                                .await
                            {
                                Ok((proof_response, authenticated)) => {
                                    conn.send(BackendMessage::AuthenticationSASLFinal(
                                        SASLServerFinalMessage {
                                            kind: SASLServerFinalMessageKinds::Verifier(
                                                proof_response.verifier,
                                            ),
                                            extensions: vec![],
                                        },
                                    ))
                                    .await?;
                                    conn.flush().await?;
                                    authenticated
                                }
                                Err(_) => {
                                    return conn
                                        .send(ErrorResponse::fatal(
                                            SqlState::INVALID_PASSWORD,
                                            "invalid password",
                                        ))
                                        .await;
                                }
                            }
                        }
                        _ => {
                            return conn
                                .send(ErrorResponse::fatal(
                                    SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                                    "expected SASLResponse message",
                                ))
                                .await;
                        }
                    }
                }
                _ => {
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                            "expected SASLResponse message",
                        ))
                        .await;
                }
            };

            let session = adapter_client.new_session(
                SessionConfig {
                    conn_id: conn.conn_id().clone(),
                    uuid: conn_uuid,
                    user,
                    client_ip: conn.peer_addr().clone(),
                    external_metadata_rx: None,
                    helm_chart_version,
                },
                authenticated,
            );
            // No frontegg check, so auth session lasts indefinitely.
            let auth_session = pending().right_future();
            (session, auth_session)
        }

        Authenticator::None => {
            let session = adapter_client.new_session(
                SessionConfig {
                    conn_id: conn.conn_id().clone(),
                    uuid: conn_uuid,
                    user,
                    client_ip: conn.peer_addr().clone(),
                    external_metadata_rx: None,
                    helm_chart_version,
                },
                Authenticated,
            );
            // No frontegg check, so auth session lasts indefinitely.
            let auth_session = pending().right_future();
            (session, auth_session)
        }
    };

    // Apply the remaining startup parameters (and any `options` pairs) as
    // session variable settings, surfacing failures as notices rather than
    // hard errors.
    let system_vars = adapter_client.get_system_vars().await;
    for (name, value) in params {
        let settings = match name.as_str() {
            "options" => match &options {
                Ok(opts) => opts,
                Err(()) => {
                    session.add_notice(AdapterNotice::BadStartupSetting {
                        name,
                        reason: "could not parse".into(),
                    });
                    continue;
                }
            },
            _ => &vec![(name, value)],
        };
        for (key, val) in settings {
            const LOCAL: bool = false;
            // TODO: Issuing an error here is better than what we did before
            // (silently ignore errors on set), but erroring the connection
            // might be the better behavior. We maybe need to support more
            // options sent by psql and drivers before we can safely do this.
            if let Err(err) = session
                .vars_mut()
                .set(&system_vars, key, VarInput::Flat(val), LOCAL)
            {
                session.add_notice(AdapterNotice::BadStartupSetting {
                    name: key.clone(),
                    reason: err.to_string(),
                });
            }
        }
    }
    session
        .vars_mut()
        .end_transaction(EndTransactionAction::Commit);

    // Enforce the global connection limit; the guard releases the slot when
    // dropped at the end of the connection.
    let _guard = match active_connection_counter.allocate_connection(session.user()) {
        Ok(drop_connection) => drop_connection,
        Err(e) => {
            let e: AdapterError = e.into();
            return conn.send(e.into_response(Severity::Fatal)).await;
        }
    };

    // Register session with adapter.
    let mut adapter_client = match adapter_client.startup(session).await {
        Ok(adapter_client) => adapter_client,
        Err(e) => return conn.send(e.into_response(Severity::Fatal)).await,
    };

    // Send the post-authentication burst: AuthenticationOk, parameter
    // statuses, backend key data, pending notices, and ReadyForQuery.
    let mut buf = vec![BackendMessage::AuthenticationOk];
    for var in adapter_client.session().vars().notify_set() {
        buf.push(BackendMessage::ParameterStatus(var.name(), var.value()));
    }
    buf.push(BackendMessage::BackendKeyData {
        conn_id: adapter_client.session().conn_id().unhandled(),
        secret_key: adapter_client.session().secret_key(),
    });
    buf.extend(
        adapter_client
            .session()
            .drain_notices()
            .into_iter()
            .map(|notice| BackendMessage::ErrorResponse(notice.into_response())),
    );
    buf.push(BackendMessage::ReadyForQuery(
        adapter_client.session().transaction().into(),
    ));
    conn.send_all(buf).await?;
    conn.flush().await?;

    let machine = StateMachine {
        conn,
        adapter_client,
        txn_needs_commit: false,
        tokio_metrics_intervals,
    };

    // Drive the protocol state machine until it completes, or tear down the
    // connection if the authentication expires first.
    select! {
        r = machine.run() => {
            // Errors produced internally (like MAX_REQUEST_SIZE being exceeded) should send an
            // error to the client informing them why the connection was closed. We still want to
            // return the original error up the stack, though, so we skip error checking during conn
            // operations.
            if let Err(err) = &r {
                let _ = conn
                    .send(ErrorResponse::fatal(
                        SqlState::CONNECTION_FAILURE,
                        err.to_string(),
                    ))
                    .await;
                let _ = conn.flush().await;
            }
            r
        },
        _ = expired => {
            conn
                .send(ErrorResponse::fatal(SqlState::INVALID_AUTHORIZATION_SPECIFICATION, "authentication expired"))
                .await?;
            conn.flush().await
        }
    }
}
607
/// Gets `oidc_auth_enabled` from options if it exists.
///
/// Returns the `oidc_auth_enabled` value (defaulting to `false` when absent
/// or unparsable) alongside the options with any `oidc_auth_enabled` entries
/// removed. If the options failed to parse, they are returned unchanged.
fn extract_oidc_auth_enabled_from_options(
    options: Result<Vec<(String, String)>, ()>,
) -> (bool, Result<Vec<(String, String)>, ()>) {
    let opts = match options {
        Ok(opts) => opts,
        // Nothing to extract from unparsable options.
        Err(()) => return (false, Err(())),
    };

    let mut oidc_auth_enabled = false;
    // Retain everything except `oidc_auth_enabled` entries, remembering the
    // value of the last such entry seen.
    let remaining: Vec<(String, String)> = opts
        .into_iter()
        .filter(|(key, value)| {
            if key == "oidc_auth_enabled" {
                oidc_auth_enabled = value.parse::<bool>().unwrap_or(false);
                false
            } else {
                true
            }
        })
        .collect();

    (oidc_auth_enabled, Ok(remaining))
}
632
633/// Returns (name, value) session settings pairs from an options value.
634///
635/// From Postgres, see pg_split_opts in postinit.c and process_postgres_switches
636/// in postgres.c.
637fn parse_options(value: &str) -> Result<Vec<(String, String)>, ()> {
638    let opts = split_options(value);
639    let mut pairs = Vec::with_capacity(opts.len());
640    let mut seen_prefix = false;
641    for opt in opts {
642        if !seen_prefix {
643            if opt == "-c" {
644                seen_prefix = true;
645            } else {
646                let (key, val) = parse_option(&opt)?;
647                pairs.push((key.to_owned(), val.to_owned()));
648            }
649        } else {
650            let (key, val) = opt.split_once('=').ok_or(())?;
651            pairs.push((key.to_owned(), val.to_owned()));
652            seen_prefix = false;
653        }
654    }
655    Ok(pairs)
656}
657
/// Returns the parsed key and value from an option of the form `--key=value`,
/// `-c key=value`, or `-ckey=value`. Returns an error if the key carried some
/// other prefix (or no prefix at all), or if there was no `=` separator.
///
/// NOTE(review): unlike Postgres, no `-` → `_` normalization is performed on
/// the key — confirm whether that is intended.
fn parse_option(option: &str) -> Result<(&str, &str), ()> {
    let (key, value) = option.split_once('=').ok_or(())?;
    // Accept `-c` first, then `--`, mirroring the prefixes Postgres accepts.
    let key = key
        .strip_prefix("-c")
        .or_else(|| key.strip_prefix("--"))
        .ok_or(())?;
    Ok((key, value))
}
670
/// Splits value by any number of spaces except those preceded by `\`.
///
/// A backslash escapes the character that follows it (`\ ` yields a literal
/// space, `\\` a literal backslash; before any other character the backslash
/// is simply dropped). A trailing `\` is ignored. Runs of unescaped spaces
/// never produce empty tokens.
fn split_options(value: &str) -> Vec<String> {
    let mut parts = Vec::new();
    // We must build owned strings because escapes prevent simple subslicing,
    // and this path is cold enough that unconditional allocation is fine.
    let mut word = String::new();
    let mut escaped = false;
    for ch in value.chars() {
        if escaped {
            // The backslash neutralizes any special meaning of `ch`.
            word.push(ch);
            escaped = false;
        } else {
            match ch {
                ' ' => {
                    // Unescaped space: finish the current token, skipping
                    // empties so repeated spaces collapse.
                    if !word.is_empty() {
                        parts.push(std::mem::take(&mut word));
                    }
                }
                '\\' => escaped = true,
                other => word.push(other),
            }
        }
    }
    if !word.is_empty() {
        parts.push(word);
    }
    parts
}
713
/// Errors that can arise while requesting and validating a password from the
/// client during authentication.
enum PasswordRequestError {
    /// The client failed to supply a valid password message; carries the
    /// fatal error response to deliver to the client.
    InvalidPasswordError(ErrorResponse),
    /// An underlying I/O error occurred while communicating with the client.
    IoError(io::Error),
}
718
719impl From<io::Error> for PasswordRequestError {
720    fn from(e: io::Error) -> Self {
721        PasswordRequestError::IoError(e)
722    }
723}
724
725/// Requests a cleartext password from a connection and returns it if it is valid.
726/// Sends an error response in the connection if the password
727/// is not valid.
728async fn request_cleartext_password<A>(
729    conn: &mut FramedConn<A>,
730) -> Result<String, PasswordRequestError>
731where
732    A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
733{
734    conn.send(BackendMessage::AuthenticationCleartextPassword)
735        .await?;
736    conn.flush().await?;
737
738    if let Some(message) = conn.recv().await? {
739        if let FrontendMessage::RawAuthentication(data) = message {
740            if let Some(FrontendMessage::Password { password }) =
741                decode_password(Cursor::new(&data)).ok()
742            {
743                return Ok(password);
744            }
745        }
746    }
747
748    Err(PasswordRequestError::InvalidPasswordError(
749        ErrorResponse::fatal(
750            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
751            "expected Password message",
752        ),
753    ))
754}
755
756/// Helper for password-based authentication using AdapterClient
757/// and returns an authenticated session.
758async fn authenticate_with_password<A>(
759    conn: &mut FramedConn<A>,
760    adapter_client: &mz_adapter::Client,
761    user: String,
762    conn_uuid: Uuid,
763    helm_chart_version: Option<String>,
764) -> Result<Session, PasswordRequestError>
765where
766    A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
767{
768    let password = match request_cleartext_password(conn).await {
769        Ok(password) => Password(password),
770        Err(e) => return Err(e),
771    };
772
773    let authenticated = match adapter_client.authenticate(&user, &password).await {
774        Ok(authenticated) => authenticated,
775        Err(err) => {
776            warn!(?err, "pgwire connection failed authentication");
777            return Err(PasswordRequestError::InvalidPasswordError(
778                ErrorResponse::fatal(SqlState::INVALID_PASSWORD, "invalid password"),
779            ));
780        }
781    };
782
783    let session = adapter_client.new_session(
784        SessionConfig {
785            conn_id: conn.conn_id().clone(),
786            uuid: conn_uuid,
787            user,
788            client_ip: conn.peer_addr().clone(),
789            external_metadata_rx: None,
790            helm_chart_version,
791        },
792        authenticated,
793    );
794
795    Ok(session)
796}
797
/// The protocol state of a pgwire connection, advanced by `StateMachine::run`.
#[derive(Debug)]
enum State {
    /// Ready to receive and process the next frontend message.
    Ready,
    /// Discard incoming messages until a Sync message (or EOF) arrives.
    Drain,
    /// The connection is finished; the state machine loop exits.
    Done,
}
804
/// Drives the pgwire protocol for a single client connection.
struct StateMachine<'a, A, I>
where
    I: Iterator<Item = TaskMetrics> + Send + 'a,
{
    /// The framed pgwire connection to the client.
    conn: &'a mut FramedConn<A>,
    /// The adapter session handle backing this connection.
    adapter_client: mz_adapter::SessionClient,
    /// When true, the current implicit transaction is committed the next time
    /// `ensure_transaction` runs (set after extended-protocol Execute).
    txn_needs_commit: bool,
    /// Infinite source of tokio task metrics intervals, sampled around each
    /// `recv` to measure scheduling delay.
    tokio_metrics_intervals: I,
}
814
/// Describes how sending rows to the client ended. The variants mirror
/// `StatementEndedExecutionReason`.
enum SendRowsEndedReason {
    /// All rows were sent successfully.
    Success {
        // Total size of the result (presumably bytes — confirm against the
        // callers that populate this, outside this view).
        result_size: u64,
        // Number of rows returned to the client.
        rows_returned: u64,
    },
    /// Sending failed; `error` carries the message.
    Errored {
        error: String,
    },
    /// The operation was canceled before completion.
    Canceled,
}
825
/// Error text reported for any command other than COMMIT/ROLLBACK issued
/// while the current transaction is aborted.
const ABORTED_TXN_MSG: &str =
    "current transaction is aborted, commands ignored until end of transaction block";
828
829impl<'a, A, I> StateMachine<'a, A, I>
830where
831    A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin + 'a,
832    I: Iterator<Item = TaskMetrics> + Send + 'a,
833{
    // Manually desugar this (don't use `async fn run`) here because a much better
    // error message is produced if there are problems with Send or other traits
    // somewhere within the Future.
    /// Runs the protocol loop until the connection terminates (`State::Done`)
    /// or an I/O error occurs.
    #[allow(clippy::manual_async_fn)]
    #[mz_ore::instrument(level = "debug")]
    fn run(mut self) -> impl Future<Output = Result<(), io::Error>> + Send + 'a {
        async move {
            let mut state = State::Ready;
            loop {
                // Flush any notices the adapter queued before handling the
                // next message.
                self.send_pending_notices().await?;
                state = match state {
                    State::Ready => self.advance_ready().await?,
                    State::Drain => self.advance_drain().await?,
                    State::Done => return Ok(()),
                };
                // Arm the idle-in-transaction timeout while waiting for the
                // next client message; it is disarmed again once one arrives.
                self.adapter_client
                    .add_idle_in_transaction_session_timeout();
            }
        }
    }
854
    /// Receives and dispatches one frontend message while in the `Ready`
    /// state, returning the next protocol state. Also records per-message
    /// processing time and recv scheduling delay in metrics.
    #[instrument(level = "debug")]
    async fn advance_ready(&mut self) -> Result<State, io::Error> {
        // Start a new metrics interval before the `recv()` call.
        self.tokio_metrics_intervals
            .next()
            .expect("infinite iterator");

        // Handle timeouts first so we don't execute any statements when there's a pending timeout.
        let message = select! {
            biased;

            // `recv_timeout()` is cancel-safe as per its docs.
            Some(timeout) = self.adapter_client.recv_timeout() => {
                let err: AdapterError = timeout.into();
                let conn_id = self.adapter_client.session().conn_id();
                tracing::warn!("session timed out, conn_id {}", conn_id);

                // Process the error, doing any state cleanup.
                let error_response = err.into_response(Severity::Fatal);
                let error_state = self.send_error_and_get_state(error_response).await;

                // Terminate __after__ we do any cleanup.
                self.adapter_client.terminate().await;

                // We must wait for the client to send a request before we can send the error response.
                // Due to the PG wire protocol, we can't send an ErrorResponse unless it is in response
                // to a client message.
                let _ = self.conn.recv().await?;
                return error_state;
            },
            // `recv()` is cancel-safe as per its docs.
            message = self.conn.recv() => message?,
        };

        // Take the metrics since just before the `recv`.
        let interval = self
            .tokio_metrics_intervals
            .next()
            .expect("infinite iterator");
        let recv_scheduling_delay_ms = interval.total_scheduled_duration.as_secs_f64() * 1000.0;

        // TODO(ggevay): Consider subtracting the scheduling delay from `received`. It's not obvious
        // whether we should do this, because the result wouldn't exactly correspond to either first
        // byte received or last byte received (for msgs that arrive in more than one network packet).
        let received = SYSTEM_TIME();

        self.adapter_client
            .remove_idle_in_transaction_session_timeout();

        // NOTE(guswynn): we could consider adding spans to all message types. Currently
        // only a few message types seem useful.
        let message_name = message.as_ref().map(|m| m.name()).unwrap_or_default();

        let start = message.as_ref().map(|_| Instant::now());
        let next_state = match message {
            Some(FrontendMessage::Query { sql }) => {
                let query_root_span =
                    tracing::info_span!(parent: None, "advance_ready", otel.name = message_name);
                query_root_span.follows_from(tracing::Span::current());
                self.query(sql, received)
                    .instrument(query_root_span)
                    .await?
            }
            Some(FrontendMessage::Parse {
                name,
                sql,
                param_types,
            }) => self.parse(name, sql, param_types).await?,
            Some(FrontendMessage::Bind {
                portal_name,
                statement_name,
                param_formats,
                raw_params,
                result_formats,
            }) => {
                self.bind(
                    portal_name,
                    statement_name,
                    param_formats,
                    raw_params,
                    result_formats,
                )
                .await?
            }
            Some(FrontendMessage::Execute {
                portal_name,
                max_rows,
            }) => {
                let max_rows = match usize::try_from(max_rows) {
                    Ok(0) | Err(_) => ExecuteCount::All, // If `max_rows < 0`, no limit.
                    Ok(n) => ExecuteCount::Count(n),
                };
                let execute_root_span =
                    tracing::info_span!(parent: None, "advance_ready", otel.name = message_name);
                execute_root_span.follows_from(tracing::Span::current());
                let state = self
                    .execute(
                        portal_name,
                        max_rows,
                        portal_exec_message,
                        None,
                        ExecuteTimeout::None,
                        None,
                        Some(received),
                    )
                    .instrument(execute_root_span)
                    .await?;
                // In PostgreSQL, when using the extended query protocol, some statements may
                // trigger an eager commit of the current implicit transaction,
                // see: <https://git.postgresql.org/gitweb/?p=postgresql.git&a=commitdiff&h=f92944137>.
                //
                // In Materialize, however, we eagerly commit every statement outside of an explicit
                // transaction when using the extended query protocol. This allows us to eliminate
                // the possibility of a multiple statement implicit transaction, which in turn
                // allows us to apply single-statement optimizations to queries issued in implicit
                // transactions in the extended query protocol.
                //
                // We don't immediately commit here to allow users to page through the portal if
                // necessary. Committing the transaction would destroy the portal before the next
                // Execute command has a chance to resume it. So we instead mark the transaction
                // for commit the next time that `ensure_transaction` is called.
                if self.adapter_client.session().transaction().is_implicit() {
                    self.txn_needs_commit = true;
                }
                state
            }
            Some(FrontendMessage::DescribeStatement { name }) => {
                self.describe_statement(&name).await?
            }
            Some(FrontendMessage::DescribePortal { name }) => self.describe_portal(&name).await?,
            Some(FrontendMessage::CloseStatement { name }) => self.close_statement(name).await?,
            Some(FrontendMessage::ClosePortal { name }) => self.close_portal(name).await?,
            Some(FrontendMessage::Flush) => self.flush().await?,
            Some(FrontendMessage::Sync) => self.sync().await?,
            Some(FrontendMessage::Terminate) => State::Done,

            // Messages belonging to the COPY or authentication sub-protocols
            // are not valid here; drain until the next Sync.
            Some(FrontendMessage::CopyData(_))
            | Some(FrontendMessage::CopyDone)
            | Some(FrontendMessage::CopyFail(_))
            | Some(FrontendMessage::Password { .. })
            | Some(FrontendMessage::RawAuthentication(_))
            | Some(FrontendMessage::SASLInitialResponse { .. })
            | Some(FrontendMessage::SASLResponse(_)) => State::Drain,
            None => State::Done,
        };

        // Record processing time and recv scheduling delay, labeled by
        // message type.
        if let Some(start) = start {
            self.adapter_client
                .inner()
                .metrics()
                .pgwire_message_processing_seconds
                .with_label_values(&[message_name])
                .observe(start.elapsed().as_secs_f64());
        }
        self.adapter_client
            .inner()
            .metrics()
            .pgwire_recv_scheduling_delay_ms
            .with_label_values(&[message_name])
            .observe(recv_scheduling_delay_ms);

        Ok(next_state)
    }
1018
1019    async fn advance_drain(&mut self) -> Result<State, io::Error> {
1020        let message = self.conn.recv().await?;
1021        if message.is_some() {
1022            self.adapter_client
1023                .remove_idle_in_transaction_session_timeout();
1024        }
1025        match message {
1026            Some(FrontendMessage::Sync) => self.sync().await,
1027            None => Ok(State::Done),
1028            _ => Ok(State::Drain),
1029        }
1030    }
1031
    /// Note that `lifecycle_timestamps` belongs to the whole "Simple Query", because the whole
    /// Simple Query is received and parsed together. This means that if there are multiple
    /// statements in a Simple Query, then all of them have the same `lifecycle_timestamps`.
    ///
    /// Declares `stmt` on the unnamed portal, executes it, streams the
    /// response to the client, and destroys the portal.
    #[instrument(level = "debug")]
    async fn one_query(
        &mut self,
        stmt: Statement<Raw>,
        sql: String,
        lifecycle_timestamps: LifecycleTimestamps,
    ) -> Result<State, io::Error> {
        // Bind the portal. Note that this does not set the empty string prepared
        // statement.
        const EMPTY_PORTAL: &str = "";
        if let Err(e) = self
            .adapter_client
            .declare(EMPTY_PORTAL.to_string(), stmt, sql)
            .await
        {
            return self
                .send_error_and_get_state(e.into_response(Severity::Error))
                .await;
        }
        let portal = self
            .adapter_client
            .session()
            .get_portal_unverified_mut(EMPTY_PORTAL)
            .expect("unnamed portal should be present");

        // Stash the receipt timestamp on the portal.
        *portal.lifecycle_timestamps = Some(lifecycle_timestamps);

        // Simple queries have no Bind step, so a statement that requires
        // parameters cannot be executed here.
        let stmt_desc = portal.desc.clone();
        if !stmt_desc.param_types.is_empty() {
            return self
                .send_error_and_get_state(ErrorResponse::error(
                    SqlState::UNDEFINED_PARAMETER,
                    "there is no parameter $1",
                ))
                .await;
        }

        // Maybe send row description.
        if let Some(relation_desc) = &stmt_desc.relation_desc {
            if !stmt_desc.is_copy {
                let formats = vec![Format::Text; stmt_desc.arity()];
                self.send(BackendMessage::RowDescription(
                    message::encode_row_description(relation_desc, &formats),
                ))
                .await?;
            }
        }

        // Execute the portal; pending notices are flushed before the
        // response (or error) is sent in either case.
        let result = match self
            .adapter_client
            .execute(EMPTY_PORTAL.to_string(), self.conn.wait_closed(), None)
            .await
        {
            Ok((response, execute_started)) => {
                self.send_pending_notices().await?;
                self.send_execute_response(
                    response,
                    stmt_desc.relation_desc,
                    EMPTY_PORTAL.to_string(),
                    ExecuteCount::All,
                    portal_exec_message,
                    None,
                    ExecuteTimeout::None,
                    execute_started,
                )
                .await
            }
            Err(e) => {
                self.send_pending_notices().await?;
                self.send_error_and_get_state(e.into_response(Severity::Error))
                    .await
            }
        };

        // Destroy the portal.
        self.adapter_client.session().remove_portal(EMPTY_PORTAL);

        result
    }
1114
    /// Starts a transaction covering `num_stmts` statements if one is not
    /// already in progress, first committing any transaction that was marked
    /// for eager commit. `message_type` labels the latency metric.
    async fn ensure_transaction(
        &mut self,
        num_stmts: usize,
        message_type: &str,
    ) -> Result<(), io::Error> {
        let start = Instant::now();
        // An earlier extended-protocol Execute may have marked the implicit
        // transaction for commit; perform that commit before proceeding.
        if self.txn_needs_commit {
            self.commit_transaction().await?;
        }
        // `start_transaction` can't error (but assert that, just in case it
        // changes in the future).
        let res = self.adapter_client.start_transaction(Some(num_stmts));
        assert_ok!(res);
        self.adapter_client
            .inner()
            .metrics()
            .pgwire_ensure_transaction_seconds
            .with_label_values(&[message_type])
            .observe(start.elapsed().as_secs_f64());
        Ok(())
    }
1136
1137    fn parse_sql<'b>(&self, sql: &'b str) -> Result<Vec<StatementParseResult<'b>>, ErrorResponse> {
1138        let parse_start = Instant::now();
1139        let result = match self.adapter_client.parse(sql) {
1140            Ok(result) => result.map_err(|e| {
1141                // Convert our 0-based byte position to pgwire's 1-based character
1142                // position.
1143                let pos = sql[..e.error.pos].chars().count() + 1;
1144                ErrorResponse::error(SqlState::SYNTAX_ERROR, e.error.message).with_position(pos)
1145            }),
1146            Err(msg) => Err(ErrorResponse::error(SqlState::PROGRAM_LIMIT_EXCEEDED, msg)),
1147        };
1148        self.adapter_client
1149            .inner()
1150            .metrics()
1151            .parse_seconds
1152            .observe(parse_start.elapsed().as_secs_f64());
1153        result
1154    }
1155
    /// Executes a "Simple Query", see
    /// <https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-SIMPLE-QUERY>
    ///
    /// For implicit transaction handling, see "Multiple Statements in a Simple Query" in the above.
    #[instrument(level = "debug")]
    async fn query(&mut self, sql: String, received: EpochMillis) -> Result<State, io::Error> {
        // Parse first before doing any transaction checking.
        let stmts = match self.parse_sql(&sql) {
            Ok(stmts) => stmts,
            Err(err) => {
                self.send_error_and_get_state(err).await?;
                return self.ready().await;
            }
        };

        let num_stmts = stmts.len();

        // Compare with postgres' backend/tcop/postgres.c exec_simple_query.
        for StatementParseResult { ast: stmt, sql } in stmts {
            // In an aborted transaction, reject all commands except COMMIT/ROLLBACK.
            if self.is_aborted_txn() && !is_txn_exit_stmt(Some(&stmt)) {
                self.aborted_txn_error().await?;
                break;
            }

            // Start an implicit transaction if we aren't in any transaction and there's
            // more than one statement. This mirrors the `use_implicit_block` variable in
            // postgres.
            //
            // This needs to be done in the loop instead of once at the top because
            // a COMMIT/ROLLBACK statement needs to start a new transaction on next
            // statement.
            self.ensure_transaction(num_stmts, "query").await?;

            // Every statement in the simple query shares the same receipt
            // timestamp (see `one_query` docs).
            match self
                .one_query(stmt, sql.to_string(), LifecycleTimestamps { received })
                .await?
            {
                State::Ready => (),
                State::Drain => break,
                State::Done => return Ok(State::Done),
            }
        }

        // Implicit transactions are closed at the end of a Query message.
        {
            if self.adapter_client.session().transaction().is_implicit() {
                self.commit_transaction().await?;
            }
        }

        // An empty query string still gets a response.
        if num_stmts == 0 {
            self.send(BackendMessage::EmptyQueryResponse).await?;
        }

        self.ready().await
    }
1213
    /// Handles an extended-protocol Parse message: resolves the
    /// client-supplied parameter type OIDs, parses the SQL (which must
    /// contain at most one statement), and stores the result as a prepared
    /// statement under `name`.
    #[instrument(level = "debug")]
    async fn parse(
        &mut self,
        name: String,
        sql: String,
        param_oids: Vec<u32>,
    ) -> Result<State, io::Error> {
        // Start a transaction if we aren't in one.
        self.ensure_transaction(1, "parse").await?;

        // Map each OID to an optional scalar type; unknown OIDs are a
        // protocol violation.
        let mut param_types = vec![];
        for oid in param_oids {
            match mz_pgrepr::Type::from_oid(oid) {
                Ok(ty) => match SqlScalarType::try_from(&ty) {
                    Ok(ty) => param_types.push(Some(ty)),
                    Err(err) => {
                        return self
                            .send_error_and_get_state(ErrorResponse::error(
                                SqlState::INVALID_PARAMETER_VALUE,
                                err.to_string(),
                            ))
                            .await;
                    }
                },
                // An OID of zero leaves the parameter's type unspecified.
                Err(_) if oid == 0 => param_types.push(None),
                Err(e) => {
                    return self
                        .send_error_and_get_state(ErrorResponse::error(
                            SqlState::PROTOCOL_VIOLATION,
                            e.to_string(),
                        ))
                        .await;
                }
            }
        }

        let stmts = match self.parse_sql(&sql) {
            Ok(stmts) => stmts,
            Err(err) => {
                return self.send_error_and_get_state(err).await;
            }
        };
        if stmts.len() > 1 {
            return self
                .send_error_and_get_state(ErrorResponse::error(
                    SqlState::INTERNAL_ERROR,
                    "cannot insert multiple commands into a prepared statement",
                ))
                .await;
        }
        // An empty Parse message is allowed and prepares an empty statement.
        let (maybe_stmt, sql) = match stmts.into_iter().next() {
            None => (None, ""),
            Some(StatementParseResult { ast, sql }) => (Some(ast), sql),
        };
        // In an aborted transaction, reject everything except COMMIT/ROLLBACK.
        if self.is_aborted_txn() && !is_txn_exit_stmt(maybe_stmt.as_ref()) {
            return self.aborted_txn_error().await;
        }
        match self
            .adapter_client
            .prepare(name, maybe_stmt, sql.to_string(), param_types)
            .await
        {
            Ok(()) => {
                self.send(BackendMessage::ParseComplete).await?;
                Ok(State::Ready)
            }
            Err(e) => {
                self.send_error_and_get_state(e.into_response(Severity::Error))
                    .await
            }
        }
    }
1286
    /// Commits and clears the current transaction.
    ///
    /// Any adapter error is reported to the client as an error response
    /// rather than returned (see `end_transaction`).
    #[instrument(level = "debug")]
    async fn commit_transaction(&mut self) -> Result<(), io::Error> {
        self.end_transaction(EndTransactionAction::Commit).await
    }
1292
    /// Rolls back and clears the current transaction.
    ///
    /// Any adapter error is reported to the client as an error response
    /// rather than returned (see `end_transaction`).
    #[instrument(level = "debug")]
    async fn rollback_transaction(&mut self) -> Result<(), io::Error> {
        self.end_transaction(EndTransactionAction::Rollback).await
    }
1298
1299    /// End a transaction and report to the user if an error occurred.
1300    #[instrument(level = "debug")]
1301    async fn end_transaction(&mut self, action: EndTransactionAction) -> Result<(), io::Error> {
1302        self.txn_needs_commit = false;
1303        let resp = self.adapter_client.end_transaction(action).await;
1304        if let Err(err) = resp {
1305            self.send(BackendMessage::ErrorResponse(
1306                err.into_response(Severity::Error),
1307            ))
1308            .await?;
1309        }
1310        Ok(())
1311    }
1312
    /// Handles an extended-protocol Bind message: decodes the supplied
    /// parameter values for the prepared statement `statement_name` and
    /// creates a portal named `portal_name` ready for execution.
    #[instrument(level = "debug")]
    async fn bind(
        &mut self,
        portal_name: String,
        statement_name: String,
        param_formats: Vec<Format>,
        raw_params: Vec<Option<Vec<u8>>>,
        result_formats: Vec<Format>,
    ) -> Result<State, io::Error> {
        // Start a transaction if we aren't in one.
        self.ensure_transaction(1, "bind").await?;

        let aborted_txn = self.is_aborted_txn();
        let stmt = match self
            .adapter_client
            .get_prepared_statement(&statement_name)
            .await
        {
            Ok(stmt) => stmt,
            Err(err) => {
                return self
                    .send_error_and_get_state(err.into_response(Severity::Error))
                    .await;
            }
        };

        // The client must supply exactly one value per declared parameter.
        let param_types = &stmt.desc().param_types;
        if param_types.len() != raw_params.len() {
            let message = format!(
                "bind message supplies {actual} parameters, \
                 but prepared statement \"{name}\" requires {expected}",
                name = statement_name,
                actual = raw_params.len(),
                expected = param_types.len()
            );
            return self
                .send_error_and_get_state(ErrorResponse::error(
                    SqlState::PROTOCOL_VIOLATION,
                    message,
                ))
                .await;
        }
        let param_formats = match pad_formats(param_formats, raw_params.len()) {
            Ok(param_formats) => param_formats,
            Err(msg) => {
                return self
                    .send_error_and_get_state(ErrorResponse::error(
                        SqlState::PROTOCOL_VIOLATION,
                        msg,
                    ))
                    .await;
            }
        };
        // In an aborted transaction, only COMMIT/ROLLBACK may be bound.
        if aborted_txn && !is_txn_exit_stmt(stmt.stmt()) {
            return self.aborted_txn_error().await;
        }
        // Decode each raw parameter into a datum according to its declared
        // type and wire format; the arena owns the decoded values.
        let buf = RowArena::new();
        let mut params = vec![];
        for ((raw_param, mz_typ), format) in raw_params
            .into_iter()
            .zip_eq(param_types)
            .zip_eq(param_formats)
        {
            let pg_typ = mz_pgrepr::Type::from(mz_typ);
            let datum = match raw_param {
                None => Datum::Null,
                Some(bytes) => match mz_pgrepr::Value::decode(format, &pg_typ, &bytes) {
                    Ok(param) => match param.into_datum_decode_error(&buf, &pg_typ, "parameter") {
                        Ok(datum) => datum,
                        Err(msg) => {
                            return self
                                .send_error_and_get_state(ErrorResponse::error(
                                    SqlState::INVALID_PARAMETER_VALUE,
                                    msg,
                                ))
                                .await;
                        }
                    },
                    Err(err) => {
                        let msg = format!("unable to decode parameter: {}", err);
                        return self
                            .send_error_and_get_state(ErrorResponse::error(
                                SqlState::INVALID_PARAMETER_VALUE,
                                msg,
                            ))
                            .await;
                    }
                },
            };
            params.push((datum, mz_typ.clone()))
        }

        // Pad the result formats out to one per output column.
        let result_formats = match pad_formats(
            result_formats,
            stmt.desc()
                .relation_desc
                .clone()
                .map(|desc| desc.typ().column_types.len())
                .unwrap_or(0),
        ) {
            Ok(result_formats) => result_formats,
            Err(msg) => {
                return self
                    .send_error_and_get_state(ErrorResponse::error(
                        SqlState::PROTOCOL_VIOLATION,
                        msg,
                    ))
                    .await;
            }
        };

        // Binary encodings are disabled for list, map, and aclitem types, but this doesn't
        // apply to COPY TO statements.
        if !stmt.stmt().map_or(false, |stmt| match stmt {
            Statement::Copy(CopyStatement {
                direction: CopyDirection::To,
                ..
            }) => true,
            Statement::Copy(CopyStatement {
                direction: CopyDirection::From,
                // To be conservative, we are restricting COPY FROM to only allow list/map/aclitem types if it is not
                // copying from STDIN. It is likely that this works in theory, but is risky and likely to OOM anyways
                // as all the data will be held in a buffer in memory before being processed.
                target: CopyTarget::Expr(_),
                ..
            }) => true,
            _ => false,
        }) {
            if let Some(desc) = stmt.desc().relation_desc.clone() {
                for (format, ty) in result_formats.iter().zip_eq(desc.iter_types()) {
                    match (format, &ty.scalar_type) {
                        (Format::Binary, mz_repr::SqlScalarType::List { .. }) => {
                            return self
                                .send_error_and_get_state(ErrorResponse::error(
                                    SqlState::PROTOCOL_VIOLATION,
                                    "binary encoding of list types is not implemented",
                                ))
                                .await;
                        }
                        (Format::Binary, mz_repr::SqlScalarType::Map { .. }) => {
                            return self
                                .send_error_and_get_state(ErrorResponse::error(
                                    SqlState::PROTOCOL_VIOLATION,
                                    "binary encoding of map types is not implemented",
                                ))
                                .await;
                        }
                        (Format::Binary, mz_repr::SqlScalarType::AclItem) => {
                            return self
                                .send_error_and_get_state(ErrorResponse::error(
                                    SqlState::PROTOCOL_VIOLATION,
                                    "binary encoding of aclitem types does not exist",
                                ))
                                .await;
                        }
                        _ => (),
                    }
                }
            }
        }

        // Install the portal on the session; failure (e.g. stale statement)
        // is reported to the client.
        let desc = stmt.desc().clone();
        let logging = Arc::clone(stmt.logging());
        let stmt_ast = stmt.stmt().cloned();
        let state_revision = stmt.state_revision;
        if let Err(err) = self.adapter_client.session().set_portal(
            portal_name,
            desc,
            stmt_ast,
            logging,
            params,
            result_formats,
            state_revision,
        ) {
            return self
                .send_error_and_get_state(err.into_response(Severity::Error))
                .await;
        }

        self.send(BackendMessage::BindComplete).await?;
        Ok(State::Ready)
    }
1495
    /// `outer_ctx_extra` is Some when we are executing as part of an outer statement, e.g., a FETCH
    /// triggering the execution of the underlying query.
    ///
    /// Drives execution of the portal named `portal_name`: starts it if it has
    /// not yet run, resumes it if rows are still pending, or replays its saved
    /// command tag if it already completed. Returns the next protocol [`State`],
    /// or an [`io::Error`] if communication with the client fails.
    ///
    /// Returns a boxed future rather than using `async fn` — the call graph can
    /// re-enter `execute` (FETCH handling calls back into it), which requires
    /// type erasure to avoid an infinitely-sized future type.
    fn execute(
        &mut self,
        portal_name: String,
        max_rows: ExecuteCount,
        get_response: GetResponse,
        fetch_portal_name: Option<String>,
        timeout: ExecuteTimeout,
        outer_ctx_extra: Option<ExecuteContextGuard>,
        received: Option<EpochMillis>,
    ) -> BoxFuture<'_, Result<State, io::Error>> {
        async move {
            let aborted_txn = self.is_aborted_txn();

            // Check if the portal has been started and can be continued.
            let portal = match self
                .adapter_client
                .session()
                .get_portal_unverified_mut(&portal_name)
            {
                Some(portal) => portal,
                None => {
                    let msg = format!("portal {} does not exist", portal_name.quoted());
                    // If this execute was driven by an outer statement (e.g. a
                    // FETCH), its statement-logging context must be retired
                    // before we report the error.
                    if let Some(outer_ctx_extra) = outer_ctx_extra {
                        self.adapter_client.retire_execute(
                            outer_ctx_extra,
                            StatementEndedExecutionReason::Errored { error: msg.clone() },
                        );
                    }
                    return self
                        .send_error_and_get_state(ErrorResponse::error(
                            SqlState::INVALID_CURSOR_NAME,
                            msg,
                        ))
                        .await;
                }
            };

            // Record when this Execute message was received, for statement
            // lifecycle logging (None clears any previous timestamps).
            *portal.lifecycle_timestamps = received.map(LifecycleTimestamps::new);

            // In an aborted transaction, reject all commands except COMMIT/ROLLBACK.
            let txn_exit_stmt = is_txn_exit_stmt(portal.stmt.as_deref());
            if aborted_txn && !txn_exit_stmt {
                if let Some(outer_ctx_extra) = outer_ctx_extra {
                    self.adapter_client.retire_execute(
                        outer_ctx_extra,
                        StatementEndedExecutionReason::Errored {
                            error: ABORTED_TXN_MSG.to_string(),
                        },
                    );
                }
                return self.aborted_txn_error().await;
            }

            // Clone the row description now: the adapter call below consumes
            // our borrow of the portal.
            let row_desc = portal.desc.relation_desc.clone();
            match portal.state {
                PortalState::NotStarted => {
                    // Start a transaction if we aren't in one.
                    self.ensure_transaction(1, "execute").await?;
                    match self
                        .adapter_client
                        .execute(
                            portal_name.clone(),
                            self.conn.wait_closed(),
                            outer_ctx_extra,
                        )
                        .await
                    {
                        Ok((response, execute_started)) => {
                            self.send_pending_notices().await?;
                            self.send_execute_response(
                                response,
                                row_desc,
                                portal_name,
                                max_rows,
                                get_response,
                                fetch_portal_name,
                                timeout,
                                execute_started,
                            )
                            .await
                        }
                        Err(e) => {
                            self.send_pending_notices().await?;
                            self.send_error_and_get_state(e.into_response(Severity::Error))
                                .await
                        }
                    }
                }
                PortalState::InProgress(rows) => {
                    // The portal was started by an earlier Execute and still
                    // has rows buffered; continue streaming them.
                    let rows = rows.take().expect("InProgress rows must be populated");
                    let (result, statement_ended_execution_reason) = match self
                        .send_rows(
                            row_desc.expect("portal missing row desc on resumption"),
                            portal_name,
                            rows,
                            max_rows,
                            get_response,
                            fetch_portal_name,
                            timeout,
                        )
                        .await
                    {
                        Err(e) => {
                            // This is an error communicating with the connection.
                            // We consider that to be a cancelation, rather than a query error.
                            (Err(e), StatementEndedExecutionReason::Canceled)
                        }
                        Ok((ok, SendRowsEndedReason::Canceled)) => {
                            (Ok(ok), StatementEndedExecutionReason::Canceled)
                        }
                        // NOTE: For now the values for `result_size` and
                        // `rows_returned` in fetches are a bit confusing.
                        // We record `Some(n)` for the first fetch, where `n` is
                        // the number of bytes/rows returned by the inner
                        // execute (regardless of how many rows the
                        // fetch fetched), and `None` for subsequent fetches.
                        //
                        // This arguably makes sense since the size/rows
                        // returned measures how much work the compute
                        // layer had to do to satisfy the query, but
                        // we should revisit it if/when we start
                        // logging the inner execute separately.
                        Ok((
                            ok,
                            SendRowsEndedReason::Success {
                                result_size: _,
                                rows_returned: _,
                            },
                        )) => (
                            Ok(ok),
                            StatementEndedExecutionReason::Success {
                                result_size: None,
                                rows_returned: None,
                                execution_strategy: None,
                            },
                        ),
                        Ok((ok, SendRowsEndedReason::Errored { error })) => {
                            (Ok(ok), StatementEndedExecutionReason::Errored { error })
                        }
                    };
                    if let Some(outer_ctx_extra) = outer_ctx_extra {
                        self.adapter_client
                            .retire_execute(outer_ctx_extra, statement_ended_execution_reason);
                    }
                    result
                }
                // FETCH is an awkward command for our current architecture. In Postgres it
                // will extract <count> rows from the target portal, cache them, and return
                // them to the user as requested. Its command tag is always FETCH <num rows
                // extracted>. In Materialize, since we have chosen to not fully support FETCH,
                // we must remember the number of rows that were returned. Use this tag to
                // remember that information and return it.
                PortalState::Completed(Some(tag)) => {
                    let tag = tag.to_string();
                    if let Some(outer_ctx_extra) = outer_ctx_extra {
                        self.adapter_client.retire_execute(
                            outer_ctx_extra,
                            StatementEndedExecutionReason::Success {
                                result_size: None,
                                rows_returned: None,
                                execution_strategy: None,
                            },
                        );
                    }
                    self.send(BackendMessage::CommandComplete { tag }).await?;
                    Ok(State::Ready)
                }
                // A completed portal with no saved tag cannot be re-executed.
                PortalState::Completed(None) => {
                    let error = format!(
                        "portal {} cannot be run",
                        Ident::new_unchecked(portal_name).to_ast_string_stable()
                    );
                    if let Some(outer_ctx_extra) = outer_ctx_extra {
                        self.adapter_client.retire_execute(
                            outer_ctx_extra,
                            StatementEndedExecutionReason::Errored {
                                error: error.clone(),
                            },
                        );
                    }
                    self.send_error_and_get_state(ErrorResponse::error(
                        SqlState::OBJECT_NOT_IN_PREREQUISITE_STATE,
                        error,
                    ))
                    .await
                }
            }
        }
        .instrument(debug_span!("execute"))
        .boxed()
    }
1689
1690    #[instrument(level = "debug")]
1691    async fn describe_statement(&mut self, name: &str) -> Result<State, io::Error> {
1692        // Start a transaction if we aren't in one.
1693        self.ensure_transaction(1, "describe_statement").await?;
1694
1695        let stmt = match self.adapter_client.get_prepared_statement(name).await {
1696            Ok(stmt) => stmt,
1697            Err(err) => {
1698                return self
1699                    .send_error_and_get_state(err.into_response(Severity::Error))
1700                    .await;
1701            }
1702        };
1703        // Cloning to avoid a mutable borrow issue because `send` also uses `adapter_client`
1704        let parameter_desc = BackendMessage::ParameterDescription(
1705            stmt.desc()
1706                .param_types
1707                .iter()
1708                .map(mz_pgrepr::Type::from)
1709                .collect(),
1710        );
1711        // Claim that all results will be output in text format, even
1712        // though the true result formats are not yet known. A bit
1713        // weird, but this is the behavior that PostgreSQL specifies.
1714        let formats = vec![Format::Text; stmt.desc().arity()];
1715        let row_desc = describe_rows(stmt.desc(), &formats);
1716        self.send_all([parameter_desc, row_desc]).await?;
1717        Ok(State::Ready)
1718    }
1719
1720    #[instrument(level = "debug")]
1721    async fn describe_portal(&mut self, name: &str) -> Result<State, io::Error> {
1722        // Start a transaction if we aren't in one.
1723        self.ensure_transaction(1, "describe_portal").await?;
1724
1725        let session = self.adapter_client.session();
1726        let row_desc = session
1727            .get_portal_unverified(name)
1728            .map(|portal| describe_rows(&portal.desc, &portal.result_formats));
1729        match row_desc {
1730            Some(row_desc) => {
1731                self.send(row_desc).await?;
1732                Ok(State::Ready)
1733            }
1734            None => {
1735                self.send_error_and_get_state(ErrorResponse::error(
1736                    SqlState::INVALID_CURSOR_NAME,
1737                    format!("portal {} does not exist", name.quoted()),
1738                ))
1739                .await
1740            }
1741        }
1742    }
1743
1744    #[instrument(level = "debug")]
1745    async fn close_statement(&mut self, name: String) -> Result<State, io::Error> {
1746        self.adapter_client
1747            .session()
1748            .remove_prepared_statement(&name);
1749        self.send(BackendMessage::CloseComplete).await?;
1750        Ok(State::Ready)
1751    }
1752
1753    #[instrument(level = "debug")]
1754    async fn close_portal(&mut self, name: String) -> Result<State, io::Error> {
1755        self.adapter_client.session().remove_portal(&name);
1756        self.send(BackendMessage::CloseComplete).await?;
1757        Ok(State::Ready)
1758    }
1759
1760    fn complete_portal(&mut self, name: &str) {
1761        let portal = self
1762            .adapter_client
1763            .session()
1764            .get_portal_unverified_mut(name)
1765            .expect("portal should exist");
1766        *portal.state = PortalState::Completed(None);
1767    }
1768
1769    async fn fetch(
1770        &mut self,
1771        name: String,
1772        count: Option<FetchDirection>,
1773        max_rows: ExecuteCount,
1774        fetch_portal_name: Option<String>,
1775        timeout: ExecuteTimeout,
1776        ctx_extra: ExecuteContextGuard,
1777    ) -> Result<State, io::Error> {
1778        // Unlike Execute, no count specified in FETCH returns 1 row, and 0 means 0
1779        // instead of All.
1780        let count = count.unwrap_or(FetchDirection::ForwardCount(1));
1781
1782        // Figure out how many rows we should send back by looking at the various
1783        // combinations of the execute and fetch.
1784        //
1785        // In Postgres, Fetch will cache <count> rows from the target portal and
1786        // return those as requested (if, say, an Execute message was sent with a
1787        // max_rows < the Fetch's count). We expect that case to be incredibly rare and
1788        // so have chosen to not support it until users request it. This eases
1789        // implementation difficulty since we don't have to be able to "send" rows to
1790        // a buffer.
1791        //
1792        // TODO(mjibson): Test this somehow? Need to divide up the pgtest files in
1793        // order to have some that are not Postgres compatible.
1794        let count = match (max_rows, count) {
1795            (ExecuteCount::Count(max_rows), FetchDirection::ForwardCount(count)) => {
1796                let count = usize::cast_from(count);
1797                if max_rows < count {
1798                    let msg = "Execute with max_rows < a FETCH's count is not supported";
1799                    self.adapter_client.retire_execute(
1800                        ctx_extra,
1801                        StatementEndedExecutionReason::Errored {
1802                            error: msg.to_string(),
1803                        },
1804                    );
1805                    return self
1806                        .send_error_and_get_state(ErrorResponse::error(
1807                            SqlState::FEATURE_NOT_SUPPORTED,
1808                            msg,
1809                        ))
1810                        .await;
1811                }
1812                ExecuteCount::Count(count)
1813            }
1814            (ExecuteCount::Count(_), FetchDirection::ForwardAll) => {
1815                let msg = "Execute with max_rows of a FETCH ALL is not supported";
1816                self.adapter_client.retire_execute(
1817                    ctx_extra,
1818                    StatementEndedExecutionReason::Errored {
1819                        error: msg.to_string(),
1820                    },
1821                );
1822                return self
1823                    .send_error_and_get_state(ErrorResponse::error(
1824                        SqlState::FEATURE_NOT_SUPPORTED,
1825                        msg,
1826                    ))
1827                    .await;
1828            }
1829            (ExecuteCount::All, FetchDirection::ForwardAll) => ExecuteCount::All,
1830            (ExecuteCount::All, FetchDirection::ForwardCount(count)) => {
1831                ExecuteCount::Count(usize::cast_from(count))
1832            }
1833        };
1834        let cursor_name = name.to_string();
1835        self.execute(
1836            cursor_name,
1837            count,
1838            fetch_message,
1839            fetch_portal_name,
1840            timeout,
1841            Some(ctx_extra),
1842            None,
1843        )
1844        .await
1845    }
1846
1847    async fn flush(&mut self) -> Result<State, io::Error> {
1848        self.conn.flush().await?;
1849        Ok(State::Ready)
1850    }
1851
1852    /// Sends a backend message to the client, after applying a severity filter.
1853    ///
1854    /// The message is only sent if its severity is above the severity set
1855    /// in the session, with the default value being NOTICE.
1856    #[instrument(level = "debug")]
1857    async fn send<M>(&mut self, message: M) -> Result<(), io::Error>
1858    where
1859        M: Into<BackendMessage>,
1860    {
1861        let message: BackendMessage = message.into();
1862        let is_error =
1863            matches!(&message, BackendMessage::ErrorResponse(e) if e.severity.is_error());
1864
1865        self.conn.send(message).await?;
1866
1867        // Flush immediately after sending an error response, as some clients
1868        // expect to be able to read the error response before sending a Sync
1869        // message. This is arguably in violation of the protocol specification,
1870        // but the specification is somewhat ambiguous, and easier to match
1871        // PostgreSQL here than to fix all the clients that have this
1872        // expectation.
1873        if is_error {
1874            self.conn.flush().await?;
1875        }
1876
1877        Ok(())
1878    }
1879
1880    #[instrument(level = "debug")]
1881    pub async fn send_all(
1882        &mut self,
1883        messages: impl IntoIterator<Item = BackendMessage>,
1884    ) -> Result<(), io::Error> {
1885        for m in messages {
1886            self.send(m).await?;
1887        }
1888        Ok(())
1889    }
1890
1891    #[instrument(level = "debug")]
1892    async fn sync(&mut self) -> Result<State, io::Error> {
1893        // Close the current transaction if we are in an implicit transaction.
1894        if self.adapter_client.session().transaction().is_implicit() {
1895            self.commit_transaction().await?;
1896        }
1897        self.ready().await
1898    }
1899
1900    #[instrument(level = "debug")]
1901    async fn ready(&mut self) -> Result<State, io::Error> {
1902        let txn_state = self.adapter_client.session().transaction().into();
1903        self.send(BackendMessage::ReadyForQuery(txn_state)).await?;
1904        self.flush().await
1905    }
1906
1907    #[allow(clippy::too_many_arguments)]
1908    #[instrument(level = "debug")]
1909    async fn send_execute_response(
1910        &mut self,
1911        response: ExecuteResponse,
1912        row_desc: Option<RelationDesc>,
1913        portal_name: String,
1914        max_rows: ExecuteCount,
1915        get_response: GetResponse,
1916        fetch_portal_name: Option<String>,
1917        timeout: ExecuteTimeout,
1918        execute_started: Instant,
1919    ) -> Result<State, io::Error> {
1920        let mut tag = response.tag();
1921
1922        macro_rules! command_complete {
1923            () => {{
1924                self.send(BackendMessage::CommandComplete {
1925                    tag: tag
1926                        .take()
1927                        .expect("command_complete only called on tag-generating results"),
1928                })
1929                .await?;
1930                Ok(State::Ready)
1931            }};
1932        }
1933
1934        let r = match response {
1935            ExecuteResponse::ClosedCursor => {
1936                self.complete_portal(&portal_name);
1937                command_complete!()
1938            }
1939            ExecuteResponse::DeclaredCursor => {
1940                self.complete_portal(&portal_name);
1941                command_complete!()
1942            }
1943            ExecuteResponse::EmptyQuery => {
1944                self.send(BackendMessage::EmptyQueryResponse).await?;
1945                Ok(State::Ready)
1946            }
1947            ExecuteResponse::Fetch {
1948                name,
1949                count,
1950                timeout,
1951                ctx_extra,
1952            } => {
1953                self.fetch(
1954                    name,
1955                    count,
1956                    max_rows,
1957                    Some(portal_name.to_string()),
1958                    timeout,
1959                    ctx_extra,
1960                )
1961                .await
1962            }
1963            ExecuteResponse::SendingRowsStreaming {
1964                rows,
1965                instance_id,
1966                strategy,
1967            } => {
1968                let row_desc = row_desc
1969                    .expect("missing row description for ExecuteResponse::SendingRowsStreaming");
1970
1971                let span = tracing::debug_span!("sending_rows_streaming");
1972
1973                self.send_rows(
1974                    row_desc,
1975                    portal_name,
1976                    InProgressRows::new(RecordFirstRowStream::new(
1977                        Box::new(rows),
1978                        execute_started,
1979                        &self.adapter_client,
1980                        Some(instance_id),
1981                        Some(strategy),
1982                    )),
1983                    max_rows,
1984                    get_response,
1985                    fetch_portal_name,
1986                    timeout,
1987                )
1988                .instrument(span)
1989                .await
1990                .map(|(state, _)| state)
1991            }
1992            ExecuteResponse::SendingRowsImmediate { rows } => {
1993                let row_desc = row_desc
1994                    .expect("missing row description for ExecuteResponse::SendingRowsImmediate");
1995
1996                let span = tracing::debug_span!("sending_rows_immediate");
1997
1998                let stream =
1999                    futures::stream::once(futures::future::ready(PeekResponseUnary::Rows(rows)));
2000                self.send_rows(
2001                    row_desc,
2002                    portal_name,
2003                    InProgressRows::new(RecordFirstRowStream::new(
2004                        Box::new(stream),
2005                        execute_started,
2006                        &self.adapter_client,
2007                        None,
2008                        Some(StatementExecutionStrategy::Constant),
2009                    )),
2010                    max_rows,
2011                    get_response,
2012                    fetch_portal_name,
2013                    timeout,
2014                )
2015                .instrument(span)
2016                .await
2017                .map(|(state, _)| state)
2018            }
2019            ExecuteResponse::SetVariable { name, .. } => {
2020                // This code is somewhat awkwardly structured because we
2021                // can't hold `var` across an await point.
2022                let qn = name.to_string();
2023                let msg = if let Some(var) = self
2024                    .adapter_client
2025                    .session()
2026                    .vars_mut()
2027                    .notify_set()
2028                    .find(|v| v.name() == qn)
2029                {
2030                    Some(BackendMessage::ParameterStatus(var.name(), var.value()))
2031                } else {
2032                    None
2033                };
2034                if let Some(msg) = msg {
2035                    self.send(msg).await?;
2036                }
2037                command_complete!()
2038            }
2039            ExecuteResponse::Subscribing {
2040                rx,
2041                ctx_extra,
2042                instance_id,
2043            } => {
2044                if fetch_portal_name.is_none() {
2045                    let mut msg = ErrorResponse::notice(
2046                        SqlState::WARNING,
2047                        "streaming SUBSCRIBE rows directly requires a client that does not buffer output",
2048                    );
2049                    if self.adapter_client.session().vars().application_name() == "psql" {
2050                        msg.hint = Some(
2051                            "Wrap your SUBSCRIBE statement in `COPY (SUBSCRIBE ...) TO STDOUT`."
2052                                .into(),
2053                        )
2054                    }
2055                    self.send(msg).await?;
2056                    self.conn.flush().await?;
2057                }
2058                let row_desc =
2059                    row_desc.expect("missing row description for ExecuteResponse::Subscribing");
2060                let (result, statement_ended_execution_reason) = match self
2061                    .send_rows(
2062                        row_desc,
2063                        portal_name,
2064                        InProgressRows::new(RecordFirstRowStream::new(
2065                            Box::new(UnboundedReceiverStream::new(rx)),
2066                            execute_started,
2067                            &self.adapter_client,
2068                            Some(instance_id),
2069                            None,
2070                        )),
2071                        max_rows,
2072                        get_response,
2073                        fetch_portal_name,
2074                        timeout,
2075                    )
2076                    .await
2077                {
2078                    Err(e) => {
2079                        // This is an error communicating with the connection.
2080                        // We consider that to be a cancelation, rather than a query error.
2081                        (Err(e), StatementEndedExecutionReason::Canceled)
2082                    }
2083                    Ok((ok, SendRowsEndedReason::Canceled)) => {
2084                        (Ok(ok), StatementEndedExecutionReason::Canceled)
2085                    }
2086                    Ok((
2087                        ok,
2088                        SendRowsEndedReason::Success {
2089                            result_size,
2090                            rows_returned,
2091                        },
2092                    )) => (
2093                        Ok(ok),
2094                        StatementEndedExecutionReason::Success {
2095                            result_size: Some(result_size),
2096                            rows_returned: Some(rows_returned),
2097                            execution_strategy: None,
2098                        },
2099                    ),
2100                    Ok((ok, SendRowsEndedReason::Errored { error })) => {
2101                        (Ok(ok), StatementEndedExecutionReason::Errored { error })
2102                    }
2103                };
2104                self.adapter_client
2105                    .retire_execute(ctx_extra, statement_ended_execution_reason);
2106                return result;
2107            }
2108            ExecuteResponse::CopyTo { format, resp } => {
2109                let row_desc =
2110                    row_desc.expect("missing row description for ExecuteResponse::CopyTo");
2111                match *resp {
2112                    ExecuteResponse::Subscribing {
2113                        rx,
2114                        ctx_extra,
2115                        instance_id,
2116                    } => {
2117                        let (result, statement_ended_execution_reason) = match self
2118                            .copy_rows(
2119                                format,
2120                                row_desc,
2121                                RecordFirstRowStream::new(
2122                                    Box::new(UnboundedReceiverStream::new(rx)),
2123                                    execute_started,
2124                                    &self.adapter_client,
2125                                    Some(instance_id),
2126                                    None,
2127                                ),
2128                            )
2129                            .await
2130                        {
2131                            Err(e) => {
2132                                // This is an error communicating with the connection.
2133                                // We consider that to be a cancelation, rather than a query error.
2134                                (Err(e), StatementEndedExecutionReason::Canceled)
2135                            }
2136                            Ok((
2137                                state,
2138                                SendRowsEndedReason::Success {
2139                                    result_size,
2140                                    rows_returned,
2141                                },
2142                            )) => (
2143                                Ok(state),
2144                                StatementEndedExecutionReason::Success {
2145                                    result_size: Some(result_size),
2146                                    rows_returned: Some(rows_returned),
2147                                    execution_strategy: None,
2148                                },
2149                            ),
2150                            Ok((state, SendRowsEndedReason::Errored { error })) => {
2151                                (Ok(state), StatementEndedExecutionReason::Errored { error })
2152                            }
2153                            Ok((state, SendRowsEndedReason::Canceled)) => {
2154                                (Ok(state), StatementEndedExecutionReason::Canceled)
2155                            }
2156                        };
2157                        self.adapter_client
2158                            .retire_execute(ctx_extra, statement_ended_execution_reason);
2159                        return result;
2160                    }
2161                    ExecuteResponse::SendingRowsStreaming {
2162                        rows,
2163                        instance_id,
2164                        strategy,
2165                    } => {
2166                        // We don't need to finalize execution here;
2167                        // it was already done in the
2168                        // coordinator. Just extract the state and
2169                        // return that.
2170                        return self
2171                            .copy_rows(
2172                                format,
2173                                row_desc,
2174                                RecordFirstRowStream::new(
2175                                    Box::new(rows),
2176                                    execute_started,
2177                                    &self.adapter_client,
2178                                    Some(instance_id),
2179                                    Some(strategy),
2180                                ),
2181                            )
2182                            .await
2183                            .map(|(state, _)| state);
2184                    }
2185                    ExecuteResponse::SendingRowsImmediate { rows } => {
2186                        let span = tracing::debug_span!("sending_rows_immediate");
2187
2188                        let rows = futures::stream::once(futures::future::ready(
2189                            PeekResponseUnary::Rows(rows),
2190                        ));
2191                        // We don't need to finalize execution here;
2192                        // it was already done in the
2193                        // coordinator. Just extract the state and
2194                        // return that.
2195                        return self
2196                            .copy_rows(
2197                                format,
2198                                row_desc,
2199                                RecordFirstRowStream::new(
2200                                    Box::new(rows),
2201                                    execute_started,
2202                                    &self.adapter_client,
2203                                    None,
2204                                    Some(StatementExecutionStrategy::Constant),
2205                                ),
2206                            )
2207                            .instrument(span)
2208                            .await
2209                            .map(|(state, _)| state);
2210                    }
2211                    _ => {
2212                        return self
2213                            .send_error_and_get_state(ErrorResponse::error(
2214                                SqlState::INTERNAL_ERROR,
2215                                "unsupported COPY response type".to_string(),
2216                            ))
2217                            .await;
2218                    }
2219                };
2220            }
2221            ExecuteResponse::CopyFrom {
2222                target_id,
2223                target_name,
2224                columns,
2225                params,
2226                ctx_extra,
2227            } => {
2228                let row_desc =
2229                    row_desc.expect("missing row description for ExecuteResponse::CopyFrom");
2230                self.copy_from(target_id, target_name, columns, params, row_desc, ctx_extra)
2231                    .await
2232            }
2233            ExecuteResponse::TransactionCommitted { params }
2234            | ExecuteResponse::TransactionRolledBack { params } => {
2235                let notify_set: mz_ore::collections::HashSet<String> = self
2236                    .adapter_client
2237                    .session()
2238                    .vars()
2239                    .notify_set()
2240                    .map(|v| v.name().to_string())
2241                    .collect();
2242
2243                // Only report on parameters that are in the notify set.
2244                for (name, value) in params
2245                    .into_iter()
2246                    .filter(|(name, _v)| notify_set.contains(*name))
2247                {
2248                    let msg = BackendMessage::ParameterStatus(name, value);
2249                    self.send(msg).await?;
2250                }
2251                command_complete!()
2252            }
2253
2254            ExecuteResponse::AlteredDefaultPrivileges
2255            | ExecuteResponse::AlteredObject(..)
2256            | ExecuteResponse::AlteredRole
2257            | ExecuteResponse::AlteredSystemConfiguration
2258            | ExecuteResponse::CreatedCluster { .. }
2259            | ExecuteResponse::CreatedClusterReplica { .. }
2260            | ExecuteResponse::CreatedConnection { .. }
2261            | ExecuteResponse::CreatedDatabase { .. }
2262            | ExecuteResponse::CreatedIndex { .. }
2263            | ExecuteResponse::CreatedIntrospectionSubscribe
2264            | ExecuteResponse::CreatedMaterializedView { .. }
2265            | ExecuteResponse::CreatedContinualTask { .. }
2266            | ExecuteResponse::CreatedRole
2267            | ExecuteResponse::CreatedSchema { .. }
2268            | ExecuteResponse::CreatedSecret { .. }
2269            | ExecuteResponse::CreatedSink { .. }
2270            | ExecuteResponse::CreatedSource { .. }
2271            | ExecuteResponse::CreatedTable { .. }
2272            | ExecuteResponse::CreatedType
2273            | ExecuteResponse::CreatedView { .. }
2274            | ExecuteResponse::CreatedViews { .. }
2275            | ExecuteResponse::CreatedNetworkPolicy
2276            | ExecuteResponse::Comment
2277            | ExecuteResponse::Deallocate { .. }
2278            | ExecuteResponse::Deleted(..)
2279            | ExecuteResponse::DiscardedAll
2280            | ExecuteResponse::DiscardedTemp
2281            | ExecuteResponse::DroppedObject(_)
2282            | ExecuteResponse::DroppedOwned
2283            | ExecuteResponse::GrantedPrivilege
2284            | ExecuteResponse::GrantedRole
2285            | ExecuteResponse::Inserted(..)
2286            | ExecuteResponse::Copied(..)
2287            | ExecuteResponse::Prepare
2288            | ExecuteResponse::Raised
2289            | ExecuteResponse::ReassignOwned
2290            | ExecuteResponse::RevokedPrivilege
2291            | ExecuteResponse::RevokedRole
2292            | ExecuteResponse::StartedTransaction { .. }
2293            | ExecuteResponse::Updated(..)
2294            | ExecuteResponse::ValidatedConnection => {
2295                command_complete!()
2296            }
2297        };
2298
2299        assert_none!(tag, "tag created but not consumed: {:?}", tag);
2300        r
2301    }
2302
    /// Sends rows from an in-progress portal to the client as `DataRow`
    /// messages.
    ///
    /// Streams batches from `rows` until the client's `max_rows` limit is
    /// reached, the underlying stream is exhausted, or the requested `timeout`
    /// elapses. Any unsent remainder of the current batch is stashed back on
    /// the portal so a subsequent `Execute`/`FETCH` on the same portal resumes
    /// where this call left off.
    ///
    /// When invoked on behalf of a FETCH, `fetch_portal_name` names the outer
    /// FETCH portal, whose result formats take precedence over those of
    /// `portal_name`. `get_response` builds the final completion message from
    /// the number of rows actually sent.
    ///
    /// Returns the next protocol `State` together with a
    /// `SendRowsEndedReason` describing how the send ended, for statement
    /// logging.
    #[allow(clippy::too_many_arguments)]
    // TODO(guswynn): figure out how to get it to compile without skip_all
    #[mz_ore::instrument(level = "debug")]
    async fn send_rows(
        &mut self,
        row_desc: RelationDesc,
        portal_name: String,
        mut rows: InProgressRows,
        max_rows: ExecuteCount,
        get_response: GetResponse,
        fetch_portal_name: Option<String>,
        timeout: ExecuteTimeout,
    ) -> Result<(State, SendRowsEndedReason), io::Error> {
        // If this portal is being executed from a FETCH then we need to use the result
        // format type of the outer portal.
        let result_format_portal_name: &str = if let Some(ref name) = fetch_portal_name {
            name
        } else {
            &portal_name
        };
        let result_formats = self
            .adapter_client
            .session()
            .get_portal_unverified(result_format_portal_name)
            .expect("valid fetch portal name for send rows")
            .result_formats
            .clone();

        // `deadline == None` means "wait indefinitely"; `wait_once` defers
        // setting the deadline until after the first non-empty batch arrives.
        let (mut wait_once, mut deadline) = match timeout {
            ExecuteTimeout::None => (false, None),
            ExecuteTimeout::Seconds(t) => (
                false,
                Some(tokio::time::Instant::now() + tokio::time::Duration::from_secs_f64(t)),
            ),
            ExecuteTimeout::WaitOnce => (true, None),
        };

        // Sanity check that the various `RelationDesc`s match up.
        {
            let portal_name_desc = &self
                .adapter_client
                .session()
                .get_portal_unverified(portal_name.as_str())
                .expect("portal should exist")
                .desc
                .relation_desc;
            if let Some(portal_name_desc) = portal_name_desc {
                soft_assert_eq_or_log!(portal_name_desc, &row_desc);
            }
            if let Some(fetch_portal_name) = &fetch_portal_name {
                let fetch_portal_desc = &self
                    .adapter_client
                    .session()
                    .get_portal_unverified(fetch_portal_name)
                    .expect("portal should exist")
                    .desc
                    .relation_desc;
                if let Some(fetch_portal_desc) = fetch_portal_desc {
                    soft_assert_eq_or_log!(fetch_portal_desc, &row_desc);
                }
            }
        }

        // Configure the wire encoder with one (pg type, result format) pair
        // per output column; `zip_eq` asserts the counts line up.
        self.conn.set_encode_state(
            row_desc
                .typ()
                .column_types
                .iter()
                .map(|ty| mz_pgrepr::Type::from(&ty.scalar_type))
                .zip_eq(result_formats)
                .collect(),
        );

        let mut total_sent_rows = 0;
        let mut total_sent_bytes = 0;
        // want_rows is the maximum number of rows the client wants.
        let mut want_rows = match max_rows {
            ExecuteCount::All => usize::MAX,
            ExecuteCount::Count(count) => count,
        };

        // Send rows while the client still wants them and there are still rows to send.
        loop {
            // Fetch next batch of rows, waiting for a possible requested
            // timeout or notice. A previously stashed batch (`rows.current`)
            // is consumed before pulling from the stream.
            let batch = if rows.current.is_some() {
                FetchResult::Rows(rows.current.take())
            } else if want_rows == 0 {
                FetchResult::Rows(None)
            } else {
                let notice_fut = self.adapter_client.session().recv_notice();
                tokio::select! {
                    err = self.conn.wait_closed() => return Err(err),
                    // This branch is disabled entirely while `deadline` is
                    // `None` (the `if` guard), so the `unwrap_or_else` value
                    // is never awaited in that case.
                    _ = time::sleep_until(
                        deadline.unwrap_or_else(tokio::time::Instant::now),
                    ), if deadline.is_some() => FetchResult::Rows(None),
                    notice = notice_fut => {
                        FetchResult::Notice(notice)
                    }
                    batch = rows.remaining.recv() => match batch {
                        None => FetchResult::Rows(None),
                        Some(PeekResponseUnary::Rows(rows)) => FetchResult::Rows(Some(rows)),
                        Some(PeekResponseUnary::Error(err)) => FetchResult::Error(err),
                        Some(PeekResponseUnary::Canceled) => FetchResult::Canceled,
                    },
                }
            };

            match batch {
                FetchResult::Rows(None) => break,
                FetchResult::Rows(Some(mut batch_rows)) => {
                    // Guard against the coordinator sending rows that don't
                    // match the advertised description.
                    if let Err(err) = verify_datum_desc(&row_desc, &mut batch_rows) {
                        let msg = err.to_string();
                        return self
                            .send_error_and_get_state(err.into_response(Severity::Error))
                            .await
                            .map(|state| (state, SendRowsEndedReason::Errored { error: msg }));
                    }

                    // If wait_once is true: the first time this fn is called it blocks (same as
                    // deadline == None). The second time this fn is called it should behave the
                    // same a 0s timeout.
                    if wait_once && batch_rows.peek().is_some() {
                        deadline = Some(tokio::time::Instant::now());
                        wait_once = false;
                    }

                    // Send a portion of the rows.
                    let mut sent_rows = 0;
                    let mut sent_bytes = 0;
                    let messages = (&mut batch_rows)
                        // TODO(parkmycar): This is a fair bit of juggling between iterator types
                        // to count the total number of bytes. Alternatively we could track the
                        // total sent bytes in this .map(...) call, but having side effects in map
                        // is a code smell.
                        .map(|row| {
                            let row_len = row.byte_len();
                            let values = mz_pgrepr::values_from_row(row, row_desc.typ());
                            (row_len, BackendMessage::DataRow(values))
                        })
                        // Side-effecting counters: only rows that survive the
                        // `take(want_rows)` below are actually counted, since
                        // the chain is lazy.
                        .inspect(|(row_len, _)| {
                            sent_bytes += row_len;
                            sent_rows += 1
                        })
                        .map(|(_row_len, row)| row)
                        .take(want_rows);
                    self.send_all(messages).await?;

                    total_sent_rows += sent_rows;
                    total_sent_bytes += sent_bytes;
                    want_rows -= sent_rows;

                    // If we have sent the number of requested rows, put the remainder of the batch
                    // (if any) back and stop sending.
                    if want_rows == 0 {
                        if batch_rows.peek().is_some() {
                            rows.current = Some(batch_rows);
                        }
                        break;
                    }

                    self.conn.flush().await?;
                }
                FetchResult::Notice(notice) => {
                    // Forward asynchronous notices immediately so the client
                    // sees them while rows are still streaming.
                    self.send(notice.into_response()).await?;
                    self.conn.flush().await?;
                }
                FetchResult::Error(text) => {
                    return self
                        .send_error_and_get_state(ErrorResponse::error(
                            SqlState::INTERNAL_ERROR,
                            text.clone(),
                        ))
                        .await
                        .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
                }
                FetchResult::Canceled => {
                    return self
                        .send_error_and_get_state(ErrorResponse::error(
                            SqlState::QUERY_CANCELED,
                            "canceling statement due to user request",
                        ))
                        .await
                        .map(|state| (state, SendRowsEndedReason::Canceled));
                }
            }
        }

        let portal = self
            .adapter_client
            .session()
            .get_portal_unverified_mut(&portal_name)
            .expect("valid portal name for send rows");

        // Snapshot these before `rows` is moved back into the portal below.
        let saw_rows = rows.remaining.saw_rows;
        let no_more_rows = rows.no_more_rows();
        let recorded_first_row_instant = rows.remaining.recorded_first_row_instant;

        // Always return rows back, even if it's empty. This prevents an unclosed
        // portal from re-executing after it has been emptied.
        *portal.state = PortalState::InProgress(Some(rows));

        let fetch_portal = fetch_portal_name.map(|name| {
            self.adapter_client
                .session()
                .get_portal_unverified_mut(&name)
                .expect("valid fetch portal")
        });
        let response_message = get_response(max_rows, total_sent_rows, fetch_portal);
        self.send(response_message).await?;

        // Attend to metrics if there are no more rows.
        if no_more_rows {
            let statement_type = if let Some(stmt) = &self
                .adapter_client
                .session()
                .get_portal_unverified(&portal_name)
                .expect("valid portal name for send_rows")
                .stmt
            {
                metrics::statement_type_label_value(stmt.deref())
            } else {
                "no-statement"
            };
            let duration = if saw_rows {
                recorded_first_row_instant
                    .expect("recorded_first_row_instant because saw_rows")
                    .elapsed()
            } else {
                // If the result is empty, then we define time from first to last row as 0.
                // (Note that, currently, an empty result involves a PeekResponse with 0 rows, which
                // does flip `saw_rows`, so this code path is currently not exercised.)
                Duration::ZERO
            };
            self.adapter_client
                .inner()
                .metrics()
                .result_rows_first_to_last_byte_seconds
                .with_label_values(&[statement_type])
                .observe(duration.as_secs_f64());
        }

        Ok((
            State::Ready,
            SendRowsEndedReason::Success {
                result_size: u64::cast_from(total_sent_bytes),
                rows_returned: u64::cast_from(total_sent_rows),
            },
        ))
    }
2553
    /// Streams query results to the client in COPY TO format.
    ///
    /// Sends a `CopyOutResponse` followed by one `CopyData` message per row
    /// (in binary mode, the PGCOPY header is prepended to the first row's
    /// `CopyData` and a trailer is sent at the end), then `CopyDone` and a
    /// `COPY <n>` command tag.
    ///
    /// `format` selects the encoding (`Parquet` is rejected with an error);
    /// `row_desc` supplies the column types used to encode each row.
    ///
    /// Returns the next protocol `State` and a `SendRowsEndedReason` for
    /// statement logging.
    #[mz_ore::instrument(level = "debug")]
    async fn copy_rows(
        &mut self,
        format: CopyFormat,
        row_desc: RelationDesc,
        mut stream: RecordFirstRowStream,
    ) -> Result<(State, SendRowsEndedReason), io::Error> {
        // Map the requested COPY format to per-row encoding parameters plus
        // the wire format advertised in `CopyOutResponse` (CSV is sent with
        // the text wire format).
        let (row_format, encode_format) = match format {
            CopyFormat::Text => (
                CopyFormatParams::Text(CopyTextFormatParams::default()),
                Format::Text,
            ),
            CopyFormat::Binary => (CopyFormatParams::Binary, Format::Binary),
            CopyFormat::Csv => (
                CopyFormatParams::Csv(CopyCsvFormatParams::default()),
                Format::Text,
            ),
            CopyFormat::Parquet => {
                let text = "Parquet format is not supported".to_string();
                return self
                    .send_error_and_get_state(ErrorResponse::error(
                        SqlState::INTERNAL_ERROR,
                        text.clone(),
                    ))
                    .await
                    .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
            }
        };

        // Encodes a single row into `out` using the selected format.
        let encode_fn = |row: &RowRef, typ: &SqlRelationType, out: &mut Vec<u8>| {
            mz_pgcopy::encode_copy_format(&row_format, row, typ, out)
        };

        let typ = row_desc.typ();
        // One format entry per column, all identical.
        let column_formats = iter::repeat(encode_format)
            .take(typ.column_types.len())
            .collect();
        self.send(BackendMessage::CopyOutResponse {
            overall_format: encode_format,
            column_formats,
        })
        .await?;

        // In Postgres, binary copy has a header that is followed (in the same
        // CopyData) by the first row. In order to replicate their behavior, use a
        // common vec that we can extend one time now and then fill up with the encode
        // functions.
        let mut out = Vec::new();

        if let CopyFormat::Binary = format {
            // 11-byte signature.
            out.extend(b"PGCOPY\n\xFF\r\n\0");
            // 32-bit flags field.
            out.extend([0, 0, 0, 0]);
            // 32-bit header extension length field.
            out.extend([0, 0, 0, 0]);
        }

        let mut count = 0;
        let mut total_sent_bytes = 0;
        loop {
            tokio::select! {
                // Bail out promptly if the client hangs up.
                e = self.conn.wait_closed() => return Err(e),
                batch = stream.recv() => match batch {
                    None => break,
                    Some(PeekResponseUnary::Error(text)) => {
                        let err =
                            ErrorResponse::error(SqlState::INTERNAL_ERROR, text.clone());
                        return self
                            .send_error_and_get_state(err)
                            .await
                            .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
                    }
                    Some(PeekResponseUnary::Canceled) => {
                        return self.send_error_and_get_state(ErrorResponse::error(
                                SqlState::QUERY_CANCELED,
                                "canceling statement due to user request",
                            ))
                            .await.map(|state| (state, SendRowsEndedReason::Canceled));
                    }
                    Some(PeekResponseUnary::Rows(mut rows)) => {
                        // NOTE(review): assumes `count()` here is the
                        // non-consuming `RowIterator::count` (the batch size),
                        // not `Iterator::count` — otherwise the loop below
                        // would see an exhausted iterator; confirm.
                        count += rows.count();
                        while let Some(row) = rows.next() {
                            total_sent_bytes += row.byte_len();
                            // `mem::take` hands off the buffer (including the
                            // binary header on the first row) and leaves an
                            // empty one for the next row.
                            encode_fn(row, typ, &mut out)?;
                            self.send(BackendMessage::CopyData(mem::take(&mut out)))
                                .await?;
                        }
                    }
                },
                notice = self.adapter_client.session().recv_notice() => {
                    self.send(notice.into_response())
                        .await?;
                    self.conn.flush().await?;
                }
            }

            self.conn.flush().await?;
        }
        // Send required trailers.
        if let CopyFormat::Binary = format {
            // Binary COPY ends with a 16-bit word of -1 in place of a tuple
            // field count.
            let trailer: i16 = -1;
            out.extend(trailer.to_be_bytes());
            self.send(BackendMessage::CopyData(mem::take(&mut out)))
                .await?;
        }

        let tag = format!("COPY {}", count);
        self.send(BackendMessage::CopyDone).await?;
        self.send(BackendMessage::CommandComplete { tag }).await?;
        Ok((
            State::Ready,
            SendRowsEndedReason::Success {
                result_size: u64::cast_from(total_sent_bytes),
                rows_returned: u64::cast_from(count),
            },
        ))
    }
2672
2673    /// Handles the copy-in mode of the postgres protocol from transferring
2674    /// data to the server.
2675    #[instrument(level = "debug")]
2676    async fn copy_from(
2677        &mut self,
2678        target_id: CatalogItemId,
2679        target_name: String,
2680        columns: Vec<ColumnIndex>,
2681        params: CopyFormatParams<'static>,
2682        row_desc: RelationDesc,
2683        mut ctx_extra: ExecuteContextGuard,
2684    ) -> Result<State, io::Error> {
2685        let res = self
2686            .copy_from_inner(
2687                target_id,
2688                target_name,
2689                columns,
2690                params,
2691                row_desc,
2692                &mut ctx_extra,
2693            )
2694            .await;
2695        match &res {
2696            Ok(State::Ready) => {
2697                self.adapter_client.retire_execute(
2698                    ctx_extra,
2699                    StatementEndedExecutionReason::Success {
2700                        result_size: None,
2701                        rows_returned: None,
2702                        execution_strategy: None,
2703                    },
2704                );
2705            }
2706            Ok(State::Done) => {
2707                // The connection closed gracefully without sending us a `CopyDone`,
2708                // causing us to just drop the copy request.
2709                // For the purposes of statement logging, we count this as a cancellation.
2710                self.adapter_client
2711                    .retire_execute(ctx_extra, StatementEndedExecutionReason::Canceled);
2712            }
2713            Err(e) => {
2714                self.adapter_client.retire_execute(
2715                    ctx_extra,
2716                    StatementEndedExecutionReason::Errored {
2717                        error: format!("{e}"),
2718                    },
2719                );
2720            }
2721            Ok(state) if matches!(state, State::Drain) => {}
2722            other => {
2723                mz_ore::soft_panic_or_log!(
2724                    "unexpected COPY FROM state in copy_from: {other:?}; \
2725                     relying on ExecuteContextGuard::drop to retire as Aborted"
2726                );
2727            }
2728        }
2729        res
2730    }
2731
2732    async fn copy_from_inner(
2733        &mut self,
2734        target_id: CatalogItemId,
2735        target_name: String,
2736        columns: Vec<ColumnIndex>,
2737        params: CopyFormatParams<'static>,
2738        row_desc: RelationDesc,
2739        ctx_extra: &mut ExecuteContextGuard,
2740    ) -> Result<State, io::Error> {
2741        let typ = row_desc.typ();
2742        let column_formats = vec![Format::Text; typ.column_types.len()];
2743        self.send(BackendMessage::CopyInResponse {
2744            overall_format: Format::Text,
2745            column_formats,
2746        })
2747        .await?;
2748        self.conn.flush().await?;
2749
2750        // Set up the parallel streaming batch builders in the coordinator.
2751        let writer = match self
2752            .adapter_client
2753            .start_copy_from_stdin(
2754                target_id,
2755                target_name.clone(),
2756                columns.clone(),
2757                row_desc.clone(),
2758                params.clone(),
2759            )
2760            .await
2761        {
2762            Ok(writer) => writer,
2763            Err(e) => {
2764                // Drain remaining CopyData/CopyDone/CopyFail messages from the
2765                // socket. Since CopyInResponse was already sent, the client may
2766                // have pipelined copy data that we must consume before returning
2767                // the error, otherwise they'd be misinterpreted as top-level
2768                // protocol messages and cause a deadlock.
2769                loop {
2770                    match self.conn.recv().await? {
2771                        Some(FrontendMessage::CopyData(_)) => {}
2772                        Some(FrontendMessage::CopyDone) | Some(FrontendMessage::CopyFail(_)) => {
2773                            break;
2774                        }
2775                        Some(FrontendMessage::Flush) | Some(FrontendMessage::Sync) => {}
2776                        Some(_) => break,
2777                        None => return Ok(State::Done),
2778                    }
2779                }
2780                self.adapter_client.retire_execute(
2781                    std::mem::take(ctx_extra),
2782                    StatementEndedExecutionReason::Errored {
2783                        error: e.to_string(),
2784                    },
2785                );
2786                return self
2787                    .send_error_and_get_state(e.into_response(Severity::Error))
2788                    .await;
2789            }
2790        };
2791
2792        // Enable copy mode on the codec to skip aggregate buffer size checks.
2793        self.conn.set_copy_mode(true);
2794
2795        // Batch size for splitting raw data across parallel workers (~32MB).
2796        const BATCH_SIZE: usize = 32 * 1024 * 1024;
2797        let max_copy_from_row_size = self
2798            .adapter_client
2799            .get_system_vars()
2800            .await
2801            .max_copy_from_row_size()
2802            .try_into()
2803            .unwrap_or(usize::MAX);
2804
2805        let mut data = Vec::new();
2806        let mut row_scanner = CopyRowScanner::new(&params);
2807        let num_workers = writer.batch_txs.len();
2808        let mut next_worker: usize = 0;
2809        let mut saw_copy_done = false;
2810        let mut saw_end_marker = false;
2811        let mut copy_from_error: Option<(SqlState, String)> = None;
2812
2813        // Receive loop: accumulate CopyData, split at row boundaries,
2814        // round-robin raw chunks to parallel batch builder workers.
2815        loop {
2816            let message = self.conn.recv().await?;
2817            match message {
2818                Some(FrontendMessage::CopyData(buf)) => {
2819                    if saw_end_marker {
2820                        // Per PostgreSQL COPY behavior, ignore all bytes after
2821                        // the end-of-copy marker until CopyDone.
2822                        continue;
2823                    }
2824                    data.extend(buf);
2825                    row_scanner.scan_new_bytes(&data);
2826
2827                    if let Some(end_pos) = row_scanner.end_marker_end() {
2828                        data.truncate(end_pos);
2829                        row_scanner.on_truncate(end_pos);
2830                        saw_end_marker = true;
2831                    }
2832
2833                    // Guard against pathological single rows that never terminate.
2834                    if row_scanner.current_row_size(data.len()) > max_copy_from_row_size {
2835                        copy_from_error = Some((
2836                            SqlState::INSUFFICIENT_RESOURCES,
2837                            format!(
2838                                "COPY FROM STDIN row exceeded max_copy_from_row_size \
2839                                 ({max_copy_from_row_size} bytes)"
2840                            ),
2841                        ));
2842                        break;
2843                    }
2844
2845                    // When buffer exceeds batch size, split at the last complete row
2846                    // and send the complete rows chunk to the next worker.
2847                    let mut send_failed = false;
2848                    while data.len() >= BATCH_SIZE {
2849                        let split_pos = match row_scanner.last_row_end() {
2850                            Some(pos) => pos,
2851                            None => break, // no complete row yet
2852                        };
2853                        let remainder = data.split_off(split_pos);
2854                        let chunk = std::mem::replace(&mut data, remainder);
2855                        row_scanner.on_split(split_pos);
2856                        if writer.batch_txs[next_worker].send(chunk).await.is_err() {
2857                            send_failed = true;
2858                            break;
2859                        }
2860                        next_worker = (next_worker + 1) % num_workers;
2861                    }
2862                    // Worker dropped (likely errored) — stop sending,
2863                    // fall through to completion_rx for the real error.
2864                    if send_failed {
2865                        break;
2866                    }
2867                }
2868                Some(FrontendMessage::CopyDone) => {
2869                    // Send any remaining data to the next worker.
2870                    if !data.is_empty() {
2871                        let chunk = std::mem::take(&mut data);
2872                        // Ignore send failure — completion_rx will have the error.
2873                        let _ = writer.batch_txs[next_worker].send(chunk).await;
2874                    }
2875                    saw_copy_done = true;
2876                    break;
2877                }
2878                Some(FrontendMessage::CopyFail(err)) => {
2879                    self.adapter_client.retire_execute(
2880                        std::mem::take(ctx_extra),
2881                        StatementEndedExecutionReason::Canceled,
2882                    );
2883                    // Drop the writer to signal cancellation to the background tasks.
2884                    drop(writer);
2885                    self.conn.set_copy_mode(false);
2886                    return self
2887                        .send_error_and_get_state(ErrorResponse::error(
2888                            SqlState::QUERY_CANCELED,
2889                            format!("COPY from stdin failed: {}", err),
2890                        ))
2891                        .await;
2892                }
2893                Some(FrontendMessage::Flush) | Some(FrontendMessage::Sync) => {}
2894                Some(_) => {
2895                    let msg = "unexpected message type during COPY from stdin";
2896                    self.adapter_client.retire_execute(
2897                        std::mem::take(ctx_extra),
2898                        StatementEndedExecutionReason::Errored {
2899                            error: msg.to_string(),
2900                        },
2901                    );
2902                    drop(writer);
2903                    self.conn.set_copy_mode(false);
2904                    return self
2905                        .send_error_and_get_state(ErrorResponse::error(
2906                            SqlState::PROTOCOL_VIOLATION,
2907                            msg,
2908                        ))
2909                        .await;
2910                }
2911                None => {
2912                    drop(writer);
2913                    self.conn.set_copy_mode(false);
2914                    return Ok(State::Done);
2915                }
2916            }
2917        }
2918
2919        // If we exited the receive loop before seeing `CopyDone` (e.g. because
2920        // a worker failed and dropped its channel), keep draining COPY input to
2921        // avoid desynchronizing the protocol state machine.
2922        if !saw_copy_done {
2923            loop {
2924                match self.conn.recv().await? {
2925                    Some(FrontendMessage::CopyData(_)) => {}
2926                    Some(FrontendMessage::CopyDone) | Some(FrontendMessage::CopyFail(_)) => {
2927                        break;
2928                    }
2929                    Some(FrontendMessage::Flush) | Some(FrontendMessage::Sync) => {}
2930                    Some(_) => {
2931                        let msg = "unexpected message type during COPY from stdin";
2932                        self.adapter_client.retire_execute(
2933                            std::mem::take(ctx_extra),
2934                            StatementEndedExecutionReason::Errored {
2935                                error: msg.to_string(),
2936                            },
2937                        );
2938                        drop(writer);
2939                        self.conn.set_copy_mode(false);
2940                        return self
2941                            .send_error_and_get_state(ErrorResponse::error(
2942                                SqlState::PROTOCOL_VIOLATION,
2943                                msg,
2944                            ))
2945                            .await;
2946                    }
2947                    None => {
2948                        drop(writer);
2949                        self.conn.set_copy_mode(false);
2950                        return Ok(State::Done);
2951                    }
2952                }
2953            }
2954        }
2955
2956        if let Some((code, msg)) = copy_from_error {
2957            self.adapter_client.retire_execute(
2958                std::mem::take(ctx_extra),
2959                StatementEndedExecutionReason::Errored { error: msg.clone() },
2960            );
2961            drop(writer);
2962            self.conn.set_copy_mode(false);
2963            return self
2964                .send_error_and_get_state(ErrorResponse::error(code, msg))
2965                .await;
2966        }
2967
2968        self.conn.set_copy_mode(false);
2969
2970        // Drop all senders to signal EOF to the background batch builders.
2971        // If copy_err is set, a worker already failed — dropping the senders
2972        // will cause remaining workers to stop, and we'll get the real error
2973        // from completion_rx below.
2974        drop(writer.batch_txs);
2975
2976        // Wait for all parallel workers to finish building batches.
2977        let (proto_batches, row_count) = match writer.completion_rx.await {
2978            Ok(Ok(result)) => result,
2979            Ok(Err(e)) => {
2980                self.adapter_client.retire_execute(
2981                    std::mem::take(ctx_extra),
2982                    StatementEndedExecutionReason::Errored {
2983                        error: e.to_string(),
2984                    },
2985                );
2986                return self
2987                    .send_error_and_get_state(e.into_response(Severity::Error))
2988                    .await;
2989            }
2990            Err(_) => {
2991                let msg = "COPY FROM STDIN: background batch builder tasks dropped";
2992                self.adapter_client.retire_execute(
2993                    std::mem::take(ctx_extra),
2994                    StatementEndedExecutionReason::Errored {
2995                        error: msg.to_string(),
2996                    },
2997                );
2998                return self
2999                    .send_error_and_get_state(ErrorResponse::error(SqlState::INTERNAL_ERROR, msg))
3000                    .await;
3001            }
3002        };
3003
3004        // Stage all batches in the session's transaction for atomic commit.
3005        if let Err(e) = self
3006            .adapter_client
3007            .stage_copy_from_stdin_batches(target_id, proto_batches)
3008        {
3009            self.adapter_client.retire_execute(
3010                std::mem::take(ctx_extra),
3011                StatementEndedExecutionReason::Errored {
3012                    error: e.to_string(),
3013                },
3014            );
3015            return self
3016                .send_error_and_get_state(e.into_response(Severity::Error))
3017                .await;
3018        }
3019
3020        let tag = format!("COPY {}", row_count);
3021        self.send(BackendMessage::CommandComplete { tag }).await?;
3022
3023        Ok(State::Ready)
3024    }
3025
3026    #[instrument(level = "debug")]
3027    async fn send_pending_notices(&mut self) -> Result<(), io::Error> {
3028        let notices = self
3029            .adapter_client
3030            .session()
3031            .drain_notices()
3032            .into_iter()
3033            .map(|notice| BackendMessage::ErrorResponse(notice.into_response()));
3034        self.send_all(notices).await?;
3035        Ok(())
3036    }
3037
3038    #[instrument(level = "debug")]
3039    async fn send_error_and_get_state(&mut self, err: ErrorResponse) -> Result<State, io::Error> {
3040        assert!(err.severity.is_error());
3041        debug!(
3042            "cid={} error code={}",
3043            self.adapter_client.session().conn_id(),
3044            err.code.code()
3045        );
3046        let is_fatal = err.severity.is_fatal();
3047        self.send(BackendMessage::ErrorResponse(err)).await?;
3048
3049        let txn = self.adapter_client.session().transaction();
3050        match txn {
3051            // Error can be called from describe and parse and so might not be in an active
3052            // transaction.
3053            TransactionStatus::Default | TransactionStatus::Failed(_) => {}
3054            // In Started (i.e., a single statement), cleanup ourselves.
3055            TransactionStatus::Started(_) => {
3056                self.rollback_transaction().await?;
3057            }
3058            // Implicit transactions also clear themselves.
3059            TransactionStatus::InTransactionImplicit(_) => {
3060                self.rollback_transaction().await?;
3061            }
3062            // Explicit transactions move to failed.
3063            TransactionStatus::InTransaction(_) => {
3064                self.adapter_client.fail_transaction();
3065            }
3066        };
3067        if is_fatal {
3068            Ok(State::Done)
3069        } else {
3070            Ok(State::Drain)
3071        }
3072    }
3073
3074    #[instrument(level = "debug")]
3075    async fn aborted_txn_error(&mut self) -> Result<State, io::Error> {
3076        self.send(BackendMessage::ErrorResponse(ErrorResponse::error(
3077            SqlState::IN_FAILED_SQL_TRANSACTION,
3078            ABORTED_TXN_MSG,
3079        )))
3080        .await?;
3081        Ok(State::Drain)
3082    }
3083
3084    fn is_aborted_txn(&mut self) -> bool {
3085        matches!(
3086            self.adapter_client.session().transaction(),
3087            TransactionStatus::Failed(_)
3088        )
3089    }
3090}
3091
3092fn pad_formats(formats: Vec<Format>, n: usize) -> Result<Vec<Format>, String> {
3093    match (formats.len(), n) {
3094        (0, e) => Ok(vec![Format::Text; e]),
3095        (1, e) => Ok(iter::repeat(formats[0]).take(e).collect()),
3096        (a, e) if a == e => Ok(formats),
3097        (a, e) => Err(format!(
3098            "expected {} field format specifiers, but got {}",
3099            e, a
3100        )),
3101    }
3102}
3103
3104fn describe_rows(stmt_desc: &StatementDesc, formats: &[Format]) -> BackendMessage {
3105    match &stmt_desc.relation_desc {
3106        Some(desc) if !stmt_desc.is_copy => {
3107            BackendMessage::RowDescription(message::encode_row_description(desc, formats))
3108        }
3109        _ => BackendMessage::NoData,
3110    }
3111}
3112
/// Signature of the callback used by `send_rows` to produce the final
/// backend message once rows have been sent (e.g. `CommandComplete` or
/// `PortalSuspended`).
type GetResponse = fn(
    max_rows: ExecuteCount,
    total_sent_rows: usize,
    fetch_portal: Option<PortalRefMut>,
) -> BackendMessage;
3118
3119// A GetResponse used by send_rows during execute messages on portals or for
3120// simple query messages.
3121fn portal_exec_message(
3122    max_rows: ExecuteCount,
3123    total_sent_rows: usize,
3124    _fetch_portal: Option<PortalRefMut>,
3125) -> BackendMessage {
3126    // If max_rows is not specified, we will always send back a CommandComplete. If
3127    // max_rows is specified, we only send CommandComplete if there were more rows
3128    // requested than were remaining. That is, if max_rows == number of rows that
3129    // were remaining before sending (not that are remaining after sending), then
3130    // we still send a PortalSuspended. The number of remaining rows after the rows
3131    // have been sent doesn't matter. This matches postgres.
3132    match max_rows {
3133        ExecuteCount::Count(max_rows) if max_rows <= total_sent_rows => {
3134            BackendMessage::PortalSuspended
3135        }
3136        _ => BackendMessage::CommandComplete {
3137            tag: format!("SELECT {}", total_sent_rows),
3138        },
3139    }
3140}
3141
3142// A GetResponse used by send_rows during FETCH queries.
3143fn fetch_message(
3144    _max_rows: ExecuteCount,
3145    total_sent_rows: usize,
3146    fetch_portal: Option<PortalRefMut>,
3147) -> BackendMessage {
3148    let tag = format!("FETCH {}", total_sent_rows);
3149    if let Some(portal) = fetch_portal {
3150        *portal.state = PortalState::Completed(Some(tag.clone()));
3151    }
3152    BackendMessage::CommandComplete { tag }
3153}
3154
3155fn get_authenticator(
3156    authenticator_kind: AuthenticatorKind,
3157    frontegg: Option<FronteggAuthenticator>,
3158    oidc: GenericOidcAuthenticator,
3159    adapter_client: mz_adapter::Client,
3160    // If oidc_auth_enabled exists as an option in the pgwire connection's
3161    // `option` parameter
3162    oidc_auth_option_enabled: bool,
3163) -> Authenticator {
3164    match authenticator_kind {
3165        AuthenticatorKind::Frontegg => Authenticator::Frontegg(
3166            frontegg.expect("Frontegg authenticator should exist with AuthenticatorKind::Frontegg"),
3167        ),
3168        AuthenticatorKind::Password => Authenticator::Password(adapter_client),
3169        AuthenticatorKind::Sasl => Authenticator::Sasl(adapter_client),
3170        AuthenticatorKind::Oidc => {
3171            if oidc_auth_option_enabled {
3172                Authenticator::Oidc(oidc)
3173            } else {
3174                // Fallback to password authentication if oidc auth is not enabled
3175                // through options.
3176                Authenticator::Password(adapter_client)
3177            }
3178        }
3179        AuthenticatorKind::None => Authenticator::None,
3180    }
3181}
3182
/// How many rows an execute request may send before suspending its portal.
#[derive(Debug, Copy, Clone)]
enum ExecuteCount {
    // Send every remaining row.
    All,
    // Send at most this many rows.
    Count(usize),
}
3188
3189// See postgres' backend/tcop/postgres.c IsTransactionExitStmt.
3190fn is_txn_exit_stmt(stmt: Option<&Statement<Raw>>) -> bool {
3191    match stmt {
3192        // Add PREPARE to this if we ever support it.
3193        Some(stmt) => matches!(stmt, Statement::Commit(_) | Statement::Rollback(_)),
3194        None => false,
3195    }
3196}
3197
/// The possible outcomes observed while waiting for rows during a FETCH.
#[derive(Debug)]
enum FetchResult {
    // Row data became available. NOTE(review): `None` presumably means the
    // underlying stream is exhausted — confirm at the call sites.
    Rows(Option<Box<dyn RowIterator + Send + Sync>>),
    // The query was canceled before completing.
    Canceled,
    // The query failed; carries the error message.
    Error(String),
    // A notice arrived and should be forwarded to the client.
    Notice(AdapterNotice),
}
3205
/// Incremental scanner over the byte buffer of a `COPY ... FROM STDIN`
/// stream that tracks row boundaries and the `\.` end-of-input marker.
#[derive(Debug)]
struct CopyRowScanner {
    // Offset up to which `scan_new_bytes` has already examined the buffer.
    scan_pos: usize,
    // Offset one past the end of the last complete row seen, if any.
    last_row_end: Option<usize>,
    // Offset one past the end of the `\.` end-of-copy marker, if seen.
    end_marker_end: Option<usize>,
    // CSV parsing state; `None` for non-CSV formats, which are scanned for
    // newline-terminated rows instead.
    csv: Option<CsvScanState>,
}
3213
/// State for incrementally parsing CSV-formatted COPY input via `csv_core`.
#[derive(Debug)]
struct CsvScanState {
    // The underlying push parser; retains state across chunks of input.
    reader: csv_core::Reader,
    // Scratch output buffer for `read_record`; grown on demand.
    output: Vec<u8>,
    // Scratch field-end buffer for `read_record`; grown on demand.
    ends: Vec<usize>,
    // Unquoted bytes of the record currently being accumulated.
    record: Vec<u8>,
    // Field end offsets (into `record`) for the current record.
    record_ends: Vec<usize>,
    // True while the first record (a header) must not be treated as a
    // potential end marker.
    skip_first_record: bool,
}
3223
impl CopyRowScanner {
    /// Creates a scanner for the given COPY format: CSV formats get a
    /// stateful `csv_core`-backed parser; all other formats fall back to a
    /// newline scan.
    fn new(params: &CopyFormatParams<'_>) -> Self {
        let csv = match params {
            CopyFormatParams::Csv(CopyCsvFormatParams {
                delimiter,
                quote,
                escape,
                header,
                ..
            }) => Some(CsvScanState::new(*delimiter, *quote, *escape, *header)),
            // NOTE(review): the fallback scan assumes newline-delimited,
            // text-like input; confirm callers never use this scanner for
            // binary COPY.
            _ => None,
        };

        CopyRowScanner {
            scan_pos: 0,
            last_row_end: None,
            end_marker_end: None,
            csv,
        }
    }

    /// Scans the bytes of `data` that arrived since the previous call (from
    /// `scan_pos` onward), updating `last_row_end` and — once the `\.`
    /// end-of-copy marker is found — `end_marker_end`.
    fn scan_new_bytes(&mut self, data: &[u8]) {
        if self.scan_pos >= data.len() {
            return;
        }

        if let Some(csv) = self.csv.as_mut() {
            let mut input = &data[self.scan_pos..];
            let mut consumed = 0usize;
            while !input.is_empty() {
                let (result, n_input, n_output, n_ends) =
                    csv.reader
                        .read_record(input, &mut csv.output, &mut csv.ends);
                consumed += n_input;
                input = &input[n_input..];
                // Accumulate the unquoted bytes and field boundaries emitted
                // so far for the (possibly still partial) current record.
                if !csv.output.is_empty() {
                    csv.record.extend_from_slice(&csv.output[..n_output]);
                }
                if !csv.ends.is_empty() {
                    csv.record_ends.extend_from_slice(&csv.ends[..n_ends]);
                }

                match result {
                    ReadRecordResult::InputEmpty => break,
                    // Scratch buffers are grown only when the parser made no
                    // input progress at all; otherwise just loop and retry.
                    ReadRecordResult::OutputFull => {
                        if n_input == 0 {
                            csv.output
                                .resize(csv.output.len().saturating_mul(2).max(1), 0);
                        }
                    }
                    ReadRecordResult::OutputEndsFull => {
                        if n_input == 0 {
                            csv.ends.resize(csv.ends.len().saturating_mul(2).max(1), 0);
                        }
                    }
                    ReadRecordResult::Record | ReadRecordResult::End => {
                        let row_end = self.scan_pos + consumed;
                        self.last_row_end = Some(row_end);
                        if self.end_marker_end.is_none() {
                            // The end marker is a record consisting of a
                            // single two-byte field that is exactly `\.`.
                            // A configured header record is never treated as
                            // the marker.
                            let is_marker = if csv.skip_first_record {
                                csv.skip_first_record = false;
                                false
                            } else if csv.record_ends.len() == 1 {
                                let end = csv.record_ends[0];
                                end == 2 && csv.record.get(0..end) == Some(b"\\.")
                            } else {
                                false
                            };
                            if is_marker {
                                self.end_marker_end = Some(row_end);
                                break;
                            }
                        }
                        csv.record.clear();
                        csv.record_ends.clear();
                    }
                }
            }
        } else {
            // Non-CSV: rows end at `\n`. The current row resumes at the end
            // of the last complete row (or the start of the buffer).
            let mut row_start = self.last_row_end.unwrap_or(0);
            for (offset, b) in data[self.scan_pos..].iter().enumerate() {
                if *b == b'\n' {
                    let row_end = self.scan_pos + offset + 1;
                    self.last_row_end = Some(row_end);
                    if self.end_marker_end.is_none() {
                        let row = &data[row_start..row_end];
                        // NOTE(review): this treats any row that merely
                        // *starts* with `\.` as the end marker, not only a
                        // bare `\.` line — confirm that is intended.
                        if row.get(0..2) == Some(b"\\.") {
                            self.end_marker_end = Some(row_end);
                            break;
                        }
                    }
                    row_start = row_end;
                }
            }
        }

        // The whole buffer now counts as examined; bytes after an end marker
        // are deliberately ignored.
        self.scan_pos = data.len();
    }

    /// Offset one past the last complete row seen so far, if any.
    fn last_row_end(&self) -> Option<usize> {
        self.last_row_end
    }

    /// Offset one past the `\.` end-of-copy marker, if it has been seen.
    fn end_marker_end(&self) -> Option<usize> {
        self.end_marker_end
    }

    /// Size in bytes of the trailing partial row, given the buffer's
    /// current length.
    fn current_row_size(&self, data_len: usize) -> usize {
        data_len.saturating_sub(self.last_row_end.unwrap_or(0))
    }

    /// Adjusts recorded offsets after the first `split_pos` bytes of the
    /// buffer have been removed.
    fn on_split(&mut self, split_pos: usize) {
        self.scan_pos = self.scan_pos.saturating_sub(split_pos);
        // Any previously recorded row boundary left with the split prefix.
        self.last_row_end = None;
        self.end_marker_end = self
            .end_marker_end
            .and_then(|end| end.checked_sub(split_pos));
    }

    /// Adjusts recorded offsets after the buffer has been truncated to
    /// `new_len` bytes, discarding positions past the new end.
    fn on_truncate(&mut self, new_len: usize) {
        self.scan_pos = self.scan_pos.min(new_len);
        self.last_row_end = self.last_row_end.filter(|&end| end <= new_len);
        self.end_marker_end = self.end_marker_end.filter(|&end| end <= new_len);
    }
}
3349
3350impl CsvScanState {
3351    fn new(delimiter: u8, quote: u8, escape: u8, header: bool) -> Self {
3352        let (double_quote, escape) = if quote == escape {
3353            (true, None)
3354        } else {
3355            (false, Some(escape))
3356        };
3357        CsvScanState {
3358            reader: csv_core::ReaderBuilder::new()
3359                .delimiter(delimiter)
3360                .quote(quote)
3361                .double_quote(double_quote)
3362                .escape(escape)
3363                .build(),
3364            output: vec![0; 1],
3365            ends: vec![0; 1],
3366            record: Vec::new(),
3367            record_ends: Vec::new(),
3368            skip_first_record: header,
3369        }
3370    }
3371}
3372
// Unit tests for the startup `options` parameter parsing helpers.
#[cfg(test)]
mod test {
    use super::*;

    // Table-driven coverage of `parse_options` over whole option strings,
    // exercising both the `--key=val` and `-c key=val` spellings as well as
    // malformed inputs.
    #[mz_ore::test]
    fn test_parse_options() {
        struct TestCase {
            input: &'static str,
            expect: Result<Vec<(&'static str, &'static str)>, ()>,
        }
        let tests = vec![
            TestCase {
                input: "",
                expect: Ok(vec![]),
            },
            TestCase {
                input: "--key",
                expect: Err(()),
            },
            TestCase {
                input: "--key=val",
                expect: Ok(vec![("key", "val")]),
            },
            TestCase {
                input: r#"--key=val -ckey2=val2 -c key3=val3 -c key4=val4 -ckey5=val5"#,
                expect: Ok(vec![
                    ("key", "val"),
                    ("key2", "val2"),
                    ("key3", "val3"),
                    ("key4", "val4"),
                    ("key5", "val5"),
                ]),
            },
            TestCase {
                input: r#"-c\ key=val"#,
                expect: Ok(vec![(" key", "val")]),
            },
            TestCase {
                input: "--key=val -ckey2 val2",
                expect: Err(()),
            },
            // Unclear what this should do.
            TestCase {
                input: "--key=",
                expect: Ok(vec![("key", "")]),
            },
        ];
        for test in tests {
            let got = parse_options(test.input);
            let expect = test.expect.map(|r| {
                r.into_iter()
                    .map(|(k, v)| (k.to_owned(), v.to_owned()))
                    .collect()
            });
            assert_eq!(got, expect, "input: {}", test.input);
        }
    }

    // Table-driven coverage of `parse_option` on a single option token,
    // including degenerate forms like `--` and `--=`.
    #[mz_ore::test]
    fn test_parse_option() {
        struct TestCase {
            input: &'static str,
            expect: Result<(&'static str, &'static str), ()>,
        }
        let tests = vec![
            TestCase {
                input: "",
                expect: Err(()),
            },
            TestCase {
                input: "--",
                expect: Err(()),
            },
            TestCase {
                input: "--c",
                expect: Err(()),
            },
            TestCase {
                input: "a=b",
                expect: Err(()),
            },
            TestCase {
                input: "--a=b",
                expect: Ok(("a", "b")),
            },
            TestCase {
                input: "--ca=b",
                expect: Ok(("ca", "b")),
            },
            TestCase {
                input: "-ca=b",
                expect: Ok(("a", "b")),
            },
            // Unclear what this should error, but at least test it.
            TestCase {
                input: "--=",
                expect: Ok(("", "")),
            },
        ];
        for test in tests {
            let got = parse_option(test.input);
            assert_eq!(got, test.expect, "input: {}", test.input);
        }
    }

    // Table-driven coverage of `split_options`, focusing on backslash
    // escaping of spaces and trailing/leading whitespace handling.
    #[mz_ore::test]
    fn test_split_options() {
        struct TestCase {
            input: &'static str,
            expect: Vec<&'static str>,
        }
        let tests = vec![
            TestCase {
                input: "",
                expect: vec![],
            },
            TestCase {
                input: "  ",
                expect: vec![],
            },
            TestCase {
                input: " a ",
                expect: vec!["a"],
            },
            TestCase {
                input: "  ab     cd   ",
                expect: vec!["ab", "cd"],
            },
            TestCase {
                input: r#"  ab\     cd   "#,
                expect: vec!["ab ", "cd"],
            },
            TestCase {
                input: r#"  ab\\     cd   "#,
                expect: vec![r#"ab\"#, "cd"],
            },
            TestCase {
                input: r#"  ab\\\     cd   "#,
                expect: vec![r#"ab\ "#, "cd"],
            },
            TestCase {
                input: r#"  ab\\\ cd   "#,
                expect: vec![r#"ab\ cd"#],
            },
            TestCase {
                input: r#"  ab\\\cd   "#,
                expect: vec![r#"ab\cd"#],
            },
            TestCase {
                input: r#"a\"#,
                expect: vec!["a"],
            },
            TestCase {
                input: r#"a\ "#,
                expect: vec!["a "],
            },
            TestCase {
                input: r#"\"#,
                expect: vec![],
            },
            TestCase {
                input: r#"\ "#,
                expect: vec![r#" "#],
            },
            TestCase {
                input: r#" \ "#,
                expect: vec![r#" "#],
            },
            TestCase {
                input: r#"\  "#,
                expect: vec![r#" "#],
            },
        ];
        for test in tests {
            let got = split_options(test.input);
            assert_eq!(got, test.expect, "input: {}", test.input);
        }
    }
}