Skip to main content

mz_environmentd/http/
sql.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::collections::BTreeMap;
11use std::net::{IpAddr, SocketAddr};
12use std::pin::pin;
13use std::sync::Arc;
14use std::time::Duration;
15
16use anyhow::anyhow;
17use async_trait::async_trait;
18use axum::extract::connect_info::ConnectInfo;
19use axum::extract::ws::{CloseFrame, Message, Utf8Bytes, WebSocket};
20use axum::extract::{State, WebSocketUpgrade};
21use axum::response::IntoResponse;
22use axum::{Extension, Json};
23use futures::Future;
24use futures::future::BoxFuture;
25
26use http::StatusCode;
27use itertools::Itertools;
28use mz_adapter::client::RecordFirstRowStream;
29use mz_adapter::session::{EndTransactionAction, TransactionStatus};
30use mz_adapter::statement_logging::{StatementEndedExecutionReason, StatementExecutionStrategy};
31use mz_adapter::{
32    AdapterError, AdapterNotice, ExecuteContextGuard, ExecuteResponse, ExecuteResponseKind,
33    PeekResponseUnary, SessionClient, verify_datum_desc,
34};
35use mz_auth::password::Password;
36use mz_catalog::memory::objects::{Cluster, ClusterReplica};
37use mz_interchange::encode::TypedDatum;
38use mz_interchange::json::{JsonNumberPolicy, ToJson};
39use mz_ore::cast::CastFrom;
40use mz_ore::metrics::{MakeCollectorOpts, MetricsRegistry};
41use mz_ore::result::ResultExt;
42use mz_repr::{Datum, RelationDesc, RowArena, RowIterator};
43use mz_sql::ast::display::AstDisplay;
44use mz_sql::ast::{CopyDirection, CopyStatement, CopyTarget, Raw, Statement, StatementKind};
45use mz_sql::parse::StatementParseResult;
46use mz_sql::plan::Plan;
47use mz_sql::session::metadata::SessionMetadata;
48use prometheus::Opts;
49use prometheus::core::{AtomicF64, GenericGaugeVec};
50use serde::{Deserialize, Serialize};
51use tokio::{select, time};
52use tokio_postgres::error::SqlState;
53use tokio_stream::wrappers::UnboundedReceiverStream;
54use tower_sessions::Session as TowerSession;
55use tracing::{debug, info};
56use tungstenite::protocol::frame::coding::CloseCode;
57
58use crate::http::prometheus::PrometheusSqlQuery;
59use crate::http::{
60    AuthError, AuthedClient, AuthedUser, MAX_REQUEST_SIZE, WsState, ensure_session_unexpired,
61    init_ws, maybe_get_authenticated_session,
62};
63
/// Errors that can arise while serving SQL over the HTTP and WebSocket APIs.
///
/// The `#[from]` variants allow `?` to convert adapter, JSON, and axum errors
/// into this type automatically.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// An error reported by the adapter; displayed transparently.
    #[error(transparent)]
    Adapter(#[from] AdapterError),
    /// A JSON (de)serialization failure from `serde_json`.
    #[error(transparent)]
    Json(#[from] serde_json::Error),
    /// An error surfaced by axum, e.g. while sending on a WebSocket.
    #[error(transparent)]
    Axum(#[from] axum::Error),
    /// A `SUBSCRIBE` was attempted on a transport that does not allow it.
    #[error("SUBSCRIBE only supported over websocket")]
    SubscribeOnlyOverWs,
    /// A statement was issued inside a transaction that has already failed.
    #[error("current transaction is aborted, commands ignored until end of transaction block")]
    AbortedTransaction,
    /// The requested operation is not available through this API.
    #[error("unsupported via this API: {0}")]
    Unsupported(String),
    /// A free-form error that carries no SQL-specific structure.
    #[error("{0}")]
    Unstructured(anyhow::Error),
}
81
82impl Error {
83    pub fn detail(&self) -> Option<String> {
84        match self {
85            Error::Adapter(err) => err.detail(),
86            _ => None,
87        }
88    }
89
90    pub fn hint(&self) -> Option<String> {
91        match self {
92            Error::Adapter(err) => err.hint(),
93            _ => None,
94        }
95    }
96
97    pub fn position(&self) -> Option<usize> {
98        match self {
99            Error::Adapter(err) => err.position(),
100            _ => None,
101        }
102    }
103
104    pub fn code(&self) -> SqlState {
105        match self {
106            Error::Adapter(err) => err.code(),
107            Error::AbortedTransaction => SqlState::IN_FAILED_SQL_TRANSACTION,
108            _ => SqlState::INTERNAL_ERROR,
109        }
110    }
111}
112
/// Extra label names appended to every per-replica Prometheus metric.
///
/// The order here must match the order in which `execute_promsql_query`
/// appends the corresponding label values (replica full name, cluster id,
/// replica id).
static PER_REPLICA_LABELS: &[&str] = &["replica_full_name", "instance_id", "replica_id"];
114
115async fn execute_promsql_query(
116    client: &mut AuthedClient,
117    query: &PrometheusSqlQuery<'_>,
118    metrics_registry: &MetricsRegistry,
119    metrics_by_name: &mut BTreeMap<String, GenericGaugeVec<AtomicF64>>,
120    cluster: Option<(&Cluster, &ClusterReplica)>,
121) {
122    assert_eq!(query.per_replica, cluster.is_some());
123
124    let mut res = SqlResponse {
125        results: Vec::new(),
126    };
127
128    execute_request(client, query.to_sql_request(cluster), &mut res)
129        .await
130        .expect("valid SQL query");
131
132    let result = match res.results.as_slice() {
133        // Each query issued is preceded by several SET commands
134        // to make sure it is routed to the right cluster replica.
135        [
136            SqlResult::Ok { .. },
137            SqlResult::Ok { .. },
138            SqlResult::Ok { .. },
139            result,
140        ] => result,
141        // Transient errors are fine, like if the cluster or replica
142        // was dropped before the promsql query was executed. We
143        // should not see errors in the steady state.
144        _ => {
145            info!(
146                "error executing prometheus query {}: {:?}",
147                query.metric_name, res
148            );
149            return;
150        }
151    };
152
153    let SqlResult::Rows { desc, rows, .. } = result else {
154        info!(
155            "did not receive rows for SQL query for prometheus metric {}: {:?}, {:?}",
156            query.metric_name, result, cluster
157        );
158        return;
159    };
160
161    let gauge_vec = metrics_by_name
162        .entry(query.metric_name.to_string())
163        .or_insert_with(|| {
164            let mut label_names: Vec<String> = desc
165                .columns
166                .iter()
167                .filter(|col| col.name != query.value_column_name)
168                .map(|col| col.name.clone())
169                .collect();
170
171            if query.per_replica {
172                label_names.extend(PER_REPLICA_LABELS.iter().map(|label| label.to_string()));
173            }
174
175            metrics_registry.register::<GenericGaugeVec<AtomicF64>>(MakeCollectorOpts {
176                opts: Opts::new(query.metric_name, query.help).variable_labels(label_names),
177                buckets: None,
178            })
179        });
180
181    for row in rows {
182        let mut label_values = desc
183            .columns
184            .iter()
185            .zip_eq(row)
186            .filter(|(col, _)| col.name != query.value_column_name)
187            .map(|(_, val)| val.as_str().expect("must be string"))
188            .collect::<Vec<_>>();
189
190        let value = desc
191            .columns
192            .iter()
193            .zip_eq(row)
194            .find(|(col, _)| col.name == query.value_column_name)
195            .map(|(_, val)| val.as_str().unwrap_or("0").parse::<f64>().unwrap_or(0.0))
196            .unwrap_or(0.0);
197
198        match cluster {
199            Some((cluster, replica)) => {
200                let replica_full_name = format!("{}.{}", cluster.name, replica.name);
201                let cluster_id = cluster.id.to_string();
202                let replica_id = replica.replica_id.to_string();
203
204                label_values.push(&replica_full_name);
205                label_values.push(&cluster_id);
206                label_values.push(&replica_id);
207
208                gauge_vec
209                    .get_metric_with_label_values(&label_values)
210                    .expect("valid labels")
211                    .set(value);
212            }
213            None => {
214                gauge_vec
215                    .get_metric_with_label_values(&label_values)
216                    .expect("valid labels")
217                    .set(value);
218            }
219        }
220    }
221}
222
223async fn handle_promsql_query(
224    client: &mut AuthedClient,
225    query: &PrometheusSqlQuery<'_>,
226    metrics_registry: &MetricsRegistry,
227    metrics_by_name: &mut BTreeMap<String, GenericGaugeVec<AtomicF64>>,
228) {
229    if !query.per_replica {
230        execute_promsql_query(client, query, metrics_registry, metrics_by_name, None).await;
231        return;
232    }
233
234    let catalog = client.client.catalog_snapshot("handle_promsql_query").await;
235    let clusters: Vec<&Cluster> = catalog.clusters().collect();
236
237    for cluster in clusters {
238        for replica in cluster.replicas() {
239            execute_promsql_query(
240                client,
241                query,
242                metrics_registry,
243                metrics_by_name,
244                Some((cluster, replica)),
245            )
246            .await;
247        }
248    }
249}
250
251pub async fn handle_promsql(
252    mut client: AuthedClient,
253    queries: &[PrometheusSqlQuery<'_>],
254) -> MetricsRegistry {
255    let metrics_registry = MetricsRegistry::new();
256    let mut metrics_by_name = BTreeMap::new();
257
258    for query in queries {
259        handle_promsql_query(&mut client, query, &metrics_registry, &mut metrics_by_name).await;
260    }
261
262    metrics_registry
263}
264
265pub async fn handle_sql(
266    mut client: AuthedClient,
267    Json(request): Json<SqlRequest>,
268) -> impl IntoResponse {
269    let mut res = SqlResponse {
270        results: Vec::new(),
271    };
272    // Don't need to worry about timeouts or resetting cancel here because there is always exactly 1
273    // request.
274    match execute_request(&mut client, request, &mut res).await {
275        Ok(()) => Ok(Json(res)),
276        Err(e) => Err((StatusCode::BAD_REQUEST, e.to_string())),
277    }
278}
279
/// A user that was already authenticated before the WebSocket upgrade,
/// tagged with the mechanism that authenticated them.
#[derive(Debug)]
pub enum ExistingUser {
    /// An AuthedUser provided by the
    /// `x_materialize_user_header_auth` middleware
    XMaterializeUserHeader(AuthedUser),
    /// An AuthedUser provided by an authenticated session
    /// established via [`crate::http::handle_login`].
    Session(AuthedUser),
}
289
290pub(crate) async fn handle_sql_ws(
291    State(state): State<WsState>,
292    existing_user: Option<Extension<AuthedUser>>,
293    ws: WebSocketUpgrade,
294    ConnectInfo(addr): ConnectInfo<SocketAddr>,
295    tower_session: Option<Extension<TowerSession>>,
296) -> Result<impl IntoResponse, AuthError> {
297    let session = tower_session.map(|Extension(session)| session);
298    // The `x_materialize_user_header_auth` middleware may have already provided the user for us
299    let user = match existing_user {
300        Some(Extension(user)) => Some(ExistingUser::XMaterializeUserHeader(user)),
301        None => {
302            let session = maybe_get_authenticated_session(session.as_ref()).await;
303            if let Some((session, session_data)) = session {
304                let user = ensure_session_unexpired(session, session_data).await?;
305                Some(ExistingUser::Session(user))
306            } else {
307                None
308            }
309        }
310    };
311
312    let addr = Box::new(addr.ip());
313    Ok(ws
314        .max_message_size(MAX_REQUEST_SIZE)
315        .on_upgrade(|ws| async move { run_ws(state, user, *addr, ws).await }))
316}
317
/// Authentication payload for a WebSocket connection.
///
/// The enum is `#[serde(untagged)]`, so serde tries the variants in
/// declaration order and picks the first one whose fields match. The order
/// of the variants is therefore significant.
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
#[serde(untagged)]
pub enum WebSocketAuth {
    /// Username and password credentials.
    Basic {
        user: String,
        password: Password,
        /// Optional key/value options; defaults to empty when omitted.
        #[serde(default)]
        options: BTreeMap<String, String>,
    },
    /// A bearer token.
    Bearer {
        token: String,
        /// Optional key/value options; defaults to empty when omitted.
        #[serde(default)]
        options: BTreeMap<String, String>,
    },
    /// No credentials, only options. Every field is optional, so this
    /// variant matches any payload the earlier variants did not.
    OptionsOnly {
        /// Optional key/value options; defaults to empty when omitted.
        #[serde(default)]
        options: BTreeMap<String, String>,
    },
}
337
/// Drives a single SQL WebSocket connection from initialization until the
/// client disconnects, a send fails, or a session timeout fires.
///
/// After `init_ws` succeeds, the startup messages are sent in this order: one
/// `ParameterStatus` per variable from `notify_set()`, `BackendKeyData`,
/// `ReadyForQuery`, and then any notices drained from the session. Each
/// subsequent text or binary message is parsed as a [`SqlRequest`], executed,
/// and followed by an optional error, pending notices, and `ReadyForQuery`.
///
/// If `init_ws` fails, a close frame is sent and the function returns.
async fn run_ws(state: WsState, user: Option<ExistingUser>, peer_addr: IpAddr, mut ws: WebSocket) {
    let mut client = match init_ws(state, user, peer_addr, &mut ws).await {
        Ok(client) => client,
        Err(e) => {
            // We omit most detail from the error message we send to the client, to
            // avoid giving attackers unnecessary information during auth. AdapterErrors
            // are safe to return because they're generated after authentication.
            debug!("WS request failed init: {}", e);
            let reason: Utf8Bytes = match e.downcast_ref::<AdapterError>() {
                Some(error) => error.to_string().into(),
                None => "unauthorized".to_string().into(),
            };
            // Best effort: the connection is being abandoned either way.
            let _ = ws
                .send(Message::Close(Some(CloseFrame {
                    code: CloseCode::Protocol.into(),
                    reason,
                })))
                .await;
            return;
        }
    };

    // Successful auth, send startup messages.
    let mut msgs = Vec::new();
    let session = client.client.session();
    for var in session.vars().notify_set() {
        msgs.push(WebSocketResponse::ParameterStatus(ParameterStatus {
            name: var.name().to_string(),
            value: var.value(),
        }));
    }
    msgs.push(WebSocketResponse::BackendKeyData(BackendKeyData {
        conn_id: session.conn_id().unhandled(),
        secret_key: session.secret_key(),
    }));
    msgs.push(WebSocketResponse::ReadyForQuery(
        session.transaction_code().into(),
    ));
    // Send failures are ignored here; a broken connection is detected by the
    // notice forwarding below or by the receive loop.
    for msg in msgs {
        let _ = ws
            .send(Message::Text(
                serde_json::to_string(&msg).expect("must serialize").into(),
            ))
            .await;
    }

    // Send any notices that might have been generated on startup.
    let notices = session.drain_notices();
    if let Err(err) = forward_notices(&mut ws, notices).await {
        debug!("failed to forward notices to WebSocket, {err:?}");
        return;
    }

    loop {
        // Handle timeouts first so we don't execute any statements when there's a pending timeout.
        let msg = select! {
            biased;

            // `recv_timeout()` is cancel-safe as per its docs.
            Some(timeout) = client.client.recv_timeout() => {
                client.client.terminate().await;
                // We must wait for the client to send a request before we can send the error
                // response. Although this isn't the PG wire protocol, we choose to mirror it by
                // only sending errors as responses to requests.
                let _ = ws.recv().await;
                let err = Error::from(AdapterError::from(timeout));
                let _ = send_ws_response(&mut ws, WebSocketResponse::Error(err.into())).await;
                return;
            },
            message = ws.recv() => message,
        };

        client.client.remove_idle_in_transaction_session_timeout();

        let msg = match msg {
            Some(Ok(msg)) => msg,
            _ => {
                // client disconnected
                return;
            }
        };

        // Parse failures are kept as `Err` so they are reported to the client
        // through the normal response path below.
        let req: Result<SqlRequest, Error> = match msg {
            Message::Text(data) => serde_json::from_str(&data).err_into(),
            Message::Binary(data) => serde_json::from_slice(&data).err_into(),
            // Handled automatically by the server.
            Message::Ping(_) => {
                continue;
            }
            Message::Pong(_) => {
                continue;
            }
            Message::Close(_) => {
                return;
            }
        };

        // Figure out if we need to send an error, any notices, but always the ready message.
        let err = match run_ws_request(req, &mut client, &mut ws).await {
            Ok(()) => None,
            Err(err) => Some(WebSocketResponse::Error(err.into())),
        };

        // After running our request, there are several messages we need to send in a
        // specific order.
        //
        // Note: we nest these into a closure so we can centralize our error handling
        // for when sending over the WebSocket fails. We could also use a try {} block
        // here, but those aren't stabilized yet.
        let ws_response = || async {
            // First respond with any error that might have occurred.
            if let Some(e_resp) = err {
                send_ws_response(&mut ws, e_resp).await?;
            }

            // Then forward along any notices we generated.
            let notices = client.client.session().drain_notices();
            forward_notices(&mut ws, notices).await?;

            // Finally, respond that we're ready for the next query.
            let ready =
                WebSocketResponse::ReadyForQuery(client.client.session().transaction_code().into());
            send_ws_response(&mut ws, ready).await?;

            Ok::<_, Error>(())
        };

        if let Err(err) = ws_response().await {
            debug!("failed to send response over WebSocket, {err:?}");
            return;
        }
    }
}
471
472async fn run_ws_request(
473    req: Result<SqlRequest, Error>,
474    client: &mut AuthedClient,
475    ws: &mut WebSocket,
476) -> Result<(), Error> {
477    let req = req?;
478    execute_request(client, req, ws).await
479}
480
481/// Sends a single [`WebSocketResponse`] over the provided [`WebSocket`].
482async fn send_ws_response(ws: &mut WebSocket, resp: WebSocketResponse) -> Result<(), Error> {
483    let msg = serde_json::to_string(&resp).unwrap();
484    let msg = Message::Text(msg.into());
485    ws.send(msg).await?;
486
487    Ok(())
488}
489
490/// Forwards a collection of Notices to the provided [`WebSocket`].
491async fn forward_notices(
492    ws: &mut WebSocket,
493    notices: impl IntoIterator<Item = AdapterNotice>,
494) -> Result<(), Error> {
495    let ws_notices = notices.into_iter().map(|notice| {
496        WebSocketResponse::Notice(Notice {
497            message: notice.to_string(),
498            code: notice.code().code().to_string(),
499            severity: notice.severity().as_str().to_lowercase(),
500            detail: notice.detail(),
501            hint: notice.hint(),
502        })
503    });
504
505    for notice in ws_notices {
506        send_ws_response(ws, notice).await?;
507    }
508
509    Ok(())
510}
511
/// A request to execute SQL over HTTP.
///
/// The enum is `#[serde(untagged)]`: a JSON object with a `query` field
/// deserializes as [`SqlRequest::Simple`], and one with a `queries` field as
/// [`SqlRequest::Extended`].
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum SqlRequest {
    /// A simple query request.
    Simple {
        /// A query string containing zero or more queries delimited by
        /// semicolons.
        query: String,
    },
    /// An extended query request.
    Extended {
        /// Queries to execute using the extended protocol.
        queries: Vec<ExtendedRequest>,
    },
}
528
/// A request to execute a SQL query using the extended protocol.
#[derive(Serialize, Deserialize, Debug)]
pub struct ExtendedRequest {
    /// A query string containing zero or one queries.
    query: String,
    /// Optional parameters for the query. Defaults to an empty list when the
    /// field is omitted; a `None` entry is a JSON `null` parameter.
    #[serde(default)]
    params: Vec<Option<String>>,
}
538
/// The response to a `SqlRequest`.
///
/// Also acts as the [`ResultSender`] for plain HTTP requests, accumulating
/// results until the whole response is sent at once.
#[derive(Debug, Serialize, Deserialize)]
pub struct SqlResponse {
    /// The results for each query in the request.
    pub(in crate::http) results: Vec<SqlResult>,
}
545
546impl SqlResponse {
547    /// Creates a new empty SqlResponse for collecting results.
548    pub(in crate::http) fn new() -> Self {
549        Self {
550            results: Vec::new(),
551        }
552    }
553}
554
/// The outcome of executing one statement, as handed to a [`ResultSender`].
pub(in crate::http) enum StatementResult {
    /// A fully materialized result.
    SqlResult(SqlResult),
    /// A `SUBSCRIBE` whose rows still have to be streamed from `rx`.
    Subscribe {
        /// Description of the columns of the streamed rows.
        desc: RelationDesc,
        /// The command complete tag.
        tag: String,
        /// Stream from which the subscription's rows are received.
        rx: RecordFirstRowStream,
        /// Execution context that must still be retired for statement logging.
        ctx_extra: ExecuteContextGuard,
    },
}
564
565impl From<SqlResult> for StatementResult {
566    fn from(inner: SqlResult) -> Self {
567        Self::SqlResult(inner)
568    }
569}
570
/// The result of a single query in a [`SqlResponse`].
///
/// Serialized untagged, so the variant is distinguished on the wire only by
/// which fields are present.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum SqlResult {
    /// The query returned rows.
    Rows {
        /// The command complete tag.
        tag: String,
        /// The result rows.
        rows: Vec<Vec<serde_json::Value>>,
        /// Information about each column.
        desc: Description,
        /// Any notices generated during execution of the query.
        notices: Vec<Notice>,
    },
    /// The query executed successfully but did not return rows.
    Ok {
        /// The command complete tag.
        ok: String,
        /// Any notices generated during execution of the query.
        notices: Vec<Notice>,
        /// Any parameters that may have changed.
        ///
        /// Note: skip serializing this field in a response if the list of parameters is empty.
        #[serde(skip_serializing_if = "Vec::is_empty")]
        parameters: Vec<ParameterStatus>,
    },
    /// The query returned an error.
    Err {
        /// The error that the query produced.
        error: SqlError,
        /// Any notices generated during execution of the query.
        notices: Vec<Notice>,
    },
}
605
impl SqlResult {
    /// Convert adapter Row results into the web row result format. Error if the row format does not
    /// match the expected descriptor.
    ///
    /// Rows are read from `rows_stream` until it is exhausted. While waiting,
    /// session notices are forwarded through `sender` if it supports
    /// streaming notices, and a connection error reported by `sender` aborts
    /// the call.
    ///
    /// Query-level failures (a peek error, a dropped dependency, cancellation,
    /// a descriptor mismatch, or exceeding `max_query_result_size`) are
    /// returned as `Ok(SqlResult::Err { .. })`. Only connection errors and
    /// failures to emit streaming notices are returned as `Err`.
    // TODO(aljoscha): Bail when max_result_size is exceeded.
    async fn rows<S>(
        sender: &mut S,
        client: &mut SessionClient,
        mut rows_stream: RecordFirstRowStream,
        max_query_result_size: usize,
        desc: &RelationDesc,
    ) -> Result<SqlResult, Error>
    where
        S: ResultSender,
    {
        let mut rows: Vec<Vec<serde_json::Value>> = vec![];
        let mut datum_vec = mz_repr::DatumVec::new();
        let types = &desc.typ().column_types;

        // Running total of the byte length of all rows seen so far.
        let mut query_result_size = 0;

        loop {
            let peek_response = tokio::select! {
                // Only polled when the sender can stream notices.
                notice = client.session().recv_notice(), if S::SUPPORTS_STREAMING_NOTICES => {
                    sender.emit_streaming_notices(vec![notice]).await?;
                    continue;
                }
                e = sender.connection_error() => return Err(e),
                r = rows_stream.recv() => {
                    match r {
                        Some(r) => r,
                        // The stream is exhausted: all rows have been received.
                        None => break,
                    }
                },
            };

            let mut sql_rows = match peek_response {
                PeekResponseUnary::Rows(rows) => rows,
                PeekResponseUnary::Error(e) => {
                    return Ok(SqlResult::err(client, Error::Unstructured(anyhow!(e))));
                }
                PeekResponseUnary::DependencyDropped(dep) => {
                    return Ok(SqlResult::err(client, dep.to_concurrent_dependency_drop()));
                }
                PeekResponseUnary::Canceled => {
                    return Ok(SqlResult::err(client, AdapterError::Canceled));
                }
            };

            if let Err(err) = verify_datum_desc(desc, &mut sql_rows) {
                return Ok(SqlResult::Err {
                    error: err.into(),
                    notices: make_notices(client),
                });
            }

            while let Some(row) = sql_rows.next() {
                // Enforce the size limit before converting the row to JSON.
                query_result_size += row.byte_len();
                if query_result_size > max_query_result_size {
                    use bytesize::ByteSize;
                    return Ok(SqlResult::err(
                        client,
                        AdapterError::ResultSize(format!(
                            "result exceeds max size of {}",
                            ByteSize::b(u64::cast_from(max_query_result_size))
                        )),
                    ));
                }

                let datums = datum_vec.borrow_with(row);
                rows.push(
                    datums
                        .iter()
                        .enumerate()
                        .map(|(i, d)| {
                            TypedDatum::new(*d, &types[i])
                                .json(&JsonNumberPolicy::ConvertNumberToString)
                        })
                        .collect(),
                );
            }
        }

        let tag = format!("SELECT {}", rows.len());
        Ok(SqlResult::Rows {
            tag,
            rows,
            desc: Description::from(desc),
            notices: make_notices(client),
        })
    }

    /// Builds an error result from `error`, attaching the notices returned by
    /// `make_notices` for `client`.
    fn err(client: &mut SessionClient, error: impl Into<SqlError>) -> SqlResult {
        SqlResult::Err {
            error: error.into(),
            notices: make_notices(client),
        }
    }

    /// Builds a rowless success result with the given command tag and changed
    /// parameters, attaching the notices returned by `make_notices` for
    /// `client`.
    fn ok(client: &mut SessionClient, tag: String, params: Vec<ParameterStatus>) -> SqlResult {
        SqlResult::Ok {
            ok: tag,
            parameters: params,
            notices: make_notices(client),
        }
    }
}
712
/// The wire representation of an error returned to API clients.
///
/// The optional fields are left out of the serialized form when they are
/// `None`.
#[derive(Debug, Deserialize, Serialize)]
pub struct SqlError {
    /// The error's display text.
    pub message: String,
    /// The SQLSTATE code, as a string.
    pub code: String,
    /// Additional detail about the error, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
    /// A hint for resolving the error, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub hint: Option<String>,
    /// The position associated with the error, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub position: Option<usize>,
}
724
725impl From<Error> for SqlError {
726    fn from(err: Error) -> Self {
727        SqlError {
728            message: err.to_string(),
729            code: err.code().code().to_string(),
730            detail: err.detail(),
731            hint: err.hint(),
732            position: err.position(),
733        }
734    }
735}
736
737impl From<AdapterError> for SqlError {
738    fn from(value: AdapterError) -> Self {
739        Error::from(value).into()
740    }
741}
742
/// A single message sent from the server to a WebSocket client.
///
/// Serialized adjacently tagged: the variant name is placed in a `type`
/// field and the variant's data in a `payload` field.
#[derive(Debug, Deserialize, Serialize)]
#[serde(tag = "type", content = "payload")]
pub enum WebSocketResponse {
    /// The server is ready for the next request; carries the session's
    /// transaction code.
    ReadyForQuery(String),
    /// A notice generated by the session.
    Notice(Notice),
    /// Column descriptions for a row-returning statement.
    Rows(Description),
    /// A single result row.
    Row(Vec<serde_json::Value>),
    /// A statement's result is about to be sent.
    CommandStarting(CommandStarting),
    /// A statement finished; carries a string payload (presumably the
    /// command complete tag, to be confirmed against the sender).
    CommandComplete(String),
    /// An error response.
    Error(SqlError),
    /// The name and current value of a session variable.
    ParameterStatus(ParameterStatus),
    /// The connection id and secret key of the session.
    BackendKeyData(BackendKeyData),
}
756
/// The wire representation of a notice generated while serving a request.
#[derive(Debug, Serialize, Deserialize)]
pub struct Notice {
    /// The notice's display text.
    message: String,
    /// The notice's code, as a string.
    code: String,
    /// The notice's severity.
    severity: String,
    /// Additional detail, omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
    /// A hint, omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub hint: Option<String>,
}
767
768impl Notice {
769    pub fn message(&self) -> &str {
770        &self.message
771    }
772}
773
/// Describes the columns of a row-returning result.
#[derive(Debug, Serialize, Deserialize)]
pub struct Description {
    /// One entry per result column, in column order.
    pub columns: Vec<Column>,
}
778
779impl From<&RelationDesc> for Description {
780    fn from(desc: &RelationDesc) -> Self {
781        let columns = desc
782            .iter()
783            .map(|(name, typ)| {
784                let pg_type = mz_pgrepr::Type::from(&typ.scalar_type);
785                Column {
786                    name: name.to_string(),
787                    type_oid: pg_type.oid(),
788                    type_len: pg_type.typlen(),
789                    type_mod: pg_type.typmod(),
790                }
791            })
792            .collect();
793        Description { columns }
794    }
795}
796
/// Describes a single result column.
#[derive(Debug, Serialize, Deserialize)]
pub struct Column {
    /// The column's name.
    pub name: String,
    /// The oid of the column's `mz_pgrepr::Type`.
    pub type_oid: u32,
    /// The `typlen` of the column's `mz_pgrepr::Type`.
    pub type_len: i16,
    /// The `typmod` of the column's `mz_pgrepr::Type`.
    pub type_mod: i32,
}
804
/// The name and current value of a session variable.
#[derive(Debug, Serialize, Deserialize)]
pub struct ParameterStatus {
    /// The variable's name.
    name: String,
    /// The variable's value, rendered as a string.
    value: String,
}
810
/// Identifies the session backing a WebSocket connection.
#[derive(Debug, Serialize, Deserialize)]
pub struct BackendKeyData {
    /// The session's connection id.
    conn_id: u32,
    /// The session's secret key.
    secret_key: u32,
}
816
/// Sent over a WebSocket before a statement's result to describe what kind of
/// result follows.
#[derive(Debug, Serialize, Deserialize)]
pub struct CommandStarting {
    /// Whether the result carries rows.
    has_rows: bool,
    /// Whether the rows are streamed (true for `SUBSCRIBE`).
    is_streaming: bool,
}
822
/// Trait describing how to transmit a response to a client. HTTP clients
/// accumulate into a Vec and send all at once. WebSocket clients send each
/// message as they occur.
#[async_trait]
pub(in crate::http) trait ResultSender: Send {
    /// Whether notices can be delivered to the client while a statement is
    /// still running. Defaults to `false`.
    const SUPPORTS_STREAMING_NOTICES: bool = false;

    /// Adds a result to the client.
    ///
    /// The first component of the return value is `Err` if sending to the
    /// client produced an error and the server should disconnect. It is
    /// `Ok(Err(()))` if the statement produced an error and should error the
    /// transaction, but remain connected. It is `Ok(Ok(()))` if the statement
    /// succeeded.
    ///
    /// The second component of the return value is `Some` if execution still
    /// needs to be retired for statement logging purposes.
    async fn add_result(
        &mut self,
        client: &mut SessionClient,
        res: StatementResult,
    ) -> (
        Result<Result<(), ()>, Error>,
        Option<(StatementEndedExecutionReason, ExecuteContextGuard)>,
    );

    /// Returns a future that resolves only when the client connection has gone away.
    fn connection_error(&mut self) -> BoxFuture<'_, Error>;
    /// Reports whether the client supports streaming SUBSCRIBE results.
    fn allow_subscribe(&self) -> bool;

    /// Emits streaming notices to the client.
    ///
    /// Must only be called when `SUPPORTS_STREAMING_NOTICES` is true, and
    /// senders that set the constant must override this method: the default
    /// implementation panics via `unreachable!`.
    async fn emit_streaming_notices(&mut self, _: Vec<AdapterNotice>) -> Result<(), Error> {
        unreachable!("streaming notices marked as unsupported")
    }
}
858
859#[async_trait]
860impl ResultSender for SqlResponse {
861    // The first component of the return value is
862    // Err if sending to the client
863    // produced an error and the server should disconnect. It is Ok(Err) if the statement
864    // produced an error and should error the transaction, but remain connected. It is Ok(Ok(()))
865    // if the statement succeeded.
866    // The second component of the return value is `Some` if execution still
867    // needs to be retired for statement logging purposes.
868    async fn add_result(
869        &mut self,
870        _client: &mut SessionClient,
871        res: StatementResult,
872    ) -> (
873        Result<Result<(), ()>, Error>,
874        Option<(StatementEndedExecutionReason, ExecuteContextGuard)>,
875    ) {
876        let (res, stmt_logging) = match res {
877            StatementResult::SqlResult(res) => {
878                let is_err = matches!(res, SqlResult::Err { .. });
879                self.results.push(res);
880                let res = if is_err { Err(()) } else { Ok(()) };
881                (res, None)
882            }
883            StatementResult::Subscribe { ctx_extra, .. } => {
884                let message = "SUBSCRIBE only supported over websocket";
885                self.results.push(SqlResult::Err {
886                    error: Error::SubscribeOnlyOverWs.into(),
887                    notices: Vec::new(),
888                });
889                (
890                    Err(()),
891                    Some((
892                        StatementEndedExecutionReason::Errored {
893                            error: message.into(),
894                        },
895                        ctx_extra,
896                    )),
897                )
898            }
899        };
900        (Ok(res), stmt_logging)
901    }
902
903    fn connection_error(&mut self) -> BoxFuture<'_, Error> {
904        Box::pin(futures::future::pending())
905    }
906
907    fn allow_subscribe(&self) -> bool {
908        false
909    }
910}
911
912#[async_trait]
913impl ResultSender for WebSocket {
914    const SUPPORTS_STREAMING_NOTICES: bool = true;
915
916    // The first component of the return value is Err if sending to the client produced an error and
917    // the server should disconnect. It is Ok(Err) if the statement produced an error and should
918    // error the transaction, but remain connected. It is Ok(Ok(())) if the statement succeeded. The
919    // second component of the return value is `Some` if execution still needs to be retired for
920    // statement logging purposes.
921    async fn add_result(
922        &mut self,
923        client: &mut SessionClient,
924        res: StatementResult,
925    ) -> (
926        Result<Result<(), ()>, Error>,
927        Option<(StatementEndedExecutionReason, ExecuteContextGuard)>,
928    ) {
929        let (has_rows, is_streaming) = match res {
930            StatementResult::SqlResult(SqlResult::Err { .. }) => (false, false),
931            StatementResult::SqlResult(SqlResult::Ok { .. }) => (false, false),
932            StatementResult::SqlResult(SqlResult::Rows { .. }) => (true, false),
933            StatementResult::Subscribe { .. } => (true, true),
934        };
935        if let Err(e) = send_ws_response(
936            self,
937            WebSocketResponse::CommandStarting(CommandStarting {
938                has_rows,
939                is_streaming,
940            }),
941        )
942        .await
943        {
944            return (Err(e), None);
945        }
946
947        let (is_err, msgs, stmt_logging) = match res {
948            StatementResult::SqlResult(SqlResult::Rows {
949                tag,
950                rows,
951                desc,
952                notices,
953            }) => {
954                // Stream rows directly to avoid buffering all rows as
955                // WebSocketResponse, which inflates memory due to enum sizing.
956                if let Err(e) = send_ws_response(self, WebSocketResponse::Rows(desc)).await {
957                    return (Err(e), None);
958                }
959                for row in rows {
960                    if let Err(e) = send_ws_response(self, WebSocketResponse::Row(row)).await {
961                        return (Err(e), None);
962                    }
963                }
964                let mut msgs = vec![WebSocketResponse::CommandComplete(tag)];
965                msgs.extend(notices.into_iter().map(WebSocketResponse::Notice));
966                (false, msgs, None)
967            }
968            StatementResult::SqlResult(SqlResult::Ok {
969                ok,
970                parameters,
971                notices,
972            }) => {
973                let mut msgs = vec![WebSocketResponse::CommandComplete(ok)];
974                msgs.extend(notices.into_iter().map(WebSocketResponse::Notice));
975                msgs.extend(
976                    parameters
977                        .into_iter()
978                        .map(WebSocketResponse::ParameterStatus),
979                );
980                (false, msgs, None)
981            }
982            StatementResult::SqlResult(SqlResult::Err { error, notices }) => {
983                let mut msgs = vec![WebSocketResponse::Error(error)];
984                msgs.extend(notices.into_iter().map(WebSocketResponse::Notice));
985                (true, msgs, None)
986            }
987            StatementResult::Subscribe {
988                ref desc,
989                tag,
990                mut rx,
991                ctx_extra,
992            } => {
993                if let Err(e) = send_ws_response(self, WebSocketResponse::Rows(desc.into())).await {
994                    // We consider the remote breaking the connection to be a cancellation,
995                    // matching the behavior for pgwire
996                    return (
997                        Err(e),
998                        Some((StatementEndedExecutionReason::Canceled, ctx_extra)),
999                    );
1000                }
1001
1002                let mut datum_vec = mz_repr::DatumVec::new();
1003                let mut result_size: usize = 0;
1004                let mut rows_returned = 0;
1005                loop {
1006                    let res = match await_rows(self, client, rx.recv()).await {
1007                        Ok(res) => res,
1008                        Err(e) => {
1009                            // We consider the remote breaking the connection to be a cancellation,
1010                            // matching the behavior for pgwire
1011                            return (
1012                                Err(e),
1013                                Some((StatementEndedExecutionReason::Canceled, ctx_extra)),
1014                            );
1015                        }
1016                    };
1017                    match res {
1018                        Some(PeekResponseUnary::Rows(mut rows)) => {
1019                            if let Err(err) = verify_datum_desc(desc, &mut rows) {
1020                                let error = err.to_string();
1021                                break (
1022                                    true,
1023                                    vec![WebSocketResponse::Error(err.into())],
1024                                    Some((
1025                                        StatementEndedExecutionReason::Errored { error },
1026                                        ctx_extra,
1027                                    )),
1028                                );
1029                            }
1030
1031                            rows_returned += rows.count();
1032                            while let Some(row) = rows.next() {
1033                                result_size += row.byte_len();
1034                                let datums = datum_vec.borrow_with(row);
1035                                let types = &desc.typ().column_types;
1036                                if let Err(e) = send_ws_response(
1037                                    self,
1038                                    WebSocketResponse::Row(
1039                                        datums
1040                                            .iter()
1041                                            .enumerate()
1042                                            .map(|(i, d)| {
1043                                                TypedDatum::new(*d, &types[i])
1044                                                    .json(&JsonNumberPolicy::ConvertNumberToString)
1045                                            })
1046                                            .collect(),
1047                                    ),
1048                                )
1049                                .await
1050                                {
1051                                    // We consider the remote breaking the connection to be a cancellation,
1052                                    // matching the behavior for pgwire
1053                                    return (
1054                                        Err(e),
1055                                        Some((StatementEndedExecutionReason::Canceled, ctx_extra)),
1056                                    );
1057                                }
1058                            }
1059                        }
1060                        Some(PeekResponseUnary::Error(error)) => {
1061                            break (
1062                                true,
1063                                vec![WebSocketResponse::Error(
1064                                    Error::Unstructured(anyhow!(error.clone())).into(),
1065                                )],
1066                                Some((StatementEndedExecutionReason::Errored { error }, ctx_extra)),
1067                            );
1068                        }
1069                        Some(PeekResponseUnary::DependencyDropped(dep)) => {
1070                            let err = dep.to_concurrent_dependency_drop();
1071                            let error = err.to_string();
1072                            break (
1073                                true,
1074                                vec![WebSocketResponse::Error(err.into())],
1075                                Some((StatementEndedExecutionReason::Errored { error }, ctx_extra)),
1076                            );
1077                        }
1078                        Some(PeekResponseUnary::Canceled) => {
1079                            break (
1080                                true,
1081                                vec![WebSocketResponse::Error(AdapterError::Canceled.into())],
1082                                Some((StatementEndedExecutionReason::Canceled, ctx_extra)),
1083                            );
1084                        }
1085                        None => {
1086                            break (
1087                                false,
1088                                vec![WebSocketResponse::CommandComplete(tag)],
1089                                Some((
1090                                    StatementEndedExecutionReason::Success {
1091                                        result_size: Some(u64::cast_from(result_size)),
1092                                        rows_returned: Some(u64::cast_from(rows_returned)),
1093                                        execution_strategy: Some(
1094                                            StatementExecutionStrategy::Standard,
1095                                        ),
1096                                    },
1097                                    ctx_extra,
1098                                )),
1099                            );
1100                        }
1101                    }
1102                }
1103            }
1104        };
1105        for msg in msgs {
1106            if let Err(e) = send_ws_response(self, msg).await {
1107                return (
1108                    Err(e),
1109                    stmt_logging.map(|(_old_reason, ctx_extra)| {
1110                        (StatementEndedExecutionReason::Canceled, ctx_extra)
1111                    }),
1112                );
1113            }
1114        }
1115        (Ok(if is_err { Err(()) } else { Ok(()) }), stmt_logging)
1116    }
1117
1118    // Send a websocket Ping every second to verify the client is still
1119    // connected.
1120    fn connection_error(&mut self) -> BoxFuture<'_, Error> {
1121        Box::pin(async {
1122            let mut tick = time::interval(Duration::from_secs(1));
1123            tick.tick().await;
1124            loop {
1125                tick.tick().await;
1126                if let Err(err) = self.send(Message::Ping(Vec::new().into())).await {
1127                    return err.into();
1128                }
1129            }
1130        })
1131    }
1132
1133    fn allow_subscribe(&self) -> bool {
1134        true
1135    }
1136
1137    async fn emit_streaming_notices(&mut self, notices: Vec<AdapterNotice>) -> Result<(), Error> {
1138        forward_notices(self, notices).await
1139    }
1140}
1141
1142async fn await_rows<S, F, R>(sender: &mut S, client: &mut SessionClient, f: F) -> Result<R, Error>
1143where
1144    S: ResultSender,
1145    F: Future<Output = R> + Send,
1146{
1147    let mut f = pin!(f);
1148    loop {
1149        tokio::select! {
1150            notice = client.session().recv_notice(), if S::SUPPORTS_STREAMING_NOTICES => {
1151                sender.emit_streaming_notices(vec![notice]).await?;
1152            }
1153            e = sender.connection_error() => return Err(e),
1154            r = &mut f => return Ok(r),
1155        }
1156    }
1157}
1158
1159async fn send_and_retire<S: ResultSender>(
1160    res: StatementResult,
1161    client: &mut SessionClient,
1162    sender: &mut S,
1163) -> Result<Result<(), ()>, Error> {
1164    let (res, stmt_logging) = sender.add_result(client, res).await;
1165    if let Some((reason, ctx_extra)) = stmt_logging {
1166        client.retire_execute(ctx_extra, reason);
1167    }
1168    res
1169}
1170
1171/// Returns Ok(Err) if any statement error'd during execution.
1172async fn execute_stmt_group<S: ResultSender>(
1173    client: &mut SessionClient,
1174    sender: &mut S,
1175    stmt_group: Vec<(Statement<Raw>, String, Vec<Option<String>>)>,
1176) -> Result<Result<(), ()>, Error> {
1177    let num_stmts = stmt_group.len();
1178    for (stmt, sql, params) in stmt_group {
1179        assert!(
1180            num_stmts <= 1 || params.is_empty(),
1181            "statement groups contain more than 1 statement iff Simple request, which does not support parameters"
1182        );
1183
1184        let is_aborted_txn = matches!(client.session().transaction(), TransactionStatus::Failed(_));
1185        if is_aborted_txn && !is_txn_exit_stmt(&stmt) {
1186            let err = SqlResult::err(client, Error::AbortedTransaction);
1187            let _ = send_and_retire(err.into(), client, sender).await?;
1188            return Ok(Err(()));
1189        }
1190
1191        // Mirror the behavior of the PostgreSQL simple query protocol.
1192        // See the pgwire::protocol::StateMachine::query method for details.
1193        if let Err(e) = client.start_transaction(Some(num_stmts)) {
1194            let err = SqlResult::err(client, e);
1195            let _ = send_and_retire(err.into(), client, sender).await?;
1196            return Ok(Err(()));
1197        }
1198        let res = execute_stmt(client, sender, stmt, sql, params).await?;
1199        let is_err = send_and_retire(res, client, sender).await?;
1200
1201        if is_err.is_err() {
1202            // Mirror StateMachine::error, which sometimes will clean up the
1203            // transaction state instead of always leaving it in Failed.
1204            let txn = client.session().transaction();
1205            match txn {
1206                // Error can be called from describe and parse and so might not be in an active
1207                // transaction.
1208                TransactionStatus::Default | TransactionStatus::Failed(_) => {}
1209                // In Started (i.e., a single statement) and implicit transactions cleanup themselves.
1210                TransactionStatus::Started(_) | TransactionStatus::InTransactionImplicit(_) => {
1211                    if let Err(err) = client.end_transaction(EndTransactionAction::Rollback).await {
1212                        let err = SqlResult::err(client, err);
1213                        let _ = send_and_retire(err.into(), client, sender).await?;
1214                    }
1215                }
1216                // Explicit transactions move to failed.
1217                TransactionStatus::InTransaction(_) => {
1218                    client.fail_transaction();
1219                }
1220            }
1221            return Ok(Err(()));
1222        }
1223    }
1224    Ok(Ok(()))
1225}
1226
1227/// Executes an entire [`SqlRequest`].
1228///
1229/// See the user-facing documentation about the HTTP API for a description of
1230/// the semantics of this function.
1231/// Executes a SQL request and sends results to the provided sender.
1232///
1233/// Made visible to http submodules (like mcp) via `pub(in crate::http)` to allow
1234/// reuse of SQL execution logic.
1235pub(in crate::http) async fn execute_request<S: ResultSender>(
1236    client: &mut AuthedClient,
1237    request: SqlRequest,
1238    sender: &mut S,
1239) -> Result<(), Error> {
1240    let client = &mut client.client;
1241
1242    // This API prohibits executing statements with responses whose
1243    // semantics are at odds with an HTTP response.
1244    fn check_prohibited_stmts<S: ResultSender>(
1245        sender: &S,
1246        stmt: &Statement<Raw>,
1247    ) -> Result<(), Error> {
1248        let kind: StatementKind = stmt.into();
1249        let execute_responses = Plan::generated_from(&kind)
1250            .into_iter()
1251            .map(ExecuteResponse::generated_from)
1252            .flatten()
1253            .collect::<Vec<_>>();
1254
1255        // Special-case `COPY TO` statements that are not `COPY ... TO STDOUT`, since
1256        // StatementKind::Copy links to several `ExecuteResponseKind`s that are not supported,
1257        // but this specific statement should be allowed.
1258        let is_valid_copy = matches!(
1259            stmt,
1260            Statement::Copy(CopyStatement {
1261                direction: CopyDirection::To,
1262                target: CopyTarget::Expr(_),
1263                ..
1264            }) | Statement::Copy(CopyStatement {
1265                direction: CopyDirection::From,
1266                target: CopyTarget::Expr(_),
1267                ..
1268            })
1269        );
1270
1271        if !is_valid_copy
1272            && execute_responses.iter().any(|execute_response| {
1273                // Returns true if a statement or execute response are unsupported.
1274                match execute_response {
1275                    ExecuteResponseKind::Subscribing if sender.allow_subscribe() => false,
1276                    ExecuteResponseKind::Fetch
1277                    | ExecuteResponseKind::Subscribing
1278                    | ExecuteResponseKind::CopyFrom
1279                    | ExecuteResponseKind::DeclaredCursor
1280                    | ExecuteResponseKind::ClosedCursor => true,
1281                    // Various statements generate `PeekPlan` (`SELECT`, `COPY`,
1282                    // `EXPLAIN`, `SHOW`) which has both `SendRows` and `CopyTo` as its
1283                    // possible response types. but `COPY` needs be picked out because
1284                    // http don't support its response type
1285                    ExecuteResponseKind::CopyTo if matches!(kind, StatementKind::Copy) => true,
1286                    _ => false,
1287                }
1288            })
1289        {
1290            return Err(Error::Unsupported(stmt.to_ast_string_simple()));
1291        }
1292        Ok(())
1293    }
1294
1295    fn parse<'a>(
1296        client: &SessionClient,
1297        query: &'a str,
1298    ) -> Result<Vec<StatementParseResult<'a>>, Error> {
1299        let result = client
1300            .parse(query)
1301            .map_err(|e| Error::Unstructured(anyhow!(e)))?;
1302        result.map_err(|e| AdapterError::from(e).into())
1303    }
1304
1305    let mut stmt_groups = vec![];
1306
1307    match request {
1308        SqlRequest::Simple { query } => match parse(client, &query) {
1309            Ok(stmts) => {
1310                let mut stmt_group = Vec::with_capacity(stmts.len());
1311                let mut stmt_err = None;
1312                for StatementParseResult { ast: stmt, sql } in stmts {
1313                    if let Err(err) = check_prohibited_stmts(sender, &stmt) {
1314                        stmt_err = Some(err);
1315                        break;
1316                    }
1317                    stmt_group.push((stmt, sql.to_string(), vec![]));
1318                }
1319                stmt_groups.push(stmt_err.map(Err).unwrap_or_else(|| Ok(stmt_group)));
1320            }
1321            Err(e) => stmt_groups.push(Err(e)),
1322        },
1323        SqlRequest::Extended { queries } => {
1324            for ExtendedRequest { query, params } in queries {
1325                match parse(client, &query) {
1326                    Ok(mut stmts) => {
1327                        if stmts.len() != 1 {
1328                            return Err(Error::Unstructured(anyhow!(
1329                                "each query must contain exactly 1 statement, but \"{}\" contains {}",
1330                                query,
1331                                stmts.len()
1332                            )));
1333                        }
1334
1335                        let StatementParseResult { ast: stmt, sql } = stmts.pop().unwrap();
1336                        stmt_groups.push(
1337                            check_prohibited_stmts(sender, &stmt)
1338                                .map(|_| vec![(stmt, sql.to_string(), params)]),
1339                        );
1340                    }
1341                    Err(e) => stmt_groups.push(Err(e)),
1342                };
1343            }
1344        }
1345    }
1346
1347    for stmt_group_res in stmt_groups {
1348        let executed = match stmt_group_res {
1349            Ok(stmt_group) => execute_stmt_group(client, sender, stmt_group).await,
1350            Err(e) => {
1351                let err = SqlResult::err(client, e);
1352                let _ = send_and_retire(err.into(), client, sender).await?;
1353                Ok(Err(()))
1354            }
1355        };
1356        // At the end of each group, commit implicit transactions. Do that here so that any `?`
1357        // early return can still be handled here.
1358        if client.session().transaction().is_implicit() {
1359            let ended = client.end_transaction(EndTransactionAction::Commit).await;
1360            if let Err(err) = ended {
1361                let err = SqlResult::err(client, err);
1362                let _ = send_and_retire(StatementResult::SqlResult(err), client, sender).await?;
1363            }
1364        }
1365        if executed?.is_err() {
1366            break;
1367        }
1368    }
1369
1370    Ok(())
1371}
1372
1373/// Executes a single statement in a [`SqlRequest`].
1374async fn execute_stmt<S: ResultSender>(
1375    client: &mut SessionClient,
1376    sender: &mut S,
1377    stmt: Statement<Raw>,
1378    sql: String,
1379    raw_params: Vec<Option<String>>,
1380) -> Result<StatementResult, Error> {
1381    const EMPTY_PORTAL: &str = "";
1382    if let Err(e) = client
1383        .prepare(EMPTY_PORTAL.into(), Some(stmt.clone()), sql, vec![])
1384        .await
1385    {
1386        return Ok(SqlResult::err(client, e).into());
1387    }
1388
1389    let prep_stmt = match client.get_prepared_statement(EMPTY_PORTAL).await {
1390        Ok(stmt) => stmt,
1391        Err(err) => {
1392            return Ok(SqlResult::err(client, err).into());
1393        }
1394    };
1395
1396    let param_types = &prep_stmt.desc().param_types;
1397    if param_types.len() != raw_params.len() {
1398        let message = anyhow!(
1399            "request supplied {actual} parameters, \
1400                        but {statement} requires {expected}",
1401            statement = stmt.to_ast_string_simple(),
1402            actual = raw_params.len(),
1403            expected = param_types.len()
1404        );
1405        return Ok(SqlResult::err(client, Error::Unstructured(message)).into());
1406    }
1407
1408    let buf = RowArena::new();
1409    let mut params = vec![];
1410    for (raw_param, mz_typ) in raw_params.into_iter().zip_eq(param_types) {
1411        let pg_typ = mz_pgrepr::Type::from(mz_typ);
1412        let datum = match raw_param {
1413            None => Datum::Null,
1414            Some(raw_param) => {
1415                match mz_pgrepr::Value::decode(
1416                    mz_pgwire_common::Format::Text,
1417                    &pg_typ,
1418                    raw_param.as_bytes(),
1419                ) {
1420                    Ok(param) => match param.into_datum_decode_error(&buf, &pg_typ, "parameter") {
1421                        Ok(datum) => datum,
1422                        Err(msg) => {
1423                            return Ok(
1424                                SqlResult::err(client, Error::Unstructured(anyhow!(msg))).into()
1425                            );
1426                        }
1427                    },
1428                    Err(err) => {
1429                        let msg = anyhow!("unable to decode parameter: {}", err);
1430                        return Ok(SqlResult::err(client, Error::Unstructured(msg)).into());
1431                    }
1432                }
1433            }
1434        };
1435        params.push((datum, mz_typ.clone()))
1436    }
1437
1438    let result_formats = vec![
1439        mz_pgwire_common::Format::Text;
1440        prep_stmt
1441            .desc()
1442            .relation_desc
1443            .clone()
1444            .map(|desc| desc.typ().column_types.len())
1445            .unwrap_or(0)
1446    ];
1447
1448    let desc = prep_stmt.desc().clone();
1449    let logging = Arc::clone(prep_stmt.logging());
1450    let stmt_ast = prep_stmt.stmt().cloned();
1451    let state_revision = prep_stmt.state_revision;
1452    if let Err(err) = client.session().set_portal(
1453        EMPTY_PORTAL.into(),
1454        desc,
1455        stmt_ast,
1456        logging,
1457        params,
1458        result_formats,
1459        state_revision,
1460    ) {
1461        return Ok(SqlResult::err(client, err).into());
1462    }
1463
1464    let desc = client
1465        .session()
1466        // We do not need to verify here because `client.execute` verifies below.
1467        .get_portal_unverified(EMPTY_PORTAL)
1468        .map(|portal| portal.desc.clone())
1469        .expect("unnamed portal should be present");
1470
1471    let res = client
1472        .execute(EMPTY_PORTAL.into(), futures::future::pending(), None)
1473        .await;
1474
1475    if S::SUPPORTS_STREAMING_NOTICES {
1476        sender
1477            .emit_streaming_notices(client.session().drain_notices())
1478            .await?;
1479    }
1480
1481    let (res, execute_started) = match res {
1482        Ok(res) => res,
1483        Err(e) => {
1484            return Ok(SqlResult::err(client, e).into());
1485        }
1486    };
1487    let tag = res.tag();
1488
1489    Ok(match res {
1490        ExecuteResponse::CreatedConnection { .. }
1491        | ExecuteResponse::CreatedDatabase { .. }
1492        | ExecuteResponse::CreatedSchema { .. }
1493        | ExecuteResponse::CreatedRole
1494        | ExecuteResponse::CreatedCluster { .. }
1495        | ExecuteResponse::CreatedClusterReplica { .. }
1496        | ExecuteResponse::CreatedTable { .. }
1497        | ExecuteResponse::CreatedIndex { .. }
1498        | ExecuteResponse::CreatedIntrospectionSubscribe
1499        | ExecuteResponse::CreatedSecret { .. }
1500        | ExecuteResponse::CreatedSource { .. }
1501        | ExecuteResponse::CreatedSink { .. }
1502        | ExecuteResponse::CreatedView { .. }
1503        | ExecuteResponse::CreatedViews { .. }
1504        | ExecuteResponse::CreatedMaterializedView { .. }
1505        | ExecuteResponse::CreatedType
1506        | ExecuteResponse::CreatedNetworkPolicy
1507        | ExecuteResponse::Comment
1508        | ExecuteResponse::Deleted(_)
1509        | ExecuteResponse::DiscardedTemp
1510        | ExecuteResponse::DiscardedAll
1511        | ExecuteResponse::DroppedObject(_)
1512        | ExecuteResponse::DroppedOwned
1513        | ExecuteResponse::EmptyQuery
1514        | ExecuteResponse::GrantedPrivilege
1515        | ExecuteResponse::GrantedRole
1516        | ExecuteResponse::Inserted(_)
1517        | ExecuteResponse::Copied(_)
1518        | ExecuteResponse::Raised
1519        | ExecuteResponse::ReassignOwned
1520        | ExecuteResponse::RevokedPrivilege
1521        | ExecuteResponse::AlteredDefaultPrivileges
1522        | ExecuteResponse::RevokedRole
1523        | ExecuteResponse::StartedTransaction { .. }
1524        | ExecuteResponse::Updated(_)
1525        | ExecuteResponse::AlteredObject(_)
1526        | ExecuteResponse::AlteredRole
1527        | ExecuteResponse::AlteredSystemConfiguration
1528        | ExecuteResponse::Deallocate { .. }
1529        | ExecuteResponse::ValidatedConnection
1530        | ExecuteResponse::Prepare => SqlResult::ok(
1531            client,
1532            tag.expect("ok only called on tag-generating results"),
1533            Vec::default(),
1534        )
1535        .into(),
1536        ExecuteResponse::TransactionCommitted { params }
1537        | ExecuteResponse::TransactionRolledBack { params } => {
1538            let notify_set: mz_ore::collections::HashSet<_> = client
1539                .session()
1540                .vars()
1541                .notify_set()
1542                .map(|v| v.name().to_string())
1543                .collect();
1544            let params = params
1545                .into_iter()
1546                .filter(|(name, _value)| notify_set.contains(*name))
1547                .map(|(name, value)| ParameterStatus {
1548                    name: name.to_string(),
1549                    value,
1550                })
1551                .collect();
1552            SqlResult::ok(
1553                client,
1554                tag.expect("ok only called on tag-generating results"),
1555                params,
1556            )
1557            .into()
1558        }
1559        ExecuteResponse::SetVariable { name, .. } => {
1560            let mut params = Vec::with_capacity(1);
1561            if let Some(var) = client
1562                .session()
1563                .vars()
1564                .notify_set()
1565                .find(|v| v.name() == &name)
1566            {
1567                params.push(ParameterStatus {
1568                    name,
1569                    value: var.value(),
1570                });
1571            };
1572            SqlResult::ok(
1573                client,
1574                tag.expect("ok only called on tag-generating results"),
1575                params,
1576            )
1577            .into()
1578        }
1579        ExecuteResponse::SendingRowsStreaming {
1580            rows,
1581            instance_id,
1582            strategy,
1583        } => {
1584            let max_query_result_size =
1585                usize::cast_from(client.get_system_vars().await.max_result_size());
1586
1587            let rows_stream = RecordFirstRowStream::new(
1588                Box::new(rows),
1589                execute_started,
1590                client,
1591                Some(instance_id),
1592                Some(strategy),
1593            );
1594
1595            SqlResult::rows(
1596                sender,
1597                client,
1598                rows_stream,
1599                max_query_result_size,
1600                &desc.relation_desc.expect("RelationDesc must exist"),
1601            )
1602            .await?
1603            .into()
1604        }
1605        ExecuteResponse::SendingRowsImmediate { rows } => {
1606            let max_query_result_size =
1607                usize::cast_from(client.get_system_vars().await.max_result_size());
1608
1609            let rows = futures::stream::once(futures::future::ready(PeekResponseUnary::Rows(rows)));
1610            let rows_stream =
1611                RecordFirstRowStream::new(Box::new(rows), execute_started, client, None, None);
1612
1613            SqlResult::rows(
1614                sender,
1615                client,
1616                rows_stream,
1617                max_query_result_size,
1618                &desc.relation_desc.expect("RelationDesc must exist"),
1619            )
1620            .await?
1621            .into()
1622        }
1623        ExecuteResponse::Subscribing {
1624            rx,
1625            ctx_extra,
1626            instance_id,
1627        } => StatementResult::Subscribe {
1628            tag: "SUBSCRIBE".into(),
1629            desc: desc.relation_desc.unwrap(),
1630            rx: RecordFirstRowStream::new(
1631                Box::new(UnboundedReceiverStream::new(rx)),
1632                execute_started,
1633                client,
1634                Some(instance_id),
1635                None,
1636            ),
1637            ctx_extra,
1638        },
1639        res @ (ExecuteResponse::Fetch { .. }
1640        | ExecuteResponse::CopyTo { .. }
1641        | ExecuteResponse::CopyFrom { .. }
1642        | ExecuteResponse::DeclaredCursor
1643        | ExecuteResponse::ClosedCursor) => SqlResult::err(
1644            client,
1645            Error::Unstructured(anyhow!(
1646                "internal error: encountered prohibited ExecuteResponse {:?}.\n\n
1647            This is a bug. Can you please file an bug report letting us know?\n
1648            https://github.com/MaterializeInc/materialize/discussions/new?category=bug-reports",
1649                ExecuteResponseKind::from(res)
1650            )),
1651        )
1652        .into(),
1653    })
1654}
1655
1656fn make_notices(client: &mut SessionClient) -> Vec<Notice> {
1657    client
1658        .session()
1659        .drain_notices()
1660        .into_iter()
1661        .map(|notice| Notice {
1662            message: notice.to_string(),
1663            code: notice.code().code().to_string(),
1664            severity: notice.severity().as_str().to_lowercase(),
1665            detail: notice.detail(),
1666            hint: notice.hint(),
1667        })
1668        .collect()
1669}
1670
1671// Duplicated from protocol.rs.
1672// See postgres' backend/tcop/postgres.c IsTransactionExitStmt.
1673fn is_txn_exit_stmt(stmt: &Statement<Raw>) -> bool {
1674    matches!(
1675        stmt,
1676        Statement::Commit(_) | Statement::Rollback(_) | Statement::Prepare(_)
1677    )
1678}
1679
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use super::{Password, WebSocketAuth};

    /// Smoke test: each JSON auth payload below must deserialize into the
    /// expected `WebSocketAuth` value, with and without an explicit
    /// `options` object.
    #[mz_ore::test]
    fn smoke_test_websocket_auth_parse() {
        // Pairs of (input JSON, value it must deserialize to).
        let cases: Vec<(&'static str, WebSocketAuth)> = vec![
            // User/password payload without `options`.
            (
                r#"{ "user": "mz", "password": "1234" }"#,
                WebSocketAuth::Basic {
                    user: String::from("mz"),
                    password: Password(String::from("1234")),
                    options: BTreeMap::new(),
                },
            ),
            // User/password payload with an empty `options` object.
            (
                r#"{ "user": "mz", "password": "1234", "options": {} }"#,
                WebSocketAuth::Basic {
                    user: String::from("mz"),
                    password: Password(String::from("1234")),
                    options: BTreeMap::new(),
                },
            ),
            // Token payload without `options`.
            (
                r#"{ "token": "i_am_a_token" }"#,
                WebSocketAuth::Bearer {
                    token: String::from("i_am_a_token"),
                    options: BTreeMap::new(),
                },
            ),
            // Token payload with a populated `options` object.
            (
                r#"{ "token": "i_am_a_token", "options": { "foo": "bar" } }"#,
                WebSocketAuth::Bearer {
                    token: String::from("i_am_a_token"),
                    options: BTreeMap::from([(String::from("foo"), String::from("bar"))]),
                },
            ),
        ];

        for (json, expected) in cases {
            let parsed: WebSocketAuth = serde_json::from_str(json).unwrap();
            assert_eq!(parsed, expected);
        }
    }
}