mz_pgwire/
protocol.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::collections::BTreeMap;
11use std::convert::TryFrom;
12use std::future::Future;
13use std::sync::Arc;
14use std::time::Instant;
15use std::{iter, mem};
16
17use byteorder::{ByteOrder, NetworkEndian};
18use futures::future::{BoxFuture, FutureExt, pending};
19use itertools::izip;
20use mz_adapter::client::RecordFirstRowStream;
21use mz_adapter::session::{
22    EndTransactionAction, InProgressRows, Portal, PortalState, SessionConfig, TransactionStatus,
23};
24use mz_adapter::statement_logging::{StatementEndedExecutionReason, StatementExecutionStrategy};
25use mz_adapter::{
26    AdapterError, AdapterNotice, ExecuteContextExtra, ExecuteResponse, PeekResponseUnary,
27    RowsFuture, verify_datum_desc,
28};
29use mz_auth::password::Password;
30use mz_frontegg_auth::Authenticator as FronteggAuthentication;
31use mz_ore::cast::CastFrom;
32use mz_ore::netio::AsyncReady;
33use mz_ore::str::StrExt;
34use mz_ore::{assert_none, assert_ok, instrument};
35use mz_pgcopy::{CopyCsvFormatParams, CopyFormatParams, CopyTextFormatParams};
36use mz_pgwire_common::{
37    ConnectionCounter, ErrorResponse, Format, FrontendMessage, Severity, VERSION_3, VERSIONS,
38};
39use mz_repr::user::InternalUserMetadata;
40use mz_repr::{
41    CatalogItemId, ColumnIndex, Datum, RelationDesc, RelationType, RowArena, RowIterator, RowRef,
42    ScalarType,
43};
44use mz_server_core::TlsMode;
45use mz_sql::ast::display::AstDisplay;
46use mz_sql::ast::{CopyDirection, CopyStatement, FetchDirection, Ident, Raw, Statement};
47use mz_sql::parse::StatementParseResult;
48use mz_sql::plan::{CopyFormat, ExecuteTimeout, StatementDesc};
49use mz_sql::session::metadata::SessionMetadata;
50use mz_sql::session::user::INTERNAL_USER_NAMES;
51use mz_sql::session::vars::{MAX_COPY_FROM_SIZE, Var, VarInput};
52use postgres::error::SqlState;
53use tokio::io::{self, AsyncRead, AsyncWrite};
54use tokio::select;
55use tokio::sync::mpsc::UnboundedReceiver;
56use tokio::time::{self};
57use tokio_stream::wrappers::UnboundedReceiverStream;
58use tracing::{Instrument, debug, debug_span, warn};
59use uuid::Uuid;
60
61use crate::codec::FramedConn;
62use crate::message::{self, BackendMessage};
63
64/// Reports whether the given stream begins with a pgwire handshake.
65///
66/// To avoid false negatives, there must be at least eight bytes in `buf`.
67pub fn match_handshake(buf: &[u8]) -> bool {
68    // The pgwire StartupMessage looks like this:
69    //
70    //     i32 - Length of entire message.
71    //     i32 - Protocol version number.
72    //     [String] - Arbitrary key-value parameters of any length.
73    //
74    // Since arbitrary parameters can be included in the StartupMessage, the
75    // first Int32 is worthless, since the message could have any length.
76    // Instead, we sniff the protocol version number.
77    if buf.len() < 8 {
78        return false;
79    }
80    let version = NetworkEndian::read_i32(&buf[4..8]);
81    VERSIONS.contains(&version)
82}
83
/// Parameters for the [`run`] function.
///
/// The connection and the optional Frontegg authenticator are borrowed for
/// `'a`. All other fields are owned values that `run` consumes.
pub struct RunParams<'a, A> {
    /// The TLS mode of the pgwire server. `run` rejects the connection if it
    /// is not compatible with this mode.
    pub tls_mode: Option<TlsMode>,
    /// A client for the adapter. It is used to authenticate (self-hosted auth)
    /// and to create and start up the session.
    pub adapter_client: mz_adapter::Client,
    /// The connection to the client.
    pub conn: &'a mut FramedConn<A>,
    /// The universally unique identifier for the connection, recorded in the
    /// new session's config.
    pub conn_uuid: Uuid,
    /// The protocol version that the client provided in the startup message.
    /// Only `VERSION_3` is accepted.
    pub version: i32,
    /// The parameters that the client provided in the startup message.
    /// `user` selects the login user. `options` is parsed into individual
    /// settings. Every other entry is applied as a session variable.
    pub params: BTreeMap<String, String>,
    /// Frontegg authentication. When present, the client must send a password
    /// that Frontegg accepts, and the connection ends when the resulting auth
    /// session expires.
    pub frontegg: Option<&'a FronteggAuthentication>,
    /// Whether to use self hosted auth. When set and `frontegg` is absent, the
    /// client's password is checked through the adapter.
    pub use_self_hosted_auth: bool,
    /// Whether this is an internal server that permits access to restricted
    /// system resources. An internal server only admits internal users. A
    /// non-internal server refuses reserved role names.
    pub internal: bool,
    /// Global connection limit and count. A slot is allocated for the
    /// session's user and held for the remainder of `run`.
    pub active_connection_counter: ConnectionCounter,
    /// Helm chart version, recorded in the new session's config.
    pub helm_chart_version: Option<String>,
}
110
/// Runs a pgwire connection to completion.
///
/// This involves responding to `FrontendMessage::StartupMessage` and all future
/// requests until the client terminates the connection or a fatal error occurs.
///
/// Startup proceeds in this order:
///
/// 1. Reject unsupported protocol versions.
/// 2. Check that the requested user may log in through this server (internal
///    vs. external).
/// 3. Check TLS compatibility.
/// 4. Authenticate via Frontegg, via a self-hosted password, or not at all.
/// 5. Apply the startup parameters as session variables.
/// 6. Allocate a connection slot and register the session with the adapter.
/// 7. Send the startup messages, ending with `ReadyForQuery`.
///
/// After that, the [`StateMachine`] serves requests until it finishes or, with
/// Frontegg, the authentication session expires.
///
/// Note that this function returns successfully even upon delivering a fatal
/// error to the client. It only returns `Err` if an unexpected I/O error occurs
/// while communicating with the client, e.g., if the connection is severed in
/// the middle of a request.
#[mz_ore::instrument(level = "debug")]
pub async fn run<'a, A>(
    RunParams {
        tls_mode,
        adapter_client,
        conn,
        conn_uuid,
        version,
        mut params,
        frontegg,
        use_self_hosted_auth,
        internal,
        active_connection_counter,
        helm_chart_version,
    }: RunParams<'a, A>,
) -> Result<(), io::Error>
where
    A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
{
    // Only protocol version 3 is accepted. Anything else is rejected with a
    // fatal error before any authentication happens.
    if version != VERSION_3 {
        return conn
            .send(ErrorResponse::fatal(
                SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
                "server does not support the client's requested protocol version",
            ))
            .await;
    }

    // A missing `user` parameter becomes the empty user name. Removing it also
    // keeps it out of the session-variable loop below.
    let user = params.remove("user").unwrap_or_else(String::new);

    if internal {
        // The internal server can only be used to connect to the internal users.
        if !INTERNAL_USER_NAMES.contains(&user) {
            let msg = format!("unauthorized login to user '{user}'");
            return conn
                .send(ErrorResponse::fatal(SqlState::INSUFFICIENT_PRIVILEGE, msg))
                .await;
        }
    } else {
        // The external server cannot be used to connect to any system users.
        if mz_adapter::catalog::is_reserved_role_name(user.as_str()) {
            let msg = format!("unauthorized login to user '{user}'");
            return conn
                .send(ErrorResponse::fatal(SqlState::INSUFFICIENT_PRIVILEGE, msg))
                .await;
        }
    }

    if let Err(err) = conn.inner().ensure_tls_compatibility(&tls_mode) {
        return conn.send(err).await;
    }

    // Authenticate the user and create the session. Alongside the session we
    // build `expired`, a future that resolves when authentication expires.
    // With Frontegg it tracks the auth session. Otherwise it is `pending()`
    // and never resolves. `left_future`/`right_future` give both arms one type.
    let (mut session, expired) = if let Some(frontegg) = frontegg {
        conn.send(BackendMessage::AuthenticationCleartextPassword)
            .await?;
        conn.flush().await?;
        let password = match conn.recv().await? {
            Some(FrontendMessage::Password { password }) => password,
            _ => {
                return conn
                    .send(ErrorResponse::fatal(
                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                        "expected Password message",
                    ))
                    .await;
            }
        };

        let auth_response = frontegg.authenticate(&user, &password).await;
        match auth_response {
            Ok(mut auth_session) => {
                // Create a session based on the auth session.
                //
                // In particular, it's important that the username come from the
                // auth session, as Frontegg may return an email address with
                // different casing than the user supplied via the pgwire
                // username field. We want to use the Frontegg casing as
                // canonical.
                let session = adapter_client.new_session(SessionConfig {
                    conn_id: conn.conn_id().clone(),
                    uuid: conn_uuid,
                    user: auth_session.user().into(),
                    client_ip: conn.peer_addr().clone(),
                    external_metadata_rx: Some(auth_session.external_metadata_rx()),
                    internal_user_metadata: None,
                    helm_chart_version,
                });
                let expired = async move { auth_session.expired().await };
                (session, expired.left_future())
            }
            Err(err) => {
                warn!(?err, "pgwire connection failed authentication");
                return conn
                    .send(ErrorResponse::fatal(
                        SqlState::INVALID_PASSWORD,
                        "invalid password",
                    ))
                    .await;
            }
        }
    } else if use_self_hosted_auth {
        conn.send(BackendMessage::AuthenticationCleartextPassword)
            .await?;
        conn.flush().await?;
        let password = match conn.recv().await? {
            Some(FrontendMessage::Password { password }) => Password(password),
            _ => {
                return conn
                    .send(ErrorResponse::fatal(
                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                        "expected Password message",
                    ))
                    .await;
            }
        };
        let auth_response = match adapter_client.authenticate(&user, &password).await {
            Ok(resp) => resp,
            Err(err) => {
                warn!(?err, "pgwire connection failed authentication");
                return conn
                    .send(ErrorResponse::fatal(
                        SqlState::INVALID_PASSWORD,
                        "invalid password",
                    ))
                    .await;
            }
        };
        // The superuser flag reported by the adapter's password check is
        // carried into the session as internal user metadata.
        let session = adapter_client.new_session(SessionConfig {
            conn_id: conn.conn_id().clone(),
            uuid: conn_uuid,
            user,
            client_ip: conn.peer_addr().clone(),
            external_metadata_rx: None,
            internal_user_metadata: Some(InternalUserMetadata {
                superuser: auth_response.superuser,
            }),
            helm_chart_version,
        });
        // No frontegg check, so auth session lasts indefinitely.
        let auth_session = pending().right_future();
        (session, auth_session)
    } else {
        let session = adapter_client.new_session(SessionConfig {
            conn_id: conn.conn_id().clone(),
            uuid: conn_uuid,
            user,
            client_ip: conn.peer_addr().clone(),
            external_metadata_rx: None,
            internal_user_metadata: None,
            helm_chart_version,
        });
        // No frontegg check, so auth session lasts indefinitely.
        let auth_session = pending().right_future();
        (session, auth_session)
    };

    // Apply the remaining startup parameters as session variables. The
    // `options` parameter may carry several settings at once. Problems become
    // notices on the session rather than failing the connection.
    let system_vars = adapter_client.get_system_vars().await;
    for (name, value) in params {
        let settings = match name.as_str() {
            "options" => match parse_options(&value) {
                Ok(opts) => opts,
                Err(()) => {
                    session.add_notice(AdapterNotice::BadStartupSetting {
                        name,
                        reason: "could not parse".into(),
                    });
                    continue;
                }
            },
            _ => vec![(name, value)],
        };
        for (key, val) in settings {
            const LOCAL: bool = false;
            // TODO: Issuing an error here is better than what we did before
            // (silently ignore errors on set), but erroring the connection
            // might be the better behavior. We maybe need to support more
            // options sent by psql and drivers before we can safely do this.
            if let Err(err) =
                session
                    .vars_mut()
                    .set(&system_vars, &key, VarInput::Flat(&val), LOCAL)
            {
                session.add_notice(AdapterNotice::BadStartupSetting {
                    name: key,
                    reason: err.to_string(),
                });
            }
        }
    }
    // Commit the variable settings applied above.
    session
        .vars_mut()
        .end_transaction(EndTransactionAction::Commit);

    // Reserve a connection slot for this user. `_guard` is held until this
    // function returns.
    let _guard = match active_connection_counter.allocate_connection(session.user()) {
        Ok(drop_connection) => drop_connection,
        Err(e) => {
            let e: AdapterError = e.into();
            return conn.send(e.into_response(Severity::Fatal)).await;
        }
    };

    // Register session with adapter.
    let mut adapter_client = match adapter_client.startup(session).await {
        Ok(adapter_client) => adapter_client,
        Err(e) => return conn.send(e.into_response(Severity::Fatal)).await,
    };

    // Send the startup sequence in one batch:
    // - AuthenticationOk,
    // - a ParameterStatus for each variable in the session's notify set,
    // - BackendKeyData,
    // - any notices accumulated so far (e.g. bad startup settings),
    // - ReadyForQuery.
    let mut buf = vec![BackendMessage::AuthenticationOk];
    for var in adapter_client.session().vars().notify_set() {
        buf.push(BackendMessage::ParameterStatus(var.name(), var.value()));
    }
    buf.push(BackendMessage::BackendKeyData {
        conn_id: adapter_client.session().conn_id().unhandled(),
        secret_key: adapter_client.session().secret_key(),
    });
    buf.extend(
        adapter_client
            .session()
            .drain_notices()
            .into_iter()
            .map(|notice| BackendMessage::ErrorResponse(notice.into_response())),
    );
    buf.push(BackendMessage::ReadyForQuery(
        adapter_client.session().transaction().into(),
    ));
    conn.send_all(buf).await?;
    conn.flush().await?;

    let machine = StateMachine {
        conn,
        adapter_client,
        txn_needs_commit: false,
    };

    // Serve requests until the state machine finishes or authentication
    // expires, whichever happens first.
    select! {
        r = machine.run() => {
            // Errors produced internally (like MAX_REQUEST_SIZE being exceeded) should send an
            // error to the client informing them why the connection was closed. We still want to
            // return the original error up the stack, though, so we skip error checking during conn
            // operations.
            if let Err(err) = &r {
                let _ = conn
                    .send(ErrorResponse::fatal(
                        SqlState::CONNECTION_FAILURE,
                        err.to_string(),
                    ))
                    .await;
                let _ = conn.flush().await;
            }
            r
        },
        _ = expired => {
            conn
                .send(ErrorResponse::fatal(SqlState::INVALID_AUTHORIZATION_SPECIFICATION, "authentication expired"))
                .await?;
            conn.flush().await
        }
    }
}
379
380/// Returns (name, value) session settings pairs from an options value.
381///
382/// From Postgres, see pg_split_opts in postinit.c and process_postgres_switches
383/// in postgres.c.
384fn parse_options(value: &str) -> Result<Vec<(String, String)>, ()> {
385    let opts = split_options(value);
386    let mut pairs = Vec::with_capacity(opts.len());
387    let mut seen_prefix = false;
388    for opt in opts {
389        if !seen_prefix {
390            if opt == "-c" {
391                seen_prefix = true;
392            } else {
393                let (key, val) = parse_option(&opt)?;
394                pairs.push((key.to_owned(), val.to_owned()));
395            }
396        } else {
397            let (key, val) = opt.split_once('=').ok_or(())?;
398            pairs.push((key.to_owned(), val.to_owned()));
399            seen_prefix = false;
400        }
401    }
402    Ok(pairs)
403}
404
/// Parses a single option word of the form `--key=value` or `-ckey=value`,
/// returning `(key, value)` with the prefix stripped from the key.
///
/// The value is everything after the first `=`, so it may itself contain `=`.
/// The key is returned exactly as written after its prefix; this function does
/// not translate `-` into `_`.
///
/// The two-word form `-c key=value` is not handled here. A bare `-c` word has
/// no `=` and is rejected, so callers must recognize it themselves.
///
/// Returns `Err(())` if `option` contains no `=` or its key has neither the
/// `-c` nor the `--` prefix.
fn parse_option(option: &str) -> Result<(&str, &str), ()> {
    let (key, value) = option.split_once('=').ok_or(())?;
    for prefix in &["-c", "--"] {
        if let Some(key) = key.strip_prefix(prefix) {
            return Ok((key, value));
        }
    }
    Err(())
}
417
/// Splits `value` into words separated by runs of spaces.
///
/// A backslash escapes the character after it, so `\ ` puts a literal space
/// into the current word and `\\` a literal backslash. Any other escaped
/// character is kept without the backslash. A trailing unpaired backslash is
/// dropped, and empty words are never produced. Only the space character
/// separates words.
fn split_options(value: &str) -> Vec<String> {
    let mut words = Vec::new();
    // Escapes mean a word can't simply be a subslice of `value`, so each word
    // is accumulated into an owned buffer. This is not hot enough to be worth
    // avoiding the allocation for unescaped words.
    let mut word = String::new();
    let mut escaped = false;
    for ch in value.chars() {
        match (escaped, ch) {
            // Whatever follows a backslash is taken literally.
            (true, _) => {
                word.push(ch);
                escaped = false;
            }
            (false, '\\') => escaped = true,
            (false, ' ') => {
                // Consecutive spaces would otherwise yield empty words.
                if !word.is_empty() {
                    words.push(std::mem::take(&mut word));
                }
            }
            (false, _) => word.push(ch),
        }
    }
    if !word.is_empty() {
        words.push(word);
    }
    words
}
460
/// The state of the [`StateMachine`] between frontend messages.
#[derive(Debug)]
enum State {
    /// Waiting for the next frontend message and handling it normally.
    Ready,
    /// Discarding frontend messages until the next `Sync`, which is handled
    /// normally.
    Drain,
    /// The connection is finished; the run loop returns.
    Done,
}
467
/// Serves pgwire requests for a session after startup has completed.
struct StateMachine<'a, A> {
    /// The framed connection to the client.
    conn: &'a mut FramedConn<A>,
    /// The adapter client through which the session is accessed and
    /// statements are executed.
    adapter_client: mz_adapter::SessionClient,
    /// Set after an extended-protocol `Execute` in an implicit transaction.
    /// The transaction is committed at the next `ensure_transaction` call
    /// rather than immediately, so the portal can still be paged through.
    txn_needs_commit: bool,
}
473
/// How sending a result set's rows to the client ended.
///
/// NOTE(review): this is constructed and consumed outside this view. The
/// variant descriptions below are inferred from their names; confirm them.
enum SendRowsEndedReason {
    /// The rows were sent successfully.
    Success {
        /// Size of the result. NOTE(review): units (presumably bytes) are not
        /// visible here; confirm.
        result_size: u64,
        /// Number of rows returned to the client.
        rows_returned: u64,
    },
    /// Sending stopped because of an error, described by `error`.
    Errored {
        error: String,
    },
    /// Sending stopped because it was canceled.
    Canceled,
}
484
/// Error message for commands issued while the current transaction is aborted
/// (only transaction-exit statements are accepted in that state).
const ABORTED_TXN_MSG: &str =
    "current transaction is aborted, commands ignored until end of transaction block";
487
488impl<'a, A> StateMachine<'a, A>
489where
490    A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin + 'a,
491{
492    // Manually desugar this (don't use `async fn run`) here because a much better
493    // error message is produced if there are problems with Send or other traits
494    // somewhere within the Future.
495    #[allow(clippy::manual_async_fn)]
496    #[mz_ore::instrument(level = "debug")]
497    fn run(mut self) -> impl Future<Output = Result<(), io::Error>> + Send + 'a {
498        async move {
499            let mut state = State::Ready;
500            loop {
501                self.send_pending_notices().await?;
502                state = match state {
503                    State::Ready => self.advance_ready().await?,
504                    State::Drain => self.advance_drain().await?,
505                    State::Done => return Ok(()),
506                };
507                self.adapter_client
508                    .add_idle_in_transaction_session_timeout();
509            }
510        }
511    }
512
    /// Receives and handles one frontend message while in [`State::Ready`],
    /// returning the state to move to next.
    ///
    /// A pending session timeout takes priority over reading a message. In that
    /// case the session is errored with a fatal response and terminated, and one
    /// more client message is awaited before the resulting state is returned.
    #[instrument(level = "debug")]
    async fn advance_ready(&mut self) -> Result<State, io::Error> {
        // Handle timeouts first so we don't execute any statements when there's a pending timeout.
        let message = select! {
            biased;

            // `recv_timeout()` is cancel-safe as per its docs.
            Some(timeout) = self.adapter_client.recv_timeout() => {
                let err: AdapterError = timeout.into();
                let conn_id = self.adapter_client.session().conn_id();
                tracing::warn!("session timed out, conn_id {}", conn_id);

                // Process the error, doing any state cleanup.
                let error_response = err.into_response(Severity::Fatal);
                let error_state = self.error(error_response).await;

                // Terminate __after__ we do any cleanup.
                self.adapter_client.terminate().await;

                // We must wait for the client to send a request before we can send the error response.
                // Due to the PG wire protocol, we can't send an ErrorResponse unless it is in response
                // to a client message.
                let _ = self.conn.recv().await?;
                return error_state;
            },
            // `recv()` is cancel-safe as per its docs.
            message = self.conn.recv() => message?,
        };

        // A message (or end of input) arrived, so clear the idle-in-transaction
        // timeout.
        self.adapter_client
            .remove_idle_in_transaction_session_timeout();

        // NOTE(guswynn): we could consider adding spans to all message types. Currently
        // only a few message types seem useful.
        let message_name = message.as_ref().map(|m| m.name()).unwrap_or_default();

        let next_state = match message {
            Some(FrontendMessage::Query { sql }) => {
                let query_root_span =
                    tracing::info_span!(parent: None, "advance_ready", otel.name = message_name);
                query_root_span.follows_from(tracing::Span::current());
                self.query(sql).instrument(query_root_span).await?
            }
            Some(FrontendMessage::Parse {
                name,
                sql,
                param_types,
            }) => self.parse(name, sql, param_types).await?,
            Some(FrontendMessage::Bind {
                portal_name,
                statement_name,
                param_formats,
                raw_params,
                result_formats,
            }) => {
                self.bind(
                    portal_name,
                    statement_name,
                    param_formats,
                    raw_params,
                    result_formats,
                )
                .await?
            }
            Some(FrontendMessage::Execute {
                portal_name,
                max_rows,
            }) => {
                let max_rows = match usize::try_from(max_rows) {
                    Ok(0) | Err(_) => ExecuteCount::All, // If `max_rows < 0`, no limit.
                    Ok(n) => ExecuteCount::Count(n),
                };
                let execute_root_span =
                    tracing::info_span!(parent: None, "advance_ready", otel.name = message_name);
                execute_root_span.follows_from(tracing::Span::current());
                let state = self
                    .execute(
                        portal_name,
                        max_rows,
                        portal_exec_message,
                        None,
                        ExecuteTimeout::None,
                        None,
                    )
                    .instrument(execute_root_span)
                    .await?;
                // In PostgreSQL, when using the extended query protocol, some statements may
                // trigger an eager commit of the current implicit transaction,
                // see: <https://git.postgresql.org/gitweb/?p=postgresql.git&a=commitdiff&h=f92944137>.
                //
                // In Materialize, however, we eagerly commit every statement outside of an explicit
                // transaction when using the extended query protocol. This allows us to eliminate
                // the possibility of a multiple statement implicit transaction, which in turn
                // allows us to apply single-statement optimizations to queries issued in implicit
                // transactions in the extended query protocol.
                //
                // We don't immediately commit here to allow users to page through the portal if
                // necessary. Committing the transaction would destroy the portal before the next
                // Execute command has a chance to resume it. So we instead mark the transaction
                // for commit the next time that `ensure_transaction` is called.
                if self.adapter_client.session().transaction().is_implicit() {
                    self.txn_needs_commit = true;
                }
                state
            }
            Some(FrontendMessage::DescribeStatement { name }) => {
                self.describe_statement(&name).await?
            }
            Some(FrontendMessage::DescribePortal { name }) => self.describe_portal(&name).await?,
            Some(FrontendMessage::CloseStatement { name }) => self.close_statement(name).await?,
            Some(FrontendMessage::ClosePortal { name }) => self.close_portal(name).await?,
            Some(FrontendMessage::Flush) => self.flush().await?,
            Some(FrontendMessage::Sync) => self.sync().await?,
            Some(FrontendMessage::Terminate) => State::Done,

            // These messages are not handled in the ready state. They switch the
            // machine to discarding input until the next `Sync`.
            Some(FrontendMessage::CopyData(_))
            | Some(FrontendMessage::CopyDone)
            | Some(FrontendMessage::CopyFail(_))
            | Some(FrontendMessage::Password { .. }) => State::Drain,
            // End of input from the client finishes the connection.
            None => State::Done,
        };

        Ok(next_state)
    }
637
638    async fn advance_drain(&mut self) -> Result<State, io::Error> {
639        let message = self.conn.recv().await?;
640        if message.is_some() {
641            self.adapter_client
642                .remove_idle_in_transaction_session_timeout();
643        }
644        match message {
645            Some(FrontendMessage::Sync) => self.sync().await,
646            None => Ok(State::Done),
647            _ => Ok(State::Drain),
648        }
649    }
650
    /// Executes one statement of a simple-protocol `Query` through the unnamed
    /// portal.
    ///
    /// When the statement returns rows and is not a `COPY`, a text-format
    /// `RowDescription` is sent first. The unnamed portal is removed after
    /// execution. Statements with parameters are rejected, since a simple
    /// query cannot supply parameter values.
    #[instrument(level = "debug")]
    async fn one_query(&mut self, stmt: Statement<Raw>, sql: String) -> Result<State, io::Error> {
        // Bind the portal. Note that this does not set the empty string prepared
        // statement.
        const EMPTY_PORTAL: &str = "";
        if let Err(e) = self
            .adapter_client
            .declare(EMPTY_PORTAL.to_string(), stmt, sql)
            .await
        {
            return self.error(e.into_response(Severity::Error)).await;
        }

        let stmt_desc = self
            .adapter_client
            .session()
            .get_portal_unverified(EMPTY_PORTAL)
            .map(|portal| portal.desc.clone())
            .expect("unnamed portal should be present");
        // NOTE(review): this early return skips the `remove_portal` call at the
        // end, leaving the unnamed portal in place; confirm that is intended.
        if !stmt_desc.param_types.is_empty() {
            return self
                .error(ErrorResponse::error(
                    SqlState::UNDEFINED_PARAMETER,
                    "there is no parameter $1",
                ))
                .await;
        }

        // Maybe send row description.
        if let Some(relation_desc) = &stmt_desc.relation_desc {
            if !stmt_desc.is_copy {
                // Simple-query results are always sent in text format.
                let formats = vec![Format::Text; stmt_desc.arity()];
                self.send(BackendMessage::RowDescription(
                    message::encode_row_description(relation_desc, &formats),
                ))
                .await?;
            }
        }

        let result = match self
            .adapter_client
            .execute(EMPTY_PORTAL.to_string(), self.conn.wait_closed(), None)
            .await
        {
            Ok((response, execute_started)) => {
                self.send_pending_notices().await?;
                self.send_execute_response(
                    response,
                    stmt_desc.relation_desc,
                    EMPTY_PORTAL.to_string(),
                    ExecuteCount::All,
                    portal_exec_message,
                    None,
                    ExecuteTimeout::None,
                    execute_started,
                )
                .await
            }
            Err(e) => {
                self.send_pending_notices().await?;
                self.error(e.into_response(Severity::Error)).await
            }
        };

        // Destroy the portal.
        self.adapter_client.session().remove_portal(EMPTY_PORTAL);

        result
    }
720
721    async fn ensure_transaction(&mut self, num_stmts: usize) -> Result<(), io::Error> {
722        if self.txn_needs_commit {
723            self.commit_transaction().await?;
724        }
725        // start_transaction can't error (but assert that just in case it changes in
726        // the future.
727        let res = self.adapter_client.start_transaction(Some(num_stmts));
728        assert_ok!(res);
729        Ok(())
730    }
731
732    fn parse_sql<'b>(&self, sql: &'b str) -> Result<Vec<StatementParseResult<'b>>, ErrorResponse> {
733        match self.adapter_client.parse(sql) {
734            Ok(result) => result.map_err(|e| {
735                // Convert our 0-based byte position to pgwire's 1-based character
736                // position.
737                let pos = sql[..e.error.pos].chars().count() + 1;
738                ErrorResponse::error(SqlState::SYNTAX_ERROR, e.error.message).with_position(pos)
739            }),
740            Err(msg) => Err(ErrorResponse::error(SqlState::PROGRAM_LIMIT_EXCEEDED, msg)),
741        }
742    }
743
    // See "Multiple Statements in a Simple Query" which documents how implicit
    // transactions are handled.
    // From https://www.postgresql.org/docs/current/protocol-flow.html
    /// Handles a simple-protocol `Query` message.
    ///
    /// `sql` is parsed and each resulting statement is run in order. Any
    /// implicit transaction is then committed, and the method finishes by
    /// calling `ready()`. A query with no statements produces an
    /// `EmptyQueryResponse`.
    #[instrument(level = "debug")]
    async fn query(&mut self, sql: String) -> Result<State, io::Error> {
        // Parse first before doing any transaction checking.
        let stmts = match self.parse_sql(&sql) {
            Ok(stmts) => stmts,
            Err(err) => {
                self.error(err).await?;
                return self.ready().await;
            }
        };

        let num_stmts = stmts.len();

        // Compare with postgres' backend/tcop/postgres.c exec_simple_query.
        // Note: the loop's `sql` shadows the full query string; presumably it is
        // the text of just this statement.
        for StatementParseResult { ast: stmt, sql } in stmts {
            // In an aborted transaction, reject all commands except COMMIT/ROLLBACK.
            if self.is_aborted_txn() && !is_txn_exit_stmt(Some(&stmt)) {
                self.aborted_txn_error().await?;
                break;
            }

            // Start an implicit transaction if we aren't in any transaction and there's
            // more than one statement. This mirrors the `use_implicit_block` variable in
            // postgres.
            //
            // This needs to be done in the loop instead of once at the top because
            // a COMMIT/ROLLBACK statement needs to start a new transaction on next
            // statement.
            self.ensure_transaction(num_stmts).await?;

            // `Drain` stops processing the remaining statements; `Done` ends the
            // connection immediately.
            match self.one_query(stmt, sql.to_string()).await? {
                State::Ready => (),
                State::Drain => break,
                State::Done => return Ok(State::Done),
            }
        }

        // Implicit transactions are closed at the end of a Query message.
        {
            if self.adapter_client.session().transaction().is_implicit() {
                self.commit_transaction().await?;
            }
        }

        if num_stmts == 0 {
            self.send(BackendMessage::EmptyQueryResponse).await?;
        }

        self.ready().await
    }
797
798    #[instrument(level = "debug")]
799    async fn parse(
800        &mut self,
801        name: String,
802        sql: String,
803        param_oids: Vec<u32>,
804    ) -> Result<State, io::Error> {
805        // Start a transaction if we aren't in one.
806        self.ensure_transaction(1).await?;
807
808        let mut param_types = vec![];
809        for oid in param_oids {
810            match mz_pgrepr::Type::from_oid(oid) {
811                Ok(ty) => match ScalarType::try_from(&ty) {
812                    Ok(ty) => param_types.push(Some(ty)),
813                    Err(err) => {
814                        return self
815                            .error(ErrorResponse::error(
816                                SqlState::INVALID_PARAMETER_VALUE,
817                                err.to_string(),
818                            ))
819                            .await;
820                    }
821                },
822                Err(_) if oid == 0 => param_types.push(None),
823                Err(e) => {
824                    return self
825                        .error(ErrorResponse::error(
826                            SqlState::PROTOCOL_VIOLATION,
827                            e.to_string(),
828                        ))
829                        .await;
830                }
831            }
832        }
833
834        let stmts = match self.parse_sql(&sql) {
835            Ok(stmts) => stmts,
836            Err(err) => {
837                return self.error(err).await;
838            }
839        };
840        if stmts.len() > 1 {
841            return self
842                .error(ErrorResponse::error(
843                    SqlState::INTERNAL_ERROR,
844                    "cannot insert multiple commands into a prepared statement",
845                ))
846                .await;
847        }
848        let (maybe_stmt, sql) = match stmts.into_iter().next() {
849            None => (None, ""),
850            Some(StatementParseResult { ast, sql }) => (Some(ast), sql),
851        };
852        if self.is_aborted_txn() && !is_txn_exit_stmt(maybe_stmt.as_ref()) {
853            return self.aborted_txn_error().await;
854        }
855        match self
856            .adapter_client
857            .prepare(name, maybe_stmt, sql.to_string(), param_types)
858            .await
859        {
860            Ok(()) => {
861                self.send(BackendMessage::ParseComplete).await?;
862                Ok(State::Ready)
863            }
864            Err(e) => self.error(e.into_response(Severity::Error)).await,
865        }
866    }
867
868    /// Commits and clears the current transaction.
869    #[instrument(level = "debug")]
870    async fn commit_transaction(&mut self) -> Result<(), io::Error> {
871        self.end_transaction(EndTransactionAction::Commit).await
872    }
873
874    /// Rollback and clears the current transaction.
875    #[instrument(level = "debug")]
876    async fn rollback_transaction(&mut self) -> Result<(), io::Error> {
877        self.end_transaction(EndTransactionAction::Rollback).await
878    }
879
880    /// End a transaction and report to the user if an error occurred.
881    #[instrument(level = "debug")]
882    async fn end_transaction(&mut self, action: EndTransactionAction) -> Result<(), io::Error> {
883        self.txn_needs_commit = false;
884        let resp = self.adapter_client.end_transaction(action).await;
885        if let Err(err) = resp {
886            self.send(BackendMessage::ErrorResponse(
887                err.into_response(Severity::Error),
888            ))
889            .await?;
890        }
891        Ok(())
892    }
893
894    #[instrument(level = "debug")]
895    async fn bind(
896        &mut self,
897        portal_name: String,
898        statement_name: String,
899        param_formats: Vec<Format>,
900        raw_params: Vec<Option<Vec<u8>>>,
901        result_formats: Vec<Format>,
902    ) -> Result<State, io::Error> {
903        // Start a transaction if we aren't in one.
904        self.ensure_transaction(1).await?;
905
906        let aborted_txn = self.is_aborted_txn();
907        let stmt = match self
908            .adapter_client
909            .get_prepared_statement(&statement_name)
910            .await
911        {
912            Ok(stmt) => stmt,
913            Err(err) => return self.error(err.into_response(Severity::Error)).await,
914        };
915
916        let param_types = &stmt.desc().param_types;
917        if param_types.len() != raw_params.len() {
918            let message = format!(
919                "bind message supplies {actual} parameters, \
920                 but prepared statement \"{name}\" requires {expected}",
921                name = statement_name,
922                actual = raw_params.len(),
923                expected = param_types.len()
924            );
925            return self
926                .error(ErrorResponse::error(SqlState::PROTOCOL_VIOLATION, message))
927                .await;
928        }
929        let param_formats = match pad_formats(param_formats, raw_params.len()) {
930            Ok(param_formats) => param_formats,
931            Err(msg) => {
932                return self
933                    .error(ErrorResponse::error(SqlState::PROTOCOL_VIOLATION, msg))
934                    .await;
935            }
936        };
937        if aborted_txn && !is_txn_exit_stmt(stmt.stmt()) {
938            return self.aborted_txn_error().await;
939        }
940        let buf = RowArena::new();
941        let mut params = vec![];
942        for (raw_param, mz_typ, format) in izip!(raw_params, param_types, param_formats) {
943            let pg_typ = mz_pgrepr::Type::from(mz_typ);
944            let datum = match raw_param {
945                None => Datum::Null,
946                Some(bytes) => match mz_pgrepr::Value::decode(format, &pg_typ, &bytes) {
947                    Ok(param) => param.into_datum(&buf, &pg_typ),
948                    Err(err) => {
949                        let msg = format!("unable to decode parameter: {}", err);
950                        return self
951                            .error(ErrorResponse::error(SqlState::INVALID_PARAMETER_VALUE, msg))
952                            .await;
953                    }
954                },
955            };
956            params.push((datum, mz_typ.clone()))
957        }
958
959        let result_formats = match pad_formats(
960            result_formats,
961            stmt.desc()
962                .relation_desc
963                .clone()
964                .map(|desc| desc.typ().column_types.len())
965                .unwrap_or(0),
966        ) {
967            Ok(result_formats) => result_formats,
968            Err(msg) => {
969                return self
970                    .error(ErrorResponse::error(SqlState::PROTOCOL_VIOLATION, msg))
971                    .await;
972            }
973        };
974
975        // Binary encodings are disabled for list, map, and aclitem types, but this doesn't
976        // apply to COPY TO statements.
977        if !stmt.stmt().map_or(false, |stmt| {
978            matches!(
979                stmt,
980                Statement::Copy(CopyStatement {
981                    direction: CopyDirection::To,
982                    ..
983                })
984            )
985        }) {
986            if let Some(desc) = stmt.desc().relation_desc.clone() {
987                for (format, ty) in result_formats.iter().zip(desc.iter_types()) {
988                    match (format, &ty.scalar_type) {
989                        (Format::Binary, mz_repr::ScalarType::List { .. }) => {
990                            return self
991                                .error(ErrorResponse::error(
992                                    SqlState::PROTOCOL_VIOLATION,
993                                    "binary encoding of list types is not implemented",
994                                ))
995                                .await;
996                        }
997                        (Format::Binary, mz_repr::ScalarType::Map { .. }) => {
998                            return self
999                                .error(ErrorResponse::error(
1000                                    SqlState::PROTOCOL_VIOLATION,
1001                                    "binary encoding of map types is not implemented",
1002                                ))
1003                                .await;
1004                        }
1005                        (Format::Binary, mz_repr::ScalarType::AclItem) => {
1006                            return self
1007                                .error(ErrorResponse::error(
1008                                    SqlState::PROTOCOL_VIOLATION,
1009                                    "binary encoding of aclitem types does not exist",
1010                                ))
1011                                .await;
1012                        }
1013                        _ => (),
1014                    }
1015                }
1016            }
1017        }
1018
1019        let desc = stmt.desc().clone();
1020        let revision = stmt.catalog_revision;
1021        let logging = Arc::clone(stmt.logging());
1022        let stmt = stmt.stmt().cloned();
1023        if let Err(err) = self.adapter_client.session().set_portal(
1024            portal_name,
1025            desc,
1026            stmt,
1027            logging,
1028            params,
1029            result_formats,
1030            revision,
1031        ) {
1032            return self.error(err.into_response(Severity::Error)).await;
1033        }
1034
1035        self.send(BackendMessage::BindComplete).await?;
1036        Ok(State::Ready)
1037    }
1038
    /// Handles an Execute message for the portal `portal_name`. FETCH also
    /// re-enters here via `fetch`.
    ///
    /// The portal's state decides what happens:
    /// - `NotStarted`: execution is started through the adapter.
    /// - `InProgress`: streaming of the remaining rows resumes.
    /// - `Completed(Some(tag))`: the saved command tag is replayed.
    /// - `Completed(None)`: an error is reported saying the portal cannot be run.
    ///
    /// When `outer_ctx_extra` is `Some` (the FETCH path), it is retired with the
    /// outcome on every path. The exception is `NotStarted`, where it is passed
    /// to the adapter's `execute` instead.
    ///
    /// This is a plain `fn` returning a `BoxFuture` rather than an `async fn`.
    /// It is reached recursively (`execute` -> `send_execute_response` ->
    /// `fetch` -> `execute`), and recursive async calls need a boxed future.
    fn execute(
        &mut self,
        portal_name: String,
        max_rows: ExecuteCount,
        get_response: GetResponse,
        fetch_portal_name: Option<String>,
        timeout: ExecuteTimeout,
        outer_ctx_extra: Option<ExecuteContextExtra>,
    ) -> BoxFuture<'_, Result<State, io::Error>> {
        async move {
            // Read the aborted state up front; `portal` below is a mutable
            // borrow of the session.
            let aborted_txn = self.is_aborted_txn();

            // Check if the portal has been started and can be continued.
            let portal = match self
                .adapter_client
                .session()
                .get_portal_unverified_mut(&portal_name)
            {
                Some(portal) => portal,
                None => {
                    let msg = format!("portal {} does not exist", portal_name.quoted());
                    if let Some(outer_ctx_extra) = outer_ctx_extra {
                        self.adapter_client.retire_execute(
                            outer_ctx_extra,
                            StatementEndedExecutionReason::Errored { error: msg.clone() },
                        );
                    }
                    return self
                        .error(ErrorResponse::error(SqlState::INVALID_CURSOR_NAME, msg))
                        .await;
                }
            };

            // In an aborted transaction, reject all commands except COMMIT/ROLLBACK.
            let txn_exit_stmt = is_txn_exit_stmt(portal.stmt.as_deref());
            if aborted_txn && !txn_exit_stmt {
                if let Some(outer_ctx_extra) = outer_ctx_extra {
                    self.adapter_client.retire_execute(
                        outer_ctx_extra,
                        StatementEndedExecutionReason::Errored {
                            error: ABORTED_TXN_MSG.to_string(),
                        },
                    );
                }
                return self.aborted_txn_error().await;
            }

            // Cloned so the row description stays usable after the portal
            // borrow ends and `self` is used mutably in the arms below.
            let row_desc = portal.desc.relation_desc.clone();
            match &mut portal.state {
                PortalState::NotStarted => {
                    // Start a transaction if we aren't in one.
                    self.ensure_transaction(1).await?;
                    // The connection-close future is passed to the adapter;
                    // NOTE(review): presumably so execution can react to the
                    // client disconnecting — confirm in the adapter.
                    match self
                        .adapter_client
                        .execute(
                            portal_name.clone(),
                            self.conn.wait_closed(),
                            outer_ctx_extra,
                        )
                        .await
                    {
                        Ok((response, execute_started)) => {
                            self.send_pending_notices().await?;
                            self.send_execute_response(
                                response,
                                row_desc,
                                portal_name,
                                max_rows,
                                get_response,
                                fetch_portal_name,
                                timeout,
                                execute_started,
                            )
                            .await
                        }
                        Err(e) => {
                            self.send_pending_notices().await?;
                            self.error(e.into_response(Severity::Error)).await
                        }
                    }
                }
                PortalState::InProgress(rows) => {
                    // Move the in-progress row stream out of the portal so it
                    // can be handed to `send_rows`.
                    let rows = rows.take().expect("InProgress rows must be populated");
                    let (result, statement_ended_execution_reason) = match self
                        .send_rows(
                            row_desc.expect("portal missing row desc on resumption"),
                            portal_name,
                            rows,
                            max_rows,
                            get_response,
                            fetch_portal_name,
                            timeout,
                        )
                        .await
                    {
                        Err(e) => {
                            // This is an error communicating with the connection.
                            // We consider that to be a cancelation, rather than a query error.
                            (Err(e), StatementEndedExecutionReason::Canceled)
                        }
                        Ok((ok, SendRowsEndedReason::Canceled)) => {
                            (Ok(ok), StatementEndedExecutionReason::Canceled)
                        }
                        // NOTE: For now the values for `result_size` and
                        // `rows_returned` in fetches are a bit confusing.
                        // We record `Some(n)` for the first fetch, where `n` is
                        // the number of bytes/rows returned by the inner
                        // execute (regardless of how many rows the
                        // fetch fetched), and `None` for subsequent fetches.
                        //
                        // This arguably makes sense since the size/rows
                        // returned measures how much work the compute
                        // layer had to do to satisfy the query, but
                        // we should revisit it if/when we start
                        // logging the inner execute separately.
                        Ok((
                            ok,
                            SendRowsEndedReason::Success {
                                result_size: _,
                                rows_returned: _,
                            },
                        )) => (
                            Ok(ok),
                            StatementEndedExecutionReason::Success {
                                result_size: None,
                                rows_returned: None,
                                execution_strategy: None,
                            },
                        ),
                        Ok((ok, SendRowsEndedReason::Errored { error })) => {
                            (Ok(ok), StatementEndedExecutionReason::Errored { error })
                        }
                    };
                    if let Some(outer_ctx_extra) = outer_ctx_extra {
                        self.adapter_client
                            .retire_execute(outer_ctx_extra, statement_ended_execution_reason);
                    }
                    result
                }
                // FETCH is an awkward command for our current architecture. In Postgres it
                // will extract <count> rows from the target portal, cache them, and return
                // them to the user as requested. Its command tag is always FETCH <num rows
                // extracted>. In Materialize, since we have chosen to not fully support FETCH,
                // we must remember the number of rows that were returned. Use this tag to
                // remember that information and return it.
                PortalState::Completed(Some(tag)) => {
                    let tag = tag.to_string();
                    if let Some(outer_ctx_extra) = outer_ctx_extra {
                        self.adapter_client.retire_execute(
                            outer_ctx_extra,
                            StatementEndedExecutionReason::Success {
                                result_size: None,
                                rows_returned: None,
                                execution_strategy: None,
                            },
                        );
                    }
                    self.send(BackendMessage::CommandComplete { tag }).await?;
                    Ok(State::Ready)
                }
                // A portal marked completed without a tag (see `complete_portal`)
                // cannot be executed again.
                PortalState::Completed(None) => {
                    let error = format!(
                        "portal {} cannot be run",
                        Ident::new_unchecked(portal_name).to_ast_string_stable()
                    );
                    if let Some(outer_ctx_extra) = outer_ctx_extra {
                        self.adapter_client.retire_execute(
                            outer_ctx_extra,
                            StatementEndedExecutionReason::Errored {
                                error: error.clone(),
                            },
                        );
                    }
                    self.error(ErrorResponse::error(
                        SqlState::OBJECT_NOT_IN_PREREQUISITE_STATE,
                        error,
                    ))
                    .await
                }
            }
        }
        // This is a plain `fn`, so the tracing span is attached to the async
        // block's future directly.
        .instrument(debug_span!("execute"))
        .boxed()
    }
1223
1224    #[instrument(level = "debug")]
1225    async fn describe_statement(&mut self, name: &str) -> Result<State, io::Error> {
1226        // Start a transaction if we aren't in one.
1227        self.ensure_transaction(1).await?;
1228
1229        let stmt = match self.adapter_client.get_prepared_statement(name).await {
1230            Ok(stmt) => stmt,
1231            Err(err) => return self.error(err.into_response(Severity::Error)).await,
1232        };
1233        // Cloning to avoid a mutable borrow issue because `send` also uses `adapter_client`
1234        let parameter_desc = BackendMessage::ParameterDescription(
1235            stmt.desc()
1236                .param_types
1237                .iter()
1238                .map(mz_pgrepr::Type::from)
1239                .collect(),
1240        );
1241        // Claim that all results will be output in text format, even
1242        // though the true result formats are not yet known. A bit
1243        // weird, but this is the behavior that PostgreSQL specifies.
1244        let formats = vec![Format::Text; stmt.desc().arity()];
1245        let row_desc = describe_rows(stmt.desc(), &formats);
1246        self.send_all([parameter_desc, row_desc]).await?;
1247        Ok(State::Ready)
1248    }
1249
1250    #[instrument(level = "debug")]
1251    async fn describe_portal(&mut self, name: &str) -> Result<State, io::Error> {
1252        // Start a transaction if we aren't in one.
1253        self.ensure_transaction(1).await?;
1254
1255        let session = self.adapter_client.session();
1256        let row_desc = session
1257            .get_portal_unverified(name)
1258            .map(|portal| describe_rows(&portal.desc, &portal.result_formats));
1259        match row_desc {
1260            Some(row_desc) => {
1261                self.send(row_desc).await?;
1262                Ok(State::Ready)
1263            }
1264            None => {
1265                self.error(ErrorResponse::error(
1266                    SqlState::INVALID_CURSOR_NAME,
1267                    format!("portal {} does not exist", name.quoted()),
1268                ))
1269                .await
1270            }
1271        }
1272    }
1273
1274    #[instrument(level = "debug")]
1275    async fn close_statement(&mut self, name: String) -> Result<State, io::Error> {
1276        self.adapter_client
1277            .session()
1278            .remove_prepared_statement(&name);
1279        self.send(BackendMessage::CloseComplete).await?;
1280        Ok(State::Ready)
1281    }
1282
1283    #[instrument(level = "debug")]
1284    async fn close_portal(&mut self, name: String) -> Result<State, io::Error> {
1285        self.adapter_client.session().remove_portal(&name);
1286        self.send(BackendMessage::CloseComplete).await?;
1287        Ok(State::Ready)
1288    }
1289
1290    fn complete_portal(&mut self, name: &str) {
1291        let portal = self
1292            .adapter_client
1293            .session()
1294            .get_portal_unverified_mut(name)
1295            .expect("portal should exist");
1296        portal.state = PortalState::Completed(None);
1297    }
1298
1299    async fn fetch(
1300        &mut self,
1301        name: String,
1302        count: Option<FetchDirection>,
1303        max_rows: ExecuteCount,
1304        fetch_portal_name: Option<String>,
1305        timeout: ExecuteTimeout,
1306        ctx_extra: ExecuteContextExtra,
1307    ) -> Result<State, io::Error> {
1308        // Unlike Execute, no count specified in FETCH returns 1 row, and 0 means 0
1309        // instead of All.
1310        let count = count.unwrap_or(FetchDirection::ForwardCount(1));
1311
1312        // Figure out how many rows we should send back by looking at the various
1313        // combinations of the execute and fetch.
1314        //
1315        // In Postgres, Fetch will cache <count> rows from the target portal and
1316        // return those as requested (if, say, an Execute message was sent with a
1317        // max_rows < the Fetch's count). We expect that case to be incredibly rare and
1318        // so have chosen to not support it until users request it. This eases
1319        // implementation difficulty since we don't have to be able to "send" rows to
1320        // a buffer.
1321        //
1322        // TODO(mjibson): Test this somehow? Need to divide up the pgtest files in
1323        // order to have some that are not Postgres compatible.
1324        let count = match (max_rows, count) {
1325            (ExecuteCount::Count(max_rows), FetchDirection::ForwardCount(count)) => {
1326                let count = usize::cast_from(count);
1327                if max_rows < count {
1328                    let msg = "Execute with max_rows < a FETCH's count is not supported";
1329                    self.adapter_client.retire_execute(
1330                        ctx_extra,
1331                        StatementEndedExecutionReason::Errored {
1332                            error: msg.to_string(),
1333                        },
1334                    );
1335                    return self
1336                        .error(ErrorResponse::error(SqlState::FEATURE_NOT_SUPPORTED, msg))
1337                        .await;
1338                }
1339                ExecuteCount::Count(count)
1340            }
1341            (ExecuteCount::Count(_), FetchDirection::ForwardAll) => {
1342                let msg = "Execute with max_rows of a FETCH ALL is not supported";
1343                self.adapter_client.retire_execute(
1344                    ctx_extra,
1345                    StatementEndedExecutionReason::Errored {
1346                        error: msg.to_string(),
1347                    },
1348                );
1349                return self
1350                    .error(ErrorResponse::error(SqlState::FEATURE_NOT_SUPPORTED, msg))
1351                    .await;
1352            }
1353            (ExecuteCount::All, FetchDirection::ForwardAll) => ExecuteCount::All,
1354            (ExecuteCount::All, FetchDirection::ForwardCount(count)) => {
1355                ExecuteCount::Count(usize::cast_from(count))
1356            }
1357        };
1358        let cursor_name = name.to_string();
1359        self.execute(
1360            cursor_name,
1361            count,
1362            fetch_message,
1363            fetch_portal_name,
1364            timeout,
1365            Some(ctx_extra),
1366        )
1367        .await
1368    }
1369
1370    async fn flush(&mut self) -> Result<State, io::Error> {
1371        self.conn.flush().await?;
1372        Ok(State::Ready)
1373    }
1374
1375    /// Sends a backend message to the client, after applying a severity filter.
1376    ///
1377    /// The message is only sent if its severity is above the severity set
1378    /// in the session, with the default value being NOTICE.
1379    #[instrument(level = "debug")]
1380    async fn send<M>(&mut self, message: M) -> Result<(), io::Error>
1381    where
1382        M: Into<BackendMessage>,
1383    {
1384        let message: BackendMessage = message.into();
1385        let is_error =
1386            matches!(&message, BackendMessage::ErrorResponse(e) if e.severity.is_error());
1387
1388        self.conn.send(message).await?;
1389
1390        // Flush immediately after sending an error response, as some clients
1391        // expect to be able to read the error response before sending a Sync
1392        // message. This is arguably in violation of the protocol specification,
1393        // but the specification is somewhat ambiguous, and easier to match
1394        // PostgreSQL here than to fix all the clients that have this
1395        // expectation.
1396        if is_error {
1397            self.conn.flush().await?;
1398        }
1399
1400        Ok(())
1401    }
1402
1403    #[instrument(level = "debug")]
1404    pub async fn send_all(
1405        &mut self,
1406        messages: impl IntoIterator<Item = BackendMessage>,
1407    ) -> Result<(), io::Error> {
1408        for m in messages {
1409            self.send(m).await?;
1410        }
1411        Ok(())
1412    }
1413
1414    #[instrument(level = "debug")]
1415    async fn sync(&mut self) -> Result<State, io::Error> {
1416        // Close the current transaction if we are in an implicit transaction.
1417        if self.adapter_client.session().transaction().is_implicit() {
1418            self.commit_transaction().await?;
1419        }
1420        self.ready().await
1421    }
1422
1423    #[instrument(level = "debug")]
1424    async fn ready(&mut self) -> Result<State, io::Error> {
1425        let txn_state = self.adapter_client.session().transaction().into();
1426        self.send(BackendMessage::ReadyForQuery(txn_state)).await?;
1427        self.flush().await
1428    }
1429
    // Converts a RowsFuture to a stream while also checking for connection close.
    //
    // While waiting for `rows` to resolve, this method also does two things:
    // - Forwards each session notice to the client, flushing after each one.
    // - Returns the connection's close error if the connection closes first.
    // The resolved value is placed in a fresh unbounded channel. The sender is
    // dropped on return, so the returned receiver yields exactly that one value
    // and then ends.
    //
    // `parent` is the span under which the `row_future_to_stream` tracing span
    // is created.
    #[instrument(level = "debug")]
    async fn row_future_to_stream<'s, 'p>(
        &'s mut self,
        parent: &'p tracing::Span,
        mut rows: RowsFuture,
    ) -> Result<UnboundedReceiver<PeekResponseUnary>, io::Error>
    where
        'p: 's,
    {
        // select is safe to use because if close finishes, rows is canceled,
        // which is the intended behavior.
        let span = tracing::debug_span!(parent: parent, "row_future_to_stream");
        async move {
            // `rows` is polled by reference, so forwarding a notice does not
            // drop or restart the pending rows future between iterations. The
            // channel send below cannot fail, because `rx` is still alive.
            loop {
                tokio::select! {
                    err = self.conn.wait_closed() => return Err(err),
                    rows = &mut rows => {
                        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
                        tx.send(rows).expect("send must succeed");
                        return Ok(rx);
                    }
                    notice = self.adapter_client.session().recv_notice() => {
                        self.send(notice.into_response())
                            .await?;
                        self.conn.flush().await?;
                    }
                }
            }
        }
        .instrument(span)
        .await
    }
1463
1464    #[allow(clippy::too_many_arguments)]
1465    #[instrument(level = "debug")]
1466    async fn send_execute_response(
1467        &mut self,
1468        response: ExecuteResponse,
1469        row_desc: Option<RelationDesc>,
1470        portal_name: String,
1471        max_rows: ExecuteCount,
1472        get_response: GetResponse,
1473        fetch_portal_name: Option<String>,
1474        timeout: ExecuteTimeout,
1475        execute_started: Instant,
1476    ) -> Result<State, io::Error> {
1477        let mut tag = response.tag();
1478
1479        macro_rules! command_complete {
1480            () => {{
1481                self.send(BackendMessage::CommandComplete {
1482                    tag: tag
1483                        .take()
1484                        .expect("command_complete only called on tag-generating results"),
1485                })
1486                .await?;
1487                Ok(State::Ready)
1488            }};
1489        }
1490
1491        let r = match response {
1492            ExecuteResponse::ClosedCursor => {
1493                self.complete_portal(&portal_name);
1494                command_complete!()
1495            }
1496            ExecuteResponse::DeclaredCursor => {
1497                self.complete_portal(&portal_name);
1498                command_complete!()
1499            }
1500            ExecuteResponse::EmptyQuery => {
1501                self.send(BackendMessage::EmptyQueryResponse).await?;
1502                Ok(State::Ready)
1503            }
1504            ExecuteResponse::Fetch {
1505                name,
1506                count,
1507                timeout,
1508                ctx_extra,
1509            } => {
1510                self.fetch(
1511                    name,
1512                    count,
1513                    max_rows,
1514                    Some(portal_name.to_string()),
1515                    timeout,
1516                    ctx_extra,
1517                )
1518                .await
1519            }
1520            ExecuteResponse::SendingRows {
1521                future: rx,
1522                instance_id,
1523                strategy,
1524            } => {
1525                let row_desc =
1526                    row_desc.expect("missing row description for ExecuteResponse::SendingRows");
1527
1528                let span = tracing::debug_span!("sending_rows");
1529                let rows = self.row_future_to_stream(&span, rx).await?;
1530
1531                self.send_rows(
1532                    row_desc,
1533                    portal_name,
1534                    InProgressRows::new(RecordFirstRowStream::new(
1535                        Box::new(UnboundedReceiverStream::new(rows)),
1536                        execute_started,
1537                        &self.adapter_client,
1538                        Some(instance_id),
1539                        Some(strategy),
1540                    )),
1541                    max_rows,
1542                    get_response,
1543                    fetch_portal_name,
1544                    timeout,
1545                )
1546                .instrument(span)
1547                .await
1548                .map(|(state, _)| state)
1549            }
1550            ExecuteResponse::SendingRowsImmediate { rows } => {
1551                let row_desc = row_desc
1552                    .expect("missing row description for ExecuteResponse::SendingRowsImmediate");
1553
1554                let span = tracing::debug_span!("sending_rows_immediate");
1555
1556                let stream =
1557                    futures::stream::once(futures::future::ready(PeekResponseUnary::Rows(rows)));
1558                self.send_rows(
1559                    row_desc,
1560                    portal_name,
1561                    InProgressRows::new(RecordFirstRowStream::new(
1562                        Box::new(stream),
1563                        execute_started,
1564                        &self.adapter_client,
1565                        None,
1566                        Some(StatementExecutionStrategy::Constant),
1567                    )),
1568                    max_rows,
1569                    get_response,
1570                    fetch_portal_name,
1571                    timeout,
1572                )
1573                .instrument(span)
1574                .await
1575                .map(|(state, _)| state)
1576            }
1577            ExecuteResponse::SetVariable { name, .. } => {
1578                // This code is somewhat awkwardly structured because we
1579                // can't hold `var` across an await point.
1580                let qn = name.to_string();
1581                let msg = if let Some(var) = self
1582                    .adapter_client
1583                    .session()
1584                    .vars_mut()
1585                    .notify_set()
1586                    .find(|v| v.name() == qn)
1587                {
1588                    Some(BackendMessage::ParameterStatus(var.name(), var.value()))
1589                } else {
1590                    None
1591                };
1592                if let Some(msg) = msg {
1593                    self.send(msg).await?;
1594                }
1595                command_complete!()
1596            }
1597            ExecuteResponse::Subscribing {
1598                rx,
1599                ctx_extra,
1600                instance_id,
1601            } => {
1602                if fetch_portal_name.is_none() {
1603                    let mut msg = ErrorResponse::notice(
1604                        SqlState::WARNING,
1605                        "streaming SUBSCRIBE rows directly requires a client that does not buffer output",
1606                    );
1607                    if self.adapter_client.session().vars().application_name() == "psql" {
1608                        msg.hint = Some(
1609                            "Wrap your SUBSCRIBE statement in `COPY (SUBSCRIBE ...) TO STDOUT`."
1610                                .into(),
1611                        )
1612                    }
1613                    self.send(msg).await?;
1614                    self.conn.flush().await?;
1615                }
1616                let row_desc =
1617                    row_desc.expect("missing row description for ExecuteResponse::Subscribing");
1618                let (result, statement_ended_execution_reason) = match self
1619                    .send_rows(
1620                        row_desc,
1621                        portal_name,
1622                        InProgressRows::new(RecordFirstRowStream::new(
1623                            Box::new(UnboundedReceiverStream::new(rx)),
1624                            execute_started,
1625                            &self.adapter_client,
1626                            Some(instance_id),
1627                            None,
1628                        )),
1629                        max_rows,
1630                        get_response,
1631                        fetch_portal_name,
1632                        timeout,
1633                    )
1634                    .await
1635                {
1636                    Err(e) => {
1637                        // This is an error communicating with the connection.
1638                        // We consider that to be a cancelation, rather than a query error.
1639                        (Err(e), StatementEndedExecutionReason::Canceled)
1640                    }
1641                    Ok((ok, SendRowsEndedReason::Canceled)) => {
1642                        (Ok(ok), StatementEndedExecutionReason::Canceled)
1643                    }
1644                    Ok((
1645                        ok,
1646                        SendRowsEndedReason::Success {
1647                            result_size,
1648                            rows_returned,
1649                        },
1650                    )) => (
1651                        Ok(ok),
1652                        StatementEndedExecutionReason::Success {
1653                            result_size: Some(result_size),
1654                            rows_returned: Some(rows_returned),
1655                            execution_strategy: None,
1656                        },
1657                    ),
1658                    Ok((ok, SendRowsEndedReason::Errored { error })) => {
1659                        (Ok(ok), StatementEndedExecutionReason::Errored { error })
1660                    }
1661                };
1662                self.adapter_client
1663                    .retire_execute(ctx_extra, statement_ended_execution_reason);
1664                return result;
1665            }
1666            ExecuteResponse::CopyTo { format, resp } => {
1667                let row_desc =
1668                    row_desc.expect("missing row description for ExecuteResponse::CopyTo");
1669                match *resp {
1670                    ExecuteResponse::Subscribing {
1671                        rx,
1672                        ctx_extra,
1673                        instance_id,
1674                    } => {
1675                        let (result, statement_ended_execution_reason) = match self
1676                            .copy_rows(
1677                                format,
1678                                row_desc,
1679                                RecordFirstRowStream::new(
1680                                    Box::new(UnboundedReceiverStream::new(rx)),
1681                                    execute_started,
1682                                    &self.adapter_client,
1683                                    Some(instance_id),
1684                                    None,
1685                                ),
1686                            )
1687                            .await
1688                        {
1689                            Err(e) => {
1690                                // This is an error communicating with the connection.
1691                                // We consider that to be a cancelation, rather than a query error.
1692                                (Err(e), StatementEndedExecutionReason::Canceled)
1693                            }
1694                            Ok((
1695                                state,
1696                                SendRowsEndedReason::Success {
1697                                    result_size,
1698                                    rows_returned,
1699                                },
1700                            )) => (
1701                                Ok(state),
1702                                StatementEndedExecutionReason::Success {
1703                                    result_size: Some(result_size),
1704                                    rows_returned: Some(rows_returned),
1705                                    execution_strategy: None,
1706                                },
1707                            ),
1708                            Ok((state, SendRowsEndedReason::Errored { error })) => {
1709                                (Ok(state), StatementEndedExecutionReason::Errored { error })
1710                            }
1711                            Ok((state, SendRowsEndedReason::Canceled)) => {
1712                                (Ok(state), StatementEndedExecutionReason::Canceled)
1713                            }
1714                        };
1715                        self.adapter_client
1716                            .retire_execute(ctx_extra, statement_ended_execution_reason);
1717                        return result;
1718                    }
1719                    ExecuteResponse::SendingRows {
1720                        future: rows_rx,
1721                        instance_id,
1722                        strategy,
1723                    } => {
1724                        let span = tracing::debug_span!("sending_rows");
1725                        let rows = self.row_future_to_stream(&span, rows_rx).await?;
1726                        // We don't need to finalize execution here;
1727                        // it was already done in the
1728                        // coordinator. Just extract the state and
1729                        // return that.
1730                        return self
1731                            .copy_rows(
1732                                format,
1733                                row_desc,
1734                                RecordFirstRowStream::new(
1735                                    Box::new(UnboundedReceiverStream::new(rows)),
1736                                    execute_started,
1737                                    &self.adapter_client,
1738                                    Some(instance_id),
1739                                    Some(strategy),
1740                                ),
1741                            )
1742                            .await
1743                            .map(|(state, _)| state);
1744                    }
1745                    ExecuteResponse::SendingRowsImmediate { rows } => {
1746                        let span = tracing::debug_span!("sending_rows_immediate");
1747
1748                        let rows = futures::stream::once(futures::future::ready(
1749                            PeekResponseUnary::Rows(rows),
1750                        ));
1751                        // We don't need to finalize execution here;
1752                        // it was already done in the
1753                        // coordinator. Just extract the state and
1754                        // return that.
1755                        return self
1756                            .copy_rows(
1757                                format,
1758                                row_desc,
1759                                RecordFirstRowStream::new(
1760                                    Box::new(rows),
1761                                    execute_started,
1762                                    &self.adapter_client,
1763                                    None,
1764                                    Some(StatementExecutionStrategy::Constant),
1765                                ),
1766                            )
1767                            .instrument(span)
1768                            .await
1769                            .map(|(state, _)| state);
1770                    }
1771                    _ => {
1772                        return self
1773                            .error(ErrorResponse::error(
1774                                SqlState::INTERNAL_ERROR,
1775                                "unsupported COPY response type".to_string(),
1776                            ))
1777                            .await;
1778                    }
1779                };
1780            }
1781            ExecuteResponse::CopyFrom {
1782                id,
1783                columns,
1784                params,
1785                ctx_extra,
1786            } => {
1787                let row_desc =
1788                    row_desc.expect("missing row description for ExecuteResponse::CopyFrom");
1789                self.copy_from(id, columns, params, row_desc, ctx_extra)
1790                    .await
1791            }
1792            ExecuteResponse::TransactionCommitted { params }
1793            | ExecuteResponse::TransactionRolledBack { params } => {
1794                let notify_set: mz_ore::collections::HashSet<String> = self
1795                    .adapter_client
1796                    .session()
1797                    .vars()
1798                    .notify_set()
1799                    .map(|v| v.name().to_string())
1800                    .collect();
1801
1802                // Only report on parameters that are in the notify set.
1803                for (name, value) in params
1804                    .into_iter()
1805                    .filter(|(name, _v)| notify_set.contains(*name))
1806                {
1807                    let msg = BackendMessage::ParameterStatus(name, value);
1808                    self.send(msg).await?;
1809                }
1810                command_complete!()
1811            }
1812
1813            ExecuteResponse::AlteredDefaultPrivileges
1814            | ExecuteResponse::AlteredObject(..)
1815            | ExecuteResponse::AlteredRole
1816            | ExecuteResponse::AlteredSystemConfiguration
1817            | ExecuteResponse::CreatedCluster { .. }
1818            | ExecuteResponse::CreatedClusterReplica { .. }
1819            | ExecuteResponse::CreatedConnection { .. }
1820            | ExecuteResponse::CreatedDatabase { .. }
1821            | ExecuteResponse::CreatedIndex { .. }
1822            | ExecuteResponse::CreatedIntrospectionSubscribe
1823            | ExecuteResponse::CreatedMaterializedView { .. }
1824            | ExecuteResponse::CreatedContinualTask { .. }
1825            | ExecuteResponse::CreatedRole
1826            | ExecuteResponse::CreatedSchema { .. }
1827            | ExecuteResponse::CreatedSecret { .. }
1828            | ExecuteResponse::CreatedSink { .. }
1829            | ExecuteResponse::CreatedSource { .. }
1830            | ExecuteResponse::CreatedTable { .. }
1831            | ExecuteResponse::CreatedType
1832            | ExecuteResponse::CreatedView { .. }
1833            | ExecuteResponse::CreatedViews { .. }
1834            | ExecuteResponse::CreatedNetworkPolicy
1835            | ExecuteResponse::Comment
1836            | ExecuteResponse::Deallocate { .. }
1837            | ExecuteResponse::Deleted(..)
1838            | ExecuteResponse::DiscardedAll
1839            | ExecuteResponse::DiscardedTemp
1840            | ExecuteResponse::DroppedObject(_)
1841            | ExecuteResponse::DroppedOwned
1842            | ExecuteResponse::GrantedPrivilege
1843            | ExecuteResponse::GrantedRole
1844            | ExecuteResponse::Inserted(..)
1845            | ExecuteResponse::Copied(..)
1846            | ExecuteResponse::Prepare
1847            | ExecuteResponse::Raised
1848            | ExecuteResponse::ReassignOwned
1849            | ExecuteResponse::RevokedPrivilege
1850            | ExecuteResponse::RevokedRole
1851            | ExecuteResponse::StartedTransaction { .. }
1852            | ExecuteResponse::Updated(..)
1853            | ExecuteResponse::ValidatedConnection => {
1854                command_complete!()
1855            }
1856        };
1857
1858        assert_none!(tag, "tag created but not consumed: {:?}", tag);
1859        r
1860    }
1861
    /// Streams rows from `rows` to the client as `DataRow` messages and then sends
    /// the message produced by `get_response`.
    ///
    /// * `row_desc` describes the rows; every batch is checked against it with
    ///   `verify_datum_desc` before any of its rows are sent.
    /// * Result formats come from the portal named `fetch_portal_name` when it is
    ///   set (execution driven by a FETCH), otherwise from `portal_name`.
    /// * `max_rows` caps how many rows are sent by this call. If the cap is reached
    ///   in the middle of a batch, the unsent remainder is stashed in
    ///   `rows.current`, and the next call starts from it.
    /// * `timeout` controls how long to wait for new batches:
    ///   - `ExecuteTimeout::None` waits until the stream ends.
    ///   - `ExecuteTimeout::Seconds(t)` stops waiting `t` seconds after this call
    ///     started.
    ///   - `ExecuteTimeout::WaitOnce` waits without a deadline until the first
    ///     non-empty batch arrives, then switches to an already-expired deadline.
    /// * Notices received while waiting are forwarded to the client and flushed.
    ///
    /// On the success path, `rows` is stored back into the `portal_name` portal as
    /// `PortalState::InProgress`, even if it is exhausted. The returned
    /// `SendRowsEndedReason::Success` reports the number of rows sent and the sum of
    /// their in-memory `byte_len()`s, not the encoded wire size.
    ///
    /// Stream errors, `verify_datum_desc` failures and cancellation are reported to
    /// the client through `self.error`. Its result is then paired with
    /// `SendRowsEndedReason::Errored` / `Canceled`; the portal state is not updated
    /// on those paths. `Err` is returned when the connection closes while waiting
    /// for rows, or when sending or flushing fails.
    #[allow(clippy::too_many_arguments)]
    // TODO(guswynn): figure out how to get it to compile without skip_all
    #[mz_ore::instrument(level = "debug")]
    async fn send_rows(
        &mut self,
        row_desc: RelationDesc,
        portal_name: String,
        mut rows: InProgressRows,
        max_rows: ExecuteCount,
        get_response: GetResponse,
        fetch_portal_name: Option<String>,
        timeout: ExecuteTimeout,
    ) -> Result<(State, SendRowsEndedReason), io::Error> {
        // If this portal is being executed from a FETCH then we need to use the result
        // format type of the outer portal.
        let result_format_portal_name: &str = if let Some(ref name) = fetch_portal_name {
            name
        } else {
            &portal_name
        };
        let result_formats = self
            .adapter_client
            .session()
            .get_portal_unverified(result_format_portal_name)
            .expect("valid fetch portal name for send rows")
            .result_formats
            .clone();

        // Translate the requested timeout into the waiting policy used below:
        // `deadline` bounds how long we wait for new batches, and `wait_once`
        // defers installing a deadline until the first non-empty batch.
        let (mut wait_once, mut deadline) = match timeout {
            ExecuteTimeout::None => (false, None),
            ExecuteTimeout::Seconds(t) => (
                false,
                Some(tokio::time::Instant::now() + tokio::time::Duration::from_secs_f64(t)),
            ),
            ExecuteTimeout::WaitOnce => (true, None),
        };

        // Tell the connection's encoder how to encode each column: the column's
        // pg type paired with the result format requested for it.
        self.conn.set_encode_state(
            row_desc
                .typ()
                .column_types
                .iter()
                .map(|ty| mz_pgrepr::Type::from(&ty.scalar_type))
                .zip(result_formats)
                .collect(),
        );

        let mut total_sent_rows = 0;
        let mut total_sent_bytes = 0;
        // want_rows is the maximum number of rows the client wants.
        let mut want_rows = match max_rows {
            ExecuteCount::All => usize::MAX,
            ExecuteCount::Count(count) => count,
        };

        // Send rows while the client still wants them and there are still rows to send.
        loop {
            // Fetch next batch of rows, waiting for a possible requested
            // timeout or notice.
            let batch = if rows.current.is_some() {
                // Resume from a batch left over by a previous, row-limited call.
                FetchResult::Rows(rows.current.take())
            } else if want_rows == 0 {
                // The row limit is already exhausted; don't wait for more batches.
                FetchResult::Rows(None)
            } else {
                let notice_fut = self.adapter_client.session().recv_notice();
                tokio::select! {
                    // The client went away while we were waiting.
                    err = self.conn.wait_closed() => return Err(err),
                    // Only armed when a deadline is set (see the `if` guard); the
                    // `unwrap_or_else` fallback merely gives the expression a value
                    // otherwise. Hitting the deadline ends this round of sending.
                    _ = time::sleep_until(deadline.unwrap_or_else(tokio::time::Instant::now)), if deadline.is_some() => FetchResult::Rows(None),
                    notice = notice_fut => {
                        FetchResult::Notice(notice)
                    }
                    batch = rows.remaining.recv() => match batch {
                        None => FetchResult::Rows(None),
                        Some(PeekResponseUnary::Rows(rows)) => FetchResult::Rows(Some(rows)),
                        Some(PeekResponseUnary::Error(err)) => FetchResult::Error(err),
                        Some(PeekResponseUnary::Canceled) => FetchResult::Canceled,
                    },
                }
            };

            match batch {
                FetchResult::Rows(None) => break,
                FetchResult::Rows(Some(mut batch_rows)) => {
                    if let Err(err) = verify_datum_desc(&row_desc, &mut batch_rows) {
                        let msg = err.to_string();
                        return self
                            .error(err.into_response(Severity::Error))
                            .await
                            .map(|state| (state, SendRowsEndedReason::Errored { error: msg }));
                    }

                    // With `wait_once`, we wait without a deadline (like
                    // `deadline == None`) until the first non-empty batch arrives.
                    // From then on we behave like a 0s timeout by installing a
                    // deadline that has already passed.
                    if wait_once && batch_rows.peek().is_some() {
                        deadline = Some(tokio::time::Instant::now());
                        wait_once = false;
                    }

                    // Send a portion of the rows.
                    let mut sent_rows = 0;
                    let mut sent_bytes = 0;
                    // `take(want_rows)` stops pulling after `want_rows` items, so the
                    // `inspect` counters only see rows handed to `send_all`. Unsent
                    // rows remain in `batch_rows`.
                    let messages = (&mut batch_rows)
                        // TODO(parkmycar): This is a fair bit of juggling between iterator types
                        // to count the total number of bytes. Alternatively we could track the
                        // total sent bytes in this .map(...) call, but having side effects in map
                        // is a code smell.
                        .map(|row| {
                            let row_len = row.byte_len();
                            let values = mz_pgrepr::values_from_row(row, row_desc.typ());
                            (row_len, BackendMessage::DataRow(values))
                        })
                        .inspect(|(row_len, _)| {
                            sent_bytes += row_len;
                            sent_rows += 1
                        })
                        .map(|(_row_len, row)| row)
                        .take(want_rows);
                    self.send_all(messages).await?;

                    total_sent_rows += sent_rows;
                    total_sent_bytes += sent_bytes;
                    want_rows -= sent_rows;

                    // If we have sent the number of requested rows, put the remainder of the batch
                    // (if any) back and stop sending.
                    if want_rows == 0 {
                        if batch_rows.peek().is_some() {
                            rows.current = Some(batch_rows);
                        }
                        break;
                    }

                    self.conn.flush().await?;
                }
                FetchResult::Notice(notice) => {
                    self.send(notice.into_response()).await?;
                    self.conn.flush().await?;
                }
                FetchResult::Error(text) => {
                    return self
                        .error(ErrorResponse::error(SqlState::INTERNAL_ERROR, text.clone()))
                        .await
                        .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
                }
                FetchResult::Canceled => {
                    return self
                        .error(ErrorResponse::error(
                            SqlState::QUERY_CANCELED,
                            "canceling statement due to user request",
                        ))
                        .await
                        .map(|state| (state, SendRowsEndedReason::Canceled));
                }
            }
        }

        let portal = self
            .adapter_client
            .session()
            .get_portal_unverified_mut(&portal_name)
            .expect("valid portal name for send rows");

        // Always return rows back, even if it's empty. This prevents an unclosed
        // portal from re-executing after it has been emptied.
        portal.state = PortalState::InProgress(Some(rows));

        // For FETCH, `get_response` also receives the outer portal.
        let fetch_portal = fetch_portal_name.map(|name| {
            self.adapter_client
                .session()
                .get_portal_unverified_mut(&name)
                .expect("valid fetch portal")
        });
        let response_message = get_response(max_rows, total_sent_rows, fetch_portal);
        self.send(response_message).await?;
        Ok((
            State::Ready,
            SendRowsEndedReason::Success {
                result_size: u64::cast_from(total_sent_bytes),
                rows_returned: u64::cast_from(total_sent_rows),
            },
        ))
    }
2045
2046    #[mz_ore::instrument(level = "debug")]
2047    async fn copy_rows(
2048        &mut self,
2049        format: CopyFormat,
2050        row_desc: RelationDesc,
2051        mut stream: RecordFirstRowStream,
2052    ) -> Result<(State, SendRowsEndedReason), io::Error> {
2053        let (row_format, encode_format) = match format {
2054            CopyFormat::Text => (
2055                CopyFormatParams::Text(CopyTextFormatParams::default()),
2056                Format::Text,
2057            ),
2058            CopyFormat::Binary => (CopyFormatParams::Binary, Format::Binary),
2059            CopyFormat::Csv => (
2060                CopyFormatParams::Csv(CopyCsvFormatParams::default()),
2061                Format::Text,
2062            ),
2063            CopyFormat::Parquet => {
2064                let text = "Parquet format is not supported".to_string();
2065                return self
2066                    .error(ErrorResponse::error(SqlState::INTERNAL_ERROR, text.clone()))
2067                    .await
2068                    .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
2069            }
2070        };
2071
2072        let encode_fn = |row: &RowRef, typ: &RelationType, out: &mut Vec<u8>| {
2073            mz_pgcopy::encode_copy_format(&row_format, row, typ, out)
2074        };
2075
2076        let typ = row_desc.typ();
2077        let column_formats = iter::repeat(encode_format)
2078            .take(typ.column_types.len())
2079            .collect();
2080        self.send(BackendMessage::CopyOutResponse {
2081            overall_format: encode_format,
2082            column_formats,
2083        })
2084        .await?;
2085
2086        // In Postgres, binary copy has a header that is followed (in the same
2087        // CopyData) by the first row. In order to replicate their behavior, use a
2088        // common vec that we can extend one time now and then fill up with the encode
2089        // functions.
2090        let mut out = Vec::new();
2091
2092        if let CopyFormat::Binary = format {
2093            // 11-byte signature.
2094            out.extend(b"PGCOPY\n\xFF\r\n\0");
2095            // 32-bit flags field.
2096            out.extend([0, 0, 0, 0]);
2097            // 32-bit header extension length field.
2098            out.extend([0, 0, 0, 0]);
2099        }
2100
2101        let mut count = 0;
2102        let mut total_sent_bytes = 0;
2103        loop {
2104            tokio::select! {
2105                e = self.conn.wait_closed() => return Err(e),
2106                batch = stream.recv() => match batch {
2107                    None => break,
2108                    Some(PeekResponseUnary::Error(text)) => {
2109                        return self
2110                            .error(ErrorResponse::error(SqlState::INTERNAL_ERROR, text.clone()))
2111                        .await
2112                        .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
2113                    }
2114                    Some(PeekResponseUnary::Canceled) => {
2115                        return self.error(ErrorResponse::error(
2116                                SqlState::QUERY_CANCELED,
2117                                "canceling statement due to user request",
2118                            ))
2119                            .await.map(|state| (state, SendRowsEndedReason::Canceled));
2120                    }
2121                    Some(PeekResponseUnary::Rows(mut rows)) => {
2122                        count += rows.count();
2123                        while let Some(row) = rows.next() {
2124                            total_sent_bytes += row.byte_len();
2125                            encode_fn(row, typ, &mut out)?;
2126                            self.send(BackendMessage::CopyData(mem::take(&mut out)))
2127                                .await?;
2128                        }
2129                    }
2130                },
2131                notice = self.adapter_client.session().recv_notice() => {
2132                    self.send(notice.into_response())
2133                        .await?;
2134                    self.conn.flush().await?;
2135                }
2136            }
2137
2138            self.conn.flush().await?;
2139        }
2140        // Send required trailers.
2141        if let CopyFormat::Binary = format {
2142            let trailer: i16 = -1;
2143            out.extend(trailer.to_be_bytes());
2144            self.send(BackendMessage::CopyData(mem::take(&mut out)))
2145                .await?;
2146        }
2147
2148        let tag = format!("COPY {}", count);
2149        self.send(BackendMessage::CopyDone).await?;
2150        self.send(BackendMessage::CommandComplete { tag }).await?;
2151        Ok((
2152            State::Ready,
2153            SendRowsEndedReason::Success {
2154                result_size: u64::cast_from(total_sent_bytes),
2155                rows_returned: u64::cast_from(count),
2156            },
2157        ))
2158    }
2159
2160    /// Handles the copy-in mode of the postgres protocol from transferring
2161    /// data to the server.
2162    #[instrument(level = "debug")]
2163    async fn copy_from(
2164        &mut self,
2165        id: CatalogItemId,
2166        columns: Vec<ColumnIndex>,
2167        params: CopyFormatParams<'_>,
2168        row_desc: RelationDesc,
2169        mut ctx_extra: ExecuteContextExtra,
2170    ) -> Result<State, io::Error> {
2171        let res = self
2172            .copy_from_inner(id, columns, params, row_desc, &mut ctx_extra)
2173            .await;
2174        match &res {
2175            Ok(State::Done) => {
2176                // The connection closed gracefully without sending us a `CopyDone`,
2177                // causing us to just drop the copy request.
2178                // For the purposes of statement logging, we count this as a cancellation.
2179                self.adapter_client
2180                    .retire_execute(ctx_extra, StatementEndedExecutionReason::Canceled);
2181            }
2182            Err(e) => {
2183                self.adapter_client.retire_execute(
2184                    ctx_extra,
2185                    StatementEndedExecutionReason::Errored {
2186                        error: format!("{e}"),
2187                    },
2188                );
2189            }
2190            other => {
2191                tracing::warn!(?other, "aborting COPY FROM");
2192                self.adapter_client
2193                    .retire_execute(ctx_extra, StatementEndedExecutionReason::Aborted);
2194            }
2195        }
2196        res
2197    }
2198
2199    async fn copy_from_inner(
2200        &mut self,
2201        id: CatalogItemId,
2202        columns: Vec<ColumnIndex>,
2203        params: CopyFormatParams<'_>,
2204        row_desc: RelationDesc,
2205        ctx_extra: &mut ExecuteContextExtra,
2206    ) -> Result<State, io::Error> {
2207        let typ = row_desc.typ();
2208        let column_formats = vec![Format::Text; typ.column_types.len()];
2209        self.send(BackendMessage::CopyInResponse {
2210            overall_format: Format::Text,
2211            column_formats,
2212        })
2213        .await?;
2214        self.conn.flush().await?;
2215
2216        let system_vars = self.adapter_client.get_system_vars().await;
2217        let max_size = system_vars
2218            .get(MAX_COPY_FROM_SIZE.name())
2219            .ok()
2220            .and_then(|max_size| max_size.value().parse().ok())
2221            .unwrap_or(usize::MAX);
2222        tracing::debug!("COPY FROM max buffer size: {max_size} bytes");
2223
2224        let mut data = Vec::new();
2225        loop {
2226            let message = self.conn.recv().await?;
2227            match message {
2228                Some(FrontendMessage::CopyData(buf)) => {
2229                    // Bail before we OOM.
2230                    if (data.len() + buf.len()) > max_size {
2231                        return self
2232                            .error(ErrorResponse::error(
2233                                SqlState::INSUFFICIENT_RESOURCES,
2234                                "COPY FROM STDIN too large",
2235                            ))
2236                            .await;
2237                    }
2238                    data.extend(buf)
2239                }
2240                Some(FrontendMessage::CopyDone) => break,
2241                Some(FrontendMessage::CopyFail(err)) => {
2242                    self.adapter_client.retire_execute(
2243                        std::mem::take(ctx_extra),
2244                        StatementEndedExecutionReason::Canceled,
2245                    );
2246                    return self
2247                        .error(ErrorResponse::error(
2248                            SqlState::QUERY_CANCELED,
2249                            format!("COPY from stdin failed: {}", err),
2250                        ))
2251                        .await;
2252                }
2253                Some(FrontendMessage::Flush) | Some(FrontendMessage::Sync) => {}
2254                Some(_) => {
2255                    let msg = "unexpected message type during COPY from stdin";
2256                    self.adapter_client.retire_execute(
2257                        std::mem::take(ctx_extra),
2258                        StatementEndedExecutionReason::Errored {
2259                            error: msg.to_string(),
2260                        },
2261                    );
2262                    return self
2263                        .error(ErrorResponse::error(SqlState::PROTOCOL_VIOLATION, msg))
2264                        .await;
2265                }
2266                None => {
2267                    return Ok(State::Done);
2268                }
2269            }
2270        }
2271
2272        let column_types = typ
2273            .column_types
2274            .iter()
2275            .map(|x| &x.scalar_type)
2276            .map(mz_pgrepr::Type::from)
2277            .collect::<Vec<mz_pgrepr::Type>>();
2278
2279        let rows = match mz_pgcopy::decode_copy_format(&data, &column_types, params) {
2280            Ok(rows) => rows,
2281            Err(e) => {
2282                self.adapter_client.retire_execute(
2283                    std::mem::take(ctx_extra),
2284                    StatementEndedExecutionReason::Errored {
2285                        error: e.to_string(),
2286                    },
2287                );
2288                return self
2289                    .error(ErrorResponse::error(
2290                        SqlState::BAD_COPY_FILE_FORMAT,
2291                        format!("{}", e),
2292                    ))
2293                    .await;
2294            }
2295        };
2296
2297        let count = rows.len();
2298
2299        if let Err(e) = self
2300            .adapter_client
2301            .insert_rows(id, columns, rows, std::mem::take(ctx_extra))
2302            .await
2303        {
2304            self.adapter_client.retire_execute(
2305                std::mem::take(ctx_extra),
2306                StatementEndedExecutionReason::Errored {
2307                    error: e.to_string(),
2308                },
2309            );
2310            return self.error(e.into_response(Severity::Error)).await;
2311        }
2312
2313        let tag = format!("COPY {}", count);
2314        self.send(BackendMessage::CommandComplete { tag }).await?;
2315
2316        Ok(State::Ready)
2317    }
2318
2319    #[instrument(level = "debug")]
2320    async fn send_pending_notices(&mut self) -> Result<(), io::Error> {
2321        let notices = self
2322            .adapter_client
2323            .session()
2324            .drain_notices()
2325            .into_iter()
2326            .map(|notice| BackendMessage::ErrorResponse(notice.into_response()));
2327        self.send_all(notices).await?;
2328        Ok(())
2329    }
2330
2331    #[instrument(level = "debug")]
2332    async fn error(&mut self, err: ErrorResponse) -> Result<State, io::Error> {
2333        assert!(err.severity.is_error());
2334        debug!(
2335            "cid={} error code={}",
2336            self.adapter_client.session().conn_id(),
2337            err.code.code()
2338        );
2339        let is_fatal = err.severity.is_fatal();
2340        self.send(BackendMessage::ErrorResponse(err)).await?;
2341
2342        let txn = self.adapter_client.session().transaction();
2343        match txn {
2344            // Error can be called from describe and parse and so might not be in an active
2345            // transaction.
2346            TransactionStatus::Default | TransactionStatus::Failed(_) => {}
2347            // In Started (i.e., a single statement), cleanup ourselves.
2348            TransactionStatus::Started(_) => {
2349                self.rollback_transaction().await?;
2350            }
2351            // Implicit transactions also clear themselves.
2352            TransactionStatus::InTransactionImplicit(_) => {
2353                self.rollback_transaction().await?;
2354            }
2355            // Explicit transactions move to failed.
2356            TransactionStatus::InTransaction(_) => {
2357                self.adapter_client.fail_transaction();
2358            }
2359        };
2360        if is_fatal {
2361            Ok(State::Done)
2362        } else {
2363            Ok(State::Drain)
2364        }
2365    }
2366
2367    #[instrument(level = "debug")]
2368    async fn aborted_txn_error(&mut self) -> Result<State, io::Error> {
2369        self.send(BackendMessage::ErrorResponse(ErrorResponse::error(
2370            SqlState::IN_FAILED_SQL_TRANSACTION,
2371            ABORTED_TXN_MSG,
2372        )))
2373        .await?;
2374        Ok(State::Drain)
2375    }
2376
2377    fn is_aborted_txn(&mut self) -> bool {
2378        matches!(
2379            self.adapter_client.session().transaction(),
2380            TransactionStatus::Failed(_)
2381        )
2382    }
2383}
2384
2385fn pad_formats(formats: Vec<Format>, n: usize) -> Result<Vec<Format>, String> {
2386    match (formats.len(), n) {
2387        (0, e) => Ok(vec![Format::Text; e]),
2388        (1, e) => Ok(iter::repeat(formats[0]).take(e).collect()),
2389        (a, e) if a == e => Ok(formats),
2390        (a, e) => Err(format!(
2391            "expected {} field format specifiers, but got {}",
2392            e, a
2393        )),
2394    }
2395}
2396
2397fn describe_rows(stmt_desc: &StatementDesc, formats: &[Format]) -> BackendMessage {
2398    match &stmt_desc.relation_desc {
2399        Some(desc) if !stmt_desc.is_copy => {
2400            BackendMessage::RowDescription(message::encode_row_description(desc, formats))
2401        }
2402        _ => BackendMessage::NoData,
2403    }
2404}
2405
2406type GetResponse = fn(
2407    max_rows: ExecuteCount,
2408    total_sent_rows: usize,
2409    fetch_portal: Option<&mut Portal>,
2410) -> BackendMessage;
2411
2412// A GetResponse used by send_rows during execute messages on portals or for
2413// simple query messages.
2414fn portal_exec_message(
2415    max_rows: ExecuteCount,
2416    total_sent_rows: usize,
2417    _fetch_portal: Option<&mut Portal>,
2418) -> BackendMessage {
2419    // If max_rows is not specified, we will always send back a CommandComplete. If
2420    // max_rows is specified, we only send CommandComplete if there were more rows
2421    // requested than were remaining. That is, if max_rows == number of rows that
2422    // were remaining before sending (not that are remaining after sending), then
2423    // we still send a PortalSuspended. The number of remaining rows after the rows
2424    // have been sent doesn't matter. This matches postgres.
2425    match max_rows {
2426        ExecuteCount::Count(max_rows) if max_rows <= total_sent_rows => {
2427            BackendMessage::PortalSuspended
2428        }
2429        _ => BackendMessage::CommandComplete {
2430            tag: format!("SELECT {}", total_sent_rows),
2431        },
2432    }
2433}
2434
2435// A GetResponse used by send_rows during FETCH queries.
2436fn fetch_message(
2437    _max_rows: ExecuteCount,
2438    total_sent_rows: usize,
2439    fetch_portal: Option<&mut Portal>,
2440) -> BackendMessage {
2441    let tag = format!("FETCH {}", total_sent_rows);
2442    if let Some(portal) = fetch_portal {
2443        portal.state = PortalState::Completed(Some(tag.clone()));
2444    }
2445    BackendMessage::CommandComplete { tag }
2446}
2447
/// Row limit for a single execution step.
#[derive(Clone, Copy, Debug)]
enum ExecuteCount {
    /// No limit; `portal_exec_message` always completes the command.
    All,
    /// A row limit, compared by `portal_exec_message` against the rows sent.
    Count(usize),
}
2453
2454// See postgres' backend/tcop/postgres.c IsTransactionExitStmt.
2455fn is_txn_exit_stmt(stmt: Option<&Statement<Raw>>) -> bool {
2456    match stmt {
2457        // Add PREPARE to this if we ever support it.
2458        Some(stmt) => matches!(stmt, Statement::Commit(_) | Statement::Rollback(_)),
2459        None => false,
2460    }
2461}
2462
2463#[derive(Debug)]
2464enum FetchResult {
2465    Rows(Option<Box<dyn RowIterator + Send + Sync>>),
2466    Canceled,
2467    Error(String),
2468    Notice(AdapterNotice),
2469}
2470
#[cfg(test)]
mod test {
    use super::*;

    /// `parse_options` turns a startup `options` string into `(key, value)` pairs.
    /// It accepts both the `--key=value` and `-c key=value` spellings and rejects
    /// malformed entries.
    #[mz_ore::test]
    fn test_parse_options() {
        let cases: &[(&str, Result<&[(&str, &str)], ()>)] = &[
            ("", Ok(&[])),
            ("--key", Err(())),
            ("--key=val", Ok(&[("key", "val")])),
            (
                r#"--key=val -ckey2=val2 -c key3=val3 -c key4=val4 -ckey5=val5"#,
                Ok(&[
                    ("key", "val"),
                    ("key2", "val2"),
                    ("key3", "val3"),
                    ("key4", "val4"),
                    ("key5", "val5"),
                ]),
            ),
            (r#"-c\ key=val"#, Ok(&[(" key", "val")])),
            ("--key=val -ckey2 val2", Err(())),
            // Unclear what this should do.
            ("--key=", Ok(&[("key", "")])),
        ];
        for &(input, expected) in cases {
            let expected = expected.map(|pairs| {
                pairs
                    .iter()
                    .map(|&(key, value)| (key.to_owned(), value.to_owned()))
                    .collect()
            });
            assert_eq!(parse_options(input), expected, "input: {}", input);
        }
    }

    /// `parse_option` parses a single `--key=value` or `-ckey=value` token.
    #[mz_ore::test]
    fn test_parse_option() {
        let cases: &[(&str, Result<(&str, &str), ()>)] = &[
            ("", Err(())),
            ("--", Err(())),
            ("--c", Err(())),
            ("a=b", Err(())),
            ("--a=b", Ok(("a", "b"))),
            ("--ca=b", Ok(("ca", "b"))),
            ("-ca=b", Ok(("a", "b"))),
            // Unclear whether this should be an error, but at least test it.
            ("--=", Ok(("", ""))),
        ];
        for &(input, expected) in cases {
            assert_eq!(parse_option(input), expected, "input: {}", input);
        }
    }

    /// `split_options` splits on unescaped whitespace. A backslash escapes the
    /// character that follows it, and a trailing lone backslash is dropped.
    #[mz_ore::test]
    fn test_split_options() {
        let cases: &[(&str, &[&str])] = &[
            ("", &[]),
            ("  ", &[]),
            (" a ", &["a"]),
            ("  ab     cd   ", &["ab", "cd"]),
            (r#"  ab\     cd   "#, &["ab ", "cd"]),
            (r#"  ab\\     cd   "#, &[r#"ab\"#, "cd"]),
            (r#"  ab\\\     cd   "#, &[r#"ab\ "#, "cd"]),
            (r#"  ab\\\ cd   "#, &[r#"ab\ cd"#]),
            (r#"  ab\\\cd   "#, &[r#"ab\cd"#]),
            (r#"a\"#, &["a"]),
            (r#"a\ "#, &["a "]),
            (r#"\"#, &[]),
            (r#"\ "#, &[r#" "#]),
            (r#" \ "#, &[r#" "#]),
            (r#"\  "#, &[r#" "#]),
        ];
        for &(input, expected) in cases {
            assert_eq!(split_options(input), expected.to_vec(), "input: {}", input);
        }
    }
}