mz_pgwire/
protocol.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::collections::BTreeMap;
11use std::convert::TryFrom;
12use std::future::Future;
13use std::ops::Deref;
14use std::sync::Arc;
15use std::time::{Duration, Instant};
16use std::{iter, mem};
17
18use byteorder::{ByteOrder, NetworkEndian};
19use futures::future::{BoxFuture, FutureExt, pending};
20use itertools::Itertools;
21use mz_adapter::client::RecordFirstRowStream;
22use mz_adapter::session::{
23    EndTransactionAction, InProgressRows, LifecycleTimestamps, PortalRefMut, PortalState,
24    SessionConfig, TransactionStatus,
25};
26use mz_adapter::statement_logging::{StatementEndedExecutionReason, StatementExecutionStrategy};
27use mz_adapter::{
28    AdapterError, AdapterNotice, ExecuteContextExtra, ExecuteResponse, PeekResponseUnary, metrics,
29    verify_datum_desc,
30};
31use mz_auth::password::Password;
32use mz_authenticator::Authenticator;
33use mz_ore::cast::CastFrom;
34use mz_ore::netio::AsyncReady;
35use mz_ore::now::{EpochMillis, SYSTEM_TIME};
36use mz_ore::str::StrExt;
37use mz_ore::{assert_none, assert_ok, instrument, soft_assert_eq_or_log};
38use mz_pgcopy::{CopyCsvFormatParams, CopyFormatParams, CopyTextFormatParams};
39use mz_pgwire_common::{
40    ConnectionCounter, ErrorResponse, Format, FrontendMessage, Severity, VERSION_3, VERSIONS,
41};
42use mz_repr::user::InternalUserMetadata;
43use mz_repr::{
44    CatalogItemId, ColumnIndex, Datum, RelationDesc, RelationType, RowArena, RowIterator, RowRef,
45    ScalarType,
46};
47use mz_server_core::TlsMode;
48use mz_server_core::listeners::AllowedRoles;
49use mz_sql::ast::display::AstDisplay;
50use mz_sql::ast::{CopyDirection, CopyStatement, FetchDirection, Ident, Raw, Statement};
51use mz_sql::parse::StatementParseResult;
52use mz_sql::plan::{CopyFormat, ExecuteTimeout, StatementDesc};
53use mz_sql::session::metadata::SessionMetadata;
54use mz_sql::session::user::INTERNAL_USER_NAMES;
55use mz_sql::session::vars::{MAX_COPY_FROM_SIZE, Var, VarInput};
56use postgres::error::SqlState;
57use tokio::io::{self, AsyncRead, AsyncWrite};
58use tokio::select;
59use tokio::time::{self};
60use tokio_metrics::TaskMetrics;
61use tokio_stream::wrappers::UnboundedReceiverStream;
62use tracing::{Instrument, debug, debug_span, warn};
63use uuid::Uuid;
64
65use crate::codec::FramedConn;
66use crate::message::{self, BackendMessage};
67
68/// Reports whether the given stream begins with a pgwire handshake.
69///
70/// To avoid false negatives, there must be at least eight bytes in `buf`.
71pub fn match_handshake(buf: &[u8]) -> bool {
72    // The pgwire StartupMessage looks like this:
73    //
74    //     i32 - Length of entire message.
75    //     i32 - Protocol version number.
76    //     [String] - Arbitrary key-value parameters of any length.
77    //
78    // Since arbitrary parameters can be included in the StartupMessage, the
79    // first Int32 is worthless, since the message could have any length.
80    // Instead, we sniff the protocol version number.
81    if buf.len() < 8 {
82        return false;
83    }
84    let version = NetworkEndian::read_i32(&buf[4..8]);
85    VERSIONS.contains(&version)
86}
87
88/// Parameters for the [`run`] function.
89pub struct RunParams<'a, A, I>
90where
91    I: Iterator<Item = TaskMetrics> + Send,
92{
93    /// The TLS mode of the pgwire server.
94    pub tls_mode: Option<TlsMode>,
95    /// A client for the adapter.
96    pub adapter_client: mz_adapter::Client,
97    /// The connection to the client.
98    pub conn: &'a mut FramedConn<A>,
99    /// The universally unique identifier for the connection.
100    pub conn_uuid: Uuid,
101    /// The protocol version that the client provided in the startup message.
102    pub version: i32,
103    /// The parameters that the client provided in the startup message.
104    pub params: BTreeMap<String, String>,
105    /// Authentication method to use. Frontegg, Password, or None.
106    pub authenticator: Authenticator,
107    /// Global connection limit and count
108    pub active_connection_counter: ConnectionCounter,
109    /// Helm chart version
110    pub helm_chart_version: Option<String>,
111    /// Whether to allow reserved users (ie: mz_system).
112    pub allowed_roles: AllowedRoles,
113    /// Tokio metrics
114    pub tokio_metrics_intervals: I,
115}
116
/// Runs a pgwire connection to completion.
///
/// This involves responding to `FrontendMessage::StartupMessage` and all future
/// requests until the client terminates the connection or a fatal error occurs.
///
/// The connection proceeds through these phases:
/// 1. It is rejected if the protocol version is unsupported, the role is not
///    permitted by `allowed_roles`, or the transport is not TLS-compatible.
/// 2. The client is authenticated according to `authenticator`.
/// 3. Startup parameters (including `options`) are applied to the new session.
/// 4. A slot is reserved from `active_connection_counter`.
/// 5. The session is registered with the adapter.
/// 6. The server sends `AuthenticationOk`, parameter statuses, `BackendKeyData`,
///    any notices, and `ReadyForQuery`.
/// 7. The message loop runs until it finishes or authentication expires.
///
/// Note that this function returns successfully even upon delivering a fatal
/// error to the client. It only returns `Err` if an unexpected I/O error occurs
/// while communicating with the client, e.g., if the connection is severed in
/// the middle of a request.
#[mz_ore::instrument(level = "debug")]
pub async fn run<'a, A, I>(
    RunParams {
        tls_mode,
        adapter_client,
        conn,
        conn_uuid,
        version,
        mut params,
        authenticator,
        active_connection_counter,
        helm_chart_version,
        allowed_roles,
        tokio_metrics_intervals,
    }: RunParams<'a, A, I>,
) -> Result<(), io::Error>
where
    A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
    I: Iterator<Item = TaskMetrics> + Send,
{
    if version != VERSION_3 {
        return conn
            .send(ErrorResponse::fatal(
                SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
                "server does not support the client's requested protocol version",
            ))
            .await;
    }

    // `user` is removed from `params` so the remaining entries can be applied
    // as session variables below. A missing user becomes the empty string.
    let user = params.remove("user").unwrap_or_else(String::new);

    // TODO move this somewhere it can be shared with HTTP
    let is_internal_user = INTERNAL_USER_NAMES.contains(&user);
    // this is a superset of internal users
    let is_reserved_user = mz_adapter::catalog::is_reserved_role_name(user.as_str());
    let role_allowed = match allowed_roles {
        AllowedRoles::Normal => !is_reserved_user,
        AllowedRoles::Internal => is_internal_user,
        AllowedRoles::NormalAndInternal => !is_reserved_user || is_internal_user,
    };
    if !role_allowed {
        let msg = format!("unauthorized login to user '{user}'");
        return conn
            .send(ErrorResponse::fatal(SqlState::INSUFFICIENT_PRIVILEGE, msg))
            .await;
    }

    if let Err(err) = conn.inner().ensure_tls_compatibility(&tls_mode) {
        return conn.send(err).await;
    }

    // Authenticate and build the session. `expired` is a future that resolves
    // when authentication lapses. Only Frontegg sessions can expire. The other
    // arms use a future that never resolves.
    let (mut session, expired) = match authenticator {
        Authenticator::Frontegg(frontegg) => {
            conn.send(BackendMessage::AuthenticationCleartextPassword)
                .await?;
            conn.flush().await?;
            let password = match conn.recv().await? {
                Some(FrontendMessage::Password { password }) => password,
                _ => {
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                            "expected Password message",
                        ))
                        .await;
                }
            };

            let auth_response = frontegg.authenticate(&user, &password).await;
            match auth_response {
                Ok(mut auth_session) => {
                    // Create a session based on the auth session.
                    //
                    // In particular, it's important that the username come from the
                    // auth session, as Frontegg may return an email address with
                    // different casing than the user supplied via the pgwire
                    // username field. We want to use the Frontegg casing as
                    // canonical.
                    let session = adapter_client.new_session(SessionConfig {
                        conn_id: conn.conn_id().clone(),
                        uuid: conn_uuid,
                        user: auth_session.user().into(),
                        client_ip: conn.peer_addr().clone(),
                        external_metadata_rx: Some(auth_session.external_metadata_rx()),
                        internal_user_metadata: None,
                        helm_chart_version,
                    });
                    let expired = async move { auth_session.expired().await };
                    // `left_future`/`right_future` give all three arms a
                    // single `Either` type.
                    (session, expired.left_future())
                }
                Err(err) => {
                    warn!(?err, "pgwire connection failed authentication");
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_PASSWORD,
                            "invalid password",
                        ))
                        .await;
                }
            }
        }
        // NOTE: this arm's `adapter_client` binding (the authenticator's payload)
        // shadows the function-level `adapter_client`. Both `authenticate` and
        // `new_session` below go through the authenticator's client.
        Authenticator::Password(adapter_client) => {
            conn.send(BackendMessage::AuthenticationCleartextPassword)
                .await?;
            conn.flush().await?;
            let password = match conn.recv().await? {
                Some(FrontendMessage::Password { password }) => Password(password),
                _ => {
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                            "expected Password message",
                        ))
                        .await;
                }
            };
            let auth_response = match adapter_client.authenticate(&user, &password).await {
                Ok(resp) => resp,
                Err(err) => {
                    warn!(?err, "pgwire connection failed authentication");
                    return conn
                        .send(ErrorResponse::fatal(
                            SqlState::INVALID_PASSWORD,
                            "invalid password",
                        ))
                        .await;
                }
            };
            let session = adapter_client.new_session(SessionConfig {
                conn_id: conn.conn_id().clone(),
                uuid: conn_uuid,
                user,
                client_ip: conn.peer_addr().clone(),
                external_metadata_rx: None,
                internal_user_metadata: Some(InternalUserMetadata {
                    superuser: auth_response.superuser,
                }),
                helm_chart_version,
            });
            // No frontegg check, so auth session lasts indefinitely.
            let auth_session = pending().right_future();
            (session, auth_session)
        }
        Authenticator::None => {
            let session = adapter_client.new_session(SessionConfig {
                conn_id: conn.conn_id().clone(),
                uuid: conn_uuid,
                user,
                client_ip: conn.peer_addr().clone(),
                external_metadata_rx: None,
                internal_user_metadata: None,
                helm_chart_version,
            });
            // No frontegg check, so auth session lasts indefinitely.
            let auth_session = pending().right_future();
            (session, auth_session)
        }
    };

    // Apply the remaining startup parameters as session variables. The special
    // `options` parameter holds command-line style settings (`-c key=value`,
    // `--key=value`) that are parsed into individual pairs first. Failures
    // become notices rather than connection errors.
    let system_vars = adapter_client.get_system_vars().await;
    for (name, value) in params {
        let settings = match name.as_str() {
            "options" => match parse_options(&value) {
                Ok(opts) => opts,
                Err(()) => {
                    session.add_notice(AdapterNotice::BadStartupSetting {
                        name,
                        reason: "could not parse".into(),
                    });
                    continue;
                }
            },
            _ => vec![(name, value)],
        };
        for (key, val) in settings {
            const LOCAL: bool = false;
            // TODO: Issuing an error here is better than what we did before
            // (silently ignore errors on set), but erroring the connection
            // might be the better behavior. We maybe need to support more
            // options sent by psql and drivers before we can safely do this.
            if let Err(err) =
                session
                    .vars_mut()
                    .set(&system_vars, &key, VarInput::Flat(&val), LOCAL)
            {
                session.add_notice(AdapterNotice::BadStartupSetting {
                    name: key,
                    reason: err.to_string(),
                });
            }
        }
    }
    // Commit the variable changes made above.
    session
        .vars_mut()
        .end_transaction(EndTransactionAction::Commit);

    // Keep the guard alive for the rest of the connection. Presumably dropping
    // it releases the slot in `active_connection_counter` (confirm against
    // `ConnectionCounter`).
    let _guard = match active_connection_counter.allocate_connection(session.user()) {
        Ok(drop_connection) => drop_connection,
        Err(e) => {
            let e: AdapterError = e.into();
            return conn.send(e.into_response(Severity::Fatal)).await;
        }
    };

    // Register session with adapter.
    let mut adapter_client = match adapter_client.startup(session).await {
        Ok(adapter_client) => adapter_client,
        Err(e) => return conn.send(e.into_response(Severity::Fatal)).await,
    };

    // Startup reply: AuthenticationOk, the ParameterStatus of every variable
    // flagged for notification, BackendKeyData, any queued notices (sent as
    // ErrorResponse-framed notices), then ReadyForQuery. All are sent in one batch.
    let mut buf = vec![BackendMessage::AuthenticationOk];
    for var in adapter_client.session().vars().notify_set() {
        buf.push(BackendMessage::ParameterStatus(var.name(), var.value()));
    }
    buf.push(BackendMessage::BackendKeyData {
        conn_id: adapter_client.session().conn_id().unhandled(),
        secret_key: adapter_client.session().secret_key(),
    });
    buf.extend(
        adapter_client
            .session()
            .drain_notices()
            .into_iter()
            .map(|notice| BackendMessage::ErrorResponse(notice.into_response())),
    );
    buf.push(BackendMessage::ReadyForQuery(
        adapter_client.session().transaction().into(),
    ));
    conn.send_all(buf).await?;
    conn.flush().await?;

    let machine = StateMachine {
        conn,
        adapter_client,
        txn_needs_commit: false,
        tokio_metrics_intervals,
    };

    // Run the message loop, but abandon it as soon as authentication expires.
    select! {
        r = machine.run() => {
            // Errors produced internally (like MAX_REQUEST_SIZE being exceeded) should send an
            // error to the client informing them why the connection was closed. We still want to
            // return the original error up the stack, though, so we skip error checking during conn
            // operations.
            if let Err(err) = &r {
                let _ = conn
                    .send(ErrorResponse::fatal(
                        SqlState::CONNECTION_FAILURE,
                        err.to_string(),
                    ))
                    .await;
                let _ = conn.flush().await;
            }
            r
        },
        _ = expired => {
            conn
                .send(ErrorResponse::fatal(SqlState::INVALID_AUTHORIZATION_SPECIFICATION, "authentication expired"))
                .await?;
            conn.flush().await
        }
    }
}
389
390/// Returns (name, value) session settings pairs from an options value.
391///
392/// From Postgres, see pg_split_opts in postinit.c and process_postgres_switches
393/// in postgres.c.
394fn parse_options(value: &str) -> Result<Vec<(String, String)>, ()> {
395    let opts = split_options(value);
396    let mut pairs = Vec::with_capacity(opts.len());
397    let mut seen_prefix = false;
398    for opt in opts {
399        if !seen_prefix {
400            if opt == "-c" {
401                seen_prefix = true;
402            } else {
403                let (key, val) = parse_option(&opt)?;
404                pairs.push((key.to_owned(), val.to_owned()));
405            }
406        } else {
407            let (key, val) = opt.split_once('=').ok_or(())?;
408            pairs.push((key.to_owned(), val.to_owned()));
409            seen_prefix = false;
410        }
411    }
412    Ok(pairs)
413}
414
/// Returns the parsed key and value from an option of the form `--key=value`
/// or `-ckey=value`.
///
/// The input is split at the first `=`, so the value may itself contain `=`.
/// The key is returned exactly as written after its prefix.
///
/// NOTE: unlike PostgreSQL's `ParseLongOption`, dashes in the key are *not*
/// rewritten to underscores (`--a-b=c` yields key `a-b`).
///
/// Returns `Err(())` if there is no `=` or the key has neither the `-c` nor
/// the `--` prefix.
fn parse_option(option: &str) -> Result<(&str, &str), ()> {
    let (key, value) = option.split_once('=').ok_or(())?;
    ["-c", "--"]
        .into_iter()
        .find_map(|prefix| key.strip_prefix(prefix))
        .map(|key| (key, value))
        .ok_or(())
}
427
/// Splits `value` into arguments on runs of spaces.
///
/// A backslash makes the following character literal: `\ ` keeps a space
/// inside an argument, and `\\` yields a single backslash. Any other escaped
/// character is kept with the backslash dropped. A trailing lone backslash is
/// discarded. Empty arguments are never produced.
fn split_options(value: &str) -> Vec<String> {
    // Escapes force us to build each argument rather than slice `value`. This
    // is not hot enough to be worth avoiding the allocation.
    let mut parts: Vec<String> = Vec::new();
    let mut token = String::new();
    let mut escaped = false;
    for ch in value.chars() {
        if escaped {
            // The previous character was an unescaped `\`, so take this one literally.
            token.push(ch);
            escaped = false;
        } else if ch == '\\' {
            escaped = true;
        } else if ch == ' ' {
            // Consecutive spaces must not create empty arguments.
            if !token.is_empty() {
                parts.push(std::mem::take(&mut token));
            }
        } else {
            token.push(ch);
        }
    }
    if !token.is_empty() {
        parts.push(token);
    }
    parts
}
470
/// Position of the connection's message loop (`StateMachine::run`).
#[derive(Debug)]
enum State {
    /// Handle the next frontend message normally (`advance_ready`).
    Ready,
    /// Read and discard messages until a `Sync` arrives (`advance_drain`), then
    /// continue with whatever state `sync` returns. A closed connection ends
    /// the loop instead.
    Drain,
    /// Stop the loop; `run` returns `Ok(())`.
    Done,
}
477
478struct StateMachine<'a, A, I>
479where
480    I: Iterator<Item = TaskMetrics> + Send + 'a,
481{
482    conn: &'a mut FramedConn<A>,
483    adapter_client: mz_adapter::SessionClient,
484    txn_needs_commit: bool,
485    tokio_metrics_intervals: I,
486}
487
/// How sending rows to the client ended (judging by the name).
///
/// NOTE(review): constructed and consumed outside this chunk. It presumably
/// feeds statement logging (`StatementEndedExecutionReason` is imported here);
/// confirm.
enum SendRowsEndedReason {
    /// Every row was sent.
    Success {
        /// Size of the result; presumably in bytes (confirm at construction site).
        result_size: u64,
        /// Number of rows returned to the client.
        rows_returned: u64,
    },
    /// Sending stopped because of an error, described by `error`.
    Errored {
        error: String,
    },
    /// Sending stopped because the operation was canceled.
    Canceled,
}
498
/// Error text for a command received while the current transaction is aborted.
///
/// NOTE(review): its uses are outside this chunk. The wording appears to mirror
/// PostgreSQL's message for the same situation; confirm before editing, since
/// clients may match on it.
const ABORTED_TXN_MSG: &str =
    "current transaction is aborted, commands ignored until end of transaction block";
501
502impl<'a, A, I> StateMachine<'a, A, I>
503where
504    A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin + 'a,
505    I: Iterator<Item = TaskMetrics> + Send + 'a,
506{
507    // Manually desugar this (don't use `async fn run`) here because a much better
508    // error message is produced if there are problems with Send or other traits
509    // somewhere within the Future.
510    #[allow(clippy::manual_async_fn)]
511    #[mz_ore::instrument(level = "debug")]
512    fn run(mut self) -> impl Future<Output = Result<(), io::Error>> + Send + 'a {
513        async move {
514            let mut state = State::Ready;
515            loop {
516                self.send_pending_notices().await?;
517                state = match state {
518                    State::Ready => self.advance_ready().await?,
519                    State::Drain => self.advance_drain().await?,
520                    State::Done => return Ok(()),
521                };
522                self.adapter_client
523                    .add_idle_in_transaction_session_timeout();
524            }
525        }
526    }
527
    /// Handles one message while in the `Ready` state and returns the next state.
    ///
    /// Waits for either a session timeout or the next frontend message. The
    /// `select!` is `biased`, so a pending timeout wins. The message is
    /// dispatched to its handler and per-message metrics are recorded.
    /// `Terminate` and a closed connection (`None`) yield `State::Done`. Stray
    /// COPY or `Password` messages yield `State::Drain`.
    #[instrument(level = "debug")]
    async fn advance_ready(&mut self) -> Result<State, io::Error> {
        // Start a new metrics interval before the `recv()` call.
        // The yielded value is deliberately discarded; only the boundary matters.
        self.tokio_metrics_intervals
            .next()
            .expect("infinite iterator");

        // Handle timeouts first so we don't execute any statements when there's a pending timeout.
        let message = select! {
            biased;

            // `recv_timeout()` is cancel-safe as per its docs.
            Some(timeout) = self.adapter_client.recv_timeout() => {
                let err: AdapterError = timeout.into();
                let conn_id = self.adapter_client.session().conn_id();
                tracing::warn!("session timed out, conn_id {}", conn_id);

                // Process the error, doing any state cleanup.
                let error_response = err.into_response(Severity::Fatal);
                let error_state = self.error(error_response).await;

                // Terminate __after__ we do any cleanup.
                self.adapter_client.terminate().await;

                // We must wait for the client to send a request before we can send the error response.
                // Due to the PG wire protocol, we can't send an ErrorResponse unless it is in response
                // to a client message.
                let _ = self.conn.recv().await?;
                return error_state;
            },
            // `recv()` is cancel-safe as per its docs.
            message = self.conn.recv() => message?,
        };

        // Take the metrics since just before the `recv`.
        let interval = self
            .tokio_metrics_intervals
            .next()
            .expect("infinite iterator");
        // tokio-metrics' `total_scheduled_duration` for that interval, in milliseconds.
        let recv_scheduling_delay_ms = interval.total_scheduled_duration.as_secs_f64() * 1000.0;

        // TODO(ggevay): Consider subtracting the scheduling delay from `received`. It's not obvious
        // whether we should do this, because the result wouldn't exactly correspond to either first
        // byte received or last byte received (for msgs that arrive in more than one network packet).
        let received = SYSTEM_TIME();

        self.adapter_client
            .remove_idle_in_transaction_session_timeout();

        // NOTE(guswynn): we could consider adding spans to all message types. Currently
        // only a few message types seem useful.
        let message_name = message.as_ref().map(|m| m.name()).unwrap_or_default();

        // Processing time is only measured when a message actually arrived.
        let start = message.as_ref().map(|_| Instant::now());
        let next_state = match message {
            Some(FrontendMessage::Query { sql }) => {
                let query_root_span =
                    tracing::info_span!(parent: None, "advance_ready", otel.name = message_name);
                query_root_span.follows_from(tracing::Span::current());
                self.query(sql, received)
                    .instrument(query_root_span)
                    .await?
            }
            Some(FrontendMessage::Parse {
                name,
                sql,
                param_types,
            }) => self.parse(name, sql, param_types).await?,
            Some(FrontendMessage::Bind {
                portal_name,
                statement_name,
                param_formats,
                raw_params,
                result_formats,
            }) => {
                self.bind(
                    portal_name,
                    statement_name,
                    param_formats,
                    raw_params,
                    result_formats,
                )
                .await?
            }
            Some(FrontendMessage::Execute {
                portal_name,
                max_rows,
            }) => {
                let max_rows = match usize::try_from(max_rows) {
                    Ok(0) | Err(_) => ExecuteCount::All, // If `max_rows < 0`, no limit.
                    Ok(n) => ExecuteCount::Count(n),
                };
                let execute_root_span =
                    tracing::info_span!(parent: None, "advance_ready", otel.name = message_name);
                execute_root_span.follows_from(tracing::Span::current());
                let state = self
                    .execute(
                        portal_name,
                        max_rows,
                        portal_exec_message,
                        None,
                        ExecuteTimeout::None,
                        None,
                        Some(received),
                    )
                    .instrument(execute_root_span)
                    .await?;
                // In PostgreSQL, when using the extended query protocol, some statements may
                // trigger an eager commit of the current implicit transaction,
                // see: <https://git.postgresql.org/gitweb/?p=postgresql.git&a=commitdiff&h=f92944137>.
                //
                // In Materialize, however, we eagerly commit every statement outside of an explicit
                // transaction when using the extended query protocol. This allows us to eliminate
                // the possibility of a multiple statement implicit transaction, which in turn
                // allows us to apply single-statement optimizations to queries issued in implicit
                // transactions in the extended query protocol.
                //
                // We don't immediately commit here to allow users to page through the portal if
                // necessary. Committing the transaction would destroy the portal before the next
                // Execute command has a chance to resume it. So we instead mark the transaction
                // for commit the next time that `ensure_transaction` is called.
                if self.adapter_client.session().transaction().is_implicit() {
                    self.txn_needs_commit = true;
                }
                state
            }
            Some(FrontendMessage::DescribeStatement { name }) => {
                self.describe_statement(&name).await?
            }
            Some(FrontendMessage::DescribePortal { name }) => self.describe_portal(&name).await?,
            Some(FrontendMessage::CloseStatement { name }) => self.close_statement(name).await?,
            Some(FrontendMessage::ClosePortal { name }) => self.close_portal(name).await?,
            Some(FrontendMessage::Flush) => self.flush().await?,
            Some(FrontendMessage::Sync) => self.sync().await?,
            Some(FrontendMessage::Terminate) => State::Done,

            // These are not expected in the Ready state; skip ahead to the next Sync.
            Some(FrontendMessage::CopyData(_))
            | Some(FrontendMessage::CopyDone)
            | Some(FrontendMessage::CopyFail(_))
            | Some(FrontendMessage::Password { .. }) => State::Drain,
            None => State::Done,
        };

        if let Some(start) = start {
            self.adapter_client
                .inner()
                .metrics()
                .pgwire_message_processing_seconds
                .with_label_values(&[message_name])
                .observe(start.elapsed().as_secs_f64());
        }
        self.adapter_client
            .inner()
            .metrics()
            .pgwire_recv_scheduling_delay_ms
            .with_label_values(&[message_name])
            .observe(recv_scheduling_delay_ms);

        Ok(next_state)
    }
688
689    async fn advance_drain(&mut self) -> Result<State, io::Error> {
690        let message = self.conn.recv().await?;
691        if message.is_some() {
692            self.adapter_client
693                .remove_idle_in_transaction_session_timeout();
694        }
695        match message {
696            Some(FrontendMessage::Sync) => self.sync().await,
697            None => Ok(State::Done),
698            _ => Ok(State::Drain),
699        }
700    }
701
    /// Executes one statement of a Simple Query through the unnamed portal.
    ///
    /// The statement is declared into the empty-named portal. Statements that
    /// reference parameters are rejected, because Simple Query has no Bind step.
    /// For non-COPY statements that return rows, a text-format `RowDescription`
    /// is sent first. The portal is then executed, and it is removed afterwards.
    ///
    /// Note that `lifecycle_timestamps` belongs to the whole "Simple Query", because the whole
    /// Simple Query is received and parsed together. This means that if there are multiple
    /// statements in a Simple Query, then all of them have the same `lifecycle_timestamps`.
    #[instrument(level = "debug")]
    async fn one_query(
        &mut self,
        stmt: Statement<Raw>,
        sql: String,
        lifecycle_timestamps: LifecycleTimestamps,
    ) -> Result<State, io::Error> {
        // Bind the portal. Note that this does not set the empty string prepared
        // statement.
        const EMPTY_PORTAL: &str = "";
        if let Err(e) = self
            .adapter_client
            .declare(EMPTY_PORTAL.to_string(), stmt, sql)
            .await
        {
            return self.error(e.into_response(Severity::Error)).await;
        }
        let portal = self
            .adapter_client
            .session()
            .get_portal_unverified_mut(EMPTY_PORTAL)
            .expect("unnamed portal should be present");

        *portal.lifecycle_timestamps = Some(lifecycle_timestamps);

        // Cloned so the description outlives the mutable borrow of the session.
        let stmt_desc = portal.desc.clone();
        // NOTE(review): this early return leaves the unnamed portal in place.
        // Presumably the next `declare` of "" replaces it; confirm.
        if !stmt_desc.param_types.is_empty() {
            return self
                .error(ErrorResponse::error(
                    SqlState::UNDEFINED_PARAMETER,
                    "there is no parameter $1",
                ))
                .await;
        }

        // Maybe send row description.
        if let Some(relation_desc) = &stmt_desc.relation_desc {
            if !stmt_desc.is_copy {
                // Simple Query results are always sent in text format.
                let formats = vec![Format::Text; stmt_desc.arity()];
                self.send(BackendMessage::RowDescription(
                    message::encode_row_description(relation_desc, &formats),
                ))
                .await?;
            }
        }

        let result = match self
            .adapter_client
            .execute(EMPTY_PORTAL.to_string(), self.conn.wait_closed(), None)
            .await
        {
            Ok((response, execute_started)) => {
                self.send_pending_notices().await?;
                self.send_execute_response(
                    response,
                    stmt_desc.relation_desc,
                    EMPTY_PORTAL.to_string(),
                    ExecuteCount::All,
                    portal_exec_message,
                    None,
                    ExecuteTimeout::None,
                    execute_started,
                )
                .await
            }
            Err(e) => {
                self.send_pending_notices().await?;
                self.error(e.into_response(Severity::Error)).await
            }
        };

        // Destroy the portal.
        self.adapter_client.session().remove_portal(EMPTY_PORTAL);

        result
    }
781
782    async fn ensure_transaction(
783        &mut self,
784        num_stmts: usize,
785        message_type: &str,
786    ) -> Result<(), io::Error> {
787        let start = Instant::now();
788        if self.txn_needs_commit {
789            self.commit_transaction().await?;
790        }
791        // start_transaction can't error (but assert that just in case it changes in
792        // the future.
793        let res = self.adapter_client.start_transaction(Some(num_stmts));
794        assert_ok!(res);
795        self.adapter_client
796            .inner()
797            .metrics()
798            .pgwire_ensure_transaction_seconds
799            .with_label_values(&[message_type])
800            .observe(start.elapsed().as_secs_f64());
801        Ok(())
802    }
803
804    fn parse_sql<'b>(&self, sql: &'b str) -> Result<Vec<StatementParseResult<'b>>, ErrorResponse> {
805        let parse_start = Instant::now();
806        let result = match self.adapter_client.parse(sql) {
807            Ok(result) => result.map_err(|e| {
808                // Convert our 0-based byte position to pgwire's 1-based character
809                // position.
810                let pos = sql[..e.error.pos].chars().count() + 1;
811                ErrorResponse::error(SqlState::SYNTAX_ERROR, e.error.message).with_position(pos)
812            }),
813            Err(msg) => Err(ErrorResponse::error(SqlState::PROGRAM_LIMIT_EXCEEDED, msg)),
814        };
815        self.adapter_client
816            .inner()
817            .metrics()
818            .parse_seconds
819            .with_label_values(&[])
820            .observe(parse_start.elapsed().as_secs_f64());
821        result
822    }
823
824    /// Executes a "Simple Query", see
825    /// <https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-SIMPLE-QUERY>
826    ///
827    /// For implicit transaction handling, see "Multiple Statements in a Simple Query" in the above.
828    #[instrument(level = "debug")]
829    async fn query(&mut self, sql: String, received: EpochMillis) -> Result<State, io::Error> {
830        // Parse first before doing any transaction checking.
831        let stmts = match self.parse_sql(&sql) {
832            Ok(stmts) => stmts,
833            Err(err) => {
834                self.error(err).await?;
835                return self.ready().await;
836            }
837        };
838
839        let num_stmts = stmts.len();
840
841        // Compare with postgres' backend/tcop/postgres.c exec_simple_query.
842        for StatementParseResult { ast: stmt, sql } in stmts {
843            // In an aborted transaction, reject all commands except COMMIT/ROLLBACK.
844            if self.is_aborted_txn() && !is_txn_exit_stmt(Some(&stmt)) {
845                self.aborted_txn_error().await?;
846                break;
847            }
848
849            // Start an implicit transaction if we aren't in any transaction and there's
850            // more than one statement. This mirrors the `use_implicit_block` variable in
851            // postgres.
852            //
853            // This needs to be done in the loop instead of once at the top because
854            // a COMMIT/ROLLBACK statement needs to start a new transaction on next
855            // statement.
856            self.ensure_transaction(num_stmts, "query").await?;
857
858            match self
859                .one_query(stmt, sql.to_string(), LifecycleTimestamps { received })
860                .await?
861            {
862                State::Ready => (),
863                State::Drain => break,
864                State::Done => return Ok(State::Done),
865            }
866        }
867
868        // Implicit transactions are closed at the end of a Query message.
869        {
870            if self.adapter_client.session().transaction().is_implicit() {
871                self.commit_transaction().await?;
872            }
873        }
874
875        if num_stmts == 0 {
876            self.send(BackendMessage::EmptyQueryResponse).await?;
877        }
878
879        self.ready().await
880    }
881
882    #[instrument(level = "debug")]
883    async fn parse(
884        &mut self,
885        name: String,
886        sql: String,
887        param_oids: Vec<u32>,
888    ) -> Result<State, io::Error> {
889        // Start a transaction if we aren't in one.
890        self.ensure_transaction(1, "parse").await?;
891
892        let mut param_types = vec![];
893        for oid in param_oids {
894            match mz_pgrepr::Type::from_oid(oid) {
895                Ok(ty) => match ScalarType::try_from(&ty) {
896                    Ok(ty) => param_types.push(Some(ty)),
897                    Err(err) => {
898                        return self
899                            .error(ErrorResponse::error(
900                                SqlState::INVALID_PARAMETER_VALUE,
901                                err.to_string(),
902                            ))
903                            .await;
904                    }
905                },
906                Err(_) if oid == 0 => param_types.push(None),
907                Err(e) => {
908                    return self
909                        .error(ErrorResponse::error(
910                            SqlState::PROTOCOL_VIOLATION,
911                            e.to_string(),
912                        ))
913                        .await;
914                }
915            }
916        }
917
918        let stmts = match self.parse_sql(&sql) {
919            Ok(stmts) => stmts,
920            Err(err) => {
921                return self.error(err).await;
922            }
923        };
924        if stmts.len() > 1 {
925            return self
926                .error(ErrorResponse::error(
927                    SqlState::INTERNAL_ERROR,
928                    "cannot insert multiple commands into a prepared statement",
929                ))
930                .await;
931        }
932        let (maybe_stmt, sql) = match stmts.into_iter().next() {
933            None => (None, ""),
934            Some(StatementParseResult { ast, sql }) => (Some(ast), sql),
935        };
936        if self.is_aborted_txn() && !is_txn_exit_stmt(maybe_stmt.as_ref()) {
937            return self.aborted_txn_error().await;
938        }
939        match self
940            .adapter_client
941            .prepare(name, maybe_stmt, sql.to_string(), param_types)
942            .await
943        {
944            Ok(()) => {
945                self.send(BackendMessage::ParseComplete).await?;
946                Ok(State::Ready)
947            }
948            Err(e) => self.error(e.into_response(Severity::Error)).await,
949        }
950    }
951
952    /// Commits and clears the current transaction.
953    #[instrument(level = "debug")]
954    async fn commit_transaction(&mut self) -> Result<(), io::Error> {
955        self.end_transaction(EndTransactionAction::Commit).await
956    }
957
958    /// Rollback and clears the current transaction.
959    #[instrument(level = "debug")]
960    async fn rollback_transaction(&mut self) -> Result<(), io::Error> {
961        self.end_transaction(EndTransactionAction::Rollback).await
962    }
963
964    /// End a transaction and report to the user if an error occurred.
965    #[instrument(level = "debug")]
966    async fn end_transaction(&mut self, action: EndTransactionAction) -> Result<(), io::Error> {
967        self.txn_needs_commit = false;
968        let resp = self.adapter_client.end_transaction(action).await;
969        if let Err(err) = resp {
970            self.send(BackendMessage::ErrorResponse(
971                err.into_response(Severity::Error),
972            ))
973            .await?;
974        }
975        Ok(())
976    }
977
978    #[instrument(level = "debug")]
979    async fn bind(
980        &mut self,
981        portal_name: String,
982        statement_name: String,
983        param_formats: Vec<Format>,
984        raw_params: Vec<Option<Vec<u8>>>,
985        result_formats: Vec<Format>,
986    ) -> Result<State, io::Error> {
987        // Start a transaction if we aren't in one.
988        self.ensure_transaction(1, "bind").await?;
989
990        let aborted_txn = self.is_aborted_txn();
991        let stmt = match self
992            .adapter_client
993            .get_prepared_statement(&statement_name)
994            .await
995        {
996            Ok(stmt) => stmt,
997            Err(err) => return self.error(err.into_response(Severity::Error)).await,
998        };
999
1000        let param_types = &stmt.desc().param_types;
1001        if param_types.len() != raw_params.len() {
1002            let message = format!(
1003                "bind message supplies {actual} parameters, \
1004                 but prepared statement \"{name}\" requires {expected}",
1005                name = statement_name,
1006                actual = raw_params.len(),
1007                expected = param_types.len()
1008            );
1009            return self
1010                .error(ErrorResponse::error(SqlState::PROTOCOL_VIOLATION, message))
1011                .await;
1012        }
1013        let param_formats = match pad_formats(param_formats, raw_params.len()) {
1014            Ok(param_formats) => param_formats,
1015            Err(msg) => {
1016                return self
1017                    .error(ErrorResponse::error(SqlState::PROTOCOL_VIOLATION, msg))
1018                    .await;
1019            }
1020        };
1021        if aborted_txn && !is_txn_exit_stmt(stmt.stmt()) {
1022            return self.aborted_txn_error().await;
1023        }
1024        let buf = RowArena::new();
1025        let mut params = vec![];
1026        for ((raw_param, mz_typ), format) in raw_params
1027            .into_iter()
1028            .zip_eq(param_types)
1029            .zip_eq(param_formats)
1030        {
1031            let pg_typ = mz_pgrepr::Type::from(mz_typ);
1032            let datum = match raw_param {
1033                None => Datum::Null,
1034                Some(bytes) => match mz_pgrepr::Value::decode(format, &pg_typ, &bytes) {
1035                    Ok(param) => param.into_datum(&buf, &pg_typ),
1036                    Err(err) => {
1037                        let msg = format!("unable to decode parameter: {}", err);
1038                        return self
1039                            .error(ErrorResponse::error(SqlState::INVALID_PARAMETER_VALUE, msg))
1040                            .await;
1041                    }
1042                },
1043            };
1044            params.push((datum, mz_typ.clone()))
1045        }
1046
1047        let result_formats = match pad_formats(
1048            result_formats,
1049            stmt.desc()
1050                .relation_desc
1051                .clone()
1052                .map(|desc| desc.typ().column_types.len())
1053                .unwrap_or(0),
1054        ) {
1055            Ok(result_formats) => result_formats,
1056            Err(msg) => {
1057                return self
1058                    .error(ErrorResponse::error(SqlState::PROTOCOL_VIOLATION, msg))
1059                    .await;
1060            }
1061        };
1062
1063        // Binary encodings are disabled for list, map, and aclitem types, but this doesn't
1064        // apply to COPY TO statements.
1065        if !stmt.stmt().map_or(false, |stmt| {
1066            matches!(
1067                stmt,
1068                Statement::Copy(CopyStatement {
1069                    direction: CopyDirection::To,
1070                    ..
1071                })
1072            )
1073        }) {
1074            if let Some(desc) = stmt.desc().relation_desc.clone() {
1075                for (format, ty) in result_formats.iter().zip_eq(desc.iter_types()) {
1076                    match (format, &ty.scalar_type) {
1077                        (Format::Binary, mz_repr::ScalarType::List { .. }) => {
1078                            return self
1079                                .error(ErrorResponse::error(
1080                                    SqlState::PROTOCOL_VIOLATION,
1081                                    "binary encoding of list types is not implemented",
1082                                ))
1083                                .await;
1084                        }
1085                        (Format::Binary, mz_repr::ScalarType::Map { .. }) => {
1086                            return self
1087                                .error(ErrorResponse::error(
1088                                    SqlState::PROTOCOL_VIOLATION,
1089                                    "binary encoding of map types is not implemented",
1090                                ))
1091                                .await;
1092                        }
1093                        (Format::Binary, mz_repr::ScalarType::AclItem) => {
1094                            return self
1095                                .error(ErrorResponse::error(
1096                                    SqlState::PROTOCOL_VIOLATION,
1097                                    "binary encoding of aclitem types does not exist",
1098                                ))
1099                                .await;
1100                        }
1101                        _ => (),
1102                    }
1103                }
1104            }
1105        }
1106
1107        let desc = stmt.desc().clone();
1108        let logging = Arc::clone(stmt.logging());
1109        let stmt_ast = stmt.stmt().cloned();
1110        let state_revision = stmt.state_revision;
1111        if let Err(err) = self.adapter_client.session().set_portal(
1112            portal_name,
1113            desc,
1114            stmt_ast,
1115            logging,
1116            params,
1117            result_formats,
1118            state_revision,
1119        ) {
1120            return self.error(err.into_response(Severity::Error)).await;
1121        }
1122
1123        self.send(BackendMessage::BindComplete).await?;
1124        Ok(State::Ready)
1125    }
1126
    /// Runs or resumes the portal `portal_name`, sending at most `max_rows` rows.
    ///
    /// The action depends on the portal's state:
    /// - `NotStarted`: a transaction is ensured, the adapter executes the portal,
    ///   and the response is forwarded through `send_execute_response`.
    /// - `InProgress`: the pending rows are streamed with `send_rows`.
    /// - `Completed(Some(tag))`: only the remembered command tag is sent (see the
    ///   FETCH note on that arm).
    /// - `Completed(None)`: an `OBJECT_NOT_IN_PREREQUISITE_STATE` error is sent.
    ///
    /// A missing portal yields `INVALID_CURSOR_NAME`. Inside an aborted
    /// transaction, anything other than a transaction-exit statement is rejected.
    ///
    /// `outer_ctx_extra` is the statement-logging context of an enclosing
    /// statement (`fetch` passes its own). In the `NotStarted` case it is handed
    /// to the adapter's `execute`; on every other path it is retired here.
    /// `received`, when present, becomes the portal's lifecycle timestamps; when
    /// absent, they are cleared. `get_response`, `fetch_portal_name`, and
    /// `timeout` are forwarded to the row-sending helpers.
    ///
    /// NOTE(review): presumably this returns a `BoxFuture` instead of being an
    /// `async fn` because it can be reached recursively (e.g. via `fetch`).
    /// Confirm before changing the signature.
    fn execute(
        &mut self,
        portal_name: String,
        max_rows: ExecuteCount,
        get_response: GetResponse,
        fetch_portal_name: Option<String>,
        timeout: ExecuteTimeout,
        outer_ctx_extra: Option<ExecuteContextExtra>,
        received: Option<EpochMillis>,
    ) -> BoxFuture<'_, Result<State, io::Error>> {
        async move {
            // Sampled up front: `portal` below mutably borrows the session until
            // its state is matched, which presumably rules out calling this later.
            let aborted_txn = self.is_aborted_txn();

            // Check if the portal has been started and can be continued.
            let portal = match self
                .adapter_client
                .session()
                .get_portal_unverified_mut(&portal_name)
            {
                Some(portal) => portal,
                None => {
                    let msg = format!("portal {} does not exist", portal_name.quoted());
                    if let Some(outer_ctx_extra) = outer_ctx_extra {
                        self.adapter_client.retire_execute(
                            outer_ctx_extra,
                            StatementEndedExecutionReason::Errored { error: msg.clone() },
                        );
                    }
                    return self
                        .error(ErrorResponse::error(SqlState::INVALID_CURSOR_NAME, msg))
                        .await;
                }
            };

            // Overwrites the portal's timestamps; `None` when `received` is `None`.
            *portal.lifecycle_timestamps = received.map(LifecycleTimestamps::new);

            // In an aborted transaction, reject all commands except COMMIT/ROLLBACK.
            let txn_exit_stmt = is_txn_exit_stmt(portal.stmt.as_deref());
            if aborted_txn && !txn_exit_stmt {
                if let Some(outer_ctx_extra) = outer_ctx_extra {
                    self.adapter_client.retire_execute(
                        outer_ctx_extra,
                        StatementEndedExecutionReason::Errored {
                            error: ABORTED_TXN_MSG.to_string(),
                        },
                    );
                }
                return self.aborted_txn_error().await;
            }

            // Cloned so the description can be used after the portal borrow ends.
            let row_desc = portal.desc.relation_desc.clone();
            match portal.state {
                PortalState::NotStarted => {
                    // Start a transaction if we aren't in one.
                    self.ensure_transaction(1, "execute").await?;
                    match self
                        .adapter_client
                        .execute(
                            portal_name.clone(),
                            self.conn.wait_closed(),
                            outer_ctx_extra,
                        )
                        .await
                    {
                        Ok((response, execute_started)) => {
                            self.send_pending_notices().await?;
                            self.send_execute_response(
                                response,
                                row_desc,
                                portal_name,
                                max_rows,
                                get_response,
                                fetch_portal_name,
                                timeout,
                                execute_started,
                            )
                            .await
                        }
                        Err(e) => {
                            self.send_pending_notices().await?;
                            self.error(e.into_response(Severity::Error)).await
                        }
                    }
                }
                PortalState::InProgress(rows) => {
                    // The pending rows are moved out of the portal for sending.
                    let rows = rows.take().expect("InProgress rows must be populated");
                    let (result, statement_ended_execution_reason) = match self
                        .send_rows(
                            row_desc.expect("portal missing row desc on resumption"),
                            portal_name,
                            rows,
                            max_rows,
                            get_response,
                            fetch_portal_name,
                            timeout,
                        )
                        .await
                    {
                        Err(e) => {
                            // This is an error communicating with the connection.
                            // We consider that to be a cancelation, rather than a query error.
                            (Err(e), StatementEndedExecutionReason::Canceled)
                        }
                        Ok((ok, SendRowsEndedReason::Canceled)) => {
                            (Ok(ok), StatementEndedExecutionReason::Canceled)
                        }
                        // NOTE: For now the values for `result_size` and
                        // `rows_returned` in fetches are a bit confusing.
                        // We record `Some(n)` for the first fetch, where `n` is
                        // the number of bytes/rows returned by the inner
                        // execute (regardless of how many rows the
                        // fetch fetched), and `None` for subsequent fetches.
                        //
                        // This arguably makes sense since the size/rows
                        // returned measures how much work the compute
                        // layer had to do to satisfy the query, but
                        // we should revisit it if/when we start
                        // logging the inner execute separately.
                        Ok((
                            ok,
                            SendRowsEndedReason::Success {
                                result_size: _,
                                rows_returned: _,
                            },
                        )) => (
                            Ok(ok),
                            StatementEndedExecutionReason::Success {
                                result_size: None,
                                rows_returned: None,
                                execution_strategy: None,
                            },
                        ),
                        Ok((ok, SendRowsEndedReason::Errored { error })) => {
                            (Ok(ok), StatementEndedExecutionReason::Errored { error })
                        }
                    };
                    if let Some(outer_ctx_extra) = outer_ctx_extra {
                        self.adapter_client
                            .retire_execute(outer_ctx_extra, statement_ended_execution_reason);
                    }
                    result
                }
                // FETCH is an awkward command for our current architecture. In Postgres it
                // will extract <count> rows from the target portal, cache them, and return
                // them to the user as requested. Its command tag is always FETCH <num rows
                // extracted>. In Materialize, since we have chosen to not fully support FETCH,
                // we must remember the number of rows that were returned. Use this tag to
                // remember that information and return it.
                PortalState::Completed(Some(tag)) => {
                    let tag = tag.to_string();
                    if let Some(outer_ctx_extra) = outer_ctx_extra {
                        self.adapter_client.retire_execute(
                            outer_ctx_extra,
                            StatementEndedExecutionReason::Success {
                                result_size: None,
                                rows_returned: None,
                                execution_strategy: None,
                            },
                        );
                    }
                    self.send(BackendMessage::CommandComplete { tag }).await?;
                    Ok(State::Ready)
                }
                // The portal already ran to completion with no tag to replay.
                PortalState::Completed(None) => {
                    let error = format!(
                        "portal {} cannot be run",
                        Ident::new_unchecked(portal_name).to_ast_string_stable()
                    );
                    if let Some(outer_ctx_extra) = outer_ctx_extra {
                        self.adapter_client.retire_execute(
                            outer_ctx_extra,
                            StatementEndedExecutionReason::Errored {
                                error: error.clone(),
                            },
                        );
                    }
                    self.error(ErrorResponse::error(
                        SqlState::OBJECT_NOT_IN_PREREQUISITE_STATE,
                        error,
                    ))
                    .await
                }
            }
        }
        .instrument(debug_span!("execute"))
        .boxed()
    }
1314
1315    #[instrument(level = "debug")]
1316    async fn describe_statement(&mut self, name: &str) -> Result<State, io::Error> {
1317        // Start a transaction if we aren't in one.
1318        self.ensure_transaction(1, "describe_statement").await?;
1319
1320        let stmt = match self.adapter_client.get_prepared_statement(name).await {
1321            Ok(stmt) => stmt,
1322            Err(err) => return self.error(err.into_response(Severity::Error)).await,
1323        };
1324        // Cloning to avoid a mutable borrow issue because `send` also uses `adapter_client`
1325        let parameter_desc = BackendMessage::ParameterDescription(
1326            stmt.desc()
1327                .param_types
1328                .iter()
1329                .map(mz_pgrepr::Type::from)
1330                .collect(),
1331        );
1332        // Claim that all results will be output in text format, even
1333        // though the true result formats are not yet known. A bit
1334        // weird, but this is the behavior that PostgreSQL specifies.
1335        let formats = vec![Format::Text; stmt.desc().arity()];
1336        let row_desc = describe_rows(stmt.desc(), &formats);
1337        self.send_all([parameter_desc, row_desc]).await?;
1338        Ok(State::Ready)
1339    }
1340
1341    #[instrument(level = "debug")]
1342    async fn describe_portal(&mut self, name: &str) -> Result<State, io::Error> {
1343        // Start a transaction if we aren't in one.
1344        self.ensure_transaction(1, "describe_portal").await?;
1345
1346        let session = self.adapter_client.session();
1347        let row_desc = session
1348            .get_portal_unverified(name)
1349            .map(|portal| describe_rows(&portal.desc, &portal.result_formats));
1350        match row_desc {
1351            Some(row_desc) => {
1352                self.send(row_desc).await?;
1353                Ok(State::Ready)
1354            }
1355            None => {
1356                self.error(ErrorResponse::error(
1357                    SqlState::INVALID_CURSOR_NAME,
1358                    format!("portal {} does not exist", name.quoted()),
1359                ))
1360                .await
1361            }
1362        }
1363    }
1364
1365    #[instrument(level = "debug")]
1366    async fn close_statement(&mut self, name: String) -> Result<State, io::Error> {
1367        self.adapter_client
1368            .session()
1369            .remove_prepared_statement(&name);
1370        self.send(BackendMessage::CloseComplete).await?;
1371        Ok(State::Ready)
1372    }
1373
1374    #[instrument(level = "debug")]
1375    async fn close_portal(&mut self, name: String) -> Result<State, io::Error> {
1376        self.adapter_client.session().remove_portal(&name);
1377        self.send(BackendMessage::CloseComplete).await?;
1378        Ok(State::Ready)
1379    }
1380
1381    fn complete_portal(&mut self, name: &str) {
1382        let portal = self
1383            .adapter_client
1384            .session()
1385            .get_portal_unverified_mut(name)
1386            .expect("portal should exist");
1387        *portal.state = PortalState::Completed(None);
1388    }
1389
1390    async fn fetch(
1391        &mut self,
1392        name: String,
1393        count: Option<FetchDirection>,
1394        max_rows: ExecuteCount,
1395        fetch_portal_name: Option<String>,
1396        timeout: ExecuteTimeout,
1397        ctx_extra: ExecuteContextExtra,
1398    ) -> Result<State, io::Error> {
1399        // Unlike Execute, no count specified in FETCH returns 1 row, and 0 means 0
1400        // instead of All.
1401        let count = count.unwrap_or(FetchDirection::ForwardCount(1));
1402
1403        // Figure out how many rows we should send back by looking at the various
1404        // combinations of the execute and fetch.
1405        //
1406        // In Postgres, Fetch will cache <count> rows from the target portal and
1407        // return those as requested (if, say, an Execute message was sent with a
1408        // max_rows < the Fetch's count). We expect that case to be incredibly rare and
1409        // so have chosen to not support it until users request it. This eases
1410        // implementation difficulty since we don't have to be able to "send" rows to
1411        // a buffer.
1412        //
1413        // TODO(mjibson): Test this somehow? Need to divide up the pgtest files in
1414        // order to have some that are not Postgres compatible.
1415        let count = match (max_rows, count) {
1416            (ExecuteCount::Count(max_rows), FetchDirection::ForwardCount(count)) => {
1417                let count = usize::cast_from(count);
1418                if max_rows < count {
1419                    let msg = "Execute with max_rows < a FETCH's count is not supported";
1420                    self.adapter_client.retire_execute(
1421                        ctx_extra,
1422                        StatementEndedExecutionReason::Errored {
1423                            error: msg.to_string(),
1424                        },
1425                    );
1426                    return self
1427                        .error(ErrorResponse::error(SqlState::FEATURE_NOT_SUPPORTED, msg))
1428                        .await;
1429                }
1430                ExecuteCount::Count(count)
1431            }
1432            (ExecuteCount::Count(_), FetchDirection::ForwardAll) => {
1433                let msg = "Execute with max_rows of a FETCH ALL is not supported";
1434                self.adapter_client.retire_execute(
1435                    ctx_extra,
1436                    StatementEndedExecutionReason::Errored {
1437                        error: msg.to_string(),
1438                    },
1439                );
1440                return self
1441                    .error(ErrorResponse::error(SqlState::FEATURE_NOT_SUPPORTED, msg))
1442                    .await;
1443            }
1444            (ExecuteCount::All, FetchDirection::ForwardAll) => ExecuteCount::All,
1445            (ExecuteCount::All, FetchDirection::ForwardCount(count)) => {
1446                ExecuteCount::Count(usize::cast_from(count))
1447            }
1448        };
1449        let cursor_name = name.to_string();
1450        self.execute(
1451            cursor_name,
1452            count,
1453            fetch_message,
1454            fetch_portal_name,
1455            timeout,
1456            Some(ctx_extra),
1457            None,
1458        )
1459        .await
1460    }
1461
    /// Flushes any buffered outgoing messages on the connection.
    ///
    /// On success it returns `State::Ready`, so message handlers can end with
    /// `self.flush().await`.
    async fn flush(&mut self) -> Result<State, io::Error> {
        self.conn.flush().await?;
        Ok(State::Ready)
    }
1466
1467    /// Sends a backend message to the client, after applying a severity filter.
1468    ///
1469    /// The message is only sent if its severity is above the severity set
1470    /// in the session, with the default value being NOTICE.
1471    #[instrument(level = "debug")]
1472    async fn send<M>(&mut self, message: M) -> Result<(), io::Error>
1473    where
1474        M: Into<BackendMessage>,
1475    {
1476        let message: BackendMessage = message.into();
1477        let is_error =
1478            matches!(&message, BackendMessage::ErrorResponse(e) if e.severity.is_error());
1479
1480        self.conn.send(message).await?;
1481
1482        // Flush immediately after sending an error response, as some clients
1483        // expect to be able to read the error response before sending a Sync
1484        // message. This is arguably in violation of the protocol specification,
1485        // but the specification is somewhat ambiguous, and easier to match
1486        // PostgreSQL here than to fix all the clients that have this
1487        // expectation.
1488        if is_error {
1489            self.conn.flush().await?;
1490        }
1491
1492        Ok(())
1493    }
1494
1495    #[instrument(level = "debug")]
1496    pub async fn send_all(
1497        &mut self,
1498        messages: impl IntoIterator<Item = BackendMessage>,
1499    ) -> Result<(), io::Error> {
1500        for m in messages {
1501            self.send(m).await?;
1502        }
1503        Ok(())
1504    }
1505
1506    #[instrument(level = "debug")]
1507    async fn sync(&mut self) -> Result<State, io::Error> {
1508        // Close the current transaction if we are in an implicit transaction.
1509        if self.adapter_client.session().transaction().is_implicit() {
1510            self.commit_transaction().await?;
1511        }
1512        self.ready().await
1513    }
1514
1515    #[instrument(level = "debug")]
1516    async fn ready(&mut self) -> Result<State, io::Error> {
1517        let txn_state = self.adapter_client.session().transaction().into();
1518        self.send(BackendMessage::ReadyForQuery(txn_state)).await?;
1519        self.flush().await
1520    }
1521
    /// Translates an [`ExecuteResponse`] into backend messages for the client.
    ///
    /// The command tag is computed from `response` up front. Arms that finish
    /// with `CommandComplete` consume it through the local `command_complete!`
    /// macro. The other paths delegate:
    /// - row-producing responses are streamed through [`Self::send_rows`];
    /// - `CopyTo` goes through `copy_rows`;
    /// - `Fetch` goes to `fetch`;
    /// - `CopyFrom` goes to `copy_from`.
    ///
    /// Parameters:
    /// - `response`: the adapter's result for the executed statement.
    /// - `row_desc`: description of the result rows. Required (the call panics
    ///   if it is `None`) for the `SendingRowsStreaming`, `SendingRowsImmediate`,
    ///   `Subscribing`, `CopyTo`, and `CopyFrom` responses.
    /// - `portal_name`: the portal being executed.
    /// - `max_rows`: the maximum number of rows the client asked for. It is
    ///   forwarded to `send_rows` / `fetch`.
    /// - `get_response`: builds the message that `send_rows` sends after the
    ///   rows.
    /// - `fetch_portal_name`: `Some` when this portal is being executed on
    ///   behalf of a FETCH.
    /// - `timeout`: how long `send_rows` waits for more rows. It is not used by
    ///   the `Fetch` arm, which carries its own timeout.
    /// - `execute_started`: when execution began. It is handed to
    ///   `RecordFirstRowStream`.
    ///
    /// # Panics
    ///
    /// Panics if `row_desc` is `None` for a response that requires it. Also
    /// panics if a tag produced by `response.tag()` is left unconsumed by an
    /// arm that reaches the end of this function.
    #[allow(clippy::too_many_arguments)]
    #[instrument(level = "debug")]
    async fn send_execute_response(
        &mut self,
        response: ExecuteResponse,
        row_desc: Option<RelationDesc>,
        portal_name: String,
        max_rows: ExecuteCount,
        get_response: GetResponse,
        fetch_portal_name: Option<String>,
        timeout: ExecuteTimeout,
        execute_started: Instant,
    ) -> Result<State, io::Error> {
        // Computed before `response` is moved into the `match` below.
        let mut tag = response.tag();

        // Sends `CommandComplete` with the precomputed tag (taking it out of
        // `tag`) and evaluates to `Ok(State::Ready)`.
        macro_rules! command_complete {
            () => {{
                self.send(BackendMessage::CommandComplete {
                    tag: tag
                        .take()
                        .expect("command_complete only called on tag-generating results"),
                })
                .await?;
                Ok(State::Ready)
            }};
        }

        let r = match response {
            ExecuteResponse::ClosedCursor => {
                self.complete_portal(&portal_name);
                command_complete!()
            }
            ExecuteResponse::DeclaredCursor => {
                self.complete_portal(&portal_name);
                command_complete!()
            }
            ExecuteResponse::EmptyQuery => {
                self.send(BackendMessage::EmptyQueryResponse).await?;
                Ok(State::Ready)
            }
            ExecuteResponse::Fetch {
                name,
                count,
                timeout,
                ctx_extra,
            } => {
                // `timeout` here is the FETCH statement's own timeout. It
                // shadows this function's `timeout` parameter. The current
                // portal is passed along as the fetch portal, because
                // `send_rows` takes result formats from the fetch portal when
                // one is given.
                self.fetch(
                    name,
                    count,
                    max_rows,
                    Some(portal_name.to_string()),
                    timeout,
                    ctx_extra,
                )
                .await
            }
            ExecuteResponse::SendingRowsStreaming {
                rows,
                instance_id,
                strategy,
            } => {
                let row_desc = row_desc
                    .expect("missing row description for ExecuteResponse::SendingRowsStreaming");

                let span = tracing::debug_span!("sending_rows_streaming");

                // The `SendRowsEndedReason` is discarded here. Unlike the
                // `Subscribing` arm, this arm does not call `retire_execute`.
                // Presumably execution is finalized elsewhere, as the `CopyTo`
                // arms note for the coordinator — TODO confirm.
                self.send_rows(
                    row_desc,
                    portal_name,
                    InProgressRows::new(RecordFirstRowStream::new(
                        Box::new(rows),
                        execute_started,
                        &self.adapter_client,
                        Some(instance_id),
                        Some(strategy),
                    )),
                    max_rows,
                    get_response,
                    fetch_portal_name,
                    timeout,
                )
                .instrument(span)
                .await
                .map(|(state, _)| state)
            }
            ExecuteResponse::SendingRowsImmediate { rows } => {
                let row_desc = row_desc
                    .expect("missing row description for ExecuteResponse::SendingRowsImmediate");

                let span = tracing::debug_span!("sending_rows_immediate");

                // Wrap the already-available rows in a single-item stream so
                // they can share the streaming path in `send_rows`.
                let stream =
                    futures::stream::once(futures::future::ready(PeekResponseUnary::Rows(rows)));
                self.send_rows(
                    row_desc,
                    portal_name,
                    InProgressRows::new(RecordFirstRowStream::new(
                        Box::new(stream),
                        execute_started,
                        &self.adapter_client,
                        None,
                        Some(StatementExecutionStrategy::Constant),
                    )),
                    max_rows,
                    get_response,
                    fetch_portal_name,
                    timeout,
                )
                .instrument(span)
                .await
                .map(|(state, _)| state)
            }
            ExecuteResponse::SetVariable { name, .. } => {
                // This code is somewhat awkwardly structured because we
                // can't hold `var` across an await point.
                // A ParameterStatus is only reported for variables in the
                // session's notify set.
                let qn = name.to_string();
                let msg = if let Some(var) = self
                    .adapter_client
                    .session()
                    .vars_mut()
                    .notify_set()
                    .find(|v| v.name() == qn)
                {
                    Some(BackendMessage::ParameterStatus(var.name(), var.value()))
                } else {
                    None
                };
                if let Some(msg) = msg {
                    self.send(msg).await?;
                }
                command_complete!()
            }
            ExecuteResponse::Subscribing {
                rx,
                ctx_extra,
                instance_id,
            } => {
                // When the SUBSCRIBE is not being read through FETCH, send a
                // notice warning that streaming its rows directly needs a
                // non-buffering client (with an extra hint for psql). Flush so
                // the notice is sent before any rows.
                if fetch_portal_name.is_none() {
                    let mut msg = ErrorResponse::notice(
                        SqlState::WARNING,
                        "streaming SUBSCRIBE rows directly requires a client that does not buffer output",
                    );
                    if self.adapter_client.session().vars().application_name() == "psql" {
                        msg.hint = Some(
                            "Wrap your SUBSCRIBE statement in `COPY (SUBSCRIBE ...) TO STDOUT`."
                                .into(),
                        )
                    }
                    self.send(msg).await?;
                    self.conn.flush().await?;
                }
                let row_desc =
                    row_desc.expect("missing row description for ExecuteResponse::Subscribing");
                // Stream the rows, then map how sending ended onto the reason
                // reported to `retire_execute` below.
                let (result, statement_ended_execution_reason) = match self
                    .send_rows(
                        row_desc,
                        portal_name,
                        InProgressRows::new(RecordFirstRowStream::new(
                            Box::new(UnboundedReceiverStream::new(rx)),
                            execute_started,
                            &self.adapter_client,
                            Some(instance_id),
                            None,
                        )),
                        max_rows,
                        get_response,
                        fetch_portal_name,
                        timeout,
                    )
                    .await
                {
                    Err(e) => {
                        // This is an error communicating with the connection.
                        // We consider that to be a cancelation, rather than a query error.
                        (Err(e), StatementEndedExecutionReason::Canceled)
                    }
                    Ok((ok, SendRowsEndedReason::Canceled)) => {
                        (Ok(ok), StatementEndedExecutionReason::Canceled)
                    }
                    Ok((
                        ok,
                        SendRowsEndedReason::Success {
                            result_size,
                            rows_returned,
                        },
                    )) => (
                        Ok(ok),
                        StatementEndedExecutionReason::Success {
                            result_size: Some(result_size),
                            rows_returned: Some(rows_returned),
                            execution_strategy: None,
                        },
                    ),
                    Ok((ok, SendRowsEndedReason::Errored { error })) => {
                        (Ok(ok), StatementEndedExecutionReason::Errored { error })
                    }
                };
                self.adapter_client
                    .retire_execute(ctx_extra, statement_ended_execution_reason);
                // Returns directly, which skips the `tag` assertion at the end
                // of this function.
                return result;
            }
            ExecuteResponse::CopyTo { format, resp } => {
                let row_desc =
                    row_desc.expect("missing row description for ExecuteResponse::CopyTo");
                // Only subscribe and row-sending inner responses can be copied
                // out. Every arm returns directly, which skips the `tag`
                // assertion at the end of this function.
                match *resp {
                    ExecuteResponse::Subscribing {
                        rx,
                        ctx_extra,
                        instance_id,
                    } => {
                        let (result, statement_ended_execution_reason) = match self
                            .copy_rows(
                                format,
                                row_desc,
                                RecordFirstRowStream::new(
                                    Box::new(UnboundedReceiverStream::new(rx)),
                                    execute_started,
                                    &self.adapter_client,
                                    Some(instance_id),
                                    None,
                                ),
                            )
                            .await
                        {
                            Err(e) => {
                                // This is an error communicating with the connection.
                                // We consider that to be a cancelation, rather than a query error.
                                (Err(e), StatementEndedExecutionReason::Canceled)
                            }
                            Ok((
                                state,
                                SendRowsEndedReason::Success {
                                    result_size,
                                    rows_returned,
                                },
                            )) => (
                                Ok(state),
                                StatementEndedExecutionReason::Success {
                                    result_size: Some(result_size),
                                    rows_returned: Some(rows_returned),
                                    execution_strategy: None,
                                },
                            ),
                            Ok((state, SendRowsEndedReason::Errored { error })) => {
                                (Ok(state), StatementEndedExecutionReason::Errored { error })
                            }
                            Ok((state, SendRowsEndedReason::Canceled)) => {
                                (Ok(state), StatementEndedExecutionReason::Canceled)
                            }
                        };
                        self.adapter_client
                            .retire_execute(ctx_extra, statement_ended_execution_reason);
                        return result;
                    }
                    ExecuteResponse::SendingRowsStreaming {
                        rows,
                        instance_id,
                        strategy,
                    } => {
                        // We don't need to finalize execution here;
                        // it was already done in the
                        // coordinator. Just extract the state and
                        // return that.
                        return self
                            .copy_rows(
                                format,
                                row_desc,
                                RecordFirstRowStream::new(
                                    Box::new(rows),
                                    execute_started,
                                    &self.adapter_client,
                                    Some(instance_id),
                                    Some(strategy),
                                ),
                            )
                            .await
                            .map(|(state, _)| state);
                    }
                    ExecuteResponse::SendingRowsImmediate { rows } => {
                        let span = tracing::debug_span!("sending_rows_immediate");

                        let rows = futures::stream::once(futures::future::ready(
                            PeekResponseUnary::Rows(rows),
                        ));
                        // We don't need to finalize execution here;
                        // it was already done in the
                        // coordinator. Just extract the state and
                        // return that.
                        return self
                            .copy_rows(
                                format,
                                row_desc,
                                RecordFirstRowStream::new(
                                    Box::new(rows),
                                    execute_started,
                                    &self.adapter_client,
                                    None,
                                    Some(StatementExecutionStrategy::Constant),
                                ),
                            )
                            .instrument(span)
                            .await
                            .map(|(state, _)| state);
                    }
                    _ => {
                        return self
                            .error(ErrorResponse::error(
                                SqlState::INTERNAL_ERROR,
                                "unsupported COPY response type".to_string(),
                            ))
                            .await;
                    }
                };
            }
            ExecuteResponse::CopyFrom {
                id,
                columns,
                params,
                ctx_extra,
            } => {
                let row_desc =
                    row_desc.expect("missing row description for ExecuteResponse::CopyFrom");
                self.copy_from(id, columns, params, row_desc, ctx_extra)
                    .await
            }
            ExecuteResponse::TransactionCommitted { params }
            | ExecuteResponse::TransactionRolledBack { params } => {
                let notify_set: mz_ore::collections::HashSet<String> = self
                    .adapter_client
                    .session()
                    .vars()
                    .notify_set()
                    .map(|v| v.name().to_string())
                    .collect();

                // Only report on parameters that are in the notify set.
                for (name, value) in params
                    .into_iter()
                    .filter(|(name, _v)| notify_set.contains(*name))
                {
                    let msg = BackendMessage::ParameterStatus(name, value);
                    self.send(msg).await?;
                }
                command_complete!()
            }

            ExecuteResponse::AlteredDefaultPrivileges
            | ExecuteResponse::AlteredObject(..)
            | ExecuteResponse::AlteredRole
            | ExecuteResponse::AlteredSystemConfiguration
            | ExecuteResponse::CreatedCluster { .. }
            | ExecuteResponse::CreatedClusterReplica { .. }
            | ExecuteResponse::CreatedConnection { .. }
            | ExecuteResponse::CreatedDatabase { .. }
            | ExecuteResponse::CreatedIndex { .. }
            | ExecuteResponse::CreatedIntrospectionSubscribe
            | ExecuteResponse::CreatedMaterializedView { .. }
            | ExecuteResponse::CreatedContinualTask { .. }
            | ExecuteResponse::CreatedRole
            | ExecuteResponse::CreatedSchema { .. }
            | ExecuteResponse::CreatedSecret { .. }
            | ExecuteResponse::CreatedSink { .. }
            | ExecuteResponse::CreatedSource { .. }
            | ExecuteResponse::CreatedTable { .. }
            | ExecuteResponse::CreatedType
            | ExecuteResponse::CreatedView { .. }
            | ExecuteResponse::CreatedViews { .. }
            | ExecuteResponse::CreatedNetworkPolicy
            | ExecuteResponse::Comment
            | ExecuteResponse::Deallocate { .. }
            | ExecuteResponse::Deleted(..)
            | ExecuteResponse::DiscardedAll
            | ExecuteResponse::DiscardedTemp
            | ExecuteResponse::DroppedObject(_)
            | ExecuteResponse::DroppedOwned
            | ExecuteResponse::GrantedPrivilege
            | ExecuteResponse::GrantedRole
            | ExecuteResponse::Inserted(..)
            | ExecuteResponse::Copied(..)
            | ExecuteResponse::Prepare
            | ExecuteResponse::Raised
            | ExecuteResponse::ReassignOwned
            | ExecuteResponse::RevokedPrivilege
            | ExecuteResponse::RevokedRole
            | ExecuteResponse::StartedTransaction { .. }
            | ExecuteResponse::Updated(..)
            | ExecuteResponse::ValidatedConnection => {
                command_complete!()
            }
        };

        // Arms that reach this point without calling `command_complete!` are
        // expected to come from responses whose `tag()` is `None`.
        assert_none!(tag, "tag created but not consumed: {:?}", tag);
        r
    }
1916
    /// Streams rows from `rows` to the client as `DataRow` messages.
    ///
    /// Sending stops when any of these happens:
    /// - `max_rows` rows have been sent;
    /// - the row stream ends;
    /// - the `timeout` deadline passes;
    /// - the stream reports an error or a cancellation.
    ///
    /// Session notices that arrive while waiting for rows are forwarded to the
    /// client. Any unsent remainder of a partially sent batch is kept in
    /// `rows.current`, and `rows` is stored back into the portal as
    /// `PortalState::InProgress` so the portal is not re-executed. Finally, the
    /// message built by `get_response` is sent.
    ///
    /// Result formats come from the portal named by `fetch_portal_name` when
    /// it is `Some` (execution on behalf of a FETCH). Otherwise they come from
    /// `portal_name`.
    ///
    /// `timeout` semantics:
    /// - `ExecuteTimeout::None`: wait for rows with no deadline.
    /// - `ExecuteTimeout::Seconds(t)`: stop waiting `t` seconds after this call
    ///   starts.
    /// - `ExecuteTimeout::WaitOnce`: wait with no deadline until a non-empty
    ///   batch arrives, then set the deadline to "now".
    ///
    /// Returns the next [`State`] together with a [`SendRowsEndedReason`]:
    /// - `Success` carries the number of rows sent and the sum of their
    ///   `byte_len()`;
    /// - `Errored` is returned for a row/description mismatch or a stream
    ///   error;
    /// - `Canceled` is returned for a stream cancellation.
    ///
    /// Returns `Err` if the connection closes while waiting for rows, or if
    /// writing to it fails.
    ///
    /// # Panics
    ///
    /// Panics in either of these cases:
    /// - `portal_name` (or `fetch_portal_name`, when given) does not name an
    ///   existing portal;
    /// - the number of result formats differs from the number of columns in
    ///   `row_desc` (`zip_eq`).
    #[allow(clippy::too_many_arguments)]
    // TODO(guswynn): figure out how to get it to compile without skip_all
    #[mz_ore::instrument(level = "debug")]
    async fn send_rows(
        &mut self,
        row_desc: RelationDesc,
        portal_name: String,
        mut rows: InProgressRows,
        max_rows: ExecuteCount,
        get_response: GetResponse,
        fetch_portal_name: Option<String>,
        timeout: ExecuteTimeout,
    ) -> Result<(State, SendRowsEndedReason), io::Error> {
        // If this portal is being executed from a FETCH then we need to use the result
        // format type of the outer portal.
        let result_format_portal_name: &str = if let Some(ref name) = fetch_portal_name {
            name
        } else {
            &portal_name
        };
        let result_formats = self
            .adapter_client
            .session()
            .get_portal_unverified(result_format_portal_name)
            .expect("valid fetch portal name for send rows")
            .result_formats
            .clone();

        // `deadline == None` means "wait indefinitely". `wait_once` arms a
        // zero-length deadline after the first non-empty batch (see below).
        let (mut wait_once, mut deadline) = match timeout {
            ExecuteTimeout::None => (false, None),
            ExecuteTimeout::Seconds(t) => (
                false,
                Some(tokio::time::Instant::now() + tokio::time::Duration::from_secs_f64(t)),
            ),
            ExecuteTimeout::WaitOnce => (true, None),
        };

        // Sanity check that the various `RelationDesc`s match up.
        {
            let portal_name_desc = &self
                .adapter_client
                .session()
                .get_portal_unverified(portal_name.as_str())
                .expect("portal should exist")
                .desc
                .relation_desc;
            if let Some(portal_name_desc) = portal_name_desc {
                soft_assert_eq_or_log!(portal_name_desc, &row_desc);
            }
            if let Some(fetch_portal_name) = &fetch_portal_name {
                let fetch_portal_desc = &self
                    .adapter_client
                    .session()
                    .get_portal_unverified(fetch_portal_name)
                    .expect("portal should exist")
                    .desc
                    .relation_desc;
                if let Some(fetch_portal_desc) = fetch_portal_desc {
                    soft_assert_eq_or_log!(fetch_portal_desc, &row_desc);
                }
            }
        }

        // Configure the connection's encoder with one (type, format) pair per
        // column. `zip_eq` panics if the counts differ.
        self.conn.set_encode_state(
            row_desc
                .typ()
                .column_types
                .iter()
                .map(|ty| mz_pgrepr::Type::from(&ty.scalar_type))
                .zip_eq(result_formats)
                .collect(),
        );

        let mut total_sent_rows = 0;
        let mut total_sent_bytes = 0;
        // want_rows is the maximum number of rows the client wants.
        let mut want_rows = match max_rows {
            ExecuteCount::All => usize::MAX,
            ExecuteCount::Count(count) => count,
        };

        // Send rows while the client still wants them and there are still rows to send.
        loop {
            // Fetch next batch of rows, waiting for a possible requested
            // timeout or notice.
            // Leftover rows stashed by an earlier call are always taken first.
            // Once the row budget is spent, nothing new is pulled. Otherwise,
            // wait for whichever comes first:
            // - the connection closing (returns `Err`);
            // - the deadline passing (only armed when `deadline` is `Some`);
            // - a session notice;
            // - the next item from the row stream.
            let batch = if rows.current.is_some() {
                FetchResult::Rows(rows.current.take())
            } else if want_rows == 0 {
                FetchResult::Rows(None)
            } else {
                let notice_fut = self.adapter_client.session().recv_notice();
                tokio::select! {
                    err = self.conn.wait_closed() => return Err(err),
                    _ = time::sleep_until(deadline.unwrap_or_else(tokio::time::Instant::now)), if deadline.is_some() => FetchResult::Rows(None),
                    notice = notice_fut => {
                        FetchResult::Notice(notice)
                    }
                    batch = rows.remaining.recv() => match batch {
                        None => FetchResult::Rows(None),
                        Some(PeekResponseUnary::Rows(rows)) => FetchResult::Rows(Some(rows)),
                        Some(PeekResponseUnary::Error(err)) => FetchResult::Error(err),
                        Some(PeekResponseUnary::Canceled) => FetchResult::Canceled,
                    },
                }
            };

            match batch {
                // The stream ended, the deadline passed, or the row budget is
                // spent.
                FetchResult::Rows(None) => break,
                FetchResult::Rows(Some(mut batch_rows)) => {
                    if let Err(err) = verify_datum_desc(&row_desc, &mut batch_rows) {
                        let msg = err.to_string();
                        return self
                            .error(err.into_response(Severity::Error))
                            .await
                            .map(|state| (state, SendRowsEndedReason::Errored { error: msg }));
                    }

                    // If wait_once is true: the first time this fn is called it blocks (same as
                    // deadline == None). The second time this fn is called it should behave the
                    // same as a 0s timeout.
                    if wait_once && batch_rows.peek().is_some() {
                        deadline = Some(tokio::time::Instant::now());
                        wait_once = false;
                    }

                    // Send a portion of the rows.
                    let mut sent_rows = 0;
                    let mut sent_bytes = 0;
                    // `take(want_rows)` comes after `inspect`, so the counters
                    // only see rows that are actually pulled for sending. Rows
                    // beyond the budget stay in `batch_rows`.
                    let messages = (&mut batch_rows)
                        // TODO(parkmycar): This is a fair bit of juggling between iterator types
                        // to count the total number of bytes. Alternatively we could track the
                        // total sent bytes in this .map(...) call, but having side effects in map
                        // is a code smell.
                        .map(|row| {
                            let row_len = row.byte_len();
                            let values = mz_pgrepr::values_from_row(row, row_desc.typ());
                            (row_len, BackendMessage::DataRow(values))
                        })
                        .inspect(|(row_len, _)| {
                            sent_bytes += row_len;
                            sent_rows += 1
                        })
                        .map(|(_row_len, row)| row)
                        .take(want_rows);
                    self.send_all(messages).await?;

                    total_sent_rows += sent_rows;
                    total_sent_bytes += sent_bytes;
                    want_rows -= sent_rows;

                    // If we have sent the number of requested rows, put the remainder of the batch
                    // (if any) back and stop sending.
                    if want_rows == 0 {
                        if batch_rows.peek().is_some() {
                            rows.current = Some(batch_rows);
                        }
                        break;
                    }

                    // Flush so this batch reaches the client before we wait
                    // for the next one.
                    self.conn.flush().await?;
                }
                FetchResult::Notice(notice) => {
                    self.send(notice.into_response()).await?;
                    self.conn.flush().await?;
                }
                FetchResult::Error(text) => {
                    return self
                        .error(ErrorResponse::error(SqlState::INTERNAL_ERROR, text.clone()))
                        .await
                        .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
                }
                FetchResult::Canceled => {
                    return self
                        .error(ErrorResponse::error(
                            SqlState::QUERY_CANCELED,
                            "canceling statement due to user request",
                        ))
                        .await
                        .map(|state| (state, SendRowsEndedReason::Canceled));
                }
            }
        }

        let portal = self
            .adapter_client
            .session()
            .get_portal_unverified_mut(&portal_name)
            .expect("valid portal name for send rows");

        // Read these before `rows` is moved into the portal state below.
        let saw_rows = rows.remaining.saw_rows;
        let no_more_rows = rows.no_more_rows();
        let recorded_first_row_instant = rows.remaining.recorded_first_row_instant;

        // Always return rows back, even if it's empty. This prevents an unclosed
        // portal from re-executing after it has been emptied.
        *portal.state = PortalState::InProgress(Some(rows));

        let fetch_portal = fetch_portal_name.map(|name| {
            self.adapter_client
                .session()
                .get_portal_unverified_mut(&name)
                .expect("valid fetch portal")
        });
        let response_message = get_response(max_rows, total_sent_rows, fetch_portal);
        self.send(response_message).await?;

        // Attend to metrics if there are no more rows.
        if no_more_rows {
            let statement_type = if let Some(stmt) = &self
                .adapter_client
                .session()
                .get_portal_unverified(&portal_name)
                .expect("valid portal name for send_rows")
                .stmt
            {
                metrics::statement_type_label_value(stmt.deref())
            } else {
                "no-statement"
            };
            let duration = if saw_rows {
                recorded_first_row_instant
                    .expect("recorded_first_row_instant because saw_rows")
                    .elapsed()
            } else {
                // If the result is empty, then we define time from first to last row as 0.
                // (Note that, currently, an empty result involves a PeekResponse with 0 rows, which
                // does flip `saw_rows`, so this code path is currently not exercised.)
                Duration::ZERO
            };
            self.adapter_client
                .inner()
                .metrics()
                .result_rows_first_to_last_byte_seconds
                .with_label_values(&[statement_type])
                .observe(duration.as_secs_f64());
        }

        Ok((
            State::Ready,
            SendRowsEndedReason::Success {
                result_size: u64::cast_from(total_sent_bytes),
                rows_returned: u64::cast_from(total_sent_rows),
            },
        ))
    }
2162
    /// Streams the rows produced by `stream` to the client using the COPY TO
    /// (copy-out) sub-protocol, encoding each row in `format`.
    ///
    /// Text and CSV output are advertised as text columns and binary output as
    /// binary columns in the `CopyOutResponse`. Parquet is rejected with an
    /// error before any `CopyOutResponse` is sent. Every row is encoded with
    /// `mz_pgcopy::encode_copy_format` and sent in its own `CopyData` message;
    /// once the stream is exhausted, `CopyDone` and a `COPY <n>` command tag
    /// are sent.
    ///
    /// Returns the next connection state together with a `SendRowsEndedReason`
    /// describing whether the copy succeeded (with the row count and the summed
    /// `RowRef::byte_len` of the rows), errored, or was canceled.
    ///
    /// # Errors
    ///
    /// Returns `Err` if the connection is closed while rows are being streamed,
    /// or if encoding, sending, or flushing fails.
    #[mz_ore::instrument(level = "debug")]
    async fn copy_rows(
        &mut self,
        format: CopyFormat,
        row_desc: RelationDesc,
        mut stream: RecordFirstRowStream,
    ) -> Result<(State, SendRowsEndedReason), io::Error> {
        // `row_format` configures the row encoder, while `encode_format` is the
        // format advertised to the client in `CopyOutResponse`.
        let (row_format, encode_format) = match format {
            CopyFormat::Text => (
                CopyFormatParams::Text(CopyTextFormatParams::default()),
                Format::Text,
            ),
            CopyFormat::Binary => (CopyFormatParams::Binary, Format::Binary),
            CopyFormat::Csv => (
                CopyFormatParams::Csv(CopyCsvFormatParams::default()),
                Format::Text,
            ),
            CopyFormat::Parquet => {
                // Not supported here. Fail before entering copy-out mode, since
                // no `CopyOutResponse` has been sent yet.
                let text = "Parquet format is not supported".to_string();
                return self
                    .error(ErrorResponse::error(SqlState::INTERNAL_ERROR, text.clone()))
                    .await
                    .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
            }
        };

        let encode_fn = |row: &RowRef, typ: &RelationType, out: &mut Vec<u8>| {
            mz_pgcopy::encode_copy_format(&row_format, row, typ, out)
        };

        // Every column is advertised with the same format as the overall copy.
        let typ = row_desc.typ();
        let column_formats = iter::repeat(encode_format)
            .take(typ.column_types.len())
            .collect();
        self.send(BackendMessage::CopyOutResponse {
            overall_format: encode_format,
            column_formats,
        })
        .await?;

        // In Postgres, binary copy has a header that is followed (in the same
        // CopyData) by the first row. In order to replicate their behavior, use a
        // common vec that we can extend one time now and then fill up with the encode
        // functions.
        let mut out = Vec::new();

        if let CopyFormat::Binary = format {
            // 11-byte signature.
            out.extend(b"PGCOPY\n\xFF\r\n\0");
            // 32-bit flags field.
            out.extend([0, 0, 0, 0]);
            // 32-bit header extension length field.
            out.extend([0, 0, 0, 0]);
        }

        // `count` is the number of rows in the batches received. `total_sent_bytes`
        // sums `RowRef::byte_len`, not the size of the encoded COPY data.
        let mut count = 0;
        let mut total_sent_bytes = 0;
        // Relay row batches and session notices until the row stream is
        // exhausted. Return early if the client disconnects, or if the stream
        // reports an error or a cancellation.
        loop {
            tokio::select! {
                e = self.conn.wait_closed() => return Err(e),
                batch = stream.recv() => match batch {
                    None => break,
                    Some(PeekResponseUnary::Error(text)) => {
                        return self
                            .error(ErrorResponse::error(SqlState::INTERNAL_ERROR, text.clone()))
                        .await
                        .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
                    }
                    Some(PeekResponseUnary::Canceled) => {
                        return self.error(ErrorResponse::error(
                                SqlState::QUERY_CANCELED,
                                "canceling statement due to user request",
                            ))
                            .await.map(|state| (state, SendRowsEndedReason::Canceled));
                    }
                    Some(PeekResponseUnary::Rows(mut rows)) => {
                        // Counted up front for the whole batch.
                        count += rows.count();
                        while let Some(row) = rows.next() {
                            total_sent_bytes += row.byte_len();
                            // One CopyData message per row. For binary output, the
                            // first message of the copy also carries the header
                            // written into `out` above.
                            encode_fn(row, typ, &mut out)?;
                            self.send(BackendMessage::CopyData(mem::take(&mut out)))
                                .await?;
                        }
                    }
                },
                notice = self.adapter_client.session().recv_notice() => {
                    self.send(notice.into_response())
                        .await?;
                    self.conn.flush().await?;
                }
            }

            self.conn.flush().await?;
        }
        // Send required trailers.
        // For binary output, if no rows were sent, `out` still holds the header,
        // so the header and trailer go out in the same CopyData message.
        if let CopyFormat::Binary = format {
            let trailer: i16 = -1;
            out.extend(trailer.to_be_bytes());
            self.send(BackendMessage::CopyData(mem::take(&mut out)))
                .await?;
        }

        let tag = format!("COPY {}", count);
        self.send(BackendMessage::CopyDone).await?;
        self.send(BackendMessage::CommandComplete { tag }).await?;
        Ok((
            State::Ready,
            SendRowsEndedReason::Success {
                result_size: u64::cast_from(total_sent_bytes),
                rows_returned: u64::cast_from(count),
            },
        ))
    }
2276
2277    /// Handles the copy-in mode of the postgres protocol from transferring
2278    /// data to the server.
2279    #[instrument(level = "debug")]
2280    async fn copy_from(
2281        &mut self,
2282        id: CatalogItemId,
2283        columns: Vec<ColumnIndex>,
2284        params: CopyFormatParams<'_>,
2285        row_desc: RelationDesc,
2286        mut ctx_extra: ExecuteContextExtra,
2287    ) -> Result<State, io::Error> {
2288        let res = self
2289            .copy_from_inner(id, columns, params, row_desc, &mut ctx_extra)
2290            .await;
2291        match &res {
2292            Ok(State::Done) => {
2293                // The connection closed gracefully without sending us a `CopyDone`,
2294                // causing us to just drop the copy request.
2295                // For the purposes of statement logging, we count this as a cancellation.
2296                self.adapter_client
2297                    .retire_execute(ctx_extra, StatementEndedExecutionReason::Canceled);
2298            }
2299            Err(e) => {
2300                self.adapter_client.retire_execute(
2301                    ctx_extra,
2302                    StatementEndedExecutionReason::Errored {
2303                        error: format!("{e}"),
2304                    },
2305                );
2306            }
2307            other => {
2308                tracing::warn!(?other, "aborting COPY FROM");
2309                self.adapter_client
2310                    .retire_execute(ctx_extra, StatementEndedExecutionReason::Aborted);
2311            }
2312        }
2313        res
2314    }
2315
2316    async fn copy_from_inner(
2317        &mut self,
2318        id: CatalogItemId,
2319        columns: Vec<ColumnIndex>,
2320        params: CopyFormatParams<'_>,
2321        row_desc: RelationDesc,
2322        ctx_extra: &mut ExecuteContextExtra,
2323    ) -> Result<State, io::Error> {
2324        let typ = row_desc.typ();
2325        let column_formats = vec![Format::Text; typ.column_types.len()];
2326        self.send(BackendMessage::CopyInResponse {
2327            overall_format: Format::Text,
2328            column_formats,
2329        })
2330        .await?;
2331        self.conn.flush().await?;
2332
2333        let system_vars = self.adapter_client.get_system_vars().await;
2334        let max_size = system_vars
2335            .get(MAX_COPY_FROM_SIZE.name())
2336            .ok()
2337            .and_then(|max_size| max_size.value().parse().ok())
2338            .unwrap_or(usize::MAX);
2339        tracing::debug!("COPY FROM max buffer size: {max_size} bytes");
2340
2341        let mut data = Vec::new();
2342        loop {
2343            let message = self.conn.recv().await?;
2344            match message {
2345                Some(FrontendMessage::CopyData(buf)) => {
2346                    // Bail before we OOM.
2347                    if (data.len() + buf.len()) > max_size {
2348                        return self
2349                            .error(ErrorResponse::error(
2350                                SqlState::INSUFFICIENT_RESOURCES,
2351                                "COPY FROM STDIN too large",
2352                            ))
2353                            .await;
2354                    }
2355                    data.extend(buf)
2356                }
2357                Some(FrontendMessage::CopyDone) => break,
2358                Some(FrontendMessage::CopyFail(err)) => {
2359                    self.adapter_client.retire_execute(
2360                        std::mem::take(ctx_extra),
2361                        StatementEndedExecutionReason::Canceled,
2362                    );
2363                    return self
2364                        .error(ErrorResponse::error(
2365                            SqlState::QUERY_CANCELED,
2366                            format!("COPY from stdin failed: {}", err),
2367                        ))
2368                        .await;
2369                }
2370                Some(FrontendMessage::Flush) | Some(FrontendMessage::Sync) => {}
2371                Some(_) => {
2372                    let msg = "unexpected message type during COPY from stdin";
2373                    self.adapter_client.retire_execute(
2374                        std::mem::take(ctx_extra),
2375                        StatementEndedExecutionReason::Errored {
2376                            error: msg.to_string(),
2377                        },
2378                    );
2379                    return self
2380                        .error(ErrorResponse::error(SqlState::PROTOCOL_VIOLATION, msg))
2381                        .await;
2382                }
2383                None => {
2384                    return Ok(State::Done);
2385                }
2386            }
2387        }
2388
2389        let column_types = typ
2390            .column_types
2391            .iter()
2392            .map(|x| &x.scalar_type)
2393            .map(mz_pgrepr::Type::from)
2394            .collect::<Vec<mz_pgrepr::Type>>();
2395
2396        let rows = match mz_pgcopy::decode_copy_format(&data, &column_types, params) {
2397            Ok(rows) => rows,
2398            Err(e) => {
2399                self.adapter_client.retire_execute(
2400                    std::mem::take(ctx_extra),
2401                    StatementEndedExecutionReason::Errored {
2402                        error: e.to_string(),
2403                    },
2404                );
2405                return self
2406                    .error(ErrorResponse::error(
2407                        SqlState::BAD_COPY_FILE_FORMAT,
2408                        format!("{}", e),
2409                    ))
2410                    .await;
2411            }
2412        };
2413
2414        let count = rows.len();
2415
2416        if let Err(e) = self
2417            .adapter_client
2418            .insert_rows(id, columns, rows, std::mem::take(ctx_extra))
2419            .await
2420        {
2421            self.adapter_client.retire_execute(
2422                std::mem::take(ctx_extra),
2423                StatementEndedExecutionReason::Errored {
2424                    error: e.to_string(),
2425                },
2426            );
2427            return self.error(e.into_response(Severity::Error)).await;
2428        }
2429
2430        let tag = format!("COPY {}", count);
2431        self.send(BackendMessage::CommandComplete { tag }).await?;
2432
2433        Ok(State::Ready)
2434    }
2435
2436    #[instrument(level = "debug")]
2437    async fn send_pending_notices(&mut self) -> Result<(), io::Error> {
2438        let notices = self
2439            .adapter_client
2440            .session()
2441            .drain_notices()
2442            .into_iter()
2443            .map(|notice| BackendMessage::ErrorResponse(notice.into_response()));
2444        self.send_all(notices).await?;
2445        Ok(())
2446    }
2447
2448    #[instrument(level = "debug")]
2449    async fn error(&mut self, err: ErrorResponse) -> Result<State, io::Error> {
2450        assert!(err.severity.is_error());
2451        debug!(
2452            "cid={} error code={}",
2453            self.adapter_client.session().conn_id(),
2454            err.code.code()
2455        );
2456        let is_fatal = err.severity.is_fatal();
2457        self.send(BackendMessage::ErrorResponse(err)).await?;
2458
2459        let txn = self.adapter_client.session().transaction();
2460        match txn {
2461            // Error can be called from describe and parse and so might not be in an active
2462            // transaction.
2463            TransactionStatus::Default | TransactionStatus::Failed(_) => {}
2464            // In Started (i.e., a single statement), cleanup ourselves.
2465            TransactionStatus::Started(_) => {
2466                self.rollback_transaction().await?;
2467            }
2468            // Implicit transactions also clear themselves.
2469            TransactionStatus::InTransactionImplicit(_) => {
2470                self.rollback_transaction().await?;
2471            }
2472            // Explicit transactions move to failed.
2473            TransactionStatus::InTransaction(_) => {
2474                self.adapter_client.fail_transaction();
2475            }
2476        };
2477        if is_fatal {
2478            Ok(State::Done)
2479        } else {
2480            Ok(State::Drain)
2481        }
2482    }
2483
2484    #[instrument(level = "debug")]
2485    async fn aborted_txn_error(&mut self) -> Result<State, io::Error> {
2486        self.send(BackendMessage::ErrorResponse(ErrorResponse::error(
2487            SqlState::IN_FAILED_SQL_TRANSACTION,
2488            ABORTED_TXN_MSG,
2489        )))
2490        .await?;
2491        Ok(State::Drain)
2492    }
2493
2494    fn is_aborted_txn(&mut self) -> bool {
2495        matches!(
2496            self.adapter_client.session().transaction(),
2497            TransactionStatus::Failed(_)
2498        )
2499    }
2500}
2501
2502fn pad_formats(formats: Vec<Format>, n: usize) -> Result<Vec<Format>, String> {
2503    match (formats.len(), n) {
2504        (0, e) => Ok(vec![Format::Text; e]),
2505        (1, e) => Ok(iter::repeat(formats[0]).take(e).collect()),
2506        (a, e) if a == e => Ok(formats),
2507        (a, e) => Err(format!(
2508            "expected {} field format specifiers, but got {}",
2509            e, a
2510        )),
2511    }
2512}
2513
2514fn describe_rows(stmt_desc: &StatementDesc, formats: &[Format]) -> BackendMessage {
2515    match &stmt_desc.relation_desc {
2516        Some(desc) if !stmt_desc.is_copy => {
2517            BackendMessage::RowDescription(message::encode_row_description(desc, formats))
2518        }
2519        _ => BackendMessage::NoData,
2520    }
2521}
2522
/// Callback used by `send_rows` to build the message that follows the rows it
/// has sent.
///
/// It receives the row limit (`max_rows`), the number of rows sent
/// (`total_sent_rows`), and an optional mutable portal reference. See
/// `portal_exec_message` and `fetch_message` for the implementations in this
/// file; only `fetch_message` uses the portal.
type GetResponse = fn(
    max_rows: ExecuteCount,
    total_sent_rows: usize,
    fetch_portal: Option<PortalRefMut>,
) -> BackendMessage;
2528
2529// A GetResponse used by send_rows during execute messages on portals or for
2530// simple query messages.
2531fn portal_exec_message(
2532    max_rows: ExecuteCount,
2533    total_sent_rows: usize,
2534    _fetch_portal: Option<PortalRefMut>,
2535) -> BackendMessage {
2536    // If max_rows is not specified, we will always send back a CommandComplete. If
2537    // max_rows is specified, we only send CommandComplete if there were more rows
2538    // requested than were remaining. That is, if max_rows == number of rows that
2539    // were remaining before sending (not that are remaining after sending), then
2540    // we still send a PortalSuspended. The number of remaining rows after the rows
2541    // have been sent doesn't matter. This matches postgres.
2542    match max_rows {
2543        ExecuteCount::Count(max_rows) if max_rows <= total_sent_rows => {
2544            BackendMessage::PortalSuspended
2545        }
2546        _ => BackendMessage::CommandComplete {
2547            tag: format!("SELECT {}", total_sent_rows),
2548        },
2549    }
2550}
2551
2552// A GetResponse used by send_rows during FETCH queries.
2553fn fetch_message(
2554    _max_rows: ExecuteCount,
2555    total_sent_rows: usize,
2556    fetch_portal: Option<PortalRefMut>,
2557) -> BackendMessage {
2558    let tag = format!("FETCH {}", total_sent_rows);
2559    if let Some(portal) = fetch_portal {
2560        *portal.state = PortalState::Completed(Some(tag.clone()));
2561    }
2562    BackendMessage::CommandComplete { tag }
2563}
2564
/// A row limit: either unbounded or a fixed maximum number of rows.
#[derive(Clone, Copy, Debug)]
enum ExecuteCount {
    /// No row limit.
    All,
    /// An explicit maximum row count.
    Count(usize),
}
2570
2571// See postgres' backend/tcop/postgres.c IsTransactionExitStmt.
2572fn is_txn_exit_stmt(stmt: Option<&Statement<Raw>>) -> bool {
2573    match stmt {
2574        // Add PREPARE to this if we ever support it.
2575        Some(stmt) => matches!(stmt, Statement::Commit(_) | Statement::Rollback(_)),
2576        None => false,
2577    }
2578}
2579
/// One outcome observed while fetching query output.
///
/// NOTE(review): the producer and consumer of this type are outside this
/// chunk. The variant descriptions below are inferred from their names and
/// payload types; confirm them against the code that builds `FetchResult`.
#[derive(Debug)]
enum FetchResult {
    /// A batch of rows. Presumably `None` signals that no (more) rows are
    /// available; TODO confirm.
    Rows(Option<Box<dyn RowIterator + Send + Sync>>),
    /// Presumably: the query was canceled.
    Canceled,
    /// An error, carried as its message text.
    Error(String),
    /// An adapter notice to be forwarded.
    Notice(AdapterNotice),
}
2587
#[cfg(test)]
mod test {
    use super::*;

    /// Table-driven checks for `parse_options`: each case pairs an input with
    /// either the expected `(key, value)` pairs or `Err(())`.
    #[mz_ore::test]
    fn test_parse_options() {
        let cases: Vec<(&str, Result<Vec<(&str, &str)>, ()>)> = vec![
            ("", Ok(vec![])),
            ("--key", Err(())),
            ("--key=val", Ok(vec![("key", "val")])),
            (
                r#"--key=val -ckey2=val2 -c key3=val3 -c key4=val4 -ckey5=val5"#,
                Ok(vec![
                    ("key", "val"),
                    ("key2", "val2"),
                    ("key3", "val3"),
                    ("key4", "val4"),
                    ("key5", "val5"),
                ]),
            ),
            (r#"-c\ key=val"#, Ok(vec![(" key", "val")])),
            ("--key=val -ckey2 val2", Err(())),
            // Unclear what this should do.
            ("--key=", Ok(vec![("key", "")])),
        ];
        for (input, expected) in cases {
            let got = parse_options(input);
            let expected = expected.map(|pairs| {
                pairs
                    .into_iter()
                    .map(|(key, value)| (key.to_owned(), value.to_owned()))
                    .collect()
            });
            assert_eq!(got, expected, "input: {}", input);
        }
    }

    /// Table-driven checks for `parse_option` on a single option token.
    #[mz_ore::test]
    fn test_parse_option() {
        let cases = [
            ("", Err(())),
            ("--", Err(())),
            ("--c", Err(())),
            ("a=b", Err(())),
            ("--a=b", Ok(("a", "b"))),
            ("--ca=b", Ok(("ca", "b"))),
            ("-ca=b", Ok(("a", "b"))),
            // Unclear what this should error, but at least test it.
            ("--=", Ok(("", ""))),
        ];
        for (input, expected) in cases {
            let got = parse_option(input);
            assert_eq!(got, expected, "input: {}", input);
        }
    }

    /// Table-driven checks for `split_options`, covering whitespace runs and
    /// backslash escapes.
    #[mz_ore::test]
    fn test_split_options() {
        let cases: Vec<(&str, Vec<&str>)> = vec![
            ("", vec![]),
            ("  ", vec![]),
            (" a ", vec!["a"]),
            ("  ab     cd   ", vec!["ab", "cd"]),
            (r#"  ab\     cd   "#, vec!["ab ", "cd"]),
            (r#"  ab\\     cd   "#, vec![r#"ab\"#, "cd"]),
            (r#"  ab\\\     cd   "#, vec![r#"ab\ "#, "cd"]),
            (r#"  ab\\\ cd   "#, vec![r#"ab\ cd"#]),
            (r#"  ab\\\cd   "#, vec![r#"ab\cd"#]),
            (r#"a\"#, vec!["a"]),
            (r#"a\ "#, vec!["a "]),
            (r#"\"#, vec![]),
            (r#"\ "#, vec![r#" "#]),
            (r#" \ "#, vec![r#" "#]),
            (r#"\  "#, vec![r#" "#]),
        ];
        for (input, expected) in cases {
            let got = split_options(input);
            assert_eq!(got, expected, "input: {}", input);
        }
    }
}