Skip to main content

mz_pgwire/
protocol.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::collections::BTreeMap;
11use std::convert::TryFrom;
12use std::future::Future;
13use std::ops::Deref;
14use std::sync::Arc;
15use std::time::{Duration, Instant};
16use std::{iter, mem};
17
18use base64::prelude::*;
19use byteorder::{ByteOrder, NetworkEndian};
20use csv_core::ReadRecordResult;
21use futures::future::{BoxFuture, FutureExt, pending};
22use itertools::Itertools;
23use mz_adapter::client::RecordFirstRowStream;
24use mz_adapter::session::{
25    EndTransactionAction, InProgressRows, LifecycleTimestamps, PortalRefMut, PortalState, Session,
26    SessionConfig, TransactionStatus,
27};
28use mz_adapter::statement_logging::{StatementEndedExecutionReason, StatementExecutionStrategy};
29use mz_adapter::{
30    AdapterError, AdapterNotice, ExecuteContextGuard, ExecuteResponse, PeekResponseUnary, metrics,
31    verify_datum_desc,
32};
33use mz_adapter_types::dyncfgs::OIDC_GROUP_CLAIM;
34use mz_auth::Authenticated;
35use mz_auth::password::Password;
36use mz_authenticator::{Authenticator, GenericOidcAuthenticator};
37use mz_frontegg_auth::Authenticator as FronteggAuthenticator;
38use mz_ore::cast::CastFrom;
39use mz_ore::netio::AsyncReady;
40use mz_ore::now::{EpochMillis, SYSTEM_TIME};
41use mz_ore::str::StrExt;
42use mz_ore::{assert_none, assert_ok, instrument, soft_assert_eq_or_log, soft_assert_or_log};
43use mz_pgcopy::{CopyCsvFormatParams, CopyFormatParams, CopyTextFormatParams};
44use mz_pgwire_common::{
45    ConnectionCounter, Cursor, ErrorResponse, Format, FrontendMessage, Severity, VERSION_3,
46    VERSIONS,
47};
48use mz_repr::{
49    CatalogItemId, ColumnIndex, Datum, RelationDesc, RowArena, RowIterator, RowRef,
50    SqlRelationType, SqlScalarType,
51};
52use mz_server_core::TlsMode;
53use mz_server_core::listeners;
54use mz_server_core::listeners::AllowedRoles;
55use mz_sql::ast::display::AstDisplay;
56use mz_sql::ast::{
57    CopyDirection, CopyStatement, CopyTarget, FetchDirection, Ident, Raw, Statement,
58};
59use mz_sql::parse::StatementParseResult;
60use mz_sql::plan::{CopyFormat, ExecuteTimeout, StatementDesc};
61use mz_sql::session::metadata::SessionMetadata;
62use mz_sql::session::user::INTERNAL_USER_NAMES;
63use mz_sql::session::vars::VarInput;
64use postgres::error::SqlState;
65use tokio::io::{self, AsyncRead, AsyncWrite};
66use tokio::select;
67use tokio::time::{self};
68use tokio_metrics::TaskMetrics;
69use tokio_stream::wrappers::UnboundedReceiverStream;
70use tracing::{Instrument, debug, debug_span, warn};
71use uuid::Uuid;
72
73use crate::codec::{
74    FramedConn, decode_password, decode_sasl_initial_response, decode_sasl_response,
75};
76use crate::message::{
77    self, BackendMessage, SASLServerFinalMessage, SASLServerFinalMessageKinds,
78    SASLServerFirstMessage,
79};
80
81/// Reports whether the given stream begins with a pgwire handshake.
82///
83/// To avoid false negatives, there must be at least eight bytes in `buf`.
84pub fn match_handshake(buf: &[u8]) -> bool {
85    // The pgwire StartupMessage looks like this:
86    //
87    //     i32 - Length of entire message.
88    //     i32 - Protocol version number.
89    //     [String] - Arbitrary key-value parameters of any length.
90    //
91    // Since arbitrary parameters can be included in the StartupMessage, the
92    // first Int32 is worthless, since the message could have any length.
93    // Instead, we sniff the protocol version number.
94    if buf.len() < 8 {
95        return false;
96    }
97    let version = NetworkEndian::read_i32(&buf[4..8]);
98    VERSIONS.contains(&version)
99}
100
/// Parameters for the [`run`] function.
///
/// [`run`] takes this struct by value and destructures it, so every field is
/// consumed by a single connection.
pub struct RunParams<'a, A, I>
where
    I: Iterator<Item = TaskMetrics> + Send,
{
    /// The TLS mode of the pgwire server.
    ///
    /// [`run`] checks the connection against this mode before authenticating.
    pub tls_mode: Option<TlsMode>,
    /// A client for the adapter.
    pub adapter_client: mz_adapter::Client,
    /// The connection to the client.
    pub conn: &'a mut FramedConn<A>,
    /// The universally unique identifier for the connection.
    pub conn_uuid: Uuid,
    /// The protocol version that the client provided in the startup message.
    ///
    /// [`run`] rejects any value other than `VERSION_3` with a fatal error.
    pub version: i32,
    /// The parameters that the client provided in the startup message.
    ///
    /// [`run`] removes the `user` entry to obtain the login name, parses the
    /// `options` entry into individual settings, and applies every remaining
    /// entry as a session variable.
    pub params: BTreeMap<String, String>,
    /// Frontegg JWT authenticator.
    pub frontegg: Option<FronteggAuthenticator>,
    /// OIDC authenticator.
    pub oidc: GenericOidcAuthenticator,
    /// The authentication method defined by the server's listener
    /// configuration.
    pub authenticator_kind: listeners::AuthenticatorKind,
    /// Global connection limit and count
    pub active_connection_counter: ConnectionCounter,
    /// Helm chart version
    pub helm_chart_version: Option<String>,
    /// Which classes of roles may log in on this listener: normal roles,
    /// internal roles (ie: mz_system), or both.
    pub allowed_roles: AllowedRoles,
    /// Tokio metrics
    ///
    /// [`run`] hands this iterator to the connection's state machine.
    pub tokio_metrics_intervals: I,
}
134
/// Runs a pgwire connection to completion.
///
/// This involves responding to `FrontendMessage::StartupMessage` and all future
/// requests until the client terminates the connection or a fatal error occurs.
///
/// The function proceeds in phases:
///
/// 1. Reject unsupported protocol versions, disallowed roles, and connections
///    whose TLS state is incompatible with `tls_mode`.
/// 2. Authenticate the client according to the configured authenticator and
///    build a [`Session`].
/// 3. Apply the startup parameters to the session, reserve a slot in the
///    connection counter, and register the session with the adapter.
/// 4. Send the post-authentication messages and drive the connection's state
///    machine until it finishes or the authentication expires.
///
/// Note that this function returns successfully even upon delivering a fatal
/// error to the client. It only returns `Err` if an unexpected I/O error occurs
/// while communicating with the client, e.g., if the connection is severed in
/// the middle of a request.
#[mz_ore::instrument(level = "debug")]
pub async fn run<'a, A, I>(
    RunParams {
        tls_mode,
        adapter_client,
        conn,
        conn_uuid,
        version,
        mut params,
        frontegg,
        oidc,
        authenticator_kind,
        active_connection_counter,
        helm_chart_version,
        allowed_roles,
        tokio_metrics_intervals,
    }: RunParams<'a, A, I>,
) -> Result<(), io::Error>
where
    A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
    I: Iterator<Item = TaskMetrics> + Send,
{
    // Only protocol version 3 is served; anything else gets a fatal error.
    if version != VERSION_3 {
        return conn
            .send(ErrorResponse::fatal(
                SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
                "server does not support the client's requested protocol version",
            ))
            .await;
    }

    // A missing `user` parameter becomes the empty string. The entry is removed
    // from `params` so the startup-settings loop below never sees it.
    let user = params.remove("user").unwrap_or_else(String::new);
    // `options` is parsed up front; a parse failure is reported later as a
    // notice rather than failing the connection.
    let options = parse_options(params.get("options").unwrap_or(&String::new()));
    let authenticator =
        get_authenticator(authenticator_kind, frontegg, oidc, adapter_client.clone());
    // TODO move this somewhere it can be shared with HTTP
    let is_internal_user = INTERNAL_USER_NAMES.contains(&user);
    // this is a superset of internal users
    let is_reserved_user = mz_adapter::catalog::is_reserved_role_name(user.as_str());
    let role_allowed = match allowed_roles {
        AllowedRoles::Normal => !is_reserved_user,
        AllowedRoles::Internal => is_internal_user,
        AllowedRoles::NormalAndInternal => !is_reserved_user || is_internal_user,
    };
    if !role_allowed {
        let msg = format!("unauthorized login to user '{user}'");
        return conn
            .send(ErrorResponse::fatal(SqlState::INSUFFICIENT_PRIVILEGE, msg))
            .await;
    }

    // Reject the connection if its TLS state does not satisfy `tls_mode`.
    if let Err(err) = conn.inner().ensure_tls_compatibility(&tls_mode) {
        return conn.send(err).await;
    }

    // Shadows the listener-level kind with the kind reported by the
    // authenticator that was actually constructed.
    let authenticator_kind = authenticator.kind();

    // Authenticate and build the session. `expired` is a future that is raced
    // against the state machine below; only the Frontegg arm supplies one that
    // can resolve, every other arm uses `pending()`.
    let (mut session, expired) = match authenticator {
        Authenticator::Frontegg(frontegg) => {
            let password = match request_cleartext_password(conn).await {
                Ok(password) => password,
                Err(PasswordRequestError::IoError(e)) => return Err(e),
                Err(PasswordRequestError::InvalidPasswordError(e)) => {
                    return conn.send(e).await;
                }
            };

            let group_claim =
                OIDC_GROUP_CLAIM.get(adapter_client.get_system_vars().await.dyncfgs());
            let auth_response = frontegg
                .authenticate(&user, &password, Some(&group_claim))
                .await;
            match auth_response {
                // Create a session based on the auth session.
                //
                // In particular, it's important that the username come from the
                // auth session, as Frontegg may return an email address with
                // different casing than the user supplied via the pgwire
                // username field.
                Ok((mut auth_session, authenticated)) => {
                    let groups = auth_session.groups();
                    let session = adapter_client.new_session(
                        SessionConfig {
                            conn_id: conn.conn_id().clone(),
                            uuid: conn_uuid,
                            user: auth_session.user().into(),
                            client_ip: conn.peer_addr().clone(),
                            external_metadata_rx: Some(auth_session.external_metadata_rx()),
                            helm_chart_version,
                            authenticator_kind,
                            groups,
                        },
                        authenticated,
                    );
                    let expired = async move { auth_session.expired().await };
                    (session, expired.left_future())
                }
                Err(err) => {
                    // The underlying error is only logged; the client receives
                    // a generic message.
                    warn!(?err, "pgwire connection failed authentication");
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_PASSWORD,
                            "invalid password",
                        ))
                        .await;
                }
            }
        }
        Authenticator::Oidc(oidc) => {
            // OIDC listener: accepts either a JWT (uses OIDC authentication) or a
            // plain SQL password (uses SQL password authentication).
            let password = match request_cleartext_password(conn).await {
                Ok(password) => password,
                Err(PasswordRequestError::IoError(e)) => return Err(e),
                Err(PasswordRequestError::InvalidPasswordError(e)) => {
                    return conn.send(e).await;
                }
            };
            if is_jwt(&password) {
                let auth_response = oidc.authenticate(&password, Some(&user)).await;
                match auth_response {
                    Ok((mut claims, authenticated)) => {
                        let groups = claims.groups.take();
                        let session = adapter_client.new_session(
                            SessionConfig {
                                conn_id: conn.conn_id().clone(),
                                uuid: conn_uuid,
                                // The session user comes from the validated
                                // claims, not from the startup parameter.
                                user: std::mem::take(&mut claims.user),
                                client_ip: conn.peer_addr().clone(),
                                external_metadata_rx: None,
                                helm_chart_version,
                                authenticator_kind,
                                groups,
                            },
                            authenticated,
                        );
                        // No invalidation of the auth session once authenticated,
                        // so auth session lasts indefinitely.
                        (session, pending().right_future())
                    }
                    Err(err) => {
                        warn!(?err, "pgwire connection failed authentication");
                        return conn.send(err.into_response()).await;
                    }
                }
            } else {
                let session = match authenticate_with_password(
                    conn,
                    &adapter_client,
                    user,
                    Password(password),
                    conn_uuid,
                    helm_chart_version,
                )
                .await
                {
                    Ok(session) => session,
                    Err(PasswordRequestError::IoError(e)) => return Err(e),
                    Err(PasswordRequestError::InvalidPasswordError(e)) => {
                        return conn.send(e).await;
                    }
                };
                (session, pending().right_future())
            }
        }
        Authenticator::Password(adapter_client) => {
            let password = match request_cleartext_password(conn).await {
                Ok(password) => password,
                Err(PasswordRequestError::IoError(e)) => return Err(e),
                Err(PasswordRequestError::InvalidPasswordError(e)) => {
                    return conn.send(e).await;
                }
            };
            let session = match authenticate_with_password(
                conn,
                &adapter_client,
                user,
                Password(password),
                conn_uuid,
                helm_chart_version,
            )
            .await
            {
                Ok(session) => session,
                Err(PasswordRequestError::IoError(e)) => return Err(e),
                Err(PasswordRequestError::InvalidPasswordError(e)) => {
                    return conn.send(e).await;
                }
            };
            // No frontegg check, so auth session lasts indefinitely.
            (session, pending().right_future())
        }
        Authenticator::Sasl(adapter_client) => {
            // Start the handshake
            conn.send(BackendMessage::AuthenticationSASL).await?;
            conn.flush().await?;
            // Get the initial response indicating chosen mechanism
            let (mechanism, initial_response) = match conn.recv().await? {
                Some(FrontendMessage::RawAuthentication(data)) => {
                    match decode_sasl_initial_response(Cursor::new(&data)).ok() {
                        Some(FrontendMessage::SASLInitialResponse {
                            gs2_header,
                            mechanism,
                            initial_response,
                        }) => {
                            // We do not support channel binding
                            if gs2_header.channel_binding_enabled() {
                                return conn
                                    .send(ErrorResponse::fatal(
                                        SqlState::PROTOCOL_VIOLATION,
                                        "channel binding not supported",
                                    ))
                                    .await;
                            }
                            (mechanism, initial_response)
                        }
                        _ => {
                            return conn
                                .send(ErrorResponse::fatal(
                                    SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                                    "expected SASLInitialResponse message",
                                ))
                                .await;
                        }
                    }
                }
                _ => {
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                            "expected SASLInitialResponse message",
                        ))
                        .await;
                }
            };

            // SCRAM-SHA-256 is the only mechanism accepted here.
            if mechanism != "SCRAM-SHA-256" {
                return conn
                    .send(ErrorResponse::fatal(
                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                        "unsupported SASL mechanism",
                    ))
                    .await;
            }

            // Bound the length of the client-supplied nonce.
            if initial_response.nonce.len() > 256 {
                return conn
                    .send(ErrorResponse::fatal(
                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                        "nonce too long",
                    ))
                    .await;
            }

            let (server_first_message_raw, mock_hash) = match adapter_client
                .generate_sasl_challenge(&user, &initial_response.nonce)
                .await
            {
                Ok(response) => {
                    // Kept as text because it is part of the auth message that
                    // the client's proof is verified against below.
                    let server_first_message_raw = format!(
                        "r={},s={},i={}",
                        response.nonce, response.salt, response.iteration_count
                    );

                    // Fixed placeholder keys, combined with the challenge's salt
                    // and iteration count into `mock_hash`, which is passed to
                    // `verify_sasl_proof` below.
                    // NOTE(review): presumably a stand-in verifier for users
                    // without stored credentials -- confirm against the adapter.
                    let client_key = [0u8; 32];
                    let server_key = [1u8; 32];
                    let mock_hash = format!(
                        "SCRAM-SHA-256${}:{}${}:{}",
                        response.iteration_count,
                        response.salt,
                        BASE64_STANDARD.encode(client_key),
                        BASE64_STANDARD.encode(server_key)
                    );

                    conn.send(BackendMessage::AuthenticationSASLContinue(
                        SASLServerFirstMessage {
                            iteration_count: response.iteration_count,
                            nonce: response.nonce,
                            salt: response.salt,
                        },
                    ))
                    .await?;
                    conn.flush().await?;
                    (server_first_message_raw, mock_hash)
                }
                Err(e) => {
                    return conn.send(e.into_response(Severity::Fatal)).await;
                }
            };

            let authenticated = match conn.recv().await? {
                Some(FrontendMessage::RawAuthentication(data)) => {
                    match decode_sasl_response(Cursor::new(&data)).ok() {
                        Some(FrontendMessage::SASLResponse(response)) => {
                            // Auth message: client-first-bare, server-first, and
                            // client-final-bare, joined by commas.
                            let auth_message = format!(
                                "{},{},{}",
                                initial_response.client_first_message_bare_raw,
                                server_first_message_raw,
                                response.client_final_message_bare_raw
                            );
                            // Bound the length of the client-supplied proof.
                            if response.proof.len() > 1024 {
                                return conn
                                    .send(ErrorResponse::fatal(
                                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                                        "proof too long",
                                    ))
                                    .await;
                            }
                            match adapter_client
                                .verify_sasl_proof(
                                    &user,
                                    &response.proof,
                                    &auth_message,
                                    &mock_hash,
                                )
                                .await
                            {
                                Ok((proof_response, authenticated)) => {
                                    conn.send(BackendMessage::AuthenticationSASLFinal(
                                        SASLServerFinalMessage {
                                            kind: SASLServerFinalMessageKinds::Verifier(
                                                proof_response.verifier,
                                            ),
                                            extensions: vec![],
                                        },
                                    ))
                                    .await?;
                                    conn.flush().await?;
                                    authenticated
                                }
                                // The verification error is discarded; the client
                                // receives a generic message.
                                Err(_) => {
                                    return conn
                                        .send(ErrorResponse::fatal(
                                            SqlState::INVALID_PASSWORD,
                                            "invalid password",
                                        ))
                                        .await;
                                }
                            }
                        }
                        _ => {
                            return conn
                                .send(ErrorResponse::fatal(
                                    SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                                    "expected SASLResponse message",
                                ))
                                .await;
                        }
                    }
                }
                _ => {
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                            "expected SASLResponse message",
                        ))
                        .await;
                }
            };

            let session = adapter_client.new_session(
                SessionConfig {
                    conn_id: conn.conn_id().clone(),
                    uuid: conn_uuid,
                    user,
                    client_ip: conn.peer_addr().clone(),
                    external_metadata_rx: None,
                    helm_chart_version,
                    authenticator_kind,
                    groups: None,
                },
                authenticated,
            );
            // No frontegg check, so auth session lasts indefinitely.
            let auth_session = pending().right_future();
            (session, auth_session)
        }

        // No authentication exchange: the session is created directly for the
        // user named in the startup message.
        Authenticator::None => {
            let session = adapter_client.new_session(
                SessionConfig {
                    conn_id: conn.conn_id().clone(),
                    uuid: conn_uuid,
                    user,
                    client_ip: conn.peer_addr().clone(),
                    external_metadata_rx: None,
                    helm_chart_version,
                    authenticator_kind,
                    groups: None,
                },
                Authenticated,
            );
            // No frontegg check, so auth session lasts indefinitely.
            let auth_session = pending().right_future();
            (session, auth_session)
        }
    };

    // Apply the remaining startup parameters as session variables. The
    // `options` parameter expands to the pairs parsed above; every other
    // parameter is a single (name, value) pair.
    let system_vars = adapter_client.get_system_vars().await;
    for (name, value) in params {
        let settings = match name.as_str() {
            "options" => match &options {
                Ok(opts) => opts,
                Err(()) => {
                    session.add_notice(AdapterNotice::BadStartupSetting {
                        name,
                        reason: "could not parse".into(),
                    });
                    continue;
                }
            },
            _ => &vec![(name, value)],
        };
        for (key, val) in settings {
            const LOCAL: bool = false;
            // TODO: Issuing an error here is better than what we did before
            // (silently ignore errors on set), but erroring the connection
            // might be the better behavior. We maybe need to support more
            // options sent by psql and drivers before we can safely do this.
            if let Err(err) = session
                .vars_mut()
                .set(&system_vars, key, VarInput::Flat(val), LOCAL)
            {
                session.add_notice(AdapterNotice::BadStartupSetting {
                    name: key.clone(),
                    reason: err.to_string(),
                });
            }
        }
    }
    // NOTE(review): presumably commits the variable values set above so they
    // persist as the session's values -- confirm against `SessionVars`.
    session
        .vars_mut()
        .end_transaction(EndTransactionAction::Commit);

    // Reserve a slot against the connection limit.
    // NOTE(review): the returned guard presumably releases the slot when it is
    // dropped at the end of this function -- confirm against `ConnectionCounter`.
    let _guard = match active_connection_counter.allocate_connection(session.user()) {
        Ok(drop_connection) => drop_connection,
        Err(e) => {
            let e: AdapterError = e.into();
            return conn.send(e.into_response(Severity::Fatal)).await;
        }
    };

    // Register session with adapter.
    let mut adapter_client = match adapter_client.startup(session).await {
        Ok(adapter_client) => adapter_client,
        Err(e) => return conn.send(e.into_response(Severity::Fatal)).await,
    };

    // Batch the post-authentication messages: AuthenticationOk, parameter
    // statuses, backend key data, any notices queued during startup, and
    // finally ReadyForQuery.
    let mut buf = vec![BackendMessage::AuthenticationOk];
    for var in adapter_client.session().vars().notify_set() {
        buf.push(BackendMessage::ParameterStatus(var.name(), var.value()));
    }
    buf.push(BackendMessage::BackendKeyData {
        conn_id: adapter_client.session().conn_id().unhandled(),
        secret_key: adapter_client.session().secret_key(),
    });
    buf.extend(
        adapter_client
            .session()
            .drain_notices()
            .into_iter()
            .map(|notice| BackendMessage::ErrorResponse(notice.into_response())),
    );
    buf.push(BackendMessage::ReadyForQuery(
        adapter_client.session().transaction().into(),
    ));
    conn.send_all(buf).await?;
    conn.flush().await?;

    let machine = StateMachine {
        conn,
        adapter_client,
        txn_needs_commit: false,
        tokio_metrics_intervals,
    };

    // Drive the state machine until it finishes or the authentication expires,
    // whichever happens first.
    select! {
        r = machine.run() => {
            // Errors produced internally (like MAX_REQUEST_SIZE being exceeded) should send an
            // error to the client informing them why the connection was closed. We still want to
            // return the original error up the stack, though, so we skip error checking during conn
            // operations.
            if let Err(err) = &r {
                let _ = conn
                    .send(ErrorResponse::fatal(
                        SqlState::CONNECTION_FAILURE,
                        err.to_string(),
                    ))
                    .await;
                let _ = conn.flush().await;
            }
            r
        },
        _ = expired => {
            conn
                .send(ErrorResponse::fatal(SqlState::INVALID_AUTHORIZATION_SPECIFICATION, "authentication expired"))
                .await?;
            conn.flush().await
        }
    }
}
645
646/// Decides if a given password is a JWT by checking
647/// if we can decode its header.
648fn is_jwt(password: &str) -> bool {
649    jsonwebtoken::decode_header(password).is_ok()
650}
651
/// Returns (name, value) session settings pairs from an options value.
///
/// `value` is split into words by [`split_options`]. Each setting may be
/// written as `--key=value`, `-ckey=value`, or as the two words `-c key=value`.
/// As in Postgres, every `-` in a key is replaced with `_`, so
/// `--statement-timeout=5s` names the setting `statement_timeout`. Values are
/// returned unchanged.
///
/// Returns `Err(())` if any word is not in one of the accepted forms. A
/// trailing `-c` with no following word is ignored.
///
/// From Postgres, see pg_split_opts in postinit.c and process_postgres_switches
/// in postgres.c.
fn parse_options(value: &str) -> Result<Vec<(String, String)>, ()> {
    let opts = split_options(value);
    let mut pairs = Vec::with_capacity(opts.len());
    // Set when the previous word was a bare `-c`, meaning the current word is
    // its `key=value` argument.
    let mut seen_prefix = false;
    for opt in opts {
        let (key, val) = if seen_prefix {
            seen_prefix = false;
            opt.split_once('=').ok_or(())?
        } else if opt == "-c" {
            seen_prefix = true;
            continue;
        } else {
            parse_option(&opt)?
        };
        // `parse_option` hands back borrowed slices, so the `-` to `_`
        // normalization of the key has to happen here, for both spellings.
        pairs.push((key.replace('-', "_"), val.to_owned()));
    }
    Ok(pairs)
}

/// Returns the parsed key and value from an option of the form `--key=value`
/// or `-ckey=value`.
///
/// The key is returned verbatim with only the prefix removed; replacing `-`
/// with `_` is left to the caller because this function returns borrowed
/// slices of `option`. Returns an error if there is no `=` or if the text
/// before it starts with neither `-c` nor `--`.
fn parse_option(option: &str) -> Result<(&str, &str), ()> {
    let (key, value) = option.split_once('=').ok_or(())?;
    ["-c", "--"]
        .iter()
        .find_map(|prefix| key.strip_prefix(*prefix))
        .map(|key| (key, value))
        .ok_or(())
}

/// Splits value by any number of spaces except those preceded by `\`.
///
/// A `\` makes the following character literal and is itself dropped, so `\ `
/// yields a space inside a word and `\\` yields a single `\`.
fn split_options(value: &str) -> Vec<String> {
    let mut strs = Vec::new();
    // Need to build a string because of the escaping, so we can't simply
    // subslice into value, and this isn't called enough to need to make it
    // smart so it only builds a string if needed.
    let mut current = String::new();
    let mut was_slash = false;
    for c in value.chars() {
        was_slash = match c {
            ' ' => {
                if was_slash {
                    current.push(' ');
                } else if !current.is_empty() {
                    // To ignore multiple spaces in a row, only push if current
                    // is not empty.
                    strs.push(std::mem::take(&mut current));
                }
                false
            }
            '\\' => {
                if was_slash {
                    // Two slashes in a row will add a slash and not escape the
                    // next char.
                    current.push('\\');
                    false
                } else {
                    true
                }
            }
            _ => {
                current.push(c);
                false
            }
        };
    }
    // A `\` at the end will be ignored.
    if !current.is_empty() {
        strs.push(current);
    }
    strs
}
732
/// Failure modes of the password helpers in this file
/// (`request_cleartext_password` and `authenticate_with_password`).
enum PasswordRequestError {
    /// The client's message or credentials were rejected. Carries the error
    /// response that callers forward to the client.
    InvalidPasswordError(ErrorResponse),
    /// Communicating with the client failed. Callers return this error
    /// directly instead of reporting it to the client.
    IoError(io::Error),
}
737
738impl From<io::Error> for PasswordRequestError {
739    fn from(e: io::Error) -> Self {
740        PasswordRequestError::IoError(e)
741    }
742}
743
744/// Requests a cleartext password from a connection and returns it if it is valid.
745/// Sends an error response in the connection if the password
746/// is not valid.
747async fn request_cleartext_password<A>(
748    conn: &mut FramedConn<A>,
749) -> Result<String, PasswordRequestError>
750where
751    A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
752{
753    conn.send(BackendMessage::AuthenticationCleartextPassword)
754        .await?;
755    conn.flush().await?;
756
757    if let Some(message) = conn.recv().await? {
758        if let FrontendMessage::RawAuthentication(data) = message {
759            if let Some(FrontendMessage::Password { password }) =
760                decode_password(Cursor::new(&data)).ok()
761            {
762                return Ok(password);
763            }
764        }
765    }
766
767    Err(PasswordRequestError::InvalidPasswordError(
768        ErrorResponse::fatal(
769            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
770            "expected Password message",
771        ),
772    ))
773}
774
775/// Helper for password-based authentication using AdapterClient
776/// and returns an authenticated session.
777async fn authenticate_with_password<A>(
778    conn: &FramedConn<A>,
779    adapter_client: &mz_adapter::Client,
780    user: String,
781    password: Password,
782    conn_uuid: Uuid,
783    helm_chart_version: Option<String>,
784) -> Result<Session, PasswordRequestError>
785where
786    A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
787{
788    let authenticated = match adapter_client.authenticate(&user, &password).await {
789        Ok(authenticated) => authenticated,
790        Err(err) => {
791            warn!(?err, "pgwire connection failed authentication");
792            return Err(PasswordRequestError::InvalidPasswordError(
793                ErrorResponse::fatal(SqlState::INVALID_PASSWORD, "invalid password"),
794            ));
795        }
796    };
797
798    let session = adapter_client.new_session(
799        SessionConfig {
800            conn_id: conn.conn_id().clone(),
801            uuid: conn_uuid,
802            user,
803            client_ip: conn.peer_addr().clone(),
804            external_metadata_rx: None,
805            helm_chart_version,
806            authenticator_kind: mz_auth::AuthenticatorKind::Password,
807            groups: None,
808        },
809        authenticated,
810    );
811
812    Ok(session)
813}
814
/// The states the connection's message loop moves between.
#[derive(Debug)]
enum State {
    /// Receive the next frontend message and dispatch it (`advance_ready`).
    Ready,
    /// Discard incoming messages until a `Sync` arrives (`advance_drain`).
    Drain,
    /// Stop processing; the message loop returns.
    Done,
}
821
822struct StateMachine<'a, A, I>
823where
824    I: Iterator<Item = TaskMetrics> + Send + 'a,
825{
826    conn: &'a mut FramedConn<A>,
827    adapter_client: mz_adapter::SessionClient,
828    txn_needs_commit: bool,
829    tokio_metrics_intervals: I,
830}
831
/// Describes how an operation that sends rows finished.
///
/// NOTE(review): the producer and consumer of this type are defined elsewhere
/// in this file; the field notes below are inferred from names and should be
/// confirmed there.
enum SendRowsEndedReason {
    /// The rows were sent successfully.
    Success {
        /// Presumably the size of the result in bytes -- TODO confirm.
        result_size: u64,
        /// Presumably the number of rows sent -- TODO confirm.
        rows_returned: u64,
    },
    /// Sending failed; `error` holds the error text.
    Errored {
        error: String,
    },
    /// Sending was canceled before it completed.
    Canceled,
}
842
/// Error text used when a command is rejected because the current transaction
/// is in the aborted state.
const ABORTED_TXN_MSG: &str =
    "current transaction is aborted, commands ignored until end of transaction block";
845
846impl<'a, A, I> StateMachine<'a, A, I>
847where
848    A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin + 'a,
849    I: Iterator<Item = TaskMetrics> + Send + 'a,
850{
851    // Manually desugar this (don't use `async fn run`) here because a much better
852    // error message is produced if there are problems with Send or other traits
853    // somewhere within the Future.
854    #[allow(clippy::manual_async_fn)]
855    #[mz_ore::instrument(level = "debug")]
856    fn run(mut self) -> impl Future<Output = Result<(), io::Error>> + Send + 'a {
857        async move {
858            let mut state = State::Ready;
859            loop {
860                self.send_pending_notices().await?;
861                state = match state {
862                    State::Ready => self.advance_ready().await?,
863                    State::Drain => self.advance_drain().await?,
864                    State::Done => return Ok(()),
865                };
866                self.adapter_client
867                    .add_idle_in_transaction_session_timeout();
868            }
869        }
870    }
871
872    #[instrument(level = "debug")]
873    async fn advance_ready(&mut self) -> Result<State, io::Error> {
874        // Start a new metrics interval before the `recv()` call.
875        self.tokio_metrics_intervals
876            .next()
877            .expect("infinite iterator");
878
879        // Handle timeouts first so we don't execute any statements when there's a pending timeout.
880        let message = select! {
881            biased;
882
883            // `recv_timeout()` is cancel-safe as per it's docs.
884            Some(timeout) = self.adapter_client.recv_timeout() => {
885                let err: AdapterError = timeout.into();
886                let conn_id = self.adapter_client.session().conn_id();
887                tracing::warn!("session timed out, conn_id {}", conn_id);
888
889                // Process the error, doing any state cleanup.
890                let error_response = err.into_response(Severity::Fatal);
891                let error_state = self.send_error_and_get_state(error_response).await;
892
893                // Terminate __after__ we do any cleanup.
894                self.adapter_client.terminate().await;
895
896                // We must wait for the client to send a request before we can send the error response.
897                // Due to the PG wire protocol, we can't send an ErrorResponse unless it is in response
898                // to a client message.
899                let _ = self.conn.recv().await?;
900                return error_state;
901            },
902            // `recv()` is cancel-safe as per it's docs.
903            message = self.conn.recv() => message?,
904        };
905
906        // Take the metrics since just before the `recv`.
907        let interval = self
908            .tokio_metrics_intervals
909            .next()
910            .expect("infinite iterator");
911        let recv_scheduling_delay_ms = interval.total_scheduled_duration.as_secs_f64() * 1000.0;
912
913        // TODO(ggevay): Consider subtracting the scheduling delay from `received`. It's not obvious
914        // whether we should do this, because the result wouldn't exactly correspond to either first
915        // byte received or last byte received (for msgs that arrive in more than one network packet).
916        let received = SYSTEM_TIME();
917
918        self.adapter_client
919            .remove_idle_in_transaction_session_timeout();
920
921        // NOTE(guswynn): we could consider adding spans to all message types. Currently
922        // only a few message types seem useful.
923        let message_name = message.as_ref().map(|m| m.name()).unwrap_or_default();
924
925        let start = message.as_ref().map(|_| Instant::now());
926        let next_state = match message {
927            Some(FrontendMessage::Query { sql }) => {
928                let query_root_span =
929                    tracing::info_span!(parent: None, "advance_ready", otel.name = message_name);
930                query_root_span.follows_from(tracing::Span::current());
931                self.query(sql, received)
932                    .instrument(query_root_span)
933                    .await?
934            }
935            Some(FrontendMessage::Parse {
936                name,
937                sql,
938                param_types,
939            }) => self.parse(name, sql, param_types).await?,
940            Some(FrontendMessage::Bind {
941                portal_name,
942                statement_name,
943                param_formats,
944                raw_params,
945                result_formats,
946            }) => {
947                self.bind(
948                    portal_name,
949                    statement_name,
950                    param_formats,
951                    raw_params,
952                    result_formats,
953                )
954                .await?
955            }
956            Some(FrontendMessage::Execute {
957                portal_name,
958                max_rows,
959            }) => {
960                let max_rows = match usize::try_from(max_rows) {
961                    Ok(0) | Err(_) => ExecuteCount::All, // If `max_rows < 0`, no limit.
962                    Ok(n) => ExecuteCount::Count(n),
963                };
964                let execute_root_span =
965                    tracing::info_span!(parent: None, "advance_ready", otel.name = message_name);
966                execute_root_span.follows_from(tracing::Span::current());
967                let state = self
968                    .execute(
969                        portal_name,
970                        max_rows,
971                        portal_exec_message,
972                        None,
973                        ExecuteTimeout::None,
974                        None,
975                        Some(received),
976                    )
977                    .instrument(execute_root_span)
978                    .await?;
979                // In PostgreSQL, when using the extended query protocol, some statements may
980                // trigger an eager commit of the current implicit transaction,
981                // see: <https://git.postgresql.org/gitweb/?p=postgresql.git&a=commitdiff&h=f92944137>.
982                //
983                // In Materialize, however, we eagerly commit every statement outside of an explicit
984                // transaction when using the extended query protocol. This allows us to eliminate
985                // the possibility of a multiple statement implicit transaction, which in turn
986                // allows us to apply single-statement optimizations to queries issued in implicit
987                // transactions in the extended query protocol.
988                //
989                // We don't immediately commit here to allow users to page through the portal if
990                // necessary. Committing the transaction would destroy the portal before the next
991                // Execute command has a chance to resume it. So we instead mark the transaction
992                // for commit the next time that `ensure_transaction` is called.
993                if self.adapter_client.session().transaction().is_implicit() {
994                    self.txn_needs_commit = true;
995                }
996                state
997            }
998            Some(FrontendMessage::DescribeStatement { name }) => {
999                self.describe_statement(&name).await?
1000            }
1001            Some(FrontendMessage::DescribePortal { name }) => self.describe_portal(&name).await?,
1002            Some(FrontendMessage::CloseStatement { name }) => self.close_statement(name).await?,
1003            Some(FrontendMessage::ClosePortal { name }) => self.close_portal(name).await?,
1004            Some(FrontendMessage::Flush) => self.flush().await?,
1005            Some(FrontendMessage::Sync) => self.sync().await?,
1006            Some(FrontendMessage::Terminate) => State::Done,
1007
1008            Some(FrontendMessage::CopyData(_))
1009            | Some(FrontendMessage::CopyDone)
1010            | Some(FrontendMessage::CopyFail(_))
1011            | Some(FrontendMessage::Password { .. })
1012            | Some(FrontendMessage::RawAuthentication(_))
1013            | Some(FrontendMessage::SASLInitialResponse { .. })
1014            | Some(FrontendMessage::SASLResponse(_)) => State::Drain,
1015            None => State::Done,
1016        };
1017
1018        if let Some(start) = start {
1019            self.adapter_client
1020                .inner()
1021                .metrics()
1022                .pgwire_message_processing_seconds
1023                .with_label_values(&[message_name])
1024                .observe(start.elapsed().as_secs_f64());
1025        }
1026        self.adapter_client
1027            .inner()
1028            .metrics()
1029            .pgwire_recv_scheduling_delay_ms
1030            .with_label_values(&[message_name])
1031            .observe(recv_scheduling_delay_ms);
1032
1033        Ok(next_state)
1034    }
1035
1036    async fn advance_drain(&mut self) -> Result<State, io::Error> {
1037        let message = self.conn.recv().await?;
1038        if message.is_some() {
1039            self.adapter_client
1040                .remove_idle_in_transaction_session_timeout();
1041        }
1042        match message {
1043            Some(FrontendMessage::Sync) => self.sync().await,
1044            None => Ok(State::Done),
1045            _ => Ok(State::Drain),
1046        }
1047    }
1048
1049    /// Note that `lifecycle_timestamps` belongs to the whole "Simple Query", because the whole
1050    /// Simple Query is received and parsed together. This means that if there are multiple
1051    /// statements in a Simple Query, then all of them have the same `lifecycle_timestamps`.
1052    #[instrument(level = "debug")]
1053    async fn one_query(
1054        &mut self,
1055        stmt: Statement<Raw>,
1056        sql: String,
1057        lifecycle_timestamps: LifecycleTimestamps,
1058    ) -> Result<State, io::Error> {
1059        // Bind the portal. Note that this does not set the empty string prepared
1060        // statement.
1061        const EMPTY_PORTAL: &str = "";
1062        if let Err(e) = self
1063            .adapter_client
1064            .declare(EMPTY_PORTAL.to_string(), stmt, sql)
1065            .await
1066        {
1067            return self
1068                .send_error_and_get_state(e.into_response(Severity::Error))
1069                .await;
1070        }
1071        let portal = self
1072            .adapter_client
1073            .session()
1074            .get_portal_unverified_mut(EMPTY_PORTAL)
1075            .expect("unnamed portal should be present");
1076
1077        *portal.lifecycle_timestamps = Some(lifecycle_timestamps);
1078
1079        let stmt_desc = portal.desc.clone();
1080        if !stmt_desc.param_types.is_empty() {
1081            return self
1082                .send_error_and_get_state(ErrorResponse::error(
1083                    SqlState::UNDEFINED_PARAMETER,
1084                    "there is no parameter $1",
1085                ))
1086                .await;
1087        }
1088
1089        // Maybe send row description.
1090        if let Some(relation_desc) = &stmt_desc.relation_desc {
1091            if !stmt_desc.is_copy {
1092                let formats = vec![Format::Text; stmt_desc.arity()];
1093                self.send(BackendMessage::RowDescription(
1094                    message::encode_row_description(relation_desc, &formats),
1095                ))
1096                .await?;
1097            }
1098        }
1099
1100        let result = match self
1101            .adapter_client
1102            .execute(EMPTY_PORTAL.to_string(), self.conn.wait_closed(), None)
1103            .await
1104        {
1105            Ok((response, execute_started)) => {
1106                self.send_pending_notices().await?;
1107                self.send_execute_response(
1108                    response,
1109                    stmt_desc.relation_desc,
1110                    EMPTY_PORTAL.to_string(),
1111                    ExecuteCount::All,
1112                    portal_exec_message,
1113                    None,
1114                    ExecuteTimeout::None,
1115                    execute_started,
1116                )
1117                .await
1118            }
1119            Err(e) => {
1120                self.send_pending_notices().await?;
1121                self.send_error_and_get_state(e.into_response(Severity::Error))
1122                    .await
1123            }
1124        };
1125
1126        // Destroy the portal.
1127        self.adapter_client.session().remove_portal(EMPTY_PORTAL);
1128
1129        result
1130    }
1131
1132    async fn ensure_transaction(
1133        &mut self,
1134        num_stmts: usize,
1135        message_type: &str,
1136    ) -> Result<(), io::Error> {
1137        let start = Instant::now();
1138        if self.txn_needs_commit {
1139            self.commit_transaction().await?;
1140        }
1141        // start_transaction can't error (but assert that just in case it changes in
1142        // the future.
1143        let res = self.adapter_client.start_transaction(Some(num_stmts));
1144        assert_ok!(res);
1145        self.adapter_client
1146            .inner()
1147            .metrics()
1148            .pgwire_ensure_transaction_seconds
1149            .with_label_values(&[message_type])
1150            .observe(start.elapsed().as_secs_f64());
1151        Ok(())
1152    }
1153
1154    fn parse_sql<'b>(&self, sql: &'b str) -> Result<Vec<StatementParseResult<'b>>, ErrorResponse> {
1155        let parse_start = Instant::now();
1156        let result = match self.adapter_client.parse(sql) {
1157            Ok(result) => result.map_err(|e| {
1158                // Convert our 0-based byte position to pgwire's 1-based character
1159                // position.
1160                let pos = sql[..e.error.pos].chars().count() + 1;
1161                ErrorResponse::error(SqlState::SYNTAX_ERROR, e.error.message).with_position(pos)
1162            }),
1163            Err(msg) => Err(ErrorResponse::error(SqlState::PROGRAM_LIMIT_EXCEEDED, msg)),
1164        };
1165        self.adapter_client
1166            .inner()
1167            .metrics()
1168            .parse_seconds
1169            .observe(parse_start.elapsed().as_secs_f64());
1170        result
1171    }
1172
1173    /// Executes a "Simple Query", see
1174    /// <https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-SIMPLE-QUERY>
1175    ///
1176    /// For implicit transaction handling, see "Multiple Statements in a Simple Query" in the above.
1177    #[instrument(level = "debug")]
1178    async fn query(&mut self, sql: String, received: EpochMillis) -> Result<State, io::Error> {
1179        // Parse first before doing any transaction checking.
1180        let stmts = match self.parse_sql(&sql) {
1181            Ok(stmts) => stmts,
1182            Err(err) => {
1183                self.send_error_and_get_state(err).await?;
1184                return self.ready().await;
1185            }
1186        };
1187
1188        let num_stmts = stmts.len();
1189
1190        // Compare with postgres' backend/tcop/postgres.c exec_simple_query.
1191        for StatementParseResult { ast: stmt, sql } in stmts {
1192            // In an aborted transaction, reject all commands except COMMIT/ROLLBACK.
1193            if self.is_aborted_txn() && !is_txn_exit_stmt(Some(&stmt)) {
1194                self.aborted_txn_error().await?;
1195                break;
1196            }
1197
1198            // Start an implicit transaction if we aren't in any transaction and there's
1199            // more than one statement. This mirrors the `use_implicit_block` variable in
1200            // postgres.
1201            //
1202            // This needs to be done in the loop instead of once at the top because
1203            // a COMMIT/ROLLBACK statement needs to start a new transaction on next
1204            // statement.
1205            self.ensure_transaction(num_stmts, "query").await?;
1206
1207            match self
1208                .one_query(stmt, sql.to_string(), LifecycleTimestamps { received })
1209                .await?
1210            {
1211                State::Ready => (),
1212                State::Drain => break,
1213                State::Done => return Ok(State::Done),
1214            }
1215        }
1216
1217        // Implicit transactions are closed at the end of a Query message.
1218        {
1219            if self.adapter_client.session().transaction().is_implicit() {
1220                self.commit_transaction().await?;
1221            }
1222        }
1223
1224        if num_stmts == 0 {
1225            self.send(BackendMessage::EmptyQueryResponse).await?;
1226        }
1227
1228        self.ready().await
1229    }
1230
1231    #[instrument(level = "debug")]
1232    async fn parse(
1233        &mut self,
1234        name: String,
1235        sql: String,
1236        param_oids: Vec<u32>,
1237    ) -> Result<State, io::Error> {
1238        // Start a transaction if we aren't in one.
1239        self.ensure_transaction(1, "parse").await?;
1240
1241        let mut param_types = vec![];
1242        for oid in param_oids {
1243            match mz_pgrepr::Type::from_oid(oid) {
1244                Ok(ty) => match SqlScalarType::try_from(&ty) {
1245                    Ok(ty) => param_types.push(Some(ty)),
1246                    Err(err) => {
1247                        return self
1248                            .send_error_and_get_state(ErrorResponse::error(
1249                                SqlState::INVALID_PARAMETER_VALUE,
1250                                err.to_string(),
1251                            ))
1252                            .await;
1253                    }
1254                },
1255                Err(_) if oid == 0 => param_types.push(None),
1256                Err(e) => {
1257                    return self
1258                        .send_error_and_get_state(ErrorResponse::error(
1259                            SqlState::PROTOCOL_VIOLATION,
1260                            e.to_string(),
1261                        ))
1262                        .await;
1263                }
1264            }
1265        }
1266
1267        let stmts = match self.parse_sql(&sql) {
1268            Ok(stmts) => stmts,
1269            Err(err) => {
1270                return self.send_error_and_get_state(err).await;
1271            }
1272        };
1273        if stmts.len() > 1 {
1274            return self
1275                .send_error_and_get_state(ErrorResponse::error(
1276                    SqlState::INTERNAL_ERROR,
1277                    "cannot insert multiple commands into a prepared statement",
1278                ))
1279                .await;
1280        }
1281        let (maybe_stmt, sql) = match stmts.into_iter().next() {
1282            None => (None, ""),
1283            Some(StatementParseResult { ast, sql }) => (Some(ast), sql),
1284        };
1285        if self.is_aborted_txn() && !is_txn_exit_stmt(maybe_stmt.as_ref()) {
1286            return self.aborted_txn_error().await;
1287        }
1288        match self
1289            .adapter_client
1290            .prepare(name, maybe_stmt, sql.to_string(), param_types)
1291            .await
1292        {
1293            Ok(()) => {
1294                self.send(BackendMessage::ParseComplete).await?;
1295                Ok(State::Ready)
1296            }
1297            Err(e) => {
1298                self.send_error_and_get_state(e.into_response(Severity::Error))
1299                    .await
1300            }
1301        }
1302    }
1303
1304    /// Commits and clears the current transaction.
1305    #[instrument(level = "debug")]
1306    async fn commit_transaction(&mut self) -> Result<(), io::Error> {
1307        self.end_transaction(EndTransactionAction::Commit).await
1308    }
1309
1310    /// Rollback and clears the current transaction.
1311    #[instrument(level = "debug")]
1312    async fn rollback_transaction(&mut self) -> Result<(), io::Error> {
1313        self.end_transaction(EndTransactionAction::Rollback).await
1314    }
1315
1316    /// End a transaction and report to the user if an error occurred.
1317    #[instrument(level = "debug")]
1318    async fn end_transaction(&mut self, action: EndTransactionAction) -> Result<(), io::Error> {
1319        self.txn_needs_commit = false;
1320        let resp = self.adapter_client.end_transaction(action).await;
1321        if let Err(err) = resp {
1322            self.send(BackendMessage::ErrorResponse(
1323                err.into_response(Severity::Error),
1324            ))
1325            .await?;
1326        }
1327        Ok(())
1328    }
1329
1330    #[instrument(level = "debug")]
1331    async fn bind(
1332        &mut self,
1333        portal_name: String,
1334        statement_name: String,
1335        param_formats: Vec<Format>,
1336        raw_params: Vec<Option<Vec<u8>>>,
1337        result_formats: Vec<Format>,
1338    ) -> Result<State, io::Error> {
1339        // Start a transaction if we aren't in one.
1340        self.ensure_transaction(1, "bind").await?;
1341
1342        let aborted_txn = self.is_aborted_txn();
1343        let stmt = match self
1344            .adapter_client
1345            .get_prepared_statement(&statement_name)
1346            .await
1347        {
1348            Ok(stmt) => stmt,
1349            Err(err) => {
1350                return self
1351                    .send_error_and_get_state(err.into_response(Severity::Error))
1352                    .await;
1353            }
1354        };
1355
1356        let param_types = &stmt.desc().param_types;
1357        if param_types.len() != raw_params.len() {
1358            let message = format!(
1359                "bind message supplies {actual} parameters, \
1360                 but prepared statement \"{name}\" requires {expected}",
1361                name = statement_name,
1362                actual = raw_params.len(),
1363                expected = param_types.len()
1364            );
1365            return self
1366                .send_error_and_get_state(ErrorResponse::error(
1367                    SqlState::PROTOCOL_VIOLATION,
1368                    message,
1369                ))
1370                .await;
1371        }
1372        let param_formats = match pad_formats(param_formats, raw_params.len()) {
1373            Ok(param_formats) => param_formats,
1374            Err(msg) => {
1375                return self
1376                    .send_error_and_get_state(ErrorResponse::error(
1377                        SqlState::PROTOCOL_VIOLATION,
1378                        msg,
1379                    ))
1380                    .await;
1381            }
1382        };
1383        if aborted_txn && !is_txn_exit_stmt(stmt.stmt()) {
1384            return self.aborted_txn_error().await;
1385        }
1386        let buf = RowArena::new();
1387        let mut params = vec![];
1388        for ((raw_param, mz_typ), format) in raw_params
1389            .into_iter()
1390            .zip_eq(param_types)
1391            .zip_eq(param_formats)
1392        {
1393            let pg_typ = mz_pgrepr::Type::from(mz_typ);
1394            let datum = match raw_param {
1395                None => Datum::Null,
1396                Some(bytes) => match mz_pgrepr::Value::decode(format, &pg_typ, &bytes) {
1397                    Ok(param) => match param.into_datum_decode_error(&buf, &pg_typ, "parameter") {
1398                        Ok(datum) => datum,
1399                        Err(msg) => {
1400                            return self
1401                                .send_error_and_get_state(ErrorResponse::error(
1402                                    SqlState::INVALID_PARAMETER_VALUE,
1403                                    msg,
1404                                ))
1405                                .await;
1406                        }
1407                    },
1408                    Err(err) => {
1409                        let msg = format!("unable to decode parameter: {}", err);
1410                        return self
1411                            .send_error_and_get_state(ErrorResponse::error(
1412                                SqlState::INVALID_PARAMETER_VALUE,
1413                                msg,
1414                            ))
1415                            .await;
1416                    }
1417                },
1418            };
1419            params.push((datum, mz_typ.clone()))
1420        }
1421
1422        let result_formats = match pad_formats(
1423            result_formats,
1424            stmt.desc()
1425                .relation_desc
1426                .clone()
1427                .map(|desc| desc.typ().column_types.len())
1428                .unwrap_or(0),
1429        ) {
1430            Ok(result_formats) => result_formats,
1431            Err(msg) => {
1432                return self
1433                    .send_error_and_get_state(ErrorResponse::error(
1434                        SqlState::PROTOCOL_VIOLATION,
1435                        msg,
1436                    ))
1437                    .await;
1438            }
1439        };
1440
1441        // Binary encodings are disabled for list, map, and aclitem types, but this doesn't
1442        // apply to COPY TO statements.
1443        if !stmt.stmt().map_or(false, |stmt| match stmt {
1444            Statement::Copy(CopyStatement {
1445                direction: CopyDirection::To,
1446                ..
1447            }) => true,
1448            Statement::Copy(CopyStatement {
1449                direction: CopyDirection::From,
1450                // To be conservative, we are restricting COPY FROM to only allow list/map/aclitem types if it is not
1451                // copying from STDIN. It is likely that this works in theory, but is risky and likely to OOM anyways
1452                // as all the data will be held in a buffer in memory before being processed.
1453                target: CopyTarget::Expr(_),
1454                ..
1455            }) => true,
1456            _ => false,
1457        }) {
1458            if let Some(desc) = stmt.desc().relation_desc.clone() {
1459                for (format, ty) in result_formats.iter().zip_eq(desc.iter_types()) {
1460                    match (format, &ty.scalar_type) {
1461                        (Format::Binary, mz_repr::SqlScalarType::List { .. }) => {
1462                            return self
1463                                .send_error_and_get_state(ErrorResponse::error(
1464                                    SqlState::PROTOCOL_VIOLATION,
1465                                    "binary encoding of list types is not implemented",
1466                                ))
1467                                .await;
1468                        }
1469                        (Format::Binary, mz_repr::SqlScalarType::Map { .. }) => {
1470                            return self
1471                                .send_error_and_get_state(ErrorResponse::error(
1472                                    SqlState::PROTOCOL_VIOLATION,
1473                                    "binary encoding of map types is not implemented",
1474                                ))
1475                                .await;
1476                        }
1477                        (Format::Binary, mz_repr::SqlScalarType::AclItem) => {
1478                            return self
1479                                .send_error_and_get_state(ErrorResponse::error(
1480                                    SqlState::PROTOCOL_VIOLATION,
1481                                    "binary encoding of aclitem types does not exist",
1482                                ))
1483                                .await;
1484                        }
1485                        _ => (),
1486                    }
1487                }
1488            }
1489        }
1490
1491        let desc = stmt.desc().clone();
1492        let logging = Arc::clone(stmt.logging());
1493        let stmt_ast = stmt.stmt().cloned();
1494        let state_revision = stmt.state_revision;
1495        if let Err(err) = self.adapter_client.session().set_portal(
1496            portal_name,
1497            desc,
1498            stmt_ast,
1499            logging,
1500            params,
1501            result_formats,
1502            state_revision,
1503        ) {
1504            return self
1505                .send_error_and_get_state(err.into_response(Severity::Error))
1506                .await;
1507        }
1508
1509        self.send(BackendMessage::BindComplete).await?;
1510        Ok(State::Ready)
1511    }
1512
1513    /// Executes the portal named `portal_name` and reports the outcome to the client.
1514    ///
        /// A missing portal is reported as `INVALID_CURSOR_NAME`, and inside an aborted transaction
        /// everything except a transaction-exit statement is rejected. Otherwise the portal's
        /// state decides what happens: `NotStarted` portals are handed to the adapter for
        /// execution, `InProgress` portals resume sending their remaining rows, and `Completed`
        /// portals either replay their remembered command tag or fail with
        /// `OBJECT_NOT_IN_PREREQUISITE_STATE`.
        ///
        /// `max_rows`, `get_response`, `fetch_portal_name` and `timeout` are passed through
        /// unchanged to `send_execute_response` / `send_rows`. `received`, when present, is turned
        /// into the portal's `LifecycleTimestamps`.
        ///
        /// `outer_ctx_extra` is Some when we are executing as part of an outer statement, e.g., a FETCH
        /// triggering the execution of the underlying query. When the portal is missing, the
        /// transaction is aborted, or the portal is `InProgress`/`Completed`, the guard is retired
        /// here via `retire_execute`; for `NotStarted` portals it is passed on to
        /// `adapter_client.execute` instead.
        ///
        /// Returns a boxed future: `execute` -> `send_execute_response` -> `fetch` -> `execute` is a
        /// recursive async call cycle, which Rust only accepts when the future is boxed somewhere.
1515    fn execute(
1516        &mut self,
1517        portal_name: String,
1518        max_rows: ExecuteCount,
1519        get_response: GetResponse,
1520        fetch_portal_name: Option<String>,
1521        timeout: ExecuteTimeout,
1522        outer_ctx_extra: Option<ExecuteContextGuard>,
1523        received: Option<EpochMillis>,
1524    ) -> BoxFuture<'_, Result<State, io::Error>> {
1525        async move {
                // Snapshot the aborted-transaction flag; it is consulted below once the portal's
                // statement is known.
1526            let aborted_txn = self.is_aborted_txn();
1527
1528            // Check if the portal has been started and can be continued.
1529            let portal = match self
1530                .adapter_client
1531                .session()
1532                .get_portal_unverified_mut(&portal_name)
1533            {
1534                Some(portal) => portal,
1535                None => {
                        // Unknown portal: retire any outer statement with the same message that
                        // is sent to the client.
1536                    let msg = format!("portal {} does not exist", portal_name.quoted());
1537                    if let Some(outer_ctx_extra) = outer_ctx_extra {
1538                        self.adapter_client.retire_execute(
1539                            outer_ctx_extra,
1540                            StatementEndedExecutionReason::Errored { error: msg.clone() },
1541                        );
1542                    }
1543                    return self
1544                        .send_error_and_get_state(ErrorResponse::error(
1545                            SqlState::INVALID_CURSOR_NAME,
1546                            msg,
1547                        ))
1548                        .await;
1549                }
1550            };
1551
                // Stamp the portal with the receive time, or reset the timestamps to `None` when no
                // receive time was supplied.
1552            *portal.lifecycle_timestamps = received.map(LifecycleTimestamps::new);
1553
1554            // In an aborted transaction, reject all commands except COMMIT/ROLLBACK.
1555            let txn_exit_stmt = is_txn_exit_stmt(portal.stmt.as_deref());
1556            if aborted_txn && !txn_exit_stmt {
1557                if let Some(outer_ctx_extra) = outer_ctx_extra {
1558                    self.adapter_client.retire_execute(
1559                        outer_ctx_extra,
1560                        StatementEndedExecutionReason::Errored {
1561                            error: ABORTED_TXN_MSG.to_string(),
1562                        },
1563                    );
1564                }
1565                return self.aborted_txn_error().await;
1566            }
1567
                // Clone the relation description out of the portal so it can be moved into the
                // senders below.
1568            let row_desc = portal.desc.relation_desc.clone();
1569            match portal.state {
1570                PortalState::NotStarted => {
1571                    // Start a transaction if we aren't in one.
1572                    self.ensure_transaction(1, "execute").await?;
1573                    match self
1574                        .adapter_client
1575                        .execute(
1576                            portal_name.clone(),
                                // NOTE(review): presumably lets the adapter notice a client
                                // disconnect while the statement runs -- confirm.
1577                            self.conn.wait_closed(),
1578                            outer_ctx_extra,
1579                        )
1580                        .await
1581                    {
1582                        Ok((response, execute_started)) => {
                                // Pending notices go out ahead of the response itself.
1583                            self.send_pending_notices().await?;
1584                            self.send_execute_response(
1585                                response,
1586                                row_desc,
1587                                portal_name,
1588                                max_rows,
1589                                get_response,
1590                                fetch_portal_name,
1591                                timeout,
1592                                execute_started,
1593                            )
1594                            .await
1595                        }
1596                        Err(e) => {
                                // Pending notices go out ahead of the error as well.
1597                            self.send_pending_notices().await?;
1598                            self.send_error_and_get_state(e.into_response(Severity::Error))
1599                                .await
1600                        }
1601                    }
1602                }
1603                PortalState::InProgress(rows) => {
                        // Move the remaining rows out of the portal state (leaving `None` behind);
                        // panics if they were already taken.
1604                    let rows = rows.take().expect("InProgress rows must be populated");
1605                    let (result, statement_ended_execution_reason) = match self
1606                        .send_rows(
1607                            row_desc.expect("portal missing row desc on resumption"),
1608                            portal_name,
1609                            rows,
1610                            max_rows,
1611                            get_response,
1612                            fetch_portal_name,
1613                            timeout,
1614                        )
1615                        .await
1616                    {
1617                        Err(e) => {
1618                            // This is an error communicating with the connection.
1619                            // We consider that to be a cancelation, rather than a query error.
1620                            (Err(e), StatementEndedExecutionReason::Canceled)
1621                        }
1622                        Ok((ok, SendRowsEndedReason::Canceled)) => {
1623                            (Ok(ok), StatementEndedExecutionReason::Canceled)
1624                        }
1625                        // NOTE: For now the values for `result_size` and
1626                        // `rows_returned` in fetches are a bit confusing.
1627                        // We record `Some(n)` for the first fetch, where `n` is
1628                        // the number of bytes/rows returned by the inner
1629                        // execute (regardless of how many rows the
1630                        // fetch fetched), and `None` for subsequent fetches.
1631                        //
1632                        // This arguably makes sense since the size/rows
1633                        // returned measures how much work the compute
1634                        // layer had to do to satisfy the query, but
1635                        // we should revisit it if/when we start
1636                        // logging the inner execute separately.
1637                        Ok((
1638                            ok,
1639                            SendRowsEndedReason::Success {
1640                                result_size: _,
1641                                rows_returned: _,
1642                            },
1643                        )) => (
1644                            Ok(ok),
1645                            StatementEndedExecutionReason::Success {
1646                                result_size: None,
1647                                rows_returned: None,
1648                                execution_strategy: None,
1649                            },
1650                        ),
1651                        Ok((ok, SendRowsEndedReason::Errored { error })) => {
1652                            (Ok(ok), StatementEndedExecutionReason::Errored { error })
1653                        }
1654                    };
                        // Report how this resumption ended to the outer statement, if there is one.
1655                    if let Some(outer_ctx_extra) = outer_ctx_extra {
1656                        self.adapter_client
1657                            .retire_execute(outer_ctx_extra, statement_ended_execution_reason);
1658                    }
1659                    result
1660                }
1661                // FETCH is an awkward command for our current architecture. In Postgres it
1662                // will extract <count> rows from the target portal, cache them, and return
1663                // them to the user as requested. Its command tag is always FETCH <num rows
1664                // extracted>. In Materialize, since we have chosen to not fully support FETCH,
1665                // we must remember the number of rows that were returned. Use this tag to
1666                // remember that information and return it.
1667                PortalState::Completed(Some(tag)) => {
                        // Copy the tag out so the portal borrow is not needed while sending.
1668                    let tag = tag.to_string();
1669                    if let Some(outer_ctx_extra) = outer_ctx_extra {
1670                        self.adapter_client.retire_execute(
1671                            outer_ctx_extra,
1672                            StatementEndedExecutionReason::Success {
1673                                result_size: None,
1674                                rows_returned: None,
1675                                execution_strategy: None,
1676                            },
1677                        );
1678                    }
1679                    self.send(BackendMessage::CommandComplete { tag }).await?;
1680                    Ok(State::Ready)
1681                }
                    // A completed portal with no remembered tag cannot be executed again; the
                    // same message is logged for the outer statement and sent to the client.
1682                PortalState::Completed(None) => {
1683                    let error = format!(
1684                        "portal {} cannot be run",
1685                        Ident::new_unchecked(portal_name).to_ast_string_stable()
1686                    );
1687                    if let Some(outer_ctx_extra) = outer_ctx_extra {
1688                        self.adapter_client.retire_execute(
1689                            outer_ctx_extra,
1690                            StatementEndedExecutionReason::Errored {
1691                                error: error.clone(),
1692                            },
1693                        );
1694                    }
1695                    self.send_error_and_get_state(ErrorResponse::error(
1696                        SqlState::OBJECT_NOT_IN_PREREQUISITE_STATE,
1697                        error,
1698                    ))
1699                    .await
1700                }
1701            }
1702        }
            // Attach a debug span, then box the future to match the `BoxFuture` return type.
1703        .instrument(debug_span!("execute"))
1704        .boxed()
1705    }
1706
1707    #[instrument(level = "debug")]
1708    async fn describe_statement(&mut self, name: &str) -> Result<State, io::Error> {
1709        // Start a transaction if we aren't in one.
1710        self.ensure_transaction(1, "describe_statement").await?;
1711
1712        let stmt = match self.adapter_client.get_prepared_statement(name).await {
1713            Ok(stmt) => stmt,
1714            Err(err) => {
1715                return self
1716                    .send_error_and_get_state(err.into_response(Severity::Error))
1717                    .await;
1718            }
1719        };
1720        // Cloning to avoid a mutable borrow issue because `send` also uses `adapter_client`
1721        let parameter_desc = BackendMessage::ParameterDescription(
1722            stmt.desc()
1723                .param_types
1724                .iter()
1725                .map(mz_pgrepr::Type::from)
1726                .collect(),
1727        );
1728        // Claim that all results will be output in text format, even
1729        // though the true result formats are not yet known. A bit
1730        // weird, but this is the behavior that PostgreSQL specifies.
1731        let formats = vec![Format::Text; stmt.desc().arity()];
1732        let row_desc = describe_rows(stmt.desc(), &formats);
1733        self.send_all([parameter_desc, row_desc]).await?;
1734        Ok(State::Ready)
1735    }
1736
1737    #[instrument(level = "debug")]
1738    async fn describe_portal(&mut self, name: &str) -> Result<State, io::Error> {
1739        // Start a transaction if we aren't in one.
1740        self.ensure_transaction(1, "describe_portal").await?;
1741
1742        let session = self.adapter_client.session();
1743        let row_desc = session
1744            .get_portal_unverified(name)
1745            .map(|portal| describe_rows(&portal.desc, &portal.result_formats));
1746        match row_desc {
1747            Some(row_desc) => {
1748                self.send(row_desc).await?;
1749                Ok(State::Ready)
1750            }
1751            None => {
1752                self.send_error_and_get_state(ErrorResponse::error(
1753                    SqlState::INVALID_CURSOR_NAME,
1754                    format!("portal {} does not exist", name.quoted()),
1755                ))
1756                .await
1757            }
1758        }
1759    }
1760
1761    #[instrument(level = "debug")]
1762    async fn close_statement(&mut self, name: String) -> Result<State, io::Error> {
1763        self.adapter_client
1764            .session()
1765            .remove_prepared_statement(&name);
1766        self.send(BackendMessage::CloseComplete).await?;
1767        Ok(State::Ready)
1768    }
1769
1770    #[instrument(level = "debug")]
1771    async fn close_portal(&mut self, name: String) -> Result<State, io::Error> {
1772        self.adapter_client.session().remove_portal(&name);
1773        self.send(BackendMessage::CloseComplete).await?;
1774        Ok(State::Ready)
1775    }
1776
1777    fn complete_portal(&mut self, name: &str) {
1778        let portal = self
1779            .adapter_client
1780            .session()
1781            .get_portal_unverified_mut(name)
1782            .expect("portal should exist");
1783        *portal.state = PortalState::Completed(None);
1784    }
1785
1786    async fn fetch(
1787        &mut self,
1788        name: String,
1789        count: Option<FetchDirection>,
1790        max_rows: ExecuteCount,
1791        fetch_portal_name: Option<String>,
1792        timeout: ExecuteTimeout,
1793        ctx_extra: ExecuteContextGuard,
1794    ) -> Result<State, io::Error> {
1795        // Unlike Execute, no count specified in FETCH returns 1 row, and 0 means 0
1796        // instead of All.
1797        let count = count.unwrap_or(FetchDirection::ForwardCount(1));
1798
1799        // Figure out how many rows we should send back by looking at the various
1800        // combinations of the execute and fetch.
1801        //
1802        // In Postgres, Fetch will cache <count> rows from the target portal and
1803        // return those as requested (if, say, an Execute message was sent with a
1804        // max_rows < the Fetch's count). We expect that case to be incredibly rare and
1805        // so have chosen to not support it until users request it. This eases
1806        // implementation difficulty since we don't have to be able to "send" rows to
1807        // a buffer.
1808        //
1809        // TODO(mjibson): Test this somehow? Need to divide up the pgtest files in
1810        // order to have some that are not Postgres compatible.
1811        let count = match (max_rows, count) {
1812            (ExecuteCount::Count(max_rows), FetchDirection::ForwardCount(count)) => {
1813                let count = usize::cast_from(count);
1814                if max_rows < count {
1815                    let msg = "Execute with max_rows < a FETCH's count is not supported";
1816                    self.adapter_client.retire_execute(
1817                        ctx_extra,
1818                        StatementEndedExecutionReason::Errored {
1819                            error: msg.to_string(),
1820                        },
1821                    );
1822                    return self
1823                        .send_error_and_get_state(ErrorResponse::error(
1824                            SqlState::FEATURE_NOT_SUPPORTED,
1825                            msg,
1826                        ))
1827                        .await;
1828                }
1829                ExecuteCount::Count(count)
1830            }
1831            (ExecuteCount::Count(_), FetchDirection::ForwardAll) => {
1832                let msg = "Execute with max_rows of a FETCH ALL is not supported";
1833                self.adapter_client.retire_execute(
1834                    ctx_extra,
1835                    StatementEndedExecutionReason::Errored {
1836                        error: msg.to_string(),
1837                    },
1838                );
1839                return self
1840                    .send_error_and_get_state(ErrorResponse::error(
1841                        SqlState::FEATURE_NOT_SUPPORTED,
1842                        msg,
1843                    ))
1844                    .await;
1845            }
1846            (ExecuteCount::All, FetchDirection::ForwardAll) => ExecuteCount::All,
1847            (ExecuteCount::All, FetchDirection::ForwardCount(count)) => {
1848                ExecuteCount::Count(usize::cast_from(count))
1849            }
1850        };
1851        let cursor_name = name.to_string();
1852        self.execute(
1853            cursor_name,
1854            count,
1855            fetch_message,
1856            fetch_portal_name,
1857            timeout,
1858            Some(ctx_extra),
1859            None,
1860        )
1861        .await
1862    }
1863
1864    async fn flush(&mut self) -> Result<State, io::Error> {
1865        self.conn.flush().await?;
1866        Ok(State::Ready)
1867    }
1868
1869    /// Sends a backend message to the client, after applying a severity filter.
1870    ///
1871    /// The message is only sent if its severity is above the severity set
1872    /// in the session, with the default value being NOTICE.
1873    #[instrument(level = "debug")]
1874    async fn send<M>(&mut self, message: M) -> Result<(), io::Error>
1875    where
1876        M: Into<BackendMessage>,
1877    {
1878        let message: BackendMessage = message.into();
1879        let is_error =
1880            matches!(&message, BackendMessage::ErrorResponse(e) if e.severity.is_error());
1881
1882        self.conn.send(message).await?;
1883
1884        // Flush immediately after sending an error response, as some clients
1885        // expect to be able to read the error response before sending a Sync
1886        // message. This is arguably in violation of the protocol specification,
1887        // but the specification is somewhat ambiguous, and easier to match
1888        // PostgreSQL here than to fix all the clients that have this
1889        // expectation.
1890        if is_error {
1891            self.conn.flush().await?;
1892        }
1893
1894        Ok(())
1895    }
1896
1897    #[instrument(level = "debug")]
1898    pub async fn send_all(
1899        &mut self,
1900        messages: impl IntoIterator<Item = BackendMessage>,
1901    ) -> Result<(), io::Error> {
1902        for m in messages {
1903            self.send(m).await?;
1904        }
1905        Ok(())
1906    }
1907
1908    #[instrument(level = "debug")]
1909    async fn sync(&mut self) -> Result<State, io::Error> {
1910        // Close the current transaction if we are in an implicit transaction.
1911        if self.adapter_client.session().transaction().is_implicit() {
1912            self.commit_transaction().await?;
1913        }
1914        self.ready().await
1915    }
1916
1917    #[instrument(level = "debug")]
1918    async fn ready(&mut self) -> Result<State, io::Error> {
1919        let txn_state = self.adapter_client.session().transaction().into();
1920        self.send(BackendMessage::ReadyForQuery(txn_state)).await?;
1921        self.flush().await
1922    }
1923
1924    #[allow(clippy::too_many_arguments)]
1925    #[instrument(level = "debug")]
1926    async fn send_execute_response(
1927        &mut self,
1928        response: ExecuteResponse,
1929        row_desc: Option<RelationDesc>,
1930        portal_name: String,
1931        max_rows: ExecuteCount,
1932        get_response: GetResponse,
1933        fetch_portal_name: Option<String>,
1934        timeout: ExecuteTimeout,
1935        execute_started: Instant,
1936    ) -> Result<State, io::Error> {
1937        let mut tag = response.tag();
1938
1939        macro_rules! command_complete {
1940            () => {{
1941                self.send(BackendMessage::CommandComplete {
1942                    tag: tag
1943                        .take()
1944                        .expect("command_complete only called on tag-generating results"),
1945                })
1946                .await?;
1947                Ok(State::Ready)
1948            }};
1949        }
1950
1951        let r = match response {
1952            ExecuteResponse::ClosedCursor => {
1953                self.complete_portal(&portal_name);
1954                command_complete!()
1955            }
1956            ExecuteResponse::DeclaredCursor => {
1957                self.complete_portal(&portal_name);
1958                command_complete!()
1959            }
1960            ExecuteResponse::EmptyQuery => {
1961                self.send(BackendMessage::EmptyQueryResponse).await?;
1962                Ok(State::Ready)
1963            }
1964            ExecuteResponse::Fetch {
1965                name,
1966                count,
1967                timeout,
1968                ctx_extra,
1969            } => {
1970                self.fetch(
1971                    name,
1972                    count,
1973                    max_rows,
1974                    Some(portal_name.to_string()),
1975                    timeout,
1976                    ctx_extra,
1977                )
1978                .await
1979            }
1980            ExecuteResponse::SendingRowsStreaming {
1981                rows,
1982                instance_id,
1983                strategy,
1984            } => {
1985                let row_desc = row_desc
1986                    .expect("missing row description for ExecuteResponse::SendingRowsStreaming");
1987
1988                let span = tracing::debug_span!("sending_rows_streaming");
1989
1990                self.send_rows(
1991                    row_desc,
1992                    portal_name,
1993                    InProgressRows::new(RecordFirstRowStream::new(
1994                        Box::new(rows),
1995                        execute_started,
1996                        &self.adapter_client,
1997                        Some(instance_id),
1998                        Some(strategy),
1999                    )),
2000                    max_rows,
2001                    get_response,
2002                    fetch_portal_name,
2003                    timeout,
2004                )
2005                .instrument(span)
2006                .await
2007                .map(|(state, _)| state)
2008            }
2009            ExecuteResponse::SendingRowsImmediate { rows } => {
2010                let row_desc = row_desc
2011                    .expect("missing row description for ExecuteResponse::SendingRowsImmediate");
2012
2013                let span = tracing::debug_span!("sending_rows_immediate");
2014
2015                let stream =
2016                    futures::stream::once(futures::future::ready(PeekResponseUnary::Rows(rows)));
2017                self.send_rows(
2018                    row_desc,
2019                    portal_name,
2020                    InProgressRows::new(RecordFirstRowStream::new(
2021                        Box::new(stream),
2022                        execute_started,
2023                        &self.adapter_client,
2024                        None,
2025                        Some(StatementExecutionStrategy::Constant),
2026                    )),
2027                    max_rows,
2028                    get_response,
2029                    fetch_portal_name,
2030                    timeout,
2031                )
2032                .instrument(span)
2033                .await
2034                .map(|(state, _)| state)
2035            }
2036            ExecuteResponse::SetVariable { name, .. } => {
2037                // This code is somewhat awkwardly structured because we
2038                // can't hold `var` across an await point.
2039                let qn = name.to_string();
2040                let msg = if let Some(var) = self
2041                    .adapter_client
2042                    .session()
2043                    .vars_mut()
2044                    .notify_set()
2045                    .find(|v| v.name() == qn)
2046                {
2047                    Some(BackendMessage::ParameterStatus(var.name(), var.value()))
2048                } else {
2049                    None
2050                };
2051                if let Some(msg) = msg {
2052                    self.send(msg).await?;
2053                }
2054                command_complete!()
2055            }
2056            ExecuteResponse::Subscribing {
2057                rx,
2058                ctx_extra,
2059                instance_id,
2060            } => {
2061                if fetch_portal_name.is_none() {
2062                    let mut msg = ErrorResponse::notice(
2063                        SqlState::WARNING,
2064                        "streaming SUBSCRIBE rows directly requires a client that does not buffer output",
2065                    );
2066                    if self.adapter_client.session().vars().application_name() == "psql" {
2067                        msg.hint = Some(
2068                            "Wrap your SUBSCRIBE statement in `COPY (SUBSCRIBE ...) TO STDOUT`."
2069                                .into(),
2070                        )
2071                    }
2072                    self.send(msg).await?;
2073                    self.conn.flush().await?;
2074                }
2075                let row_desc =
2076                    row_desc.expect("missing row description for ExecuteResponse::Subscribing");
2077                let (result, statement_ended_execution_reason) = match self
2078                    .send_rows(
2079                        row_desc,
2080                        portal_name,
2081                        InProgressRows::new(RecordFirstRowStream::new(
2082                            Box::new(UnboundedReceiverStream::new(rx)),
2083                            execute_started,
2084                            &self.adapter_client,
2085                            Some(instance_id),
2086                            None,
2087                        )),
2088                        max_rows,
2089                        get_response,
2090                        fetch_portal_name,
2091                        timeout,
2092                    )
2093                    .await
2094                {
2095                    Err(e) => {
2096                        // This is an error communicating with the connection.
2097                        // We consider that to be a cancelation, rather than a query error.
2098                        (Err(e), StatementEndedExecutionReason::Canceled)
2099                    }
2100                    Ok((ok, SendRowsEndedReason::Canceled)) => {
2101                        (Ok(ok), StatementEndedExecutionReason::Canceled)
2102                    }
2103                    Ok((
2104                        ok,
2105                        SendRowsEndedReason::Success {
2106                            result_size,
2107                            rows_returned,
2108                        },
2109                    )) => (
2110                        Ok(ok),
2111                        StatementEndedExecutionReason::Success {
2112                            result_size: Some(result_size),
2113                            rows_returned: Some(rows_returned),
2114                            execution_strategy: None,
2115                        },
2116                    ),
2117                    Ok((ok, SendRowsEndedReason::Errored { error })) => {
2118                        (Ok(ok), StatementEndedExecutionReason::Errored { error })
2119                    }
2120                };
2121                self.adapter_client
2122                    .retire_execute(ctx_extra, statement_ended_execution_reason);
2123                return result;
2124            }
2125            ExecuteResponse::CopyTo { format, resp } => {
2126                let row_desc =
2127                    row_desc.expect("missing row description for ExecuteResponse::CopyTo");
2128                match *resp {
2129                    ExecuteResponse::Subscribing {
2130                        rx,
2131                        ctx_extra,
2132                        instance_id,
2133                    } => {
2134                        let (result, statement_ended_execution_reason) = match self
2135                            .copy_rows(
2136                                format,
2137                                row_desc,
2138                                RecordFirstRowStream::new(
2139                                    Box::new(UnboundedReceiverStream::new(rx)),
2140                                    execute_started,
2141                                    &self.adapter_client,
2142                                    Some(instance_id),
2143                                    None,
2144                                ),
2145                            )
2146                            .await
2147                        {
2148                            Err(e) => {
2149                                // This is an error communicating with the connection.
2150                                // We consider that to be a cancelation, rather than a query error.
2151                                (Err(e), StatementEndedExecutionReason::Canceled)
2152                            }
2153                            Ok((
2154                                state,
2155                                SendRowsEndedReason::Success {
2156                                    result_size,
2157                                    rows_returned,
2158                                },
2159                            )) => (
2160                                Ok(state),
2161                                StatementEndedExecutionReason::Success {
2162                                    result_size: Some(result_size),
2163                                    rows_returned: Some(rows_returned),
2164                                    execution_strategy: None,
2165                                },
2166                            ),
2167                            Ok((state, SendRowsEndedReason::Errored { error })) => {
2168                                (Ok(state), StatementEndedExecutionReason::Errored { error })
2169                            }
2170                            Ok((state, SendRowsEndedReason::Canceled)) => {
2171                                (Ok(state), StatementEndedExecutionReason::Canceled)
2172                            }
2173                        };
2174                        self.adapter_client
2175                            .retire_execute(ctx_extra, statement_ended_execution_reason);
2176                        return result;
2177                    }
2178                    ExecuteResponse::SendingRowsStreaming {
2179                        rows,
2180                        instance_id,
2181                        strategy,
2182                    } => {
2183                        // We don't need to finalize execution here;
2184                        // it was already done in the
2185                        // coordinator. Just extract the state and
2186                        // return that.
2187                        return self
2188                            .copy_rows(
2189                                format,
2190                                row_desc,
2191                                RecordFirstRowStream::new(
2192                                    Box::new(rows),
2193                                    execute_started,
2194                                    &self.adapter_client,
2195                                    Some(instance_id),
2196                                    Some(strategy),
2197                                ),
2198                            )
2199                            .await
2200                            .map(|(state, _)| state);
2201                    }
2202                    ExecuteResponse::SendingRowsImmediate { rows } => {
2203                        let span = tracing::debug_span!("sending_rows_immediate");
2204
2205                        let rows = futures::stream::once(futures::future::ready(
2206                            PeekResponseUnary::Rows(rows),
2207                        ));
2208                        // We don't need to finalize execution here;
2209                        // it was already done in the
2210                        // coordinator. Just extract the state and
2211                        // return that.
2212                        return self
2213                            .copy_rows(
2214                                format,
2215                                row_desc,
2216                                RecordFirstRowStream::new(
2217                                    Box::new(rows),
2218                                    execute_started,
2219                                    &self.adapter_client,
2220                                    None,
2221                                    Some(StatementExecutionStrategy::Constant),
2222                                ),
2223                            )
2224                            .instrument(span)
2225                            .await
2226                            .map(|(state, _)| state);
2227                    }
2228                    _ => {
2229                        return self
2230                            .send_error_and_get_state(ErrorResponse::error(
2231                                SqlState::INTERNAL_ERROR,
2232                                "unsupported COPY response type".to_string(),
2233                            ))
2234                            .await;
2235                    }
2236                };
2237            }
2238            ExecuteResponse::CopyFrom {
2239                target_id,
2240                target_name,
2241                columns,
2242                params,
2243                ctx_extra,
2244            } => {
2245                let row_desc =
2246                    row_desc.expect("missing row description for ExecuteResponse::CopyFrom");
2247                self.copy_from(target_id, target_name, columns, params, row_desc, ctx_extra)
2248                    .await
2249            }
2250            ExecuteResponse::TransactionCommitted { params }
2251            | ExecuteResponse::TransactionRolledBack { params } => {
2252                let notify_set: mz_ore::collections::HashSet<String> = self
2253                    .adapter_client
2254                    .session()
2255                    .vars()
2256                    .notify_set()
2257                    .map(|v| v.name().to_string())
2258                    .collect();
2259
2260                // Only report on parameters that are in the notify set.
2261                for (name, value) in params
2262                    .into_iter()
2263                    .filter(|(name, _v)| notify_set.contains(*name))
2264                {
2265                    let msg = BackendMessage::ParameterStatus(name, value);
2266                    self.send(msg).await?;
2267                }
2268                command_complete!()
2269            }
2270
2271            ExecuteResponse::AlteredDefaultPrivileges
2272            | ExecuteResponse::AlteredObject(..)
2273            | ExecuteResponse::AlteredRole
2274            | ExecuteResponse::AlteredSystemConfiguration
2275            | ExecuteResponse::CreatedCluster { .. }
2276            | ExecuteResponse::CreatedClusterReplica { .. }
2277            | ExecuteResponse::CreatedConnection { .. }
2278            | ExecuteResponse::CreatedDatabase { .. }
2279            | ExecuteResponse::CreatedIndex { .. }
2280            | ExecuteResponse::CreatedIntrospectionSubscribe
2281            | ExecuteResponse::CreatedMaterializedView { .. }
2282            | ExecuteResponse::CreatedRole
2283            | ExecuteResponse::CreatedSchema { .. }
2284            | ExecuteResponse::CreatedSecret { .. }
2285            | ExecuteResponse::CreatedSink { .. }
2286            | ExecuteResponse::CreatedSource { .. }
2287            | ExecuteResponse::CreatedTable { .. }
2288            | ExecuteResponse::CreatedType
2289            | ExecuteResponse::CreatedView { .. }
2290            | ExecuteResponse::CreatedViews { .. }
2291            | ExecuteResponse::CreatedNetworkPolicy
2292            | ExecuteResponse::Comment
2293            | ExecuteResponse::Deallocate { .. }
2294            | ExecuteResponse::Deleted(..)
2295            | ExecuteResponse::DiscardedAll
2296            | ExecuteResponse::DiscardedTemp
2297            | ExecuteResponse::DroppedObject(_)
2298            | ExecuteResponse::DroppedOwned
2299            | ExecuteResponse::GrantedPrivilege
2300            | ExecuteResponse::GrantedRole
2301            | ExecuteResponse::Inserted(..)
2302            | ExecuteResponse::Copied(..)
2303            | ExecuteResponse::Prepare
2304            | ExecuteResponse::Raised
2305            | ExecuteResponse::ReassignOwned
2306            | ExecuteResponse::RevokedPrivilege
2307            | ExecuteResponse::RevokedRole
2308            | ExecuteResponse::StartedTransaction { .. }
2309            | ExecuteResponse::Updated(..)
2310            | ExecuteResponse::ValidatedConnection => {
2311                command_complete!()
2312            }
2313        };
2314
2315        assert_none!(tag, "tag created but not consumed: {:?}", tag);
2316        r
2317    }
2318
    /// Sends rows from `rows` to the client on behalf of the portal named
    /// `portal_name`, then sends the completion message built by `get_response`.
    ///
    /// * `row_desc` describes the rows; each batch is checked against it with
    ///   `verify_datum_desc` before being sent.
    /// * `max_rows` bounds how many rows are sent by this call. Any unsent
    ///   remainder of the current batch is kept in `rows.current`.
    /// * `fetch_portal_name`, when set, names the outer portal whose result
    ///   formats are used instead of those of `portal_name`; that portal is also
    ///   handed to `get_response`.
    /// * `timeout` controls how long to wait for further batches once the
    ///   buffered rows are exhausted.
    ///
    /// On the normal path, `rows` is always stored back into the portal as
    /// `PortalState::InProgress`, and the call returns `State::Ready` with
    /// `SendRowsEndedReason::Success` carrying the number of rows and bytes sent.
    /// If the stream reports an error or a cancellation, or a batch fails
    /// verification, an error response is sent instead and the state returned by
    /// `send_error_and_get_state` is paired with `Errored` or `Canceled`.
    ///
    /// # Errors
    ///
    /// Returns an `io::Error` if sending or flushing fails, or if the connection
    /// is reported closed while waiting for more rows.
    ///
    /// # Panics
    ///
    /// Panics if `portal_name` or `fetch_portal_name` does not name an existing
    /// portal.
    #[allow(clippy::too_many_arguments)]
    // TODO(guswynn): figure out how to get it to compile without skip_all
    #[mz_ore::instrument(level = "debug")]
    async fn send_rows(
        &mut self,
        row_desc: RelationDesc,
        portal_name: String,
        mut rows: InProgressRows,
        max_rows: ExecuteCount,
        get_response: GetResponse,
        fetch_portal_name: Option<String>,
        timeout: ExecuteTimeout,
    ) -> Result<(State, SendRowsEndedReason), io::Error> {
        // If this portal is being executed from a FETCH then we need to use the result
        // format type of the outer portal.
        let result_format_portal_name: &str = if let Some(ref name) = fetch_portal_name {
            name
        } else {
            &portal_name
        };
        let result_formats = self
            .adapter_client
            .session()
            .get_portal_unverified(result_format_portal_name)
            .expect("valid fetch portal name for send rows")
            .result_formats
            .clone();

        // Translate the requested timeout into a `wait_once` flag plus an optional
        // absolute deadline. Both are mutable because `WaitOnce` turns into an
        // already-expired deadline after the first non-empty batch (see below).
        let (mut wait_once, mut deadline) = match timeout {
            ExecuteTimeout::None => (false, None),
            ExecuteTimeout::Seconds(t) => (
                false,
                Some(tokio::time::Instant::now() + tokio::time::Duration::from_secs_f64(t)),
            ),
            ExecuteTimeout::WaitOnce => (true, None),
        };

        // Sanity check that the various `RelationDesc`s match up.
        {
            let portal_name_desc = &self
                .adapter_client
                .session()
                .get_portal_unverified(portal_name.as_str())
                .expect("portal should exist")
                .desc
                .relation_desc;
            if let Some(portal_name_desc) = portal_name_desc {
                soft_assert_eq_or_log!(portal_name_desc, &row_desc);
            }
            if let Some(fetch_portal_name) = &fetch_portal_name {
                let fetch_portal_desc = &self
                    .adapter_client
                    .session()
                    .get_portal_unverified(fetch_portal_name)
                    .expect("portal should exist")
                    .desc
                    .relation_desc;
                if let Some(fetch_portal_desc) = fetch_portal_desc {
                    soft_assert_eq_or_log!(fetch_portal_desc, &row_desc);
                }
            }
        }

        // Pair each column's pgwire type with its requested result format so the
        // connection knows how to encode the `DataRow` messages sent below.
        // `zip_eq` panics if the number of columns and formats differ.
        self.conn.set_encode_state(
            row_desc
                .typ()
                .column_types
                .iter()
                .map(|ty| mz_pgrepr::Type::from(&ty.scalar_type))
                .zip_eq(result_formats)
                .collect(),
        );

        let mut total_sent_rows = 0;
        let mut total_sent_bytes = 0;
        // want_rows is the maximum number of rows the client wants.
        let mut want_rows = match max_rows {
            ExecuteCount::All => usize::MAX,
            ExecuteCount::Count(count) => count,
        };

        // Send rows while the client still wants them and there are still rows to send.
        loop {
            // Fetch next batch of rows, waiting for a possible requested
            // timeout or notice.
            //
            // A batch left partially consumed by an earlier call is served first.
            // If no more rows are wanted, stop without polling the stream at all.
            let batch = if rows.current.is_some() {
                FetchResult::Rows(rows.current.take())
            } else if want_rows == 0 {
                FetchResult::Rows(None)
            } else {
                let notice_fut = self.adapter_client.session().recv_notice();
                // Biased: drain available data before checking the deadline.
                // This is critical for the WaitOnce case, where the deadline
                // is set to `Instant::now()` right after the first batch:
                // without `biased`, `recv()` and the already-expired deadline
                // race nondeterministically, so we might break the loop
                // before `no_more_rows` is set (or even before ready rows
                // are consumed). With an explicit `TIMEOUT`, missing a batch
                // right at the boundary is acceptable, but WaitOnce fires
                // immediately and the race is not.
                //
                // Trade-off: if `recv()` keeps returning Ready (unlikely in
                // practice—row processing + flush is slower than upstream
                // tick granularity), a `TIMEOUT` deadline could be delayed.
                // See database-issues#9470.
                tokio::select! {
                    biased;
                    err = self.conn.wait_closed() => return Err(err),
                    batch = rows.remaining.recv() => match batch {
                        None => FetchResult::Rows(None),
                        Some(PeekResponseUnary::Rows(rows)) => FetchResult::Rows(Some(rows)),
                        Some(PeekResponseUnary::Error(err)) => FetchResult::Error(err),
                        Some(PeekResponseUnary::DependencyDropped(dep)) => {
                            FetchResult::Error(dep.query_terminated_error())
                        }
                        Some(PeekResponseUnary::Canceled) => FetchResult::Canceled,
                    },
                    notice = notice_fut => {
                        FetchResult::Notice(notice)
                    }
                    // This arm is only enabled when a deadline exists. Hitting the
                    // deadline is reported as "no more rows for now", which ends
                    // the loop.
                    _ = time::sleep_until(
                        deadline.unwrap_or_else(tokio::time::Instant::now),
                    ), if deadline.is_some() => FetchResult::Rows(None),
                }
            };

            match batch {
                FetchResult::Rows(None) => break,
                FetchResult::Rows(Some(mut batch_rows)) => {
                    if let Err(err) = verify_datum_desc(&row_desc, &mut batch_rows) {
                        let msg = err.to_string();
                        return self
                            .send_error_and_get_state(err.into_response(Severity::Error))
                            .await
                            .map(|state| (state, SendRowsEndedReason::Errored { error: msg }));
                    }

                    // If wait_once is true: the first time this fn is called it blocks (same as
                    // deadline == None). The second time this fn is called it should behave the
                    // same as a 0s timeout.
                    if wait_once && batch_rows.peek().is_some() {
                        deadline = Some(tokio::time::Instant::now());
                        wait_once = false;
                    }

                    // Send a portion of the rows.
                    //
                    // `take(want_rows)` is the last adapter in the chain, so the
                    // `inspect` counters only observe rows that are actually sent.
                    let mut sent_rows = 0;
                    let mut sent_bytes = 0;
                    let messages = (&mut batch_rows)
                        // TODO(parkmycar): This is a fair bit of juggling between iterator types
                        // to count the total number of bytes. Alternatively we could track the
                        // total sent bytes in this .map(...) call, but having side effects in map
                        // is a code smell.
                        .map(|row| {
                            let row_len = row.byte_len();
                            let values = mz_pgrepr::values_from_row(row, row_desc.typ());
                            (row_len, BackendMessage::DataRow(values))
                        })
                        .inspect(|(row_len, _)| {
                            sent_bytes += row_len;
                            sent_rows += 1
                        })
                        .map(|(_row_len, row)| row)
                        .take(want_rows);
                    self.send_all(messages).await?;

                    total_sent_rows += sent_rows;
                    total_sent_bytes += sent_bytes;
                    want_rows -= sent_rows;

                    // If we have sent the number of requested rows, put the remainder of the batch
                    // (if any) back and stop sending.
                    if want_rows == 0 {
                        if batch_rows.peek().is_some() {
                            rows.current = Some(batch_rows);
                        }
                        break;
                    }

                    self.conn.flush().await?;
                }
                FetchResult::Notice(notice) => {
                    // Forward the notice right away and keep waiting for rows.
                    self.send(notice.into_response()).await?;
                    self.conn.flush().await?;
                }
                FetchResult::Error(text) => {
                    return self
                        .send_error_and_get_state(ErrorResponse::error(
                            SqlState::INTERNAL_ERROR,
                            text.clone(),
                        ))
                        .await
                        .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
                }
                FetchResult::Canceled => {
                    return self
                        .send_error_and_get_state(ErrorResponse::error(
                            SqlState::QUERY_CANCELED,
                            "canceling statement due to user request",
                        ))
                        .await
                        .map(|state| (state, SendRowsEndedReason::Canceled));
                }
            }
        }

        let portal = self
            .adapter_client
            .session()
            .get_portal_unverified_mut(&portal_name)
            .expect("valid portal name for send rows");

        // Snapshot everything needed from `rows` now: it is moved into the portal
        // state just below.
        let saw_rows = rows.remaining.saw_rows;
        let no_more_rows = rows.no_more_rows();
        let metric_recorded = rows.remaining.metric_recorded;
        let recorded_first_row_instant = rows.remaining.recorded_first_row_instant;

        // Mark the metric as recorded on the stream itself so a later call on the
        // same, already exhausted, stream does not record it again. The actual
        // observation happens further down using the snapshot taken above.
        if no_more_rows && !metric_recorded {
            rows.remaining.metric_recorded = true;
        }

        // Always return rows back, even if it's empty. This prevents an unclosed
        // portal from re-executing after it has been emptied.
        *portal.state = PortalState::InProgress(Some(rows));

        let fetch_portal = fetch_portal_name.map(|name| {
            self.adapter_client
                .session()
                .get_portal_unverified_mut(&name)
                .expect("valid fetch portal")
        });
        let response_message = get_response(max_rows, total_sent_rows, fetch_portal);
        self.send(response_message).await?;

        // Attend to metrics if there are no more rows. Only record once per stream
        // to avoid polluting the histogram when an exhausted cursor is FETCHed again.
        if no_more_rows && !metric_recorded {
            let statement_type = if let Some(stmt) = &self
                .adapter_client
                .session()
                .get_portal_unverified(&portal_name)
                .expect("valid portal name for send_rows")
                .stmt
            {
                metrics::statement_type_label_value(stmt.deref())
            } else {
                "no-statement"
            };
            let duration = if saw_rows {
                recorded_first_row_instant
                    .expect("recorded_first_row_instant because saw_rows")
                    .elapsed()
            } else {
                // If the result is empty, then we define time from first to last row as 0.
                // (Note that, currently, an empty result involves a PeekResponse with 0 rows, which
                // does flip `saw_rows`, so this code path is currently not exercised.)
                Duration::ZERO
            };
            self.adapter_client
                .inner()
                .metrics()
                .result_rows_first_to_last_byte_seconds
                .with_label_values(&[statement_type])
                .observe(duration.as_secs_f64());
        }

        Ok((
            State::Ready,
            SendRowsEndedReason::Success {
                result_size: u64::cast_from(total_sent_bytes),
                rows_returned: u64::cast_from(total_sent_rows),
            },
        ))
    }
2593
2594    #[mz_ore::instrument(level = "debug")]
2595    async fn copy_rows(
2596        &mut self,
2597        format: CopyFormat,
2598        row_desc: RelationDesc,
2599        mut stream: RecordFirstRowStream,
2600    ) -> Result<(State, SendRowsEndedReason), io::Error> {
2601        let (row_format, encode_format) = match format {
2602            CopyFormat::Text => (
2603                CopyFormatParams::Text(CopyTextFormatParams::default()),
2604                Format::Text,
2605            ),
2606            CopyFormat::Binary => (CopyFormatParams::Binary, Format::Binary),
2607            CopyFormat::Csv => (
2608                CopyFormatParams::Csv(CopyCsvFormatParams::default()),
2609                Format::Text,
2610            ),
2611            CopyFormat::Parquet => {
2612                let text = "Parquet format is not supported".to_string();
2613                return self
2614                    .send_error_and_get_state(ErrorResponse::error(
2615                        SqlState::INTERNAL_ERROR,
2616                        text.clone(),
2617                    ))
2618                    .await
2619                    .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
2620            }
2621        };
2622
2623        let encode_fn = |row: &RowRef, typ: &SqlRelationType, out: &mut Vec<u8>| {
2624            mz_pgcopy::encode_copy_format(&row_format, row, typ, out)
2625        };
2626
2627        let typ = row_desc.typ();
2628        let column_formats = iter::repeat(encode_format)
2629            .take(typ.column_types.len())
2630            .collect();
2631        self.send(BackendMessage::CopyOutResponse {
2632            overall_format: encode_format,
2633            column_formats,
2634        })
2635        .await?;
2636
2637        // In Postgres, binary copy has a header that is followed (in the same
2638        // CopyData) by the first row. In order to replicate their behavior, use a
2639        // common vec that we can extend one time now and then fill up with the encode
2640        // functions.
2641        let mut out = Vec::new();
2642
2643        if let CopyFormat::Binary = format {
2644            // 11-byte signature.
2645            out.extend(b"PGCOPY\n\xFF\r\n\0");
2646            // 32-bit flags field.
2647            out.extend([0, 0, 0, 0]);
2648            // 32-bit header extension length field.
2649            out.extend([0, 0, 0, 0]);
2650        }
2651
2652        let mut count = 0;
2653        let mut total_sent_bytes = 0;
2654        loop {
2655            tokio::select! {
2656                e = self.conn.wait_closed() => return Err(e),
2657                batch = stream.recv() => match batch {
2658                    None => break,
2659                    Some(PeekResponseUnary::Error(text)) => {
2660                        let err =
2661                            ErrorResponse::error(SqlState::INTERNAL_ERROR, text.clone());
2662                        return self
2663                            .send_error_and_get_state(err)
2664                            .await
2665                            .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
2666                    }
2667                    Some(PeekResponseUnary::DependencyDropped(dep)) => {
2668                        let err = dep.to_concurrent_dependency_drop();
2669                        let text = err.to_string();
2670                        let resp = err.into_response(Severity::Error);
2671                        return self
2672                            .send_error_and_get_state(resp)
2673                            .await
2674                            .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
2675                    }
2676                    Some(PeekResponseUnary::Canceled) => {
2677                        return self.send_error_and_get_state(ErrorResponse::error(
2678                                SqlState::QUERY_CANCELED,
2679                                "canceling statement due to user request",
2680                            ))
2681                            .await.map(|state| (state, SendRowsEndedReason::Canceled));
2682                    }
2683                    Some(PeekResponseUnary::Rows(mut rows)) => {
2684                        count += rows.count();
2685                        while let Some(row) = rows.next() {
2686                            total_sent_bytes += row.byte_len();
2687                            encode_fn(row, typ, &mut out)?;
2688                            self.send(BackendMessage::CopyData(mem::take(&mut out)))
2689                                .await?;
2690                        }
2691                    }
2692                },
2693                notice = self.adapter_client.session().recv_notice() => {
2694                    self.send(notice.into_response())
2695                        .await?;
2696                    self.conn.flush().await?;
2697                }
2698            }
2699
2700            self.conn.flush().await?;
2701        }
2702        // Send required trailers.
2703        if let CopyFormat::Binary = format {
2704            let trailer: i16 = -1;
2705            out.extend(trailer.to_be_bytes());
2706            self.send(BackendMessage::CopyData(mem::take(&mut out)))
2707                .await?;
2708        }
2709
2710        let tag = format!("COPY {}", count);
2711        self.send(BackendMessage::CopyDone).await?;
2712        self.send(BackendMessage::CommandComplete { tag }).await?;
2713        Ok((
2714            State::Ready,
2715            SendRowsEndedReason::Success {
2716                result_size: u64::cast_from(total_sent_bytes),
2717                rows_returned: u64::cast_from(count),
2718            },
2719        ))
2720    }
2721
2722    /// Handles the copy-in mode of the postgres protocol from transferring
2723    /// data to the server.
2724    #[instrument(level = "debug")]
2725    async fn copy_from(
2726        &mut self,
2727        target_id: CatalogItemId,
2728        target_name: String,
2729        columns: Vec<ColumnIndex>,
2730        params: CopyFormatParams<'static>,
2731        row_desc: RelationDesc,
2732        mut ctx_extra: ExecuteContextGuard,
2733    ) -> Result<State, io::Error> {
2734        let res = self
2735            .copy_from_inner(
2736                target_id,
2737                target_name,
2738                columns,
2739                params,
2740                row_desc,
2741                &mut ctx_extra,
2742            )
2743            .await;
2744        match &res {
2745            Ok(State::Ready) => {
2746                self.adapter_client.retire_execute(
2747                    ctx_extra,
2748                    StatementEndedExecutionReason::Success {
2749                        result_size: None,
2750                        rows_returned: None,
2751                        execution_strategy: None,
2752                    },
2753                );
2754            }
2755            Ok(State::Done) => {
2756                // The connection closed gracefully without sending us a `CopyDone`,
2757                // causing us to just drop the copy request.
2758                // For the purposes of statement logging, we count this as a cancellation.
2759                self.adapter_client
2760                    .retire_execute(ctx_extra, StatementEndedExecutionReason::Canceled);
2761            }
2762            Err(e) => {
2763                self.adapter_client.retire_execute(
2764                    ctx_extra,
2765                    StatementEndedExecutionReason::Errored {
2766                        error: format!("{e}"),
2767                    },
2768                );
2769            }
2770            Ok(State::Drain) => {}
2771        }
2772        res
2773    }
2774
2775    async fn copy_from_inner(
2776        &mut self,
2777        target_id: CatalogItemId,
2778        target_name: String,
2779        columns: Vec<ColumnIndex>,
2780        params: CopyFormatParams<'static>,
2781        row_desc: RelationDesc,
2782        ctx_extra: &mut ExecuteContextGuard,
2783    ) -> Result<State, io::Error> {
2784        let typ = row_desc.typ();
2785        let column_formats = vec![Format::Text; typ.column_types.len()];
2786        self.send(BackendMessage::CopyInResponse {
2787            overall_format: Format::Text,
2788            column_formats,
2789        })
2790        .await?;
2791        self.conn.flush().await?;
2792
2793        // Set up the parallel streaming batch builders in the coordinator.
2794        let writer = match self
2795            .adapter_client
2796            .start_copy_from_stdin(
2797                target_id,
2798                target_name.clone(),
2799                columns.clone(),
2800                row_desc.clone(),
2801                params.clone(),
2802            )
2803            .await
2804        {
2805            Ok(writer) => writer,
2806            Err(e) => {
2807                // Drain remaining CopyData/CopyDone/CopyFail messages from the
2808                // socket. Since CopyInResponse was already sent, the client may
2809                // have pipelined copy data that we must consume before returning
2810                // the error, otherwise they'd be misinterpreted as top-level
2811                // protocol messages and cause a deadlock.
2812                loop {
2813                    match self.conn.recv().await? {
2814                        Some(FrontendMessage::CopyData(_)) => {}
2815                        Some(FrontendMessage::CopyDone) | Some(FrontendMessage::CopyFail(_)) => {
2816                            break;
2817                        }
2818                        Some(FrontendMessage::Flush) | Some(FrontendMessage::Sync) => {}
2819                        Some(_) => break,
2820                        None => return Ok(State::Done),
2821                    }
2822                }
2823                self.adapter_client.retire_execute(
2824                    std::mem::take(ctx_extra),
2825                    StatementEndedExecutionReason::Errored {
2826                        error: e.to_string(),
2827                    },
2828                );
2829                return self
2830                    .send_error_and_get_state(e.into_response(Severity::Error))
2831                    .await;
2832            }
2833        };
2834
2835        // Enable copy mode on the codec to skip aggregate buffer size checks.
2836        self.conn.set_copy_mode(true);
2837
2838        // Batch size for splitting raw data across parallel workers (~32MB).
2839        const BATCH_SIZE: usize = 32 * 1024 * 1024;
2840        let max_copy_from_row_size = self
2841            .adapter_client
2842            .get_system_vars()
2843            .await
2844            .max_copy_from_row_size()
2845            .try_into()
2846            .unwrap_or(usize::MAX);
2847
2848        let mut data = Vec::new();
2849        let mut row_scanner = CopyRowScanner::new(&params);
2850        let num_workers = writer.batch_txs.len();
2851        let mut next_worker: usize = 0;
2852        let mut saw_copy_done = false;
2853        let mut saw_end_marker = false;
2854        let mut copy_from_error: Option<(SqlState, String)> = None;
2855
2856        // Receive loop: accumulate CopyData, split at row boundaries,
2857        // round-robin raw chunks to parallel batch builder workers.
2858        loop {
2859            let message = self.conn.recv().await?;
2860            match message {
2861                Some(FrontendMessage::CopyData(buf)) => {
2862                    if saw_end_marker {
2863                        // Per PostgreSQL COPY behavior, ignore all bytes after
2864                        // the end-of-copy marker until CopyDone.
2865                        continue;
2866                    }
2867                    data.extend(buf);
2868                    row_scanner.scan_new_bytes(&data);
2869
2870                    if let Some(end_pos) = row_scanner.end_marker_end() {
2871                        data.truncate(end_pos);
2872                        row_scanner.on_truncate(end_pos);
2873                        saw_end_marker = true;
2874                    }
2875
2876                    // Guard against pathological single rows that never terminate.
2877                    if row_scanner.current_row_size(data.len()) > max_copy_from_row_size {
2878                        copy_from_error = Some((
2879                            SqlState::INSUFFICIENT_RESOURCES,
2880                            format!(
2881                                "COPY FROM STDIN row exceeded max_copy_from_row_size \
2882                                 ({max_copy_from_row_size} bytes)"
2883                            ),
2884                        ));
2885                        break;
2886                    }
2887
2888                    // When buffer exceeds batch size, split at the last complete row
2889                    // and send the complete rows chunk to the next worker.
2890                    let mut send_failed = false;
2891                    while data.len() >= BATCH_SIZE {
2892                        let split_pos = match row_scanner.last_row_end() {
2893                            Some(pos) => pos,
2894                            None => break, // no complete row yet
2895                        };
2896                        let remainder = data.split_off(split_pos);
2897                        let chunk = std::mem::replace(&mut data, remainder);
2898                        row_scanner.on_split(split_pos);
2899                        if writer.batch_txs[next_worker].send(chunk).await.is_err() {
2900                            send_failed = true;
2901                            break;
2902                        }
2903                        next_worker = (next_worker + 1) % num_workers;
2904                    }
2905                    // Worker dropped (likely errored) — stop sending,
2906                    // fall through to completion_rx for the real error.
2907                    if send_failed {
2908                        break;
2909                    }
2910                }
2911                Some(FrontendMessage::CopyDone) => {
2912                    // Send any remaining data to the next worker.
2913                    if !data.is_empty() {
2914                        let chunk = std::mem::take(&mut data);
2915                        // Ignore send failure — completion_rx will have the error.
2916                        let _ = writer.batch_txs[next_worker].send(chunk).await;
2917                    }
2918                    saw_copy_done = true;
2919                    break;
2920                }
2921                Some(FrontendMessage::CopyFail(err)) => {
2922                    self.adapter_client.retire_execute(
2923                        std::mem::take(ctx_extra),
2924                        StatementEndedExecutionReason::Canceled,
2925                    );
2926                    // Drop the writer to signal cancellation to the background tasks.
2927                    drop(writer);
2928                    self.conn.set_copy_mode(false);
2929                    return self
2930                        .send_error_and_get_state(ErrorResponse::error(
2931                            SqlState::QUERY_CANCELED,
2932                            format!("COPY from stdin failed: {}", err),
2933                        ))
2934                        .await;
2935                }
2936                Some(FrontendMessage::Flush) | Some(FrontendMessage::Sync) => {}
2937                Some(_) => {
2938                    let msg = "unexpected message type during COPY from stdin";
2939                    self.adapter_client.retire_execute(
2940                        std::mem::take(ctx_extra),
2941                        StatementEndedExecutionReason::Errored {
2942                            error: msg.to_string(),
2943                        },
2944                    );
2945                    drop(writer);
2946                    self.conn.set_copy_mode(false);
2947                    return self
2948                        .send_error_and_get_state(ErrorResponse::error(
2949                            SqlState::PROTOCOL_VIOLATION,
2950                            msg,
2951                        ))
2952                        .await;
2953                }
2954                None => {
2955                    drop(writer);
2956                    self.conn.set_copy_mode(false);
2957                    return Ok(State::Done);
2958                }
2959            }
2960        }
2961
2962        // If we exited the receive loop before seeing `CopyDone` (e.g. because
2963        // a worker failed and dropped its channel), keep draining COPY input to
2964        // avoid desynchronizing the protocol state machine.
2965        if !saw_copy_done {
2966            loop {
2967                match self.conn.recv().await? {
2968                    Some(FrontendMessage::CopyData(_)) => {}
2969                    Some(FrontendMessage::CopyDone) | Some(FrontendMessage::CopyFail(_)) => {
2970                        break;
2971                    }
2972                    Some(FrontendMessage::Flush) | Some(FrontendMessage::Sync) => {}
2973                    Some(_) => {
2974                        let msg = "unexpected message type during COPY from stdin";
2975                        self.adapter_client.retire_execute(
2976                            std::mem::take(ctx_extra),
2977                            StatementEndedExecutionReason::Errored {
2978                                error: msg.to_string(),
2979                            },
2980                        );
2981                        drop(writer);
2982                        self.conn.set_copy_mode(false);
2983                        return self
2984                            .send_error_and_get_state(ErrorResponse::error(
2985                                SqlState::PROTOCOL_VIOLATION,
2986                                msg,
2987                            ))
2988                            .await;
2989                    }
2990                    None => {
2991                        drop(writer);
2992                        self.conn.set_copy_mode(false);
2993                        return Ok(State::Done);
2994                    }
2995                }
2996            }
2997        }
2998
2999        if let Some((code, msg)) = copy_from_error {
3000            self.adapter_client.retire_execute(
3001                std::mem::take(ctx_extra),
3002                StatementEndedExecutionReason::Errored { error: msg.clone() },
3003            );
3004            drop(writer);
3005            self.conn.set_copy_mode(false);
3006            return self
3007                .send_error_and_get_state(ErrorResponse::error(code, msg))
3008                .await;
3009        }
3010
3011        self.conn.set_copy_mode(false);
3012
3013        // Drop all senders to signal EOF to the background batch builders.
3014        // If copy_err is set, a worker already failed — dropping the senders
3015        // will cause remaining workers to stop, and we'll get the real error
3016        // from completion_rx below.
3017        drop(writer.batch_txs);
3018
3019        // Wait for all parallel workers to finish building batches.
3020        let (proto_batches, row_count) = match writer.completion_rx.await {
3021            Ok(Ok(result)) => result,
3022            Ok(Err(e)) => {
3023                self.adapter_client.retire_execute(
3024                    std::mem::take(ctx_extra),
3025                    StatementEndedExecutionReason::Errored {
3026                        error: e.to_string(),
3027                    },
3028                );
3029                return self
3030                    .send_error_and_get_state(e.into_response(Severity::Error))
3031                    .await;
3032            }
3033            Err(_) => {
3034                let msg = "COPY FROM STDIN: background batch builder tasks dropped";
3035                self.adapter_client.retire_execute(
3036                    std::mem::take(ctx_extra),
3037                    StatementEndedExecutionReason::Errored {
3038                        error: msg.to_string(),
3039                    },
3040                );
3041                return self
3042                    .send_error_and_get_state(ErrorResponse::error(SqlState::INTERNAL_ERROR, msg))
3043                    .await;
3044            }
3045        };
3046
3047        // Stage all batches in the session's transaction for atomic commit.
3048        if let Err(e) = self
3049            .adapter_client
3050            .stage_copy_from_stdin_batches(target_id, proto_batches)
3051        {
3052            self.adapter_client.retire_execute(
3053                std::mem::take(ctx_extra),
3054                StatementEndedExecutionReason::Errored {
3055                    error: e.to_string(),
3056                },
3057            );
3058            return self
3059                .send_error_and_get_state(e.into_response(Severity::Error))
3060                .await;
3061        }
3062
3063        let tag = format!("COPY {}", row_count);
3064        self.send(BackendMessage::CommandComplete { tag }).await?;
3065
3066        Ok(State::Ready)
3067    }
3068
3069    #[instrument(level = "debug")]
3070    async fn send_pending_notices(&mut self) -> Result<(), io::Error> {
3071        let notices = self
3072            .adapter_client
3073            .session()
3074            .drain_notices()
3075            .into_iter()
3076            .map(|notice| BackendMessage::ErrorResponse(notice.into_response()));
3077        self.send_all(notices).await?;
3078        Ok(())
3079    }
3080
3081    #[instrument(level = "debug")]
3082    async fn send_error_and_get_state(&mut self, err: ErrorResponse) -> Result<State, io::Error> {
3083        assert!(err.severity.is_error());
3084        debug!(
3085            "cid={} error code={}",
3086            self.adapter_client.session().conn_id(),
3087            err.code.code()
3088        );
3089        let is_fatal = err.severity.is_fatal();
3090        self.send(BackendMessage::ErrorResponse(err)).await?;
3091
3092        let txn = self.adapter_client.session().transaction();
3093        match txn {
3094            // Error can be called from describe and parse and so might not be in an active
3095            // transaction.
3096            TransactionStatus::Default | TransactionStatus::Failed(_) => {}
3097            // In Started (i.e., a single statement), cleanup ourselves.
3098            TransactionStatus::Started(_) => {
3099                self.rollback_transaction().await?;
3100            }
3101            // Implicit transactions also clear themselves.
3102            TransactionStatus::InTransactionImplicit(_) => {
3103                self.rollback_transaction().await?;
3104            }
3105            // Explicit transactions move to failed.
3106            TransactionStatus::InTransaction(_) => {
3107                self.adapter_client.fail_transaction();
3108            }
3109        };
3110        if is_fatal {
3111            Ok(State::Done)
3112        } else {
3113            Ok(State::Drain)
3114        }
3115    }
3116
3117    #[instrument(level = "debug")]
3118    async fn aborted_txn_error(&mut self) -> Result<State, io::Error> {
3119        self.send(BackendMessage::ErrorResponse(ErrorResponse::error(
3120            SqlState::IN_FAILED_SQL_TRANSACTION,
3121            ABORTED_TXN_MSG,
3122        )))
3123        .await?;
3124        Ok(State::Drain)
3125    }
3126
3127    fn is_aborted_txn(&mut self) -> bool {
3128        matches!(
3129            self.adapter_client.session().transaction(),
3130            TransactionStatus::Failed(_)
3131        )
3132    }
3133}
3134
3135fn pad_formats(formats: Vec<Format>, n: usize) -> Result<Vec<Format>, String> {
3136    match (formats.len(), n) {
3137        (0, e) => Ok(vec![Format::Text; e]),
3138        (1, e) => Ok(iter::repeat(formats[0]).take(e).collect()),
3139        (a, e) if a == e => Ok(formats),
3140        (a, e) => Err(format!(
3141            "expected {} field format specifiers, but got {}",
3142            e, a
3143        )),
3144    }
3145}
3146
3147fn describe_rows(stmt_desc: &StatementDesc, formats: &[Format]) -> BackendMessage {
3148    match &stmt_desc.relation_desc {
3149        Some(desc) if !stmt_desc.is_copy => {
3150            BackendMessage::RowDescription(message::encode_row_description(desc, formats))
3151        }
3152        _ => BackendMessage::NoData,
3153    }
3154}
3155
/// Function-pointer type for choosing the message that follows a batch of
/// sent rows; `portal_exec_message` and `fetch_message` both have this shape.
///
/// Arguments are the requested row limit, the total number of rows sent, and
/// an optional portal reference that the callback may update.
type GetResponse = fn(
    max_rows: ExecuteCount,
    total_sent_rows: usize,
    fetch_portal: Option<PortalRefMut>,
) -> BackendMessage;
3161
3162// A GetResponse used by send_rows during execute messages on portals or for
3163// simple query messages.
3164fn portal_exec_message(
3165    max_rows: ExecuteCount,
3166    total_sent_rows: usize,
3167    _fetch_portal: Option<PortalRefMut>,
3168) -> BackendMessage {
3169    // If max_rows is not specified, we will always send back a CommandComplete. If
3170    // max_rows is specified, we only send CommandComplete if there were more rows
3171    // requested than were remaining. That is, if max_rows == number of rows that
3172    // were remaining before sending (not that are remaining after sending), then
3173    // we still send a PortalSuspended. The number of remaining rows after the rows
3174    // have been sent doesn't matter. This matches postgres.
3175    match max_rows {
3176        ExecuteCount::Count(max_rows) if max_rows <= total_sent_rows => {
3177            BackendMessage::PortalSuspended
3178        }
3179        _ => BackendMessage::CommandComplete {
3180            tag: format!("SELECT {}", total_sent_rows),
3181        },
3182    }
3183}
3184
3185// A GetResponse used by send_rows during FETCH queries.
3186fn fetch_message(
3187    _max_rows: ExecuteCount,
3188    total_sent_rows: usize,
3189    fetch_portal: Option<PortalRefMut>,
3190) -> BackendMessage {
3191    let tag = format!("FETCH {}", total_sent_rows);
3192    if let Some(portal) = fetch_portal {
3193        *portal.state = PortalState::Completed(Some(tag.clone()));
3194    }
3195    BackendMessage::CommandComplete { tag }
3196}
3197
3198fn get_authenticator(
3199    authenticator_kind: listeners::AuthenticatorKind,
3200    frontegg: Option<FronteggAuthenticator>,
3201    oidc: GenericOidcAuthenticator,
3202    adapter_client: mz_adapter::Client,
3203) -> Authenticator {
3204    match authenticator_kind {
3205        listeners::AuthenticatorKind::Frontegg => Authenticator::Frontegg(frontegg.expect(
3206            "Frontegg authenticator should exist with listeners::AuthenticatorKind::Frontegg",
3207        )),
3208        listeners::AuthenticatorKind::Password => Authenticator::Password(adapter_client),
3209        listeners::AuthenticatorKind::Sasl => Authenticator::Sasl(adapter_client),
3210        listeners::AuthenticatorKind::Oidc => Authenticator::Oidc(oidc),
3211        listeners::AuthenticatorKind::None => Authenticator::None,
3212    }
3213}
3214
/// Row limit passed to `GetResponse` callbacks.
#[derive(Debug, Copy, Clone)]
enum ExecuteCount {
    /// No limit was specified.
    All,
    /// A specific limit; `portal_exec_message` compares it against the number
    /// of rows sent.
    Count(usize),
}
3220
3221// See postgres' backend/tcop/postgres.c IsTransactionExitStmt.
3222fn is_txn_exit_stmt(stmt: Option<&Statement<Raw>>) -> bool {
3223    match stmt {
3224        // Add PREPARE to this if we ever support it.
3225        Some(stmt) => matches!(stmt, Statement::Commit(_) | Statement::Rollback(_)),
3226        None => false,
3227    }
3228}
3229
/// Possible outcomes of a fetch.
// NOTE(review): the producer and consumer of this enum are outside this part
// of the file; confirm the exact meaning of `Rows(None)` there.
#[derive(Debug)]
enum FetchResult {
    /// An optional boxed row iterator.
    Rows(Option<Box<dyn RowIterator + Send + Sync>>),
    /// The fetch was canceled.
    Canceled,
    /// The fetch failed; carries an error message.
    Error(String),
    /// A notice was produced.
    Notice(AdapterNotice),
}
3237
/// Incremental scanner over a COPY input buffer that tracks row boundaries
/// and the end-of-copy marker. All offsets are byte offsets into the buffer
/// passed to `scan_new_bytes`.
#[derive(Debug)]
struct CopyRowScanner {
    // Offset up to which the buffer has already been scanned.
    scan_pos: usize,
    // Offset just past the most recently completed row, if any.
    last_row_end: Option<usize>,
    // Offset just past the end-of-copy marker, once one has been seen.
    end_marker_end: Option<usize>,
    // Byte offset within `data` at which the in-progress CSV record begins.
    // Used to verify the end-of-copy marker against the raw input bytes,
    // distinguishing a literal `\.` line from a quoted CSV value `"\."`
    // whose decoded form is also `\.`.
    record_start: usize,
    // CSV parsing state; `Some` only when the COPY format is CSV.
    csv: Option<CsvScanState>,
}
3250
/// State used by `CopyRowScanner` to find record boundaries in CSV input.
#[derive(Debug)]
struct CsvScanState {
    // csv-core reader configured with the COPY delimiter/quote/escape bytes.
    reader: csv_core::Reader,
    // Output buffer handed to `read_record`; doubled when the reader reports
    // `OutputFull` without consuming any input.
    output: Vec<u8>,
    // Field-ends buffer handed to `read_record`; doubled when the reader
    // reports `OutputEndsFull` without consuming any input.
    ends: Vec<usize>,
    // Set from the COPY `header` option; while true, the next completed record
    // is not checked for the end-of-copy marker, and the flag is then cleared.
    skip_first_record: bool,
}
3258
3259impl CopyRowScanner {
3260    fn new(params: &CopyFormatParams<'_>) -> Self {
3261        let csv = match params {
3262            CopyFormatParams::Csv(CopyCsvFormatParams {
3263                delimiter,
3264                quote,
3265                escape,
3266                header,
3267                ..
3268            }) => Some(CsvScanState::new(*delimiter, *quote, *escape, *header)),
3269            _ => None,
3270        };
3271
3272        CopyRowScanner {
3273            scan_pos: 0,
3274            last_row_end: None,
3275            end_marker_end: None,
3276            record_start: 0,
3277            csv,
3278        }
3279    }
3280
3281    fn scan_new_bytes(&mut self, data: &[u8]) {
3282        if self.scan_pos >= data.len() {
3283            return;
3284        }
3285
3286        if let Some(csv) = self.csv.as_mut() {
3287            let mut input = &data[self.scan_pos..];
3288            let mut consumed = 0usize;
3289            while !input.is_empty() {
3290                let (result, n_input, _n_output, _n_ends) =
3291                    csv.reader
3292                        .read_record(input, &mut csv.output, &mut csv.ends);
3293                consumed += n_input;
3294                input = &input[n_input..];
3295
3296                match result {
3297                    ReadRecordResult::InputEmpty => break,
3298                    ReadRecordResult::OutputFull => {
3299                        if n_input == 0 {
3300                            csv.output
3301                                .resize(csv.output.len().saturating_mul(2).max(1), 0);
3302                        }
3303                    }
3304                    ReadRecordResult::OutputEndsFull => {
3305                        if n_input == 0 {
3306                            csv.ends.resize(csv.ends.len().saturating_mul(2).max(1), 0);
3307                        }
3308                    }
3309                    ReadRecordResult::Record | ReadRecordResult::End => {
3310                        let row_end = self.scan_pos + consumed;
3311                        self.last_row_end = Some(row_end);
3312                        if self.end_marker_end.is_none() {
3313                            let is_marker = if csv.skip_first_record {
3314                                csv.skip_first_record = false;
3315                                false
3316                            } else {
3317                                // Detect the marker against the raw input
3318                                // bytes, not the CSV-decoded record. A quoted
3319                                // data row `"\."` decodes to `\.` but must be
3320                                // imported as data; only a bare `\.` line
3321                                // terminates the COPY.
3322                                let raw = &data[self.record_start..row_end];
3323                                // csv-core ends a CRLF record after the `\r`,
3324                                // leaving the trailing `\n` as the leading byte
3325                                // of the next record's span; a CR-only record
3326                                // ends in a lone `\r`. So a `\.` marker record's
3327                                // raw span can be `\.\n` (LF), `\n\.\r` (CRLF)
3328                                // or `\.\r` (CR). Trim CR/LF from both ends
3329                                // before comparing — a trailing-only strip would
3330                                // miss the CRLF/CR forms. Quoted `"\."` data
3331                                // keeps its surrounding quotes after trimming and
3332                                // is therefore correctly rejected.
3333                                let start = raw
3334                                    .iter()
3335                                    .take_while(|&&b| b == b'\r' || b == b'\n')
3336                                    .count();
3337                                let trailing = raw[start..]
3338                                    .iter()
3339                                    .rev()
3340                                    .take_while(|&&b| b == b'\r' || b == b'\n')
3341                                    .count();
3342                                let trimmed = &raw[start..raw.len() - trailing];
3343                                trimmed == b"\\."
3344                            };
3345                            if is_marker {
3346                                self.end_marker_end = Some(row_end);
3347                                self.record_start = row_end;
3348                                break;
3349                            }
3350                        }
3351                        self.record_start = row_end;
3352                    }
3353                }
3354            }
3355        } else {
3356            let mut row_start = self.last_row_end.unwrap_or(0);
3357            for (offset, b) in data[self.scan_pos..].iter().enumerate() {
3358                if *b == b'\n' {
3359                    let row_end = self.scan_pos + offset + 1;
3360                    self.last_row_end = Some(row_end);
3361                    if self.end_marker_end.is_none() {
3362                        let row = &data[row_start..row_end];
3363                        if row.get(0..2) == Some(b"\\.") {
3364                            self.end_marker_end = Some(row_end);
3365                            break;
3366                        }
3367                    }
3368                    row_start = row_end;
3369                }
3370            }
3371        }
3372
3373        self.scan_pos = data.len();
3374    }
3375
3376    fn last_row_end(&self) -> Option<usize> {
3377        self.last_row_end
3378    }
3379
3380    fn end_marker_end(&self) -> Option<usize> {
3381        self.end_marker_end
3382    }
3383
3384    fn current_row_size(&self, data_len: usize) -> usize {
3385        data_len.saturating_sub(self.last_row_end.unwrap_or(0))
3386    }
3387
3388    fn on_split(&mut self, split_pos: usize) {
3389        self.scan_pos = self.scan_pos.saturating_sub(split_pos);
3390        self.last_row_end = None;
3391        self.end_marker_end = self
3392            .end_marker_end
3393            .and_then(|end| end.checked_sub(split_pos));
3394        // `record_start` is only maintained for the CSV path; the text and
3395        // binary paths leave it at 0. For CSV, splits always occur at a
3396        // completed-row boundary, so the in-progress record (if any) starts at
3397        // the new beginning of the buffer. Assert that invariant so the
3398        // `saturating_sub` below doesn't silently paper over a bug that
3399        // bisected an in-progress record — but only when CSV is in use, since
3400        // otherwise `record_start` is meaninglessly 0.
3401        soft_assert_or_log!(
3402            self.csv.is_none() || self.record_start >= split_pos,
3403            "split bisected an in-progress CSV record: record_start={} < split_pos={}",
3404            self.record_start,
3405            split_pos,
3406        );
3407        self.record_start = self.record_start.saturating_sub(split_pos);
3408    }
3409
3410    fn on_truncate(&mut self, new_len: usize) {
3411        self.scan_pos = self.scan_pos.min(new_len);
3412        self.last_row_end = self.last_row_end.filter(|&end| end <= new_len);
3413        self.end_marker_end = self.end_marker_end.filter(|&end| end <= new_len);
3414        self.record_start = self.record_start.min(new_len);
3415    }
3416}
3417
3418impl CsvScanState {
3419    fn new(delimiter: u8, quote: u8, escape: u8, header: bool) -> Self {
3420        let (double_quote, escape) = if quote == escape {
3421            (true, None)
3422        } else {
3423            (false, Some(escape))
3424        };
3425        CsvScanState {
3426            reader: csv_core::ReaderBuilder::new()
3427                .delimiter(delimiter)
3428                .quote(quote)
3429                .double_quote(double_quote)
3430                .escape(escape)
3431                .build(),
3432            output: vec![0; 1],
3433            ends: vec![0; 1],
3434            skip_first_record: header,
3435        }
3436    }
3437}
3438
#[cfg(test)]
mod test {
    use super::*;

    #[mz_ore::test]
    fn test_copy_row_scanner_end_marker_line_endings() {
        // A bare `\.` end-of-copy marker must be detected for LF, CRLF and CR
        // line endings, and a quoted `"\."` data row must never be mistaken
        // for it. csv-core ends a CRLF record after the `\r` (the `\n` becomes
        // the next record's leading byte), so the marker's raw span is `\.\n`
        // (LF), `\n\.\r` (CRLF) or `\.\r` (CR); stripping only trailing bytes
        // would miss the CRLF/CR forms and silently import post-marker rows.
        let params = CopyFormatParams::Csv(CopyCsvFormatParams::default());

        // Scans `data` with a fresh scanner and reports the marker boundary.
        let marker_end = |data: &[u8]| -> Option<usize> {
            let mut scanner = CopyRowScanner::new(&params);
            scanner.scan_new_bytes(data);
            scanner.end_marker_end()
        };

        for eol in [&b"\n"[..], b"\r\n", b"\r"] {
            // Concatenates `lines`, terminating every one of them with `eol`.
            let join = |lines: &[&str]| -> Vec<u8> {
                lines.iter().fold(Vec::new(), |mut out, line| {
                    out.extend_from_slice(line.as_bytes());
                    out.extend_from_slice(eol);
                    out
                })
            };

            // Bare `\.` as the second record, so `record_start` has already
            // moved past the orphaned terminator of `first`. csv-core reports
            // the record after a single terminator byte, so the boundary sits
            // one byte past `first<eol>\.`.
            let data = join(&["first", "\\.", "after"]);
            let prefix_len = b"first".len() + eol.len() + b"\\.".len();
            assert_eq!(
                marker_end(&data),
                Some(prefix_len + 1),
                "bare marker, eol={eol:?}"
            );

            // Quoted "\." is data, not the marker.
            let data = join(&["before", "\"\\.\"", "after"]);
            assert_eq!(marker_end(&data), None, "quoted marker, eol={eol:?}");
        }
    }

    #[mz_ore::test]
    fn test_copy_row_scanner_non_csv_split() {
        // Regression: only the CSV path maintains `record_start`; text and
        // binary leave it at 0. `on_split` must therefore not assert
        // `record_start >= split_pos` for those formats, or it would fire on
        // every split of a large text/binary COPY stream (soft-assertions
        // panic under test). Mirrors `COPY ... FROM STDIN` (default text
        // format) splitting at a row boundary once the buffer fills.
        let non_csv_formats = [
            CopyFormatParams::Text(CopyTextFormatParams::default()),
            CopyFormatParams::Binary,
        ];
        for params in non_csv_formats {
            let data = b"1\thello world\t2\tsome text value here\n\
                         3\thello world\t6\tsome text value here\n";
            let mut scanner = CopyRowScanner::new(&params);
            scanner.scan_new_bytes(data);

            let split_pos = scanner.last_row_end().expect("a complete row");
            assert!(split_pos > 0, "params={params:?}");

            // Must not panic via the CSV-only `on_split` soft-assert.
            scanner.on_split(split_pos);
            assert_eq!(scanner.record_start, 0, "params={params:?}");
        }
    }

    #[mz_ore::test]
    fn test_parse_options() {
        // Each case is (input, expected key/value pairs or an error).
        let cases: Vec<(
            &'static str,
            Result<Vec<(&'static str, &'static str)>, ()>,
        )> = vec![
            ("", Ok(vec![])),
            ("--key", Err(())),
            ("--key=val", Ok(vec![("key", "val")])),
            (
                r#"--key=val -ckey2=val2 -c key3=val3 -c key4=val4 -ckey5=val5"#,
                Ok(vec![
                    ("key", "val"),
                    ("key2", "val2"),
                    ("key3", "val3"),
                    ("key4", "val4"),
                    ("key5", "val5"),
                ]),
            ),
            (r#"-c\ key=val"#, Ok(vec![(" key", "val")])),
            ("--key=val -ckey2 val2", Err(())),
            // Unclear what this should do.
            ("--key=", Ok(vec![("key", "")])),
        ];
        for (input, expect) in cases {
            let got = parse_options(input);
            let expect = expect.map(|pairs| {
                pairs
                    .into_iter()
                    .map(|(k, v)| (k.to_owned(), v.to_owned()))
                    .collect()
            });
            assert_eq!(got, expect, "input: {}", input);
        }
    }

    #[mz_ore::test]
    fn test_parse_option() {
        // Each case is (input, expected key/value pair or an error).
        let cases: Vec<(&'static str, Result<(&'static str, &'static str), ()>)> = vec![
            ("", Err(())),
            ("--", Err(())),
            ("--c", Err(())),
            ("a=b", Err(())),
            ("--a=b", Ok(("a", "b"))),
            ("--ca=b", Ok(("ca", "b"))),
            ("-ca=b", Ok(("a", "b"))),
            // Unclear what this should error, but at least test it.
            ("--=", Ok(("", ""))),
        ];
        for (input, expect) in cases {
            let got = parse_option(input);
            assert_eq!(got, expect, "input: {}", input);
        }
    }

    #[mz_ore::test]
    fn test_split_options() {
        // Each case is (input, expected tokens after splitting/unescaping).
        let cases: Vec<(&'static str, Vec<&'static str>)> = vec![
            ("", vec![]),
            ("  ", vec![]),
            (" a ", vec!["a"]),
            ("  ab     cd   ", vec!["ab", "cd"]),
            (r#"  ab\     cd   "#, vec!["ab ", "cd"]),
            (r#"  ab\\     cd   "#, vec![r#"ab\"#, "cd"]),
            (r#"  ab\\\     cd   "#, vec![r#"ab\ "#, "cd"]),
            (r#"  ab\\\ cd   "#, vec![r#"ab\ cd"#]),
            (r#"  ab\\\cd   "#, vec![r#"ab\cd"#]),
            (r#"a\"#, vec!["a"]),
            (r#"a\ "#, vec!["a "]),
            (r#"\"#, vec![]),
            (r#"\ "#, vec![r#" "#]),
            (r#" \ "#, vec![r#" "#]),
            (r#"\  "#, vec![r#" "#]),
        ];
        for (input, expect) in cases {
            let got = split_options(input);
            assert_eq!(got, expect, "input: {}", input);
        }
    }

    #[mz_ore::test]
    fn test_is_jwt() {
        // A real JWT header decodes successfully.
        assert!(is_jwt("eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxIn0.signature"));

        // Not JWTs: plain strings, wrong segment count, non-JSON headers.
        let not_jwts = [
            "",
            "secure_password",
            "p4ss.w0rd",
            "aaa.bbb.ccc",
            "eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxIn0",
            "eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxIn0.sig.extra",
        ];
        for s in not_jwts {
            assert!(!is_jwt(s), "is_jwt({s:?})");
        }
    }
}