mz_pgwire/
protocol.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::collections::BTreeMap;
11use std::convert::TryFrom;
12use std::future::Future;
13use std::ops::Deref;
14use std::sync::Arc;
15use std::time::{Duration, Instant};
16use std::{iter, mem};
17
18use base64::prelude::*;
19use byteorder::{ByteOrder, NetworkEndian};
20use futures::future::{BoxFuture, FutureExt, pending};
21use itertools::Itertools;
22use mz_adapter::client::RecordFirstRowStream;
23use mz_adapter::session::{
24    EndTransactionAction, InProgressRows, LifecycleTimestamps, PortalRefMut, PortalState,
25    SessionConfig, TransactionStatus,
26};
27use mz_adapter::statement_logging::{StatementEndedExecutionReason, StatementExecutionStrategy};
28use mz_adapter::{
29    AdapterError, AdapterNotice, ExecuteContextExtra, ExecuteResponse, PeekResponseUnary, metrics,
30    verify_datum_desc,
31};
32use mz_auth::password::Password;
33use mz_authenticator::Authenticator;
34use mz_ore::cast::CastFrom;
35use mz_ore::netio::AsyncReady;
36use mz_ore::now::{EpochMillis, SYSTEM_TIME};
37use mz_ore::str::StrExt;
38use mz_ore::{assert_none, assert_ok, instrument, soft_assert_eq_or_log};
39use mz_pgcopy::{CopyCsvFormatParams, CopyFormatParams, CopyTextFormatParams};
40use mz_pgwire_common::{
41    ConnectionCounter, Cursor, ErrorResponse, Format, FrontendMessage, Severity, VERSION_3,
42    VERSIONS,
43};
44use mz_repr::user::InternalUserMetadata;
45use mz_repr::{
46    CatalogItemId, ColumnIndex, Datum, RelationDesc, RowArena, RowIterator, RowRef,
47    SqlRelationType, SqlScalarType,
48};
49use mz_server_core::TlsMode;
50use mz_server_core::listeners::AllowedRoles;
51use mz_sql::ast::display::AstDisplay;
52use mz_sql::ast::{CopyDirection, CopyStatement, FetchDirection, Ident, Raw, Statement};
53use mz_sql::parse::StatementParseResult;
54use mz_sql::plan::{CopyFormat, ExecuteTimeout, StatementDesc};
55use mz_sql::session::metadata::SessionMetadata;
56use mz_sql::session::user::INTERNAL_USER_NAMES;
57use mz_sql::session::vars::{MAX_COPY_FROM_SIZE, Var, VarInput};
58use postgres::error::SqlState;
59use tokio::io::{self, AsyncRead, AsyncWrite};
60use tokio::select;
61use tokio::time::{self};
62use tokio_metrics::TaskMetrics;
63use tokio_stream::wrappers::UnboundedReceiverStream;
64use tracing::{Instrument, debug, debug_span, warn};
65use uuid::Uuid;
66
67use crate::codec::{
68    FramedConn, decode_password, decode_sasl_initial_response, decode_sasl_response,
69};
70use crate::message::{
71    self, BackendMessage, SASLServerFinalMessage, SASLServerFinalMessageKinds,
72    SASLServerFirstMessage,
73};
74
75/// Reports whether the given stream begins with a pgwire handshake.
76///
77/// To avoid false negatives, there must be at least eight bytes in `buf`.
78pub fn match_handshake(buf: &[u8]) -> bool {
79    // The pgwire StartupMessage looks like this:
80    //
81    //     i32 - Length of entire message.
82    //     i32 - Protocol version number.
83    //     [String] - Arbitrary key-value parameters of any length.
84    //
85    // Since arbitrary parameters can be included in the StartupMessage, the
86    // first Int32 is worthless, since the message could have any length.
87    // Instead, we sniff the protocol version number.
88    if buf.len() < 8 {
89        return false;
90    }
91    let version = NetworkEndian::read_i32(&buf[4..8]);
92    VERSIONS.contains(&version)
93}
94
95/// Parameters for the [`run`] function.
96pub struct RunParams<'a, A, I>
97where
98    I: Iterator<Item = TaskMetrics> + Send,
99{
100    /// The TLS mode of the pgwire server.
101    pub tls_mode: Option<TlsMode>,
102    /// A client for the adapter.
103    pub adapter_client: mz_adapter::Client,
104    /// The connection to the client.
105    pub conn: &'a mut FramedConn<A>,
106    /// The universally unique identifier for the connection.
107    pub conn_uuid: Uuid,
108    /// The protocol version that the client provided in the startup message.
109    pub version: i32,
110    /// The parameters that the client provided in the startup message.
111    pub params: BTreeMap<String, String>,
112    /// Authentication method to use. Frontegg, Password, or None.
113    pub authenticator: Authenticator,
114    /// Global connection limit and count
115    pub active_connection_counter: ConnectionCounter,
116    /// Helm chart version
117    pub helm_chart_version: Option<String>,
118    /// Whether to allow reserved users (ie: mz_system).
119    pub allowed_roles: AllowedRoles,
120    /// Tokio metrics
121    pub tokio_metrics_intervals: I,
122}
123
124/// Runs a pgwire connection to completion.
125///
126/// This involves responding to `FrontendMessage::StartupMessage` and all future
127/// requests until the client terminates the connection or a fatal error occurs.
128///
129/// Note that this function returns successfully even upon delivering a fatal
130/// error to the client. It only returns `Err` if an unexpected I/O error occurs
131/// while communicating with the client, e.g., if the connection is severed in
132/// the middle of a request.
133#[mz_ore::instrument(level = "debug")]
134pub async fn run<'a, A, I>(
135    RunParams {
136        tls_mode,
137        adapter_client,
138        conn,
139        conn_uuid,
140        version,
141        mut params,
142        authenticator,
143        active_connection_counter,
144        helm_chart_version,
145        allowed_roles,
146        tokio_metrics_intervals,
147    }: RunParams<'a, A, I>,
148) -> Result<(), io::Error>
149where
150    A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
151    I: Iterator<Item = TaskMetrics> + Send,
152{
153    if version != VERSION_3 {
154        return conn
155            .send(ErrorResponse::fatal(
156                SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
157                "server does not support the client's requested protocol version",
158            ))
159            .await;
160    }
161
162    let user = params.remove("user").unwrap_or_else(String::new);
163
164    // TODO move this somewhere it can be shared with HTTP
165    let is_internal_user = INTERNAL_USER_NAMES.contains(&user);
166    // this is a superset of internal users
167    let is_reserved_user = mz_adapter::catalog::is_reserved_role_name(user.as_str());
168    let role_allowed = match allowed_roles {
169        AllowedRoles::Normal => !is_reserved_user,
170        AllowedRoles::Internal => is_internal_user,
171        AllowedRoles::NormalAndInternal => !is_reserved_user || is_internal_user,
172    };
173    if !role_allowed {
174        let msg = format!("unauthorized login to user '{user}'");
175        return conn
176            .send(ErrorResponse::fatal(SqlState::INSUFFICIENT_PRIVILEGE, msg))
177            .await;
178    }
179
180    if let Err(err) = conn.inner().ensure_tls_compatibility(&tls_mode) {
181        return conn.send(err).await;
182    }
183
184    let (mut session, expired) = match authenticator {
185        Authenticator::Frontegg(frontegg) => {
186            conn.send(BackendMessage::AuthenticationCleartextPassword)
187                .await?;
188            conn.flush().await?;
189            let password = match conn.recv().await? {
190                Some(FrontendMessage::RawAuthentication(data)) => {
191                    match decode_password(Cursor::new(&data)).ok() {
192                        Some(FrontendMessage::Password { password }) => password,
193                        _ => {
194                            return conn
195                                .send(ErrorResponse::fatal(
196                                    SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
197                                    "expected Password message",
198                                ))
199                                .await;
200                        }
201                    }
202                }
203                _ => {
204                    return conn
205                        .send(ErrorResponse::fatal(
206                            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
207                            "expected Password message",
208                        ))
209                        .await;
210                }
211            };
212
213            let auth_response = frontegg.authenticate(&user, &password).await;
214            match auth_response {
215                Ok(mut auth_session) => {
216                    // Create a session based on the auth session.
217                    //
218                    // In particular, it's important that the username come from the
219                    // auth session, as Frontegg may return an email address with
220                    // different casing than the user supplied via the pgwire
221                    // username field. We want to use the Frontegg casing as
222                    // canonical.
223                    let session = adapter_client.new_session(SessionConfig {
224                        conn_id: conn.conn_id().clone(),
225                        uuid: conn_uuid,
226                        user: auth_session.user().into(),
227                        client_ip: conn.peer_addr().clone(),
228                        external_metadata_rx: Some(auth_session.external_metadata_rx()),
229                        internal_user_metadata: None,
230                        helm_chart_version,
231                    });
232                    let expired = async move { auth_session.expired().await };
233                    (session, expired.left_future())
234                }
235                Err(err) => {
236                    warn!(?err, "pgwire connection failed authentication");
237                    return conn
238                        .send(ErrorResponse::fatal(
239                            SqlState::INVALID_PASSWORD,
240                            "invalid password",
241                        ))
242                        .await;
243                }
244            }
245        }
246        Authenticator::Password(adapter_client) => {
247            conn.send(BackendMessage::AuthenticationCleartextPassword)
248                .await?;
249            conn.flush().await?;
250            let password = match conn.recv().await? {
251                Some(FrontendMessage::RawAuthentication(data)) => {
252                    match decode_password(Cursor::new(&data)).ok() {
253                        Some(FrontendMessage::Password { password }) => Password(password),
254                        _ => {
255                            return conn
256                                .send(ErrorResponse::fatal(
257                                    SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
258                                    "expected Password message",
259                                ))
260                                .await;
261                        }
262                    }
263                }
264                _ => {
265                    return conn
266                        .send(ErrorResponse::fatal(
267                            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
268                            "expected Password message",
269                        ))
270                        .await;
271                }
272            };
273            let auth_response = match adapter_client.authenticate(&user, &password).await {
274                Ok(resp) => resp,
275                Err(err) => {
276                    warn!(?err, "pgwire connection failed authentication");
277                    return conn
278                        .send(ErrorResponse::fatal(
279                            SqlState::INVALID_PASSWORD,
280                            "invalid password",
281                        ))
282                        .await;
283                }
284            };
285            let session = adapter_client.new_session(SessionConfig {
286                conn_id: conn.conn_id().clone(),
287                uuid: conn_uuid,
288                user,
289                client_ip: conn.peer_addr().clone(),
290                external_metadata_rx: None,
291                internal_user_metadata: Some(InternalUserMetadata {
292                    superuser: auth_response.superuser,
293                }),
294                helm_chart_version,
295            });
296            // No frontegg check, so auth session lasts indefinitely.
297            let auth_session = pending().right_future();
298            (session, auth_session)
299        }
300        Authenticator::Sasl(adapter_client) => {
301            // Start the handshake
302            conn.send(BackendMessage::AuthenticationSASL).await?;
303            conn.flush().await?;
304            // Get the initial response indicating chosen mechanism
305            let (mechanism, initial_response) = match conn.recv().await? {
306                Some(FrontendMessage::RawAuthentication(data)) => {
307                    match decode_sasl_initial_response(Cursor::new(&data)).ok() {
308                        Some(FrontendMessage::SASLInitialResponse {
309                            gs2_header,
310                            mechanism,
311                            initial_response,
312                        }) => {
313                            // We do not support channel binding
314                            if gs2_header.channel_binding_enabled() {
315                                return conn
316                                    .send(ErrorResponse::fatal(
317                                        SqlState::PROTOCOL_VIOLATION,
318                                        "channel binding not supported",
319                                    ))
320                                    .await;
321                            }
322                            (mechanism, initial_response)
323                        }
324                        _ => {
325                            return conn
326                                .send(ErrorResponse::fatal(
327                                    SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
328                                    "expected SASLInitialResponse message",
329                                ))
330                                .await;
331                        }
332                    }
333                }
334                _ => {
335                    return conn
336                        .send(ErrorResponse::fatal(
337                            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
338                            "expected SASLInitialResponse message",
339                        ))
340                        .await;
341                }
342            };
343
344            if mechanism != "SCRAM-SHA-256" {
345                return conn
346                    .send(ErrorResponse::fatal(
347                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
348                        "unsupported SASL mechanism",
349                    ))
350                    .await;
351            }
352
353            if initial_response.nonce.len() > 256 {
354                return conn
355                    .send(ErrorResponse::fatal(
356                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
357                        "nonce too long",
358                    ))
359                    .await;
360            }
361
362            let (server_first_message_raw, mock_hash) = match adapter_client
363                .generate_sasl_challenge(&user, &initial_response.nonce)
364                .await
365            {
366                Ok(response) => {
367                    let server_first_message_raw = format!(
368                        "r={},s={},i={}",
369                        response.nonce, response.salt, response.iteration_count
370                    );
371
372                    let client_key = [0u8; 32];
373                    let server_key = [1u8; 32];
374                    let mock_hash = format!(
375                        "SCRAM-SHA-256${}:{}${}:{}",
376                        response.iteration_count,
377                        response.salt,
378                        BASE64_STANDARD.encode(client_key),
379                        BASE64_STANDARD.encode(server_key)
380                    );
381
382                    conn.send(BackendMessage::AuthenticationSASLContinue(
383                        SASLServerFirstMessage {
384                            iteration_count: response.iteration_count,
385                            nonce: response.nonce,
386                            salt: response.salt,
387                        },
388                    ))
389                    .await?;
390                    conn.flush().await?;
391                    (server_first_message_raw, mock_hash)
392                }
393                Err(e) => {
394                    return conn.send(e.into_response(Severity::Fatal)).await;
395                }
396            };
397
398            let auth_resp = match conn.recv().await? {
399                Some(FrontendMessage::RawAuthentication(data)) => {
400                    match decode_sasl_response(Cursor::new(&data)).ok() {
401                        Some(FrontendMessage::SASLResponse(response)) => {
402                            let auth_message = format!(
403                                "{},{},{}",
404                                initial_response.client_first_message_bare_raw,
405                                server_first_message_raw,
406                                response.client_final_message_bare_raw
407                            );
408                            if response.proof.len() > 1024 {
409                                return conn
410                                    .send(ErrorResponse::fatal(
411                                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
412                                        "proof too long",
413                                    ))
414                                    .await;
415                            }
416                            match adapter_client
417                                .verify_sasl_proof(
418                                    &user,
419                                    &response.proof,
420                                    &auth_message,
421                                    &mock_hash,
422                                )
423                                .await
424                            {
425                                Ok(resp) => {
426                                    conn.send(BackendMessage::AuthenticationSASLFinal(
427                                        SASLServerFinalMessage {
428                                            kind: SASLServerFinalMessageKinds::Verifier(
429                                                resp.verifier,
430                                            ),
431                                            extensions: vec![],
432                                        },
433                                    ))
434                                    .await?;
435                                    conn.flush().await?;
436                                    resp.auth_resp
437                                }
438                                Err(_) => {
439                                    return conn
440                                        .send(ErrorResponse::fatal(
441                                            SqlState::INVALID_PASSWORD,
442                                            "invalid password",
443                                        ))
444                                        .await;
445                                }
446                            }
447                        }
448                        _ => {
449                            return conn
450                                .send(ErrorResponse::fatal(
451                                    SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
452                                    "expected SASLResponse message",
453                                ))
454                                .await;
455                        }
456                    }
457                }
458                _ => {
459                    return conn
460                        .send(ErrorResponse::fatal(
461                            SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
462                            "expected SASLResponse message",
463                        ))
464                        .await;
465                }
466            };
467
468            let session = adapter_client.new_session(SessionConfig {
469                conn_id: conn.conn_id().clone(),
470                uuid: conn_uuid,
471                user,
472                client_ip: conn.peer_addr().clone(),
473                external_metadata_rx: None,
474                internal_user_metadata: Some(InternalUserMetadata {
475                    superuser: auth_resp.superuser,
476                }),
477                helm_chart_version,
478            });
479            // No frontegg check, so auth session lasts indefinitely.
480            let auth_session = pending().right_future();
481            (session, auth_session)
482        }
483        Authenticator::None => {
484            let session = adapter_client.new_session(SessionConfig {
485                conn_id: conn.conn_id().clone(),
486                uuid: conn_uuid,
487                user,
488                client_ip: conn.peer_addr().clone(),
489                external_metadata_rx: None,
490                internal_user_metadata: None,
491                helm_chart_version,
492            });
493            // No frontegg check, so auth session lasts indefinitely.
494            let auth_session = pending().right_future();
495            (session, auth_session)
496        }
497    };
498
499    let system_vars = adapter_client.get_system_vars().await;
500    for (name, value) in params {
501        let settings = match name.as_str() {
502            "options" => match parse_options(&value) {
503                Ok(opts) => opts,
504                Err(()) => {
505                    session.add_notice(AdapterNotice::BadStartupSetting {
506                        name,
507                        reason: "could not parse".into(),
508                    });
509                    continue;
510                }
511            },
512            _ => vec![(name, value)],
513        };
514        for (key, val) in settings {
515            const LOCAL: bool = false;
516            // TODO: Issuing an error here is better than what we did before
517            // (silently ignore errors on set), but erroring the connection
518            // might be the better behavior. We maybe need to support more
519            // options sent by psql and drivers before we can safely do this.
520            if let Err(err) =
521                session
522                    .vars_mut()
523                    .set(&system_vars, &key, VarInput::Flat(&val), LOCAL)
524            {
525                session.add_notice(AdapterNotice::BadStartupSetting {
526                    name: key,
527                    reason: err.to_string(),
528                });
529            }
530        }
531    }
532    session
533        .vars_mut()
534        .end_transaction(EndTransactionAction::Commit);
535
536    let _guard = match active_connection_counter.allocate_connection(session.user()) {
537        Ok(drop_connection) => drop_connection,
538        Err(e) => {
539            let e: AdapterError = e.into();
540            return conn.send(e.into_response(Severity::Fatal)).await;
541        }
542    };
543
544    // Register session with adapter.
545    let mut adapter_client = match adapter_client.startup(session).await {
546        Ok(adapter_client) => adapter_client,
547        Err(e) => return conn.send(e.into_response(Severity::Fatal)).await,
548    };
549
550    let mut buf = vec![BackendMessage::AuthenticationOk];
551    for var in adapter_client.session().vars().notify_set() {
552        buf.push(BackendMessage::ParameterStatus(var.name(), var.value()));
553    }
554    buf.push(BackendMessage::BackendKeyData {
555        conn_id: adapter_client.session().conn_id().unhandled(),
556        secret_key: adapter_client.session().secret_key(),
557    });
558    buf.extend(
559        adapter_client
560            .session()
561            .drain_notices()
562            .into_iter()
563            .map(|notice| BackendMessage::ErrorResponse(notice.into_response())),
564    );
565    buf.push(BackendMessage::ReadyForQuery(
566        adapter_client.session().transaction().into(),
567    ));
568    conn.send_all(buf).await?;
569    conn.flush().await?;
570
571    let machine = StateMachine {
572        conn,
573        adapter_client,
574        txn_needs_commit: false,
575        tokio_metrics_intervals,
576    };
577
578    select! {
579        r = machine.run() => {
580            // Errors produced internally (like MAX_REQUEST_SIZE being exceeded) should send an
581            // error to the client informing them why the connection was closed. We still want to
582            // return the original error up the stack, though, so we skip error checking during conn
583            // operations.
584            if let Err(err) = &r {
585                let _ = conn
586                    .send(ErrorResponse::fatal(
587                        SqlState::CONNECTION_FAILURE,
588                        err.to_string(),
589                    ))
590                    .await;
591                let _ = conn.flush().await;
592            }
593            r
594        },
595        _ = expired => {
596            conn
597                .send(ErrorResponse::fatal(SqlState::INVALID_AUTHORIZATION_SPECIFICATION, "authentication expired"))
598                .await?;
599            conn.flush().await
600        }
601    }
602}
603
604/// Returns (name, value) session settings pairs from an options value.
605///
606/// From Postgres, see pg_split_opts in postinit.c and process_postgres_switches
607/// in postgres.c.
608fn parse_options(value: &str) -> Result<Vec<(String, String)>, ()> {
609    let opts = split_options(value);
610    let mut pairs = Vec::with_capacity(opts.len());
611    let mut seen_prefix = false;
612    for opt in opts {
613        if !seen_prefix {
614            if opt == "-c" {
615                seen_prefix = true;
616            } else {
617                let (key, val) = parse_option(&opt)?;
618                pairs.push((key.to_owned(), val.to_owned()));
619            }
620        } else {
621            let (key, val) = opt.split_once('=').ok_or(())?;
622            pairs.push((key.to_owned(), val.to_owned()));
623            seen_prefix = false;
624        }
625    }
626    Ok(pairs)
627}
628
/// Splits a single option token of the form `--key=value` or `-ckey=value`
/// into its key and value.
///
/// The token is divided at the first `=`, so the value may itself contain `=`.
/// The `-c` or `--` prefix is removed from the key, and the rest of the key is
/// returned exactly as written; this function applies no normalization to it.
///
/// Returns `Err(())` if the token contains no `=` or if the key starts with
/// neither prefix.
fn parse_option(option: &str) -> Result<(&str, &str), ()> {
    let (raw_key, value) = option.split_once('=').ok_or(())?;
    // `-c` is tried before `--`, matching the order of the prefix list.
    raw_key
        .strip_prefix("-c")
        .or_else(|| raw_key.strip_prefix("--"))
        .map(|key| (key, value))
        .ok_or(())
}
641
/// Splits value by any number of spaces except those preceded by `\`.
///
/// A `\` makes the character after it literal and is itself dropped, so `\ `
/// yields a space inside a token and `\\` yields a single backslash. Runs of
/// unescaped spaces never produce empty tokens.
fn split_options(value: &str) -> Vec<String> {
    let mut tokens = Vec::new();
    // Escape processing rewrites the text, so each token is accumulated into
    // its own string instead of being sliced out of `value`. This is not
    // called often enough to justify avoiding the allocation.
    let mut token = String::new();
    // True when the previous character was an unescaped backslash.
    let mut escaped = false;
    for ch in value.chars() {
        if escaped {
            // Whatever follows a backslash is taken literally, including a
            // space or a second backslash.
            token.push(ch);
            escaped = false;
        } else if ch == '\\' {
            escaped = true;
        } else if ch != ' ' {
            token.push(ch);
        } else if !token.is_empty() {
            // An unescaped space ends the token. The emptiness check is what
            // collapses several spaces in a row.
            tokens.push(std::mem::take(&mut token));
        }
    }
    // A `\` at the end will be ignored.
    if !token.is_empty() {
        tokens.push(token);
    }
    tokens
}
684
/// The phase of a connection's message loop, as stepped by
/// `StateMachine::run`.
#[derive(Debug)]
enum State {
    /// The next step is taken by `advance_ready`.
    Ready,
    /// The next step is taken by `advance_drain`.
    Drain,
    /// Terminal state: the message loop returns `Ok(())`.
    Done,
}
691
692struct StateMachine<'a, A, I>
693where
694    I: Iterator<Item = TaskMetrics> + Send + 'a,
695{
696    conn: &'a mut FramedConn<A>,
697    adapter_client: mz_adapter::SessionClient,
698    txn_needs_commit: bool,
699    tokio_metrics_intervals: I,
700}
701
/// The outcome of a row-sending operation.
///
/// NOTE(review): the code that produces and consumes this type is outside
/// this excerpt; the variant descriptions below follow the names only.
enum SendRowsEndedReason {
    /// The operation finished normally.
    Success {
        /// Size of the result (presumably in bytes — confirm at the use site).
        result_size: u64,
        /// Number of rows returned.
        rows_returned: u64,
    },
    /// The operation failed.
    Errored {
        /// The error, rendered as text.
        error: String,
    },
    /// The operation was canceled.
    Canceled,
}
712
/// Message text telling the client that the current transaction is aborted
/// and that commands are ignored until the transaction block ends.
const ABORTED_TXN_MSG: &str =
    "current transaction is aborted, commands ignored until end of transaction block";
715
716impl<'a, A, I> StateMachine<'a, A, I>
717where
718    A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin + 'a,
719    I: Iterator<Item = TaskMetrics> + Send + 'a,
720{
721    // Manually desugar this (don't use `async fn run`) here because a much better
722    // error message is produced if there are problems with Send or other traits
723    // somewhere within the Future.
724    #[allow(clippy::manual_async_fn)]
725    #[mz_ore::instrument(level = "debug")]
726    fn run(mut self) -> impl Future<Output = Result<(), io::Error>> + Send + 'a {
727        async move {
728            let mut state = State::Ready;
729            loop {
730                self.send_pending_notices().await?;
731                state = match state {
732                    State::Ready => self.advance_ready().await?,
733                    State::Drain => self.advance_drain().await?,
734                    State::Done => return Ok(()),
735                };
736                self.adapter_client
737                    .add_idle_in_transaction_session_timeout();
738            }
739        }
740    }
741
    /// Handles one step of the `State::Ready` state: waits for either a session timeout or
    /// the next frontend message, dispatches the message to its handler, and returns the
    /// state to continue in.
    ///
    /// A pending timeout always wins over an incoming message (the `select!` is `biased`). On
    /// timeout a fatal error is processed, the adapter client is terminated, one more client
    /// message is awaited, and the state produced by the error handling is returned.
    ///
    /// For a received message, the scheduling delay of the `recv` and (when a message was
    /// actually present) the time spent processing it are recorded in metrics labelled with
    /// the message name. Copy and authentication messages are not valid here and move the
    /// connection to `State::Drain`; `Terminate` and the absence of a message yield
    /// `State::Done`.
    #[instrument(level = "debug")]
    async fn advance_ready(&mut self) -> Result<State, io::Error> {
        // Start a new metrics interval before the `recv()` call.
        self.tokio_metrics_intervals
            .next()
            .expect("infinite iterator");

        // Handle timeouts first so we don't execute any statements when there's a pending timeout.
        let message = select! {
            biased;

            // `recv_timeout()` is cancel-safe as per its docs.
            Some(timeout) = self.adapter_client.recv_timeout() => {
                let err: AdapterError = timeout.into();
                let conn_id = self.adapter_client.session().conn_id();
                tracing::warn!("session timed out, conn_id {}", conn_id);

                // Process the error, doing any state cleanup.
                let error_response = err.into_response(Severity::Fatal);
                let error_state = self.error(error_response).await;

                // Terminate __after__ we do any cleanup.
                self.adapter_client.terminate().await;

                // We must wait for the client to send a request before we can send the error response.
                // Due to the PG wire protocol, we can't send an ErrorResponse unless it is in response
                // to a client message.
                let _ = self.conn.recv().await?;
                return error_state;
            },
            // `recv()` is cancel-safe as per its docs.
            message = self.conn.recv() => message?,
        };

        // Take the metrics since just before the `recv`.
        let interval = self
            .tokio_metrics_intervals
            .next()
            .expect("infinite iterator");
        let recv_scheduling_delay_ms = interval.total_scheduled_duration.as_secs_f64() * 1000.0;

        // TODO(ggevay): Consider subtracting the scheduling delay from `received`. It's not obvious
        // whether we should do this, because the result wouldn't exactly correspond to either first
        // byte received or last byte received (for msgs that arrive in more than one network packet).
        let received = SYSTEM_TIME();

        self.adapter_client
            .remove_idle_in_transaction_session_timeout();

        // NOTE(guswynn): we could consider adding spans to all message types. Currently
        // only a few message types seem useful.
        let message_name = message.as_ref().map(|m| m.name()).unwrap_or_default();

        // Processing time is only measured when a message was actually received.
        let start = message.as_ref().map(|_| Instant::now());
        let next_state = match message {
            Some(FrontendMessage::Query { sql }) => {
                // Run the query under a fresh root span (no parent) that is linked to the
                // current span via `follows_from`.
                let query_root_span =
                    tracing::info_span!(parent: None, "advance_ready", otel.name = message_name);
                query_root_span.follows_from(tracing::Span::current());
                self.query(sql, received)
                    .instrument(query_root_span)
                    .await?
            }
            Some(FrontendMessage::Parse {
                name,
                sql,
                param_types,
            }) => self.parse(name, sql, param_types).await?,
            Some(FrontendMessage::Bind {
                portal_name,
                statement_name,
                param_formats,
                raw_params,
                result_formats,
            }) => {
                self.bind(
                    portal_name,
                    statement_name,
                    param_formats,
                    raw_params,
                    result_formats,
                )
                .await?
            }
            Some(FrontendMessage::Execute {
                portal_name,
                max_rows,
            }) => {
                let max_rows = match usize::try_from(max_rows) {
                    Ok(0) | Err(_) => ExecuteCount::All, // If `max_rows < 0`, no limit.
                    Ok(n) => ExecuteCount::Count(n),
                };
                // As for `Query`, execution gets its own root span linked to the current one.
                let execute_root_span =
                    tracing::info_span!(parent: None, "advance_ready", otel.name = message_name);
                execute_root_span.follows_from(tracing::Span::current());
                let state = self
                    .execute(
                        portal_name,
                        max_rows,
                        portal_exec_message,
                        None,
                        ExecuteTimeout::None,
                        None,
                        Some(received),
                    )
                    .instrument(execute_root_span)
                    .await?;
                // In PostgreSQL, when using the extended query protocol, some statements may
                // trigger an eager commit of the current implicit transaction,
                // see: <https://git.postgresql.org/gitweb/?p=postgresql.git&a=commitdiff&h=f92944137>.
                //
                // In Materialize, however, we eagerly commit every statement outside of an explicit
                // transaction when using the extended query protocol. This allows us to eliminate
                // the possibility of a multiple statement implicit transaction, which in turn
                // allows us to apply single-statement optimizations to queries issued in implicit
                // transactions in the extended query protocol.
                //
                // We don't immediately commit here to allow users to page through the portal if
                // necessary. Committing the transaction would destroy the portal before the next
                // Execute command has a chance to resume it. So we instead mark the transaction
                // for commit the next time that `ensure_transaction` is called.
                if self.adapter_client.session().transaction().is_implicit() {
                    self.txn_needs_commit = true;
                }
                state
            }
            Some(FrontendMessage::DescribeStatement { name }) => {
                self.describe_statement(&name).await?
            }
            Some(FrontendMessage::DescribePortal { name }) => self.describe_portal(&name).await?,
            Some(FrontendMessage::CloseStatement { name }) => self.close_statement(name).await?,
            Some(FrontendMessage::ClosePortal { name }) => self.close_portal(name).await?,
            Some(FrontendMessage::Flush) => self.flush().await?,
            Some(FrontendMessage::Sync) => self.sync().await?,
            Some(FrontendMessage::Terminate) => State::Done,

            // Copy and authentication messages are not handled in the ready state; switch to
            // draining.
            Some(FrontendMessage::CopyData(_))
            | Some(FrontendMessage::CopyDone)
            | Some(FrontendMessage::CopyFail(_))
            | Some(FrontendMessage::Password { .. })
            | Some(FrontendMessage::RawAuthentication(_))
            | Some(FrontendMessage::SASLInitialResponse { .. })
            | Some(FrontendMessage::SASLResponse(_)) => State::Drain,
            None => State::Done,
        };

        if let Some(start) = start {
            self.adapter_client
                .inner()
                .metrics()
                .pgwire_message_processing_seconds
                .with_label_values(&[message_name])
                .observe(start.elapsed().as_secs_f64());
        }
        // The scheduling delay is recorded even when no message was received; the label is
        // then the default (empty) message name.
        self.adapter_client
            .inner()
            .metrics()
            .pgwire_recv_scheduling_delay_ms
            .with_label_values(&[message_name])
            .observe(recv_scheduling_delay_ms);

        Ok(next_state)
    }
905
906    async fn advance_drain(&mut self) -> Result<State, io::Error> {
907        let message = self.conn.recv().await?;
908        if message.is_some() {
909            self.adapter_client
910                .remove_idle_in_transaction_session_timeout();
911        }
912        match message {
913            Some(FrontendMessage::Sync) => self.sync().await,
914            None => Ok(State::Done),
915            _ => Ok(State::Drain),
916        }
917    }
918
919    /// Note that `lifecycle_timestamps` belongs to the whole "Simple Query", because the whole
920    /// Simple Query is received and parsed together. This means that if there are multiple
921    /// statements in a Simple Query, then all of them have the same `lifecycle_timestamps`.
922    #[instrument(level = "debug")]
923    async fn one_query(
924        &mut self,
925        stmt: Statement<Raw>,
926        sql: String,
927        lifecycle_timestamps: LifecycleTimestamps,
928    ) -> Result<State, io::Error> {
929        // Bind the portal. Note that this does not set the empty string prepared
930        // statement.
931        const EMPTY_PORTAL: &str = "";
932        if let Err(e) = self
933            .adapter_client
934            .declare(EMPTY_PORTAL.to_string(), stmt, sql)
935            .await
936        {
937            return self.error(e.into_response(Severity::Error)).await;
938        }
939        let portal = self
940            .adapter_client
941            .session()
942            .get_portal_unverified_mut(EMPTY_PORTAL)
943            .expect("unnamed portal should be present");
944
945        *portal.lifecycle_timestamps = Some(lifecycle_timestamps);
946
947        let stmt_desc = portal.desc.clone();
948        if !stmt_desc.param_types.is_empty() {
949            return self
950                .error(ErrorResponse::error(
951                    SqlState::UNDEFINED_PARAMETER,
952                    "there is no parameter $1",
953                ))
954                .await;
955        }
956
957        // Maybe send row description.
958        if let Some(relation_desc) = &stmt_desc.relation_desc {
959            if !stmt_desc.is_copy {
960                let formats = vec![Format::Text; stmt_desc.arity()];
961                self.send(BackendMessage::RowDescription(
962                    message::encode_row_description(relation_desc, &formats),
963                ))
964                .await?;
965            }
966        }
967
968        let result = match self
969            .adapter_client
970            .execute(EMPTY_PORTAL.to_string(), self.conn.wait_closed(), None)
971            .await
972        {
973            Ok((response, execute_started)) => {
974                self.send_pending_notices().await?;
975                self.send_execute_response(
976                    response,
977                    stmt_desc.relation_desc,
978                    EMPTY_PORTAL.to_string(),
979                    ExecuteCount::All,
980                    portal_exec_message,
981                    None,
982                    ExecuteTimeout::None,
983                    execute_started,
984                )
985                .await
986            }
987            Err(e) => {
988                self.send_pending_notices().await?;
989                self.error(e.into_response(Severity::Error)).await
990            }
991        };
992
993        // Destroy the portal.
994        self.adapter_client.session().remove_portal(EMPTY_PORTAL);
995
996        result
997    }
998
999    async fn ensure_transaction(
1000        &mut self,
1001        num_stmts: usize,
1002        message_type: &str,
1003    ) -> Result<(), io::Error> {
1004        let start = Instant::now();
1005        if self.txn_needs_commit {
1006            self.commit_transaction().await?;
1007        }
1008        // start_transaction can't error (but assert that just in case it changes in
1009        // the future.
1010        let res = self.adapter_client.start_transaction(Some(num_stmts));
1011        assert_ok!(res);
1012        self.adapter_client
1013            .inner()
1014            .metrics()
1015            .pgwire_ensure_transaction_seconds
1016            .with_label_values(&[message_type])
1017            .observe(start.elapsed().as_secs_f64());
1018        Ok(())
1019    }
1020
1021    fn parse_sql<'b>(&self, sql: &'b str) -> Result<Vec<StatementParseResult<'b>>, ErrorResponse> {
1022        let parse_start = Instant::now();
1023        let result = match self.adapter_client.parse(sql) {
1024            Ok(result) => result.map_err(|e| {
1025                // Convert our 0-based byte position to pgwire's 1-based character
1026                // position.
1027                let pos = sql[..e.error.pos].chars().count() + 1;
1028                ErrorResponse::error(SqlState::SYNTAX_ERROR, e.error.message).with_position(pos)
1029            }),
1030            Err(msg) => Err(ErrorResponse::error(SqlState::PROGRAM_LIMIT_EXCEEDED, msg)),
1031        };
1032        self.adapter_client
1033            .inner()
1034            .metrics()
1035            .parse_seconds
1036            .with_label_values(&[])
1037            .observe(parse_start.elapsed().as_secs_f64());
1038        result
1039    }
1040
1041    /// Executes a "Simple Query", see
1042    /// <https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-SIMPLE-QUERY>
1043    ///
1044    /// For implicit transaction handling, see "Multiple Statements in a Simple Query" in the above.
1045    #[instrument(level = "debug")]
1046    async fn query(&mut self, sql: String, received: EpochMillis) -> Result<State, io::Error> {
1047        // Parse first before doing any transaction checking.
1048        let stmts = match self.parse_sql(&sql) {
1049            Ok(stmts) => stmts,
1050            Err(err) => {
1051                self.error(err).await?;
1052                return self.ready().await;
1053            }
1054        };
1055
1056        let num_stmts = stmts.len();
1057
1058        // Compare with postgres' backend/tcop/postgres.c exec_simple_query.
1059        for StatementParseResult { ast: stmt, sql } in stmts {
1060            // In an aborted transaction, reject all commands except COMMIT/ROLLBACK.
1061            if self.is_aborted_txn() && !is_txn_exit_stmt(Some(&stmt)) {
1062                self.aborted_txn_error().await?;
1063                break;
1064            }
1065
1066            // Start an implicit transaction if we aren't in any transaction and there's
1067            // more than one statement. This mirrors the `use_implicit_block` variable in
1068            // postgres.
1069            //
1070            // This needs to be done in the loop instead of once at the top because
1071            // a COMMIT/ROLLBACK statement needs to start a new transaction on next
1072            // statement.
1073            self.ensure_transaction(num_stmts, "query").await?;
1074
1075            match self
1076                .one_query(stmt, sql.to_string(), LifecycleTimestamps { received })
1077                .await?
1078            {
1079                State::Ready => (),
1080                State::Drain => break,
1081                State::Done => return Ok(State::Done),
1082            }
1083        }
1084
1085        // Implicit transactions are closed at the end of a Query message.
1086        {
1087            if self.adapter_client.session().transaction().is_implicit() {
1088                self.commit_transaction().await?;
1089            }
1090        }
1091
1092        if num_stmts == 0 {
1093            self.send(BackendMessage::EmptyQueryResponse).await?;
1094        }
1095
1096        self.ready().await
1097    }
1098
1099    #[instrument(level = "debug")]
1100    async fn parse(
1101        &mut self,
1102        name: String,
1103        sql: String,
1104        param_oids: Vec<u32>,
1105    ) -> Result<State, io::Error> {
1106        // Start a transaction if we aren't in one.
1107        self.ensure_transaction(1, "parse").await?;
1108
1109        let mut param_types = vec![];
1110        for oid in param_oids {
1111            match mz_pgrepr::Type::from_oid(oid) {
1112                Ok(ty) => match SqlScalarType::try_from(&ty) {
1113                    Ok(ty) => param_types.push(Some(ty)),
1114                    Err(err) => {
1115                        return self
1116                            .error(ErrorResponse::error(
1117                                SqlState::INVALID_PARAMETER_VALUE,
1118                                err.to_string(),
1119                            ))
1120                            .await;
1121                    }
1122                },
1123                Err(_) if oid == 0 => param_types.push(None),
1124                Err(e) => {
1125                    return self
1126                        .error(ErrorResponse::error(
1127                            SqlState::PROTOCOL_VIOLATION,
1128                            e.to_string(),
1129                        ))
1130                        .await;
1131                }
1132            }
1133        }
1134
1135        let stmts = match self.parse_sql(&sql) {
1136            Ok(stmts) => stmts,
1137            Err(err) => {
1138                return self.error(err).await;
1139            }
1140        };
1141        if stmts.len() > 1 {
1142            return self
1143                .error(ErrorResponse::error(
1144                    SqlState::INTERNAL_ERROR,
1145                    "cannot insert multiple commands into a prepared statement",
1146                ))
1147                .await;
1148        }
1149        let (maybe_stmt, sql) = match stmts.into_iter().next() {
1150            None => (None, ""),
1151            Some(StatementParseResult { ast, sql }) => (Some(ast), sql),
1152        };
1153        if self.is_aborted_txn() && !is_txn_exit_stmt(maybe_stmt.as_ref()) {
1154            return self.aborted_txn_error().await;
1155        }
1156        match self
1157            .adapter_client
1158            .prepare(name, maybe_stmt, sql.to_string(), param_types)
1159            .await
1160        {
1161            Ok(()) => {
1162                self.send(BackendMessage::ParseComplete).await?;
1163                Ok(State::Ready)
1164            }
1165            Err(e) => self.error(e.into_response(Severity::Error)).await,
1166        }
1167    }
1168
1169    /// Commits and clears the current transaction.
1170    #[instrument(level = "debug")]
1171    async fn commit_transaction(&mut self) -> Result<(), io::Error> {
1172        self.end_transaction(EndTransactionAction::Commit).await
1173    }
1174
1175    /// Rollback and clears the current transaction.
1176    #[instrument(level = "debug")]
1177    async fn rollback_transaction(&mut self) -> Result<(), io::Error> {
1178        self.end_transaction(EndTransactionAction::Rollback).await
1179    }
1180
1181    /// End a transaction and report to the user if an error occurred.
1182    #[instrument(level = "debug")]
1183    async fn end_transaction(&mut self, action: EndTransactionAction) -> Result<(), io::Error> {
1184        self.txn_needs_commit = false;
1185        let resp = self.adapter_client.end_transaction(action).await;
1186        if let Err(err) = resp {
1187            self.send(BackendMessage::ErrorResponse(
1188                err.into_response(Severity::Error),
1189            ))
1190            .await?;
1191        }
1192        Ok(())
1193    }
1194
1195    #[instrument(level = "debug")]
1196    async fn bind(
1197        &mut self,
1198        portal_name: String,
1199        statement_name: String,
1200        param_formats: Vec<Format>,
1201        raw_params: Vec<Option<Vec<u8>>>,
1202        result_formats: Vec<Format>,
1203    ) -> Result<State, io::Error> {
1204        // Start a transaction if we aren't in one.
1205        self.ensure_transaction(1, "bind").await?;
1206
1207        let aborted_txn = self.is_aborted_txn();
1208        let stmt = match self
1209            .adapter_client
1210            .get_prepared_statement(&statement_name)
1211            .await
1212        {
1213            Ok(stmt) => stmt,
1214            Err(err) => return self.error(err.into_response(Severity::Error)).await,
1215        };
1216
1217        let param_types = &stmt.desc().param_types;
1218        if param_types.len() != raw_params.len() {
1219            let message = format!(
1220                "bind message supplies {actual} parameters, \
1221                 but prepared statement \"{name}\" requires {expected}",
1222                name = statement_name,
1223                actual = raw_params.len(),
1224                expected = param_types.len()
1225            );
1226            return self
1227                .error(ErrorResponse::error(SqlState::PROTOCOL_VIOLATION, message))
1228                .await;
1229        }
1230        let param_formats = match pad_formats(param_formats, raw_params.len()) {
1231            Ok(param_formats) => param_formats,
1232            Err(msg) => {
1233                return self
1234                    .error(ErrorResponse::error(SqlState::PROTOCOL_VIOLATION, msg))
1235                    .await;
1236            }
1237        };
1238        if aborted_txn && !is_txn_exit_stmt(stmt.stmt()) {
1239            return self.aborted_txn_error().await;
1240        }
1241        let buf = RowArena::new();
1242        let mut params = vec![];
1243        for ((raw_param, mz_typ), format) in raw_params
1244            .into_iter()
1245            .zip_eq(param_types)
1246            .zip_eq(param_formats)
1247        {
1248            let pg_typ = mz_pgrepr::Type::from(mz_typ);
1249            let datum = match raw_param {
1250                None => Datum::Null,
1251                Some(bytes) => match mz_pgrepr::Value::decode(format, &pg_typ, &bytes) {
1252                    Ok(param) => param.into_datum(&buf, &pg_typ),
1253                    Err(err) => {
1254                        let msg = format!("unable to decode parameter: {}", err);
1255                        return self
1256                            .error(ErrorResponse::error(SqlState::INVALID_PARAMETER_VALUE, msg))
1257                            .await;
1258                    }
1259                },
1260            };
1261            params.push((datum, mz_typ.clone()))
1262        }
1263
1264        let result_formats = match pad_formats(
1265            result_formats,
1266            stmt.desc()
1267                .relation_desc
1268                .clone()
1269                .map(|desc| desc.typ().column_types.len())
1270                .unwrap_or(0),
1271        ) {
1272            Ok(result_formats) => result_formats,
1273            Err(msg) => {
1274                return self
1275                    .error(ErrorResponse::error(SqlState::PROTOCOL_VIOLATION, msg))
1276                    .await;
1277            }
1278        };
1279
1280        // Binary encodings are disabled for list, map, and aclitem types, but this doesn't
1281        // apply to COPY TO statements.
1282        if !stmt.stmt().map_or(false, |stmt| {
1283            matches!(
1284                stmt,
1285                Statement::Copy(CopyStatement {
1286                    direction: CopyDirection::To,
1287                    ..
1288                })
1289            )
1290        }) {
1291            if let Some(desc) = stmt.desc().relation_desc.clone() {
1292                for (format, ty) in result_formats.iter().zip_eq(desc.iter_types()) {
1293                    match (format, &ty.scalar_type) {
1294                        (Format::Binary, mz_repr::SqlScalarType::List { .. }) => {
1295                            return self
1296                                .error(ErrorResponse::error(
1297                                    SqlState::PROTOCOL_VIOLATION,
1298                                    "binary encoding of list types is not implemented",
1299                                ))
1300                                .await;
1301                        }
1302                        (Format::Binary, mz_repr::SqlScalarType::Map { .. }) => {
1303                            return self
1304                                .error(ErrorResponse::error(
1305                                    SqlState::PROTOCOL_VIOLATION,
1306                                    "binary encoding of map types is not implemented",
1307                                ))
1308                                .await;
1309                        }
1310                        (Format::Binary, mz_repr::SqlScalarType::AclItem) => {
1311                            return self
1312                                .error(ErrorResponse::error(
1313                                    SqlState::PROTOCOL_VIOLATION,
1314                                    "binary encoding of aclitem types does not exist",
1315                                ))
1316                                .await;
1317                        }
1318                        _ => (),
1319                    }
1320                }
1321            }
1322        }
1323
1324        let desc = stmt.desc().clone();
1325        let logging = Arc::clone(stmt.logging());
1326        let stmt_ast = stmt.stmt().cloned();
1327        let state_revision = stmt.state_revision;
1328        if let Err(err) = self.adapter_client.session().set_portal(
1329            portal_name,
1330            desc,
1331            stmt_ast,
1332            logging,
1333            params,
1334            result_formats,
1335            state_revision,
1336        ) {
1337            return self.error(err.into_response(Severity::Error)).await;
1338        }
1339
1340        self.send(BackendMessage::BindComplete).await?;
1341        Ok(State::Ready)
1342    }
1343
1344    fn execute(
1345        &mut self,
1346        portal_name: String,
1347        max_rows: ExecuteCount,
1348        get_response: GetResponse,
1349        fetch_portal_name: Option<String>,
1350        timeout: ExecuteTimeout,
1351        outer_ctx_extra: Option<ExecuteContextExtra>,
1352        received: Option<EpochMillis>,
1353    ) -> BoxFuture<'_, Result<State, io::Error>> {
1354        async move {
1355            let aborted_txn = self.is_aborted_txn();
1356
1357            // Check if the portal has been started and can be continued.
1358            let portal = match self
1359                .adapter_client
1360                .session()
1361                .get_portal_unverified_mut(&portal_name)
1362            {
1363                Some(portal) => portal,
1364                None => {
1365                    let msg = format!("portal {} does not exist", portal_name.quoted());
1366                    if let Some(outer_ctx_extra) = outer_ctx_extra {
1367                        self.adapter_client.retire_execute(
1368                            outer_ctx_extra,
1369                            StatementEndedExecutionReason::Errored { error: msg.clone() },
1370                        );
1371                    }
1372                    return self
1373                        .error(ErrorResponse::error(SqlState::INVALID_CURSOR_NAME, msg))
1374                        .await;
1375                }
1376            };
1377
1378            *portal.lifecycle_timestamps = received.map(LifecycleTimestamps::new);
1379
1380            // In an aborted transaction, reject all commands except COMMIT/ROLLBACK.
1381            let txn_exit_stmt = is_txn_exit_stmt(portal.stmt.as_deref());
1382            if aborted_txn && !txn_exit_stmt {
1383                if let Some(outer_ctx_extra) = outer_ctx_extra {
1384                    self.adapter_client.retire_execute(
1385                        outer_ctx_extra,
1386                        StatementEndedExecutionReason::Errored {
1387                            error: ABORTED_TXN_MSG.to_string(),
1388                        },
1389                    );
1390                }
1391                return self.aborted_txn_error().await;
1392            }
1393
1394            let row_desc = portal.desc.relation_desc.clone();
1395            match portal.state {
1396                PortalState::NotStarted => {
1397                    // Start a transaction if we aren't in one.
1398                    self.ensure_transaction(1, "execute").await?;
1399                    match self
1400                        .adapter_client
1401                        .execute(
1402                            portal_name.clone(),
1403                            self.conn.wait_closed(),
1404                            outer_ctx_extra,
1405                        )
1406                        .await
1407                    {
1408                        Ok((response, execute_started)) => {
1409                            self.send_pending_notices().await?;
1410                            self.send_execute_response(
1411                                response,
1412                                row_desc,
1413                                portal_name,
1414                                max_rows,
1415                                get_response,
1416                                fetch_portal_name,
1417                                timeout,
1418                                execute_started,
1419                            )
1420                            .await
1421                        }
1422                        Err(e) => {
1423                            self.send_pending_notices().await?;
1424                            self.error(e.into_response(Severity::Error)).await
1425                        }
1426                    }
1427                }
1428                PortalState::InProgress(rows) => {
1429                    let rows = rows.take().expect("InProgress rows must be populated");
1430                    let (result, statement_ended_execution_reason) = match self
1431                        .send_rows(
1432                            row_desc.expect("portal missing row desc on resumption"),
1433                            portal_name,
1434                            rows,
1435                            max_rows,
1436                            get_response,
1437                            fetch_portal_name,
1438                            timeout,
1439                        )
1440                        .await
1441                    {
1442                        Err(e) => {
1443                            // This is an error communicating with the connection.
1444                            // We consider that to be a cancelation, rather than a query error.
1445                            (Err(e), StatementEndedExecutionReason::Canceled)
1446                        }
1447                        Ok((ok, SendRowsEndedReason::Canceled)) => {
1448                            (Ok(ok), StatementEndedExecutionReason::Canceled)
1449                        }
1450                        // NOTE: For now the values for `result_size` and
1451                        // `rows_returned` in fetches are a bit confusing.
1452                        // We record `Some(n)` for the first fetch, where `n` is
1453                        // the number of bytes/rows returned by the inner
1454                        // execute (regardless of how many rows the
1455                        // fetch fetched), and `None` for subsequent fetches.
1456                        //
1457                        // This arguably makes sense since the size/rows
1458                        // returned measures how much work the compute
1459                        // layer had to do to satisfy the query, but
1460                        // we should revisit it if/when we start
1461                        // logging the inner execute separately.
1462                        Ok((
1463                            ok,
1464                            SendRowsEndedReason::Success {
1465                                result_size: _,
1466                                rows_returned: _,
1467                            },
1468                        )) => (
1469                            Ok(ok),
1470                            StatementEndedExecutionReason::Success {
1471                                result_size: None,
1472                                rows_returned: None,
1473                                execution_strategy: None,
1474                            },
1475                        ),
1476                        Ok((ok, SendRowsEndedReason::Errored { error })) => {
1477                            (Ok(ok), StatementEndedExecutionReason::Errored { error })
1478                        }
1479                    };
1480                    if let Some(outer_ctx_extra) = outer_ctx_extra {
1481                        self.adapter_client
1482                            .retire_execute(outer_ctx_extra, statement_ended_execution_reason);
1483                    }
1484                    result
1485                }
1486                // FETCH is an awkward command for our current architecture. In Postgres it
1487                // will extract <count> rows from the target portal, cache them, and return
1488                // them to the user as requested. Its command tag is always FETCH <num rows
1489                // extracted>. In Materialize, since we have chosen to not fully support FETCH,
1490                // we must remember the number of rows that were returned. Use this tag to
1491                // remember that information and return it.
1492                PortalState::Completed(Some(tag)) => {
1493                    let tag = tag.to_string();
1494                    if let Some(outer_ctx_extra) = outer_ctx_extra {
1495                        self.adapter_client.retire_execute(
1496                            outer_ctx_extra,
1497                            StatementEndedExecutionReason::Success {
1498                                result_size: None,
1499                                rows_returned: None,
1500                                execution_strategy: None,
1501                            },
1502                        );
1503                    }
1504                    self.send(BackendMessage::CommandComplete { tag }).await?;
1505                    Ok(State::Ready)
1506                }
1507                PortalState::Completed(None) => {
1508                    let error = format!(
1509                        "portal {} cannot be run",
1510                        Ident::new_unchecked(portal_name).to_ast_string_stable()
1511                    );
1512                    if let Some(outer_ctx_extra) = outer_ctx_extra {
1513                        self.adapter_client.retire_execute(
1514                            outer_ctx_extra,
1515                            StatementEndedExecutionReason::Errored {
1516                                error: error.clone(),
1517                            },
1518                        );
1519                    }
1520                    self.error(ErrorResponse::error(
1521                        SqlState::OBJECT_NOT_IN_PREREQUISITE_STATE,
1522                        error,
1523                    ))
1524                    .await
1525                }
1526            }
1527        }
1528        .instrument(debug_span!("execute"))
1529        .boxed()
1530    }
1531
1532    #[instrument(level = "debug")]
1533    async fn describe_statement(&mut self, name: &str) -> Result<State, io::Error> {
1534        // Start a transaction if we aren't in one.
1535        self.ensure_transaction(1, "describe_statement").await?;
1536
1537        let stmt = match self.adapter_client.get_prepared_statement(name).await {
1538            Ok(stmt) => stmt,
1539            Err(err) => return self.error(err.into_response(Severity::Error)).await,
1540        };
1541        // Cloning to avoid a mutable borrow issue because `send` also uses `adapter_client`
1542        let parameter_desc = BackendMessage::ParameterDescription(
1543            stmt.desc()
1544                .param_types
1545                .iter()
1546                .map(mz_pgrepr::Type::from)
1547                .collect(),
1548        );
1549        // Claim that all results will be output in text format, even
1550        // though the true result formats are not yet known. A bit
1551        // weird, but this is the behavior that PostgreSQL specifies.
1552        let formats = vec![Format::Text; stmt.desc().arity()];
1553        let row_desc = describe_rows(stmt.desc(), &formats);
1554        self.send_all([parameter_desc, row_desc]).await?;
1555        Ok(State::Ready)
1556    }
1557
1558    #[instrument(level = "debug")]
1559    async fn describe_portal(&mut self, name: &str) -> Result<State, io::Error> {
1560        // Start a transaction if we aren't in one.
1561        self.ensure_transaction(1, "describe_portal").await?;
1562
1563        let session = self.adapter_client.session();
1564        let row_desc = session
1565            .get_portal_unverified(name)
1566            .map(|portal| describe_rows(&portal.desc, &portal.result_formats));
1567        match row_desc {
1568            Some(row_desc) => {
1569                self.send(row_desc).await?;
1570                Ok(State::Ready)
1571            }
1572            None => {
1573                self.error(ErrorResponse::error(
1574                    SqlState::INVALID_CURSOR_NAME,
1575                    format!("portal {} does not exist", name.quoted()),
1576                ))
1577                .await
1578            }
1579        }
1580    }
1581
1582    #[instrument(level = "debug")]
1583    async fn close_statement(&mut self, name: String) -> Result<State, io::Error> {
1584        self.adapter_client
1585            .session()
1586            .remove_prepared_statement(&name);
1587        self.send(BackendMessage::CloseComplete).await?;
1588        Ok(State::Ready)
1589    }
1590
1591    #[instrument(level = "debug")]
1592    async fn close_portal(&mut self, name: String) -> Result<State, io::Error> {
1593        self.adapter_client.session().remove_portal(&name);
1594        self.send(BackendMessage::CloseComplete).await?;
1595        Ok(State::Ready)
1596    }
1597
1598    fn complete_portal(&mut self, name: &str) {
1599        let portal = self
1600            .adapter_client
1601            .session()
1602            .get_portal_unverified_mut(name)
1603            .expect("portal should exist");
1604        *portal.state = PortalState::Completed(None);
1605    }
1606
1607    async fn fetch(
1608        &mut self,
1609        name: String,
1610        count: Option<FetchDirection>,
1611        max_rows: ExecuteCount,
1612        fetch_portal_name: Option<String>,
1613        timeout: ExecuteTimeout,
1614        ctx_extra: ExecuteContextExtra,
1615    ) -> Result<State, io::Error> {
1616        // Unlike Execute, no count specified in FETCH returns 1 row, and 0 means 0
1617        // instead of All.
1618        let count = count.unwrap_or(FetchDirection::ForwardCount(1));
1619
1620        // Figure out how many rows we should send back by looking at the various
1621        // combinations of the execute and fetch.
1622        //
1623        // In Postgres, Fetch will cache <count> rows from the target portal and
1624        // return those as requested (if, say, an Execute message was sent with a
1625        // max_rows < the Fetch's count). We expect that case to be incredibly rare and
1626        // so have chosen to not support it until users request it. This eases
1627        // implementation difficulty since we don't have to be able to "send" rows to
1628        // a buffer.
1629        //
1630        // TODO(mjibson): Test this somehow? Need to divide up the pgtest files in
1631        // order to have some that are not Postgres compatible.
1632        let count = match (max_rows, count) {
1633            (ExecuteCount::Count(max_rows), FetchDirection::ForwardCount(count)) => {
1634                let count = usize::cast_from(count);
1635                if max_rows < count {
1636                    let msg = "Execute with max_rows < a FETCH's count is not supported";
1637                    self.adapter_client.retire_execute(
1638                        ctx_extra,
1639                        StatementEndedExecutionReason::Errored {
1640                            error: msg.to_string(),
1641                        },
1642                    );
1643                    return self
1644                        .error(ErrorResponse::error(SqlState::FEATURE_NOT_SUPPORTED, msg))
1645                        .await;
1646                }
1647                ExecuteCount::Count(count)
1648            }
1649            (ExecuteCount::Count(_), FetchDirection::ForwardAll) => {
1650                let msg = "Execute with max_rows of a FETCH ALL is not supported";
1651                self.adapter_client.retire_execute(
1652                    ctx_extra,
1653                    StatementEndedExecutionReason::Errored {
1654                        error: msg.to_string(),
1655                    },
1656                );
1657                return self
1658                    .error(ErrorResponse::error(SqlState::FEATURE_NOT_SUPPORTED, msg))
1659                    .await;
1660            }
1661            (ExecuteCount::All, FetchDirection::ForwardAll) => ExecuteCount::All,
1662            (ExecuteCount::All, FetchDirection::ForwardCount(count)) => {
1663                ExecuteCount::Count(usize::cast_from(count))
1664            }
1665        };
1666        let cursor_name = name.to_string();
1667        self.execute(
1668            cursor_name,
1669            count,
1670            fetch_message,
1671            fetch_portal_name,
1672            timeout,
1673            Some(ctx_extra),
1674            None,
1675        )
1676        .await
1677    }
1678
1679    async fn flush(&mut self) -> Result<State, io::Error> {
1680        self.conn.flush().await?;
1681        Ok(State::Ready)
1682    }
1683
1684    /// Sends a backend message to the client, after applying a severity filter.
1685    ///
1686    /// The message is only sent if its severity is above the severity set
1687    /// in the session, with the default value being NOTICE.
1688    #[instrument(level = "debug")]
1689    async fn send<M>(&mut self, message: M) -> Result<(), io::Error>
1690    where
1691        M: Into<BackendMessage>,
1692    {
1693        let message: BackendMessage = message.into();
1694        let is_error =
1695            matches!(&message, BackendMessage::ErrorResponse(e) if e.severity.is_error());
1696
1697        self.conn.send(message).await?;
1698
1699        // Flush immediately after sending an error response, as some clients
1700        // expect to be able to read the error response before sending a Sync
1701        // message. This is arguably in violation of the protocol specification,
1702        // but the specification is somewhat ambiguous, and easier to match
1703        // PostgreSQL here than to fix all the clients that have this
1704        // expectation.
1705        if is_error {
1706            self.conn.flush().await?;
1707        }
1708
1709        Ok(())
1710    }
1711
1712    #[instrument(level = "debug")]
1713    pub async fn send_all(
1714        &mut self,
1715        messages: impl IntoIterator<Item = BackendMessage>,
1716    ) -> Result<(), io::Error> {
1717        for m in messages {
1718            self.send(m).await?;
1719        }
1720        Ok(())
1721    }
1722
1723    #[instrument(level = "debug")]
1724    async fn sync(&mut self) -> Result<State, io::Error> {
1725        // Close the current transaction if we are in an implicit transaction.
1726        if self.adapter_client.session().transaction().is_implicit() {
1727            self.commit_transaction().await?;
1728        }
1729        self.ready().await
1730    }
1731
1732    #[instrument(level = "debug")]
1733    async fn ready(&mut self) -> Result<State, io::Error> {
1734        let txn_state = self.adapter_client.session().transaction().into();
1735        self.send(BackendMessage::ReadyForQuery(txn_state)).await?;
1736        self.flush().await
1737    }
1738
1739    #[allow(clippy::too_many_arguments)]
1740    #[instrument(level = "debug")]
1741    async fn send_execute_response(
1742        &mut self,
1743        response: ExecuteResponse,
1744        row_desc: Option<RelationDesc>,
1745        portal_name: String,
1746        max_rows: ExecuteCount,
1747        get_response: GetResponse,
1748        fetch_portal_name: Option<String>,
1749        timeout: ExecuteTimeout,
1750        execute_started: Instant,
1751    ) -> Result<State, io::Error> {
1752        let mut tag = response.tag();
1753
1754        macro_rules! command_complete {
1755            () => {{
1756                self.send(BackendMessage::CommandComplete {
1757                    tag: tag
1758                        .take()
1759                        .expect("command_complete only called on tag-generating results"),
1760                })
1761                .await?;
1762                Ok(State::Ready)
1763            }};
1764        }
1765
1766        let r = match response {
1767            ExecuteResponse::ClosedCursor => {
1768                self.complete_portal(&portal_name);
1769                command_complete!()
1770            }
1771            ExecuteResponse::DeclaredCursor => {
1772                self.complete_portal(&portal_name);
1773                command_complete!()
1774            }
1775            ExecuteResponse::EmptyQuery => {
1776                self.send(BackendMessage::EmptyQueryResponse).await?;
1777                Ok(State::Ready)
1778            }
1779            ExecuteResponse::Fetch {
1780                name,
1781                count,
1782                timeout,
1783                ctx_extra,
1784            } => {
1785                self.fetch(
1786                    name,
1787                    count,
1788                    max_rows,
1789                    Some(portal_name.to_string()),
1790                    timeout,
1791                    ctx_extra,
1792                )
1793                .await
1794            }
1795            ExecuteResponse::SendingRowsStreaming {
1796                rows,
1797                instance_id,
1798                strategy,
1799            } => {
1800                let row_desc = row_desc
1801                    .expect("missing row description for ExecuteResponse::SendingRowsStreaming");
1802
1803                let span = tracing::debug_span!("sending_rows_streaming");
1804
1805                self.send_rows(
1806                    row_desc,
1807                    portal_name,
1808                    InProgressRows::new(RecordFirstRowStream::new(
1809                        Box::new(rows),
1810                        execute_started,
1811                        &self.adapter_client,
1812                        Some(instance_id),
1813                        Some(strategy),
1814                    )),
1815                    max_rows,
1816                    get_response,
1817                    fetch_portal_name,
1818                    timeout,
1819                )
1820                .instrument(span)
1821                .await
1822                .map(|(state, _)| state)
1823            }
1824            ExecuteResponse::SendingRowsImmediate { rows } => {
1825                let row_desc = row_desc
1826                    .expect("missing row description for ExecuteResponse::SendingRowsImmediate");
1827
1828                let span = tracing::debug_span!("sending_rows_immediate");
1829
1830                let stream =
1831                    futures::stream::once(futures::future::ready(PeekResponseUnary::Rows(rows)));
1832                self.send_rows(
1833                    row_desc,
1834                    portal_name,
1835                    InProgressRows::new(RecordFirstRowStream::new(
1836                        Box::new(stream),
1837                        execute_started,
1838                        &self.adapter_client,
1839                        None,
1840                        Some(StatementExecutionStrategy::Constant),
1841                    )),
1842                    max_rows,
1843                    get_response,
1844                    fetch_portal_name,
1845                    timeout,
1846                )
1847                .instrument(span)
1848                .await
1849                .map(|(state, _)| state)
1850            }
1851            ExecuteResponse::SetVariable { name, .. } => {
1852                // This code is somewhat awkwardly structured because we
1853                // can't hold `var` across an await point.
1854                let qn = name.to_string();
1855                let msg = if let Some(var) = self
1856                    .adapter_client
1857                    .session()
1858                    .vars_mut()
1859                    .notify_set()
1860                    .find(|v| v.name() == qn)
1861                {
1862                    Some(BackendMessage::ParameterStatus(var.name(), var.value()))
1863                } else {
1864                    None
1865                };
1866                if let Some(msg) = msg {
1867                    self.send(msg).await?;
1868                }
1869                command_complete!()
1870            }
1871            ExecuteResponse::Subscribing {
1872                rx,
1873                ctx_extra,
1874                instance_id,
1875            } => {
1876                if fetch_portal_name.is_none() {
1877                    let mut msg = ErrorResponse::notice(
1878                        SqlState::WARNING,
1879                        "streaming SUBSCRIBE rows directly requires a client that does not buffer output",
1880                    );
1881                    if self.adapter_client.session().vars().application_name() == "psql" {
1882                        msg.hint = Some(
1883                            "Wrap your SUBSCRIBE statement in `COPY (SUBSCRIBE ...) TO STDOUT`."
1884                                .into(),
1885                        )
1886                    }
1887                    self.send(msg).await?;
1888                    self.conn.flush().await?;
1889                }
1890                let row_desc =
1891                    row_desc.expect("missing row description for ExecuteResponse::Subscribing");
1892                let (result, statement_ended_execution_reason) = match self
1893                    .send_rows(
1894                        row_desc,
1895                        portal_name,
1896                        InProgressRows::new(RecordFirstRowStream::new(
1897                            Box::new(UnboundedReceiverStream::new(rx)),
1898                            execute_started,
1899                            &self.adapter_client,
1900                            Some(instance_id),
1901                            None,
1902                        )),
1903                        max_rows,
1904                        get_response,
1905                        fetch_portal_name,
1906                        timeout,
1907                    )
1908                    .await
1909                {
1910                    Err(e) => {
1911                        // This is an error communicating with the connection.
1912                        // We consider that to be a cancelation, rather than a query error.
1913                        (Err(e), StatementEndedExecutionReason::Canceled)
1914                    }
1915                    Ok((ok, SendRowsEndedReason::Canceled)) => {
1916                        (Ok(ok), StatementEndedExecutionReason::Canceled)
1917                    }
1918                    Ok((
1919                        ok,
1920                        SendRowsEndedReason::Success {
1921                            result_size,
1922                            rows_returned,
1923                        },
1924                    )) => (
1925                        Ok(ok),
1926                        StatementEndedExecutionReason::Success {
1927                            result_size: Some(result_size),
1928                            rows_returned: Some(rows_returned),
1929                            execution_strategy: None,
1930                        },
1931                    ),
1932                    Ok((ok, SendRowsEndedReason::Errored { error })) => {
1933                        (Ok(ok), StatementEndedExecutionReason::Errored { error })
1934                    }
1935                };
1936                self.adapter_client
1937                    .retire_execute(ctx_extra, statement_ended_execution_reason);
1938                return result;
1939            }
1940            ExecuteResponse::CopyTo { format, resp } => {
1941                let row_desc =
1942                    row_desc.expect("missing row description for ExecuteResponse::CopyTo");
1943                match *resp {
1944                    ExecuteResponse::Subscribing {
1945                        rx,
1946                        ctx_extra,
1947                        instance_id,
1948                    } => {
1949                        let (result, statement_ended_execution_reason) = match self
1950                            .copy_rows(
1951                                format,
1952                                row_desc,
1953                                RecordFirstRowStream::new(
1954                                    Box::new(UnboundedReceiverStream::new(rx)),
1955                                    execute_started,
1956                                    &self.adapter_client,
1957                                    Some(instance_id),
1958                                    None,
1959                                ),
1960                            )
1961                            .await
1962                        {
1963                            Err(e) => {
1964                                // This is an error communicating with the connection.
1965                                // We consider that to be a cancelation, rather than a query error.
1966                                (Err(e), StatementEndedExecutionReason::Canceled)
1967                            }
1968                            Ok((
1969                                state,
1970                                SendRowsEndedReason::Success {
1971                                    result_size,
1972                                    rows_returned,
1973                                },
1974                            )) => (
1975                                Ok(state),
1976                                StatementEndedExecutionReason::Success {
1977                                    result_size: Some(result_size),
1978                                    rows_returned: Some(rows_returned),
1979                                    execution_strategy: None,
1980                                },
1981                            ),
1982                            Ok((state, SendRowsEndedReason::Errored { error })) => {
1983                                (Ok(state), StatementEndedExecutionReason::Errored { error })
1984                            }
1985                            Ok((state, SendRowsEndedReason::Canceled)) => {
1986                                (Ok(state), StatementEndedExecutionReason::Canceled)
1987                            }
1988                        };
1989                        self.adapter_client
1990                            .retire_execute(ctx_extra, statement_ended_execution_reason);
1991                        return result;
1992                    }
1993                    ExecuteResponse::SendingRowsStreaming {
1994                        rows,
1995                        instance_id,
1996                        strategy,
1997                    } => {
1998                        // We don't need to finalize execution here;
1999                        // it was already done in the
2000                        // coordinator. Just extract the state and
2001                        // return that.
2002                        return self
2003                            .copy_rows(
2004                                format,
2005                                row_desc,
2006                                RecordFirstRowStream::new(
2007                                    Box::new(rows),
2008                                    execute_started,
2009                                    &self.adapter_client,
2010                                    Some(instance_id),
2011                                    Some(strategy),
2012                                ),
2013                            )
2014                            .await
2015                            .map(|(state, _)| state);
2016                    }
2017                    ExecuteResponse::SendingRowsImmediate { rows } => {
2018                        let span = tracing::debug_span!("sending_rows_immediate");
2019
2020                        let rows = futures::stream::once(futures::future::ready(
2021                            PeekResponseUnary::Rows(rows),
2022                        ));
2023                        // We don't need to finalize execution here;
2024                        // it was already done in the
2025                        // coordinator. Just extract the state and
2026                        // return that.
2027                        return self
2028                            .copy_rows(
2029                                format,
2030                                row_desc,
2031                                RecordFirstRowStream::new(
2032                                    Box::new(rows),
2033                                    execute_started,
2034                                    &self.adapter_client,
2035                                    None,
2036                                    Some(StatementExecutionStrategy::Constant),
2037                                ),
2038                            )
2039                            .instrument(span)
2040                            .await
2041                            .map(|(state, _)| state);
2042                    }
2043                    _ => {
2044                        return self
2045                            .error(ErrorResponse::error(
2046                                SqlState::INTERNAL_ERROR,
2047                                "unsupported COPY response type".to_string(),
2048                            ))
2049                            .await;
2050                    }
2051                };
2052            }
2053            ExecuteResponse::CopyFrom {
2054                target_id,
2055                target_name,
2056                columns,
2057                params,
2058                ctx_extra,
2059            } => {
2060                let row_desc =
2061                    row_desc.expect("missing row description for ExecuteResponse::CopyFrom");
2062                self.copy_from(target_id, target_name, columns, params, row_desc, ctx_extra)
2063                    .await
2064            }
2065            ExecuteResponse::TransactionCommitted { params }
2066            | ExecuteResponse::TransactionRolledBack { params } => {
2067                let notify_set: mz_ore::collections::HashSet<String> = self
2068                    .adapter_client
2069                    .session()
2070                    .vars()
2071                    .notify_set()
2072                    .map(|v| v.name().to_string())
2073                    .collect();
2074
2075                // Only report on parameters that are in the notify set.
2076                for (name, value) in params
2077                    .into_iter()
2078                    .filter(|(name, _v)| notify_set.contains(*name))
2079                {
2080                    let msg = BackendMessage::ParameterStatus(name, value);
2081                    self.send(msg).await?;
2082                }
2083                command_complete!()
2084            }
2085
2086            ExecuteResponse::AlteredDefaultPrivileges
2087            | ExecuteResponse::AlteredObject(..)
2088            | ExecuteResponse::AlteredRole
2089            | ExecuteResponse::AlteredSystemConfiguration
2090            | ExecuteResponse::CreatedCluster { .. }
2091            | ExecuteResponse::CreatedClusterReplica { .. }
2092            | ExecuteResponse::CreatedConnection { .. }
2093            | ExecuteResponse::CreatedDatabase { .. }
2094            | ExecuteResponse::CreatedIndex { .. }
2095            | ExecuteResponse::CreatedIntrospectionSubscribe
2096            | ExecuteResponse::CreatedMaterializedView { .. }
2097            | ExecuteResponse::CreatedContinualTask { .. }
2098            | ExecuteResponse::CreatedRole
2099            | ExecuteResponse::CreatedSchema { .. }
2100            | ExecuteResponse::CreatedSecret { .. }
2101            | ExecuteResponse::CreatedSink { .. }
2102            | ExecuteResponse::CreatedSource { .. }
2103            | ExecuteResponse::CreatedTable { .. }
2104            | ExecuteResponse::CreatedType
2105            | ExecuteResponse::CreatedView { .. }
2106            | ExecuteResponse::CreatedViews { .. }
2107            | ExecuteResponse::CreatedNetworkPolicy
2108            | ExecuteResponse::Comment
2109            | ExecuteResponse::Deallocate { .. }
2110            | ExecuteResponse::Deleted(..)
2111            | ExecuteResponse::DiscardedAll
2112            | ExecuteResponse::DiscardedTemp
2113            | ExecuteResponse::DroppedObject(_)
2114            | ExecuteResponse::DroppedOwned
2115            | ExecuteResponse::GrantedPrivilege
2116            | ExecuteResponse::GrantedRole
2117            | ExecuteResponse::Inserted(..)
2118            | ExecuteResponse::Copied(..)
2119            | ExecuteResponse::Prepare
2120            | ExecuteResponse::Raised
2121            | ExecuteResponse::ReassignOwned
2122            | ExecuteResponse::RevokedPrivilege
2123            | ExecuteResponse::RevokedRole
2124            | ExecuteResponse::StartedTransaction { .. }
2125            | ExecuteResponse::Updated(..)
2126            | ExecuteResponse::ValidatedConnection => {
2127                command_complete!()
2128            }
2129        };
2130
2131        assert_none!(tag, "tag created but not consumed: {:?}", tag);
2132        r
2133    }
2134
2135    #[allow(clippy::too_many_arguments)]
2136    // TODO(guswynn): figure out how to get it to compile without skip_all
2137    #[mz_ore::instrument(level = "debug")]
2138    async fn send_rows(
2139        &mut self,
2140        row_desc: RelationDesc,
2141        portal_name: String,
2142        mut rows: InProgressRows,
2143        max_rows: ExecuteCount,
2144        get_response: GetResponse,
2145        fetch_portal_name: Option<String>,
2146        timeout: ExecuteTimeout,
2147    ) -> Result<(State, SendRowsEndedReason), io::Error> {
2148        // If this portal is being executed from a FETCH then we need to use the result
2149        // format type of the outer portal.
2150        let result_format_portal_name: &str = if let Some(ref name) = fetch_portal_name {
2151            name
2152        } else {
2153            &portal_name
2154        };
2155        let result_formats = self
2156            .adapter_client
2157            .session()
2158            .get_portal_unverified(result_format_portal_name)
2159            .expect("valid fetch portal name for send rows")
2160            .result_formats
2161            .clone();
2162
2163        let (mut wait_once, mut deadline) = match timeout {
2164            ExecuteTimeout::None => (false, None),
2165            ExecuteTimeout::Seconds(t) => (
2166                false,
2167                Some(tokio::time::Instant::now() + tokio::time::Duration::from_secs_f64(t)),
2168            ),
2169            ExecuteTimeout::WaitOnce => (true, None),
2170        };
2171
2172        // Sanity check that the various `RelationDesc`s match up.
2173        {
2174            let portal_name_desc = &self
2175                .adapter_client
2176                .session()
2177                .get_portal_unverified(portal_name.as_str())
2178                .expect("portal should exist")
2179                .desc
2180                .relation_desc;
2181            if let Some(portal_name_desc) = portal_name_desc {
2182                soft_assert_eq_or_log!(portal_name_desc, &row_desc);
2183            }
2184            if let Some(fetch_portal_name) = &fetch_portal_name {
2185                let fetch_portal_desc = &self
2186                    .adapter_client
2187                    .session()
2188                    .get_portal_unverified(fetch_portal_name)
2189                    .expect("portal should exist")
2190                    .desc
2191                    .relation_desc;
2192                if let Some(fetch_portal_desc) = fetch_portal_desc {
2193                    soft_assert_eq_or_log!(fetch_portal_desc, &row_desc);
2194                }
2195            }
2196        }
2197
2198        self.conn.set_encode_state(
2199            row_desc
2200                .typ()
2201                .column_types
2202                .iter()
2203                .map(|ty| mz_pgrepr::Type::from(&ty.scalar_type))
2204                .zip_eq(result_formats)
2205                .collect(),
2206        );
2207
2208        let mut total_sent_rows = 0;
2209        let mut total_sent_bytes = 0;
2210        // want_rows is the maximum number of rows the client wants.
2211        let mut want_rows = match max_rows {
2212            ExecuteCount::All => usize::MAX,
2213            ExecuteCount::Count(count) => count,
2214        };
2215
2216        // Send rows while the client still wants them and there are still rows to send.
2217        loop {
2218            // Fetch next batch of rows, waiting for a possible requested
2219            // timeout or notice.
2220            let batch = if rows.current.is_some() {
2221                FetchResult::Rows(rows.current.take())
2222            } else if want_rows == 0 {
2223                FetchResult::Rows(None)
2224            } else {
2225                let notice_fut = self.adapter_client.session().recv_notice();
2226                tokio::select! {
2227                    err = self.conn.wait_closed() => return Err(err),
2228                    _ = time::sleep_until(deadline.unwrap_or_else(tokio::time::Instant::now)), if deadline.is_some() => FetchResult::Rows(None),
2229                    notice = notice_fut => {
2230                        FetchResult::Notice(notice)
2231                    }
2232                    batch = rows.remaining.recv() => match batch {
2233                        None => FetchResult::Rows(None),
2234                        Some(PeekResponseUnary::Rows(rows)) => FetchResult::Rows(Some(rows)),
2235                        Some(PeekResponseUnary::Error(err)) => FetchResult::Error(err),
2236                        Some(PeekResponseUnary::Canceled) => FetchResult::Canceled,
2237                    },
2238                }
2239            };
2240
2241            match batch {
2242                FetchResult::Rows(None) => break,
2243                FetchResult::Rows(Some(mut batch_rows)) => {
2244                    if let Err(err) = verify_datum_desc(&row_desc, &mut batch_rows) {
2245                        let msg = err.to_string();
2246                        return self
2247                            .error(err.into_response(Severity::Error))
2248                            .await
2249                            .map(|state| (state, SendRowsEndedReason::Errored { error: msg }));
2250                    }
2251
2252                    // If wait_once is true: the first time this fn is called it blocks (same as
2253                    // deadline == None). The second time this fn is called it should behave the
2254                    // same a 0s timeout.
2255                    if wait_once && batch_rows.peek().is_some() {
2256                        deadline = Some(tokio::time::Instant::now());
2257                        wait_once = false;
2258                    }
2259
2260                    // Send a portion of the rows.
2261                    let mut sent_rows = 0;
2262                    let mut sent_bytes = 0;
2263                    let messages = (&mut batch_rows)
2264                        // TODO(parkmycar): This is a fair bit of juggling between iterator types
2265                        // to count the total number of bytes. Alternatively we could track the
2266                        // total sent bytes in this .map(...) call, but having side effects in map
2267                        // is a code smell.
2268                        .map(|row| {
2269                            let row_len = row.byte_len();
2270                            let values = mz_pgrepr::values_from_row(row, row_desc.typ());
2271                            (row_len, BackendMessage::DataRow(values))
2272                        })
2273                        .inspect(|(row_len, _)| {
2274                            sent_bytes += row_len;
2275                            sent_rows += 1
2276                        })
2277                        .map(|(_row_len, row)| row)
2278                        .take(want_rows);
2279                    self.send_all(messages).await?;
2280
2281                    total_sent_rows += sent_rows;
2282                    total_sent_bytes += sent_bytes;
2283                    want_rows -= sent_rows;
2284
2285                    // If we have sent the number of requested rows, put the remainder of the batch
2286                    // (if any) back and stop sending.
2287                    if want_rows == 0 {
2288                        if batch_rows.peek().is_some() {
2289                            rows.current = Some(batch_rows);
2290                        }
2291                        break;
2292                    }
2293
2294                    self.conn.flush().await?;
2295                }
2296                FetchResult::Notice(notice) => {
2297                    self.send(notice.into_response()).await?;
2298                    self.conn.flush().await?;
2299                }
2300                FetchResult::Error(text) => {
2301                    return self
2302                        .error(ErrorResponse::error(SqlState::INTERNAL_ERROR, text.clone()))
2303                        .await
2304                        .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
2305                }
2306                FetchResult::Canceled => {
2307                    return self
2308                        .error(ErrorResponse::error(
2309                            SqlState::QUERY_CANCELED,
2310                            "canceling statement due to user request",
2311                        ))
2312                        .await
2313                        .map(|state| (state, SendRowsEndedReason::Canceled));
2314                }
2315            }
2316        }
2317
2318        let portal = self
2319            .adapter_client
2320            .session()
2321            .get_portal_unverified_mut(&portal_name)
2322            .expect("valid portal name for send rows");
2323
2324        let saw_rows = rows.remaining.saw_rows;
2325        let no_more_rows = rows.no_more_rows();
2326        let recorded_first_row_instant = rows.remaining.recorded_first_row_instant;
2327
2328        // Always return rows back, even if it's empty. This prevents an unclosed
2329        // portal from re-executing after it has been emptied.
2330        *portal.state = PortalState::InProgress(Some(rows));
2331
2332        let fetch_portal = fetch_portal_name.map(|name| {
2333            self.adapter_client
2334                .session()
2335                .get_portal_unverified_mut(&name)
2336                .expect("valid fetch portal")
2337        });
2338        let response_message = get_response(max_rows, total_sent_rows, fetch_portal);
2339        self.send(response_message).await?;
2340
2341        // Attend to metrics if there are no more rows.
2342        if no_more_rows {
2343            let statement_type = if let Some(stmt) = &self
2344                .adapter_client
2345                .session()
2346                .get_portal_unverified(&portal_name)
2347                .expect("valid portal name for send_rows")
2348                .stmt
2349            {
2350                metrics::statement_type_label_value(stmt.deref())
2351            } else {
2352                "no-statement"
2353            };
2354            let duration = if saw_rows {
2355                recorded_first_row_instant
2356                    .expect("recorded_first_row_instant because saw_rows")
2357                    .elapsed()
2358            } else {
2359                // If the result is empty, then we define time from first to last row as 0.
2360                // (Note that, currently, an empty result involves a PeekResponse with 0 rows, which
2361                // does flip `saw_rows`, so this code path is currently not exercised.)
2362                Duration::ZERO
2363            };
2364            self.adapter_client
2365                .inner()
2366                .metrics()
2367                .result_rows_first_to_last_byte_seconds
2368                .with_label_values(&[statement_type])
2369                .observe(duration.as_secs_f64());
2370        }
2371
2372        Ok((
2373            State::Ready,
2374            SendRowsEndedReason::Success {
2375                result_size: u64::cast_from(total_sent_bytes),
2376                rows_returned: u64::cast_from(total_sent_rows),
2377            },
2378        ))
2379    }
2380
2381    #[mz_ore::instrument(level = "debug")]
2382    async fn copy_rows(
2383        &mut self,
2384        format: CopyFormat,
2385        row_desc: RelationDesc,
2386        mut stream: RecordFirstRowStream,
2387    ) -> Result<(State, SendRowsEndedReason), io::Error> {
2388        let (row_format, encode_format) = match format {
2389            CopyFormat::Text => (
2390                CopyFormatParams::Text(CopyTextFormatParams::default()),
2391                Format::Text,
2392            ),
2393            CopyFormat::Binary => (CopyFormatParams::Binary, Format::Binary),
2394            CopyFormat::Csv => (
2395                CopyFormatParams::Csv(CopyCsvFormatParams::default()),
2396                Format::Text,
2397            ),
2398            CopyFormat::Parquet => {
2399                let text = "Parquet format is not supported".to_string();
2400                return self
2401                    .error(ErrorResponse::error(SqlState::INTERNAL_ERROR, text.clone()))
2402                    .await
2403                    .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
2404            }
2405        };
2406
2407        let encode_fn = |row: &RowRef, typ: &SqlRelationType, out: &mut Vec<u8>| {
2408            mz_pgcopy::encode_copy_format(&row_format, row, typ, out)
2409        };
2410
2411        let typ = row_desc.typ();
2412        let column_formats = iter::repeat(encode_format)
2413            .take(typ.column_types.len())
2414            .collect();
2415        self.send(BackendMessage::CopyOutResponse {
2416            overall_format: encode_format,
2417            column_formats,
2418        })
2419        .await?;
2420
2421        // In Postgres, binary copy has a header that is followed (in the same
2422        // CopyData) by the first row. In order to replicate their behavior, use a
2423        // common vec that we can extend one time now and then fill up with the encode
2424        // functions.
2425        let mut out = Vec::new();
2426
2427        if let CopyFormat::Binary = format {
2428            // 11-byte signature.
2429            out.extend(b"PGCOPY\n\xFF\r\n\0");
2430            // 32-bit flags field.
2431            out.extend([0, 0, 0, 0]);
2432            // 32-bit header extension length field.
2433            out.extend([0, 0, 0, 0]);
2434        }
2435
2436        let mut count = 0;
2437        let mut total_sent_bytes = 0;
2438        loop {
2439            tokio::select! {
2440                e = self.conn.wait_closed() => return Err(e),
2441                batch = stream.recv() => match batch {
2442                    None => break,
2443                    Some(PeekResponseUnary::Error(text)) => {
2444                        return self
2445                            .error(ErrorResponse::error(SqlState::INTERNAL_ERROR, text.clone()))
2446                        .await
2447                        .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
2448                    }
2449                    Some(PeekResponseUnary::Canceled) => {
2450                        return self.error(ErrorResponse::error(
2451                                SqlState::QUERY_CANCELED,
2452                                "canceling statement due to user request",
2453                            ))
2454                            .await.map(|state| (state, SendRowsEndedReason::Canceled));
2455                    }
2456                    Some(PeekResponseUnary::Rows(mut rows)) => {
2457                        count += rows.count();
2458                        while let Some(row) = rows.next() {
2459                            total_sent_bytes += row.byte_len();
2460                            encode_fn(row, typ, &mut out)?;
2461                            self.send(BackendMessage::CopyData(mem::take(&mut out)))
2462                                .await?;
2463                        }
2464                    }
2465                },
2466                notice = self.adapter_client.session().recv_notice() => {
2467                    self.send(notice.into_response())
2468                        .await?;
2469                    self.conn.flush().await?;
2470                }
2471            }
2472
2473            self.conn.flush().await?;
2474        }
2475        // Send required trailers.
2476        if let CopyFormat::Binary = format {
2477            let trailer: i16 = -1;
2478            out.extend(trailer.to_be_bytes());
2479            self.send(BackendMessage::CopyData(mem::take(&mut out)))
2480                .await?;
2481        }
2482
2483        let tag = format!("COPY {}", count);
2484        self.send(BackendMessage::CopyDone).await?;
2485        self.send(BackendMessage::CommandComplete { tag }).await?;
2486        Ok((
2487            State::Ready,
2488            SendRowsEndedReason::Success {
2489                result_size: u64::cast_from(total_sent_bytes),
2490                rows_returned: u64::cast_from(count),
2491            },
2492        ))
2493    }
2494
2495    /// Handles the copy-in mode of the postgres protocol from transferring
2496    /// data to the server.
2497    #[instrument(level = "debug")]
2498    async fn copy_from(
2499        &mut self,
2500        target_id: CatalogItemId,
2501        target_name: String,
2502        columns: Vec<ColumnIndex>,
2503        params: CopyFormatParams<'_>,
2504        row_desc: RelationDesc,
2505        mut ctx_extra: ExecuteContextExtra,
2506    ) -> Result<State, io::Error> {
2507        let res = self
2508            .copy_from_inner(
2509                target_id,
2510                target_name,
2511                columns,
2512                params,
2513                row_desc,
2514                &mut ctx_extra,
2515            )
2516            .await;
2517        match &res {
2518            Ok(State::Done) => {
2519                // The connection closed gracefully without sending us a `CopyDone`,
2520                // causing us to just drop the copy request.
2521                // For the purposes of statement logging, we count this as a cancellation.
2522                self.adapter_client
2523                    .retire_execute(ctx_extra, StatementEndedExecutionReason::Canceled);
2524            }
2525            Err(e) => {
2526                self.adapter_client.retire_execute(
2527                    ctx_extra,
2528                    StatementEndedExecutionReason::Errored {
2529                        error: format!("{e}"),
2530                    },
2531                );
2532            }
2533            other => {
2534                tracing::warn!(?other, "aborting COPY FROM");
2535                self.adapter_client
2536                    .retire_execute(ctx_extra, StatementEndedExecutionReason::Aborted);
2537            }
2538        }
2539        res
2540    }
2541
2542    async fn copy_from_inner(
2543        &mut self,
2544        target_id: CatalogItemId,
2545        target_name: String,
2546        columns: Vec<ColumnIndex>,
2547        params: CopyFormatParams<'_>,
2548        row_desc: RelationDesc,
2549        ctx_extra: &mut ExecuteContextExtra,
2550    ) -> Result<State, io::Error> {
2551        let typ = row_desc.typ();
2552        let column_formats = vec![Format::Text; typ.column_types.len()];
2553        self.send(BackendMessage::CopyInResponse {
2554            overall_format: Format::Text,
2555            column_formats,
2556        })
2557        .await?;
2558        self.conn.flush().await?;
2559
2560        let system_vars = self.adapter_client.get_system_vars().await;
2561        let max_size = system_vars
2562            .get(MAX_COPY_FROM_SIZE.name())
2563            .ok()
2564            .and_then(|max_size| max_size.value().parse().ok())
2565            .unwrap_or(usize::MAX);
2566        tracing::debug!("COPY FROM max buffer size: {max_size} bytes");
2567
2568        let mut data = Vec::new();
2569        loop {
2570            let message = self.conn.recv().await?;
2571            match message {
2572                Some(FrontendMessage::CopyData(buf)) => {
2573                    // Bail before we OOM.
2574                    if (data.len() + buf.len()) > max_size {
2575                        return self
2576                            .error(ErrorResponse::error(
2577                                SqlState::INSUFFICIENT_RESOURCES,
2578                                "COPY FROM STDIN too large",
2579                            ))
2580                            .await;
2581                    }
2582                    data.extend(buf)
2583                }
2584                Some(FrontendMessage::CopyDone) => break,
2585                Some(FrontendMessage::CopyFail(err)) => {
2586                    self.adapter_client.retire_execute(
2587                        std::mem::take(ctx_extra),
2588                        StatementEndedExecutionReason::Canceled,
2589                    );
2590                    return self
2591                        .error(ErrorResponse::error(
2592                            SqlState::QUERY_CANCELED,
2593                            format!("COPY from stdin failed: {}", err),
2594                        ))
2595                        .await;
2596                }
2597                Some(FrontendMessage::Flush) | Some(FrontendMessage::Sync) => {}
2598                Some(_) => {
2599                    let msg = "unexpected message type during COPY from stdin";
2600                    self.adapter_client.retire_execute(
2601                        std::mem::take(ctx_extra),
2602                        StatementEndedExecutionReason::Errored {
2603                            error: msg.to_string(),
2604                        },
2605                    );
2606                    return self
2607                        .error(ErrorResponse::error(SqlState::PROTOCOL_VIOLATION, msg))
2608                        .await;
2609                }
2610                None => {
2611                    return Ok(State::Done);
2612                }
2613            }
2614        }
2615
2616        let column_types = typ
2617            .column_types
2618            .iter()
2619            .map(|x| &x.scalar_type)
2620            .map(mz_pgrepr::Type::from)
2621            .collect::<Vec<mz_pgrepr::Type>>();
2622
2623        let rows = match mz_pgcopy::decode_copy_format(&data, &column_types, params) {
2624            Ok(rows) => rows,
2625            Err(e) => {
2626                self.adapter_client.retire_execute(
2627                    std::mem::take(ctx_extra),
2628                    StatementEndedExecutionReason::Errored {
2629                        error: e.to_string(),
2630                    },
2631                );
2632                return self
2633                    .error(ErrorResponse::error(
2634                        SqlState::BAD_COPY_FILE_FORMAT,
2635                        format!("{}", e),
2636                    ))
2637                    .await;
2638            }
2639        };
2640
2641        let count = rows.len();
2642
2643        if let Err(e) = self
2644            .adapter_client
2645            .insert_rows(
2646                target_id,
2647                target_name,
2648                columns,
2649                rows,
2650                std::mem::take(ctx_extra),
2651            )
2652            .await
2653        {
2654            self.adapter_client.retire_execute(
2655                std::mem::take(ctx_extra),
2656                StatementEndedExecutionReason::Errored {
2657                    error: e.to_string(),
2658                },
2659            );
2660            return self.error(e.into_response(Severity::Error)).await;
2661        }
2662
2663        let tag = format!("COPY {}", count);
2664        self.send(BackendMessage::CommandComplete { tag }).await?;
2665
2666        Ok(State::Ready)
2667    }
2668
2669    #[instrument(level = "debug")]
2670    async fn send_pending_notices(&mut self) -> Result<(), io::Error> {
2671        let notices = self
2672            .adapter_client
2673            .session()
2674            .drain_notices()
2675            .into_iter()
2676            .map(|notice| BackendMessage::ErrorResponse(notice.into_response()));
2677        self.send_all(notices).await?;
2678        Ok(())
2679    }
2680
2681    #[instrument(level = "debug")]
2682    async fn error(&mut self, err: ErrorResponse) -> Result<State, io::Error> {
2683        assert!(err.severity.is_error());
2684        debug!(
2685            "cid={} error code={}",
2686            self.adapter_client.session().conn_id(),
2687            err.code.code()
2688        );
2689        let is_fatal = err.severity.is_fatal();
2690        self.send(BackendMessage::ErrorResponse(err)).await?;
2691
2692        let txn = self.adapter_client.session().transaction();
2693        match txn {
2694            // Error can be called from describe and parse and so might not be in an active
2695            // transaction.
2696            TransactionStatus::Default | TransactionStatus::Failed(_) => {}
2697            // In Started (i.e., a single statement), cleanup ourselves.
2698            TransactionStatus::Started(_) => {
2699                self.rollback_transaction().await?;
2700            }
2701            // Implicit transactions also clear themselves.
2702            TransactionStatus::InTransactionImplicit(_) => {
2703                self.rollback_transaction().await?;
2704            }
2705            // Explicit transactions move to failed.
2706            TransactionStatus::InTransaction(_) => {
2707                self.adapter_client.fail_transaction();
2708            }
2709        };
2710        if is_fatal {
2711            Ok(State::Done)
2712        } else {
2713            Ok(State::Drain)
2714        }
2715    }
2716
2717    #[instrument(level = "debug")]
2718    async fn aborted_txn_error(&mut self) -> Result<State, io::Error> {
2719        self.send(BackendMessage::ErrorResponse(ErrorResponse::error(
2720            SqlState::IN_FAILED_SQL_TRANSACTION,
2721            ABORTED_TXN_MSG,
2722        )))
2723        .await?;
2724        Ok(State::Drain)
2725    }
2726
2727    fn is_aborted_txn(&mut self) -> bool {
2728        matches!(
2729            self.adapter_client.session().transaction(),
2730            TransactionStatus::Failed(_)
2731        )
2732    }
2733}
2734
2735fn pad_formats(formats: Vec<Format>, n: usize) -> Result<Vec<Format>, String> {
2736    match (formats.len(), n) {
2737        (0, e) => Ok(vec![Format::Text; e]),
2738        (1, e) => Ok(iter::repeat(formats[0]).take(e).collect()),
2739        (a, e) if a == e => Ok(formats),
2740        (a, e) => Err(format!(
2741            "expected {} field format specifiers, but got {}",
2742            e, a
2743        )),
2744    }
2745}
2746
2747fn describe_rows(stmt_desc: &StatementDesc, formats: &[Format]) -> BackendMessage {
2748    match &stmt_desc.relation_desc {
2749        Some(desc) if !stmt_desc.is_copy => {
2750            BackendMessage::RowDescription(message::encode_row_description(desc, formats))
2751        }
2752        _ => BackendMessage::NoData,
2753    }
2754}
2755
2756type GetResponse = fn(
2757    max_rows: ExecuteCount,
2758    total_sent_rows: usize,
2759    fetch_portal: Option<PortalRefMut>,
2760) -> BackendMessage;
2761
2762// A GetResponse used by send_rows during execute messages on portals or for
2763// simple query messages.
2764fn portal_exec_message(
2765    max_rows: ExecuteCount,
2766    total_sent_rows: usize,
2767    _fetch_portal: Option<PortalRefMut>,
2768) -> BackendMessage {
2769    // If max_rows is not specified, we will always send back a CommandComplete. If
2770    // max_rows is specified, we only send CommandComplete if there were more rows
2771    // requested than were remaining. That is, if max_rows == number of rows that
2772    // were remaining before sending (not that are remaining after sending), then
2773    // we still send a PortalSuspended. The number of remaining rows after the rows
2774    // have been sent doesn't matter. This matches postgres.
2775    match max_rows {
2776        ExecuteCount::Count(max_rows) if max_rows <= total_sent_rows => {
2777            BackendMessage::PortalSuspended
2778        }
2779        _ => BackendMessage::CommandComplete {
2780            tag: format!("SELECT {}", total_sent_rows),
2781        },
2782    }
2783}
2784
2785// A GetResponse used by send_rows during FETCH queries.
2786fn fetch_message(
2787    _max_rows: ExecuteCount,
2788    total_sent_rows: usize,
2789    fetch_portal: Option<PortalRefMut>,
2790) -> BackendMessage {
2791    let tag = format!("FETCH {}", total_sent_rows);
2792    if let Some(portal) = fetch_portal {
2793        *portal.state = PortalState::Completed(Some(tag.clone()));
2794    }
2795    BackendMessage::CommandComplete { tag }
2796}
2797
/// How many rows a client asked to receive.
#[derive(Clone, Copy, Debug)]
enum ExecuteCount {
    /// No limit: send every row.
    All,
    /// Send at most this many rows.
    Count(usize),
}
2803
2804// See postgres' backend/tcop/postgres.c IsTransactionExitStmt.
2805fn is_txn_exit_stmt(stmt: Option<&Statement<Raw>>) -> bool {
2806    match stmt {
2807        // Add PREPARE to this if we ever support it.
2808        Some(stmt) => matches!(stmt, Statement::Commit(_) | Statement::Rollback(_)),
2809        None => false,
2810    }
2811}
2812
2813#[derive(Debug)]
2814enum FetchResult {
2815    Rows(Option<Box<dyn RowIterator + Send + Sync>>),
2816    Canceled,
2817    Error(String),
2818    Notice(AdapterNotice),
2819}
2820
#[cfg(test)]
mod test {
    use super::*;

    /// Table-driven check of `parse_options`: each case pairs an options
    /// string with the key/value pairs it must produce, or `Err(())`.
    #[mz_ore::test]
    fn test_parse_options() {
        type Expected = Result<Vec<(&'static str, &'static str)>, ()>;

        let cases: Vec<(&'static str, Expected)> = vec![
            ("", Ok(vec![])),
            ("--key", Err(())),
            ("--key=val", Ok(vec![("key", "val")])),
            (
                r#"--key=val -ckey2=val2 -c key3=val3 -c key4=val4 -ckey5=val5"#,
                Ok(vec![
                    ("key", "val"),
                    ("key2", "val2"),
                    ("key3", "val3"),
                    ("key4", "val4"),
                    ("key5", "val5"),
                ]),
            ),
            (r#"-c\ key=val"#, Ok(vec![(" key", "val")])),
            ("--key=val -ckey2 val2", Err(())),
            // Unclear what this should do.
            ("--key=", Ok(vec![("key", "")])),
        ];

        for (input, expected) in cases {
            let got = parse_options(input);
            // Convert the borrowed expectation into owned strings for comparison.
            let expect = expected.map(|pairs| {
                pairs
                    .into_iter()
                    .map(|(key, value)| (key.to_owned(), value.to_owned()))
                    .collect()
            });
            assert_eq!(got, expect, "input: {}", input);
        }
    }

    /// Table-driven check of `parse_option`: each case pairs a single option
    /// token with the key/value pair it must produce, or `Err(())`.
    #[mz_ore::test]
    fn test_parse_option() {
        type Expected = Result<(&'static str, &'static str), ()>;

        let cases: Vec<(&'static str, Expected)> = vec![
            ("", Err(())),
            ("--", Err(())),
            ("--c", Err(())),
            ("a=b", Err(())),
            ("--a=b", Ok(("a", "b"))),
            ("--ca=b", Ok(("ca", "b"))),
            ("-ca=b", Ok(("a", "b"))),
            // Unclear whether this should error, but at least test it.
            ("--=", Ok(("", ""))),
        ];

        for (input, expect) in cases {
            let got = parse_option(input);
            assert_eq!(got, expect, "input: {}", input);
        }
    }

    /// Table-driven check of `split_options`: each case pairs an options
    /// string with the list of tokens it must be split into.
    #[mz_ore::test]
    fn test_split_options() {
        let cases: Vec<(&'static str, Vec<&'static str>)> = vec![
            ("", vec![]),
            ("  ", vec![]),
            (" a ", vec!["a"]),
            ("  ab     cd   ", vec!["ab", "cd"]),
            (r#"  ab\     cd   "#, vec!["ab ", "cd"]),
            (r#"  ab\\     cd   "#, vec![r#"ab\"#, "cd"]),
            (r#"  ab\\\     cd   "#, vec![r#"ab\ "#, "cd"]),
            (r#"  ab\\\ cd   "#, vec![r#"ab\ cd"#]),
            (r#"  ab\\\cd   "#, vec![r#"ab\cd"#]),
            (r#"a\"#, vec!["a"]),
            (r#"a\ "#, vec!["a "]),
            (r#"\"#, vec![]),
            (r#"\ "#, vec![r#" "#]),
            (r#" \ "#, vec![r#" "#]),
            (r#"\  "#, vec![r#" "#]),
        ];

        for (input, expect) in cases {
            let got = split_options(input);
            assert_eq!(got, expect, "input: {}", input);
        }
    }
}