mz_environmentd/http/
sql.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::borrow::Cow;
11use std::collections::BTreeMap;
12use std::net::{IpAddr, SocketAddr};
13use std::pin::pin;
14use std::sync::Arc;
15use std::time::Duration;
16
17use anyhow::anyhow;
18use async_trait::async_trait;
19use axum::extract::connect_info::ConnectInfo;
20use axum::extract::ws::{CloseFrame, Message, Utf8Bytes, WebSocket};
21use axum::extract::{State, WebSocketUpgrade};
22use axum::response::IntoResponse;
23use axum::{Extension, Json};
24use futures::Future;
25use futures::future::BoxFuture;
26
27use http::StatusCode;
28use itertools::izip;
29use mz_adapter::client::RecordFirstRowStream;
30use mz_adapter::session::{EndTransactionAction, TransactionStatus};
31use mz_adapter::statement_logging::{StatementEndedExecutionReason, StatementExecutionStrategy};
32use mz_adapter::{
33    AdapterError, AdapterNotice, ExecuteContextExtra, ExecuteResponse, ExecuteResponseKind,
34    PeekResponseUnary, SessionClient, verify_datum_desc,
35};
36use mz_catalog::memory::objects::{Cluster, ClusterReplica};
37use mz_interchange::encode::TypedDatum;
38use mz_interchange::json::{JsonNumberPolicy, ToJson};
39use mz_ore::cast::CastFrom;
40use mz_ore::metrics::{MakeCollectorOpts, MetricsRegistry};
41use mz_ore::result::ResultExt;
42use mz_repr::{Datum, RelationDesc, RowArena, RowIterator};
43use mz_sql::ast::display::AstDisplay;
44use mz_sql::ast::{CopyDirection, CopyStatement, CopyTarget, Raw, Statement, StatementKind};
45use mz_sql::parse::StatementParseResult;
46use mz_sql::plan::Plan;
47use mz_sql::session::metadata::SessionMetadata;
48use prometheus::Opts;
49use prometheus::core::{AtomicF64, GenericGaugeVec};
50use serde::{Deserialize, Serialize};
51use tokio::{select, time};
52use tokio_postgres::error::SqlState;
53use tokio_stream::wrappers::UnboundedReceiverStream;
54use tracing::{debug, error, info};
55use tungstenite::protocol::frame::coding::CloseCode;
56
57use crate::http::prometheus::PrometheusSqlQuery;
58use crate::http::{AuthedClient, AuthedUser, MAX_REQUEST_SIZE, WsState, init_ws};
59
60#[derive(Debug, thiserror::Error)]
61pub enum Error {
62    #[error(transparent)]
63    Adapter(#[from] AdapterError),
64    #[error(transparent)]
65    Json(#[from] serde_json::Error),
66    #[error(transparent)]
67    Axum(#[from] axum::Error),
68    #[error("SUBSCRIBE only supported over websocket")]
69    SubscribeOnlyOverWs,
70    #[error("current transaction is aborted, commands ignored until end of transaction block")]
71    AbortedTransaction,
72    #[error("unsupported via this API: {0}")]
73    Unsupported(String),
74    #[error("{0}")]
75    Unstructured(anyhow::Error),
76}
77
78impl Error {
79    pub fn detail(&self) -> Option<String> {
80        match self {
81            Error::Adapter(err) => err.detail(),
82            _ => None,
83        }
84    }
85
86    pub fn hint(&self) -> Option<String> {
87        match self {
88            Error::Adapter(err) => err.hint(),
89            _ => None,
90        }
91    }
92
93    pub fn position(&self) -> Option<usize> {
94        match self {
95            Error::Adapter(err) => err.position(),
96            _ => None,
97        }
98    }
99
100    pub fn code(&self) -> SqlState {
101        match self {
102            Error::Adapter(err) => err.code(),
103            Error::AbortedTransaction => SqlState::IN_FAILED_SQL_TRANSACTION,
104            _ => SqlState::INTERNAL_ERROR,
105        }
106    }
107}
108
/// Extra label names attached to every per-replica Prometheus metric.
///
/// The order matters: label values for per-replica rows are appended in exactly this
/// order (full replica name, cluster id, replica id).
static PER_REPLICA_LABELS: &[&str] = &[
    "replica_full_name",
    "instance_id",
    "replica_id",
];
110
111async fn execute_promsql_query(
112    client: &mut AuthedClient,
113    query: &PrometheusSqlQuery<'_>,
114    metrics_registry: &MetricsRegistry,
115    metrics_by_name: &mut BTreeMap<String, GenericGaugeVec<AtomicF64>>,
116    cluster: Option<(&Cluster, &ClusterReplica)>,
117) {
118    assert_eq!(query.per_replica, cluster.is_some());
119
120    let mut res = SqlResponse {
121        results: Vec::new(),
122    };
123
124    execute_request(client, query.to_sql_request(cluster), &mut res)
125        .await
126        .expect("valid SQL query");
127
128    let result = match res.results.as_slice() {
129        // Each query issued is preceded by several SET commands
130        // to make sure it is routed to the right cluster replica.
131        [
132            SqlResult::Ok { .. },
133            SqlResult::Ok { .. },
134            SqlResult::Ok { .. },
135            result,
136        ] => result,
137        // Transient errors are fine, like if the cluster or replica
138        // was dropped before the promsql query was executed. We
139        // should not see errors in the steady state.
140        _ => {
141            info!(
142                "error executing prometheus query {}: {:?}",
143                query.metric_name, res
144            );
145            return;
146        }
147    };
148
149    let SqlResult::Rows { desc, rows, .. } = result else {
150        info!(
151            "did not receive rows for SQL query for prometheus metric {}: {:?}, {:?}",
152            query.metric_name, result, cluster
153        );
154        return;
155    };
156
157    let gauge_vec = metrics_by_name
158        .entry(query.metric_name.to_string())
159        .or_insert_with(|| {
160            let mut label_names: Vec<String> = desc
161                .columns
162                .iter()
163                .filter(|col| col.name != query.value_column_name)
164                .map(|col| col.name.clone())
165                .collect();
166
167            if query.per_replica {
168                label_names.extend(PER_REPLICA_LABELS.iter().map(|label| label.to_string()));
169            }
170
171            metrics_registry.register::<GenericGaugeVec<AtomicF64>>(MakeCollectorOpts {
172                opts: Opts::new(query.metric_name, query.help).variable_labels(label_names),
173                buckets: None,
174            })
175        });
176
177    for row in rows {
178        let mut label_values = desc
179            .columns
180            .iter()
181            .zip(row)
182            .filter(|(col, _)| col.name != query.value_column_name)
183            .map(|(_, val)| val.as_str().expect("must be string"))
184            .collect::<Vec<_>>();
185
186        let value = desc
187            .columns
188            .iter()
189            .zip(row)
190            .find(|(col, _)| col.name == query.value_column_name)
191            .map(|(_, val)| val.as_str().unwrap_or("0").parse::<f64>().unwrap_or(0.0))
192            .unwrap_or(0.0);
193
194        match cluster {
195            Some((cluster, replica)) => {
196                let replica_full_name = format!("{}.{}", cluster.name, replica.name);
197                let cluster_id = cluster.id.to_string();
198                let replica_id = replica.replica_id.to_string();
199
200                label_values.push(&replica_full_name);
201                label_values.push(&cluster_id);
202                label_values.push(&replica_id);
203
204                gauge_vec
205                    .get_metric_with_label_values(&label_values)
206                    .expect("valid labels")
207                    .set(value);
208            }
209            None => {
210                gauge_vec
211                    .get_metric_with_label_values(&label_values)
212                    .expect("valid labels")
213                    .set(value);
214            }
215        }
216    }
217}
218
219async fn handle_promsql_query(
220    client: &mut AuthedClient,
221    query: &PrometheusSqlQuery<'_>,
222    metrics_registry: &MetricsRegistry,
223    metrics_by_name: &mut BTreeMap<String, GenericGaugeVec<AtomicF64>>,
224) {
225    if !query.per_replica {
226        execute_promsql_query(client, query, metrics_registry, metrics_by_name, None).await;
227        return;
228    }
229
230    let catalog = client.client.catalog_snapshot().await;
231    let clusters: Vec<&Cluster> = catalog.clusters().collect();
232
233    for cluster in clusters {
234        for replica in cluster.replicas() {
235            execute_promsql_query(
236                client,
237                query,
238                metrics_registry,
239                metrics_by_name,
240                Some((cluster, replica)),
241            )
242            .await;
243        }
244    }
245}
246
247pub async fn handle_promsql(
248    mut client: AuthedClient,
249    queries: &[PrometheusSqlQuery<'_>],
250) -> MetricsRegistry {
251    let metrics_registry = MetricsRegistry::new();
252    let mut metrics_by_name = BTreeMap::new();
253
254    for query in queries {
255        handle_promsql_query(&mut client, query, &metrics_registry, &mut metrics_by_name).await;
256    }
257
258    metrics_registry
259}
260
261pub async fn handle_sql(
262    mut client: AuthedClient,
263    Json(request): Json<SqlRequest>,
264) -> impl IntoResponse {
265    let mut res = SqlResponse {
266        results: Vec::new(),
267    };
268    // Don't need to worry about timeouts or resetting cancel here because there is always exactly 1
269    // request.
270    match execute_request(&mut client, request, &mut res).await {
271        Ok(()) => Ok(Json(res)),
272        Err(e) => Err((StatusCode::BAD_REQUEST, e.to_string())),
273    }
274}
275
276pub async fn handle_sql_ws(
277    State(state): State<WsState>,
278    existing_user: Option<Extension<AuthedUser>>,
279    ws: WebSocketUpgrade,
280    ConnectInfo(addr): ConnectInfo<SocketAddr>,
281) -> impl IntoResponse {
282    // An upstream middleware may have already provided the user for us
283    let user = existing_user.and_then(|Extension(user)| Some(user));
284    let addr = Box::new(addr.ip());
285    ws.max_message_size(MAX_REQUEST_SIZE)
286        .on_upgrade(|ws| async move { run_ws(&state, user, *addr, ws).await })
287}
288
289#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
290#[serde(untagged)]
291pub enum WebSocketAuth {
292    Basic {
293        user: String,
294        password: String,
295        #[serde(default)]
296        options: BTreeMap<String, String>,
297    },
298    Bearer {
299        token: String,
300        #[serde(default)]
301        options: BTreeMap<String, String>,
302    },
303    OptionsOnly {
304        #[serde(default)]
305        options: BTreeMap<String, String>,
306    },
307}
308
/// Drives one SQL-over-WebSocket session from authentication until it ends.
///
/// The session proceeds in three stages:
/// 1. Authenticate via `init_ws`. On failure, send a `Close` frame and return.
/// 2. Send startup messages, then any notices produced during startup. The startup
///    messages are:
///    - one `ParameterStatus` per variable in the session's `notify_set`;
///    - `BackendKeyData`;
///    - `ReadyForQuery`.
/// 3. Loop: wait for a request or a timeout from `recv_timeout`, execute the request, then
///    reply with an optional `Error`, any pending notices, and `ReadyForQuery`.
///
/// Returns when any of the following happens:
/// - the client disconnects or sends `Close`;
/// - receiving fails;
/// - a timeout fires;
/// - sending a response fails.
async fn run_ws(state: &WsState, user: Option<AuthedUser>, peer_addr: IpAddr, mut ws: WebSocket) {
    let mut client = match init_ws(state, user, peer_addr, &mut ws).await {
        Ok(client) => client,
        Err(e) => {
            // We omit most detail from the error message we send to the client, to
            // avoid giving attackers unnecessary information during auth. AdapterErrors
            // are safe to return because they're generated after authentication.
            debug!("WS request failed init: {}", e);
            let reason = match e.downcast_ref::<AdapterError>() {
                Some(error) => Cow::Owned(error.to_string()),
                None => "unauthorized".into(),
            };
            // Best-effort: the session is abandoned either way, so a failure to deliver
            // the close frame is ignored.
            let _ = ws
                .send(Message::Close(Some(CloseFrame {
                    code: CloseCode::Protocol.into(),
                    reason: Utf8Bytes::from(reason.as_ref()),
                })))
                .await;
            return;
        }
    };

    // Successful auth, send startup messages.
    let mut msgs = Vec::new();
    let session = client.client.session();
    for var in session.vars().notify_set() {
        msgs.push(WebSocketResponse::ParameterStatus(ParameterStatus {
            name: var.name().to_string(),
            value: var.value(),
        }));
    }
    msgs.push(WebSocketResponse::BackendKeyData(BackendKeyData {
        conn_id: session.conn_id().unhandled(),
        secret_key: session.secret_key(),
    }));
    msgs.push(WebSocketResponse::ReadyForQuery(
        session.transaction_code().into(),
    ));
    // NOTE(review): send errors are ignored here. A broken socket only surfaces later,
    // when forwarding notices or on the next `recv`.
    for msg in msgs {
        let _ = ws
            .send(Message::Text(
                serde_json::to_string(&msg).expect("must serialize").into(),
            ))
            .await;
    }

    // Send any notices that might have been generated on startup.
    let notices = session.drain_notices();
    if let Err(err) = forward_notices(&mut ws, notices).await {
        debug!("failed to forward notices to WebSocket, {err:?}");
        return;
    }

    loop {
        // Handle timeouts first so we don't execute any statements when there's a pending timeout.
        let msg = select! {
            biased;

            // `recv_timeout()` is cancel-safe as per its docs.
            Some(timeout) = client.client.recv_timeout() => {
                client.client.terminate().await;
                // We must wait for the client to send a request before we can send the error
                // response. Although this isn't the PG wire protocol, we choose to mirror it by
                // only sending errors as responses to requests.
                let _ = ws.recv().await;
                let err = Error::from(AdapterError::from(timeout));
                let _ = send_ws_response(&mut ws, WebSocketResponse::Error(err.into())).await;
                return;
            },
            message = ws.recv() => message,
        };

        client.client.remove_idle_in_transaction_session_timeout();

        let msg = match msg {
            Some(Ok(msg)) => msg,
            _ => {
                // The client disconnected (`None`) or receiving failed (`Some(Err(_))`).
                return;
            }
        };

        // A request that fails to decode is not fatal. The decode error is passed through
        // `run_ws_request` and reported to the client as an `Error` response.
        let req: Result<SqlRequest, Error> = match msg {
            Message::Text(data) => serde_json::from_str(&data).err_into(),
            Message::Binary(data) => serde_json::from_slice(&data).err_into(),
            // Handled automatically by the server.
            Message::Ping(_) => {
                continue;
            }
            Message::Pong(_) => {
                continue;
            }
            Message::Close(_) => {
                return;
            }
        };

        // Figure out if we need to send an error, any notices, but always the ready message.
        let err = match run_ws_request(req, &mut client, &mut ws).await {
            Ok(()) => None,
            Err(err) => Some(WebSocketResponse::Error(err.into())),
        };

        // After running our request, there are several messages we need to send in a
        // specific order.
        //
        // Note: we nest these into a closure so we can centralize our error handling
        // for when sending over the WebSocket fails. We could also use a try {} block
        // here, but those aren't stabilized yet.
        let ws_response = || async {
            // First respond with any error that might have occurred.
            if let Some(e_resp) = err {
                send_ws_response(&mut ws, e_resp).await?;
            }

            // Then forward along any notices we generated.
            let notices = client.client.session().drain_notices();
            forward_notices(&mut ws, notices).await?;

            // Finally, respond that we're ready for the next query.
            let ready =
                WebSocketResponse::ReadyForQuery(client.client.session().transaction_code().into());
            send_ws_response(&mut ws, ready).await?;

            Ok::<_, Error>(())
        };

        // A failed send means the connection is unusable; end the session.
        if let Err(err) = ws_response().await {
            debug!("failed to send response over WebSocket, {err:?}");
            return;
        }
    }
}
442
443async fn run_ws_request(
444    req: Result<SqlRequest, Error>,
445    client: &mut AuthedClient,
446    ws: &mut WebSocket,
447) -> Result<(), Error> {
448    let req = req?;
449    execute_request(client, req, ws).await
450}
451
452/// Sends a single [`WebSocketResponse`] over the provided [`WebSocket`].
453async fn send_ws_response(ws: &mut WebSocket, resp: WebSocketResponse) -> Result<(), Error> {
454    let msg = serde_json::to_string(&resp).unwrap();
455    let msg = Message::Text(msg.into());
456    ws.send(msg).await?;
457
458    Ok(())
459}
460
461/// Forwards a collection of Notices to the provided [`WebSocket`].
462async fn forward_notices(
463    ws: &mut WebSocket,
464    notices: impl IntoIterator<Item = AdapterNotice>,
465) -> Result<(), Error> {
466    let ws_notices = notices.into_iter().map(|notice| {
467        WebSocketResponse::Notice(Notice {
468            message: notice.to_string(),
469            code: notice.code().code().to_string(),
470            severity: notice.severity().as_str().to_lowercase(),
471            detail: notice.detail(),
472            hint: notice.hint(),
473        })
474    });
475
476    for notice in ws_notices {
477        send_ws_response(ws, notice).await?;
478    }
479
480    Ok(())
481}
482
483/// A request to execute SQL over HTTP.
484#[derive(Serialize, Deserialize, Debug)]
485#[serde(untagged)]
486pub enum SqlRequest {
487    /// A simple query request.
488    Simple {
489        /// A query string containing zero or more queries delimited by
490        /// semicolons.
491        query: String,
492    },
493    /// An extended query request.
494    Extended {
495        /// Queries to execute using the extended protocol.
496        queries: Vec<ExtendedRequest>,
497    },
498}
499
500/// An request to execute a SQL query using the extended protocol.
501#[derive(Serialize, Deserialize, Debug)]
502pub struct ExtendedRequest {
503    /// A query string containing zero or one queries.
504    query: String,
505    /// Optional parameters for the query.
506    #[serde(default)]
507    params: Vec<Option<String>>,
508}
509
510/// The response to a `SqlRequest`.
511#[derive(Debug, Serialize, Deserialize)]
512pub struct SqlResponse {
513    /// The results for each query in the request.
514    results: Vec<SqlResult>,
515}
516
517enum StatementResult {
518    SqlResult(SqlResult),
519    Subscribe {
520        desc: RelationDesc,
521        tag: String,
522        rx: RecordFirstRowStream,
523        ctx_extra: ExecuteContextExtra,
524    },
525}
526
527impl From<SqlResult> for StatementResult {
528    fn from(inner: SqlResult) -> Self {
529        Self::SqlResult(inner)
530    }
531}
532
533/// The result of a single query in a [`SqlResponse`].
534#[derive(Debug, Serialize, Deserialize)]
535#[serde(untagged)]
536pub enum SqlResult {
537    /// The query returned rows.
538    Rows {
539        /// The command complete tag.
540        tag: String,
541        /// The result rows.
542        rows: Vec<Vec<serde_json::Value>>,
543        /// Information about each column.
544        desc: Description,
545        // Any notices generated during execution of the query.
546        notices: Vec<Notice>,
547    },
548    /// The query executed successfully but did not return rows.
549    Ok {
550        /// The command complete tag.
551        ok: String,
552        /// Any notices generated during execution of the query.
553        notices: Vec<Notice>,
554        /// Any parameters that may have changed.
555        ///
556        /// Note: skip serializing this field in a response if the list of parameters is empty.
557        #[serde(skip_serializing_if = "Vec::is_empty")]
558        parameters: Vec<ParameterStatus>,
559    },
560    /// The query returned an error.
561    Err {
562        error: SqlError,
563        // Any notices generated during execution of the query.
564        notices: Vec<Notice>,
565    },
566}
567
568impl SqlResult {
569    /// Convert adapter Row results into the web row result format. Error if the row format does not
570    /// match the expected descriptor.
571    fn rows(
572        client: &mut SessionClient,
573        mut sql_rows: Box<dyn RowIterator>,
574        desc: &RelationDesc,
575    ) -> SqlResult {
576        if let Err(err) = verify_datum_desc(desc, &mut sql_rows) {
577            return SqlResult::Err {
578                error: err.into(),
579                notices: make_notices(client),
580            };
581        }
582
583        let mut rows: Vec<Vec<serde_json::Value>> = vec![];
584        let mut datum_vec = mz_repr::DatumVec::new();
585        let types = &desc.typ().column_types;
586
587        while let Some(row) = sql_rows.next() {
588            let datums = datum_vec.borrow_with(row);
589            rows.push(
590                datums
591                    .iter()
592                    .enumerate()
593                    .map(|(i, d)| {
594                        TypedDatum::new(*d, &types[i])
595                            .json(&JsonNumberPolicy::ConvertNumberToString)
596                    })
597                    .collect(),
598            );
599        }
600
601        let tag = format!("SELECT {}", rows.len());
602        SqlResult::Rows {
603            tag,
604            rows,
605            desc: Description::from(desc),
606            notices: make_notices(client),
607        }
608    }
609
610    fn err(client: &mut SessionClient, error: impl Into<SqlError>) -> SqlResult {
611        SqlResult::Err {
612            error: error.into(),
613            notices: make_notices(client),
614        }
615    }
616
617    fn ok(client: &mut SessionClient, tag: String, params: Vec<ParameterStatus>) -> SqlResult {
618        SqlResult::Ok {
619            ok: tag,
620            parameters: params,
621            notices: make_notices(client),
622        }
623    }
624}
625
626#[derive(Debug, Deserialize, Serialize)]
627pub struct SqlError {
628    pub message: String,
629    pub code: String,
630    #[serde(skip_serializing_if = "Option::is_none")]
631    pub detail: Option<String>,
632    #[serde(skip_serializing_if = "Option::is_none")]
633    pub hint: Option<String>,
634    #[serde(skip_serializing_if = "Option::is_none")]
635    pub position: Option<usize>,
636}
637
638impl From<Error> for SqlError {
639    fn from(err: Error) -> Self {
640        SqlError {
641            message: err.to_string(),
642            code: err.code().code().to_string(),
643            detail: err.detail(),
644            hint: err.hint(),
645            position: err.position(),
646        }
647    }
648}
649
650impl From<AdapterError> for SqlError {
651    fn from(value: AdapterError) -> Self {
652        Error::from(value).into()
653    }
654}
655
656#[derive(Debug, Deserialize, Serialize)]
657#[serde(tag = "type", content = "payload")]
658pub enum WebSocketResponse {
659    ReadyForQuery(String),
660    Notice(Notice),
661    Rows(Description),
662    Row(Vec<serde_json::Value>),
663    CommandStarting(CommandStarting),
664    CommandComplete(String),
665    Error(SqlError),
666    ParameterStatus(ParameterStatus),
667    BackendKeyData(BackendKeyData),
668}
669
670#[derive(Debug, Serialize, Deserialize)]
671pub struct Notice {
672    message: String,
673    code: String,
674    severity: String,
675    #[serde(skip_serializing_if = "Option::is_none")]
676    pub detail: Option<String>,
677    #[serde(skip_serializing_if = "Option::is_none")]
678    pub hint: Option<String>,
679}
680
681impl Notice {
682    pub fn message(&self) -> &str {
683        &self.message
684    }
685}
686
687#[derive(Debug, Serialize, Deserialize)]
688pub struct Description {
689    pub columns: Vec<Column>,
690}
691
692impl From<&RelationDesc> for Description {
693    fn from(desc: &RelationDesc) -> Self {
694        let columns = desc
695            .iter()
696            .map(|(name, typ)| {
697                let pg_type = mz_pgrepr::Type::from(&typ.scalar_type);
698                Column {
699                    name: name.to_string(),
700                    type_oid: pg_type.oid(),
701                    type_len: pg_type.typlen(),
702                    type_mod: pg_type.typmod(),
703                }
704            })
705            .collect();
706        Description { columns }
707    }
708}
709
710#[derive(Debug, Serialize, Deserialize)]
711pub struct Column {
712    pub name: String,
713    pub type_oid: u32,
714    pub type_len: i16,
715    pub type_mod: i32,
716}
717
718#[derive(Debug, Serialize, Deserialize)]
719pub struct ParameterStatus {
720    name: String,
721    value: String,
722}
723
724#[derive(Debug, Serialize, Deserialize)]
725pub struct BackendKeyData {
726    conn_id: u32,
727    secret_key: u32,
728}
729
730#[derive(Debug, Serialize, Deserialize)]
731pub struct CommandStarting {
732    has_rows: bool,
733    is_streaming: bool,
734}
735
736/// Trait describing how to transmit a response to a client. HTTP clients
737/// accumulate into a Vec and send all at once. WebSocket clients send each
738/// message as they occur.
739#[async_trait]
740trait ResultSender: Send {
741    const SUPPORTS_STREAMING_NOTICES: bool = false;
742
743    /// Adds a result to the client. The first component of the return value is
744    /// Err if sending to the client
745    /// produced an error and the server should disconnect. It is Ok(Err) if the statement
746    /// produced an error and should error the transaction, but remain connected. It is Ok(Ok(()))
747    /// if the statement succeeded.
748    /// The second component of the return value is `Some` if execution still
749    /// needs to be retired for statement logging purposes.
750    async fn add_result(
751        &mut self,
752        client: &mut SessionClient,
753        res: StatementResult,
754    ) -> (
755        Result<Result<(), ()>, Error>,
756        Option<(StatementEndedExecutionReason, ExecuteContextExtra)>,
757    );
758
759    /// Returns a future that resolves only when the client connection has gone away.
760    fn connection_error(&mut self) -> BoxFuture<Error>;
761    /// Reports whether the client supports streaming SUBSCRIBE results.
762    fn allow_subscribe(&self) -> bool;
763
764    /// Emits a streaming notice if the sender supports it.
765    ///
766    /// Does nothing if `SUPPORTS_STREAMING_NOTICES` is false.
767    async fn emit_streaming_notices(&mut self, _: Vec<AdapterNotice>) -> Result<(), Error> {
768        unreachable!("streaming notices marked as unsupported")
769    }
770}
771
772#[async_trait]
773impl ResultSender for SqlResponse {
774    // The first component of the return value is
775    // Err if sending to the client
776    // produced an error and the server should disconnect. It is Ok(Err) if the statement
777    // produced an error and should error the transaction, but remain connected. It is Ok(Ok(()))
778    // if the statement succeeded.
779    // The second component of the return value is `Some` if execution still
780    // needs to be retired for statement logging purposes.
781    async fn add_result(
782        &mut self,
783        _client: &mut SessionClient,
784        res: StatementResult,
785    ) -> (
786        Result<Result<(), ()>, Error>,
787        Option<(StatementEndedExecutionReason, ExecuteContextExtra)>,
788    ) {
789        let (res, stmt_logging) = match res {
790            StatementResult::SqlResult(res) => {
791                let is_err = matches!(res, SqlResult::Err { .. });
792                self.results.push(res);
793                let res = if is_err { Err(()) } else { Ok(()) };
794                (res, None)
795            }
796            StatementResult::Subscribe { ctx_extra, .. } => {
797                let message = "SUBSCRIBE only supported over websocket";
798                self.results.push(SqlResult::Err {
799                    error: Error::SubscribeOnlyOverWs.into(),
800                    notices: Vec::new(),
801                });
802                (
803                    Err(()),
804                    Some((
805                        StatementEndedExecutionReason::Errored {
806                            error: message.into(),
807                        },
808                        ctx_extra,
809                    )),
810                )
811            }
812        };
813        (Ok(res), stmt_logging)
814    }
815
816    fn connection_error(&mut self) -> BoxFuture<Error> {
817        Box::pin(futures::future::pending())
818    }
819
820    fn allow_subscribe(&self) -> bool {
821        false
822    }
823}
824
825#[async_trait]
826impl ResultSender for WebSocket {
827    const SUPPORTS_STREAMING_NOTICES: bool = true;
828
829    // The first component of the return value is Err if sending to the client produced an error and
830    // the server should disconnect. It is Ok(Err) if the statement produced an error and should
831    // error the transaction, but remain connected. It is Ok(Ok(())) if the statement succeeded. The
832    // second component of the return value is `Some` if execution still needs to be retired for
833    // statement logging purposes.
834    async fn add_result(
835        &mut self,
836        client: &mut SessionClient,
837        res: StatementResult,
838    ) -> (
839        Result<Result<(), ()>, Error>,
840        Option<(StatementEndedExecutionReason, ExecuteContextExtra)>,
841    ) {
842        let (has_rows, is_streaming) = match res {
843            StatementResult::SqlResult(SqlResult::Err { .. }) => (false, false),
844            StatementResult::SqlResult(SqlResult::Ok { .. }) => (false, false),
845            StatementResult::SqlResult(SqlResult::Rows { .. }) => (true, false),
846            StatementResult::Subscribe { .. } => (true, true),
847        };
848        if let Err(e) = send_ws_response(
849            self,
850            WebSocketResponse::CommandStarting(CommandStarting {
851                has_rows,
852                is_streaming,
853            }),
854        )
855        .await
856        {
857            return (Err(e), None);
858        }
859
860        let (is_err, msgs, stmt_logging) = match res {
861            StatementResult::SqlResult(SqlResult::Rows {
862                tag,
863                rows,
864                desc,
865                notices,
866            }) => {
867                let mut msgs = vec![WebSocketResponse::Rows(desc)];
868                msgs.extend(rows.into_iter().map(WebSocketResponse::Row));
869                msgs.push(WebSocketResponse::CommandComplete(tag));
870                msgs.extend(notices.into_iter().map(WebSocketResponse::Notice));
871                (false, msgs, None)
872            }
873            StatementResult::SqlResult(SqlResult::Ok {
874                ok,
875                parameters,
876                notices,
877            }) => {
878                let mut msgs = vec![WebSocketResponse::CommandComplete(ok)];
879                msgs.extend(notices.into_iter().map(WebSocketResponse::Notice));
880                msgs.extend(
881                    parameters
882                        .into_iter()
883                        .map(WebSocketResponse::ParameterStatus),
884                );
885                (false, msgs, None)
886            }
887            StatementResult::SqlResult(SqlResult::Err { error, notices }) => {
888                let mut msgs = vec![WebSocketResponse::Error(error)];
889                msgs.extend(notices.into_iter().map(WebSocketResponse::Notice));
890                (true, msgs, None)
891            }
892            StatementResult::Subscribe {
893                ref desc,
894                tag,
895                mut rx,
896                ctx_extra,
897            } => {
898                if let Err(e) = send_ws_response(self, WebSocketResponse::Rows(desc.into())).await {
899                    // We consider the remote breaking the connection to be a cancellation,
900                    // matching the behavior for pgwire
901                    return (
902                        Err(e),
903                        Some((StatementEndedExecutionReason::Canceled, ctx_extra)),
904                    );
905                }
906
907                let mut datum_vec = mz_repr::DatumVec::new();
908                let mut result_size: usize = 0;
909                let mut rows_returned = 0;
910                loop {
911                    let res = match await_rows(self, client, rx.recv()).await {
912                        Ok(res) => res,
913                        Err(e) => {
914                            // We consider the remote breaking the connection to be a cancellation,
915                            // matching the behavior for pgwire
916                            return (
917                                Err(e),
918                                Some((StatementEndedExecutionReason::Canceled, ctx_extra)),
919                            );
920                        }
921                    };
922                    match res {
923                        Some(PeekResponseUnary::Rows(mut rows)) => {
924                            if let Err(err) = verify_datum_desc(desc, &mut rows) {
925                                let error = err.to_string();
926                                break (
927                                    true,
928                                    vec![WebSocketResponse::Error(err.into())],
929                                    Some((
930                                        StatementEndedExecutionReason::Errored { error },
931                                        ctx_extra,
932                                    )),
933                                );
934                            }
935
936                            rows_returned += rows.count();
937                            while let Some(row) = rows.next() {
938                                result_size += row.byte_len();
939                                let datums = datum_vec.borrow_with(row);
940                                let types = &desc.typ().column_types;
941                                if let Err(e) = send_ws_response(
942                                    self,
943                                    WebSocketResponse::Row(
944                                        datums
945                                            .iter()
946                                            .enumerate()
947                                            .map(|(i, d)| {
948                                                TypedDatum::new(*d, &types[i])
949                                                    .json(&JsonNumberPolicy::ConvertNumberToString)
950                                            })
951                                            .collect(),
952                                    ),
953                                )
954                                .await
955                                {
956                                    // We consider the remote breaking the connection to be a cancellation,
957                                    // matching the behavior for pgwire
958                                    return (
959                                        Err(e),
960                                        Some((StatementEndedExecutionReason::Canceled, ctx_extra)),
961                                    );
962                                }
963                            }
964                        }
965                        Some(PeekResponseUnary::Error(error)) => {
966                            break (
967                                true,
968                                vec![WebSocketResponse::Error(
969                                    Error::Unstructured(anyhow!(error.clone())).into(),
970                                )],
971                                Some((StatementEndedExecutionReason::Errored { error }, ctx_extra)),
972                            );
973                        }
974                        Some(PeekResponseUnary::Canceled) => {
975                            break (
976                                true,
977                                vec![WebSocketResponse::Error(AdapterError::Canceled.into())],
978                                Some((StatementEndedExecutionReason::Canceled, ctx_extra)),
979                            );
980                        }
981                        None => {
982                            break (
983                                false,
984                                vec![WebSocketResponse::CommandComplete(tag)],
985                                Some((
986                                    StatementEndedExecutionReason::Success {
987                                        result_size: Some(u64::cast_from(result_size)),
988                                        rows_returned: Some(u64::cast_from(rows_returned)),
989                                        execution_strategy: Some(
990                                            StatementExecutionStrategy::Standard,
991                                        ),
992                                    },
993                                    ctx_extra,
994                                )),
995                            );
996                        }
997                    }
998                }
999            }
1000        };
1001        for msg in msgs {
1002            if let Err(e) = send_ws_response(self, msg).await {
1003                return (
1004                    Err(e),
1005                    stmt_logging.map(|(_old_reason, ctx_extra)| {
1006                        (StatementEndedExecutionReason::Canceled, ctx_extra)
1007                    }),
1008                );
1009            }
1010        }
1011        (Ok(if is_err { Err(()) } else { Ok(()) }), stmt_logging)
1012    }
1013
1014    // Send a websocket Ping every second to verify the client is still
1015    // connected.
1016    fn connection_error(&mut self) -> BoxFuture<Error> {
1017        Box::pin(async {
1018            let mut tick = time::interval(Duration::from_secs(1));
1019            tick.tick().await;
1020            loop {
1021                tick.tick().await;
1022                if let Err(err) = self.send(Message::Ping(Vec::new().into())).await {
1023                    return err.into();
1024                }
1025            }
1026        })
1027    }
1028
1029    fn allow_subscribe(&self) -> bool {
1030        true
1031    }
1032
1033    async fn emit_streaming_notices(&mut self, notices: Vec<AdapterNotice>) -> Result<(), Error> {
1034        forward_notices(self, notices).await
1035    }
1036}
1037
1038async fn await_rows<S, F, R>(sender: &mut S, client: &mut SessionClient, f: F) -> Result<R, Error>
1039where
1040    S: ResultSender,
1041    F: Future<Output = R> + Send,
1042{
1043    let mut f = pin!(f);
1044    loop {
1045        tokio::select! {
1046            notice = client.session().recv_notice(), if S::SUPPORTS_STREAMING_NOTICES => {
1047                sender.emit_streaming_notices(vec![notice]).await?;
1048            }
1049            e = sender.connection_error() => return Err(e),
1050            r = &mut f => return Ok(r),
1051        }
1052    }
1053}
1054
1055async fn send_and_retire<S: ResultSender>(
1056    res: StatementResult,
1057    client: &mut SessionClient,
1058    sender: &mut S,
1059) -> Result<Result<(), ()>, Error> {
1060    let (res, stmt_logging) = sender.add_result(client, res).await;
1061    if let Some((reason, ctx_extra)) = stmt_logging {
1062        client.retire_execute(ctx_extra, reason);
1063    }
1064    res
1065}
1066
1067/// Returns Ok(Err) if any statement error'd during execution.
1068async fn execute_stmt_group<S: ResultSender>(
1069    client: &mut SessionClient,
1070    sender: &mut S,
1071    stmt_group: Vec<(Statement<Raw>, String, Vec<Option<String>>)>,
1072) -> Result<Result<(), ()>, Error> {
1073    let num_stmts = stmt_group.len();
1074    for (stmt, sql, params) in stmt_group {
1075        assert!(
1076            num_stmts <= 1 || params.is_empty(),
1077            "statement groups contain more than 1 statement iff Simple request, which does not support parameters"
1078        );
1079
1080        let is_aborted_txn = matches!(client.session().transaction(), TransactionStatus::Failed(_));
1081        if is_aborted_txn && !is_txn_exit_stmt(&stmt) {
1082            let err = SqlResult::err(client, Error::AbortedTransaction);
1083            let _ = send_and_retire(err.into(), client, sender).await?;
1084            return Ok(Err(()));
1085        }
1086
1087        // Mirror the behavior of the PostgreSQL simple query protocol.
1088        // See the pgwire::protocol::StateMachine::query method for details.
1089        if let Err(e) = client.start_transaction(Some(num_stmts)) {
1090            let err = SqlResult::err(client, e);
1091            let _ = send_and_retire(err.into(), client, sender).await?;
1092            return Ok(Err(()));
1093        }
1094        let res = execute_stmt(client, sender, stmt, sql, params).await?;
1095        let is_err = send_and_retire(res, client, sender).await?;
1096
1097        if is_err.is_err() {
1098            // Mirror StateMachine::error, which sometimes will clean up the
1099            // transaction state instead of always leaving it in Failed.
1100            let txn = client.session().transaction();
1101            match txn {
1102                // Error can be called from describe and parse and so might not be in an active
1103                // transaction.
1104                TransactionStatus::Default | TransactionStatus::Failed(_) => {}
1105                // In Started (i.e., a single statement) and implicit transactions cleanup themselves.
1106                TransactionStatus::Started(_) | TransactionStatus::InTransactionImplicit(_) => {
1107                    if let Err(err) = client.end_transaction(EndTransactionAction::Rollback).await {
1108                        let err = SqlResult::err(client, err);
1109                        let _ = send_and_retire(err.into(), client, sender).await?;
1110                    }
1111                }
1112                // Explicit transactions move to failed.
1113                TransactionStatus::InTransaction(_) => {
1114                    client.fail_transaction();
1115                }
1116            }
1117            return Ok(Err(()));
1118        }
1119    }
1120    Ok(Ok(()))
1121}
1122
1123/// Executes an entire [`SqlRequest`].
1124///
1125/// See the user-facing documentation about the HTTP API for a description of
1126/// the semantics of this function.
1127async fn execute_request<S: ResultSender>(
1128    client: &mut AuthedClient,
1129    request: SqlRequest,
1130    sender: &mut S,
1131) -> Result<(), Error> {
1132    let client = &mut client.client;
1133
1134    // This API prohibits executing statements with responses whose
1135    // semantics are at odds with an HTTP response.
1136    fn check_prohibited_stmts<S: ResultSender>(
1137        sender: &S,
1138        stmt: &Statement<Raw>,
1139    ) -> Result<(), Error> {
1140        let kind: StatementKind = stmt.into();
1141        let execute_responses = Plan::generated_from(&kind)
1142            .into_iter()
1143            .map(ExecuteResponse::generated_from)
1144            .flatten()
1145            .collect::<Vec<_>>();
1146
1147        // Special-case `COPY TO` statements that are not `COPY ... TO STDOUT`, since
1148        // StatementKind::Copy links to several `ExecuteResponseKind`s that are not supported,
1149        // but this specific statement should be allowed.
1150        let is_valid_copy = matches!(
1151            stmt,
1152            Statement::Copy(CopyStatement {
1153                direction: CopyDirection::To,
1154                target: CopyTarget::Expr(_),
1155                ..
1156            }) | Statement::Copy(CopyStatement {
1157                direction: CopyDirection::From,
1158                target: CopyTarget::Expr(_),
1159                ..
1160            })
1161        );
1162
1163        if !is_valid_copy
1164            && execute_responses.iter().any(|execute_response| {
1165                // Returns true if a statement or execute response are unsupported.
1166                match execute_response {
1167                    ExecuteResponseKind::Subscribing if sender.allow_subscribe() => false,
1168                    ExecuteResponseKind::Fetch
1169                    | ExecuteResponseKind::Subscribing
1170                    | ExecuteResponseKind::CopyFrom
1171                    | ExecuteResponseKind::DeclaredCursor
1172                    | ExecuteResponseKind::ClosedCursor => true,
1173                    // Various statements generate `PeekPlan` (`SELECT`, `COPY`,
1174                    // `EXPLAIN`, `SHOW`) which has both `SendRows` and `CopyTo` as its
1175                    // possible response types. but `COPY` needs be picked out because
1176                    // http don't support its response type
1177                    ExecuteResponseKind::CopyTo if matches!(kind, StatementKind::Copy) => true,
1178                    _ => false,
1179                }
1180            })
1181        {
1182            return Err(Error::Unsupported(stmt.to_ast_string_simple()));
1183        }
1184        Ok(())
1185    }
1186
1187    fn parse<'a>(
1188        client: &SessionClient,
1189        query: &'a str,
1190    ) -> Result<Vec<StatementParseResult<'a>>, Error> {
1191        let result = client
1192            .parse(query)
1193            .map_err(|e| Error::Unstructured(anyhow!(e)))?;
1194        result.map_err(|e| AdapterError::from(e).into())
1195    }
1196
1197    let mut stmt_groups = vec![];
1198
1199    match request {
1200        SqlRequest::Simple { query } => match parse(client, &query) {
1201            Ok(stmts) => {
1202                let mut stmt_group = Vec::with_capacity(stmts.len());
1203                let mut stmt_err = None;
1204                for StatementParseResult { ast: stmt, sql } in stmts {
1205                    if let Err(err) = check_prohibited_stmts(sender, &stmt) {
1206                        stmt_err = Some(err);
1207                        break;
1208                    }
1209                    stmt_group.push((stmt, sql.to_string(), vec![]));
1210                }
1211                stmt_groups.push(stmt_err.map(Err).unwrap_or_else(|| Ok(stmt_group)));
1212            }
1213            Err(e) => stmt_groups.push(Err(e)),
1214        },
1215        SqlRequest::Extended { queries } => {
1216            for ExtendedRequest { query, params } in queries {
1217                match parse(client, &query) {
1218                    Ok(mut stmts) => {
1219                        if stmts.len() != 1 {
1220                            return Err(Error::Unstructured(anyhow!(
1221                                "each query must contain exactly 1 statement, but \"{}\" contains {}",
1222                                query,
1223                                stmts.len()
1224                            )));
1225                        }
1226
1227                        let StatementParseResult { ast: stmt, sql } = stmts.pop().unwrap();
1228                        stmt_groups.push(
1229                            check_prohibited_stmts(sender, &stmt)
1230                                .map(|_| vec![(stmt, sql.to_string(), params)]),
1231                        );
1232                    }
1233                    Err(e) => stmt_groups.push(Err(e)),
1234                };
1235            }
1236        }
1237    }
1238
1239    for stmt_group_res in stmt_groups {
1240        let executed = match stmt_group_res {
1241            Ok(stmt_group) => execute_stmt_group(client, sender, stmt_group).await,
1242            Err(e) => {
1243                let err = SqlResult::err(client, e);
1244                let _ = send_and_retire(err.into(), client, sender).await?;
1245                Ok(Err(()))
1246            }
1247        };
1248        // At the end of each group, commit implicit transactions. Do that here so that any `?`
1249        // early return can still be handled here.
1250        if client.session().transaction().is_implicit() {
1251            let ended = client.end_transaction(EndTransactionAction::Commit).await;
1252            if let Err(err) = ended {
1253                let err = SqlResult::err(client, err);
1254                let _ = send_and_retire(StatementResult::SqlResult(err), client, sender).await?;
1255            }
1256        }
1257        if executed?.is_err() {
1258            break;
1259        }
1260    }
1261
1262    Ok(())
1263}
1264
1265/// Executes a single statement in a [`SqlRequest`].
1266async fn execute_stmt<S: ResultSender>(
1267    client: &mut SessionClient,
1268    sender: &mut S,
1269    stmt: Statement<Raw>,
1270    sql: String,
1271    raw_params: Vec<Option<String>>,
1272) -> Result<StatementResult, Error> {
1273    const EMPTY_PORTAL: &str = "";
1274    if let Err(e) = client
1275        .prepare(EMPTY_PORTAL.into(), Some(stmt.clone()), sql, vec![])
1276        .await
1277    {
1278        return Ok(SqlResult::err(client, e).into());
1279    }
1280
1281    let prep_stmt = match client.get_prepared_statement(EMPTY_PORTAL).await {
1282        Ok(stmt) => stmt,
1283        Err(err) => {
1284            return Ok(SqlResult::err(client, err).into());
1285        }
1286    };
1287
1288    let param_types = &prep_stmt.desc().param_types;
1289    if param_types.len() != raw_params.len() {
1290        let message = anyhow!(
1291            "request supplied {actual} parameters, \
1292                        but {statement} requires {expected}",
1293            statement = stmt.to_ast_string_simple(),
1294            actual = raw_params.len(),
1295            expected = param_types.len()
1296        );
1297        return Ok(SqlResult::err(client, Error::Unstructured(message)).into());
1298    }
1299
1300    let buf = RowArena::new();
1301    let mut params = vec![];
1302    for (raw_param, mz_typ) in izip!(raw_params, param_types) {
1303        let pg_typ = mz_pgrepr::Type::from(mz_typ);
1304        let datum = match raw_param {
1305            None => Datum::Null,
1306            Some(raw_param) => {
1307                match mz_pgrepr::Value::decode(
1308                    mz_pgwire_common::Format::Text,
1309                    &pg_typ,
1310                    raw_param.as_bytes(),
1311                ) {
1312                    Ok(param) => param.into_datum(&buf, &pg_typ),
1313                    Err(err) => {
1314                        let msg = anyhow!("unable to decode parameter: {}", err);
1315                        return Ok(SqlResult::err(client, Error::Unstructured(msg)).into());
1316                    }
1317                }
1318            }
1319        };
1320        params.push((datum, mz_typ.clone()))
1321    }
1322
1323    let result_formats = vec![
1324        mz_pgwire_common::Format::Text;
1325        prep_stmt
1326            .desc()
1327            .relation_desc
1328            .clone()
1329            .map(|desc| desc.typ().column_types.len())
1330            .unwrap_or(0)
1331    ];
1332
1333    let desc = prep_stmt.desc().clone();
1334    let revision = prep_stmt.catalog_revision;
1335    let stmt = prep_stmt.stmt().cloned();
1336    let logging = Arc::clone(prep_stmt.logging());
1337    if let Err(err) = client.session().set_portal(
1338        EMPTY_PORTAL.into(),
1339        desc,
1340        stmt,
1341        logging,
1342        params,
1343        result_formats,
1344        revision,
1345    ) {
1346        return Ok(SqlResult::err(client, err).into());
1347    }
1348
1349    let desc = client
1350        .session()
1351        // We do not need to verify here because `client.execute` verifies below.
1352        .get_portal_unverified(EMPTY_PORTAL)
1353        .map(|portal| portal.desc.clone())
1354        .expect("unnamed portal should be present");
1355
1356    let res = client
1357        .execute(EMPTY_PORTAL.into(), futures::future::pending(), None)
1358        .await;
1359
1360    if S::SUPPORTS_STREAMING_NOTICES {
1361        sender
1362            .emit_streaming_notices(client.session().drain_notices())
1363            .await?;
1364    }
1365
1366    let (res, execute_started) = match res {
1367        Ok(res) => res,
1368        Err(e) => {
1369            return Ok(SqlResult::err(client, e).into());
1370        }
1371    };
1372    let tag = res.tag();
1373
1374    Ok(match res {
1375        ExecuteResponse::CreatedConnection { .. }
1376        | ExecuteResponse::CreatedDatabase { .. }
1377        | ExecuteResponse::CreatedSchema { .. }
1378        | ExecuteResponse::CreatedRole
1379        | ExecuteResponse::CreatedCluster { .. }
1380        | ExecuteResponse::CreatedClusterReplica { .. }
1381        | ExecuteResponse::CreatedTable { .. }
1382        | ExecuteResponse::CreatedIndex { .. }
1383        | ExecuteResponse::CreatedIntrospectionSubscribe
1384        | ExecuteResponse::CreatedSecret { .. }
1385        | ExecuteResponse::CreatedSource { .. }
1386        | ExecuteResponse::CreatedSink { .. }
1387        | ExecuteResponse::CreatedView { .. }
1388        | ExecuteResponse::CreatedViews { .. }
1389        | ExecuteResponse::CreatedMaterializedView { .. }
1390        | ExecuteResponse::CreatedContinualTask { .. }
1391        | ExecuteResponse::CreatedType
1392        | ExecuteResponse::CreatedNetworkPolicy
1393        | ExecuteResponse::Comment
1394        | ExecuteResponse::Deleted(_)
1395        | ExecuteResponse::DiscardedTemp
1396        | ExecuteResponse::DiscardedAll
1397        | ExecuteResponse::DroppedObject(_)
1398        | ExecuteResponse::DroppedOwned
1399        | ExecuteResponse::EmptyQuery
1400        | ExecuteResponse::GrantedPrivilege
1401        | ExecuteResponse::GrantedRole
1402        | ExecuteResponse::Inserted(_)
1403        | ExecuteResponse::Copied(_)
1404        | ExecuteResponse::Raised
1405        | ExecuteResponse::ReassignOwned
1406        | ExecuteResponse::RevokedPrivilege
1407        | ExecuteResponse::AlteredDefaultPrivileges
1408        | ExecuteResponse::RevokedRole
1409        | ExecuteResponse::StartedTransaction { .. }
1410        | ExecuteResponse::Updated(_)
1411        | ExecuteResponse::AlteredObject(_)
1412        | ExecuteResponse::AlteredRole
1413        | ExecuteResponse::AlteredSystemConfiguration
1414        | ExecuteResponse::Deallocate { .. }
1415        | ExecuteResponse::ValidatedConnection
1416        | ExecuteResponse::Prepare => SqlResult::ok(
1417            client,
1418            tag.expect("ok only called on tag-generating results"),
1419            Vec::default(),
1420        )
1421        .into(),
1422        ExecuteResponse::TransactionCommitted { params }
1423        | ExecuteResponse::TransactionRolledBack { params } => {
1424            let notify_set: mz_ore::collections::HashSet<_> = client
1425                .session()
1426                .vars()
1427                .notify_set()
1428                .map(|v| v.name().to_string())
1429                .collect();
1430            let params = params
1431                .into_iter()
1432                .filter(|(name, _value)| notify_set.contains(*name))
1433                .map(|(name, value)| ParameterStatus {
1434                    name: name.to_string(),
1435                    value,
1436                })
1437                .collect();
1438            SqlResult::ok(
1439                client,
1440                tag.expect("ok only called on tag-generating results"),
1441                params,
1442            )
1443            .into()
1444        }
1445        ExecuteResponse::SetVariable { name, .. } => {
1446            let mut params = Vec::with_capacity(1);
1447            if let Some(var) = client
1448                .session()
1449                .vars()
1450                .notify_set()
1451                .find(|v| v.name() == &name)
1452            {
1453                params.push(ParameterStatus {
1454                    name,
1455                    value: var.value(),
1456                });
1457            };
1458            SqlResult::ok(
1459                client,
1460                tag.expect("ok only called on tag-generating results"),
1461                params,
1462            )
1463            .into()
1464        }
1465        ExecuteResponse::SendingRows {
1466            future: mut rows,
1467            instance_id,
1468            strategy,
1469        } => {
1470            let rows = match await_rows(sender, client, &mut rows).await? {
1471                PeekResponseUnary::Rows(rows) => {
1472                    RecordFirstRowStream::record(
1473                        execute_started,
1474                        client,
1475                        Some(instance_id),
1476                        Some(strategy),
1477                    );
1478                    rows
1479                }
1480                PeekResponseUnary::Error(e) => {
1481                    return Ok(SqlResult::err(client, Error::Unstructured(anyhow!(e))).into());
1482                }
1483                PeekResponseUnary::Canceled => {
1484                    return Ok(SqlResult::err(client, AdapterError::Canceled).into());
1485                }
1486            };
1487            SqlResult::rows(
1488                client,
1489                rows,
1490                &desc.relation_desc.expect("RelationDesc must exist"),
1491            )
1492            .into()
1493        }
1494        ExecuteResponse::SendingRowsImmediate { rows } => SqlResult::rows(
1495            client,
1496            rows,
1497            &desc.relation_desc.expect("RelationDesc must exist"),
1498        )
1499        .into(),
1500        ExecuteResponse::Subscribing {
1501            rx,
1502            ctx_extra,
1503            instance_id,
1504        } => StatementResult::Subscribe {
1505            tag: "SUBSCRIBE".into(),
1506            desc: desc.relation_desc.unwrap(),
1507            rx: RecordFirstRowStream::new(
1508                Box::new(UnboundedReceiverStream::new(rx)),
1509                execute_started,
1510                client,
1511                Some(instance_id),
1512                None,
1513            ),
1514            ctx_extra,
1515        },
1516        res @ (ExecuteResponse::Fetch { .. }
1517        | ExecuteResponse::CopyTo { .. }
1518        | ExecuteResponse::CopyFrom { .. }
1519        | ExecuteResponse::DeclaredCursor
1520        | ExecuteResponse::ClosedCursor) => SqlResult::err(
1521            client,
1522            Error::Unstructured(anyhow!(
1523                "internal error: encountered prohibited ExecuteResponse {:?}.\n\n
1524            This is a bug. Can you please file an bug report letting us know?\n
1525            https://github.com/MaterializeInc/materialize/discussions/new?category=bug-reports",
1526                ExecuteResponseKind::from(res)
1527            )),
1528        )
1529        .into(),
1530    })
1531}
1532
1533fn make_notices(client: &mut SessionClient) -> Vec<Notice> {
1534    client
1535        .session()
1536        .drain_notices()
1537        .into_iter()
1538        .map(|notice| Notice {
1539            message: notice.to_string(),
1540            code: notice.code().code().to_string(),
1541            severity: notice.severity().as_str().to_lowercase(),
1542            detail: notice.detail(),
1543            hint: notice.hint(),
1544        })
1545        .collect()
1546}
1547
1548// Duplicated from protocol.rs.
1549// See postgres' backend/tcop/postgres.c IsTransactionExitStmt.
1550fn is_txn_exit_stmt(stmt: &Statement<Raw>) -> bool {
1551    matches!(
1552        stmt,
1553        Statement::Commit(_) | Statement::Rollback(_) | Statement::Prepare(_)
1554    )
1555}
1556
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use super::WebSocketAuth;

    /// Every accepted JSON shape for websocket authentication deserializes into
    /// the matching [`WebSocketAuth`] variant, with `options` defaulting to empty.
    #[mz_ore::test]
    fn smoke_test_websocket_auth_parse() {
        let basic_mz = || WebSocketAuth::Basic {
            user: "mz".to_string(),
            password: "1234".to_string(),
            options: BTreeMap::default(),
        };

        let cases: [(&str, WebSocketAuth); 4] = [
            (r#"{ "user": "mz", "password": "1234" }"#, basic_mz()),
            (
                r#"{ "user": "mz", "password": "1234", "options": {} }"#,
                basic_mz(),
            ),
            (
                r#"{ "token": "i_am_a_token" }"#,
                WebSocketAuth::Bearer {
                    token: "i_am_a_token".to_string(),
                    options: BTreeMap::default(),
                },
            ),
            (
                r#"{ "token": "i_am_a_token", "options": { "foo": "bar" } }"#,
                WebSocketAuth::Bearer {
                    token: "i_am_a_token".to_string(),
                    options: BTreeMap::from([("foo".to_string(), "bar".to_string())]),
                },
            ),
        ];

        for (json, expected) in cases {
            let parsed: WebSocketAuth = serde_json::from_str(json).unwrap();
            assert_eq!(parsed, expected);
        }
    }
}