mz_pgwire/
protocol.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::collections::BTreeMap;
11use std::convert::TryFrom;
12use std::future::Future;
13use std::ops::Deref;
14use std::sync::Arc;
15use std::time::{Duration, Instant};
16use std::{iter, mem};
17
18use base64::prelude::*;
19use byteorder::{ByteOrder, NetworkEndian};
20use futures::future::{BoxFuture, FutureExt, pending};
21use itertools::Itertools;
22use mz_adapter::client::RecordFirstRowStream;
23use mz_adapter::session::{
24    EndTransactionAction, InProgressRows, LifecycleTimestamps, PortalRefMut, PortalState,
25    SessionConfig, TransactionStatus,
26};
27use mz_adapter::statement_logging::{StatementEndedExecutionReason, StatementExecutionStrategy};
28use mz_adapter::{
29    AdapterError, AdapterNotice, ExecuteContextExtra, ExecuteResponse, PeekResponseUnary, metrics,
30    verify_datum_desc,
31};
32use mz_auth::password::Password;
33use mz_authenticator::Authenticator;
34use mz_ore::cast::CastFrom;
35use mz_ore::netio::AsyncReady;
36use mz_ore::now::{EpochMillis, SYSTEM_TIME};
37use mz_ore::str::StrExt;
38use mz_ore::{assert_none, assert_ok, instrument, soft_assert_eq_or_log};
39use mz_pgcopy::{CopyCsvFormatParams, CopyFormatParams, CopyTextFormatParams};
40use mz_pgwire_common::{
41    ConnectionCounter, Cursor, ErrorResponse, Format, FrontendMessage, Severity, VERSION_3,
42    VERSIONS,
43};
44use mz_repr::user::InternalUserMetadata;
45use mz_repr::{
46    CatalogItemId, ColumnIndex, Datum, RelationDesc, RowArena, RowIterator, RowRef,
47    SqlRelationType, SqlScalarType,
48};
49use mz_server_core::TlsMode;
50use mz_server_core::listeners::AllowedRoles;
51use mz_sql::ast::display::AstDisplay;
52use mz_sql::ast::{CopyDirection, CopyStatement, FetchDirection, Ident, Raw, Statement};
53use mz_sql::parse::StatementParseResult;
54use mz_sql::plan::{CopyFormat, ExecuteTimeout, StatementDesc};
55use mz_sql::session::metadata::SessionMetadata;
56use mz_sql::session::user::INTERNAL_USER_NAMES;
57use mz_sql::session::vars::{MAX_COPY_FROM_SIZE, Var, VarInput};
58use postgres::error::SqlState;
59use tokio::io::{self, AsyncRead, AsyncWrite};
60use tokio::select;
61use tokio::time::{self};
62use tokio_metrics::TaskMetrics;
63use tokio_stream::wrappers::UnboundedReceiverStream;
64use tracing::{Instrument, debug, debug_span, warn};
65use uuid::Uuid;
66
67use crate::codec::{
68    FramedConn, decode_password, decode_sasl_initial_response, decode_sasl_response,
69};
70use crate::message::{
71    self, BackendMessage, SASLServerFinalMessage, SASLServerFinalMessageKinds,
72    SASLServerFirstMessage,
73};
74
75/// Reports whether the given stream begins with a pgwire handshake.
76///
77/// To avoid false negatives, there must be at least eight bytes in `buf`.
78pub fn match_handshake(buf: &[u8]) -> bool {
79    // The pgwire StartupMessage looks like this:
80    //
81    //     i32 - Length of entire message.
82    //     i32 - Protocol version number.
83    //     [String] - Arbitrary key-value parameters of any length.
84    //
85    // Since arbitrary parameters can be included in the StartupMessage, the
86    // first Int32 is worthless, since the message could have any length.
87    // Instead, we sniff the protocol version number.
88    if buf.len() < 8 {
89        return false;
90    }
91    let version = NetworkEndian::read_i32(&buf[4..8]);
92    VERSIONS.contains(&version)
93}
94
95/// Parameters for the [`run`] function.
96pub struct RunParams<'a, A, I>
97where
98    I: Iterator<Item = TaskMetrics> + Send,
99{
100    /// The TLS mode of the pgwire server.
101    pub tls_mode: Option<TlsMode>,
102    /// A client for the adapter.
103    pub adapter_client: mz_adapter::Client,
104    /// The connection to the client.
105    pub conn: &'a mut FramedConn<A>,
106    /// The universally unique identifier for the connection.
107    pub conn_uuid: Uuid,
108    /// The protocol version that the client provided in the startup message.
109    pub version: i32,
110    /// The parameters that the client provided in the startup message.
111    pub params: BTreeMap<String, String>,
112    /// Authentication method to use. Frontegg, Password, or None.
113    pub authenticator: Authenticator,
114    /// Global connection limit and count
115    pub active_connection_counter: ConnectionCounter,
116    /// Helm chart version
117    pub helm_chart_version: Option<String>,
118    /// Whether to allow reserved users (ie: mz_system).
119    pub allowed_roles: AllowedRoles,
120    /// Tokio metrics
121    pub tokio_metrics_intervals: I,
122}
123
124/// Runs a pgwire connection to completion.
125///
126/// This involves responding to `FrontendMessage::StartupMessage` and all future
127/// requests until the client terminates the connection or a fatal error occurs.
128///
129/// Note that this function returns successfully even upon delivering a fatal
130/// error to the client. It only returns `Err` if an unexpected I/O error occurs
131/// while communicating with the client, e.g., if the connection is severed in
132/// the middle of a request.
133#[mz_ore::instrument(level = "debug")]
134pub async fn run<'a, A, I>(
135    RunParams {
136        tls_mode,
137        adapter_client,
138        conn,
139        conn_uuid,
140        version,
141        mut params,
142        authenticator,
143        active_connection_counter,
144        helm_chart_version,
145        allowed_roles,
146        tokio_metrics_intervals,
147    }: RunParams<'a, A, I>,
148) -> Result<(), io::Error>
149where
150    A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
151    I: Iterator<Item = TaskMetrics> + Send,
152{
153    if version != VERSION_3 {
154        return conn
155            .send(ErrorResponse::fatal(
156                SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
157                "server does not support the client's requested protocol version",
158            ))
159            .await;
160    }
161
162    let user = params.remove("user").unwrap_or_else(String::new);
163
164    // TODO move this somewhere it can be shared with HTTP
165    let is_internal_user = INTERNAL_USER_NAMES.contains(&user);
166    // this is a superset of internal users
167    let is_reserved_user = mz_adapter::catalog::is_reserved_role_name(user.as_str());
168    let role_allowed = match allowed_roles {
169        AllowedRoles::Normal => !is_reserved_user,
170        AllowedRoles::Internal => is_internal_user,
171        AllowedRoles::NormalAndInternal => !is_reserved_user || is_internal_user,
172    };
173    if !role_allowed {
174        let msg = format!("unauthorized login to user '{user}'");
175        return conn
176            .send(ErrorResponse::fatal(SqlState::INSUFFICIENT_PRIVILEGE, msg))
177            .await;
178    }
179
180    if let Err(err) = conn.inner().ensure_tls_compatibility(&tls_mode) {
181        return conn.send(err).await;
182    }
183
184    let (mut session, expired) = match authenticator {
185        Authenticator::Frontegg(frontegg) => {
186            conn.send(BackendMessage::AuthenticationCleartextPassword)
187                .await?;
188            conn.flush().await?;
189            let password = match conn.recv().await? {
190                Some(FrontendMessage::RawAuthentication(data)) => {
191                    match decode_password(Cursor::new(&data)).ok() {
192                        Some(FrontendMessage::Password { password }) => password,
193                        _ => {
194                            return conn
195                                .send(ErrorResponse::fatal(
196                                    SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
197                                    "expected Password message",
198                                ))
199                                .await;
200                        }
201                    }
202                }
203                _ => {
204                    return conn
205                        .send(ErrorResponse::fatal(
206                            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
207                            "expected Password message",
208                        ))
209                        .await;
210                }
211            };
212
213            let auth_response = frontegg.authenticate(&user, &password).await;
214            match auth_response {
215                Ok(mut auth_session) => {
216                    // Create a session based on the auth session.
217                    //
218                    // In particular, it's important that the username come from the
219                    // auth session, as Frontegg may return an email address with
220                    // different casing than the user supplied via the pgwire
221                    // username field. We want to use the Frontegg casing as
222                    // canonical.
223                    let session = adapter_client.new_session(SessionConfig {
224                        conn_id: conn.conn_id().clone(),
225                        uuid: conn_uuid,
226                        user: auth_session.user().into(),
227                        client_ip: conn.peer_addr().clone(),
228                        external_metadata_rx: Some(auth_session.external_metadata_rx()),
229                        internal_user_metadata: None,
230                        helm_chart_version,
231                    });
232                    let expired = async move { auth_session.expired().await };
233                    (session, expired.left_future())
234                }
235                Err(err) => {
236                    warn!(?err, "pgwire connection failed authentication");
237                    return conn
238                        .send(ErrorResponse::fatal(
239                            SqlState::INVALID_PASSWORD,
240                            "invalid password",
241                        ))
242                        .await;
243                }
244            }
245        }
246        Authenticator::Password(adapter_client) => {
247            conn.send(BackendMessage::AuthenticationCleartextPassword)
248                .await?;
249            conn.flush().await?;
250            let password = match conn.recv().await? {
251                Some(FrontendMessage::RawAuthentication(data)) => {
252                    match decode_password(Cursor::new(&data)).ok() {
253                        Some(FrontendMessage::Password { password }) => Password(password),
254                        _ => {
255                            return conn
256                                .send(ErrorResponse::fatal(
257                                    SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
258                                    "expected Password message",
259                                ))
260                                .await;
261                        }
262                    }
263                }
264                _ => {
265                    return conn
266                        .send(ErrorResponse::fatal(
267                            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
268                            "expected Password message",
269                        ))
270                        .await;
271                }
272            };
273            let auth_response = match adapter_client.authenticate(&user, &password).await {
274                Ok(resp) => resp,
275                Err(err) => {
276                    warn!(?err, "pgwire connection failed authentication");
277                    return conn
278                        .send(ErrorResponse::fatal(
279                            SqlState::INVALID_PASSWORD,
280                            "invalid password",
281                        ))
282                        .await;
283                }
284            };
285            let session = adapter_client.new_session(SessionConfig {
286                conn_id: conn.conn_id().clone(),
287                uuid: conn_uuid,
288                user,
289                client_ip: conn.peer_addr().clone(),
290                external_metadata_rx: None,
291                internal_user_metadata: Some(InternalUserMetadata {
292                    superuser: auth_response.superuser,
293                }),
294                helm_chart_version,
295            });
296            // No frontegg check, so auth session lasts indefinitely.
297            let auth_session = pending().right_future();
298            (session, auth_session)
299        }
300        Authenticator::Sasl(adapter_client) => {
301            // Start the handshake
302            conn.send(BackendMessage::AuthenticationSASL).await?;
303            conn.flush().await?;
304            // Get the initial response indicating chosen mechanism
305            let (mechanism, initial_response) = match conn.recv().await? {
306                Some(FrontendMessage::RawAuthentication(data)) => {
307                    match decode_sasl_initial_response(Cursor::new(&data)).ok() {
308                        Some(FrontendMessage::SASLInitialResponse {
309                            gs2_header,
310                            mechanism,
311                            initial_response,
312                        }) => {
313                            // We do not support channel binding
314                            if gs2_header.channel_binding_enabled() {
315                                return conn
316                                    .send(ErrorResponse::fatal(
317                                        SqlState::PROTOCOL_VIOLATION,
318                                        "channel binding not supported",
319                                    ))
320                                    .await;
321                            }
322                            (mechanism, initial_response)
323                        }
324                        _ => {
325                            return conn
326                                .send(ErrorResponse::fatal(
327                                    SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
328                                    "expected SASLInitialResponse message",
329                                ))
330                                .await;
331                        }
332                    }
333                }
334                _ => {
335                    return conn
336                        .send(ErrorResponse::fatal(
337                            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
338                            "expected SASLInitialResponse message",
339                        ))
340                        .await;
341                }
342            };
343
344            if mechanism != "SCRAM-SHA-256" {
345                return conn
346                    .send(ErrorResponse::fatal(
347                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
348                        "unsupported SASL mechanism",
349                    ))
350                    .await;
351            }
352
353            if initial_response.nonce.len() > 256 {
354                return conn
355                    .send(ErrorResponse::fatal(
356                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
357                        "nonce too long",
358                    ))
359                    .await;
360            }
361
362            let (server_first_message_raw, mock_hash) = match adapter_client
363                .generate_sasl_challenge(&user, &initial_response.nonce)
364                .await
365            {
366                Ok(response) => {
367                    let server_first_message_raw = format!(
368                        "r={},s={},i={}",
369                        response.nonce, response.salt, response.iteration_count
370                    );
371
372                    let client_key = [0u8; 32];
373                    let server_key = [1u8; 32];
374                    let mock_hash = format!(
375                        "SCRAM-SHA-256${}:{}${}:{}",
376                        response.iteration_count,
377                        response.salt,
378                        BASE64_STANDARD.encode(client_key),
379                        BASE64_STANDARD.encode(server_key)
380                    );
381
382                    conn.send(BackendMessage::AuthenticationSASLContinue(
383                        SASLServerFirstMessage {
384                            iteration_count: response.iteration_count,
385                            nonce: response.nonce,
386                            salt: response.salt,
387                        },
388                    ))
389                    .await?;
390                    conn.flush().await?;
391                    (server_first_message_raw, mock_hash)
392                }
393                Err(e) => {
394                    return conn.send(e.into_response(Severity::Fatal)).await;
395                }
396            };
397
398            let auth_resp = match conn.recv().await? {
399                Some(FrontendMessage::RawAuthentication(data)) => {
400                    match decode_sasl_response(Cursor::new(&data)).ok() {
401                        Some(FrontendMessage::SASLResponse(response)) => {
402                            let auth_message = format!(
403                                "{},{},{}",
404                                initial_response.client_first_message_bare_raw,
405                                server_first_message_raw,
406                                response.client_final_message_bare_raw
407                            );
408                            if response.proof.len() > 1024 {
409                                return conn
410                                    .send(ErrorResponse::fatal(
411                                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
412                                        "proof too long",
413                                    ))
414                                    .await;
415                            }
416                            match adapter_client
417                                .verify_sasl_proof(
418                                    &user,
419                                    &response.proof,
420                                    &auth_message,
421                                    &mock_hash,
422                                )
423                                .await
424                            {
425                                Ok(resp) => {
426                                    conn.send(BackendMessage::AuthenticationSASLFinal(
427                                        SASLServerFinalMessage {
428                                            kind: SASLServerFinalMessageKinds::Verifier(
429                                                resp.verifier,
430                                            ),
431                                            extensions: vec![],
432                                        },
433                                    ))
434                                    .await?;
435                                    conn.flush().await?;
436                                    resp.auth_resp
437                                }
438                                Err(_) => {
439                                    return conn
440                                        .send(ErrorResponse::fatal(
441                                            SqlState::INVALID_PASSWORD,
442                                            "invalid password",
443                                        ))
444                                        .await;
445                                }
446                            }
447                        }
448                        _ => {
449                            return conn
450                                .send(ErrorResponse::fatal(
451                                    SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
452                                    "expected SASLResponse message",
453                                ))
454                                .await;
455                        }
456                    }
457                }
458                _ => {
459                    return conn
460                        .send(ErrorResponse::fatal(
461                            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
462                            "expected SASLResponse message",
463                        ))
464                        .await;
465                }
466            };
467
468            let session = adapter_client.new_session(SessionConfig {
469                conn_id: conn.conn_id().clone(),
470                uuid: conn_uuid,
471                user,
472                client_ip: conn.peer_addr().clone(),
473                external_metadata_rx: None,
474                internal_user_metadata: Some(InternalUserMetadata {
475                    superuser: auth_resp.superuser,
476                }),
477                helm_chart_version,
478            });
479            // No frontegg check, so auth session lasts indefinitely.
480            let auth_session = pending().right_future();
481            (session, auth_session)
482        }
483        Authenticator::None => {
484            let session = adapter_client.new_session(SessionConfig {
485                conn_id: conn.conn_id().clone(),
486                uuid: conn_uuid,
487                user,
488                client_ip: conn.peer_addr().clone(),
489                external_metadata_rx: None,
490                internal_user_metadata: None,
491                helm_chart_version,
492            });
493            // No frontegg check, so auth session lasts indefinitely.
494            let auth_session = pending().right_future();
495            (session, auth_session)
496        }
497    };
498
499    let system_vars = adapter_client.get_system_vars().await;
500    for (name, value) in params {
501        let settings = match name.as_str() {
502            "options" => match parse_options(&value) {
503                Ok(opts) => opts,
504                Err(()) => {
505                    session.add_notice(AdapterNotice::BadStartupSetting {
506                        name,
507                        reason: "could not parse".into(),
508                    });
509                    continue;
510                }
511            },
512            _ => vec![(name, value)],
513        };
514        for (key, val) in settings {
515            const LOCAL: bool = false;
516            // TODO: Issuing an error here is better than what we did before
517            // (silently ignore errors on set), but erroring the connection
518            // might be the better behavior. We maybe need to support more
519            // options sent by psql and drivers before we can safely do this.
520            if let Err(err) =
521                session
522                    .vars_mut()
523                    .set(&system_vars, &key, VarInput::Flat(&val), LOCAL)
524            {
525                session.add_notice(AdapterNotice::BadStartupSetting {
526                    name: key,
527                    reason: err.to_string(),
528                });
529            }
530        }
531    }
532    session
533        .vars_mut()
534        .end_transaction(EndTransactionAction::Commit);
535
536    let _guard = match active_connection_counter.allocate_connection(session.user()) {
537        Ok(drop_connection) => drop_connection,
538        Err(e) => {
539            let e: AdapterError = e.into();
540            return conn.send(e.into_response(Severity::Fatal)).await;
541        }
542    };
543
544    // Register session with adapter.
545    let mut adapter_client = match adapter_client.startup(session).await {
546        Ok(adapter_client) => adapter_client,
547        Err(e) => return conn.send(e.into_response(Severity::Fatal)).await,
548    };
549
550    let mut buf = vec![BackendMessage::AuthenticationOk];
551    for var in adapter_client.session().vars().notify_set() {
552        buf.push(BackendMessage::ParameterStatus(var.name(), var.value()));
553    }
554    buf.push(BackendMessage::BackendKeyData {
555        conn_id: adapter_client.session().conn_id().unhandled(),
556        secret_key: adapter_client.session().secret_key(),
557    });
558    buf.extend(
559        adapter_client
560            .session()
561            .drain_notices()
562            .into_iter()
563            .map(|notice| BackendMessage::ErrorResponse(notice.into_response())),
564    );
565    buf.push(BackendMessage::ReadyForQuery(
566        adapter_client.session().transaction().into(),
567    ));
568    conn.send_all(buf).await?;
569    conn.flush().await?;
570
571    let machine = StateMachine {
572        conn,
573        adapter_client,
574        txn_needs_commit: false,
575        tokio_metrics_intervals,
576    };
577
578    select! {
579        r = machine.run() => {
580            // Errors produced internally (like MAX_REQUEST_SIZE being exceeded) should send an
581            // error to the client informing them why the connection was closed. We still want to
582            // return the original error up the stack, though, so we skip error checking during conn
583            // operations.
584            if let Err(err) = &r {
585                let _ = conn
586                    .send(ErrorResponse::fatal(
587                        SqlState::CONNECTION_FAILURE,
588                        err.to_string(),
589                    ))
590                    .await;
591                let _ = conn.flush().await;
592            }
593            r
594        },
595        _ = expired => {
596            conn
597                .send(ErrorResponse::fatal(SqlState::INVALID_AUTHORIZATION_SPECIFICATION, "authentication expired"))
598                .await?;
599            conn.flush().await
600        }
601    }
602}
603
604/// Returns (name, value) session settings pairs from an options value.
605///
606/// From Postgres, see pg_split_opts in postinit.c and process_postgres_switches
607/// in postgres.c.
608fn parse_options(value: &str) -> Result<Vec<(String, String)>, ()> {
609    let opts = split_options(value);
610    let mut pairs = Vec::with_capacity(opts.len());
611    let mut seen_prefix = false;
612    for opt in opts {
613        if !seen_prefix {
614            if opt == "-c" {
615                seen_prefix = true;
616            } else {
617                let (key, val) = parse_option(&opt)?;
618                pairs.push((key.to_owned(), val.to_owned()));
619            }
620        } else {
621            let (key, val) = opt.split_once('=').ok_or(())?;
622            pairs.push((key.to_owned(), val.to_owned()));
623            seen_prefix = false;
624        }
625    }
626    Ok(pairs)
627}
628
/// Returns the parsed key and value from an option of the form `--key=value`
/// or `-ckey=value`. (A bare `-c` followed by a separate `key=value` token is
/// handled by the caller.)
///
/// The option is split at its first `=`, so the value may itself contain `=`.
/// Keys are returned verbatim: unlike PostgreSQL's `ParseLongOption`, `-` is
/// NOT rewritten to `_`.
///
/// # Errors
///
/// Returns `Err(())` if there is no `=` or if the key has neither the `-c` nor
/// the `--` prefix.
fn parse_option(option: &str) -> Result<(&str, &str), ()> {
    let (raw_key, value) = option.split_once('=').ok_or(())?;
    ["-c", "--"]
        .iter()
        .find_map(|prefix| raw_key.strip_prefix(prefix))
        .map(|key| (key, value))
        .ok_or(())
}
641
/// Tokenizes `value` on runs of ASCII spaces, honoring backslash escapes.
///
/// Consecutive spaces never produce empty tokens. A backslash makes the next
/// character literal: `\ ` keeps a space inside a token, `\\` yields one
/// backslash, and `\x` for any other `x` yields just `x`. A lone backslash at
/// the very end of the input is dropped.
fn split_options(value: &str) -> Vec<String> {
    let mut tokens = Vec::new();
    // Tokens are built character by character because escapes mean a token
    // is not always a plain subslice of `value`.
    let mut token = String::new();
    let mut chars = value.chars();
    // `while let` rather than `for`: an escape consumes a second character.
    while let Some(ch) = chars.next() {
        match ch {
            '\\' => {
                if let Some(escaped) = chars.next() {
                    token.push(escaped);
                }
            }
            ' ' => {
                if !token.is_empty() {
                    tokens.push(std::mem::take(&mut token));
                }
            }
            other => token.push(other),
        }
    }
    if !token.is_empty() {
        tokens.push(token);
    }
    tokens
}
684
/// The phase of the connection's message loop, as driven by
/// `StateMachine::run`.
#[derive(Debug)]
enum State {
    /// Handled by `advance_ready`, which waits for the next frontend message.
    Ready,
    /// Handled by `advance_drain`. Presumably this discards client messages
    /// until the protocol resynchronizes; NOTE(review): confirm there.
    Drain,
    /// The connection is finished; the loop returns `Ok(())`.
    Done,
}
691
692struct StateMachine<'a, A, I>
693where
694    I: Iterator<Item = TaskMetrics> + Send + 'a,
695{
696    conn: &'a mut FramedConn<A>,
697    adapter_client: mz_adapter::SessionClient,
698    txn_needs_commit: bool,
699    tokio_metrics_intervals: I,
700}
701
/// The outcome of streaming a result set to the client.
enum SendRowsEndedReason {
    /// Streaming completed.
    Success {
        /// Size of the sent result; presumably in bytes (confirm at producer).
        result_size: u64,
        /// Number of rows sent to the client.
        rows_returned: u64,
    },
    /// Streaming stopped because of an error.
    Errored {
        /// Human-readable description of the error.
        error: String,
    },
    /// Streaming was canceled before completion.
    Canceled,
}
712
/// Message used when a command is rejected because the current transaction
/// has already failed. The text matches PostgreSQL's wording for this case.
const ABORTED_TXN_MSG: &str =
    "current transaction is aborted, commands ignored until end of transaction block";
715
716impl<'a, A, I> StateMachine<'a, A, I>
717where
718    A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin + 'a,
719    I: Iterator<Item = TaskMetrics> + Send + 'a,
720{
721    // Manually desugar this (don't use `async fn run`) here because a much better
722    // error message is produced if there are problems with Send or other traits
723    // somewhere within the Future.
724    #[allow(clippy::manual_async_fn)]
725    #[mz_ore::instrument(level = "debug")]
726    fn run(mut self) -> impl Future<Output = Result<(), io::Error>> + Send + 'a {
727        async move {
728            let mut state = State::Ready;
729            loop {
730                self.send_pending_notices().await?;
731                state = match state {
732                    State::Ready => self.advance_ready().await?,
733                    State::Drain => self.advance_drain().await?,
734                    State::Done => return Ok(()),
735                };
736                self.adapter_client
737                    .add_idle_in_transaction_session_timeout();
738            }
739        }
740    }
741
    /// Reads the next frontend message in the `Ready` state and dispatches it.
    ///
    /// A pending session timeout takes priority over the next client message
    /// (the `select!` is `biased`), so no statement runs once a timeout has
    /// fired. On timeout the error is processed with `Severity::Fatal`, the
    /// adapter client is terminated, and one more client message is awaited
    /// before the resulting error state is returned.
    ///
    /// Otherwise the message is dispatched to its handler, and the returned
    /// state is what the connection moves to next: `CopyData`, `CopyDone`,
    /// `CopyFail`, and authentication messages lead to `Drain`, while
    /// `Terminate` or no message (`None`) lead to `Done`. The per-message
    /// processing time (only when a message was received) and the `recv`
    /// scheduling delay are recorded as metrics.
    #[instrument(level = "debug")]
    async fn advance_ready(&mut self) -> Result<State, io::Error> {
        // Start a new metrics interval before the `recv()` call.
        self.tokio_metrics_intervals
            .next()
            .expect("infinite iterator");

        // Handle timeouts first so we don't execute any statements when there's a pending timeout.
        let message = select! {
            biased;

            // `recv_timeout()` is cancel-safe as per its docs.
            Some(timeout) = self.adapter_client.recv_timeout() => {
                let err: AdapterError = timeout.into();
                let conn_id = self.adapter_client.session().conn_id();
                tracing::warn!("session timed out, conn_id {}", conn_id);

                // Process the error, doing any state cleanup.
                let error_response = err.into_response(Severity::Fatal);
                let error_state = self.error(error_response).await;

                // Terminate __after__ we do any cleanup.
                self.adapter_client.terminate().await;

                // We must wait for the client to send a request before we can send the error response.
                // Due to the PG wire protocol, we can't send an ErrorResponse unless it is in response
                // to a client message.
                let _ = self.conn.recv().await?;
                return error_state;
            },
            // `recv()` is cancel-safe as per its docs.
            message = self.conn.recv() => message?,
        };

        // Take the metrics since just before the `recv`.
        let interval = self
            .tokio_metrics_intervals
            .next()
            .expect("infinite iterator");
        let recv_scheduling_delay_ms = interval.total_scheduled_duration.as_secs_f64() * 1000.0;

        // TODO(ggevay): Consider subtracting the scheduling delay from `received`. It's not obvious
        // whether we should do this, because the result wouldn't exactly correspond to either first
        // byte received or last byte received (for msgs that arrive in more than one network packet).
        let received = SYSTEM_TIME();

        // A message arrived, so the idle-in-transaction timeout no longer applies;
        // `run` adds it back after this step completes.
        self.adapter_client
            .remove_idle_in_transaction_session_timeout();

        // NOTE(guswynn): we could consider adding spans to all message types. Currently
        // only a few message types seem useful.
        let message_name = message.as_ref().map(|m| m.name()).unwrap_or_default();

        // Processing time is only measured when a message was actually received.
        let start = message.as_ref().map(|_| Instant::now());
        let next_state = match message {
            Some(FrontendMessage::Query { sql }) => {
                let query_root_span =
                    tracing::info_span!(parent: None, "advance_ready", otel.name = message_name);
                query_root_span.follows_from(tracing::Span::current());
                self.query(sql, received)
                    .instrument(query_root_span)
                    .await?
            }
            Some(FrontendMessage::Parse {
                name,
                sql,
                param_types,
            }) => self.parse(name, sql, param_types).await?,
            Some(FrontendMessage::Bind {
                portal_name,
                statement_name,
                param_formats,
                raw_params,
                result_formats,
            }) => {
                self.bind(
                    portal_name,
                    statement_name,
                    param_formats,
                    raw_params,
                    result_formats,
                )
                .await?
            }
            Some(FrontendMessage::Execute {
                portal_name,
                max_rows,
            }) => {
                let max_rows = match usize::try_from(max_rows) {
                    Ok(0) | Err(_) => ExecuteCount::All, // Zero or negative `max_rows`: no limit.
                    Ok(n) => ExecuteCount::Count(n),
                };
                let execute_root_span =
                    tracing::info_span!(parent: None, "advance_ready", otel.name = message_name);
                execute_root_span.follows_from(tracing::Span::current());
                let state = self
                    .execute(
                        portal_name,
                        max_rows,
                        portal_exec_message,
                        None,
                        ExecuteTimeout::None,
                        None,
                        Some(received),
                    )
                    .instrument(execute_root_span)
                    .await?;
                // In PostgreSQL, when using the extended query protocol, some statements may
                // trigger an eager commit of the current implicit transaction,
                // see: <https://git.postgresql.org/gitweb/?p=postgresql.git&a=commitdiff&h=f92944137>.
                //
                // In Materialize, however, we eagerly commit every statement outside of an explicit
                // transaction when using the extended query protocol. This allows us to eliminate
                // the possibility of a multiple statement implicit transaction, which in turn
                // allows us to apply single-statement optimizations to queries issued in implicit
                // transactions in the extended query protocol.
                //
                // We don't immediately commit here to allow users to page through the portal if
                // necessary. Committing the transaction would destroy the portal before the next
                // Execute command has a chance to resume it. So we instead mark the transaction
                // for commit the next time that `ensure_transaction` is called.
                if self.adapter_client.session().transaction().is_implicit() {
                    self.txn_needs_commit = true;
                }
                state
            }
            Some(FrontendMessage::DescribeStatement { name }) => {
                self.describe_statement(&name).await?
            }
            Some(FrontendMessage::DescribePortal { name }) => self.describe_portal(&name).await?,
            Some(FrontendMessage::CloseStatement { name }) => self.close_statement(name).await?,
            Some(FrontendMessage::ClosePortal { name }) => self.close_portal(name).await?,
            Some(FrontendMessage::Flush) => self.flush().await?,
            Some(FrontendMessage::Sync) => self.sync().await?,
            Some(FrontendMessage::Terminate) => State::Done,

            // Copy sub-protocol and authentication messages are not handled in
            // this state; switch to draining instead.
            Some(FrontendMessage::CopyData(_))
            | Some(FrontendMessage::CopyDone)
            | Some(FrontendMessage::CopyFail(_))
            | Some(FrontendMessage::Password { .. })
            | Some(FrontendMessage::RawAuthentication(_))
            | Some(FrontendMessage::SASLInitialResponse { .. })
            | Some(FrontendMessage::SASLResponse(_)) => State::Drain,
            None => State::Done,
        };

        if let Some(start) = start {
            self.adapter_client
                .inner()
                .metrics()
                .pgwire_message_processing_seconds
                .with_label_values(&[message_name])
                .observe(start.elapsed().as_secs_f64());
        }
        // Recorded even when no message arrived (with an empty label in that case).
        self.adapter_client
            .inner()
            .metrics()
            .pgwire_recv_scheduling_delay_ms
            .with_label_values(&[message_name])
            .observe(recv_scheduling_delay_ms);

        Ok(next_state)
    }
905
906    async fn advance_drain(&mut self) -> Result<State, io::Error> {
907        let message = self.conn.recv().await?;
908        if message.is_some() {
909            self.adapter_client
910                .remove_idle_in_transaction_session_timeout();
911        }
912        match message {
913            Some(FrontendMessage::Sync) => self.sync().await,
914            None => Ok(State::Done),
915            _ => Ok(State::Drain),
916        }
917    }
918
919    /// Note that `lifecycle_timestamps` belongs to the whole "Simple Query", because the whole
920    /// Simple Query is received and parsed together. This means that if there are multiple
921    /// statements in a Simple Query, then all of them have the same `lifecycle_timestamps`.
922    #[instrument(level = "debug")]
923    async fn one_query(
924        &mut self,
925        stmt: Statement<Raw>,
926        sql: String,
927        lifecycle_timestamps: LifecycleTimestamps,
928    ) -> Result<State, io::Error> {
929        // Bind the portal. Note that this does not set the empty string prepared
930        // statement.
931        const EMPTY_PORTAL: &str = "";
932        if let Err(e) = self
933            .adapter_client
934            .declare(EMPTY_PORTAL.to_string(), stmt, sql)
935            .await
936        {
937            return self.error(e.into_response(Severity::Error)).await;
938        }
939        let portal = self
940            .adapter_client
941            .session()
942            .get_portal_unverified_mut(EMPTY_PORTAL)
943            .expect("unnamed portal should be present");
944
945        *portal.lifecycle_timestamps = Some(lifecycle_timestamps);
946
947        let stmt_desc = portal.desc.clone();
948        if !stmt_desc.param_types.is_empty() {
949            return self
950                .error(ErrorResponse::error(
951                    SqlState::UNDEFINED_PARAMETER,
952                    "there is no parameter $1",
953                ))
954                .await;
955        }
956
957        // Maybe send row description.
958        if let Some(relation_desc) = &stmt_desc.relation_desc {
959            if !stmt_desc.is_copy {
960                let formats = vec![Format::Text; stmt_desc.arity()];
961                self.send(BackendMessage::RowDescription(
962                    message::encode_row_description(relation_desc, &formats),
963                ))
964                .await?;
965            }
966        }
967
968        let result = match self
969            .adapter_client
970            .execute(EMPTY_PORTAL.to_string(), self.conn.wait_closed(), None)
971            .await
972        {
973            Ok((response, execute_started)) => {
974                self.send_pending_notices().await?;
975                self.send_execute_response(
976                    response,
977                    stmt_desc.relation_desc,
978                    EMPTY_PORTAL.to_string(),
979                    ExecuteCount::All,
980                    portal_exec_message,
981                    None,
982                    ExecuteTimeout::None,
983                    execute_started,
984                )
985                .await
986            }
987            Err(e) => {
988                self.send_pending_notices().await?;
989                self.error(e.into_response(Severity::Error)).await
990            }
991        };
992
993        // Destroy the portal.
994        self.adapter_client.session().remove_portal(EMPTY_PORTAL);
995
996        result
997    }
998
999    async fn ensure_transaction(
1000        &mut self,
1001        num_stmts: usize,
1002        message_type: &str,
1003    ) -> Result<(), io::Error> {
1004        let start = Instant::now();
1005        if self.txn_needs_commit {
1006            self.commit_transaction().await?;
1007        }
1008        // start_transaction can't error (but assert that just in case it changes in
1009        // the future.
1010        let res = self.adapter_client.start_transaction(Some(num_stmts));
1011        assert_ok!(res);
1012        self.adapter_client
1013            .inner()
1014            .metrics()
1015            .pgwire_ensure_transaction_seconds
1016            .with_label_values(&[message_type])
1017            .observe(start.elapsed().as_secs_f64());
1018        Ok(())
1019    }
1020
1021    fn parse_sql<'b>(&self, sql: &'b str) -> Result<Vec<StatementParseResult<'b>>, ErrorResponse> {
1022        let parse_start = Instant::now();
1023        let result = match self.adapter_client.parse(sql) {
1024            Ok(result) => result.map_err(|e| {
1025                // Convert our 0-based byte position to pgwire's 1-based character
1026                // position.
1027                let pos = sql[..e.error.pos].chars().count() + 1;
1028                ErrorResponse::error(SqlState::SYNTAX_ERROR, e.error.message).with_position(pos)
1029            }),
1030            Err(msg) => Err(ErrorResponse::error(SqlState::PROGRAM_LIMIT_EXCEEDED, msg)),
1031        };
1032        self.adapter_client
1033            .inner()
1034            .metrics()
1035            .parse_seconds
1036            .observe(parse_start.elapsed().as_secs_f64());
1037        result
1038    }
1039
1040    /// Executes a "Simple Query", see
1041    /// <https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-SIMPLE-QUERY>
1042    ///
1043    /// For implicit transaction handling, see "Multiple Statements in a Simple Query" in the above.
1044    #[instrument(level = "debug")]
1045    async fn query(&mut self, sql: String, received: EpochMillis) -> Result<State, io::Error> {
1046        // Parse first before doing any transaction checking.
1047        let stmts = match self.parse_sql(&sql) {
1048            Ok(stmts) => stmts,
1049            Err(err) => {
1050                self.error(err).await?;
1051                return self.ready().await;
1052            }
1053        };
1054
1055        let num_stmts = stmts.len();
1056
1057        // Compare with postgres' backend/tcop/postgres.c exec_simple_query.
1058        for StatementParseResult { ast: stmt, sql } in stmts {
1059            // In an aborted transaction, reject all commands except COMMIT/ROLLBACK.
1060            if self.is_aborted_txn() && !is_txn_exit_stmt(Some(&stmt)) {
1061                self.aborted_txn_error().await?;
1062                break;
1063            }
1064
1065            // Start an implicit transaction if we aren't in any transaction and there's
1066            // more than one statement. This mirrors the `use_implicit_block` variable in
1067            // postgres.
1068            //
1069            // This needs to be done in the loop instead of once at the top because
1070            // a COMMIT/ROLLBACK statement needs to start a new transaction on next
1071            // statement.
1072            self.ensure_transaction(num_stmts, "query").await?;
1073
1074            match self
1075                .one_query(stmt, sql.to_string(), LifecycleTimestamps { received })
1076                .await?
1077            {
1078                State::Ready => (),
1079                State::Drain => break,
1080                State::Done => return Ok(State::Done),
1081            }
1082        }
1083
1084        // Implicit transactions are closed at the end of a Query message.
1085        {
1086            if self.adapter_client.session().transaction().is_implicit() {
1087                self.commit_transaction().await?;
1088            }
1089        }
1090
1091        if num_stmts == 0 {
1092            self.send(BackendMessage::EmptyQueryResponse).await?;
1093        }
1094
1095        self.ready().await
1096    }
1097
1098    #[instrument(level = "debug")]
1099    async fn parse(
1100        &mut self,
1101        name: String,
1102        sql: String,
1103        param_oids: Vec<u32>,
1104    ) -> Result<State, io::Error> {
1105        // Start a transaction if we aren't in one.
1106        self.ensure_transaction(1, "parse").await?;
1107
1108        let mut param_types = vec![];
1109        for oid in param_oids {
1110            match mz_pgrepr::Type::from_oid(oid) {
1111                Ok(ty) => match SqlScalarType::try_from(&ty) {
1112                    Ok(ty) => param_types.push(Some(ty)),
1113                    Err(err) => {
1114                        return self
1115                            .error(ErrorResponse::error(
1116                                SqlState::INVALID_PARAMETER_VALUE,
1117                                err.to_string(),
1118                            ))
1119                            .await;
1120                    }
1121                },
1122                Err(_) if oid == 0 => param_types.push(None),
1123                Err(e) => {
1124                    return self
1125                        .error(ErrorResponse::error(
1126                            SqlState::PROTOCOL_VIOLATION,
1127                            e.to_string(),
1128                        ))
1129                        .await;
1130                }
1131            }
1132        }
1133
1134        let stmts = match self.parse_sql(&sql) {
1135            Ok(stmts) => stmts,
1136            Err(err) => {
1137                return self.error(err).await;
1138            }
1139        };
1140        if stmts.len() > 1 {
1141            return self
1142                .error(ErrorResponse::error(
1143                    SqlState::INTERNAL_ERROR,
1144                    "cannot insert multiple commands into a prepared statement",
1145                ))
1146                .await;
1147        }
1148        let (maybe_stmt, sql) = match stmts.into_iter().next() {
1149            None => (None, ""),
1150            Some(StatementParseResult { ast, sql }) => (Some(ast), sql),
1151        };
1152        if self.is_aborted_txn() && !is_txn_exit_stmt(maybe_stmt.as_ref()) {
1153            return self.aborted_txn_error().await;
1154        }
1155        match self
1156            .adapter_client
1157            .prepare(name, maybe_stmt, sql.to_string(), param_types)
1158            .await
1159        {
1160            Ok(()) => {
1161                self.send(BackendMessage::ParseComplete).await?;
1162                Ok(State::Ready)
1163            }
1164            Err(e) => self.error(e.into_response(Severity::Error)).await,
1165        }
1166    }
1167
1168    /// Commits and clears the current transaction.
1169    #[instrument(level = "debug")]
1170    async fn commit_transaction(&mut self) -> Result<(), io::Error> {
1171        self.end_transaction(EndTransactionAction::Commit).await
1172    }
1173
1174    /// Rollback and clears the current transaction.
1175    #[instrument(level = "debug")]
1176    async fn rollback_transaction(&mut self) -> Result<(), io::Error> {
1177        self.end_transaction(EndTransactionAction::Rollback).await
1178    }
1179
1180    /// End a transaction and report to the user if an error occurred.
1181    #[instrument(level = "debug")]
1182    async fn end_transaction(&mut self, action: EndTransactionAction) -> Result<(), io::Error> {
1183        self.txn_needs_commit = false;
1184        let resp = self.adapter_client.end_transaction(action).await;
1185        if let Err(err) = resp {
1186            self.send(BackendMessage::ErrorResponse(
1187                err.into_response(Severity::Error),
1188            ))
1189            .await?;
1190        }
1191        Ok(())
1192    }
1193
    /// Handles an extended-protocol `Bind` message. It binds the prepared statement
    /// `statement_name`, together with the client-supplied parameter values, into
    /// the portal `portal_name`.
    ///
    /// Steps:
    /// - Start a transaction if none is open.
    /// - Look up the prepared statement.
    /// - Check the parameter count.
    /// - Normalize the parameter and result format codes via `pad_formats`.
    /// - Decode every parameter using its format and the statement's parameter types.
    /// - Reject binary result formats for list, map, and aclitem columns
    ///   (`COPY ... TO` statements are exempt).
    /// - Store the portal in the session.
    ///
    /// On success `BindComplete` is sent and the connection stays `Ready`. Every
    /// validation failure is reported through `self.error`.
    #[instrument(level = "debug")]
    async fn bind(
        &mut self,
        portal_name: String,
        statement_name: String,
        param_formats: Vec<Format>,
        raw_params: Vec<Option<Vec<u8>>>,
        result_formats: Vec<Format>,
    ) -> Result<State, io::Error> {
        // Start a transaction if we aren't in one.
        self.ensure_transaction(1, "bind").await?;

        // Checked further down, once the statement is known: an aborted
        // transaction still allows COMMIT/ROLLBACK to be bound.
        let aborted_txn = self.is_aborted_txn();
        let stmt = match self
            .adapter_client
            .get_prepared_statement(&statement_name)
            .await
        {
            Ok(stmt) => stmt,
            Err(err) => return self.error(err.into_response(Severity::Error)).await,
        };

        // The client must supply exactly one value per declared parameter.
        let param_types = &stmt.desc().param_types;
        if param_types.len() != raw_params.len() {
            let message = format!(
                "bind message supplies {actual} parameters, \
                 but prepared statement \"{name}\" requires {expected}",
                name = statement_name,
                actual = raw_params.len(),
                expected = param_types.len()
            );
            return self
                .error(ErrorResponse::error(SqlState::PROTOCOL_VIOLATION, message))
                .await;
        }
        // `pad_formats` must yield one format per parameter (or an error
        // message); the `zip_eq` in the decode loop below relies on that.
        let param_formats = match pad_formats(param_formats, raw_params.len()) {
            Ok(param_formats) => param_formats,
            Err(msg) => {
                return self
                    .error(ErrorResponse::error(SqlState::PROTOCOL_VIOLATION, msg))
                    .await;
            }
        };
        if aborted_txn && !is_txn_exit_stmt(stmt.stmt()) {
            return self.aborted_txn_error().await;
        }
        // Decode each raw parameter with its format and declared type. A missing
        // value (`None`) binds SQL NULL.
        let buf = RowArena::new();
        let mut params = vec![];
        for ((raw_param, mz_typ), format) in raw_params
            .into_iter()
            .zip_eq(param_types)
            .zip_eq(param_formats)
        {
            let pg_typ = mz_pgrepr::Type::from(mz_typ);
            let datum = match raw_param {
                None => Datum::Null,
                Some(bytes) => match mz_pgrepr::Value::decode(format, &pg_typ, &bytes) {
                    Ok(param) => param.into_datum(&buf, &pg_typ),
                    Err(err) => {
                        let msg = format!("unable to decode parameter: {}", err);
                        return self
                            .error(ErrorResponse::error(SqlState::INVALID_PARAMETER_VALUE, msg))
                            .await;
                    }
                },
            };
            params.push((datum, mz_typ.clone()))
        }

        // One result format per output column (zero columns when the statement
        // has no relation description).
        let result_formats = match pad_formats(
            result_formats,
            stmt.desc()
                .relation_desc
                .clone()
                .map(|desc| desc.typ().column_types.len())
                .unwrap_or(0),
        ) {
            Ok(result_formats) => result_formats,
            Err(msg) => {
                return self
                    .error(ErrorResponse::error(SqlState::PROTOCOL_VIOLATION, msg))
                    .await;
            }
        };

        // Binary encodings are disabled for list, map, and aclitem types, but this doesn't
        // apply to COPY TO statements.
        if !stmt.stmt().map_or(false, |stmt| {
            matches!(
                stmt,
                Statement::Copy(CopyStatement {
                    direction: CopyDirection::To,
                    ..
                })
            )
        }) {
            if let Some(desc) = stmt.desc().relation_desc.clone() {
                for (format, ty) in result_formats.iter().zip_eq(desc.iter_types()) {
                    match (format, &ty.scalar_type) {
                        (Format::Binary, mz_repr::SqlScalarType::List { .. }) => {
                            return self
                                .error(ErrorResponse::error(
                                    SqlState::PROTOCOL_VIOLATION,
                                    "binary encoding of list types is not implemented",
                                ))
                                .await;
                        }
                        (Format::Binary, mz_repr::SqlScalarType::Map { .. }) => {
                            return self
                                .error(ErrorResponse::error(
                                    SqlState::PROTOCOL_VIOLATION,
                                    "binary encoding of map types is not implemented",
                                ))
                                .await;
                        }
                        (Format::Binary, mz_repr::SqlScalarType::AclItem) => {
                            return self
                                .error(ErrorResponse::error(
                                    SqlState::PROTOCOL_VIOLATION,
                                    "binary encoding of aclitem types does not exist",
                                ))
                                .await;
                        }
                        _ => (),
                    }
                }
            }
        }

        // Copy what the portal needs out of the prepared statement, then store
        // the portal in the session.
        let desc = stmt.desc().clone();
        let logging = Arc::clone(stmt.logging());
        let stmt_ast = stmt.stmt().cloned();
        let state_revision = stmt.state_revision;
        if let Err(err) = self.adapter_client.session().set_portal(
            portal_name,
            desc,
            stmt_ast,
            logging,
            params,
            result_formats,
            state_revision,
        ) {
            return self.error(err.into_response(Severity::Error)).await;
        }

        self.send(BackendMessage::BindComplete).await?;
        Ok(State::Ready)
    }
1342
1343    fn execute(
1344        &mut self,
1345        portal_name: String,
1346        max_rows: ExecuteCount,
1347        get_response: GetResponse,
1348        fetch_portal_name: Option<String>,
1349        timeout: ExecuteTimeout,
1350        outer_ctx_extra: Option<ExecuteContextExtra>,
1351        received: Option<EpochMillis>,
1352    ) -> BoxFuture<'_, Result<State, io::Error>> {
1353        async move {
1354            let aborted_txn = self.is_aborted_txn();
1355
1356            // Check if the portal has been started and can be continued.
1357            let portal = match self
1358                .adapter_client
1359                .session()
1360                .get_portal_unverified_mut(&portal_name)
1361            {
1362                Some(portal) => portal,
1363                None => {
1364                    let msg = format!("portal {} does not exist", portal_name.quoted());
1365                    if let Some(outer_ctx_extra) = outer_ctx_extra {
1366                        self.adapter_client.retire_execute(
1367                            outer_ctx_extra,
1368                            StatementEndedExecutionReason::Errored { error: msg.clone() },
1369                        );
1370                    }
1371                    return self
1372                        .error(ErrorResponse::error(SqlState::INVALID_CURSOR_NAME, msg))
1373                        .await;
1374                }
1375            };
1376
1377            *portal.lifecycle_timestamps = received.map(LifecycleTimestamps::new);
1378
1379            // In an aborted transaction, reject all commands except COMMIT/ROLLBACK.
1380            let txn_exit_stmt = is_txn_exit_stmt(portal.stmt.as_deref());
1381            if aborted_txn && !txn_exit_stmt {
1382                if let Some(outer_ctx_extra) = outer_ctx_extra {
1383                    self.adapter_client.retire_execute(
1384                        outer_ctx_extra,
1385                        StatementEndedExecutionReason::Errored {
1386                            error: ABORTED_TXN_MSG.to_string(),
1387                        },
1388                    );
1389                }
1390                return self.aborted_txn_error().await;
1391            }
1392
1393            let row_desc = portal.desc.relation_desc.clone();
1394            match portal.state {
1395                PortalState::NotStarted => {
1396                    // Start a transaction if we aren't in one.
1397                    self.ensure_transaction(1, "execute").await?;
1398                    match self
1399                        .adapter_client
1400                        .execute(
1401                            portal_name.clone(),
1402                            self.conn.wait_closed(),
1403                            outer_ctx_extra,
1404                        )
1405                        .await
1406                    {
1407                        Ok((response, execute_started)) => {
1408                            self.send_pending_notices().await?;
1409                            self.send_execute_response(
1410                                response,
1411                                row_desc,
1412                                portal_name,
1413                                max_rows,
1414                                get_response,
1415                                fetch_portal_name,
1416                                timeout,
1417                                execute_started,
1418                            )
1419                            .await
1420                        }
1421                        Err(e) => {
1422                            self.send_pending_notices().await?;
1423                            self.error(e.into_response(Severity::Error)).await
1424                        }
1425                    }
1426                }
1427                PortalState::InProgress(rows) => {
1428                    let rows = rows.take().expect("InProgress rows must be populated");
1429                    let (result, statement_ended_execution_reason) = match self
1430                        .send_rows(
1431                            row_desc.expect("portal missing row desc on resumption"),
1432                            portal_name,
1433                            rows,
1434                            max_rows,
1435                            get_response,
1436                            fetch_portal_name,
1437                            timeout,
1438                        )
1439                        .await
1440                    {
1441                        Err(e) => {
1442                            // This is an error communicating with the connection.
1443                            // We consider that to be a cancelation, rather than a query error.
1444                            (Err(e), StatementEndedExecutionReason::Canceled)
1445                        }
1446                        Ok((ok, SendRowsEndedReason::Canceled)) => {
1447                            (Ok(ok), StatementEndedExecutionReason::Canceled)
1448                        }
1449                        // NOTE: For now the values for `result_size` and
1450                        // `rows_returned` in fetches are a bit confusing.
1451                        // We record `Some(n)` for the first fetch, where `n` is
1452                        // the number of bytes/rows returned by the inner
1453                        // execute (regardless of how many rows the
1454                        // fetch fetched), and `None` for subsequent fetches.
1455                        //
1456                        // This arguably makes sense since the size/rows
1457                        // returned measures how much work the compute
1458                        // layer had to do to satisfy the query, but
1459                        // we should revisit it if/when we start
1460                        // logging the inner execute separately.
1461                        Ok((
1462                            ok,
1463                            SendRowsEndedReason::Success {
1464                                result_size: _,
1465                                rows_returned: _,
1466                            },
1467                        )) => (
1468                            Ok(ok),
1469                            StatementEndedExecutionReason::Success {
1470                                result_size: None,
1471                                rows_returned: None,
1472                                execution_strategy: None,
1473                            },
1474                        ),
1475                        Ok((ok, SendRowsEndedReason::Errored { error })) => {
1476                            (Ok(ok), StatementEndedExecutionReason::Errored { error })
1477                        }
1478                    };
1479                    if let Some(outer_ctx_extra) = outer_ctx_extra {
1480                        self.adapter_client
1481                            .retire_execute(outer_ctx_extra, statement_ended_execution_reason);
1482                    }
1483                    result
1484                }
1485                // FETCH is an awkward command for our current architecture. In Postgres it
1486                // will extract <count> rows from the target portal, cache them, and return
1487                // them to the user as requested. Its command tag is always FETCH <num rows
1488                // extracted>. In Materialize, since we have chosen to not fully support FETCH,
1489                // we must remember the number of rows that were returned. Use this tag to
1490                // remember that information and return it.
1491                PortalState::Completed(Some(tag)) => {
1492                    let tag = tag.to_string();
1493                    if let Some(outer_ctx_extra) = outer_ctx_extra {
1494                        self.adapter_client.retire_execute(
1495                            outer_ctx_extra,
1496                            StatementEndedExecutionReason::Success {
1497                                result_size: None,
1498                                rows_returned: None,
1499                                execution_strategy: None,
1500                            },
1501                        );
1502                    }
1503                    self.send(BackendMessage::CommandComplete { tag }).await?;
1504                    Ok(State::Ready)
1505                }
1506                PortalState::Completed(None) => {
1507                    let error = format!(
1508                        "portal {} cannot be run",
1509                        Ident::new_unchecked(portal_name).to_ast_string_stable()
1510                    );
1511                    if let Some(outer_ctx_extra) = outer_ctx_extra {
1512                        self.adapter_client.retire_execute(
1513                            outer_ctx_extra,
1514                            StatementEndedExecutionReason::Errored {
1515                                error: error.clone(),
1516                            },
1517                        );
1518                    }
1519                    self.error(ErrorResponse::error(
1520                        SqlState::OBJECT_NOT_IN_PREREQUISITE_STATE,
1521                        error,
1522                    ))
1523                    .await
1524                }
1525            }
1526        }
1527        .instrument(debug_span!("execute"))
1528        .boxed()
1529    }
1530
1531    #[instrument(level = "debug")]
1532    async fn describe_statement(&mut self, name: &str) -> Result<State, io::Error> {
1533        // Start a transaction if we aren't in one.
1534        self.ensure_transaction(1, "describe_statement").await?;
1535
1536        let stmt = match self.adapter_client.get_prepared_statement(name).await {
1537            Ok(stmt) => stmt,
1538            Err(err) => return self.error(err.into_response(Severity::Error)).await,
1539        };
1540        // Cloning to avoid a mutable borrow issue because `send` also uses `adapter_client`
1541        let parameter_desc = BackendMessage::ParameterDescription(
1542            stmt.desc()
1543                .param_types
1544                .iter()
1545                .map(mz_pgrepr::Type::from)
1546                .collect(),
1547        );
1548        // Claim that all results will be output in text format, even
1549        // though the true result formats are not yet known. A bit
1550        // weird, but this is the behavior that PostgreSQL specifies.
1551        let formats = vec![Format::Text; stmt.desc().arity()];
1552        let row_desc = describe_rows(stmt.desc(), &formats);
1553        self.send_all([parameter_desc, row_desc]).await?;
1554        Ok(State::Ready)
1555    }
1556
1557    #[instrument(level = "debug")]
1558    async fn describe_portal(&mut self, name: &str) -> Result<State, io::Error> {
1559        // Start a transaction if we aren't in one.
1560        self.ensure_transaction(1, "describe_portal").await?;
1561
1562        let session = self.adapter_client.session();
1563        let row_desc = session
1564            .get_portal_unverified(name)
1565            .map(|portal| describe_rows(&portal.desc, &portal.result_formats));
1566        match row_desc {
1567            Some(row_desc) => {
1568                self.send(row_desc).await?;
1569                Ok(State::Ready)
1570            }
1571            None => {
1572                self.error(ErrorResponse::error(
1573                    SqlState::INVALID_CURSOR_NAME,
1574                    format!("portal {} does not exist", name.quoted()),
1575                ))
1576                .await
1577            }
1578        }
1579    }
1580
1581    #[instrument(level = "debug")]
1582    async fn close_statement(&mut self, name: String) -> Result<State, io::Error> {
1583        self.adapter_client
1584            .session()
1585            .remove_prepared_statement(&name);
1586        self.send(BackendMessage::CloseComplete).await?;
1587        Ok(State::Ready)
1588    }
1589
1590    #[instrument(level = "debug")]
1591    async fn close_portal(&mut self, name: String) -> Result<State, io::Error> {
1592        self.adapter_client.session().remove_portal(&name);
1593        self.send(BackendMessage::CloseComplete).await?;
1594        Ok(State::Ready)
1595    }
1596
1597    fn complete_portal(&mut self, name: &str) {
1598        let portal = self
1599            .adapter_client
1600            .session()
1601            .get_portal_unverified_mut(name)
1602            .expect("portal should exist");
1603        *portal.state = PortalState::Completed(None);
1604    }
1605
1606    async fn fetch(
1607        &mut self,
1608        name: String,
1609        count: Option<FetchDirection>,
1610        max_rows: ExecuteCount,
1611        fetch_portal_name: Option<String>,
1612        timeout: ExecuteTimeout,
1613        ctx_extra: ExecuteContextExtra,
1614    ) -> Result<State, io::Error> {
1615        // Unlike Execute, no count specified in FETCH returns 1 row, and 0 means 0
1616        // instead of All.
1617        let count = count.unwrap_or(FetchDirection::ForwardCount(1));
1618
1619        // Figure out how many rows we should send back by looking at the various
1620        // combinations of the execute and fetch.
1621        //
1622        // In Postgres, Fetch will cache <count> rows from the target portal and
1623        // return those as requested (if, say, an Execute message was sent with a
1624        // max_rows < the Fetch's count). We expect that case to be incredibly rare and
1625        // so have chosen to not support it until users request it. This eases
1626        // implementation difficulty since we don't have to be able to "send" rows to
1627        // a buffer.
1628        //
1629        // TODO(mjibson): Test this somehow? Need to divide up the pgtest files in
1630        // order to have some that are not Postgres compatible.
1631        let count = match (max_rows, count) {
1632            (ExecuteCount::Count(max_rows), FetchDirection::ForwardCount(count)) => {
1633                let count = usize::cast_from(count);
1634                if max_rows < count {
1635                    let msg = "Execute with max_rows < a FETCH's count is not supported";
1636                    self.adapter_client.retire_execute(
1637                        ctx_extra,
1638                        StatementEndedExecutionReason::Errored {
1639                            error: msg.to_string(),
1640                        },
1641                    );
1642                    return self
1643                        .error(ErrorResponse::error(SqlState::FEATURE_NOT_SUPPORTED, msg))
1644                        .await;
1645                }
1646                ExecuteCount::Count(count)
1647            }
1648            (ExecuteCount::Count(_), FetchDirection::ForwardAll) => {
1649                let msg = "Execute with max_rows of a FETCH ALL is not supported";
1650                self.adapter_client.retire_execute(
1651                    ctx_extra,
1652                    StatementEndedExecutionReason::Errored {
1653                        error: msg.to_string(),
1654                    },
1655                );
1656                return self
1657                    .error(ErrorResponse::error(SqlState::FEATURE_NOT_SUPPORTED, msg))
1658                    .await;
1659            }
1660            (ExecuteCount::All, FetchDirection::ForwardAll) => ExecuteCount::All,
1661            (ExecuteCount::All, FetchDirection::ForwardCount(count)) => {
1662                ExecuteCount::Count(usize::cast_from(count))
1663            }
1664        };
1665        let cursor_name = name.to_string();
1666        self.execute(
1667            cursor_name,
1668            count,
1669            fetch_message,
1670            fetch_portal_name,
1671            timeout,
1672            Some(ctx_extra),
1673            None,
1674        )
1675        .await
1676    }
1677
1678    async fn flush(&mut self) -> Result<State, io::Error> {
1679        self.conn.flush().await?;
1680        Ok(State::Ready)
1681    }
1682
1683    /// Sends a backend message to the client, after applying a severity filter.
1684    ///
1685    /// The message is only sent if its severity is above the severity set
1686    /// in the session, with the default value being NOTICE.
1687    #[instrument(level = "debug")]
1688    async fn send<M>(&mut self, message: M) -> Result<(), io::Error>
1689    where
1690        M: Into<BackendMessage>,
1691    {
1692        let message: BackendMessage = message.into();
1693        let is_error =
1694            matches!(&message, BackendMessage::ErrorResponse(e) if e.severity.is_error());
1695
1696        self.conn.send(message).await?;
1697
1698        // Flush immediately after sending an error response, as some clients
1699        // expect to be able to read the error response before sending a Sync
1700        // message. This is arguably in violation of the protocol specification,
1701        // but the specification is somewhat ambiguous, and easier to match
1702        // PostgreSQL here than to fix all the clients that have this
1703        // expectation.
1704        if is_error {
1705            self.conn.flush().await?;
1706        }
1707
1708        Ok(())
1709    }
1710
1711    #[instrument(level = "debug")]
1712    pub async fn send_all(
1713        &mut self,
1714        messages: impl IntoIterator<Item = BackendMessage>,
1715    ) -> Result<(), io::Error> {
1716        for m in messages {
1717            self.send(m).await?;
1718        }
1719        Ok(())
1720    }
1721
1722    #[instrument(level = "debug")]
1723    async fn sync(&mut self) -> Result<State, io::Error> {
1724        // Close the current transaction if we are in an implicit transaction.
1725        if self.adapter_client.session().transaction().is_implicit() {
1726            self.commit_transaction().await?;
1727        }
1728        self.ready().await
1729    }
1730
1731    #[instrument(level = "debug")]
1732    async fn ready(&mut self) -> Result<State, io::Error> {
1733        let txn_state = self.adapter_client.session().transaction().into();
1734        self.send(BackendMessage::ReadyForQuery(txn_state)).await?;
1735        self.flush().await
1736    }
1737
    /// Sends the client the backend messages for one `ExecuteResponse` and
    /// returns the connection's next `State`.
    ///
    /// Parameters:
    /// - `row_desc` must be `Some` for the row-producing responses
    ///   (`SendingRowsStreaming`, `SendingRowsImmediate`, `Subscribing`,
    ///   `CopyTo`, `CopyFrom`). Those arms panic via `expect` otherwise.
    /// - `portal_name` is the portal being executed. For `Fetch` it is passed on
    ///   as the FETCH's outer portal (`fetch_portal_name` of `fetch`).
    /// - `max_rows`, `get_response`, `fetch_portal_name`, and `timeout` are
    ///   forwarded to `send_rows`. `max_rows` is also forwarded to `fetch`.
    ///   In the `Fetch` arm, the response's own `timeout` shadows the
    ///   parameter.
    /// - `execute_started` is passed to every `RecordFirstRowStream` built here.
    ///
    /// Statement retirement:
    /// - `Subscribing`, and `CopyTo` wrapping `Subscribing`, call
    ///   `retire_execute` here with the outcome reported by
    ///   `send_rows`/`copy_rows`.
    /// - `Fetch` and `CopyFrom` pass their `ctx_extra` on to `fetch`/`copy_from`.
    /// - The `SendingRows*` variants carry no `ctx_extra`.
    ///
    /// Arms that end in `command_complete!` send the response's tag. The
    /// trailing `assert_none!` panics if an arm that falls through to the end
    /// leaves a tag unsent. Arms that `return` early skip that check.
    #[allow(clippy::too_many_arguments)]
    #[instrument(level = "debug")]
    async fn send_execute_response(
        &mut self,
        response: ExecuteResponse,
        row_desc: Option<RelationDesc>,
        portal_name: String,
        max_rows: ExecuteCount,
        get_response: GetResponse,
        fetch_portal_name: Option<String>,
        timeout: ExecuteTimeout,
        execute_started: Instant,
    ) -> Result<State, io::Error> {
        let mut tag = response.tag();

        // Sends CommandComplete with the response's tag and evaluates to
        // `Ok(State::Ready)`. `take()` leaves `tag` as `None` so the final
        // `assert_none!` can tell it was consumed. Panics if there was no tag.
        macro_rules! command_complete {
            () => {{
                self.send(BackendMessage::CommandComplete {
                    tag: tag
                        .take()
                        .expect("command_complete only called on tag-generating results"),
                })
                .await?;
                Ok(State::Ready)
            }};
        }

        let r = match response {
            // DECLARE / CLOSE: mark the portal `Completed(None)` so that
            // executing it again reports that it cannot be run.
            ExecuteResponse::ClosedCursor => {
                self.complete_portal(&portal_name);
                command_complete!()
            }
            ExecuteResponse::DeclaredCursor => {
                self.complete_portal(&portal_name);
                command_complete!()
            }
            ExecuteResponse::EmptyQuery => {
                self.send(BackendMessage::EmptyQueryResponse).await?;
                Ok(State::Ready)
            }
            // Note: the destructured `timeout` (the FETCH's own timeout) shadows
            // this function's `timeout` parameter within this arm.
            ExecuteResponse::Fetch {
                name,
                count,
                timeout,
                ctx_extra,
            } => {
                self.fetch(
                    name,
                    count,
                    max_rows,
                    Some(portal_name.to_string()),
                    timeout,
                    ctx_extra,
                )
                .await
            }
            ExecuteResponse::SendingRowsStreaming {
                rows,
                instance_id,
                strategy,
            } => {
                let row_desc = row_desc
                    .expect("missing row description for ExecuteResponse::SendingRowsStreaming");

                let span = tracing::debug_span!("sending_rows_streaming");

                // The end reason from `send_rows` is discarded; this arm does
                // not retire the statement.
                self.send_rows(
                    row_desc,
                    portal_name,
                    InProgressRows::new(RecordFirstRowStream::new(
                        Box::new(rows),
                        execute_started,
                        &self.adapter_client,
                        Some(instance_id),
                        Some(strategy),
                    )),
                    max_rows,
                    get_response,
                    fetch_portal_name,
                    timeout,
                )
                .instrument(span)
                .await
                .map(|(state, _)| state)
            }
            ExecuteResponse::SendingRowsImmediate { rows } => {
                let row_desc = row_desc
                    .expect("missing row description for ExecuteResponse::SendingRowsImmediate");

                let span = tracing::debug_span!("sending_rows_immediate");

                // Wrap the already-available rows in a one-item stream so they
                // go through the same `send_rows` path as streamed results.
                let stream =
                    futures::stream::once(futures::future::ready(PeekResponseUnary::Rows(rows)));
                self.send_rows(
                    row_desc,
                    portal_name,
                    InProgressRows::new(RecordFirstRowStream::new(
                        Box::new(stream),
                        execute_started,
                        &self.adapter_client,
                        None,
                        Some(StatementExecutionStrategy::Constant),
                    )),
                    max_rows,
                    get_response,
                    fetch_portal_name,
                    timeout,
                )
                .instrument(span)
                .await
                .map(|(state, _)| state)
            }
            ExecuteResponse::SetVariable { name, .. } => {
                // This code is somewhat awkwardly structured because we
                // can't hold `var` across an await point.
                let qn = name.to_string();
                let msg = if let Some(var) = self
                    .adapter_client
                    .session()
                    .vars_mut()
                    .notify_set()
                    .find(|v| v.name() == qn)
                {
                    Some(BackendMessage::ParameterStatus(var.name(), var.value()))
                } else {
                    None
                };
                if let Some(msg) = msg {
                    self.send(msg).await?;
                }
                command_complete!()
            }
            ExecuteResponse::Subscribing {
                rx,
                ctx_extra,
                instance_id,
            } => {
                // When SUBSCRIBE rows are streamed directly (not through a
                // FETCH, i.e. no outer portal), warn the client. The explicit
                // flush pushes the notice out immediately.
                if fetch_portal_name.is_none() {
                    let mut msg = ErrorResponse::notice(
                        SqlState::WARNING,
                        "streaming SUBSCRIBE rows directly requires a client that does not buffer output",
                    );
                    if self.adapter_client.session().vars().application_name() == "psql" {
                        msg.hint = Some(
                            "Wrap your SUBSCRIBE statement in `COPY (SUBSCRIBE ...) TO STDOUT`."
                                .into(),
                        )
                    }
                    self.send(msg).await?;
                    self.conn.flush().await?;
                }
                let row_desc =
                    row_desc.expect("missing row description for ExecuteResponse::Subscribing");
                // Map the outcome of `send_rows` onto a statement-logging end
                // reason, then retire the statement below.
                let (result, statement_ended_execution_reason) = match self
                    .send_rows(
                        row_desc,
                        portal_name,
                        InProgressRows::new(RecordFirstRowStream::new(
                            Box::new(UnboundedReceiverStream::new(rx)),
                            execute_started,
                            &self.adapter_client,
                            Some(instance_id),
                            None,
                        )),
                        max_rows,
                        get_response,
                        fetch_portal_name,
                        timeout,
                    )
                    .await
                {
                    Err(e) => {
                        // This is an error communicating with the connection.
                        // We consider that to be a cancelation, rather than a query error.
                        (Err(e), StatementEndedExecutionReason::Canceled)
                    }
                    Ok((ok, SendRowsEndedReason::Canceled)) => {
                        (Ok(ok), StatementEndedExecutionReason::Canceled)
                    }
                    Ok((
                        ok,
                        SendRowsEndedReason::Success {
                            result_size,
                            rows_returned,
                        },
                    )) => (
                        Ok(ok),
                        StatementEndedExecutionReason::Success {
                            result_size: Some(result_size),
                            rows_returned: Some(rows_returned),
                            execution_strategy: None,
                        },
                    ),
                    Ok((ok, SendRowsEndedReason::Errored { error })) => {
                        (Ok(ok), StatementEndedExecutionReason::Errored { error })
                    }
                };
                self.adapter_client
                    .retire_execute(ctx_extra, statement_ended_execution_reason);
                return result;
            }
            // COPY ... TO: the wrapped response supplies the rows, which are
            // written with `copy_rows` in the requested `format`.
            ExecuteResponse::CopyTo { format, resp } => {
                let row_desc =
                    row_desc.expect("missing row description for ExecuteResponse::CopyTo");
                match *resp {
                    ExecuteResponse::Subscribing {
                        rx,
                        ctx_extra,
                        instance_id,
                    } => {
                        let (result, statement_ended_execution_reason) = match self
                            .copy_rows(
                                format,
                                row_desc,
                                RecordFirstRowStream::new(
                                    Box::new(UnboundedReceiverStream::new(rx)),
                                    execute_started,
                                    &self.adapter_client,
                                    Some(instance_id),
                                    None,
                                ),
                            )
                            .await
                        {
                            Err(e) => {
                                // This is an error communicating with the connection.
                                // We consider that to be a cancelation, rather than a query error.
                                (Err(e), StatementEndedExecutionReason::Canceled)
                            }
                            Ok((
                                state,
                                SendRowsEndedReason::Success {
                                    result_size,
                                    rows_returned,
                                },
                            )) => (
                                Ok(state),
                                StatementEndedExecutionReason::Success {
                                    result_size: Some(result_size),
                                    rows_returned: Some(rows_returned),
                                    execution_strategy: None,
                                },
                            ),
                            Ok((state, SendRowsEndedReason::Errored { error })) => {
                                (Ok(state), StatementEndedExecutionReason::Errored { error })
                            }
                            Ok((state, SendRowsEndedReason::Canceled)) => {
                                (Ok(state), StatementEndedExecutionReason::Canceled)
                            }
                        };
                        self.adapter_client
                            .retire_execute(ctx_extra, statement_ended_execution_reason);
                        return result;
                    }
                    ExecuteResponse::SendingRowsStreaming {
                        rows,
                        instance_id,
                        strategy,
                    } => {
                        // We don't need to finalize execution here;
                        // it was already done in the
                        // coordinator. Just extract the state and
                        // return that.
                        return self
                            .copy_rows(
                                format,
                                row_desc,
                                RecordFirstRowStream::new(
                                    Box::new(rows),
                                    execute_started,
                                    &self.adapter_client,
                                    Some(instance_id),
                                    Some(strategy),
                                ),
                            )
                            .await
                            .map(|(state, _)| state);
                    }
                    ExecuteResponse::SendingRowsImmediate { rows } => {
                        let span = tracing::debug_span!("sending_rows_immediate");

                        let rows = futures::stream::once(futures::future::ready(
                            PeekResponseUnary::Rows(rows),
                        ));
                        // We don't need to finalize execution here;
                        // it was already done in the
                        // coordinator. Just extract the state and
                        // return that.
                        return self
                            .copy_rows(
                                format,
                                row_desc,
                                RecordFirstRowStream::new(
                                    Box::new(rows),
                                    execute_started,
                                    &self.adapter_client,
                                    None,
                                    Some(StatementExecutionStrategy::Constant),
                                ),
                            )
                            .instrument(span)
                            .await
                            .map(|(state, _)| state);
                    }
                    // Any other wrapped response kind is rejected with an
                    // internal error.
                    _ => {
                        return self
                            .error(ErrorResponse::error(
                                SqlState::INTERNAL_ERROR,
                                "unsupported COPY response type".to_string(),
                            ))
                            .await;
                    }
                };
            }
            ExecuteResponse::CopyFrom {
                target_id,
                target_name,
                columns,
                params,
                ctx_extra,
            } => {
                let row_desc =
                    row_desc.expect("missing row description for ExecuteResponse::CopyFrom");
                self.copy_from(target_id, target_name, columns, params, row_desc, ctx_extra)
                    .await
            }
            ExecuteResponse::TransactionCommitted { params }
            | ExecuteResponse::TransactionRolledBack { params } => {
                let notify_set: mz_ore::collections::HashSet<String> = self
                    .adapter_client
                    .session()
                    .vars()
                    .notify_set()
                    .map(|v| v.name().to_string())
                    .collect();

                // Only report on parameters that are in the notify set.
                for (name, value) in params
                    .into_iter()
                    .filter(|(name, _v)| notify_set.contains(*name))
                {
                    let msg = BackendMessage::ParameterStatus(name, value);
                    self.send(msg).await?;
                }
                command_complete!()
            }

            // Responses that only need their command tag sent back.
            ExecuteResponse::AlteredDefaultPrivileges
            | ExecuteResponse::AlteredObject(..)
            | ExecuteResponse::AlteredRole
            | ExecuteResponse::AlteredSystemConfiguration
            | ExecuteResponse::CreatedCluster { .. }
            | ExecuteResponse::CreatedClusterReplica { .. }
            | ExecuteResponse::CreatedConnection { .. }
            | ExecuteResponse::CreatedDatabase { .. }
            | ExecuteResponse::CreatedIndex { .. }
            | ExecuteResponse::CreatedIntrospectionSubscribe
            | ExecuteResponse::CreatedMaterializedView { .. }
            | ExecuteResponse::CreatedContinualTask { .. }
            | ExecuteResponse::CreatedRole
            | ExecuteResponse::CreatedSchema { .. }
            | ExecuteResponse::CreatedSecret { .. }
            | ExecuteResponse::CreatedSink { .. }
            | ExecuteResponse::CreatedSource { .. }
            | ExecuteResponse::CreatedTable { .. }
            | ExecuteResponse::CreatedType
            | ExecuteResponse::CreatedView { .. }
            | ExecuteResponse::CreatedViews { .. }
            | ExecuteResponse::CreatedNetworkPolicy
            | ExecuteResponse::Comment
            | ExecuteResponse::Deallocate { .. }
            | ExecuteResponse::Deleted(..)
            | ExecuteResponse::DiscardedAll
            | ExecuteResponse::DiscardedTemp
            | ExecuteResponse::DroppedObject(_)
            | ExecuteResponse::DroppedOwned
            | ExecuteResponse::GrantedPrivilege
            | ExecuteResponse::GrantedRole
            | ExecuteResponse::Inserted(..)
            | ExecuteResponse::Copied(..)
            | ExecuteResponse::Prepare
            | ExecuteResponse::Raised
            | ExecuteResponse::ReassignOwned
            | ExecuteResponse::RevokedPrivilege
            | ExecuteResponse::RevokedRole
            | ExecuteResponse::StartedTransaction { .. }
            | ExecuteResponse::Updated(..)
            | ExecuteResponse::ValidatedConnection => {
                command_complete!()
            }
        };

        // Every arm that reaches this point must have sent any tag the
        // response produced (via `command_complete!`).
        assert_none!(tag, "tag created but not consumed: {:?}", tag);
        r
    }
2133
2134    #[allow(clippy::too_many_arguments)]
2135    // TODO(guswynn): figure out how to get it to compile without skip_all
2136    #[mz_ore::instrument(level = "debug")]
2137    async fn send_rows(
2138        &mut self,
2139        row_desc: RelationDesc,
2140        portal_name: String,
2141        mut rows: InProgressRows,
2142        max_rows: ExecuteCount,
2143        get_response: GetResponse,
2144        fetch_portal_name: Option<String>,
2145        timeout: ExecuteTimeout,
2146    ) -> Result<(State, SendRowsEndedReason), io::Error> {
2147        // If this portal is being executed from a FETCH then we need to use the result
2148        // format type of the outer portal.
2149        let result_format_portal_name: &str = if let Some(ref name) = fetch_portal_name {
2150            name
2151        } else {
2152            &portal_name
2153        };
2154        let result_formats = self
2155            .adapter_client
2156            .session()
2157            .get_portal_unverified(result_format_portal_name)
2158            .expect("valid fetch portal name for send rows")
2159            .result_formats
2160            .clone();
2161
2162        let (mut wait_once, mut deadline) = match timeout {
2163            ExecuteTimeout::None => (false, None),
2164            ExecuteTimeout::Seconds(t) => (
2165                false,
2166                Some(tokio::time::Instant::now() + tokio::time::Duration::from_secs_f64(t)),
2167            ),
2168            ExecuteTimeout::WaitOnce => (true, None),
2169        };
2170
2171        // Sanity check that the various `RelationDesc`s match up.
2172        {
2173            let portal_name_desc = &self
2174                .adapter_client
2175                .session()
2176                .get_portal_unverified(portal_name.as_str())
2177                .expect("portal should exist")
2178                .desc
2179                .relation_desc;
2180            if let Some(portal_name_desc) = portal_name_desc {
2181                soft_assert_eq_or_log!(portal_name_desc, &row_desc);
2182            }
2183            if let Some(fetch_portal_name) = &fetch_portal_name {
2184                let fetch_portal_desc = &self
2185                    .adapter_client
2186                    .session()
2187                    .get_portal_unverified(fetch_portal_name)
2188                    .expect("portal should exist")
2189                    .desc
2190                    .relation_desc;
2191                if let Some(fetch_portal_desc) = fetch_portal_desc {
2192                    soft_assert_eq_or_log!(fetch_portal_desc, &row_desc);
2193                }
2194            }
2195        }
2196
2197        self.conn.set_encode_state(
2198            row_desc
2199                .typ()
2200                .column_types
2201                .iter()
2202                .map(|ty| mz_pgrepr::Type::from(&ty.scalar_type))
2203                .zip_eq(result_formats)
2204                .collect(),
2205        );
2206
2207        let mut total_sent_rows = 0;
2208        let mut total_sent_bytes = 0;
2209        // want_rows is the maximum number of rows the client wants.
2210        let mut want_rows = match max_rows {
2211            ExecuteCount::All => usize::MAX,
2212            ExecuteCount::Count(count) => count,
2213        };
2214
2215        // Send rows while the client still wants them and there are still rows to send.
2216        loop {
2217            // Fetch next batch of rows, waiting for a possible requested
2218            // timeout or notice.
2219            let batch = if rows.current.is_some() {
2220                FetchResult::Rows(rows.current.take())
2221            } else if want_rows == 0 {
2222                FetchResult::Rows(None)
2223            } else {
2224                let notice_fut = self.adapter_client.session().recv_notice();
2225                tokio::select! {
2226                    err = self.conn.wait_closed() => return Err(err),
2227                    _ = time::sleep_until(deadline.unwrap_or_else(tokio::time::Instant::now)), if deadline.is_some() => FetchResult::Rows(None),
2228                    notice = notice_fut => {
2229                        FetchResult::Notice(notice)
2230                    }
2231                    batch = rows.remaining.recv() => match batch {
2232                        None => FetchResult::Rows(None),
2233                        Some(PeekResponseUnary::Rows(rows)) => FetchResult::Rows(Some(rows)),
2234                        Some(PeekResponseUnary::Error(err)) => FetchResult::Error(err),
2235                        Some(PeekResponseUnary::Canceled) => FetchResult::Canceled,
2236                    },
2237                }
2238            };
2239
2240            match batch {
2241                FetchResult::Rows(None) => break,
2242                FetchResult::Rows(Some(mut batch_rows)) => {
2243                    if let Err(err) = verify_datum_desc(&row_desc, &mut batch_rows) {
2244                        let msg = err.to_string();
2245                        return self
2246                            .error(err.into_response(Severity::Error))
2247                            .await
2248                            .map(|state| (state, SendRowsEndedReason::Errored { error: msg }));
2249                    }
2250
2251                    // If wait_once is true: the first time this fn is called it blocks (same as
2252                    // deadline == None). The second time this fn is called it should behave the
2253                    // same a 0s timeout.
2254                    if wait_once && batch_rows.peek().is_some() {
2255                        deadline = Some(tokio::time::Instant::now());
2256                        wait_once = false;
2257                    }
2258
2259                    // Send a portion of the rows.
2260                    let mut sent_rows = 0;
2261                    let mut sent_bytes = 0;
2262                    let messages = (&mut batch_rows)
2263                        // TODO(parkmycar): This is a fair bit of juggling between iterator types
2264                        // to count the total number of bytes. Alternatively we could track the
2265                        // total sent bytes in this .map(...) call, but having side effects in map
2266                        // is a code smell.
2267                        .map(|row| {
2268                            let row_len = row.byte_len();
2269                            let values = mz_pgrepr::values_from_row(row, row_desc.typ());
2270                            (row_len, BackendMessage::DataRow(values))
2271                        })
2272                        .inspect(|(row_len, _)| {
2273                            sent_bytes += row_len;
2274                            sent_rows += 1
2275                        })
2276                        .map(|(_row_len, row)| row)
2277                        .take(want_rows);
2278                    self.send_all(messages).await?;
2279
2280                    total_sent_rows += sent_rows;
2281                    total_sent_bytes += sent_bytes;
2282                    want_rows -= sent_rows;
2283
2284                    // If we have sent the number of requested rows, put the remainder of the batch
2285                    // (if any) back and stop sending.
2286                    if want_rows == 0 {
2287                        if batch_rows.peek().is_some() {
2288                            rows.current = Some(batch_rows);
2289                        }
2290                        break;
2291                    }
2292
2293                    self.conn.flush().await?;
2294                }
2295                FetchResult::Notice(notice) => {
2296                    self.send(notice.into_response()).await?;
2297                    self.conn.flush().await?;
2298                }
2299                FetchResult::Error(text) => {
2300                    return self
2301                        .error(ErrorResponse::error(SqlState::INTERNAL_ERROR, text.clone()))
2302                        .await
2303                        .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
2304                }
2305                FetchResult::Canceled => {
2306                    return self
2307                        .error(ErrorResponse::error(
2308                            SqlState::QUERY_CANCELED,
2309                            "canceling statement due to user request",
2310                        ))
2311                        .await
2312                        .map(|state| (state, SendRowsEndedReason::Canceled));
2313                }
2314            }
2315        }
2316
2317        let portal = self
2318            .adapter_client
2319            .session()
2320            .get_portal_unverified_mut(&portal_name)
2321            .expect("valid portal name for send rows");
2322
2323        let saw_rows = rows.remaining.saw_rows;
2324        let no_more_rows = rows.no_more_rows();
2325        let recorded_first_row_instant = rows.remaining.recorded_first_row_instant;
2326
2327        // Always return rows back, even if it's empty. This prevents an unclosed
2328        // portal from re-executing after it has been emptied.
2329        *portal.state = PortalState::InProgress(Some(rows));
2330
2331        let fetch_portal = fetch_portal_name.map(|name| {
2332            self.adapter_client
2333                .session()
2334                .get_portal_unverified_mut(&name)
2335                .expect("valid fetch portal")
2336        });
2337        let response_message = get_response(max_rows, total_sent_rows, fetch_portal);
2338        self.send(response_message).await?;
2339
2340        // Attend to metrics if there are no more rows.
2341        if no_more_rows {
2342            let statement_type = if let Some(stmt) = &self
2343                .adapter_client
2344                .session()
2345                .get_portal_unverified(&portal_name)
2346                .expect("valid portal name for send_rows")
2347                .stmt
2348            {
2349                metrics::statement_type_label_value(stmt.deref())
2350            } else {
2351                "no-statement"
2352            };
2353            let duration = if saw_rows {
2354                recorded_first_row_instant
2355                    .expect("recorded_first_row_instant because saw_rows")
2356                    .elapsed()
2357            } else {
2358                // If the result is empty, then we define time from first to last row as 0.
2359                // (Note that, currently, an empty result involves a PeekResponse with 0 rows, which
2360                // does flip `saw_rows`, so this code path is currently not exercised.)
2361                Duration::ZERO
2362            };
2363            self.adapter_client
2364                .inner()
2365                .metrics()
2366                .result_rows_first_to_last_byte_seconds
2367                .with_label_values(&[statement_type])
2368                .observe(duration.as_secs_f64());
2369        }
2370
2371        Ok((
2372            State::Ready,
2373            SendRowsEndedReason::Success {
2374                result_size: u64::cast_from(total_sent_bytes),
2375                rows_returned: u64::cast_from(total_sent_rows),
2376            },
2377        ))
2378    }
2379
2380    #[mz_ore::instrument(level = "debug")]
2381    async fn copy_rows(
2382        &mut self,
2383        format: CopyFormat,
2384        row_desc: RelationDesc,
2385        mut stream: RecordFirstRowStream,
2386    ) -> Result<(State, SendRowsEndedReason), io::Error> {
2387        let (row_format, encode_format) = match format {
2388            CopyFormat::Text => (
2389                CopyFormatParams::Text(CopyTextFormatParams::default()),
2390                Format::Text,
2391            ),
2392            CopyFormat::Binary => (CopyFormatParams::Binary, Format::Binary),
2393            CopyFormat::Csv => (
2394                CopyFormatParams::Csv(CopyCsvFormatParams::default()),
2395                Format::Text,
2396            ),
2397            CopyFormat::Parquet => {
2398                let text = "Parquet format is not supported".to_string();
2399                return self
2400                    .error(ErrorResponse::error(SqlState::INTERNAL_ERROR, text.clone()))
2401                    .await
2402                    .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
2403            }
2404        };
2405
2406        let encode_fn = |row: &RowRef, typ: &SqlRelationType, out: &mut Vec<u8>| {
2407            mz_pgcopy::encode_copy_format(&row_format, row, typ, out)
2408        };
2409
2410        let typ = row_desc.typ();
2411        let column_formats = iter::repeat(encode_format)
2412            .take(typ.column_types.len())
2413            .collect();
2414        self.send(BackendMessage::CopyOutResponse {
2415            overall_format: encode_format,
2416            column_formats,
2417        })
2418        .await?;
2419
2420        // In Postgres, binary copy has a header that is followed (in the same
2421        // CopyData) by the first row. In order to replicate their behavior, use a
2422        // common vec that we can extend one time now and then fill up with the encode
2423        // functions.
2424        let mut out = Vec::new();
2425
2426        if let CopyFormat::Binary = format {
2427            // 11-byte signature.
2428            out.extend(b"PGCOPY\n\xFF\r\n\0");
2429            // 32-bit flags field.
2430            out.extend([0, 0, 0, 0]);
2431            // 32-bit header extension length field.
2432            out.extend([0, 0, 0, 0]);
2433        }
2434
2435        let mut count = 0;
2436        let mut total_sent_bytes = 0;
2437        loop {
2438            tokio::select! {
2439                e = self.conn.wait_closed() => return Err(e),
2440                batch = stream.recv() => match batch {
2441                    None => break,
2442                    Some(PeekResponseUnary::Error(text)) => {
2443                        return self
2444                            .error(ErrorResponse::error(SqlState::INTERNAL_ERROR, text.clone()))
2445                        .await
2446                        .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
2447                    }
2448                    Some(PeekResponseUnary::Canceled) => {
2449                        return self.error(ErrorResponse::error(
2450                                SqlState::QUERY_CANCELED,
2451                                "canceling statement due to user request",
2452                            ))
2453                            .await.map(|state| (state, SendRowsEndedReason::Canceled));
2454                    }
2455                    Some(PeekResponseUnary::Rows(mut rows)) => {
2456                        count += rows.count();
2457                        while let Some(row) = rows.next() {
2458                            total_sent_bytes += row.byte_len();
2459                            encode_fn(row, typ, &mut out)?;
2460                            self.send(BackendMessage::CopyData(mem::take(&mut out)))
2461                                .await?;
2462                        }
2463                    }
2464                },
2465                notice = self.adapter_client.session().recv_notice() => {
2466                    self.send(notice.into_response())
2467                        .await?;
2468                    self.conn.flush().await?;
2469                }
2470            }
2471
2472            self.conn.flush().await?;
2473        }
2474        // Send required trailers.
2475        if let CopyFormat::Binary = format {
2476            let trailer: i16 = -1;
2477            out.extend(trailer.to_be_bytes());
2478            self.send(BackendMessage::CopyData(mem::take(&mut out)))
2479                .await?;
2480        }
2481
2482        let tag = format!("COPY {}", count);
2483        self.send(BackendMessage::CopyDone).await?;
2484        self.send(BackendMessage::CommandComplete { tag }).await?;
2485        Ok((
2486            State::Ready,
2487            SendRowsEndedReason::Success {
2488                result_size: u64::cast_from(total_sent_bytes),
2489                rows_returned: u64::cast_from(count),
2490            },
2491        ))
2492    }
2493
2494    /// Handles the copy-in mode of the postgres protocol from transferring
2495    /// data to the server.
2496    #[instrument(level = "debug")]
2497    async fn copy_from(
2498        &mut self,
2499        target_id: CatalogItemId,
2500        target_name: String,
2501        columns: Vec<ColumnIndex>,
2502        params: CopyFormatParams<'_>,
2503        row_desc: RelationDesc,
2504        mut ctx_extra: ExecuteContextExtra,
2505    ) -> Result<State, io::Error> {
2506        let res = self
2507            .copy_from_inner(
2508                target_id,
2509                target_name,
2510                columns,
2511                params,
2512                row_desc,
2513                &mut ctx_extra,
2514            )
2515            .await;
2516        match &res {
2517            Ok(State::Done) => {
2518                // The connection closed gracefully without sending us a `CopyDone`,
2519                // causing us to just drop the copy request.
2520                // For the purposes of statement logging, we count this as a cancellation.
2521                self.adapter_client
2522                    .retire_execute(ctx_extra, StatementEndedExecutionReason::Canceled);
2523            }
2524            Err(e) => {
2525                self.adapter_client.retire_execute(
2526                    ctx_extra,
2527                    StatementEndedExecutionReason::Errored {
2528                        error: format!("{e}"),
2529                    },
2530                );
2531            }
2532            other => {
2533                tracing::warn!(?other, "aborting COPY FROM");
2534                self.adapter_client
2535                    .retire_execute(ctx_extra, StatementEndedExecutionReason::Aborted);
2536            }
2537        }
2538        res
2539    }
2540
2541    async fn copy_from_inner(
2542        &mut self,
2543        target_id: CatalogItemId,
2544        target_name: String,
2545        columns: Vec<ColumnIndex>,
2546        params: CopyFormatParams<'_>,
2547        row_desc: RelationDesc,
2548        ctx_extra: &mut ExecuteContextExtra,
2549    ) -> Result<State, io::Error> {
2550        let typ = row_desc.typ();
2551        let column_formats = vec![Format::Text; typ.column_types.len()];
2552        self.send(BackendMessage::CopyInResponse {
2553            overall_format: Format::Text,
2554            column_formats,
2555        })
2556        .await?;
2557        self.conn.flush().await?;
2558
2559        let system_vars = self.adapter_client.get_system_vars().await;
2560        let max_size = system_vars
2561            .get(MAX_COPY_FROM_SIZE.name())
2562            .ok()
2563            .and_then(|max_size| max_size.value().parse().ok())
2564            .unwrap_or(usize::MAX);
2565        tracing::debug!("COPY FROM max buffer size: {max_size} bytes");
2566
2567        let mut data = Vec::new();
2568        loop {
2569            let message = self.conn.recv().await?;
2570            match message {
2571                Some(FrontendMessage::CopyData(buf)) => {
2572                    // Bail before we OOM.
2573                    if (data.len() + buf.len()) > max_size {
2574                        return self
2575                            .error(ErrorResponse::error(
2576                                SqlState::INSUFFICIENT_RESOURCES,
2577                                "COPY FROM STDIN too large",
2578                            ))
2579                            .await;
2580                    }
2581                    data.extend(buf)
2582                }
2583                Some(FrontendMessage::CopyDone) => break,
2584                Some(FrontendMessage::CopyFail(err)) => {
2585                    self.adapter_client.retire_execute(
2586                        std::mem::take(ctx_extra),
2587                        StatementEndedExecutionReason::Canceled,
2588                    );
2589                    return self
2590                        .error(ErrorResponse::error(
2591                            SqlState::QUERY_CANCELED,
2592                            format!("COPY from stdin failed: {}", err),
2593                        ))
2594                        .await;
2595                }
2596                Some(FrontendMessage::Flush) | Some(FrontendMessage::Sync) => {}
2597                Some(_) => {
2598                    let msg = "unexpected message type during COPY from stdin";
2599                    self.adapter_client.retire_execute(
2600                        std::mem::take(ctx_extra),
2601                        StatementEndedExecutionReason::Errored {
2602                            error: msg.to_string(),
2603                        },
2604                    );
2605                    return self
2606                        .error(ErrorResponse::error(SqlState::PROTOCOL_VIOLATION, msg))
2607                        .await;
2608                }
2609                None => {
2610                    return Ok(State::Done);
2611                }
2612            }
2613        }
2614
2615        let column_types = typ
2616            .column_types
2617            .iter()
2618            .map(|x| &x.scalar_type)
2619            .map(mz_pgrepr::Type::from)
2620            .collect::<Vec<mz_pgrepr::Type>>();
2621
2622        let rows = match mz_pgcopy::decode_copy_format(&data, &column_types, params) {
2623            Ok(rows) => rows,
2624            Err(e) => {
2625                self.adapter_client.retire_execute(
2626                    std::mem::take(ctx_extra),
2627                    StatementEndedExecutionReason::Errored {
2628                        error: e.to_string(),
2629                    },
2630                );
2631                return self
2632                    .error(ErrorResponse::error(
2633                        SqlState::BAD_COPY_FILE_FORMAT,
2634                        format!("{}", e),
2635                    ))
2636                    .await;
2637            }
2638        };
2639
2640        let count = rows.len();
2641
2642        if let Err(e) = self
2643            .adapter_client
2644            .insert_rows(
2645                target_id,
2646                target_name,
2647                columns,
2648                rows,
2649                std::mem::take(ctx_extra),
2650            )
2651            .await
2652        {
2653            self.adapter_client.retire_execute(
2654                std::mem::take(ctx_extra),
2655                StatementEndedExecutionReason::Errored {
2656                    error: e.to_string(),
2657                },
2658            );
2659            return self.error(e.into_response(Severity::Error)).await;
2660        }
2661
2662        let tag = format!("COPY {}", count);
2663        self.send(BackendMessage::CommandComplete { tag }).await?;
2664
2665        Ok(State::Ready)
2666    }
2667
2668    #[instrument(level = "debug")]
2669    async fn send_pending_notices(&mut self) -> Result<(), io::Error> {
2670        let notices = self
2671            .adapter_client
2672            .session()
2673            .drain_notices()
2674            .into_iter()
2675            .map(|notice| BackendMessage::ErrorResponse(notice.into_response()));
2676        self.send_all(notices).await?;
2677        Ok(())
2678    }
2679
2680    #[instrument(level = "debug")]
2681    async fn error(&mut self, err: ErrorResponse) -> Result<State, io::Error> {
2682        assert!(err.severity.is_error());
2683        debug!(
2684            "cid={} error code={}",
2685            self.adapter_client.session().conn_id(),
2686            err.code.code()
2687        );
2688        let is_fatal = err.severity.is_fatal();
2689        self.send(BackendMessage::ErrorResponse(err)).await?;
2690
2691        let txn = self.adapter_client.session().transaction();
2692        match txn {
2693            // Error can be called from describe and parse and so might not be in an active
2694            // transaction.
2695            TransactionStatus::Default | TransactionStatus::Failed(_) => {}
2696            // In Started (i.e., a single statement), cleanup ourselves.
2697            TransactionStatus::Started(_) => {
2698                self.rollback_transaction().await?;
2699            }
2700            // Implicit transactions also clear themselves.
2701            TransactionStatus::InTransactionImplicit(_) => {
2702                self.rollback_transaction().await?;
2703            }
2704            // Explicit transactions move to failed.
2705            TransactionStatus::InTransaction(_) => {
2706                self.adapter_client.fail_transaction();
2707            }
2708        };
2709        if is_fatal {
2710            Ok(State::Done)
2711        } else {
2712            Ok(State::Drain)
2713        }
2714    }
2715
2716    #[instrument(level = "debug")]
2717    async fn aborted_txn_error(&mut self) -> Result<State, io::Error> {
2718        self.send(BackendMessage::ErrorResponse(ErrorResponse::error(
2719            SqlState::IN_FAILED_SQL_TRANSACTION,
2720            ABORTED_TXN_MSG,
2721        )))
2722        .await?;
2723        Ok(State::Drain)
2724    }
2725
2726    fn is_aborted_txn(&mut self) -> bool {
2727        matches!(
2728            self.adapter_client.session().transaction(),
2729            TransactionStatus::Failed(_)
2730        )
2731    }
2732}
2733
2734fn pad_formats(formats: Vec<Format>, n: usize) -> Result<Vec<Format>, String> {
2735    match (formats.len(), n) {
2736        (0, e) => Ok(vec![Format::Text; e]),
2737        (1, e) => Ok(iter::repeat(formats[0]).take(e).collect()),
2738        (a, e) if a == e => Ok(formats),
2739        (a, e) => Err(format!(
2740            "expected {} field format specifiers, but got {}",
2741            e, a
2742        )),
2743    }
2744}
2745
2746fn describe_rows(stmt_desc: &StatementDesc, formats: &[Format]) -> BackendMessage {
2747    match &stmt_desc.relation_desc {
2748        Some(desc) if !stmt_desc.is_copy => {
2749            BackendMessage::RowDescription(message::encode_row_description(desc, formats))
2750        }
2751        _ => BackendMessage::NoData,
2752    }
2753}
2754
2755type GetResponse = fn(
2756    max_rows: ExecuteCount,
2757    total_sent_rows: usize,
2758    fetch_portal: Option<PortalRefMut>,
2759) -> BackendMessage;
2760
2761// A GetResponse used by send_rows during execute messages on portals or for
2762// simple query messages.
2763fn portal_exec_message(
2764    max_rows: ExecuteCount,
2765    total_sent_rows: usize,
2766    _fetch_portal: Option<PortalRefMut>,
2767) -> BackendMessage {
2768    // If max_rows is not specified, we will always send back a CommandComplete. If
2769    // max_rows is specified, we only send CommandComplete if there were more rows
2770    // requested than were remaining. That is, if max_rows == number of rows that
2771    // were remaining before sending (not that are remaining after sending), then
2772    // we still send a PortalSuspended. The number of remaining rows after the rows
2773    // have been sent doesn't matter. This matches postgres.
2774    match max_rows {
2775        ExecuteCount::Count(max_rows) if max_rows <= total_sent_rows => {
2776            BackendMessage::PortalSuspended
2777        }
2778        _ => BackendMessage::CommandComplete {
2779            tag: format!("SELECT {}", total_sent_rows),
2780        },
2781    }
2782}
2783
2784// A GetResponse used by send_rows during FETCH queries.
2785fn fetch_message(
2786    _max_rows: ExecuteCount,
2787    total_sent_rows: usize,
2788    fetch_portal: Option<PortalRefMut>,
2789) -> BackendMessage {
2790    let tag = format!("FETCH {}", total_sent_rows);
2791    if let Some(portal) = fetch_portal {
2792        *portal.state = PortalState::Completed(Some(tag.clone()));
2793    }
2794    BackendMessage::CommandComplete { tag }
2795}
2796
/// Row limit requested by an Execute message (or FETCH).
#[derive(Debug, Clone, Copy)]
enum ExecuteCount {
    /// No limit: every remaining row may be sent.
    All,
    /// Send at most this many rows.
    Count(usize),
}
2802
2803// See postgres' backend/tcop/postgres.c IsTransactionExitStmt.
2804fn is_txn_exit_stmt(stmt: Option<&Statement<Raw>>) -> bool {
2805    match stmt {
2806        // Add PREPARE to this if we ever support it.
2807        Some(stmt) => matches!(stmt, Statement::Commit(_) | Statement::Rollback(_)),
2808        None => false,
2809    }
2810}
2811
2812#[derive(Debug)]
2813enum FetchResult {
2814    Rows(Option<Box<dyn RowIterator + Send + Sync>>),
2815    Canceled,
2816    Error(String),
2817    Notice(AdapterNotice),
2818}
2819
#[cfg(test)]
mod test {
    use super::*;

    /// `parse_options` turns a whole `options` string into key/value pairs and
    /// rejects the input if any single option is malformed.
    #[mz_ore::test]
    fn test_parse_options() {
        let cases: Vec<(&str, Result<Vec<(&str, &str)>, ()>)> = vec![
            ("", Ok(vec![])),
            ("--key", Err(())),
            ("--key=val", Ok(vec![("key", "val")])),
            (
                r#"--key=val -ckey2=val2 -c key3=val3 -c key4=val4 -ckey5=val5"#,
                Ok(vec![
                    ("key", "val"),
                    ("key2", "val2"),
                    ("key3", "val3"),
                    ("key4", "val4"),
                    ("key5", "val5"),
                ]),
            ),
            (r#"-c\ key=val"#, Ok(vec![(" key", "val")])),
            ("--key=val -ckey2 val2", Err(())),
            // Unclear what this should do.
            ("--key=", Ok(vec![("key", "")])),
        ];
        for (input, expected) in cases {
            let expected = expected.map(|pairs| {
                pairs
                    .into_iter()
                    .map(|(k, v)| (k.to_owned(), v.to_owned()))
                    .collect()
            });
            assert_eq!(parse_options(input), expected, "input: {}", input);
        }
    }

    /// `parse_option` parses a single `--key=value` or `-ckey=value` token.
    #[mz_ore::test]
    fn test_parse_option() {
        let cases: Vec<(&str, Result<(&str, &str), ()>)> = vec![
            ("", Err(())),
            ("--", Err(())),
            ("--c", Err(())),
            ("a=b", Err(())),
            ("--a=b", Ok(("a", "b"))),
            ("--ca=b", Ok(("ca", "b"))),
            ("-ca=b", Ok(("a", "b"))),
            // Unclear whether this should error, but at least test it.
            ("--=", Ok(("", ""))),
        ];
        for (input, expected) in cases {
            assert_eq!(parse_option(input), expected, "input: {}", input);
        }
    }

    /// `split_options` splits on unescaped whitespace. A backslash escapes the
    /// character after it, and a trailing lone backslash is dropped.
    #[mz_ore::test]
    fn test_split_options() {
        let cases: Vec<(&str, Vec<&str>)> = vec![
            ("", vec![]),
            ("  ", vec![]),
            (" a ", vec!["a"]),
            ("  ab     cd   ", vec!["ab", "cd"]),
            (r#"  ab\     cd   "#, vec!["ab ", "cd"]),
            (r#"  ab\\     cd   "#, vec![r#"ab\"#, "cd"]),
            (r#"  ab\\\     cd   "#, vec![r#"ab\ "#, "cd"]),
            (r#"  ab\\\ cd   "#, vec![r#"ab\ cd"#]),
            (r#"  ab\\\cd   "#, vec![r#"ab\cd"#]),
            (r#"a\"#, vec!["a"]),
            (r#"a\ "#, vec!["a "]),
            (r#"\"#, vec![]),
            (r#"\ "#, vec![r#" "#]),
            (r#" \ "#, vec![r#" "#]),
            (r#"\  "#, vec![r#" "#]),
        ];
        for (input, expected) in cases {
            assert_eq!(split_options(input), expected, "input: {}", input);
        }
    }
}