// mz_pgwire/protocol.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

10use std::collections::BTreeMap;
11use std::convert::TryFrom;
12use std::future::Future;
13use std::ops::Deref;
14use std::sync::Arc;
15use std::time::{Duration, Instant};
16use std::{iter, mem};
17
18use base64::prelude::*;
19use byteorder::{ByteOrder, NetworkEndian};
20use csv_core::ReadRecordResult;
21use futures::future::{BoxFuture, FutureExt, pending};
22use itertools::Itertools;
23use mz_adapter::client::RecordFirstRowStream;
24use mz_adapter::session::{
25    EndTransactionAction, InProgressRows, LifecycleTimestamps, PortalRefMut, PortalState, Session,
26    SessionConfig, TransactionStatus,
27};
28use mz_adapter::statement_logging::{StatementEndedExecutionReason, StatementExecutionStrategy};
29use mz_adapter::{
30    AdapterError, AdapterNotice, ExecuteContextGuard, ExecuteResponse, PeekResponseUnary, metrics,
31    verify_datum_desc,
32};
33use mz_auth::Authenticated;
34use mz_auth::password::Password;
35use mz_authenticator::{Authenticator, GenericOidcAuthenticator};
36use mz_frontegg_auth::Authenticator as FronteggAuthenticator;
37use mz_ore::cast::CastFrom;
38use mz_ore::netio::AsyncReady;
39use mz_ore::now::{EpochMillis, SYSTEM_TIME};
40use mz_ore::str::StrExt;
41use mz_ore::{assert_none, assert_ok, instrument, soft_assert_eq_or_log};
42use mz_pgcopy::{CopyCsvFormatParams, CopyFormatParams, CopyTextFormatParams};
43use mz_pgwire_common::{
44    ConnectionCounter, Cursor, ErrorResponse, Format, FrontendMessage, Severity, VERSION_3,
45    VERSIONS,
46};
47use mz_repr::{
48    CatalogItemId, ColumnIndex, Datum, RelationDesc, RowArena, RowIterator, RowRef,
49    SqlRelationType, SqlScalarType,
50};
51use mz_server_core::TlsMode;
52use mz_server_core::listeners;
53use mz_server_core::listeners::AllowedRoles;
54use mz_sql::ast::display::AstDisplay;
55use mz_sql::ast::{
56    CopyDirection, CopyStatement, CopyTarget, FetchDirection, Ident, Raw, Statement,
57};
58use mz_sql::parse::StatementParseResult;
59use mz_sql::plan::{CopyFormat, ExecuteTimeout, StatementDesc};
60use mz_sql::session::metadata::SessionMetadata;
61use mz_sql::session::user::INTERNAL_USER_NAMES;
62use mz_sql::session::vars::VarInput;
63use postgres::error::SqlState;
64use tokio::io::{self, AsyncRead, AsyncWrite};
65use tokio::select;
66use tokio::time::{self};
67use tokio_metrics::TaskMetrics;
68use tokio_stream::wrappers::UnboundedReceiverStream;
69use tracing::{Instrument, debug, debug_span, warn};
70use uuid::Uuid;
71
72use crate::codec::{
73    FramedConn, decode_password, decode_sasl_initial_response, decode_sasl_response,
74};
75use crate::message::{
76    self, BackendMessage, SASLServerFinalMessage, SASLServerFinalMessageKinds,
77    SASLServerFirstMessage,
78};
79
80/// Reports whether the given stream begins with a pgwire handshake.
81///
82/// To avoid false negatives, there must be at least eight bytes in `buf`.
83pub fn match_handshake(buf: &[u8]) -> bool {
84    // The pgwire StartupMessage looks like this:
85    //
86    //     i32 - Length of entire message.
87    //     i32 - Protocol version number.
88    //     [String] - Arbitrary key-value parameters of any length.
89    //
90    // Since arbitrary parameters can be included in the StartupMessage, the
91    // first Int32 is worthless, since the message could have any length.
92    // Instead, we sniff the protocol version number.
93    if buf.len() < 8 {
94        return false;
95    }
96    let version = NetworkEndian::read_i32(&buf[4..8]);
97    VERSIONS.contains(&version)
98}
99
/// Parameters for the [`run`] function.
///
/// Bundles everything a freshly accepted pgwire connection needs: the raw
/// framed connection, the startup-message contents already decoded by the
/// caller, and the server-side authentication/limit configuration.
pub struct RunParams<'a, A, I>
where
    I: Iterator<Item = TaskMetrics> + Send,
{
    /// The TLS mode of the pgwire server.
    pub tls_mode: Option<TlsMode>,
    /// A client for the adapter.
    pub adapter_client: mz_adapter::Client,
    /// The connection to the client.
    pub conn: &'a mut FramedConn<A>,
    /// The universally unique identifier for the connection.
    pub conn_uuid: Uuid,
    /// The protocol version that the client provided in the startup message.
    pub version: i32,
    /// The parameters that the client provided in the startup message.
    pub params: BTreeMap<String, String>,
    /// Frontegg JWT authenticator.
    pub frontegg: Option<FronteggAuthenticator>,
    /// OIDC authenticator.
    pub oidc: GenericOidcAuthenticator,
    /// The authentication method defined by the server's listener
    /// configuration.
    pub authenticator_kind: listeners::AuthenticatorKind,
    /// Global connection limit and count
    pub active_connection_counter: ConnectionCounter,
    /// Helm chart version
    pub helm_chart_version: Option<String>,
    /// Whether to allow reserved users (ie: mz_system).
    pub allowed_roles: AllowedRoles,
    /// Tokio metrics
    pub tokio_metrics_intervals: I,
}
133
/// Runs a pgwire connection to completion.
///
/// This involves responding to `FrontendMessage::StartupMessage` and all future
/// requests until the client terminates the connection or a fatal error occurs.
///
/// Note that this function returns successfully even upon delivering a fatal
/// error to the client. It only returns `Err` if an unexpected I/O error occurs
/// while communicating with the client, e.g., if the connection is severed in
/// the middle of a request.
#[mz_ore::instrument(level = "debug")]
pub async fn run<'a, A, I>(
    RunParams {
        tls_mode,
        adapter_client,
        conn,
        conn_uuid,
        version,
        mut params,
        frontegg,
        oidc,
        authenticator_kind,
        active_connection_counter,
        helm_chart_version,
        allowed_roles,
        tokio_metrics_intervals,
    }: RunParams<'a, A, I>,
) -> Result<(), io::Error>
where
    A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
    I: Iterator<Item = TaskMetrics> + Send,
{
    // Only pgwire protocol version 3 is supported; reject anything else with
    // a fatal error before doing any further work.
    if version != VERSION_3 {
        return conn
            .send(ErrorResponse::fatal(
                SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
                "server does not support the client's requested protocol version",
            ))
            .await;
    }

    // Pull the username out of the startup parameters (empty string if the
    // client sent none) and parse the space-separated "options" parameter.
    let user = params.remove("user").unwrap_or_else(String::new);
    let options = parse_options(params.get("options").unwrap_or(&String::new()));

    // If oidc_auth_enabled exists as an option, return its value and filter it from
    // the remaining options.
    let (oidc_auth_enabled, options) = extract_oidc_auth_enabled_from_options(options);
    let authenticator = get_authenticator(
        authenticator_kind,
        frontegg,
        oidc,
        adapter_client.clone(),
        oidc_auth_enabled,
    );
    // TODO move this somewhere it can be shared with HTTP
    let is_internal_user = INTERNAL_USER_NAMES.contains(&user);
    // this is a superset of internal users
    let is_reserved_user = mz_adapter::catalog::is_reserved_role_name(user.as_str());
    // Gate logins by the listener's role policy before authenticating.
    let role_allowed = match allowed_roles {
        AllowedRoles::Normal => !is_reserved_user,
        AllowedRoles::Internal => is_internal_user,
        AllowedRoles::NormalAndInternal => !is_reserved_user || is_internal_user,
    };
    if !role_allowed {
        let msg = format!("unauthorized login to user '{user}'");
        return conn
            .send(ErrorResponse::fatal(SqlState::INSUFFICIENT_PRIVILEGE, msg))
            .await;
    }

    // Verify the connection's TLS state is acceptable for the configured TLS
    // mode; on mismatch, deliver the prepared error response and stop.
    if let Err(err) = conn.inner().ensure_tls_compatibility(&tls_mode) {
        return conn.send(err).await;
    }

    let authenticator_kind = authenticator.kind();

    // Authenticate the client and construct a `Session`. The second element of
    // the tuple is a future that resolves if/when the authentication expires
    // (only the Frontegg flow ever expires; every other flow uses `pending()`,
    // which never resolves).
    let (mut session, expired) = match authenticator {
        Authenticator::Frontegg(frontegg) => {
            // Frontegg flow: the client supplies a password (API token) via
            // the cleartext-password exchange.
            let password = match request_cleartext_password(conn).await {
                Ok(password) => password,
                Err(PasswordRequestError::IoError(e)) => return Err(e),
                Err(PasswordRequestError::InvalidPasswordError(e)) => {
                    return conn.send(e).await;
                }
            };

            let auth_response = frontegg.authenticate(&user, &password).await;
            match auth_response {
                // Create a session based on the auth session.
                //
                // In particular, it's important that the username come from the
                // auth session, as Frontegg may return an email address with
                // different casing than the user supplied via the pgwire
                // username field.
                Ok((mut auth_session, authenticated)) => {
                    let session = adapter_client.new_session(
                        SessionConfig {
                            conn_id: conn.conn_id().clone(),
                            uuid: conn_uuid,
                            user: auth_session.user().into(),
                            client_ip: conn.peer_addr().clone(),
                            external_metadata_rx: Some(auth_session.external_metadata_rx()),
                            helm_chart_version,
                            authenticator_kind,
                            groups: None,
                        },
                        authenticated,
                    );
                    // Resolves when the Frontegg auth session expires, which
                    // terminates the connection (see the `select!` below).
                    let expired = async move { auth_session.expired().await };
                    (session, expired.left_future())
                }
                Err(err) => {
                    warn!(?err, "pgwire connection failed authentication");
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_PASSWORD,
                            "invalid password",
                        ))
                        .await;
                }
            }
        }
        Authenticator::Oidc(oidc) => {
            // OIDC authentication: JWT sent as password in cleartext flow
            let jwt = match request_cleartext_password(conn).await {
                Ok(password) => password,
                Err(PasswordRequestError::IoError(e)) => return Err(e),
                Err(PasswordRequestError::InvalidPasswordError(e)) => {
                    return conn.send(e).await;
                }
            };
            let auth_response = oidc.authenticate(&jwt, Some(&user)).await;
            match auth_response {
                Ok((mut claims, authenticated)) => {
                    // The session's user and groups come from the validated
                    // token claims, not from the startup-message username.
                    let groups = claims.groups.take();
                    let session = adapter_client.new_session(
                        SessionConfig {
                            conn_id: conn.conn_id().clone(),
                            uuid: conn_uuid,
                            user: std::mem::take(&mut claims.user),
                            client_ip: conn.peer_addr().clone(),
                            external_metadata_rx: None,
                            helm_chart_version,
                            authenticator_kind,
                            groups,
                        },
                        authenticated,
                    );
                    // No invalidation of the auth session once authenticated,
                    // so auth session lasts indefinitely.
                    (session, pending().right_future())
                }
                Err(err) => {
                    warn!(?err, "pgwire connection failed authentication");
                    return conn.send(err.into_response()).await;
                }
            }
        }
        Authenticator::Password(adapter_client) => {
            // Built-in password authentication against the adapter's own
            // credential store; see `authenticate_with_password`.
            let session = match authenticate_with_password(
                conn,
                &adapter_client,
                user,
                conn_uuid,
                helm_chart_version,
                authenticator_kind,
            )
            .await
            {
                Ok(session) => session,
                Err(PasswordRequestError::IoError(e)) => return Err(e),
                Err(PasswordRequestError::InvalidPasswordError(e)) => {
                    return conn.send(e).await;
                }
            };
            // No frontegg check, so auth session lasts indefinitely.
            (session, pending().right_future())
        }
        Authenticator::Sasl(adapter_client) => {
            // SCRAM-SHA-256 SASL exchange:
            //   1. server advertises SASL            (AuthenticationSASL)
            //   2. client picks mechanism + nonce    (SASLInitialResponse)
            //   3. server sends salt/iterations      (AuthenticationSASLContinue)
            //   4. client proves the password        (SASLResponse)
            //   5. server sends its verifier         (AuthenticationSASLFinal)
            // Start the handshake
            conn.send(BackendMessage::AuthenticationSASL).await?;
            conn.flush().await?;
            // Get the initial response indicating chosen mechanism
            let (mechanism, initial_response) = match conn.recv().await? {
                Some(FrontendMessage::RawAuthentication(data)) => {
                    match decode_sasl_initial_response(Cursor::new(&data)).ok() {
                        Some(FrontendMessage::SASLInitialResponse {
                            gs2_header,
                            mechanism,
                            initial_response,
                        }) => {
                            // We do not support channel binding
                            if gs2_header.channel_binding_enabled() {
                                return conn
                                    .send(ErrorResponse::fatal(
                                        SqlState::PROTOCOL_VIOLATION,
                                        "channel binding not supported",
                                    ))
                                    .await;
                            }
                            (mechanism, initial_response)
                        }
                        _ => {
                            return conn
                                .send(ErrorResponse::fatal(
                                    SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                                    "expected SASLInitialResponse message",
                                ))
                                .await;
                        }
                    }
                }
                _ => {
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                            "expected SASLInitialResponse message",
                        ))
                        .await;
                }
            };

            if mechanism != "SCRAM-SHA-256" {
                return conn
                    .send(ErrorResponse::fatal(
                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                        "unsupported SASL mechanism",
                    ))
                    .await;
            }

            // Bound the size of the client-supplied nonce before handing it
            // to the adapter.
            if initial_response.nonce.len() > 256 {
                return conn
                    .send(ErrorResponse::fatal(
                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                        "nonce too long",
                    ))
                    .await;
            }

            let (server_first_message_raw, mock_hash) = match adapter_client
                .generate_sasl_challenge(&user, &initial_response.nonce)
                .await
            {
                Ok(response) => {
                    // Raw server-first-message, retained verbatim because it
                    // is a component of the SCRAM AuthMessage below.
                    let server_first_message_raw = format!(
                        "r={},s={},i={}",
                        response.nonce, response.salt, response.iteration_count
                    );

                    // NOTE(review): fixed all-zero/all-one keys build a
                    // placeholder stored-credential string passed to
                    // `verify_sasl_proof`; presumably a fallback so behavior
                    // stays uniform when no real hash exists — confirm
                    // against the adapter's SASL implementation.
                    let client_key = [0u8; 32];
                    let server_key = [1u8; 32];
                    let mock_hash = format!(
                        "SCRAM-SHA-256${}:{}${}:{}",
                        response.iteration_count,
                        response.salt,
                        BASE64_STANDARD.encode(client_key),
                        BASE64_STANDARD.encode(server_key)
                    );

                    conn.send(BackendMessage::AuthenticationSASLContinue(
                        SASLServerFirstMessage {
                            iteration_count: response.iteration_count,
                            nonce: response.nonce,
                            salt: response.salt,
                        },
                    ))
                    .await?;
                    conn.flush().await?;
                    (server_first_message_raw, mock_hash)
                }
                Err(e) => {
                    return conn.send(e.into_response(Severity::Fatal)).await;
                }
            };

            let authenticated = match conn.recv().await? {
                Some(FrontendMessage::RawAuthentication(data)) => {
                    match decode_sasl_response(Cursor::new(&data)).ok() {
                        Some(FrontendMessage::SASLResponse(response)) => {
                            // SCRAM AuthMessage: client-first-bare ","
                            // server-first "," client-final-without-proof.
                            let auth_message = format!(
                                "{},{},{}",
                                initial_response.client_first_message_bare_raw,
                                server_first_message_raw,
                                response.client_final_message_bare_raw
                            );
                            // Bound the size of the client-supplied proof.
                            if response.proof.len() > 1024 {
                                return conn
                                    .send(ErrorResponse::fatal(
                                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                                        "proof too long",
                                    ))
                                    .await;
                            }
                            match adapter_client
                                .verify_sasl_proof(
                                    &user,
                                    &response.proof,
                                    &auth_message,
                                    &mock_hash,
                                )
                                .await
                            {
                                Ok((proof_response, authenticated)) => {
                                    // Send the server verifier so the client
                                    // can authenticate the server in turn.
                                    conn.send(BackendMessage::AuthenticationSASLFinal(
                                        SASLServerFinalMessage {
                                            kind: SASLServerFinalMessageKinds::Verifier(
                                                proof_response.verifier,
                                            ),
                                            extensions: vec![],
                                        },
                                    ))
                                    .await?;
                                    conn.flush().await?;
                                    authenticated
                                }
                                Err(_) => {
                                    return conn
                                        .send(ErrorResponse::fatal(
                                            SqlState::INVALID_PASSWORD,
                                            "invalid password",
                                        ))
                                        .await;
                                }
                            }
                        }
                        _ => {
                            return conn
                                .send(ErrorResponse::fatal(
                                    SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                                    "expected SASLResponse message",
                                ))
                                .await;
                        }
                    }
                }
                _ => {
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                            "expected SASLResponse message",
                        ))
                        .await;
                }
            };

            let session = adapter_client.new_session(
                SessionConfig {
                    conn_id: conn.conn_id().clone(),
                    uuid: conn_uuid,
                    user,
                    client_ip: conn.peer_addr().clone(),
                    external_metadata_rx: None,
                    helm_chart_version,
                    authenticator_kind,
                    groups: None,
                },
                authenticated,
            );
            // No frontegg check, so auth session lasts indefinitely.
            let auth_session = pending().right_future();
            (session, auth_session)
        }

        Authenticator::None => {
            // Authentication disabled: trust the startup-message username.
            let session = adapter_client.new_session(
                SessionConfig {
                    conn_id: conn.conn_id().clone(),
                    uuid: conn_uuid,
                    user,
                    client_ip: conn.peer_addr().clone(),
                    external_metadata_rx: None,
                    helm_chart_version,
                    authenticator_kind,
                    groups: None,
                },
                Authenticated,
            );
            // No frontegg check, so auth session lasts indefinitely.
            let auth_session = pending().right_future();
            (session, auth_session)
        }
    };

    // Apply the remaining startup parameters as session variables. The
    // special "options" parameter expands to the pairs parsed earlier; a
    // parse failure or a bad setting produces a notice, not a fatal error.
    let system_vars = adapter_client.get_system_vars().await;
    for (name, value) in params {
        let settings = match name.as_str() {
            "options" => match &options {
                Ok(opts) => opts,
                Err(()) => {
                    session.add_notice(AdapterNotice::BadStartupSetting {
                        name,
                        reason: "could not parse".into(),
                    });
                    continue;
                }
            },
            _ => &vec![(name, value)],
        };
        for (key, val) in settings {
            const LOCAL: bool = false;
            // TODO: Issuing an error here is better than what we did before
            // (silently ignore errors on set), but erroring the connection
            // might be the better behavior. We maybe need to support more
            // options sent by psql and drivers before we can safely do this.
            if let Err(err) = session
                .vars_mut()
                .set(&system_vars, key, VarInput::Flat(val), LOCAL)
            {
                session.add_notice(AdapterNotice::BadStartupSetting {
                    name: key.clone(),
                    reason: err.to_string(),
                });
            }
        }
    }
    session
        .vars_mut()
        .end_transaction(EndTransactionAction::Commit);

    // Enforce the global connection limit. The guard is held for the rest of
    // the connection's lifetime and releases the slot when dropped.
    let _guard = match active_connection_counter.allocate_connection(session.user()) {
        Ok(drop_connection) => drop_connection,
        Err(e) => {
            let e: AdapterError = e.into();
            return conn.send(e.into_response(Severity::Fatal)).await;
        }
    };

    // Register session with adapter.
    let mut adapter_client = match adapter_client.startup(session).await {
        Ok(adapter_client) => adapter_client,
        Err(e) => return conn.send(e.into_response(Severity::Fatal)).await,
    };

    // Announce successful startup in one batch: AuthenticationOk, the
    // parameter statuses clients are notified about, the backend key data
    // (conn id + secret key), any notices accumulated during startup, and
    // finally ReadyForQuery with the current transaction status.
    let mut buf = vec![BackendMessage::AuthenticationOk];
    for var in adapter_client.session().vars().notify_set() {
        buf.push(BackendMessage::ParameterStatus(var.name(), var.value()));
    }
    buf.push(BackendMessage::BackendKeyData {
        conn_id: adapter_client.session().conn_id().unhandled(),
        secret_key: adapter_client.session().secret_key(),
    });
    buf.extend(
        adapter_client
            .session()
            .drain_notices()
            .into_iter()
            .map(|notice| BackendMessage::ErrorResponse(notice.into_response())),
    );
    buf.push(BackendMessage::ReadyForQuery(
        adapter_client.session().transaction().into(),
    ));
    conn.send_all(buf).await?;
    conn.flush().await?;

    let machine = StateMachine {
        conn,
        adapter_client,
        txn_needs_commit: false,
        tokio_metrics_intervals,
    };

    // Drive the message loop until the connection finishes, or cut it off if
    // the authentication session expires first.
    select! {
        r = machine.run() => {
            // Errors produced internally (like MAX_REQUEST_SIZE being exceeded) should send an
            // error to the client informing them why the connection was closed. We still want to
            // return the original error up the stack, though, so we skip error checking during conn
            // operations.
            if let Err(err) = &r {
                let _ = conn
                    .send(ErrorResponse::fatal(
                        SqlState::CONNECTION_FAILURE,
                        err.to_string(),
                    ))
                    .await;
                let _ = conn.flush().await;
            }
            r
        },
        _ = expired => {
            conn
                .send(ErrorResponse::fatal(SqlState::INVALID_AUTHORIZATION_SPECIFICATION, "authentication expired"))
                .await?;
            conn.flush().await
        }
    }
}
620
/// Gets `oidc_auth_enabled` from options if it exists.
///
/// Returns the `oidc_auth_enabled` value (defaulting to `false` when absent
/// or unparsable) along with the options with every `oidc_auth_enabled`
/// entry removed. A pre-existing parse error is passed through untouched.
fn extract_oidc_auth_enabled_from_options(
    options: Result<Vec<(String, String)>, ()>,
) -> (bool, Result<Vec<(String, String)>, ()>) {
    let opts = match options {
        Ok(opts) => opts,
        Err(()) => return (false, Err(())),
    };

    let mut enabled = false;
    // Keep everything except `oidc_auth_enabled`, recording its value as we
    // go; a later occurrence overrides an earlier one.
    let remaining: Vec<(String, String)> = opts
        .into_iter()
        .filter(|(key, value)| {
            if key == "oidc_auth_enabled" {
                enabled = value.parse::<bool>().unwrap_or(false);
                false
            } else {
                true
            }
        })
        .collect();

    (enabled, Ok(remaining))
}
645
646/// Returns (name, value) session settings pairs from an options value.
647///
648/// From Postgres, see pg_split_opts in postinit.c and process_postgres_switches
649/// in postgres.c.
650fn parse_options(value: &str) -> Result<Vec<(String, String)>, ()> {
651    let opts = split_options(value);
652    let mut pairs = Vec::with_capacity(opts.len());
653    let mut seen_prefix = false;
654    for opt in opts {
655        if !seen_prefix {
656            if opt == "-c" {
657                seen_prefix = true;
658            } else {
659                let (key, val) = parse_option(&opt)?;
660                pairs.push((key.to_owned(), val.to_owned()));
661            }
662        } else {
663            let (key, val) = opt.split_once('=').ok_or(())?;
664            pairs.push((key.to_owned(), val.to_owned()));
665            seen_prefix = false;
666        }
667    }
668    Ok(pairs)
669}
670
/// Returns the parsed key and value from an option of the form
/// `--key=value` or `-ckey=value`.
///
/// The key is returned verbatim with its prefix stripped (no `-`/`_`
/// normalization is performed here). Returns an error if the option has no
/// `=` or carries some other prefix.
fn parse_option(option: &str) -> Result<(&str, &str), ()> {
    let (key, value) = option.split_once('=').ok_or(())?;
    // "-c" is checked before "--", matching the original prefix order.
    ["-c", "--"]
        .iter()
        .find_map(|prefix| key.strip_prefix(prefix))
        .map(|key| (key, value))
        .ok_or(())
}
683
/// Splits value by any number of spaces except those preceded by `\`.
///
/// A backslash escapes the character that follows it (so `\ ` yields a
/// literal space inside a token and `\\` a literal backslash); a lone
/// trailing `\` is dropped. Runs of unescaped spaces produce no empty
/// tokens.
fn split_options(value: &str) -> Vec<String> {
    let mut parts = Vec::new();
    // Escaping means we cannot simply subslice `value`, so tokens are built
    // up character by character.
    let mut word = String::new();
    let mut escaped = false;
    for ch in value.chars() {
        if escaped {
            // The previous character was an unescaped backslash: take this
            // character literally (the backslash itself is dropped unless
            // doubled, in which case `ch` is the second backslash).
            escaped = false;
            word.push(ch);
        } else {
            match ch {
                ' ' => {
                    // Token boundary; skip empty tokens from repeated spaces.
                    if !word.is_empty() {
                        parts.push(std::mem::take(&mut word));
                    }
                }
                '\\' => escaped = true,
                other => word.push(other),
            }
        }
    }
    // Flush the final token; a trailing `\` is silently ignored.
    if !word.is_empty() {
        parts.push(word);
    }
    parts
}
726
/// Ways the cleartext-password exchange can fail.
enum PasswordRequestError {
    /// The client's response was not a valid `Password` message, or the
    /// supplied credentials were rejected. Carries the fatal response the
    /// caller should deliver to the client.
    InvalidPasswordError(ErrorResponse),
    /// The underlying connection failed; propagated to the caller as-is.
    IoError(io::Error),
}
731
732impl From<io::Error> for PasswordRequestError {
733    fn from(e: io::Error) -> Self {
734        PasswordRequestError::IoError(e)
735    }
736}
737
738/// Requests a cleartext password from a connection and returns it if it is valid.
739/// Sends an error response in the connection if the password
740/// is not valid.
741async fn request_cleartext_password<A>(
742    conn: &mut FramedConn<A>,
743) -> Result<String, PasswordRequestError>
744where
745    A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
746{
747    conn.send(BackendMessage::AuthenticationCleartextPassword)
748        .await?;
749    conn.flush().await?;
750
751    if let Some(message) = conn.recv().await? {
752        if let FrontendMessage::RawAuthentication(data) = message {
753            if let Some(FrontendMessage::Password { password }) =
754                decode_password(Cursor::new(&data)).ok()
755            {
756                return Ok(password);
757            }
758        }
759    }
760
761    Err(PasswordRequestError::InvalidPasswordError(
762        ErrorResponse::fatal(
763            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
764            "expected Password message",
765        ),
766    ))
767}
768
769/// Helper for password-based authentication using AdapterClient
770/// and returns an authenticated session.
771async fn authenticate_with_password<A>(
772    conn: &mut FramedConn<A>,
773    adapter_client: &mz_adapter::Client,
774    user: String,
775    conn_uuid: Uuid,
776    helm_chart_version: Option<String>,
777    authenticator_kind: mz_auth::AuthenticatorKind,
778) -> Result<Session, PasswordRequestError>
779where
780    A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
781{
782    let password = match request_cleartext_password(conn).await {
783        Ok(password) => Password(password),
784        Err(e) => return Err(e),
785    };
786
787    let authenticated = match adapter_client.authenticate(&user, &password).await {
788        Ok(authenticated) => authenticated,
789        Err(err) => {
790            warn!(?err, "pgwire connection failed authentication");
791            return Err(PasswordRequestError::InvalidPasswordError(
792                ErrorResponse::fatal(SqlState::INVALID_PASSWORD, "invalid password"),
793            ));
794        }
795    };
796
797    let session = adapter_client.new_session(
798        SessionConfig {
799            conn_id: conn.conn_id().clone(),
800            uuid: conn_uuid,
801            user,
802            client_ip: conn.peer_addr().clone(),
803            external_metadata_rx: None,
804            helm_chart_version,
805            authenticator_kind,
806            groups: None,
807        },
808        authenticated,
809    );
810
811    Ok(session)
812}
813
/// The state of the pgwire protocol state machine.
#[derive(Debug)]
enum State {
    /// Ready to receive and process the next frontend message.
    Ready,
    /// Discarding incoming messages until a `Sync` arrives (or the connection
    /// closes), at which point normal processing resumes.
    Drain,
    /// The connection is finished; the state machine should exit.
    Done,
}
820
/// Drives the pgwire protocol for a single connection: receives frontend
/// messages from `conn` and executes them via `adapter_client`.
struct StateMachine<'a, A, I>
where
    I: Iterator<Item = TaskMetrics> + Send + 'a,
{
    /// The framed pgwire connection to the client.
    conn: &'a mut FramedConn<A>,
    /// The adapter session client used to execute statements.
    adapter_client: mz_adapter::SessionClient,
    /// When true, the current implicit transaction is committed the next time
    /// `ensure_transaction` is called (set after an extended-protocol Execute).
    txn_needs_commit: bool,
    /// Infinite stream of tokio task metrics intervals, used to measure
    /// scheduling delay around `recv()` calls.
    tokio_metrics_intervals: I,
}
830
/// The reason that sending rows to the client ended.
enum SendRowsEndedReason {
    /// All requested rows were sent successfully.
    Success {
        // Total size of the result sent (presumably bytes — confirm at callers).
        result_size: u64,
        // Number of rows returned to the client.
        rows_returned: u64,
    },
    /// Sending rows failed with an error.
    Errored {
        // Human-readable description of the error.
        error: String,
    },
    /// The operation was canceled before all rows were sent.
    Canceled,
}
841
/// Error message reported for any command (other than a transaction-exit
/// statement like COMMIT/ROLLBACK) issued while the transaction is aborted.
const ABORTED_TXN_MSG: &str =
    "current transaction is aborted, commands ignored until end of transaction block";
844
845impl<'a, A, I> StateMachine<'a, A, I>
846where
847    A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin + 'a,
848    I: Iterator<Item = TaskMetrics> + Send + 'a,
849{
    // Manually desugar this (don't use `async fn run`) here because a much better
    // error message is produced if there are problems with Send or other traits
    // somewhere within the Future.
    /// Drives the protocol state machine until the connection terminates or an
    /// I/O error occurs.
    #[allow(clippy::manual_async_fn)]
    #[mz_ore::instrument(level = "debug")]
    fn run(mut self) -> impl Future<Output = Result<(), io::Error>> + Send + 'a {
        async move {
            let mut state = State::Ready;
            loop {
                // Flush any notices queued by the adapter before advancing.
                self.send_pending_notices().await?;
                state = match state {
                    State::Ready => self.advance_ready().await?,
                    State::Drain => self.advance_drain().await?,
                    State::Done => return Ok(()),
                };
                // Re-arm the idle-in-transaction timeout while we wait for the
                // next message.
                self.adapter_client
                    .add_idle_in_transaction_session_timeout();
            }
        }
    }
870
    /// Receives and dispatches the next frontend message, returning the next
    /// protocol state. Also handles pending session timeouts and records
    /// per-message processing time and recv scheduling-delay metrics.
    #[instrument(level = "debug")]
    async fn advance_ready(&mut self) -> Result<State, io::Error> {
        // Start a new metrics interval before the `recv()` call.
        self.tokio_metrics_intervals
            .next()
            .expect("infinite iterator");

        // Handle timeouts first so we don't execute any statements when there's a pending timeout.
        let message = select! {
            biased;

            // `recv_timeout()` is cancel-safe as per its docs.
            Some(timeout) = self.adapter_client.recv_timeout() => {
                let err: AdapterError = timeout.into();
                let conn_id = self.adapter_client.session().conn_id();
                tracing::warn!("session timed out, conn_id {}", conn_id);

                // Process the error, doing any state cleanup.
                let error_response = err.into_response(Severity::Fatal);
                let error_state = self.send_error_and_get_state(error_response).await;

                // Terminate __after__ we do any cleanup.
                self.adapter_client.terminate().await;

                // We must wait for the client to send a request before we can send the error response.
                // Due to the PG wire protocol, we can't send an ErrorResponse unless it is in response
                // to a client message.
                let _ = self.conn.recv().await?;
                return error_state;
            },
            // `recv()` is cancel-safe as per its docs.
            message = self.conn.recv() => message?,
        };

        // Take the metrics since just before the `recv`.
        let interval = self
            .tokio_metrics_intervals
            .next()
            .expect("infinite iterator");
        let recv_scheduling_delay_ms = interval.total_scheduled_duration.as_secs_f64() * 1000.0;

        // TODO(ggevay): Consider subtracting the scheduling delay from `received`. It's not obvious
        // whether we should do this, because the result wouldn't exactly correspond to either first
        // byte received or last byte received (for msgs that arrive in more than one network packet).
        let received = SYSTEM_TIME();

        self.adapter_client
            .remove_idle_in_transaction_session_timeout();

        // NOTE(guswynn): we could consider adding spans to all message types. Currently
        // only a few message types seem useful.
        let message_name = message.as_ref().map(|m| m.name()).unwrap_or_default();

        let start = message.as_ref().map(|_| Instant::now());
        let next_state = match message {
            Some(FrontendMessage::Query { sql }) => {
                let query_root_span =
                    tracing::info_span!(parent: None, "advance_ready", otel.name = message_name);
                query_root_span.follows_from(tracing::Span::current());
                self.query(sql, received)
                    .instrument(query_root_span)
                    .await?
            }
            Some(FrontendMessage::Parse {
                name,
                sql,
                param_types,
            }) => self.parse(name, sql, param_types).await?,
            Some(FrontendMessage::Bind {
                portal_name,
                statement_name,
                param_formats,
                raw_params,
                result_formats,
            }) => {
                self.bind(
                    portal_name,
                    statement_name,
                    param_formats,
                    raw_params,
                    result_formats,
                )
                .await?
            }
            Some(FrontendMessage::Execute {
                portal_name,
                max_rows,
            }) => {
                let max_rows = match usize::try_from(max_rows) {
                    Ok(0) | Err(_) => ExecuteCount::All, // If `max_rows < 0`, no limit.
                    Ok(n) => ExecuteCount::Count(n),
                };
                let execute_root_span =
                    tracing::info_span!(parent: None, "advance_ready", otel.name = message_name);
                execute_root_span.follows_from(tracing::Span::current());
                let state = self
                    .execute(
                        portal_name,
                        max_rows,
                        portal_exec_message,
                        None,
                        ExecuteTimeout::None,
                        None,
                        Some(received),
                    )
                    .instrument(execute_root_span)
                    .await?;
                // In PostgreSQL, when using the extended query protocol, some statements may
                // trigger an eager commit of the current implicit transaction,
                // see: <https://git.postgresql.org/gitweb/?p=postgresql.git&a=commitdiff&h=f92944137>.
                //
                // In Materialize, however, we eagerly commit every statement outside of an explicit
                // transaction when using the extended query protocol. This allows us to eliminate
                // the possibility of a multiple statement implicit transaction, which in turn
                // allows us to apply single-statement optimizations to queries issued in implicit
                // transactions in the extended query protocol.
                //
                // We don't immediately commit here to allow users to page through the portal if
                // necessary. Committing the transaction would destroy the portal before the next
                // Execute command has a chance to resume it. So we instead mark the transaction
                // for commit the next time that `ensure_transaction` is called.
                if self.adapter_client.session().transaction().is_implicit() {
                    self.txn_needs_commit = true;
                }
                state
            }
            Some(FrontendMessage::DescribeStatement { name }) => {
                self.describe_statement(&name).await?
            }
            Some(FrontendMessage::DescribePortal { name }) => self.describe_portal(&name).await?,
            Some(FrontendMessage::CloseStatement { name }) => self.close_statement(name).await?,
            Some(FrontendMessage::ClosePortal { name }) => self.close_portal(name).await?,
            Some(FrontendMessage::Flush) => self.flush().await?,
            Some(FrontendMessage::Sync) => self.sync().await?,
            Some(FrontendMessage::Terminate) => State::Done,

            // Messages not handled in the ready state put the connection into
            // drain mode, discarding input until the next Sync.
            Some(FrontendMessage::CopyData(_))
            | Some(FrontendMessage::CopyDone)
            | Some(FrontendMessage::CopyFail(_))
            | Some(FrontendMessage::Password { .. })
            | Some(FrontendMessage::RawAuthentication(_))
            | Some(FrontendMessage::SASLInitialResponse { .. })
            | Some(FrontendMessage::SASLResponse(_)) => State::Drain,
            None => State::Done,
        };

        if let Some(start) = start {
            self.adapter_client
                .inner()
                .metrics()
                .pgwire_message_processing_seconds
                .with_label_values(&[message_name])
                .observe(start.elapsed().as_secs_f64());
        }
        self.adapter_client
            .inner()
            .metrics()
            .pgwire_recv_scheduling_delay_ms
            .with_label_values(&[message_name])
            .observe(recv_scheduling_delay_ms);

        Ok(next_state)
    }
1034
1035    async fn advance_drain(&mut self) -> Result<State, io::Error> {
1036        let message = self.conn.recv().await?;
1037        if message.is_some() {
1038            self.adapter_client
1039                .remove_idle_in_transaction_session_timeout();
1040        }
1041        match message {
1042            Some(FrontendMessage::Sync) => self.sync().await,
1043            None => Ok(State::Done),
1044            _ => Ok(State::Drain),
1045        }
1046    }
1047
    /// Executes a single statement of a Simple Query through the unnamed
    /// portal, sending the row description (if any) and the execute response.
    ///
    /// Note that `lifecycle_timestamps` belongs to the whole "Simple Query", because the whole
    /// Simple Query is received and parsed together. This means that if there are multiple
    /// statements in a Simple Query, then all of them have the same `lifecycle_timestamps`.
    #[instrument(level = "debug")]
    async fn one_query(
        &mut self,
        stmt: Statement<Raw>,
        sql: String,
        lifecycle_timestamps: LifecycleTimestamps,
    ) -> Result<State, io::Error> {
        // Bind the portal. Note that this does not set the empty string prepared
        // statement.
        const EMPTY_PORTAL: &str = "";
        if let Err(e) = self
            .adapter_client
            .declare(EMPTY_PORTAL.to_string(), stmt, sql)
            .await
        {
            return self
                .send_error_and_get_state(e.into_response(Severity::Error))
                .await;
        }
        let portal = self
            .adapter_client
            .session()
            .get_portal_unverified_mut(EMPTY_PORTAL)
            .expect("unnamed portal should be present");

        *portal.lifecycle_timestamps = Some(lifecycle_timestamps);

        // Simple queries cannot carry parameters, so a statement that expects
        // any is an error.
        let stmt_desc = portal.desc.clone();
        if !stmt_desc.param_types.is_empty() {
            return self
                .send_error_and_get_state(ErrorResponse::error(
                    SqlState::UNDEFINED_PARAMETER,
                    "there is no parameter $1",
                ))
                .await;
        }

        // Maybe send row description. Simple-query results always use the
        // text format.
        if let Some(relation_desc) = &stmt_desc.relation_desc {
            if !stmt_desc.is_copy {
                let formats = vec![Format::Text; stmt_desc.arity()];
                self.send(BackendMessage::RowDescription(
                    message::encode_row_description(relation_desc, &formats),
                ))
                .await?;
            }
        }

        let result = match self
            .adapter_client
            .execute(EMPTY_PORTAL.to_string(), self.conn.wait_closed(), None)
            .await
        {
            Ok((response, execute_started)) => {
                self.send_pending_notices().await?;
                self.send_execute_response(
                    response,
                    stmt_desc.relation_desc,
                    EMPTY_PORTAL.to_string(),
                    ExecuteCount::All,
                    portal_exec_message,
                    None,
                    ExecuteTimeout::None,
                    execute_started,
                )
                .await
            }
            Err(e) => {
                self.send_pending_notices().await?;
                self.send_error_and_get_state(e.into_response(Severity::Error))
                    .await
            }
        };

        // Destroy the portal.
        self.adapter_client.session().remove_portal(EMPTY_PORTAL);

        result
    }
1130
1131    async fn ensure_transaction(
1132        &mut self,
1133        num_stmts: usize,
1134        message_type: &str,
1135    ) -> Result<(), io::Error> {
1136        let start = Instant::now();
1137        if self.txn_needs_commit {
1138            self.commit_transaction().await?;
1139        }
1140        // start_transaction can't error (but assert that just in case it changes in
1141        // the future.
1142        let res = self.adapter_client.start_transaction(Some(num_stmts));
1143        assert_ok!(res);
1144        self.adapter_client
1145            .inner()
1146            .metrics()
1147            .pgwire_ensure_transaction_seconds
1148            .with_label_values(&[message_type])
1149            .observe(start.elapsed().as_secs_f64());
1150        Ok(())
1151    }
1152
1153    fn parse_sql<'b>(&self, sql: &'b str) -> Result<Vec<StatementParseResult<'b>>, ErrorResponse> {
1154        let parse_start = Instant::now();
1155        let result = match self.adapter_client.parse(sql) {
1156            Ok(result) => result.map_err(|e| {
1157                // Convert our 0-based byte position to pgwire's 1-based character
1158                // position.
1159                let pos = sql[..e.error.pos].chars().count() + 1;
1160                ErrorResponse::error(SqlState::SYNTAX_ERROR, e.error.message).with_position(pos)
1161            }),
1162            Err(msg) => Err(ErrorResponse::error(SqlState::PROGRAM_LIMIT_EXCEEDED, msg)),
1163        };
1164        self.adapter_client
1165            .inner()
1166            .metrics()
1167            .parse_seconds
1168            .observe(parse_start.elapsed().as_secs_f64());
1169        result
1170    }
1171
1172    /// Executes a "Simple Query", see
1173    /// <https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-SIMPLE-QUERY>
1174    ///
1175    /// For implicit transaction handling, see "Multiple Statements in a Simple Query" in the above.
1176    #[instrument(level = "debug")]
1177    async fn query(&mut self, sql: String, received: EpochMillis) -> Result<State, io::Error> {
1178        // Parse first before doing any transaction checking.
1179        let stmts = match self.parse_sql(&sql) {
1180            Ok(stmts) => stmts,
1181            Err(err) => {
1182                self.send_error_and_get_state(err).await?;
1183                return self.ready().await;
1184            }
1185        };
1186
1187        let num_stmts = stmts.len();
1188
1189        // Compare with postgres' backend/tcop/postgres.c exec_simple_query.
1190        for StatementParseResult { ast: stmt, sql } in stmts {
1191            // In an aborted transaction, reject all commands except COMMIT/ROLLBACK.
1192            if self.is_aborted_txn() && !is_txn_exit_stmt(Some(&stmt)) {
1193                self.aborted_txn_error().await?;
1194                break;
1195            }
1196
1197            // Start an implicit transaction if we aren't in any transaction and there's
1198            // more than one statement. This mirrors the `use_implicit_block` variable in
1199            // postgres.
1200            //
1201            // This needs to be done in the loop instead of once at the top because
1202            // a COMMIT/ROLLBACK statement needs to start a new transaction on next
1203            // statement.
1204            self.ensure_transaction(num_stmts, "query").await?;
1205
1206            match self
1207                .one_query(stmt, sql.to_string(), LifecycleTimestamps { received })
1208                .await?
1209            {
1210                State::Ready => (),
1211                State::Drain => break,
1212                State::Done => return Ok(State::Done),
1213            }
1214        }
1215
1216        // Implicit transactions are closed at the end of a Query message.
1217        {
1218            if self.adapter_client.session().transaction().is_implicit() {
1219                self.commit_transaction().await?;
1220            }
1221        }
1222
1223        if num_stmts == 0 {
1224            self.send(BackendMessage::EmptyQueryResponse).await?;
1225        }
1226
1227        self.ready().await
1228    }
1229
1230    #[instrument(level = "debug")]
1231    async fn parse(
1232        &mut self,
1233        name: String,
1234        sql: String,
1235        param_oids: Vec<u32>,
1236    ) -> Result<State, io::Error> {
1237        // Start a transaction if we aren't in one.
1238        self.ensure_transaction(1, "parse").await?;
1239
1240        let mut param_types = vec![];
1241        for oid in param_oids {
1242            match mz_pgrepr::Type::from_oid(oid) {
1243                Ok(ty) => match SqlScalarType::try_from(&ty) {
1244                    Ok(ty) => param_types.push(Some(ty)),
1245                    Err(err) => {
1246                        return self
1247                            .send_error_and_get_state(ErrorResponse::error(
1248                                SqlState::INVALID_PARAMETER_VALUE,
1249                                err.to_string(),
1250                            ))
1251                            .await;
1252                    }
1253                },
1254                Err(_) if oid == 0 => param_types.push(None),
1255                Err(e) => {
1256                    return self
1257                        .send_error_and_get_state(ErrorResponse::error(
1258                            SqlState::PROTOCOL_VIOLATION,
1259                            e.to_string(),
1260                        ))
1261                        .await;
1262                }
1263            }
1264        }
1265
1266        let stmts = match self.parse_sql(&sql) {
1267            Ok(stmts) => stmts,
1268            Err(err) => {
1269                return self.send_error_and_get_state(err).await;
1270            }
1271        };
1272        if stmts.len() > 1 {
1273            return self
1274                .send_error_and_get_state(ErrorResponse::error(
1275                    SqlState::INTERNAL_ERROR,
1276                    "cannot insert multiple commands into a prepared statement",
1277                ))
1278                .await;
1279        }
1280        let (maybe_stmt, sql) = match stmts.into_iter().next() {
1281            None => (None, ""),
1282            Some(StatementParseResult { ast, sql }) => (Some(ast), sql),
1283        };
1284        if self.is_aborted_txn() && !is_txn_exit_stmt(maybe_stmt.as_ref()) {
1285            return self.aborted_txn_error().await;
1286        }
1287        match self
1288            .adapter_client
1289            .prepare(name, maybe_stmt, sql.to_string(), param_types)
1290            .await
1291        {
1292            Ok(()) => {
1293                self.send(BackendMessage::ParseComplete).await?;
1294                Ok(State::Ready)
1295            }
1296            Err(e) => {
1297                self.send_error_and_get_state(e.into_response(Severity::Error))
1298                    .await
1299            }
1300        }
1301    }
1302
    /// Commits and clears the current transaction.
    ///
    /// Errors from ending the transaction are reported to the client rather
    /// than returned; see `end_transaction`.
    #[instrument(level = "debug")]
    async fn commit_transaction(&mut self) -> Result<(), io::Error> {
        self.end_transaction(EndTransactionAction::Commit).await
    }
1308
    /// Rolls back and clears the current transaction.
    ///
    /// Errors from ending the transaction are reported to the client rather
    /// than returned; see `end_transaction`.
    #[instrument(level = "debug")]
    async fn rollback_transaction(&mut self) -> Result<(), io::Error> {
        self.end_transaction(EndTransactionAction::Rollback).await
    }
1314
1315    /// End a transaction and report to the user if an error occurred.
1316    #[instrument(level = "debug")]
1317    async fn end_transaction(&mut self, action: EndTransactionAction) -> Result<(), io::Error> {
1318        self.txn_needs_commit = false;
1319        let resp = self.adapter_client.end_transaction(action).await;
1320        if let Err(err) = resp {
1321            self.send(BackendMessage::ErrorResponse(
1322                err.into_response(Severity::Error),
1323            ))
1324            .await?;
1325        }
1326        Ok(())
1327    }
1328
    /// Handles an extended-protocol Bind message: decodes `raw_params` against
    /// the prepared statement's parameter types and binds them (together with
    /// the requested result formats) to the portal named `portal_name`.
    #[instrument(level = "debug")]
    async fn bind(
        &mut self,
        portal_name: String,
        statement_name: String,
        param_formats: Vec<Format>,
        raw_params: Vec<Option<Vec<u8>>>,
        result_formats: Vec<Format>,
    ) -> Result<State, io::Error> {
        // Start a transaction if we aren't in one.
        self.ensure_transaction(1, "bind").await?;

        let aborted_txn = self.is_aborted_txn();
        let stmt = match self
            .adapter_client
            .get_prepared_statement(&statement_name)
            .await
        {
            Ok(stmt) => stmt,
            Err(err) => {
                return self
                    .send_error_and_get_state(err.into_response(Severity::Error))
                    .await;
            }
        };

        // The client must supply exactly as many parameters as the prepared
        // statement declares.
        let param_types = &stmt.desc().param_types;
        if param_types.len() != raw_params.len() {
            let message = format!(
                "bind message supplies {actual} parameters, \
                 but prepared statement \"{name}\" requires {expected}",
                name = statement_name,
                actual = raw_params.len(),
                expected = param_types.len()
            );
            return self
                .send_error_and_get_state(ErrorResponse::error(
                    SqlState::PROTOCOL_VIOLATION,
                    message,
                ))
                .await;
        }
        let param_formats = match pad_formats(param_formats, raw_params.len()) {
            Ok(param_formats) => param_formats,
            Err(msg) => {
                return self
                    .send_error_and_get_state(ErrorResponse::error(
                        SqlState::PROTOCOL_VIOLATION,
                        msg,
                    ))
                    .await;
            }
        };
        // In an aborted transaction, only COMMIT/ROLLBACK may be bound.
        if aborted_txn && !is_txn_exit_stmt(stmt.stmt()) {
            return self.aborted_txn_error().await;
        }
        // Decode each raw parameter into a datum using its declared type and
        // wire format. (The zip_eqs cannot panic: lengths were checked and
        // padded above.)
        let buf = RowArena::new();
        let mut params = vec![];
        for ((raw_param, mz_typ), format) in raw_params
            .into_iter()
            .zip_eq(param_types)
            .zip_eq(param_formats)
        {
            let pg_typ = mz_pgrepr::Type::from(mz_typ);
            let datum = match raw_param {
                None => Datum::Null,
                Some(bytes) => match mz_pgrepr::Value::decode(format, &pg_typ, &bytes) {
                    Ok(param) => match param.into_datum_decode_error(&buf, &pg_typ, "parameter") {
                        Ok(datum) => datum,
                        Err(msg) => {
                            return self
                                .send_error_and_get_state(ErrorResponse::error(
                                    SqlState::INVALID_PARAMETER_VALUE,
                                    msg,
                                ))
                                .await;
                        }
                    },
                    Err(err) => {
                        let msg = format!("unable to decode parameter: {}", err);
                        return self
                            .send_error_and_get_state(ErrorResponse::error(
                                SqlState::INVALID_PARAMETER_VALUE,
                                msg,
                            ))
                            .await;
                    }
                },
            };
            params.push((datum, mz_typ.clone()))
        }

        let result_formats = match pad_formats(
            result_formats,
            stmt.desc()
                .relation_desc
                .clone()
                .map(|desc| desc.typ().column_types.len())
                .unwrap_or(0),
        ) {
            Ok(result_formats) => result_formats,
            Err(msg) => {
                return self
                    .send_error_and_get_state(ErrorResponse::error(
                        SqlState::PROTOCOL_VIOLATION,
                        msg,
                    ))
                    .await;
            }
        };

        // Binary encodings are disabled for list, map, and aclitem types, but this doesn't
        // apply to COPY TO statements.
        if !stmt.stmt().map_or(false, |stmt| match stmt {
            Statement::Copy(CopyStatement {
                direction: CopyDirection::To,
                ..
            }) => true,
            Statement::Copy(CopyStatement {
                direction: CopyDirection::From,
                // To be conservative, we are restricting COPY FROM to only allow list/map/aclitem types if it is not
                // copying from STDIN. It is likely that this works in theory, but is risky and likely to OOM anyways
                // as all the data will be held in a buffer in memory before being processed.
                target: CopyTarget::Expr(_),
                ..
            }) => true,
            _ => false,
        }) {
            if let Some(desc) = stmt.desc().relation_desc.clone() {
                for (format, ty) in result_formats.iter().zip_eq(desc.iter_types()) {
                    match (format, &ty.scalar_type) {
                        (Format::Binary, mz_repr::SqlScalarType::List { .. }) => {
                            return self
                                .send_error_and_get_state(ErrorResponse::error(
                                    SqlState::PROTOCOL_VIOLATION,
                                    "binary encoding of list types is not implemented",
                                ))
                                .await;
                        }
                        (Format::Binary, mz_repr::SqlScalarType::Map { .. }) => {
                            return self
                                .send_error_and_get_state(ErrorResponse::error(
                                    SqlState::PROTOCOL_VIOLATION,
                                    "binary encoding of map types is not implemented",
                                ))
                                .await;
                        }
                        (Format::Binary, mz_repr::SqlScalarType::AclItem) => {
                            return self
                                .send_error_and_get_state(ErrorResponse::error(
                                    SqlState::PROTOCOL_VIOLATION,
                                    "binary encoding of aclitem types does not exist",
                                ))
                                .await;
                        }
                        _ => (),
                    }
                }
            }
        }

        // Everything validated; install the portal on the session.
        let desc = stmt.desc().clone();
        let logging = Arc::clone(stmt.logging());
        let stmt_ast = stmt.stmt().cloned();
        let state_revision = stmt.state_revision;
        if let Err(err) = self.adapter_client.session().set_portal(
            portal_name,
            desc,
            stmt_ast,
            logging,
            params,
            result_formats,
            state_revision,
        ) {
            return self
                .send_error_and_get_state(err.into_response(Severity::Error))
                .await;
        }

        self.send(BackendMessage::BindComplete).await?;
        Ok(State::Ready)
    }
1511
    /// Executes the portal named `portal_name`, translating the adapter's
    /// response into pgwire messages.
    ///
    /// `outer_ctx_extra` is Some when we are executing as part of an outer statement, e.g., a FETCH
    /// triggering the execution of the underlying query.
    ///
    /// `received` carries the time at which the triggering client message was
    /// received; it is recorded on the portal for lifecycle timestamping.
    ///
    /// Returns a boxed future because `execute` is indirectly recursive:
    /// `execute` -> `send_execute_response` -> `fetch` -> `execute`.
    fn execute(
        &mut self,
        portal_name: String,
        max_rows: ExecuteCount,
        get_response: GetResponse,
        fetch_portal_name: Option<String>,
        timeout: ExecuteTimeout,
        outer_ctx_extra: Option<ExecuteContextGuard>,
        received: Option<EpochMillis>,
    ) -> BoxFuture<'_, Result<State, io::Error>> {
        async move {
            let aborted_txn = self.is_aborted_txn();

            // Check if the portal has been started and can be continued.
            let portal = match self
                .adapter_client
                .session()
                .get_portal_unverified_mut(&portal_name)
            {
                Some(portal) => portal,
                None => {
                    let msg = format!("portal {} does not exist", portal_name.quoted());
                    // If an outer statement (e.g. FETCH) drove this execution,
                    // retire its statement-logging context with the error
                    // before replying to the client.
                    if let Some(outer_ctx_extra) = outer_ctx_extra {
                        self.adapter_client.retire_execute(
                            outer_ctx_extra,
                            StatementEndedExecutionReason::Errored { error: msg.clone() },
                        );
                    }
                    return self
                        .send_error_and_get_state(ErrorResponse::error(
                            SqlState::INVALID_CURSOR_NAME,
                            msg,
                        ))
                        .await;
                }
            };

            // Stamp the portal with the receipt time of the triggering
            // message (cleared when `received` is None).
            *portal.lifecycle_timestamps = received.map(LifecycleTimestamps::new);

            // In an aborted transaction, reject all commands except COMMIT/ROLLBACK.
            let txn_exit_stmt = is_txn_exit_stmt(portal.stmt.as_deref());
            if aborted_txn && !txn_exit_stmt {
                if let Some(outer_ctx_extra) = outer_ctx_extra {
                    self.adapter_client.retire_execute(
                        outer_ctx_extra,
                        StatementEndedExecutionReason::Errored {
                            error: ABORTED_TXN_MSG.to_string(),
                        },
                    );
                }
                return self.aborted_txn_error().await;
            }

            let row_desc = portal.desc.relation_desc.clone();
            match portal.state {
                // First execution of this portal: hand it to the adapter and
                // translate the response into wire messages.
                PortalState::NotStarted => {
                    // Start a transaction if we aren't in one.
                    self.ensure_transaction(1, "execute").await?;
                    match self
                        .adapter_client
                        .execute(
                            portal_name.clone(),
                            self.conn.wait_closed(),
                            outer_ctx_extra,
                        )
                        .await
                    {
                        Ok((response, execute_started)) => {
                            self.send_pending_notices().await?;
                            self.send_execute_response(
                                response,
                                row_desc,
                                portal_name,
                                max_rows,
                                get_response,
                                fetch_portal_name,
                                timeout,
                                execute_started,
                            )
                            .await
                        }
                        Err(e) => {
                            self.send_pending_notices().await?;
                            self.send_error_and_get_state(e.into_response(Severity::Error))
                                .await
                        }
                    }
                }
                // The portal was previously suspended with rows remaining;
                // resume streaming them to the client.
                PortalState::InProgress(rows) => {
                    let rows = rows.take().expect("InProgress rows must be populated");
                    let (result, statement_ended_execution_reason) = match self
                        .send_rows(
                            row_desc.expect("portal missing row desc on resumption"),
                            portal_name,
                            rows,
                            max_rows,
                            get_response,
                            fetch_portal_name,
                            timeout,
                        )
                        .await
                    {
                        Err(e) => {
                            // This is an error communicating with the connection.
                            // We consider that to be a cancelation, rather than a query error.
                            (Err(e), StatementEndedExecutionReason::Canceled)
                        }
                        Ok((ok, SendRowsEndedReason::Canceled)) => {
                            (Ok(ok), StatementEndedExecutionReason::Canceled)
                        }
                        // NOTE: For now the values for `result_size` and
                        // `rows_returned` in fetches are a bit confusing.
                        // We record `Some(n)` for the first fetch, where `n` is
                        // the number of bytes/rows returned by the inner
                        // execute (regardless of how many rows the
                        // fetch fetched), and `None` for subsequent fetches.
                        //
                        // This arguably makes sense since the size/rows
                        // returned measures how much work the compute
                        // layer had to do to satisfy the query, but
                        // we should revisit it if/when we start
                        // logging the inner execute separately.
                        Ok((
                            ok,
                            SendRowsEndedReason::Success {
                                result_size: _,
                                rows_returned: _,
                            },
                        )) => (
                            Ok(ok),
                            StatementEndedExecutionReason::Success {
                                result_size: None,
                                rows_returned: None,
                                execution_strategy: None,
                            },
                        ),
                        Ok((ok, SendRowsEndedReason::Errored { error })) => {
                            (Ok(ok), StatementEndedExecutionReason::Errored { error })
                        }
                    };
                    if let Some(outer_ctx_extra) = outer_ctx_extra {
                        self.adapter_client
                            .retire_execute(outer_ctx_extra, statement_ended_execution_reason);
                    }
                    result
                }
                // FETCH is an awkward command for our current architecture. In Postgres it
                // will extract <count> rows from the target portal, cache them, and return
                // them to the user as requested. Its command tag is always FETCH <num rows
                // extracted>. In Materialize, since we have chosen to not fully support FETCH,
                // we must remember the number of rows that were returned. Use this tag to
                // remember that information and return it.
                PortalState::Completed(Some(tag)) => {
                    let tag = tag.to_string();
                    if let Some(outer_ctx_extra) = outer_ctx_extra {
                        self.adapter_client.retire_execute(
                            outer_ctx_extra,
                            StatementEndedExecutionReason::Success {
                                result_size: None,
                                rows_returned: None,
                                execution_strategy: None,
                            },
                        );
                    }
                    self.send(BackendMessage::CommandComplete { tag }).await?;
                    Ok(State::Ready)
                }
                // A portal that completed without a saved command tag cannot
                // be executed again.
                PortalState::Completed(None) => {
                    let error = format!(
                        "portal {} cannot be run",
                        Ident::new_unchecked(portal_name).to_ast_string_stable()
                    );
                    if let Some(outer_ctx_extra) = outer_ctx_extra {
                        self.adapter_client.retire_execute(
                            outer_ctx_extra,
                            StatementEndedExecutionReason::Errored {
                                error: error.clone(),
                            },
                        );
                    }
                    self.send_error_and_get_state(ErrorResponse::error(
                        SqlState::OBJECT_NOT_IN_PREREQUISITE_STATE,
                        error,
                    ))
                    .await
                }
            }
        }
        .instrument(debug_span!("execute"))
        .boxed()
    }
1705
1706    #[instrument(level = "debug")]
1707    async fn describe_statement(&mut self, name: &str) -> Result<State, io::Error> {
1708        // Start a transaction if we aren't in one.
1709        self.ensure_transaction(1, "describe_statement").await?;
1710
1711        let stmt = match self.adapter_client.get_prepared_statement(name).await {
1712            Ok(stmt) => stmt,
1713            Err(err) => {
1714                return self
1715                    .send_error_and_get_state(err.into_response(Severity::Error))
1716                    .await;
1717            }
1718        };
1719        // Cloning to avoid a mutable borrow issue because `send` also uses `adapter_client`
1720        let parameter_desc = BackendMessage::ParameterDescription(
1721            stmt.desc()
1722                .param_types
1723                .iter()
1724                .map(mz_pgrepr::Type::from)
1725                .collect(),
1726        );
1727        // Claim that all results will be output in text format, even
1728        // though the true result formats are not yet known. A bit
1729        // weird, but this is the behavior that PostgreSQL specifies.
1730        let formats = vec![Format::Text; stmt.desc().arity()];
1731        let row_desc = describe_rows(stmt.desc(), &formats);
1732        self.send_all([parameter_desc, row_desc]).await?;
1733        Ok(State::Ready)
1734    }
1735
1736    #[instrument(level = "debug")]
1737    async fn describe_portal(&mut self, name: &str) -> Result<State, io::Error> {
1738        // Start a transaction if we aren't in one.
1739        self.ensure_transaction(1, "describe_portal").await?;
1740
1741        let session = self.adapter_client.session();
1742        let row_desc = session
1743            .get_portal_unverified(name)
1744            .map(|portal| describe_rows(&portal.desc, &portal.result_formats));
1745        match row_desc {
1746            Some(row_desc) => {
1747                self.send(row_desc).await?;
1748                Ok(State::Ready)
1749            }
1750            None => {
1751                self.send_error_and_get_state(ErrorResponse::error(
1752                    SqlState::INVALID_CURSOR_NAME,
1753                    format!("portal {} does not exist", name.quoted()),
1754                ))
1755                .await
1756            }
1757        }
1758    }
1759
1760    #[instrument(level = "debug")]
1761    async fn close_statement(&mut self, name: String) -> Result<State, io::Error> {
1762        self.adapter_client
1763            .session()
1764            .remove_prepared_statement(&name);
1765        self.send(BackendMessage::CloseComplete).await?;
1766        Ok(State::Ready)
1767    }
1768
1769    #[instrument(level = "debug")]
1770    async fn close_portal(&mut self, name: String) -> Result<State, io::Error> {
1771        self.adapter_client.session().remove_portal(&name);
1772        self.send(BackendMessage::CloseComplete).await?;
1773        Ok(State::Ready)
1774    }
1775
1776    fn complete_portal(&mut self, name: &str) {
1777        let portal = self
1778            .adapter_client
1779            .session()
1780            .get_portal_unverified_mut(name)
1781            .expect("portal should exist");
1782        *portal.state = PortalState::Completed(None);
1783    }
1784
1785    async fn fetch(
1786        &mut self,
1787        name: String,
1788        count: Option<FetchDirection>,
1789        max_rows: ExecuteCount,
1790        fetch_portal_name: Option<String>,
1791        timeout: ExecuteTimeout,
1792        ctx_extra: ExecuteContextGuard,
1793    ) -> Result<State, io::Error> {
1794        // Unlike Execute, no count specified in FETCH returns 1 row, and 0 means 0
1795        // instead of All.
1796        let count = count.unwrap_or(FetchDirection::ForwardCount(1));
1797
1798        // Figure out how many rows we should send back by looking at the various
1799        // combinations of the execute and fetch.
1800        //
1801        // In Postgres, Fetch will cache <count> rows from the target portal and
1802        // return those as requested (if, say, an Execute message was sent with a
1803        // max_rows < the Fetch's count). We expect that case to be incredibly rare and
1804        // so have chosen to not support it until users request it. This eases
1805        // implementation difficulty since we don't have to be able to "send" rows to
1806        // a buffer.
1807        //
1808        // TODO(mjibson): Test this somehow? Need to divide up the pgtest files in
1809        // order to have some that are not Postgres compatible.
1810        let count = match (max_rows, count) {
1811            (ExecuteCount::Count(max_rows), FetchDirection::ForwardCount(count)) => {
1812                let count = usize::cast_from(count);
1813                if max_rows < count {
1814                    let msg = "Execute with max_rows < a FETCH's count is not supported";
1815                    self.adapter_client.retire_execute(
1816                        ctx_extra,
1817                        StatementEndedExecutionReason::Errored {
1818                            error: msg.to_string(),
1819                        },
1820                    );
1821                    return self
1822                        .send_error_and_get_state(ErrorResponse::error(
1823                            SqlState::FEATURE_NOT_SUPPORTED,
1824                            msg,
1825                        ))
1826                        .await;
1827                }
1828                ExecuteCount::Count(count)
1829            }
1830            (ExecuteCount::Count(_), FetchDirection::ForwardAll) => {
1831                let msg = "Execute with max_rows of a FETCH ALL is not supported";
1832                self.adapter_client.retire_execute(
1833                    ctx_extra,
1834                    StatementEndedExecutionReason::Errored {
1835                        error: msg.to_string(),
1836                    },
1837                );
1838                return self
1839                    .send_error_and_get_state(ErrorResponse::error(
1840                        SqlState::FEATURE_NOT_SUPPORTED,
1841                        msg,
1842                    ))
1843                    .await;
1844            }
1845            (ExecuteCount::All, FetchDirection::ForwardAll) => ExecuteCount::All,
1846            (ExecuteCount::All, FetchDirection::ForwardCount(count)) => {
1847                ExecuteCount::Count(usize::cast_from(count))
1848            }
1849        };
1850        let cursor_name = name.to_string();
1851        self.execute(
1852            cursor_name,
1853            count,
1854            fetch_message,
1855            fetch_portal_name,
1856            timeout,
1857            Some(ctx_extra),
1858            None,
1859        )
1860        .await
1861    }
1862
    /// Flushes any buffered backend messages out to the client connection,
    /// then reports the connection as ready for the next message.
    async fn flush(&mut self) -> Result<State, io::Error> {
        self.conn.flush().await?;
        Ok(State::Ready)
    }
1867
1868    /// Sends a backend message to the client, after applying a severity filter.
1869    ///
1870    /// The message is only sent if its severity is above the severity set
1871    /// in the session, with the default value being NOTICE.
1872    #[instrument(level = "debug")]
1873    async fn send<M>(&mut self, message: M) -> Result<(), io::Error>
1874    where
1875        M: Into<BackendMessage>,
1876    {
1877        let message: BackendMessage = message.into();
1878        let is_error =
1879            matches!(&message, BackendMessage::ErrorResponse(e) if e.severity.is_error());
1880
1881        self.conn.send(message).await?;
1882
1883        // Flush immediately after sending an error response, as some clients
1884        // expect to be able to read the error response before sending a Sync
1885        // message. This is arguably in violation of the protocol specification,
1886        // but the specification is somewhat ambiguous, and easier to match
1887        // PostgreSQL here than to fix all the clients that have this
1888        // expectation.
1889        if is_error {
1890            self.conn.flush().await?;
1891        }
1892
1893        Ok(())
1894    }
1895
1896    #[instrument(level = "debug")]
1897    pub async fn send_all(
1898        &mut self,
1899        messages: impl IntoIterator<Item = BackendMessage>,
1900    ) -> Result<(), io::Error> {
1901        for m in messages {
1902            self.send(m).await?;
1903        }
1904        Ok(())
1905    }
1906
1907    #[instrument(level = "debug")]
1908    async fn sync(&mut self) -> Result<State, io::Error> {
1909        // Close the current transaction if we are in an implicit transaction.
1910        if self.adapter_client.session().transaction().is_implicit() {
1911            self.commit_transaction().await?;
1912        }
1913        self.ready().await
1914    }
1915
1916    #[instrument(level = "debug")]
1917    async fn ready(&mut self) -> Result<State, io::Error> {
1918        let txn_state = self.adapter_client.session().transaction().into();
1919        self.send(BackendMessage::ReadyForQuery(txn_state)).await?;
1920        self.flush().await
1921    }
1922
1923    #[allow(clippy::too_many_arguments)]
1924    #[instrument(level = "debug")]
1925    async fn send_execute_response(
1926        &mut self,
1927        response: ExecuteResponse,
1928        row_desc: Option<RelationDesc>,
1929        portal_name: String,
1930        max_rows: ExecuteCount,
1931        get_response: GetResponse,
1932        fetch_portal_name: Option<String>,
1933        timeout: ExecuteTimeout,
1934        execute_started: Instant,
1935    ) -> Result<State, io::Error> {
1936        let mut tag = response.tag();
1937
1938        macro_rules! command_complete {
1939            () => {{
1940                self.send(BackendMessage::CommandComplete {
1941                    tag: tag
1942                        .take()
1943                        .expect("command_complete only called on tag-generating results"),
1944                })
1945                .await?;
1946                Ok(State::Ready)
1947            }};
1948        }
1949
1950        let r = match response {
1951            ExecuteResponse::ClosedCursor => {
1952                self.complete_portal(&portal_name);
1953                command_complete!()
1954            }
1955            ExecuteResponse::DeclaredCursor => {
1956                self.complete_portal(&portal_name);
1957                command_complete!()
1958            }
1959            ExecuteResponse::EmptyQuery => {
1960                self.send(BackendMessage::EmptyQueryResponse).await?;
1961                Ok(State::Ready)
1962            }
1963            ExecuteResponse::Fetch {
1964                name,
1965                count,
1966                timeout,
1967                ctx_extra,
1968            } => {
1969                self.fetch(
1970                    name,
1971                    count,
1972                    max_rows,
1973                    Some(portal_name.to_string()),
1974                    timeout,
1975                    ctx_extra,
1976                )
1977                .await
1978            }
1979            ExecuteResponse::SendingRowsStreaming {
1980                rows,
1981                instance_id,
1982                strategy,
1983            } => {
1984                let row_desc = row_desc
1985                    .expect("missing row description for ExecuteResponse::SendingRowsStreaming");
1986
1987                let span = tracing::debug_span!("sending_rows_streaming");
1988
1989                self.send_rows(
1990                    row_desc,
1991                    portal_name,
1992                    InProgressRows::new(RecordFirstRowStream::new(
1993                        Box::new(rows),
1994                        execute_started,
1995                        &self.adapter_client,
1996                        Some(instance_id),
1997                        Some(strategy),
1998                    )),
1999                    max_rows,
2000                    get_response,
2001                    fetch_portal_name,
2002                    timeout,
2003                )
2004                .instrument(span)
2005                .await
2006                .map(|(state, _)| state)
2007            }
2008            ExecuteResponse::SendingRowsImmediate { rows } => {
2009                let row_desc = row_desc
2010                    .expect("missing row description for ExecuteResponse::SendingRowsImmediate");
2011
2012                let span = tracing::debug_span!("sending_rows_immediate");
2013
2014                let stream =
2015                    futures::stream::once(futures::future::ready(PeekResponseUnary::Rows(rows)));
2016                self.send_rows(
2017                    row_desc,
2018                    portal_name,
2019                    InProgressRows::new(RecordFirstRowStream::new(
2020                        Box::new(stream),
2021                        execute_started,
2022                        &self.adapter_client,
2023                        None,
2024                        Some(StatementExecutionStrategy::Constant),
2025                    )),
2026                    max_rows,
2027                    get_response,
2028                    fetch_portal_name,
2029                    timeout,
2030                )
2031                .instrument(span)
2032                .await
2033                .map(|(state, _)| state)
2034            }
2035            ExecuteResponse::SetVariable { name, .. } => {
2036                // This code is somewhat awkwardly structured because we
2037                // can't hold `var` across an await point.
2038                let qn = name.to_string();
2039                let msg = if let Some(var) = self
2040                    .adapter_client
2041                    .session()
2042                    .vars_mut()
2043                    .notify_set()
2044                    .find(|v| v.name() == qn)
2045                {
2046                    Some(BackendMessage::ParameterStatus(var.name(), var.value()))
2047                } else {
2048                    None
2049                };
2050                if let Some(msg) = msg {
2051                    self.send(msg).await?;
2052                }
2053                command_complete!()
2054            }
2055            ExecuteResponse::Subscribing {
2056                rx,
2057                ctx_extra,
2058                instance_id,
2059            } => {
2060                if fetch_portal_name.is_none() {
2061                    let mut msg = ErrorResponse::notice(
2062                        SqlState::WARNING,
2063                        "streaming SUBSCRIBE rows directly requires a client that does not buffer output",
2064                    );
2065                    if self.adapter_client.session().vars().application_name() == "psql" {
2066                        msg.hint = Some(
2067                            "Wrap your SUBSCRIBE statement in `COPY (SUBSCRIBE ...) TO STDOUT`."
2068                                .into(),
2069                        )
2070                    }
2071                    self.send(msg).await?;
2072                    self.conn.flush().await?;
2073                }
2074                let row_desc =
2075                    row_desc.expect("missing row description for ExecuteResponse::Subscribing");
2076                let (result, statement_ended_execution_reason) = match self
2077                    .send_rows(
2078                        row_desc,
2079                        portal_name,
2080                        InProgressRows::new(RecordFirstRowStream::new(
2081                            Box::new(UnboundedReceiverStream::new(rx)),
2082                            execute_started,
2083                            &self.adapter_client,
2084                            Some(instance_id),
2085                            None,
2086                        )),
2087                        max_rows,
2088                        get_response,
2089                        fetch_portal_name,
2090                        timeout,
2091                    )
2092                    .await
2093                {
2094                    Err(e) => {
2095                        // This is an error communicating with the connection.
2096                        // We consider that to be a cancelation, rather than a query error.
2097                        (Err(e), StatementEndedExecutionReason::Canceled)
2098                    }
2099                    Ok((ok, SendRowsEndedReason::Canceled)) => {
2100                        (Ok(ok), StatementEndedExecutionReason::Canceled)
2101                    }
2102                    Ok((
2103                        ok,
2104                        SendRowsEndedReason::Success {
2105                            result_size,
2106                            rows_returned,
2107                        },
2108                    )) => (
2109                        Ok(ok),
2110                        StatementEndedExecutionReason::Success {
2111                            result_size: Some(result_size),
2112                            rows_returned: Some(rows_returned),
2113                            execution_strategy: None,
2114                        },
2115                    ),
2116                    Ok((ok, SendRowsEndedReason::Errored { error })) => {
2117                        (Ok(ok), StatementEndedExecutionReason::Errored { error })
2118                    }
2119                };
2120                self.adapter_client
2121                    .retire_execute(ctx_extra, statement_ended_execution_reason);
2122                return result;
2123            }
2124            ExecuteResponse::CopyTo { format, resp } => {
2125                let row_desc =
2126                    row_desc.expect("missing row description for ExecuteResponse::CopyTo");
2127                match *resp {
2128                    ExecuteResponse::Subscribing {
2129                        rx,
2130                        ctx_extra,
2131                        instance_id,
2132                    } => {
2133                        let (result, statement_ended_execution_reason) = match self
2134                            .copy_rows(
2135                                format,
2136                                row_desc,
2137                                RecordFirstRowStream::new(
2138                                    Box::new(UnboundedReceiverStream::new(rx)),
2139                                    execute_started,
2140                                    &self.adapter_client,
2141                                    Some(instance_id),
2142                                    None,
2143                                ),
2144                            )
2145                            .await
2146                        {
2147                            Err(e) => {
2148                                // This is an error communicating with the connection.
2149                                // We consider that to be a cancelation, rather than a query error.
2150                                (Err(e), StatementEndedExecutionReason::Canceled)
2151                            }
2152                            Ok((
2153                                state,
2154                                SendRowsEndedReason::Success {
2155                                    result_size,
2156                                    rows_returned,
2157                                },
2158                            )) => (
2159                                Ok(state),
2160                                StatementEndedExecutionReason::Success {
2161                                    result_size: Some(result_size),
2162                                    rows_returned: Some(rows_returned),
2163                                    execution_strategy: None,
2164                                },
2165                            ),
2166                            Ok((state, SendRowsEndedReason::Errored { error })) => {
2167                                (Ok(state), StatementEndedExecutionReason::Errored { error })
2168                            }
2169                            Ok((state, SendRowsEndedReason::Canceled)) => {
2170                                (Ok(state), StatementEndedExecutionReason::Canceled)
2171                            }
2172                        };
2173                        self.adapter_client
2174                            .retire_execute(ctx_extra, statement_ended_execution_reason);
2175                        return result;
2176                    }
2177                    ExecuteResponse::SendingRowsStreaming {
2178                        rows,
2179                        instance_id,
2180                        strategy,
2181                    } => {
2182                        // We don't need to finalize execution here;
2183                        // it was already done in the
2184                        // coordinator. Just extract the state and
2185                        // return that.
2186                        return self
2187                            .copy_rows(
2188                                format,
2189                                row_desc,
2190                                RecordFirstRowStream::new(
2191                                    Box::new(rows),
2192                                    execute_started,
2193                                    &self.adapter_client,
2194                                    Some(instance_id),
2195                                    Some(strategy),
2196                                ),
2197                            )
2198                            .await
2199                            .map(|(state, _)| state);
2200                    }
2201                    ExecuteResponse::SendingRowsImmediate { rows } => {
2202                        let span = tracing::debug_span!("sending_rows_immediate");
2203
2204                        let rows = futures::stream::once(futures::future::ready(
2205                            PeekResponseUnary::Rows(rows),
2206                        ));
2207                        // We don't need to finalize execution here;
2208                        // it was already done in the
2209                        // coordinator. Just extract the state and
2210                        // return that.
2211                        return self
2212                            .copy_rows(
2213                                format,
2214                                row_desc,
2215                                RecordFirstRowStream::new(
2216                                    Box::new(rows),
2217                                    execute_started,
2218                                    &self.adapter_client,
2219                                    None,
2220                                    Some(StatementExecutionStrategy::Constant),
2221                                ),
2222                            )
2223                            .instrument(span)
2224                            .await
2225                            .map(|(state, _)| state);
2226                    }
2227                    _ => {
2228                        return self
2229                            .send_error_and_get_state(ErrorResponse::error(
2230                                SqlState::INTERNAL_ERROR,
2231                                "unsupported COPY response type".to_string(),
2232                            ))
2233                            .await;
2234                    }
2235                };
2236            }
2237            ExecuteResponse::CopyFrom {
2238                target_id,
2239                target_name,
2240                columns,
2241                params,
2242                ctx_extra,
2243            } => {
2244                let row_desc =
2245                    row_desc.expect("missing row description for ExecuteResponse::CopyFrom");
2246                self.copy_from(target_id, target_name, columns, params, row_desc, ctx_extra)
2247                    .await
2248            }
2249            ExecuteResponse::TransactionCommitted { params }
2250            | ExecuteResponse::TransactionRolledBack { params } => {
2251                let notify_set: mz_ore::collections::HashSet<String> = self
2252                    .adapter_client
2253                    .session()
2254                    .vars()
2255                    .notify_set()
2256                    .map(|v| v.name().to_string())
2257                    .collect();
2258
2259                // Only report on parameters that are in the notify set.
2260                for (name, value) in params
2261                    .into_iter()
2262                    .filter(|(name, _v)| notify_set.contains(*name))
2263                {
2264                    let msg = BackendMessage::ParameterStatus(name, value);
2265                    self.send(msg).await?;
2266                }
2267                command_complete!()
2268            }
2269
2270            ExecuteResponse::AlteredDefaultPrivileges
2271            | ExecuteResponse::AlteredObject(..)
2272            | ExecuteResponse::AlteredRole
2273            | ExecuteResponse::AlteredSystemConfiguration
2274            | ExecuteResponse::CreatedCluster { .. }
2275            | ExecuteResponse::CreatedClusterReplica { .. }
2276            | ExecuteResponse::CreatedConnection { .. }
2277            | ExecuteResponse::CreatedDatabase { .. }
2278            | ExecuteResponse::CreatedIndex { .. }
2279            | ExecuteResponse::CreatedIntrospectionSubscribe
2280            | ExecuteResponse::CreatedMaterializedView { .. }
2281            | ExecuteResponse::CreatedRole
2282            | ExecuteResponse::CreatedSchema { .. }
2283            | ExecuteResponse::CreatedSecret { .. }
2284            | ExecuteResponse::CreatedSink { .. }
2285            | ExecuteResponse::CreatedSource { .. }
2286            | ExecuteResponse::CreatedTable { .. }
2287            | ExecuteResponse::CreatedType
2288            | ExecuteResponse::CreatedView { .. }
2289            | ExecuteResponse::CreatedViews { .. }
2290            | ExecuteResponse::CreatedNetworkPolicy
2291            | ExecuteResponse::Comment
2292            | ExecuteResponse::Deallocate { .. }
2293            | ExecuteResponse::Deleted(..)
2294            | ExecuteResponse::DiscardedAll
2295            | ExecuteResponse::DiscardedTemp
2296            | ExecuteResponse::DroppedObject(_)
2297            | ExecuteResponse::DroppedOwned
2298            | ExecuteResponse::GrantedPrivilege
2299            | ExecuteResponse::GrantedRole
2300            | ExecuteResponse::Inserted(..)
2301            | ExecuteResponse::Copied(..)
2302            | ExecuteResponse::Prepare
2303            | ExecuteResponse::Raised
2304            | ExecuteResponse::ReassignOwned
2305            | ExecuteResponse::RevokedPrivilege
2306            | ExecuteResponse::RevokedRole
2307            | ExecuteResponse::StartedTransaction { .. }
2308            | ExecuteResponse::Updated(..)
2309            | ExecuteResponse::ValidatedConnection => {
2310                command_complete!()
2311            }
2312        };
2313
2314        assert_none!(tag, "tag created but not consumed: {:?}", tag);
2315        r
2316    }
2317
    /// Streams rows for an executed portal back to the client, honoring the
    /// client's `max_rows` limit and the requested `timeout`, and finishes by
    /// sending the completion message produced by `get_response`.
    ///
    /// When the portal is being drained by a FETCH statement,
    /// `fetch_portal_name` names the outer portal whose result formats take
    /// precedence. Any rows left unsent are stashed back into the portal's
    /// state so a later Execute/FETCH can resume the stream. Returns the next
    /// protocol `State` together with a `SendRowsEndedReason` describing how
    /// row sending ended, which feeds statement logging and metrics.
    #[allow(clippy::too_many_arguments)]
    // TODO(guswynn): figure out how to get it to compile without skip_all
    #[mz_ore::instrument(level = "debug")]
    async fn send_rows(
        &mut self,
        row_desc: RelationDesc,
        portal_name: String,
        mut rows: InProgressRows,
        max_rows: ExecuteCount,
        get_response: GetResponse,
        fetch_portal_name: Option<String>,
        timeout: ExecuteTimeout,
    ) -> Result<(State, SendRowsEndedReason), io::Error> {
        // If this portal is being executed from a FETCH then we need to use the result
        // format type of the outer portal.
        let result_format_portal_name: &str = if let Some(ref name) = fetch_portal_name {
            name
        } else {
            &portal_name
        };
        let result_formats = self
            .adapter_client
            .session()
            .get_portal_unverified(result_format_portal_name)
            .expect("valid fetch portal name for send rows")
            .result_formats
            .clone();

        // `deadline` bounds how long we wait for the next batch of rows.
        // `wait_once` means: block indefinitely for the first nonempty batch,
        // then behave like a 0s timeout (the deadline is armed below once the
        // first rows arrive).
        let (mut wait_once, mut deadline) = match timeout {
            ExecuteTimeout::None => (false, None),
            ExecuteTimeout::Seconds(t) => (
                false,
                Some(tokio::time::Instant::now() + tokio::time::Duration::from_secs_f64(t)),
            ),
            ExecuteTimeout::WaitOnce => (true, None),
        };

        // Sanity check that the various `RelationDesc`s match up.
        {
            let portal_name_desc = &self
                .adapter_client
                .session()
                .get_portal_unverified(portal_name.as_str())
                .expect("portal should exist")
                .desc
                .relation_desc;
            if let Some(portal_name_desc) = portal_name_desc {
                soft_assert_eq_or_log!(portal_name_desc, &row_desc);
            }
            if let Some(fetch_portal_name) = &fetch_portal_name {
                let fetch_portal_desc = &self
                    .adapter_client
                    .session()
                    .get_portal_unverified(fetch_portal_name)
                    .expect("portal should exist")
                    .desc
                    .relation_desc;
                if let Some(fetch_portal_desc) = fetch_portal_desc {
                    soft_assert_eq_or_log!(fetch_portal_desc, &row_desc);
                }
            }
        }

        // Configure the connection to encode each column with the client's
        // requested result format.
        self.conn.set_encode_state(
            row_desc
                .typ()
                .column_types
                .iter()
                .map(|ty| mz_pgrepr::Type::from(&ty.scalar_type))
                .zip_eq(result_formats)
                .collect(),
        );

        let mut total_sent_rows = 0;
        let mut total_sent_bytes = 0;
        // want_rows is the maximum number of rows the client wants.
        let mut want_rows = match max_rows {
            ExecuteCount::All => usize::MAX,
            ExecuteCount::Count(count) => count,
        };

        // Send rows while the client still wants them and there are still rows to send.
        loop {
            // Fetch next batch of rows, waiting for a possible requested
            // timeout or notice.
            let batch = if rows.current.is_some() {
                FetchResult::Rows(rows.current.take())
            } else if want_rows == 0 {
                FetchResult::Rows(None)
            } else {
                let notice_fut = self.adapter_client.session().recv_notice();
                // Biased: drain available data before checking the deadline.
                // This is critical for the WaitOnce case, where the deadline
                // is set to `Instant::now()` right after the first batch:
                // without `biased`, `recv()` and the already-expired deadline
                // race nondeterministically, so we might break the loop
                // before `no_more_rows` is set (or even before ready rows
                // are consumed). With an explicit `TIMEOUT`, missing a batch
                // right at the boundary is acceptable, but WaitOnce fires
                // immediately and the race is not.
                //
                // Trade-off: if `recv()` keeps returning Ready (unlikely in
                // practice—row processing + flush is slower than upstream
                // tick granularity), a `TIMEOUT` deadline could be delayed.
                // See database-issues#9470.
                tokio::select! {
                    biased;
                    err = self.conn.wait_closed() => return Err(err),
                    batch = rows.remaining.recv() => match batch {
                        None => FetchResult::Rows(None),
                        Some(PeekResponseUnary::Rows(rows)) => FetchResult::Rows(Some(rows)),
                        Some(PeekResponseUnary::Error(err)) => FetchResult::Error(err),
                        Some(PeekResponseUnary::DependencyDropped(dep)) => {
                            FetchResult::Error(dep.query_terminated_error())
                        }
                        Some(PeekResponseUnary::Canceled) => FetchResult::Canceled,
                    },
                    notice = notice_fut => {
                        FetchResult::Notice(notice)
                    }
                    _ = time::sleep_until(
                        deadline.unwrap_or_else(tokio::time::Instant::now),
                    ), if deadline.is_some() => FetchResult::Rows(None),
                }
            };

            match batch {
                FetchResult::Rows(None) => break,
                FetchResult::Rows(Some(mut batch_rows)) => {
                    // Refuse to send rows that don't match the advertised
                    // description; surface the mismatch to the client as an
                    // error instead of encoding garbage.
                    if let Err(err) = verify_datum_desc(&row_desc, &mut batch_rows) {
                        let msg = err.to_string();
                        return self
                            .send_error_and_get_state(err.into_response(Severity::Error))
                            .await
                            .map(|state| (state, SendRowsEndedReason::Errored { error: msg }));
                    }

                    // If wait_once is true: the first time this fn is called it blocks (same as
                    // deadline == None). The second time this fn is called it should behave the
                    // same as a 0s timeout.
                    if wait_once && batch_rows.peek().is_some() {
                        deadline = Some(tokio::time::Instant::now());
                        wait_once = false;
                    }

                    // Send a portion of the rows.
                    let mut sent_rows = 0;
                    let mut sent_bytes = 0;
                    let messages = (&mut batch_rows)
                        // TODO(parkmycar): This is a fair bit of juggling between iterator types
                        // to count the total number of bytes. Alternatively we could track the
                        // total sent bytes in this .map(...) call, but having side effects in map
                        // is a code smell.
                        .map(|row| {
                            let row_len = row.byte_len();
                            let values = mz_pgrepr::values_from_row(row, row_desc.typ());
                            (row_len, BackendMessage::DataRow(values))
                        })
                        .inspect(|(row_len, _)| {
                            sent_bytes += row_len;
                            sent_rows += 1
                        })
                        .map(|(_row_len, row)| row)
                        .take(want_rows);
                    self.send_all(messages).await?;

                    total_sent_rows += sent_rows;
                    total_sent_bytes += sent_bytes;
                    want_rows -= sent_rows;

                    // If we have sent the number of requested rows, put the remainder of the batch
                    // (if any) back and stop sending.
                    if want_rows == 0 {
                        if batch_rows.peek().is_some() {
                            rows.current = Some(batch_rows);
                        }
                        break;
                    }

                    self.conn.flush().await?;
                }
                FetchResult::Notice(notice) => {
                    self.send(notice.into_response()).await?;
                    self.conn.flush().await?;
                }
                FetchResult::Error(text) => {
                    return self
                        .send_error_and_get_state(ErrorResponse::error(
                            SqlState::INTERNAL_ERROR,
                            text.clone(),
                        ))
                        .await
                        .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
                }
                FetchResult::Canceled => {
                    return self
                        .send_error_and_get_state(ErrorResponse::error(
                            SqlState::QUERY_CANCELED,
                            "canceling statement due to user request",
                        ))
                        .await
                        .map(|state| (state, SendRowsEndedReason::Canceled));
                }
            }
        }

        let portal = self
            .adapter_client
            .session()
            .get_portal_unverified_mut(&portal_name)
            .expect("valid portal name for send rows");

        // Snapshot stream bookkeeping before the stream is moved back into
        // the portal below; the snapshot (not the updated flag) drives the
        // metrics block at the end of this function.
        let saw_rows = rows.remaining.saw_rows;
        let no_more_rows = rows.no_more_rows();
        let metric_recorded = rows.remaining.metric_recorded;
        let recorded_first_row_instant = rows.remaining.recorded_first_row_instant;

        // Mark the metric as recorded on the stream itself so a later FETCH
        // on the exhausted portal doesn't observe it a second time.
        if no_more_rows && !metric_recorded {
            rows.remaining.metric_recorded = true;
        }

        // Always return rows back, even if it's empty. This prevents an unclosed
        // portal from re-executing after it has been emptied.
        *portal.state = PortalState::InProgress(Some(rows));

        let fetch_portal = fetch_portal_name.map(|name| {
            self.adapter_client
                .session()
                .get_portal_unverified_mut(&name)
                .expect("valid fetch portal")
        });
        let response_message = get_response(max_rows, total_sent_rows, fetch_portal);
        self.send(response_message).await?;

        // Attend to metrics if there are no more rows. Only record once per stream
        // to avoid polluting the histogram when an exhausted cursor is FETCHed again.
        if no_more_rows && !metric_recorded {
            let statement_type = if let Some(stmt) = &self
                .adapter_client
                .session()
                .get_portal_unverified(&portal_name)
                .expect("valid portal name for send_rows")
                .stmt
            {
                metrics::statement_type_label_value(stmt.deref())
            } else {
                "no-statement"
            };
            let duration = if saw_rows {
                recorded_first_row_instant
                    .expect("recorded_first_row_instant because saw_rows")
                    .elapsed()
            } else {
                // If the result is empty, then we define time from first to last row as 0.
                // (Note that, currently, an empty result involves a PeekResponse with 0 rows, which
                // does flip `saw_rows`, so this code path is currently not exercised.)
                Duration::ZERO
            };
            self.adapter_client
                .inner()
                .metrics()
                .result_rows_first_to_last_byte_seconds
                .with_label_values(&[statement_type])
                .observe(duration.as_secs_f64());
        }

        Ok((
            State::Ready,
            SendRowsEndedReason::Success {
                result_size: u64::cast_from(total_sent_bytes),
                rows_returned: u64::cast_from(total_sent_rows),
            },
        ))
    }
2592
    /// Sends the rows of `stream` to the client via the COPY TO ("copy out")
    /// subprotocol in the requested `format` (text, CSV, or binary; Parquet
    /// is rejected with an error).
    ///
    /// Emits a `CopyOutResponse`, one `CopyData` message per row, and then
    /// `CopyDone` followed by a `COPY <count>` command tag. Returns the next
    /// protocol `State` and a `SendRowsEndedReason` for statement logging.
    #[mz_ore::instrument(level = "debug")]
    async fn copy_rows(
        &mut self,
        format: CopyFormat,
        row_desc: RelationDesc,
        mut stream: RecordFirstRowStream,
    ) -> Result<(State, SendRowsEndedReason), io::Error> {
        // Pick the wire encoding: text and CSV rows are sent in text format;
        // binary COPY uses the binary format for every column.
        let (row_format, encode_format) = match format {
            CopyFormat::Text => (
                CopyFormatParams::Text(CopyTextFormatParams::default()),
                Format::Text,
            ),
            CopyFormat::Binary => (CopyFormatParams::Binary, Format::Binary),
            CopyFormat::Csv => (
                CopyFormatParams::Csv(CopyCsvFormatParams::default()),
                Format::Text,
            ),
            CopyFormat::Parquet => {
                let text = "Parquet format is not supported".to_string();
                return self
                    .send_error_and_get_state(ErrorResponse::error(
                        SqlState::INTERNAL_ERROR,
                        text.clone(),
                    ))
                    .await
                    .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
            }
        };

        // Encodes one row into `out` using the chosen copy format.
        let encode_fn = |row: &RowRef, typ: &SqlRelationType, out: &mut Vec<u8>| {
            mz_pgcopy::encode_copy_format(&row_format, row, typ, out)
        };

        // Every column is advertised with the same format as the overall copy.
        let typ = row_desc.typ();
        let column_formats = iter::repeat(encode_format)
            .take(typ.column_types.len())
            .collect();
        self.send(BackendMessage::CopyOutResponse {
            overall_format: encode_format,
            column_formats,
        })
        .await?;

        // In Postgres, binary copy has a header that is followed (in the same
        // CopyData) by the first row. In order to replicate their behavior, use a
        // common vec that we can extend one time now and then fill up with the encode
        // functions.
        let mut out = Vec::new();

        if let CopyFormat::Binary = format {
            // 11-byte signature.
            out.extend(b"PGCOPY\n\xFF\r\n\0");
            // 32-bit flags field.
            out.extend([0, 0, 0, 0]);
            // 32-bit header extension length field.
            out.extend([0, 0, 0, 0]);
        }

        let mut count = 0;
        let mut total_sent_bytes = 0;
        loop {
            tokio::select! {
                e = self.conn.wait_closed() => return Err(e),
                batch = stream.recv() => match batch {
                    None => break,
                    Some(PeekResponseUnary::Error(text)) => {
                        let err =
                            ErrorResponse::error(SqlState::INTERNAL_ERROR, text.clone());
                        return self
                            .send_error_and_get_state(err)
                            .await
                            .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
                    }
                    Some(PeekResponseUnary::DependencyDropped(dep)) => {
                        let err = dep.to_concurrent_dependency_drop();
                        let text = err.to_string();
                        let resp = err.into_response(Severity::Error);
                        return self
                            .send_error_and_get_state(resp)
                            .await
                            .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
                    }
                    Some(PeekResponseUnary::Canceled) => {
                        return self.send_error_and_get_state(ErrorResponse::error(
                                SqlState::QUERY_CANCELED,
                                "canceling statement due to user request",
                            ))
                            .await.map(|state| (state, SendRowsEndedReason::Canceled));
                    }
                    Some(PeekResponseUnary::Rows(mut rows)) => {
                        // Record the batch's size up front; the iterator is
                        // still fully drained below, one CopyData per row.
                        count += rows.count();
                        while let Some(row) = rows.next() {
                            total_sent_bytes += row.byte_len();
                            encode_fn(row, typ, &mut out)?;
                            self.send(BackendMessage::CopyData(mem::take(&mut out)))
                                .await?;
                        }
                    }
                },
                // Forward any session notices to the client as they arrive.
                notice = self.adapter_client.session().recv_notice() => {
                    self.send(notice.into_response())
                        .await?;
                    self.conn.flush().await?;
                }
            }

            self.conn.flush().await?;
        }
        // Send required trailers. Binary COPY ends with a 16-bit word
        // containing -1 in network byte order.
        if let CopyFormat::Binary = format {
            let trailer: i16 = -1;
            out.extend(trailer.to_be_bytes());
            self.send(BackendMessage::CopyData(mem::take(&mut out)))
                .await?;
        }

        let tag = format!("COPY {}", count);
        self.send(BackendMessage::CopyDone).await?;
        self.send(BackendMessage::CommandComplete { tag }).await?;
        Ok((
            State::Ready,
            SendRowsEndedReason::Success {
                result_size: u64::cast_from(total_sent_bytes),
                rows_returned: u64::cast_from(count),
            },
        ))
    }
2720
2721    /// Handles the copy-in mode of the postgres protocol from transferring
2722    /// data to the server.
2723    #[instrument(level = "debug")]
2724    async fn copy_from(
2725        &mut self,
2726        target_id: CatalogItemId,
2727        target_name: String,
2728        columns: Vec<ColumnIndex>,
2729        params: CopyFormatParams<'static>,
2730        row_desc: RelationDesc,
2731        mut ctx_extra: ExecuteContextGuard,
2732    ) -> Result<State, io::Error> {
2733        let res = self
2734            .copy_from_inner(
2735                target_id,
2736                target_name,
2737                columns,
2738                params,
2739                row_desc,
2740                &mut ctx_extra,
2741            )
2742            .await;
2743        match &res {
2744            Ok(State::Ready) => {
2745                self.adapter_client.retire_execute(
2746                    ctx_extra,
2747                    StatementEndedExecutionReason::Success {
2748                        result_size: None,
2749                        rows_returned: None,
2750                        execution_strategy: None,
2751                    },
2752                );
2753            }
2754            Ok(State::Done) => {
2755                // The connection closed gracefully without sending us a `CopyDone`,
2756                // causing us to just drop the copy request.
2757                // For the purposes of statement logging, we count this as a cancellation.
2758                self.adapter_client
2759                    .retire_execute(ctx_extra, StatementEndedExecutionReason::Canceled);
2760            }
2761            Err(e) => {
2762                self.adapter_client.retire_execute(
2763                    ctx_extra,
2764                    StatementEndedExecutionReason::Errored {
2765                        error: format!("{e}"),
2766                    },
2767                );
2768            }
2769            Ok(State::Drain) => {}
2770        }
2771        res
2772    }
2773
2774    async fn copy_from_inner(
2775        &mut self,
2776        target_id: CatalogItemId,
2777        target_name: String,
2778        columns: Vec<ColumnIndex>,
2779        params: CopyFormatParams<'static>,
2780        row_desc: RelationDesc,
2781        ctx_extra: &mut ExecuteContextGuard,
2782    ) -> Result<State, io::Error> {
2783        let typ = row_desc.typ();
2784        let column_formats = vec![Format::Text; typ.column_types.len()];
2785        self.send(BackendMessage::CopyInResponse {
2786            overall_format: Format::Text,
2787            column_formats,
2788        })
2789        .await?;
2790        self.conn.flush().await?;
2791
2792        // Set up the parallel streaming batch builders in the coordinator.
2793        let writer = match self
2794            .adapter_client
2795            .start_copy_from_stdin(
2796                target_id,
2797                target_name.clone(),
2798                columns.clone(),
2799                row_desc.clone(),
2800                params.clone(),
2801            )
2802            .await
2803        {
2804            Ok(writer) => writer,
2805            Err(e) => {
2806                // Drain remaining CopyData/CopyDone/CopyFail messages from the
2807                // socket. Since CopyInResponse was already sent, the client may
2808                // have pipelined copy data that we must consume before returning
2809                // the error, otherwise they'd be misinterpreted as top-level
2810                // protocol messages and cause a deadlock.
2811                loop {
2812                    match self.conn.recv().await? {
2813                        Some(FrontendMessage::CopyData(_)) => {}
2814                        Some(FrontendMessage::CopyDone) | Some(FrontendMessage::CopyFail(_)) => {
2815                            break;
2816                        }
2817                        Some(FrontendMessage::Flush) | Some(FrontendMessage::Sync) => {}
2818                        Some(_) => break,
2819                        None => return Ok(State::Done),
2820                    }
2821                }
2822                self.adapter_client.retire_execute(
2823                    std::mem::take(ctx_extra),
2824                    StatementEndedExecutionReason::Errored {
2825                        error: e.to_string(),
2826                    },
2827                );
2828                return self
2829                    .send_error_and_get_state(e.into_response(Severity::Error))
2830                    .await;
2831            }
2832        };
2833
2834        // Enable copy mode on the codec to skip aggregate buffer size checks.
2835        self.conn.set_copy_mode(true);
2836
2837        // Batch size for splitting raw data across parallel workers (~32MB).
2838        const BATCH_SIZE: usize = 32 * 1024 * 1024;
2839        let max_copy_from_row_size = self
2840            .adapter_client
2841            .get_system_vars()
2842            .await
2843            .max_copy_from_row_size()
2844            .try_into()
2845            .unwrap_or(usize::MAX);
2846
2847        let mut data = Vec::new();
2848        let mut row_scanner = CopyRowScanner::new(&params);
2849        let num_workers = writer.batch_txs.len();
2850        let mut next_worker: usize = 0;
2851        let mut saw_copy_done = false;
2852        let mut saw_end_marker = false;
2853        let mut copy_from_error: Option<(SqlState, String)> = None;
2854
2855        // Receive loop: accumulate CopyData, split at row boundaries,
2856        // round-robin raw chunks to parallel batch builder workers.
2857        loop {
2858            let message = self.conn.recv().await?;
2859            match message {
2860                Some(FrontendMessage::CopyData(buf)) => {
2861                    if saw_end_marker {
2862                        // Per PostgreSQL COPY behavior, ignore all bytes after
2863                        // the end-of-copy marker until CopyDone.
2864                        continue;
2865                    }
2866                    data.extend(buf);
2867                    row_scanner.scan_new_bytes(&data);
2868
2869                    if let Some(end_pos) = row_scanner.end_marker_end() {
2870                        data.truncate(end_pos);
2871                        row_scanner.on_truncate(end_pos);
2872                        saw_end_marker = true;
2873                    }
2874
2875                    // Guard against pathological single rows that never terminate.
2876                    if row_scanner.current_row_size(data.len()) > max_copy_from_row_size {
2877                        copy_from_error = Some((
2878                            SqlState::INSUFFICIENT_RESOURCES,
2879                            format!(
2880                                "COPY FROM STDIN row exceeded max_copy_from_row_size \
2881                                 ({max_copy_from_row_size} bytes)"
2882                            ),
2883                        ));
2884                        break;
2885                    }
2886
2887                    // When buffer exceeds batch size, split at the last complete row
2888                    // and send the complete rows chunk to the next worker.
2889                    let mut send_failed = false;
2890                    while data.len() >= BATCH_SIZE {
2891                        let split_pos = match row_scanner.last_row_end() {
2892                            Some(pos) => pos,
2893                            None => break, // no complete row yet
2894                        };
2895                        let remainder = data.split_off(split_pos);
2896                        let chunk = std::mem::replace(&mut data, remainder);
2897                        row_scanner.on_split(split_pos);
2898                        if writer.batch_txs[next_worker].send(chunk).await.is_err() {
2899                            send_failed = true;
2900                            break;
2901                        }
2902                        next_worker = (next_worker + 1) % num_workers;
2903                    }
2904                    // Worker dropped (likely errored) — stop sending,
2905                    // fall through to completion_rx for the real error.
2906                    if send_failed {
2907                        break;
2908                    }
2909                }
2910                Some(FrontendMessage::CopyDone) => {
2911                    // Send any remaining data to the next worker.
2912                    if !data.is_empty() {
2913                        let chunk = std::mem::take(&mut data);
2914                        // Ignore send failure — completion_rx will have the error.
2915                        let _ = writer.batch_txs[next_worker].send(chunk).await;
2916                    }
2917                    saw_copy_done = true;
2918                    break;
2919                }
2920                Some(FrontendMessage::CopyFail(err)) => {
2921                    self.adapter_client.retire_execute(
2922                        std::mem::take(ctx_extra),
2923                        StatementEndedExecutionReason::Canceled,
2924                    );
2925                    // Drop the writer to signal cancellation to the background tasks.
2926                    drop(writer);
2927                    self.conn.set_copy_mode(false);
2928                    return self
2929                        .send_error_and_get_state(ErrorResponse::error(
2930                            SqlState::QUERY_CANCELED,
2931                            format!("COPY from stdin failed: {}", err),
2932                        ))
2933                        .await;
2934                }
2935                Some(FrontendMessage::Flush) | Some(FrontendMessage::Sync) => {}
2936                Some(_) => {
2937                    let msg = "unexpected message type during COPY from stdin";
2938                    self.adapter_client.retire_execute(
2939                        std::mem::take(ctx_extra),
2940                        StatementEndedExecutionReason::Errored {
2941                            error: msg.to_string(),
2942                        },
2943                    );
2944                    drop(writer);
2945                    self.conn.set_copy_mode(false);
2946                    return self
2947                        .send_error_and_get_state(ErrorResponse::error(
2948                            SqlState::PROTOCOL_VIOLATION,
2949                            msg,
2950                        ))
2951                        .await;
2952                }
2953                None => {
2954                    drop(writer);
2955                    self.conn.set_copy_mode(false);
2956                    return Ok(State::Done);
2957                }
2958            }
2959        }
2960
2961        // If we exited the receive loop before seeing `CopyDone` (e.g. because
2962        // a worker failed and dropped its channel), keep draining COPY input to
2963        // avoid desynchronizing the protocol state machine.
2964        if !saw_copy_done {
2965            loop {
2966                match self.conn.recv().await? {
2967                    Some(FrontendMessage::CopyData(_)) => {}
2968                    Some(FrontendMessage::CopyDone) | Some(FrontendMessage::CopyFail(_)) => {
2969                        break;
2970                    }
2971                    Some(FrontendMessage::Flush) | Some(FrontendMessage::Sync) => {}
2972                    Some(_) => {
2973                        let msg = "unexpected message type during COPY from stdin";
2974                        self.adapter_client.retire_execute(
2975                            std::mem::take(ctx_extra),
2976                            StatementEndedExecutionReason::Errored {
2977                                error: msg.to_string(),
2978                            },
2979                        );
2980                        drop(writer);
2981                        self.conn.set_copy_mode(false);
2982                        return self
2983                            .send_error_and_get_state(ErrorResponse::error(
2984                                SqlState::PROTOCOL_VIOLATION,
2985                                msg,
2986                            ))
2987                            .await;
2988                    }
2989                    None => {
2990                        drop(writer);
2991                        self.conn.set_copy_mode(false);
2992                        return Ok(State::Done);
2993                    }
2994                }
2995            }
2996        }
2997
2998        if let Some((code, msg)) = copy_from_error {
2999            self.adapter_client.retire_execute(
3000                std::mem::take(ctx_extra),
3001                StatementEndedExecutionReason::Errored { error: msg.clone() },
3002            );
3003            drop(writer);
3004            self.conn.set_copy_mode(false);
3005            return self
3006                .send_error_and_get_state(ErrorResponse::error(code, msg))
3007                .await;
3008        }
3009
3010        self.conn.set_copy_mode(false);
3011
3012        // Drop all senders to signal EOF to the background batch builders.
3013        // If copy_err is set, a worker already failed — dropping the senders
3014        // will cause remaining workers to stop, and we'll get the real error
3015        // from completion_rx below.
3016        drop(writer.batch_txs);
3017
3018        // Wait for all parallel workers to finish building batches.
3019        let (proto_batches, row_count) = match writer.completion_rx.await {
3020            Ok(Ok(result)) => result,
3021            Ok(Err(e)) => {
3022                self.adapter_client.retire_execute(
3023                    std::mem::take(ctx_extra),
3024                    StatementEndedExecutionReason::Errored {
3025                        error: e.to_string(),
3026                    },
3027                );
3028                return self
3029                    .send_error_and_get_state(e.into_response(Severity::Error))
3030                    .await;
3031            }
3032            Err(_) => {
3033                let msg = "COPY FROM STDIN: background batch builder tasks dropped";
3034                self.adapter_client.retire_execute(
3035                    std::mem::take(ctx_extra),
3036                    StatementEndedExecutionReason::Errored {
3037                        error: msg.to_string(),
3038                    },
3039                );
3040                return self
3041                    .send_error_and_get_state(ErrorResponse::error(SqlState::INTERNAL_ERROR, msg))
3042                    .await;
3043            }
3044        };
3045
3046        // Stage all batches in the session's transaction for atomic commit.
3047        if let Err(e) = self
3048            .adapter_client
3049            .stage_copy_from_stdin_batches(target_id, proto_batches)
3050        {
3051            self.adapter_client.retire_execute(
3052                std::mem::take(ctx_extra),
3053                StatementEndedExecutionReason::Errored {
3054                    error: e.to_string(),
3055                },
3056            );
3057            return self
3058                .send_error_and_get_state(e.into_response(Severity::Error))
3059                .await;
3060        }
3061
3062        let tag = format!("COPY {}", row_count);
3063        self.send(BackendMessage::CommandComplete { tag }).await?;
3064
3065        Ok(State::Ready)
3066    }
3067
3068    #[instrument(level = "debug")]
3069    async fn send_pending_notices(&mut self) -> Result<(), io::Error> {
3070        let notices = self
3071            .adapter_client
3072            .session()
3073            .drain_notices()
3074            .into_iter()
3075            .map(|notice| BackendMessage::ErrorResponse(notice.into_response()));
3076        self.send_all(notices).await?;
3077        Ok(())
3078    }
3079
3080    #[instrument(level = "debug")]
3081    async fn send_error_and_get_state(&mut self, err: ErrorResponse) -> Result<State, io::Error> {
3082        assert!(err.severity.is_error());
3083        debug!(
3084            "cid={} error code={}",
3085            self.adapter_client.session().conn_id(),
3086            err.code.code()
3087        );
3088        let is_fatal = err.severity.is_fatal();
3089        self.send(BackendMessage::ErrorResponse(err)).await?;
3090
3091        let txn = self.adapter_client.session().transaction();
3092        match txn {
3093            // Error can be called from describe and parse and so might not be in an active
3094            // transaction.
3095            TransactionStatus::Default | TransactionStatus::Failed(_) => {}
3096            // In Started (i.e., a single statement), cleanup ourselves.
3097            TransactionStatus::Started(_) => {
3098                self.rollback_transaction().await?;
3099            }
3100            // Implicit transactions also clear themselves.
3101            TransactionStatus::InTransactionImplicit(_) => {
3102                self.rollback_transaction().await?;
3103            }
3104            // Explicit transactions move to failed.
3105            TransactionStatus::InTransaction(_) => {
3106                self.adapter_client.fail_transaction();
3107            }
3108        };
3109        if is_fatal {
3110            Ok(State::Done)
3111        } else {
3112            Ok(State::Drain)
3113        }
3114    }
3115
3116    #[instrument(level = "debug")]
3117    async fn aborted_txn_error(&mut self) -> Result<State, io::Error> {
3118        self.send(BackendMessage::ErrorResponse(ErrorResponse::error(
3119            SqlState::IN_FAILED_SQL_TRANSACTION,
3120            ABORTED_TXN_MSG,
3121        )))
3122        .await?;
3123        Ok(State::Drain)
3124    }
3125
3126    fn is_aborted_txn(&mut self) -> bool {
3127        matches!(
3128            self.adapter_client.session().transaction(),
3129            TransactionStatus::Failed(_)
3130        )
3131    }
3132}
3133
3134fn pad_formats(formats: Vec<Format>, n: usize) -> Result<Vec<Format>, String> {
3135    match (formats.len(), n) {
3136        (0, e) => Ok(vec![Format::Text; e]),
3137        (1, e) => Ok(iter::repeat(formats[0]).take(e).collect()),
3138        (a, e) if a == e => Ok(formats),
3139        (a, e) => Err(format!(
3140            "expected {} field format specifiers, but got {}",
3141            e, a
3142        )),
3143    }
3144}
3145
3146fn describe_rows(stmt_desc: &StatementDesc, formats: &[Format]) -> BackendMessage {
3147    match &stmt_desc.relation_desc {
3148        Some(desc) if !stmt_desc.is_copy => {
3149            BackendMessage::RowDescription(message::encode_row_description(desc, formats))
3150        }
3151        _ => BackendMessage::NoData,
3152    }
3153}
3154
/// Callback that produces the final backend message after a batch of rows has
/// been sent to the client (e.g. `CommandComplete` vs. `PortalSuspended`).
/// See `portal_exec_message` and `fetch_message` for the two implementations.
type GetResponse = fn(
    max_rows: ExecuteCount,
    total_sent_rows: usize,
    fetch_portal: Option<PortalRefMut>,
) -> BackendMessage;
3160
3161// A GetResponse used by send_rows during execute messages on portals or for
3162// simple query messages.
3163fn portal_exec_message(
3164    max_rows: ExecuteCount,
3165    total_sent_rows: usize,
3166    _fetch_portal: Option<PortalRefMut>,
3167) -> BackendMessage {
3168    // If max_rows is not specified, we will always send back a CommandComplete. If
3169    // max_rows is specified, we only send CommandComplete if there were more rows
3170    // requested than were remaining. That is, if max_rows == number of rows that
3171    // were remaining before sending (not that are remaining after sending), then
3172    // we still send a PortalSuspended. The number of remaining rows after the rows
3173    // have been sent doesn't matter. This matches postgres.
3174    match max_rows {
3175        ExecuteCount::Count(max_rows) if max_rows <= total_sent_rows => {
3176            BackendMessage::PortalSuspended
3177        }
3178        _ => BackendMessage::CommandComplete {
3179            tag: format!("SELECT {}", total_sent_rows),
3180        },
3181    }
3182}
3183
3184// A GetResponse used by send_rows during FETCH queries.
3185fn fetch_message(
3186    _max_rows: ExecuteCount,
3187    total_sent_rows: usize,
3188    fetch_portal: Option<PortalRefMut>,
3189) -> BackendMessage {
3190    let tag = format!("FETCH {}", total_sent_rows);
3191    if let Some(portal) = fetch_portal {
3192        *portal.state = PortalState::Completed(Some(tag.clone()));
3193    }
3194    BackendMessage::CommandComplete { tag }
3195}
3196
3197fn get_authenticator(
3198    authenticator_kind: listeners::AuthenticatorKind,
3199    frontegg: Option<FronteggAuthenticator>,
3200    oidc: GenericOidcAuthenticator,
3201    adapter_client: mz_adapter::Client,
3202    // If oidc_auth_enabled exists as an option in the pgwire connection's
3203    // `option` parameter
3204    oidc_auth_option_enabled: bool,
3205) -> Authenticator {
3206    match authenticator_kind {
3207        listeners::AuthenticatorKind::Frontegg => Authenticator::Frontegg(frontegg.expect(
3208            "Frontegg authenticator should exist with listeners::AuthenticatorKind::Frontegg",
3209        )),
3210        listeners::AuthenticatorKind::Password => Authenticator::Password(adapter_client),
3211        listeners::AuthenticatorKind::Sasl => Authenticator::Sasl(adapter_client),
3212        listeners::AuthenticatorKind::Oidc => {
3213            if oidc_auth_option_enabled {
3214                Authenticator::Oidc(oidc)
3215            } else {
3216                // Fallback to password authentication if oidc auth is not enabled
3217                // through options.
3218                Authenticator::Password(adapter_client)
3219            }
3220        }
3221        listeners::AuthenticatorKind::None => Authenticator::None,
3222    }
3223}
3224
/// The row limit attached to an Execute message.
#[derive(Debug, Copy, Clone)]
enum ExecuteCount {
    // No limit: send all remaining rows.
    All,
    // Send at most this many rows; see `portal_exec_message` for how this
    // interacts with `PortalSuspended`.
    Count(usize),
}
3230
3231// See postgres' backend/tcop/postgres.c IsTransactionExitStmt.
3232fn is_txn_exit_stmt(stmt: Option<&Statement<Raw>>) -> bool {
3233    match stmt {
3234        // Add PREPARE to this if we ever support it.
3235        Some(stmt) => matches!(stmt, Statement::Commit(_) | Statement::Rollback(_)),
3236        None => false,
3237    }
3238}
3239
/// The possible outcomes of waiting for rows during a FETCH.
#[derive(Debug)]
enum FetchResult {
    // A batch of rows; `None` presumably signals end-of-stream — confirm at
    // the producer (not visible in this chunk).
    Rows(Option<Box<dyn RowIterator + Send + Sync>>),
    // The query was canceled before rows arrived.
    Canceled,
    // The query failed; carries the error message.
    Error(String),
    // A notice to relay to the client.
    Notice(AdapterNotice),
}
3247
/// Incrementally locates row boundaries and the `\.` end-of-data marker in a
/// growing COPY FROM STDIN buffer, so the buffer can be split at
/// complete-row positions.
#[derive(Debug)]
struct CopyRowScanner {
    // Number of leading buffer bytes already scanned; scanning resumes here.
    scan_pos: usize,
    // Offset one past the end of the last complete row seen, if any.
    last_row_end: Option<usize>,
    // Offset one past the `\.` end-of-data marker, once seen.
    end_marker_end: Option<usize>,
    // CSV-specific scan state; `None` for other formats, which are split on
    // plain newlines.
    csv: Option<CsvScanState>,
}
3255
/// State for the `csv_core`-backed variant of [`CopyRowScanner`].
#[derive(Debug)]
struct CsvScanState {
    // The incremental CSV parser.
    reader: csv_core::Reader,
    // Scratch output buffer for `read_record`; grown on demand.
    output: Vec<u8>,
    // Scratch field-end buffer for `read_record`; grown on demand.
    ends: Vec<usize>,
    // Accumulated unescaped bytes of the record currently being parsed.
    record: Vec<u8>,
    // Accumulated field-end offsets of the record currently being parsed.
    record_ends: Vec<usize>,
    // When true, the next complete record is a header row and must not be
    // treated as the end-of-data marker.
    skip_first_record: bool,
}
3265
impl CopyRowScanner {
    /// Creates a scanner for data in the given COPY format.
    ///
    /// CSV data needs a stateful `csv_core`-backed scanner (quoted fields may
    /// contain newlines, so row boundaries cannot be found by byte
    /// inspection); every other format is split on plain `\n` bytes.
    fn new(params: &CopyFormatParams<'_>) -> Self {
        let csv = match params {
            CopyFormatParams::Csv(CopyCsvFormatParams {
                delimiter,
                quote,
                escape,
                header,
                ..
            }) => Some(CsvScanState::new(*delimiter, *quote, *escape, *header)),
            _ => None,
        };

        CopyRowScanner {
            scan_pos: 0,
            last_row_end: None,
            end_marker_end: None,
            csv,
        }
    }

    /// Scans any bytes of `data` past `scan_pos`, updating `last_row_end` and
    /// (if found) `end_marker_end`. `data` must be the same buffer as on
    /// previous calls, possibly with more bytes appended, unless it was
    /// rebased via `on_split` or clamped via `on_truncate`.
    fn scan_new_bytes(&mut self, data: &[u8]) {
        if self.scan_pos >= data.len() {
            // Nothing new to look at.
            return;
        }

        if let Some(csv) = self.csv.as_mut() {
            let mut input = &data[self.scan_pos..];
            let mut consumed = 0usize;
            while !input.is_empty() {
                let (result, n_input, n_output, n_ends) =
                    csv.reader
                        .read_record(input, &mut csv.output, &mut csv.ends);
                consumed += n_input;
                input = &input[n_input..];
                // Accumulate the unescaped record bytes and field ends so a
                // complete record can later be compared against `\.`.
                if !csv.output.is_empty() {
                    csv.record.extend_from_slice(&csv.output[..n_output]);
                }
                if !csv.ends.is_empty() {
                    csv.record_ends.extend_from_slice(&csv.ends[..n_ends]);
                }

                match result {
                    ReadRecordResult::InputEmpty => break,
                    ReadRecordResult::OutputFull => {
                        // Only grow the scratch buffer when no input progress
                        // was made; otherwise retry, since the produced bytes
                        // were already drained into `record` above.
                        if n_input == 0 {
                            csv.output
                                .resize(csv.output.len().saturating_mul(2).max(1), 0);
                        }
                    }
                    ReadRecordResult::OutputEndsFull => {
                        if n_input == 0 {
                            csv.ends.resize(csv.ends.len().saturating_mul(2).max(1), 0);
                        }
                    }
                    ReadRecordResult::Record | ReadRecordResult::End => {
                        let row_end = self.scan_pos + consumed;
                        self.last_row_end = Some(row_end);
                        if self.end_marker_end.is_none() {
                            // The end-of-data marker is a record with a single
                            // field whose contents are exactly `\.`. A header
                            // row is never treated as the marker.
                            let is_marker = if csv.skip_first_record {
                                csv.skip_first_record = false;
                                false
                            } else if csv.record_ends.len() == 1 {
                                let end = csv.record_ends[0];
                                end == 2 && csv.record.get(0..end) == Some(b"\\.")
                            } else {
                                false
                            };
                            if is_marker {
                                self.end_marker_end = Some(row_end);
                                break;
                            }
                        }
                        csv.record.clear();
                        csv.record_ends.clear();
                    }
                }
            }
        } else {
            // Non-CSV formats: rows are newline-terminated.
            let mut row_start = self.last_row_end.unwrap_or(0);
            for (offset, b) in data[self.scan_pos..].iter().enumerate() {
                if *b == b'\n' {
                    let row_end = self.scan_pos + offset + 1;
                    self.last_row_end = Some(row_end);
                    if self.end_marker_end.is_none() {
                        let row = &data[row_start..row_end];
                        // NOTE(review): this matches any row *starting* with
                        // `\.`, while the CSV branch requires the record to be
                        // exactly `\.` — presumably rows beginning with `\.`
                        // cannot otherwise occur in text mode; confirm.
                        if row.get(0..2) == Some(b"\\.") {
                            self.end_marker_end = Some(row_end);
                            break;
                        }
                    }
                    row_start = row_end;
                }
            }
        }

        self.scan_pos = data.len();
    }

    /// Offset one past the end of the last complete row, if any has completed.
    fn last_row_end(&self) -> Option<usize> {
        self.last_row_end
    }

    /// Offset one past the `\.` end-of-data marker, if it has been seen.
    fn end_marker_end(&self) -> Option<usize> {
        self.end_marker_end
    }

    /// Size in bytes of the trailing (incomplete) row, given the buffer's
    /// total length.
    fn current_row_size(&self, data_len: usize) -> usize {
        data_len.saturating_sub(self.last_row_end.unwrap_or(0))
    }

    /// Rebases internal offsets after the first `split_pos` bytes of the
    /// buffer were split off and sent away.
    fn on_split(&mut self, split_pos: usize) {
        self.scan_pos = self.scan_pos.saturating_sub(split_pos);
        self.last_row_end = None;
        self.end_marker_end = self
            .end_marker_end
            .and_then(|end| end.checked_sub(split_pos));
    }

    /// Clamps internal offsets after the buffer was truncated to `new_len`
    /// bytes.
    fn on_truncate(&mut self, new_len: usize) {
        self.scan_pos = self.scan_pos.min(new_len);
        self.last_row_end = self.last_row_end.filter(|&end| end <= new_len);
        self.end_marker_end = self.end_marker_end.filter(|&end| end <= new_len);
    }
}
3391
3392impl CsvScanState {
3393    fn new(delimiter: u8, quote: u8, escape: u8, header: bool) -> Self {
3394        let (double_quote, escape) = if quote == escape {
3395            (true, None)
3396        } else {
3397            (false, Some(escape))
3398        };
3399        CsvScanState {
3400            reader: csv_core::ReaderBuilder::new()
3401                .delimiter(delimiter)
3402                .quote(quote)
3403                .double_quote(double_quote)
3404                .escape(escape)
3405                .build(),
3406            output: vec![0; 1],
3407            ends: vec![0; 1],
3408            record: Vec::new(),
3409            record_ends: Vec::new(),
3410            skip_first_record: header,
3411        }
3412    }
3413}
3414
// Unit tests for the connection-startup `options` parameter parsing helpers
// (`parse_options`, `parse_option`, `split_options`) defined earlier in this
// file.
#[cfg(test)]
mod test {
    use super::*;

    // Full options strings parse into key/value pairs; malformed input errors.
    #[mz_ore::test]
    fn test_parse_options() {
        struct TestCase {
            input: &'static str,
            expect: Result<Vec<(&'static str, &'static str)>, ()>,
        }
        let tests = vec![
            TestCase {
                input: "",
                expect: Ok(vec![]),
            },
            TestCase {
                input: "--key",
                expect: Err(()),
            },
            TestCase {
                input: "--key=val",
                expect: Ok(vec![("key", "val")]),
            },
            TestCase {
                input: r#"--key=val -ckey2=val2 -c key3=val3 -c key4=val4 -ckey5=val5"#,
                expect: Ok(vec![
                    ("key", "val"),
                    ("key2", "val2"),
                    ("key3", "val3"),
                    ("key4", "val4"),
                    ("key5", "val5"),
                ]),
            },
            TestCase {
                input: r#"-c\ key=val"#,
                expect: Ok(vec![(" key", "val")]),
            },
            TestCase {
                input: "--key=val -ckey2 val2",
                expect: Err(()),
            },
            // Unclear what this should do.
            TestCase {
                input: "--key=",
                expect: Ok(vec![("key", "")]),
            },
        ];
        for test in tests {
            let got = parse_options(test.input);
            let expect = test.expect.map(|r| {
                r.into_iter()
                    .map(|(k, v)| (k.to_owned(), v.to_owned()))
                    .collect()
            });
            assert_eq!(got, expect, "input: {}", test.input);
        }
    }

    // A single `--key=value` / `-ckey=value` token parses into one pair.
    #[mz_ore::test]
    fn test_parse_option() {
        struct TestCase {
            input: &'static str,
            expect: Result<(&'static str, &'static str), ()>,
        }
        let tests = vec![
            TestCase {
                input: "",
                expect: Err(()),
            },
            TestCase {
                input: "--",
                expect: Err(()),
            },
            TestCase {
                input: "--c",
                expect: Err(()),
            },
            TestCase {
                input: "a=b",
                expect: Err(()),
            },
            TestCase {
                input: "--a=b",
                expect: Ok(("a", "b")),
            },
            TestCase {
                input: "--ca=b",
                expect: Ok(("ca", "b")),
            },
            TestCase {
                input: "-ca=b",
                expect: Ok(("a", "b")),
            },
            // Unclear what this should error, but at least test it.
            TestCase {
                input: "--=",
                expect: Ok(("", "")),
            },
        ];
        for test in tests {
            let got = parse_option(test.input);
            assert_eq!(got, test.expect, "input: {}", test.input);
        }
    }

    // Whitespace-splitting honors backslash escapes for embedded spaces.
    #[mz_ore::test]
    fn test_split_options() {
        struct TestCase {
            input: &'static str,
            expect: Vec<&'static str>,
        }
        let tests = vec![
            TestCase {
                input: "",
                expect: vec![],
            },
            TestCase {
                input: "  ",
                expect: vec![],
            },
            TestCase {
                input: " a ",
                expect: vec!["a"],
            },
            TestCase {
                input: "  ab     cd   ",
                expect: vec!["ab", "cd"],
            },
            TestCase {
                input: r#"  ab\     cd   "#,
                expect: vec!["ab ", "cd"],
            },
            TestCase {
                input: r#"  ab\\     cd   "#,
                expect: vec![r#"ab\"#, "cd"],
            },
            TestCase {
                input: r#"  ab\\\     cd   "#,
                expect: vec![r#"ab\ "#, "cd"],
            },
            TestCase {
                input: r#"  ab\\\ cd   "#,
                expect: vec![r#"ab\ cd"#],
            },
            TestCase {
                input: r#"  ab\\\cd   "#,
                expect: vec![r#"ab\cd"#],
            },
            TestCase {
                input: r#"a\"#,
                expect: vec!["a"],
            },
            TestCase {
                input: r#"a\ "#,
                expect: vec!["a "],
            },
            TestCase {
                input: r#"\"#,
                expect: vec![],
            },
            TestCase {
                input: r#"\ "#,
                expect: vec![r#" "#],
            },
            TestCase {
                input: r#" \ "#,
                expect: vec![r#" "#],
            },
            TestCase {
                input: r#"\  "#,
                expect: vec![r#" "#],
            },
        ];
        for test in tests {
            let got = split_options(test.input);
            assert_eq!(got, test.expect, "input: {}", test.input);
        }
    }
}