Skip to main content

mz_pgwire/
protocol.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::collections::BTreeMap;
11use std::convert::TryFrom;
12use std::future::Future;
13use std::ops::Deref;
14use std::sync::Arc;
15use std::time::{Duration, Instant};
16use std::{iter, mem};
17
18use base64::prelude::*;
19use byteorder::{ByteOrder, NetworkEndian};
20use csv_core::ReadRecordResult;
21use futures::future::{BoxFuture, FutureExt, pending};
22use itertools::Itertools;
23use mz_adapter::client::RecordFirstRowStream;
24use mz_adapter::session::{
25    EndTransactionAction, InProgressRows, LifecycleTimestamps, PortalRefMut, PortalState, Session,
26    SessionConfig, TransactionStatus,
27};
28use mz_adapter::statement_logging::{StatementEndedExecutionReason, StatementExecutionStrategy};
29use mz_adapter::{
30    AdapterError, AdapterNotice, ExecuteContextGuard, ExecuteResponse, PeekResponseUnary, metrics,
31    verify_datum_desc,
32};
33use mz_adapter_types::dyncfgs::OIDC_GROUP_CLAIM;
34use mz_auth::Authenticated;
35use mz_auth::password::Password;
36use mz_authenticator::{Authenticator, GenericOidcAuthenticator};
37use mz_frontegg_auth::Authenticator as FronteggAuthenticator;
38use mz_ore::cast::CastFrom;
39use mz_ore::netio::AsyncReady;
40use mz_ore::now::{EpochMillis, SYSTEM_TIME};
41use mz_ore::str::StrExt;
42use mz_ore::{assert_none, assert_ok, instrument, soft_assert_eq_or_log, soft_assert_or_log};
43use mz_pgcopy::{CopyCsvFormatParams, CopyFormatParams, CopyTextFormatParams};
44use mz_pgwire_common::{
45    ConnectionCounter, Cursor, ErrorResponse, Format, FrontendMessage, Severity, VERSION_3,
46    VERSIONS,
47};
48use mz_repr::{
49    CatalogItemId, ColumnIndex, Datum, RelationDesc, RowArena, RowIterator, RowRef,
50    SqlRelationType, SqlScalarType,
51};
52use mz_server_core::TlsMode;
53use mz_server_core::listeners;
54use mz_server_core::listeners::AllowedRoles;
55use mz_sql::ast::display::AstDisplay;
56use mz_sql::ast::{
57    CopyDirection, CopyStatement, CopyTarget, FetchDirection, Ident, Raw, Statement,
58};
59use mz_sql::parse::StatementParseResult;
60use mz_sql::plan::{CopyFormat, ExecuteTimeout, StatementDesc};
61use mz_sql::session::metadata::SessionMetadata;
62use mz_sql::session::user::INTERNAL_USER_NAMES;
63use mz_sql::session::vars::VarInput;
64use postgres::error::SqlState;
65use tokio::io::{self, AsyncRead, AsyncWrite};
66use tokio::select;
67use tokio::time::{self};
68use tokio_metrics::TaskMetrics;
69use tokio_stream::wrappers::UnboundedReceiverStream;
70use tracing::{Instrument, debug, debug_span, warn};
71use uuid::Uuid;
72
73use crate::codec::{
74    FramedConn, decode_password, decode_sasl_initial_response, decode_sasl_response,
75};
76use crate::message::{
77    self, BackendMessage, SASLServerFinalMessage, SASLServerFinalMessageKinds,
78    SASLServerFirstMessage,
79};
80
81/// Reports whether the given stream begins with a pgwire handshake.
82///
83/// To avoid false negatives, there must be at least eight bytes in `buf`.
84pub fn match_handshake(buf: &[u8]) -> bool {
85    // The pgwire StartupMessage looks like this:
86    //
87    //     i32 - Length of entire message.
88    //     i32 - Protocol version number.
89    //     [String] - Arbitrary key-value parameters of any length.
90    //
91    // Since arbitrary parameters can be included in the StartupMessage, the
92    // first Int32 is worthless, since the message could have any length.
93    // Instead, we sniff the protocol version number.
94    if buf.len() < 8 {
95        return false;
96    }
97    let version = NetworkEndian::read_i32(&buf[4..8]);
98    VERSIONS.contains(&version)
99}
100
101/// Parameters for the [`run`] function.
102pub struct RunParams<'a, A, I>
103where
104    I: Iterator<Item = TaskMetrics> + Send,
105{
106    /// The TLS mode of the pgwire server.
107    pub tls_mode: Option<TlsMode>,
108    /// A client for the adapter.
109    pub adapter_client: mz_adapter::Client,
110    /// The connection to the client.
111    pub conn: &'a mut FramedConn<A>,
112    /// The universally unique identifier for the connection.
113    pub conn_uuid: Uuid,
114    /// The protocol version that the client provided in the startup message.
115    pub version: i32,
116    /// The parameters that the client provided in the startup message.
117    pub params: BTreeMap<String, String>,
118    /// Frontegg JWT authenticator.
119    pub frontegg: Option<FronteggAuthenticator>,
120    /// OIDC authenticator.
121    pub oidc: GenericOidcAuthenticator,
122    /// The authentication method defined by the server's listener
123    /// configuration.
124    pub authenticator_kind: listeners::AuthenticatorKind,
125    /// Global connection limit and count
126    pub active_connection_counter: ConnectionCounter,
127    /// Helm chart version
128    pub helm_chart_version: Option<String>,
129    /// Whether to allow reserved users (ie: mz_system).
130    pub allowed_roles: AllowedRoles,
131    /// Tokio metrics
132    pub tokio_metrics_intervals: I,
133}
134
135/// Runs a pgwire connection to completion.
136///
137/// This involves responding to `FrontendMessage::StartupMessage` and all future
138/// requests until the client terminates the connection or a fatal error occurs.
139///
140/// Note that this function returns successfully even upon delivering a fatal
141/// error to the client. It only returns `Err` if an unexpected I/O error occurs
142/// while communicating with the client, e.g., if the connection is severed in
143/// the middle of a request.
144#[mz_ore::instrument(level = "debug")]
145pub async fn run<'a, A, I>(
146    RunParams {
147        tls_mode,
148        adapter_client,
149        conn,
150        conn_uuid,
151        version,
152        mut params,
153        frontegg,
154        oidc,
155        authenticator_kind,
156        active_connection_counter,
157        helm_chart_version,
158        allowed_roles,
159        tokio_metrics_intervals,
160    }: RunParams<'a, A, I>,
161) -> Result<(), io::Error>
162where
163    A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
164    I: Iterator<Item = TaskMetrics> + Send,
165{
166    if version != VERSION_3 {
167        return conn
168            .send(ErrorResponse::fatal(
169                SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
170                "server does not support the client's requested protocol version",
171            ))
172            .await;
173    }
174
175    let user = params.remove("user").unwrap_or_else(String::new);
176    let options = parse_options(params.get("options").unwrap_or(&String::new()));
177    let authenticator =
178        get_authenticator(authenticator_kind, frontegg, oidc, adapter_client.clone());
179    // TODO move this somewhere it can be shared with HTTP
180    let is_internal_user = INTERNAL_USER_NAMES.contains(&user);
181    // this is a superset of internal users
182    let is_reserved_user = mz_adapter::catalog::is_reserved_role_name(user.as_str());
183    let role_allowed = match allowed_roles {
184        AllowedRoles::Normal => !is_reserved_user,
185        AllowedRoles::Internal => is_internal_user,
186        AllowedRoles::NormalAndInternal => !is_reserved_user || is_internal_user,
187    };
188    if !role_allowed {
189        let msg = format!("unauthorized login to user '{user}'");
190        return conn
191            .send(ErrorResponse::fatal(SqlState::INSUFFICIENT_PRIVILEGE, msg))
192            .await;
193    }
194
195    if let Err(err) = conn.inner().ensure_tls_compatibility(&tls_mode) {
196        return conn.send(err).await;
197    }
198
199    let authenticator_kind = authenticator.kind();
200
201    let (mut session, expired) = match authenticator {
202        Authenticator::Frontegg(frontegg) => {
203            let password = match request_cleartext_password(conn).await {
204                Ok(password) => password,
205                Err(PasswordRequestError::IoError(e)) => return Err(e),
206                Err(PasswordRequestError::InvalidPasswordError(e)) => {
207                    return conn.send(e).await;
208                }
209            };
210
211            let group_claim =
212                OIDC_GROUP_CLAIM.get(adapter_client.get_system_vars().await.dyncfgs());
213            let auth_response = frontegg
214                .authenticate(&user, &password, Some(&group_claim))
215                .await;
216            match auth_response {
217                // Create a session based on the auth session.
218                //
219                // In particular, it's important that the username come from the
220                // auth session, as Frontegg may return an email address with
221                // different casing than the user supplied via the pgwire
222                // username fN
223                Ok((mut auth_session, authenticated)) => {
224                    let groups = auth_session.groups();
225                    let session = adapter_client.new_session(
226                        SessionConfig {
227                            conn_id: conn.conn_id().clone(),
228                            uuid: conn_uuid,
229                            user: auth_session.user().into(),
230                            client_ip: conn.peer_addr().clone(),
231                            external_metadata_rx: Some(auth_session.external_metadata_rx()),
232                            helm_chart_version,
233                            authenticator_kind,
234                            groups,
235                        },
236                        authenticated,
237                    );
238                    let expired = async move { auth_session.expired().await };
239                    (session, expired.left_future())
240                }
241                Err(err) => {
242                    warn!(?err, "pgwire connection failed authentication");
243                    return conn
244                        .send(ErrorResponse::fatal(
245                            SqlState::INVALID_PASSWORD,
246                            "invalid password",
247                        ))
248                        .await;
249                }
250            }
251        }
252        Authenticator::Oidc(oidc) => {
253            // OIDC listener: accepts either a JWT (uses OIDC authentication) or a
254            // plain SQL password (uses SQL password authentication).
255            let password = match request_cleartext_password(conn).await {
256                Ok(password) => password,
257                Err(PasswordRequestError::IoError(e)) => return Err(e),
258                Err(PasswordRequestError::InvalidPasswordError(e)) => {
259                    return conn.send(e).await;
260                }
261            };
262            if is_jwt(&password) {
263                let auth_response = oidc.authenticate(&password, Some(&user)).await;
264                match auth_response {
265                    Ok((mut claims, authenticated)) => {
266                        let groups = claims.groups.take();
267                        let session = adapter_client.new_session(
268                            SessionConfig {
269                                conn_id: conn.conn_id().clone(),
270                                uuid: conn_uuid,
271                                user: std::mem::take(&mut claims.user),
272                                client_ip: conn.peer_addr().clone(),
273                                external_metadata_rx: None,
274                                helm_chart_version,
275                                authenticator_kind,
276                                groups,
277                            },
278                            authenticated,
279                        );
280                        // No invalidation of the auth session once authenticated,
281                        // so auth session lasts indefinitely.
282                        (session, pending().right_future())
283                    }
284                    Err(err) => {
285                        warn!(?err, "pgwire connection failed authentication");
286                        return conn.send(err.into_response()).await;
287                    }
288                }
289            } else {
290                let session = match authenticate_with_password(
291                    conn,
292                    &adapter_client,
293                    user,
294                    Password(password),
295                    conn_uuid,
296                    helm_chart_version,
297                )
298                .await
299                {
300                    Ok(session) => session,
301                    Err(PasswordRequestError::IoError(e)) => return Err(e),
302                    Err(PasswordRequestError::InvalidPasswordError(e)) => {
303                        return conn.send(e).await;
304                    }
305                };
306                (session, pending().right_future())
307            }
308        }
309        Authenticator::Password(adapter_client) => {
310            let password = match request_cleartext_password(conn).await {
311                Ok(password) => password,
312                Err(PasswordRequestError::IoError(e)) => return Err(e),
313                Err(PasswordRequestError::InvalidPasswordError(e)) => {
314                    return conn.send(e).await;
315                }
316            };
317            let session = match authenticate_with_password(
318                conn,
319                &adapter_client,
320                user,
321                Password(password),
322                conn_uuid,
323                helm_chart_version,
324            )
325            .await
326            {
327                Ok(session) => session,
328                Err(PasswordRequestError::IoError(e)) => return Err(e),
329                Err(PasswordRequestError::InvalidPasswordError(e)) => {
330                    return conn.send(e).await;
331                }
332            };
333            // No frontegg check, so auth session lasts indefinitely.
334            (session, pending().right_future())
335        }
336        Authenticator::Sasl(adapter_client) => {
337            // Start the handshake
338            conn.send(BackendMessage::AuthenticationSASL).await?;
339            conn.flush().await?;
340            // Get the initial response indicating chosen mechanism
341            let (mechanism, initial_response) = match conn.recv().await? {
342                Some(FrontendMessage::RawAuthentication(data)) => {
343                    match decode_sasl_initial_response(Cursor::new(&data)).ok() {
344                        Some(FrontendMessage::SASLInitialResponse {
345                            gs2_header,
346                            mechanism,
347                            initial_response,
348                        }) => {
349                            // We do not support channel binding
350                            if gs2_header.channel_binding_enabled() {
351                                return conn
352                                    .send(ErrorResponse::fatal(
353                                        SqlState::PROTOCOL_VIOLATION,
354                                        "channel binding not supported",
355                                    ))
356                                    .await;
357                            }
358                            (mechanism, initial_response)
359                        }
360                        _ => {
361                            return conn
362                                .send(ErrorResponse::fatal(
363                                    SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
364                                    "expected SASLInitialResponse message",
365                                ))
366                                .await;
367                        }
368                    }
369                }
370                _ => {
371                    return conn
372                        .send(ErrorResponse::fatal(
373                            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
374                            "expected SASLInitialResponse message",
375                        ))
376                        .await;
377                }
378            };
379
380            if mechanism != "SCRAM-SHA-256" {
381                return conn
382                    .send(ErrorResponse::fatal(
383                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
384                        "unsupported SASL mechanism",
385                    ))
386                    .await;
387            }
388
389            if initial_response.nonce.len() > 256 {
390                return conn
391                    .send(ErrorResponse::fatal(
392                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
393                        "nonce too long",
394                    ))
395                    .await;
396            }
397
398            let (server_first_message_raw, mock_hash) = match adapter_client
399                .generate_sasl_challenge(&user, &initial_response.nonce)
400                .await
401            {
402                Ok(response) => {
403                    let server_first_message_raw = format!(
404                        "r={},s={},i={}",
405                        response.nonce, response.salt, response.iteration_count
406                    );
407
408                    let client_key = [0u8; 32];
409                    let server_key = [1u8; 32];
410                    let mock_hash = format!(
411                        "SCRAM-SHA-256${}:{}${}:{}",
412                        response.iteration_count,
413                        response.salt,
414                        BASE64_STANDARD.encode(client_key),
415                        BASE64_STANDARD.encode(server_key)
416                    );
417
418                    conn.send(BackendMessage::AuthenticationSASLContinue(
419                        SASLServerFirstMessage {
420                            iteration_count: response.iteration_count,
421                            nonce: response.nonce,
422                            salt: response.salt,
423                        },
424                    ))
425                    .await?;
426                    conn.flush().await?;
427                    (server_first_message_raw, mock_hash)
428                }
429                Err(e) => {
430                    return conn.send(e.into_response(Severity::Fatal)).await;
431                }
432            };
433
434            let authenticated = match conn.recv().await? {
435                Some(FrontendMessage::RawAuthentication(data)) => {
436                    match decode_sasl_response(Cursor::new(&data)).ok() {
437                        Some(FrontendMessage::SASLResponse(response)) => {
438                            let auth_message = format!(
439                                "{},{},{}",
440                                initial_response.client_first_message_bare_raw,
441                                server_first_message_raw,
442                                response.client_final_message_bare_raw
443                            );
444                            if response.proof.len() > 1024 {
445                                return conn
446                                    .send(ErrorResponse::fatal(
447                                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
448                                        "proof too long",
449                                    ))
450                                    .await;
451                            }
452                            match adapter_client
453                                .verify_sasl_proof(
454                                    &user,
455                                    &response.proof,
456                                    &auth_message,
457                                    &mock_hash,
458                                )
459                                .await
460                            {
461                                Ok((proof_response, authenticated)) => {
462                                    conn.send(BackendMessage::AuthenticationSASLFinal(
463                                        SASLServerFinalMessage {
464                                            kind: SASLServerFinalMessageKinds::Verifier(
465                                                proof_response.verifier,
466                                            ),
467                                            extensions: vec![],
468                                        },
469                                    ))
470                                    .await?;
471                                    conn.flush().await?;
472                                    authenticated
473                                }
474                                Err(_) => {
475                                    return conn
476                                        .send(ErrorResponse::fatal(
477                                            SqlState::INVALID_PASSWORD,
478                                            "invalid password",
479                                        ))
480                                        .await;
481                                }
482                            }
483                        }
484                        _ => {
485                            return conn
486                                .send(ErrorResponse::fatal(
487                                    SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
488                                    "expected SASLResponse message",
489                                ))
490                                .await;
491                        }
492                    }
493                }
494                _ => {
495                    return conn
496                        .send(ErrorResponse::fatal(
497                            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
498                            "expected SASLResponse message",
499                        ))
500                        .await;
501                }
502            };
503
504            let session = adapter_client.new_session(
505                SessionConfig {
506                    conn_id: conn.conn_id().clone(),
507                    uuid: conn_uuid,
508                    user,
509                    client_ip: conn.peer_addr().clone(),
510                    external_metadata_rx: None,
511                    helm_chart_version,
512                    authenticator_kind,
513                    groups: None,
514                },
515                authenticated,
516            );
517            // No frontegg check, so auth session lasts indefinitely.
518            let auth_session = pending().right_future();
519            (session, auth_session)
520        }
521
522        Authenticator::None => {
523            let session = adapter_client.new_session(
524                SessionConfig {
525                    conn_id: conn.conn_id().clone(),
526                    uuid: conn_uuid,
527                    user,
528                    client_ip: conn.peer_addr().clone(),
529                    external_metadata_rx: None,
530                    helm_chart_version,
531                    authenticator_kind,
532                    groups: None,
533                },
534                Authenticated,
535            );
536            // No frontegg check, so auth session lasts indefinitely.
537            let auth_session = pending().right_future();
538            (session, auth_session)
539        }
540    };
541
542    let system_vars = adapter_client.get_system_vars().await;
543    for (name, value) in params {
544        let settings = match name.as_str() {
545            "options" => match &options {
546                Ok(opts) => opts,
547                Err(()) => {
548                    session.add_notice(AdapterNotice::BadStartupSetting {
549                        name,
550                        reason: "could not parse".into(),
551                    });
552                    continue;
553                }
554            },
555            _ => &vec![(name, value)],
556        };
557        for (key, val) in settings {
558            const LOCAL: bool = false;
559            // TODO: Issuing an error here is better than what we did before
560            // (silently ignore errors on set), but erroring the connection
561            // might be the better behavior. We maybe need to support more
562            // options sent by psql and drivers before we can safely do this.
563            if let Err(err) = session
564                .vars_mut()
565                .set(&system_vars, key, VarInput::Flat(val), LOCAL)
566            {
567                session.add_notice(AdapterNotice::BadStartupSetting {
568                    name: key.clone(),
569                    reason: err.to_string(),
570                });
571            }
572        }
573    }
574    session
575        .vars_mut()
576        .end_transaction(EndTransactionAction::Commit);
577
578    let _guard = match active_connection_counter.allocate_connection(session.user()) {
579        Ok(drop_connection) => drop_connection,
580        Err(e) => {
581            let e: AdapterError = e.into();
582            return conn.send(e.into_response(Severity::Fatal)).await;
583        }
584    };
585
586    // Register session with adapter.
587    let mut adapter_client = match adapter_client.startup(session).await {
588        Ok(adapter_client) => adapter_client,
589        Err(e) => return conn.send(e.into_response(Severity::Fatal)).await,
590    };
591
592    let mut buf = vec![BackendMessage::AuthenticationOk];
593    for var in adapter_client.session().vars().notify_set() {
594        buf.push(BackendMessage::ParameterStatus(var.name(), var.value()));
595    }
596    buf.push(BackendMessage::BackendKeyData {
597        conn_id: adapter_client.session().conn_id().unhandled(),
598        secret_key: adapter_client.session().secret_key(),
599    });
600    buf.extend(
601        adapter_client
602            .session()
603            .drain_notices()
604            .into_iter()
605            .map(|notice| BackendMessage::ErrorResponse(notice.into_response())),
606    );
607    buf.push(BackendMessage::ReadyForQuery(
608        adapter_client.session().transaction().into(),
609    ));
610    conn.send_all(buf).await?;
611    conn.flush().await?;
612
613    let machine = StateMachine {
614        conn,
615        adapter_client,
616        txn_needs_commit: false,
617        tokio_metrics_intervals,
618    };
619
620    select! {
621        r = machine.run() => {
622            // Errors produced internally (like MAX_REQUEST_SIZE being exceeded) should send an
623            // error to the client informing them why the connection was closed. We still want to
624            // return the original error up the stack, though, so we skip error checking during conn
625            // operations.
626            if let Err(err) = &r {
627                let _ = conn
628                    .send(ErrorResponse::fatal(
629                        SqlState::CONNECTION_FAILURE,
630                        err.to_string(),
631                    ))
632                    .await;
633                let _ = conn.flush().await;
634            }
635            r
636        },
637        _ = expired => {
638            conn
639                .send(ErrorResponse::fatal(SqlState::INVALID_AUTHORIZATION_SPECIFICATION, "authentication expired"))
640                .await?;
641            conn.flush().await
642        }
643    }
644}
645
646/// Decides if a given password is a JWT by checking
647/// if we can decode its header.
648fn is_jwt(password: &str) -> bool {
649    jsonwebtoken::decode_header(password).is_ok()
650}
651
652/// Returns (name, value) session settings pairs from an options value.
653///
654/// From Postgres, see pg_split_opts in postinit.c and process_postgres_switches
655/// in postgres.c.
656fn parse_options(value: &str) -> Result<Vec<(String, String)>, ()> {
657    let opts = split_options(value);
658    let mut pairs = Vec::with_capacity(opts.len());
659    let mut seen_prefix = false;
660    for opt in opts {
661        if !seen_prefix {
662            if opt == "-c" {
663                seen_prefix = true;
664            } else {
665                let (key, val) = parse_option(&opt)?;
666                pairs.push((key.to_owned(), val.to_owned()));
667            }
668        } else {
669            let (key, val) = opt.split_once('=').ok_or(())?;
670            pairs.push((key.to_owned(), val.to_owned()));
671            seen_prefix = false;
672        }
673    }
674    Ok(pairs)
675}
676
/// Splits a single option token of the form `--key=value` or `-ckey=value`
/// into `(key, value)`.
///
/// The split happens at the first `=`, so the value may itself contain `=`.
/// The `-c` prefix is tried before `--`. The key is returned verbatim after
/// the prefix is stripped: no `-` to `_` translation is performed.
///
/// Returns `Err(())` if the token has no `=`, or if its key does not start
/// with one of the recognized prefixes.
fn parse_option(option: &str) -> Result<(&str, &str), ()> {
    let (raw_key, value) = option.split_once('=').ok_or(())?;
    ["-c", "--"]
        .iter()
        .find_map(|prefix| raw_key.strip_prefix(*prefix))
        .map(|key| (key, value))
        .ok_or(())
}
689
/// Splits `value` into tokens separated by runs of unescaped spaces.
///
/// A backslash makes the following character literal and is itself dropped:
///
/// - `\ ` puts a space inside the current token;
/// - `\\` yields one backslash;
/// - `\x` yields `x`.
///
/// A lone trailing backslash is ignored. Empty tokens are never produced.
///
/// Tokens are owned `String`s because unescaping means they cannot simply be
/// subslices of `value`. This is not called often enough to be worth
/// optimizing.
fn split_options(value: &str) -> Vec<String> {
    let mut tokens = Vec::new();
    let mut token = String::new();
    let mut chars = value.chars();
    while let Some(ch) = chars.next() {
        match ch {
            '\\' => {
                // Take the escaped character literally; nothing follows a
                // trailing backslash, so it is dropped.
                if let Some(escaped) = chars.next() {
                    token.push(escaped);
                }
            }
            ' ' => {
                // Only close a token that has content, so runs of spaces
                // collapse.
                if !token.is_empty() {
                    tokens.push(std::mem::take(&mut token));
                }
            }
            other => token.push(other),
        }
    }
    if !token.is_empty() {
        tokens.push(token);
    }
    tokens
}
732
733enum PasswordRequestError {
734    InvalidPasswordError(ErrorResponse),
735    IoError(io::Error),
736}
737
738impl From<io::Error> for PasswordRequestError {
739    fn from(e: io::Error) -> Self {
740        PasswordRequestError::IoError(e)
741    }
742}
743
744/// Requests a cleartext password from a connection and returns it if it is valid.
745/// Sends an error response in the connection if the password
746/// is not valid.
747async fn request_cleartext_password<A>(
748    conn: &mut FramedConn<A>,
749) -> Result<String, PasswordRequestError>
750where
751    A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
752{
753    conn.send(BackendMessage::AuthenticationCleartextPassword)
754        .await?;
755    conn.flush().await?;
756
757    if let Some(message) = conn.recv().await? {
758        if let FrontendMessage::RawAuthentication(data) = message {
759            if let Some(FrontendMessage::Password { password }) =
760                decode_password(Cursor::new(&data)).ok()
761            {
762                return Ok(password);
763            }
764        }
765    }
766
767    Err(PasswordRequestError::InvalidPasswordError(
768        ErrorResponse::fatal(
769            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
770            "expected Password message",
771        ),
772    ))
773}
774
775/// Helper for password-based authentication using AdapterClient
776/// and returns an authenticated session.
777async fn authenticate_with_password<A>(
778    conn: &FramedConn<A>,
779    adapter_client: &mz_adapter::Client,
780    user: String,
781    password: Password,
782    conn_uuid: Uuid,
783    helm_chart_version: Option<String>,
784) -> Result<Session, PasswordRequestError>
785where
786    A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
787{
788    let authenticated = match adapter_client.authenticate(&user, &password).await {
789        Ok(authenticated) => authenticated,
790        Err(err) => {
791            warn!(?err, "pgwire connection failed authentication");
792            return Err(PasswordRequestError::InvalidPasswordError(
793                ErrorResponse::fatal(SqlState::INVALID_PASSWORD, "invalid password"),
794            ));
795        }
796    };
797
798    let session = adapter_client.new_session(
799        SessionConfig {
800            conn_id: conn.conn_id().clone(),
801            uuid: conn_uuid,
802            user,
803            client_ip: conn.peer_addr().clone(),
804            external_metadata_rx: None,
805            helm_chart_version,
806            authenticator_kind: mz_auth::AuthenticatorKind::Password,
807            groups: None,
808        },
809        authenticated,
810    );
811
812    Ok(session)
813}
814
/// Where the protocol state machine's `run` loop goes next.
#[derive(Debug)]
enum State {
    /// Receive and handle the next frontend message (`advance_ready`).
    Ready,
    /// Discard frontend messages until a `Sync` arrives (`advance_drain`).
    Drain,
    /// Stop: `run` returns.
    Done,
}
821
822struct StateMachine<'a, A, I>
823where
824    I: Iterator<Item = TaskMetrics> + Send + 'a,
825{
826    conn: &'a mut FramedConn<A>,
827    adapter_client: mz_adapter::SessionClient,
828    txn_needs_commit: bool,
829    tokio_metrics_intervals: I,
830}
831
/// Why sending rows to the client stopped.
enum SendRowsEndedReason {
    /// All rows were sent.
    Success {
        /// Size of the result that was sent (unit not visible here — presumably bytes).
        result_size: u64,
        /// Number of rows returned to the client.
        rows_returned: u64,
    },
    /// Sending stopped because of an error, described by `error`.
    Errored { error: String },
    /// Sending was canceled.
    Canceled,
}
842
/// Message used when a command is rejected because the current transaction is aborted.
const ABORTED_TXN_MSG: &str = concat!(
    "current transaction is aborted, ",
    "commands ignored until end of transaction block"
);
845
846impl<'a, A, I> StateMachine<'a, A, I>
847where
848    A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin + 'a,
849    I: Iterator<Item = TaskMetrics> + Send + 'a,
850{
851    // Manually desugar this (don't use `async fn run`) here because a much better
852    // error message is produced if there are problems with Send or other traits
853    // somewhere within the Future.
854    #[allow(clippy::manual_async_fn)]
855    #[mz_ore::instrument(level = "debug")]
856    fn run(mut self) -> impl Future<Output = Result<(), io::Error>> + Send + 'a {
857        async move {
858            let mut state = State::Ready;
859            loop {
860                self.send_pending_notices().await?;
861                state = match state {
862                    State::Ready => self.advance_ready().await?,
863                    State::Drain => self.advance_drain().await?,
864                    State::Done => return Ok(()),
865                };
866                self.adapter_client
867                    .add_idle_in_transaction_session_timeout();
868            }
869        }
870    }
871
    /// Receives and dispatches the next frontend message while in `State::Ready`.
    ///
    /// A pending session timeout takes priority over reading a message (`biased` select).
    /// When a timeout fires, it is converted into a fatal error and handled through
    /// `send_error_and_get_state`. The session is then terminated, and one more client
    /// message is awaited before returning the resulting state.
    ///
    /// Otherwise the received message is dispatched to its handler, and the returned
    /// [`State`] tells `run` what to do next. Copy, password, and SASL messages are not
    /// expected here and switch to `State::Drain`. `Terminate` or a `None` from `recv`
    /// yields `State::Done`.
    ///
    /// When a message was received, its processing time is recorded in
    /// `pgwire_message_processing_seconds`. The scheduling delay measured around `recv()` is
    /// recorded in `pgwire_recv_scheduling_delay_ms`. Both metrics are labeled by message name.
    #[instrument(level = "debug")]
    async fn advance_ready(&mut self) -> Result<State, io::Error> {
        // Start a new metrics interval before the `recv()` call.
        self.tokio_metrics_intervals
            .next()
            .expect("infinite iterator");

        // Handle timeouts first so we don't execute any statements when there's a pending timeout.
        let message = select! {
            biased;

            // `recv_timeout()` is cancel-safe as per its docs.
            Some(timeout) = self.adapter_client.recv_timeout() => {
                let err: AdapterError = timeout.into();
                let conn_id = self.adapter_client.session().conn_id();
                tracing::warn!("session timed out, conn_id {}", conn_id);

                // Process the error, doing any state cleanup.
                let error_response = err.into_response(Severity::Fatal);
                let error_state = self.send_error_and_get_state(error_response).await;

                // Terminate __after__ we do any cleanup.
                self.adapter_client.terminate().await;

                // We must wait for the client to send a request before we can send the error response.
                // Due to the PG wire protocol, we can't send an ErrorResponse unless it is in response
                // to a client message.
                let _ = self.conn.recv().await?;
                return error_state;
            },
            // `recv()` is cancel-safe as per its docs.
            message = self.conn.recv() => message?,
        };

        // Take the metrics since just before the `recv`.
        let interval = self
            .tokio_metrics_intervals
            .next()
            .expect("infinite iterator");
        let recv_scheduling_delay_ms = interval.total_scheduled_duration.as_secs_f64() * 1000.0;

        // TODO(ggevay): Consider subtracting the scheduling delay from `received`. It's not obvious
        // whether we should do this, because the result wouldn't exactly correspond to either first
        // byte received or last byte received (for msgs that arrive in more than one network packet).
        let received = SYSTEM_TIME();

        // Activity arrived, so drop the idle-in-transaction timeout; `run` re-arms it after
        // this step completes.
        self.adapter_client
            .remove_idle_in_transaction_session_timeout();

        // NOTE(guswynn): we could consider adding spans to all message types. Currently
        // only a few message types seem useful.
        let message_name = message.as_ref().map(|m| m.name()).unwrap_or_default();

        // Processing time is only measured (and recorded below) when a message was received.
        let start = message.as_ref().map(|_| Instant::now());
        let next_state = match message {
            Some(FrontendMessage::Query { sql }) => {
                let query_root_span =
                    tracing::info_span!(parent: None, "advance_ready", otel.name = message_name);
                query_root_span.follows_from(tracing::Span::current());
                self.query(sql, received)
                    .instrument(query_root_span)
                    .await?
            }
            Some(FrontendMessage::Parse {
                name,
                sql,
                param_types,
            }) => self.parse(name, sql, param_types).await?,
            Some(FrontendMessage::Bind {
                portal_name,
                statement_name,
                param_formats,
                raw_params,
                result_formats,
            }) => {
                self.bind(
                    portal_name,
                    statement_name,
                    param_formats,
                    raw_params,
                    result_formats,
                )
                .await?
            }
            Some(FrontendMessage::Execute {
                portal_name,
                max_rows,
            }) => {
                let max_rows = match usize::try_from(max_rows) {
                    Ok(0) | Err(_) => ExecuteCount::All, // If `max_rows < 0`, no limit.
                    Ok(n) => ExecuteCount::Count(n),
                };
                let execute_root_span =
                    tracing::info_span!(parent: None, "advance_ready", otel.name = message_name);
                execute_root_span.follows_from(tracing::Span::current());
                let state = self
                    .execute(
                        portal_name,
                        max_rows,
                        portal_exec_message,
                        None,
                        ExecuteTimeout::None,
                        None,
                        Some(received),
                    )
                    .instrument(execute_root_span)
                    .await?;
                // In PostgreSQL, when using the extended query protocol, some statements may
                // trigger an eager commit of the current implicit transaction,
                // see: <https://git.postgresql.org/gitweb/?p=postgresql.git&a=commitdiff&h=f92944137>.
                //
                // In Materialize, however, we eagerly commit every statement outside of an explicit
                // transaction when using the extended query protocol. This allows us to eliminate
                // the possibility of a multiple statement implicit transaction, which in turn
                // allows us to apply single-statement optimizations to queries issued in implicit
                // transactions in the extended query protocol.
                //
                // We don't immediately commit here to allow users to page through the portal if
                // necessary. Committing the transaction would destroy the portal before the next
                // Execute command has a chance to resume it. So we instead mark the transaction
                // for commit the next time that `ensure_transaction` is called.
                if self.adapter_client.session().transaction().is_implicit() {
                    self.txn_needs_commit = true;
                }
                state
            }
            Some(FrontendMessage::DescribeStatement { name }) => {
                self.describe_statement(&name).await?
            }
            Some(FrontendMessage::DescribePortal { name }) => self.describe_portal(&name).await?,
            Some(FrontendMessage::CloseStatement { name }) => self.close_statement(name).await?,
            Some(FrontendMessage::ClosePortal { name }) => self.close_portal(name).await?,
            Some(FrontendMessage::Flush) => self.flush().await?,
            Some(FrontendMessage::Sync) => self.sync().await?,
            Some(FrontendMessage::Terminate) => State::Done,

            // These messages are not expected in the Ready state; switch to draining input
            // until the next `Sync`.
            Some(FrontendMessage::CopyData(_))
            | Some(FrontendMessage::CopyDone)
            | Some(FrontendMessage::CopyFail(_))
            | Some(FrontendMessage::Password { .. })
            | Some(FrontendMessage::RawAuthentication(_))
            | Some(FrontendMessage::SASLInitialResponse { .. })
            | Some(FrontendMessage::SASLResponse(_)) => State::Drain,
            // `recv` produced no message (presumably the client closed the connection).
            None => State::Done,
        };

        if let Some(start) = start {
            self.adapter_client
                .inner()
                .metrics()
                .pgwire_message_processing_seconds
                .with_label_values(&[message_name])
                .observe(start.elapsed().as_secs_f64());
        }
        self.adapter_client
            .inner()
            .metrics()
            .pgwire_recv_scheduling_delay_ms
            .with_label_values(&[message_name])
            .observe(recv_scheduling_delay_ms);

        Ok(next_state)
    }
1035
1036    async fn advance_drain(&mut self) -> Result<State, io::Error> {
1037        let message = self.conn.recv().await?;
1038        if message.is_some() {
1039            self.adapter_client
1040                .remove_idle_in_transaction_session_timeout();
1041        }
1042        match message {
1043            Some(FrontendMessage::Sync) => self.sync().await,
1044            None => Ok(State::Done),
1045            _ => Ok(State::Drain),
1046        }
1047    }
1048
1049    /// Note that `lifecycle_timestamps` belongs to the whole "Simple Query", because the whole
1050    /// Simple Query is received and parsed together. This means that if there are multiple
1051    /// statements in a Simple Query, then all of them have the same `lifecycle_timestamps`.
1052    #[instrument(level = "debug")]
1053    async fn one_query(
1054        &mut self,
1055        stmt: Statement<Raw>,
1056        sql: String,
1057        lifecycle_timestamps: LifecycleTimestamps,
1058    ) -> Result<State, io::Error> {
1059        // Bind the portal. Note that this does not set the empty string prepared
1060        // statement.
1061        const EMPTY_PORTAL: &str = "";
1062        if let Err(e) = self
1063            .adapter_client
1064            .declare(EMPTY_PORTAL.to_string(), stmt, sql)
1065            .await
1066        {
1067            return self
1068                .send_error_and_get_state(e.into_response(Severity::Error))
1069                .await;
1070        }
1071        let portal = self
1072            .adapter_client
1073            .session()
1074            .get_portal_unverified_mut(EMPTY_PORTAL)
1075            .expect("unnamed portal should be present");
1076
1077        *portal.lifecycle_timestamps = Some(lifecycle_timestamps);
1078
1079        let stmt_desc = portal.desc.clone();
1080        if !stmt_desc.param_types.is_empty() {
1081            return self
1082                .send_error_and_get_state(ErrorResponse::error(
1083                    SqlState::UNDEFINED_PARAMETER,
1084                    "there is no parameter $1",
1085                ))
1086                .await;
1087        }
1088
1089        // Maybe send row description.
1090        if let Some(relation_desc) = &stmt_desc.relation_desc {
1091            if !stmt_desc.is_copy {
1092                let formats = vec![Format::Text; stmt_desc.arity()];
1093                self.send(BackendMessage::RowDescription(
1094                    message::encode_row_description(relation_desc, &formats),
1095                ))
1096                .await?;
1097            }
1098        }
1099
1100        let result = match self
1101            .adapter_client
1102            .execute(EMPTY_PORTAL.to_string(), self.conn.wait_closed(), None)
1103            .await
1104        {
1105            Ok((response, execute_started)) => {
1106                self.send_pending_notices().await?;
1107                self.send_execute_response(
1108                    response,
1109                    stmt_desc.relation_desc,
1110                    EMPTY_PORTAL.to_string(),
1111                    ExecuteCount::All,
1112                    portal_exec_message,
1113                    None,
1114                    ExecuteTimeout::None,
1115                    execute_started,
1116                )
1117                .await
1118            }
1119            Err(e) => {
1120                self.send_pending_notices().await?;
1121                self.send_error_and_get_state(e.into_response(Severity::Error))
1122                    .await
1123            }
1124        };
1125
1126        // Destroy the portal.
1127        self.adapter_client.session().remove_portal(EMPTY_PORTAL);
1128
1129        result
1130    }
1131
1132    async fn ensure_transaction(
1133        &mut self,
1134        num_stmts: usize,
1135        message_type: &str,
1136    ) -> Result<(), io::Error> {
1137        let start = Instant::now();
1138        if self.txn_needs_commit {
1139            self.commit_transaction().await?;
1140        }
1141        // start_transaction can't error (but assert that just in case it changes in
1142        // the future.
1143        let res = self.adapter_client.start_transaction(Some(num_stmts));
1144        assert_ok!(res);
1145        self.adapter_client
1146            .inner()
1147            .metrics()
1148            .pgwire_ensure_transaction_seconds
1149            .with_label_values(&[message_type])
1150            .observe(start.elapsed().as_secs_f64());
1151        Ok(())
1152    }
1153
1154    fn parse_sql<'b>(&self, sql: &'b str) -> Result<Vec<StatementParseResult<'b>>, ErrorResponse> {
1155        let parse_start = Instant::now();
1156        let result = match self.adapter_client.parse(sql) {
1157            Ok(result) => result.map_err(|e| {
1158                // Convert our 0-based byte position to pgwire's 1-based character
1159                // position.
1160                let pos = sql[..e.error.pos].chars().count() + 1;
1161                ErrorResponse::error(SqlState::SYNTAX_ERROR, e.error.message).with_position(pos)
1162            }),
1163            Err(msg) => Err(ErrorResponse::error(SqlState::PROGRAM_LIMIT_EXCEEDED, msg)),
1164        };
1165        self.adapter_client
1166            .inner()
1167            .metrics()
1168            .parse_seconds
1169            .observe(parse_start.elapsed().as_secs_f64());
1170        result
1171    }
1172
1173    /// Executes a "Simple Query", see
1174    /// <https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-SIMPLE-QUERY>
1175    ///
1176    /// For implicit transaction handling, see "Multiple Statements in a Simple Query" in the above.
1177    #[instrument(level = "debug")]
1178    async fn query(&mut self, sql: String, received: EpochMillis) -> Result<State, io::Error> {
1179        // Parse first before doing any transaction checking.
1180        let stmts = match self.parse_sql(&sql) {
1181            Ok(stmts) => stmts,
1182            Err(err) => {
1183                self.send_error_and_get_state(err).await?;
1184                return self.ready().await;
1185            }
1186        };
1187
1188        let num_stmts = stmts.len();
1189
1190        // Compare with postgres' backend/tcop/postgres.c exec_simple_query.
1191        for StatementParseResult { ast: stmt, sql } in stmts {
1192            // In an aborted transaction, reject all commands except COMMIT/ROLLBACK.
1193            if self.is_aborted_txn() && !is_txn_exit_stmt(Some(&stmt)) {
1194                self.aborted_txn_error().await?;
1195                break;
1196            }
1197
1198            // Start an implicit transaction if we aren't in any transaction and there's
1199            // more than one statement. This mirrors the `use_implicit_block` variable in
1200            // postgres.
1201            //
1202            // This needs to be done in the loop instead of once at the top because
1203            // a COMMIT/ROLLBACK statement needs to start a new transaction on next
1204            // statement.
1205            self.ensure_transaction(num_stmts, "query").await?;
1206
1207            match self
1208                .one_query(stmt, sql.to_string(), LifecycleTimestamps { received })
1209                .await?
1210            {
1211                State::Ready => (),
1212                State::Drain => break,
1213                State::Done => return Ok(State::Done),
1214            }
1215        }
1216
1217        // Implicit transactions are closed at the end of a Query message.
1218        {
1219            if self.adapter_client.session().transaction().is_implicit() {
1220                self.commit_transaction().await?;
1221            }
1222        }
1223
1224        if num_stmts == 0 {
1225            self.send(BackendMessage::EmptyQueryResponse).await?;
1226        }
1227
1228        self.ready().await
1229    }
1230
1231    #[instrument(level = "debug")]
1232    async fn parse(
1233        &mut self,
1234        name: String,
1235        sql: String,
1236        param_oids: Vec<u32>,
1237    ) -> Result<State, io::Error> {
1238        // Start a transaction if we aren't in one.
1239        self.ensure_transaction(1, "parse").await?;
1240
1241        let mut param_types = vec![];
1242        for oid in param_oids {
1243            match mz_pgrepr::Type::from_oid(oid) {
1244                Ok(ty) => match SqlScalarType::try_from(&ty) {
1245                    Ok(ty) => param_types.push(Some(ty)),
1246                    Err(err) => {
1247                        return self
1248                            .send_error_and_get_state(ErrorResponse::error(
1249                                SqlState::INVALID_PARAMETER_VALUE,
1250                                err.to_string(),
1251                            ))
1252                            .await;
1253                    }
1254                },
1255                Err(_) if oid == 0 => param_types.push(None),
1256                Err(e) => {
1257                    return self
1258                        .send_error_and_get_state(ErrorResponse::error(
1259                            SqlState::PROTOCOL_VIOLATION,
1260                            e.to_string(),
1261                        ))
1262                        .await;
1263                }
1264            }
1265        }
1266
1267        let stmts = match self.parse_sql(&sql) {
1268            Ok(stmts) => stmts,
1269            Err(err) => {
1270                return self.send_error_and_get_state(err).await;
1271            }
1272        };
1273        if stmts.len() > 1 {
1274            return self
1275                .send_error_and_get_state(ErrorResponse::error(
1276                    SqlState::INTERNAL_ERROR,
1277                    "cannot insert multiple commands into a prepared statement",
1278                ))
1279                .await;
1280        }
1281        let (maybe_stmt, sql) = match stmts.into_iter().next() {
1282            None => (None, ""),
1283            Some(StatementParseResult { ast, sql }) => (Some(ast), sql),
1284        };
1285        if self.is_aborted_txn() && !is_txn_exit_stmt(maybe_stmt.as_ref()) {
1286            return self.aborted_txn_error().await;
1287        }
1288        match self
1289            .adapter_client
1290            .prepare(name, maybe_stmt, sql.to_string(), param_types)
1291            .await
1292        {
1293            Ok(()) => {
1294                self.send(BackendMessage::ParseComplete).await?;
1295                Ok(State::Ready)
1296            }
1297            Err(e) => {
1298                self.send_error_and_get_state(e.into_response(Severity::Error))
1299                    .await
1300            }
1301        }
1302    }
1303
1304    /// Commits and clears the current transaction.
1305    #[instrument(level = "debug")]
1306    async fn commit_transaction(&mut self) -> Result<(), io::Error> {
1307        self.end_transaction(EndTransactionAction::Commit).await
1308    }
1309
1310    /// Rollback and clears the current transaction.
1311    #[instrument(level = "debug")]
1312    async fn rollback_transaction(&mut self) -> Result<(), io::Error> {
1313        self.end_transaction(EndTransactionAction::Rollback).await
1314    }
1315
1316    /// End a transaction and report to the user if an error occurred.
1317    #[instrument(level = "debug")]
1318    async fn end_transaction(&mut self, action: EndTransactionAction) -> Result<(), io::Error> {
1319        self.txn_needs_commit = false;
1320        let resp = self.adapter_client.end_transaction(action).await;
1321        if let Err(err) = resp {
1322            self.send(BackendMessage::ErrorResponse(
1323                err.into_response(Severity::Error),
1324            ))
1325            .await?;
1326        }
1327        Ok(())
1328    }
1329
1330    #[instrument(level = "debug")]
1331    async fn bind(
1332        &mut self,
1333        portal_name: String,
1334        statement_name: String,
1335        param_formats: Vec<Format>,
1336        raw_params: Vec<Option<Vec<u8>>>,
1337        result_formats: Vec<Format>,
1338    ) -> Result<State, io::Error> {
1339        // Start a transaction if we aren't in one.
1340        self.ensure_transaction(1, "bind").await?;
1341
1342        let aborted_txn = self.is_aborted_txn();
1343        let stmt = match self
1344            .adapter_client
1345            .get_prepared_statement(&statement_name)
1346            .await
1347        {
1348            Ok(stmt) => stmt,
1349            Err(err) => {
1350                return self
1351                    .send_error_and_get_state(err.into_response(Severity::Error))
1352                    .await;
1353            }
1354        };
1355
1356        let param_types = &stmt.desc().param_types;
1357        if param_types.len() != raw_params.len() {
1358            let message = format!(
1359                "bind message supplies {actual} parameters, \
1360                 but prepared statement \"{name}\" requires {expected}",
1361                name = statement_name,
1362                actual = raw_params.len(),
1363                expected = param_types.len()
1364            );
1365            return self
1366                .send_error_and_get_state(ErrorResponse::error(
1367                    SqlState::PROTOCOL_VIOLATION,
1368                    message,
1369                ))
1370                .await;
1371        }
1372        let param_formats = match pad_formats(param_formats, raw_params.len()) {
1373            Ok(param_formats) => param_formats,
1374            Err(msg) => {
1375                return self
1376                    .send_error_and_get_state(ErrorResponse::error(
1377                        SqlState::PROTOCOL_VIOLATION,
1378                        msg,
1379                    ))
1380                    .await;
1381            }
1382        };
1383        if aborted_txn && !is_txn_exit_stmt(stmt.stmt()) {
1384            return self.aborted_txn_error().await;
1385        }
1386        let buf = RowArena::new();
1387        let mut params = vec![];
1388        for ((raw_param, mz_typ), format) in raw_params
1389            .into_iter()
1390            .zip_eq(param_types)
1391            .zip_eq(param_formats)
1392        {
1393            let pg_typ = mz_pgrepr::Type::from(mz_typ);
1394            let datum = match raw_param {
1395                None => Datum::Null,
1396                Some(bytes) => match mz_pgrepr::Value::decode(format, &pg_typ, &bytes) {
1397                    Ok(param) => match param.into_datum_decode_error(&buf, &pg_typ, "parameter") {
1398                        Ok(datum) => datum,
1399                        Err(msg) => {
1400                            return self
1401                                .send_error_and_get_state(ErrorResponse::error(
1402                                    SqlState::INVALID_PARAMETER_VALUE,
1403                                    msg,
1404                                ))
1405                                .await;
1406                        }
1407                    },
1408                    Err(err) => {
1409                        let msg = format!("unable to decode parameter: {}", err);
1410                        return self
1411                            .send_error_and_get_state(ErrorResponse::error(
1412                                SqlState::INVALID_PARAMETER_VALUE,
1413                                msg,
1414                            ))
1415                            .await;
1416                    }
1417                },
1418            };
1419            params.push((datum, mz_typ.clone()))
1420        }
1421
1422        let result_formats = match pad_formats(
1423            result_formats,
1424            stmt.desc()
1425                .relation_desc
1426                .clone()
1427                .map(|desc| desc.typ().column_types.len())
1428                .unwrap_or(0),
1429        ) {
1430            Ok(result_formats) => result_formats,
1431            Err(msg) => {
1432                return self
1433                    .send_error_and_get_state(ErrorResponse::error(
1434                        SqlState::PROTOCOL_VIOLATION,
1435                        msg,
1436                    ))
1437                    .await;
1438            }
1439        };
1440
1441        // Binary encodings are disabled for list, map, and aclitem types, but this doesn't
1442        // apply to COPY TO statements.
1443        if !stmt.stmt().map_or(false, |stmt| match stmt {
1444            Statement::Copy(CopyStatement {
1445                direction: CopyDirection::To,
1446                ..
1447            }) => true,
1448            Statement::Copy(CopyStatement {
1449                direction: CopyDirection::From,
1450                // To be conservative, we are restricting COPY FROM to only allow list/map/aclitem types if it is not
1451                // copying from STDIN. It is likely that this works in theory, but is risky and likely to OOM anyways
1452                // as all the data will be held in a buffer in memory before being processed.
1453                target: CopyTarget::Expr(_),
1454                ..
1455            }) => true,
1456            _ => false,
1457        }) {
1458            if let Some(desc) = stmt.desc().relation_desc.clone() {
1459                for (format, ty) in result_formats.iter().zip_eq(desc.iter_types()) {
1460                    if let Format::Binary = format {
1461                        if let Err(msg) = mz_pgrepr::Value::binary_encoding_error(&ty.scalar_type) {
1462                            return self
1463                                .send_error_and_get_state(ErrorResponse::error(
1464                                    SqlState::PROTOCOL_VIOLATION,
1465                                    msg,
1466                                ))
1467                                .await;
1468                        }
1469                    }
1470                }
1471            }
1472        }
1473
1474        let desc = stmt.desc().clone();
1475        let logging = Arc::clone(stmt.logging());
1476        let stmt_ast = stmt.stmt().cloned();
1477        let state_revision = stmt.state_revision;
1478        if let Err(err) = self.adapter_client.session().set_portal(
1479            portal_name,
1480            desc,
1481            stmt_ast,
1482            logging,
1483            params,
1484            result_formats,
1485            state_revision,
1486        ) {
1487            return self
1488                .send_error_and_get_state(err.into_response(Severity::Error))
1489                .await;
1490        }
1491
1492        self.send(BackendMessage::BindComplete).await?;
1493        Ok(State::Ready)
1494    }
1495
    /// Executes, or resumes, the portal named `portal_name` and sends its
    /// results to the client.
    ///
    /// What happens depends on the portal's current `PortalState`:
    /// - `NotStarted`: a transaction is started if needed, the adapter executes
    ///   the portal, and the response is handed to `send_execute_response`.
    /// - `InProgress`: the rows left over from an earlier execution are
    ///   streamed via `send_rows`.
    /// - `Completed(Some(tag))`: the stored command tag is sent again (see the
    ///   FETCH comment on that arm).
    /// - `Completed(None)`: an `OBJECT_NOT_IN_PREREQUISITE_STATE` error is sent.
    ///
    /// A missing portal yields an `INVALID_CURSOR_NAME` error. While the
    /// transaction is aborted, only transaction-exit statements (COMMIT/ROLLBACK)
    /// are allowed to run.
    ///
    /// `outer_ctx_extra` is Some when we are executing as part of an outer statement, e.g., a FETCH
    /// triggering the execution of the underlying query. On the error, resumed-rows
    /// and completed-portal paths it is retired here with the matching outcome.
    /// On the `NotStarted` path it is passed on to the adapter's `execute`
    /// instead (if `ensure_transaction` fails first, it is dropped without being
    /// retired here).
    ///
    /// Other parameters:
    /// - `max_rows`, `get_response`, `fetch_portal_name`, `timeout`: forwarded
    ///   to `send_rows` / `send_execute_response`. `fetch_portal_name` is
    ///   `Some` when running on behalf of a FETCH (see `fetch`).
    /// - `received`: overwrites the portal's lifecycle timestamps; `None`
    ///   clears them (`fetch` passes `None`).
    ///
    /// The future is boxed because this function is indirectly recursive
    /// (`execute` -> `send_execute_response` -> `fetch` -> `execute`), and the
    /// future of a recursive async function must be boxed to have a finite type.
    fn execute(
        &mut self,
        portal_name: String,
        max_rows: ExecuteCount,
        get_response: GetResponse,
        fetch_portal_name: Option<String>,
        timeout: ExecuteTimeout,
        outer_ctx_extra: Option<ExecuteContextGuard>,
        received: Option<EpochMillis>,
    ) -> BoxFuture<'_, Result<State, io::Error>> {
        async move {
            // Computed up front, before `portal` below borrows the session.
            let aborted_txn = self.is_aborted_txn();

            // Check if the portal has been started and can be continued.
            let portal = match self
                .adapter_client
                .session()
                .get_portal_unverified_mut(&portal_name)
            {
                Some(portal) => portal,
                None => {
                    let msg = format!("portal {} does not exist", portal_name.quoted());
                    if let Some(outer_ctx_extra) = outer_ctx_extra {
                        self.adapter_client.retire_execute(
                            outer_ctx_extra,
                            StatementEndedExecutionReason::Errored { error: msg.clone() },
                        );
                    }
                    return self
                        .send_error_and_get_state(ErrorResponse::error(
                            SqlState::INVALID_CURSOR_NAME,
                            msg,
                        ))
                        .await;
                }
            };

            // Reset the portal's lifecycle timestamps for this execution;
            // `None` clears them.
            *portal.lifecycle_timestamps = received.map(LifecycleTimestamps::new);

            // In an aborted transaction, reject all commands except COMMIT/ROLLBACK.
            let txn_exit_stmt = is_txn_exit_stmt(portal.stmt.as_deref());
            if aborted_txn && !txn_exit_stmt {
                if let Some(outer_ctx_extra) = outer_ctx_extra {
                    self.adapter_client.retire_execute(
                        outer_ctx_extra,
                        StatementEndedExecutionReason::Errored {
                            error: ABORTED_TXN_MSG.to_string(),
                        },
                    );
                }
                return self.aborted_txn_error().await;
            }

            // Cloned so the description stays usable after the arms below
            // call `&mut self` methods.
            let row_desc = portal.desc.relation_desc.clone();
            match portal.state {
                PortalState::NotStarted => {
                    // Start a transaction if we aren't in one.
                    self.ensure_transaction(1, "execute").await?;
                    // NOTE(review): `wait_closed()` presumably resolves when the
                    // client connection closes, so the adapter can stop early —
                    // confirm.
                    match self
                        .adapter_client
                        .execute(
                            portal_name.clone(),
                            self.conn.wait_closed(),
                            outer_ctx_extra,
                        )
                        .await
                    {
                        Ok((response, execute_started)) => {
                            self.send_pending_notices().await?;
                            self.send_execute_response(
                                response,
                                row_desc,
                                portal_name,
                                max_rows,
                                get_response,
                                fetch_portal_name,
                                timeout,
                                execute_started,
                            )
                            .await
                        }
                        Err(e) => {
                            self.send_pending_notices().await?;
                            self.send_error_and_get_state(e.into_response(Severity::Error))
                                .await
                        }
                    }
                }
                PortalState::InProgress(rows) => {
                    // Take ownership of the pending rows; the portal's slot is
                    // left as `None`.
                    let rows = rows.take().expect("InProgress rows must be populated");
                    let (result, statement_ended_execution_reason) = match self
                        .send_rows(
                            row_desc.expect("portal missing row desc on resumption"),
                            portal_name,
                            rows,
                            max_rows,
                            get_response,
                            fetch_portal_name,
                            timeout,
                        )
                        .await
                    {
                        Err(e) => {
                            // This is an error communicating with the connection.
                            // We consider that to be a cancelation, rather than a query error.
                            (Err(e), StatementEndedExecutionReason::Canceled)
                        }
                        Ok((ok, SendRowsEndedReason::Canceled)) => {
                            (Ok(ok), StatementEndedExecutionReason::Canceled)
                        }
                        // NOTE: For now the values for `result_size` and
                        // `rows_returned` in fetches are a bit confusing.
                        // We record `Some(n)` for the first fetch, where `n` is
                        // the number of bytes/rows returned by the inner
                        // execute (regardless of how many rows the
                        // fetch fetched), and `None` for subsequent fetches.
                        //
                        // This arguably makes sense since the size/rows
                        // returned measures how much work the compute
                        // layer had to do to satisfy the query, but
                        // we should revisit it if/when we start
                        // logging the inner execute separately.
                        Ok((
                            ok,
                            SendRowsEndedReason::Success {
                                result_size: _,
                                rows_returned: _,
                            },
                        )) => (
                            Ok(ok),
                            StatementEndedExecutionReason::Success {
                                result_size: None,
                                rows_returned: None,
                                execution_strategy: None,
                            },
                        ),
                        Ok((ok, SendRowsEndedReason::Errored { error })) => {
                            (Ok(ok), StatementEndedExecutionReason::Errored { error })
                        }
                    };
                    if let Some(outer_ctx_extra) = outer_ctx_extra {
                        self.adapter_client
                            .retire_execute(outer_ctx_extra, statement_ended_execution_reason);
                    }
                    result
                }
                // FETCH is an awkward command for our current architecture. In Postgres it
                // will extract <count> rows from the target portal, cache them, and return
                // them to the user as requested. Its command tag is always FETCH <num rows
                // extracted>. In Materialize, since we have chosen to not fully support FETCH,
                // we must remember the number of rows that were returned. Use this tag to
                // remember that information and return it.
                PortalState::Completed(Some(tag)) => {
                    let tag = tag.to_string();
                    if let Some(outer_ctx_extra) = outer_ctx_extra {
                        self.adapter_client.retire_execute(
                            outer_ctx_extra,
                            StatementEndedExecutionReason::Success {
                                result_size: None,
                                rows_returned: None,
                                execution_strategy: None,
                            },
                        );
                    }
                    self.send(BackendMessage::CommandComplete { tag }).await?;
                    Ok(State::Ready)
                }
                // Completed without a stored tag (e.g. via `complete_portal`):
                // there is nothing left to run or replay.
                PortalState::Completed(None) => {
                    let error = format!(
                        "portal {} cannot be run",
                        Ident::new_unchecked(portal_name).to_ast_string_stable()
                    );
                    if let Some(outer_ctx_extra) = outer_ctx_extra {
                        self.adapter_client.retire_execute(
                            outer_ctx_extra,
                            StatementEndedExecutionReason::Errored {
                                error: error.clone(),
                            },
                        );
                    }
                    self.send_error_and_get_state(ErrorResponse::error(
                        SqlState::OBJECT_NOT_IN_PREREQUISITE_STATE,
                        error,
                    ))
                    .await
                }
            }
        }
        .instrument(debug_span!("execute"))
        .boxed()
    }
1689
1690    #[instrument(level = "debug")]
1691    async fn describe_statement(&mut self, name: &str) -> Result<State, io::Error> {
1692        // Start a transaction if we aren't in one.
1693        self.ensure_transaction(1, "describe_statement").await?;
1694
1695        let stmt = match self.adapter_client.get_prepared_statement(name).await {
1696            Ok(stmt) => stmt,
1697            Err(err) => {
1698                return self
1699                    .send_error_and_get_state(err.into_response(Severity::Error))
1700                    .await;
1701            }
1702        };
1703        // Cloning to avoid a mutable borrow issue because `send` also uses `adapter_client`
1704        let parameter_desc = BackendMessage::ParameterDescription(
1705            stmt.desc()
1706                .param_types
1707                .iter()
1708                .map(mz_pgrepr::Type::from)
1709                .collect(),
1710        );
1711        // Claim that all results will be output in text format, even
1712        // though the true result formats are not yet known. A bit
1713        // weird, but this is the behavior that PostgreSQL specifies.
1714        let formats = vec![Format::Text; stmt.desc().arity()];
1715        let row_desc = describe_rows(stmt.desc(), &formats);
1716        self.send_all([parameter_desc, row_desc]).await?;
1717        Ok(State::Ready)
1718    }
1719
1720    #[instrument(level = "debug")]
1721    async fn describe_portal(&mut self, name: &str) -> Result<State, io::Error> {
1722        // Start a transaction if we aren't in one.
1723        self.ensure_transaction(1, "describe_portal").await?;
1724
1725        let session = self.adapter_client.session();
1726        let row_desc = session
1727            .get_portal_unverified(name)
1728            .map(|portal| describe_rows(&portal.desc, &portal.result_formats));
1729        match row_desc {
1730            Some(row_desc) => {
1731                self.send(row_desc).await?;
1732                Ok(State::Ready)
1733            }
1734            None => {
1735                self.send_error_and_get_state(ErrorResponse::error(
1736                    SqlState::INVALID_CURSOR_NAME,
1737                    format!("portal {} does not exist", name.quoted()),
1738                ))
1739                .await
1740            }
1741        }
1742    }
1743
1744    #[instrument(level = "debug")]
1745    async fn close_statement(&mut self, name: String) -> Result<State, io::Error> {
1746        self.adapter_client
1747            .session()
1748            .remove_prepared_statement(&name);
1749        self.send(BackendMessage::CloseComplete).await?;
1750        Ok(State::Ready)
1751    }
1752
1753    #[instrument(level = "debug")]
1754    async fn close_portal(&mut self, name: String) -> Result<State, io::Error> {
1755        self.adapter_client.session().remove_portal(&name);
1756        self.send(BackendMessage::CloseComplete).await?;
1757        Ok(State::Ready)
1758    }
1759
1760    fn complete_portal(&mut self, name: &str) {
1761        let portal = self
1762            .adapter_client
1763            .session()
1764            .get_portal_unverified_mut(name)
1765            .expect("portal should exist");
1766        *portal.state = PortalState::Completed(None);
1767    }
1768
1769    async fn fetch(
1770        &mut self,
1771        name: String,
1772        count: Option<FetchDirection>,
1773        max_rows: ExecuteCount,
1774        fetch_portal_name: Option<String>,
1775        timeout: ExecuteTimeout,
1776        ctx_extra: ExecuteContextGuard,
1777    ) -> Result<State, io::Error> {
1778        // Unlike Execute, no count specified in FETCH returns 1 row, and 0 means 0
1779        // instead of All.
1780        let count = count.unwrap_or(FetchDirection::ForwardCount(1));
1781
1782        // Figure out how many rows we should send back by looking at the various
1783        // combinations of the execute and fetch.
1784        //
1785        // In Postgres, Fetch will cache <count> rows from the target portal and
1786        // return those as requested (if, say, an Execute message was sent with a
1787        // max_rows < the Fetch's count). We expect that case to be incredibly rare and
1788        // so have chosen to not support it until users request it. This eases
1789        // implementation difficulty since we don't have to be able to "send" rows to
1790        // a buffer.
1791        //
1792        // TODO(mjibson): Test this somehow? Need to divide up the pgtest files in
1793        // order to have some that are not Postgres compatible.
1794        let count = match (max_rows, count) {
1795            (ExecuteCount::Count(max_rows), FetchDirection::ForwardCount(count)) => {
1796                let count = usize::cast_from(count);
1797                if max_rows < count {
1798                    let msg = "Execute with max_rows < a FETCH's count is not supported";
1799                    self.adapter_client.retire_execute(
1800                        ctx_extra,
1801                        StatementEndedExecutionReason::Errored {
1802                            error: msg.to_string(),
1803                        },
1804                    );
1805                    return self
1806                        .send_error_and_get_state(ErrorResponse::error(
1807                            SqlState::FEATURE_NOT_SUPPORTED,
1808                            msg,
1809                        ))
1810                        .await;
1811                }
1812                ExecuteCount::Count(count)
1813            }
1814            (ExecuteCount::Count(_), FetchDirection::ForwardAll) => {
1815                let msg = "Execute with max_rows of a FETCH ALL is not supported";
1816                self.adapter_client.retire_execute(
1817                    ctx_extra,
1818                    StatementEndedExecutionReason::Errored {
1819                        error: msg.to_string(),
1820                    },
1821                );
1822                return self
1823                    .send_error_and_get_state(ErrorResponse::error(
1824                        SqlState::FEATURE_NOT_SUPPORTED,
1825                        msg,
1826                    ))
1827                    .await;
1828            }
1829            (ExecuteCount::All, FetchDirection::ForwardAll) => ExecuteCount::All,
1830            (ExecuteCount::All, FetchDirection::ForwardCount(count)) => {
1831                ExecuteCount::Count(usize::cast_from(count))
1832            }
1833        };
1834        let cursor_name = name.to_string();
1835        self.execute(
1836            cursor_name,
1837            count,
1838            fetch_message,
1839            fetch_portal_name,
1840            timeout,
1841            Some(ctx_extra),
1842            None,
1843        )
1844        .await
1845    }
1846
1847    async fn flush(&mut self) -> Result<State, io::Error> {
1848        self.conn.flush().await?;
1849        Ok(State::Ready)
1850    }
1851
1852    /// Sends a backend message to the client, after applying a severity filter.
1853    ///
1854    /// The message is only sent if its severity is above the severity set
1855    /// in the session, with the default value being NOTICE.
1856    #[instrument(level = "debug")]
1857    async fn send<M>(&mut self, message: M) -> Result<(), io::Error>
1858    where
1859        M: Into<BackendMessage>,
1860    {
1861        let message: BackendMessage = message.into();
1862        let is_error =
1863            matches!(&message, BackendMessage::ErrorResponse(e) if e.severity.is_error());
1864
1865        self.conn.send(message).await?;
1866
1867        // Flush immediately after sending an error response, as some clients
1868        // expect to be able to read the error response before sending a Sync
1869        // message. This is arguably in violation of the protocol specification,
1870        // but the specification is somewhat ambiguous, and easier to match
1871        // PostgreSQL here than to fix all the clients that have this
1872        // expectation.
1873        if is_error {
1874            self.conn.flush().await?;
1875        }
1876
1877        Ok(())
1878    }
1879
1880    #[instrument(level = "debug")]
1881    pub async fn send_all(
1882        &mut self,
1883        messages: impl IntoIterator<Item = BackendMessage>,
1884    ) -> Result<(), io::Error> {
1885        for m in messages {
1886            self.send(m).await?;
1887        }
1888        Ok(())
1889    }
1890
1891    #[instrument(level = "debug")]
1892    async fn sync(&mut self) -> Result<State, io::Error> {
1893        // Close the current transaction if we are in an implicit transaction.
1894        if self.adapter_client.session().transaction().is_implicit() {
1895            self.commit_transaction().await?;
1896        }
1897        self.ready().await
1898    }
1899
1900    #[instrument(level = "debug")]
1901    async fn ready(&mut self) -> Result<State, io::Error> {
1902        let txn_state = self.adapter_client.session().transaction().into();
1903        self.send(BackendMessage::ReadyForQuery(txn_state)).await?;
1904        self.flush().await
1905    }
1906
1907    #[allow(clippy::too_many_arguments)]
1908    #[instrument(level = "debug")]
1909    async fn send_execute_response(
1910        &mut self,
1911        response: ExecuteResponse,
1912        row_desc: Option<RelationDesc>,
1913        portal_name: String,
1914        max_rows: ExecuteCount,
1915        get_response: GetResponse,
1916        fetch_portal_name: Option<String>,
1917        timeout: ExecuteTimeout,
1918        execute_started: Instant,
1919    ) -> Result<State, io::Error> {
1920        let mut tag = response.tag();
1921
1922        macro_rules! command_complete {
1923            () => {{
1924                self.send(BackendMessage::CommandComplete {
1925                    tag: tag
1926                        .take()
1927                        .expect("command_complete only called on tag-generating results"),
1928                })
1929                .await?;
1930                Ok(State::Ready)
1931            }};
1932        }
1933
1934        let r = match response {
1935            ExecuteResponse::ClosedCursor => {
1936                self.complete_portal(&portal_name);
1937                command_complete!()
1938            }
1939            ExecuteResponse::DeclaredCursor => {
1940                self.complete_portal(&portal_name);
1941                command_complete!()
1942            }
1943            ExecuteResponse::EmptyQuery => {
1944                self.send(BackendMessage::EmptyQueryResponse).await?;
1945                Ok(State::Ready)
1946            }
1947            ExecuteResponse::Fetch {
1948                name,
1949                count,
1950                timeout,
1951                ctx_extra,
1952            } => {
1953                self.fetch(
1954                    name,
1955                    count,
1956                    max_rows,
1957                    Some(portal_name.to_string()),
1958                    timeout,
1959                    ctx_extra,
1960                )
1961                .await
1962            }
1963            ExecuteResponse::SendingRowsStreaming {
1964                rows,
1965                instance_id,
1966                strategy,
1967            } => {
1968                let row_desc = row_desc
1969                    .expect("missing row description for ExecuteResponse::SendingRowsStreaming");
1970
1971                let span = tracing::debug_span!("sending_rows_streaming");
1972
1973                self.send_rows(
1974                    row_desc,
1975                    portal_name,
1976                    InProgressRows::new(RecordFirstRowStream::new(
1977                        Box::new(rows),
1978                        execute_started,
1979                        &self.adapter_client,
1980                        Some(instance_id),
1981                        Some(strategy),
1982                    )),
1983                    max_rows,
1984                    get_response,
1985                    fetch_portal_name,
1986                    timeout,
1987                )
1988                .instrument(span)
1989                .await
1990                .map(|(state, _)| state)
1991            }
1992            ExecuteResponse::SendingRowsImmediate { rows } => {
1993                let row_desc = row_desc
1994                    .expect("missing row description for ExecuteResponse::SendingRowsImmediate");
1995
1996                let span = tracing::debug_span!("sending_rows_immediate");
1997
1998                let stream =
1999                    futures::stream::once(futures::future::ready(PeekResponseUnary::Rows(rows)));
2000                self.send_rows(
2001                    row_desc,
2002                    portal_name,
2003                    InProgressRows::new(RecordFirstRowStream::new(
2004                        Box::new(stream),
2005                        execute_started,
2006                        &self.adapter_client,
2007                        None,
2008                        Some(StatementExecutionStrategy::Constant),
2009                    )),
2010                    max_rows,
2011                    get_response,
2012                    fetch_portal_name,
2013                    timeout,
2014                )
2015                .instrument(span)
2016                .await
2017                .map(|(state, _)| state)
2018            }
2019            ExecuteResponse::SetVariable { name, .. } => {
2020                // This code is somewhat awkwardly structured because we
2021                // can't hold `var` across an await point.
2022                let qn = name.to_string();
2023                let msg = if let Some(var) = self
2024                    .adapter_client
2025                    .session()
2026                    .vars_mut()
2027                    .notify_set()
2028                    .find(|v| v.name() == qn)
2029                {
2030                    Some(BackendMessage::ParameterStatus(var.name(), var.value()))
2031                } else {
2032                    None
2033                };
2034                if let Some(msg) = msg {
2035                    self.send(msg).await?;
2036                }
2037                command_complete!()
2038            }
2039            ExecuteResponse::Subscribing {
2040                rx,
2041                ctx_extra,
2042                instance_id,
2043            } => {
2044                if fetch_portal_name.is_none() {
2045                    let mut msg = ErrorResponse::notice(
2046                        SqlState::WARNING,
2047                        "streaming SUBSCRIBE rows directly requires a client that does not buffer output",
2048                    );
2049                    if self.adapter_client.session().vars().application_name() == "psql" {
2050                        msg.hint = Some(
2051                            "Wrap your SUBSCRIBE statement in `COPY (SUBSCRIBE ...) TO STDOUT`."
2052                                .into(),
2053                        )
2054                    }
2055                    self.send(msg).await?;
2056                    self.conn.flush().await?;
2057                }
2058                let row_desc =
2059                    row_desc.expect("missing row description for ExecuteResponse::Subscribing");
2060                let (result, statement_ended_execution_reason) = match self
2061                    .send_rows(
2062                        row_desc,
2063                        portal_name,
2064                        InProgressRows::new(RecordFirstRowStream::new(
2065                            Box::new(UnboundedReceiverStream::new(rx)),
2066                            execute_started,
2067                            &self.adapter_client,
2068                            Some(instance_id),
2069                            None,
2070                        )),
2071                        max_rows,
2072                        get_response,
2073                        fetch_portal_name,
2074                        timeout,
2075                    )
2076                    .await
2077                {
2078                    Err(e) => {
2079                        // This is an error communicating with the connection.
2080                        // We consider that to be a cancelation, rather than a query error.
2081                        (Err(e), StatementEndedExecutionReason::Canceled)
2082                    }
2083                    Ok((ok, SendRowsEndedReason::Canceled)) => {
2084                        (Ok(ok), StatementEndedExecutionReason::Canceled)
2085                    }
2086                    Ok((
2087                        ok,
2088                        SendRowsEndedReason::Success {
2089                            result_size,
2090                            rows_returned,
2091                        },
2092                    )) => (
2093                        Ok(ok),
2094                        StatementEndedExecutionReason::Success {
2095                            result_size: Some(result_size),
2096                            rows_returned: Some(rows_returned),
2097                            execution_strategy: None,
2098                        },
2099                    ),
2100                    Ok((ok, SendRowsEndedReason::Errored { error })) => {
2101                        (Ok(ok), StatementEndedExecutionReason::Errored { error })
2102                    }
2103                };
2104                self.adapter_client
2105                    .retire_execute(ctx_extra, statement_ended_execution_reason);
2106                return result;
2107            }
2108            ExecuteResponse::CopyTo { format, resp } => {
2109                let row_desc =
2110                    row_desc.expect("missing row description for ExecuteResponse::CopyTo");
2111                match *resp {
2112                    ExecuteResponse::Subscribing {
2113                        rx,
2114                        ctx_extra,
2115                        instance_id,
2116                    } => {
2117                        let (result, statement_ended_execution_reason) = match self
2118                            .copy_rows(
2119                                format,
2120                                row_desc,
2121                                RecordFirstRowStream::new(
2122                                    Box::new(UnboundedReceiverStream::new(rx)),
2123                                    execute_started,
2124                                    &self.adapter_client,
2125                                    Some(instance_id),
2126                                    None,
2127                                ),
2128                            )
2129                            .await
2130                        {
2131                            Err(e) => {
2132                                // This is an error communicating with the connection.
2133                                // We consider that to be a cancelation, rather than a query error.
2134                                (Err(e), StatementEndedExecutionReason::Canceled)
2135                            }
2136                            Ok((
2137                                state,
2138                                SendRowsEndedReason::Success {
2139                                    result_size,
2140                                    rows_returned,
2141                                },
2142                            )) => (
2143                                Ok(state),
2144                                StatementEndedExecutionReason::Success {
2145                                    result_size: Some(result_size),
2146                                    rows_returned: Some(rows_returned),
2147                                    execution_strategy: None,
2148                                },
2149                            ),
2150                            Ok((state, SendRowsEndedReason::Errored { error })) => {
2151                                (Ok(state), StatementEndedExecutionReason::Errored { error })
2152                            }
2153                            Ok((state, SendRowsEndedReason::Canceled)) => {
2154                                (Ok(state), StatementEndedExecutionReason::Canceled)
2155                            }
2156                        };
2157                        self.adapter_client
2158                            .retire_execute(ctx_extra, statement_ended_execution_reason);
2159                        return result;
2160                    }
2161                    ExecuteResponse::SendingRowsStreaming {
2162                        rows,
2163                        instance_id,
2164                        strategy,
2165                    } => {
2166                        // We don't need to finalize execution here;
2167                        // it was already done in the
2168                        // coordinator. Just extract the state and
2169                        // return that.
2170                        return self
2171                            .copy_rows(
2172                                format,
2173                                row_desc,
2174                                RecordFirstRowStream::new(
2175                                    Box::new(rows),
2176                                    execute_started,
2177                                    &self.adapter_client,
2178                                    Some(instance_id),
2179                                    Some(strategy),
2180                                ),
2181                            )
2182                            .await
2183                            .map(|(state, _)| state);
2184                    }
2185                    ExecuteResponse::SendingRowsImmediate { rows } => {
2186                        let span = tracing::debug_span!("sending_rows_immediate");
2187
2188                        let rows = futures::stream::once(futures::future::ready(
2189                            PeekResponseUnary::Rows(rows),
2190                        ));
2191                        // We don't need to finalize execution here;
2192                        // it was already done in the
2193                        // coordinator. Just extract the state and
2194                        // return that.
2195                        return self
2196                            .copy_rows(
2197                                format,
2198                                row_desc,
2199                                RecordFirstRowStream::new(
2200                                    Box::new(rows),
2201                                    execute_started,
2202                                    &self.adapter_client,
2203                                    None,
2204                                    Some(StatementExecutionStrategy::Constant),
2205                                ),
2206                            )
2207                            .instrument(span)
2208                            .await
2209                            .map(|(state, _)| state);
2210                    }
2211                    _ => {
2212                        return self
2213                            .send_error_and_get_state(ErrorResponse::error(
2214                                SqlState::INTERNAL_ERROR,
2215                                "unsupported COPY response type".to_string(),
2216                            ))
2217                            .await;
2218                    }
2219                };
2220            }
2221            ExecuteResponse::CopyFrom {
2222                target_id,
2223                target_name,
2224                columns,
2225                params,
2226                ctx_extra,
2227            } => {
2228                let row_desc =
2229                    row_desc.expect("missing row description for ExecuteResponse::CopyFrom");
2230                self.copy_from(target_id, target_name, columns, params, row_desc, ctx_extra)
2231                    .await
2232            }
2233            ExecuteResponse::TransactionCommitted { params }
2234            | ExecuteResponse::TransactionRolledBack { params } => {
2235                let notify_set: mz_ore::collections::HashSet<String> = self
2236                    .adapter_client
2237                    .session()
2238                    .vars()
2239                    .notify_set()
2240                    .map(|v| v.name().to_string())
2241                    .collect();
2242
2243                // Only report on parameters that are in the notify set.
2244                for (name, value) in params
2245                    .into_iter()
2246                    .filter(|(name, _v)| notify_set.contains(*name))
2247                {
2248                    let msg = BackendMessage::ParameterStatus(name, value);
2249                    self.send(msg).await?;
2250                }
2251                command_complete!()
2252            }
2253
2254            ExecuteResponse::AlteredDefaultPrivileges
2255            | ExecuteResponse::AlteredObject(..)
2256            | ExecuteResponse::AlteredRole
2257            | ExecuteResponse::AlteredSystemConfiguration
2258            | ExecuteResponse::CreatedCluster { .. }
2259            | ExecuteResponse::CreatedClusterReplica { .. }
2260            | ExecuteResponse::CreatedConnection { .. }
2261            | ExecuteResponse::CreatedDatabase { .. }
2262            | ExecuteResponse::CreatedIndex { .. }
2263            | ExecuteResponse::CreatedIntrospectionSubscribe
2264            | ExecuteResponse::CreatedMaterializedView { .. }
2265            | ExecuteResponse::CreatedRole
2266            | ExecuteResponse::CreatedSchema { .. }
2267            | ExecuteResponse::CreatedSecret { .. }
2268            | ExecuteResponse::CreatedSink { .. }
2269            | ExecuteResponse::CreatedSource { .. }
2270            | ExecuteResponse::CreatedTable { .. }
2271            | ExecuteResponse::CreatedType
2272            | ExecuteResponse::CreatedView { .. }
2273            | ExecuteResponse::CreatedViews { .. }
2274            | ExecuteResponse::CreatedNetworkPolicy
2275            | ExecuteResponse::Comment
2276            | ExecuteResponse::Deallocate { .. }
2277            | ExecuteResponse::Deleted(..)
2278            | ExecuteResponse::DiscardedAll
2279            | ExecuteResponse::DiscardedTemp
2280            | ExecuteResponse::DroppedObject(_)
2281            | ExecuteResponse::DroppedOwned
2282            | ExecuteResponse::GrantedPrivilege
2283            | ExecuteResponse::GrantedRole
2284            | ExecuteResponse::Inserted(..)
2285            | ExecuteResponse::Copied(..)
2286            | ExecuteResponse::Prepare
2287            | ExecuteResponse::Raised
2288            | ExecuteResponse::ReassignOwned
2289            | ExecuteResponse::RevokedPrivilege
2290            | ExecuteResponse::RevokedRole
2291            | ExecuteResponse::StartedTransaction { .. }
2292            | ExecuteResponse::Updated(..)
2293            | ExecuteResponse::ValidatedConnection => {
2294                command_complete!()
2295            }
2296        };
2297
2298        assert_none!(tag, "tag created but not consumed: {:?}", tag);
2299        r
2300    }
2301
    /// Sends rows from an in-progress portal to the client as `DataRow`
    /// messages, followed by the message built by `get_response`.
    ///
    /// Arguments:
    /// * `row_desc`: the expected shape of the rows; every batch is checked
    ///   against it with `verify_datum_desc` and encoded according to it.
    /// * `portal_name`: the portal being executed. On the normal exit path,
    ///   `rows` is stored back into this portal as `PortalState::InProgress`
    ///   so that an emptied portal is not re-executed.
    /// * `rows`: the row stream. A batch already held in `rows.current` is sent
    ///   before any new batch is received.
    /// * `max_rows`: the most rows to send in this call (`All` means no limit).
    /// * `get_response`: builds the final message from `max_rows`, the number
    ///   of rows sent, and the FETCH portal, if any.
    /// * `fetch_portal_name`: set when executing on behalf of a FETCH; that
    ///   portal's result formats are used to encode the rows.
    /// * `timeout`: how long to wait for further batches. `WaitOnce` blocks
    ///   until the first non-empty batch arrives; after that only batches that
    ///   are already available are drained.
    ///
    /// Returns the next connection state and why sending ended. Upstream
    /// errors, cancellation, and rows that fail verification are reported to
    /// the client as error responses, and the matching reason is returned with
    /// the state. `Err` is returned for I/O failures, including the connection
    /// closing while waiting for rows.
    #[allow(clippy::too_many_arguments)]
    // TODO(guswynn): figure out how to get it to compile without skip_all
    #[mz_ore::instrument(level = "debug")]
    async fn send_rows(
        &mut self,
        row_desc: RelationDesc,
        portal_name: String,
        mut rows: InProgressRows,
        max_rows: ExecuteCount,
        get_response: GetResponse,
        fetch_portal_name: Option<String>,
        timeout: ExecuteTimeout,
    ) -> Result<(State, SendRowsEndedReason), io::Error> {
        // If this portal is being executed from a FETCH then we need to use the result
        // format type of the outer portal.
        let result_format_portal_name: &str = if let Some(ref name) = fetch_portal_name {
            name
        } else {
            &portal_name
        };
        let result_formats = self
            .adapter_client
            .session()
            .get_portal_unverified(result_format_portal_name)
            .expect("valid fetch portal name for send rows")
            .result_formats
            .clone();

        // `deadline` bounds how long we wait for further batches. `WaitOnce`
        // starts without a deadline; once the first non-empty batch arrives the
        // deadline is set to "now" (see below), so later waits end immediately.
        let (mut wait_once, mut deadline) = match timeout {
            ExecuteTimeout::None => (false, None),
            ExecuteTimeout::Seconds(t) => (
                false,
                Some(tokio::time::Instant::now() + tokio::time::Duration::from_secs_f64(t)),
            ),
            ExecuteTimeout::WaitOnce => (true, None),
        };

        // Sanity check that the various `RelationDesc`s match up.
        {
            let portal_name_desc = &self
                .adapter_client
                .session()
                .get_portal_unverified(portal_name.as_str())
                .expect("portal should exist")
                .desc
                .relation_desc;
            if let Some(portal_name_desc) = portal_name_desc {
                soft_assert_eq_or_log!(portal_name_desc, &row_desc);
            }
            if let Some(fetch_portal_name) = &fetch_portal_name {
                let fetch_portal_desc = &self
                    .adapter_client
                    .session()
                    .get_portal_unverified(fetch_portal_name)
                    .expect("portal should exist")
                    .desc
                    .relation_desc;
                if let Some(fetch_portal_desc) = fetch_portal_desc {
                    soft_assert_eq_or_log!(fetch_portal_desc, &row_desc);
                }
            }
        }

        // Tell the connection how to encode each column: its pgrepr type paired
        // with the requested result format. `zip_eq` panics if the column count
        // and the number of result formats differ.
        self.conn.set_encode_state(
            row_desc
                .typ()
                .column_types
                .iter()
                .map(|ty| mz_pgrepr::Type::from(&ty.scalar_type))
                .zip_eq(result_formats)
                .collect(),
        );

        let mut total_sent_rows = 0;
        let mut total_sent_bytes = 0;
        // want_rows is the maximum number of rows the client wants.
        let mut want_rows = match max_rows {
            ExecuteCount::All => usize::MAX,
            ExecuteCount::Count(count) => count,
        };

        // Send rows while the client still wants them and there are still rows to send.
        loop {
            // Fetch next batch of rows, waiting for a possible requested
            // timeout or notice.
            // A batch left in `rows.current` (e.g. by an earlier call that hit
            // its row limit) is sent before anything new is received, and no
            // new batch is awaited once the client wants no more rows.
            let batch = if rows.current.is_some() {
                FetchResult::Rows(rows.current.take())
            } else if want_rows == 0 {
                FetchResult::Rows(None)
            } else {
                let notice_fut = self.adapter_client.session().recv_notice();
                // Biased: drain available data before checking the deadline.
                // This is critical for the WaitOnce case, where the deadline
                // is set to `Instant::now()` right after the first batch:
                // without `biased`, `recv()` and the already-expired deadline
                // race nondeterministically, so we might break the loop
                // before `no_more_rows` is set (or even before ready rows
                // are consumed). With an explicit `TIMEOUT`, missing a batch
                // right at the boundary is acceptable, but WaitOnce fires
                // immediately and the race is not.
                //
                // Trade-off: if `recv()` keeps returning Ready (unlikely in
                // practice—row processing + flush is slower than upstream
                // tick granularity), a `TIMEOUT` deadline could be delayed.
                // See database-issues#9470.
                tokio::select! {
                    biased;
                    err = self.conn.wait_closed() => return Err(err),
                    batch = rows.remaining.recv() => match batch {
                        None => FetchResult::Rows(None),
                        Some(PeekResponseUnary::Rows(rows)) => FetchResult::Rows(Some(rows)),
                        Some(PeekResponseUnary::Error(err)) => FetchResult::Error(err),
                        Some(PeekResponseUnary::DependencyDropped(dep)) => {
                            FetchResult::Error(dep.query_terminated_error())
                        }
                        Some(PeekResponseUnary::Canceled) => FetchResult::Canceled,
                    },
                    notice = notice_fut => {
                        FetchResult::Notice(notice)
                    }
                    _ = time::sleep_until(
                        deadline.unwrap_or_else(tokio::time::Instant::now),
                    ), if deadline.is_some() => FetchResult::Rows(None),
                }
            };

            match batch {
                // No batch: the stream is exhausted, the deadline elapsed, or
                // the client wants no more rows.
                FetchResult::Rows(None) => break,
                FetchResult::Rows(Some(mut batch_rows)) => {
                    if let Err(err) = verify_datum_desc(&row_desc, &mut batch_rows) {
                        let msg = err.to_string();
                        return self
                            .send_error_and_get_state(err.into_response(Severity::Error))
                            .await
                            .map(|state| (state, SendRowsEndedReason::Errored { error: msg }));
                    }

                    // If wait_once is true: the first time this fn is called it blocks (same as
                    // deadline == None). The second time this fn is called it should behave the
                    // same a 0s timeout.
                    if wait_once && batch_rows.peek().is_some() {
                        deadline = Some(tokio::time::Instant::now());
                        wait_once = false;
                    }

                    // Send a portion of the rows.
                    // At most `want_rows` rows are sent; because `batch_rows` is
                    // only borrowed mutably, any unsent rows remain in it.
                    let mut sent_rows = 0;
                    let mut sent_bytes = 0;
                    let messages = (&mut batch_rows)
                        // TODO(parkmycar): This is a fair bit of juggling between iterator types
                        // to count the total number of bytes. Alternatively we could track the
                        // total sent bytes in this .map(...) call, but having side effects in map
                        // is a code smell.
                        .map(|row| {
                            let row_len = row.byte_len();
                            let values = mz_pgrepr::values_from_row(row, row_desc.typ());
                            (row_len, BackendMessage::DataRow(values))
                        })
                        .inspect(|(row_len, _)| {
                            sent_bytes += row_len;
                            sent_rows += 1
                        })
                        .map(|(_row_len, row)| row)
                        .take(want_rows);
                    self.send_all(messages).await?;

                    total_sent_rows += sent_rows;
                    total_sent_bytes += sent_bytes;
                    want_rows -= sent_rows;

                    // If we have sent the number of requested rows, put the remainder of the batch
                    // (if any) back and stop sending.
                    if want_rows == 0 {
                        if batch_rows.peek().is_some() {
                            rows.current = Some(batch_rows);
                        }
                        break;
                    }

                    // More rows are still wanted: flush what was sent before
                    // waiting for the next batch.
                    self.conn.flush().await?;
                }
                // Forward session notices to the client as they arrive.
                FetchResult::Notice(notice) => {
                    self.send(notice.into_response()).await?;
                    self.conn.flush().await?;
                }
                FetchResult::Error(text) => {
                    return self
                        .send_error_and_get_state(ErrorResponse::error(
                            SqlState::INTERNAL_ERROR,
                            text.clone(),
                        ))
                        .await
                        .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
                }
                FetchResult::Canceled => {
                    return self
                        .send_error_and_get_state(ErrorResponse::error(
                            SqlState::QUERY_CANCELED,
                            "canceling statement due to user request",
                        ))
                        .await
                        .map(|state| (state, SendRowsEndedReason::Canceled));
                }
            }
        }

        let portal = self
            .adapter_client
            .session()
            .get_portal_unverified_mut(&portal_name)
            .expect("valid portal name for send rows");

        // Capture the stream's bookkeeping before `rows` is moved back into
        // the portal below.
        let saw_rows = rows.remaining.saw_rows;
        let no_more_rows = rows.no_more_rows();
        let metric_recorded = rows.remaining.metric_recorded;
        let recorded_first_row_instant = rows.remaining.recorded_first_row_instant;

        // Mark the metric as recorded on the stream that is stored in the
        // portal, so re-fetching an exhausted portal does not record it again.
        if no_more_rows && !metric_recorded {
            rows.remaining.metric_recorded = true;
        }

        // Always return rows back, even if it's empty. This prevents an unclosed
        // portal from re-executing after it has been emptied.
        *portal.state = PortalState::InProgress(Some(rows));

        let fetch_portal = fetch_portal_name.map(|name| {
            self.adapter_client
                .session()
                .get_portal_unverified_mut(&name)
                .expect("valid fetch portal")
        });
        let response_message = get_response(max_rows, total_sent_rows, fetch_portal);
        self.send(response_message).await?;

        // Attend to metrics if there are no more rows. Only record once per stream
        // to avoid polluting the histogram when an exhausted cursor is FETCHed again.
        if no_more_rows && !metric_recorded {
            let statement_type = if let Some(stmt) = &self
                .adapter_client
                .session()
                .get_portal_unverified(&portal_name)
                .expect("valid portal name for send_rows")
                .stmt
            {
                metrics::statement_type_label_value(stmt.deref())
            } else {
                "no-statement"
            };
            let duration = if saw_rows {
                recorded_first_row_instant
                    .expect("recorded_first_row_instant because saw_rows")
                    .elapsed()
            } else {
                // If the result is empty, then we define time from first to last row as 0.
                // (Note that, currently, an empty result involves a PeekResponse with 0 rows, which
                // does flip `saw_rows`, so this code path is currently not exercised.)
                Duration::ZERO
            };
            self.adapter_client
                .inner()
                .metrics()
                .result_rows_first_to_last_byte_seconds
                .with_label_values(&[statement_type])
                .observe(duration.as_secs_f64());
        }

        Ok((
            State::Ready,
            SendRowsEndedReason::Success {
                result_size: u64::cast_from(total_sent_bytes),
                rows_returned: u64::cast_from(total_sent_rows),
            },
        ))
    }
2576
2577    #[mz_ore::instrument(level = "debug")]
2578    async fn copy_rows(
2579        &mut self,
2580        format: CopyFormat,
2581        row_desc: RelationDesc,
2582        mut stream: RecordFirstRowStream,
2583    ) -> Result<(State, SendRowsEndedReason), io::Error> {
2584        let (row_format, encode_format) = match format {
2585            CopyFormat::Text => (
2586                CopyFormatParams::Text(CopyTextFormatParams::default()),
2587                Format::Text,
2588            ),
2589            CopyFormat::Binary => (CopyFormatParams::Binary, Format::Binary),
2590            CopyFormat::Csv => (
2591                CopyFormatParams::Csv(CopyCsvFormatParams::default()),
2592                Format::Text,
2593            ),
2594            CopyFormat::Parquet => {
2595                let text = "Parquet format is not supported".to_string();
2596                return self
2597                    .send_error_and_get_state(ErrorResponse::error(
2598                        SqlState::INTERNAL_ERROR,
2599                        text.clone(),
2600                    ))
2601                    .await
2602                    .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
2603            }
2604        };
2605
2606        // Binary encoding is not implemented for some types (e.g., list, map,
2607        // and aclitem). Unlike the extended query protocol's Bind handler, COPY
2608        // does not validate this when binding the portal: the portal's result
2609        // formats describe the `CopyData` wrapper, not the COPY format itself,
2610        // so the Bind handler explicitly skips `COPY TO` statements. We must
2611        // therefore check here, before streaming any rows, otherwise
2612        // `encode_binary` would panic mid-stream (SQL-323).
2613        if let CopyFormat::Binary = format {
2614            if let Some(msg) = row_desc
2615                .iter_types()
2616                .find_map(|ty| mz_pgrepr::Value::binary_encoding_error(&ty.scalar_type).err())
2617            {
2618                return self
2619                    .send_error_and_get_state(ErrorResponse::error(
2620                        SqlState::PROTOCOL_VIOLATION,
2621                        msg,
2622                    ))
2623                    .await
2624                    .map(|state| {
2625                        (
2626                            state,
2627                            SendRowsEndedReason::Errored {
2628                                error: msg.to_string(),
2629                            },
2630                        )
2631                    });
2632            }
2633        }
2634
2635        let encode_fn = |row: &RowRef, typ: &SqlRelationType, out: &mut Vec<u8>| {
2636            mz_pgcopy::encode_copy_format(&row_format, row, typ, out)
2637        };
2638
2639        let typ = row_desc.typ();
2640        let column_formats = iter::repeat(encode_format)
2641            .take(typ.column_types.len())
2642            .collect();
2643        self.send(BackendMessage::CopyOutResponse {
2644            overall_format: encode_format,
2645            column_formats,
2646        })
2647        .await?;
2648
2649        // In Postgres, binary copy has a header that is followed (in the same
2650        // CopyData) by the first row. In order to replicate their behavior, use a
2651        // common vec that we can extend one time now and then fill up with the encode
2652        // functions.
2653        let mut out = Vec::new();
2654
2655        if let CopyFormat::Binary = format {
2656            // 11-byte signature.
2657            out.extend(b"PGCOPY\n\xFF\r\n\0");
2658            // 32-bit flags field.
2659            out.extend([0, 0, 0, 0]);
2660            // 32-bit header extension length field.
2661            out.extend([0, 0, 0, 0]);
2662        }
2663
2664        let mut count = 0;
2665        let mut total_sent_bytes = 0;
2666        loop {
2667            tokio::select! {
2668                e = self.conn.wait_closed() => return Err(e),
2669                batch = stream.recv() => match batch {
2670                    None => break,
2671                    Some(PeekResponseUnary::Error(text)) => {
2672                        let err =
2673                            ErrorResponse::error(SqlState::INTERNAL_ERROR, text.clone());
2674                        return self
2675                            .send_error_and_get_state(err)
2676                            .await
2677                            .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
2678                    }
2679                    Some(PeekResponseUnary::DependencyDropped(dep)) => {
2680                        let err = dep.to_concurrent_dependency_drop();
2681                        let text = err.to_string();
2682                        let resp = err.into_response(Severity::Error);
2683                        return self
2684                            .send_error_and_get_state(resp)
2685                            .await
2686                            .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
2687                    }
2688                    Some(PeekResponseUnary::Canceled) => {
2689                        return self.send_error_and_get_state(ErrorResponse::error(
2690                                SqlState::QUERY_CANCELED,
2691                                "canceling statement due to user request",
2692                            ))
2693                            .await.map(|state| (state, SendRowsEndedReason::Canceled));
2694                    }
2695                    Some(PeekResponseUnary::Rows(mut rows)) => {
2696                        count += rows.count();
2697                        while let Some(row) = rows.next() {
2698                            total_sent_bytes += row.byte_len();
2699                            encode_fn(row, typ, &mut out)?;
2700                            self.send(BackendMessage::CopyData(mem::take(&mut out)))
2701                                .await?;
2702                        }
2703                    }
2704                },
2705                notice = self.adapter_client.session().recv_notice() => {
2706                    self.send(notice.into_response())
2707                        .await?;
2708                    self.conn.flush().await?;
2709                }
2710            }
2711
2712            self.conn.flush().await?;
2713        }
2714        // Send required trailers.
2715        if let CopyFormat::Binary = format {
2716            let trailer: i16 = -1;
2717            out.extend(trailer.to_be_bytes());
2718            self.send(BackendMessage::CopyData(mem::take(&mut out)))
2719                .await?;
2720        }
2721
2722        let tag = format!("COPY {}", count);
2723        self.send(BackendMessage::CopyDone).await?;
2724        self.send(BackendMessage::CommandComplete { tag }).await?;
2725        Ok((
2726            State::Ready,
2727            SendRowsEndedReason::Success {
2728                result_size: u64::cast_from(total_sent_bytes),
2729                rows_returned: u64::cast_from(count),
2730            },
2731        ))
2732    }
2733
2734    /// Handles the copy-in mode of the postgres protocol from transferring
2735    /// data to the server.
2736    #[instrument(level = "debug")]
2737    async fn copy_from(
2738        &mut self,
2739        target_id: CatalogItemId,
2740        target_name: String,
2741        columns: Vec<ColumnIndex>,
2742        params: CopyFormatParams<'static>,
2743        row_desc: RelationDesc,
2744        mut ctx_extra: ExecuteContextGuard,
2745    ) -> Result<State, io::Error> {
2746        let res = self
2747            .copy_from_inner(
2748                target_id,
2749                target_name,
2750                columns,
2751                params,
2752                row_desc,
2753                &mut ctx_extra,
2754            )
2755            .await;
2756        match &res {
2757            Ok(State::Ready) => {
2758                self.adapter_client.retire_execute(
2759                    ctx_extra,
2760                    StatementEndedExecutionReason::Success {
2761                        result_size: None,
2762                        rows_returned: None,
2763                        execution_strategy: None,
2764                    },
2765                );
2766            }
2767            Ok(State::Done) => {
2768                // The connection closed gracefully without sending us a `CopyDone`,
2769                // causing us to just drop the copy request.
2770                // For the purposes of statement logging, we count this as a cancellation.
2771                self.adapter_client
2772                    .retire_execute(ctx_extra, StatementEndedExecutionReason::Canceled);
2773            }
2774            Err(e) => {
2775                self.adapter_client.retire_execute(
2776                    ctx_extra,
2777                    StatementEndedExecutionReason::Errored {
2778                        error: format!("{e}"),
2779                    },
2780                );
2781            }
2782            Ok(State::Drain) => {}
2783        }
2784        res
2785    }
2786
2787    async fn copy_from_inner(
2788        &mut self,
2789        target_id: CatalogItemId,
2790        target_name: String,
2791        columns: Vec<ColumnIndex>,
2792        params: CopyFormatParams<'static>,
2793        row_desc: RelationDesc,
2794        ctx_extra: &mut ExecuteContextGuard,
2795    ) -> Result<State, io::Error> {
2796        let typ = row_desc.typ();
2797        let column_formats = vec![Format::Text; typ.column_types.len()];
2798        self.send(BackendMessage::CopyInResponse {
2799            overall_format: Format::Text,
2800            column_formats,
2801        })
2802        .await?;
2803        self.conn.flush().await?;
2804
2805        // Set up the parallel streaming batch builders in the coordinator.
2806        let writer = match self
2807            .adapter_client
2808            .start_copy_from_stdin(
2809                target_id,
2810                target_name.clone(),
2811                columns.clone(),
2812                row_desc.clone(),
2813                params.clone(),
2814            )
2815            .await
2816        {
2817            Ok(writer) => writer,
2818            Err(e) => {
2819                // Drain remaining CopyData/CopyDone/CopyFail messages from the
2820                // socket. Since CopyInResponse was already sent, the client may
2821                // have pipelined copy data that we must consume before returning
2822                // the error, otherwise they'd be misinterpreted as top-level
2823                // protocol messages and cause a deadlock.
2824                loop {
2825                    match self.conn.recv().await? {
2826                        Some(FrontendMessage::CopyData(_)) => {}
2827                        Some(FrontendMessage::CopyDone) | Some(FrontendMessage::CopyFail(_)) => {
2828                            break;
2829                        }
2830                        Some(FrontendMessage::Flush) | Some(FrontendMessage::Sync) => {}
2831                        Some(_) => break,
2832                        None => return Ok(State::Done),
2833                    }
2834                }
2835                self.adapter_client.retire_execute(
2836                    std::mem::take(ctx_extra),
2837                    StatementEndedExecutionReason::Errored {
2838                        error: e.to_string(),
2839                    },
2840                );
2841                return self
2842                    .send_error_and_get_state(e.into_response(Severity::Error))
2843                    .await;
2844            }
2845        };
2846
2847        // Enable copy mode on the codec to skip aggregate buffer size checks.
2848        self.conn.set_copy_mode(true);
2849
2850        // Batch size for splitting raw data across parallel workers (~32MB).
2851        const BATCH_SIZE: usize = 32 * 1024 * 1024;
2852        let max_copy_from_row_size = self
2853            .adapter_client
2854            .get_system_vars()
2855            .await
2856            .max_copy_from_row_size()
2857            .try_into()
2858            .unwrap_or(usize::MAX);
2859
2860        let mut data = Vec::new();
2861        let mut row_scanner = CopyRowScanner::new(&params);
2862        let num_workers = writer.batch_txs.len();
2863        let mut next_worker: usize = 0;
2864        let mut saw_copy_done = false;
2865        let mut saw_end_marker = false;
2866        let mut copy_from_error: Option<(SqlState, String)> = None;
2867
2868        // Receive loop: accumulate CopyData, split at row boundaries,
2869        // round-robin raw chunks to parallel batch builder workers.
2870        loop {
2871            let message = self.conn.recv().await?;
2872            match message {
2873                Some(FrontendMessage::CopyData(buf)) => {
2874                    if saw_end_marker {
2875                        // Per PostgreSQL COPY behavior, ignore all bytes after
2876                        // the end-of-copy marker until CopyDone.
2877                        continue;
2878                    }
2879                    data.extend(buf);
2880                    row_scanner.scan_new_bytes(&data);
2881
2882                    if let Some(end_pos) = row_scanner.end_marker_end() {
2883                        data.truncate(end_pos);
2884                        row_scanner.on_truncate(end_pos);
2885                        saw_end_marker = true;
2886                    }
2887
2888                    // Guard against pathological single rows that never terminate.
2889                    if row_scanner.current_row_size(data.len()) > max_copy_from_row_size {
2890                        copy_from_error = Some((
2891                            SqlState::INSUFFICIENT_RESOURCES,
2892                            format!(
2893                                "COPY FROM STDIN row exceeded max_copy_from_row_size \
2894                                 ({max_copy_from_row_size} bytes)"
2895                            ),
2896                        ));
2897                        break;
2898                    }
2899
2900                    // When buffer exceeds batch size, split at the last complete row
2901                    // and send the complete rows chunk to the next worker.
2902                    let mut send_failed = false;
2903                    while data.len() >= BATCH_SIZE {
2904                        let split_pos = match row_scanner.last_row_end() {
2905                            Some(pos) => pos,
2906                            None => break, // no complete row yet
2907                        };
2908                        let remainder = data.split_off(split_pos);
2909                        let chunk = std::mem::replace(&mut data, remainder);
2910                        row_scanner.on_split(split_pos);
2911                        if writer.batch_txs[next_worker].send(chunk).await.is_err() {
2912                            send_failed = true;
2913                            break;
2914                        }
2915                        next_worker = (next_worker + 1) % num_workers;
2916                    }
2917                    // Worker dropped (likely errored) — stop sending,
2918                    // fall through to completion_rx for the real error.
2919                    if send_failed {
2920                        break;
2921                    }
2922                }
2923                Some(FrontendMessage::CopyDone) => {
2924                    // Send any remaining data to the next worker.
2925                    if !data.is_empty() {
2926                        let chunk = std::mem::take(&mut data);
2927                        // Ignore send failure — completion_rx will have the error.
2928                        let _ = writer.batch_txs[next_worker].send(chunk).await;
2929                    }
2930                    saw_copy_done = true;
2931                    break;
2932                }
2933                Some(FrontendMessage::CopyFail(err)) => {
2934                    self.adapter_client.retire_execute(
2935                        std::mem::take(ctx_extra),
2936                        StatementEndedExecutionReason::Canceled,
2937                    );
2938                    // Drop the writer to signal cancellation to the background tasks.
2939                    drop(writer);
2940                    self.conn.set_copy_mode(false);
2941                    return self
2942                        .send_error_and_get_state(ErrorResponse::error(
2943                            SqlState::QUERY_CANCELED,
2944                            format!("COPY from stdin failed: {}", err),
2945                        ))
2946                        .await;
2947                }
2948                Some(FrontendMessage::Flush) | Some(FrontendMessage::Sync) => {}
2949                Some(_) => {
2950                    let msg = "unexpected message type during COPY from stdin";
2951                    self.adapter_client.retire_execute(
2952                        std::mem::take(ctx_extra),
2953                        StatementEndedExecutionReason::Errored {
2954                            error: msg.to_string(),
2955                        },
2956                    );
2957                    drop(writer);
2958                    self.conn.set_copy_mode(false);
2959                    return self
2960                        .send_error_and_get_state(ErrorResponse::error(
2961                            SqlState::PROTOCOL_VIOLATION,
2962                            msg,
2963                        ))
2964                        .await;
2965                }
2966                None => {
2967                    drop(writer);
2968                    self.conn.set_copy_mode(false);
2969                    return Ok(State::Done);
2970                }
2971            }
2972        }
2973
2974        // If we exited the receive loop before seeing `CopyDone` (e.g. because
2975        // a worker failed and dropped its channel), keep draining COPY input to
2976        // avoid desynchronizing the protocol state machine.
2977        if !saw_copy_done {
2978            loop {
2979                match self.conn.recv().await? {
2980                    Some(FrontendMessage::CopyData(_)) => {}
2981                    Some(FrontendMessage::CopyDone) | Some(FrontendMessage::CopyFail(_)) => {
2982                        break;
2983                    }
2984                    Some(FrontendMessage::Flush) | Some(FrontendMessage::Sync) => {}
2985                    Some(_) => {
2986                        let msg = "unexpected message type during COPY from stdin";
2987                        self.adapter_client.retire_execute(
2988                            std::mem::take(ctx_extra),
2989                            StatementEndedExecutionReason::Errored {
2990                                error: msg.to_string(),
2991                            },
2992                        );
2993                        drop(writer);
2994                        self.conn.set_copy_mode(false);
2995                        return self
2996                            .send_error_and_get_state(ErrorResponse::error(
2997                                SqlState::PROTOCOL_VIOLATION,
2998                                msg,
2999                            ))
3000                            .await;
3001                    }
3002                    None => {
3003                        drop(writer);
3004                        self.conn.set_copy_mode(false);
3005                        return Ok(State::Done);
3006                    }
3007                }
3008            }
3009        }
3010
3011        if let Some((code, msg)) = copy_from_error {
3012            self.adapter_client.retire_execute(
3013                std::mem::take(ctx_extra),
3014                StatementEndedExecutionReason::Errored { error: msg.clone() },
3015            );
3016            drop(writer);
3017            self.conn.set_copy_mode(false);
3018            return self
3019                .send_error_and_get_state(ErrorResponse::error(code, msg))
3020                .await;
3021        }
3022
3023        self.conn.set_copy_mode(false);
3024
3025        // Drop all senders to signal EOF to the background batch builders.
3026        // If copy_err is set, a worker already failed — dropping the senders
3027        // will cause remaining workers to stop, and we'll get the real error
3028        // from completion_rx below.
3029        drop(writer.batch_txs);
3030
3031        // Wait for all parallel workers to finish building batches.
3032        let (proto_batches, row_count) = match writer.completion_rx.await {
3033            Ok(Ok(result)) => result,
3034            Ok(Err(e)) => {
3035                self.adapter_client.retire_execute(
3036                    std::mem::take(ctx_extra),
3037                    StatementEndedExecutionReason::Errored {
3038                        error: e.to_string(),
3039                    },
3040                );
3041                return self
3042                    .send_error_and_get_state(e.into_response(Severity::Error))
3043                    .await;
3044            }
3045            Err(_) => {
3046                let msg = "COPY FROM STDIN: background batch builder tasks dropped";
3047                self.adapter_client.retire_execute(
3048                    std::mem::take(ctx_extra),
3049                    StatementEndedExecutionReason::Errored {
3050                        error: msg.to_string(),
3051                    },
3052                );
3053                return self
3054                    .send_error_and_get_state(ErrorResponse::error(SqlState::INTERNAL_ERROR, msg))
3055                    .await;
3056            }
3057        };
3058
3059        // Stage all batches in the session's transaction for atomic commit.
3060        if let Err(e) = self
3061            .adapter_client
3062            .stage_copy_from_stdin_batches(target_id, proto_batches)
3063        {
3064            self.adapter_client.retire_execute(
3065                std::mem::take(ctx_extra),
3066                StatementEndedExecutionReason::Errored {
3067                    error: e.to_string(),
3068                },
3069            );
3070            return self
3071                .send_error_and_get_state(e.into_response(Severity::Error))
3072                .await;
3073        }
3074
3075        let tag = format!("COPY {}", row_count);
3076        self.send(BackendMessage::CommandComplete { tag }).await?;
3077
3078        Ok(State::Ready)
3079    }
3080
3081    #[instrument(level = "debug")]
3082    async fn send_pending_notices(&mut self) -> Result<(), io::Error> {
3083        let notices = self
3084            .adapter_client
3085            .session()
3086            .drain_notices()
3087            .into_iter()
3088            .map(|notice| BackendMessage::ErrorResponse(notice.into_response()));
3089        self.send_all(notices).await?;
3090        Ok(())
3091    }
3092
3093    #[instrument(level = "debug")]
3094    async fn send_error_and_get_state(&mut self, err: ErrorResponse) -> Result<State, io::Error> {
3095        assert!(err.severity.is_error());
3096        debug!(
3097            "cid={} error code={}",
3098            self.adapter_client.session().conn_id(),
3099            err.code.code()
3100        );
3101        let is_fatal = err.severity.is_fatal();
3102        self.send(BackendMessage::ErrorResponse(err)).await?;
3103
3104        let txn = self.adapter_client.session().transaction();
3105        match txn {
3106            // Error can be called from describe and parse and so might not be in an active
3107            // transaction.
3108            TransactionStatus::Default | TransactionStatus::Failed(_) => {}
3109            // In Started (i.e., a single statement), cleanup ourselves.
3110            TransactionStatus::Started(_) => {
3111                self.rollback_transaction().await?;
3112            }
3113            // Implicit transactions also clear themselves.
3114            TransactionStatus::InTransactionImplicit(_) => {
3115                self.rollback_transaction().await?;
3116            }
3117            // Explicit transactions move to failed.
3118            TransactionStatus::InTransaction(_) => {
3119                self.adapter_client.fail_transaction();
3120            }
3121        };
3122        if is_fatal {
3123            Ok(State::Done)
3124        } else {
3125            Ok(State::Drain)
3126        }
3127    }
3128
3129    #[instrument(level = "debug")]
3130    async fn aborted_txn_error(&mut self) -> Result<State, io::Error> {
3131        self.send(BackendMessage::ErrorResponse(ErrorResponse::error(
3132            SqlState::IN_FAILED_SQL_TRANSACTION,
3133            ABORTED_TXN_MSG,
3134        )))
3135        .await?;
3136        Ok(State::Drain)
3137    }
3138
3139    fn is_aborted_txn(&mut self) -> bool {
3140        matches!(
3141            self.adapter_client.session().transaction(),
3142            TransactionStatus::Failed(_)
3143        )
3144    }
3145}
3146
3147fn pad_formats(formats: Vec<Format>, n: usize) -> Result<Vec<Format>, String> {
3148    match (formats.len(), n) {
3149        (0, e) => Ok(vec![Format::Text; e]),
3150        (1, e) => Ok(iter::repeat(formats[0]).take(e).collect()),
3151        (a, e) if a == e => Ok(formats),
3152        (a, e) => Err(format!(
3153            "expected {} field format specifiers, but got {}",
3154            e, a
3155        )),
3156    }
3157}
3158
3159fn describe_rows(stmt_desc: &StatementDesc, formats: &[Format]) -> BackendMessage {
3160    match &stmt_desc.relation_desc {
3161        Some(desc) if !stmt_desc.is_copy => {
3162            BackendMessage::RowDescription(message::encode_row_description(desc, formats))
3163        }
3164        _ => BackendMessage::NoData,
3165    }
3166}
3167
/// Callback used by `send_rows` to build the message that ends a batch of
/// rows.
///
/// It receives the requested row limit, the number of rows sent so far, and
/// (for FETCH) the portal being fetched from. See `portal_exec_message` and
/// `fetch_message` for the two implementations.
type GetResponse = fn(
    max_rows: ExecuteCount,
    total_sent_rows: usize,
    fetch_portal: Option<PortalRefMut>,
) -> BackendMessage;
3173
3174// A GetResponse used by send_rows during execute messages on portals or for
3175// simple query messages.
3176fn portal_exec_message(
3177    max_rows: ExecuteCount,
3178    total_sent_rows: usize,
3179    _fetch_portal: Option<PortalRefMut>,
3180) -> BackendMessage {
3181    // If max_rows is not specified, we will always send back a CommandComplete. If
3182    // max_rows is specified, we only send CommandComplete if there were more rows
3183    // requested than were remaining. That is, if max_rows == number of rows that
3184    // were remaining before sending (not that are remaining after sending), then
3185    // we still send a PortalSuspended. The number of remaining rows after the rows
3186    // have been sent doesn't matter. This matches postgres.
3187    match max_rows {
3188        ExecuteCount::Count(max_rows) if max_rows <= total_sent_rows => {
3189            BackendMessage::PortalSuspended
3190        }
3191        _ => BackendMessage::CommandComplete {
3192            tag: format!("SELECT {}", total_sent_rows),
3193        },
3194    }
3195}
3196
3197// A GetResponse used by send_rows during FETCH queries.
3198fn fetch_message(
3199    _max_rows: ExecuteCount,
3200    total_sent_rows: usize,
3201    fetch_portal: Option<PortalRefMut>,
3202) -> BackendMessage {
3203    let tag = format!("FETCH {}", total_sent_rows);
3204    if let Some(portal) = fetch_portal {
3205        *portal.state = PortalState::Completed(Some(tag.clone()));
3206    }
3207    BackendMessage::CommandComplete { tag }
3208}
3209
3210fn get_authenticator(
3211    authenticator_kind: listeners::AuthenticatorKind,
3212    frontegg: Option<FronteggAuthenticator>,
3213    oidc: GenericOidcAuthenticator,
3214    adapter_client: mz_adapter::Client,
3215) -> Authenticator {
3216    match authenticator_kind {
3217        listeners::AuthenticatorKind::Frontegg => Authenticator::Frontegg(frontegg.expect(
3218            "Frontegg authenticator should exist with listeners::AuthenticatorKind::Frontegg",
3219        )),
3220        listeners::AuthenticatorKind::Password => Authenticator::Password(adapter_client),
3221        listeners::AuthenticatorKind::Sasl => Authenticator::Sasl(adapter_client),
3222        listeners::AuthenticatorKind::Oidc => Authenticator::Oidc(oidc),
3223        listeners::AuthenticatorKind::None => Authenticator::None,
3224    }
3225}
3226
/// Row limit for one execution, passed to `GetResponse` callbacks as
/// `max_rows`.
#[derive(Debug, Copy, Clone)]
enum ExecuteCount {
    /// No limit: send every row.
    All,
    /// Send at most this many rows.
    Count(usize),
}
3232
3233// See postgres' backend/tcop/postgres.c IsTransactionExitStmt.
3234fn is_txn_exit_stmt(stmt: Option<&Statement<Raw>>) -> bool {
3235    match stmt {
3236        // Add PREPARE to this if we ever support it.
3237        Some(stmt) => matches!(stmt, Statement::Commit(_) | Statement::Rollback(_)),
3238        None => false,
3239    }
3240}
3241
/// Outcome of one step of fetching results.
///
/// NOTE(review): the producer and consumer are outside this view; the variant
/// descriptions below are inferred from their payloads. Confirm against the
/// call sites.
#[derive(Debug)]
enum FetchResult {
    /// A batch of rows. Presumably `None` means no further rows — TODO confirm.
    Rows(Option<Box<dyn RowIterator + Send + Sync>>),
    /// The fetch was canceled.
    Canceled,
    /// The fetch failed with the given message.
    Error(String),
    /// A notice to relay to the client.
    Notice(AdapterNotice),
}
3249
/// Incrementally locates row boundaries and the `\.` end-of-copy marker in the
/// buffer of bytes accumulated from COPY FROM STDIN `CopyData` messages.
///
/// All offsets are byte positions within that buffer. They are rebased by
/// `on_split` when a prefix of complete rows is removed, and clamped by
/// `on_truncate` when the buffer is cut short.
#[derive(Debug)]
struct CopyRowScanner {
    /// Number of leading buffer bytes already fed through the scanner.
    scan_pos: usize,
    /// Offset just past the last complete row found since the last split.
    last_row_end: Option<usize>,
    /// Offset just past the end-of-copy marker row, once one has been found.
    end_marker_end: Option<usize>,
    /// Byte offset within `data` at which the in-progress CSV record begins.
    /// Used to verify the end-of-copy marker against the raw input bytes,
    /// distinguishing a literal `\.` line from a quoted CSV value `"\."`
    /// whose decoded form is also `\.`.
    record_start: usize,
    /// csv-core parsing state; `Some` only for CSV-format COPY.
    csv: Option<CsvScanState>,
}
3262
/// csv-core state used by `CopyRowScanner` to find CSV record boundaries.
#[derive(Debug)]
struct CsvScanState {
    /// Resumable csv-core reader configured with the COPY's delimiter, quote
    /// and escape bytes.
    reader: csv_core::Reader,
    /// Scratch space for decoded field bytes; the scanner never reads it.
    output: Vec<u8>,
    /// Scratch space for field end offsets; the scanner never reads it.
    ends: Vec<usize>,
    /// Set while the HEADER record has not been seen yet. That record is never
    /// treated as the end-of-copy marker.
    skip_first_record: bool,
}
3270
impl CopyRowScanner {
    /// Creates a scanner for a COPY FROM STDIN stream in the given format.
    ///
    /// Only CSV input gets a csv-core state machine (built from the COPY's
    /// delimiter, quote, escape and HEADER settings). Every other format is
    /// scanned by treating each `\n` as a row terminator.
    fn new(params: &CopyFormatParams<'_>) -> Self {
        let csv = match params {
            CopyFormatParams::Csv(CopyCsvFormatParams {
                delimiter,
                quote,
                escape,
                header,
                ..
            }) => Some(CsvScanState::new(*delimiter, *quote, *escape, *header)),
            _ => None,
        };

        CopyRowScanner {
            scan_pos: 0,
            last_row_end: None,
            end_marker_end: None,
            record_start: 0,
            csv,
        }
    }

    /// Scans the bytes of `data` not seen by earlier calls (those at or after
    /// `scan_pos`). It records the end of the last complete row and, the first
    /// time it appears, the end of the end-of-copy marker row.
    ///
    /// `data` is expected to be the same growing buffer on every call, with
    /// any removal of bytes reported through `on_split` / `on_truncate`. Bytes
    /// before `scan_pos` are assumed to be already scanned.
    fn scan_new_bytes(&mut self, data: &[u8]) {
        // Nothing has been appended since the last scan.
        if self.scan_pos >= data.len() {
            return;
        }

        if let Some(csv) = self.csv.as_mut() {
            // Feed only the unscanned tail to csv-core. Only the result kind
            // and the number of input bytes consumed are used; the decoded
            // field bytes and field end offsets written into the scratch
            // buffers are discarded.
            let mut input = &data[self.scan_pos..];
            // Input bytes consumed so far in this call, relative to `scan_pos`.
            let mut consumed = 0usize;
            while !input.is_empty() {
                let (result, n_input, _n_output, _n_ends) =
                    csv.reader
                        .read_record(input, &mut csv.output, &mut csv.ends);
                consumed += n_input;
                input = &input[n_input..];

                match result {
                    // All input consumed partway through a record; the reader
                    // resumes from its saved state on the next call.
                    ReadRecordResult::InputEmpty => break,
                    // A scratch buffer filled up. The same buffer is passed
                    // again from its start on the next iteration, so it is
                    // grown only when no input was consumed, which guarantees
                    // forward progress.
                    ReadRecordResult::OutputFull => {
                        if n_input == 0 {
                            csv.output
                                .resize(csv.output.len().saturating_mul(2).max(1), 0);
                        }
                    }
                    ReadRecordResult::OutputEndsFull => {
                        if n_input == 0 {
                            csv.ends.resize(csv.ends.len().saturating_mul(2).max(1), 0);
                        }
                    }
                    // A record just ended at `row_end`.
                    ReadRecordResult::Record | ReadRecordResult::End => {
                        let row_end = self.scan_pos + consumed;
                        self.last_row_end = Some(row_end);
                        // Only the first marker matters; later records are
                        // not inspected for it.
                        if self.end_marker_end.is_none() {
                            let is_marker = if csv.skip_first_record {
                                csv.skip_first_record = false;
                                false
                            } else {
                                // Detect the marker against the raw input
                                // bytes, not the CSV-decoded record. A quoted
                                // data row `"\."` decodes to `\.` but must be
                                // imported as data; only a bare `\.` line
                                // terminates the COPY.
                                let raw = &data[self.record_start..row_end];
                                // csv-core ends a CRLF record after the `\r`,
                                // leaving the trailing `\n` as the leading byte
                                // of the next record's span; a CR-only record
                                // ends in a lone `\r`. So a `\.` marker record's
                                // raw span can be `\.\n` (LF), `\n\.\r` (CRLF)
                                // or `\.\r` (CR). Trim CR/LF from both ends
                                // before comparing — a trailing-only strip would
                                // miss the CRLF/CR forms. Quoted `"\."` data
                                // keeps its surrounding quotes after trimming and
                                // is therefore correctly rejected.
                                let start = raw
                                    .iter()
                                    .take_while(|&&b| b == b'\r' || b == b'\n')
                                    .count();
                                let trailing = raw[start..]
                                    .iter()
                                    .rev()
                                    .take_while(|&&b| b == b'\r' || b == b'\n')
                                    .count();
                                let trimmed = &raw[start..raw.len() - trailing];
                                trimmed == b"\\."
                            };
                            if is_marker {
                                self.end_marker_end = Some(row_end);
                                self.record_start = row_end;
                                // Stop scanning: bytes after the marker are
                                // not examined.
                                break;
                            }
                        }
                        self.record_start = row_end;
                    }
                }
            }
        } else {
            // Non-CSV formats: every `\n` terminates a row, and a row whose
            // first two bytes are `\.` is taken as the end-of-copy marker.
            // NOTE(review): `new` also routes binary COPY here, where `\n`
            // bytes need not be row boundaries. Confirm that binary input is
            // split safely, or never reaches this scanner.
            //
            // A row may straddle scans, so the current row starts at the end
            // of the last complete row (or at the start of the buffer).
            let mut row_start = self.last_row_end.unwrap_or(0);
            for (offset, b) in data[self.scan_pos..].iter().enumerate() {
                if *b == b'\n' {
                    let row_end = self.scan_pos + offset + 1;
                    self.last_row_end = Some(row_end);
                    if self.end_marker_end.is_none() {
                        let row = &data[row_start..row_end];
                        if row.get(0..2) == Some(b"\\.") {
                            self.end_marker_end = Some(row_end);
                            break;
                        }
                    }
                    row_start = row_end;
                }
            }
        }

        // The whole buffer now counts as scanned. That includes any bytes after
        // a marker that made the loops above `break` early.
        self.scan_pos = data.len();
    }

    /// Offset just past the last complete row found since the last split, if
    /// any.
    fn last_row_end(&self) -> Option<usize> {
        self.last_row_end
    }

    /// Offset just past the end-of-copy marker row, once one has been found.
    fn end_marker_end(&self) -> Option<usize> {
        self.end_marker_end
    }

    /// Number of bytes in the trailing incomplete row of a `data_len`-byte
    /// buffer: everything after the last complete row (or the whole buffer if
    /// there is none).
    fn current_row_size(&self, data_len: usize) -> usize {
        data_len.saturating_sub(self.last_row_end.unwrap_or(0))
    }

    /// Rebases all offsets after the first `split_pos` bytes were removed from
    /// the front of the buffer. `split_pos` is expected to be a completed-row
    /// boundary.
    fn on_split(&mut self, split_pos: usize) {
        self.scan_pos = self.scan_pos.saturating_sub(split_pos);
        // Rows before the split were handed off; no complete row remains.
        self.last_row_end = None;
        self.end_marker_end = self
            .end_marker_end
            .and_then(|end| end.checked_sub(split_pos));
        // `record_start` is only maintained for the CSV path; the text and
        // binary paths leave it at 0. For CSV, splits always occur at a
        // completed-row boundary, so the in-progress record (if any) starts at
        // the new beginning of the buffer. Assert that invariant so the
        // `saturating_sub` below doesn't silently paper over a bug that
        // bisected an in-progress record — but only when CSV is in use, since
        // otherwise `record_start` is meaninglessly 0.
        soft_assert_or_log!(
            self.csv.is_none() || self.record_start >= split_pos,
            "split bisected an in-progress CSV record: record_start={} < split_pos={}",
            self.record_start,
            split_pos,
        );
        self.record_start = self.record_start.saturating_sub(split_pos);
    }

    /// Clamps all offsets after the buffer was truncated to `new_len` bytes.
    /// Offsets that now point past the end are dropped.
    fn on_truncate(&mut self, new_len: usize) {
        self.scan_pos = self.scan_pos.min(new_len);
        self.last_row_end = self.last_row_end.filter(|&end| end <= new_len);
        self.end_marker_end = self.end_marker_end.filter(|&end| end <= new_len);
        self.record_start = self.record_start.min(new_len);
    }
}
3429
3430impl CsvScanState {
3431    fn new(delimiter: u8, quote: u8, escape: u8, header: bool) -> Self {
3432        let (double_quote, escape) = if quote == escape {
3433            (true, None)
3434        } else {
3435            (false, Some(escape))
3436        };
3437        CsvScanState {
3438            reader: csv_core::ReaderBuilder::new()
3439                .delimiter(delimiter)
3440                .quote(quote)
3441                .double_quote(double_quote)
3442                .escape(escape)
3443                .build(),
3444            output: vec![0; 1],
3445            ends: vec![0; 1],
3446            skip_first_record: header,
3447        }
3448    }
3449}
3450
#[cfg(test)]
mod test {
    //! Unit tests for the COPY FROM STDIN row scanner, the startup `options`
    //! parsing helpers, and JWT detection.

    use super::*;

    #[mz_ore::test]
    fn test_copy_row_scanner_end_marker_line_endings() {
        // The pgwire COPY row scanner must detect a bare `\.` end-of-copy
        // marker for every line ending, and must never mistake a quoted
        // `"\."` data row for it. csv-core ends a CRLF record after the `\r`
        // (leaving the `\n` as the next record's leading byte), so the raw
        // record span of a `\.` marker is `\.\n` (LF), `\n\.\r` (CRLF) or
        // `\.\r` (CR); a trailing-only strip would miss the CRLF/CR forms and
        // silently import post-marker rows.
        let params = CopyFormatParams::Csv(CopyCsvFormatParams::default());

        // Scans `data` in one pass and reports where the marker row ends.
        let marker_end = |data: &[u8]| -> Option<usize> {
            let mut scanner = CopyRowScanner::new(&params);
            scanner.scan_new_bytes(data);
            scanner.end_marker_end()
        };

        for eol in [&b"\n"[..], b"\r\n", b"\r"] {
            // Joins `lines`, terminating each one with `eol`.
            let join = |lines: &[&str]| -> Vec<u8> {
                let mut out = Vec::new();
                for line in lines {
                    out.extend_from_slice(line.as_bytes());
                    out.extend_from_slice(eol);
                }
                out
            };

            // Bare `\.` (the marker is the second record, so record_start has
            // already advanced past the orphaned terminator of `first`).
            // csv-core reports the record after a single terminator byte, so
            // the marker boundary sits just past `first<eol>\.` + one byte.
            let data = join(&["first", "\\.", "after"]);
            let mut prefix = Vec::new();
            prefix.extend_from_slice(b"first");
            prefix.extend_from_slice(eol);
            prefix.extend_from_slice(b"\\.");
            assert_eq!(
                marker_end(&data),
                Some(prefix.len() + 1),
                "bare marker, eol={eol:?}"
            );

            // Quoted "\." is data, not the marker.
            let data = join(&["before", "\"\\.\"", "after"]);
            assert_eq!(marker_end(&data), None, "quoted marker, eol={eol:?}");
        }
    }

    #[mz_ore::test]
    fn test_copy_row_scanner_non_csv_split() {
        // Regression: `record_start` is only maintained for the CSV path; the
        // text and binary paths leave it at 0. `on_split` must therefore not
        // assert `record_start >= split_pos` for those formats — that fires on
        // every split of a large text/binary COPY stream (soft-assertions
        // panic under test). Mirrors `COPY ... FROM STDIN` (default text
        // format) splitting at a row boundary once the buffer fills.
        for params in [
            CopyFormatParams::Text(CopyTextFormatParams::default()),
            CopyFormatParams::Binary,
        ] {
            let mut scanner = CopyRowScanner::new(&params);
            let data = b"1\thello world\t2\tsome text value here\n\
                         3\thello world\t6\tsome text value here\n";
            scanner.scan_new_bytes(data);
            let split_pos = scanner.last_row_end().expect("a complete row");
            assert!(split_pos > 0, "params={params:?}");
            // Must not panic via the CSV-only `on_split` soft-assert.
            scanner.on_split(split_pos);
            assert_eq!(scanner.record_start, 0, "params={params:?}");
        }
    }

    #[mz_ore::test]
    fn test_parse_options() {
        // `parse_options` turns a whole startup `options` string into
        // (key, value) pairs, accepting `--key=val`, `-ckey=val` and
        // `-c key=val`, and honoring backslash-escaped spaces.
        struct TestCase {
            input: &'static str,
            expect: Result<Vec<(&'static str, &'static str)>, ()>,
        }
        let tests = vec![
            TestCase {
                input: "",
                expect: Ok(vec![]),
            },
            TestCase {
                input: "--key",
                expect: Err(()),
            },
            TestCase {
                input: "--key=val",
                expect: Ok(vec![("key", "val")]),
            },
            TestCase {
                input: r#"--key=val -ckey2=val2 -c key3=val3 -c key4=val4 -ckey5=val5"#,
                expect: Ok(vec![
                    ("key", "val"),
                    ("key2", "val2"),
                    ("key3", "val3"),
                    ("key4", "val4"),
                    ("key5", "val5"),
                ]),
            },
            TestCase {
                input: r#"-c\ key=val"#,
                expect: Ok(vec![(" key", "val")]),
            },
            TestCase {
                input: "--key=val -ckey2 val2",
                expect: Err(()),
            },
            // Unclear what this should do.
            TestCase {
                input: "--key=",
                expect: Ok(vec![("key", "")]),
            },
        ];
        for test in tests {
            let got = parse_options(test.input);
            let expect = test.expect.map(|r| {
                r.into_iter()
                    .map(|(k, v)| (k.to_owned(), v.to_owned()))
                    .collect()
            });
            assert_eq!(got, expect, "input: {}", test.input);
        }
    }

    #[mz_ore::test]
    fn test_parse_option() {
        // `parse_option` parses a single `--key=val` or `-ckey=val` token.
        struct TestCase {
            input: &'static str,
            expect: Result<(&'static str, &'static str), ()>,
        }
        let tests = vec![
            TestCase {
                input: "",
                expect: Err(()),
            },
            TestCase {
                input: "--",
                expect: Err(()),
            },
            TestCase {
                input: "--c",
                expect: Err(()),
            },
            TestCase {
                input: "a=b",
                expect: Err(()),
            },
            TestCase {
                input: "--a=b",
                expect: Ok(("a", "b")),
            },
            TestCase {
                input: "--ca=b",
                expect: Ok(("ca", "b")),
            },
            TestCase {
                input: "-ca=b",
                expect: Ok(("a", "b")),
            },
            // Unclear what this should error, but at least test it.
            TestCase {
                input: "--=",
                expect: Ok(("", "")),
            },
        ];
        for test in tests {
            let got = parse_option(test.input);
            assert_eq!(got, test.expect, "input: {}", test.input);
        }
    }

    #[mz_ore::test]
    fn test_split_options() {
        // `split_options` splits on unescaped whitespace; a backslash makes
        // the following character literal, and a trailing lone backslash is
        // dropped.
        struct TestCase {
            input: &'static str,
            expect: Vec<&'static str>,
        }
        let tests = vec![
            TestCase {
                input: "",
                expect: vec![],
            },
            TestCase {
                input: "  ",
                expect: vec![],
            },
            TestCase {
                input: " a ",
                expect: vec!["a"],
            },
            TestCase {
                input: "  ab     cd   ",
                expect: vec!["ab", "cd"],
            },
            TestCase {
                input: r#"  ab\     cd   "#,
                expect: vec!["ab ", "cd"],
            },
            TestCase {
                input: r#"  ab\\     cd   "#,
                expect: vec![r#"ab\"#, "cd"],
            },
            TestCase {
                input: r#"  ab\\\     cd   "#,
                expect: vec![r#"ab\ "#, "cd"],
            },
            TestCase {
                input: r#"  ab\\\ cd   "#,
                expect: vec![r#"ab\ cd"#],
            },
            TestCase {
                input: r#"  ab\\\cd   "#,
                expect: vec![r#"ab\cd"#],
            },
            TestCase {
                input: r#"a\"#,
                expect: vec!["a"],
            },
            TestCase {
                input: r#"a\ "#,
                expect: vec!["a "],
            },
            TestCase {
                input: r#"\"#,
                expect: vec![],
            },
            TestCase {
                input: r#"\ "#,
                expect: vec![r#" "#],
            },
            TestCase {
                input: r#" \ "#,
                expect: vec![r#" "#],
            },
            TestCase {
                input: r#"\  "#,
                expect: vec![r#" "#],
            },
        ];
        for test in tests {
            let got = split_options(test.input);
            assert_eq!(got, test.expect, "input: {}", test.input);
        }
    }

    #[mz_ore::test]
    fn test_is_jwt() {
        // A real JWT header decodes successfully.
        assert!(is_jwt("eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxIn0.signature"));
        // Not JWTs: plain strings, wrong segment count, non-JSON headers.
        for s in [
            "",
            "secure_password",
            "p4ss.w0rd",
            "aaa.bbb.ccc",
            "eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxIn0",
            "eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxIn0.sig.extra",
        ] {
            assert!(!is_jwt(s), "is_jwt({s:?})");
        }
    }
}