mz_pgwire/
protocol.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::collections::BTreeMap;
11use std::convert::TryFrom;
12use std::future::Future;
13use std::ops::Deref;
14use std::sync::Arc;
15use std::time::{Duration, Instant};
16use std::{iter, mem};
17
18use base64::prelude::*;
19use byteorder::{ByteOrder, NetworkEndian};
20use futures::future::{BoxFuture, FutureExt, pending};
21use itertools::Itertools;
22use mz_adapter::client::RecordFirstRowStream;
23use mz_adapter::session::{
24    EndTransactionAction, InProgressRows, LifecycleTimestamps, PortalRefMut, PortalState,
25    SessionConfig, TransactionStatus,
26};
27use mz_adapter::statement_logging::{StatementEndedExecutionReason, StatementExecutionStrategy};
28use mz_adapter::{
29    AdapterError, AdapterNotice, ExecuteContextExtra, ExecuteResponse, PeekResponseUnary, metrics,
30    verify_datum_desc,
31};
32use mz_auth::password::Password;
33use mz_authenticator::Authenticator;
34use mz_ore::cast::CastFrom;
35use mz_ore::netio::AsyncReady;
36use mz_ore::now::{EpochMillis, SYSTEM_TIME};
37use mz_ore::str::StrExt;
38use mz_ore::{assert_none, assert_ok, instrument, soft_assert_eq_or_log};
39use mz_pgcopy::{CopyCsvFormatParams, CopyFormatParams, CopyTextFormatParams};
40use mz_pgwire_common::{
41    ConnectionCounter, Cursor, ErrorResponse, Format, FrontendMessage, Severity, VERSION_3,
42    VERSIONS,
43};
44use mz_repr::user::InternalUserMetadata;
45use mz_repr::{
46    CatalogItemId, ColumnIndex, Datum, RelationDesc, RowArena, RowIterator, RowRef,
47    SqlRelationType, SqlScalarType,
48};
49use mz_server_core::TlsMode;
50use mz_server_core::listeners::AllowedRoles;
51use mz_sql::ast::display::AstDisplay;
52use mz_sql::ast::{CopyDirection, CopyStatement, FetchDirection, Ident, Raw, Statement};
53use mz_sql::parse::StatementParseResult;
54use mz_sql::plan::{CopyFormat, ExecuteTimeout, StatementDesc};
55use mz_sql::session::metadata::SessionMetadata;
56use mz_sql::session::user::INTERNAL_USER_NAMES;
57use mz_sql::session::vars::{MAX_COPY_FROM_SIZE, Var, VarInput};
58use postgres::error::SqlState;
59use tokio::io::{self, AsyncRead, AsyncWrite};
60use tokio::select;
61use tokio::time::{self};
62use tokio_metrics::TaskMetrics;
63use tokio_stream::wrappers::UnboundedReceiverStream;
64use tracing::{Instrument, debug, debug_span, warn};
65use uuid::Uuid;
66
67use crate::codec::{
68    FramedConn, decode_password, decode_sasl_initial_response, decode_sasl_response,
69};
70use crate::message::{
71    self, BackendMessage, SASLServerFinalMessage, SASLServerFinalMessageKinds,
72    SASLServerFirstMessage,
73};
74
75/// Reports whether the given stream begins with a pgwire handshake.
76///
77/// To avoid false negatives, there must be at least eight bytes in `buf`.
78pub fn match_handshake(buf: &[u8]) -> bool {
79    // The pgwire StartupMessage looks like this:
80    //
81    //     i32 - Length of entire message.
82    //     i32 - Protocol version number.
83    //     [String] - Arbitrary key-value parameters of any length.
84    //
85    // Since arbitrary parameters can be included in the StartupMessage, the
86    // first Int32 is worthless, since the message could have any length.
87    // Instead, we sniff the protocol version number.
88    if buf.len() < 8 {
89        return false;
90    }
91    let version = NetworkEndian::read_i32(&buf[4..8]);
92    VERSIONS.contains(&version)
93}
94
95/// Parameters for the [`run`] function.
96pub struct RunParams<'a, A, I>
97where
98    I: Iterator<Item = TaskMetrics> + Send,
99{
100    /// The TLS mode of the pgwire server.
101    pub tls_mode: Option<TlsMode>,
102    /// A client for the adapter.
103    pub adapter_client: mz_adapter::Client,
104    /// The connection to the client.
105    pub conn: &'a mut FramedConn<A>,
106    /// The universally unique identifier for the connection.
107    pub conn_uuid: Uuid,
108    /// The protocol version that the client provided in the startup message.
109    pub version: i32,
110    /// The parameters that the client provided in the startup message.
111    pub params: BTreeMap<String, String>,
112    /// Authentication method to use. Frontegg, Password, or None.
113    pub authenticator: Authenticator,
114    /// Global connection limit and count
115    pub active_connection_counter: ConnectionCounter,
116    /// Helm chart version
117    pub helm_chart_version: Option<String>,
118    /// Whether to allow reserved users (ie: mz_system).
119    pub allowed_roles: AllowedRoles,
120    /// Tokio metrics
121    pub tokio_metrics_intervals: I,
122}
123
124/// Runs a pgwire connection to completion.
125///
126/// This involves responding to `FrontendMessage::StartupMessage` and all future
127/// requests until the client terminates the connection or a fatal error occurs.
128///
129/// Note that this function returns successfully even upon delivering a fatal
130/// error to the client. It only returns `Err` if an unexpected I/O error occurs
131/// while communicating with the client, e.g., if the connection is severed in
132/// the middle of a request.
133#[mz_ore::instrument(level = "debug")]
134pub async fn run<'a, A, I>(
135    RunParams {
136        tls_mode,
137        adapter_client,
138        conn,
139        conn_uuid,
140        version,
141        mut params,
142        authenticator,
143        active_connection_counter,
144        helm_chart_version,
145        allowed_roles,
146        tokio_metrics_intervals,
147    }: RunParams<'a, A, I>,
148) -> Result<(), io::Error>
149where
150    A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
151    I: Iterator<Item = TaskMetrics> + Send,
152{
153    if version != VERSION_3 {
154        return conn
155            .send(ErrorResponse::fatal(
156                SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
157                "server does not support the client's requested protocol version",
158            ))
159            .await;
160    }
161
162    let user = params.remove("user").unwrap_or_else(String::new);
163
164    // TODO move this somewhere it can be shared with HTTP
165    let is_internal_user = INTERNAL_USER_NAMES.contains(&user);
166    // this is a superset of internal users
167    let is_reserved_user = mz_adapter::catalog::is_reserved_role_name(user.as_str());
168    let role_allowed = match allowed_roles {
169        AllowedRoles::Normal => !is_reserved_user,
170        AllowedRoles::Internal => is_internal_user,
171        AllowedRoles::NormalAndInternal => !is_reserved_user || is_internal_user,
172    };
173    if !role_allowed {
174        let msg = format!("unauthorized login to user '{user}'");
175        return conn
176            .send(ErrorResponse::fatal(SqlState::INSUFFICIENT_PRIVILEGE, msg))
177            .await;
178    }
179
180    if let Err(err) = conn.inner().ensure_tls_compatibility(&tls_mode) {
181        return conn.send(err).await;
182    }
183
184    let (mut session, expired) = match authenticator {
185        Authenticator::Frontegg(frontegg) => {
186            conn.send(BackendMessage::AuthenticationCleartextPassword)
187                .await?;
188            conn.flush().await?;
189            let password = match conn.recv().await? {
190                Some(FrontendMessage::RawAuthentication(data)) => {
191                    match decode_password(Cursor::new(&data)).ok() {
192                        Some(FrontendMessage::Password { password }) => password,
193                        _ => {
194                            return conn
195                                .send(ErrorResponse::fatal(
196                                    SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
197                                    "expected Password message",
198                                ))
199                                .await;
200                        }
201                    }
202                }
203                _ => {
204                    return conn
205                        .send(ErrorResponse::fatal(
206                            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
207                            "expected Password message",
208                        ))
209                        .await;
210                }
211            };
212
213            let auth_response = frontegg.authenticate(&user, &password).await;
214            match auth_response {
215                Ok(mut auth_session) => {
216                    // Create a session based on the auth session.
217                    //
218                    // In particular, it's important that the username come from the
219                    // auth session, as Frontegg may return an email address with
220                    // different casing than the user supplied via the pgwire
221                    // username field. We want to use the Frontegg casing as
222                    // canonical.
223                    let session = adapter_client.new_session(SessionConfig {
224                        conn_id: conn.conn_id().clone(),
225                        uuid: conn_uuid,
226                        user: auth_session.user().into(),
227                        client_ip: conn.peer_addr().clone(),
228                        external_metadata_rx: Some(auth_session.external_metadata_rx()),
229                        internal_user_metadata: None,
230                        helm_chart_version,
231                    });
232                    let expired = async move { auth_session.expired().await };
233                    (session, expired.left_future())
234                }
235                Err(err) => {
236                    warn!(?err, "pgwire connection failed authentication");
237                    return conn
238                        .send(ErrorResponse::fatal(
239                            SqlState::INVALID_PASSWORD,
240                            "invalid password",
241                        ))
242                        .await;
243                }
244            }
245        }
246        Authenticator::Password(adapter_client) => {
247            conn.send(BackendMessage::AuthenticationCleartextPassword)
248                .await?;
249            conn.flush().await?;
250            let password = match conn.recv().await? {
251                Some(FrontendMessage::RawAuthentication(data)) => {
252                    match decode_password(Cursor::new(&data)).ok() {
253                        Some(FrontendMessage::Password { password }) => Password(password),
254                        _ => {
255                            return conn
256                                .send(ErrorResponse::fatal(
257                                    SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
258                                    "expected Password message",
259                                ))
260                                .await;
261                        }
262                    }
263                }
264                _ => {
265                    return conn
266                        .send(ErrorResponse::fatal(
267                            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
268                            "expected Password message",
269                        ))
270                        .await;
271                }
272            };
273            let auth_response = match adapter_client.authenticate(&user, &password).await {
274                Ok(resp) => resp,
275                Err(err) => {
276                    warn!(?err, "pgwire connection failed authentication");
277                    return conn
278                        .send(ErrorResponse::fatal(
279                            SqlState::INVALID_PASSWORD,
280                            "invalid password",
281                        ))
282                        .await;
283                }
284            };
285            let session = adapter_client.new_session(SessionConfig {
286                conn_id: conn.conn_id().clone(),
287                uuid: conn_uuid,
288                user,
289                client_ip: conn.peer_addr().clone(),
290                external_metadata_rx: None,
291                internal_user_metadata: Some(InternalUserMetadata {
292                    superuser: auth_response.superuser,
293                }),
294                helm_chart_version,
295            });
296            // No frontegg check, so auth session lasts indefinitely.
297            let auth_session = pending().right_future();
298            (session, auth_session)
299        }
300        Authenticator::Sasl(adapter_client) => {
301            // Start the handshake
302            conn.send(BackendMessage::AuthenticationSASL).await?;
303            conn.flush().await?;
304            // Get the initial response indicating chosen mechanism
305            let (mechanism, initial_response) = match conn.recv().await? {
306                Some(FrontendMessage::RawAuthentication(data)) => {
307                    match decode_sasl_initial_response(Cursor::new(&data)).ok() {
308                        Some(FrontendMessage::SASLInitialResponse {
309                            gs2_header,
310                            mechanism,
311                            initial_response,
312                        }) => {
313                            // We do not support channel binding
314                            if gs2_header.channel_binding_enabled() {
315                                return conn
316                                    .send(ErrorResponse::fatal(
317                                        SqlState::PROTOCOL_VIOLATION,
318                                        "channel binding not supported",
319                                    ))
320                                    .await;
321                            }
322                            (mechanism, initial_response)
323                        }
324                        _ => {
325                            return conn
326                                .send(ErrorResponse::fatal(
327                                    SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
328                                    "expected SASLInitialResponse message",
329                                ))
330                                .await;
331                        }
332                    }
333                }
334                _ => {
335                    return conn
336                        .send(ErrorResponse::fatal(
337                            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
338                            "expected SASLInitialResponse message",
339                        ))
340                        .await;
341                }
342            };
343
344            if mechanism != "SCRAM-SHA-256" {
345                return conn
346                    .send(ErrorResponse::fatal(
347                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
348                        "unsupported SASL mechanism",
349                    ))
350                    .await;
351            }
352
353            if initial_response.nonce.len() > 256 {
354                return conn
355                    .send(ErrorResponse::fatal(
356                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
357                        "nonce too long",
358                    ))
359                    .await;
360            }
361
362            let (server_first_message_raw, mock_hash) = match adapter_client
363                .generate_sasl_challenge(&user, &initial_response.nonce)
364                .await
365            {
366                Ok(response) => {
367                    let server_first_message_raw = format!(
368                        "r={},s={},i={}",
369                        response.nonce, response.salt, response.iteration_count
370                    );
371
372                    let client_key = [0u8; 32];
373                    let server_key = [1u8; 32];
374                    let mock_hash = format!(
375                        "SCRAM-SHA-256${}:{}${}:{}",
376                        response.iteration_count,
377                        response.salt,
378                        BASE64_STANDARD.encode(client_key),
379                        BASE64_STANDARD.encode(server_key)
380                    );
381
382                    conn.send(BackendMessage::AuthenticationSASLContinue(
383                        SASLServerFirstMessage {
384                            iteration_count: response.iteration_count,
385                            nonce: response.nonce,
386                            salt: response.salt,
387                        },
388                    ))
389                    .await?;
390                    conn.flush().await?;
391                    (server_first_message_raw, mock_hash)
392                }
393                Err(e) => {
394                    return conn.send(e.into_response(Severity::Fatal)).await;
395                }
396            };
397
398            let auth_resp = match conn.recv().await? {
399                Some(FrontendMessage::RawAuthentication(data)) => {
400                    match decode_sasl_response(Cursor::new(&data)).ok() {
401                        Some(FrontendMessage::SASLResponse(response)) => {
402                            let auth_message = format!(
403                                "{},{},{}",
404                                initial_response.client_first_message_bare_raw,
405                                server_first_message_raw,
406                                response.client_final_message_bare_raw
407                            );
408                            if response.proof.len() > 1024 {
409                                return conn
410                                    .send(ErrorResponse::fatal(
411                                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
412                                        "proof too long",
413                                    ))
414                                    .await;
415                            }
416                            match adapter_client
417                                .verify_sasl_proof(
418                                    &user,
419                                    &response.proof,
420                                    &auth_message,
421                                    &mock_hash,
422                                )
423                                .await
424                            {
425                                Ok(resp) => {
426                                    conn.send(BackendMessage::AuthenticationSASLFinal(
427                                        SASLServerFinalMessage {
428                                            kind: SASLServerFinalMessageKinds::Verifier(
429                                                resp.verifier,
430                                            ),
431                                            extensions: vec![],
432                                        },
433                                    ))
434                                    .await?;
435                                    conn.flush().await?;
436                                    resp.auth_resp
437                                }
438                                Err(_) => {
439                                    return conn
440                                        .send(ErrorResponse::fatal(
441                                            SqlState::INVALID_PASSWORD,
442                                            "invalid password",
443                                        ))
444                                        .await;
445                                }
446                            }
447                        }
448                        _ => {
449                            return conn
450                                .send(ErrorResponse::fatal(
451                                    SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
452                                    "expected SASLResponse message",
453                                ))
454                                .await;
455                        }
456                    }
457                }
458                _ => {
459                    return conn
460                        .send(ErrorResponse::fatal(
461                            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
462                            "expected SASLResponse message",
463                        ))
464                        .await;
465                }
466            };
467
468            let session = adapter_client.new_session(SessionConfig {
469                conn_id: conn.conn_id().clone(),
470                uuid: conn_uuid,
471                user,
472                client_ip: conn.peer_addr().clone(),
473                external_metadata_rx: None,
474                internal_user_metadata: Some(InternalUserMetadata {
475                    superuser: auth_resp.superuser,
476                }),
477                helm_chart_version,
478            });
479            // No frontegg check, so auth session lasts indefinitely.
480            let auth_session = pending().right_future();
481            (session, auth_session)
482        }
483        Authenticator::None => {
484            let session = adapter_client.new_session(SessionConfig {
485                conn_id: conn.conn_id().clone(),
486                uuid: conn_uuid,
487                user,
488                client_ip: conn.peer_addr().clone(),
489                external_metadata_rx: None,
490                internal_user_metadata: None,
491                helm_chart_version,
492            });
493            // No frontegg check, so auth session lasts indefinitely.
494            let auth_session = pending().right_future();
495            (session, auth_session)
496        }
497    };
498
499    let system_vars = adapter_client.get_system_vars().await;
500    for (name, value) in params {
501        let settings = match name.as_str() {
502            "options" => match parse_options(&value) {
503                Ok(opts) => opts,
504                Err(()) => {
505                    session.add_notice(AdapterNotice::BadStartupSetting {
506                        name,
507                        reason: "could not parse".into(),
508                    });
509                    continue;
510                }
511            },
512            _ => vec![(name, value)],
513        };
514        for (key, val) in settings {
515            const LOCAL: bool = false;
516            // TODO: Issuing an error here is better than what we did before
517            // (silently ignore errors on set), but erroring the connection
518            // might be the better behavior. We maybe need to support more
519            // options sent by psql and drivers before we can safely do this.
520            if let Err(err) =
521                session
522                    .vars_mut()
523                    .set(&system_vars, &key, VarInput::Flat(&val), LOCAL)
524            {
525                session.add_notice(AdapterNotice::BadStartupSetting {
526                    name: key,
527                    reason: err.to_string(),
528                });
529            }
530        }
531    }
532    session
533        .vars_mut()
534        .end_transaction(EndTransactionAction::Commit);
535
536    let _guard = match active_connection_counter.allocate_connection(session.user()) {
537        Ok(drop_connection) => drop_connection,
538        Err(e) => {
539            let e: AdapterError = e.into();
540            return conn.send(e.into_response(Severity::Fatal)).await;
541        }
542    };
543
544    // Register session with adapter.
545    let mut adapter_client = match adapter_client.startup(session).await {
546        Ok(adapter_client) => adapter_client,
547        Err(e) => return conn.send(e.into_response(Severity::Fatal)).await,
548    };
549
550    let mut buf = vec![BackendMessage::AuthenticationOk];
551    for var in adapter_client.session().vars().notify_set() {
552        buf.push(BackendMessage::ParameterStatus(var.name(), var.value()));
553    }
554    buf.push(BackendMessage::BackendKeyData {
555        conn_id: adapter_client.session().conn_id().unhandled(),
556        secret_key: adapter_client.session().secret_key(),
557    });
558    buf.extend(
559        adapter_client
560            .session()
561            .drain_notices()
562            .into_iter()
563            .map(|notice| BackendMessage::ErrorResponse(notice.into_response())),
564    );
565    buf.push(BackendMessage::ReadyForQuery(
566        adapter_client.session().transaction().into(),
567    ));
568    conn.send_all(buf).await?;
569    conn.flush().await?;
570
571    let machine = StateMachine {
572        conn,
573        adapter_client,
574        txn_needs_commit: false,
575        tokio_metrics_intervals,
576    };
577
578    select! {
579        r = machine.run() => {
580            // Errors produced internally (like MAX_REQUEST_SIZE being exceeded) should send an
581            // error to the client informing them why the connection was closed. We still want to
582            // return the original error up the stack, though, so we skip error checking during conn
583            // operations.
584            if let Err(err) = &r {
585                let _ = conn
586                    .send(ErrorResponse::fatal(
587                        SqlState::CONNECTION_FAILURE,
588                        err.to_string(),
589                    ))
590                    .await;
591                let _ = conn.flush().await;
592            }
593            r
594        },
595        _ = expired => {
596            conn
597                .send(ErrorResponse::fatal(SqlState::INVALID_AUTHORIZATION_SPECIFICATION, "authentication expired"))
598                .await?;
599            conn.flush().await
600        }
601    }
602}
603
604/// Returns (name, value) session settings pairs from an options value.
605///
606/// From Postgres, see pg_split_opts in postinit.c and process_postgres_switches
607/// in postgres.c.
608fn parse_options(value: &str) -> Result<Vec<(String, String)>, ()> {
609    let opts = split_options(value);
610    let mut pairs = Vec::with_capacity(opts.len());
611    let mut seen_prefix = false;
612    for opt in opts {
613        if !seen_prefix {
614            if opt == "-c" {
615                seen_prefix = true;
616            } else {
617                let (key, val) = parse_option(&opt)?;
618                pairs.push((key.to_owned(), val.to_owned()));
619            }
620        } else {
621            let (key, val) = opt.split_once('=').ok_or(())?;
622            pairs.push((key.to_owned(), val.to_owned()));
623            seen_prefix = false;
624        }
625    }
626    Ok(pairs)
627}
628
/// Returns the parsed key and value from a single option token of the form
/// `--key=value` or `-ckey=value`.
///
/// The token is split at its first `=`. The part before it must start with
/// `-c` or `--` (checked in that order), and that prefix is stripped to give
/// the key. The key is otherwise returned verbatim: no `-` to `_` rewriting
/// is done here. Everything after the first `=` is the value, including any
/// further `=` characters.
///
/// Returns `Err(())` if there is no `=` or if neither prefix is present.
fn parse_option(option: &str) -> Result<(&str, &str), ()> {
    let (flagged_key, value) = option.split_once('=').ok_or(())?;
    flagged_key
        .strip_prefix("-c")
        .or_else(|| flagged_key.strip_prefix("--"))
        .map(|key| (key, value))
        .ok_or(())
}
641
/// Splits value by any number of spaces except those preceded by `\`.
///
/// A backslash escapes the character that follows it: the backslash itself
/// is dropped and the next character (a space, another backslash, or anything
/// else) is appended literally to the current token. Runs of unescaped spaces
/// never produce empty tokens.
fn split_options(value: &str) -> Vec<String> {
    let mut tokens = Vec::new();
    // Escapes are resolved while scanning, so tokens cannot simply be slices
    // of `value`. Each one is accumulated here.
    let mut token = String::new();
    let mut escaped = false;
    for ch in value.chars() {
        match (ch, escaped) {
            ('\\', false) => {
                // Start of an escape sequence; the next char is taken literally.
                escaped = true;
                continue;
            }
            (' ', false) => {
                // Separator. Skipping empty tokens collapses runs of spaces.
                if !token.is_empty() {
                    tokens.push(std::mem::take(&mut token));
                }
            }
            // Either an ordinary character, or any character right after a
            // backslash (including ` ` and `\` themselves).
            (literal, _) => token.push(literal),
        }
        escaped = false;
    }
    // A `\` at the end will be ignored.
    if !token.is_empty() {
        tokens.push(token);
    }
    tokens
}
684
/// The state of the connection's main loop, see `StateMachine::run`.
#[derive(Debug)]
enum State {
    /// The next step is `advance_ready`, which waits for a message from the
    /// client and handles it.
    Ready,
    /// The next step is `advance_drain`.
    Drain,
    /// The connection is finished; the main loop returns `Ok(())`.
    Done,
}
691
692struct StateMachine<'a, A, I>
693where
694    I: Iterator<Item = TaskMetrics> + Send + 'a,
695{
696    conn: &'a mut FramedConn<A>,
697    adapter_client: mz_adapter::SessionClient,
698    txn_needs_commit: bool,
699    tokio_metrics_intervals: I,
700}
701
/// Why sending rows to the client ended.
///
/// Derives `Debug`, like [`State`], so the reason can be included in logs and
/// assertion messages.
#[derive(Debug)]
enum SendRowsEndedReason {
    /// All rows were sent.
    Success {
        // NOTE(review): units are not visible here (presumably bytes) -- confirm.
        result_size: u64,
        /// Number of rows returned.
        rows_returned: u64,
    },
    /// Sending stopped because of an error.
    Errored {
        /// The error, rendered as text.
        error: String,
    },
    /// Sending was canceled.
    Canceled,
}
712
/// Message text for commands issued while the current transaction is in the
/// aborted state.
const ABORTED_TXN_MSG: &str =
    "current transaction is aborted, commands ignored until end of transaction block";
715
716impl<'a, A, I> StateMachine<'a, A, I>
717where
718    A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin + 'a,
719    I: Iterator<Item = TaskMetrics> + Send + 'a,
720{
721    // Manually desugar this (don't use `async fn run`) here because a much better
722    // error message is produced if there are problems with Send or other traits
723    // somewhere within the Future.
724    #[allow(clippy::manual_async_fn)]
725    #[mz_ore::instrument(level = "debug")]
726    fn run(mut self) -> impl Future<Output = Result<(), io::Error>> + Send + 'a {
727        async move {
728            let mut state = State::Ready;
729            loop {
730                self.send_pending_notices().await?;
731                state = match state {
732                    State::Ready => self.advance_ready().await?,
733                    State::Drain => self.advance_drain().await?,
734                    State::Done => return Ok(()),
735                };
736                self.adapter_client
737                    .add_idle_in_transaction_session_timeout();
738            }
739        }
740    }
741
742    #[instrument(level = "debug")]
743    async fn advance_ready(&mut self) -> Result<State, io::Error> {
744        // Start a new metrics interval before the `recv()` call.
745        self.tokio_metrics_intervals
746            .next()
747            .expect("infinite iterator");
748
749        // Handle timeouts first so we don't execute any statements when there's a pending timeout.
750        let message = select! {
751            biased;
752
753            // `recv_timeout()` is cancel-safe as per its docs.
754            Some(timeout) = self.adapter_client.recv_timeout() => {
755                let err: AdapterError = timeout.into();
756                let conn_id = self.adapter_client.session().conn_id();
757                tracing::warn!("session timed out, conn_id {}", conn_id);
758
759                // Process the error, doing any state cleanup.
760                let error_response = err.into_response(Severity::Fatal);
761                let error_state = self.error(error_response).await;
762
763                // Terminate __after__ we do any cleanup.
764                self.adapter_client.terminate().await;
765
766                // We must wait for the client to send a request before we can send the error response.
767                // Due to the PG wire protocol, we can't send an ErrorResponse unless it is in response
768                // to a client message.
769                let _ = self.conn.recv().await?;
770                return error_state;
771            },
772            // `recv()` is cancel-safe as per it's docs.
773            message = self.conn.recv() => message?,
774        };
775
776        // Take the metrics since just before the `recv`.
777        let interval = self
778            .tokio_metrics_intervals
779            .next()
780            .expect("infinite iterator");
781        let recv_scheduling_delay_ms = interval.total_scheduled_duration.as_secs_f64() * 1000.0;
782
783        // TODO(ggevay): Consider subtracting the scheduling delay from `received`. It's not obvious
784        // whether we should do this, because the result wouldn't exactly correspond to either first
785        // byte received or last byte received (for msgs that arrive in more than one network packet).
786        let received = SYSTEM_TIME();
787
788        self.adapter_client
789            .remove_idle_in_transaction_session_timeout();
790
791        // NOTE(guswynn): we could consider adding spans to all message types. Currently
792        // only a few message types seem useful.
793        let message_name = message.as_ref().map(|m| m.name()).unwrap_or_default();
794
795        let start = message.as_ref().map(|_| Instant::now());
796        let next_state = match message {
797            Some(FrontendMessage::Query { sql }) => {
798                let query_root_span =
799                    tracing::info_span!(parent: None, "advance_ready", otel.name = message_name);
800                query_root_span.follows_from(tracing::Span::current());
801                self.query(sql, received)
802                    .instrument(query_root_span)
803                    .await?
804            }
805            Some(FrontendMessage::Parse {
806                name,
807                sql,
808                param_types,
809            }) => self.parse(name, sql, param_types).await?,
810            Some(FrontendMessage::Bind {
811                portal_name,
812                statement_name,
813                param_formats,
814                raw_params,
815                result_formats,
816            }) => {
817                self.bind(
818                    portal_name,
819                    statement_name,
820                    param_formats,
821                    raw_params,
822                    result_formats,
823                )
824                .await?
825            }
826            Some(FrontendMessage::Execute {
827                portal_name,
828                max_rows,
829            }) => {
830                let max_rows = match usize::try_from(max_rows) {
831                    Ok(0) | Err(_) => ExecuteCount::All, // If `max_rows < 0`, no limit.
832                    Ok(n) => ExecuteCount::Count(n),
833                };
834                let execute_root_span =
835                    tracing::info_span!(parent: None, "advance_ready", otel.name = message_name);
836                execute_root_span.follows_from(tracing::Span::current());
837                let state = self
838                    .execute(
839                        portal_name,
840                        max_rows,
841                        portal_exec_message,
842                        None,
843                        ExecuteTimeout::None,
844                        None,
845                        Some(received),
846                    )
847                    .instrument(execute_root_span)
848                    .await?;
849                // In PostgreSQL, when using the extended query protocol, some statements may
850                // trigger an eager commit of the current implicit transaction,
851                // see: <https://git.postgresql.org/gitweb/?p=postgresql.git&a=commitdiff&h=f92944137>.
852                //
853                // In Materialize, however, we eagerly commit every statement outside of an explicit
854                // transaction when using the extended query protocol. This allows us to eliminate
855                // the possibility of a multiple statement implicit transaction, which in turn
856                // allows us to apply single-statement optimizations to queries issued in implicit
857                // transactions in the extended query protocol.
858                //
859                // We don't immediately commit here to allow users to page through the portal if
860                // necessary. Committing the transaction would destroy the portal before the next
861                // Execute command has a chance to resume it. So we instead mark the transaction
862                // for commit the next time that `ensure_transaction` is called.
863                if self.adapter_client.session().transaction().is_implicit() {
864                    self.txn_needs_commit = true;
865                }
866                state
867            }
868            Some(FrontendMessage::DescribeStatement { name }) => {
869                self.describe_statement(&name).await?
870            }
871            Some(FrontendMessage::DescribePortal { name }) => self.describe_portal(&name).await?,
872            Some(FrontendMessage::CloseStatement { name }) => self.close_statement(name).await?,
873            Some(FrontendMessage::ClosePortal { name }) => self.close_portal(name).await?,
874            Some(FrontendMessage::Flush) => self.flush().await?,
875            Some(FrontendMessage::Sync) => self.sync().await?,
876            Some(FrontendMessage::Terminate) => State::Done,
877
878            Some(FrontendMessage::CopyData(_))
879            | Some(FrontendMessage::CopyDone)
880            | Some(FrontendMessage::CopyFail(_))
881            | Some(FrontendMessage::Password { .. })
882            | Some(FrontendMessage::RawAuthentication(_))
883            | Some(FrontendMessage::SASLInitialResponse { .. })
884            | Some(FrontendMessage::SASLResponse(_)) => State::Drain,
885            None => State::Done,
886        };
887
888        if let Some(start) = start {
889            self.adapter_client
890                .inner()
891                .metrics()
892                .pgwire_message_processing_seconds
893                .with_label_values(&[message_name])
894                .observe(start.elapsed().as_secs_f64());
895        }
896        self.adapter_client
897            .inner()
898            .metrics()
899            .pgwire_recv_scheduling_delay_ms
900            .with_label_values(&[message_name])
901            .observe(recv_scheduling_delay_ms);
902
903        Ok(next_state)
904    }
905
906    async fn advance_drain(&mut self) -> Result<State, io::Error> {
907        let message = self.conn.recv().await?;
908        if message.is_some() {
909            self.adapter_client
910                .remove_idle_in_transaction_session_timeout();
911        }
912        match message {
913            Some(FrontendMessage::Sync) => self.sync().await,
914            None => Ok(State::Done),
915            _ => Ok(State::Drain),
916        }
917    }
918
919    /// Note that `lifecycle_timestamps` belongs to the whole "Simple Query", because the whole
920    /// Simple Query is received and parsed together. This means that if there are multiple
921    /// statements in a Simple Query, then all of them have the same `lifecycle_timestamps`.
922    #[instrument(level = "debug")]
923    async fn one_query(
924        &mut self,
925        stmt: Statement<Raw>,
926        sql: String,
927        lifecycle_timestamps: LifecycleTimestamps,
928    ) -> Result<State, io::Error> {
929        // Bind the portal. Note that this does not set the empty string prepared
930        // statement.
931        const EMPTY_PORTAL: &str = "";
932        if let Err(e) = self
933            .adapter_client
934            .declare(EMPTY_PORTAL.to_string(), stmt, sql)
935            .await
936        {
937            return self.error(e.into_response(Severity::Error)).await;
938        }
939        let portal = self
940            .adapter_client
941            .session()
942            .get_portal_unverified_mut(EMPTY_PORTAL)
943            .expect("unnamed portal should be present");
944
945        *portal.lifecycle_timestamps = Some(lifecycle_timestamps);
946
947        let stmt_desc = portal.desc.clone();
948        if !stmt_desc.param_types.is_empty() {
949            return self
950                .error(ErrorResponse::error(
951                    SqlState::UNDEFINED_PARAMETER,
952                    "there is no parameter $1",
953                ))
954                .await;
955        }
956
957        // Maybe send row description.
958        if let Some(relation_desc) = &stmt_desc.relation_desc {
959            if !stmt_desc.is_copy {
960                let formats = vec![Format::Text; stmt_desc.arity()];
961                self.send(BackendMessage::RowDescription(
962                    message::encode_row_description(relation_desc, &formats),
963                ))
964                .await?;
965            }
966        }
967
968        let result = match self
969            .adapter_client
970            .execute(EMPTY_PORTAL.to_string(), self.conn.wait_closed(), None)
971            .await
972        {
973            Ok((response, execute_started)) => {
974                self.send_pending_notices().await?;
975                self.send_execute_response(
976                    response,
977                    stmt_desc.relation_desc,
978                    EMPTY_PORTAL.to_string(),
979                    ExecuteCount::All,
980                    portal_exec_message,
981                    None,
982                    ExecuteTimeout::None,
983                    execute_started,
984                )
985                .await
986            }
987            Err(e) => {
988                self.send_pending_notices().await?;
989                self.error(e.into_response(Severity::Error)).await
990            }
991        };
992
993        // Destroy the portal.
994        self.adapter_client.session().remove_portal(EMPTY_PORTAL);
995
996        result
997    }
998
999    async fn ensure_transaction(
1000        &mut self,
1001        num_stmts: usize,
1002        message_type: &str,
1003    ) -> Result<(), io::Error> {
1004        let start = Instant::now();
1005        if self.txn_needs_commit {
1006            self.commit_transaction().await?;
1007        }
1008        // start_transaction can't error (but assert that just in case it changes in
1009        // the future.
1010        let res = self.adapter_client.start_transaction(Some(num_stmts));
1011        assert_ok!(res);
1012        self.adapter_client
1013            .inner()
1014            .metrics()
1015            .pgwire_ensure_transaction_seconds
1016            .with_label_values(&[message_type])
1017            .observe(start.elapsed().as_secs_f64());
1018        Ok(())
1019    }
1020
1021    fn parse_sql<'b>(&self, sql: &'b str) -> Result<Vec<StatementParseResult<'b>>, ErrorResponse> {
1022        let parse_start = Instant::now();
1023        let result = match self.adapter_client.parse(sql) {
1024            Ok(result) => result.map_err(|e| {
1025                // Convert our 0-based byte position to pgwire's 1-based character
1026                // position.
1027                let pos = sql[..e.error.pos].chars().count() + 1;
1028                ErrorResponse::error(SqlState::SYNTAX_ERROR, e.error.message).with_position(pos)
1029            }),
1030            Err(msg) => Err(ErrorResponse::error(SqlState::PROGRAM_LIMIT_EXCEEDED, msg)),
1031        };
1032        self.adapter_client
1033            .inner()
1034            .metrics()
1035            .parse_seconds
1036            .observe(parse_start.elapsed().as_secs_f64());
1037        result
1038    }
1039
1040    /// Executes a "Simple Query", see
1041    /// <https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-SIMPLE-QUERY>
1042    ///
1043    /// For implicit transaction handling, see "Multiple Statements in a Simple Query" in the above.
1044    #[instrument(level = "debug")]
1045    async fn query(&mut self, sql: String, received: EpochMillis) -> Result<State, io::Error> {
1046        // Parse first before doing any transaction checking.
1047        let stmts = match self.parse_sql(&sql) {
1048            Ok(stmts) => stmts,
1049            Err(err) => {
1050                self.error(err).await?;
1051                return self.ready().await;
1052            }
1053        };
1054
1055        let num_stmts = stmts.len();
1056
1057        // Compare with postgres' backend/tcop/postgres.c exec_simple_query.
1058        for StatementParseResult { ast: stmt, sql } in stmts {
1059            // In an aborted transaction, reject all commands except COMMIT/ROLLBACK.
1060            if self.is_aborted_txn() && !is_txn_exit_stmt(Some(&stmt)) {
1061                self.aborted_txn_error().await?;
1062                break;
1063            }
1064
1065            // Start an implicit transaction if we aren't in any transaction and there's
1066            // more than one statement. This mirrors the `use_implicit_block` variable in
1067            // postgres.
1068            //
1069            // This needs to be done in the loop instead of once at the top because
1070            // a COMMIT/ROLLBACK statement needs to start a new transaction on next
1071            // statement.
1072            self.ensure_transaction(num_stmts, "query").await?;
1073
1074            match self
1075                .one_query(stmt, sql.to_string(), LifecycleTimestamps { received })
1076                .await?
1077            {
1078                State::Ready => (),
1079                State::Drain => break,
1080                State::Done => return Ok(State::Done),
1081            }
1082        }
1083
1084        // Implicit transactions are closed at the end of a Query message.
1085        {
1086            if self.adapter_client.session().transaction().is_implicit() {
1087                self.commit_transaction().await?;
1088            }
1089        }
1090
1091        if num_stmts == 0 {
1092            self.send(BackendMessage::EmptyQueryResponse).await?;
1093        }
1094
1095        self.ready().await
1096    }
1097
1098    #[instrument(level = "debug")]
1099    async fn parse(
1100        &mut self,
1101        name: String,
1102        sql: String,
1103        param_oids: Vec<u32>,
1104    ) -> Result<State, io::Error> {
1105        // Start a transaction if we aren't in one.
1106        self.ensure_transaction(1, "parse").await?;
1107
1108        let mut param_types = vec![];
1109        for oid in param_oids {
1110            match mz_pgrepr::Type::from_oid(oid) {
1111                Ok(ty) => match SqlScalarType::try_from(&ty) {
1112                    Ok(ty) => param_types.push(Some(ty)),
1113                    Err(err) => {
1114                        return self
1115                            .error(ErrorResponse::error(
1116                                SqlState::INVALID_PARAMETER_VALUE,
1117                                err.to_string(),
1118                            ))
1119                            .await;
1120                    }
1121                },
1122                Err(_) if oid == 0 => param_types.push(None),
1123                Err(e) => {
1124                    return self
1125                        .error(ErrorResponse::error(
1126                            SqlState::PROTOCOL_VIOLATION,
1127                            e.to_string(),
1128                        ))
1129                        .await;
1130                }
1131            }
1132        }
1133
1134        let stmts = match self.parse_sql(&sql) {
1135            Ok(stmts) => stmts,
1136            Err(err) => {
1137                return self.error(err).await;
1138            }
1139        };
1140        if stmts.len() > 1 {
1141            return self
1142                .error(ErrorResponse::error(
1143                    SqlState::INTERNAL_ERROR,
1144                    "cannot insert multiple commands into a prepared statement",
1145                ))
1146                .await;
1147        }
1148        let (maybe_stmt, sql) = match stmts.into_iter().next() {
1149            None => (None, ""),
1150            Some(StatementParseResult { ast, sql }) => (Some(ast), sql),
1151        };
1152        if self.is_aborted_txn() && !is_txn_exit_stmt(maybe_stmt.as_ref()) {
1153            return self.aborted_txn_error().await;
1154        }
1155        match self
1156            .adapter_client
1157            .prepare(name, maybe_stmt, sql.to_string(), param_types)
1158            .await
1159        {
1160            Ok(()) => {
1161                self.send(BackendMessage::ParseComplete).await?;
1162                Ok(State::Ready)
1163            }
1164            Err(e) => self.error(e.into_response(Severity::Error)).await,
1165        }
1166    }
1167
1168    /// Commits and clears the current transaction.
1169    #[instrument(level = "debug")]
1170    async fn commit_transaction(&mut self) -> Result<(), io::Error> {
1171        self.end_transaction(EndTransactionAction::Commit).await
1172    }
1173
1174    /// Rollback and clears the current transaction.
1175    #[instrument(level = "debug")]
1176    async fn rollback_transaction(&mut self) -> Result<(), io::Error> {
1177        self.end_transaction(EndTransactionAction::Rollback).await
1178    }
1179
1180    /// End a transaction and report to the user if an error occurred.
1181    #[instrument(level = "debug")]
1182    async fn end_transaction(&mut self, action: EndTransactionAction) -> Result<(), io::Error> {
1183        self.txn_needs_commit = false;
1184        let resp = self.adapter_client.end_transaction(action).await;
1185        if let Err(err) = resp {
1186            self.send(BackendMessage::ErrorResponse(
1187                err.into_response(Severity::Error),
1188            ))
1189            .await?;
1190        }
1191        Ok(())
1192    }
1193
    /// Handles a `Bind` message: creates the portal `portal_name` from the
    /// prepared statement `statement_name`.
    ///
    /// `raw_params` are decoded according to `param_formats` and the types the
    /// prepared statement declares; `result_formats` are checked against the
    /// statement's result columns. Sends `BindComplete` on success. Every
    /// validation failure is reported to the client through `self.error`, whose
    /// resulting state is returned.
    #[instrument(level = "debug")]
    async fn bind(
        &mut self,
        portal_name: String,
        statement_name: String,
        param_formats: Vec<Format>,
        raw_params: Vec<Option<Vec<u8>>>,
        result_formats: Vec<Format>,
    ) -> Result<State, io::Error> {
        // Start a transaction if we aren't in one.
        self.ensure_transaction(1, "bind").await?;

        // NOTE(review): presumably read up front because `stmt` below keeps a
        // borrow of `self.adapter_client` alive — confirm.
        let aborted_txn = self.is_aborted_txn();
        let stmt = match self
            .adapter_client
            .get_prepared_statement(&statement_name)
            .await
        {
            Ok(stmt) => stmt,
            Err(err) => return self.error(err.into_response(Severity::Error)).await,
        };

        // The client must supply exactly as many parameters as the statement declares.
        let param_types = &stmt.desc().param_types;
        if param_types.len() != raw_params.len() {
            let message = format!(
                "bind message supplies {actual} parameters, \
                 but prepared statement \"{name}\" requires {expected}",
                name = statement_name,
                actual = raw_params.len(),
                expected = param_types.len()
            );
            return self
                .error(ErrorResponse::error(SqlState::PROTOCOL_VIOLATION, message))
                .await;
        }
        // Reconcile the supplied parameter formats with the number of parameters;
        // a failure is reported as a protocol violation.
        let param_formats = match pad_formats(param_formats, raw_params.len()) {
            Ok(param_formats) => param_formats,
            Err(msg) => {
                return self
                    .error(ErrorResponse::error(SqlState::PROTOCOL_VIOLATION, msg))
                    .await;
            }
        };
        // In an aborted transaction, reject everything except a transaction-exit statement.
        if aborted_txn && !is_txn_exit_stmt(stmt.stmt()) {
            return self.aborted_txn_error().await;
        }
        // Decode each raw parameter into a datum paired with its declared type.
        let buf = RowArena::new();
        let mut params = vec![];
        // NOTE(review): `zip_eq` panics on a length mismatch; `raw_params` and
        // `param_types` were checked above, and `param_formats` is presumably
        // padded to the same length by `pad_formats` — confirm.
        for ((raw_param, mz_typ), format) in raw_params
            .into_iter()
            .zip_eq(param_types)
            .zip_eq(param_formats)
        {
            let pg_typ = mz_pgrepr::Type::from(mz_typ);
            let datum = match raw_param {
                // A missing value is bound as NULL.
                None => Datum::Null,
                Some(bytes) => match mz_pgrepr::Value::decode(format, &pg_typ, &bytes) {
                    Ok(param) => param.into_datum(&buf, &pg_typ),
                    Err(err) => {
                        let msg = format!("unable to decode parameter: {}", err);
                        return self
                            .error(ErrorResponse::error(SqlState::INVALID_PARAMETER_VALUE, msg))
                            .await;
                    }
                },
            };
            params.push((datum, mz_typ.clone()))
        }

        // Reconcile the requested result formats with the number of result
        // columns (0 when the statement has no relation description).
        let result_formats = match pad_formats(
            result_formats,
            stmt.desc()
                .relation_desc
                .clone()
                .map(|desc| desc.typ().column_types.len())
                .unwrap_or(0),
        ) {
            Ok(result_formats) => result_formats,
            Err(msg) => {
                return self
                    .error(ErrorResponse::error(SqlState::PROTOCOL_VIOLATION, msg))
                    .await;
            }
        };

        // Binary encodings are disabled for list, map, and aclitem types, but this doesn't
        // apply to COPY TO statements.
        if !stmt.stmt().map_or(false, |stmt| {
            matches!(
                stmt,
                Statement::Copy(CopyStatement {
                    direction: CopyDirection::To,
                    ..
                })
            )
        }) {
            if let Some(desc) = stmt.desc().relation_desc.clone() {
                for (format, ty) in result_formats.iter().zip_eq(desc.iter_types()) {
                    match (format, &ty.scalar_type) {
                        (Format::Binary, mz_repr::SqlScalarType::List { .. }) => {
                            return self
                                .error(ErrorResponse::error(
                                    SqlState::PROTOCOL_VIOLATION,
                                    "binary encoding of list types is not implemented",
                                ))
                                .await;
                        }
                        (Format::Binary, mz_repr::SqlScalarType::Map { .. }) => {
                            return self
                                .error(ErrorResponse::error(
                                    SqlState::PROTOCOL_VIOLATION,
                                    "binary encoding of map types is not implemented",
                                ))
                                .await;
                        }
                        (Format::Binary, mz_repr::SqlScalarType::AclItem) => {
                            return self
                                .error(ErrorResponse::error(
                                    SqlState::PROTOCOL_VIOLATION,
                                    "binary encoding of aclitem types does not exist",
                                ))
                                .await;
                        }
                        _ => (),
                    }
                }
            }
        }

        // Copy what the portal needs out of the prepared statement, then
        // register the portal on the session.
        let desc = stmt.desc().clone();
        let logging = Arc::clone(stmt.logging());
        let stmt_ast = stmt.stmt().cloned();
        let state_revision = stmt.state_revision;
        if let Err(err) = self.adapter_client.session().set_portal(
            portal_name,
            desc,
            stmt_ast,
            logging,
            params,
            result_formats,
            state_revision,
        ) {
            return self.error(err.into_response(Severity::Error)).await;
        }

        self.send(BackendMessage::BindComplete).await?;
        Ok(State::Ready)
    }
1342
1343    /// `outer_ctx_extra` is Some when we are executing as part of an outer statement, e.g., a FETCH
1344    /// triggering the execution of the underlying query.
1345    fn execute(
1346        &mut self,
1347        portal_name: String,
1348        max_rows: ExecuteCount,
1349        get_response: GetResponse,
1350        fetch_portal_name: Option<String>,
1351        timeout: ExecuteTimeout,
1352        outer_ctx_extra: Option<ExecuteContextExtra>,
1353        received: Option<EpochMillis>,
1354    ) -> BoxFuture<'_, Result<State, io::Error>> {
1355        async move {
1356            let aborted_txn = self.is_aborted_txn();
1357
1358            // Check if the portal has been started and can be continued.
1359            let portal = match self
1360                .adapter_client
1361                .session()
1362                .get_portal_unverified_mut(&portal_name)
1363            {
1364                Some(portal) => portal,
1365                None => {
1366                    let msg = format!("portal {} does not exist", portal_name.quoted());
1367                    if let Some(outer_ctx_extra) = outer_ctx_extra {
1368                        self.adapter_client.retire_execute(
1369                            outer_ctx_extra,
1370                            StatementEndedExecutionReason::Errored { error: msg.clone() },
1371                        );
1372                    }
1373                    return self
1374                        .error(ErrorResponse::error(SqlState::INVALID_CURSOR_NAME, msg))
1375                        .await;
1376                }
1377            };
1378
1379            *portal.lifecycle_timestamps = received.map(LifecycleTimestamps::new);
1380
1381            // In an aborted transaction, reject all commands except COMMIT/ROLLBACK.
1382            let txn_exit_stmt = is_txn_exit_stmt(portal.stmt.as_deref());
1383            if aborted_txn && !txn_exit_stmt {
1384                if let Some(outer_ctx_extra) = outer_ctx_extra {
1385                    self.adapter_client.retire_execute(
1386                        outer_ctx_extra,
1387                        StatementEndedExecutionReason::Errored {
1388                            error: ABORTED_TXN_MSG.to_string(),
1389                        },
1390                    );
1391                }
1392                return self.aborted_txn_error().await;
1393            }
1394
1395            let row_desc = portal.desc.relation_desc.clone();
1396            match portal.state {
1397                PortalState::NotStarted => {
1398                    // Start a transaction if we aren't in one.
1399                    self.ensure_transaction(1, "execute").await?;
1400                    match self
1401                        .adapter_client
1402                        .execute(
1403                            portal_name.clone(),
1404                            self.conn.wait_closed(),
1405                            outer_ctx_extra,
1406                        )
1407                        .await
1408                    {
1409                        Ok((response, execute_started)) => {
1410                            self.send_pending_notices().await?;
1411                            self.send_execute_response(
1412                                response,
1413                                row_desc,
1414                                portal_name,
1415                                max_rows,
1416                                get_response,
1417                                fetch_portal_name,
1418                                timeout,
1419                                execute_started,
1420                            )
1421                            .await
1422                        }
1423                        Err(e) => {
1424                            self.send_pending_notices().await?;
1425                            self.error(e.into_response(Severity::Error)).await
1426                        }
1427                    }
1428                }
1429                PortalState::InProgress(rows) => {
1430                    let rows = rows.take().expect("InProgress rows must be populated");
1431                    let (result, statement_ended_execution_reason) = match self
1432                        .send_rows(
1433                            row_desc.expect("portal missing row desc on resumption"),
1434                            portal_name,
1435                            rows,
1436                            max_rows,
1437                            get_response,
1438                            fetch_portal_name,
1439                            timeout,
1440                        )
1441                        .await
1442                    {
1443                        Err(e) => {
1444                            // This is an error communicating with the connection.
1445                            // We consider that to be a cancelation, rather than a query error.
1446                            (Err(e), StatementEndedExecutionReason::Canceled)
1447                        }
1448                        Ok((ok, SendRowsEndedReason::Canceled)) => {
1449                            (Ok(ok), StatementEndedExecutionReason::Canceled)
1450                        }
1451                        // NOTE: For now the values for `result_size` and
1452                        // `rows_returned` in fetches are a bit confusing.
1453                        // We record `Some(n)` for the first fetch, where `n` is
1454                        // the number of bytes/rows returned by the inner
1455                        // execute (regardless of how many rows the
1456                        // fetch fetched), and `None` for subsequent fetches.
1457                        //
1458                        // This arguably makes sense since the size/rows
1459                        // returned measures how much work the compute
1460                        // layer had to do to satisfy the query, but
1461                        // we should revisit it if/when we start
1462                        // logging the inner execute separately.
1463                        Ok((
1464                            ok,
1465                            SendRowsEndedReason::Success {
1466                                result_size: _,
1467                                rows_returned: _,
1468                            },
1469                        )) => (
1470                            Ok(ok),
1471                            StatementEndedExecutionReason::Success {
1472                                result_size: None,
1473                                rows_returned: None,
1474                                execution_strategy: None,
1475                            },
1476                        ),
1477                        Ok((ok, SendRowsEndedReason::Errored { error })) => {
1478                            (Ok(ok), StatementEndedExecutionReason::Errored { error })
1479                        }
1480                    };
1481                    if let Some(outer_ctx_extra) = outer_ctx_extra {
1482                        self.adapter_client
1483                            .retire_execute(outer_ctx_extra, statement_ended_execution_reason);
1484                    }
1485                    result
1486                }
1487                // FETCH is an awkward command for our current architecture. In Postgres it
1488                // will extract <count> rows from the target portal, cache them, and return
1489                // them to the user as requested. Its command tag is always FETCH <num rows
1490                // extracted>. In Materialize, since we have chosen to not fully support FETCH,
1491                // we must remember the number of rows that were returned. Use this tag to
1492                // remember that information and return it.
1493                PortalState::Completed(Some(tag)) => {
1494                    let tag = tag.to_string();
1495                    if let Some(outer_ctx_extra) = outer_ctx_extra {
1496                        self.adapter_client.retire_execute(
1497                            outer_ctx_extra,
1498                            StatementEndedExecutionReason::Success {
1499                                result_size: None,
1500                                rows_returned: None,
1501                                execution_strategy: None,
1502                            },
1503                        );
1504                    }
1505                    self.send(BackendMessage::CommandComplete { tag }).await?;
1506                    Ok(State::Ready)
1507                }
1508                PortalState::Completed(None) => {
1509                    let error = format!(
1510                        "portal {} cannot be run",
1511                        Ident::new_unchecked(portal_name).to_ast_string_stable()
1512                    );
1513                    if let Some(outer_ctx_extra) = outer_ctx_extra {
1514                        self.adapter_client.retire_execute(
1515                            outer_ctx_extra,
1516                            StatementEndedExecutionReason::Errored {
1517                                error: error.clone(),
1518                            },
1519                        );
1520                    }
1521                    self.error(ErrorResponse::error(
1522                        SqlState::OBJECT_NOT_IN_PREREQUISITE_STATE,
1523                        error,
1524                    ))
1525                    .await
1526                }
1527            }
1528        }
1529        .instrument(debug_span!("execute"))
1530        .boxed()
1531    }
1532
1533    #[instrument(level = "debug")]
1534    async fn describe_statement(&mut self, name: &str) -> Result<State, io::Error> {
1535        // Start a transaction if we aren't in one.
1536        self.ensure_transaction(1, "describe_statement").await?;
1537
1538        let stmt = match self.adapter_client.get_prepared_statement(name).await {
1539            Ok(stmt) => stmt,
1540            Err(err) => return self.error(err.into_response(Severity::Error)).await,
1541        };
1542        // Cloning to avoid a mutable borrow issue because `send` also uses `adapter_client`
1543        let parameter_desc = BackendMessage::ParameterDescription(
1544            stmt.desc()
1545                .param_types
1546                .iter()
1547                .map(mz_pgrepr::Type::from)
1548                .collect(),
1549        );
1550        // Claim that all results will be output in text format, even
1551        // though the true result formats are not yet known. A bit
1552        // weird, but this is the behavior that PostgreSQL specifies.
1553        let formats = vec![Format::Text; stmt.desc().arity()];
1554        let row_desc = describe_rows(stmt.desc(), &formats);
1555        self.send_all([parameter_desc, row_desc]).await?;
1556        Ok(State::Ready)
1557    }
1558
1559    #[instrument(level = "debug")]
1560    async fn describe_portal(&mut self, name: &str) -> Result<State, io::Error> {
1561        // Start a transaction if we aren't in one.
1562        self.ensure_transaction(1, "describe_portal").await?;
1563
1564        let session = self.adapter_client.session();
1565        let row_desc = session
1566            .get_portal_unverified(name)
1567            .map(|portal| describe_rows(&portal.desc, &portal.result_formats));
1568        match row_desc {
1569            Some(row_desc) => {
1570                self.send(row_desc).await?;
1571                Ok(State::Ready)
1572            }
1573            None => {
1574                self.error(ErrorResponse::error(
1575                    SqlState::INVALID_CURSOR_NAME,
1576                    format!("portal {} does not exist", name.quoted()),
1577                ))
1578                .await
1579            }
1580        }
1581    }
1582
1583    #[instrument(level = "debug")]
1584    async fn close_statement(&mut self, name: String) -> Result<State, io::Error> {
1585        self.adapter_client
1586            .session()
1587            .remove_prepared_statement(&name);
1588        self.send(BackendMessage::CloseComplete).await?;
1589        Ok(State::Ready)
1590    }
1591
1592    #[instrument(level = "debug")]
1593    async fn close_portal(&mut self, name: String) -> Result<State, io::Error> {
1594        self.adapter_client.session().remove_portal(&name);
1595        self.send(BackendMessage::CloseComplete).await?;
1596        Ok(State::Ready)
1597    }
1598
1599    fn complete_portal(&mut self, name: &str) {
1600        let portal = self
1601            .adapter_client
1602            .session()
1603            .get_portal_unverified_mut(name)
1604            .expect("portal should exist");
1605        *portal.state = PortalState::Completed(None);
1606    }
1607
1608    async fn fetch(
1609        &mut self,
1610        name: String,
1611        count: Option<FetchDirection>,
1612        max_rows: ExecuteCount,
1613        fetch_portal_name: Option<String>,
1614        timeout: ExecuteTimeout,
1615        ctx_extra: ExecuteContextExtra,
1616    ) -> Result<State, io::Error> {
1617        // Unlike Execute, no count specified in FETCH returns 1 row, and 0 means 0
1618        // instead of All.
1619        let count = count.unwrap_or(FetchDirection::ForwardCount(1));
1620
1621        // Figure out how many rows we should send back by looking at the various
1622        // combinations of the execute and fetch.
1623        //
1624        // In Postgres, Fetch will cache <count> rows from the target portal and
1625        // return those as requested (if, say, an Execute message was sent with a
1626        // max_rows < the Fetch's count). We expect that case to be incredibly rare and
1627        // so have chosen to not support it until users request it. This eases
1628        // implementation difficulty since we don't have to be able to "send" rows to
1629        // a buffer.
1630        //
1631        // TODO(mjibson): Test this somehow? Need to divide up the pgtest files in
1632        // order to have some that are not Postgres compatible.
1633        let count = match (max_rows, count) {
1634            (ExecuteCount::Count(max_rows), FetchDirection::ForwardCount(count)) => {
1635                let count = usize::cast_from(count);
1636                if max_rows < count {
1637                    let msg = "Execute with max_rows < a FETCH's count is not supported";
1638                    self.adapter_client.retire_execute(
1639                        ctx_extra,
1640                        StatementEndedExecutionReason::Errored {
1641                            error: msg.to_string(),
1642                        },
1643                    );
1644                    return self
1645                        .error(ErrorResponse::error(SqlState::FEATURE_NOT_SUPPORTED, msg))
1646                        .await;
1647                }
1648                ExecuteCount::Count(count)
1649            }
1650            (ExecuteCount::Count(_), FetchDirection::ForwardAll) => {
1651                let msg = "Execute with max_rows of a FETCH ALL is not supported";
1652                self.adapter_client.retire_execute(
1653                    ctx_extra,
1654                    StatementEndedExecutionReason::Errored {
1655                        error: msg.to_string(),
1656                    },
1657                );
1658                return self
1659                    .error(ErrorResponse::error(SqlState::FEATURE_NOT_SUPPORTED, msg))
1660                    .await;
1661            }
1662            (ExecuteCount::All, FetchDirection::ForwardAll) => ExecuteCount::All,
1663            (ExecuteCount::All, FetchDirection::ForwardCount(count)) => {
1664                ExecuteCount::Count(usize::cast_from(count))
1665            }
1666        };
1667        let cursor_name = name.to_string();
1668        self.execute(
1669            cursor_name,
1670            count,
1671            fetch_message,
1672            fetch_portal_name,
1673            timeout,
1674            Some(ctx_extra),
1675            None,
1676        )
1677        .await
1678    }
1679
1680    async fn flush(&mut self) -> Result<State, io::Error> {
1681        self.conn.flush().await?;
1682        Ok(State::Ready)
1683    }
1684
1685    /// Sends a backend message to the client, after applying a severity filter.
1686    ///
1687    /// The message is only sent if its severity is above the severity set
1688    /// in the session, with the default value being NOTICE.
1689    #[instrument(level = "debug")]
1690    async fn send<M>(&mut self, message: M) -> Result<(), io::Error>
1691    where
1692        M: Into<BackendMessage>,
1693    {
1694        let message: BackendMessage = message.into();
1695        let is_error =
1696            matches!(&message, BackendMessage::ErrorResponse(e) if e.severity.is_error());
1697
1698        self.conn.send(message).await?;
1699
1700        // Flush immediately after sending an error response, as some clients
1701        // expect to be able to read the error response before sending a Sync
1702        // message. This is arguably in violation of the protocol specification,
1703        // but the specification is somewhat ambiguous, and easier to match
1704        // PostgreSQL here than to fix all the clients that have this
1705        // expectation.
1706        if is_error {
1707            self.conn.flush().await?;
1708        }
1709
1710        Ok(())
1711    }
1712
1713    #[instrument(level = "debug")]
1714    pub async fn send_all(
1715        &mut self,
1716        messages: impl IntoIterator<Item = BackendMessage>,
1717    ) -> Result<(), io::Error> {
1718        for m in messages {
1719            self.send(m).await?;
1720        }
1721        Ok(())
1722    }
1723
1724    #[instrument(level = "debug")]
1725    async fn sync(&mut self) -> Result<State, io::Error> {
1726        // Close the current transaction if we are in an implicit transaction.
1727        if self.adapter_client.session().transaction().is_implicit() {
1728            self.commit_transaction().await?;
1729        }
1730        self.ready().await
1731    }
1732
1733    #[instrument(level = "debug")]
1734    async fn ready(&mut self) -> Result<State, io::Error> {
1735        let txn_state = self.adapter_client.session().transaction().into();
1736        self.send(BackendMessage::ReadyForQuery(txn_state)).await?;
1737        self.flush().await
1738    }
1739
    /// Translates an [`ExecuteResponse`] into the messages sent to the client.
    ///
    /// Dispatches on the variant of `response`:
    ///
    /// * Responses that only produce a command tag send `CommandComplete`.
    /// * Row-producing responses are streamed through `send_rows` (or
    ///   `copy_rows` when wrapped in `CopyTo`).
    /// * `Fetch` and `CopyFrom` are delegated to `fetch` and `copy_from`.
    ///
    /// `row_desc` must be `Some` for every row-producing response and for
    /// `CopyFrom`; the corresponding arms panic otherwise. `portal_name` is the
    /// portal being executed, and `fetch_portal_name` is set when this
    /// execution was triggered by a FETCH running in another portal.
    /// `max_rows`, `get_response`, `timeout`, and `execute_started` are passed
    /// through to the row-sending helpers.
    ///
    /// For `Subscribing` responses (direct or inside `CopyTo`) this function
    /// also retires the accompanying `ctx_extra` with a reason derived from how
    /// the row sending ended.
    ///
    /// # Panics
    ///
    /// Panics if a required `row_desc` is missing, if `command_complete!` is
    /// reached for a response that has no tag, or if a tag was produced but
    /// never sent (checked by the `assert_none!` at the end; arms that `return`
    /// early skip that check).
    #[allow(clippy::too_many_arguments)]
    #[instrument(level = "debug")]
    async fn send_execute_response(
        &mut self,
        response: ExecuteResponse,
        row_desc: Option<RelationDesc>,
        portal_name: String,
        max_rows: ExecuteCount,
        get_response: GetResponse,
        fetch_portal_name: Option<String>,
        timeout: ExecuteTimeout,
        execute_started: Instant,
    ) -> Result<State, io::Error> {
        // Grab the tag before `response` is consumed by the `match` below.
        let mut tag = response.tag();

        // Sends `CommandComplete` with the tag (taking it out of `tag`) and
        // evaluates to `Ok(State::Ready)`. Panics if there is no tag to send.
        macro_rules! command_complete {
            () => {{
                self.send(BackendMessage::CommandComplete {
                    tag: tag
                        .take()
                        .expect("command_complete only called on tag-generating results"),
                })
                .await?;
                Ok(State::Ready)
            }};
        }

        let r = match response {
            ExecuteResponse::ClosedCursor => {
                // Mark this portal as completed before acknowledging.
                self.complete_portal(&portal_name);
                command_complete!()
            }
            ExecuteResponse::DeclaredCursor => {
                // Mark this portal as completed before acknowledging.
                self.complete_portal(&portal_name);
                command_complete!()
            }
            ExecuteResponse::EmptyQuery => {
                self.send(BackendMessage::EmptyQueryResponse).await?;
                Ok(State::Ready)
            }
            ExecuteResponse::Fetch {
                name,
                count,
                timeout,
                ctx_extra,
            } => {
                // Run the FETCH against the cursor `name`, passing the portal
                // currently being executed as the fetch portal. Note that the
                // `timeout` bound by this pattern shadows the function
                // parameter of the same name.
                self.fetch(
                    name,
                    count,
                    max_rows,
                    Some(portal_name.to_string()),
                    timeout,
                    ctx_extra,
                )
                .await
            }
            ExecuteResponse::SendingRowsStreaming {
                rows,
                instance_id,
                strategy,
            } => {
                let row_desc = row_desc
                    .expect("missing row description for ExecuteResponse::SendingRowsStreaming");

                let span = tracing::debug_span!("sending_rows_streaming");

                self.send_rows(
                    row_desc,
                    portal_name,
                    InProgressRows::new(RecordFirstRowStream::new(
                        Box::new(rows),
                        execute_started,
                        &self.adapter_client,
                        Some(instance_id),
                        Some(strategy),
                    )),
                    max_rows,
                    get_response,
                    fetch_portal_name,
                    timeout,
                )
                .instrument(span)
                .await
                // Only the resulting state is needed here; the ended reason is
                // discarded.
                .map(|(state, _)| state)
            }
            ExecuteResponse::SendingRowsImmediate { rows } => {
                let row_desc = row_desc
                    .expect("missing row description for ExecuteResponse::SendingRowsImmediate");

                let span = tracing::debug_span!("sending_rows_immediate");

                // Wrap the already-available rows in a single-item stream so
                // they can go through the same `send_rows` path as streamed
                // results.
                let stream =
                    futures::stream::once(futures::future::ready(PeekResponseUnary::Rows(rows)));
                self.send_rows(
                    row_desc,
                    portal_name,
                    InProgressRows::new(RecordFirstRowStream::new(
                        Box::new(stream),
                        execute_started,
                        &self.adapter_client,
                        None,
                        Some(StatementExecutionStrategy::Constant),
                    )),
                    max_rows,
                    get_response,
                    fetch_portal_name,
                    timeout,
                )
                .instrument(span)
                .await
                // Only the resulting state is needed here; the ended reason is
                // discarded.
                .map(|(state, _)| state)
            }
            ExecuteResponse::SetVariable { name, .. } => {
                // This code is somewhat awkwardly structured because we
                // can't hold `var` across an await point.
                let qn = name.to_string();
                // Only variables in the session's notify set are reported to
                // the client with a `ParameterStatus` message.
                let msg = if let Some(var) = self
                    .adapter_client
                    .session()
                    .vars_mut()
                    .notify_set()
                    .find(|v| v.name() == qn)
                {
                    Some(BackendMessage::ParameterStatus(var.name(), var.value()))
                } else {
                    None
                };
                if let Some(msg) = msg {
                    self.send(msg).await?;
                }
                command_complete!()
            }
            ExecuteResponse::Subscribing {
                rx,
                ctx_extra,
                instance_id,
            } => {
                // When the SUBSCRIBE is not being read through a FETCH, warn
                // the client (and flush so the notice is delivered before any
                // rows).
                if fetch_portal_name.is_none() {
                    let mut msg = ErrorResponse::notice(
                        SqlState::WARNING,
                        "streaming SUBSCRIBE rows directly requires a client that does not buffer output",
                    );
                    if self.adapter_client.session().vars().application_name() == "psql" {
                        msg.hint = Some(
                            "Wrap your SUBSCRIBE statement in `COPY (SUBSCRIBE ...) TO STDOUT`."
                                .into(),
                        )
                    }
                    self.send(msg).await?;
                    self.conn.flush().await?;
                }
                let row_desc =
                    row_desc.expect("missing row description for ExecuteResponse::Subscribing");
                // Send the rows, and translate how sending ended into the
                // reason recorded when retiring the execution below.
                let (result, statement_ended_execution_reason) = match self
                    .send_rows(
                        row_desc,
                        portal_name,
                        InProgressRows::new(RecordFirstRowStream::new(
                            Box::new(UnboundedReceiverStream::new(rx)),
                            execute_started,
                            &self.adapter_client,
                            Some(instance_id),
                            None,
                        )),
                        max_rows,
                        get_response,
                        fetch_portal_name,
                        timeout,
                    )
                    .await
                {
                    Err(e) => {
                        // This is an error communicating with the connection.
                        // We consider that to be a cancelation, rather than a query error.
                        (Err(e), StatementEndedExecutionReason::Canceled)
                    }
                    Ok((ok, SendRowsEndedReason::Canceled)) => {
                        (Ok(ok), StatementEndedExecutionReason::Canceled)
                    }
                    Ok((
                        ok,
                        SendRowsEndedReason::Success {
                            result_size,
                            rows_returned,
                        },
                    )) => (
                        Ok(ok),
                        StatementEndedExecutionReason::Success {
                            result_size: Some(result_size),
                            rows_returned: Some(rows_returned),
                            execution_strategy: None,
                        },
                    ),
                    Ok((ok, SendRowsEndedReason::Errored { error })) => {
                        (Ok(ok), StatementEndedExecutionReason::Errored { error })
                    }
                };
                self.adapter_client
                    .retire_execute(ctx_extra, statement_ended_execution_reason);
                // Early return: this skips the tag assertion at the end of the
                // function.
                return result;
            }
            ExecuteResponse::CopyTo { format, resp } => {
                let row_desc =
                    row_desc.expect("missing row description for ExecuteResponse::CopyTo");
                // Every arm of this inner match returns from the function, so
                // the tag assertion at the end is skipped for `CopyTo`.
                match *resp {
                    ExecuteResponse::Subscribing {
                        rx,
                        ctx_extra,
                        instance_id,
                    } => {
                        // Copy the rows out, and translate how copying ended
                        // into the reason recorded when retiring the execution
                        // below.
                        let (result, statement_ended_execution_reason) = match self
                            .copy_rows(
                                format,
                                row_desc,
                                RecordFirstRowStream::new(
                                    Box::new(UnboundedReceiverStream::new(rx)),
                                    execute_started,
                                    &self.adapter_client,
                                    Some(instance_id),
                                    None,
                                ),
                            )
                            .await
                        {
                            Err(e) => {
                                // This is an error communicating with the connection.
                                // We consider that to be a cancelation, rather than a query error.
                                (Err(e), StatementEndedExecutionReason::Canceled)
                            }
                            Ok((
                                state,
                                SendRowsEndedReason::Success {
                                    result_size,
                                    rows_returned,
                                },
                            )) => (
                                Ok(state),
                                StatementEndedExecutionReason::Success {
                                    result_size: Some(result_size),
                                    rows_returned: Some(rows_returned),
                                    execution_strategy: None,
                                },
                            ),
                            Ok((state, SendRowsEndedReason::Errored { error })) => {
                                (Ok(state), StatementEndedExecutionReason::Errored { error })
                            }
                            Ok((state, SendRowsEndedReason::Canceled)) => {
                                (Ok(state), StatementEndedExecutionReason::Canceled)
                            }
                        };
                        self.adapter_client
                            .retire_execute(ctx_extra, statement_ended_execution_reason);
                        return result;
                    }
                    ExecuteResponse::SendingRowsStreaming {
                        rows,
                        instance_id,
                        strategy,
                    } => {
                        // We don't need to finalize execution here;
                        // it was already done in the
                        // coordinator. Just extract the state and
                        // return that.
                        return self
                            .copy_rows(
                                format,
                                row_desc,
                                RecordFirstRowStream::new(
                                    Box::new(rows),
                                    execute_started,
                                    &self.adapter_client,
                                    Some(instance_id),
                                    Some(strategy),
                                ),
                            )
                            .await
                            .map(|(state, _)| state);
                    }
                    ExecuteResponse::SendingRowsImmediate { rows } => {
                        let span = tracing::debug_span!("sending_rows_immediate");

                        // Wrap the already-available rows in a single-item
                        // stream so they can be handed to `copy_rows`.
                        let rows = futures::stream::once(futures::future::ready(
                            PeekResponseUnary::Rows(rows),
                        ));
                        // We don't need to finalize execution here;
                        // it was already done in the
                        // coordinator. Just extract the state and
                        // return that.
                        return self
                            .copy_rows(
                                format,
                                row_desc,
                                RecordFirstRowStream::new(
                                    Box::new(rows),
                                    execute_started,
                                    &self.adapter_client,
                                    None,
                                    Some(StatementExecutionStrategy::Constant),
                                ),
                            )
                            .instrument(span)
                            .await
                            .map(|(state, _)| state);
                    }
                    _ => {
                        // Any other inner response cannot be copied out;
                        // report an internal error to the client.
                        return self
                            .error(ErrorResponse::error(
                                SqlState::INTERNAL_ERROR,
                                "unsupported COPY response type".to_string(),
                            ))
                            .await;
                    }
                };
            }
            ExecuteResponse::CopyFrom {
                target_id,
                target_name,
                columns,
                params,
                ctx_extra,
            } => {
                let row_desc =
                    row_desc.expect("missing row description for ExecuteResponse::CopyFrom");
                self.copy_from(target_id, target_name, columns, params, row_desc, ctx_extra)
                    .await
            }
            ExecuteResponse::TransactionCommitted { params }
            | ExecuteResponse::TransactionRolledBack { params } => {
                // Collect the names into an owned set first, so the session is
                // no longer borrowed while messages are being sent.
                let notify_set: mz_ore::collections::HashSet<String> = self
                    .adapter_client
                    .session()
                    .vars()
                    .notify_set()
                    .map(|v| v.name().to_string())
                    .collect();

                // Only report on parameters that are in the notify set.
                for (name, value) in params
                    .into_iter()
                    .filter(|(name, _v)| notify_set.contains(*name))
                {
                    let msg = BackendMessage::ParameterStatus(name, value);
                    self.send(msg).await?;
                }
                command_complete!()
            }

            // Everything below needs nothing beyond its command tag.
            ExecuteResponse::AlteredDefaultPrivileges
            | ExecuteResponse::AlteredObject(..)
            | ExecuteResponse::AlteredRole
            | ExecuteResponse::AlteredSystemConfiguration
            | ExecuteResponse::CreatedCluster { .. }
            | ExecuteResponse::CreatedClusterReplica { .. }
            | ExecuteResponse::CreatedConnection { .. }
            | ExecuteResponse::CreatedDatabase { .. }
            | ExecuteResponse::CreatedIndex { .. }
            | ExecuteResponse::CreatedIntrospectionSubscribe
            | ExecuteResponse::CreatedMaterializedView { .. }
            | ExecuteResponse::CreatedContinualTask { .. }
            | ExecuteResponse::CreatedRole
            | ExecuteResponse::CreatedSchema { .. }
            | ExecuteResponse::CreatedSecret { .. }
            | ExecuteResponse::CreatedSink { .. }
            | ExecuteResponse::CreatedSource { .. }
            | ExecuteResponse::CreatedTable { .. }
            | ExecuteResponse::CreatedType
            | ExecuteResponse::CreatedView { .. }
            | ExecuteResponse::CreatedViews { .. }
            | ExecuteResponse::CreatedNetworkPolicy
            | ExecuteResponse::Comment
            | ExecuteResponse::Deallocate { .. }
            | ExecuteResponse::Deleted(..)
            | ExecuteResponse::DiscardedAll
            | ExecuteResponse::DiscardedTemp
            | ExecuteResponse::DroppedObject(_)
            | ExecuteResponse::DroppedOwned
            | ExecuteResponse::GrantedPrivilege
            | ExecuteResponse::GrantedRole
            | ExecuteResponse::Inserted(..)
            | ExecuteResponse::Copied(..)
            | ExecuteResponse::Prepare
            | ExecuteResponse::Raised
            | ExecuteResponse::ReassignOwned
            | ExecuteResponse::RevokedPrivilege
            | ExecuteResponse::RevokedRole
            | ExecuteResponse::StartedTransaction { .. }
            | ExecuteResponse::Updated(..)
            | ExecuteResponse::ValidatedConnection => {
                command_complete!()
            }
        };

        // Every arm that falls through to here must either have sent its tag
        // via `command_complete!` or never have had one.
        assert_none!(tag, "tag created but not consumed: {:?}", tag);
        r
    }
2135
2136    #[allow(clippy::too_many_arguments)]
2137    // TODO(guswynn): figure out how to get it to compile without skip_all
2138    #[mz_ore::instrument(level = "debug")]
2139    async fn send_rows(
2140        &mut self,
2141        row_desc: RelationDesc,
2142        portal_name: String,
2143        mut rows: InProgressRows,
2144        max_rows: ExecuteCount,
2145        get_response: GetResponse,
2146        fetch_portal_name: Option<String>,
2147        timeout: ExecuteTimeout,
2148    ) -> Result<(State, SendRowsEndedReason), io::Error> {
2149        // If this portal is being executed from a FETCH then we need to use the result
2150        // format type of the outer portal.
2151        let result_format_portal_name: &str = if let Some(ref name) = fetch_portal_name {
2152            name
2153        } else {
2154            &portal_name
2155        };
2156        let result_formats = self
2157            .adapter_client
2158            .session()
2159            .get_portal_unverified(result_format_portal_name)
2160            .expect("valid fetch portal name for send rows")
2161            .result_formats
2162            .clone();
2163
2164        let (mut wait_once, mut deadline) = match timeout {
2165            ExecuteTimeout::None => (false, None),
2166            ExecuteTimeout::Seconds(t) => (
2167                false,
2168                Some(tokio::time::Instant::now() + tokio::time::Duration::from_secs_f64(t)),
2169            ),
2170            ExecuteTimeout::WaitOnce => (true, None),
2171        };
2172
2173        // Sanity check that the various `RelationDesc`s match up.
2174        {
2175            let portal_name_desc = &self
2176                .adapter_client
2177                .session()
2178                .get_portal_unverified(portal_name.as_str())
2179                .expect("portal should exist")
2180                .desc
2181                .relation_desc;
2182            if let Some(portal_name_desc) = portal_name_desc {
2183                soft_assert_eq_or_log!(portal_name_desc, &row_desc);
2184            }
2185            if let Some(fetch_portal_name) = &fetch_portal_name {
2186                let fetch_portal_desc = &self
2187                    .adapter_client
2188                    .session()
2189                    .get_portal_unverified(fetch_portal_name)
2190                    .expect("portal should exist")
2191                    .desc
2192                    .relation_desc;
2193                if let Some(fetch_portal_desc) = fetch_portal_desc {
2194                    soft_assert_eq_or_log!(fetch_portal_desc, &row_desc);
2195                }
2196            }
2197        }
2198
2199        self.conn.set_encode_state(
2200            row_desc
2201                .typ()
2202                .column_types
2203                .iter()
2204                .map(|ty| mz_pgrepr::Type::from(&ty.scalar_type))
2205                .zip_eq(result_formats)
2206                .collect(),
2207        );
2208
2209        let mut total_sent_rows = 0;
2210        let mut total_sent_bytes = 0;
2211        // want_rows is the maximum number of rows the client wants.
2212        let mut want_rows = match max_rows {
2213            ExecuteCount::All => usize::MAX,
2214            ExecuteCount::Count(count) => count,
2215        };
2216
2217        // Send rows while the client still wants them and there are still rows to send.
2218        loop {
2219            // Fetch next batch of rows, waiting for a possible requested
2220            // timeout or notice.
2221            let batch = if rows.current.is_some() {
2222                FetchResult::Rows(rows.current.take())
2223            } else if want_rows == 0 {
2224                FetchResult::Rows(None)
2225            } else {
2226                let notice_fut = self.adapter_client.session().recv_notice();
2227                tokio::select! {
2228                    err = self.conn.wait_closed() => return Err(err),
2229                    _ = time::sleep_until(deadline.unwrap_or_else(tokio::time::Instant::now)), if deadline.is_some() => FetchResult::Rows(None),
2230                    notice = notice_fut => {
2231                        FetchResult::Notice(notice)
2232                    }
2233                    batch = rows.remaining.recv() => match batch {
2234                        None => FetchResult::Rows(None),
2235                        Some(PeekResponseUnary::Rows(rows)) => FetchResult::Rows(Some(rows)),
2236                        Some(PeekResponseUnary::Error(err)) => FetchResult::Error(err),
2237                        Some(PeekResponseUnary::Canceled) => FetchResult::Canceled,
2238                    },
2239                }
2240            };
2241
2242            match batch {
2243                FetchResult::Rows(None) => break,
2244                FetchResult::Rows(Some(mut batch_rows)) => {
2245                    if let Err(err) = verify_datum_desc(&row_desc, &mut batch_rows) {
2246                        let msg = err.to_string();
2247                        return self
2248                            .error(err.into_response(Severity::Error))
2249                            .await
2250                            .map(|state| (state, SendRowsEndedReason::Errored { error: msg }));
2251                    }
2252
2253                    // If wait_once is true: the first time this fn is called it blocks (same as
2254                    // deadline == None). The second time this fn is called it should behave the
2255                    // same a 0s timeout.
2256                    if wait_once && batch_rows.peek().is_some() {
2257                        deadline = Some(tokio::time::Instant::now());
2258                        wait_once = false;
2259                    }
2260
2261                    // Send a portion of the rows.
2262                    let mut sent_rows = 0;
2263                    let mut sent_bytes = 0;
2264                    let messages = (&mut batch_rows)
2265                        // TODO(parkmycar): This is a fair bit of juggling between iterator types
2266                        // to count the total number of bytes. Alternatively we could track the
2267                        // total sent bytes in this .map(...) call, but having side effects in map
2268                        // is a code smell.
2269                        .map(|row| {
2270                            let row_len = row.byte_len();
2271                            let values = mz_pgrepr::values_from_row(row, row_desc.typ());
2272                            (row_len, BackendMessage::DataRow(values))
2273                        })
2274                        .inspect(|(row_len, _)| {
2275                            sent_bytes += row_len;
2276                            sent_rows += 1
2277                        })
2278                        .map(|(_row_len, row)| row)
2279                        .take(want_rows);
2280                    self.send_all(messages).await?;
2281
2282                    total_sent_rows += sent_rows;
2283                    total_sent_bytes += sent_bytes;
2284                    want_rows -= sent_rows;
2285
2286                    // If we have sent the number of requested rows, put the remainder of the batch
2287                    // (if any) back and stop sending.
2288                    if want_rows == 0 {
2289                        if batch_rows.peek().is_some() {
2290                            rows.current = Some(batch_rows);
2291                        }
2292                        break;
2293                    }
2294
2295                    self.conn.flush().await?;
2296                }
2297                FetchResult::Notice(notice) => {
2298                    self.send(notice.into_response()).await?;
2299                    self.conn.flush().await?;
2300                }
2301                FetchResult::Error(text) => {
2302                    return self
2303                        .error(ErrorResponse::error(SqlState::INTERNAL_ERROR, text.clone()))
2304                        .await
2305                        .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
2306                }
2307                FetchResult::Canceled => {
2308                    return self
2309                        .error(ErrorResponse::error(
2310                            SqlState::QUERY_CANCELED,
2311                            "canceling statement due to user request",
2312                        ))
2313                        .await
2314                        .map(|state| (state, SendRowsEndedReason::Canceled));
2315                }
2316            }
2317        }
2318
2319        let portal = self
2320            .adapter_client
2321            .session()
2322            .get_portal_unverified_mut(&portal_name)
2323            .expect("valid portal name for send rows");
2324
2325        let saw_rows = rows.remaining.saw_rows;
2326        let no_more_rows = rows.no_more_rows();
2327        let recorded_first_row_instant = rows.remaining.recorded_first_row_instant;
2328
2329        // Always return rows back, even if it's empty. This prevents an unclosed
2330        // portal from re-executing after it has been emptied.
2331        *portal.state = PortalState::InProgress(Some(rows));
2332
2333        let fetch_portal = fetch_portal_name.map(|name| {
2334            self.adapter_client
2335                .session()
2336                .get_portal_unverified_mut(&name)
2337                .expect("valid fetch portal")
2338        });
2339        let response_message = get_response(max_rows, total_sent_rows, fetch_portal);
2340        self.send(response_message).await?;
2341
2342        // Attend to metrics if there are no more rows.
2343        if no_more_rows {
2344            let statement_type = if let Some(stmt) = &self
2345                .adapter_client
2346                .session()
2347                .get_portal_unverified(&portal_name)
2348                .expect("valid portal name for send_rows")
2349                .stmt
2350            {
2351                metrics::statement_type_label_value(stmt.deref())
2352            } else {
2353                "no-statement"
2354            };
2355            let duration = if saw_rows {
2356                recorded_first_row_instant
2357                    .expect("recorded_first_row_instant because saw_rows")
2358                    .elapsed()
2359            } else {
2360                // If the result is empty, then we define time from first to last row as 0.
2361                // (Note that, currently, an empty result involves a PeekResponse with 0 rows, which
2362                // does flip `saw_rows`, so this code path is currently not exercised.)
2363                Duration::ZERO
2364            };
2365            self.adapter_client
2366                .inner()
2367                .metrics()
2368                .result_rows_first_to_last_byte_seconds
2369                .with_label_values(&[statement_type])
2370                .observe(duration.as_secs_f64());
2371        }
2372
2373        Ok((
2374            State::Ready,
2375            SendRowsEndedReason::Success {
2376                result_size: u64::cast_from(total_sent_bytes),
2377                rows_returned: u64::cast_from(total_sent_rows),
2378            },
2379        ))
2380    }
2381
2382    #[mz_ore::instrument(level = "debug")]
2383    async fn copy_rows(
2384        &mut self,
2385        format: CopyFormat,
2386        row_desc: RelationDesc,
2387        mut stream: RecordFirstRowStream,
2388    ) -> Result<(State, SendRowsEndedReason), io::Error> {
2389        let (row_format, encode_format) = match format {
2390            CopyFormat::Text => (
2391                CopyFormatParams::Text(CopyTextFormatParams::default()),
2392                Format::Text,
2393            ),
2394            CopyFormat::Binary => (CopyFormatParams::Binary, Format::Binary),
2395            CopyFormat::Csv => (
2396                CopyFormatParams::Csv(CopyCsvFormatParams::default()),
2397                Format::Text,
2398            ),
2399            CopyFormat::Parquet => {
2400                let text = "Parquet format is not supported".to_string();
2401                return self
2402                    .error(ErrorResponse::error(SqlState::INTERNAL_ERROR, text.clone()))
2403                    .await
2404                    .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
2405            }
2406        };
2407
2408        let encode_fn = |row: &RowRef, typ: &SqlRelationType, out: &mut Vec<u8>| {
2409            mz_pgcopy::encode_copy_format(&row_format, row, typ, out)
2410        };
2411
2412        let typ = row_desc.typ();
2413        let column_formats = iter::repeat(encode_format)
2414            .take(typ.column_types.len())
2415            .collect();
2416        self.send(BackendMessage::CopyOutResponse {
2417            overall_format: encode_format,
2418            column_formats,
2419        })
2420        .await?;
2421
2422        // In Postgres, binary copy has a header that is followed (in the same
2423        // CopyData) by the first row. In order to replicate their behavior, use a
2424        // common vec that we can extend one time now and then fill up with the encode
2425        // functions.
2426        let mut out = Vec::new();
2427
2428        if let CopyFormat::Binary = format {
2429            // 11-byte signature.
2430            out.extend(b"PGCOPY\n\xFF\r\n\0");
2431            // 32-bit flags field.
2432            out.extend([0, 0, 0, 0]);
2433            // 32-bit header extension length field.
2434            out.extend([0, 0, 0, 0]);
2435        }
2436
2437        let mut count = 0;
2438        let mut total_sent_bytes = 0;
2439        loop {
2440            tokio::select! {
2441                e = self.conn.wait_closed() => return Err(e),
2442                batch = stream.recv() => match batch {
2443                    None => break,
2444                    Some(PeekResponseUnary::Error(text)) => {
2445                        return self
2446                            .error(ErrorResponse::error(SqlState::INTERNAL_ERROR, text.clone()))
2447                        .await
2448                        .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
2449                    }
2450                    Some(PeekResponseUnary::Canceled) => {
2451                        return self.error(ErrorResponse::error(
2452                                SqlState::QUERY_CANCELED,
2453                                "canceling statement due to user request",
2454                            ))
2455                            .await.map(|state| (state, SendRowsEndedReason::Canceled));
2456                    }
2457                    Some(PeekResponseUnary::Rows(mut rows)) => {
2458                        count += rows.count();
2459                        while let Some(row) = rows.next() {
2460                            total_sent_bytes += row.byte_len();
2461                            encode_fn(row, typ, &mut out)?;
2462                            self.send(BackendMessage::CopyData(mem::take(&mut out)))
2463                                .await?;
2464                        }
2465                    }
2466                },
2467                notice = self.adapter_client.session().recv_notice() => {
2468                    self.send(notice.into_response())
2469                        .await?;
2470                    self.conn.flush().await?;
2471                }
2472            }
2473
2474            self.conn.flush().await?;
2475        }
2476        // Send required trailers.
2477        if let CopyFormat::Binary = format {
2478            let trailer: i16 = -1;
2479            out.extend(trailer.to_be_bytes());
2480            self.send(BackendMessage::CopyData(mem::take(&mut out)))
2481                .await?;
2482        }
2483
2484        let tag = format!("COPY {}", count);
2485        self.send(BackendMessage::CopyDone).await?;
2486        self.send(BackendMessage::CommandComplete { tag }).await?;
2487        Ok((
2488            State::Ready,
2489            SendRowsEndedReason::Success {
2490                result_size: u64::cast_from(total_sent_bytes),
2491                rows_returned: u64::cast_from(count),
2492            },
2493        ))
2494    }
2495
2496    /// Handles the copy-in mode of the postgres protocol from transferring
2497    /// data to the server.
2498    #[instrument(level = "debug")]
2499    async fn copy_from(
2500        &mut self,
2501        target_id: CatalogItemId,
2502        target_name: String,
2503        columns: Vec<ColumnIndex>,
2504        params: CopyFormatParams<'_>,
2505        row_desc: RelationDesc,
2506        mut ctx_extra: ExecuteContextExtra,
2507    ) -> Result<State, io::Error> {
2508        let res = self
2509            .copy_from_inner(
2510                target_id,
2511                target_name,
2512                columns,
2513                params,
2514                row_desc,
2515                &mut ctx_extra,
2516            )
2517            .await;
2518        match &res {
2519            Ok(State::Done) => {
2520                // The connection closed gracefully without sending us a `CopyDone`,
2521                // causing us to just drop the copy request.
2522                // For the purposes of statement logging, we count this as a cancellation.
2523                self.adapter_client
2524                    .retire_execute(ctx_extra, StatementEndedExecutionReason::Canceled);
2525            }
2526            Err(e) => {
2527                self.adapter_client.retire_execute(
2528                    ctx_extra,
2529                    StatementEndedExecutionReason::Errored {
2530                        error: format!("{e}"),
2531                    },
2532                );
2533            }
2534            other => {
2535                tracing::warn!(?other, "aborting COPY FROM");
2536                self.adapter_client
2537                    .retire_execute(ctx_extra, StatementEndedExecutionReason::Aborted);
2538            }
2539        }
2540        res
2541    }
2542
2543    async fn copy_from_inner(
2544        &mut self,
2545        target_id: CatalogItemId,
2546        target_name: String,
2547        columns: Vec<ColumnIndex>,
2548        params: CopyFormatParams<'_>,
2549        row_desc: RelationDesc,
2550        ctx_extra: &mut ExecuteContextExtra,
2551    ) -> Result<State, io::Error> {
2552        let typ = row_desc.typ();
2553        let column_formats = vec![Format::Text; typ.column_types.len()];
2554        self.send(BackendMessage::CopyInResponse {
2555            overall_format: Format::Text,
2556            column_formats,
2557        })
2558        .await?;
2559        self.conn.flush().await?;
2560
2561        let system_vars = self.adapter_client.get_system_vars().await;
2562        let max_size = system_vars
2563            .get(MAX_COPY_FROM_SIZE.name())
2564            .ok()
2565            .and_then(|max_size| max_size.value().parse().ok())
2566            .unwrap_or(usize::MAX);
2567        tracing::debug!("COPY FROM max buffer size: {max_size} bytes");
2568
2569        let mut data = Vec::new();
2570        loop {
2571            let message = self.conn.recv().await?;
2572            match message {
2573                Some(FrontendMessage::CopyData(buf)) => {
2574                    // Bail before we OOM.
2575                    if (data.len() + buf.len()) > max_size {
2576                        return self
2577                            .error(ErrorResponse::error(
2578                                SqlState::INSUFFICIENT_RESOURCES,
2579                                "COPY FROM STDIN too large",
2580                            ))
2581                            .await;
2582                    }
2583                    data.extend(buf)
2584                }
2585                Some(FrontendMessage::CopyDone) => break,
2586                Some(FrontendMessage::CopyFail(err)) => {
2587                    self.adapter_client.retire_execute(
2588                        std::mem::take(ctx_extra),
2589                        StatementEndedExecutionReason::Canceled,
2590                    );
2591                    return self
2592                        .error(ErrorResponse::error(
2593                            SqlState::QUERY_CANCELED,
2594                            format!("COPY from stdin failed: {}", err),
2595                        ))
2596                        .await;
2597                }
2598                Some(FrontendMessage::Flush) | Some(FrontendMessage::Sync) => {}
2599                Some(_) => {
2600                    let msg = "unexpected message type during COPY from stdin";
2601                    self.adapter_client.retire_execute(
2602                        std::mem::take(ctx_extra),
2603                        StatementEndedExecutionReason::Errored {
2604                            error: msg.to_string(),
2605                        },
2606                    );
2607                    return self
2608                        .error(ErrorResponse::error(SqlState::PROTOCOL_VIOLATION, msg))
2609                        .await;
2610                }
2611                None => {
2612                    return Ok(State::Done);
2613                }
2614            }
2615        }
2616
2617        let column_types = typ
2618            .column_types
2619            .iter()
2620            .map(|x| &x.scalar_type)
2621            .map(mz_pgrepr::Type::from)
2622            .collect::<Vec<mz_pgrepr::Type>>();
2623
2624        let rows = match mz_pgcopy::decode_copy_format(&data, &column_types, params) {
2625            Ok(rows) => rows,
2626            Err(e) => {
2627                self.adapter_client.retire_execute(
2628                    std::mem::take(ctx_extra),
2629                    StatementEndedExecutionReason::Errored {
2630                        error: e.to_string(),
2631                    },
2632                );
2633                return self
2634                    .error(ErrorResponse::error(
2635                        SqlState::BAD_COPY_FILE_FORMAT,
2636                        format!("{}", e),
2637                    ))
2638                    .await;
2639            }
2640        };
2641
2642        let count = rows.len();
2643
2644        if let Err(e) = self
2645            .adapter_client
2646            .insert_rows(
2647                target_id,
2648                target_name,
2649                columns,
2650                rows,
2651                std::mem::take(ctx_extra),
2652            )
2653            .await
2654        {
2655            self.adapter_client.retire_execute(
2656                std::mem::take(ctx_extra),
2657                StatementEndedExecutionReason::Errored {
2658                    error: e.to_string(),
2659                },
2660            );
2661            return self.error(e.into_response(Severity::Error)).await;
2662        }
2663
2664        let tag = format!("COPY {}", count);
2665        self.send(BackendMessage::CommandComplete { tag }).await?;
2666
2667        Ok(State::Ready)
2668    }
2669
2670    #[instrument(level = "debug")]
2671    async fn send_pending_notices(&mut self) -> Result<(), io::Error> {
2672        let notices = self
2673            .adapter_client
2674            .session()
2675            .drain_notices()
2676            .into_iter()
2677            .map(|notice| BackendMessage::ErrorResponse(notice.into_response()));
2678        self.send_all(notices).await?;
2679        Ok(())
2680    }
2681
2682    #[instrument(level = "debug")]
2683    async fn error(&mut self, err: ErrorResponse) -> Result<State, io::Error> {
2684        assert!(err.severity.is_error());
2685        debug!(
2686            "cid={} error code={}",
2687            self.adapter_client.session().conn_id(),
2688            err.code.code()
2689        );
2690        let is_fatal = err.severity.is_fatal();
2691        self.send(BackendMessage::ErrorResponse(err)).await?;
2692
2693        let txn = self.adapter_client.session().transaction();
2694        match txn {
2695            // Error can be called from describe and parse and so might not be in an active
2696            // transaction.
2697            TransactionStatus::Default | TransactionStatus::Failed(_) => {}
2698            // In Started (i.e., a single statement), cleanup ourselves.
2699            TransactionStatus::Started(_) => {
2700                self.rollback_transaction().await?;
2701            }
2702            // Implicit transactions also clear themselves.
2703            TransactionStatus::InTransactionImplicit(_) => {
2704                self.rollback_transaction().await?;
2705            }
2706            // Explicit transactions move to failed.
2707            TransactionStatus::InTransaction(_) => {
2708                self.adapter_client.fail_transaction();
2709            }
2710        };
2711        if is_fatal {
2712            Ok(State::Done)
2713        } else {
2714            Ok(State::Drain)
2715        }
2716    }
2717
2718    #[instrument(level = "debug")]
2719    async fn aborted_txn_error(&mut self) -> Result<State, io::Error> {
2720        self.send(BackendMessage::ErrorResponse(ErrorResponse::error(
2721            SqlState::IN_FAILED_SQL_TRANSACTION,
2722            ABORTED_TXN_MSG,
2723        )))
2724        .await?;
2725        Ok(State::Drain)
2726    }
2727
2728    fn is_aborted_txn(&mut self) -> bool {
2729        matches!(
2730            self.adapter_client.session().transaction(),
2731            TransactionStatus::Failed(_)
2732        )
2733    }
2734}
2735
2736fn pad_formats(formats: Vec<Format>, n: usize) -> Result<Vec<Format>, String> {
2737    match (formats.len(), n) {
2738        (0, e) => Ok(vec![Format::Text; e]),
2739        (1, e) => Ok(iter::repeat(formats[0]).take(e).collect()),
2740        (a, e) if a == e => Ok(formats),
2741        (a, e) => Err(format!(
2742            "expected {} field format specifiers, but got {}",
2743            e, a
2744        )),
2745    }
2746}
2747
2748fn describe_rows(stmt_desc: &StatementDesc, formats: &[Format]) -> BackendMessage {
2749    match &stmt_desc.relation_desc {
2750        Some(desc) if !stmt_desc.is_copy => {
2751            BackendMessage::RowDescription(message::encode_row_description(desc, formats))
2752        }
2753        _ => BackendMessage::NoData,
2754    }
2755}
2756
/// Signature of the function `send_rows` calls to build the message that
/// concludes a round of row sending.
///
/// It receives the maximum number of rows that was requested, the number of
/// rows actually sent, and, if the rows were sent on behalf of a `FETCH`, the
/// outer fetch portal.
type GetResponse = fn(
    max_rows: ExecuteCount,
    total_sent_rows: usize,
    fetch_portal: Option<PortalRefMut>,
) -> BackendMessage;
2762
2763// A GetResponse used by send_rows during execute messages on portals or for
2764// simple query messages.
2765fn portal_exec_message(
2766    max_rows: ExecuteCount,
2767    total_sent_rows: usize,
2768    _fetch_portal: Option<PortalRefMut>,
2769) -> BackendMessage {
2770    // If max_rows is not specified, we will always send back a CommandComplete. If
2771    // max_rows is specified, we only send CommandComplete if there were more rows
2772    // requested than were remaining. That is, if max_rows == number of rows that
2773    // were remaining before sending (not that are remaining after sending), then
2774    // we still send a PortalSuspended. The number of remaining rows after the rows
2775    // have been sent doesn't matter. This matches postgres.
2776    match max_rows {
2777        ExecuteCount::Count(max_rows) if max_rows <= total_sent_rows => {
2778            BackendMessage::PortalSuspended
2779        }
2780        _ => BackendMessage::CommandComplete {
2781            tag: format!("SELECT {}", total_sent_rows),
2782        },
2783    }
2784}
2785
2786// A GetResponse used by send_rows during FETCH queries.
2787fn fetch_message(
2788    _max_rows: ExecuteCount,
2789    total_sent_rows: usize,
2790    fetch_portal: Option<PortalRefMut>,
2791) -> BackendMessage {
2792    let tag = format!("FETCH {}", total_sent_rows);
2793    if let Some(portal) = fetch_portal {
2794        *portal.state = PortalState::Completed(Some(tag.clone()));
2795    }
2796    BackendMessage::CommandComplete { tag }
2797}
2798
/// The maximum number of rows a client wants returned from an execution.
#[derive(Debug, Copy, Clone)]
enum ExecuteCount {
    /// No limit; `send_rows` treats this as `usize::MAX` wanted rows.
    All,
    /// At most the given number of rows.
    Count(usize),
}
2804
2805// See postgres' backend/tcop/postgres.c IsTransactionExitStmt.
2806fn is_txn_exit_stmt(stmt: Option<&Statement<Raw>>) -> bool {
2807    match stmt {
2808        // Add PREPARE to this if we ever support it.
2809        Some(stmt) => matches!(stmt, Statement::Commit(_) | Statement::Rollback(_)),
2810        None => false,
2811    }
2812}
2813
/// The outcome of one attempt by `send_rows` to obtain the next batch of rows.
#[derive(Debug)]
enum FetchResult {
    /// A batch of rows to send, or `None` when `send_rows` should stop
    /// sending.
    Rows(Option<Box<dyn RowIterator + Send + Sync>>),
    /// The row stream reported that the query was canceled.
    Canceled,
    /// The row stream reported an error with the given message.
    Error(String),
    /// A notice arrived while waiting for rows and should be forwarded.
    Notice(AdapterNotice),
}
2821
#[cfg(test)]
mod test {
    use super::*;

    /// Exercises `parse_options` on whole option strings.
    #[mz_ore::test]
    fn test_parse_options() {
        type Expected = Result<Vec<(&'static str, &'static str)>, ()>;

        // Each case pairs an input with the expected key/value pairs (or an error).
        let cases: Vec<(&'static str, Expected)> = vec![
            ("", Ok(vec![])),
            ("--key", Err(())),
            ("--key=val", Ok(vec![("key", "val")])),
            (
                r#"--key=val -ckey2=val2 -c key3=val3 -c key4=val4 -ckey5=val5"#,
                Ok(vec![
                    ("key", "val"),
                    ("key2", "val2"),
                    ("key3", "val3"),
                    ("key4", "val4"),
                    ("key5", "val5"),
                ]),
            ),
            (r#"-c\ key=val"#, Ok(vec![(" key", "val")])),
            ("--key=val -ckey2 val2", Err(())),
            // Unclear what this should do.
            ("--key=", Ok(vec![("key", "")])),
        ];

        for (input, expect) in cases {
            let got = parse_options(input);
            let expect = expect.map(|pairs| {
                pairs
                    .into_iter()
                    .map(|(k, v)| (k.to_owned(), v.to_owned()))
                    .collect()
            });
            assert_eq!(got, expect, "input: {}", input);
        }
    }

    /// Exercises `parse_option` on single option tokens.
    #[mz_ore::test]
    fn test_parse_option() {
        type Expected = Result<(&'static str, &'static str), ()>;

        let cases: Vec<(&'static str, Expected)> = vec![
            ("", Err(())),
            ("--", Err(())),
            ("--c", Err(())),
            ("a=b", Err(())),
            ("--a=b", Ok(("a", "b"))),
            ("--ca=b", Ok(("ca", "b"))),
            ("-ca=b", Ok(("a", "b"))),
            // Unclear whether this should be an error, but at least pin the
            // current behavior.
            ("--=", Ok(("", ""))),
        ];

        for (input, expect) in cases {
            let got = parse_option(input);
            assert_eq!(got, expect, "input: {}", input);
        }
    }

    /// Exercises `split_options`, including backslash escapes.
    #[mz_ore::test]
    fn test_split_options() {
        let cases: Vec<(&'static str, Vec<&'static str>)> = vec![
            ("", vec![]),
            ("  ", vec![]),
            (" a ", vec!["a"]),
            ("  ab     cd   ", vec!["ab", "cd"]),
            (r#"  ab\     cd   "#, vec!["ab ", "cd"]),
            (r#"  ab\\     cd   "#, vec![r#"ab\"#, "cd"]),
            (r#"  ab\\\     cd   "#, vec![r#"ab\ "#, "cd"]),
            (r#"  ab\\\ cd   "#, vec![r#"ab\ cd"#]),
            (r#"  ab\\\cd   "#, vec![r#"ab\cd"#]),
            (r#"a\"#, vec!["a"]),
            (r#"a\ "#, vec!["a "]),
            (r#"\"#, vec![]),
            (r#"\ "#, vec![r#" "#]),
            (r#" \ "#, vec![r#" "#]),
            (r#"\  "#, vec![r#" "#]),
        ];

        for (input, expect) in cases {
            let got = split_options(input);
            assert_eq!(got, expect, "input: {}", input);
        }
    }
}