Skip to main content

mz_environmentd/http/
sql.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::collections::BTreeMap;
11use std::net::{IpAddr, SocketAddr};
12use std::pin::pin;
13use std::sync::Arc;
14use std::time::Duration;
15
16use anyhow::anyhow;
17use async_trait::async_trait;
18use axum::extract::connect_info::ConnectInfo;
19use axum::extract::ws::{CloseFrame, Message, Utf8Bytes, WebSocket};
20use axum::extract::{State, WebSocketUpgrade};
21use axum::response::IntoResponse;
22use axum::{Extension, Json};
23use futures::Future;
24use futures::future::BoxFuture;
25
26use http::StatusCode;
27use itertools::Itertools;
28use mz_adapter::client::RecordFirstRowStream;
29use mz_adapter::session::{EndTransactionAction, TransactionStatus};
30use mz_adapter::statement_logging::{StatementEndedExecutionReason, StatementExecutionStrategy};
31use mz_adapter::{
32    AdapterError, AdapterNotice, ExecuteContextGuard, ExecuteResponse, ExecuteResponseKind,
33    PeekResponseUnary, SessionClient, verify_datum_desc,
34};
35use mz_auth::password::Password;
36use mz_catalog::memory::objects::{Cluster, ClusterReplica};
37use mz_interchange::encode::TypedDatum;
38use mz_interchange::json::{JsonNumberPolicy, ToJson};
39use mz_ore::cast::CastFrom;
40use mz_ore::metrics::{MakeCollectorOpts, MetricsRegistry};
41use mz_ore::result::ResultExt;
42use mz_ore::sql::Sql;
43use mz_repr::{Datum, RelationDesc, RowArena, RowIterator};
44use mz_sql::ast::display::AstDisplay;
45use mz_sql::ast::{CopyDirection, CopyStatement, CopyTarget, Raw, Statement, StatementKind};
46use mz_sql::parse::StatementParseResult;
47use mz_sql::plan::Plan;
48use mz_sql::session::metadata::SessionMetadata;
49use prometheus::Opts;
50use prometheus::core::{AtomicF64, GenericGaugeVec};
51use serde::{Deserialize, Serialize};
52use tokio::{select, time};
53use tokio_postgres::error::SqlState;
54use tokio_stream::wrappers::UnboundedReceiverStream;
55use tower_sessions::Session as TowerSession;
56use tracing::{debug, info};
57use tungstenite::protocol::frame::coding::CloseCode;
58
59use crate::http::prometheus::PrometheusSqlQuery;
60use crate::http::{
61    AuthError, AuthedClient, AuthedUser, MAX_REQUEST_SIZE, WsState, ensure_session_unexpired,
62    init_ws, maybe_get_authenticated_session,
63};
64
65#[derive(Debug, thiserror::Error)]
66pub enum Error {
67    #[error(transparent)]
68    Adapter(#[from] AdapterError),
69    #[error(transparent)]
70    Json(#[from] serde_json::Error),
71    #[error(transparent)]
72    Axum(#[from] axum::Error),
73    #[error("SUBSCRIBE only supported over websocket")]
74    SubscribeOnlyOverWs,
75    #[error("current transaction is aborted, commands ignored until end of transaction block")]
76    AbortedTransaction,
77    #[error("unsupported via this API: {0}")]
78    Unsupported(String),
79    #[error("{0}")]
80    Unstructured(anyhow::Error),
81}
82
83impl Error {
84    pub fn detail(&self) -> Option<String> {
85        match self {
86            Error::Adapter(err) => err.detail(),
87            _ => None,
88        }
89    }
90
91    pub fn hint(&self) -> Option<String> {
92        match self {
93            Error::Adapter(err) => err.hint(),
94            _ => None,
95        }
96    }
97
98    pub fn position(&self) -> Option<usize> {
99        match self {
100            Error::Adapter(err) => err.position(),
101            _ => None,
102        }
103    }
104
105    pub fn code(&self) -> SqlState {
106        match self {
107            Error::Adapter(err) => err.code(),
108            Error::AbortedTransaction => SqlState::IN_FAILED_SQL_TRANSACTION,
109            _ => SqlState::INTERNAL_ERROR,
110        }
111    }
112}
113
114static PER_REPLICA_LABELS: &[&str] = &["replica_full_name", "instance_id", "replica_id"];
115
116async fn execute_promsql_query(
117    client: &mut AuthedClient,
118    query: &PrometheusSqlQuery<'_>,
119    metrics_registry: &MetricsRegistry,
120    metrics_by_name: &mut BTreeMap<String, GenericGaugeVec<AtomicF64>>,
121    cluster: Option<(&Cluster, &ClusterReplica)>,
122) {
123    assert_eq!(query.per_replica, cluster.is_some());
124
125    let mut res = SqlResponse {
126        results: Vec::new(),
127    };
128
129    execute_request(client, query.to_sql_request(cluster), &mut res)
130        .await
131        .expect("valid SQL query");
132
133    let result = match res.results.as_slice() {
134        // Each query issued is preceded by several SET commands
135        // to make sure it is routed to the right cluster replica.
136        [
137            SqlResult::Ok { .. },
138            SqlResult::Ok { .. },
139            SqlResult::Ok { .. },
140            result,
141        ] => result,
142        // Transient errors are fine, like if the cluster or replica
143        // was dropped before the promsql query was executed. We
144        // should not see errors in the steady state.
145        _ => {
146            info!(
147                "error executing prometheus query {}: {:?}",
148                query.metric_name, res
149            );
150            return;
151        }
152    };
153
154    let SqlResult::Rows { desc, rows, .. } = result else {
155        info!(
156            "did not receive rows for SQL query for prometheus metric {}: {:?}, {:?}",
157            query.metric_name, result, cluster
158        );
159        return;
160    };
161
162    let gauge_vec = metrics_by_name
163        .entry(query.metric_name.to_string())
164        .or_insert_with(|| {
165            let mut label_names: Vec<String> = desc
166                .columns
167                .iter()
168                .filter(|col| col.name != query.value_column_name)
169                .map(|col| col.name.clone())
170                .collect();
171
172            if query.per_replica {
173                label_names.extend(PER_REPLICA_LABELS.iter().map(|label| label.to_string()));
174            }
175
176            metrics_registry.register::<GenericGaugeVec<AtomicF64>>(MakeCollectorOpts {
177                opts: Opts::new(query.metric_name, query.help).variable_labels(label_names),
178                buckets: None,
179            })
180        });
181
182    for row in rows {
183        let mut label_values = desc
184            .columns
185            .iter()
186            .zip_eq(row)
187            .filter(|(col, _)| col.name != query.value_column_name)
188            .map(|(_, val)| val.as_str().expect("must be string"))
189            .collect::<Vec<_>>();
190
191        let value = desc
192            .columns
193            .iter()
194            .zip_eq(row)
195            .find(|(col, _)| col.name == query.value_column_name)
196            .map(|(_, val)| val.as_str().unwrap_or("0").parse::<f64>().unwrap_or(0.0))
197            .unwrap_or(0.0);
198
199        match cluster {
200            Some((cluster, replica)) => {
201                let replica_full_name = format!("{}.{}", cluster.name, replica.name);
202                let cluster_id = cluster.id.to_string();
203                let replica_id = replica.replica_id.to_string();
204
205                label_values.push(&replica_full_name);
206                label_values.push(&cluster_id);
207                label_values.push(&replica_id);
208
209                gauge_vec
210                    .get_metric_with_label_values(&label_values)
211                    .expect("valid labels")
212                    .set(value);
213            }
214            None => {
215                gauge_vec
216                    .get_metric_with_label_values(&label_values)
217                    .expect("valid labels")
218                    .set(value);
219            }
220        }
221    }
222}
223
224async fn handle_promsql_query(
225    client: &mut AuthedClient,
226    query: &PrometheusSqlQuery<'_>,
227    metrics_registry: &MetricsRegistry,
228    metrics_by_name: &mut BTreeMap<String, GenericGaugeVec<AtomicF64>>,
229) {
230    if !query.per_replica {
231        execute_promsql_query(client, query, metrics_registry, metrics_by_name, None).await;
232        return;
233    }
234
235    let catalog = client.client.catalog_snapshot("handle_promsql_query").await;
236    let clusters: Vec<&Cluster> = catalog.clusters().collect();
237
238    for cluster in clusters {
239        for replica in cluster.replicas() {
240            execute_promsql_query(
241                client,
242                query,
243                metrics_registry,
244                metrics_by_name,
245                Some((cluster, replica)),
246            )
247            .await;
248        }
249    }
250}
251
252pub async fn handle_promsql(
253    mut client: AuthedClient,
254    queries: &[PrometheusSqlQuery<'_>],
255) -> MetricsRegistry {
256    let metrics_registry = MetricsRegistry::new();
257    let mut metrics_by_name = BTreeMap::new();
258
259    for query in queries {
260        handle_promsql_query(&mut client, query, &metrics_registry, &mut metrics_by_name).await;
261    }
262
263    metrics_registry
264}
265
266pub async fn handle_sql(
267    mut client: AuthedClient,
268    Json(request): Json<SqlRequest>,
269) -> impl IntoResponse {
270    let mut res = SqlResponse {
271        results: Vec::new(),
272    };
273    // Don't need to worry about timeouts or resetting cancel here because there is always exactly 1
274    // request.
275    match execute_request(&mut client, request, &mut res).await {
276        Ok(()) => Ok(Json(res)),
277        Err(e) => Err((StatusCode::BAD_REQUEST, e.to_string())),
278    }
279}
280
281#[derive(Debug)]
282pub enum ExistingUser {
283    /// An AuthedUser provided by the
284    /// `x_materialize_user_header_auth` middleware
285    XMaterializeUserHeader(AuthedUser),
286    /// An AuthedUser provided by an authenticated session
287    /// established via [`crate::http::handle_login`].
288    Session(AuthedUser),
289}
290
291pub(crate) async fn handle_sql_ws(
292    State(state): State<WsState>,
293    existing_user: Option<Extension<AuthedUser>>,
294    ws: WebSocketUpgrade,
295    ConnectInfo(addr): ConnectInfo<SocketAddr>,
296    tower_session: Option<Extension<TowerSession>>,
297) -> Result<impl IntoResponse, AuthError> {
298    let session = tower_session.map(|Extension(session)| session);
299    // The `x_materialize_user_header_auth` middleware may have already provided the user for us
300    let user = match existing_user {
301        Some(Extension(user)) => Some(ExistingUser::XMaterializeUserHeader(user)),
302        None => {
303            let session = maybe_get_authenticated_session(session.as_ref()).await;
304            if let Some((session, session_data)) = session {
305                let user = ensure_session_unexpired(session, session_data).await?;
306                Some(ExistingUser::Session(user))
307            } else {
308                None
309            }
310        }
311    };
312
313    let addr = Box::new(addr.ip());
314    Ok(ws
315        .max_message_size(MAX_REQUEST_SIZE)
316        .on_upgrade(|ws| async move { run_ws(state, user, *addr, ws).await }))
317}
318
/// Authentication payload for a SQL WebSocket connection: basic
/// user/password credentials, a bearer token, or no credentials at all, each
/// optionally carrying session `options`.
///
/// NOTE(review): presumably parsed by `init_ws` when no user was established
/// before the upgrade — confirm.
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
// `untagged`: serde tries the variants in declaration order and uses the first
// one that deserializes. `OptionsOnly` accepts any object (its only field is
// defaulted and unknown fields are ignored), so it must stay last or it would
// shadow the credential-bearing variants.
#[serde(untagged)]
pub enum WebSocketAuth {
    /// Username and password credentials.
    Basic {
        user: String,
        password: Password,
        /// Session options; empty when omitted.
        #[serde(default)]
        options: BTreeMap<String, String>,
    },
    /// A bearer token.
    Bearer {
        token: String,
        /// Session options; empty when omitted.
        #[serde(default)]
        options: BTreeMap<String, String>,
    },
    /// No credentials, only session options.
    OptionsOnly {
        /// Session options; empty when omitted.
        #[serde(default)]
        options: BTreeMap<String, String>,
    },
}
338
339async fn run_ws(state: WsState, user: Option<ExistingUser>, peer_addr: IpAddr, mut ws: WebSocket) {
340    let mut client = match init_ws(state, user, peer_addr, &mut ws).await {
341        Ok(client) => client,
342        Err(e) => {
343            // We omit most detail from the error message we send to the client, to
344            // avoid giving attackers unnecessary information during auth. AdapterErrors
345            // are safe to return because they're generated after authentication.
346            debug!("WS request failed init: {}", e);
347            let reason: Utf8Bytes = match e.downcast_ref::<AdapterError>() {
348                Some(error) => error.to_string().into(),
349                None => "unauthorized".to_string().into(),
350            };
351            let _ = ws
352                .send(Message::Close(Some(CloseFrame {
353                    code: CloseCode::Protocol.into(),
354                    reason,
355                })))
356                .await;
357            return;
358        }
359    };
360
361    // Successful auth, send startup messages.
362    let mut msgs = Vec::new();
363    let session = client.client.session();
364    for var in session.vars().notify_set() {
365        msgs.push(WebSocketResponse::ParameterStatus(ParameterStatus {
366            name: var.name().to_string(),
367            value: var.value(),
368        }));
369    }
370    msgs.push(WebSocketResponse::BackendKeyData(BackendKeyData {
371        conn_id: session.conn_id().unhandled(),
372        secret_key: session.secret_key(),
373    }));
374    msgs.push(WebSocketResponse::ReadyForQuery(
375        session.transaction_code().into(),
376    ));
377    for msg in msgs {
378        let _ = ws
379            .send(Message::Text(
380                serde_json::to_string(&msg).expect("must serialize").into(),
381            ))
382            .await;
383    }
384
385    // Send any notices that might have been generated on startup.
386    let notices = session.drain_notices();
387    if let Err(err) = forward_notices(&mut ws, notices).await {
388        debug!("failed to forward notices to WebSocket, {err:?}");
389        return;
390    }
391
392    loop {
393        // Handle timeouts first so we don't execute any statements when there's a pending timeout.
394        let msg = select! {
395            biased;
396
397            // `recv_timeout()` is cancel-safe as per it's docs.
398            Some(timeout) = client.client.recv_timeout() => {
399                client.client.terminate().await;
400                // We must wait for the client to send a request before we can send the error
401                // response. Although this isn't the PG wire protocol, we choose to mirror it by
402                // only sending errors as responses to requests.
403                let _ = ws.recv().await;
404                let err = Error::from(AdapterError::from(timeout));
405                let _ = send_ws_response(&mut ws, WebSocketResponse::Error(err.into())).await;
406                return;
407            },
408            message = ws.recv() => message,
409        };
410
411        client.client.remove_idle_in_transaction_session_timeout();
412
413        let msg = match msg {
414            Some(Ok(msg)) => msg,
415            _ => {
416                // client disconnected
417                return;
418            }
419        };
420
421        let req: Result<SqlRequest, Error> = match msg {
422            Message::Text(data) => serde_json::from_str(&data).err_into(),
423            Message::Binary(data) => serde_json::from_slice(&data).err_into(),
424            // Handled automatically by the server.
425            Message::Ping(_) => {
426                continue;
427            }
428            Message::Pong(_) => {
429                continue;
430            }
431            Message::Close(_) => {
432                return;
433            }
434        };
435
436        // Figure out if we need to send an error, any notices, but always the ready message.
437        let err = match run_ws_request(req, &mut client, &mut ws).await {
438            Ok(()) => None,
439            Err(err) => Some(WebSocketResponse::Error(err.into())),
440        };
441
442        // After running our request, there are several messages we need to send in a
443        // specific order.
444        //
445        // Note: we nest these into a closure so we can centralize our error handling
446        // for when sending over the WebSocket fails. We could also use a try {} block
447        // here, but those aren't stabilized yet.
448        let ws_response = || async {
449            // First respond with any error that might have occurred.
450            if let Some(e_resp) = err {
451                send_ws_response(&mut ws, e_resp).await?;
452            }
453
454            // Then forward along any notices we generated.
455            let notices = client.client.session().drain_notices();
456            forward_notices(&mut ws, notices).await?;
457
458            // Finally, respond that we're ready for the next query.
459            let ready =
460                WebSocketResponse::ReadyForQuery(client.client.session().transaction_code().into());
461            send_ws_response(&mut ws, ready).await?;
462
463            Ok::<_, Error>(())
464        };
465
466        if let Err(err) = ws_response().await {
467            debug!("failed to send response over WebSocket, {err:?}");
468            return;
469        }
470    }
471}
472
473async fn run_ws_request(
474    req: Result<SqlRequest, Error>,
475    client: &mut AuthedClient,
476    ws: &mut WebSocket,
477) -> Result<(), Error> {
478    let req = req?;
479    execute_request(client, req, ws).await
480}
481
482/// Sends a single [`WebSocketResponse`] over the provided [`WebSocket`].
483async fn send_ws_response(ws: &mut WebSocket, resp: WebSocketResponse) -> Result<(), Error> {
484    let msg = serde_json::to_string(&resp).unwrap();
485    let msg = Message::Text(msg.into());
486    ws.send(msg).await?;
487
488    Ok(())
489}
490
491/// Forwards a collection of Notices to the provided [`WebSocket`].
492async fn forward_notices(
493    ws: &mut WebSocket,
494    notices: impl IntoIterator<Item = AdapterNotice>,
495) -> Result<(), Error> {
496    let ws_notices = notices.into_iter().map(|notice| {
497        WebSocketResponse::Notice(Notice {
498            message: notice.to_string(),
499            code: notice.code().code().to_string(),
500            severity: notice.severity().as_str().to_lowercase(),
501            detail: notice.detail(),
502            hint: notice.hint(),
503        })
504    });
505
506    for notice in ws_notices {
507        send_ws_response(ws, notice).await?;
508    }
509
510    Ok(())
511}
512
/// A request to execute SQL over HTTP.
///
/// The same shape is also accepted as a text or binary WebSocket message.
/// Deserialization is `untagged`: an object is first tried as
/// [`SqlRequest::Simple`] (which needs a `query` field), then as
/// [`SqlRequest::Extended`] (which needs a `queries` field).
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum SqlRequest {
    /// A simple query request.
    Simple {
        /// A query string containing zero or more queries delimited by
        /// semicolons.
        query: Sql,
    },
    /// An extended query request.
    Extended {
        /// Queries to execute using the extended protocol.
        queries: Vec<ExtendedRequest>,
    },
}
529
/// A request to execute a SQL query using the extended protocol.
#[derive(Serialize, Deserialize, Debug)]
pub struct ExtendedRequest {
    /// A query string containing zero or one queries.
    query: String,
    /// Optional parameters for the query, as text; empty when omitted.
    /// NOTE(review): `None` presumably binds SQL NULL — confirm.
    #[serde(default)]
    params: Vec<Option<String>>,
}
539
540/// The response to a `SqlRequest`.
541#[derive(Debug, Serialize, Deserialize)]
542pub struct SqlResponse {
543    /// The results for each query in the request.
544    pub(in crate::http) results: Vec<SqlResult>,
545}
546
547impl SqlResponse {
548    /// Creates a new empty SqlResponse for collecting results.
549    pub(in crate::http) fn new() -> Self {
550        Self {
551            results: Vec::new(),
552        }
553    }
554}
555
556pub(in crate::http) enum StatementResult {
557    SqlResult(SqlResult),
558    Subscribe {
559        desc: RelationDesc,
560        tag: String,
561        rx: RecordFirstRowStream,
562        ctx_extra: ExecuteContextGuard,
563    },
564}
565
566impl From<SqlResult> for StatementResult {
567    fn from(inner: SqlResult) -> Self {
568        Self::SqlResult(inner)
569    }
570}
571
/// The result of a single query in a [`SqlResponse`].
///
/// Serialized `untagged`, so clients tell the variants apart by their fields:
/// `rows` for [`SqlResult::Rows`], `ok` for [`SqlResult::Ok`], and `error` for
/// [`SqlResult::Err`].
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum SqlResult {
    /// The query returned rows.
    Rows {
        /// The command complete tag.
        tag: String,
        /// The result rows.
        rows: Vec<Vec<serde_json::Value>>,
        /// Information about each column.
        desc: Description,
        /// Any notices generated during execution of the query.
        notices: Vec<Notice>,
    },
    /// The query executed successfully but did not return rows.
    Ok {
        /// The command complete tag.
        ok: String,
        /// Any notices generated during execution of the query.
        notices: Vec<Notice>,
        /// Any parameters that may have changed.
        ///
        /// Note: skip serializing this field in a response if the list of parameters is empty.
        #[serde(skip_serializing_if = "Vec::is_empty")]
        parameters: Vec<ParameterStatus>,
    },
    /// The query returned an error.
    Err {
        /// The error that occurred.
        error: SqlError,
        /// Any notices generated during execution of the query.
        notices: Vec<Notice>,
    },
}
606
607impl SqlResult {
608    /// Convert adapter Row results into the web row result format. Error if the row format does not
609    /// match the expected descriptor.
610    // TODO(aljoscha): Bail when max_result_size is exceeded.
611    async fn rows<S>(
612        sender: &mut S,
613        client: &mut SessionClient,
614        mut rows_stream: RecordFirstRowStream,
615        max_query_result_size: usize,
616        desc: &RelationDesc,
617    ) -> Result<SqlResult, Error>
618    where
619        S: ResultSender,
620    {
621        let mut rows: Vec<Vec<serde_json::Value>> = vec![];
622        let mut datum_vec = mz_repr::DatumVec::new();
623        let types = &desc.typ().column_types;
624
625        let mut query_result_size = 0;
626
627        loop {
628            let peek_response = tokio::select! {
629                notice = client.session().recv_notice(), if S::SUPPORTS_STREAMING_NOTICES => {
630                    sender.emit_streaming_notices(vec![notice]).await?;
631                    continue;
632                }
633                e = sender.connection_error() => return Err(e),
634                r = rows_stream.recv() => {
635                    match r {
636                        Some(r) => r,
637                        None => break,
638                    }
639                },
640            };
641
642            let mut sql_rows = match peek_response {
643                PeekResponseUnary::Rows(rows) => rows,
644                PeekResponseUnary::Error(e) => {
645                    return Ok(SqlResult::err(client, Error::Unstructured(anyhow!(e))));
646                }
647                PeekResponseUnary::DependencyDropped(dep) => {
648                    return Ok(SqlResult::err(client, dep.to_concurrent_dependency_drop()));
649                }
650                PeekResponseUnary::Canceled => {
651                    return Ok(SqlResult::err(client, AdapterError::Canceled));
652                }
653            };
654
655            if let Err(err) = verify_datum_desc(desc, &mut sql_rows) {
656                return Ok(SqlResult::Err {
657                    error: err.into(),
658                    notices: make_notices(client),
659                });
660            }
661
662            while let Some(row) = sql_rows.next() {
663                query_result_size += row.byte_len();
664                if query_result_size > max_query_result_size {
665                    use bytesize::ByteSize;
666                    return Ok(SqlResult::err(
667                        client,
668                        AdapterError::ResultSize(format!(
669                            "result exceeds max size of {}",
670                            ByteSize::b(u64::cast_from(max_query_result_size))
671                        )),
672                    ));
673                }
674
675                let datums = datum_vec.borrow_with(row);
676                rows.push(
677                    datums
678                        .iter()
679                        .enumerate()
680                        .map(|(i, d)| {
681                            TypedDatum::new(*d, &types[i])
682                                .json(&JsonNumberPolicy::ConvertNumberToString)
683                        })
684                        .collect(),
685                );
686            }
687        }
688
689        let tag = format!("SELECT {}", rows.len());
690        Ok(SqlResult::Rows {
691            tag,
692            rows,
693            desc: Description::from(desc),
694            notices: make_notices(client),
695        })
696    }
697
698    fn err(client: &mut SessionClient, error: impl Into<SqlError>) -> SqlResult {
699        SqlResult::Err {
700            error: error.into(),
701            notices: make_notices(client),
702        }
703    }
704
705    fn ok(client: &mut SessionClient, tag: String, params: Vec<ParameterStatus>) -> SqlResult {
706        SqlResult::Ok {
707            ok: tag,
708            parameters: params,
709            notices: make_notices(client),
710        }
711    }
712}
713
714#[derive(Debug, Deserialize, Serialize)]
715pub struct SqlError {
716    pub message: String,
717    pub code: String,
718    #[serde(skip_serializing_if = "Option::is_none")]
719    pub detail: Option<String>,
720    #[serde(skip_serializing_if = "Option::is_none")]
721    pub hint: Option<String>,
722    #[serde(skip_serializing_if = "Option::is_none")]
723    pub position: Option<usize>,
724}
725
726impl From<Error> for SqlError {
727    fn from(err: Error) -> Self {
728        SqlError {
729            message: err.to_string(),
730            code: err.code().code().to_string(),
731            detail: err.detail(),
732            hint: err.hint(),
733            position: err.position(),
734        }
735    }
736}
737
738impl From<AdapterError> for SqlError {
739    fn from(value: AdapterError) -> Self {
740        Error::from(value).into()
741    }
742}
743
/// A message sent from the server to a SQL WebSocket client.
///
/// Serialized adjacently tagged, as `{"type": "<Variant>", "payload": ...}`.
#[derive(Debug, Deserialize, Serialize)]
#[serde(tag = "type", content = "payload")]
pub enum WebSocketResponse {
    /// Sent at startup and after each request once the server is ready for
    /// the next one; the payload is the session's transaction status code.
    ReadyForQuery(String),
    /// A notice generated by the session.
    Notice(Notice),
    /// Column descriptions for a row-returning statement.
    /// NOTE(review): presumably followed by `Row` messages — confirm.
    Rows(Description),
    /// A single result row.
    Row(Vec<serde_json::Value>),
    /// Announces the start of a statement's results.
    CommandStarting(CommandStarting),
    /// Completes a statement.
    /// NOTE(review): payload is presumably the command tag — confirm.
    CommandComplete(String),
    /// An error in response to the current request.
    Error(SqlError),
    /// The current value of a session parameter.
    ParameterStatus(ParameterStatus),
    /// Connection identification, sent at startup.
    BackendKeyData(BackendKeyData),
}
757
758#[derive(Debug, Serialize, Deserialize)]
759pub struct Notice {
760    message: String,
761    code: String,
762    severity: String,
763    #[serde(skip_serializing_if = "Option::is_none")]
764    pub detail: Option<String>,
765    #[serde(skip_serializing_if = "Option::is_none")]
766    pub hint: Option<String>,
767}
768
769impl Notice {
770    pub fn message(&self) -> &str {
771        &self.message
772    }
773}
774
775#[derive(Debug, Serialize, Deserialize)]
776pub struct Description {
777    pub columns: Vec<Column>,
778}
779
780impl From<&RelationDesc> for Description {
781    fn from(desc: &RelationDesc) -> Self {
782        let columns = desc
783            .iter()
784            .map(|(name, typ)| {
785                let pg_type = mz_pgrepr::Type::from(&typ.scalar_type);
786                Column {
787                    name: name.to_string(),
788                    type_oid: pg_type.oid(),
789                    type_len: pg_type.typlen(),
790                    type_mod: pg_type.typmod(),
791                }
792            })
793            .collect();
794        Description { columns }
795    }
796}
797
/// Describes one result column using PostgreSQL type metadata.
#[derive(Debug, Serialize, Deserialize)]
pub struct Column {
    /// The column's name.
    pub name: String,
    /// The OID of the column's PostgreSQL type.
    pub type_oid: u32,
    /// The PostgreSQL type length (`typlen`) of the column's type.
    pub type_len: i16,
    /// The PostgreSQL type modifier (`typmod`) of the column.
    pub type_mod: i32,
}

/// The name and current value of a session parameter.
#[derive(Debug, Serialize, Deserialize)]
pub struct ParameterStatus {
    name: String,
    value: String,
}

/// Identifies the connection to the client: its connection id and secret key.
///
/// NOTE(review): presumably used by clients to issue cancellation requests,
/// as in the PostgreSQL protocol — confirm.
#[derive(Debug, Serialize, Deserialize)]
pub struct BackendKeyData {
    conn_id: u32,
    secret_key: u32,
}

/// Announces the start of a statement's results over a WebSocket.
#[derive(Debug, Serialize, Deserialize)]
pub struct CommandStarting {
    /// Whether the statement produces rows.
    has_rows: bool,
    /// Whether those rows are streamed (as for `SUBSCRIBE`).
    is_streaming: bool,
}
823
/// Trait describing how to transmit a response to a client. HTTP clients
/// accumulate into a Vec and send all at once. WebSocket clients send each
/// message as they occur.
#[async_trait]
pub(in crate::http) trait ResultSender: Send {
    /// Whether [`ResultSender::emit_streaming_notices`] may be called to
    /// forward notices while a statement is still producing rows. Defaults to
    /// `false`.
    const SUPPORTS_STREAMING_NOTICES: bool = false;

    /// Adds a result to the client.
    ///
    /// The first component of the return value is:
    /// - `Err` if sending to the client produced an error and the server
    ///   should disconnect;
    /// - `Ok(Err(()))` if the statement produced an error and should error the
    ///   transaction, but the connection should remain open;
    /// - `Ok(Ok(()))` if the statement succeeded.
    ///
    /// The second component of the return value is `Some` if execution still
    /// needs to be retired for statement logging purposes.
    async fn add_result(
        &mut self,
        client: &mut SessionClient,
        res: StatementResult,
    ) -> (
        Result<Result<(), ()>, Error>,
        Option<(StatementEndedExecutionReason, ExecuteContextGuard)>,
    );

    /// Returns a future that resolves only when the client connection has gone away.
    fn connection_error(&mut self) -> BoxFuture<'_, Error>;
    /// Reports whether the client supports streaming SUBSCRIBE results.
    fn allow_subscribe(&self) -> bool;

    /// Emits streaming notices to the client.
    ///
    /// Must only be called when `SUPPORTS_STREAMING_NOTICES` is `true`, and
    /// implementors that set it must override this method. The default
    /// implementation panics via `unreachable!`.
    async fn emit_streaming_notices(&mut self, _: Vec<AdapterNotice>) -> Result<(), Error> {
        unreachable!("streaming notices marked as unsupported")
    }
}
859
860#[async_trait]
861impl ResultSender for SqlResponse {
862    // The first component of the return value is
863    // Err if sending to the client
864    // produced an error and the server should disconnect. It is Ok(Err) if the statement
865    // produced an error and should error the transaction, but remain connected. It is Ok(Ok(()))
866    // if the statement succeeded.
867    // The second component of the return value is `Some` if execution still
868    // needs to be retired for statement logging purposes.
869    async fn add_result(
870        &mut self,
871        _client: &mut SessionClient,
872        res: StatementResult,
873    ) -> (
874        Result<Result<(), ()>, Error>,
875        Option<(StatementEndedExecutionReason, ExecuteContextGuard)>,
876    ) {
877        let (res, stmt_logging) = match res {
878            StatementResult::SqlResult(res) => {
879                let is_err = matches!(res, SqlResult::Err { .. });
880                self.results.push(res);
881                let res = if is_err { Err(()) } else { Ok(()) };
882                (res, None)
883            }
884            StatementResult::Subscribe { ctx_extra, .. } => {
885                let message = "SUBSCRIBE only supported over websocket";
886                self.results.push(SqlResult::Err {
887                    error: Error::SubscribeOnlyOverWs.into(),
888                    notices: Vec::new(),
889                });
890                (
891                    Err(()),
892                    Some((
893                        StatementEndedExecutionReason::Errored {
894                            error: message.into(),
895                        },
896                        ctx_extra,
897                    )),
898                )
899            }
900        };
901        (Ok(res), stmt_logging)
902    }
903
904    fn connection_error(&mut self) -> BoxFuture<'_, Error> {
905        Box::pin(futures::future::pending())
906    }
907
908    fn allow_subscribe(&self) -> bool {
909        false
910    }
911}
912
#[async_trait]
impl ResultSender for WebSocket {
    // Notices raised while a statement is being awaited are forwarded to the client as they
    // arrive (via `emit_streaming_notices` below) instead of only after the statement ends.
    const SUPPORTS_STREAMING_NOTICES: bool = true;

    // Streams one statement result to the client as a sequence of websocket messages:
    // `CommandStarting` first; then, for row-returning statements, `Rows` (the description)
    // followed by one `Row` message per row; and finally the trailing messages collected in
    // `msgs` (`CommandComplete` or `Error`, plus any notices and parameter statuses).
    //
    // The first component of the return value is Err if sending to the client produced an error and
    // the server should disconnect. It is Ok(Err) if the statement produced an error and should
    // error the transaction, but remain connected. It is Ok(Ok(())) if the statement succeeded. The
    // second component of the return value is `Some` if execution still needs to be retired for
    // statement logging purposes.
    async fn add_result(
        &mut self,
        client: &mut SessionClient,
        res: StatementResult,
    ) -> (
        Result<Result<(), ()>, Error>,
        Option<(StatementEndedExecutionReason, ExecuteContextGuard)>,
    ) {
        // Announce the shape of the result before any payload. Only SUBSCRIBE is streaming.
        // These patterns bind nothing, so `res` is not moved here.
        let (has_rows, is_streaming) = match res {
            StatementResult::SqlResult(SqlResult::Err { .. }) => (false, false),
            StatementResult::SqlResult(SqlResult::Ok { .. }) => (false, false),
            StatementResult::SqlResult(SqlResult::Rows { .. }) => (true, false),
            StatementResult::Subscribe { .. } => (true, true),
        };
        if let Err(e) = send_ws_response(
            self,
            WebSocketResponse::CommandStarting(CommandStarting {
                has_rows,
                is_streaming,
            }),
        )
        .await
        {
            return (Err(e), None);
        }

        // Row payloads are sent inside the match arms as they are produced; `msgs` holds the
        // trailing messages that are sent after the match.
        let (is_err, msgs, stmt_logging) = match res {
            StatementResult::SqlResult(SqlResult::Rows {
                tag,
                rows,
                desc,
                notices,
            }) => {
                // Stream rows directly to avoid buffering all rows as
                // WebSocketResponse, which inflates memory due to enum sizing.
                if let Err(e) = send_ws_response(self, WebSocketResponse::Rows(desc)).await {
                    return (Err(e), None);
                }
                for row in rows {
                    if let Err(e) = send_ws_response(self, WebSocketResponse::Row(row)).await {
                        return (Err(e), None);
                    }
                }
                let mut msgs = vec![WebSocketResponse::CommandComplete(tag)];
                msgs.extend(notices.into_iter().map(WebSocketResponse::Notice));
                (false, msgs, None)
            }
            StatementResult::SqlResult(SqlResult::Ok {
                ok,
                parameters,
                notices,
            }) => {
                let mut msgs = vec![WebSocketResponse::CommandComplete(ok)];
                msgs.extend(notices.into_iter().map(WebSocketResponse::Notice));
                msgs.extend(
                    parameters
                        .into_iter()
                        .map(WebSocketResponse::ParameterStatus),
                );
                (false, msgs, None)
            }
            StatementResult::SqlResult(SqlResult::Err { error, notices }) => {
                let mut msgs = vec![WebSocketResponse::Error(error)];
                msgs.extend(notices.into_iter().map(WebSocketResponse::Notice));
                (true, msgs, None)
            }
            StatementResult::Subscribe {
                ref desc,
                tag,
                mut rx,
                ctx_extra,
            } => {
                if let Err(e) = send_ws_response(self, WebSocketResponse::Rows(desc.into())).await {
                    // We consider the remote breaking the connection to be a cancellation,
                    // matching the behavior for pgwire
                    return (
                        Err(e),
                        Some((StatementEndedExecutionReason::Canceled, ctx_extra)),
                    );
                }

                // Declared once and reused for every row below.
                let mut datum_vec = mz_repr::DatumVec::new();
                // Running totals reported in the statement-logging `Success` reason.
                let mut result_size: usize = 0;
                let mut rows_returned = 0;
                loop {
                    // Wait for the next batch while forwarding session notices and watching
                    // for a broken connection (see `await_rows`).
                    let res = match await_rows(self, client, rx.recv()).await {
                        Ok(res) => res,
                        Err(e) => {
                            // We consider the remote breaking the connection to be a cancellation,
                            // matching the behavior for pgwire
                            return (
                                Err(e),
                                Some((StatementEndedExecutionReason::Canceled, ctx_extra)),
                            );
                        }
                    };
                    match res {
                        Some(PeekResponseUnary::Rows(mut rows)) => {
                            // Validate the batch against `desc`; on failure the SUBSCRIBE ends
                            // with an error.
                            if let Err(err) = verify_datum_desc(desc, &mut rows) {
                                let error = err.to_string();
                                break (
                                    true,
                                    vec![WebSocketResponse::Error(err.into())],
                                    Some((
                                        StatementEndedExecutionReason::Errored { error },
                                        ctx_extra,
                                    )),
                                );
                            }

                            // The count is read before the rows are drained with `next()`.
                            // NOTE(review): this relies on `count()` not advancing the
                            // iterator — confirm against `RowIterator`.
                            rows_returned += rows.count();
                            while let Some(row) = rows.next() {
                                result_size += row.byte_len();
                                let datums = datum_vec.borrow_with(row);
                                let types = &desc.typ().column_types;
                                // Each datum is JSON-encoded with its column type; numbers are
                                // rendered as strings (`ConvertNumberToString`).
                                if let Err(e) = send_ws_response(
                                    self,
                                    WebSocketResponse::Row(
                                        datums
                                            .iter()
                                            .enumerate()
                                            .map(|(i, d)| {
                                                TypedDatum::new(*d, &types[i])
                                                    .json(&JsonNumberPolicy::ConvertNumberToString)
                                            })
                                            .collect(),
                                    ),
                                )
                                .await
                                {
                                    // We consider the remote breaking the connection to be a cancellation,
                                    // matching the behavior for pgwire
                                    return (
                                        Err(e),
                                        Some((StatementEndedExecutionReason::Canceled, ctx_extra)),
                                    );
                                }
                            }
                        }
                        Some(PeekResponseUnary::Error(error)) => {
                            break (
                                true,
                                vec![WebSocketResponse::Error(
                                    Error::Unstructured(anyhow!(error.clone())).into(),
                                )],
                                Some((StatementEndedExecutionReason::Errored { error }, ctx_extra)),
                            );
                        }
                        Some(PeekResponseUnary::DependencyDropped(dep)) => {
                            let err = dep.to_concurrent_dependency_drop();
                            let error = err.to_string();
                            break (
                                true,
                                vec![WebSocketResponse::Error(err.into())],
                                Some((StatementEndedExecutionReason::Errored { error }, ctx_extra)),
                            );
                        }
                        Some(PeekResponseUnary::Canceled) => {
                            break (
                                true,
                                vec![WebSocketResponse::Error(AdapterError::Canceled.into())],
                                Some((StatementEndedExecutionReason::Canceled, ctx_extra)),
                            );
                        }
                        // `None` from `recv` is treated as the SUBSCRIBE finishing successfully.
                        None => {
                            break (
                                false,
                                vec![WebSocketResponse::CommandComplete(tag)],
                                Some((
                                    StatementEndedExecutionReason::Success {
                                        result_size: Some(u64::cast_from(result_size)),
                                        rows_returned: Some(u64::cast_from(rows_returned)),
                                        execution_strategy: Some(
                                            StatementExecutionStrategy::Standard,
                                        ),
                                    },
                                    ctx_extra,
                                )),
                            );
                        }
                    }
                }
            }
        };
        for msg in msgs {
            if let Err(e) = send_ws_response(self, msg).await {
                // The client went away before the trailing messages were delivered; any
                // pending statement-logging reason is replaced with `Canceled`.
                return (
                    Err(e),
                    stmt_logging.map(|(_old_reason, ctx_extra)| {
                        (StatementEndedExecutionReason::Canceled, ctx_extra)
                    }),
                );
            }
        }
        (Ok(if is_err { Err(()) } else { Ok(()) }), stmt_logging)
    }

    // Send a websocket Ping every second to verify the client is still
    // connected. The interval's first tick is awaited up front (with tokio intervals it
    // completes immediately — assumes `time` is `tokio::time`), so the first Ping goes out
    // after one full period. Resolves only when sending a Ping fails.
    fn connection_error(&mut self) -> BoxFuture<'_, Error> {
        Box::pin(async {
            let mut tick = time::interval(Duration::from_secs(1));
            tick.tick().await;
            loop {
                tick.tick().await;
                if let Err(err) = self.send(Message::Ping(Vec::new().into())).await {
                    return err.into();
                }
            }
        })
    }

    // Websocket clients can receive SUBSCRIBE results incrementally.
    fn allow_subscribe(&self) -> bool {
        true
    }

    // Sends the given notices to the client immediately via `forward_notices`.
    async fn emit_streaming_notices(&mut self, notices: Vec<AdapterNotice>) -> Result<(), Error> {
        forward_notices(self, notices).await
    }
}
1142
1143async fn await_rows<S, F, R>(sender: &mut S, client: &mut SessionClient, f: F) -> Result<R, Error>
1144where
1145    S: ResultSender,
1146    F: Future<Output = R> + Send,
1147{
1148    let mut f = pin!(f);
1149    loop {
1150        tokio::select! {
1151            notice = client.session().recv_notice(), if S::SUPPORTS_STREAMING_NOTICES => {
1152                sender.emit_streaming_notices(vec![notice]).await?;
1153            }
1154            e = sender.connection_error() => return Err(e),
1155            r = &mut f => return Ok(r),
1156        }
1157    }
1158}
1159
1160async fn send_and_retire<S: ResultSender>(
1161    res: StatementResult,
1162    client: &mut SessionClient,
1163    sender: &mut S,
1164) -> Result<Result<(), ()>, Error> {
1165    let (res, stmt_logging) = sender.add_result(client, res).await;
1166    if let Some((reason, ctx_extra)) = stmt_logging {
1167        client.retire_execute(ctx_extra, reason);
1168    }
1169    res
1170}
1171
1172/// Returns Ok(Err) if any statement error'd during execution.
1173async fn execute_stmt_group<S: ResultSender>(
1174    client: &mut SessionClient,
1175    sender: &mut S,
1176    stmt_group: Vec<(Statement<Raw>, String, Vec<Option<String>>)>,
1177) -> Result<Result<(), ()>, Error> {
1178    let num_stmts = stmt_group.len();
1179    for (stmt, sql, params) in stmt_group {
1180        assert!(
1181            num_stmts <= 1 || params.is_empty(),
1182            "statement groups contain more than 1 statement iff Simple request, which does not support parameters"
1183        );
1184
1185        let is_aborted_txn = matches!(client.session().transaction(), TransactionStatus::Failed(_));
1186        if is_aborted_txn && !is_txn_exit_stmt(&stmt) {
1187            let err = SqlResult::err(client, Error::AbortedTransaction);
1188            let _ = send_and_retire(err.into(), client, sender).await?;
1189            return Ok(Err(()));
1190        }
1191
1192        // Mirror the behavior of the PostgreSQL simple query protocol.
1193        // See the pgwire::protocol::StateMachine::query method for details.
1194        if let Err(e) = client.start_transaction(Some(num_stmts)) {
1195            let err = SqlResult::err(client, e);
1196            let _ = send_and_retire(err.into(), client, sender).await?;
1197            return Ok(Err(()));
1198        }
1199        let res = execute_stmt(client, sender, stmt, sql, params).await?;
1200        let is_err = send_and_retire(res, client, sender).await?;
1201
1202        if is_err.is_err() {
1203            // Mirror StateMachine::error, which sometimes will clean up the
1204            // transaction state instead of always leaving it in Failed.
1205            let txn = client.session().transaction();
1206            match txn {
1207                // Error can be called from describe and parse and so might not be in an active
1208                // transaction.
1209                TransactionStatus::Default | TransactionStatus::Failed(_) => {}
1210                // In Started (i.e., a single statement) and implicit transactions cleanup themselves.
1211                TransactionStatus::Started(_) | TransactionStatus::InTransactionImplicit(_) => {
1212                    if let Err(err) = client.end_transaction(EndTransactionAction::Rollback).await {
1213                        let err = SqlResult::err(client, err);
1214                        let _ = send_and_retire(err.into(), client, sender).await?;
1215                    }
1216                }
1217                // Explicit transactions move to failed.
1218                TransactionStatus::InTransaction(_) => {
1219                    client.fail_transaction();
1220                }
1221            }
1222            return Ok(Err(()));
1223        }
1224    }
1225    Ok(Ok(()))
1226}
1227
1228/// Executes an entire [`SqlRequest`].
1229///
1230/// See the user-facing documentation about the HTTP API for a description of
1231/// the semantics of this function.
1232/// Executes a SQL request and sends results to the provided sender.
1233///
1234/// Made visible to http submodules (like mcp) via `pub(in crate::http)` to allow
1235/// reuse of SQL execution logic.
1236pub(in crate::http) async fn execute_request<S: ResultSender>(
1237    client: &mut AuthedClient,
1238    request: SqlRequest,
1239    sender: &mut S,
1240) -> Result<(), Error> {
1241    let client = &mut client.client;
1242
1243    // This API prohibits executing statements with responses whose
1244    // semantics are at odds with an HTTP response.
1245    fn check_prohibited_stmts<S: ResultSender>(
1246        sender: &S,
1247        stmt: &Statement<Raw>,
1248    ) -> Result<(), Error> {
1249        let kind: StatementKind = stmt.into();
1250        let execute_responses = Plan::generated_from(&kind)
1251            .into_iter()
1252            .map(ExecuteResponse::generated_from)
1253            .flatten()
1254            .collect::<Vec<_>>();
1255
1256        // Special-case `COPY TO` statements that are not `COPY ... TO STDOUT`, since
1257        // StatementKind::Copy links to several `ExecuteResponseKind`s that are not supported,
1258        // but this specific statement should be allowed.
1259        let is_valid_copy = matches!(
1260            stmt,
1261            Statement::Copy(CopyStatement {
1262                direction: CopyDirection::To,
1263                target: CopyTarget::Expr(_),
1264                ..
1265            }) | Statement::Copy(CopyStatement {
1266                direction: CopyDirection::From,
1267                target: CopyTarget::Expr(_),
1268                ..
1269            })
1270        );
1271
1272        if !is_valid_copy
1273            && execute_responses.iter().any(|execute_response| {
1274                // Returns true if a statement or execute response are unsupported.
1275                match execute_response {
1276                    ExecuteResponseKind::Subscribing if sender.allow_subscribe() => false,
1277                    ExecuteResponseKind::Fetch
1278                    | ExecuteResponseKind::Subscribing
1279                    | ExecuteResponseKind::CopyFrom
1280                    | ExecuteResponseKind::DeclaredCursor
1281                    | ExecuteResponseKind::ClosedCursor => true,
1282                    // Various statements generate `PeekPlan` (`SELECT`, `COPY`,
1283                    // `EXPLAIN`, `SHOW`) which has both `SendRows` and `CopyTo` as its
1284                    // possible response types. but `COPY` needs be picked out because
1285                    // http don't support its response type
1286                    ExecuteResponseKind::CopyTo if matches!(kind, StatementKind::Copy) => true,
1287                    _ => false,
1288                }
1289            })
1290        {
1291            return Err(Error::Unsupported(stmt.to_ast_string_simple()));
1292        }
1293        Ok(())
1294    }
1295
1296    fn parse<'a>(
1297        client: &SessionClient,
1298        query: &'a str,
1299    ) -> Result<Vec<StatementParseResult<'a>>, Error> {
1300        let result = client
1301            .parse(query)
1302            .map_err(|e| Error::Unstructured(anyhow!(e)))?;
1303        result.map_err(|e| AdapterError::from(e).into())
1304    }
1305
1306    let mut stmt_groups = vec![];
1307
1308    match request {
1309        SqlRequest::Simple { query } => match parse(client, query.as_str()) {
1310            Ok(stmts) => {
1311                let mut stmt_group = Vec::with_capacity(stmts.len());
1312                let mut stmt_err = None;
1313                for StatementParseResult { ast: stmt, sql } in stmts {
1314                    if let Err(err) = check_prohibited_stmts(sender, &stmt) {
1315                        stmt_err = Some(err);
1316                        break;
1317                    }
1318                    stmt_group.push((stmt, sql.to_string(), vec![]));
1319                }
1320                stmt_groups.push(stmt_err.map(Err).unwrap_or_else(|| Ok(stmt_group)));
1321            }
1322            Err(e) => stmt_groups.push(Err(e)),
1323        },
1324        SqlRequest::Extended { queries } => {
1325            for ExtendedRequest { query, params } in queries {
1326                match parse(client, &query) {
1327                    Ok(mut stmts) => {
1328                        if stmts.len() != 1 {
1329                            return Err(Error::Unstructured(anyhow!(
1330                                "each query must contain exactly 1 statement, but \"{}\" contains {}",
1331                                query,
1332                                stmts.len()
1333                            )));
1334                        }
1335
1336                        let StatementParseResult { ast: stmt, sql } = stmts.pop().unwrap();
1337                        stmt_groups.push(
1338                            check_prohibited_stmts(sender, &stmt)
1339                                .map(|_| vec![(stmt, sql.to_string(), params)]),
1340                        );
1341                    }
1342                    Err(e) => stmt_groups.push(Err(e)),
1343                };
1344            }
1345        }
1346    }
1347
1348    for stmt_group_res in stmt_groups {
1349        let executed = match stmt_group_res {
1350            Ok(stmt_group) => execute_stmt_group(client, sender, stmt_group).await,
1351            Err(e) => {
1352                let err = SqlResult::err(client, e);
1353                let _ = send_and_retire(err.into(), client, sender).await?;
1354                Ok(Err(()))
1355            }
1356        };
1357        // At the end of each group, commit implicit transactions. Do that here so that any `?`
1358        // early return can still be handled here.
1359        if client.session().transaction().is_implicit() {
1360            let ended = client.end_transaction(EndTransactionAction::Commit).await;
1361            if let Err(err) = ended {
1362                let err = SqlResult::err(client, err);
1363                let _ = send_and_retire(StatementResult::SqlResult(err), client, sender).await?;
1364            }
1365        }
1366        if executed?.is_err() {
1367            break;
1368        }
1369    }
1370
1371    Ok(())
1372}
1373
1374/// Executes a single statement in a [`SqlRequest`].
1375async fn execute_stmt<S: ResultSender>(
1376    client: &mut SessionClient,
1377    sender: &mut S,
1378    stmt: Statement<Raw>,
1379    sql: String,
1380    raw_params: Vec<Option<String>>,
1381) -> Result<StatementResult, Error> {
1382    const EMPTY_PORTAL: &str = "";
1383    if let Err(e) = client
1384        .prepare(EMPTY_PORTAL.into(), Some(stmt.clone()), sql, vec![])
1385        .await
1386    {
1387        return Ok(SqlResult::err(client, e).into());
1388    }
1389
1390    let prep_stmt = match client.get_prepared_statement(EMPTY_PORTAL).await {
1391        Ok(stmt) => stmt,
1392        Err(err) => {
1393            return Ok(SqlResult::err(client, err).into());
1394        }
1395    };
1396
1397    let param_types = &prep_stmt.desc().param_types;
1398    if param_types.len() != raw_params.len() {
1399        let message = anyhow!(
1400            "request supplied {actual} parameters, \
1401                        but {statement} requires {expected}",
1402            statement = stmt.to_ast_string_simple(),
1403            actual = raw_params.len(),
1404            expected = param_types.len()
1405        );
1406        return Ok(SqlResult::err(client, Error::Unstructured(message)).into());
1407    }
1408
1409    let buf = RowArena::new();
1410    let mut params = vec![];
1411    for (raw_param, mz_typ) in raw_params.into_iter().zip_eq(param_types) {
1412        let pg_typ = mz_pgrepr::Type::from(mz_typ);
1413        let datum = match raw_param {
1414            None => Datum::Null,
1415            Some(raw_param) => {
1416                match mz_pgrepr::Value::decode(
1417                    mz_pgwire_common::Format::Text,
1418                    &pg_typ,
1419                    raw_param.as_bytes(),
1420                ) {
1421                    Ok(param) => match param.into_datum_decode_error(&buf, &pg_typ, "parameter") {
1422                        Ok(datum) => datum,
1423                        Err(msg) => {
1424                            return Ok(
1425                                SqlResult::err(client, Error::Unstructured(anyhow!(msg))).into()
1426                            );
1427                        }
1428                    },
1429                    Err(err) => {
1430                        let msg = anyhow!("unable to decode parameter: {}", err);
1431                        return Ok(SqlResult::err(client, Error::Unstructured(msg)).into());
1432                    }
1433                }
1434            }
1435        };
1436        params.push((datum, mz_typ.clone()))
1437    }
1438
1439    let result_formats = vec![
1440        mz_pgwire_common::Format::Text;
1441        prep_stmt
1442            .desc()
1443            .relation_desc
1444            .clone()
1445            .map(|desc| desc.typ().column_types.len())
1446            .unwrap_or(0)
1447    ];
1448
1449    let desc = prep_stmt.desc().clone();
1450    let logging = Arc::clone(prep_stmt.logging());
1451    let stmt_ast = prep_stmt.stmt().cloned();
1452    let state_revision = prep_stmt.state_revision;
1453    if let Err(err) = client.session().set_portal(
1454        EMPTY_PORTAL.into(),
1455        desc,
1456        stmt_ast,
1457        logging,
1458        params,
1459        result_formats,
1460        state_revision,
1461    ) {
1462        return Ok(SqlResult::err(client, err).into());
1463    }
1464
1465    let desc = client
1466        .session()
1467        // We do not need to verify here because `client.execute` verifies below.
1468        .get_portal_unverified(EMPTY_PORTAL)
1469        .map(|portal| portal.desc.clone())
1470        .expect("unnamed portal should be present");
1471
1472    let res = client
1473        .execute(EMPTY_PORTAL.into(), futures::future::pending(), None)
1474        .await;
1475
1476    if S::SUPPORTS_STREAMING_NOTICES {
1477        sender
1478            .emit_streaming_notices(client.session().drain_notices())
1479            .await?;
1480    }
1481
1482    let (res, execute_started) = match res {
1483        Ok(res) => res,
1484        Err(e) => {
1485            return Ok(SqlResult::err(client, e).into());
1486        }
1487    };
1488    let tag = res.tag();
1489
1490    Ok(match res {
1491        ExecuteResponse::CreatedConnection { .. }
1492        | ExecuteResponse::CreatedDatabase { .. }
1493        | ExecuteResponse::CreatedSchema { .. }
1494        | ExecuteResponse::CreatedRole
1495        | ExecuteResponse::CreatedCluster { .. }
1496        | ExecuteResponse::CreatedClusterReplica { .. }
1497        | ExecuteResponse::CreatedTable { .. }
1498        | ExecuteResponse::CreatedIndex { .. }
1499        | ExecuteResponse::CreatedIntrospectionSubscribe
1500        | ExecuteResponse::CreatedSecret { .. }
1501        | ExecuteResponse::CreatedSource { .. }
1502        | ExecuteResponse::CreatedSink { .. }
1503        | ExecuteResponse::CreatedView { .. }
1504        | ExecuteResponse::CreatedViews { .. }
1505        | ExecuteResponse::CreatedMaterializedView { .. }
1506        | ExecuteResponse::CreatedType
1507        | ExecuteResponse::CreatedNetworkPolicy
1508        | ExecuteResponse::Comment
1509        | ExecuteResponse::Deleted(_)
1510        | ExecuteResponse::DiscardedTemp
1511        | ExecuteResponse::DiscardedAll
1512        | ExecuteResponse::DroppedObject(_)
1513        | ExecuteResponse::DroppedOwned
1514        | ExecuteResponse::EmptyQuery
1515        | ExecuteResponse::GrantedPrivilege
1516        | ExecuteResponse::GrantedRole
1517        | ExecuteResponse::Inserted(_)
1518        | ExecuteResponse::Copied(_)
1519        | ExecuteResponse::Raised
1520        | ExecuteResponse::ReassignOwned
1521        | ExecuteResponse::RevokedPrivilege
1522        | ExecuteResponse::AlteredDefaultPrivileges
1523        | ExecuteResponse::RevokedRole
1524        | ExecuteResponse::StartedTransaction { .. }
1525        | ExecuteResponse::Updated(_)
1526        | ExecuteResponse::AlteredObject(_)
1527        | ExecuteResponse::AlteredRole
1528        | ExecuteResponse::AlteredSystemConfiguration
1529        | ExecuteResponse::Deallocate { .. }
1530        | ExecuteResponse::ValidatedConnection
1531        | ExecuteResponse::Prepare => SqlResult::ok(
1532            client,
1533            tag.expect("ok only called on tag-generating results"),
1534            Vec::default(),
1535        )
1536        .into(),
1537        ExecuteResponse::TransactionCommitted { params }
1538        | ExecuteResponse::TransactionRolledBack { params } => {
1539            let notify_set: mz_ore::collections::HashSet<_> = client
1540                .session()
1541                .vars()
1542                .notify_set()
1543                .map(|v| v.name().to_string())
1544                .collect();
1545            let params = params
1546                .into_iter()
1547                .filter(|(name, _value)| notify_set.contains(*name))
1548                .map(|(name, value)| ParameterStatus {
1549                    name: name.to_string(),
1550                    value,
1551                })
1552                .collect();
1553            SqlResult::ok(
1554                client,
1555                tag.expect("ok only called on tag-generating results"),
1556                params,
1557            )
1558            .into()
1559        }
1560        ExecuteResponse::SetVariable { name, .. } => {
1561            let mut params = Vec::with_capacity(1);
1562            if let Some(var) = client
1563                .session()
1564                .vars()
1565                .notify_set()
1566                .find(|v| v.name() == &name)
1567            {
1568                params.push(ParameterStatus {
1569                    name,
1570                    value: var.value(),
1571                });
1572            };
1573            SqlResult::ok(
1574                client,
1575                tag.expect("ok only called on tag-generating results"),
1576                params,
1577            )
1578            .into()
1579        }
1580        ExecuteResponse::SendingRowsStreaming {
1581            rows,
1582            instance_id,
1583            strategy,
1584        } => {
1585            let max_query_result_size =
1586                usize::cast_from(client.get_system_vars().await.max_result_size());
1587
1588            let rows_stream = RecordFirstRowStream::new(
1589                Box::new(rows),
1590                execute_started,
1591                client,
1592                Some(instance_id),
1593                Some(strategy),
1594            );
1595
1596            SqlResult::rows(
1597                sender,
1598                client,
1599                rows_stream,
1600                max_query_result_size,
1601                &desc.relation_desc.expect("RelationDesc must exist"),
1602            )
1603            .await?
1604            .into()
1605        }
1606        ExecuteResponse::SendingRowsImmediate { rows } => {
1607            let max_query_result_size =
1608                usize::cast_from(client.get_system_vars().await.max_result_size());
1609
1610            let rows = futures::stream::once(futures::future::ready(PeekResponseUnary::Rows(rows)));
1611            let rows_stream =
1612                RecordFirstRowStream::new(Box::new(rows), execute_started, client, None, None);
1613
1614            SqlResult::rows(
1615                sender,
1616                client,
1617                rows_stream,
1618                max_query_result_size,
1619                &desc.relation_desc.expect("RelationDesc must exist"),
1620            )
1621            .await?
1622            .into()
1623        }
1624        ExecuteResponse::Subscribing {
1625            rx,
1626            ctx_extra,
1627            instance_id,
1628        } => StatementResult::Subscribe {
1629            tag: "SUBSCRIBE".into(),
1630            desc: desc.relation_desc.unwrap(),
1631            rx: RecordFirstRowStream::new(
1632                Box::new(UnboundedReceiverStream::new(rx)),
1633                execute_started,
1634                client,
1635                Some(instance_id),
1636                None,
1637            ),
1638            ctx_extra,
1639        },
1640        res @ (ExecuteResponse::Fetch { .. }
1641        | ExecuteResponse::CopyTo { .. }
1642        | ExecuteResponse::CopyFrom { .. }
1643        | ExecuteResponse::DeclaredCursor
1644        | ExecuteResponse::ClosedCursor) => SqlResult::err(
1645            client,
1646            Error::Unstructured(anyhow!(
1647                "internal error: encountered prohibited ExecuteResponse {:?}.\n\n
1648            This is a bug. Can you please file an bug report letting us know?\n
1649            https://github.com/MaterializeInc/materialize/discussions/new?category=bug-reports",
1650                ExecuteResponseKind::from(res)
1651            )),
1652        )
1653        .into(),
1654    })
1655}
1656
1657fn make_notices(client: &mut SessionClient) -> Vec<Notice> {
1658    client
1659        .session()
1660        .drain_notices()
1661        .into_iter()
1662        .map(|notice| Notice {
1663            message: notice.to_string(),
1664            code: notice.code().code().to_string(),
1665            severity: notice.severity().as_str().to_lowercase(),
1666            detail: notice.detail(),
1667            hint: notice.hint(),
1668        })
1669        .collect()
1670}
1671
1672// Duplicated from protocol.rs.
1673// See postgres' backend/tcop/postgres.c IsTransactionExitStmt.
1674fn is_txn_exit_stmt(stmt: &Statement<Raw>) -> bool {
1675    matches!(
1676        stmt,
1677        Statement::Commit(_) | Statement::Rollback(_) | Statement::Prepare(_)
1678    )
1679}
1680
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use super::{Password, WebSocketAuth};

    /// Checks that each supported JSON shape of the WebSocket auth message
    /// deserializes into the expected `WebSocketAuth` variant.
    #[mz_ore::test]
    fn smoke_test_websocket_auth_parse() {
        // Expected `Basic` value for user `mz` / password `1234` with no options.
        let basic_mz = || WebSocketAuth::Basic {
            user: "mz".to_string(),
            password: Password("1234".to_string()),
            options: BTreeMap::default(),
        };

        let cases = [
            // Username/password with `options` omitted entirely.
            (r#"{ "user": "mz", "password": "1234" }"#, basic_mz()),
            // Username/password with an explicitly empty `options` object.
            (r#"{ "user": "mz", "password": "1234", "options": {} }"#, basic_mz()),
            // Bearer token with `options` omitted.
            (
                r#"{ "token": "i_am_a_token" }"#,
                WebSocketAuth::Bearer {
                    token: "i_am_a_token".to_string(),
                    options: BTreeMap::default(),
                },
            ),
            // Bearer token carrying a single option.
            (
                r#"{ "token": "i_am_a_token", "options": { "foo": "bar" } }"#,
                WebSocketAuth::Bearer {
                    token: "i_am_a_token".to_string(),
                    options: BTreeMap::from([("foo".to_string(), "bar".to_string())]),
                },
            ),
        ];

        for (json, expected) in cases {
            let parsed: WebSocketAuth = serde_json::from_str(json).unwrap();
            assert_eq!(parsed, expected);
        }
    }
}