1use std::collections::BTreeMap;
11use std::net::{IpAddr, SocketAddr};
12use std::pin::pin;
13use std::sync::Arc;
14use std::time::Duration;
15
16use anyhow::anyhow;
17use async_trait::async_trait;
18use axum::extract::connect_info::ConnectInfo;
19use axum::extract::ws::{CloseFrame, Message, Utf8Bytes, WebSocket};
20use axum::extract::{State, WebSocketUpgrade};
21use axum::response::IntoResponse;
22use axum::{Extension, Json};
23use futures::Future;
24use futures::future::BoxFuture;
25
26use http::StatusCode;
27use itertools::Itertools;
28use mz_adapter::client::RecordFirstRowStream;
29use mz_adapter::session::{EndTransactionAction, TransactionStatus};
30use mz_adapter::statement_logging::{StatementEndedExecutionReason, StatementExecutionStrategy};
31use mz_adapter::{
32 AdapterError, AdapterNotice, ExecuteContextGuard, ExecuteResponse, ExecuteResponseKind,
33 PeekResponseUnary, SessionClient, verify_datum_desc,
34};
35use mz_auth::password::Password;
36use mz_catalog::memory::objects::{Cluster, ClusterReplica};
37use mz_interchange::encode::TypedDatum;
38use mz_interchange::json::{JsonNumberPolicy, ToJson};
39use mz_ore::cast::CastFrom;
40use mz_ore::metrics::{MakeCollectorOpts, MetricsRegistry};
41use mz_ore::result::ResultExt;
42use mz_ore::sql::Sql;
43use mz_repr::{Datum, RelationDesc, RowArena, RowIterator};
44use mz_sql::ast::display::AstDisplay;
45use mz_sql::ast::{CopyDirection, CopyStatement, CopyTarget, Raw, Statement, StatementKind};
46use mz_sql::parse::StatementParseResult;
47use mz_sql::plan::Plan;
48use mz_sql::session::metadata::SessionMetadata;
49use prometheus::Opts;
50use prometheus::core::{AtomicF64, GenericGaugeVec};
51use serde::{Deserialize, Serialize};
52use tokio::{select, time};
53use tokio_postgres::error::SqlState;
54use tokio_stream::wrappers::UnboundedReceiverStream;
55use tower_sessions::Session as TowerSession;
56use tracing::{debug, info};
57use tungstenite::protocol::frame::coding::CloseCode;
58
59use crate::http::prometheus::PrometheusSqlQuery;
60use crate::http::{
61 AuthError, AuthedClient, AuthedUser, MAX_REQUEST_SIZE, WsState, ensure_session_unexpired,
62 init_ws, maybe_get_authenticated_session,
63};
64
/// Errors that can occur while serving a SQL request over HTTP or WebSocket.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// An error reported by the adapter layer.
    #[error(transparent)]
    Adapter(#[from] AdapterError),
    /// A JSON (de)serialization error, e.g. a malformed request body.
    #[error(transparent)]
    Json(#[from] serde_json::Error),
    /// An error from axum, e.g. a failed WebSocket send.
    #[error(transparent)]
    Axum(#[from] axum::Error),
    /// A `SUBSCRIBE` was attempted through a sender that cannot stream results.
    #[error("SUBSCRIBE only supported over websocket")]
    SubscribeOnlyOverWs,
    /// A statement other than a transaction exit was issued inside a failed
    /// transaction.
    #[error("current transaction is aborted, commands ignored until end of transaction block")]
    AbortedTransaction,
    /// The statement (rendered as SQL in the payload) cannot be served by this API.
    #[error("unsupported via this API: {0}")]
    Unsupported(String),
    /// Any other error, carried opaquely.
    #[error("{0}")]
    Unstructured(anyhow::Error),
}
82
83impl Error {
84 pub fn detail(&self) -> Option<String> {
85 match self {
86 Error::Adapter(err) => err.detail(),
87 _ => None,
88 }
89 }
90
91 pub fn hint(&self) -> Option<String> {
92 match self {
93 Error::Adapter(err) => err.hint(),
94 _ => None,
95 }
96 }
97
98 pub fn position(&self) -> Option<usize> {
99 match self {
100 Error::Adapter(err) => err.position(),
101 _ => None,
102 }
103 }
104
105 pub fn code(&self) -> SqlState {
106 match self {
107 Error::Adapter(err) => err.code(),
108 Error::AbortedTransaction => SqlState::IN_FAILED_SQL_TRANSACTION,
109 _ => SqlState::INTERNAL_ERROR,
110 }
111 }
112}
113
/// Label names appended to every per-replica Prometheus metric.
///
/// The order must match the order in which `execute_promsql_query` appends
/// the replica's full name, cluster id and replica id.
static PER_REPLICA_LABELS: &[&str] = &["replica_full_name", "instance_id", "replica_id"];
115
116async fn execute_promsql_query(
117 client: &mut AuthedClient,
118 query: &PrometheusSqlQuery<'_>,
119 metrics_registry: &MetricsRegistry,
120 metrics_by_name: &mut BTreeMap<String, GenericGaugeVec<AtomicF64>>,
121 cluster: Option<(&Cluster, &ClusterReplica)>,
122) {
123 assert_eq!(query.per_replica, cluster.is_some());
124
125 let mut res = SqlResponse {
126 results: Vec::new(),
127 };
128
129 execute_request(client, query.to_sql_request(cluster), &mut res)
130 .await
131 .expect("valid SQL query");
132
133 let result = match res.results.as_slice() {
134 [
137 SqlResult::Ok { .. },
138 SqlResult::Ok { .. },
139 SqlResult::Ok { .. },
140 result,
141 ] => result,
142 _ => {
146 info!(
147 "error executing prometheus query {}: {:?}",
148 query.metric_name, res
149 );
150 return;
151 }
152 };
153
154 let SqlResult::Rows { desc, rows, .. } = result else {
155 info!(
156 "did not receive rows for SQL query for prometheus metric {}: {:?}, {:?}",
157 query.metric_name, result, cluster
158 );
159 return;
160 };
161
162 let gauge_vec = metrics_by_name
163 .entry(query.metric_name.to_string())
164 .or_insert_with(|| {
165 let mut label_names: Vec<String> = desc
166 .columns
167 .iter()
168 .filter(|col| col.name != query.value_column_name)
169 .map(|col| col.name.clone())
170 .collect();
171
172 if query.per_replica {
173 label_names.extend(PER_REPLICA_LABELS.iter().map(|label| label.to_string()));
174 }
175
176 metrics_registry.register::<GenericGaugeVec<AtomicF64>>(MakeCollectorOpts {
177 opts: Opts::new(query.metric_name, query.help).variable_labels(label_names),
178 buckets: None,
179 })
180 });
181
182 for row in rows {
183 let mut label_values = desc
184 .columns
185 .iter()
186 .zip_eq(row)
187 .filter(|(col, _)| col.name != query.value_column_name)
188 .map(|(_, val)| val.as_str().expect("must be string"))
189 .collect::<Vec<_>>();
190
191 let value = desc
192 .columns
193 .iter()
194 .zip_eq(row)
195 .find(|(col, _)| col.name == query.value_column_name)
196 .map(|(_, val)| val.as_str().unwrap_or("0").parse::<f64>().unwrap_or(0.0))
197 .unwrap_or(0.0);
198
199 match cluster {
200 Some((cluster, replica)) => {
201 let replica_full_name = format!("{}.{}", cluster.name, replica.name);
202 let cluster_id = cluster.id.to_string();
203 let replica_id = replica.replica_id.to_string();
204
205 label_values.push(&replica_full_name);
206 label_values.push(&cluster_id);
207 label_values.push(&replica_id);
208
209 gauge_vec
210 .get_metric_with_label_values(&label_values)
211 .expect("valid labels")
212 .set(value);
213 }
214 None => {
215 gauge_vec
216 .get_metric_with_label_values(&label_values)
217 .expect("valid labels")
218 .set(value);
219 }
220 }
221 }
222}
223
224async fn handle_promsql_query(
225 client: &mut AuthedClient,
226 query: &PrometheusSqlQuery<'_>,
227 metrics_registry: &MetricsRegistry,
228 metrics_by_name: &mut BTreeMap<String, GenericGaugeVec<AtomicF64>>,
229) {
230 if !query.per_replica {
231 execute_promsql_query(client, query, metrics_registry, metrics_by_name, None).await;
232 return;
233 }
234
235 let catalog = client.client.catalog_snapshot("handle_promsql_query").await;
236 let clusters: Vec<&Cluster> = catalog.clusters().collect();
237
238 for cluster in clusters {
239 for replica in cluster.replicas() {
240 execute_promsql_query(
241 client,
242 query,
243 metrics_registry,
244 metrics_by_name,
245 Some((cluster, replica)),
246 )
247 .await;
248 }
249 }
250}
251
252pub async fn handle_promsql(
253 mut client: AuthedClient,
254 queries: &[PrometheusSqlQuery<'_>],
255) -> MetricsRegistry {
256 let metrics_registry = MetricsRegistry::new();
257 let mut metrics_by_name = BTreeMap::new();
258
259 for query in queries {
260 handle_promsql_query(&mut client, query, &metrics_registry, &mut metrics_by_name).await;
261 }
262
263 metrics_registry
264}
265
266pub async fn handle_sql(
267 mut client: AuthedClient,
268 Json(request): Json<SqlRequest>,
269) -> impl IntoResponse {
270 let mut res = SqlResponse {
271 results: Vec::new(),
272 };
273 match execute_request(&mut client, request, &mut res).await {
276 Ok(()) => Ok(Json(res)),
277 Err(e) => Err((StatusCode::BAD_REQUEST, e.to_string())),
278 }
279}
280
/// A user that was already authenticated before the WebSocket upgrade, and
/// how that authentication was obtained.
#[derive(Debug)]
pub enum ExistingUser {
    /// Provided as an `AuthedUser` request extension. Presumably set by
    /// middleware from the `x-materialize-user` header; confirm there.
    XMaterializeUserHeader(AuthedUser),
    /// Recovered from an unexpired, authenticated HTTP session.
    Session(AuthedUser),
}
290
291pub(crate) async fn handle_sql_ws(
292 State(state): State<WsState>,
293 existing_user: Option<Extension<AuthedUser>>,
294 ws: WebSocketUpgrade,
295 ConnectInfo(addr): ConnectInfo<SocketAddr>,
296 tower_session: Option<Extension<TowerSession>>,
297) -> Result<impl IntoResponse, AuthError> {
298 let session = tower_session.map(|Extension(session)| session);
299 let user = match existing_user {
301 Some(Extension(user)) => Some(ExistingUser::XMaterializeUserHeader(user)),
302 None => {
303 let session = maybe_get_authenticated_session(session.as_ref()).await;
304 if let Some((session, session_data)) = session {
305 let user = ensure_session_unexpired(session, session_data).await?;
306 Some(ExistingUser::Session(user))
307 } else {
308 None
309 }
310 }
311 };
312
313 let addr = Box::new(addr.ip());
314 Ok(ws
315 .max_message_size(MAX_REQUEST_SIZE)
316 .on_upgrade(|ws| async move { run_ws(state, user, *addr, ws).await }))
317}
318
/// Authentication payload for the WebSocket SQL endpoint.
///
/// NOTE(review): presumably this is the client's first message, consumed by
/// `init_ws`; confirm there.
///
/// The enum is `untagged`, so serde tries the variants in declaration order.
/// `Basic` and `Bearer` must stay ahead of `OptionsOnly`, which matches any
/// object because its only field has a default.
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
#[serde(untagged)]
pub enum WebSocketAuth {
    /// Username and password credentials.
    Basic {
        user: String,
        password: Password,
        /// Session options. Empty when omitted.
        #[serde(default)]
        options: BTreeMap<String, String>,
    },
    /// A bearer token.
    Bearer {
        token: String,
        /// Session options. Empty when omitted.
        #[serde(default)]
        options: BTreeMap<String, String>,
    },
    /// No credentials, only session options.
    OptionsOnly {
        /// Session options. Empty when omitted.
        #[serde(default)]
        options: BTreeMap<String, String>,
    },
}
338
/// Drives one SQL-over-WebSocket connection until it closes.
///
/// If `init_ws` fails, a close frame with a protocol-error code is sent and
/// the function returns.
///
/// On success the server sends, in order:
/// - the session's notify-set parameters;
/// - backend key data;
/// - an initial `ReadyForQuery`;
/// - any pending notices.
///
/// It then loops over incoming messages. For each request it executes the
/// request, then replies with any request-level error, the pending notices
/// and a fresh `ReadyForQuery`.
async fn run_ws(state: WsState, user: Option<ExistingUser>, peer_addr: IpAddr, mut ws: WebSocket) {
    let mut client = match init_ws(state, user, peer_addr, &mut ws).await {
        Ok(client) => client,
        Err(e) => {
            debug!("WS request failed init: {}", e);
            // Surface adapter errors verbatim. Report everything else as a
            // generic authorization failure.
            let reason: Utf8Bytes = match e.downcast_ref::<AdapterError>() {
                Some(error) => error.to_string().into(),
                None => "unauthorized".to_string().into(),
            };
            let _ = ws
                .send(Message::Close(Some(CloseFrame {
                    code: CloseCode::Protocol.into(),
                    reason,
                })))
                .await;
            return;
        }
    };

    // Startup messages describing the freshly initialized session.
    let mut msgs = Vec::new();
    let session = client.client.session();
    for var in session.vars().notify_set() {
        msgs.push(WebSocketResponse::ParameterStatus(ParameterStatus {
            name: var.name().to_string(),
            value: var.value(),
        }));
    }
    msgs.push(WebSocketResponse::BackendKeyData(BackendKeyData {
        conn_id: session.conn_id().unhandled(),
        secret_key: session.secret_key(),
    }));
    msgs.push(WebSocketResponse::ReadyForQuery(
        session.transaction_code().into(),
    ));
    // Send failures are ignored here. The notice forwarding and the receive
    // loop below stop on a broken socket.
    for msg in msgs {
        let _ = ws
            .send(Message::Text(
                serde_json::to_string(&msg).expect("must serialize").into(),
            ))
            .await;
    }

    let notices = session.drain_notices();
    if let Err(err) = forward_notices(&mut ws, notices).await {
        debug!("failed to forward notices to WebSocket, {err:?}");
        return;
    }

    loop {
        let msg = select! {
            // `biased` polls branches in order, so a fired session timeout
            // takes precedence over an incoming message.
            biased;

            // Disabled whenever `recv_timeout` yields `None`.
            Some(timeout) = client.client.recv_timeout() => {
                client.client.terminate().await;
                // NOTE(review): this waits for one more client message (or
                // close) before reporting the timeout. The reason is not
                // evident from this code; confirm intent.
                let _ = ws.recv().await;
                let err = Error::from(AdapterError::from(timeout));
                let _ = send_ws_response(&mut ws, WebSocketResponse::Error(err.into())).await;
                return;
            },
            message = ws.recv() => message,
        };

        // A message arrived, so clear any armed idle-in-transaction timeout.
        client.client.remove_idle_in_transaction_session_timeout();

        let msg = match msg {
            Some(Ok(msg)) => msg,
            _ => {
                // The stream ended or errored; the connection is gone.
                return;
            }
        };

        let req: Result<SqlRequest, Error> = match msg {
            Message::Text(data) => serde_json::from_str(&data).err_into(),
            Message::Binary(data) => serde_json::from_slice(&data).err_into(),
            Message::Ping(_) => {
                // Keepalive frames carry no request.
                continue;
            }
            Message::Pong(_) => {
                continue;
            }
            Message::Close(_) => {
                return;
            }
        };

        // Request-level failures (including undecodable requests) are reported
        // to the client instead of closing the connection.
        let err = match run_ws_request(req, &mut client, &mut ws).await {
            Ok(()) => None,
            Err(err) => Some(WebSocketResponse::Error(err.into())),
        };

        let ws_response = || async {
            if let Some(e_resp) = err {
                send_ws_response(&mut ws, e_resp).await?;
            }

            let notices = client.client.session().drain_notices();
            forward_notices(&mut ws, notices).await?;

            // Tell the client the server is ready for its next request.
            let ready =
                WebSocketResponse::ReadyForQuery(client.client.session().transaction_code().into());
            send_ws_response(&mut ws, ready).await?;

            Ok::<_, Error>(())
        };

        if let Err(err) = ws_response().await {
            debug!("failed to send response over WebSocket, {err:?}");
            return;
        }
    }
}
472
473async fn run_ws_request(
474 req: Result<SqlRequest, Error>,
475 client: &mut AuthedClient,
476 ws: &mut WebSocket,
477) -> Result<(), Error> {
478 let req = req?;
479 execute_request(client, req, ws).await
480}
481
482async fn send_ws_response(ws: &mut WebSocket, resp: WebSocketResponse) -> Result<(), Error> {
484 let msg = serde_json::to_string(&resp).unwrap();
485 let msg = Message::Text(msg.into());
486 ws.send(msg).await?;
487
488 Ok(())
489}
490
491async fn forward_notices(
493 ws: &mut WebSocket,
494 notices: impl IntoIterator<Item = AdapterNotice>,
495) -> Result<(), Error> {
496 let ws_notices = notices.into_iter().map(|notice| {
497 WebSocketResponse::Notice(Notice {
498 message: notice.to_string(),
499 code: notice.code().code().to_string(),
500 severity: notice.severity().as_str().to_lowercase(),
501 detail: notice.detail(),
502 hint: notice.hint(),
503 })
504 });
505
506 for notice in ws_notices {
507 send_ws_response(ws, notice).await?;
508 }
509
510 Ok(())
511}
512
/// A SQL request submitted over HTTP or WebSocket.
///
/// The enum is `untagged`. An object with a `query` field is a `Simple`
/// request, and one with a `queries` field is an `Extended` request.
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum SqlRequest {
    /// One query string that may contain several statements.
    /// Parameters are not supported.
    Simple {
        query: Sql,
    },
    /// A list of queries. Each must contain exactly one statement and may
    /// carry parameters.
    Extended {
        queries: Vec<ExtendedRequest>,
    },
}
529
/// One query of an [`SqlRequest::Extended`] request.
#[derive(Serialize, Deserialize, Debug)]
pub struct ExtendedRequest {
    /// Query text. It must contain exactly one statement.
    query: String,
    /// Text-encoded parameter values. Empty when omitted.
    /// NOTE(review): `None` presumably binds SQL `NULL`; confirm in `execute_stmt`.
    #[serde(default)]
    params: Vec<Option<String>>,
}
539
/// The buffered results of an HTTP SQL request.
#[derive(Debug, Serialize, Deserialize)]
pub struct SqlResponse {
    /// Results in the order they were produced.
    pub(in crate::http) results: Vec<SqlResult>,
}
546
547impl SqlResponse {
548 pub(in crate::http) fn new() -> Self {
550 Self {
551 results: Vec::new(),
552 }
553 }
554}
555
/// The outcome of executing one statement, before it is handed to a
/// [`ResultSender`].
pub(in crate::http) enum StatementResult {
    /// A complete, already-materialized result.
    SqlResult(SqlResult),
    /// A running `SUBSCRIBE` whose rows are read from `rx` as they arrive.
    Subscribe {
        desc: RelationDesc,
        tag: String,
        rx: RecordFirstRowStream,
        /// Statement-logging guard, retired through `retire_execute` once the
        /// subscription finishes or is abandoned.
        ctx_extra: ExecuteContextGuard,
    },
}
565
566impl From<SqlResult> for StatementResult {
567 fn from(inner: SqlResult) -> Self {
568 Self::SqlResult(inner)
569 }
570}
571
572#[derive(Debug, Serialize, Deserialize)]
574#[serde(untagged)]
575pub enum SqlResult {
576 Rows {
578 tag: String,
580 rows: Vec<Vec<serde_json::Value>>,
582 desc: Description,
584 notices: Vec<Notice>,
586 },
587 Ok {
589 ok: String,
591 notices: Vec<Notice>,
593 #[serde(skip_serializing_if = "Vec::is_empty")]
597 parameters: Vec<ParameterStatus>,
598 },
599 Err {
601 error: SqlError,
602 notices: Vec<Notice>,
604 },
605}
606
impl SqlResult {
    /// Drains `rows_stream` into an [`SqlResult::Rows`] tagged `SELECT <n>`.
    ///
    /// Each batch is checked against `desc`, and each row is converted to
    /// JSON with numbers rendered as strings. An [`SqlResult::Err`] is
    /// returned instead (still inside `Ok`) if any of these happens:
    /// - the stream reports an error, a dropped dependency or cancellation;
    /// - a batch does not match `desc`;
    /// - the cumulative row size exceeds `max_query_result_size` bytes.
    ///
    /// While waiting for rows, notices are forwarded through `sender` when it
    /// supports streaming them.
    ///
    /// # Errors
    ///
    /// Returns `Err` if `sender` reports a connection error or fails to emit
    /// a streamed notice.
    async fn rows<S>(
        sender: &mut S,
        client: &mut SessionClient,
        mut rows_stream: RecordFirstRowStream,
        max_query_result_size: usize,
        desc: &RelationDesc,
    ) -> Result<SqlResult, Error>
    where
        S: ResultSender,
    {
        let mut rows: Vec<Vec<serde_json::Value>> = vec![];
        let mut datum_vec = mz_repr::DatumVec::new();
        let types = &desc.typ().column_types;

        // Running total of `byte_len()` over all rows seen so far.
        let mut query_result_size = 0;

        loop {
            let peek_response = tokio::select! {
                notice = client.session().recv_notice(), if S::SUPPORTS_STREAMING_NOTICES => {
                    sender.emit_streaming_notices(vec![notice]).await?;
                    continue;
                }
                e = sender.connection_error() => return Err(e),
                r = rows_stream.recv() => {
                    match r {
                        Some(r) => r,
                        // The stream is exhausted: all rows have been received.
                        None => break,
                    }
                },
            };

            let mut sql_rows = match peek_response {
                PeekResponseUnary::Rows(rows) => rows,
                PeekResponseUnary::Error(e) => {
                    return Ok(SqlResult::err(client, Error::Unstructured(anyhow!(e))));
                }
                PeekResponseUnary::DependencyDropped(dep) => {
                    return Ok(SqlResult::err(client, dep.to_concurrent_dependency_drop()));
                }
                PeekResponseUnary::Canceled => {
                    return Ok(SqlResult::err(client, AdapterError::Canceled));
                }
            };

            if let Err(err) = verify_datum_desc(desc, &mut sql_rows) {
                return Ok(SqlResult::Err {
                    error: err.into(),
                    notices: make_notices(client),
                });
            }

            while let Some(row) = sql_rows.next() {
                query_result_size += row.byte_len();
                if query_result_size > max_query_result_size {
                    use bytesize::ByteSize;
                    return Ok(SqlResult::err(
                        client,
                        AdapterError::ResultSize(format!(
                            "result exceeds max size of {}",
                            ByteSize::b(u64::cast_from(max_query_result_size))
                        )),
                    ));
                }

                let datums = datum_vec.borrow_with(row);
                rows.push(
                    datums
                        .iter()
                        .enumerate()
                        .map(|(i, d)| {
                            TypedDatum::new(*d, &types[i])
                                .json(&JsonNumberPolicy::ConvertNumberToString)
                        })
                        .collect(),
                );
            }
        }

        let tag = format!("SELECT {}", rows.len());
        Ok(SqlResult::Rows {
            tag,
            rows,
            desc: Description::from(desc),
            notices: make_notices(client),
        })
    }

    /// Builds an [`SqlResult::Err`] from `error`, attaching the notices
    /// returned by `make_notices`.
    fn err(client: &mut SessionClient, error: impl Into<SqlError>) -> SqlResult {
        SqlResult::Err {
            error: error.into(),
            notices: make_notices(client),
        }
    }

    /// Builds an [`SqlResult::Ok`] with command tag `tag` and parameter
    /// statuses `params`, attaching the notices returned by `make_notices`.
    fn ok(client: &mut SessionClient, tag: String, params: Vec<ParameterStatus>) -> SqlResult {
        SqlResult::Ok {
            ok: tag,
            parameters: params,
            notices: make_notices(client),
        }
    }
}
713
/// An error as reported to SQL clients.
#[derive(Debug, Deserialize, Serialize)]
pub struct SqlError {
    /// Human-readable error message.
    pub message: String,
    /// SQLSTATE code, e.g. `XX000` for an internal error.
    pub code: String,
    /// Optional detail text. Omitted from the JSON when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
    /// Optional hint text. Omitted from the JSON when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub hint: Option<String>,
    /// Position in the query text the error refers to, as reported by the
    /// adapter. Omitted from the JSON when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub position: Option<usize>,
}
725
726impl From<Error> for SqlError {
727 fn from(err: Error) -> Self {
728 SqlError {
729 message: err.to_string(),
730 code: err.code().code().to_string(),
731 detail: err.detail(),
732 hint: err.hint(),
733 position: err.position(),
734 }
735 }
736}
737
738impl From<AdapterError> for SqlError {
739 fn from(value: AdapterError) -> Self {
740 Error::from(value).into()
741 }
742}
743
/// A message sent from the server to a WebSocket client.
///
/// Serialized adjacently tagged, as `{"type": <variant>, "payload": <data>}`.
#[derive(Debug, Deserialize, Serialize)]
#[serde(tag = "type", content = "payload")]
pub enum WebSocketResponse {
    /// The server is ready for the next request. The payload is the session's
    /// transaction status code.
    ReadyForQuery(String),
    /// A notice raised during execution.
    Notice(Notice),
    /// Column descriptions for the `Row` messages that follow.
    Rows(Description),
    /// One row of values.
    Row(Vec<serde_json::Value>),
    /// Announces the start of a statement's result.
    CommandStarting(CommandStarting),
    /// A statement finished. The payload is its command tag.
    CommandComplete(String),
    /// A statement or request failed.
    Error(SqlError),
    /// A session parameter's current value.
    ParameterStatus(ParameterStatus),
    /// Connection identification data, sent at connection startup.
    BackendKeyData(BackendKeyData),
}
757
/// A non-error message raised while executing SQL.
#[derive(Debug, Serialize, Deserialize)]
pub struct Notice {
    /// Human-readable notice text.
    message: String,
    /// SQLSTATE code.
    code: String,
    /// Severity name, lowercased.
    severity: String,
    /// Optional detail text. Omitted from the JSON when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
    /// Optional hint text. Omitted from the JSON when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub hint: Option<String>,
}
768
769impl Notice {
770 pub fn message(&self) -> &str {
771 &self.message
772 }
773}
774
/// Column metadata for a row-returning result.
#[derive(Debug, Serialize, Deserialize)]
pub struct Description {
    /// One entry per result column, in column order.
    pub columns: Vec<Column>,
}
779
780impl From<&RelationDesc> for Description {
781 fn from(desc: &RelationDesc) -> Self {
782 let columns = desc
783 .iter()
784 .map(|(name, typ)| {
785 let pg_type = mz_pgrepr::Type::from(&typ.scalar_type);
786 Column {
787 name: name.to_string(),
788 type_oid: pg_type.oid(),
789 type_len: pg_type.typlen(),
790 type_mod: pg_type.typmod(),
791 }
792 })
793 .collect();
794 Description { columns }
795 }
796}
797
/// Metadata for one result column.
#[derive(Debug, Serialize, Deserialize)]
pub struct Column {
    /// Column name.
    pub name: String,
    /// PostgreSQL type OID.
    pub type_oid: u32,
    /// PostgreSQL `typlen` of the column type.
    pub type_len: i16,
    /// PostgreSQL type modifier (`typmod`).
    pub type_mod: i32,
}
805
/// A session parameter's name and current value.
#[derive(Debug, Serialize, Deserialize)]
pub struct ParameterStatus {
    name: String,
    value: String,
}
811
/// Identifies the connection to the client.
///
/// NOTE(review): presumably used for cancellation requests, as with the
/// PostgreSQL wire protocol's BackendKeyData; confirm.
#[derive(Debug, Serialize, Deserialize)]
pub struct BackendKeyData {
    /// Connection id.
    conn_id: u32,
    /// Secret key paired with `conn_id`.
    secret_key: u32,
}
817
/// Sent before each statement's result to describe what follows.
#[derive(Debug, Serialize, Deserialize)]
pub struct CommandStarting {
    /// Whether `Rows`/`Row` messages follow.
    has_rows: bool,
    /// Whether rows are streamed incrementally (`SUBSCRIBE`).
    is_streaming: bool,
}
823
/// A destination for statement results.
///
/// [`SqlResponse`] buffers results for a single HTTP response.
/// [`WebSocket`] streams them to the client.
#[async_trait]
pub(in crate::http) trait ResultSender: Send {
    /// Whether notices may be delivered while a statement is still running,
    /// via [`ResultSender::emit_streaming_notices`].
    const SUPPORTS_STREAMING_NOTICES: bool = false;

    /// Delivers the result of one statement.
    ///
    /// Returns a pair:
    /// - The delivery outcome. This is `Err(_)` if the result could not be
    ///   delivered (e.g. the connection failed). Otherwise it is
    ///   `Ok(Err(()))` if the statement errored and `Ok(Ok(()))` if it
    ///   succeeded.
    /// - Optionally, the reason and guard with which the caller must retire
    ///   the statement's execution for statement logging.
    async fn add_result(
        &mut self,
        client: &mut SessionClient,
        res: StatementResult,
    ) -> (
        Result<Result<(), ()>, Error>,
        Option<(StatementEndedExecutionReason, ExecuteContextGuard)>,
    );

    /// Returns a future that resolves only if the underlying connection
    /// fails, yielding the error. Callers race it against long-running work.
    fn connection_error(&mut self) -> BoxFuture<'_, Error>;

    /// Whether `SUBSCRIBE` may be executed through this sender.
    fn allow_subscribe(&self) -> bool;

    /// Delivers notices raised mid-statement.
    ///
    /// Only called when [`ResultSender::SUPPORTS_STREAMING_NOTICES`] is
    /// `true`. The default implementation panics.
    async fn emit_streaming_notices(&mut self, _: Vec<AdapterNotice>) -> Result<(), Error> {
        unreachable!("streaming notices marked as unsupported")
    }
}
859
860#[async_trait]
861impl ResultSender for SqlResponse {
862 async fn add_result(
870 &mut self,
871 _client: &mut SessionClient,
872 res: StatementResult,
873 ) -> (
874 Result<Result<(), ()>, Error>,
875 Option<(StatementEndedExecutionReason, ExecuteContextGuard)>,
876 ) {
877 let (res, stmt_logging) = match res {
878 StatementResult::SqlResult(res) => {
879 let is_err = matches!(res, SqlResult::Err { .. });
880 self.results.push(res);
881 let res = if is_err { Err(()) } else { Ok(()) };
882 (res, None)
883 }
884 StatementResult::Subscribe { ctx_extra, .. } => {
885 let message = "SUBSCRIBE only supported over websocket";
886 self.results.push(SqlResult::Err {
887 error: Error::SubscribeOnlyOverWs.into(),
888 notices: Vec::new(),
889 });
890 (
891 Err(()),
892 Some((
893 StatementEndedExecutionReason::Errored {
894 error: message.into(),
895 },
896 ctx_extra,
897 )),
898 )
899 }
900 };
901 (Ok(res), stmt_logging)
902 }
903
904 fn connection_error(&mut self) -> BoxFuture<'_, Error> {
905 Box::pin(futures::future::pending())
906 }
907
908 fn allow_subscribe(&self) -> bool {
909 false
910 }
911}
912
913#[async_trait]
914impl ResultSender for WebSocket {
915 const SUPPORTS_STREAMING_NOTICES: bool = true;
916
917 async fn add_result(
923 &mut self,
924 client: &mut SessionClient,
925 res: StatementResult,
926 ) -> (
927 Result<Result<(), ()>, Error>,
928 Option<(StatementEndedExecutionReason, ExecuteContextGuard)>,
929 ) {
930 let (has_rows, is_streaming) = match res {
931 StatementResult::SqlResult(SqlResult::Err { .. }) => (false, false),
932 StatementResult::SqlResult(SqlResult::Ok { .. }) => (false, false),
933 StatementResult::SqlResult(SqlResult::Rows { .. }) => (true, false),
934 StatementResult::Subscribe { .. } => (true, true),
935 };
936 if let Err(e) = send_ws_response(
937 self,
938 WebSocketResponse::CommandStarting(CommandStarting {
939 has_rows,
940 is_streaming,
941 }),
942 )
943 .await
944 {
945 return (Err(e), None);
946 }
947
948 let (is_err, msgs, stmt_logging) = match res {
949 StatementResult::SqlResult(SqlResult::Rows {
950 tag,
951 rows,
952 desc,
953 notices,
954 }) => {
955 if let Err(e) = send_ws_response(self, WebSocketResponse::Rows(desc)).await {
958 return (Err(e), None);
959 }
960 for row in rows {
961 if let Err(e) = send_ws_response(self, WebSocketResponse::Row(row)).await {
962 return (Err(e), None);
963 }
964 }
965 let mut msgs = vec![WebSocketResponse::CommandComplete(tag)];
966 msgs.extend(notices.into_iter().map(WebSocketResponse::Notice));
967 (false, msgs, None)
968 }
969 StatementResult::SqlResult(SqlResult::Ok {
970 ok,
971 parameters,
972 notices,
973 }) => {
974 let mut msgs = vec![WebSocketResponse::CommandComplete(ok)];
975 msgs.extend(notices.into_iter().map(WebSocketResponse::Notice));
976 msgs.extend(
977 parameters
978 .into_iter()
979 .map(WebSocketResponse::ParameterStatus),
980 );
981 (false, msgs, None)
982 }
983 StatementResult::SqlResult(SqlResult::Err { error, notices }) => {
984 let mut msgs = vec![WebSocketResponse::Error(error)];
985 msgs.extend(notices.into_iter().map(WebSocketResponse::Notice));
986 (true, msgs, None)
987 }
988 StatementResult::Subscribe {
989 ref desc,
990 tag,
991 mut rx,
992 ctx_extra,
993 } => {
994 if let Err(e) = send_ws_response(self, WebSocketResponse::Rows(desc.into())).await {
995 return (
998 Err(e),
999 Some((StatementEndedExecutionReason::Canceled, ctx_extra)),
1000 );
1001 }
1002
1003 let mut datum_vec = mz_repr::DatumVec::new();
1004 let mut result_size: usize = 0;
1005 let mut rows_returned = 0;
1006 loop {
1007 let res = match await_rows(self, client, rx.recv()).await {
1008 Ok(res) => res,
1009 Err(e) => {
1010 return (
1013 Err(e),
1014 Some((StatementEndedExecutionReason::Canceled, ctx_extra)),
1015 );
1016 }
1017 };
1018 match res {
1019 Some(PeekResponseUnary::Rows(mut rows)) => {
1020 if let Err(err) = verify_datum_desc(desc, &mut rows) {
1021 let error = err.to_string();
1022 break (
1023 true,
1024 vec![WebSocketResponse::Error(err.into())],
1025 Some((
1026 StatementEndedExecutionReason::Errored { error },
1027 ctx_extra,
1028 )),
1029 );
1030 }
1031
1032 rows_returned += rows.count();
1033 while let Some(row) = rows.next() {
1034 result_size += row.byte_len();
1035 let datums = datum_vec.borrow_with(row);
1036 let types = &desc.typ().column_types;
1037 if let Err(e) = send_ws_response(
1038 self,
1039 WebSocketResponse::Row(
1040 datums
1041 .iter()
1042 .enumerate()
1043 .map(|(i, d)| {
1044 TypedDatum::new(*d, &types[i])
1045 .json(&JsonNumberPolicy::ConvertNumberToString)
1046 })
1047 .collect(),
1048 ),
1049 )
1050 .await
1051 {
1052 return (
1055 Err(e),
1056 Some((StatementEndedExecutionReason::Canceled, ctx_extra)),
1057 );
1058 }
1059 }
1060 }
1061 Some(PeekResponseUnary::Error(error)) => {
1062 break (
1063 true,
1064 vec![WebSocketResponse::Error(
1065 Error::Unstructured(anyhow!(error.clone())).into(),
1066 )],
1067 Some((StatementEndedExecutionReason::Errored { error }, ctx_extra)),
1068 );
1069 }
1070 Some(PeekResponseUnary::DependencyDropped(dep)) => {
1071 let err = dep.to_concurrent_dependency_drop();
1072 let error = err.to_string();
1073 break (
1074 true,
1075 vec![WebSocketResponse::Error(err.into())],
1076 Some((StatementEndedExecutionReason::Errored { error }, ctx_extra)),
1077 );
1078 }
1079 Some(PeekResponseUnary::Canceled) => {
1080 break (
1081 true,
1082 vec![WebSocketResponse::Error(AdapterError::Canceled.into())],
1083 Some((StatementEndedExecutionReason::Canceled, ctx_extra)),
1084 );
1085 }
1086 None => {
1087 break (
1088 false,
1089 vec![WebSocketResponse::CommandComplete(tag)],
1090 Some((
1091 StatementEndedExecutionReason::Success {
1092 result_size: Some(u64::cast_from(result_size)),
1093 rows_returned: Some(u64::cast_from(rows_returned)),
1094 execution_strategy: Some(
1095 StatementExecutionStrategy::Standard,
1096 ),
1097 },
1098 ctx_extra,
1099 )),
1100 );
1101 }
1102 }
1103 }
1104 }
1105 };
1106 for msg in msgs {
1107 if let Err(e) = send_ws_response(self, msg).await {
1108 return (
1109 Err(e),
1110 stmt_logging.map(|(_old_reason, ctx_extra)| {
1111 (StatementEndedExecutionReason::Canceled, ctx_extra)
1112 }),
1113 );
1114 }
1115 }
1116 (Ok(if is_err { Err(()) } else { Ok(()) }), stmt_logging)
1117 }
1118
1119 fn connection_error(&mut self) -> BoxFuture<'_, Error> {
1122 Box::pin(async {
1123 let mut tick = time::interval(Duration::from_secs(1));
1124 tick.tick().await;
1125 loop {
1126 tick.tick().await;
1127 if let Err(err) = self.send(Message::Ping(Vec::new().into())).await {
1128 return err.into();
1129 }
1130 }
1131 })
1132 }
1133
1134 fn allow_subscribe(&self) -> bool {
1135 true
1136 }
1137
1138 async fn emit_streaming_notices(&mut self, notices: Vec<AdapterNotice>) -> Result<(), Error> {
1139 forward_notices(self, notices).await
1140 }
1141}
1142
1143async fn await_rows<S, F, R>(sender: &mut S, client: &mut SessionClient, f: F) -> Result<R, Error>
1144where
1145 S: ResultSender,
1146 F: Future<Output = R> + Send,
1147{
1148 let mut f = pin!(f);
1149 loop {
1150 tokio::select! {
1151 notice = client.session().recv_notice(), if S::SUPPORTS_STREAMING_NOTICES => {
1152 sender.emit_streaming_notices(vec![notice]).await?;
1153 }
1154 e = sender.connection_error() => return Err(e),
1155 r = &mut f => return Ok(r),
1156 }
1157 }
1158}
1159
1160async fn send_and_retire<S: ResultSender>(
1161 res: StatementResult,
1162 client: &mut SessionClient,
1163 sender: &mut S,
1164) -> Result<Result<(), ()>, Error> {
1165 let (res, stmt_logging) = sender.add_result(client, res).await;
1166 if let Some((reason, ctx_extra)) = stmt_logging {
1167 client.retire_execute(ctx_extra, reason);
1168 }
1169 res
1170}
1171
/// Executes a group of statements that start a transaction together.
///
/// Simple requests yield one group containing all of their statements.
/// Extended requests yield one single-statement group per query.
///
/// Execution stops at the first error. On error, the transaction is handled
/// according to its state:
/// - implicit or just-started transactions are rolled back;
/// - explicit transactions are marked failed;
/// - otherwise it is left as is.
///
/// Returns `Ok(Err(()))` if a statement failed and `Ok(Ok(()))` otherwise.
/// The outer `Err` means results could not be delivered.
async fn execute_stmt_group<S: ResultSender>(
    client: &mut SessionClient,
    sender: &mut S,
    stmt_group: Vec<(Statement<Raw>, String, Vec<Option<String>>)>,
) -> Result<Result<(), ()>, Error> {
    let num_stmts = stmt_group.len();
    for (stmt, sql, params) in stmt_group {
        assert!(
            num_stmts <= 1 || params.is_empty(),
            "statement groups contain more than 1 statement iff Simple request, which does not support parameters"
        );

        // Inside a failed transaction, only statements that exit the
        // transaction may run.
        let is_aborted_txn = matches!(client.session().transaction(), TransactionStatus::Failed(_));
        if is_aborted_txn && !is_txn_exit_stmt(&stmt) {
            let err = SqlResult::err(client, Error::AbortedTransaction);
            let _ = send_and_retire(err.into(), client, sender).await?;
            return Ok(Err(()));
        }

        // The group size is passed along, presumably so the session can tell
        // single-statement from multi-statement implicit transactions.
        if let Err(e) = client.start_transaction(Some(num_stmts)) {
            let err = SqlResult::err(client, e);
            let _ = send_and_retire(err.into(), client, sender).await?;
            return Ok(Err(()));
        }
        let res = execute_stmt(client, sender, stmt, sql, params).await?;
        let is_err = send_and_retire(res, client, sender).await?;

        if is_err.is_err() {
            let txn = client.session().transaction();
            match txn {
                // No transaction to clean up, or it is already failed.
                TransactionStatus::Default | TransactionStatus::Failed(_) => {}
                // Implicit (or just-started) transactions are rolled back.
                TransactionStatus::Started(_) | TransactionStatus::InTransactionImplicit(_) => {
                    if let Err(err) = client.end_transaction(EndTransactionAction::Rollback).await {
                        let err = SqlResult::err(client, err);
                        let _ = send_and_retire(err.into(), client, sender).await?;
                    }
                }
                // Explicit transactions stay failed until the client ends them.
                TransactionStatus::InTransaction(_) => {
                    client.fail_transaction();
                }
            }
            return Ok(Err(()));
        }
    }
    Ok(Ok(()))
}
1227
/// Parses and executes every statement in `request`, delivering each result
/// to `sender`.
///
/// Some statements are rejected before execution with
/// [`Error::Unsupported`], because this API cannot deliver their responses:
/// - cursors and `FETCH`;
/// - `COPY` whose target is not an expression;
/// - `SUBSCRIBE`, when the sender disallows it.
///
/// Parse errors and rejections are delivered as error results, with one
/// exception: an extended query that does not contain exactly one statement
/// fails the whole request with `Err`.
///
/// After each statement group, any still-open implicit transaction is
/// committed. Processing stops after the first group that reports an error.
///
/// # Errors
///
/// Returns `Err` for the malformed extended query described above, or when
/// results cannot be delivered to `sender`.
pub(in crate::http) async fn execute_request<S: ResultSender>(
    client: &mut AuthedClient,
    request: SqlRequest,
    sender: &mut S,
) -> Result<(), Error> {
    let client = &mut client.client;

    /// Rejects `stmt` if any response it may generate cannot be delivered
    /// through `sender`.
    fn check_prohibited_stmts<S: ResultSender>(
        sender: &S,
        stmt: &Statement<Raw>,
    ) -> Result<(), Error> {
        let kind: StatementKind = stmt.into();
        let execute_responses = Plan::generated_from(&kind)
            .into_iter()
            .map(ExecuteResponse::generated_from)
            .flatten()
            .collect::<Vec<_>>();

        // `COPY` to or from an expression target is always allowed.
        let is_valid_copy = matches!(
            stmt,
            Statement::Copy(CopyStatement {
                direction: CopyDirection::To,
                target: CopyTarget::Expr(_),
                ..
            }) | Statement::Copy(CopyStatement {
                direction: CopyDirection::From,
                target: CopyTarget::Expr(_),
                ..
            })
        );

        if !is_valid_copy
            && execute_responses.iter().any(|execute_response| {
                match execute_response {
                    // `SUBSCRIBE` is fine when the sender can stream it.
                    ExecuteResponseKind::Subscribing if sender.allow_subscribe() => false,
                    ExecuteResponseKind::Fetch
                    | ExecuteResponseKind::Subscribing
                    | ExecuteResponseKind::CopyFrom
                    | ExecuteResponseKind::DeclaredCursor
                    | ExecuteResponseKind::ClosedCursor => true,
                    ExecuteResponseKind::CopyTo if matches!(kind, StatementKind::Copy) => true,
                    _ => false,
                }
            })
        {
            return Err(Error::Unsupported(stmt.to_ast_string_simple()));
        }
        Ok(())
    }

    /// Parses `query` into statements.
    ///
    /// The outer error from `SessionClient::parse` becomes
    /// `Error::Unstructured`. The inner parse error is converted through
    /// `AdapterError`.
    fn parse<'a>(
        client: &SessionClient,
        query: &'a str,
    ) -> Result<Vec<StatementParseResult<'a>>, Error> {
        let result = client
            .parse(query)
            .map_err(|e| Error::Unstructured(anyhow!(e)))?;
        result.map_err(|e| AdapterError::from(e).into())
    }

    // Each entry is either a group of statements to execute together, or an
    // error to report in that group's place.
    let mut stmt_groups = vec![];

    match request {
        SqlRequest::Simple { query } => match parse(client, query.as_str()) {
            Ok(stmts) => {
                let mut stmt_group = Vec::with_capacity(stmts.len());
                let mut stmt_err = None;
                // One prohibited statement rejects the whole simple query.
                for StatementParseResult { ast: stmt, sql } in stmts {
                    if let Err(err) = check_prohibited_stmts(sender, &stmt) {
                        stmt_err = Some(err);
                        break;
                    }
                    stmt_group.push((stmt, sql.to_string(), vec![]));
                }
                stmt_groups.push(stmt_err.map(Err).unwrap_or_else(|| Ok(stmt_group)));
            }
            Err(e) => stmt_groups.push(Err(e)),
        },
        SqlRequest::Extended { queries } => {
            for ExtendedRequest { query, params } in queries {
                match parse(client, &query) {
                    Ok(mut stmts) => {
                        if stmts.len() != 1 {
                            return Err(Error::Unstructured(anyhow!(
                                "each query must contain exactly 1 statement, but \"{}\" contains {}",
                                query,
                                stmts.len()
                            )));
                        }

                        let StatementParseResult { ast: stmt, sql } = stmts.pop().unwrap();
                        stmt_groups.push(
                            check_prohibited_stmts(sender, &stmt)
                                .map(|_| vec![(stmt, sql.to_string(), params)]),
                        );
                    }
                    Err(e) => stmt_groups.push(Err(e)),
                };
            }
        }
    }

    for stmt_group_res in stmt_groups {
        let executed = match stmt_group_res {
            Ok(stmt_group) => execute_stmt_group(client, sender, stmt_group).await,
            Err(e) => {
                let err = SqlResult::err(client, e);
                let _ = send_and_retire(err.into(), client, sender).await?;
                Ok(Err(()))
            }
        };
        // Close out any implicit transaction left open by this group. This
        // runs before `executed` is inspected, so it also happens when result
        // delivery failed.
        if client.session().transaction().is_implicit() {
            let ended = client.end_transaction(EndTransactionAction::Commit).await;
            if let Err(err) = ended {
                let err = SqlResult::err(client, err);
                let _ = send_and_retire(StatementResult::SqlResult(err), client, sender).await?;
            }
        }
        if executed?.is_err() {
            break;
        }
    }

    Ok(())
}
1373
1374async fn execute_stmt<S: ResultSender>(
1376 client: &mut SessionClient,
1377 sender: &mut S,
1378 stmt: Statement<Raw>,
1379 sql: String,
1380 raw_params: Vec<Option<String>>,
1381) -> Result<StatementResult, Error> {
1382 const EMPTY_PORTAL: &str = "";
1383 if let Err(e) = client
1384 .prepare(EMPTY_PORTAL.into(), Some(stmt.clone()), sql, vec![])
1385 .await
1386 {
1387 return Ok(SqlResult::err(client, e).into());
1388 }
1389
1390 let prep_stmt = match client.get_prepared_statement(EMPTY_PORTAL).await {
1391 Ok(stmt) => stmt,
1392 Err(err) => {
1393 return Ok(SqlResult::err(client, err).into());
1394 }
1395 };
1396
1397 let param_types = &prep_stmt.desc().param_types;
1398 if param_types.len() != raw_params.len() {
1399 let message = anyhow!(
1400 "request supplied {actual} parameters, \
1401 but {statement} requires {expected}",
1402 statement = stmt.to_ast_string_simple(),
1403 actual = raw_params.len(),
1404 expected = param_types.len()
1405 );
1406 return Ok(SqlResult::err(client, Error::Unstructured(message)).into());
1407 }
1408
1409 let buf = RowArena::new();
1410 let mut params = vec![];
1411 for (raw_param, mz_typ) in raw_params.into_iter().zip_eq(param_types) {
1412 let pg_typ = mz_pgrepr::Type::from(mz_typ);
1413 let datum = match raw_param {
1414 None => Datum::Null,
1415 Some(raw_param) => {
1416 match mz_pgrepr::Value::decode(
1417 mz_pgwire_common::Format::Text,
1418 &pg_typ,
1419 raw_param.as_bytes(),
1420 ) {
1421 Ok(param) => match param.into_datum_decode_error(&buf, &pg_typ, "parameter") {
1422 Ok(datum) => datum,
1423 Err(msg) => {
1424 return Ok(
1425 SqlResult::err(client, Error::Unstructured(anyhow!(msg))).into()
1426 );
1427 }
1428 },
1429 Err(err) => {
1430 let msg = anyhow!("unable to decode parameter: {}", err);
1431 return Ok(SqlResult::err(client, Error::Unstructured(msg)).into());
1432 }
1433 }
1434 }
1435 };
1436 params.push((datum, mz_typ.clone()))
1437 }
1438
1439 let result_formats = vec![
1440 mz_pgwire_common::Format::Text;
1441 prep_stmt
1442 .desc()
1443 .relation_desc
1444 .clone()
1445 .map(|desc| desc.typ().column_types.len())
1446 .unwrap_or(0)
1447 ];
1448
1449 let desc = prep_stmt.desc().clone();
1450 let logging = Arc::clone(prep_stmt.logging());
1451 let stmt_ast = prep_stmt.stmt().cloned();
1452 let state_revision = prep_stmt.state_revision;
1453 if let Err(err) = client.session().set_portal(
1454 EMPTY_PORTAL.into(),
1455 desc,
1456 stmt_ast,
1457 logging,
1458 params,
1459 result_formats,
1460 state_revision,
1461 ) {
1462 return Ok(SqlResult::err(client, err).into());
1463 }
1464
1465 let desc = client
1466 .session()
1467 .get_portal_unverified(EMPTY_PORTAL)
1469 .map(|portal| portal.desc.clone())
1470 .expect("unnamed portal should be present");
1471
1472 let res = client
1473 .execute(EMPTY_PORTAL.into(), futures::future::pending(), None)
1474 .await;
1475
1476 if S::SUPPORTS_STREAMING_NOTICES {
1477 sender
1478 .emit_streaming_notices(client.session().drain_notices())
1479 .await?;
1480 }
1481
1482 let (res, execute_started) = match res {
1483 Ok(res) => res,
1484 Err(e) => {
1485 return Ok(SqlResult::err(client, e).into());
1486 }
1487 };
1488 let tag = res.tag();
1489
1490 Ok(match res {
1491 ExecuteResponse::CreatedConnection { .. }
1492 | ExecuteResponse::CreatedDatabase { .. }
1493 | ExecuteResponse::CreatedSchema { .. }
1494 | ExecuteResponse::CreatedRole
1495 | ExecuteResponse::CreatedCluster { .. }
1496 | ExecuteResponse::CreatedClusterReplica { .. }
1497 | ExecuteResponse::CreatedTable { .. }
1498 | ExecuteResponse::CreatedIndex { .. }
1499 | ExecuteResponse::CreatedIntrospectionSubscribe
1500 | ExecuteResponse::CreatedSecret { .. }
1501 | ExecuteResponse::CreatedSource { .. }
1502 | ExecuteResponse::CreatedSink { .. }
1503 | ExecuteResponse::CreatedView { .. }
1504 | ExecuteResponse::CreatedViews { .. }
1505 | ExecuteResponse::CreatedMaterializedView { .. }
1506 | ExecuteResponse::CreatedType
1507 | ExecuteResponse::CreatedNetworkPolicy
1508 | ExecuteResponse::Comment
1509 | ExecuteResponse::Deleted(_)
1510 | ExecuteResponse::DiscardedTemp
1511 | ExecuteResponse::DiscardedAll
1512 | ExecuteResponse::DroppedObject(_)
1513 | ExecuteResponse::DroppedOwned
1514 | ExecuteResponse::EmptyQuery
1515 | ExecuteResponse::GrantedPrivilege
1516 | ExecuteResponse::GrantedRole
1517 | ExecuteResponse::Inserted(_)
1518 | ExecuteResponse::Copied(_)
1519 | ExecuteResponse::Raised
1520 | ExecuteResponse::ReassignOwned
1521 | ExecuteResponse::RevokedPrivilege
1522 | ExecuteResponse::AlteredDefaultPrivileges
1523 | ExecuteResponse::RevokedRole
1524 | ExecuteResponse::StartedTransaction { .. }
1525 | ExecuteResponse::Updated(_)
1526 | ExecuteResponse::AlteredObject(_)
1527 | ExecuteResponse::AlteredRole
1528 | ExecuteResponse::AlteredSystemConfiguration
1529 | ExecuteResponse::Deallocate { .. }
1530 | ExecuteResponse::ValidatedConnection
1531 | ExecuteResponse::Prepare => SqlResult::ok(
1532 client,
1533 tag.expect("ok only called on tag-generating results"),
1534 Vec::default(),
1535 )
1536 .into(),
1537 ExecuteResponse::TransactionCommitted { params }
1538 | ExecuteResponse::TransactionRolledBack { params } => {
1539 let notify_set: mz_ore::collections::HashSet<_> = client
1540 .session()
1541 .vars()
1542 .notify_set()
1543 .map(|v| v.name().to_string())
1544 .collect();
1545 let params = params
1546 .into_iter()
1547 .filter(|(name, _value)| notify_set.contains(*name))
1548 .map(|(name, value)| ParameterStatus {
1549 name: name.to_string(),
1550 value,
1551 })
1552 .collect();
1553 SqlResult::ok(
1554 client,
1555 tag.expect("ok only called on tag-generating results"),
1556 params,
1557 )
1558 .into()
1559 }
1560 ExecuteResponse::SetVariable { name, .. } => {
1561 let mut params = Vec::with_capacity(1);
1562 if let Some(var) = client
1563 .session()
1564 .vars()
1565 .notify_set()
1566 .find(|v| v.name() == &name)
1567 {
1568 params.push(ParameterStatus {
1569 name,
1570 value: var.value(),
1571 });
1572 };
1573 SqlResult::ok(
1574 client,
1575 tag.expect("ok only called on tag-generating results"),
1576 params,
1577 )
1578 .into()
1579 }
1580 ExecuteResponse::SendingRowsStreaming {
1581 rows,
1582 instance_id,
1583 strategy,
1584 } => {
1585 let max_query_result_size =
1586 usize::cast_from(client.get_system_vars().await.max_result_size());
1587
1588 let rows_stream = RecordFirstRowStream::new(
1589 Box::new(rows),
1590 execute_started,
1591 client,
1592 Some(instance_id),
1593 Some(strategy),
1594 );
1595
1596 SqlResult::rows(
1597 sender,
1598 client,
1599 rows_stream,
1600 max_query_result_size,
1601 &desc.relation_desc.expect("RelationDesc must exist"),
1602 )
1603 .await?
1604 .into()
1605 }
1606 ExecuteResponse::SendingRowsImmediate { rows } => {
1607 let max_query_result_size =
1608 usize::cast_from(client.get_system_vars().await.max_result_size());
1609
1610 let rows = futures::stream::once(futures::future::ready(PeekResponseUnary::Rows(rows)));
1611 let rows_stream =
1612 RecordFirstRowStream::new(Box::new(rows), execute_started, client, None, None);
1613
1614 SqlResult::rows(
1615 sender,
1616 client,
1617 rows_stream,
1618 max_query_result_size,
1619 &desc.relation_desc.expect("RelationDesc must exist"),
1620 )
1621 .await?
1622 .into()
1623 }
1624 ExecuteResponse::Subscribing {
1625 rx,
1626 ctx_extra,
1627 instance_id,
1628 } => StatementResult::Subscribe {
1629 tag: "SUBSCRIBE".into(),
1630 desc: desc.relation_desc.unwrap(),
1631 rx: RecordFirstRowStream::new(
1632 Box::new(UnboundedReceiverStream::new(rx)),
1633 execute_started,
1634 client,
1635 Some(instance_id),
1636 None,
1637 ),
1638 ctx_extra,
1639 },
1640 res @ (ExecuteResponse::Fetch { .. }
1641 | ExecuteResponse::CopyTo { .. }
1642 | ExecuteResponse::CopyFrom { .. }
1643 | ExecuteResponse::DeclaredCursor
1644 | ExecuteResponse::ClosedCursor) => SqlResult::err(
1645 client,
1646 Error::Unstructured(anyhow!(
1647 "internal error: encountered prohibited ExecuteResponse {:?}.\n\n
1648 This is a bug. Can you please file an bug report letting us know?\n
1649 https://github.com/MaterializeInc/materialize/discussions/new?category=bug-reports",
1650 ExecuteResponseKind::from(res)
1651 )),
1652 )
1653 .into(),
1654 })
1655}
1656
1657fn make_notices(client: &mut SessionClient) -> Vec<Notice> {
1658 client
1659 .session()
1660 .drain_notices()
1661 .into_iter()
1662 .map(|notice| Notice {
1663 message: notice.to_string(),
1664 code: notice.code().code().to_string(),
1665 severity: notice.severity().as_str().to_lowercase(),
1666 detail: notice.detail(),
1667 hint: notice.hint(),
1668 })
1669 .collect()
1670}
1671
1672fn is_txn_exit_stmt(stmt: &Statement<Raw>) -> bool {
1675 matches!(
1676 stmt,
1677 Statement::Commit(_) | Statement::Rollback(_) | Statement::Prepare(_)
1678 )
1679}
1680
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use super::{Password, WebSocketAuth};

    /// Checks that WebSocket authentication payloads deserialize into the
    /// expected [`WebSocketAuth`] variants, both with and without the optional
    /// `options` object.
    #[mz_ore::test]
    fn smoke_test_websocket_auth_parse() {
        // The same basic-auth credentials are expected whether `options` is
        // omitted or given as an empty object.
        let basic_mz = || WebSocketAuth::Basic {
            user: "mz".to_string(),
            password: Password("1234".to_string()),
            options: BTreeMap::default(),
        };

        let cases = [
            (r#"{ "user": "mz", "password": "1234" }"#, basic_mz()),
            (
                r#"{ "user": "mz", "password": "1234", "options": {} }"#,
                basic_mz(),
            ),
            (
                r#"{ "token": "i_am_a_token" }"#,
                WebSocketAuth::Bearer {
                    token: "i_am_a_token".to_string(),
                    options: BTreeMap::default(),
                },
            ),
            (
                r#"{ "token": "i_am_a_token", "options": { "foo": "bar" } }"#,
                WebSocketAuth::Bearer {
                    token: "i_am_a_token".to_string(),
                    options: BTreeMap::from([("foo".to_string(), "bar".to_string())]),
                },
            ),
        ];

        for (json, expected) in cases {
            let parsed: WebSocketAuth = serde_json::from_str(json).unwrap();
            assert_eq!(parsed, expected);
        }
    }
}