mz_environmentd/http/
sql.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::borrow::Cow;
11use std::collections::BTreeMap;
12use std::net::{IpAddr, SocketAddr};
13use std::pin::pin;
14use std::sync::Arc;
15use std::time::{Duration, SystemTime};
16
17use anyhow::anyhow;
18use async_trait::async_trait;
19use axum::extract::connect_info::ConnectInfo;
20use axum::extract::ws::{CloseFrame, Message, WebSocket};
21use axum::extract::{State, WebSocketUpgrade};
22use axum::response::IntoResponse;
23use axum::{Extension, Json};
24use futures::Future;
25use futures::future::BoxFuture;
26
27use http::StatusCode;
28use itertools::Itertools;
29use mz_adapter::client::RecordFirstRowStream;
30use mz_adapter::session::{EndTransactionAction, TransactionStatus};
31use mz_adapter::statement_logging::{StatementEndedExecutionReason, StatementExecutionStrategy};
32use mz_adapter::{
33    AdapterError, AdapterNotice, ExecuteContextExtra, ExecuteResponse, ExecuteResponseKind,
34    PeekResponseUnary, SessionClient, verify_datum_desc,
35};
36use mz_auth::password::Password;
37use mz_catalog::memory::objects::{Cluster, ClusterReplica};
38use mz_interchange::encode::TypedDatum;
39use mz_interchange::json::{JsonNumberPolicy, ToJson};
40use mz_ore::cast::CastFrom;
41use mz_ore::metrics::{MakeCollectorOpts, MetricsRegistry};
42use mz_ore::result::ResultExt;
43use mz_repr::{Datum, RelationDesc, RowArena, RowIterator};
44use mz_sql::ast::display::AstDisplay;
45use mz_sql::ast::{CopyDirection, CopyStatement, CopyTarget, Raw, Statement, StatementKind};
46use mz_sql::parse::StatementParseResult;
47use mz_sql::plan::Plan;
48use mz_sql::session::metadata::SessionMetadata;
49use prometheus::Opts;
50use prometheus::core::{AtomicF64, GenericGaugeVec};
51use serde::{Deserialize, Serialize};
52use tokio::{select, time};
53use tokio_postgres::error::SqlState;
54use tokio_stream::wrappers::UnboundedReceiverStream;
55use tower_sessions::Session as TowerSession;
56use tracing::{debug, error, info};
57use tungstenite::protocol::frame::coding::CloseCode;
58
59use crate::http::prometheus::PrometheusSqlQuery;
60use crate::http::{
61    AuthError, AuthedClient, AuthedUser, MAX_REQUEST_SIZE, SESSION_DURATION, TowerSessionData,
62    WsState, init_ws,
63};
64
/// Errors produced while serving SQL over the HTTP and WebSocket endpoints.
///
/// The `#[error(...)]` attributes define the message text; converting an `Error` into a
/// [`SqlError`] uses that text (via `to_string`) as the client-visible message.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// An error from the adapter. Its detail, hint, position and SQLSTATE code are surfaced
    /// through the accessor methods on [`Error`].
    #[error(transparent)]
    Adapter(#[from] AdapterError),
    /// A JSON (de)serialization error, e.g. a WebSocket frame that is not a valid request.
    #[error(transparent)]
    Json(#[from] serde_json::Error),
    /// An error from the axum transport, e.g. a failed WebSocket send.
    #[error(transparent)]
    Axum(#[from] axum::Error),
    /// `SUBSCRIBE` was issued through a sender that cannot stream results (plain HTTP).
    #[error("SUBSCRIBE only supported over websocket")]
    SubscribeOnlyOverWs,
    /// A statement was issued inside a transaction that has already failed.
    #[error("current transaction is aborted, commands ignored until end of transaction block")]
    AbortedTransaction,
    /// The request uses something this API does not support; the payload describes what.
    #[error("unsupported via this API: {0}")]
    Unsupported(String),
    /// Any other error, carried as an `anyhow::Error` (e.g. a peek error while reading rows).
    #[error("{0}")]
    Unstructured(anyhow::Error),
}
82
83impl Error {
84    pub fn detail(&self) -> Option<String> {
85        match self {
86            Error::Adapter(err) => err.detail(),
87            _ => None,
88        }
89    }
90
91    pub fn hint(&self) -> Option<String> {
92        match self {
93            Error::Adapter(err) => err.hint(),
94            _ => None,
95        }
96    }
97
98    pub fn position(&self) -> Option<usize> {
99        match self {
100            Error::Adapter(err) => err.position(),
101            _ => None,
102        }
103    }
104
105    pub fn code(&self) -> SqlState {
106        match self {
107            Error::Adapter(err) => err.code(),
108            Error::AbortedTransaction => SqlState::IN_FAILED_SQL_TRANSACTION,
109            _ => SqlState::INTERNAL_ERROR,
110        }
111    }
112}
113
/// Label names appended to every gauge produced by a per-replica Prometheus SQL query.
///
/// `execute_promsql_query` fills them, in this order, with the replica's
/// `<cluster name>.<replica name>`, the cluster's ID, and the replica's ID, so the order here
/// must stay in sync with the values pushed there.
static PER_REPLICA_LABELS: &[&str] = &["replica_full_name", "instance_id", "replica_id"];
115
116async fn execute_promsql_query(
117    client: &mut AuthedClient,
118    query: &PrometheusSqlQuery<'_>,
119    metrics_registry: &MetricsRegistry,
120    metrics_by_name: &mut BTreeMap<String, GenericGaugeVec<AtomicF64>>,
121    cluster: Option<(&Cluster, &ClusterReplica)>,
122) {
123    assert_eq!(query.per_replica, cluster.is_some());
124
125    let mut res = SqlResponse {
126        results: Vec::new(),
127    };
128
129    execute_request(client, query.to_sql_request(cluster), &mut res)
130        .await
131        .expect("valid SQL query");
132
133    let result = match res.results.as_slice() {
134        // Each query issued is preceded by several SET commands
135        // to make sure it is routed to the right cluster replica.
136        [
137            SqlResult::Ok { .. },
138            SqlResult::Ok { .. },
139            SqlResult::Ok { .. },
140            result,
141        ] => result,
142        // Transient errors are fine, like if the cluster or replica
143        // was dropped before the promsql query was executed. We
144        // should not see errors in the steady state.
145        _ => {
146            info!(
147                "error executing prometheus query {}: {:?}",
148                query.metric_name, res
149            );
150            return;
151        }
152    };
153
154    let SqlResult::Rows { desc, rows, .. } = result else {
155        info!(
156            "did not receive rows for SQL query for prometheus metric {}: {:?}, {:?}",
157            query.metric_name, result, cluster
158        );
159        return;
160    };
161
162    let gauge_vec = metrics_by_name
163        .entry(query.metric_name.to_string())
164        .or_insert_with(|| {
165            let mut label_names: Vec<String> = desc
166                .columns
167                .iter()
168                .filter(|col| col.name != query.value_column_name)
169                .map(|col| col.name.clone())
170                .collect();
171
172            if query.per_replica {
173                label_names.extend(PER_REPLICA_LABELS.iter().map(|label| label.to_string()));
174            }
175
176            metrics_registry.register::<GenericGaugeVec<AtomicF64>>(MakeCollectorOpts {
177                opts: Opts::new(query.metric_name, query.help).variable_labels(label_names),
178                buckets: None,
179            })
180        });
181
182    for row in rows {
183        let mut label_values = desc
184            .columns
185            .iter()
186            .zip_eq(row)
187            .filter(|(col, _)| col.name != query.value_column_name)
188            .map(|(_, val)| val.as_str().expect("must be string"))
189            .collect::<Vec<_>>();
190
191        let value = desc
192            .columns
193            .iter()
194            .zip_eq(row)
195            .find(|(col, _)| col.name == query.value_column_name)
196            .map(|(_, val)| val.as_str().unwrap_or("0").parse::<f64>().unwrap_or(0.0))
197            .unwrap_or(0.0);
198
199        match cluster {
200            Some((cluster, replica)) => {
201                let replica_full_name = format!("{}.{}", cluster.name, replica.name);
202                let cluster_id = cluster.id.to_string();
203                let replica_id = replica.replica_id.to_string();
204
205                label_values.push(&replica_full_name);
206                label_values.push(&cluster_id);
207                label_values.push(&replica_id);
208
209                gauge_vec
210                    .get_metric_with_label_values(&label_values)
211                    .expect("valid labels")
212                    .set(value);
213            }
214            None => {
215                gauge_vec
216                    .get_metric_with_label_values(&label_values)
217                    .expect("valid labels")
218                    .set(value);
219            }
220        }
221    }
222}
223
224async fn handle_promsql_query(
225    client: &mut AuthedClient,
226    query: &PrometheusSqlQuery<'_>,
227    metrics_registry: &MetricsRegistry,
228    metrics_by_name: &mut BTreeMap<String, GenericGaugeVec<AtomicF64>>,
229) {
230    if !query.per_replica {
231        execute_promsql_query(client, query, metrics_registry, metrics_by_name, None).await;
232        return;
233    }
234
235    let catalog = client.client.catalog_snapshot("handle_promsql_query").await;
236    let clusters: Vec<&Cluster> = catalog.clusters().collect();
237
238    for cluster in clusters {
239        for replica in cluster.replicas() {
240            execute_promsql_query(
241                client,
242                query,
243                metrics_registry,
244                metrics_by_name,
245                Some((cluster, replica)),
246            )
247            .await;
248        }
249    }
250}
251
252pub async fn handle_promsql(
253    mut client: AuthedClient,
254    queries: &[PrometheusSqlQuery<'_>],
255) -> MetricsRegistry {
256    let metrics_registry = MetricsRegistry::new();
257    let mut metrics_by_name = BTreeMap::new();
258
259    for query in queries {
260        handle_promsql_query(&mut client, query, &metrics_registry, &mut metrics_by_name).await;
261    }
262
263    metrics_registry
264}
265
266pub async fn handle_sql(
267    mut client: AuthedClient,
268    Json(request): Json<SqlRequest>,
269) -> impl IntoResponse {
270    let mut res = SqlResponse {
271        results: Vec::new(),
272    };
273    // Don't need to worry about timeouts or resetting cancel here because there is always exactly 1
274    // request.
275    match execute_request(&mut client, request, &mut res).await {
276        Ok(()) => Ok(Json(res)),
277        Err(e) => Err((StatusCode::BAD_REQUEST, e.to_string())),
278    }
279}
280
281pub async fn handle_sql_ws(
282    State(state): State<WsState>,
283    existing_user: Option<Extension<AuthedUser>>,
284    ws: WebSocketUpgrade,
285    ConnectInfo(addr): ConnectInfo<SocketAddr>,
286    tower_session: Option<Extension<TowerSession>>,
287) -> impl IntoResponse {
288    // An upstream middleware may have already provided the user for us
289    let user = match (existing_user, tower_session) {
290        (Some(Extension(user)), _) => Some(user),
291        (None, Some(session)) => {
292            if let Ok(Some(session_data)) = session.get::<TowerSessionData>("data").await {
293                // Check session expiration
294                if session_data
295                    .last_activity
296                    .elapsed()
297                    .unwrap_or(Duration::MAX)
298                    > SESSION_DURATION
299                {
300                    let _ = session.delete().await;
301                    return Err(AuthError::SessionExpired);
302                }
303                // Update last activity
304                let mut updated_data = session_data.clone();
305                updated_data.last_activity = SystemTime::now();
306                session
307                    .insert("data", &updated_data)
308                    .await
309                    .map_err(|_| AuthError::FailedToUpdateSession)?;
310                // User is authenticated via session
311                Some(AuthedUser {
312                    name: session_data.username,
313                    external_metadata_rx: None,
314                    internal_metadata: Some(session_data.internal_metadata),
315                })
316            } else {
317                None
318            }
319        }
320        _ => None,
321    };
322
323    let addr = Box::new(addr.ip());
324    Ok(ws
325        .max_message_size(MAX_REQUEST_SIZE)
326        .on_upgrade(|ws| async move { run_ws(&state, user, *addr, ws).await }))
327}
328
/// Credentials and options supplied by a WebSocket client.
///
/// The enum is `#[serde(untagged)]`, so serde tries the variants in declaration order and
/// uses the first one that deserializes:
/// - `Basic` when `user` and `password` are present;
/// - otherwise `Bearer` when `token` is present;
/// - otherwise `OptionsOnly`, whose only field is optional.
///
/// Do not reorder the variants.
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
#[serde(untagged)]
pub enum WebSocketAuth {
    /// Username and password credentials.
    Basic {
        user: String,
        password: Password,
        /// Extra key/value options; empty when omitted.
        #[serde(default)]
        options: BTreeMap<String, String>,
    },
    /// A bearer token.
    Bearer {
        token: String,
        /// Extra key/value options; empty when omitted.
        #[serde(default)]
        options: BTreeMap<String, String>,
    },
    /// No credentials, only options.
    OptionsOnly {
        /// Extra key/value options; empty when omitted.
        #[serde(default)]
        options: BTreeMap<String, String>,
    },
}
348
/// Drives one SQL-over-WebSocket session, from authentication until the connection ends.
///
/// The connection is first initialized with [`init_ws`]. On failure a close frame is sent
/// and the function returns. Only [`AdapterError`] messages are echoed to the peer; every
/// other failure is reported as `"unauthorized"`.
///
/// On success the startup messages are sent in this order, followed by any notices generated
/// during startup:
/// - a `ParameterStatus` for each variable in the session's notify set;
/// - `BackendKeyData`;
/// - `ReadyForQuery`.
///
/// After that, each text or binary frame is decoded as a [`SqlRequest`] and executed. Each
/// request is answered with an optional `Error`, then any pending notices, then a final
/// `ReadyForQuery`.
///
/// The loop ends in any of these cases:
/// - the peer disconnects or sends a close frame;
/// - sending a response fails;
/// - a session timeout reported by `recv_timeout` fires.
async fn run_ws(state: &WsState, user: Option<AuthedUser>, peer_addr: IpAddr, mut ws: WebSocket) {
    let mut client = match init_ws(state, user, peer_addr, &mut ws).await {
        Ok(client) => client,
        Err(e) => {
            // We omit most detail from the error message we send to the client, to
            // avoid giving attackers unnecessary information during auth. AdapterErrors
            // are safe to return because they're generated after authentication.
            debug!("WS request failed init: {}", e);
            let reason = match e.downcast_ref::<AdapterError>() {
                Some(error) => Cow::Owned(error.to_string()),
                None => "unauthorized".into(),
            };
            // Best effort: the connection is being abandoned either way.
            let _ = ws
                .send(Message::Close(Some(CloseFrame {
                    code: CloseCode::Protocol.into(),
                    reason,
                })))
                .await;
            return;
        }
    };

    // Successful auth, send startup messages.
    let mut msgs = Vec::new();
    let session = client.client.session();
    for var in session.vars().notify_set() {
        msgs.push(WebSocketResponse::ParameterStatus(ParameterStatus {
            name: var.name().to_string(),
            value: var.value(),
        }));
    }
    msgs.push(WebSocketResponse::BackendKeyData(BackendKeyData {
        conn_id: session.conn_id().unhandled(),
        secret_key: session.secret_key(),
    }));
    msgs.push(WebSocketResponse::ReadyForQuery(
        session.transaction_code().into(),
    ));
    // Send failures for the startup messages are ignored here. A broken connection is
    // detected by the notice forwarding below or by the receive loop.
    for msg in msgs {
        let _ = ws
            .send(Message::Text(
                serde_json::to_string(&msg).expect("must serialize"),
            ))
            .await;
    }

    // Send any notices that might have been generated on startup.
    let notices = session.drain_notices();
    if let Err(err) = forward_notices(&mut ws, notices).await {
        debug!("failed to forward notices to WebSocket, {err:?}");
        return;
    }

    loop {
        // Handle timeouts first so we don't execute any statements when there's a pending timeout.
        let msg = select! {
            biased;

            // `recv_timeout()` is cancel-safe as per its docs.
            Some(timeout) = client.client.recv_timeout() => {
                client.client.terminate().await;
                // We must wait for the client to send a request before we can send the error
                // response. Although this isn't the PG wire protocol, we choose to mirror it by
                // only sending errors as responses to requests.
                let _ = ws.recv().await;
                let err = Error::from(AdapterError::from(timeout));
                let _ = send_ws_response(&mut ws, WebSocketResponse::Error(err.into())).await;
                return;
            },
            message = ws.recv() => message,
        };

        // Something arrived from the client (or it disconnected), so drop any pending
        // idle-in-transaction session timeout.
        client.client.remove_idle_in_transaction_session_timeout();

        let msg = match msg {
            Some(Ok(msg)) => msg,
            _ => {
                // client disconnected
                return;
            }
        };

        // Decode errors are not fatal: they are carried into `run_ws_request` and reported
        // back to the client as an `Error` response.
        let req: Result<SqlRequest, Error> = match msg {
            Message::Text(data) => serde_json::from_str(&data).err_into(),
            Message::Binary(data) => serde_json::from_slice(&data).err_into(),
            // Handled automatically by the server.
            Message::Ping(_) => {
                continue;
            }
            Message::Pong(_) => {
                continue;
            }
            Message::Close(_) => {
                return;
            }
        };

        // Figure out if we need to send an error, any notices, but always the ready message.
        let err = match run_ws_request(req, &mut client, &mut ws).await {
            Ok(()) => None,
            Err(err) => Some(WebSocketResponse::Error(err.into())),
        };

        // After running our request, there are several messages we need to send in a
        // specific order.
        //
        // Note: we nest these into a closure so we can centralize our error handling
        // for when sending over the WebSocket fails. We could also use a try {} block
        // here, but those aren't stabilized yet.
        let ws_response = || async {
            // First respond with any error that might have occurred.
            if let Some(e_resp) = err {
                send_ws_response(&mut ws, e_resp).await?;
            }

            // Then forward along any notices we generated.
            let notices = client.client.session().drain_notices();
            forward_notices(&mut ws, notices).await?;

            // Finally, respond that we're ready for the next query.
            let ready =
                WebSocketResponse::ReadyForQuery(client.client.session().transaction_code().into());
            send_ws_response(&mut ws, ready).await?;

            Ok::<_, Error>(())
        };

        if let Err(err) = ws_response().await {
            debug!("failed to send response over WebSocket, {err:?}");
            return;
        }
    }
}
482
483async fn run_ws_request(
484    req: Result<SqlRequest, Error>,
485    client: &mut AuthedClient,
486    ws: &mut WebSocket,
487) -> Result<(), Error> {
488    let req = req?;
489    execute_request(client, req, ws).await
490}
491
492/// Sends a single [`WebSocketResponse`] over the provided [`WebSocket`].
493async fn send_ws_response(ws: &mut WebSocket, resp: WebSocketResponse) -> Result<(), Error> {
494    let msg = serde_json::to_string(&resp).unwrap();
495    let msg = Message::Text(msg);
496    ws.send(msg).await?;
497
498    Ok(())
499}
500
501/// Forwards a collection of Notices to the provided [`WebSocket`].
502async fn forward_notices(
503    ws: &mut WebSocket,
504    notices: impl IntoIterator<Item = AdapterNotice>,
505) -> Result<(), Error> {
506    let ws_notices = notices.into_iter().map(|notice| {
507        WebSocketResponse::Notice(Notice {
508            message: notice.to_string(),
509            code: notice.code().code().to_string(),
510            severity: notice.severity().as_str().to_lowercase(),
511            detail: notice.detail(),
512            hint: notice.hint(),
513        })
514    });
515
516    for notice in ws_notices {
517        send_ws_response(ws, notice).await?;
518    }
519
520    Ok(())
521}
522
/// A request to execute SQL over HTTP.
///
/// The same type is decoded from each text or binary frame on the WebSocket endpoint.
///
/// The enum is `#[serde(untagged)]`: an object with a `query` field is a `Simple` request,
/// and one with a `queries` field is an `Extended` request. Serde tries the variants in
/// declaration order, so do not reorder them.
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum SqlRequest {
    /// A simple query request.
    Simple {
        /// A query string containing zero or more queries delimited by
        /// semicolons.
        query: String,
    },
    /// An extended query request.
    Extended {
        /// Queries to execute using the extended protocol.
        queries: Vec<ExtendedRequest>,
    },
}
539
/// A request to execute a SQL query using the extended protocol.
#[derive(Serialize, Deserialize, Debug)]
pub struct ExtendedRequest {
    /// A query string containing zero or one queries.
    query: String,
    /// Optional parameters for the query; empty when omitted from the request.
    ///
    /// NOTE(review): presumably bound positionally to `$1`, `$2`, …, with `None` binding SQL
    /// `NULL`. Confirm against `execute_request`.
    #[serde(default)]
    params: Vec<Option<String>>,
}
549
/// The response to a `SqlRequest`.
///
/// For HTTP requests this also acts as the result sender: results are accumulated here and
/// sent to the client all at once.
#[derive(Debug, Serialize, Deserialize)]
pub struct SqlResponse {
    /// The results for each query in the request.
    results: Vec<SqlResult>,
}
556
/// The outcome of executing one statement, before it is handed to a result sender.
///
/// Most statements produce a fully materialized [`SqlResult`]. `SUBSCRIBE` instead produces a
/// stream of results that only senders supporting streaming can forward; the HTTP sender
/// rejects it.
enum StatementResult {
    /// A statement whose result is already complete.
    SqlResult(SqlResult),
    /// A `SUBSCRIBE` whose results have not been consumed yet.
    Subscribe {
        /// Describes the columns of the subscription's rows.
        desc: RelationDesc,
        /// The statement's command tag.
        tag: String,
        /// The stream the subscription's results are received from.
        rx: RecordFirstRowStream,
        /// Statement-logging state. It is handed back from `add_result` when the execution
        /// still needs to be retired.
        ctx_extra: ExecuteContextExtra,
    },
}
566
567impl From<SqlResult> for StatementResult {
568    fn from(inner: SqlResult) -> Self {
569        Self::SqlResult(inner)
570    }
571}
572
573/// The result of a single query in a [`SqlResponse`].
574#[derive(Debug, Serialize, Deserialize)]
575#[serde(untagged)]
576pub enum SqlResult {
577    /// The query returned rows.
578    Rows {
579        /// The command complete tag.
580        tag: String,
581        /// The result rows.
582        rows: Vec<Vec<serde_json::Value>>,
583        /// Information about each column.
584        desc: Description,
585        // Any notices generated during execution of the query.
586        notices: Vec<Notice>,
587    },
588    /// The query executed successfully but did not return rows.
589    Ok {
590        /// The command complete tag.
591        ok: String,
592        /// Any notices generated during execution of the query.
593        notices: Vec<Notice>,
594        /// Any parameters that may have changed.
595        ///
596        /// Note: skip serializing this field in a response if the list of parameters is empty.
597        #[serde(skip_serializing_if = "Vec::is_empty")]
598        parameters: Vec<ParameterStatus>,
599    },
600    /// The query returned an error.
601    Err {
602        error: SqlError,
603        // Any notices generated during execution of the query.
604        notices: Vec<Notice>,
605    },
606}
607
608impl SqlResult {
609    /// Convert adapter Row results into the web row result format. Error if the row format does not
610    /// match the expected descriptor.
611    // TODO(aljoscha): Bail when max_result_size is exceeded.
612    async fn rows<S>(
613        sender: &mut S,
614        client: &mut SessionClient,
615        mut rows_stream: RecordFirstRowStream,
616        max_query_result_size: usize,
617        desc: &RelationDesc,
618    ) -> Result<SqlResult, Error>
619    where
620        S: ResultSender,
621    {
622        let mut rows: Vec<Vec<serde_json::Value>> = vec![];
623        let mut datum_vec = mz_repr::DatumVec::new();
624        let types = &desc.typ().column_types;
625
626        let mut query_result_size = 0;
627
628        loop {
629            let peek_response = tokio::select! {
630                notice = client.session().recv_notice(), if S::SUPPORTS_STREAMING_NOTICES => {
631                    sender.emit_streaming_notices(vec![notice]).await?;
632                    continue;
633                }
634                e = sender.connection_error() => return Err(e),
635                r = rows_stream.recv() => {
636                    match r {
637                        Some(r) => r,
638                        None => break,
639                    }
640                },
641            };
642
643            let mut sql_rows = match peek_response {
644                PeekResponseUnary::Rows(rows) => rows,
645                PeekResponseUnary::Error(e) => {
646                    return Ok(SqlResult::err(client, Error::Unstructured(anyhow!(e))));
647                }
648                PeekResponseUnary::Canceled => {
649                    return Ok(SqlResult::err(client, AdapterError::Canceled));
650                }
651            };
652
653            if let Err(err) = verify_datum_desc(desc, &mut sql_rows) {
654                return Ok(SqlResult::Err {
655                    error: err.into(),
656                    notices: make_notices(client),
657                });
658            }
659
660            while let Some(row) = sql_rows.next() {
661                query_result_size += row.byte_len();
662                if query_result_size > max_query_result_size {
663                    use bytesize::ByteSize;
664                    return Ok(SqlResult::err(
665                        client,
666                        AdapterError::ResultSize(format!(
667                            "result exceeds max size of {}",
668                            ByteSize::b(u64::cast_from(max_query_result_size))
669                        )),
670                    ));
671                }
672
673                let datums = datum_vec.borrow_with(row);
674                rows.push(
675                    datums
676                        .iter()
677                        .enumerate()
678                        .map(|(i, d)| {
679                            TypedDatum::new(*d, &types[i])
680                                .json(&JsonNumberPolicy::ConvertNumberToString)
681                        })
682                        .collect(),
683                );
684            }
685        }
686
687        let tag = format!("SELECT {}", rows.len());
688        Ok(SqlResult::Rows {
689            tag,
690            rows,
691            desc: Description::from(desc),
692            notices: make_notices(client),
693        })
694    }
695
696    fn err(client: &mut SessionClient, error: impl Into<SqlError>) -> SqlResult {
697        SqlResult::Err {
698            error: error.into(),
699            notices: make_notices(client),
700        }
701    }
702
703    fn ok(client: &mut SessionClient, tag: String, params: Vec<ParameterStatus>) -> SqlResult {
704        SqlResult::Ok {
705            ok: tag,
706            parameters: params,
707            notices: make_notices(client),
708        }
709    }
710}
711
/// A serializable error reported to clients.
///
/// When built from an [`Error`], `message` is the error's display text and `code` is its
/// SQLSTATE code. The optional fields are omitted from the serialized output when absent.
#[derive(Debug, Deserialize, Serialize)]
pub struct SqlError {
    /// Human-readable error message.
    pub message: String,
    /// SQLSTATE error code.
    pub code: String,
    /// Additional detail, if the error provides any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
    /// A hint for resolving the error, if the error provides one.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub hint: Option<String>,
    /// Position associated with the error, if any.
    ///
    /// NOTE(review): presumably an offset into the query text. Confirm the unit with the
    /// adapter.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub position: Option<usize>,
}
723
724impl From<Error> for SqlError {
725    fn from(err: Error) -> Self {
726        SqlError {
727            message: err.to_string(),
728            code: err.code().code().to_string(),
729            detail: err.detail(),
730            hint: err.hint(),
731            position: err.position(),
732        }
733    }
734}
735
736impl From<AdapterError> for SqlError {
737    fn from(value: AdapterError) -> Self {
738        Error::from(value).into()
739    }
740}
741
/// A single message sent to a WebSocket client.
///
/// Serialized adjacently tagged, as `{"type": "<VariantName>", "payload": ...}`.
#[derive(Debug, Deserialize, Serialize)]
#[serde(tag = "type", content = "payload")]
pub enum WebSocketResponse {
    /// The server is ready for the next request. The payload is the session's transaction
    /// status code.
    ReadyForQuery(String),
    /// A notice generated by the session.
    Notice(Notice),
    /// The column description of a row-returning result.
    ///
    /// NOTE(review): presumably sent before the corresponding `Row` messages. Confirm.
    Rows(Description),
    /// A single result row, presumably one JSON value per column.
    Row(Vec<serde_json::Value>),
    /// Sent before each statement's results. Says whether the statement returns rows and
    /// whether they are streamed.
    CommandStarting(CommandStarting),
    /// Presumably the command tag of a completed statement. Confirm.
    CommandComplete(String),
    /// An error, sent in response to a failed request or a session timeout.
    Error(SqlError),
    /// A session variable and its value, sent during startup.
    ParameterStatus(ParameterStatus),
    /// The connection ID and secret key, sent during startup.
    BackendKeyData(BackendKeyData),
}
755
/// A notice reported to a client.
///
/// When forwarded from an `AdapterNotice` over a WebSocket, `code` is the SQLSTATE code and
/// `severity` is the lowercased severity name.
#[derive(Debug, Serialize, Deserialize)]
pub struct Notice {
    /// Human-readable notice text.
    message: String,
    /// SQLSTATE code of the notice.
    code: String,
    /// Severity name.
    severity: String,
    /// Additional detail, omitted from the serialized output when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
    /// A hint, omitted from the serialized output when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub hint: Option<String>,
}
766
767impl Notice {
768    pub fn message(&self) -> &str {
769        &self.message
770    }
771}
772
/// Describes the columns of a row-returning result.
#[derive(Debug, Serialize, Deserialize)]
pub struct Description {
    /// One entry per column, in result order.
    pub columns: Vec<Column>,
}
777
778impl From<&RelationDesc> for Description {
779    fn from(desc: &RelationDesc) -> Self {
780        let columns = desc
781            .iter()
782            .map(|(name, typ)| {
783                let pg_type = mz_pgrepr::Type::from(&typ.scalar_type);
784                Column {
785                    name: name.to_string(),
786                    type_oid: pg_type.oid(),
787                    type_len: pg_type.typlen(),
788                    type_mod: pg_type.typmod(),
789                }
790            })
791            .collect();
792        Description { columns }
793    }
794}
795
/// The name and PostgreSQL type information of one result column.
#[derive(Debug, Serialize, Deserialize)]
pub struct Column {
    /// The column name.
    pub name: String,
    /// OID of the column's PostgreSQL type.
    pub type_oid: u32,
    /// Length (`typlen`) of the column's PostgreSQL type.
    pub type_len: i16,
    /// Type modifier (`typmod`) of the column's PostgreSQL type.
    pub type_mod: i32,
}
803
/// The name and current value of a session variable reported to the client.
#[derive(Debug, Serialize, Deserialize)]
pub struct ParameterStatus {
    /// The variable's name.
    name: String,
    /// The variable's value, rendered as a string.
    value: String,
}
809
/// Identifies the session's connection, sent to WebSocket clients at startup.
#[derive(Debug, Serialize, Deserialize)]
pub struct BackendKeyData {
    /// The session's connection ID.
    conn_id: u32,
    /// The session's secret key.
    secret_key: u32,
}
815
/// Announces the kind of result that is about to be sent for a statement.
#[derive(Debug, Serialize, Deserialize)]
pub struct CommandStarting {
    /// Whether the statement produces rows (a row result or a `SUBSCRIBE`).
    has_rows: bool,
    /// Whether the rows are streamed (`SUBSCRIBE`).
    is_streaming: bool,
}
821
/// Trait describing how to transmit a response to a client. HTTP clients
/// accumulate into a Vec and send all at once. WebSocket clients send each
/// message as they occur.
#[async_trait]
trait ResultSender: Send {
    /// Whether this sender can forward notices while a statement's rows are still being
    /// received.
    ///
    /// Implementors that set this to `true` must override
    /// [`ResultSender::emit_streaming_notices`], because the default implementation panics.
    const SUPPORTS_STREAMING_NOTICES: bool = false;

    /// Adds the result of a single statement to the client's response.
    ///
    /// The first component of the return value is:
    /// - `Err` if sending to the client produced an error and the server should disconnect;
    /// - `Ok(Err(()))` if the statement produced an error and should error the transaction,
    ///   but the connection should remain open;
    /// - `Ok(Ok(()))` if the statement succeeded.
    ///
    /// The second component is `Some` if execution still needs to be retired for statement
    /// logging purposes.
    async fn add_result(
        &mut self,
        client: &mut SessionClient,
        res: StatementResult,
    ) -> (
        Result<Result<(), ()>, Error>,
        Option<(StatementEndedExecutionReason, ExecuteContextExtra)>,
    );

    /// Returns a future that resolves only when the client connection has gone away.
    fn connection_error(&mut self) -> BoxFuture<'_, Error>;
    /// Reports whether the client supports streaming SUBSCRIBE results.
    fn allow_subscribe(&self) -> bool;

    /// Emits streaming notices to the client.
    ///
    /// Callers must only invoke this when `SUPPORTS_STREAMING_NOTICES` is `true`. The default
    /// implementation panics with `unreachable!` rather than ignoring the notices.
    async fn emit_streaming_notices(&mut self, _: Vec<AdapterNotice>) -> Result<(), Error> {
        unreachable!("streaming notices marked as unsupported")
    }
}
857
858#[async_trait]
859impl ResultSender for SqlResponse {
860    // The first component of the return value is
861    // Err if sending to the client
862    // produced an error and the server should disconnect. It is Ok(Err) if the statement
863    // produced an error and should error the transaction, but remain connected. It is Ok(Ok(()))
864    // if the statement succeeded.
865    // The second component of the return value is `Some` if execution still
866    // needs to be retired for statement logging purposes.
867    async fn add_result(
868        &mut self,
869        _client: &mut SessionClient,
870        res: StatementResult,
871    ) -> (
872        Result<Result<(), ()>, Error>,
873        Option<(StatementEndedExecutionReason, ExecuteContextExtra)>,
874    ) {
875        let (res, stmt_logging) = match res {
876            StatementResult::SqlResult(res) => {
877                let is_err = matches!(res, SqlResult::Err { .. });
878                self.results.push(res);
879                let res = if is_err { Err(()) } else { Ok(()) };
880                (res, None)
881            }
882            StatementResult::Subscribe { ctx_extra, .. } => {
883                let message = "SUBSCRIBE only supported over websocket";
884                self.results.push(SqlResult::Err {
885                    error: Error::SubscribeOnlyOverWs.into(),
886                    notices: Vec::new(),
887                });
888                (
889                    Err(()),
890                    Some((
891                        StatementEndedExecutionReason::Errored {
892                            error: message.into(),
893                        },
894                        ctx_extra,
895                    )),
896                )
897            }
898        };
899        (Ok(res), stmt_logging)
900    }
901
902    fn connection_error(&mut self) -> BoxFuture<'_, Error> {
903        Box::pin(futures::future::pending())
904    }
905
906    fn allow_subscribe(&self) -> bool {
907        false
908    }
909}
910
911#[async_trait]
912impl ResultSender for WebSocket {
913    const SUPPORTS_STREAMING_NOTICES: bool = true;
914
915    // The first component of the return value is Err if sending to the client produced an error and
916    // the server should disconnect. It is Ok(Err) if the statement produced an error and should
917    // error the transaction, but remain connected. It is Ok(Ok(())) if the statement succeeded. The
918    // second component of the return value is `Some` if execution still needs to be retired for
919    // statement logging purposes.
920    async fn add_result(
921        &mut self,
922        client: &mut SessionClient,
923        res: StatementResult,
924    ) -> (
925        Result<Result<(), ()>, Error>,
926        Option<(StatementEndedExecutionReason, ExecuteContextExtra)>,
927    ) {
928        let (has_rows, is_streaming) = match res {
929            StatementResult::SqlResult(SqlResult::Err { .. }) => (false, false),
930            StatementResult::SqlResult(SqlResult::Ok { .. }) => (false, false),
931            StatementResult::SqlResult(SqlResult::Rows { .. }) => (true, false),
932            StatementResult::Subscribe { .. } => (true, true),
933        };
934        if let Err(e) = send_ws_response(
935            self,
936            WebSocketResponse::CommandStarting(CommandStarting {
937                has_rows,
938                is_streaming,
939            }),
940        )
941        .await
942        {
943            return (Err(e), None);
944        }
945
946        let (is_err, msgs, stmt_logging) = match res {
947            StatementResult::SqlResult(SqlResult::Rows {
948                tag,
949                rows,
950                desc,
951                notices,
952            }) => {
953                let mut msgs = vec![WebSocketResponse::Rows(desc)];
954                msgs.extend(rows.into_iter().map(WebSocketResponse::Row));
955                msgs.push(WebSocketResponse::CommandComplete(tag));
956                msgs.extend(notices.into_iter().map(WebSocketResponse::Notice));
957                (false, msgs, None)
958            }
959            StatementResult::SqlResult(SqlResult::Ok {
960                ok,
961                parameters,
962                notices,
963            }) => {
964                let mut msgs = vec![WebSocketResponse::CommandComplete(ok)];
965                msgs.extend(notices.into_iter().map(WebSocketResponse::Notice));
966                msgs.extend(
967                    parameters
968                        .into_iter()
969                        .map(WebSocketResponse::ParameterStatus),
970                );
971                (false, msgs, None)
972            }
973            StatementResult::SqlResult(SqlResult::Err { error, notices }) => {
974                let mut msgs = vec![WebSocketResponse::Error(error)];
975                msgs.extend(notices.into_iter().map(WebSocketResponse::Notice));
976                (true, msgs, None)
977            }
978            StatementResult::Subscribe {
979                ref desc,
980                tag,
981                mut rx,
982                ctx_extra,
983            } => {
984                if let Err(e) = send_ws_response(self, WebSocketResponse::Rows(desc.into())).await {
985                    // We consider the remote breaking the connection to be a cancellation,
986                    // matching the behavior for pgwire
987                    return (
988                        Err(e),
989                        Some((StatementEndedExecutionReason::Canceled, ctx_extra)),
990                    );
991                }
992
993                let mut datum_vec = mz_repr::DatumVec::new();
994                let mut result_size: usize = 0;
995                let mut rows_returned = 0;
996                loop {
997                    let res = match await_rows(self, client, rx.recv()).await {
998                        Ok(res) => res,
999                        Err(e) => {
1000                            // We consider the remote breaking the connection to be a cancellation,
1001                            // matching the behavior for pgwire
1002                            return (
1003                                Err(e),
1004                                Some((StatementEndedExecutionReason::Canceled, ctx_extra)),
1005                            );
1006                        }
1007                    };
1008                    match res {
1009                        Some(PeekResponseUnary::Rows(mut rows)) => {
1010                            if let Err(err) = verify_datum_desc(desc, &mut rows) {
1011                                let error = err.to_string();
1012                                break (
1013                                    true,
1014                                    vec![WebSocketResponse::Error(err.into())],
1015                                    Some((
1016                                        StatementEndedExecutionReason::Errored { error },
1017                                        ctx_extra,
1018                                    )),
1019                                );
1020                            }
1021
1022                            rows_returned += rows.count();
1023                            while let Some(row) = rows.next() {
1024                                result_size += row.byte_len();
1025                                let datums = datum_vec.borrow_with(row);
1026                                let types = &desc.typ().column_types;
1027                                if let Err(e) = send_ws_response(
1028                                    self,
1029                                    WebSocketResponse::Row(
1030                                        datums
1031                                            .iter()
1032                                            .enumerate()
1033                                            .map(|(i, d)| {
1034                                                TypedDatum::new(*d, &types[i])
1035                                                    .json(&JsonNumberPolicy::ConvertNumberToString)
1036                                            })
1037                                            .collect(),
1038                                    ),
1039                                )
1040                                .await
1041                                {
1042                                    // We consider the remote breaking the connection to be a cancellation,
1043                                    // matching the behavior for pgwire
1044                                    return (
1045                                        Err(e),
1046                                        Some((StatementEndedExecutionReason::Canceled, ctx_extra)),
1047                                    );
1048                                }
1049                            }
1050                        }
1051                        Some(PeekResponseUnary::Error(error)) => {
1052                            break (
1053                                true,
1054                                vec![WebSocketResponse::Error(
1055                                    Error::Unstructured(anyhow!(error.clone())).into(),
1056                                )],
1057                                Some((StatementEndedExecutionReason::Errored { error }, ctx_extra)),
1058                            );
1059                        }
1060                        Some(PeekResponseUnary::Canceled) => {
1061                            break (
1062                                true,
1063                                vec![WebSocketResponse::Error(AdapterError::Canceled.into())],
1064                                Some((StatementEndedExecutionReason::Canceled, ctx_extra)),
1065                            );
1066                        }
1067                        None => {
1068                            break (
1069                                false,
1070                                vec![WebSocketResponse::CommandComplete(tag)],
1071                                Some((
1072                                    StatementEndedExecutionReason::Success {
1073                                        result_size: Some(u64::cast_from(result_size)),
1074                                        rows_returned: Some(u64::cast_from(rows_returned)),
1075                                        execution_strategy: Some(
1076                                            StatementExecutionStrategy::Standard,
1077                                        ),
1078                                    },
1079                                    ctx_extra,
1080                                )),
1081                            );
1082                        }
1083                    }
1084                }
1085            }
1086        };
1087        for msg in msgs {
1088            if let Err(e) = send_ws_response(self, msg).await {
1089                return (
1090                    Err(e),
1091                    stmt_logging.map(|(_old_reason, ctx_extra)| {
1092                        (StatementEndedExecutionReason::Canceled, ctx_extra)
1093                    }),
1094                );
1095            }
1096        }
1097        (Ok(if is_err { Err(()) } else { Ok(()) }), stmt_logging)
1098    }
1099
1100    // Send a websocket Ping every second to verify the client is still
1101    // connected.
1102    fn connection_error(&mut self) -> BoxFuture<'_, Error> {
1103        Box::pin(async {
1104            let mut tick = time::interval(Duration::from_secs(1));
1105            tick.tick().await;
1106            loop {
1107                tick.tick().await;
1108                if let Err(err) = self.send(Message::Ping(Vec::new())).await {
1109                    return err.into();
1110                }
1111            }
1112        })
1113    }
1114
1115    fn allow_subscribe(&self) -> bool {
1116        true
1117    }
1118
1119    async fn emit_streaming_notices(&mut self, notices: Vec<AdapterNotice>) -> Result<(), Error> {
1120        forward_notices(self, notices).await
1121    }
1122}
1123
1124async fn await_rows<S, F, R>(sender: &mut S, client: &mut SessionClient, f: F) -> Result<R, Error>
1125where
1126    S: ResultSender,
1127    F: Future<Output = R> + Send,
1128{
1129    let mut f = pin!(f);
1130    loop {
1131        tokio::select! {
1132            notice = client.session().recv_notice(), if S::SUPPORTS_STREAMING_NOTICES => {
1133                sender.emit_streaming_notices(vec![notice]).await?;
1134            }
1135            e = sender.connection_error() => return Err(e),
1136            r = &mut f => return Ok(r),
1137        }
1138    }
1139}
1140
1141async fn send_and_retire<S: ResultSender>(
1142    res: StatementResult,
1143    client: &mut SessionClient,
1144    sender: &mut S,
1145) -> Result<Result<(), ()>, Error> {
1146    let (res, stmt_logging) = sender.add_result(client, res).await;
1147    if let Some((reason, ctx_extra)) = stmt_logging {
1148        client.retire_execute(ctx_extra, reason);
1149    }
1150    res
1151}
1152
1153/// Returns Ok(Err) if any statement error'd during execution.
1154async fn execute_stmt_group<S: ResultSender>(
1155    client: &mut SessionClient,
1156    sender: &mut S,
1157    stmt_group: Vec<(Statement<Raw>, String, Vec<Option<String>>)>,
1158) -> Result<Result<(), ()>, Error> {
1159    let num_stmts = stmt_group.len();
1160    for (stmt, sql, params) in stmt_group {
1161        assert!(
1162            num_stmts <= 1 || params.is_empty(),
1163            "statement groups contain more than 1 statement iff Simple request, which does not support parameters"
1164        );
1165
1166        let is_aborted_txn = matches!(client.session().transaction(), TransactionStatus::Failed(_));
1167        if is_aborted_txn && !is_txn_exit_stmt(&stmt) {
1168            let err = SqlResult::err(client, Error::AbortedTransaction);
1169            let _ = send_and_retire(err.into(), client, sender).await?;
1170            return Ok(Err(()));
1171        }
1172
1173        // Mirror the behavior of the PostgreSQL simple query protocol.
1174        // See the pgwire::protocol::StateMachine::query method for details.
1175        if let Err(e) = client.start_transaction(Some(num_stmts)) {
1176            let err = SqlResult::err(client, e);
1177            let _ = send_and_retire(err.into(), client, sender).await?;
1178            return Ok(Err(()));
1179        }
1180        let res = execute_stmt(client, sender, stmt, sql, params).await?;
1181        let is_err = send_and_retire(res, client, sender).await?;
1182
1183        if is_err.is_err() {
1184            // Mirror StateMachine::error, which sometimes will clean up the
1185            // transaction state instead of always leaving it in Failed.
1186            let txn = client.session().transaction();
1187            match txn {
1188                // Error can be called from describe and parse and so might not be in an active
1189                // transaction.
1190                TransactionStatus::Default | TransactionStatus::Failed(_) => {}
1191                // In Started (i.e., a single statement) and implicit transactions cleanup themselves.
1192                TransactionStatus::Started(_) | TransactionStatus::InTransactionImplicit(_) => {
1193                    if let Err(err) = client.end_transaction(EndTransactionAction::Rollback).await {
1194                        let err = SqlResult::err(client, err);
1195                        let _ = send_and_retire(err.into(), client, sender).await?;
1196                    }
1197                }
1198                // Explicit transactions move to failed.
1199                TransactionStatus::InTransaction(_) => {
1200                    client.fail_transaction();
1201                }
1202            }
1203            return Ok(Err(()));
1204        }
1205    }
1206    Ok(Ok(()))
1207}
1208
1209/// Executes an entire [`SqlRequest`].
1210///
1211/// See the user-facing documentation about the HTTP API for a description of
1212/// the semantics of this function.
1213async fn execute_request<S: ResultSender>(
1214    client: &mut AuthedClient,
1215    request: SqlRequest,
1216    sender: &mut S,
1217) -> Result<(), Error> {
1218    let client = &mut client.client;
1219
1220    // This API prohibits executing statements with responses whose
1221    // semantics are at odds with an HTTP response.
1222    fn check_prohibited_stmts<S: ResultSender>(
1223        sender: &S,
1224        stmt: &Statement<Raw>,
1225    ) -> Result<(), Error> {
1226        let kind: StatementKind = stmt.into();
1227        let execute_responses = Plan::generated_from(&kind)
1228            .into_iter()
1229            .map(ExecuteResponse::generated_from)
1230            .flatten()
1231            .collect::<Vec<_>>();
1232
1233        // Special-case `COPY TO` statements that are not `COPY ... TO STDOUT`, since
1234        // StatementKind::Copy links to several `ExecuteResponseKind`s that are not supported,
1235        // but this specific statement should be allowed.
1236        let is_valid_copy = matches!(
1237            stmt,
1238            Statement::Copy(CopyStatement {
1239                direction: CopyDirection::To,
1240                target: CopyTarget::Expr(_),
1241                ..
1242            }) | Statement::Copy(CopyStatement {
1243                direction: CopyDirection::From,
1244                target: CopyTarget::Expr(_),
1245                ..
1246            })
1247        );
1248
1249        if !is_valid_copy
1250            && execute_responses.iter().any(|execute_response| {
1251                // Returns true if a statement or execute response are unsupported.
1252                match execute_response {
1253                    ExecuteResponseKind::Subscribing if sender.allow_subscribe() => false,
1254                    ExecuteResponseKind::Fetch
1255                    | ExecuteResponseKind::Subscribing
1256                    | ExecuteResponseKind::CopyFrom
1257                    | ExecuteResponseKind::DeclaredCursor
1258                    | ExecuteResponseKind::ClosedCursor => true,
1259                    // Various statements generate `PeekPlan` (`SELECT`, `COPY`,
1260                    // `EXPLAIN`, `SHOW`) which has both `SendRows` and `CopyTo` as its
1261                    // possible response types. but `COPY` needs be picked out because
1262                    // http don't support its response type
1263                    ExecuteResponseKind::CopyTo if matches!(kind, StatementKind::Copy) => true,
1264                    _ => false,
1265                }
1266            })
1267        {
1268            return Err(Error::Unsupported(stmt.to_ast_string_simple()));
1269        }
1270        Ok(())
1271    }
1272
1273    fn parse<'a>(
1274        client: &SessionClient,
1275        query: &'a str,
1276    ) -> Result<Vec<StatementParseResult<'a>>, Error> {
1277        let result = client
1278            .parse(query)
1279            .map_err(|e| Error::Unstructured(anyhow!(e)))?;
1280        result.map_err(|e| AdapterError::from(e).into())
1281    }
1282
1283    let mut stmt_groups = vec![];
1284
1285    match request {
1286        SqlRequest::Simple { query } => match parse(client, &query) {
1287            Ok(stmts) => {
1288                let mut stmt_group = Vec::with_capacity(stmts.len());
1289                let mut stmt_err = None;
1290                for StatementParseResult { ast: stmt, sql } in stmts {
1291                    if let Err(err) = check_prohibited_stmts(sender, &stmt) {
1292                        stmt_err = Some(err);
1293                        break;
1294                    }
1295                    stmt_group.push((stmt, sql.to_string(), vec![]));
1296                }
1297                stmt_groups.push(stmt_err.map(Err).unwrap_or_else(|| Ok(stmt_group)));
1298            }
1299            Err(e) => stmt_groups.push(Err(e)),
1300        },
1301        SqlRequest::Extended { queries } => {
1302            for ExtendedRequest { query, params } in queries {
1303                match parse(client, &query) {
1304                    Ok(mut stmts) => {
1305                        if stmts.len() != 1 {
1306                            return Err(Error::Unstructured(anyhow!(
1307                                "each query must contain exactly 1 statement, but \"{}\" contains {}",
1308                                query,
1309                                stmts.len()
1310                            )));
1311                        }
1312
1313                        let StatementParseResult { ast: stmt, sql } = stmts.pop().unwrap();
1314                        stmt_groups.push(
1315                            check_prohibited_stmts(sender, &stmt)
1316                                .map(|_| vec![(stmt, sql.to_string(), params)]),
1317                        );
1318                    }
1319                    Err(e) => stmt_groups.push(Err(e)),
1320                };
1321            }
1322        }
1323    }
1324
1325    for stmt_group_res in stmt_groups {
1326        let executed = match stmt_group_res {
1327            Ok(stmt_group) => execute_stmt_group(client, sender, stmt_group).await,
1328            Err(e) => {
1329                let err = SqlResult::err(client, e);
1330                let _ = send_and_retire(err.into(), client, sender).await?;
1331                Ok(Err(()))
1332            }
1333        };
1334        // At the end of each group, commit implicit transactions. Do that here so that any `?`
1335        // early return can still be handled here.
1336        if client.session().transaction().is_implicit() {
1337            let ended = client.end_transaction(EndTransactionAction::Commit).await;
1338            if let Err(err) = ended {
1339                let err = SqlResult::err(client, err);
1340                let _ = send_and_retire(StatementResult::SqlResult(err), client, sender).await?;
1341            }
1342        }
1343        if executed?.is_err() {
1344            break;
1345        }
1346    }
1347
1348    Ok(())
1349}
1350
1351/// Executes a single statement in a [`SqlRequest`].
1352async fn execute_stmt<S: ResultSender>(
1353    client: &mut SessionClient,
1354    sender: &mut S,
1355    stmt: Statement<Raw>,
1356    sql: String,
1357    raw_params: Vec<Option<String>>,
1358) -> Result<StatementResult, Error> {
1359    const EMPTY_PORTAL: &str = "";
1360    if let Err(e) = client
1361        .prepare(EMPTY_PORTAL.into(), Some(stmt.clone()), sql, vec![])
1362        .await
1363    {
1364        return Ok(SqlResult::err(client, e).into());
1365    }
1366
1367    let prep_stmt = match client.get_prepared_statement(EMPTY_PORTAL).await {
1368        Ok(stmt) => stmt,
1369        Err(err) => {
1370            return Ok(SqlResult::err(client, err).into());
1371        }
1372    };
1373
1374    let param_types = &prep_stmt.desc().param_types;
1375    if param_types.len() != raw_params.len() {
1376        let message = anyhow!(
1377            "request supplied {actual} parameters, \
1378                        but {statement} requires {expected}",
1379            statement = stmt.to_ast_string_simple(),
1380            actual = raw_params.len(),
1381            expected = param_types.len()
1382        );
1383        return Ok(SqlResult::err(client, Error::Unstructured(message)).into());
1384    }
1385
1386    let buf = RowArena::new();
1387    let mut params = vec![];
1388    for (raw_param, mz_typ) in raw_params.into_iter().zip_eq(param_types) {
1389        let pg_typ = mz_pgrepr::Type::from(mz_typ);
1390        let datum = match raw_param {
1391            None => Datum::Null,
1392            Some(raw_param) => {
1393                match mz_pgrepr::Value::decode(
1394                    mz_pgwire_common::Format::Text,
1395                    &pg_typ,
1396                    raw_param.as_bytes(),
1397                ) {
1398                    Ok(param) => param.into_datum(&buf, &pg_typ),
1399                    Err(err) => {
1400                        let msg = anyhow!("unable to decode parameter: {}", err);
1401                        return Ok(SqlResult::err(client, Error::Unstructured(msg)).into());
1402                    }
1403                }
1404            }
1405        };
1406        params.push((datum, mz_typ.clone()))
1407    }
1408
1409    let result_formats = vec![
1410        mz_pgwire_common::Format::Text;
1411        prep_stmt
1412            .desc()
1413            .relation_desc
1414            .clone()
1415            .map(|desc| desc.typ().column_types.len())
1416            .unwrap_or(0)
1417    ];
1418
1419    let desc = prep_stmt.desc().clone();
1420    let logging = Arc::clone(prep_stmt.logging());
1421    let stmt_ast = prep_stmt.stmt().cloned();
1422    let state_revision = prep_stmt.state_revision;
1423    if let Err(err) = client.session().set_portal(
1424        EMPTY_PORTAL.into(),
1425        desc,
1426        stmt_ast,
1427        logging,
1428        params,
1429        result_formats,
1430        state_revision,
1431    ) {
1432        return Ok(SqlResult::err(client, err).into());
1433    }
1434
1435    let desc = client
1436        .session()
1437        // We do not need to verify here because `client.execute` verifies below.
1438        .get_portal_unverified(EMPTY_PORTAL)
1439        .map(|portal| portal.desc.clone())
1440        .expect("unnamed portal should be present");
1441
1442    let res = client
1443        .execute(EMPTY_PORTAL.into(), futures::future::pending(), None)
1444        .await;
1445
1446    if S::SUPPORTS_STREAMING_NOTICES {
1447        sender
1448            .emit_streaming_notices(client.session().drain_notices())
1449            .await?;
1450    }
1451
1452    let (res, execute_started) = match res {
1453        Ok(res) => res,
1454        Err(e) => {
1455            return Ok(SqlResult::err(client, e).into());
1456        }
1457    };
1458    let tag = res.tag();
1459
1460    Ok(match res {
1461        ExecuteResponse::CreatedConnection { .. }
1462        | ExecuteResponse::CreatedDatabase { .. }
1463        | ExecuteResponse::CreatedSchema { .. }
1464        | ExecuteResponse::CreatedRole
1465        | ExecuteResponse::CreatedCluster { .. }
1466        | ExecuteResponse::CreatedClusterReplica { .. }
1467        | ExecuteResponse::CreatedTable { .. }
1468        | ExecuteResponse::CreatedIndex { .. }
1469        | ExecuteResponse::CreatedIntrospectionSubscribe
1470        | ExecuteResponse::CreatedSecret { .. }
1471        | ExecuteResponse::CreatedSource { .. }
1472        | ExecuteResponse::CreatedSink { .. }
1473        | ExecuteResponse::CreatedView { .. }
1474        | ExecuteResponse::CreatedViews { .. }
1475        | ExecuteResponse::CreatedMaterializedView { .. }
1476        | ExecuteResponse::CreatedContinualTask { .. }
1477        | ExecuteResponse::CreatedType
1478        | ExecuteResponse::CreatedNetworkPolicy
1479        | ExecuteResponse::Comment
1480        | ExecuteResponse::Deleted(_)
1481        | ExecuteResponse::DiscardedTemp
1482        | ExecuteResponse::DiscardedAll
1483        | ExecuteResponse::DroppedObject(_)
1484        | ExecuteResponse::DroppedOwned
1485        | ExecuteResponse::EmptyQuery
1486        | ExecuteResponse::GrantedPrivilege
1487        | ExecuteResponse::GrantedRole
1488        | ExecuteResponse::Inserted(_)
1489        | ExecuteResponse::Copied(_)
1490        | ExecuteResponse::Raised
1491        | ExecuteResponse::ReassignOwned
1492        | ExecuteResponse::RevokedPrivilege
1493        | ExecuteResponse::AlteredDefaultPrivileges
1494        | ExecuteResponse::RevokedRole
1495        | ExecuteResponse::StartedTransaction { .. }
1496        | ExecuteResponse::Updated(_)
1497        | ExecuteResponse::AlteredObject(_)
1498        | ExecuteResponse::AlteredRole
1499        | ExecuteResponse::AlteredSystemConfiguration
1500        | ExecuteResponse::Deallocate { .. }
1501        | ExecuteResponse::ValidatedConnection
1502        | ExecuteResponse::Prepare => SqlResult::ok(
1503            client,
1504            tag.expect("ok only called on tag-generating results"),
1505            Vec::default(),
1506        )
1507        .into(),
1508        ExecuteResponse::TransactionCommitted { params }
1509        | ExecuteResponse::TransactionRolledBack { params } => {
1510            let notify_set: mz_ore::collections::HashSet<_> = client
1511                .session()
1512                .vars()
1513                .notify_set()
1514                .map(|v| v.name().to_string())
1515                .collect();
1516            let params = params
1517                .into_iter()
1518                .filter(|(name, _value)| notify_set.contains(*name))
1519                .map(|(name, value)| ParameterStatus {
1520                    name: name.to_string(),
1521                    value,
1522                })
1523                .collect();
1524            SqlResult::ok(
1525                client,
1526                tag.expect("ok only called on tag-generating results"),
1527                params,
1528            )
1529            .into()
1530        }
1531        ExecuteResponse::SetVariable { name, .. } => {
1532            let mut params = Vec::with_capacity(1);
1533            if let Some(var) = client
1534                .session()
1535                .vars()
1536                .notify_set()
1537                .find(|v| v.name() == &name)
1538            {
1539                params.push(ParameterStatus {
1540                    name,
1541                    value: var.value(),
1542                });
1543            };
1544            SqlResult::ok(
1545                client,
1546                tag.expect("ok only called on tag-generating results"),
1547                params,
1548            )
1549            .into()
1550        }
1551        ExecuteResponse::SendingRowsStreaming {
1552            rows,
1553            instance_id,
1554            strategy,
1555        } => {
1556            let max_query_result_size =
1557                usize::cast_from(client.get_system_vars().await.max_result_size());
1558
1559            let rows_stream = RecordFirstRowStream::new(
1560                Box::new(rows),
1561                execute_started,
1562                client,
1563                Some(instance_id),
1564                Some(strategy),
1565            );
1566
1567            SqlResult::rows(
1568                sender,
1569                client,
1570                rows_stream,
1571                max_query_result_size,
1572                &desc.relation_desc.expect("RelationDesc must exist"),
1573            )
1574            .await?
1575            .into()
1576        }
1577        ExecuteResponse::SendingRowsImmediate { rows } => {
1578            let max_query_result_size =
1579                usize::cast_from(client.get_system_vars().await.max_result_size());
1580
1581            let rows = futures::stream::once(futures::future::ready(PeekResponseUnary::Rows(rows)));
1582            let rows_stream =
1583                RecordFirstRowStream::new(Box::new(rows), execute_started, client, None, None);
1584
1585            SqlResult::rows(
1586                sender,
1587                client,
1588                rows_stream,
1589                max_query_result_size,
1590                &desc.relation_desc.expect("RelationDesc must exist"),
1591            )
1592            .await?
1593            .into()
1594        }
1595        ExecuteResponse::Subscribing {
1596            rx,
1597            ctx_extra,
1598            instance_id,
1599        } => StatementResult::Subscribe {
1600            tag: "SUBSCRIBE".into(),
1601            desc: desc.relation_desc.unwrap(),
1602            rx: RecordFirstRowStream::new(
1603                Box::new(UnboundedReceiverStream::new(rx)),
1604                execute_started,
1605                client,
1606                Some(instance_id),
1607                None,
1608            ),
1609            ctx_extra,
1610        },
1611        res @ (ExecuteResponse::Fetch { .. }
1612        | ExecuteResponse::CopyTo { .. }
1613        | ExecuteResponse::CopyFrom { .. }
1614        | ExecuteResponse::DeclaredCursor
1615        | ExecuteResponse::ClosedCursor) => SqlResult::err(
1616            client,
1617            Error::Unstructured(anyhow!(
1618                "internal error: encountered prohibited ExecuteResponse {:?}.\n\n
1619            This is a bug. Can you please file an bug report letting us know?\n
1620            https://github.com/MaterializeInc/materialize/discussions/new?category=bug-reports",
1621                ExecuteResponseKind::from(res)
1622            )),
1623        )
1624        .into(),
1625    })
1626}
1627
1628fn make_notices(client: &mut SessionClient) -> Vec<Notice> {
1629    client
1630        .session()
1631        .drain_notices()
1632        .into_iter()
1633        .map(|notice| Notice {
1634            message: notice.to_string(),
1635            code: notice.code().code().to_string(),
1636            severity: notice.severity().as_str().to_lowercase(),
1637            detail: notice.detail(),
1638            hint: notice.hint(),
1639        })
1640        .collect()
1641}
1642
1643// Duplicated from protocol.rs.
1644// See postgres' backend/tcop/postgres.c IsTransactionExitStmt.
1645fn is_txn_exit_stmt(stmt: &Statement<Raw>) -> bool {
1646    matches!(
1647        stmt,
1648        Statement::Commit(_) | Statement::Rollback(_) | Statement::Prepare(_)
1649    )
1650}
1651
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use super::{Password, WebSocketAuth};

    /// Deserializes `json` as a [`WebSocketAuth`] and asserts it equals `expected`.
    fn assert_parses_to(json: &'static str, expected: WebSocketAuth) {
        let actual: WebSocketAuth = serde_json::from_str(json).unwrap();
        assert_eq!(actual, expected);
    }

    #[mz_ore::test]
    fn smoke_test_websocket_auth_parse() {
        // Both user/password payloads below should produce this same value.
        let basic_mz = || WebSocketAuth::Basic {
            user: "mz".to_string(),
            password: Password("1234".to_string()),
            options: BTreeMap::default(),
        };

        let cases: [(&'static str, WebSocketAuth); 4] = [
            // User/password: `options` omitted, then given as an empty object.
            (r#"{ "user": "mz", "password": "1234" }"#, basic_mz()),
            (
                r#"{ "user": "mz", "password": "1234", "options": {} }"#,
                basic_mz(),
            ),
            // Token: `options` omitted, then given with one entry.
            (
                r#"{ "token": "i_am_a_token" }"#,
                WebSocketAuth::Bearer {
                    token: "i_am_a_token".to_string(),
                    options: BTreeMap::default(),
                },
            ),
            (
                r#"{ "token": "i_am_a_token", "options": { "foo": "bar" } }"#,
                WebSocketAuth::Bearer {
                    token: "i_am_a_token".to_string(),
                    options: BTreeMap::from([("foo".to_string(), "bar".to_string())]),
                },
            ),
        ];

        for (json, expected) in cases {
            assert_parses_to(json, expected);
        }
    }
}