Skip to main content

mz_pgwire/
protocol.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::collections::BTreeMap;
11use std::convert::TryFrom;
12use std::future::Future;
13use std::ops::Deref;
14use std::sync::Arc;
15use std::time::{Duration, Instant};
16use std::{iter, mem};
17
18use base64::prelude::*;
19use byteorder::{ByteOrder, NetworkEndian};
20use futures::future::{BoxFuture, FutureExt, pending};
21use itertools::Itertools;
22use mz_adapter::client::RecordFirstRowStream;
23use mz_adapter::session::{
24    EndTransactionAction, InProgressRows, LifecycleTimestamps, PortalRefMut, PortalState, Session,
25    SessionConfig, TransactionStatus,
26};
27use mz_adapter::statement_logging::{StatementEndedExecutionReason, StatementExecutionStrategy};
28use mz_adapter::{
29    AdapterError, AdapterNotice, ExecuteContextGuard, ExecuteResponse, PeekResponseUnary, metrics,
30    verify_datum_desc,
31};
32use mz_auth::Authenticated;
33use mz_auth::password::Password;
34use mz_authenticator::{Authenticator, GenericOidcAuthenticator};
35use mz_frontegg_auth::Authenticator as FronteggAuthenticator;
36use mz_ore::cast::CastFrom;
37use mz_ore::netio::AsyncReady;
38use mz_ore::now::{EpochMillis, SYSTEM_TIME};
39use mz_ore::str::StrExt;
40use mz_ore::{assert_none, assert_ok, instrument, soft_assert_eq_or_log};
41use mz_pgcopy::{CopyCsvFormatParams, CopyFormatParams, CopyTextFormatParams};
42use mz_pgwire_common::{
43    ConnectionCounter, Cursor, ErrorResponse, Format, FrontendMessage, Severity, VERSION_3,
44    VERSIONS,
45};
46use mz_repr::{
47    CatalogItemId, ColumnIndex, Datum, RelationDesc, RowArena, RowIterator, RowRef,
48    SqlRelationType, SqlScalarType,
49};
50use mz_server_core::TlsMode;
51use mz_server_core::listeners::{AllowedRoles, AuthenticatorKind};
52use mz_sql::ast::display::AstDisplay;
53use mz_sql::ast::{CopyDirection, CopyStatement, FetchDirection, Ident, Raw, Statement};
54use mz_sql::parse::StatementParseResult;
55use mz_sql::plan::{CopyFormat, ExecuteTimeout, StatementDesc};
56use mz_sql::session::metadata::SessionMetadata;
57use mz_sql::session::user::INTERNAL_USER_NAMES;
58use mz_sql::session::vars::{MAX_COPY_FROM_SIZE, Var, VarInput};
59use postgres::error::SqlState;
60use tokio::io::{self, AsyncRead, AsyncWrite};
61use tokio::select;
62use tokio::time::{self};
63use tokio_metrics::TaskMetrics;
64use tokio_stream::wrappers::UnboundedReceiverStream;
65use tracing::{Instrument, debug, debug_span, warn};
66use uuid::Uuid;
67
68use crate::codec::{
69    FramedConn, decode_password, decode_sasl_initial_response, decode_sasl_response,
70};
71use crate::message::{
72    self, BackendMessage, SASLServerFinalMessage, SASLServerFinalMessageKinds,
73    SASLServerFirstMessage,
74};
75
76/// Reports whether the given stream begins with a pgwire handshake.
77///
78/// To avoid false negatives, there must be at least eight bytes in `buf`.
79pub fn match_handshake(buf: &[u8]) -> bool {
80    // The pgwire StartupMessage looks like this:
81    //
82    //     i32 - Length of entire message.
83    //     i32 - Protocol version number.
84    //     [String] - Arbitrary key-value parameters of any length.
85    //
86    // Since arbitrary parameters can be included in the StartupMessage, the
87    // first Int32 is worthless, since the message could have any length.
88    // Instead, we sniff the protocol version number.
89    if buf.len() < 8 {
90        return false;
91    }
92    let version = NetworkEndian::read_i32(&buf[4..8]);
93    VERSIONS.contains(&version)
94}
95
/// Parameters for the [`run`] function.
///
/// `A` is the underlying transport wrapped by the [`FramedConn`], and `I` is
/// an iterator yielding [`TaskMetrics`] samples. [`run`] destructures this
/// struct directly in its argument list.
pub struct RunParams<'a, A, I>
where
    I: Iterator<Item = TaskMetrics> + Send,
{
    /// The TLS mode of the pgwire server.
    pub tls_mode: Option<TlsMode>,
    /// A client for the adapter.
    pub adapter_client: mz_adapter::Client,
    /// The connection to the client.
    pub conn: &'a mut FramedConn<A>,
    /// The universally unique identifier for the connection.
    pub conn_uuid: Uuid,
    /// The protocol version that the client provided in the startup message.
    pub version: i32,
    /// The parameters that the client provided in the startup message.
    ///
    /// [`run`] reads the `user` and `options` entries specially and applies
    /// the remaining entries as session variables.
    pub params: BTreeMap<String, String>,
    /// Frontegg JWT authenticator.
    pub frontegg: Option<FronteggAuthenticator>,
    /// OIDC authenticator.
    pub oidc: GenericOidcAuthenticator,
    /// The authentication method defined by the server's listener
    /// configuration.
    pub authenticator_kind: AuthenticatorKind,
    /// Global connection limit and count.
    pub active_connection_counter: ConnectionCounter,
    /// Helm chart version, forwarded into the session configuration.
    pub helm_chart_version: Option<String>,
    /// Which classes of roles may log in on this listener, e.g. whether
    /// reserved users (i.e. `mz_system`) are allowed.
    pub allowed_roles: AllowedRoles,
    /// Tokio metrics.
    pub tokio_metrics_intervals: I,
}
129
/// Runs a pgwire connection to completion.
///
/// This involves responding to `FrontendMessage::StartupMessage` and all future
/// requests until the client terminates the connection or a fatal error occurs.
///
/// The startup sequence performed here is, in order:
///
/// 1. Reject any protocol version other than `VERSION_3`.
/// 2. Check the requested user against `allowed_roles`.
/// 3. Check that the connection satisfies `tls_mode`.
/// 4. Authenticate according to the selected [`Authenticator`] variant and
///    build a session.
/// 5. Apply the remaining startup parameters (including the parsed `options`)
///    as session variables; failures become notices rather than errors.
/// 6. Reserve a slot in the connection counter and register the session with
///    the adapter.
/// 7. Send `AuthenticationOk`, parameter statuses, backend key data, any
///    queued notices and `ReadyForQuery`, then hand over to the
///    [`StateMachine`].
///
/// Note that this function returns successfully even upon delivering a fatal
/// error to the client. It only returns `Err` if an unexpected I/O error occurs
/// while communicating with the client, e.g., if the connection is severed in
/// the middle of a request.
#[mz_ore::instrument(level = "debug")]
pub async fn run<'a, A, I>(
    RunParams {
        tls_mode,
        adapter_client,
        conn,
        conn_uuid,
        version,
        mut params,
        frontegg,
        oidc,
        authenticator_kind,
        active_connection_counter,
        helm_chart_version,
        allowed_roles,
        tokio_metrics_intervals,
    }: RunParams<'a, A, I>,
) -> Result<(), io::Error>
where
    A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
    I: Iterator<Item = TaskMetrics> + Send,
{
    // Only protocol version 3 is served. The result of sending the fatal
    // error is what gets returned, so a successful send yields `Ok(())`.
    if version != VERSION_3 {
        return conn
            .send(ErrorResponse::fatal(
                SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
                "server does not support the client's requested protocol version",
            ))
            .await;
    }

    // `user` is removed from `params` so that the loop over `params` further
    // down does not try to set it as a session variable. A missing user
    // becomes the empty string.
    let user = params.remove("user").unwrap_or_else(String::new);
    // `options` stays in `params`; it is parsed once here and the parsed form
    // is substituted when the loop below reaches the "options" key.
    let options = parse_options(params.get("options").unwrap_or(&String::new()));

    // If oidc_auth_enabled exists as an option, return its value and filter it from
    // the remaining options.
    let (oidc_auth_enabled, options) = extract_oidc_auth_enabled_from_options(options);
    let authenticator = get_authenticator(
        authenticator_kind,
        frontegg,
        oidc,
        adapter_client.clone(),
        oidc_auth_enabled,
    );
    // TODO move this somewhere it can be shared with HTTP
    let is_internal_user = INTERNAL_USER_NAMES.contains(&user);
    // this is a superset of internal users
    let is_reserved_user = mz_adapter::catalog::is_reserved_role_name(user.as_str());
    let role_allowed = match allowed_roles {
        AllowedRoles::Normal => !is_reserved_user,
        AllowedRoles::Internal => is_internal_user,
        AllowedRoles::NormalAndInternal => !is_reserved_user || is_internal_user,
    };
    if !role_allowed {
        let msg = format!("unauthorized login to user '{user}'");
        return conn
            .send(ErrorResponse::fatal(SqlState::INSUFFICIENT_PRIVILEGE, msg))
            .await;
    }

    if let Err(err) = conn.inner().ensure_tls_compatibility(&tls_mode) {
        return conn.send(err).await;
    }

    // Each arm yields the new session together with a future that resolves
    // when the authentication expires. Only the Frontegg arm produces a future
    // that can complete; every other arm uses `pending()`, which never does.
    let (mut session, expired) = match authenticator {
        Authenticator::Frontegg(frontegg) => {
            let password = match request_cleartext_password(conn).await {
                Ok(password) => password,
                Err(PasswordRequestError::IoError(e)) => return Err(e),
                Err(PasswordRequestError::InvalidPasswordError(e)) => {
                    return conn.send(e).await;
                }
            };

            let auth_response = frontegg.authenticate(&user, &password).await;
            match auth_response {
                // Create a session based on the auth session.
                //
                // In particular, it's important that the username come from the
                // auth session, as Frontegg may return an email address with
                // different casing than the user supplied via the pgwire
                // username field.
                Ok((mut auth_session, authenticated)) => {
                    let session = adapter_client.new_session(
                        SessionConfig {
                            conn_id: conn.conn_id().clone(),
                            uuid: conn_uuid,
                            user: auth_session.user().into(),
                            client_ip: conn.peer_addr().clone(),
                            external_metadata_rx: Some(auth_session.external_metadata_rx()),
                            helm_chart_version,
                        },
                        authenticated,
                    );
                    let expired = async move { auth_session.expired().await };
                    (session, expired.left_future())
                }
                Err(err) => {
                    // The underlying error is only logged; the client sees a
                    // generic "invalid password" message.
                    warn!(?err, "pgwire connection failed authentication");
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_PASSWORD,
                            "invalid password",
                        ))
                        .await;
                }
            }
        }
        Authenticator::Oidc(oidc) => {
            // OIDC authentication: JWT sent as password in cleartext flow
            let jwt = match request_cleartext_password(conn).await {
                Ok(password) => password,
                Err(PasswordRequestError::IoError(e)) => return Err(e),
                Err(PasswordRequestError::InvalidPasswordError(e)) => {
                    return conn.send(e).await;
                }
            };
            let auth_response = oidc.authenticate(&jwt, Some(&user)).await;
            match auth_response {
                Ok((claims, authenticated)) => {
                    // The session user is taken from the validated claims, not
                    // from the startup `user` parameter.
                    let session = adapter_client.new_session(
                        SessionConfig {
                            conn_id: conn.conn_id().clone(),
                            uuid: conn_uuid,
                            user: claims.username().into(),
                            client_ip: conn.peer_addr().clone(),
                            external_metadata_rx: None,
                            helm_chart_version,
                        },
                        authenticated,
                    );
                    // No invalidation of the auth session once authenticated,
                    // so auth session lasts indefinitely.
                    (session, pending().right_future())
                }
                Err(err) => {
                    warn!(?err, "pgwire connection failed authentication");
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_PASSWORD,
                            "invalid password",
                        ))
                        .await;
                }
            }
        }
        Authenticator::Password(adapter_client) => {
            let session = match authenticate_with_password(
                conn,
                &adapter_client,
                user,
                conn_uuid,
                helm_chart_version,
            )
            .await
            {
                Ok(session) => session,
                Err(PasswordRequestError::IoError(e)) => return Err(e),
                Err(PasswordRequestError::InvalidPasswordError(e)) => {
                    return conn.send(e).await;
                }
            };
            // No frontegg check, so auth session lasts indefinitely.
            (session, pending().right_future())
        }
        Authenticator::Sasl(adapter_client) => {
            // Start the handshake
            conn.send(BackendMessage::AuthenticationSASL).await?;
            conn.flush().await?;
            // Get the initial response indicating chosen mechanism
            let (mechanism, initial_response) = match conn.recv().await? {
                Some(FrontendMessage::RawAuthentication(data)) => {
                    match decode_sasl_initial_response(Cursor::new(&data)).ok() {
                        Some(FrontendMessage::SASLInitialResponse {
                            gs2_header,
                            mechanism,
                            initial_response,
                        }) => {
                            // We do not support channel binding
                            if gs2_header.channel_binding_enabled() {
                                return conn
                                    .send(ErrorResponse::fatal(
                                        SqlState::PROTOCOL_VIOLATION,
                                        "channel binding not supported",
                                    ))
                                    .await;
                            }
                            (mechanism, initial_response)
                        }
                        _ => {
                            return conn
                                .send(ErrorResponse::fatal(
                                    SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                                    "expected SASLInitialResponse message",
                                ))
                                .await;
                        }
                    }
                }
                _ => {
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                            "expected SASLInitialResponse message",
                        ))
                        .await;
                }
            };

            // SCRAM-SHA-256 is the only mechanism accepted here.
            if mechanism != "SCRAM-SHA-256" {
                return conn
                    .send(ErrorResponse::fatal(
                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                        "unsupported SASL mechanism",
                    ))
                    .await;
            }

            // Bound the client-supplied nonce before passing it on.
            if initial_response.nonce.len() > 256 {
                return conn
                    .send(ErrorResponse::fatal(
                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                        "nonce too long",
                    ))
                    .await;
            }

            let (server_first_message_raw, mock_hash) = match adapter_client
                .generate_sasl_challenge(&user, &initial_response.nonce)
                .await
            {
                Ok(response) => {
                    // Text form of the server-first message; it is needed
                    // again below when assembling the auth message.
                    let server_first_message_raw = format!(
                        "r={},s={},i={}",
                        response.nonce, response.salt, response.iteration_count
                    );

                    // A hash string built from fixed placeholder keys and the
                    // challenge's salt and iteration count. It is handed to
                    // `verify_sasl_proof` below.
                    // NOTE(review): presumably used as a stand-in when the
                    // user has no stored hash — confirm against the adapter.
                    let client_key = [0u8; 32];
                    let server_key = [1u8; 32];
                    let mock_hash = format!(
                        "SCRAM-SHA-256${}:{}${}:{}",
                        response.iteration_count,
                        response.salt,
                        BASE64_STANDARD.encode(client_key),
                        BASE64_STANDARD.encode(server_key)
                    );

                    conn.send(BackendMessage::AuthenticationSASLContinue(
                        SASLServerFirstMessage {
                            iteration_count: response.iteration_count,
                            nonce: response.nonce,
                            salt: response.salt,
                        },
                    ))
                    .await?;
                    conn.flush().await?;
                    (server_first_message_raw, mock_hash)
                }
                Err(e) => {
                    return conn.send(e.into_response(Severity::Fatal)).await;
                }
            };

            let authenticated = match conn.recv().await? {
                Some(FrontendMessage::RawAuthentication(data)) => {
                    match decode_sasl_response(Cursor::new(&data)).ok() {
                        Some(FrontendMessage::SASLResponse(response)) => {
                            // Concatenation of client-first-bare, server-first
                            // and client-final-bare messages, comma separated.
                            let auth_message = format!(
                                "{},{},{}",
                                initial_response.client_first_message_bare_raw,
                                server_first_message_raw,
                                response.client_final_message_bare_raw
                            );
                            // Bound the client-supplied proof before verifying.
                            if response.proof.len() > 1024 {
                                return conn
                                    .send(ErrorResponse::fatal(
                                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                                        "proof too long",
                                    ))
                                    .await;
                            }
                            match adapter_client
                                .verify_sasl_proof(
                                    &user,
                                    &response.proof,
                                    &auth_message,
                                    &mock_hash,
                                )
                                .await
                            {
                                Ok((proof_response, authenticated)) => {
                                    conn.send(BackendMessage::AuthenticationSASLFinal(
                                        SASLServerFinalMessage {
                                            kind: SASLServerFinalMessageKinds::Verifier(
                                                proof_response.verifier,
                                            ),
                                            extensions: vec![],
                                        },
                                    ))
                                    .await?;
                                    conn.flush().await?;
                                    authenticated
                                }
                                // Any verification failure is reported to the
                                // client as an invalid password; the error
                                // itself is discarded here.
                                Err(_) => {
                                    return conn
                                        .send(ErrorResponse::fatal(
                                            SqlState::INVALID_PASSWORD,
                                            "invalid password",
                                        ))
                                        .await;
                                }
                            }
                        }
                        _ => {
                            return conn
                                .send(ErrorResponse::fatal(
                                    SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                                    "expected SASLResponse message",
                                ))
                                .await;
                        }
                    }
                }
                _ => {
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                            "expected SASLResponse message",
                        ))
                        .await;
                }
            };

            let session = adapter_client.new_session(
                SessionConfig {
                    conn_id: conn.conn_id().clone(),
                    uuid: conn_uuid,
                    user,
                    client_ip: conn.peer_addr().clone(),
                    external_metadata_rx: None,
                    helm_chart_version,
                },
                authenticated,
            );
            // No frontegg check, so auth session lasts indefinitely.
            let auth_session = pending().right_future();
            (session, auth_session)
        }

        Authenticator::None => {
            // No credentials are requested; the session is created directly
            // with the `Authenticated` marker.
            let session = adapter_client.new_session(
                SessionConfig {
                    conn_id: conn.conn_id().clone(),
                    uuid: conn_uuid,
                    user,
                    client_ip: conn.peer_addr().clone(),
                    external_metadata_rx: None,
                    helm_chart_version,
                },
                Authenticated,
            );
            // No frontegg check, so auth session lasts indefinitely.
            let auth_session = pending().right_future();
            (session, auth_session)
        }
    };

    // Apply every remaining startup parameter as a session variable. The
    // "options" entry expands to the pairs parsed earlier; any other entry is
    // a single (name, value) pair.
    let system_vars = adapter_client.get_system_vars().await;
    for (name, value) in params {
        let settings = match name.as_str() {
            "options" => match &options {
                Ok(opts) => opts,
                Err(()) => {
                    session.add_notice(AdapterNotice::BadStartupSetting {
                        name,
                        reason: "could not parse".into(),
                    });
                    continue;
                }
            },
            _ => &vec![(name, value)],
        };
        for (key, val) in settings {
            const LOCAL: bool = false;
            // TODO: Issuing an error here is better than what we did before
            // (silently ignore errors on set), but erroring the connection
            // might be the better behavior. We maybe need to support more
            // options sent by psql and drivers before we can safely do this.
            if let Err(err) = session
                .vars_mut()
                .set(&system_vars, key, VarInput::Flat(val), LOCAL)
            {
                session.add_notice(AdapterNotice::BadStartupSetting {
                    name: key.clone(),
                    reason: err.to_string(),
                });
            }
        }
    }
    // Commit the variable changes made above.
    session
        .vars_mut()
        .end_transaction(EndTransactionAction::Commit);

    // The returned value is bound to `_guard` so that it lives until this
    // function returns.
    let _guard = match active_connection_counter.allocate_connection(session.user()) {
        Ok(drop_connection) => drop_connection,
        Err(e) => {
            let e: AdapterError = e.into();
            return conn.send(e.into_response(Severity::Fatal)).await;
        }
    };

    // Register session with adapter.
    let mut adapter_client = match adapter_client.startup(session).await {
        Ok(adapter_client) => adapter_client,
        Err(e) => return conn.send(e.into_response(Severity::Fatal)).await,
    };

    // Batch the whole post-authentication greeting and send it in one go:
    // AuthenticationOk, one ParameterStatus per notify-set variable,
    // BackendKeyData, queued notices, and finally ReadyForQuery.
    let mut buf = vec![BackendMessage::AuthenticationOk];
    for var in adapter_client.session().vars().notify_set() {
        buf.push(BackendMessage::ParameterStatus(var.name(), var.value()));
    }
    buf.push(BackendMessage::BackendKeyData {
        conn_id: adapter_client.session().conn_id().unhandled(),
        secret_key: adapter_client.session().secret_key(),
    });
    buf.extend(
        adapter_client
            .session()
            .drain_notices()
            .into_iter()
            .map(|notice| BackendMessage::ErrorResponse(notice.into_response())),
    );
    buf.push(BackendMessage::ReadyForQuery(
        adapter_client.session().transaction().into(),
    ));
    conn.send_all(buf).await?;
    conn.flush().await?;

    let machine = StateMachine {
        conn,
        adapter_client,
        txn_needs_commit: false,
        tokio_metrics_intervals,
    };

    // Drive the state machine until it finishes, unless the authentication
    // expires first.
    select! {
        r = machine.run() => {
            // Errors produced internally (like MAX_REQUEST_SIZE being exceeded) should send an
            // error to the client informing them why the connection was closed. We still want to
            // return the original error up the stack, though, so we skip error checking during conn
            // operations.
            if let Err(err) = &r {
                let _ = conn
                    .send(ErrorResponse::fatal(
                        SqlState::CONNECTION_FAILURE,
                        err.to_string(),
                    ))
                    .await;
                let _ = conn.flush().await;
            }
            r
        },
        _ = expired => {
            conn
                .send(ErrorResponse::fatal(SqlState::INVALID_AUTHORIZATION_SPECIFICATION, "authentication expired"))
                .await?;
            conn.flush().await
        }
    }
}
609
/// Pulls the `oidc_auth_enabled` flag out of a parsed options list.
///
/// Returns the flag's value together with the options that remain once every
/// `oidc_auth_enabled` entry has been removed. The flag is `false` when it is
/// absent or when its value is not a valid `bool` literal; if it appears more
/// than once, the last occurrence decides. An `Err` input is passed through
/// unchanged with the flag reported as `false`.
fn extract_oidc_auth_enabled_from_options(
    options: Result<Vec<(String, String)>, ()>,
) -> (bool, Result<Vec<(String, String)>, ()>) {
    let Ok(pairs) = options else {
        return (false, Err(()));
    };

    // `partition` keeps the relative order of both halves, so `last()` below
    // is the final occurrence of the flag in the original list.
    let (flags, remaining): (Vec<_>, Vec<_>) = pairs
        .into_iter()
        .partition(|(key, _)| key == "oidc_auth_enabled");

    let enabled = match flags.last() {
        Some((_, raw)) => raw.parse::<bool>().unwrap_or(false),
        None => false,
    };

    (enabled, Ok(remaining))
}
634
635/// Returns (name, value) session settings pairs from an options value.
636///
637/// From Postgres, see pg_split_opts in postinit.c and process_postgres_switches
638/// in postgres.c.
639fn parse_options(value: &str) -> Result<Vec<(String, String)>, ()> {
640    let opts = split_options(value);
641    let mut pairs = Vec::with_capacity(opts.len());
642    let mut seen_prefix = false;
643    for opt in opts {
644        if !seen_prefix {
645            if opt == "-c" {
646                seen_prefix = true;
647            } else {
648                let (key, val) = parse_option(&opt)?;
649                pairs.push((key.to_owned(), val.to_owned()));
650            }
651        } else {
652            let (key, val) = opt.split_once('=').ok_or(())?;
653            pairs.push((key.to_owned(), val.to_owned()));
654            seen_prefix = false;
655        }
656    }
657    Ok(pairs)
658}
659
/// Parses a single self-contained option token of the form `--key=value` or
/// `-ckey=value` into its key and value.
///
/// The token is split at the first `=`; everything after it is the value. The
/// key is returned with its `-c` or `--` prefix removed and is otherwise left
/// exactly as written (in particular, any `-` inside the key is kept as is).
///
/// # Errors
///
/// Returns `Err(())` if the token contains no `=` or if the part before the
/// `=` starts with neither `-c` nor `--`.
fn parse_option(option: &str) -> Result<(&str, &str), ()> {
    let (prefixed_key, value) = option.split_once('=').ok_or(())?;
    // `-c` is tried before `--`, matching the order of precedence callers
    // have always seen.
    prefixed_key
        .strip_prefix("-c")
        .or_else(|| prefixed_key.strip_prefix("--"))
        .map(|key| (key, value))
        .ok_or(())
}
672
/// Splits `value` into tokens separated by runs of spaces, honouring `\` as
/// an escape character.
///
/// A character that follows an unescaped `\` is taken literally, so `\ `
/// yields a space inside a token and `\\` yields a single backslash. Empty
/// tokens are never produced, and a lone `\` at the very end is dropped.
fn split_options(value: &str) -> Vec<String> {
    let mut tokens = Vec::new();
    // Escapes mean a token is not always a plain subslice of `value`, so each
    // token is accumulated into its own buffer. This is not hot enough to
    // justify borrowing when no escape is present.
    let mut token = String::new();
    let mut escaped = false;
    for ch in value.chars() {
        if escaped {
            // Whatever follows a backslash is literal, be it a space, another
            // backslash or an ordinary character.
            token.push(ch);
            escaped = false;
            continue;
        }
        match ch {
            '\\' => escaped = true,
            ' ' => {
                // Consecutive separators must not create empty tokens.
                if !token.is_empty() {
                    tokens.push(std::mem::take(&mut token));
                }
            }
            other => token.push(other),
        }
    }
    // A dangling escape at the end contributes nothing.
    if !token.is_empty() {
        tokens.push(token);
    }
    tokens
}
715
/// Failure modes of the password-based authentication helpers.
///
/// Callers in [`run`] treat the two variants differently: an `IoError` is
/// returned to the caller as-is, while an `InvalidPasswordError` carries the
/// response that is then sent to the client.
enum PasswordRequestError {
    /// The client did not supply an acceptable password; holds the error
    /// response describing why.
    InvalidPasswordError(ErrorResponse),
    /// Communicating with the client failed.
    IoError(io::Error),
}
720
721impl From<io::Error> for PasswordRequestError {
722    fn from(e: io::Error) -> Self {
723        PasswordRequestError::IoError(e)
724    }
725}
726
727/// Requests a cleartext password from a connection and returns it if it is valid.
728/// Sends an error response in the connection if the password
729/// is not valid.
730async fn request_cleartext_password<A>(
731    conn: &mut FramedConn<A>,
732) -> Result<String, PasswordRequestError>
733where
734    A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
735{
736    conn.send(BackendMessage::AuthenticationCleartextPassword)
737        .await?;
738    conn.flush().await?;
739
740    if let Some(message) = conn.recv().await? {
741        if let FrontendMessage::RawAuthentication(data) = message {
742            if let Some(FrontendMessage::Password { password }) =
743                decode_password(Cursor::new(&data)).ok()
744            {
745                return Ok(password);
746            }
747        }
748    }
749
750    Err(PasswordRequestError::InvalidPasswordError(
751        ErrorResponse::fatal(
752            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
753            "expected Password message",
754        ),
755    ))
756}
757
758/// Helper for password-based authentication using AdapterClient
759/// and returns an authenticated session.
760async fn authenticate_with_password<A>(
761    conn: &mut FramedConn<A>,
762    adapter_client: &mz_adapter::Client,
763    user: String,
764    conn_uuid: Uuid,
765    helm_chart_version: Option<String>,
766) -> Result<Session, PasswordRequestError>
767where
768    A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
769{
770    let password = match request_cleartext_password(conn).await {
771        Ok(password) => Password(password),
772        Err(e) => return Err(e),
773    };
774
775    let authenticated = match adapter_client.authenticate(&user, &password).await {
776        Ok(authenticated) => authenticated,
777        Err(err) => {
778            warn!(?err, "pgwire connection failed authentication");
779            return Err(PasswordRequestError::InvalidPasswordError(
780                ErrorResponse::fatal(SqlState::INVALID_PASSWORD, "invalid password"),
781            ));
782        }
783    };
784
785    let session = adapter_client.new_session(
786        SessionConfig {
787            conn_id: conn.conn_id().clone(),
788            uuid: conn_uuid,
789            user,
790            client_ip: conn.peer_addr().clone(),
791            external_metadata_rx: None,
792            helm_chart_version,
793        },
794        authenticated,
795    );
796
797    Ok(session)
798}
799
/// Where the connection's message-processing loop currently stands.
#[derive(Debug)]
enum State {
    /// Read the next frontend message and dispatch it to its handler.
    Ready,
    /// Discard incoming messages until a `Sync` message arrives.
    Drain,
    /// Stop processing; the run loop returns.
    Done,
}
806
/// Per-connection driver that reads frontend messages from `conn` and
/// services them through `adapter_client`.
struct StateMachine<'a, A, I>
where
    I: Iterator<Item = TaskMetrics> + Send + 'a,
{
    /// The framed client connection that messages are read from and written to.
    conn: &'a mut FramedConn<A>,
    /// Adapter session client used to run statements and manage transactions.
    adapter_client: mz_adapter::SessionClient,
    /// Set after an `Execute` in an implicit transaction; makes the next
    /// `ensure_transaction` call commit before starting a new transaction.
    txn_needs_commit: bool,
    /// Source of task-metrics intervals; users of this field expect it to
    /// never run out of items.
    tokio_metrics_intervals: I,
}
816
/// Why a `send_rows` call stopped sending rows.
enum SendRowsEndedReason {
    /// Sending finished without error.
    Success {
        /// Size of the result. NOTE(review): presumably in bytes; confirm
        /// against `send_rows`.
        result_size: u64,
        /// Number of rows returned.
        rows_returned: u64,
    },
    /// Sending stopped because of an error, described by `error`.
    Errored {
        error: String,
    },
    /// Sending was canceled.
    Canceled,
}
827
/// Error text used when a command is rejected because the current
/// transaction is in the aborted state.
const ABORTED_TXN_MSG: &str =
    "current transaction is aborted, commands ignored until end of transaction block";
830
831impl<'a, A, I> StateMachine<'a, A, I>
832where
833    A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin + 'a,
834    I: Iterator<Item = TaskMetrics> + Send + 'a,
835{
836    // Manually desugar this (don't use `async fn run`) here because a much better
837    // error message is produced if there are problems with Send or other traits
838    // somewhere within the Future.
839    #[allow(clippy::manual_async_fn)]
840    #[mz_ore::instrument(level = "debug")]
841    fn run(mut self) -> impl Future<Output = Result<(), io::Error>> + Send + 'a {
842        async move {
843            let mut state = State::Ready;
844            loop {
845                self.send_pending_notices().await?;
846                state = match state {
847                    State::Ready => self.advance_ready().await?,
848                    State::Drain => self.advance_drain().await?,
849                    State::Done => return Ok(()),
850                };
851                self.adapter_client
852                    .add_idle_in_transaction_session_timeout();
853            }
854        }
855    }
856
857    #[instrument(level = "debug")]
858    async fn advance_ready(&mut self) -> Result<State, io::Error> {
859        // Start a new metrics interval before the `recv()` call.
860        self.tokio_metrics_intervals
861            .next()
862            .expect("infinite iterator");
863
864        // Handle timeouts first so we don't execute any statements when there's a pending timeout.
865        let message = select! {
866            biased;
867
868            // `recv_timeout()` is cancel-safe as per it's docs.
869            Some(timeout) = self.adapter_client.recv_timeout() => {
870                let err: AdapterError = timeout.into();
871                let conn_id = self.adapter_client.session().conn_id();
872                tracing::warn!("session timed out, conn_id {}", conn_id);
873
874                // Process the error, doing any state cleanup.
875                let error_response = err.into_response(Severity::Fatal);
876                let error_state = self.error(error_response).await;
877
878                // Terminate __after__ we do any cleanup.
879                self.adapter_client.terminate().await;
880
881                // We must wait for the client to send a request before we can send the error response.
882                // Due to the PG wire protocol, we can't send an ErrorResponse unless it is in response
883                // to a client message.
884                let _ = self.conn.recv().await?;
885                return error_state;
886            },
887            // `recv()` is cancel-safe as per it's docs.
888            message = self.conn.recv() => message?,
889        };
890
891        // Take the metrics since just before the `recv`.
892        let interval = self
893            .tokio_metrics_intervals
894            .next()
895            .expect("infinite iterator");
896        let recv_scheduling_delay_ms = interval.total_scheduled_duration.as_secs_f64() * 1000.0;
897
898        // TODO(ggevay): Consider subtracting the scheduling delay from `received`. It's not obvious
899        // whether we should do this, because the result wouldn't exactly correspond to either first
900        // byte received or last byte received (for msgs that arrive in more than one network packet).
901        let received = SYSTEM_TIME();
902
903        self.adapter_client
904            .remove_idle_in_transaction_session_timeout();
905
906        // NOTE(guswynn): we could consider adding spans to all message types. Currently
907        // only a few message types seem useful.
908        let message_name = message.as_ref().map(|m| m.name()).unwrap_or_default();
909
910        let start = message.as_ref().map(|_| Instant::now());
911        let next_state = match message {
912            Some(FrontendMessage::Query { sql }) => {
913                let query_root_span =
914                    tracing::info_span!(parent: None, "advance_ready", otel.name = message_name);
915                query_root_span.follows_from(tracing::Span::current());
916                self.query(sql, received)
917                    .instrument(query_root_span)
918                    .await?
919            }
920            Some(FrontendMessage::Parse {
921                name,
922                sql,
923                param_types,
924            }) => self.parse(name, sql, param_types).await?,
925            Some(FrontendMessage::Bind {
926                portal_name,
927                statement_name,
928                param_formats,
929                raw_params,
930                result_formats,
931            }) => {
932                self.bind(
933                    portal_name,
934                    statement_name,
935                    param_formats,
936                    raw_params,
937                    result_formats,
938                )
939                .await?
940            }
941            Some(FrontendMessage::Execute {
942                portal_name,
943                max_rows,
944            }) => {
945                let max_rows = match usize::try_from(max_rows) {
946                    Ok(0) | Err(_) => ExecuteCount::All, // If `max_rows < 0`, no limit.
947                    Ok(n) => ExecuteCount::Count(n),
948                };
949                let execute_root_span =
950                    tracing::info_span!(parent: None, "advance_ready", otel.name = message_name);
951                execute_root_span.follows_from(tracing::Span::current());
952                let state = self
953                    .execute(
954                        portal_name,
955                        max_rows,
956                        portal_exec_message,
957                        None,
958                        ExecuteTimeout::None,
959                        None,
960                        Some(received),
961                    )
962                    .instrument(execute_root_span)
963                    .await?;
964                // In PostgreSQL, when using the extended query protocol, some statements may
965                // trigger an eager commit of the current implicit transaction,
966                // see: <https://git.postgresql.org/gitweb/?p=postgresql.git&a=commitdiff&h=f92944137>.
967                //
968                // In Materialize, however, we eagerly commit every statement outside of an explicit
969                // transaction when using the extended query protocol. This allows us to eliminate
970                // the possibility of a multiple statement implicit transaction, which in turn
971                // allows us to apply single-statement optimizations to queries issued in implicit
972                // transactions in the extended query protocol.
973                //
974                // We don't immediately commit here to allow users to page through the portal if
975                // necessary. Committing the transaction would destroy the portal before the next
976                // Execute command has a chance to resume it. So we instead mark the transaction
977                // for commit the next time that `ensure_transaction` is called.
978                if self.adapter_client.session().transaction().is_implicit() {
979                    self.txn_needs_commit = true;
980                }
981                state
982            }
983            Some(FrontendMessage::DescribeStatement { name }) => {
984                self.describe_statement(&name).await?
985            }
986            Some(FrontendMessage::DescribePortal { name }) => self.describe_portal(&name).await?,
987            Some(FrontendMessage::CloseStatement { name }) => self.close_statement(name).await?,
988            Some(FrontendMessage::ClosePortal { name }) => self.close_portal(name).await?,
989            Some(FrontendMessage::Flush) => self.flush().await?,
990            Some(FrontendMessage::Sync) => self.sync().await?,
991            Some(FrontendMessage::Terminate) => State::Done,
992
993            Some(FrontendMessage::CopyData(_))
994            | Some(FrontendMessage::CopyDone)
995            | Some(FrontendMessage::CopyFail(_))
996            | Some(FrontendMessage::Password { .. })
997            | Some(FrontendMessage::RawAuthentication(_))
998            | Some(FrontendMessage::SASLInitialResponse { .. })
999            | Some(FrontendMessage::SASLResponse(_)) => State::Drain,
1000            None => State::Done,
1001        };
1002
1003        if let Some(start) = start {
1004            self.adapter_client
1005                .inner()
1006                .metrics()
1007                .pgwire_message_processing_seconds
1008                .with_label_values(&[message_name])
1009                .observe(start.elapsed().as_secs_f64());
1010        }
1011        self.adapter_client
1012            .inner()
1013            .metrics()
1014            .pgwire_recv_scheduling_delay_ms
1015            .with_label_values(&[message_name])
1016            .observe(recv_scheduling_delay_ms);
1017
1018        Ok(next_state)
1019    }
1020
1021    async fn advance_drain(&mut self) -> Result<State, io::Error> {
1022        let message = self.conn.recv().await?;
1023        if message.is_some() {
1024            self.adapter_client
1025                .remove_idle_in_transaction_session_timeout();
1026        }
1027        match message {
1028            Some(FrontendMessage::Sync) => self.sync().await,
1029            None => Ok(State::Done),
1030            _ => Ok(State::Drain),
1031        }
1032    }
1033
1034    /// Note that `lifecycle_timestamps` belongs to the whole "Simple Query", because the whole
1035    /// Simple Query is received and parsed together. This means that if there are multiple
1036    /// statements in a Simple Query, then all of them have the same `lifecycle_timestamps`.
1037    #[instrument(level = "debug")]
1038    async fn one_query(
1039        &mut self,
1040        stmt: Statement<Raw>,
1041        sql: String,
1042        lifecycle_timestamps: LifecycleTimestamps,
1043    ) -> Result<State, io::Error> {
1044        // Bind the portal. Note that this does not set the empty string prepared
1045        // statement.
1046        const EMPTY_PORTAL: &str = "";
1047        if let Err(e) = self
1048            .adapter_client
1049            .declare(EMPTY_PORTAL.to_string(), stmt, sql)
1050            .await
1051        {
1052            return self.error(e.into_response(Severity::Error)).await;
1053        }
1054        let portal = self
1055            .adapter_client
1056            .session()
1057            .get_portal_unverified_mut(EMPTY_PORTAL)
1058            .expect("unnamed portal should be present");
1059
1060        *portal.lifecycle_timestamps = Some(lifecycle_timestamps);
1061
1062        let stmt_desc = portal.desc.clone();
1063        if !stmt_desc.param_types.is_empty() {
1064            return self
1065                .error(ErrorResponse::error(
1066                    SqlState::UNDEFINED_PARAMETER,
1067                    "there is no parameter $1",
1068                ))
1069                .await;
1070        }
1071
1072        // Maybe send row description.
1073        if let Some(relation_desc) = &stmt_desc.relation_desc {
1074            if !stmt_desc.is_copy {
1075                let formats = vec![Format::Text; stmt_desc.arity()];
1076                self.send(BackendMessage::RowDescription(
1077                    message::encode_row_description(relation_desc, &formats),
1078                ))
1079                .await?;
1080            }
1081        }
1082
1083        let result = match self
1084            .adapter_client
1085            .execute(EMPTY_PORTAL.to_string(), self.conn.wait_closed(), None)
1086            .await
1087        {
1088            Ok((response, execute_started)) => {
1089                self.send_pending_notices().await?;
1090                self.send_execute_response(
1091                    response,
1092                    stmt_desc.relation_desc,
1093                    EMPTY_PORTAL.to_string(),
1094                    ExecuteCount::All,
1095                    portal_exec_message,
1096                    None,
1097                    ExecuteTimeout::None,
1098                    execute_started,
1099                )
1100                .await
1101            }
1102            Err(e) => {
1103                self.send_pending_notices().await?;
1104                self.error(e.into_response(Severity::Error)).await
1105            }
1106        };
1107
1108        // Destroy the portal.
1109        self.adapter_client.session().remove_portal(EMPTY_PORTAL);
1110
1111        result
1112    }
1113
1114    async fn ensure_transaction(
1115        &mut self,
1116        num_stmts: usize,
1117        message_type: &str,
1118    ) -> Result<(), io::Error> {
1119        let start = Instant::now();
1120        if self.txn_needs_commit {
1121            self.commit_transaction().await?;
1122        }
1123        // start_transaction can't error (but assert that just in case it changes in
1124        // the future.
1125        let res = self.adapter_client.start_transaction(Some(num_stmts));
1126        assert_ok!(res);
1127        self.adapter_client
1128            .inner()
1129            .metrics()
1130            .pgwire_ensure_transaction_seconds
1131            .with_label_values(&[message_type])
1132            .observe(start.elapsed().as_secs_f64());
1133        Ok(())
1134    }
1135
1136    fn parse_sql<'b>(&self, sql: &'b str) -> Result<Vec<StatementParseResult<'b>>, ErrorResponse> {
1137        let parse_start = Instant::now();
1138        let result = match self.adapter_client.parse(sql) {
1139            Ok(result) => result.map_err(|e| {
1140                // Convert our 0-based byte position to pgwire's 1-based character
1141                // position.
1142                let pos = sql[..e.error.pos].chars().count() + 1;
1143                ErrorResponse::error(SqlState::SYNTAX_ERROR, e.error.message).with_position(pos)
1144            }),
1145            Err(msg) => Err(ErrorResponse::error(SqlState::PROGRAM_LIMIT_EXCEEDED, msg)),
1146        };
1147        self.adapter_client
1148            .inner()
1149            .metrics()
1150            .parse_seconds
1151            .observe(parse_start.elapsed().as_secs_f64());
1152        result
1153    }
1154
1155    /// Executes a "Simple Query", see
1156    /// <https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-SIMPLE-QUERY>
1157    ///
1158    /// For implicit transaction handling, see "Multiple Statements in a Simple Query" in the above.
1159    #[instrument(level = "debug")]
1160    async fn query(&mut self, sql: String, received: EpochMillis) -> Result<State, io::Error> {
1161        // Parse first before doing any transaction checking.
1162        let stmts = match self.parse_sql(&sql) {
1163            Ok(stmts) => stmts,
1164            Err(err) => {
1165                self.error(err).await?;
1166                return self.ready().await;
1167            }
1168        };
1169
1170        let num_stmts = stmts.len();
1171
1172        // Compare with postgres' backend/tcop/postgres.c exec_simple_query.
1173        for StatementParseResult { ast: stmt, sql } in stmts {
1174            // In an aborted transaction, reject all commands except COMMIT/ROLLBACK.
1175            if self.is_aborted_txn() && !is_txn_exit_stmt(Some(&stmt)) {
1176                self.aborted_txn_error().await?;
1177                break;
1178            }
1179
1180            // Start an implicit transaction if we aren't in any transaction and there's
1181            // more than one statement. This mirrors the `use_implicit_block` variable in
1182            // postgres.
1183            //
1184            // This needs to be done in the loop instead of once at the top because
1185            // a COMMIT/ROLLBACK statement needs to start a new transaction on next
1186            // statement.
1187            self.ensure_transaction(num_stmts, "query").await?;
1188
1189            match self
1190                .one_query(stmt, sql.to_string(), LifecycleTimestamps { received })
1191                .await?
1192            {
1193                State::Ready => (),
1194                State::Drain => break,
1195                State::Done => return Ok(State::Done),
1196            }
1197        }
1198
1199        // Implicit transactions are closed at the end of a Query message.
1200        {
1201            if self.adapter_client.session().transaction().is_implicit() {
1202                self.commit_transaction().await?;
1203            }
1204        }
1205
1206        if num_stmts == 0 {
1207            self.send(BackendMessage::EmptyQueryResponse).await?;
1208        }
1209
1210        self.ready().await
1211    }
1212
1213    #[instrument(level = "debug")]
1214    async fn parse(
1215        &mut self,
1216        name: String,
1217        sql: String,
1218        param_oids: Vec<u32>,
1219    ) -> Result<State, io::Error> {
1220        // Start a transaction if we aren't in one.
1221        self.ensure_transaction(1, "parse").await?;
1222
1223        let mut param_types = vec![];
1224        for oid in param_oids {
1225            match mz_pgrepr::Type::from_oid(oid) {
1226                Ok(ty) => match SqlScalarType::try_from(&ty) {
1227                    Ok(ty) => param_types.push(Some(ty)),
1228                    Err(err) => {
1229                        return self
1230                            .error(ErrorResponse::error(
1231                                SqlState::INVALID_PARAMETER_VALUE,
1232                                err.to_string(),
1233                            ))
1234                            .await;
1235                    }
1236                },
1237                Err(_) if oid == 0 => param_types.push(None),
1238                Err(e) => {
1239                    return self
1240                        .error(ErrorResponse::error(
1241                            SqlState::PROTOCOL_VIOLATION,
1242                            e.to_string(),
1243                        ))
1244                        .await;
1245                }
1246            }
1247        }
1248
1249        let stmts = match self.parse_sql(&sql) {
1250            Ok(stmts) => stmts,
1251            Err(err) => {
1252                return self.error(err).await;
1253            }
1254        };
1255        if stmts.len() > 1 {
1256            return self
1257                .error(ErrorResponse::error(
1258                    SqlState::INTERNAL_ERROR,
1259                    "cannot insert multiple commands into a prepared statement",
1260                ))
1261                .await;
1262        }
1263        let (maybe_stmt, sql) = match stmts.into_iter().next() {
1264            None => (None, ""),
1265            Some(StatementParseResult { ast, sql }) => (Some(ast), sql),
1266        };
1267        if self.is_aborted_txn() && !is_txn_exit_stmt(maybe_stmt.as_ref()) {
1268            return self.aborted_txn_error().await;
1269        }
1270        match self
1271            .adapter_client
1272            .prepare(name, maybe_stmt, sql.to_string(), param_types)
1273            .await
1274        {
1275            Ok(()) => {
1276                self.send(BackendMessage::ParseComplete).await?;
1277                Ok(State::Ready)
1278            }
1279            Err(e) => self.error(e.into_response(Severity::Error)).await,
1280        }
1281    }
1282
1283    /// Commits and clears the current transaction.
1284    #[instrument(level = "debug")]
1285    async fn commit_transaction(&mut self) -> Result<(), io::Error> {
1286        self.end_transaction(EndTransactionAction::Commit).await
1287    }
1288
1289    /// Rollback and clears the current transaction.
1290    #[instrument(level = "debug")]
1291    async fn rollback_transaction(&mut self) -> Result<(), io::Error> {
1292        self.end_transaction(EndTransactionAction::Rollback).await
1293    }
1294
1295    /// End a transaction and report to the user if an error occurred.
1296    #[instrument(level = "debug")]
1297    async fn end_transaction(&mut self, action: EndTransactionAction) -> Result<(), io::Error> {
1298        self.txn_needs_commit = false;
1299        let resp = self.adapter_client.end_transaction(action).await;
1300        if let Err(err) = resp {
1301            self.send(BackendMessage::ErrorResponse(
1302                err.into_response(Severity::Error),
1303            ))
1304            .await?;
1305        }
1306        Ok(())
1307    }
1308
1309    #[instrument(level = "debug")]
1310    async fn bind(
1311        &mut self,
1312        portal_name: String,
1313        statement_name: String,
1314        param_formats: Vec<Format>,
1315        raw_params: Vec<Option<Vec<u8>>>,
1316        result_formats: Vec<Format>,
1317    ) -> Result<State, io::Error> {
1318        // Start a transaction if we aren't in one.
1319        self.ensure_transaction(1, "bind").await?;
1320
1321        let aborted_txn = self.is_aborted_txn();
1322        let stmt = match self
1323            .adapter_client
1324            .get_prepared_statement(&statement_name)
1325            .await
1326        {
1327            Ok(stmt) => stmt,
1328            Err(err) => return self.error(err.into_response(Severity::Error)).await,
1329        };
1330
1331        let param_types = &stmt.desc().param_types;
1332        if param_types.len() != raw_params.len() {
1333            let message = format!(
1334                "bind message supplies {actual} parameters, \
1335                 but prepared statement \"{name}\" requires {expected}",
1336                name = statement_name,
1337                actual = raw_params.len(),
1338                expected = param_types.len()
1339            );
1340            return self
1341                .error(ErrorResponse::error(SqlState::PROTOCOL_VIOLATION, message))
1342                .await;
1343        }
1344        let param_formats = match pad_formats(param_formats, raw_params.len()) {
1345            Ok(param_formats) => param_formats,
1346            Err(msg) => {
1347                return self
1348                    .error(ErrorResponse::error(SqlState::PROTOCOL_VIOLATION, msg))
1349                    .await;
1350            }
1351        };
1352        if aborted_txn && !is_txn_exit_stmt(stmt.stmt()) {
1353            return self.aborted_txn_error().await;
1354        }
1355        let buf = RowArena::new();
1356        let mut params = vec![];
1357        for ((raw_param, mz_typ), format) in raw_params
1358            .into_iter()
1359            .zip_eq(param_types)
1360            .zip_eq(param_formats)
1361        {
1362            let pg_typ = mz_pgrepr::Type::from(mz_typ);
1363            let datum = match raw_param {
1364                None => Datum::Null,
1365                Some(bytes) => match mz_pgrepr::Value::decode(format, &pg_typ, &bytes) {
1366                    Ok(param) => param.into_datum(&buf, &pg_typ),
1367                    Err(err) => {
1368                        let msg = format!("unable to decode parameter: {}", err);
1369                        return self
1370                            .error(ErrorResponse::error(SqlState::INVALID_PARAMETER_VALUE, msg))
1371                            .await;
1372                    }
1373                },
1374            };
1375            params.push((datum, mz_typ.clone()))
1376        }
1377
1378        let result_formats = match pad_formats(
1379            result_formats,
1380            stmt.desc()
1381                .relation_desc
1382                .clone()
1383                .map(|desc| desc.typ().column_types.len())
1384                .unwrap_or(0),
1385        ) {
1386            Ok(result_formats) => result_formats,
1387            Err(msg) => {
1388                return self
1389                    .error(ErrorResponse::error(SqlState::PROTOCOL_VIOLATION, msg))
1390                    .await;
1391            }
1392        };
1393
1394        // Binary encodings are disabled for list, map, and aclitem types, but this doesn't
1395        // apply to COPY TO statements.
1396        if !stmt.stmt().map_or(false, |stmt| {
1397            matches!(
1398                stmt,
1399                Statement::Copy(CopyStatement {
1400                    direction: CopyDirection::To,
1401                    ..
1402                })
1403            )
1404        }) {
1405            if let Some(desc) = stmt.desc().relation_desc.clone() {
1406                for (format, ty) in result_formats.iter().zip_eq(desc.iter_types()) {
1407                    match (format, &ty.scalar_type) {
1408                        (Format::Binary, mz_repr::SqlScalarType::List { .. }) => {
1409                            return self
1410                                .error(ErrorResponse::error(
1411                                    SqlState::PROTOCOL_VIOLATION,
1412                                    "binary encoding of list types is not implemented",
1413                                ))
1414                                .await;
1415                        }
1416                        (Format::Binary, mz_repr::SqlScalarType::Map { .. }) => {
1417                            return self
1418                                .error(ErrorResponse::error(
1419                                    SqlState::PROTOCOL_VIOLATION,
1420                                    "binary encoding of map types is not implemented",
1421                                ))
1422                                .await;
1423                        }
1424                        (Format::Binary, mz_repr::SqlScalarType::AclItem) => {
1425                            return self
1426                                .error(ErrorResponse::error(
1427                                    SqlState::PROTOCOL_VIOLATION,
1428                                    "binary encoding of aclitem types does not exist",
1429                                ))
1430                                .await;
1431                        }
1432                        _ => (),
1433                    }
1434                }
1435            }
1436        }
1437
1438        let desc = stmt.desc().clone();
1439        let logging = Arc::clone(stmt.logging());
1440        let stmt_ast = stmt.stmt().cloned();
1441        let state_revision = stmt.state_revision;
1442        if let Err(err) = self.adapter_client.session().set_portal(
1443            portal_name,
1444            desc,
1445            stmt_ast,
1446            logging,
1447            params,
1448            result_formats,
1449            state_revision,
1450        ) {
1451            return self.error(err.into_response(Severity::Error)).await;
1452        }
1453
1454        self.send(BackendMessage::BindComplete).await?;
1455        Ok(State::Ready)
1456    }
1457
1458    /// `outer_ctx_extra` is Some when we are executing as part of an outer statement, e.g., a FETCH
1459    /// triggering the execution of the underlying query.
1460    fn execute(
1461        &mut self,
1462        portal_name: String,
1463        max_rows: ExecuteCount,
1464        get_response: GetResponse,
1465        fetch_portal_name: Option<String>,
1466        timeout: ExecuteTimeout,
1467        outer_ctx_extra: Option<ExecuteContextGuard>,
1468        received: Option<EpochMillis>,
1469    ) -> BoxFuture<'_, Result<State, io::Error>> {
1470        async move {
1471            let aborted_txn = self.is_aborted_txn();
1472
1473            // Check if the portal has been started and can be continued.
1474            let portal = match self
1475                .adapter_client
1476                .session()
1477                .get_portal_unverified_mut(&portal_name)
1478            {
1479                Some(portal) => portal,
1480                None => {
1481                    let msg = format!("portal {} does not exist", portal_name.quoted());
1482                    if let Some(outer_ctx_extra) = outer_ctx_extra {
1483                        self.adapter_client.retire_execute(
1484                            outer_ctx_extra,
1485                            StatementEndedExecutionReason::Errored { error: msg.clone() },
1486                        );
1487                    }
1488                    return self
1489                        .error(ErrorResponse::error(SqlState::INVALID_CURSOR_NAME, msg))
1490                        .await;
1491                }
1492            };
1493
1494            *portal.lifecycle_timestamps = received.map(LifecycleTimestamps::new);
1495
1496            // In an aborted transaction, reject all commands except COMMIT/ROLLBACK.
1497            let txn_exit_stmt = is_txn_exit_stmt(portal.stmt.as_deref());
1498            if aborted_txn && !txn_exit_stmt {
1499                if let Some(outer_ctx_extra) = outer_ctx_extra {
1500                    self.adapter_client.retire_execute(
1501                        outer_ctx_extra,
1502                        StatementEndedExecutionReason::Errored {
1503                            error: ABORTED_TXN_MSG.to_string(),
1504                        },
1505                    );
1506                }
1507                return self.aborted_txn_error().await;
1508            }
1509
1510            let row_desc = portal.desc.relation_desc.clone();
1511            match portal.state {
1512                PortalState::NotStarted => {
1513                    // Start a transaction if we aren't in one.
1514                    self.ensure_transaction(1, "execute").await?;
1515                    match self
1516                        .adapter_client
1517                        .execute(
1518                            portal_name.clone(),
1519                            self.conn.wait_closed(),
1520                            outer_ctx_extra,
1521                        )
1522                        .await
1523                    {
1524                        Ok((response, execute_started)) => {
1525                            self.send_pending_notices().await?;
1526                            self.send_execute_response(
1527                                response,
1528                                row_desc,
1529                                portal_name,
1530                                max_rows,
1531                                get_response,
1532                                fetch_portal_name,
1533                                timeout,
1534                                execute_started,
1535                            )
1536                            .await
1537                        }
1538                        Err(e) => {
1539                            self.send_pending_notices().await?;
1540                            self.error(e.into_response(Severity::Error)).await
1541                        }
1542                    }
1543                }
1544                PortalState::InProgress(rows) => {
1545                    let rows = rows.take().expect("InProgress rows must be populated");
1546                    let (result, statement_ended_execution_reason) = match self
1547                        .send_rows(
1548                            row_desc.expect("portal missing row desc on resumption"),
1549                            portal_name,
1550                            rows,
1551                            max_rows,
1552                            get_response,
1553                            fetch_portal_name,
1554                            timeout,
1555                        )
1556                        .await
1557                    {
1558                        Err(e) => {
1559                            // This is an error communicating with the connection.
1560                            // We consider that to be a cancelation, rather than a query error.
1561                            (Err(e), StatementEndedExecutionReason::Canceled)
1562                        }
1563                        Ok((ok, SendRowsEndedReason::Canceled)) => {
1564                            (Ok(ok), StatementEndedExecutionReason::Canceled)
1565                        }
1566                        // NOTE: For now the values for `result_size` and
1567                        // `rows_returned` in fetches are a bit confusing.
1568                        // We record `Some(n)` for the first fetch, where `n` is
1569                        // the number of bytes/rows returned by the inner
1570                        // execute (regardless of how many rows the
1571                        // fetch fetched), and `None` for subsequent fetches.
1572                        //
1573                        // This arguably makes sense since the size/rows
1574                        // returned measures how much work the compute
1575                        // layer had to do to satisfy the query, but
1576                        // we should revisit it if/when we start
1577                        // logging the inner execute separately.
1578                        Ok((
1579                            ok,
1580                            SendRowsEndedReason::Success {
1581                                result_size: _,
1582                                rows_returned: _,
1583                            },
1584                        )) => (
1585                            Ok(ok),
1586                            StatementEndedExecutionReason::Success {
1587                                result_size: None,
1588                                rows_returned: None,
1589                                execution_strategy: None,
1590                            },
1591                        ),
1592                        Ok((ok, SendRowsEndedReason::Errored { error })) => {
1593                            (Ok(ok), StatementEndedExecutionReason::Errored { error })
1594                        }
1595                    };
1596                    if let Some(outer_ctx_extra) = outer_ctx_extra {
1597                        self.adapter_client
1598                            .retire_execute(outer_ctx_extra, statement_ended_execution_reason);
1599                    }
1600                    result
1601                }
1602                // FETCH is an awkward command for our current architecture. In Postgres it
1603                // will extract <count> rows from the target portal, cache them, and return
1604                // them to the user as requested. Its command tag is always FETCH <num rows
1605                // extracted>. In Materialize, since we have chosen to not fully support FETCH,
1606                // we must remember the number of rows that were returned. Use this tag to
1607                // remember that information and return it.
1608                PortalState::Completed(Some(tag)) => {
1609                    let tag = tag.to_string();
1610                    if let Some(outer_ctx_extra) = outer_ctx_extra {
1611                        self.adapter_client.retire_execute(
1612                            outer_ctx_extra,
1613                            StatementEndedExecutionReason::Success {
1614                                result_size: None,
1615                                rows_returned: None,
1616                                execution_strategy: None,
1617                            },
1618                        );
1619                    }
1620                    self.send(BackendMessage::CommandComplete { tag }).await?;
1621                    Ok(State::Ready)
1622                }
1623                PortalState::Completed(None) => {
1624                    let error = format!(
1625                        "portal {} cannot be run",
1626                        Ident::new_unchecked(portal_name).to_ast_string_stable()
1627                    );
1628                    if let Some(outer_ctx_extra) = outer_ctx_extra {
1629                        self.adapter_client.retire_execute(
1630                            outer_ctx_extra,
1631                            StatementEndedExecutionReason::Errored {
1632                                error: error.clone(),
1633                            },
1634                        );
1635                    }
1636                    self.error(ErrorResponse::error(
1637                        SqlState::OBJECT_NOT_IN_PREREQUISITE_STATE,
1638                        error,
1639                    ))
1640                    .await
1641                }
1642            }
1643        }
1644        .instrument(debug_span!("execute"))
1645        .boxed()
1646    }
1647
1648    #[instrument(level = "debug")]
1649    async fn describe_statement(&mut self, name: &str) -> Result<State, io::Error> {
1650        // Start a transaction if we aren't in one.
1651        self.ensure_transaction(1, "describe_statement").await?;
1652
1653        let stmt = match self.adapter_client.get_prepared_statement(name).await {
1654            Ok(stmt) => stmt,
1655            Err(err) => return self.error(err.into_response(Severity::Error)).await,
1656        };
1657        // Cloning to avoid a mutable borrow issue because `send` also uses `adapter_client`
1658        let parameter_desc = BackendMessage::ParameterDescription(
1659            stmt.desc()
1660                .param_types
1661                .iter()
1662                .map(mz_pgrepr::Type::from)
1663                .collect(),
1664        );
1665        // Claim that all results will be output in text format, even
1666        // though the true result formats are not yet known. A bit
1667        // weird, but this is the behavior that PostgreSQL specifies.
1668        let formats = vec![Format::Text; stmt.desc().arity()];
1669        let row_desc = describe_rows(stmt.desc(), &formats);
1670        self.send_all([parameter_desc, row_desc]).await?;
1671        Ok(State::Ready)
1672    }
1673
1674    #[instrument(level = "debug")]
1675    async fn describe_portal(&mut self, name: &str) -> Result<State, io::Error> {
1676        // Start a transaction if we aren't in one.
1677        self.ensure_transaction(1, "describe_portal").await?;
1678
1679        let session = self.adapter_client.session();
1680        let row_desc = session
1681            .get_portal_unverified(name)
1682            .map(|portal| describe_rows(&portal.desc, &portal.result_formats));
1683        match row_desc {
1684            Some(row_desc) => {
1685                self.send(row_desc).await?;
1686                Ok(State::Ready)
1687            }
1688            None => {
1689                self.error(ErrorResponse::error(
1690                    SqlState::INVALID_CURSOR_NAME,
1691                    format!("portal {} does not exist", name.quoted()),
1692                ))
1693                .await
1694            }
1695        }
1696    }
1697
1698    #[instrument(level = "debug")]
1699    async fn close_statement(&mut self, name: String) -> Result<State, io::Error> {
1700        self.adapter_client
1701            .session()
1702            .remove_prepared_statement(&name);
1703        self.send(BackendMessage::CloseComplete).await?;
1704        Ok(State::Ready)
1705    }
1706
1707    #[instrument(level = "debug")]
1708    async fn close_portal(&mut self, name: String) -> Result<State, io::Error> {
1709        self.adapter_client.session().remove_portal(&name);
1710        self.send(BackendMessage::CloseComplete).await?;
1711        Ok(State::Ready)
1712    }
1713
1714    fn complete_portal(&mut self, name: &str) {
1715        let portal = self
1716            .adapter_client
1717            .session()
1718            .get_portal_unverified_mut(name)
1719            .expect("portal should exist");
1720        *portal.state = PortalState::Completed(None);
1721    }
1722
1723    async fn fetch(
1724        &mut self,
1725        name: String,
1726        count: Option<FetchDirection>,
1727        max_rows: ExecuteCount,
1728        fetch_portal_name: Option<String>,
1729        timeout: ExecuteTimeout,
1730        ctx_extra: ExecuteContextGuard,
1731    ) -> Result<State, io::Error> {
1732        // Unlike Execute, no count specified in FETCH returns 1 row, and 0 means 0
1733        // instead of All.
1734        let count = count.unwrap_or(FetchDirection::ForwardCount(1));
1735
1736        // Figure out how many rows we should send back by looking at the various
1737        // combinations of the execute and fetch.
1738        //
1739        // In Postgres, Fetch will cache <count> rows from the target portal and
1740        // return those as requested (if, say, an Execute message was sent with a
1741        // max_rows < the Fetch's count). We expect that case to be incredibly rare and
1742        // so have chosen to not support it until users request it. This eases
1743        // implementation difficulty since we don't have to be able to "send" rows to
1744        // a buffer.
1745        //
1746        // TODO(mjibson): Test this somehow? Need to divide up the pgtest files in
1747        // order to have some that are not Postgres compatible.
1748        let count = match (max_rows, count) {
1749            (ExecuteCount::Count(max_rows), FetchDirection::ForwardCount(count)) => {
1750                let count = usize::cast_from(count);
1751                if max_rows < count {
1752                    let msg = "Execute with max_rows < a FETCH's count is not supported";
1753                    self.adapter_client.retire_execute(
1754                        ctx_extra,
1755                        StatementEndedExecutionReason::Errored {
1756                            error: msg.to_string(),
1757                        },
1758                    );
1759                    return self
1760                        .error(ErrorResponse::error(SqlState::FEATURE_NOT_SUPPORTED, msg))
1761                        .await;
1762                }
1763                ExecuteCount::Count(count)
1764            }
1765            (ExecuteCount::Count(_), FetchDirection::ForwardAll) => {
1766                let msg = "Execute with max_rows of a FETCH ALL is not supported";
1767                self.adapter_client.retire_execute(
1768                    ctx_extra,
1769                    StatementEndedExecutionReason::Errored {
1770                        error: msg.to_string(),
1771                    },
1772                );
1773                return self
1774                    .error(ErrorResponse::error(SqlState::FEATURE_NOT_SUPPORTED, msg))
1775                    .await;
1776            }
1777            (ExecuteCount::All, FetchDirection::ForwardAll) => ExecuteCount::All,
1778            (ExecuteCount::All, FetchDirection::ForwardCount(count)) => {
1779                ExecuteCount::Count(usize::cast_from(count))
1780            }
1781        };
1782        let cursor_name = name.to_string();
1783        self.execute(
1784            cursor_name,
1785            count,
1786            fetch_message,
1787            fetch_portal_name,
1788            timeout,
1789            Some(ctx_extra),
1790            None,
1791        )
1792        .await
1793    }
1794
1795    async fn flush(&mut self) -> Result<State, io::Error> {
1796        self.conn.flush().await?;
1797        Ok(State::Ready)
1798    }
1799
1800    /// Sends a backend message to the client, after applying a severity filter.
1801    ///
1802    /// The message is only sent if its severity is above the severity set
1803    /// in the session, with the default value being NOTICE.
1804    #[instrument(level = "debug")]
1805    async fn send<M>(&mut self, message: M) -> Result<(), io::Error>
1806    where
1807        M: Into<BackendMessage>,
1808    {
1809        let message: BackendMessage = message.into();
1810        let is_error =
1811            matches!(&message, BackendMessage::ErrorResponse(e) if e.severity.is_error());
1812
1813        self.conn.send(message).await?;
1814
1815        // Flush immediately after sending an error response, as some clients
1816        // expect to be able to read the error response before sending a Sync
1817        // message. This is arguably in violation of the protocol specification,
1818        // but the specification is somewhat ambiguous, and easier to match
1819        // PostgreSQL here than to fix all the clients that have this
1820        // expectation.
1821        if is_error {
1822            self.conn.flush().await?;
1823        }
1824
1825        Ok(())
1826    }
1827
1828    #[instrument(level = "debug")]
1829    pub async fn send_all(
1830        &mut self,
1831        messages: impl IntoIterator<Item = BackendMessage>,
1832    ) -> Result<(), io::Error> {
1833        for m in messages {
1834            self.send(m).await?;
1835        }
1836        Ok(())
1837    }
1838
1839    #[instrument(level = "debug")]
1840    async fn sync(&mut self) -> Result<State, io::Error> {
1841        // Close the current transaction if we are in an implicit transaction.
1842        if self.adapter_client.session().transaction().is_implicit() {
1843            self.commit_transaction().await?;
1844        }
1845        self.ready().await
1846    }
1847
1848    #[instrument(level = "debug")]
1849    async fn ready(&mut self) -> Result<State, io::Error> {
1850        let txn_state = self.adapter_client.session().transaction().into();
1851        self.send(BackendMessage::ReadyForQuery(txn_state)).await?;
1852        self.flush().await
1853    }
1854
1855    #[allow(clippy::too_many_arguments)]
1856    #[instrument(level = "debug")]
1857    async fn send_execute_response(
1858        &mut self,
1859        response: ExecuteResponse,
1860        row_desc: Option<RelationDesc>,
1861        portal_name: String,
1862        max_rows: ExecuteCount,
1863        get_response: GetResponse,
1864        fetch_portal_name: Option<String>,
1865        timeout: ExecuteTimeout,
1866        execute_started: Instant,
1867    ) -> Result<State, io::Error> {
1868        let mut tag = response.tag();
1869
1870        macro_rules! command_complete {
1871            () => {{
1872                self.send(BackendMessage::CommandComplete {
1873                    tag: tag
1874                        .take()
1875                        .expect("command_complete only called on tag-generating results"),
1876                })
1877                .await?;
1878                Ok(State::Ready)
1879            }};
1880        }
1881
1882        let r = match response {
1883            ExecuteResponse::ClosedCursor => {
1884                self.complete_portal(&portal_name);
1885                command_complete!()
1886            }
1887            ExecuteResponse::DeclaredCursor => {
1888                self.complete_portal(&portal_name);
1889                command_complete!()
1890            }
1891            ExecuteResponse::EmptyQuery => {
1892                self.send(BackendMessage::EmptyQueryResponse).await?;
1893                Ok(State::Ready)
1894            }
1895            ExecuteResponse::Fetch {
1896                name,
1897                count,
1898                timeout,
1899                ctx_extra,
1900            } => {
1901                self.fetch(
1902                    name,
1903                    count,
1904                    max_rows,
1905                    Some(portal_name.to_string()),
1906                    timeout,
1907                    ctx_extra,
1908                )
1909                .await
1910            }
1911            ExecuteResponse::SendingRowsStreaming {
1912                rows,
1913                instance_id,
1914                strategy,
1915            } => {
1916                let row_desc = row_desc
1917                    .expect("missing row description for ExecuteResponse::SendingRowsStreaming");
1918
1919                let span = tracing::debug_span!("sending_rows_streaming");
1920
1921                self.send_rows(
1922                    row_desc,
1923                    portal_name,
1924                    InProgressRows::new(RecordFirstRowStream::new(
1925                        Box::new(rows),
1926                        execute_started,
1927                        &self.adapter_client,
1928                        Some(instance_id),
1929                        Some(strategy),
1930                    )),
1931                    max_rows,
1932                    get_response,
1933                    fetch_portal_name,
1934                    timeout,
1935                )
1936                .instrument(span)
1937                .await
1938                .map(|(state, _)| state)
1939            }
1940            ExecuteResponse::SendingRowsImmediate { rows } => {
1941                let row_desc = row_desc
1942                    .expect("missing row description for ExecuteResponse::SendingRowsImmediate");
1943
1944                let span = tracing::debug_span!("sending_rows_immediate");
1945
1946                let stream =
1947                    futures::stream::once(futures::future::ready(PeekResponseUnary::Rows(rows)));
1948                self.send_rows(
1949                    row_desc,
1950                    portal_name,
1951                    InProgressRows::new(RecordFirstRowStream::new(
1952                        Box::new(stream),
1953                        execute_started,
1954                        &self.adapter_client,
1955                        None,
1956                        Some(StatementExecutionStrategy::Constant),
1957                    )),
1958                    max_rows,
1959                    get_response,
1960                    fetch_portal_name,
1961                    timeout,
1962                )
1963                .instrument(span)
1964                .await
1965                .map(|(state, _)| state)
1966            }
1967            ExecuteResponse::SetVariable { name, .. } => {
1968                // This code is somewhat awkwardly structured because we
1969                // can't hold `var` across an await point.
1970                let qn = name.to_string();
1971                let msg = if let Some(var) = self
1972                    .adapter_client
1973                    .session()
1974                    .vars_mut()
1975                    .notify_set()
1976                    .find(|v| v.name() == qn)
1977                {
1978                    Some(BackendMessage::ParameterStatus(var.name(), var.value()))
1979                } else {
1980                    None
1981                };
1982                if let Some(msg) = msg {
1983                    self.send(msg).await?;
1984                }
1985                command_complete!()
1986            }
1987            ExecuteResponse::Subscribing {
1988                rx,
1989                ctx_extra,
1990                instance_id,
1991            } => {
1992                if fetch_portal_name.is_none() {
1993                    let mut msg = ErrorResponse::notice(
1994                        SqlState::WARNING,
1995                        "streaming SUBSCRIBE rows directly requires a client that does not buffer output",
1996                    );
1997                    if self.adapter_client.session().vars().application_name() == "psql" {
1998                        msg.hint = Some(
1999                            "Wrap your SUBSCRIBE statement in `COPY (SUBSCRIBE ...) TO STDOUT`."
2000                                .into(),
2001                        )
2002                    }
2003                    self.send(msg).await?;
2004                    self.conn.flush().await?;
2005                }
2006                let row_desc =
2007                    row_desc.expect("missing row description for ExecuteResponse::Subscribing");
2008                let (result, statement_ended_execution_reason) = match self
2009                    .send_rows(
2010                        row_desc,
2011                        portal_name,
2012                        InProgressRows::new(RecordFirstRowStream::new(
2013                            Box::new(UnboundedReceiverStream::new(rx)),
2014                            execute_started,
2015                            &self.adapter_client,
2016                            Some(instance_id),
2017                            None,
2018                        )),
2019                        max_rows,
2020                        get_response,
2021                        fetch_portal_name,
2022                        timeout,
2023                    )
2024                    .await
2025                {
2026                    Err(e) => {
2027                        // This is an error communicating with the connection.
2028                        // We consider that to be a cancelation, rather than a query error.
2029                        (Err(e), StatementEndedExecutionReason::Canceled)
2030                    }
2031                    Ok((ok, SendRowsEndedReason::Canceled)) => {
2032                        (Ok(ok), StatementEndedExecutionReason::Canceled)
2033                    }
2034                    Ok((
2035                        ok,
2036                        SendRowsEndedReason::Success {
2037                            result_size,
2038                            rows_returned,
2039                        },
2040                    )) => (
2041                        Ok(ok),
2042                        StatementEndedExecutionReason::Success {
2043                            result_size: Some(result_size),
2044                            rows_returned: Some(rows_returned),
2045                            execution_strategy: None,
2046                        },
2047                    ),
2048                    Ok((ok, SendRowsEndedReason::Errored { error })) => {
2049                        (Ok(ok), StatementEndedExecutionReason::Errored { error })
2050                    }
2051                };
2052                self.adapter_client
2053                    .retire_execute(ctx_extra, statement_ended_execution_reason);
2054                return result;
2055            }
2056            ExecuteResponse::CopyTo { format, resp } => {
2057                let row_desc =
2058                    row_desc.expect("missing row description for ExecuteResponse::CopyTo");
2059                match *resp {
2060                    ExecuteResponse::Subscribing {
2061                        rx,
2062                        ctx_extra,
2063                        instance_id,
2064                    } => {
2065                        let (result, statement_ended_execution_reason) = match self
2066                            .copy_rows(
2067                                format,
2068                                row_desc,
2069                                RecordFirstRowStream::new(
2070                                    Box::new(UnboundedReceiverStream::new(rx)),
2071                                    execute_started,
2072                                    &self.adapter_client,
2073                                    Some(instance_id),
2074                                    None,
2075                                ),
2076                            )
2077                            .await
2078                        {
2079                            Err(e) => {
2080                                // This is an error communicating with the connection.
2081                                // We consider that to be a cancelation, rather than a query error.
2082                                (Err(e), StatementEndedExecutionReason::Canceled)
2083                            }
2084                            Ok((
2085                                state,
2086                                SendRowsEndedReason::Success {
2087                                    result_size,
2088                                    rows_returned,
2089                                },
2090                            )) => (
2091                                Ok(state),
2092                                StatementEndedExecutionReason::Success {
2093                                    result_size: Some(result_size),
2094                                    rows_returned: Some(rows_returned),
2095                                    execution_strategy: None,
2096                                },
2097                            ),
2098                            Ok((state, SendRowsEndedReason::Errored { error })) => {
2099                                (Ok(state), StatementEndedExecutionReason::Errored { error })
2100                            }
2101                            Ok((state, SendRowsEndedReason::Canceled)) => {
2102                                (Ok(state), StatementEndedExecutionReason::Canceled)
2103                            }
2104                        };
2105                        self.adapter_client
2106                            .retire_execute(ctx_extra, statement_ended_execution_reason);
2107                        return result;
2108                    }
2109                    ExecuteResponse::SendingRowsStreaming {
2110                        rows,
2111                        instance_id,
2112                        strategy,
2113                    } => {
2114                        // We don't need to finalize execution here;
2115                        // it was already done in the
2116                        // coordinator. Just extract the state and
2117                        // return that.
2118                        return self
2119                            .copy_rows(
2120                                format,
2121                                row_desc,
2122                                RecordFirstRowStream::new(
2123                                    Box::new(rows),
2124                                    execute_started,
2125                                    &self.adapter_client,
2126                                    Some(instance_id),
2127                                    Some(strategy),
2128                                ),
2129                            )
2130                            .await
2131                            .map(|(state, _)| state);
2132                    }
2133                    ExecuteResponse::SendingRowsImmediate { rows } => {
2134                        let span = tracing::debug_span!("sending_rows_immediate");
2135
2136                        let rows = futures::stream::once(futures::future::ready(
2137                            PeekResponseUnary::Rows(rows),
2138                        ));
2139                        // We don't need to finalize execution here;
2140                        // it was already done in the
2141                        // coordinator. Just extract the state and
2142                        // return that.
2143                        return self
2144                            .copy_rows(
2145                                format,
2146                                row_desc,
2147                                RecordFirstRowStream::new(
2148                                    Box::new(rows),
2149                                    execute_started,
2150                                    &self.adapter_client,
2151                                    None,
2152                                    Some(StatementExecutionStrategy::Constant),
2153                                ),
2154                            )
2155                            .instrument(span)
2156                            .await
2157                            .map(|(state, _)| state);
2158                    }
2159                    _ => {
2160                        return self
2161                            .error(ErrorResponse::error(
2162                                SqlState::INTERNAL_ERROR,
2163                                "unsupported COPY response type".to_string(),
2164                            ))
2165                            .await;
2166                    }
2167                };
2168            }
2169            ExecuteResponse::CopyFrom {
2170                target_id,
2171                target_name,
2172                columns,
2173                params,
2174                ctx_extra,
2175            } => {
2176                let row_desc =
2177                    row_desc.expect("missing row description for ExecuteResponse::CopyFrom");
2178                self.copy_from(target_id, target_name, columns, params, row_desc, ctx_extra)
2179                    .await
2180            }
2181            ExecuteResponse::TransactionCommitted { params }
2182            | ExecuteResponse::TransactionRolledBack { params } => {
2183                let notify_set: mz_ore::collections::HashSet<String> = self
2184                    .adapter_client
2185                    .session()
2186                    .vars()
2187                    .notify_set()
2188                    .map(|v| v.name().to_string())
2189                    .collect();
2190
2191                // Only report on parameters that are in the notify set.
2192                for (name, value) in params
2193                    .into_iter()
2194                    .filter(|(name, _v)| notify_set.contains(*name))
2195                {
2196                    let msg = BackendMessage::ParameterStatus(name, value);
2197                    self.send(msg).await?;
2198                }
2199                command_complete!()
2200            }
2201
2202            ExecuteResponse::AlteredDefaultPrivileges
2203            | ExecuteResponse::AlteredObject(..)
2204            | ExecuteResponse::AlteredRole
2205            | ExecuteResponse::AlteredSystemConfiguration
2206            | ExecuteResponse::CreatedCluster { .. }
2207            | ExecuteResponse::CreatedClusterReplica { .. }
2208            | ExecuteResponse::CreatedConnection { .. }
2209            | ExecuteResponse::CreatedDatabase { .. }
2210            | ExecuteResponse::CreatedIndex { .. }
2211            | ExecuteResponse::CreatedIntrospectionSubscribe
2212            | ExecuteResponse::CreatedMaterializedView { .. }
2213            | ExecuteResponse::CreatedContinualTask { .. }
2214            | ExecuteResponse::CreatedRole
2215            | ExecuteResponse::CreatedSchema { .. }
2216            | ExecuteResponse::CreatedSecret { .. }
2217            | ExecuteResponse::CreatedSink { .. }
2218            | ExecuteResponse::CreatedSource { .. }
2219            | ExecuteResponse::CreatedTable { .. }
2220            | ExecuteResponse::CreatedType
2221            | ExecuteResponse::CreatedView { .. }
2222            | ExecuteResponse::CreatedViews { .. }
2223            | ExecuteResponse::CreatedNetworkPolicy
2224            | ExecuteResponse::Comment
2225            | ExecuteResponse::Deallocate { .. }
2226            | ExecuteResponse::Deleted(..)
2227            | ExecuteResponse::DiscardedAll
2228            | ExecuteResponse::DiscardedTemp
2229            | ExecuteResponse::DroppedObject(_)
2230            | ExecuteResponse::DroppedOwned
2231            | ExecuteResponse::GrantedPrivilege
2232            | ExecuteResponse::GrantedRole
2233            | ExecuteResponse::Inserted(..)
2234            | ExecuteResponse::Copied(..)
2235            | ExecuteResponse::Prepare
2236            | ExecuteResponse::Raised
2237            | ExecuteResponse::ReassignOwned
2238            | ExecuteResponse::RevokedPrivilege
2239            | ExecuteResponse::RevokedRole
2240            | ExecuteResponse::StartedTransaction { .. }
2241            | ExecuteResponse::Updated(..)
2242            | ExecuteResponse::ValidatedConnection => {
2243                command_complete!()
2244            }
2245        };
2246
2247        assert_none!(tag, "tag created but not consumed: {:?}", tag);
2248        r
2249    }
2250
    /// Sends rows from `rows` to the client as `DataRow` messages, followed by
    /// the message built by `get_response`.
    ///
    /// * `row_desc` describes the rows; every batch is checked against it with
    ///   `verify_datum_desc` before being sent.
    /// * `portal_name` names the portal being executed. When the loop finishes,
    ///   `rows` (possibly exhausted) is stored back into that portal as
    ///   `PortalState::InProgress`.
    /// * `max_rows` caps the number of rows sent by this call.
    /// * `get_response` builds the final message from `max_rows`, the number of
    ///   rows sent, and the fetch portal (if any).
    /// * `fetch_portal_name`, if set, names the outer portal whose result
    ///   formats are used to encode the rows instead of those of `portal_name`.
    /// * `timeout` controls how long to wait for further batches.
    ///
    /// Returns the next protocol state together with the reason sending ended.
    /// I/O failures, including the connection closing while waiting for rows,
    /// are returned as `Err`.
    ///
    /// # Panics
    ///
    /// Panics if `portal_name` or `fetch_portal_name` does not name an existing
    /// portal. The result formats are paired with the columns via `zip_eq`,
    /// which panics if the two lengths differ.
    #[allow(clippy::too_many_arguments)]
    // TODO(guswynn): figure out how to get it to compile without skip_all
    #[mz_ore::instrument(level = "debug")]
    async fn send_rows(
        &mut self,
        row_desc: RelationDesc,
        portal_name: String,
        mut rows: InProgressRows,
        max_rows: ExecuteCount,
        get_response: GetResponse,
        fetch_portal_name: Option<String>,
        timeout: ExecuteTimeout,
    ) -> Result<(State, SendRowsEndedReason), io::Error> {
        // If this portal is being executed from a FETCH then we need to use the result
        // format type of the outer portal.
        let result_format_portal_name: &str = if let Some(ref name) = fetch_portal_name {
            name
        } else {
            &portal_name
        };
        let result_formats = self
            .adapter_client
            .session()
            .get_portal_unverified(result_format_portal_name)
            .expect("valid fetch portal name for send rows")
            .result_formats
            .clone();

        // Translate the requested timeout into an optional deadline plus the
        // `wait_once` flag that is resolved inside the loop below.
        let (mut wait_once, mut deadline) = match timeout {
            ExecuteTimeout::None => (false, None),
            ExecuteTimeout::Seconds(t) => (
                false,
                Some(tokio::time::Instant::now() + tokio::time::Duration::from_secs_f64(t)),
            ),
            ExecuteTimeout::WaitOnce => (true, None),
        };

        // Sanity check that the various `RelationDesc`s match up.
        {
            let portal_name_desc = &self
                .adapter_client
                .session()
                .get_portal_unverified(portal_name.as_str())
                .expect("portal should exist")
                .desc
                .relation_desc;
            if let Some(portal_name_desc) = portal_name_desc {
                soft_assert_eq_or_log!(portal_name_desc, &row_desc);
            }
            if let Some(fetch_portal_name) = &fetch_portal_name {
                let fetch_portal_desc = &self
                    .adapter_client
                    .session()
                    .get_portal_unverified(fetch_portal_name)
                    .expect("portal should exist")
                    .desc
                    .relation_desc;
                if let Some(fetch_portal_desc) = fetch_portal_desc {
                    soft_assert_eq_or_log!(fetch_portal_desc, &row_desc);
                }
            }
        }

        // Tell the connection how to encode each column: its pgwire type paired
        // with the result format chosen above.
        self.conn.set_encode_state(
            row_desc
                .typ()
                .column_types
                .iter()
                .map(|ty| mz_pgrepr::Type::from(&ty.scalar_type))
                .zip_eq(result_formats)
                .collect(),
        );

        let mut total_sent_rows = 0;
        let mut total_sent_bytes = 0;
        // want_rows is the maximum number of rows the client wants.
        let mut want_rows = match max_rows {
            ExecuteCount::All => usize::MAX,
            ExecuteCount::Count(count) => count,
        };

        // Send rows while the client still wants them and there are still rows to send.
        loop {
            // Fetch next batch of rows, waiting for a possible requested
            // timeout or notice.
            let batch = if rows.current.is_some() {
                // A partially consumed batch is pending; resume it first.
                FetchResult::Rows(rows.current.take())
            } else if want_rows == 0 {
                FetchResult::Rows(None)
            } else {
                let notice_fut = self.adapter_client.session().recv_notice();
                tokio::select! {
                    err = self.conn.wait_closed() => return Err(err),
                    // This branch is only enabled when a deadline is set. Hitting
                    // it ends the loop exactly like an exhausted stream does.
                    _ = time::sleep_until(
                        deadline.unwrap_or_else(tokio::time::Instant::now),
                    ), if deadline.is_some() => FetchResult::Rows(None),
                    notice = notice_fut => {
                        FetchResult::Notice(notice)
                    }
                    batch = rows.remaining.recv() => match batch {
                        None => FetchResult::Rows(None),
                        Some(PeekResponseUnary::Rows(rows)) => FetchResult::Rows(Some(rows)),
                        Some(PeekResponseUnary::Error(err)) => FetchResult::Error(err),
                        Some(PeekResponseUnary::Canceled) => FetchResult::Canceled,
                    },
                }
            };

            match batch {
                FetchResult::Rows(None) => break,
                FetchResult::Rows(Some(mut batch_rows)) => {
                    if let Err(err) = verify_datum_desc(&row_desc, &mut batch_rows) {
                        let msg = err.to_string();
                        return self
                            .error(err.into_response(Severity::Error))
                            .await
                            .map(|state| (state, SendRowsEndedReason::Errored { error: msg }));
                    }

                    // If wait_once is true: the first time this fn is called it blocks (same as
                    // deadline == None). The second time this fn is called it should behave the
                    // same as a 0s timeout.
                    if wait_once && batch_rows.peek().is_some() {
                        deadline = Some(tokio::time::Instant::now());
                        wait_once = false;
                    }

                    // Send a portion of the rows.
                    let mut sent_rows = 0;
                    let mut sent_bytes = 0;
                    let messages = (&mut batch_rows)
                        // TODO(parkmycar): This is a fair bit of juggling between iterator types
                        // to count the total number of bytes. Alternatively we could track the
                        // total sent bytes in this .map(...) call, but having side effects in map
                        // is a code smell.
                        .map(|row| {
                            let row_len = row.byte_len();
                            let values = mz_pgrepr::values_from_row(row, row_desc.typ());
                            (row_len, BackendMessage::DataRow(values))
                        })
                        .inspect(|(row_len, _)| {
                            sent_bytes += row_len;
                            sent_rows += 1
                        })
                        .map(|(_row_len, row)| row)
                        .take(want_rows);
                    self.send_all(messages).await?;

                    total_sent_rows += sent_rows;
                    total_sent_bytes += sent_bytes;
                    want_rows -= sent_rows;

                    // If we have sent the number of requested rows, put the remainder of the batch
                    // (if any) back and stop sending.
                    if want_rows == 0 {
                        if batch_rows.peek().is_some() {
                            rows.current = Some(batch_rows);
                        }
                        break;
                    }

                    self.conn.flush().await?;
                }
                FetchResult::Notice(notice) => {
                    // Forward the notice right away rather than waiting for more rows.
                    self.send(notice.into_response()).await?;
                    self.conn.flush().await?;
                }
                FetchResult::Error(text) => {
                    return self
                        .error(ErrorResponse::error(SqlState::INTERNAL_ERROR, text.clone()))
                        .await
                        .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
                }
                FetchResult::Canceled => {
                    return self
                        .error(ErrorResponse::error(
                            SqlState::QUERY_CANCELED,
                            "canceling statement due to user request",
                        ))
                        .await
                        .map(|state| (state, SendRowsEndedReason::Canceled));
                }
            }
        }

        let portal = self
            .adapter_client
            .session()
            .get_portal_unverified_mut(&portal_name)
            .expect("valid portal name for send rows");

        // Read everything still needed from `rows` now, because it is moved into
        // the portal state just below.
        let saw_rows = rows.remaining.saw_rows;
        let no_more_rows = rows.no_more_rows();
        let recorded_first_row_instant = rows.remaining.recorded_first_row_instant;

        // Always return rows back, even if it's empty. This prevents an unclosed
        // portal from re-executing after it has been emptied.
        *portal.state = PortalState::InProgress(Some(rows));

        let fetch_portal = fetch_portal_name.map(|name| {
            self.adapter_client
                .session()
                .get_portal_unverified_mut(&name)
                .expect("valid fetch portal")
        });
        let response_message = get_response(max_rows, total_sent_rows, fetch_portal);
        self.send(response_message).await?;

        // Attend to metrics if there are no more rows.
        if no_more_rows {
            let statement_type = if let Some(stmt) = &self
                .adapter_client
                .session()
                .get_portal_unverified(&portal_name)
                .expect("valid portal name for send_rows")
                .stmt
            {
                metrics::statement_type_label_value(stmt.deref())
            } else {
                "no-statement"
            };
            let duration = if saw_rows {
                recorded_first_row_instant
                    .expect("recorded_first_row_instant because saw_rows")
                    .elapsed()
            } else {
                // If the result is empty, then we define time from first to last row as 0.
                // (Note that, currently, an empty result involves a PeekResponse with 0 rows, which
                // does flip `saw_rows`, so this code path is currently not exercised.)
                Duration::ZERO
            };
            self.adapter_client
                .inner()
                .metrics()
                .result_rows_first_to_last_byte_seconds
                .with_label_values(&[statement_type])
                .observe(duration.as_secs_f64());
        }

        Ok((
            State::Ready,
            SendRowsEndedReason::Success {
                result_size: u64::cast_from(total_sent_bytes),
                rows_returned: u64::cast_from(total_sent_rows),
            },
        ))
    }
2498
2499    #[mz_ore::instrument(level = "debug")]
2500    async fn copy_rows(
2501        &mut self,
2502        format: CopyFormat,
2503        row_desc: RelationDesc,
2504        mut stream: RecordFirstRowStream,
2505    ) -> Result<(State, SendRowsEndedReason), io::Error> {
2506        let (row_format, encode_format) = match format {
2507            CopyFormat::Text => (
2508                CopyFormatParams::Text(CopyTextFormatParams::default()),
2509                Format::Text,
2510            ),
2511            CopyFormat::Binary => (CopyFormatParams::Binary, Format::Binary),
2512            CopyFormat::Csv => (
2513                CopyFormatParams::Csv(CopyCsvFormatParams::default()),
2514                Format::Text,
2515            ),
2516            CopyFormat::Parquet => {
2517                let text = "Parquet format is not supported".to_string();
2518                return self
2519                    .error(ErrorResponse::error(SqlState::INTERNAL_ERROR, text.clone()))
2520                    .await
2521                    .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
2522            }
2523        };
2524
2525        let encode_fn = |row: &RowRef, typ: &SqlRelationType, out: &mut Vec<u8>| {
2526            mz_pgcopy::encode_copy_format(&row_format, row, typ, out)
2527        };
2528
2529        let typ = row_desc.typ();
2530        let column_formats = iter::repeat(encode_format)
2531            .take(typ.column_types.len())
2532            .collect();
2533        self.send(BackendMessage::CopyOutResponse {
2534            overall_format: encode_format,
2535            column_formats,
2536        })
2537        .await?;
2538
2539        // In Postgres, binary copy has a header that is followed (in the same
2540        // CopyData) by the first row. In order to replicate their behavior, use a
2541        // common vec that we can extend one time now and then fill up with the encode
2542        // functions.
2543        let mut out = Vec::new();
2544
2545        if let CopyFormat::Binary = format {
2546            // 11-byte signature.
2547            out.extend(b"PGCOPY\n\xFF\r\n\0");
2548            // 32-bit flags field.
2549            out.extend([0, 0, 0, 0]);
2550            // 32-bit header extension length field.
2551            out.extend([0, 0, 0, 0]);
2552        }
2553
2554        let mut count = 0;
2555        let mut total_sent_bytes = 0;
2556        loop {
2557            tokio::select! {
2558                e = self.conn.wait_closed() => return Err(e),
2559                batch = stream.recv() => match batch {
2560                    None => break,
2561                    Some(PeekResponseUnary::Error(text)) => {
2562                        return self
2563                            .error(ErrorResponse::error(SqlState::INTERNAL_ERROR, text.clone()))
2564                        .await
2565                        .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
2566                    }
2567                    Some(PeekResponseUnary::Canceled) => {
2568                        return self.error(ErrorResponse::error(
2569                                SqlState::QUERY_CANCELED,
2570                                "canceling statement due to user request",
2571                            ))
2572                            .await.map(|state| (state, SendRowsEndedReason::Canceled));
2573                    }
2574                    Some(PeekResponseUnary::Rows(mut rows)) => {
2575                        count += rows.count();
2576                        while let Some(row) = rows.next() {
2577                            total_sent_bytes += row.byte_len();
2578                            encode_fn(row, typ, &mut out)?;
2579                            self.send(BackendMessage::CopyData(mem::take(&mut out)))
2580                                .await?;
2581                        }
2582                    }
2583                },
2584                notice = self.adapter_client.session().recv_notice() => {
2585                    self.send(notice.into_response())
2586                        .await?;
2587                    self.conn.flush().await?;
2588                }
2589            }
2590
2591            self.conn.flush().await?;
2592        }
2593        // Send required trailers.
2594        if let CopyFormat::Binary = format {
2595            let trailer: i16 = -1;
2596            out.extend(trailer.to_be_bytes());
2597            self.send(BackendMessage::CopyData(mem::take(&mut out)))
2598                .await?;
2599        }
2600
2601        let tag = format!("COPY {}", count);
2602        self.send(BackendMessage::CopyDone).await?;
2603        self.send(BackendMessage::CommandComplete { tag }).await?;
2604        Ok((
2605            State::Ready,
2606            SendRowsEndedReason::Success {
2607                result_size: u64::cast_from(total_sent_bytes),
2608                rows_returned: u64::cast_from(count),
2609            },
2610        ))
2611    }
2612
2613    /// Handles the copy-in mode of the postgres protocol from transferring
2614    /// data to the server.
2615    #[instrument(level = "debug")]
2616    async fn copy_from(
2617        &mut self,
2618        target_id: CatalogItemId,
2619        target_name: String,
2620        columns: Vec<ColumnIndex>,
2621        params: CopyFormatParams<'_>,
2622        row_desc: RelationDesc,
2623        mut ctx_extra: ExecuteContextGuard,
2624    ) -> Result<State, io::Error> {
2625        let res = self
2626            .copy_from_inner(
2627                target_id,
2628                target_name,
2629                columns,
2630                params,
2631                row_desc,
2632                &mut ctx_extra,
2633            )
2634            .await;
2635        match &res {
2636            Ok(State::Done) => {
2637                // The connection closed gracefully without sending us a `CopyDone`,
2638                // causing us to just drop the copy request.
2639                // For the purposes of statement logging, we count this as a cancellation.
2640                self.adapter_client
2641                    .retire_execute(ctx_extra, StatementEndedExecutionReason::Canceled);
2642            }
2643            Err(e) => {
2644                self.adapter_client.retire_execute(
2645                    ctx_extra,
2646                    StatementEndedExecutionReason::Errored {
2647                        error: format!("{e}"),
2648                    },
2649                );
2650            }
2651            other => {
2652                tracing::warn!(?other, "aborting COPY FROM");
2653                self.adapter_client
2654                    .retire_execute(ctx_extra, StatementEndedExecutionReason::Aborted);
2655            }
2656        }
2657        res
2658    }
2659
    /// Runs the COPY-in exchange on behalf of `copy_from`: asks the client for
    /// data, buffers it, decodes it, and inserts the resulting rows.
    ///
    /// The client is sent a `CopyInResponse` advertising the text format for
    /// every column. `CopyData` payloads are then accumulated until `CopyDone`
    /// arrives, after which the buffer is decoded according to `params` and the
    /// rows are passed to `insert_rows`. On success a `COPY <n>`
    /// `CommandComplete` is sent.
    ///
    /// `ctx_extra` is taken (leaving its default value behind) on the paths
    /// that retire the execution here and when it is handed to `insert_rows`.
    /// On the remaining paths it is left untouched for the caller.
    ///
    /// Returns `State::Ready` on success, `State::Done` if the connection
    /// yields no further message before `CopyDone`, and otherwise whatever
    /// `self.error` returns after reporting the failure to the client. I/O
    /// failures are returned as `Err`.
    async fn copy_from_inner(
        &mut self,
        target_id: CatalogItemId,
        target_name: String,
        columns: Vec<ColumnIndex>,
        params: CopyFormatParams<'_>,
        row_desc: RelationDesc,
        ctx_extra: &mut ExecuteContextGuard,
    ) -> Result<State, io::Error> {
        let typ = row_desc.typ();
        let column_formats = vec![Format::Text; typ.column_types.len()];
        self.send(BackendMessage::CopyInResponse {
            overall_format: Format::Text,
            column_formats,
        })
        .await?;
        self.conn.flush().await?;

        // The buffer limit comes from the `MAX_COPY_FROM_SIZE` system variable.
        // If it is missing or does not parse, the size is effectively unlimited.
        let system_vars = self.adapter_client.get_system_vars().await;
        let max_size = system_vars
            .get(MAX_COPY_FROM_SIZE.name())
            .ok()
            .and_then(|max_size| max_size.value().parse().ok())
            .unwrap_or(usize::MAX);
        tracing::debug!("COPY FROM max buffer size: {max_size} bytes");

        // Accumulate all of the client's data before decoding any of it.
        let mut data = Vec::new();
        loop {
            let message = self.conn.recv().await?;
            match message {
                Some(FrontendMessage::CopyData(buf)) => {
                    // Bail before we OOM.
                    // `ctx_extra` is not retired on this path; it stays with the caller.
                    if (data.len() + buf.len()) > max_size {
                        return self
                            .error(ErrorResponse::error(
                                SqlState::INSUFFICIENT_RESOURCES,
                                "COPY FROM STDIN too large",
                            ))
                            .await;
                    }
                    data.extend(buf)
                }
                Some(FrontendMessage::CopyDone) => break,
                Some(FrontendMessage::CopyFail(err)) => {
                    // The client aborted the copy; this is recorded as a cancellation.
                    self.adapter_client.retire_execute(
                        std::mem::take(ctx_extra),
                        StatementEndedExecutionReason::Canceled,
                    );
                    return self
                        .error(ErrorResponse::error(
                            SqlState::QUERY_CANCELED,
                            format!("COPY from stdin failed: {}", err),
                        ))
                        .await;
                }
                // Flush and Sync are ignored while in copy-in mode.
                Some(FrontendMessage::Flush) | Some(FrontendMessage::Sync) => {}
                Some(_) => {
                    let msg = "unexpected message type during COPY from stdin";
                    self.adapter_client.retire_execute(
                        std::mem::take(ctx_extra),
                        StatementEndedExecutionReason::Errored {
                            error: msg.to_string(),
                        },
                    );
                    return self
                        .error(ErrorResponse::error(SqlState::PROTOCOL_VIOLATION, msg))
                        .await;
                }
                None => {
                    // No further message was received; the caller retires
                    // `ctx_extra` for this outcome.
                    return Ok(State::Done);
                }
            }
        }

        // The decoder works on pgwire types, so convert the column types first.
        let column_types = typ
            .column_types
            .iter()
            .map(|x| &x.scalar_type)
            .map(mz_pgrepr::Type::from)
            .collect::<Vec<mz_pgrepr::Type>>();

        let rows = match mz_pgcopy::decode_copy_format(&data, &column_types, params) {
            Ok(rows) => rows,
            Err(e) => {
                self.adapter_client.retire_execute(
                    std::mem::take(ctx_extra),
                    StatementEndedExecutionReason::Errored {
                        error: e.to_string(),
                    },
                );
                return self
                    .error(ErrorResponse::error(
                        SqlState::BAD_COPY_FILE_FORMAT,
                        format!("{}", e),
                    ))
                    .await;
            }
        };

        // Record the count now; `rows` is moved into `insert_rows` below.
        let count = rows.len();

        if let Err(e) = self
            .adapter_client
            .insert_rows(
                target_id,
                target_name,
                columns,
                rows,
                std::mem::take(ctx_extra),
            )
            .await
        {
            // `ctx_extra` was already taken for the `insert_rows` call above, so
            // this takes the default value that was left behind.
            self.adapter_client.retire_execute(
                std::mem::take(ctx_extra),
                StatementEndedExecutionReason::Errored {
                    error: e.to_string(),
                },
            );
            return self.error(e.into_response(Severity::Error)).await;
        }

        let tag = format!("COPY {}", count);
        self.send(BackendMessage::CommandComplete { tag }).await?;

        Ok(State::Ready)
    }
2786
2787    #[instrument(level = "debug")]
2788    async fn send_pending_notices(&mut self) -> Result<(), io::Error> {
2789        let notices = self
2790            .adapter_client
2791            .session()
2792            .drain_notices()
2793            .into_iter()
2794            .map(|notice| BackendMessage::ErrorResponse(notice.into_response()));
2795        self.send_all(notices).await?;
2796        Ok(())
2797    }
2798
2799    #[instrument(level = "debug")]
2800    async fn error(&mut self, err: ErrorResponse) -> Result<State, io::Error> {
2801        assert!(err.severity.is_error());
2802        debug!(
2803            "cid={} error code={}",
2804            self.adapter_client.session().conn_id(),
2805            err.code.code()
2806        );
2807        let is_fatal = err.severity.is_fatal();
2808        self.send(BackendMessage::ErrorResponse(err)).await?;
2809
2810        let txn = self.adapter_client.session().transaction();
2811        match txn {
2812            // Error can be called from describe and parse and so might not be in an active
2813            // transaction.
2814            TransactionStatus::Default | TransactionStatus::Failed(_) => {}
2815            // In Started (i.e., a single statement), cleanup ourselves.
2816            TransactionStatus::Started(_) => {
2817                self.rollback_transaction().await?;
2818            }
2819            // Implicit transactions also clear themselves.
2820            TransactionStatus::InTransactionImplicit(_) => {
2821                self.rollback_transaction().await?;
2822            }
2823            // Explicit transactions move to failed.
2824            TransactionStatus::InTransaction(_) => {
2825                self.adapter_client.fail_transaction();
2826            }
2827        };
2828        if is_fatal {
2829            Ok(State::Done)
2830        } else {
2831            Ok(State::Drain)
2832        }
2833    }
2834
2835    #[instrument(level = "debug")]
2836    async fn aborted_txn_error(&mut self) -> Result<State, io::Error> {
2837        self.send(BackendMessage::ErrorResponse(ErrorResponse::error(
2838            SqlState::IN_FAILED_SQL_TRANSACTION,
2839            ABORTED_TXN_MSG,
2840        )))
2841        .await?;
2842        Ok(State::Drain)
2843    }
2844
2845    fn is_aborted_txn(&mut self) -> bool {
2846        matches!(
2847            self.adapter_client.session().transaction(),
2848            TransactionStatus::Failed(_)
2849        )
2850    }
2851}
2852
2853fn pad_formats(formats: Vec<Format>, n: usize) -> Result<Vec<Format>, String> {
2854    match (formats.len(), n) {
2855        (0, e) => Ok(vec![Format::Text; e]),
2856        (1, e) => Ok(iter::repeat(formats[0]).take(e).collect()),
2857        (a, e) if a == e => Ok(formats),
2858        (a, e) => Err(format!(
2859            "expected {} field format specifiers, but got {}",
2860            e, a
2861        )),
2862    }
2863}
2864
2865fn describe_rows(stmt_desc: &StatementDesc, formats: &[Format]) -> BackendMessage {
2866    match &stmt_desc.relation_desc {
2867        Some(desc) if !stmt_desc.is_copy => {
2868            BackendMessage::RowDescription(message::encode_row_description(desc, formats))
2869        }
2870        _ => BackendMessage::NoData,
2871    }
2872}
2873
/// Signature of the callback that `send_rows` invokes once it has finished
/// sending rows, to build the message that follows them.
///
/// It receives the requested row limit, the number of rows that were sent, and
/// the outer FETCH portal if there is one.
type GetResponse = fn(
    max_rows: ExecuteCount,
    total_sent_rows: usize,
    fetch_portal: Option<PortalRefMut>,
) -> BackendMessage;
2879
2880// A GetResponse used by send_rows during execute messages on portals or for
2881// simple query messages.
2882fn portal_exec_message(
2883    max_rows: ExecuteCount,
2884    total_sent_rows: usize,
2885    _fetch_portal: Option<PortalRefMut>,
2886) -> BackendMessage {
2887    // If max_rows is not specified, we will always send back a CommandComplete. If
2888    // max_rows is specified, we only send CommandComplete if there were more rows
2889    // requested than were remaining. That is, if max_rows == number of rows that
2890    // were remaining before sending (not that are remaining after sending), then
2891    // we still send a PortalSuspended. The number of remaining rows after the rows
2892    // have been sent doesn't matter. This matches postgres.
2893    match max_rows {
2894        ExecuteCount::Count(max_rows) if max_rows <= total_sent_rows => {
2895            BackendMessage::PortalSuspended
2896        }
2897        _ => BackendMessage::CommandComplete {
2898            tag: format!("SELECT {}", total_sent_rows),
2899        },
2900    }
2901}
2902
2903// A GetResponse used by send_rows during FETCH queries.
2904fn fetch_message(
2905    _max_rows: ExecuteCount,
2906    total_sent_rows: usize,
2907    fetch_portal: Option<PortalRefMut>,
2908) -> BackendMessage {
2909    let tag = format!("FETCH {}", total_sent_rows);
2910    if let Some(portal) = fetch_portal {
2911        *portal.state = PortalState::Completed(Some(tag.clone()));
2912    }
2913    BackendMessage::CommandComplete { tag }
2914}
2915
2916fn get_authenticator(
2917    authenticator_kind: AuthenticatorKind,
2918    frontegg: Option<FronteggAuthenticator>,
2919    oidc: GenericOidcAuthenticator,
2920    adapter_client: mz_adapter::Client,
2921    // If oidc_auth_enabled exists as an option in the pgwire connection's
2922    // `option` parameter
2923    oidc_auth_option_enabled: bool,
2924) -> Authenticator {
2925    match authenticator_kind {
2926        AuthenticatorKind::Frontegg => Authenticator::Frontegg(
2927            frontegg.expect("Frontegg authenticator should exist with AuthenticatorKind::Frontegg"),
2928        ),
2929        AuthenticatorKind::Password => Authenticator::Password(adapter_client),
2930        AuthenticatorKind::Sasl => Authenticator::Sasl(adapter_client),
2931        AuthenticatorKind::Oidc => {
2932            if oidc_auth_option_enabled {
2933                Authenticator::Oidc(oidc)
2934            } else {
2935                // Fallback to password authentication if oidc auth is not enabled
2936                // through options.
2937                Authenticator::Password(adapter_client)
2938            }
2939        }
2940        AuthenticatorKind::None => Authenticator::None,
2941    }
2942}
2943
/// The maximum number of rows a single call to `send_rows` may send.
#[derive(Debug, Copy, Clone)]
enum ExecuteCount {
    /// No limit: send every available row.
    All,
    /// Send at most this many rows.
    Count(usize),
}
2949
2950// See postgres' backend/tcop/postgres.c IsTransactionExitStmt.
2951fn is_txn_exit_stmt(stmt: Option<&Statement<Raw>>) -> bool {
2952    match stmt {
2953        // Add PREPARE to this if we ever support it.
2954        Some(stmt) => matches!(stmt, Statement::Commit(_) | Statement::Rollback(_)),
2955        None => false,
2956    }
2957}
2958
/// The outcome of one attempt by `send_rows` to obtain the next batch of rows.
#[derive(Debug)]
enum FetchResult {
    /// A batch of rows to send, or `None` when `send_rows` should stop sending.
    Rows(Option<Box<dyn RowIterator + Send + Sync>>),
    /// The row source reported that it was canceled.
    Canceled,
    /// The row source reported an error with the given message.
    Error(String),
    /// A notice arrived while waiting for rows.
    Notice(AdapterNotice),
}
2966
#[cfg(test)]
mod test {
    use super::*;

    /// Runs `parse_options` over a table of complete option strings and checks
    /// that each one yields the expected key/value pairs, or is rejected.
    #[mz_ore::test]
    fn test_parse_options() {
        type Expected = Result<Vec<(&'static str, &'static str)>, ()>;

        // Each entry is (options string, expected outcome).
        let cases: Vec<(&'static str, Expected)> = vec![
            ("", Ok(vec![])),
            ("--key", Err(())),
            ("--key=val", Ok(vec![("key", "val")])),
            (
                r#"--key=val -ckey2=val2 -c key3=val3 -c key4=val4 -ckey5=val5"#,
                Ok(vec![
                    ("key", "val"),
                    ("key2", "val2"),
                    ("key3", "val3"),
                    ("key4", "val4"),
                    ("key5", "val5"),
                ]),
            ),
            (r#"-c\ key=val"#, Ok(vec![(" key", "val")])),
            ("--key=val -ckey2 val2", Err(())),
            // Unclear what this should do.
            ("--key=", Ok(vec![("key", "")])),
        ];

        for (input, expect) in cases {
            let got = parse_options(input);
            // The table holds borrowed strings for brevity; turn them into
            // owned pairs before comparing against the parser's output.
            let expect = expect.map(|pairs| {
                pairs
                    .into_iter()
                    .map(|(key, value)| (key.to_owned(), value.to_owned()))
                    .collect()
            });
            assert_eq!(got, expect, "input: {}", input);
        }
    }

    /// Runs `parse_option` over a table of single option strings and checks
    /// that each one yields the expected `(key, value)` pair, or is rejected.
    #[mz_ore::test]
    fn test_parse_option() {
        type Expected = Result<(&'static str, &'static str), ()>;

        // Each entry is (option string, expected outcome).
        let cases: Vec<(&'static str, Expected)> = vec![
            ("", Err(())),
            ("--", Err(())),
            ("--c", Err(())),
            ("a=b", Err(())),
            ("--a=b", Ok(("a", "b"))),
            ("--ca=b", Ok(("ca", "b"))),
            ("-ca=b", Ok(("a", "b"))),
            // Unclear whether this should be an error, but at least test it.
            ("--=", Ok(("", ""))),
        ];

        for (input, expect) in cases {
            let got = parse_option(input);
            assert_eq!(got, expect, "input: {}", input);
        }
    }

    /// Runs `split_options` over a table of raw strings and checks the pieces
    /// it produces. The cases cover runs of spaces between pieces and
    /// backslash escapes, including a trailing backslash.
    #[mz_ore::test]
    fn test_split_options() {
        // Each entry is (raw string, expected pieces in order).
        let cases: Vec<(&'static str, Vec<&'static str>)> = vec![
            ("", vec![]),
            ("  ", vec![]),
            (" a ", vec!["a"]),
            ("  ab     cd   ", vec!["ab", "cd"]),
            (r#"  ab\     cd   "#, vec!["ab ", "cd"]),
            (r#"  ab\\     cd   "#, vec![r#"ab\"#, "cd"]),
            (r#"  ab\\\     cd   "#, vec![r#"ab\ "#, "cd"]),
            (r#"  ab\\\ cd   "#, vec![r#"ab\ cd"#]),
            (r#"  ab\\\cd   "#, vec![r#"ab\cd"#]),
            (r#"a\"#, vec!["a"]),
            (r#"a\ "#, vec!["a "]),
            (r#"\"#, vec![]),
            (r#"\ "#, vec![r#" "#]),
            (r#" \ "#, vec![r#" "#]),
            (r#"\  "#, vec![r#" "#]),
        ];

        for (input, expect) in cases {
            let got = split_options(input);
            assert_eq!(got, expect, "input: {}", input);
        }
    }
}