Skip to main content

mz_pgwire/
protocol.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::collections::BTreeMap;
11use std::convert::TryFrom;
12use std::future::Future;
13use std::ops::Deref;
14use std::sync::Arc;
15use std::time::{Duration, Instant};
16use std::{iter, mem};
17
18use base64::prelude::*;
19use byteorder::{ByteOrder, NetworkEndian};
20use csv_core::ReadRecordResult;
21use futures::future::{BoxFuture, FutureExt, pending};
22use itertools::Itertools;
23use mz_adapter::client::RecordFirstRowStream;
24use mz_adapter::session::{
25    EndTransactionAction, InProgressRows, LifecycleTimestamps, PortalRefMut, PortalState, Session,
26    SessionConfig, TransactionStatus,
27};
28use mz_adapter::statement_logging::{StatementEndedExecutionReason, StatementExecutionStrategy};
29use mz_adapter::{
30    AdapterError, AdapterNotice, ExecuteContextGuard, ExecuteResponse, PeekResponseUnary, metrics,
31    verify_datum_desc,
32};
33use mz_auth::Authenticated;
34use mz_auth::password::Password;
35use mz_authenticator::{Authenticator, GenericOidcAuthenticator};
36use mz_frontegg_auth::Authenticator as FronteggAuthenticator;
37use mz_ore::cast::CastFrom;
38use mz_ore::netio::AsyncReady;
39use mz_ore::now::{EpochMillis, SYSTEM_TIME};
40use mz_ore::str::StrExt;
41use mz_ore::{assert_none, assert_ok, instrument, soft_assert_eq_or_log, soft_assert_or_log};
42use mz_pgcopy::{CopyCsvFormatParams, CopyFormatParams, CopyTextFormatParams};
43use mz_pgwire_common::{
44    ConnectionCounter, Cursor, ErrorResponse, Format, FrontendMessage, Severity, VERSION_3,
45    VERSIONS,
46};
47use mz_repr::{
48    CatalogItemId, ColumnIndex, Datum, RelationDesc, RowArena, RowIterator, RowRef,
49    SqlRelationType, SqlScalarType,
50};
51use mz_server_core::TlsMode;
52use mz_server_core::listeners;
53use mz_server_core::listeners::AllowedRoles;
54use mz_sql::ast::display::AstDisplay;
55use mz_sql::ast::{
56    CopyDirection, CopyStatement, CopyTarget, FetchDirection, Ident, Raw, Statement,
57};
58use mz_sql::parse::StatementParseResult;
59use mz_sql::plan::{CopyFormat, ExecuteTimeout, StatementDesc};
60use mz_sql::session::metadata::SessionMetadata;
61use mz_sql::session::user::INTERNAL_USER_NAMES;
62use mz_sql::session::vars::VarInput;
63use postgres::error::SqlState;
64use tokio::io::{self, AsyncRead, AsyncWrite};
65use tokio::select;
66use tokio::time::{self};
67use tokio_metrics::TaskMetrics;
68use tokio_stream::wrappers::UnboundedReceiverStream;
69use tracing::{Instrument, debug, debug_span, warn};
70use uuid::Uuid;
71
72use crate::codec::{
73    FramedConn, decode_password, decode_sasl_initial_response, decode_sasl_response,
74};
75use crate::message::{
76    self, BackendMessage, SASLServerFinalMessage, SASLServerFinalMessageKinds,
77    SASLServerFirstMessage,
78};
79
80/// Reports whether the given stream begins with a pgwire handshake.
81///
82/// To avoid false negatives, there must be at least eight bytes in `buf`.
83pub fn match_handshake(buf: &[u8]) -> bool {
84    // The pgwire StartupMessage looks like this:
85    //
86    //     i32 - Length of entire message.
87    //     i32 - Protocol version number.
88    //     [String] - Arbitrary key-value parameters of any length.
89    //
90    // Since arbitrary parameters can be included in the StartupMessage, the
91    // first Int32 is worthless, since the message could have any length.
92    // Instead, we sniff the protocol version number.
93    if buf.len() < 8 {
94        return false;
95    }
96    let version = NetworkEndian::read_i32(&buf[4..8]);
97    VERSIONS.contains(&version)
98}
99
/// Parameters for the [`run`] function.
///
/// Bundles everything needed to drive a single pgwire connection, from the
/// startup message through authentication and into the query loop.
pub struct RunParams<'a, A, I>
where
    I: Iterator<Item = TaskMetrics> + Send,
{
    /// The TLS mode of the pgwire server. The connection is checked for
    /// compatibility with it before authentication starts.
    pub tls_mode: Option<TlsMode>,
    /// A client for the adapter, used to authenticate the user and to create
    /// and start up the session.
    pub adapter_client: mz_adapter::Client,
    /// The connection to the client.
    pub conn: &'a mut FramedConn<A>,
    /// The universally unique identifier for the connection.
    pub conn_uuid: Uuid,
    /// The protocol version that the client provided in the startup message.
    /// Only `VERSION_3` is accepted.
    pub version: i32,
    /// The parameters that the client provided in the startup message.
    /// `user` names the user to log in as, `options` holds command-line style
    /// settings, and every other entry is applied as a session variable.
    pub params: BTreeMap<String, String>,
    /// Frontegg JWT authenticator, if one is configured.
    pub frontegg: Option<FronteggAuthenticator>,
    /// OIDC authenticator.
    pub oidc: GenericOidcAuthenticator,
    /// The authentication method defined by the server's listener
    /// configuration.
    pub authenticator_kind: listeners::AuthenticatorKind,
    /// Global connection limit and count. A slot is allocated for the
    /// session's user once authentication succeeds.
    pub active_connection_counter: ConnectionCounter,
    /// Helm chart version, recorded in the new session's configuration.
    pub helm_chart_version: Option<String>,
    /// Which roles may log in through this listener, e.g. whether reserved
    /// users (i.e., mz_system) are allowed.
    pub allowed_roles: AllowedRoles,
    /// Tokio task metrics intervals, handed to the connection's state machine.
    pub tokio_metrics_intervals: I,
}
133
/// Runs a pgwire connection to completion.
///
/// This involves responding to `FrontendMessage::StartupMessage` and all future
/// requests until the client terminates the connection or a fatal error occurs.
///
/// The connection proceeds through these phases:
///
/// 1. Reject unsupported protocol versions, users whose role is not permitted
///    by `allowed_roles`, and connections incompatible with `tls_mode`.
/// 2. Authenticate the user with the listener's authenticator (Frontegg, OIDC,
///    cleartext password, SASL SCRAM-SHA-256, or none) and create a session.
/// 3. Apply the remaining startup parameters, including the settings encoded
///    in `options`, as session variables. Settings that fail to apply become
///    notices rather than errors.
/// 4. Allocate a slot from `active_connection_counter`, register the session
///    with the adapter, and send `AuthenticationOk`, `ParameterStatus`,
///    `BackendKeyData`, any pending notices, and `ReadyForQuery`.
/// 5. Run the connection's state machine until it finishes, or until the
///    authentication session expires (only the Frontegg path can expire).
///
/// Note that this function returns successfully even upon delivering a fatal
/// error to the client. It only returns `Err` if an unexpected I/O error occurs
/// while communicating with the client, e.g., if the connection is severed in
/// the middle of a request.
#[mz_ore::instrument(level = "debug")]
pub async fn run<'a, A, I>(
    RunParams {
        tls_mode,
        adapter_client,
        conn,
        conn_uuid,
        version,
        mut params,
        frontegg,
        oidc,
        authenticator_kind,
        active_connection_counter,
        helm_chart_version,
        allowed_roles,
        tokio_metrics_intervals,
    }: RunParams<'a, A, I>,
) -> Result<(), io::Error>
where
    A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
    I: Iterator<Item = TaskMetrics> + Send,
{
    // Only protocol version 3 is supported.
    if version != VERSION_3 {
        return conn
            .send(ErrorResponse::fatal(
                SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
                "server does not support the client's requested protocol version",
            ))
            .await;
    }

    // `user` is removed from `params` so it is not applied as a session
    // variable by the startup-parameter loop below. `options` stays in
    // `params`; that loop expands it using the result parsed here.
    let user = params.remove("user").unwrap_or_else(String::new);
    let options = parse_options(params.get("options").unwrap_or(&String::new()));
    let authenticator =
        get_authenticator(authenticator_kind, frontegg, oidc, adapter_client.clone());
    // TODO move this somewhere it can be shared with HTTP
    let is_internal_user = INTERNAL_USER_NAMES.contains(&user);
    // This is a superset of internal users.
    let is_reserved_user = mz_adapter::catalog::is_reserved_role_name(user.as_str());
    // Decide whether this listener accepts the requested user at all, before
    // doing any authentication work.
    let role_allowed = match allowed_roles {
        AllowedRoles::Normal => !is_reserved_user,
        AllowedRoles::Internal => is_internal_user,
        AllowedRoles::NormalAndInternal => !is_reserved_user || is_internal_user,
    };
    if !role_allowed {
        let msg = format!("unauthorized login to user '{user}'");
        return conn
            .send(ErrorResponse::fatal(SqlState::INSUFFICIENT_PRIVILEGE, msg))
            .await;
    }

    if let Err(err) = conn.inner().ensure_tls_compatibility(&tls_mode) {
        return conn.send(err).await;
    }

    let authenticator_kind = authenticator.kind();

    // Each arm yields the new session plus a future that resolves when the
    // authentication expires. The arms use `left_future`/`right_future` so that
    // every arm produces the same `Either` future type.
    let (mut session, expired) = match authenticator {
        Authenticator::Frontegg(frontegg) => {
            let password = match request_cleartext_password(conn).await {
                Ok(password) => password,
                Err(PasswordRequestError::IoError(e)) => return Err(e),
                Err(PasswordRequestError::InvalidPasswordError(e)) => {
                    return conn.send(e).await;
                }
            };

            let auth_response = frontegg.authenticate(&user, &password).await;
            match auth_response {
                // Create a session based on the auth session.
                //
                // In particular, it's important that the username come from the
                // auth session, as Frontegg may return an email address with
                // different casing than the user supplied via the pgwire
                // username field.
                Ok((mut auth_session, authenticated)) => {
                    let session = adapter_client.new_session(
                        SessionConfig {
                            conn_id: conn.conn_id().clone(),
                            uuid: conn_uuid,
                            user: auth_session.user().into(),
                            client_ip: conn.peer_addr().clone(),
                            external_metadata_rx: Some(auth_session.external_metadata_rx()),
                            helm_chart_version,
                            authenticator_kind,
                            groups: None,
                        },
                        authenticated,
                    );
                    // Resolves when the Frontegg auth session expires, which
                    // closes the connection in the `select!` at the end.
                    let expired = async move { auth_session.expired().await };
                    (session, expired.left_future())
                }
                Err(err) => {
                    warn!(?err, "pgwire connection failed authentication");
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_PASSWORD,
                            "invalid password",
                        ))
                        .await;
                }
            }
        }
        Authenticator::Oidc(oidc) => {
            // OIDC listener: accepts either a JWT (uses OIDC authentication) or a
            // plain SQL password (uses SQL password authentication).
            let password = match request_cleartext_password(conn).await {
                Ok(password) => password,
                Err(PasswordRequestError::IoError(e)) => return Err(e),
                Err(PasswordRequestError::InvalidPasswordError(e)) => {
                    return conn.send(e).await;
                }
            };
            if is_jwt(&password) {
                let auth_response = oidc.authenticate(&password, Some(&user)).await;
                match auth_response {
                    Ok((mut claims, authenticated)) => {
                        // The session user and groups come from the token's
                        // claims, not from the startup `user` parameter.
                        let groups = claims.groups.take();
                        let session = adapter_client.new_session(
                            SessionConfig {
                                conn_id: conn.conn_id().clone(),
                                uuid: conn_uuid,
                                user: std::mem::take(&mut claims.user),
                                client_ip: conn.peer_addr().clone(),
                                external_metadata_rx: None,
                                helm_chart_version,
                                authenticator_kind,
                                groups,
                            },
                            authenticated,
                        );
                        // No invalidation of the auth session once authenticated,
                        // so auth session lasts indefinitely.
                        (session, pending().right_future())
                    }
                    Err(err) => {
                        warn!(?err, "pgwire connection failed authentication");
                        return conn.send(err.into_response()).await;
                    }
                }
            } else {
                let session = match authenticate_with_password(
                    conn,
                    &adapter_client,
                    user,
                    Password(password),
                    conn_uuid,
                    helm_chart_version,
                )
                .await
                {
                    Ok(session) => session,
                    Err(PasswordRequestError::IoError(e)) => return Err(e),
                    Err(PasswordRequestError::InvalidPasswordError(e)) => {
                        return conn.send(e).await;
                    }
                };
                (session, pending().right_future())
            }
        }
        Authenticator::Password(adapter_client) => {
            let password = match request_cleartext_password(conn).await {
                Ok(password) => password,
                Err(PasswordRequestError::IoError(e)) => return Err(e),
                Err(PasswordRequestError::InvalidPasswordError(e)) => {
                    return conn.send(e).await;
                }
            };
            let session = match authenticate_with_password(
                conn,
                &adapter_client,
                user,
                Password(password),
                conn_uuid,
                helm_chart_version,
            )
            .await
            {
                Ok(session) => session,
                Err(PasswordRequestError::IoError(e)) => return Err(e),
                Err(PasswordRequestError::InvalidPasswordError(e)) => {
                    return conn.send(e).await;
                }
            };
            // No frontegg check, so auth session lasts indefinitely.
            (session, pending().right_future())
        }
        Authenticator::Sasl(adapter_client) => {
            // Start the handshake
            conn.send(BackendMessage::AuthenticationSASL).await?;
            conn.flush().await?;
            // Get the initial response indicating chosen mechanism
            let (mechanism, initial_response) = match conn.recv().await? {
                Some(FrontendMessage::RawAuthentication(data)) => {
                    match decode_sasl_initial_response(Cursor::new(&data)).ok() {
                        Some(FrontendMessage::SASLInitialResponse {
                            gs2_header,
                            mechanism,
                            initial_response,
                        }) => {
                            // We do not support channel binding
                            if gs2_header.channel_binding_enabled() {
                                return conn
                                    .send(ErrorResponse::fatal(
                                        SqlState::PROTOCOL_VIOLATION,
                                        "channel binding not supported",
                                    ))
                                    .await;
                            }
                            (mechanism, initial_response)
                        }
                        _ => {
                            return conn
                                .send(ErrorResponse::fatal(
                                    SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                                    "expected SASLInitialResponse message",
                                ))
                                .await;
                        }
                    }
                }
                _ => {
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                            "expected SASLInitialResponse message",
                        ))
                        .await;
                }
            };

            // SCRAM-SHA-256 is the only supported SASL mechanism.
            if mechanism != "SCRAM-SHA-256" {
                return conn
                    .send(ErrorResponse::fatal(
                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                        "unsupported SASL mechanism",
                    ))
                    .await;
            }

            // Bound the client-supplied nonce before doing any work with it.
            if initial_response.nonce.len() > 256 {
                return conn
                    .send(ErrorResponse::fatal(
                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                        "nonce too long",
                    ))
                    .await;
            }

            let (server_first_message_raw, mock_hash) = match adapter_client
                .generate_sasl_challenge(&user, &initial_response.nonce)
                .await
            {
                Ok(response) => {
                    // The exact server-first-message text, kept because it is
                    // part of the SCRAM AuthMessage built below.
                    let server_first_message_raw = format!(
                        "r={},s={},i={}",
                        response.nonce, response.salt, response.iteration_count
                    );

                    // NOTE(review): a placeholder SCRAM hash with fixed keys,
                    // passed to `verify_sasl_proof`; presumably used when the
                    // user has no stored hash so failures look uniform — confirm.
                    let client_key = [0u8; 32];
                    let server_key = [1u8; 32];
                    let mock_hash = format!(
                        "SCRAM-SHA-256${}:{}${}:{}",
                        response.iteration_count,
                        response.salt,
                        BASE64_STANDARD.encode(client_key),
                        BASE64_STANDARD.encode(server_key)
                    );

                    conn.send(BackendMessage::AuthenticationSASLContinue(
                        SASLServerFirstMessage {
                            iteration_count: response.iteration_count,
                            nonce: response.nonce,
                            salt: response.salt,
                        },
                    ))
                    .await?;
                    conn.flush().await?;
                    (server_first_message_raw, mock_hash)
                }
                Err(e) => {
                    return conn.send(e.into_response(Severity::Fatal)).await;
                }
            };

            let authenticated = match conn.recv().await? {
                Some(FrontendMessage::RawAuthentication(data)) => {
                    match decode_sasl_response(Cursor::new(&data)).ok() {
                        Some(FrontendMessage::SASLResponse(response)) => {
                            // SCRAM AuthMessage (RFC 5802): client-first-message-bare,
                            // server-first-message, and the client-final-message
                            // without its proof, joined by commas.
                            let auth_message = format!(
                                "{},{},{}",
                                initial_response.client_first_message_bare_raw,
                                server_first_message_raw,
                                response.client_final_message_bare_raw
                            );
                            // Bound the client-supplied proof before verifying it.
                            if response.proof.len() > 1024 {
                                return conn
                                    .send(ErrorResponse::fatal(
                                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                                        "proof too long",
                                    ))
                                    .await;
                            }
                            match adapter_client
                                .verify_sasl_proof(
                                    &user,
                                    &response.proof,
                                    &auth_message,
                                    &mock_hash,
                                )
                                .await
                            {
                                Ok((proof_response, authenticated)) => {
                                    conn.send(BackendMessage::AuthenticationSASLFinal(
                                        SASLServerFinalMessage {
                                            kind: SASLServerFinalMessageKinds::Verifier(
                                                proof_response.verifier,
                                            ),
                                            extensions: vec![],
                                        },
                                    ))
                                    .await?;
                                    conn.flush().await?;
                                    authenticated
                                }
                                Err(_) => {
                                    return conn
                                        .send(ErrorResponse::fatal(
                                            SqlState::INVALID_PASSWORD,
                                            "invalid password",
                                        ))
                                        .await;
                                }
                            }
                        }
                        _ => {
                            return conn
                                .send(ErrorResponse::fatal(
                                    SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                                    "expected SASLResponse message",
                                ))
                                .await;
                        }
                    }
                }
                _ => {
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                            "expected SASLResponse message",
                        ))
                        .await;
                }
            };

            let session = adapter_client.new_session(
                SessionConfig {
                    conn_id: conn.conn_id().clone(),
                    uuid: conn_uuid,
                    user,
                    client_ip: conn.peer_addr().clone(),
                    external_metadata_rx: None,
                    helm_chart_version,
                    authenticator_kind,
                    groups: None,
                },
                authenticated,
            );
            // No frontegg check, so auth session lasts indefinitely.
            let auth_session = pending().right_future();
            (session, auth_session)
        }

        Authenticator::None => {
            let session = adapter_client.new_session(
                SessionConfig {
                    conn_id: conn.conn_id().clone(),
                    uuid: conn_uuid,
                    user,
                    client_ip: conn.peer_addr().clone(),
                    external_metadata_rx: None,
                    helm_chart_version,
                    authenticator_kind,
                    groups: None,
                },
                Authenticated,
            );
            // No frontegg check, so auth session lasts indefinitely.
            let auth_session = pending().right_future();
            (session, auth_session)
        }
    };

    // Apply the remaining startup parameters as session variables. `options`
    // expands into the settings parsed earlier (or one notice if it could not
    // be parsed); any other parameter is a single setting named after itself.
    let system_vars = adapter_client.get_system_vars().await;
    for (name, value) in params {
        let settings = match name.as_str() {
            "options" => match &options {
                Ok(opts) => opts,
                Err(()) => {
                    session.add_notice(AdapterNotice::BadStartupSetting {
                        name,
                        reason: "could not parse".into(),
                    });
                    continue;
                }
            },
            _ => &vec![(name, value)],
        };
        for (key, val) in settings {
            const LOCAL: bool = false;
            // TODO: Issuing an error here is better than what we did before
            // (silently ignore errors on set), but erroring the connection
            // might be the better behavior. We maybe need to support more
            // options sent by psql and drivers before we can safely do this.
            if let Err(err) = session
                .vars_mut()
                .set(&system_vars, key, VarInput::Flat(val), LOCAL)
            {
                session.add_notice(AdapterNotice::BadStartupSetting {
                    name: key.clone(),
                    reason: err.to_string(),
                });
            }
        }
    }
    // Commit the startup settings applied above.
    session
        .vars_mut()
        .end_transaction(EndTransactionAction::Commit);

    // Bound to `_guard` (not `_`) so the allocation lives until `run` returns;
    // presumably dropping it releases the connection slot.
    let _guard = match active_connection_counter.allocate_connection(session.user()) {
        Ok(drop_connection) => drop_connection,
        Err(e) => {
            let e: AdapterError = e.into();
            return conn.send(e.into_response(Severity::Fatal)).await;
        }
    };

    // Register session with adapter.
    let mut adapter_client = match adapter_client.startup(session).await {
        Ok(adapter_client) => adapter_client,
        Err(e) => return conn.send(e.into_response(Severity::Fatal)).await,
    };

    // Finish the startup handshake in a single batch: AuthenticationOk, the
    // parameters the client should be told about, backend key data, any
    // notices collected so far, and ReadyForQuery.
    let mut buf = vec![BackendMessage::AuthenticationOk];
    for var in adapter_client.session().vars().notify_set() {
        buf.push(BackendMessage::ParameterStatus(var.name(), var.value()));
    }
    buf.push(BackendMessage::BackendKeyData {
        conn_id: adapter_client.session().conn_id().unhandled(),
        secret_key: adapter_client.session().secret_key(),
    });
    buf.extend(
        adapter_client
            .session()
            .drain_notices()
            .into_iter()
            .map(|notice| BackendMessage::ErrorResponse(notice.into_response())),
    );
    buf.push(BackendMessage::ReadyForQuery(
        adapter_client.session().transaction().into(),
    ));
    conn.send_all(buf).await?;
    conn.flush().await?;

    let machine = StateMachine {
        conn,
        adapter_client,
        txn_needs_commit: false,
        tokio_metrics_intervals,
    };

    // Race the query loop against authentication expiry; whichever finishes
    // first ends the connection.
    select! {
        r = machine.run() => {
            // Errors produced internally (like MAX_REQUEST_SIZE being exceeded) should send an
            // error to the client informing them why the connection was closed. We still want to
            // return the original error up the stack, though, so we skip error checking during conn
            // operations.
            if let Err(err) = &r {
                let _ = conn
                    .send(ErrorResponse::fatal(
                        SqlState::CONNECTION_FAILURE,
                        err.to_string(),
                    ))
                    .await;
                let _ = conn.flush().await;
            }
            r
        },
        _ = expired => {
            conn
                .send(ErrorResponse::fatal(SqlState::INVALID_AUTHORIZATION_SPECIFICATION, "authentication expired"))
                .await?;
            conn.flush().await
        }
    }
}
639
640/// Decides if a given password is a JWT by checking
641/// if we can decode its header.
642fn is_jwt(password: &str) -> bool {
643    jsonwebtoken::decode_header(password).is_ok()
644}
645
646/// Returns (name, value) session settings pairs from an options value.
647///
648/// From Postgres, see pg_split_opts in postinit.c and process_postgres_switches
649/// in postgres.c.
650fn parse_options(value: &str) -> Result<Vec<(String, String)>, ()> {
651    let opts = split_options(value);
652    let mut pairs = Vec::with_capacity(opts.len());
653    let mut seen_prefix = false;
654    for opt in opts {
655        if !seen_prefix {
656            if opt == "-c" {
657                seen_prefix = true;
658            } else {
659                let (key, val) = parse_option(&opt)?;
660                pairs.push((key.to_owned(), val.to_owned()));
661            }
662        } else {
663            let (key, val) = opt.split_once('=').ok_or(())?;
664            pairs.push((key.to_owned(), val.to_owned()));
665            seen_prefix = false;
666        }
667    }
668    Ok(pairs)
669}
670
/// Parses one self-contained option of the form `--key=value` or
/// `-ckey=value` into its key and value.
///
/// The option is split at its first `=`. The key must begin with `-c` or `--`
/// (checked in that order), and that prefix is stripped. Otherwise the key is
/// returned verbatim; in particular, dashes are not converted to underscores.
/// The two-word `-c key=value` form is not handled here; the caller pairs a
/// bare `-c` with the following word itself.
///
/// Returns `Err(())` if there is no `=` or the key has neither prefix.
fn parse_option(option: &str) -> Result<(&str, &str), ()> {
    let (raw_key, value) = option.split_once('=').ok_or(())?;
    ["-c", "--"]
        .iter()
        .find_map(|prefix| raw_key.strip_prefix(prefix))
        .map(|key| (key, value))
        .ok_or(())
}
683
/// Splits `value` into words separated by one or more spaces.
///
/// A backslash makes the following character literal, so `\ ` puts a space
/// inside a word and `\\` yields a single backslash. A lone trailing
/// backslash is dropped. Empty words are never produced. Only the space
/// character (`' '`) separates words; other whitespace is ordinary text.
fn split_options(value: &str) -> Vec<String> {
    // Escapes mean we cannot just subslice `value`, so every word is built up
    // as an owned string. This runs rarely enough that it isn't worth
    // borrowing when no escapes are present.
    let mut words: Vec<String> = Vec::new();
    let mut word = String::new();
    let mut escaped = false;
    for ch in value.chars() {
        if escaped {
            // The previous character was an unescaped `\`: take this one as-is.
            word.push(ch);
            escaped = false;
            continue;
        }
        match ch {
            '\\' => escaped = true,
            ' ' => {
                // Runs of spaces collapse: only emit when a word is pending.
                if !word.is_empty() {
                    words.push(std::mem::take(&mut word));
                }
            }
            other => word.push(other),
        }
    }
    if !word.is_empty() {
        words.push(word);
    }
    words
}
726
727enum PasswordRequestError {
728    InvalidPasswordError(ErrorResponse),
729    IoError(io::Error),
730}
731
732impl From<io::Error> for PasswordRequestError {
733    fn from(e: io::Error) -> Self {
734        PasswordRequestError::IoError(e)
735    }
736}
737
738/// Requests a cleartext password from a connection and returns it if it is valid.
739/// Sends an error response in the connection if the password
740/// is not valid.
741async fn request_cleartext_password<A>(
742    conn: &mut FramedConn<A>,
743) -> Result<String, PasswordRequestError>
744where
745    A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
746{
747    conn.send(BackendMessage::AuthenticationCleartextPassword)
748        .await?;
749    conn.flush().await?;
750
751    if let Some(message) = conn.recv().await? {
752        if let FrontendMessage::RawAuthentication(data) = message {
753            if let Some(FrontendMessage::Password { password }) =
754                decode_password(Cursor::new(&data)).ok()
755            {
756                return Ok(password);
757            }
758        }
759    }
760
761    Err(PasswordRequestError::InvalidPasswordError(
762        ErrorResponse::fatal(
763            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
764            "expected Password message",
765        ),
766    ))
767}
768
769/// Helper for password-based authentication using AdapterClient
770/// and returns an authenticated session.
771async fn authenticate_with_password<A>(
772    conn: &FramedConn<A>,
773    adapter_client: &mz_adapter::Client,
774    user: String,
775    password: Password,
776    conn_uuid: Uuid,
777    helm_chart_version: Option<String>,
778) -> Result<Session, PasswordRequestError>
779where
780    A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
781{
782    let authenticated = match adapter_client.authenticate(&user, &password).await {
783        Ok(authenticated) => authenticated,
784        Err(err) => {
785            warn!(?err, "pgwire connection failed authentication");
786            return Err(PasswordRequestError::InvalidPasswordError(
787                ErrorResponse::fatal(SqlState::INVALID_PASSWORD, "invalid password"),
788            ));
789        }
790    };
791
792    let session = adapter_client.new_session(
793        SessionConfig {
794            conn_id: conn.conn_id().clone(),
795            uuid: conn_uuid,
796            user,
797            client_ip: conn.peer_addr().clone(),
798            external_metadata_rx: None,
799            helm_chart_version,
800            authenticator_kind: mz_auth::AuthenticatorKind::Password,
801            groups: None,
802        },
803        authenticated,
804    );
805
806    Ok(session)
807}
808
/// The next action for the pgwire protocol loop.
#[derive(Debug)]
enum State {
    /// Read and handle the next frontend message.
    Ready,
    /// Discard frontend messages until a `Sync` arrives or the connection closes.
    Drain,
    /// Stop running the protocol loop.
    Done,
}
815
816struct StateMachine<'a, A, I>
817where
818    I: Iterator<Item = TaskMetrics> + Send + 'a,
819{
820    conn: &'a mut FramedConn<A>,
821    adapter_client: mz_adapter::SessionClient,
822    txn_needs_commit: bool,
823    tokio_metrics_intervals: I,
824}
825
/// How an attempt to send a result set's rows finished.
enum SendRowsEndedReason {
    /// All rows were sent.
    Success {
        /// Size of the result. The unit is not established here; presumably bytes (TODO
        /// confirm).
        result_size: u64,
        /// Number of rows returned to the client.
        rows_returned: u64,
    },
    /// Sending stopped because of an error, described by `error`.
    Errored {
        error: String,
    },
    /// Sending stopped because the operation was canceled.
    Canceled,
}
836
/// Error text reported when a statement other than COMMIT/ROLLBACK is attempted
/// inside an aborted transaction.
const ABORTED_TXN_MSG: &str =
    "current transaction is aborted, commands ignored until end of transaction block";
839
840impl<'a, A, I> StateMachine<'a, A, I>
841where
842    A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin + 'a,
843    I: Iterator<Item = TaskMetrics> + Send + 'a,
844{
845    // Manually desugar this (don't use `async fn run`) here because a much better
846    // error message is produced if there are problems with Send or other traits
847    // somewhere within the Future.
848    #[allow(clippy::manual_async_fn)]
849    #[mz_ore::instrument(level = "debug")]
850    fn run(mut self) -> impl Future<Output = Result<(), io::Error>> + Send + 'a {
851        async move {
852            let mut state = State::Ready;
853            loop {
854                self.send_pending_notices().await?;
855                state = match state {
856                    State::Ready => self.advance_ready().await?,
857                    State::Drain => self.advance_drain().await?,
858                    State::Done => return Ok(()),
859                };
860                self.adapter_client
861                    .add_idle_in_transaction_session_timeout();
862            }
863        }
864    }
865
    /// Waits for the next frontend message and dispatches it to its handler, returning
    /// the state the protocol loop should move to next.
    ///
    /// A pending session timeout is checked before incoming messages (`biased` select).
    /// On timeout, the error is processed with `send_error_and_get_state` and the session
    /// is terminated. The function then waits for one more client message before
    /// returning. For every handled message, processing time and the scheduling delay
    /// measured around `recv()` are recorded in metrics labelled by message name.
    #[instrument(level = "debug")]
    async fn advance_ready(&mut self) -> Result<State, io::Error> {
        // Start a new metrics interval before the `recv()` call.
        self.tokio_metrics_intervals
            .next()
            .expect("infinite iterator");

        // Handle timeouts first so we don't execute any statements when there's a pending timeout.
        let message = select! {
            biased;

            // `recv_timeout()` is cancel-safe as per its docs.
            Some(timeout) = self.adapter_client.recv_timeout() => {
                let err: AdapterError = timeout.into();
                let conn_id = self.adapter_client.session().conn_id();
                tracing::warn!("session timed out, conn_id {}", conn_id);

                // Process the error, doing any state cleanup.
                let error_response = err.into_response(Severity::Fatal);
                let error_state = self.send_error_and_get_state(error_response).await;

                // Terminate __after__ we do any cleanup.
                self.adapter_client.terminate().await;

                // We must wait for the client to send a request before we can send the error response.
                // Due to the PG wire protocol, we can't send an ErrorResponse unless it is in response
                // to a client message.
                let _ = self.conn.recv().await?;
                return error_state;
            },
            // `recv()` is cancel-safe as per its docs.
            message = self.conn.recv() => message?,
        };

        // Take the metrics since just before the `recv`.
        let interval = self
            .tokio_metrics_intervals
            .next()
            .expect("infinite iterator");
        let recv_scheduling_delay_ms = interval.total_scheduled_duration.as_secs_f64() * 1000.0;

        // TODO(ggevay): Consider subtracting the scheduling delay from `received`. It's not obvious
        // whether we should do this, because the result wouldn't exactly correspond to either first
        // byte received or last byte received (for msgs that arrive in more than one network packet).
        let received = SYSTEM_TIME();

        // A message arrived, so the idle-in-transaction timeout no longer applies; `run`
        // adds it back after this step completes.
        self.adapter_client
            .remove_idle_in_transaction_session_timeout();

        // NOTE(guswynn): we could consider adding spans to all message types. Currently
        // only a few message types seem useful.
        let message_name = message.as_ref().map(|m| m.name()).unwrap_or_default();

        // Only time actual messages; `None` gets no processing-time sample.
        let start = message.as_ref().map(|_| Instant::now());
        let next_state = match message {
            Some(FrontendMessage::Query { sql }) => {
                let query_root_span =
                    tracing::info_span!(parent: None, "advance_ready", otel.name = message_name);
                query_root_span.follows_from(tracing::Span::current());
                self.query(sql, received)
                    .instrument(query_root_span)
                    .await?
            }
            Some(FrontendMessage::Parse {
                name,
                sql,
                param_types,
            }) => self.parse(name, sql, param_types).await?,
            Some(FrontendMessage::Bind {
                portal_name,
                statement_name,
                param_formats,
                raw_params,
                result_formats,
            }) => {
                self.bind(
                    portal_name,
                    statement_name,
                    param_formats,
                    raw_params,
                    result_formats,
                )
                .await?
            }
            Some(FrontendMessage::Execute {
                portal_name,
                max_rows,
            }) => {
                let max_rows = match usize::try_from(max_rows) {
                    Ok(0) | Err(_) => ExecuteCount::All, // If `max_rows < 0`, no limit.
                    Ok(n) => ExecuteCount::Count(n),
                };
                let execute_root_span =
                    tracing::info_span!(parent: None, "advance_ready", otel.name = message_name);
                execute_root_span.follows_from(tracing::Span::current());
                let state = self
                    .execute(
                        portal_name,
                        max_rows,
                        portal_exec_message,
                        None,
                        ExecuteTimeout::None,
                        None,
                        Some(received),
                    )
                    .instrument(execute_root_span)
                    .await?;
                // In PostgreSQL, when using the extended query protocol, some statements may
                // trigger an eager commit of the current implicit transaction,
                // see: <https://git.postgresql.org/gitweb/?p=postgresql.git&a=commitdiff&h=f92944137>.
                //
                // In Materialize, however, we eagerly commit every statement outside of an explicit
                // transaction when using the extended query protocol. This allows us to eliminate
                // the possibility of a multiple statement implicit transaction, which in turn
                // allows us to apply single-statement optimizations to queries issued in implicit
                // transactions in the extended query protocol.
                //
                // We don't immediately commit here to allow users to page through the portal if
                // necessary. Committing the transaction would destroy the portal before the next
                // Execute command has a chance to resume it. So we instead mark the transaction
                // for commit the next time that `ensure_transaction` is called.
                if self.adapter_client.session().transaction().is_implicit() {
                    self.txn_needs_commit = true;
                }
                state
            }
            Some(FrontendMessage::DescribeStatement { name }) => {
                self.describe_statement(&name).await?
            }
            Some(FrontendMessage::DescribePortal { name }) => self.describe_portal(&name).await?,
            Some(FrontendMessage::CloseStatement { name }) => self.close_statement(name).await?,
            Some(FrontendMessage::ClosePortal { name }) => self.close_portal(name).await?,
            Some(FrontendMessage::Flush) => self.flush().await?,
            Some(FrontendMessage::Sync) => self.sync().await?,
            Some(FrontendMessage::Terminate) => State::Done,

            // These messages are not handled in the ready state; switch to draining until
            // the next `Sync`.
            Some(FrontendMessage::CopyData(_))
            | Some(FrontendMessage::CopyDone)
            | Some(FrontendMessage::CopyFail(_))
            | Some(FrontendMessage::Password { .. })
            | Some(FrontendMessage::RawAuthentication(_))
            | Some(FrontendMessage::SASLInitialResponse { .. })
            | Some(FrontendMessage::SASLResponse(_)) => State::Drain,
            // `recv` yielded no message (presumably the client closed the connection).
            None => State::Done,
        };

        if let Some(start) = start {
            self.adapter_client
                .inner()
                .metrics()
                .pgwire_message_processing_seconds
                .with_label_values(&[message_name])
                .observe(start.elapsed().as_secs_f64());
        }
        self.adapter_client
            .inner()
            .metrics()
            .pgwire_recv_scheduling_delay_ms
            .with_label_values(&[message_name])
            .observe(recv_scheduling_delay_ms);

        Ok(next_state)
    }
1029
1030    async fn advance_drain(&mut self) -> Result<State, io::Error> {
1031        let message = self.conn.recv().await?;
1032        if message.is_some() {
1033            self.adapter_client
1034                .remove_idle_in_transaction_session_timeout();
1035        }
1036        match message {
1037            Some(FrontendMessage::Sync) => self.sync().await,
1038            None => Ok(State::Done),
1039            _ => Ok(State::Drain),
1040        }
1041    }
1042
1043    /// Note that `lifecycle_timestamps` belongs to the whole "Simple Query", because the whole
1044    /// Simple Query is received and parsed together. This means that if there are multiple
1045    /// statements in a Simple Query, then all of them have the same `lifecycle_timestamps`.
1046    #[instrument(level = "debug")]
1047    async fn one_query(
1048        &mut self,
1049        stmt: Statement<Raw>,
1050        sql: String,
1051        lifecycle_timestamps: LifecycleTimestamps,
1052    ) -> Result<State, io::Error> {
1053        // Bind the portal. Note that this does not set the empty string prepared
1054        // statement.
1055        const EMPTY_PORTAL: &str = "";
1056        if let Err(e) = self
1057            .adapter_client
1058            .declare(EMPTY_PORTAL.to_string(), stmt, sql)
1059            .await
1060        {
1061            return self
1062                .send_error_and_get_state(e.into_response(Severity::Error))
1063                .await;
1064        }
1065        let portal = self
1066            .adapter_client
1067            .session()
1068            .get_portal_unverified_mut(EMPTY_PORTAL)
1069            .expect("unnamed portal should be present");
1070
1071        *portal.lifecycle_timestamps = Some(lifecycle_timestamps);
1072
1073        let stmt_desc = portal.desc.clone();
1074        if !stmt_desc.param_types.is_empty() {
1075            return self
1076                .send_error_and_get_state(ErrorResponse::error(
1077                    SqlState::UNDEFINED_PARAMETER,
1078                    "there is no parameter $1",
1079                ))
1080                .await;
1081        }
1082
1083        // Maybe send row description.
1084        if let Some(relation_desc) = &stmt_desc.relation_desc {
1085            if !stmt_desc.is_copy {
1086                let formats = vec![Format::Text; stmt_desc.arity()];
1087                self.send(BackendMessage::RowDescription(
1088                    message::encode_row_description(relation_desc, &formats),
1089                ))
1090                .await?;
1091            }
1092        }
1093
1094        let result = match self
1095            .adapter_client
1096            .execute(EMPTY_PORTAL.to_string(), self.conn.wait_closed(), None)
1097            .await
1098        {
1099            Ok((response, execute_started)) => {
1100                self.send_pending_notices().await?;
1101                self.send_execute_response(
1102                    response,
1103                    stmt_desc.relation_desc,
1104                    EMPTY_PORTAL.to_string(),
1105                    ExecuteCount::All,
1106                    portal_exec_message,
1107                    None,
1108                    ExecuteTimeout::None,
1109                    execute_started,
1110                )
1111                .await
1112            }
1113            Err(e) => {
1114                self.send_pending_notices().await?;
1115                self.send_error_and_get_state(e.into_response(Severity::Error))
1116                    .await
1117            }
1118        };
1119
1120        // Destroy the portal.
1121        self.adapter_client.session().remove_portal(EMPTY_PORTAL);
1122
1123        result
1124    }
1125
1126    async fn ensure_transaction(
1127        &mut self,
1128        num_stmts: usize,
1129        message_type: &str,
1130    ) -> Result<(), io::Error> {
1131        let start = Instant::now();
1132        if self.txn_needs_commit {
1133            self.commit_transaction().await?;
1134        }
1135        // start_transaction can't error (but assert that just in case it changes in
1136        // the future.
1137        let res = self.adapter_client.start_transaction(Some(num_stmts));
1138        assert_ok!(res);
1139        self.adapter_client
1140            .inner()
1141            .metrics()
1142            .pgwire_ensure_transaction_seconds
1143            .with_label_values(&[message_type])
1144            .observe(start.elapsed().as_secs_f64());
1145        Ok(())
1146    }
1147
1148    fn parse_sql<'b>(&self, sql: &'b str) -> Result<Vec<StatementParseResult<'b>>, ErrorResponse> {
1149        let parse_start = Instant::now();
1150        let result = match self.adapter_client.parse(sql) {
1151            Ok(result) => result.map_err(|e| {
1152                // Convert our 0-based byte position to pgwire's 1-based character
1153                // position.
1154                let pos = sql[..e.error.pos].chars().count() + 1;
1155                ErrorResponse::error(SqlState::SYNTAX_ERROR, e.error.message).with_position(pos)
1156            }),
1157            Err(msg) => Err(ErrorResponse::error(SqlState::PROGRAM_LIMIT_EXCEEDED, msg)),
1158        };
1159        self.adapter_client
1160            .inner()
1161            .metrics()
1162            .parse_seconds
1163            .observe(parse_start.elapsed().as_secs_f64());
1164        result
1165    }
1166
1167    /// Executes a "Simple Query", see
1168    /// <https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-SIMPLE-QUERY>
1169    ///
1170    /// For implicit transaction handling, see "Multiple Statements in a Simple Query" in the above.
1171    #[instrument(level = "debug")]
1172    async fn query(&mut self, sql: String, received: EpochMillis) -> Result<State, io::Error> {
1173        // Parse first before doing any transaction checking.
1174        let stmts = match self.parse_sql(&sql) {
1175            Ok(stmts) => stmts,
1176            Err(err) => {
1177                self.send_error_and_get_state(err).await?;
1178                return self.ready().await;
1179            }
1180        };
1181
1182        let num_stmts = stmts.len();
1183
1184        // Compare with postgres' backend/tcop/postgres.c exec_simple_query.
1185        for StatementParseResult { ast: stmt, sql } in stmts {
1186            // In an aborted transaction, reject all commands except COMMIT/ROLLBACK.
1187            if self.is_aborted_txn() && !is_txn_exit_stmt(Some(&stmt)) {
1188                self.aborted_txn_error().await?;
1189                break;
1190            }
1191
1192            // Start an implicit transaction if we aren't in any transaction and there's
1193            // more than one statement. This mirrors the `use_implicit_block` variable in
1194            // postgres.
1195            //
1196            // This needs to be done in the loop instead of once at the top because
1197            // a COMMIT/ROLLBACK statement needs to start a new transaction on next
1198            // statement.
1199            self.ensure_transaction(num_stmts, "query").await?;
1200
1201            match self
1202                .one_query(stmt, sql.to_string(), LifecycleTimestamps { received })
1203                .await?
1204            {
1205                State::Ready => (),
1206                State::Drain => break,
1207                State::Done => return Ok(State::Done),
1208            }
1209        }
1210
1211        // Implicit transactions are closed at the end of a Query message.
1212        {
1213            if self.adapter_client.session().transaction().is_implicit() {
1214                self.commit_transaction().await?;
1215            }
1216        }
1217
1218        if num_stmts == 0 {
1219            self.send(BackendMessage::EmptyQueryResponse).await?;
1220        }
1221
1222        self.ready().await
1223    }
1224
1225    #[instrument(level = "debug")]
1226    async fn parse(
1227        &mut self,
1228        name: String,
1229        sql: String,
1230        param_oids: Vec<u32>,
1231    ) -> Result<State, io::Error> {
1232        // Start a transaction if we aren't in one.
1233        self.ensure_transaction(1, "parse").await?;
1234
1235        let mut param_types = vec![];
1236        for oid in param_oids {
1237            match mz_pgrepr::Type::from_oid(oid) {
1238                Ok(ty) => match SqlScalarType::try_from(&ty) {
1239                    Ok(ty) => param_types.push(Some(ty)),
1240                    Err(err) => {
1241                        return self
1242                            .send_error_and_get_state(ErrorResponse::error(
1243                                SqlState::INVALID_PARAMETER_VALUE,
1244                                err.to_string(),
1245                            ))
1246                            .await;
1247                    }
1248                },
1249                Err(_) if oid == 0 => param_types.push(None),
1250                Err(e) => {
1251                    return self
1252                        .send_error_and_get_state(ErrorResponse::error(
1253                            SqlState::PROTOCOL_VIOLATION,
1254                            e.to_string(),
1255                        ))
1256                        .await;
1257                }
1258            }
1259        }
1260
1261        let stmts = match self.parse_sql(&sql) {
1262            Ok(stmts) => stmts,
1263            Err(err) => {
1264                return self.send_error_and_get_state(err).await;
1265            }
1266        };
1267        if stmts.len() > 1 {
1268            return self
1269                .send_error_and_get_state(ErrorResponse::error(
1270                    SqlState::INTERNAL_ERROR,
1271                    "cannot insert multiple commands into a prepared statement",
1272                ))
1273                .await;
1274        }
1275        let (maybe_stmt, sql) = match stmts.into_iter().next() {
1276            None => (None, ""),
1277            Some(StatementParseResult { ast, sql }) => (Some(ast), sql),
1278        };
1279        if self.is_aborted_txn() && !is_txn_exit_stmt(maybe_stmt.as_ref()) {
1280            return self.aborted_txn_error().await;
1281        }
1282        match self
1283            .adapter_client
1284            .prepare(name, maybe_stmt, sql.to_string(), param_types)
1285            .await
1286        {
1287            Ok(()) => {
1288                self.send(BackendMessage::ParseComplete).await?;
1289                Ok(State::Ready)
1290            }
1291            Err(e) => {
1292                self.send_error_and_get_state(e.into_response(Severity::Error))
1293                    .await
1294            }
1295        }
1296    }
1297
1298    /// Commits and clears the current transaction.
1299    #[instrument(level = "debug")]
1300    async fn commit_transaction(&mut self) -> Result<(), io::Error> {
1301        self.end_transaction(EndTransactionAction::Commit).await
1302    }
1303
1304    /// Rollback and clears the current transaction.
1305    #[instrument(level = "debug")]
1306    async fn rollback_transaction(&mut self) -> Result<(), io::Error> {
1307        self.end_transaction(EndTransactionAction::Rollback).await
1308    }
1309
1310    /// End a transaction and report to the user if an error occurred.
1311    #[instrument(level = "debug")]
1312    async fn end_transaction(&mut self, action: EndTransactionAction) -> Result<(), io::Error> {
1313        self.txn_needs_commit = false;
1314        let resp = self.adapter_client.end_transaction(action).await;
1315        if let Err(err) = resp {
1316            self.send(BackendMessage::ErrorResponse(
1317                err.into_response(Severity::Error),
1318            ))
1319            .await?;
1320        }
1321        Ok(())
1322    }
1323
1324    #[instrument(level = "debug")]
1325    async fn bind(
1326        &mut self,
1327        portal_name: String,
1328        statement_name: String,
1329        param_formats: Vec<Format>,
1330        raw_params: Vec<Option<Vec<u8>>>,
1331        result_formats: Vec<Format>,
1332    ) -> Result<State, io::Error> {
1333        // Start a transaction if we aren't in one.
1334        self.ensure_transaction(1, "bind").await?;
1335
1336        let aborted_txn = self.is_aborted_txn();
1337        let stmt = match self
1338            .adapter_client
1339            .get_prepared_statement(&statement_name)
1340            .await
1341        {
1342            Ok(stmt) => stmt,
1343            Err(err) => {
1344                return self
1345                    .send_error_and_get_state(err.into_response(Severity::Error))
1346                    .await;
1347            }
1348        };
1349
1350        let param_types = &stmt.desc().param_types;
1351        if param_types.len() != raw_params.len() {
1352            let message = format!(
1353                "bind message supplies {actual} parameters, \
1354                 but prepared statement \"{name}\" requires {expected}",
1355                name = statement_name,
1356                actual = raw_params.len(),
1357                expected = param_types.len()
1358            );
1359            return self
1360                .send_error_and_get_state(ErrorResponse::error(
1361                    SqlState::PROTOCOL_VIOLATION,
1362                    message,
1363                ))
1364                .await;
1365        }
1366        let param_formats = match pad_formats(param_formats, raw_params.len()) {
1367            Ok(param_formats) => param_formats,
1368            Err(msg) => {
1369                return self
1370                    .send_error_and_get_state(ErrorResponse::error(
1371                        SqlState::PROTOCOL_VIOLATION,
1372                        msg,
1373                    ))
1374                    .await;
1375            }
1376        };
1377        if aborted_txn && !is_txn_exit_stmt(stmt.stmt()) {
1378            return self.aborted_txn_error().await;
1379        }
1380        let buf = RowArena::new();
1381        let mut params = vec![];
1382        for ((raw_param, mz_typ), format) in raw_params
1383            .into_iter()
1384            .zip_eq(param_types)
1385            .zip_eq(param_formats)
1386        {
1387            let pg_typ = mz_pgrepr::Type::from(mz_typ);
1388            let datum = match raw_param {
1389                None => Datum::Null,
1390                Some(bytes) => match mz_pgrepr::Value::decode(format, &pg_typ, &bytes) {
1391                    Ok(param) => match param.into_datum_decode_error(&buf, &pg_typ, "parameter") {
1392                        Ok(datum) => datum,
1393                        Err(msg) => {
1394                            return self
1395                                .send_error_and_get_state(ErrorResponse::error(
1396                                    SqlState::INVALID_PARAMETER_VALUE,
1397                                    msg,
1398                                ))
1399                                .await;
1400                        }
1401                    },
1402                    Err(err) => {
1403                        let msg = format!("unable to decode parameter: {}", err);
1404                        return self
1405                            .send_error_and_get_state(ErrorResponse::error(
1406                                SqlState::INVALID_PARAMETER_VALUE,
1407                                msg,
1408                            ))
1409                            .await;
1410                    }
1411                },
1412            };
1413            params.push((datum, mz_typ.clone()))
1414        }
1415
1416        let result_formats = match pad_formats(
1417            result_formats,
1418            stmt.desc()
1419                .relation_desc
1420                .clone()
1421                .map(|desc| desc.typ().column_types.len())
1422                .unwrap_or(0),
1423        ) {
1424            Ok(result_formats) => result_formats,
1425            Err(msg) => {
1426                return self
1427                    .send_error_and_get_state(ErrorResponse::error(
1428                        SqlState::PROTOCOL_VIOLATION,
1429                        msg,
1430                    ))
1431                    .await;
1432            }
1433        };
1434
1435        // Binary encodings are disabled for list, map, and aclitem types, but this doesn't
1436        // apply to COPY TO statements.
1437        if !stmt.stmt().map_or(false, |stmt| match stmt {
1438            Statement::Copy(CopyStatement {
1439                direction: CopyDirection::To,
1440                ..
1441            }) => true,
1442            Statement::Copy(CopyStatement {
1443                direction: CopyDirection::From,
1444                // To be conservative, we are restricting COPY FROM to only allow list/map/aclitem types if it is not
1445                // copying from STDIN. It is likely that this works in theory, but is risky and likely to OOM anyways
1446                // as all the data will be held in a buffer in memory before being processed.
1447                target: CopyTarget::Expr(_),
1448                ..
1449            }) => true,
1450            _ => false,
1451        }) {
1452            if let Some(desc) = stmt.desc().relation_desc.clone() {
1453                for (format, ty) in result_formats.iter().zip_eq(desc.iter_types()) {
1454                    match (format, &ty.scalar_type) {
1455                        (Format::Binary, mz_repr::SqlScalarType::List { .. }) => {
1456                            return self
1457                                .send_error_and_get_state(ErrorResponse::error(
1458                                    SqlState::PROTOCOL_VIOLATION,
1459                                    "binary encoding of list types is not implemented",
1460                                ))
1461                                .await;
1462                        }
1463                        (Format::Binary, mz_repr::SqlScalarType::Map { .. }) => {
1464                            return self
1465                                .send_error_and_get_state(ErrorResponse::error(
1466                                    SqlState::PROTOCOL_VIOLATION,
1467                                    "binary encoding of map types is not implemented",
1468                                ))
1469                                .await;
1470                        }
1471                        (Format::Binary, mz_repr::SqlScalarType::AclItem) => {
1472                            return self
1473                                .send_error_and_get_state(ErrorResponse::error(
1474                                    SqlState::PROTOCOL_VIOLATION,
1475                                    "binary encoding of aclitem types does not exist",
1476                                ))
1477                                .await;
1478                        }
1479                        _ => (),
1480                    }
1481                }
1482            }
1483        }
1484
1485        let desc = stmt.desc().clone();
1486        let logging = Arc::clone(stmt.logging());
1487        let stmt_ast = stmt.stmt().cloned();
1488        let state_revision = stmt.state_revision;
1489        if let Err(err) = self.adapter_client.session().set_portal(
1490            portal_name,
1491            desc,
1492            stmt_ast,
1493            logging,
1494            params,
1495            result_formats,
1496            state_revision,
1497        ) {
1498            return self
1499                .send_error_and_get_state(err.into_response(Severity::Error))
1500                .await;
1501        }
1502
1503        self.send(BackendMessage::BindComplete).await?;
1504        Ok(State::Ready)
1505    }
1506
    /// Executes (or resumes) the portal named `portal_name` and sends its results
    /// to the client.
    ///
    /// What happens depends on the portal's state:
    /// - `NotStarted`: ensures a transaction is open, asks the adapter to execute
    ///   the portal, then hands the adapter's response to `send_execute_response`.
    /// - `InProgress`: takes the portal's stored row stream and continues sending
    ///   rows via `send_rows`.
    /// - `Completed(Some(tag))`: replays the stored command tag as a
    ///   `CommandComplete` message (see the FETCH note on that arm).
    /// - `Completed(None)`: replies with an `OBJECT_NOT_IN_PREREQUISITE_STATE` error.
    ///
    /// A missing portal is reported as `INVALID_CURSOR_NAME`. Inside an aborted
    /// transaction, every statement other than COMMIT/ROLLBACK is rejected.
    ///
    /// `outer_ctx_extra` is Some when we are executing as part of an outer statement, e.g., a FETCH
    /// triggering the execution of the underlying query. Whenever this function finishes the
    /// statement itself (the error paths and the `InProgress`/`Completed` arms), it retires
    /// `outer_ctx_extra`. In the `NotStarted` arm, it is passed on to the adapter's `execute`.
    ///
    /// `received`, when present, (re)initializes the portal's lifecycle timestamps.
    ///
    /// This returns a `BoxFuture` rather than being an `async fn` because it can be
    /// re-entered recursively (`execute` -> `send_execute_response` -> `fetch` ->
    /// `execute`), and a recursive async function must box its future.
    fn execute(
        &mut self,
        portal_name: String,
        max_rows: ExecuteCount,
        get_response: GetResponse,
        fetch_portal_name: Option<String>,
        timeout: ExecuteTimeout,
        outer_ctx_extra: Option<ExecuteContextGuard>,
        received: Option<EpochMillis>,
    ) -> BoxFuture<'_, Result<State, io::Error>> {
        async move {
            // Read before `portal` below takes a mutable borrow through the session
            // (presumably the reason it is hoisted; NOTE(review): confirm).
            let aborted_txn = self.is_aborted_txn();

            // Check if the portal has been started and can be continued.
            let portal = match self
                .adapter_client
                .session()
                .get_portal_unverified_mut(&portal_name)
            {
                Some(portal) => portal,
                None => {
                    let msg = format!("portal {} does not exist", portal_name.quoted());
                    if let Some(outer_ctx_extra) = outer_ctx_extra {
                        self.adapter_client.retire_execute(
                            outer_ctx_extra,
                            StatementEndedExecutionReason::Errored { error: msg.clone() },
                        );
                    }
                    return self
                        .send_error_and_get_state(ErrorResponse::error(
                            SqlState::INVALID_CURSOR_NAME,
                            msg,
                        ))
                        .await;
                }
            };

            *portal.lifecycle_timestamps = received.map(LifecycleTimestamps::new);

            // In an aborted transaction, reject all commands except COMMIT/ROLLBACK.
            let txn_exit_stmt = is_txn_exit_stmt(portal.stmt.as_deref());
            if aborted_txn && !txn_exit_stmt {
                if let Some(outer_ctx_extra) = outer_ctx_extra {
                    self.adapter_client.retire_execute(
                        outer_ctx_extra,
                        StatementEndedExecutionReason::Errored {
                            error: ABORTED_TXN_MSG.to_string(),
                        },
                    );
                }
                return self.aborted_txn_error().await;
            }

            // Cloned so the description stays usable after the borrow of `portal`
            // ends and `self` is used mutably in the arms below.
            let row_desc = portal.desc.relation_desc.clone();
            match portal.state {
                PortalState::NotStarted => {
                    // Start a transaction if we aren't in one.
                    self.ensure_transaction(1, "execute").await?;
                    match self
                        .adapter_client
                        .execute(
                            portal_name.clone(),
                            self.conn.wait_closed(),
                            outer_ctx_extra,
                        )
                        .await
                    {
                        Ok((response, execute_started)) => {
                            // Notices raised during execution go out before the result.
                            self.send_pending_notices().await?;
                            self.send_execute_response(
                                response,
                                row_desc,
                                portal_name,
                                max_rows,
                                get_response,
                                fetch_portal_name,
                                timeout,
                                execute_started,
                            )
                            .await
                        }
                        Err(e) => {
                            // Notices raised during execution go out before the error.
                            self.send_pending_notices().await?;
                            self.send_error_and_get_state(e.into_response(Severity::Error))
                                .await
                        }
                    }
                }
                PortalState::InProgress(rows) => {
                    // Resume a partially-sent portal: take its stored row stream
                    // (leaving `None` in its place) and keep sending from it.
                    let rows = rows.take().expect("InProgress rows must be populated");
                    let (result, statement_ended_execution_reason) = match self
                        .send_rows(
                            row_desc.expect("portal missing row desc on resumption"),
                            portal_name,
                            rows,
                            max_rows,
                            get_response,
                            fetch_portal_name,
                            timeout,
                        )
                        .await
                    {
                        Err(e) => {
                            // This is an error communicating with the connection.
                            // We consider that to be a cancelation, rather than a query error.
                            (Err(e), StatementEndedExecutionReason::Canceled)
                        }
                        Ok((ok, SendRowsEndedReason::Canceled)) => {
                            (Ok(ok), StatementEndedExecutionReason::Canceled)
                        }
                        // NOTE: For now the values for `result_size` and
                        // `rows_returned` in fetches are a bit confusing.
                        // We record `Some(n)` for the first fetch, where `n` is
                        // the number of bytes/rows returned by the inner
                        // execute (regardless of how many rows the
                        // fetch fetched), and `None` for subsequent fetches.
                        //
                        // This arguably makes sense since the size/rows
                        // returned measures how much work the compute
                        // layer had to do to satisfy the query, but
                        // we should revisit it if/when we start
                        // logging the inner execute separately.
                        Ok((
                            ok,
                            SendRowsEndedReason::Success {
                                result_size: _,
                                rows_returned: _,
                            },
                        )) => (
                            Ok(ok),
                            StatementEndedExecutionReason::Success {
                                result_size: None,
                                rows_returned: None,
                                execution_strategy: None,
                            },
                        ),
                        Ok((ok, SendRowsEndedReason::Errored { error })) => {
                            (Ok(ok), StatementEndedExecutionReason::Errored { error })
                        }
                    };
                    if let Some(outer_ctx_extra) = outer_ctx_extra {
                        self.adapter_client
                            .retire_execute(outer_ctx_extra, statement_ended_execution_reason);
                    }
                    result
                }
                // FETCH is an awkward command for our current architecture. In Postgres it
                // will extract <count> rows from the target portal, cache them, and return
                // them to the user as requested. Its command tag is always FETCH <num rows
                // extracted>. In Materialize, since we have chosen to not fully support FETCH,
                // we must remember the number of rows that were returned. Use this tag to
                // remember that information and return it.
                PortalState::Completed(Some(tag)) => {
                    let tag = tag.to_string();
                    if let Some(outer_ctx_extra) = outer_ctx_extra {
                        self.adapter_client.retire_execute(
                            outer_ctx_extra,
                            StatementEndedExecutionReason::Success {
                                result_size: None,
                                rows_returned: None,
                                execution_strategy: None,
                            },
                        );
                    }
                    self.send(BackendMessage::CommandComplete { tag }).await?;
                    Ok(State::Ready)
                }
                // A completed portal without a stored tag cannot be run again.
                PortalState::Completed(None) => {
                    let error = format!(
                        "portal {} cannot be run",
                        Ident::new_unchecked(portal_name).to_ast_string_stable()
                    );
                    if let Some(outer_ctx_extra) = outer_ctx_extra {
                        self.adapter_client.retire_execute(
                            outer_ctx_extra,
                            StatementEndedExecutionReason::Errored {
                                error: error.clone(),
                            },
                        );
                    }
                    self.send_error_and_get_state(ErrorResponse::error(
                        SqlState::OBJECT_NOT_IN_PREREQUISITE_STATE,
                        error,
                    ))
                    .await
                }
            }
        }
        .instrument(debug_span!("execute"))
        .boxed()
    }
1700
1701    #[instrument(level = "debug")]
1702    async fn describe_statement(&mut self, name: &str) -> Result<State, io::Error> {
1703        // Start a transaction if we aren't in one.
1704        self.ensure_transaction(1, "describe_statement").await?;
1705
1706        let stmt = match self.adapter_client.get_prepared_statement(name).await {
1707            Ok(stmt) => stmt,
1708            Err(err) => {
1709                return self
1710                    .send_error_and_get_state(err.into_response(Severity::Error))
1711                    .await;
1712            }
1713        };
1714        // Cloning to avoid a mutable borrow issue because `send` also uses `adapter_client`
1715        let parameter_desc = BackendMessage::ParameterDescription(
1716            stmt.desc()
1717                .param_types
1718                .iter()
1719                .map(mz_pgrepr::Type::from)
1720                .collect(),
1721        );
1722        // Claim that all results will be output in text format, even
1723        // though the true result formats are not yet known. A bit
1724        // weird, but this is the behavior that PostgreSQL specifies.
1725        let formats = vec![Format::Text; stmt.desc().arity()];
1726        let row_desc = describe_rows(stmt.desc(), &formats);
1727        self.send_all([parameter_desc, row_desc]).await?;
1728        Ok(State::Ready)
1729    }
1730
1731    #[instrument(level = "debug")]
1732    async fn describe_portal(&mut self, name: &str) -> Result<State, io::Error> {
1733        // Start a transaction if we aren't in one.
1734        self.ensure_transaction(1, "describe_portal").await?;
1735
1736        let session = self.adapter_client.session();
1737        let row_desc = session
1738            .get_portal_unverified(name)
1739            .map(|portal| describe_rows(&portal.desc, &portal.result_formats));
1740        match row_desc {
1741            Some(row_desc) => {
1742                self.send(row_desc).await?;
1743                Ok(State::Ready)
1744            }
1745            None => {
1746                self.send_error_and_get_state(ErrorResponse::error(
1747                    SqlState::INVALID_CURSOR_NAME,
1748                    format!("portal {} does not exist", name.quoted()),
1749                ))
1750                .await
1751            }
1752        }
1753    }
1754
1755    #[instrument(level = "debug")]
1756    async fn close_statement(&mut self, name: String) -> Result<State, io::Error> {
1757        self.adapter_client
1758            .session()
1759            .remove_prepared_statement(&name);
1760        self.send(BackendMessage::CloseComplete).await?;
1761        Ok(State::Ready)
1762    }
1763
1764    #[instrument(level = "debug")]
1765    async fn close_portal(&mut self, name: String) -> Result<State, io::Error> {
1766        self.adapter_client.session().remove_portal(&name);
1767        self.send(BackendMessage::CloseComplete).await?;
1768        Ok(State::Ready)
1769    }
1770
1771    fn complete_portal(&mut self, name: &str) {
1772        let portal = self
1773            .adapter_client
1774            .session()
1775            .get_portal_unverified_mut(name)
1776            .expect("portal should exist");
1777        *portal.state = PortalState::Completed(None);
1778    }
1779
1780    async fn fetch(
1781        &mut self,
1782        name: String,
1783        count: Option<FetchDirection>,
1784        max_rows: ExecuteCount,
1785        fetch_portal_name: Option<String>,
1786        timeout: ExecuteTimeout,
1787        ctx_extra: ExecuteContextGuard,
1788    ) -> Result<State, io::Error> {
1789        // Unlike Execute, no count specified in FETCH returns 1 row, and 0 means 0
1790        // instead of All.
1791        let count = count.unwrap_or(FetchDirection::ForwardCount(1));
1792
1793        // Figure out how many rows we should send back by looking at the various
1794        // combinations of the execute and fetch.
1795        //
1796        // In Postgres, Fetch will cache <count> rows from the target portal and
1797        // return those as requested (if, say, an Execute message was sent with a
1798        // max_rows < the Fetch's count). We expect that case to be incredibly rare and
1799        // so have chosen to not support it until users request it. This eases
1800        // implementation difficulty since we don't have to be able to "send" rows to
1801        // a buffer.
1802        //
1803        // TODO(mjibson): Test this somehow? Need to divide up the pgtest files in
1804        // order to have some that are not Postgres compatible.
1805        let count = match (max_rows, count) {
1806            (ExecuteCount::Count(max_rows), FetchDirection::ForwardCount(count)) => {
1807                let count = usize::cast_from(count);
1808                if max_rows < count {
1809                    let msg = "Execute with max_rows < a FETCH's count is not supported";
1810                    self.adapter_client.retire_execute(
1811                        ctx_extra,
1812                        StatementEndedExecutionReason::Errored {
1813                            error: msg.to_string(),
1814                        },
1815                    );
1816                    return self
1817                        .send_error_and_get_state(ErrorResponse::error(
1818                            SqlState::FEATURE_NOT_SUPPORTED,
1819                            msg,
1820                        ))
1821                        .await;
1822                }
1823                ExecuteCount::Count(count)
1824            }
1825            (ExecuteCount::Count(_), FetchDirection::ForwardAll) => {
1826                let msg = "Execute with max_rows of a FETCH ALL is not supported";
1827                self.adapter_client.retire_execute(
1828                    ctx_extra,
1829                    StatementEndedExecutionReason::Errored {
1830                        error: msg.to_string(),
1831                    },
1832                );
1833                return self
1834                    .send_error_and_get_state(ErrorResponse::error(
1835                        SqlState::FEATURE_NOT_SUPPORTED,
1836                        msg,
1837                    ))
1838                    .await;
1839            }
1840            (ExecuteCount::All, FetchDirection::ForwardAll) => ExecuteCount::All,
1841            (ExecuteCount::All, FetchDirection::ForwardCount(count)) => {
1842                ExecuteCount::Count(usize::cast_from(count))
1843            }
1844        };
1845        let cursor_name = name.to_string();
1846        self.execute(
1847            cursor_name,
1848            count,
1849            fetch_message,
1850            fetch_portal_name,
1851            timeout,
1852            Some(ctx_extra),
1853            None,
1854        )
1855        .await
1856    }
1857
1858    async fn flush(&mut self) -> Result<State, io::Error> {
1859        self.conn.flush().await?;
1860        Ok(State::Ready)
1861    }
1862
1863    /// Sends a backend message to the client, after applying a severity filter.
1864    ///
1865    /// The message is only sent if its severity is above the severity set
1866    /// in the session, with the default value being NOTICE.
1867    #[instrument(level = "debug")]
1868    async fn send<M>(&mut self, message: M) -> Result<(), io::Error>
1869    where
1870        M: Into<BackendMessage>,
1871    {
1872        let message: BackendMessage = message.into();
1873        let is_error =
1874            matches!(&message, BackendMessage::ErrorResponse(e) if e.severity.is_error());
1875
1876        self.conn.send(message).await?;
1877
1878        // Flush immediately after sending an error response, as some clients
1879        // expect to be able to read the error response before sending a Sync
1880        // message. This is arguably in violation of the protocol specification,
1881        // but the specification is somewhat ambiguous, and easier to match
1882        // PostgreSQL here than to fix all the clients that have this
1883        // expectation.
1884        if is_error {
1885            self.conn.flush().await?;
1886        }
1887
1888        Ok(())
1889    }
1890
1891    #[instrument(level = "debug")]
1892    pub async fn send_all(
1893        &mut self,
1894        messages: impl IntoIterator<Item = BackendMessage>,
1895    ) -> Result<(), io::Error> {
1896        for m in messages {
1897            self.send(m).await?;
1898        }
1899        Ok(())
1900    }
1901
1902    #[instrument(level = "debug")]
1903    async fn sync(&mut self) -> Result<State, io::Error> {
1904        // Close the current transaction if we are in an implicit transaction.
1905        if self.adapter_client.session().transaction().is_implicit() {
1906            self.commit_transaction().await?;
1907        }
1908        self.ready().await
1909    }
1910
1911    #[instrument(level = "debug")]
1912    async fn ready(&mut self) -> Result<State, io::Error> {
1913        let txn_state = self.adapter_client.session().transaction().into();
1914        self.send(BackendMessage::ReadyForQuery(txn_state)).await?;
1915        self.flush().await
1916    }
1917
1918    #[allow(clippy::too_many_arguments)]
1919    #[instrument(level = "debug")]
1920    async fn send_execute_response(
1921        &mut self,
1922        response: ExecuteResponse,
1923        row_desc: Option<RelationDesc>,
1924        portal_name: String,
1925        max_rows: ExecuteCount,
1926        get_response: GetResponse,
1927        fetch_portal_name: Option<String>,
1928        timeout: ExecuteTimeout,
1929        execute_started: Instant,
1930    ) -> Result<State, io::Error> {
1931        let mut tag = response.tag();
1932
1933        macro_rules! command_complete {
1934            () => {{
1935                self.send(BackendMessage::CommandComplete {
1936                    tag: tag
1937                        .take()
1938                        .expect("command_complete only called on tag-generating results"),
1939                })
1940                .await?;
1941                Ok(State::Ready)
1942            }};
1943        }
1944
1945        let r = match response {
1946            ExecuteResponse::ClosedCursor => {
1947                self.complete_portal(&portal_name);
1948                command_complete!()
1949            }
1950            ExecuteResponse::DeclaredCursor => {
1951                self.complete_portal(&portal_name);
1952                command_complete!()
1953            }
1954            ExecuteResponse::EmptyQuery => {
1955                self.send(BackendMessage::EmptyQueryResponse).await?;
1956                Ok(State::Ready)
1957            }
1958            ExecuteResponse::Fetch {
1959                name,
1960                count,
1961                timeout,
1962                ctx_extra,
1963            } => {
1964                self.fetch(
1965                    name,
1966                    count,
1967                    max_rows,
1968                    Some(portal_name.to_string()),
1969                    timeout,
1970                    ctx_extra,
1971                )
1972                .await
1973            }
1974            ExecuteResponse::SendingRowsStreaming {
1975                rows,
1976                instance_id,
1977                strategy,
1978            } => {
1979                let row_desc = row_desc
1980                    .expect("missing row description for ExecuteResponse::SendingRowsStreaming");
1981
1982                let span = tracing::debug_span!("sending_rows_streaming");
1983
1984                self.send_rows(
1985                    row_desc,
1986                    portal_name,
1987                    InProgressRows::new(RecordFirstRowStream::new(
1988                        Box::new(rows),
1989                        execute_started,
1990                        &self.adapter_client,
1991                        Some(instance_id),
1992                        Some(strategy),
1993                    )),
1994                    max_rows,
1995                    get_response,
1996                    fetch_portal_name,
1997                    timeout,
1998                )
1999                .instrument(span)
2000                .await
2001                .map(|(state, _)| state)
2002            }
2003            ExecuteResponse::SendingRowsImmediate { rows } => {
2004                let row_desc = row_desc
2005                    .expect("missing row description for ExecuteResponse::SendingRowsImmediate");
2006
2007                let span = tracing::debug_span!("sending_rows_immediate");
2008
2009                let stream =
2010                    futures::stream::once(futures::future::ready(PeekResponseUnary::Rows(rows)));
2011                self.send_rows(
2012                    row_desc,
2013                    portal_name,
2014                    InProgressRows::new(RecordFirstRowStream::new(
2015                        Box::new(stream),
2016                        execute_started,
2017                        &self.adapter_client,
2018                        None,
2019                        Some(StatementExecutionStrategy::Constant),
2020                    )),
2021                    max_rows,
2022                    get_response,
2023                    fetch_portal_name,
2024                    timeout,
2025                )
2026                .instrument(span)
2027                .await
2028                .map(|(state, _)| state)
2029            }
2030            ExecuteResponse::SetVariable { name, .. } => {
2031                // This code is somewhat awkwardly structured because we
2032                // can't hold `var` across an await point.
2033                let qn = name.to_string();
2034                let msg = if let Some(var) = self
2035                    .adapter_client
2036                    .session()
2037                    .vars_mut()
2038                    .notify_set()
2039                    .find(|v| v.name() == qn)
2040                {
2041                    Some(BackendMessage::ParameterStatus(var.name(), var.value()))
2042                } else {
2043                    None
2044                };
2045                if let Some(msg) = msg {
2046                    self.send(msg).await?;
2047                }
2048                command_complete!()
2049            }
2050            ExecuteResponse::Subscribing {
2051                rx,
2052                ctx_extra,
2053                instance_id,
2054            } => {
2055                if fetch_portal_name.is_none() {
2056                    let mut msg = ErrorResponse::notice(
2057                        SqlState::WARNING,
2058                        "streaming SUBSCRIBE rows directly requires a client that does not buffer output",
2059                    );
2060                    if self.adapter_client.session().vars().application_name() == "psql" {
2061                        msg.hint = Some(
2062                            "Wrap your SUBSCRIBE statement in `COPY (SUBSCRIBE ...) TO STDOUT`."
2063                                .into(),
2064                        )
2065                    }
2066                    self.send(msg).await?;
2067                    self.conn.flush().await?;
2068                }
2069                let row_desc =
2070                    row_desc.expect("missing row description for ExecuteResponse::Subscribing");
2071                let (result, statement_ended_execution_reason) = match self
2072                    .send_rows(
2073                        row_desc,
2074                        portal_name,
2075                        InProgressRows::new(RecordFirstRowStream::new(
2076                            Box::new(UnboundedReceiverStream::new(rx)),
2077                            execute_started,
2078                            &self.adapter_client,
2079                            Some(instance_id),
2080                            None,
2081                        )),
2082                        max_rows,
2083                        get_response,
2084                        fetch_portal_name,
2085                        timeout,
2086                    )
2087                    .await
2088                {
2089                    Err(e) => {
2090                        // This is an error communicating with the connection.
2091                        // We consider that to be a cancelation, rather than a query error.
2092                        (Err(e), StatementEndedExecutionReason::Canceled)
2093                    }
2094                    Ok((ok, SendRowsEndedReason::Canceled)) => {
2095                        (Ok(ok), StatementEndedExecutionReason::Canceled)
2096                    }
2097                    Ok((
2098                        ok,
2099                        SendRowsEndedReason::Success {
2100                            result_size,
2101                            rows_returned,
2102                        },
2103                    )) => (
2104                        Ok(ok),
2105                        StatementEndedExecutionReason::Success {
2106                            result_size: Some(result_size),
2107                            rows_returned: Some(rows_returned),
2108                            execution_strategy: None,
2109                        },
2110                    ),
2111                    Ok((ok, SendRowsEndedReason::Errored { error })) => {
2112                        (Ok(ok), StatementEndedExecutionReason::Errored { error })
2113                    }
2114                };
2115                self.adapter_client
2116                    .retire_execute(ctx_extra, statement_ended_execution_reason);
2117                return result;
2118            }
2119            ExecuteResponse::CopyTo { format, resp } => {
2120                let row_desc =
2121                    row_desc.expect("missing row description for ExecuteResponse::CopyTo");
2122                match *resp {
2123                    ExecuteResponse::Subscribing {
2124                        rx,
2125                        ctx_extra,
2126                        instance_id,
2127                    } => {
2128                        let (result, statement_ended_execution_reason) = match self
2129                            .copy_rows(
2130                                format,
2131                                row_desc,
2132                                RecordFirstRowStream::new(
2133                                    Box::new(UnboundedReceiverStream::new(rx)),
2134                                    execute_started,
2135                                    &self.adapter_client,
2136                                    Some(instance_id),
2137                                    None,
2138                                ),
2139                            )
2140                            .await
2141                        {
2142                            Err(e) => {
2143                                // This is an error communicating with the connection.
2144                                // We consider that to be a cancelation, rather than a query error.
2145                                (Err(e), StatementEndedExecutionReason::Canceled)
2146                            }
2147                            Ok((
2148                                state,
2149                                SendRowsEndedReason::Success {
2150                                    result_size,
2151                                    rows_returned,
2152                                },
2153                            )) => (
2154                                Ok(state),
2155                                StatementEndedExecutionReason::Success {
2156                                    result_size: Some(result_size),
2157                                    rows_returned: Some(rows_returned),
2158                                    execution_strategy: None,
2159                                },
2160                            ),
2161                            Ok((state, SendRowsEndedReason::Errored { error })) => {
2162                                (Ok(state), StatementEndedExecutionReason::Errored { error })
2163                            }
2164                            Ok((state, SendRowsEndedReason::Canceled)) => {
2165                                (Ok(state), StatementEndedExecutionReason::Canceled)
2166                            }
2167                        };
2168                        self.adapter_client
2169                            .retire_execute(ctx_extra, statement_ended_execution_reason);
2170                        return result;
2171                    }
2172                    ExecuteResponse::SendingRowsStreaming {
2173                        rows,
2174                        instance_id,
2175                        strategy,
2176                    } => {
2177                        // We don't need to finalize execution here;
2178                        // it was already done in the
2179                        // coordinator. Just extract the state and
2180                        // return that.
2181                        return self
2182                            .copy_rows(
2183                                format,
2184                                row_desc,
2185                                RecordFirstRowStream::new(
2186                                    Box::new(rows),
2187                                    execute_started,
2188                                    &self.adapter_client,
2189                                    Some(instance_id),
2190                                    Some(strategy),
2191                                ),
2192                            )
2193                            .await
2194                            .map(|(state, _)| state);
2195                    }
2196                    ExecuteResponse::SendingRowsImmediate { rows } => {
2197                        let span = tracing::debug_span!("sending_rows_immediate");
2198
2199                        let rows = futures::stream::once(futures::future::ready(
2200                            PeekResponseUnary::Rows(rows),
2201                        ));
2202                        // We don't need to finalize execution here;
2203                        // it was already done in the
2204                        // coordinator. Just extract the state and
2205                        // return that.
2206                        return self
2207                            .copy_rows(
2208                                format,
2209                                row_desc,
2210                                RecordFirstRowStream::new(
2211                                    Box::new(rows),
2212                                    execute_started,
2213                                    &self.adapter_client,
2214                                    None,
2215                                    Some(StatementExecutionStrategy::Constant),
2216                                ),
2217                            )
2218                            .instrument(span)
2219                            .await
2220                            .map(|(state, _)| state);
2221                    }
2222                    _ => {
2223                        return self
2224                            .send_error_and_get_state(ErrorResponse::error(
2225                                SqlState::INTERNAL_ERROR,
2226                                "unsupported COPY response type".to_string(),
2227                            ))
2228                            .await;
2229                    }
2230                };
2231            }
2232            ExecuteResponse::CopyFrom {
2233                target_id,
2234                target_name,
2235                columns,
2236                params,
2237                ctx_extra,
2238            } => {
2239                let row_desc =
2240                    row_desc.expect("missing row description for ExecuteResponse::CopyFrom");
2241                self.copy_from(target_id, target_name, columns, params, row_desc, ctx_extra)
2242                    .await
2243            }
2244            ExecuteResponse::TransactionCommitted { params }
2245            | ExecuteResponse::TransactionRolledBack { params } => {
2246                let notify_set: mz_ore::collections::HashSet<String> = self
2247                    .adapter_client
2248                    .session()
2249                    .vars()
2250                    .notify_set()
2251                    .map(|v| v.name().to_string())
2252                    .collect();
2253
2254                // Only report on parameters that are in the notify set.
2255                for (name, value) in params
2256                    .into_iter()
2257                    .filter(|(name, _v)| notify_set.contains(*name))
2258                {
2259                    let msg = BackendMessage::ParameterStatus(name, value);
2260                    self.send(msg).await?;
2261                }
2262                command_complete!()
2263            }
2264
2265            ExecuteResponse::AlteredDefaultPrivileges
2266            | ExecuteResponse::AlteredObject(..)
2267            | ExecuteResponse::AlteredRole
2268            | ExecuteResponse::AlteredSystemConfiguration
2269            | ExecuteResponse::CreatedCluster { .. }
2270            | ExecuteResponse::CreatedClusterReplica { .. }
2271            | ExecuteResponse::CreatedConnection { .. }
2272            | ExecuteResponse::CreatedDatabase { .. }
2273            | ExecuteResponse::CreatedIndex { .. }
2274            | ExecuteResponse::CreatedIntrospectionSubscribe
2275            | ExecuteResponse::CreatedMaterializedView { .. }
2276            | ExecuteResponse::CreatedRole
2277            | ExecuteResponse::CreatedSchema { .. }
2278            | ExecuteResponse::CreatedSecret { .. }
2279            | ExecuteResponse::CreatedSink { .. }
2280            | ExecuteResponse::CreatedSource { .. }
2281            | ExecuteResponse::CreatedTable { .. }
2282            | ExecuteResponse::CreatedType
2283            | ExecuteResponse::CreatedView { .. }
2284            | ExecuteResponse::CreatedViews { .. }
2285            | ExecuteResponse::CreatedNetworkPolicy
2286            | ExecuteResponse::Comment
2287            | ExecuteResponse::Deallocate { .. }
2288            | ExecuteResponse::Deleted(..)
2289            | ExecuteResponse::DiscardedAll
2290            | ExecuteResponse::DiscardedTemp
2291            | ExecuteResponse::DroppedObject(_)
2292            | ExecuteResponse::DroppedOwned
2293            | ExecuteResponse::GrantedPrivilege
2294            | ExecuteResponse::GrantedRole
2295            | ExecuteResponse::Inserted(..)
2296            | ExecuteResponse::Copied(..)
2297            | ExecuteResponse::Prepare
2298            | ExecuteResponse::Raised
2299            | ExecuteResponse::ReassignOwned
2300            | ExecuteResponse::RevokedPrivilege
2301            | ExecuteResponse::RevokedRole
2302            | ExecuteResponse::StartedTransaction { .. }
2303            | ExecuteResponse::Updated(..)
2304            | ExecuteResponse::ValidatedConnection => {
2305                command_complete!()
2306            }
2307        };
2308
2309        assert_none!(tag, "tag created but not consumed: {:?}", tag);
2310        r
2311    }
2312
    /// Streams rows of an in-progress portal result to the client as `DataRow`
    /// messages, sending at most `max_rows` rows.
    ///
    /// Rows left over from a partially consumed batch are stored back into the
    /// portal (`PortalState::InProgress`), so a subsequent execute/FETCH resumes
    /// where this call stopped. `timeout` controls how long to wait for more rows:
    /// - `ExecuteTimeout::None`: wait until the row stream ends or `max_rows` is
    ///   reached.
    /// - `ExecuteTimeout::Seconds(t)`: additionally stop once `t` seconds have
    ///   elapsed since this call started.
    /// - `ExecuteTimeout::WaitOnce`: block for the first non-empty batch, then
    ///   behave like a zero-second timeout.
    ///
    /// Notices that arrive while waiting are forwarded (and flushed) to the client.
    /// After the loop, the message built by `get_response` (from `max_rows`, the
    /// number of rows sent, and the FETCH portal, if any) is sent.
    ///
    /// If `fetch_portal_name` is set, this portal is being executed on behalf of a
    /// FETCH and the result formats are taken from that outer portal.
    ///
    /// When the row stream is exhausted, the first-to-last-row latency is observed
    /// in the `result_rows_first_to_last_byte_seconds` metric, at most once per
    /// stream.
    ///
    /// Returns `(State::Ready, SendRowsEndedReason::Success { .. })` on success. If
    /// a batch fails `verify_datum_desc`, or the stream reports an error or a
    /// cancellation, an `ErrorResponse` is sent and the resulting state is returned
    /// with an `Errored`/`Canceled` reason. Returns `Err` on I/O failure or when the
    /// connection closes while waiting for rows.
    ///
    /// # Panics
    ///
    /// Panics if `portal_name` (or `fetch_portal_name`, when set) does not name a
    /// portal in the session, or if the number of result formats differs from the
    /// number of columns in `row_desc` (`zip_eq`).
    #[allow(clippy::too_many_arguments)]
    // TODO(guswynn): figure out how to get it to compile without skip_all
    #[mz_ore::instrument(level = "debug")]
    async fn send_rows(
        &mut self,
        row_desc: RelationDesc,
        portal_name: String,
        mut rows: InProgressRows,
        max_rows: ExecuteCount,
        get_response: GetResponse,
        fetch_portal_name: Option<String>,
        timeout: ExecuteTimeout,
    ) -> Result<(State, SendRowsEndedReason), io::Error> {
        // If this portal is being executed from a FETCH then we need to use the result
        // format type of the outer portal.
        let result_format_portal_name: &str = if let Some(ref name) = fetch_portal_name {
            name
        } else {
            &portal_name
        };
        let result_formats = self
            .adapter_client
            .session()
            .get_portal_unverified(result_format_portal_name)
            .expect("valid fetch portal name for send rows")
            .result_formats
            .clone();

        // `deadline`: when to stop waiting for more rows (None = no deadline).
        // `wait_once`: arm a zero-length deadline after the first non-empty batch
        // (see the WaitOnce handling in the loop below).
        let (mut wait_once, mut deadline) = match timeout {
            ExecuteTimeout::None => (false, None),
            ExecuteTimeout::Seconds(t) => (
                false,
                Some(tokio::time::Instant::now() + tokio::time::Duration::from_secs_f64(t)),
            ),
            ExecuteTimeout::WaitOnce => (true, None),
        };

        // Sanity check that the various `RelationDesc`s match up.
        {
            let portal_name_desc = &self
                .adapter_client
                .session()
                .get_portal_unverified(portal_name.as_str())
                .expect("portal should exist")
                .desc
                .relation_desc;
            if let Some(portal_name_desc) = portal_name_desc {
                soft_assert_eq_or_log!(portal_name_desc, &row_desc);
            }
            if let Some(fetch_portal_name) = &fetch_portal_name {
                let fetch_portal_desc = &self
                    .adapter_client
                    .session()
                    .get_portal_unverified(fetch_portal_name)
                    .expect("portal should exist")
                    .desc
                    .relation_desc;
                if let Some(fetch_portal_desc) = fetch_portal_desc {
                    soft_assert_eq_or_log!(fetch_portal_desc, &row_desc);
                }
            }
        }

        // Hand the codec one (pg type, result format) pair per column; presumably
        // this drives how the `DataRow` values below are encoded.
        self.conn.set_encode_state(
            row_desc
                .typ()
                .column_types
                .iter()
                .map(|ty| mz_pgrepr::Type::from(&ty.scalar_type))
                .zip_eq(result_formats)
                .collect(),
        );

        let mut total_sent_rows = 0;
        let mut total_sent_bytes = 0;
        // want_rows is the maximum number of rows the client wants.
        // NOTE(review): `ExecuteCount::Count(0)` makes this loop send nothing (the
        // `want_rows == 0` checks below end it immediately); presumably callers map
        // the protocol's "0 = unlimited" to `ExecuteCount::All` — confirm.
        let mut want_rows = match max_rows {
            ExecuteCount::All => usize::MAX,
            ExecuteCount::Count(count) => count,
        };

        // Send rows while the client still wants them and there are still rows to send.
        loop {
            // Fetch next batch of rows, waiting for a possible requested
            // timeout or notice.
            // Rows left over from a previous call (stored in `rows.current`) are
            // always sent before anything new is received from the stream.
            let batch = if rows.current.is_some() {
                FetchResult::Rows(rows.current.take())
            } else if want_rows == 0 {
                FetchResult::Rows(None)
            } else {
                let notice_fut = self.adapter_client.session().recv_notice();
                // Biased: drain available data before checking the deadline.
                // This is critical for the WaitOnce case, where the deadline
                // is set to `Instant::now()` right after the first batch:
                // without `biased`, `recv()` and the already-expired deadline
                // race nondeterministically, so we might break the loop
                // before `no_more_rows` is set (or even before ready rows
                // are consumed). With an explicit `TIMEOUT`, missing a batch
                // right at the boundary is acceptable, but WaitOnce fires
                // immediately and the race is not.
                //
                // Trade-off: if `recv()` keeps returning Ready (unlikely in
                // practice—row processing + flush is slower than upstream
                // tick granularity), a `TIMEOUT` deadline could be delayed.
                // See database-issues#9470.
                tokio::select! {
                    biased;
                    err = self.conn.wait_closed() => return Err(err),
                    batch = rows.remaining.recv() => match batch {
                        None => FetchResult::Rows(None),
                        Some(PeekResponseUnary::Rows(rows)) => FetchResult::Rows(Some(rows)),
                        Some(PeekResponseUnary::Error(err)) => FetchResult::Error(err),
                        Some(PeekResponseUnary::DependencyDropped(dep)) => {
                            FetchResult::Error(dep.query_terminated_error())
                        }
                        Some(PeekResponseUnary::Canceled) => FetchResult::Canceled,
                    },
                    notice = notice_fut => {
                        FetchResult::Notice(notice)
                    }
                    _ = time::sleep_until(
                        deadline.unwrap_or_else(tokio::time::Instant::now),
                    ), if deadline.is_some() => FetchResult::Rows(None),
                }
            };

            match batch {
                // Stream exhausted, deadline reached, or no more rows wanted.
                FetchResult::Rows(None) => break,
                FetchResult::Rows(Some(mut batch_rows)) => {
                    if let Err(err) = verify_datum_desc(&row_desc, &mut batch_rows) {
                        let msg = err.to_string();
                        return self
                            .send_error_and_get_state(err.into_response(Severity::Error))
                            .await
                            .map(|state| (state, SendRowsEndedReason::Errored { error: msg }));
                    }

                    // If wait_once is true: the first time this fn is called it blocks (same as
                    // deadline == None). The second time this fn is called it should behave the
                    // same a 0s timeout.
                    if wait_once && batch_rows.peek().is_some() {
                        deadline = Some(tokio::time::Instant::now());
                        wait_once = false;
                    }

                    // Send a portion of the rows.
                    let mut sent_rows = 0;
                    let mut sent_bytes = 0;
                    // `take(want_rows)` caps this batch at what the client still
                    // wants; unsent rows remain in `batch_rows`.
                    let messages = (&mut batch_rows)
                        // TODO(parkmycar): This is a fair bit of juggling between iterator types
                        // to count the total number of bytes. Alternatively we could track the
                        // total sent bytes in this .map(...) call, but having side effects in map
                        // is a code smell.
                        .map(|row| {
                            let row_len = row.byte_len();
                            let values = mz_pgrepr::values_from_row(row, row_desc.typ());
                            (row_len, BackendMessage::DataRow(values))
                        })
                        .inspect(|(row_len, _)| {
                            sent_bytes += row_len;
                            sent_rows += 1
                        })
                        .map(|(_row_len, row)| row)
                        .take(want_rows);
                    self.send_all(messages).await?;

                    total_sent_rows += sent_rows;
                    total_sent_bytes += sent_bytes;
                    want_rows -= sent_rows;

                    // If we have sent the number of requested rows, put the remainder of the batch
                    // (if any) back and stop sending.
                    if want_rows == 0 {
                        if batch_rows.peek().is_some() {
                            rows.current = Some(batch_rows);
                        }
                        break;
                    }

                    // The client wants more rows; flush this batch before waiting
                    // for the next one.
                    self.conn.flush().await?;
                }
                FetchResult::Notice(notice) => {
                    self.send(notice.into_response()).await?;
                    self.conn.flush().await?;
                }
                FetchResult::Error(text) => {
                    return self
                        .send_error_and_get_state(ErrorResponse::error(
                            SqlState::INTERNAL_ERROR,
                            text.clone(),
                        ))
                        .await
                        .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
                }
                FetchResult::Canceled => {
                    return self
                        .send_error_and_get_state(ErrorResponse::error(
                            SqlState::QUERY_CANCELED,
                            "canceling statement due to user request",
                        ))
                        .await
                        .map(|state| (state, SendRowsEndedReason::Canceled));
                }
            }
        }

        let portal = self
            .adapter_client
            .session()
            .get_portal_unverified_mut(&portal_name)
            .expect("valid portal name for send rows");

        // Capture the stream's bookkeeping now, because `rows` is moved into the
        // portal below.
        let saw_rows = rows.remaining.saw_rows;
        let no_more_rows = rows.no_more_rows();
        let metric_recorded = rows.remaining.metric_recorded;
        let recorded_first_row_instant = rows.remaining.recorded_first_row_instant;

        // Set the flag on the stream itself, which is stored back in the portal,
        // so a later FETCH of this exhausted stream does not record the metric
        // again. The observation itself happens after the response is sent.
        if no_more_rows && !metric_recorded {
            rows.remaining.metric_recorded = true;
        }

        // Always return rows back, even if it's empty. This prevents an unclosed
        // portal from re-executing after it has been emptied.
        *portal.state = PortalState::InProgress(Some(rows));

        let fetch_portal = fetch_portal_name.map(|name| {
            self.adapter_client
                .session()
                .get_portal_unverified_mut(&name)
                .expect("valid fetch portal")
        });
        let response_message = get_response(max_rows, total_sent_rows, fetch_portal);
        self.send(response_message).await?;

        // Attend to metrics if there are no more rows. Only record once per stream
        // to avoid polluting the histogram when an exhausted cursor is FETCHed again.
        if no_more_rows && !metric_recorded {
            let statement_type = if let Some(stmt) = &self
                .adapter_client
                .session()
                .get_portal_unverified(&portal_name)
                .expect("valid portal name for send_rows")
                .stmt
            {
                metrics::statement_type_label_value(stmt.deref())
            } else {
                "no-statement"
            };
            let duration = if saw_rows {
                recorded_first_row_instant
                    .expect("recorded_first_row_instant because saw_rows")
                    .elapsed()
            } else {
                // If the result is empty, then we define time from first to last row as 0.
                // (Note that, currently, an empty result involves a PeekResponse with 0 rows, which
                // does flip `saw_rows`, so this code path is currently not exercised.)
                Duration::ZERO
            };
            self.adapter_client
                .inner()
                .metrics()
                .result_rows_first_to_last_byte_seconds
                .with_label_values(&[statement_type])
                .observe(duration.as_secs_f64());
        }

        Ok((
            State::Ready,
            SendRowsEndedReason::Success {
                result_size: u64::cast_from(total_sent_bytes),
                rows_returned: u64::cast_from(total_sent_rows),
            },
        ))
    }
2587
2588    #[mz_ore::instrument(level = "debug")]
2589    async fn copy_rows(
2590        &mut self,
2591        format: CopyFormat,
2592        row_desc: RelationDesc,
2593        mut stream: RecordFirstRowStream,
2594    ) -> Result<(State, SendRowsEndedReason), io::Error> {
2595        let (row_format, encode_format) = match format {
2596            CopyFormat::Text => (
2597                CopyFormatParams::Text(CopyTextFormatParams::default()),
2598                Format::Text,
2599            ),
2600            CopyFormat::Binary => (CopyFormatParams::Binary, Format::Binary),
2601            CopyFormat::Csv => (
2602                CopyFormatParams::Csv(CopyCsvFormatParams::default()),
2603                Format::Text,
2604            ),
2605            CopyFormat::Parquet => {
2606                let text = "Parquet format is not supported".to_string();
2607                return self
2608                    .send_error_and_get_state(ErrorResponse::error(
2609                        SqlState::INTERNAL_ERROR,
2610                        text.clone(),
2611                    ))
2612                    .await
2613                    .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
2614            }
2615        };
2616
2617        let encode_fn = |row: &RowRef, typ: &SqlRelationType, out: &mut Vec<u8>| {
2618            mz_pgcopy::encode_copy_format(&row_format, row, typ, out)
2619        };
2620
2621        let typ = row_desc.typ();
2622        let column_formats = iter::repeat(encode_format)
2623            .take(typ.column_types.len())
2624            .collect();
2625        self.send(BackendMessage::CopyOutResponse {
2626            overall_format: encode_format,
2627            column_formats,
2628        })
2629        .await?;
2630
2631        // In Postgres, binary copy has a header that is followed (in the same
2632        // CopyData) by the first row. In order to replicate their behavior, use a
2633        // common vec that we can extend one time now and then fill up with the encode
2634        // functions.
2635        let mut out = Vec::new();
2636
2637        if let CopyFormat::Binary = format {
2638            // 11-byte signature.
2639            out.extend(b"PGCOPY\n\xFF\r\n\0");
2640            // 32-bit flags field.
2641            out.extend([0, 0, 0, 0]);
2642            // 32-bit header extension length field.
2643            out.extend([0, 0, 0, 0]);
2644        }
2645
2646        let mut count = 0;
2647        let mut total_sent_bytes = 0;
2648        loop {
2649            tokio::select! {
2650                e = self.conn.wait_closed() => return Err(e),
2651                batch = stream.recv() => match batch {
2652                    None => break,
2653                    Some(PeekResponseUnary::Error(text)) => {
2654                        let err =
2655                            ErrorResponse::error(SqlState::INTERNAL_ERROR, text.clone());
2656                        return self
2657                            .send_error_and_get_state(err)
2658                            .await
2659                            .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
2660                    }
2661                    Some(PeekResponseUnary::DependencyDropped(dep)) => {
2662                        let err = dep.to_concurrent_dependency_drop();
2663                        let text = err.to_string();
2664                        let resp = err.into_response(Severity::Error);
2665                        return self
2666                            .send_error_and_get_state(resp)
2667                            .await
2668                            .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
2669                    }
2670                    Some(PeekResponseUnary::Canceled) => {
2671                        return self.send_error_and_get_state(ErrorResponse::error(
2672                                SqlState::QUERY_CANCELED,
2673                                "canceling statement due to user request",
2674                            ))
2675                            .await.map(|state| (state, SendRowsEndedReason::Canceled));
2676                    }
2677                    Some(PeekResponseUnary::Rows(mut rows)) => {
2678                        count += rows.count();
2679                        while let Some(row) = rows.next() {
2680                            total_sent_bytes += row.byte_len();
2681                            encode_fn(row, typ, &mut out)?;
2682                            self.send(BackendMessage::CopyData(mem::take(&mut out)))
2683                                .await?;
2684                        }
2685                    }
2686                },
2687                notice = self.adapter_client.session().recv_notice() => {
2688                    self.send(notice.into_response())
2689                        .await?;
2690                    self.conn.flush().await?;
2691                }
2692            }
2693
2694            self.conn.flush().await?;
2695        }
2696        // Send required trailers.
2697        if let CopyFormat::Binary = format {
2698            let trailer: i16 = -1;
2699            out.extend(trailer.to_be_bytes());
2700            self.send(BackendMessage::CopyData(mem::take(&mut out)))
2701                .await?;
2702        }
2703
2704        let tag = format!("COPY {}", count);
2705        self.send(BackendMessage::CopyDone).await?;
2706        self.send(BackendMessage::CommandComplete { tag }).await?;
2707        Ok((
2708            State::Ready,
2709            SendRowsEndedReason::Success {
2710                result_size: u64::cast_from(total_sent_bytes),
2711                rows_returned: u64::cast_from(count),
2712            },
2713        ))
2714    }
2715
2716    /// Handles the copy-in mode of the postgres protocol from transferring
2717    /// data to the server.
2718    #[instrument(level = "debug")]
2719    async fn copy_from(
2720        &mut self,
2721        target_id: CatalogItemId,
2722        target_name: String,
2723        columns: Vec<ColumnIndex>,
2724        params: CopyFormatParams<'static>,
2725        row_desc: RelationDesc,
2726        mut ctx_extra: ExecuteContextGuard,
2727    ) -> Result<State, io::Error> {
2728        let res = self
2729            .copy_from_inner(
2730                target_id,
2731                target_name,
2732                columns,
2733                params,
2734                row_desc,
2735                &mut ctx_extra,
2736            )
2737            .await;
2738        match &res {
2739            Ok(State::Ready) => {
2740                self.adapter_client.retire_execute(
2741                    ctx_extra,
2742                    StatementEndedExecutionReason::Success {
2743                        result_size: None,
2744                        rows_returned: None,
2745                        execution_strategy: None,
2746                    },
2747                );
2748            }
2749            Ok(State::Done) => {
2750                // The connection closed gracefully without sending us a `CopyDone`,
2751                // causing us to just drop the copy request.
2752                // For the purposes of statement logging, we count this as a cancellation.
2753                self.adapter_client
2754                    .retire_execute(ctx_extra, StatementEndedExecutionReason::Canceled);
2755            }
2756            Err(e) => {
2757                self.adapter_client.retire_execute(
2758                    ctx_extra,
2759                    StatementEndedExecutionReason::Errored {
2760                        error: format!("{e}"),
2761                    },
2762                );
2763            }
2764            Ok(State::Drain) => {}
2765        }
2766        res
2767    }
2768
2769    async fn copy_from_inner(
2770        &mut self,
2771        target_id: CatalogItemId,
2772        target_name: String,
2773        columns: Vec<ColumnIndex>,
2774        params: CopyFormatParams<'static>,
2775        row_desc: RelationDesc,
2776        ctx_extra: &mut ExecuteContextGuard,
2777    ) -> Result<State, io::Error> {
2778        let typ = row_desc.typ();
2779        let column_formats = vec![Format::Text; typ.column_types.len()];
2780        self.send(BackendMessage::CopyInResponse {
2781            overall_format: Format::Text,
2782            column_formats,
2783        })
2784        .await?;
2785        self.conn.flush().await?;
2786
2787        // Set up the parallel streaming batch builders in the coordinator.
2788        let writer = match self
2789            .adapter_client
2790            .start_copy_from_stdin(
2791                target_id,
2792                target_name.clone(),
2793                columns.clone(),
2794                row_desc.clone(),
2795                params.clone(),
2796            )
2797            .await
2798        {
2799            Ok(writer) => writer,
2800            Err(e) => {
2801                // Drain remaining CopyData/CopyDone/CopyFail messages from the
2802                // socket. Since CopyInResponse was already sent, the client may
2803                // have pipelined copy data that we must consume before returning
2804                // the error, otherwise they'd be misinterpreted as top-level
2805                // protocol messages and cause a deadlock.
2806                loop {
2807                    match self.conn.recv().await? {
2808                        Some(FrontendMessage::CopyData(_)) => {}
2809                        Some(FrontendMessage::CopyDone) | Some(FrontendMessage::CopyFail(_)) => {
2810                            break;
2811                        }
2812                        Some(FrontendMessage::Flush) | Some(FrontendMessage::Sync) => {}
2813                        Some(_) => break,
2814                        None => return Ok(State::Done),
2815                    }
2816                }
2817                self.adapter_client.retire_execute(
2818                    std::mem::take(ctx_extra),
2819                    StatementEndedExecutionReason::Errored {
2820                        error: e.to_string(),
2821                    },
2822                );
2823                return self
2824                    .send_error_and_get_state(e.into_response(Severity::Error))
2825                    .await;
2826            }
2827        };
2828
2829        // Enable copy mode on the codec to skip aggregate buffer size checks.
2830        self.conn.set_copy_mode(true);
2831
2832        // Batch size for splitting raw data across parallel workers (~32MB).
2833        const BATCH_SIZE: usize = 32 * 1024 * 1024;
2834        let max_copy_from_row_size = self
2835            .adapter_client
2836            .get_system_vars()
2837            .await
2838            .max_copy_from_row_size()
2839            .try_into()
2840            .unwrap_or(usize::MAX);
2841
2842        let mut data = Vec::new();
2843        let mut row_scanner = CopyRowScanner::new(&params);
2844        let num_workers = writer.batch_txs.len();
2845        let mut next_worker: usize = 0;
2846        let mut saw_copy_done = false;
2847        let mut saw_end_marker = false;
2848        let mut copy_from_error: Option<(SqlState, String)> = None;
2849
2850        // Receive loop: accumulate CopyData, split at row boundaries,
2851        // round-robin raw chunks to parallel batch builder workers.
2852        loop {
2853            let message = self.conn.recv().await?;
2854            match message {
2855                Some(FrontendMessage::CopyData(buf)) => {
2856                    if saw_end_marker {
2857                        // Per PostgreSQL COPY behavior, ignore all bytes after
2858                        // the end-of-copy marker until CopyDone.
2859                        continue;
2860                    }
2861                    data.extend(buf);
2862                    row_scanner.scan_new_bytes(&data);
2863
2864                    if let Some(end_pos) = row_scanner.end_marker_end() {
2865                        data.truncate(end_pos);
2866                        row_scanner.on_truncate(end_pos);
2867                        saw_end_marker = true;
2868                    }
2869
2870                    // Guard against pathological single rows that never terminate.
2871                    if row_scanner.current_row_size(data.len()) > max_copy_from_row_size {
2872                        copy_from_error = Some((
2873                            SqlState::INSUFFICIENT_RESOURCES,
2874                            format!(
2875                                "COPY FROM STDIN row exceeded max_copy_from_row_size \
2876                                 ({max_copy_from_row_size} bytes)"
2877                            ),
2878                        ));
2879                        break;
2880                    }
2881
2882                    // When buffer exceeds batch size, split at the last complete row
2883                    // and send the complete rows chunk to the next worker.
2884                    let mut send_failed = false;
2885                    while data.len() >= BATCH_SIZE {
2886                        let split_pos = match row_scanner.last_row_end() {
2887                            Some(pos) => pos,
2888                            None => break, // no complete row yet
2889                        };
2890                        let remainder = data.split_off(split_pos);
2891                        let chunk = std::mem::replace(&mut data, remainder);
2892                        row_scanner.on_split(split_pos);
2893                        if writer.batch_txs[next_worker].send(chunk).await.is_err() {
2894                            send_failed = true;
2895                            break;
2896                        }
2897                        next_worker = (next_worker + 1) % num_workers;
2898                    }
2899                    // Worker dropped (likely errored) — stop sending,
2900                    // fall through to completion_rx for the real error.
2901                    if send_failed {
2902                        break;
2903                    }
2904                }
2905                Some(FrontendMessage::CopyDone) => {
2906                    // Send any remaining data to the next worker.
2907                    if !data.is_empty() {
2908                        let chunk = std::mem::take(&mut data);
2909                        // Ignore send failure — completion_rx will have the error.
2910                        let _ = writer.batch_txs[next_worker].send(chunk).await;
2911                    }
2912                    saw_copy_done = true;
2913                    break;
2914                }
2915                Some(FrontendMessage::CopyFail(err)) => {
2916                    self.adapter_client.retire_execute(
2917                        std::mem::take(ctx_extra),
2918                        StatementEndedExecutionReason::Canceled,
2919                    );
2920                    // Drop the writer to signal cancellation to the background tasks.
2921                    drop(writer);
2922                    self.conn.set_copy_mode(false);
2923                    return self
2924                        .send_error_and_get_state(ErrorResponse::error(
2925                            SqlState::QUERY_CANCELED,
2926                            format!("COPY from stdin failed: {}", err),
2927                        ))
2928                        .await;
2929                }
2930                Some(FrontendMessage::Flush) | Some(FrontendMessage::Sync) => {}
2931                Some(_) => {
2932                    let msg = "unexpected message type during COPY from stdin";
2933                    self.adapter_client.retire_execute(
2934                        std::mem::take(ctx_extra),
2935                        StatementEndedExecutionReason::Errored {
2936                            error: msg.to_string(),
2937                        },
2938                    );
2939                    drop(writer);
2940                    self.conn.set_copy_mode(false);
2941                    return self
2942                        .send_error_and_get_state(ErrorResponse::error(
2943                            SqlState::PROTOCOL_VIOLATION,
2944                            msg,
2945                        ))
2946                        .await;
2947                }
2948                None => {
2949                    drop(writer);
2950                    self.conn.set_copy_mode(false);
2951                    return Ok(State::Done);
2952                }
2953            }
2954        }
2955
2956        // If we exited the receive loop before seeing `CopyDone` (e.g. because
2957        // a worker failed and dropped its channel), keep draining COPY input to
2958        // avoid desynchronizing the protocol state machine.
2959        if !saw_copy_done {
2960            loop {
2961                match self.conn.recv().await? {
2962                    Some(FrontendMessage::CopyData(_)) => {}
2963                    Some(FrontendMessage::CopyDone) | Some(FrontendMessage::CopyFail(_)) => {
2964                        break;
2965                    }
2966                    Some(FrontendMessage::Flush) | Some(FrontendMessage::Sync) => {}
2967                    Some(_) => {
2968                        let msg = "unexpected message type during COPY from stdin";
2969                        self.adapter_client.retire_execute(
2970                            std::mem::take(ctx_extra),
2971                            StatementEndedExecutionReason::Errored {
2972                                error: msg.to_string(),
2973                            },
2974                        );
2975                        drop(writer);
2976                        self.conn.set_copy_mode(false);
2977                        return self
2978                            .send_error_and_get_state(ErrorResponse::error(
2979                                SqlState::PROTOCOL_VIOLATION,
2980                                msg,
2981                            ))
2982                            .await;
2983                    }
2984                    None => {
2985                        drop(writer);
2986                        self.conn.set_copy_mode(false);
2987                        return Ok(State::Done);
2988                    }
2989                }
2990            }
2991        }
2992
2993        if let Some((code, msg)) = copy_from_error {
2994            self.adapter_client.retire_execute(
2995                std::mem::take(ctx_extra),
2996                StatementEndedExecutionReason::Errored { error: msg.clone() },
2997            );
2998            drop(writer);
2999            self.conn.set_copy_mode(false);
3000            return self
3001                .send_error_and_get_state(ErrorResponse::error(code, msg))
3002                .await;
3003        }
3004
3005        self.conn.set_copy_mode(false);
3006
3007        // Drop all senders to signal EOF to the background batch builders.
3008        // If copy_err is set, a worker already failed — dropping the senders
3009        // will cause remaining workers to stop, and we'll get the real error
3010        // from completion_rx below.
3011        drop(writer.batch_txs);
3012
3013        // Wait for all parallel workers to finish building batches.
3014        let (proto_batches, row_count) = match writer.completion_rx.await {
3015            Ok(Ok(result)) => result,
3016            Ok(Err(e)) => {
3017                self.adapter_client.retire_execute(
3018                    std::mem::take(ctx_extra),
3019                    StatementEndedExecutionReason::Errored {
3020                        error: e.to_string(),
3021                    },
3022                );
3023                return self
3024                    .send_error_and_get_state(e.into_response(Severity::Error))
3025                    .await;
3026            }
3027            Err(_) => {
3028                let msg = "COPY FROM STDIN: background batch builder tasks dropped";
3029                self.adapter_client.retire_execute(
3030                    std::mem::take(ctx_extra),
3031                    StatementEndedExecutionReason::Errored {
3032                        error: msg.to_string(),
3033                    },
3034                );
3035                return self
3036                    .send_error_and_get_state(ErrorResponse::error(SqlState::INTERNAL_ERROR, msg))
3037                    .await;
3038            }
3039        };
3040
3041        // Stage all batches in the session's transaction for atomic commit.
3042        if let Err(e) = self
3043            .adapter_client
3044            .stage_copy_from_stdin_batches(target_id, proto_batches)
3045        {
3046            self.adapter_client.retire_execute(
3047                std::mem::take(ctx_extra),
3048                StatementEndedExecutionReason::Errored {
3049                    error: e.to_string(),
3050                },
3051            );
3052            return self
3053                .send_error_and_get_state(e.into_response(Severity::Error))
3054                .await;
3055        }
3056
3057        let tag = format!("COPY {}", row_count);
3058        self.send(BackendMessage::CommandComplete { tag }).await?;
3059
3060        Ok(State::Ready)
3061    }
3062
3063    #[instrument(level = "debug")]
3064    async fn send_pending_notices(&mut self) -> Result<(), io::Error> {
3065        let notices = self
3066            .adapter_client
3067            .session()
3068            .drain_notices()
3069            .into_iter()
3070            .map(|notice| BackendMessage::ErrorResponse(notice.into_response()));
3071        self.send_all(notices).await?;
3072        Ok(())
3073    }
3074
3075    #[instrument(level = "debug")]
3076    async fn send_error_and_get_state(&mut self, err: ErrorResponse) -> Result<State, io::Error> {
3077        assert!(err.severity.is_error());
3078        debug!(
3079            "cid={} error code={}",
3080            self.adapter_client.session().conn_id(),
3081            err.code.code()
3082        );
3083        let is_fatal = err.severity.is_fatal();
3084        self.send(BackendMessage::ErrorResponse(err)).await?;
3085
3086        let txn = self.adapter_client.session().transaction();
3087        match txn {
3088            // Error can be called from describe and parse and so might not be in an active
3089            // transaction.
3090            TransactionStatus::Default | TransactionStatus::Failed(_) => {}
3091            // In Started (i.e., a single statement), cleanup ourselves.
3092            TransactionStatus::Started(_) => {
3093                self.rollback_transaction().await?;
3094            }
3095            // Implicit transactions also clear themselves.
3096            TransactionStatus::InTransactionImplicit(_) => {
3097                self.rollback_transaction().await?;
3098            }
3099            // Explicit transactions move to failed.
3100            TransactionStatus::InTransaction(_) => {
3101                self.adapter_client.fail_transaction();
3102            }
3103        };
3104        if is_fatal {
3105            Ok(State::Done)
3106        } else {
3107            Ok(State::Drain)
3108        }
3109    }
3110
3111    #[instrument(level = "debug")]
3112    async fn aborted_txn_error(&mut self) -> Result<State, io::Error> {
3113        self.send(BackendMessage::ErrorResponse(ErrorResponse::error(
3114            SqlState::IN_FAILED_SQL_TRANSACTION,
3115            ABORTED_TXN_MSG,
3116        )))
3117        .await?;
3118        Ok(State::Drain)
3119    }
3120
3121    fn is_aborted_txn(&mut self) -> bool {
3122        matches!(
3123            self.adapter_client.session().transaction(),
3124            TransactionStatus::Failed(_)
3125        )
3126    }
3127}
3128
3129fn pad_formats(formats: Vec<Format>, n: usize) -> Result<Vec<Format>, String> {
3130    match (formats.len(), n) {
3131        (0, e) => Ok(vec![Format::Text; e]),
3132        (1, e) => Ok(iter::repeat(formats[0]).take(e).collect()),
3133        (a, e) if a == e => Ok(formats),
3134        (a, e) => Err(format!(
3135            "expected {} field format specifiers, but got {}",
3136            e, a
3137        )),
3138    }
3139}
3140
3141fn describe_rows(stmt_desc: &StatementDesc, formats: &[Format]) -> BackendMessage {
3142    match &stmt_desc.relation_desc {
3143        Some(desc) if !stmt_desc.is_copy => {
3144            BackendMessage::RowDescription(message::encode_row_description(desc, formats))
3145        }
3146        _ => BackendMessage::NoData,
3147    }
3148}
3149
/// Callback used by `send_rows` to build the message that follows the rows it
/// sent. It receives:
/// - the requested row limit (`max_rows`);
/// - the number of rows actually sent;
/// - for `FETCH`, the portal being fetched from.
///
/// Implementations: `portal_exec_message` and `fetch_message`.
type GetResponse = fn(
    max_rows: ExecuteCount,
    total_sent_rows: usize,
    fetch_portal: Option<PortalRefMut>,
) -> BackendMessage;
3155
3156// A GetResponse used by send_rows during execute messages on portals or for
3157// simple query messages.
3158fn portal_exec_message(
3159    max_rows: ExecuteCount,
3160    total_sent_rows: usize,
3161    _fetch_portal: Option<PortalRefMut>,
3162) -> BackendMessage {
3163    // If max_rows is not specified, we will always send back a CommandComplete. If
3164    // max_rows is specified, we only send CommandComplete if there were more rows
3165    // requested than were remaining. That is, if max_rows == number of rows that
3166    // were remaining before sending (not that are remaining after sending), then
3167    // we still send a PortalSuspended. The number of remaining rows after the rows
3168    // have been sent doesn't matter. This matches postgres.
3169    match max_rows {
3170        ExecuteCount::Count(max_rows) if max_rows <= total_sent_rows => {
3171            BackendMessage::PortalSuspended
3172        }
3173        _ => BackendMessage::CommandComplete {
3174            tag: format!("SELECT {}", total_sent_rows),
3175        },
3176    }
3177}
3178
3179// A GetResponse used by send_rows during FETCH queries.
3180fn fetch_message(
3181    _max_rows: ExecuteCount,
3182    total_sent_rows: usize,
3183    fetch_portal: Option<PortalRefMut>,
3184) -> BackendMessage {
3185    let tag = format!("FETCH {}", total_sent_rows);
3186    if let Some(portal) = fetch_portal {
3187        *portal.state = PortalState::Completed(Some(tag.clone()));
3188    }
3189    BackendMessage::CommandComplete { tag }
3190}
3191
3192fn get_authenticator(
3193    authenticator_kind: listeners::AuthenticatorKind,
3194    frontegg: Option<FronteggAuthenticator>,
3195    oidc: GenericOidcAuthenticator,
3196    adapter_client: mz_adapter::Client,
3197) -> Authenticator {
3198    match authenticator_kind {
3199        listeners::AuthenticatorKind::Frontegg => Authenticator::Frontegg(frontegg.expect(
3200            "Frontegg authenticator should exist with listeners::AuthenticatorKind::Frontegg",
3201        )),
3202        listeners::AuthenticatorKind::Password => Authenticator::Password(adapter_client),
3203        listeners::AuthenticatorKind::Sasl => Authenticator::Sasl(adapter_client),
3204        listeners::AuthenticatorKind::Oidc => Authenticator::Oidc(oidc),
3205        listeners::AuthenticatorKind::None => Authenticator::None,
3206    }
3207}
3208
/// Row limit requested when sending a portal's results.
///
/// - `All`: no limit was requested.
/// - `Count(n)`: at most `n` rows are requested. This is presumably enforced
///   by `send_rows`; confirm there. `portal_exec_message` compares `n` against
///   the number of rows sent to choose between `PortalSuspended` and
///   `CommandComplete`.
#[derive(Debug, Copy, Clone)]
enum ExecuteCount {
    All,
    Count(usize),
}
3214
3215// See postgres' backend/tcop/postgres.c IsTransactionExitStmt.
3216fn is_txn_exit_stmt(stmt: Option<&Statement<Raw>>) -> bool {
3217    match stmt {
3218        // Add PREPARE to this if we ever support it.
3219        Some(stmt) => matches!(stmt, Statement::Commit(_) | Statement::Rollback(_)),
3220        None => false,
3221    }
3222}
3223
/// Outcome of a fetch of rows.
///
/// NOTE(review): the code that produces and consumes this type is outside this
/// view. The variant descriptions below are inferred from names and payloads;
/// confirm them against the call sites.
#[derive(Debug)]
enum FetchResult {
    // Rows to send. Presumably `None` means no further rows are available.
    Rows(Option<Box<dyn RowIterator + Send + Sync>>),
    // The fetch was canceled.
    Canceled,
    // The fetch failed; carries the error text.
    Error(String),
    // A notice to relay to the client.
    Notice(AdapterNotice),
}
3231
/// Incrementally scans a growing `COPY FROM STDIN` input buffer for two
/// things:
/// - complete-row boundaries, so the buffer can be split and the complete
///   rows sent to batch-building workers;
/// - the end-of-copy marker (`\.`).
///
/// All offsets are byte positions into the caller's buffer. They are rebased
/// by `on_split` and `on_truncate` when the caller shrinks that buffer.
#[derive(Debug)]
struct CopyRowScanner {
    // Number of leading buffer bytes already scanned.
    scan_pos: usize,
    // Offset just past the last complete row seen since the last split.
    last_row_end: Option<usize>,
    // Offset just past the end-of-copy marker row, once one has been found.
    end_marker_end: Option<usize>,
    // Byte offset within `data` at which the in-progress CSV record begins.
    // Used to verify the end-of-copy marker against the raw input bytes,
    // distinguishing a literal `\.` line from a quoted CSV value `"\."`
    // whose decoded form is also `\.`.
    record_start: usize,
    // csv-core scanning state when the format is CSV. `None` for every other
    // format; those are scanned for `\n`-terminated rows instead.
    csv: Option<CsvScanState>,
}
3244
/// A csv-core reader plus the scratch buffers it decodes into while
/// `CopyRowScanner` scans CSV input.
///
/// Only record boundaries are used; the decoded field bytes and field-end
/// offsets written to the scratch buffers are never read.
#[derive(Debug)]
struct CsvScanState {
    reader: csv_core::Reader,
    // Scratch space for decoded field bytes (contents discarded).
    output: Vec<u8>,
    // Scratch space for field-end offsets (contents discarded).
    ends: Vec<usize>,
    // Set when the COPY has a header row. The first record is then consumed
    // without being checked as the end-of-copy marker.
    skip_first_record: bool,
}
3252
impl CopyRowScanner {
    /// Creates a scanner for a COPY in the given format.
    ///
    /// Only CSV gets a csv-core state machine, configured from the COPY's
    /// delimiter, quote, escape and header options. Every other format is
    /// scanned for `\n`-terminated rows.
    fn new(params: &CopyFormatParams<'_>) -> Self {
        let csv = match params {
            CopyFormatParams::Csv(CopyCsvFormatParams {
                delimiter,
                quote,
                escape,
                header,
                ..
            }) => Some(CsvScanState::new(*delimiter, *quote, *escape, *header)),
            _ => None,
        };

        CopyRowScanner {
            scan_pos: 0,
            last_row_end: None,
            end_marker_end: None,
            record_start: 0,
            csv,
        }
    }

    /// Scans the bytes of `data` past `scan_pos`, i.e. those appended since
    /// the previous call. Updates `last_row_end`, and sets `end_marker_end`
    /// the first time the end-of-copy marker row is seen.
    ///
    /// NOTE(review): assumes `data` is the same buffer as on earlier calls. It
    /// should only grow at the end, except for shrinks reported through
    /// `on_split` or `on_truncate`.
    fn scan_new_bytes(&mut self, data: &[u8]) {
        // Nothing new to scan.
        if self.scan_pos >= data.len() {
            return;
        }

        if let Some(csv) = self.csv.as_mut() {
            // The csv-core reader keeps its parse state between calls, so a
            // record may straddle several calls. Only the unscanned suffix is
            // fed to it. `consumed` counts bytes of that suffix read so far,
            // making `scan_pos + consumed` an absolute offset into `data`.
            let mut input = &data[self.scan_pos..];
            let mut consumed = 0usize;
            while !input.is_empty() {
                // Decoded field bytes and field ends are discarded because only
                // record boundaries matter. Both scratch buffers are therefore
                // passed from their start on every call.
                let (result, n_input, _n_output, _n_ends) =
                    csv.reader
                        .read_record(input, &mut csv.output, &mut csv.ends);
                consumed += n_input;
                input = &input[n_input..];

                match result {
                    ReadRecordResult::InputEmpty => break,
                    // A full scratch buffer ends the call early, and the loop
                    // simply calls again. The buffer grows only when the call
                    // consumed no input, i.e. it was too small for the reader
                    // to make any progress.
                    ReadRecordResult::OutputFull => {
                        if n_input == 0 {
                            csv.output
                                .resize(csv.output.len().saturating_mul(2).max(1), 0);
                        }
                    }
                    ReadRecordResult::OutputEndsFull => {
                        if n_input == 0 {
                            csv.ends.resize(csv.ends.len().saturating_mul(2).max(1), 0);
                        }
                    }
                    ReadRecordResult::Record | ReadRecordResult::End => {
                        // `data[record_start..row_end]` is the raw span of the
                        // record just completed.
                        let row_end = self.scan_pos + consumed;
                        self.last_row_end = Some(row_end);
                        if self.end_marker_end.is_none() {
                            // A header record is consumed once and never
                            // treated as the marker.
                            let is_marker = if csv.skip_first_record {
                                csv.skip_first_record = false;
                                false
                            } else {
                                // Detect the marker against the raw input
                                // bytes, not the CSV-decoded record. A quoted
                                // data row `"\."` decodes to `\.` but must be
                                // imported as data; only a bare `\.` line
                                // terminates the COPY.
                                let raw = &data[self.record_start..row_end];
                                // csv-core ends a CRLF record after the `\r`,
                                // leaving the trailing `\n` as the leading byte
                                // of the next record's span; a CR-only record
                                // ends in a lone `\r`. So a `\.` marker record's
                                // raw span can be `\.\n` (LF), `\n\.\r` (CRLF)
                                // or `\.\r` (CR). Trim CR/LF from both ends
                                // before comparing — a trailing-only strip would
                                // miss the CRLF/CR forms. Quoted `"\."` data
                                // keeps its surrounding quotes after trimming and
                                // is therefore correctly rejected.
                                let start = raw
                                    .iter()
                                    .take_while(|&&b| b == b'\r' || b == b'\n')
                                    .count();
                                let trailing = raw[start..]
                                    .iter()
                                    .rev()
                                    .take_while(|&&b| b == b'\r' || b == b'\n')
                                    .count();
                                let trimmed = &raw[start..raw.len() - trailing];
                                trimmed == b"\\."
                            };
                            if is_marker {
                                self.end_marker_end = Some(row_end);
                                self.record_start = row_end;
                                // Stop at the marker. Bytes after it are left
                                // unread by the csv-core reader.
                                break;
                            }
                        }
                        self.record_start = row_end;
                    }
                }
            }
        } else {
            // Non-CSV formats: rows are `\n`-terminated. `row_start` is where
            // the row ending at the next `\n` began.
            let mut row_start = self.last_row_end.unwrap_or(0);
            for (offset, b) in data[self.scan_pos..].iter().enumerate() {
                if *b == b'\n' {
                    let row_end = self.scan_pos + offset + 1;
                    self.last_row_end = Some(row_end);
                    if self.end_marker_end.is_none() {
                        // Any row whose first two bytes are `\.` counts as the
                        // marker; the rest of the row is not checked here.
                        let row = &data[row_start..row_end];
                        if row.get(0..2) == Some(b"\\.") {
                            self.end_marker_end = Some(row_end);
                            break;
                        }
                    }
                    row_start = row_end;
                }
            }
        }

        // Mark the whole buffer as scanned, including any bytes after a marker
        // found above.
        self.scan_pos = data.len();
    }

    /// Offset just past the last complete row scanned since the last split.
    fn last_row_end(&self) -> Option<usize> {
        self.last_row_end
    }

    /// Offset just past the end-of-copy marker row, once one has been found.
    fn end_marker_end(&self) -> Option<usize> {
        self.end_marker_end
    }

    /// Length of the trailing incomplete row in a buffer of `data_len` bytes.
    /// This is the whole buffer when no complete row has been seen since the
    /// last split.
    fn current_row_size(&self, data_len: usize) -> usize {
        data_len.saturating_sub(self.last_row_end.unwrap_or(0))
    }

    /// Rebases every offset after the caller removed the first `split_pos`
    /// bytes of the buffer. The COPY loop splits at `last_row_end`, so no
    /// complete row remains and `last_row_end` is cleared.
    fn on_split(&mut self, split_pos: usize) {
        self.scan_pos = self.scan_pos.saturating_sub(split_pos);
        self.last_row_end = None;
        self.end_marker_end = self
            .end_marker_end
            .and_then(|end| end.checked_sub(split_pos));
        // `record_start` is only maintained for the CSV path; the text and
        // binary paths leave it at 0. For CSV, splits always occur at a
        // completed-row boundary, so the in-progress record (if any) starts at
        // the new beginning of the buffer. Assert that invariant so the
        // `saturating_sub` below doesn't silently paper over a bug that
        // bisected an in-progress record — but only when CSV is in use, since
        // otherwise `record_start` is meaninglessly 0.
        soft_assert_or_log!(
            self.csv.is_none() || self.record_start >= split_pos,
            "split bisected an in-progress CSV record: record_start={} < split_pos={}",
            self.record_start,
            split_pos,
        );
        self.record_start = self.record_start.saturating_sub(split_pos);
    }

    /// Clamps every offset after the caller truncated the buffer to `new_len`
    /// bytes. Row and marker ends that no longer fall inside it are dropped.
    fn on_truncate(&mut self, new_len: usize) {
        self.scan_pos = self.scan_pos.min(new_len);
        self.last_row_end = self.last_row_end.filter(|&end| end <= new_len);
        self.end_marker_end = self.end_marker_end.filter(|&end| end <= new_len);
        self.record_start = self.record_start.min(new_len);
    }
}
3411
3412impl CsvScanState {
3413    fn new(delimiter: u8, quote: u8, escape: u8, header: bool) -> Self {
3414        let (double_quote, escape) = if quote == escape {
3415            (true, None)
3416        } else {
3417            (false, Some(escape))
3418        };
3419        CsvScanState {
3420            reader: csv_core::ReaderBuilder::new()
3421                .delimiter(delimiter)
3422                .quote(quote)
3423                .double_quote(double_quote)
3424                .escape(escape)
3425                .build(),
3426            output: vec![0; 1],
3427            ends: vec![0; 1],
3428            skip_first_record: header,
3429        }
3430    }
3431}
3432
#[cfg(test)]
mod test {
    use super::*;

    #[mz_ore::test]
    fn test_copy_row_scanner_end_marker_line_endings() {
        // The pgwire COPY row scanner must detect a bare `\.` end-of-copy
        // marker for every line ending, and must never mistake a quoted
        // `"\."` data row for it. csv-core ends a CRLF record after the `\r`
        // (leaving the `\n` as the next record's leading byte), so the raw
        // record span of a `\.` marker is `\.\n` (LF), `\n\.\r` (CRLF) or
        // `\.\r` (CR); a trailing-only strip would miss the CRLF/CR forms and
        // silently import post-marker rows.
        let params = CopyFormatParams::Csv(CopyCsvFormatParams::default());

        let marker_end = |data: &[u8]| -> Option<usize> {
            let mut scanner = CopyRowScanner::new(&params);
            scanner.scan_new_bytes(data);
            scanner.end_marker_end()
        };

        for eol in [&b"\n"[..], b"\r\n", b"\r"] {
            let join = |lines: &[&str]| -> Vec<u8> {
                let mut out = Vec::new();
                for line in lines {
                    out.extend_from_slice(line.as_bytes());
                    out.extend_from_slice(eol);
                }
                out
            };

            // Bare `\.` as the second record. `record_start` has already moved
            // past `first`'s record. For CRLF, the orphaned `\n` becomes the
            // leading byte of the marker's span and is trimmed. csv-core
            // reports the record after a single terminator byte, so the marker
            // boundary sits just past `first<eol>\.` + one byte.
            let data = join(&["first", "\\.", "after"]);
            let mut prefix = Vec::new();
            prefix.extend_from_slice(b"first");
            prefix.extend_from_slice(eol);
            prefix.extend_from_slice(b"\\.");
            assert_eq!(
                marker_end(&data),
                Some(prefix.len() + 1),
                "bare marker, eol={eol:?}"
            );

            // Quoted "\." is data, not the marker.
            let data = join(&["before", "\"\\.\"", "after"]);
            assert_eq!(marker_end(&data), None, "quoted marker, eol={eol:?}");
        }
    }

    #[mz_ore::test]
    fn test_copy_row_scanner_non_csv_split() {
        // Regression: `record_start` is only maintained for the CSV path; the
        // text and binary paths leave it at 0. `on_split` must therefore not
        // assert `record_start >= split_pos` for those formats — that fires on
        // every split of a large text/binary COPY stream (soft-assertions
        // panic under test). Mirrors `COPY ... FROM STDIN` (default text
        // format) splitting at a row boundary once the buffer fills.
        for params in [
            CopyFormatParams::Text(CopyTextFormatParams::default()),
            CopyFormatParams::Binary,
        ] {
            let mut scanner = CopyRowScanner::new(&params);
            let data = b"1\thello world\t2\tsome text value here\n\
                         3\thello world\t6\tsome text value here\n";
            scanner.scan_new_bytes(data);
            let split_pos = scanner.last_row_end().expect("a complete row");
            assert!(split_pos > 0, "params={params:?}");
            // Must not panic via the CSV-only `on_split` soft-assert.
            scanner.on_split(split_pos);
            assert_eq!(scanner.record_start, 0, "params={params:?}");
        }
    }

    #[mz_ore::test]
    fn test_copy_row_scanner_chunked_csv() {
        // In production the scanner sees the input one `CopyData` message at
        // a time, and the buffer is split at row boundaries as it fills.
        // Marker detection must depend on neither.
        let params = CopyFormatParams::Csv(CopyCsvFormatParams::default());

        for eol in [&b"\n"[..], b"\r\n", b"\r"] {
            let join = |lines: &[&str]| -> Vec<u8> {
                let mut out = Vec::new();
                for line in lines {
                    out.extend_from_slice(line.as_bytes());
                    out.extend_from_slice(eol);
                }
                out
            };

            // Feeding one byte at a time finds the same marker as scanning in
            // one go, and still skips the quoted `"\."` data row.
            let data = join(&["a,b", "\"\\.\"", "c,d", "\\.", "e,f"]);
            let mut whole = CopyRowScanner::new(&params);
            whole.scan_new_bytes(&data);
            let expected = whole.end_marker_end();
            assert!(expected.is_some(), "whole buffer, eol={eol:?}");
            let mut bytewise = CopyRowScanner::new(&params);
            for end in 1..=data.len() {
                bytewise.scan_new_bytes(&data[..end]);
            }
            assert_eq!(
                bytewise.end_marker_end(),
                expected,
                "byte-at-a-time, eol={eol:?}"
            );

            // Split at the last complete row, as the COPY loop does, then
            // append more input. Offsets must be relative to the new buffer.
            let mut scanner = CopyRowScanner::new(&params);
            let mut buf = join(&["a,b", "c,d"]);
            scanner.scan_new_bytes(&buf);
            let split_pos = scanner.last_row_end().expect("a complete row");
            buf = buf.split_off(split_pos);
            scanner.on_split(split_pos);
            buf.extend_from_slice(&join(&["\\.", "e,f"]));
            scanner.scan_new_bytes(&buf);
            // csv-core reports the marker record after the single terminator
            // byte that follows `\.`.
            let marker_at = buf
                .windows(2)
                .position(|w| w == b"\\.")
                .expect("marker present");
            assert_eq!(
                scanner.end_marker_end(),
                Some(marker_at + 3),
                "after split, eol={eol:?}"
            );
        }
    }

    #[mz_ore::test]
    fn test_parse_options() {
        struct TestCase {
            input: &'static str,
            expect: Result<Vec<(&'static str, &'static str)>, ()>,
        }
        let tests = vec![
            TestCase {
                input: "",
                expect: Ok(vec![]),
            },
            TestCase {
                input: "--key",
                expect: Err(()),
            },
            TestCase {
                input: "--key=val",
                expect: Ok(vec![("key", "val")]),
            },
            TestCase {
                input: r#"--key=val -ckey2=val2 -c key3=val3 -c key4=val4 -ckey5=val5"#,
                expect: Ok(vec![
                    ("key", "val"),
                    ("key2", "val2"),
                    ("key3", "val3"),
                    ("key4", "val4"),
                    ("key5", "val5"),
                ]),
            },
            TestCase {
                input: r#"-c\ key=val"#,
                expect: Ok(vec![(" key", "val")]),
            },
            TestCase {
                input: "--key=val -ckey2 val2",
                expect: Err(()),
            },
            // Unclear what this should do.
            TestCase {
                input: "--key=",
                expect: Ok(vec![("key", "")]),
            },
        ];
        for test in tests {
            let got = parse_options(test.input);
            let expect = test.expect.map(|r| {
                r.into_iter()
                    .map(|(k, v)| (k.to_owned(), v.to_owned()))
                    .collect()
            });
            assert_eq!(got, expect, "input: {}", test.input);
        }
    }

    #[mz_ore::test]
    fn test_parse_option() {
        struct TestCase {
            input: &'static str,
            expect: Result<(&'static str, &'static str), ()>,
        }
        let tests = vec![
            TestCase {
                input: "",
                expect: Err(()),
            },
            TestCase {
                input: "--",
                expect: Err(()),
            },
            TestCase {
                input: "--c",
                expect: Err(()),
            },
            TestCase {
                input: "a=b",
                expect: Err(()),
            },
            TestCase {
                input: "--a=b",
                expect: Ok(("a", "b")),
            },
            TestCase {
                input: "--ca=b",
                expect: Ok(("ca", "b")),
            },
            TestCase {
                input: "-ca=b",
                expect: Ok(("a", "b")),
            },
            // Unclear what this should error, but at least test it.
            TestCase {
                input: "--=",
                expect: Ok(("", "")),
            },
        ];
        for test in tests {
            let got = parse_option(test.input);
            assert_eq!(got, test.expect, "input: {}", test.input);
        }
    }

    #[mz_ore::test]
    fn test_split_options() {
        struct TestCase {
            input: &'static str,
            expect: Vec<&'static str>,
        }
        let tests = vec![
            TestCase {
                input: "",
                expect: vec![],
            },
            TestCase {
                input: "  ",
                expect: vec![],
            },
            TestCase {
                input: " a ",
                expect: vec!["a"],
            },
            TestCase {
                input: "  ab     cd   ",
                expect: vec!["ab", "cd"],
            },
            TestCase {
                input: r#"  ab\     cd   "#,
                expect: vec!["ab ", "cd"],
            },
            TestCase {
                input: r#"  ab\\     cd   "#,
                expect: vec![r#"ab\"#, "cd"],
            },
            TestCase {
                input: r#"  ab\\\     cd   "#,
                expect: vec![r#"ab\ "#, "cd"],
            },
            TestCase {
                input: r#"  ab\\\ cd   "#,
                expect: vec![r#"ab\ cd"#],
            },
            TestCase {
                input: r#"  ab\\\cd   "#,
                expect: vec![r#"ab\cd"#],
            },
            TestCase {
                input: r#"a\"#,
                expect: vec!["a"],
            },
            TestCase {
                input: r#"a\ "#,
                expect: vec!["a "],
            },
            TestCase {
                input: r#"\"#,
                expect: vec![],
            },
            TestCase {
                input: r#"\ "#,
                expect: vec![r#" "#],
            },
            TestCase {
                input: r#" \ "#,
                expect: vec![r#" "#],
            },
            TestCase {
                input: r#"\  "#,
                expect: vec![r#" "#],
            },
        ];
        for test in tests {
            let got = split_options(test.input);
            assert_eq!(got, test.expect, "input: {}", test.input);
        }
    }

    #[mz_ore::test]
    fn test_is_jwt() {
        // A real JWT header decodes successfully.
        assert!(is_jwt("eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxIn0.signature"));
        // Not JWTs: plain strings, wrong segment count, non-JSON headers.
        for s in [
            "",
            "secure_password",
            "p4ss.w0rd",
            "aaa.bbb.ccc",
            "eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxIn0",
            "eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxIn0.sig.extra",
        ] {
            assert!(!is_jwt(s), "is_jwt({s:?})");
        }
    }
}