1use std::collections::BTreeMap;
11use std::net::{IpAddr, SocketAddr};
12use std::pin::pin;
13use std::sync::Arc;
14use std::time::Duration;
15
16use anyhow::anyhow;
17use async_trait::async_trait;
18use axum::extract::connect_info::ConnectInfo;
19use axum::extract::ws::{CloseFrame, Message, Utf8Bytes, WebSocket};
20use axum::extract::{State, WebSocketUpgrade};
21use axum::response::IntoResponse;
22use axum::{Extension, Json};
23use futures::Future;
24use futures::future::BoxFuture;
25
26use http::StatusCode;
27use itertools::Itertools;
28use mz_adapter::client::RecordFirstRowStream;
29use mz_adapter::session::{EndTransactionAction, TransactionStatus};
30use mz_adapter::statement_logging::{StatementEndedExecutionReason, StatementExecutionStrategy};
31use mz_adapter::{
32 AdapterError, AdapterNotice, ExecuteContextGuard, ExecuteResponse, ExecuteResponseKind,
33 PeekResponseUnary, SessionClient, verify_datum_desc,
34};
35use mz_auth::password::Password;
36use mz_catalog::memory::objects::{Cluster, ClusterReplica};
37use mz_interchange::encode::TypedDatum;
38use mz_interchange::json::{JsonNumberPolicy, ToJson};
39use mz_ore::cast::CastFrom;
40use mz_ore::metrics::{MakeCollectorOpts, MetricsRegistry};
41use mz_ore::result::ResultExt;
42use mz_repr::{Datum, RelationDesc, RowArena, RowIterator};
43use mz_sql::ast::display::AstDisplay;
44use mz_sql::ast::{CopyDirection, CopyStatement, CopyTarget, Raw, Statement, StatementKind};
45use mz_sql::parse::StatementParseResult;
46use mz_sql::plan::Plan;
47use mz_sql::session::metadata::SessionMetadata;
48use prometheus::Opts;
49use prometheus::core::{AtomicF64, GenericGaugeVec};
50use serde::{Deserialize, Serialize};
51use tokio::{select, time};
52use tokio_postgres::error::SqlState;
53use tokio_stream::wrappers::UnboundedReceiverStream;
54use tower_sessions::Session as TowerSession;
55use tracing::{debug, info};
56use tungstenite::protocol::frame::coding::CloseCode;
57
58use crate::http::prometheus::PrometheusSqlQuery;
59use crate::http::{
60 AuthError, AuthedClient, AuthedUser, MAX_REQUEST_SIZE, WsState, ensure_session_unexpired,
61 init_ws, maybe_get_authenticated_session,
62};
63
/// Errors that can occur while serving a SQL request over HTTP or WebSocket.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// An error reported by the adapter layer; its detail, hint, position and
    /// SQLSTATE code are forwarded to the client.
    #[error(transparent)]
    Adapter(#[from] AdapterError),
    /// A JSON (de)serialization failure, e.g. a malformed WebSocket request.
    #[error(transparent)]
    Json(#[from] serde_json::Error),
    /// An error from axum, e.g. a failed WebSocket send.
    #[error(transparent)]
    Axum(#[from] axum::Error),
    /// A `SUBSCRIBE` was issued through a sender that cannot stream results.
    #[error("SUBSCRIBE only supported over websocket")]
    SubscribeOnlyOverWs,
    /// A statement other than a transaction exit was issued while the
    /// transaction is in the failed state.
    #[error("current transaction is aborted, commands ignored until end of transaction block")]
    AbortedTransaction,
    /// The statement (rendered as SQL text) cannot be served by this API.
    #[error("unsupported via this API: {0}")]
    Unsupported(String),
    /// Any other error; displayed as the wrapped error.
    #[error("{0}")]
    Unstructured(anyhow::Error),
}
81
82impl Error {
83 pub fn detail(&self) -> Option<String> {
84 match self {
85 Error::Adapter(err) => err.detail(),
86 _ => None,
87 }
88 }
89
90 pub fn hint(&self) -> Option<String> {
91 match self {
92 Error::Adapter(err) => err.hint(),
93 _ => None,
94 }
95 }
96
97 pub fn position(&self) -> Option<usize> {
98 match self {
99 Error::Adapter(err) => err.position(),
100 _ => None,
101 }
102 }
103
104 pub fn code(&self) -> SqlState {
105 match self {
106 Error::Adapter(err) => err.code(),
107 Error::AbortedTransaction => SqlState::IN_FAILED_SQL_TRANSACTION,
108 _ => SqlState::INTERNAL_ERROR,
109 }
110 }
111}
112
/// Label names appended to every per-replica Prometheus metric. The values are
/// pushed in this same order in `execute_promsql_query`: the
/// `<cluster>.<replica>` name, the cluster id (under `instance_id`), and the
/// replica id.
static PER_REPLICA_LABELS: &[&str] = &["replica_full_name", "instance_id", "replica_id"];
114
115async fn execute_promsql_query(
116 client: &mut AuthedClient,
117 query: &PrometheusSqlQuery<'_>,
118 metrics_registry: &MetricsRegistry,
119 metrics_by_name: &mut BTreeMap<String, GenericGaugeVec<AtomicF64>>,
120 cluster: Option<(&Cluster, &ClusterReplica)>,
121) {
122 assert_eq!(query.per_replica, cluster.is_some());
123
124 let mut res = SqlResponse {
125 results: Vec::new(),
126 };
127
128 execute_request(client, query.to_sql_request(cluster), &mut res)
129 .await
130 .expect("valid SQL query");
131
132 let result = match res.results.as_slice() {
133 [
136 SqlResult::Ok { .. },
137 SqlResult::Ok { .. },
138 SqlResult::Ok { .. },
139 result,
140 ] => result,
141 _ => {
145 info!(
146 "error executing prometheus query {}: {:?}",
147 query.metric_name, res
148 );
149 return;
150 }
151 };
152
153 let SqlResult::Rows { desc, rows, .. } = result else {
154 info!(
155 "did not receive rows for SQL query for prometheus metric {}: {:?}, {:?}",
156 query.metric_name, result, cluster
157 );
158 return;
159 };
160
161 let gauge_vec = metrics_by_name
162 .entry(query.metric_name.to_string())
163 .or_insert_with(|| {
164 let mut label_names: Vec<String> = desc
165 .columns
166 .iter()
167 .filter(|col| col.name != query.value_column_name)
168 .map(|col| col.name.clone())
169 .collect();
170
171 if query.per_replica {
172 label_names.extend(PER_REPLICA_LABELS.iter().map(|label| label.to_string()));
173 }
174
175 metrics_registry.register::<GenericGaugeVec<AtomicF64>>(MakeCollectorOpts {
176 opts: Opts::new(query.metric_name, query.help).variable_labels(label_names),
177 buckets: None,
178 })
179 });
180
181 for row in rows {
182 let mut label_values = desc
183 .columns
184 .iter()
185 .zip_eq(row)
186 .filter(|(col, _)| col.name != query.value_column_name)
187 .map(|(_, val)| val.as_str().expect("must be string"))
188 .collect::<Vec<_>>();
189
190 let value = desc
191 .columns
192 .iter()
193 .zip_eq(row)
194 .find(|(col, _)| col.name == query.value_column_name)
195 .map(|(_, val)| val.as_str().unwrap_or("0").parse::<f64>().unwrap_or(0.0))
196 .unwrap_or(0.0);
197
198 match cluster {
199 Some((cluster, replica)) => {
200 let replica_full_name = format!("{}.{}", cluster.name, replica.name);
201 let cluster_id = cluster.id.to_string();
202 let replica_id = replica.replica_id.to_string();
203
204 label_values.push(&replica_full_name);
205 label_values.push(&cluster_id);
206 label_values.push(&replica_id);
207
208 gauge_vec
209 .get_metric_with_label_values(&label_values)
210 .expect("valid labels")
211 .set(value);
212 }
213 None => {
214 gauge_vec
215 .get_metric_with_label_values(&label_values)
216 .expect("valid labels")
217 .set(value);
218 }
219 }
220 }
221}
222
223async fn handle_promsql_query(
224 client: &mut AuthedClient,
225 query: &PrometheusSqlQuery<'_>,
226 metrics_registry: &MetricsRegistry,
227 metrics_by_name: &mut BTreeMap<String, GenericGaugeVec<AtomicF64>>,
228) {
229 if !query.per_replica {
230 execute_promsql_query(client, query, metrics_registry, metrics_by_name, None).await;
231 return;
232 }
233
234 let catalog = client.client.catalog_snapshot("handle_promsql_query").await;
235 let clusters: Vec<&Cluster> = catalog.clusters().collect();
236
237 for cluster in clusters {
238 for replica in cluster.replicas() {
239 execute_promsql_query(
240 client,
241 query,
242 metrics_registry,
243 metrics_by_name,
244 Some((cluster, replica)),
245 )
246 .await;
247 }
248 }
249}
250
251pub async fn handle_promsql(
252 mut client: AuthedClient,
253 queries: &[PrometheusSqlQuery<'_>],
254) -> MetricsRegistry {
255 let metrics_registry = MetricsRegistry::new();
256 let mut metrics_by_name = BTreeMap::new();
257
258 for query in queries {
259 handle_promsql_query(&mut client, query, &metrics_registry, &mut metrics_by_name).await;
260 }
261
262 metrics_registry
263}
264
265pub async fn handle_sql(
266 mut client: AuthedClient,
267 Json(request): Json<SqlRequest>,
268) -> impl IntoResponse {
269 let mut res = SqlResponse {
270 results: Vec::new(),
271 };
272 match execute_request(&mut client, request, &mut res).await {
275 Ok(()) => Ok(Json(res)),
276 Err(e) => Err((StatusCode::BAD_REQUEST, e.to_string())),
277 }
278}
279
/// A user already authenticated before the WebSocket upgrade, tagged with
/// where that authentication came from.
#[derive(Debug)]
pub enum ExistingUser {
    /// Taken from an `AuthedUser` request extension (the name suggests the
    /// `X-Materialize-User` header; confirm in the auth middleware).
    XMaterializeUserHeader(AuthedUser),
    /// Taken from an authenticated, unexpired tower session.
    Session(AuthedUser),
}
289
290pub(crate) async fn handle_sql_ws(
291 State(state): State<WsState>,
292 existing_user: Option<Extension<AuthedUser>>,
293 ws: WebSocketUpgrade,
294 ConnectInfo(addr): ConnectInfo<SocketAddr>,
295 tower_session: Option<Extension<TowerSession>>,
296) -> Result<impl IntoResponse, AuthError> {
297 let session = tower_session.map(|Extension(session)| session);
298 let user = match existing_user {
300 Some(Extension(user)) => Some(ExistingUser::XMaterializeUserHeader(user)),
301 None => {
302 let session = maybe_get_authenticated_session(session.as_ref()).await;
303 if let Some((session, session_data)) = session {
304 let user = ensure_session_unexpired(session, session_data).await?;
305 Some(ExistingUser::Session(user))
306 } else {
307 None
308 }
309 }
310 };
311
312 let addr = Box::new(addr.ip());
313 Ok(ws
314 .max_message_size(MAX_REQUEST_SIZE)
315 .on_upgrade(|ws| async move { run_ws(state, user, *addr, ws).await }))
316}
317
/// Authentication payload a WebSocket client may send.
///
/// Because the enum is `untagged`, serde tries the variants in declaration
/// order: credentials with `user`/`password`, then a bearer `token`, then
/// options alone. Do not reorder the variants.
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
#[serde(untagged)]
pub enum WebSocketAuth {
    /// Username and password.
    Basic {
        user: String,
        password: Password,
        /// Extra session options; empty when omitted.
        #[serde(default)]
        options: BTreeMap<String, String>,
    },
    /// A bearer token.
    Bearer {
        token: String,
        /// Extra session options; empty when omitted.
        #[serde(default)]
        options: BTreeMap<String, String>,
    },
    /// No credentials, only session options (presumably for connections whose
    /// user is already known, e.g. an `ExistingUser`; confirm in `init_ws`).
    OptionsOnly {
        #[serde(default)]
        options: BTreeMap<String, String>,
    },
}
337
/// Serves SQL requests over an upgraded WebSocket connection.
///
/// The connection is first initialized with `init_ws`. If that fails, the
/// server sends a `Close` frame (protocol-error code) and abandons the
/// connection. The frame's reason is the adapter error message, or
/// `"unauthorized"` for any other failure.
///
/// On success the client receives, in order:
/// - a `ParameterStatus` for every variable in the session's notify set;
/// - `BackendKeyData`;
/// - `ReadyForQuery`;
/// - any pending notices.
///
/// Each later text or binary frame is decoded as a JSON [`SqlRequest`] and
/// executed, with results streamed through the [`ResultSender`] impl for
/// [`WebSocket`]. After each request the client receives the request-level
/// error (if any), the pending notices, and a `ReadyForQuery`.
///
/// Returns when:
/// - the peer closes or drops the connection;
/// - a response cannot be sent;
/// - a session timeout from `recv_timeout` has been reported.
async fn run_ws(state: WsState, user: Option<ExistingUser>, peer_addr: IpAddr, mut ws: WebSocket) {
    let mut client = match init_ws(state, user, peer_addr, &mut ws).await {
        Ok(client) => client,
        Err(e) => {
            debug!("WS request failed init: {}", e);
            // Adapter errors are surfaced verbatim; any other init failure is
            // reported generically.
            let reason: Utf8Bytes = match e.downcast_ref::<AdapterError>() {
                Some(error) => error.to_string().into(),
                None => "unauthorized".to_string().into(),
            };
            // Best effort: the connection is abandoned whether or not this
            // close frame is delivered.
            let _ = ws
                .send(Message::Close(Some(CloseFrame {
                    code: CloseCode::Protocol.into(),
                    reason,
                })))
                .await;
            return;
        }
    };

    // Initial handshake: session parameters, backend key data, and the
    // current transaction status.
    let mut msgs = Vec::new();
    let session = client.client.session();
    for var in session.vars().notify_set() {
        msgs.push(WebSocketResponse::ParameterStatus(ParameterStatus {
            name: var.name().to_string(),
            value: var.value(),
        }));
    }
    msgs.push(WebSocketResponse::BackendKeyData(BackendKeyData {
        conn_id: session.conn_id().unhandled(),
        secret_key: session.secret_key(),
    }));
    msgs.push(WebSocketResponse::ReadyForQuery(
        session.transaction_code().into(),
    ));
    // NOTE(review): send failures here are ignored; a dead connection is
    // detected by the notice forwarding or the receive loop below.
    for msg in msgs {
        let _ = ws
            .send(Message::Text(
                serde_json::to_string(&msg).expect("must serialize").into(),
            ))
            .await;
    }

    let notices = session.drain_notices();
    if let Err(err) = forward_notices(&mut ws, notices).await {
        debug!("failed to forward notices to WebSocket, {err:?}");
        return;
    }

    loop {
        let msg = select! {
            // `biased` polls branches in order, so a pending timeout wins over
            // an incoming message.
            biased;

            Some(timeout) = client.client.recv_timeout() => {
                client.client.terminate().await;
                // Presumably waits for the client's next message so the error
                // is delivered as the response to a request; confirm intent.
                let _ = ws.recv().await;
                let err = Error::from(AdapterError::from(timeout));
                let _ = send_ws_response(&mut ws, WebSocketResponse::Error(err.into())).await;
                return;
            },
            message = ws.recv() => message,
        };

        client.client.remove_idle_in_transaction_session_timeout();

        let msg = match msg {
            Some(Ok(msg)) => msg,
            // The stream ended (`None`) or errored: the connection is gone.
            _ => {
                return;
            }
        };

        // Text and binary frames both carry a JSON `SqlRequest`. A decoding
        // error is reported to the client by `run_ws_request`, like an
        // execution error.
        let req: Result<SqlRequest, Error> = match msg {
            Message::Text(data) => serde_json::from_str(&data).err_into(),
            Message::Binary(data) => serde_json::from_slice(&data).err_into(),
            // Control frames carry no request. NOTE(review): ping replies are
            // presumably handled by axum/tungstenite; confirm.
            Message::Ping(_) => {
                continue;
            }
            Message::Pong(_) => {
                continue;
            }
            Message::Close(_) => {
                return;
            }
        };

        let err = match run_ws_request(req, &mut client, &mut ws).await {
            Ok(()) => None,
            Err(err) => Some(WebSocketResponse::Error(err.into())),
        };

        // Trailer for every request: the request-level error (if any), then
        // pending notices, then the transaction status.
        let ws_response = || async {
            if let Some(e_resp) = err {
                send_ws_response(&mut ws, e_resp).await?;
            }

            let notices = client.client.session().drain_notices();
            forward_notices(&mut ws, notices).await?;

            let ready =
                WebSocketResponse::ReadyForQuery(client.client.session().transaction_code().into());
            send_ws_response(&mut ws, ready).await?;

            Ok::<_, Error>(())
        };

        if let Err(err) = ws_response().await {
            debug!("failed to send response over WebSocket, {err:?}");
            return;
        }
    }
}
471
472async fn run_ws_request(
473 req: Result<SqlRequest, Error>,
474 client: &mut AuthedClient,
475 ws: &mut WebSocket,
476) -> Result<(), Error> {
477 let req = req?;
478 execute_request(client, req, ws).await
479}
480
481async fn send_ws_response(ws: &mut WebSocket, resp: WebSocketResponse) -> Result<(), Error> {
483 let msg = serde_json::to_string(&resp).unwrap();
484 let msg = Message::Text(msg.into());
485 ws.send(msg).await?;
486
487 Ok(())
488}
489
490async fn forward_notices(
492 ws: &mut WebSocket,
493 notices: impl IntoIterator<Item = AdapterNotice>,
494) -> Result<(), Error> {
495 let ws_notices = notices.into_iter().map(|notice| {
496 WebSocketResponse::Notice(Notice {
497 message: notice.to_string(),
498 code: notice.code().code().to_string(),
499 severity: notice.severity().as_str().to_lowercase(),
500 detail: notice.detail(),
501 hint: notice.hint(),
502 })
503 });
504
505 for notice in ws_notices {
506 send_ws_response(ws, notice).await?;
507 }
508
509 Ok(())
510}
511
/// A SQL request received over HTTP or WebSocket.
///
/// Because the enum is `untagged`, a JSON object with a `query` field decodes
/// as [`SqlRequest::Simple`], and one with a `queries` field decodes as
/// [`SqlRequest::Extended`].
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum SqlRequest {
    /// SQL text that may contain any number of statements. Parameters are not
    /// supported.
    Simple {
        /// The SQL text to execute.
        query: String,
    },
    /// Single-statement queries, each with its own parameters, executed in
    /// order.
    Extended {
        /// The queries to execute.
        queries: Vec<ExtendedRequest>,
    },
}
528
/// One query of an extended [`SqlRequest`].
#[derive(Serialize, Deserialize, Debug)]
pub struct ExtendedRequest {
    /// SQL text; must contain exactly one statement.
    query: String,
    /// Parameter values for the statement (presumably text-encoded, with
    /// `None` meaning SQL `NULL`; confirm in `execute_stmt`). Empty when
    /// omitted.
    #[serde(default)]
    params: Vec<Option<String>>,
}
538
/// The response to an HTTP SQL request.
///
/// Holds one [`SqlResult`] per executed statement in execution order, plus any
/// errors raised while starting or ending transactions.
#[derive(Debug, Serialize, Deserialize)]
pub struct SqlResponse {
    pub(in crate::http) results: Vec<SqlResult>,
}
545
546impl SqlResponse {
547 pub(in crate::http) fn new() -> Self {
549 Self {
550 results: Vec::new(),
551 }
552 }
553}
554
/// The outcome of executing one statement, before it is delivered through a
/// [`ResultSender`].
pub(in crate::http) enum StatementResult {
    /// A fully materialized result.
    SqlResult(SqlResult),
    /// A `SUBSCRIBE` whose row batches are still arriving on `rx`.
    Subscribe {
        /// Description of the streamed rows.
        desc: RelationDesc,
        /// Command tag sent once the stream ends successfully.
        tag: String,
        /// Stream of row batches.
        rx: RecordFirstRowStream,
        /// Statement-logging context. It is handed back with an end reason so
        /// `send_and_retire` can retire it.
        ctx_extra: ExecuteContextGuard,
    },
}
564
565impl From<SqlResult> for StatementResult {
566 fn from(inner: SqlResult) -> Self {
567 Self::SqlResult(inner)
568 }
569}
570
571#[derive(Debug, Serialize, Deserialize)]
573#[serde(untagged)]
574pub enum SqlResult {
575 Rows {
577 tag: String,
579 rows: Vec<Vec<serde_json::Value>>,
581 desc: Description,
583 notices: Vec<Notice>,
585 },
586 Ok {
588 ok: String,
590 notices: Vec<Notice>,
592 #[serde(skip_serializing_if = "Vec::is_empty")]
596 parameters: Vec<ParameterStatus>,
597 },
598 Err {
600 error: SqlError,
601 notices: Vec<Notice>,
603 },
604}
605
606impl SqlResult {
607 async fn rows<S>(
611 sender: &mut S,
612 client: &mut SessionClient,
613 mut rows_stream: RecordFirstRowStream,
614 max_query_result_size: usize,
615 desc: &RelationDesc,
616 ) -> Result<SqlResult, Error>
617 where
618 S: ResultSender,
619 {
620 let mut rows: Vec<Vec<serde_json::Value>> = vec![];
621 let mut datum_vec = mz_repr::DatumVec::new();
622 let types = &desc.typ().column_types;
623
624 let mut query_result_size = 0;
625
626 loop {
627 let peek_response = tokio::select! {
628 notice = client.session().recv_notice(), if S::SUPPORTS_STREAMING_NOTICES => {
629 sender.emit_streaming_notices(vec![notice]).await?;
630 continue;
631 }
632 e = sender.connection_error() => return Err(e),
633 r = rows_stream.recv() => {
634 match r {
635 Some(r) => r,
636 None => break,
637 }
638 },
639 };
640
641 let mut sql_rows = match peek_response {
642 PeekResponseUnary::Rows(rows) => rows,
643 PeekResponseUnary::Error(e) => {
644 return Ok(SqlResult::err(client, Error::Unstructured(anyhow!(e))));
645 }
646 PeekResponseUnary::DependencyDropped(dep) => {
647 return Ok(SqlResult::err(client, dep.to_concurrent_dependency_drop()));
648 }
649 PeekResponseUnary::Canceled => {
650 return Ok(SqlResult::err(client, AdapterError::Canceled));
651 }
652 };
653
654 if let Err(err) = verify_datum_desc(desc, &mut sql_rows) {
655 return Ok(SqlResult::Err {
656 error: err.into(),
657 notices: make_notices(client),
658 });
659 }
660
661 while let Some(row) = sql_rows.next() {
662 query_result_size += row.byte_len();
663 if query_result_size > max_query_result_size {
664 use bytesize::ByteSize;
665 return Ok(SqlResult::err(
666 client,
667 AdapterError::ResultSize(format!(
668 "result exceeds max size of {}",
669 ByteSize::b(u64::cast_from(max_query_result_size))
670 )),
671 ));
672 }
673
674 let datums = datum_vec.borrow_with(row);
675 rows.push(
676 datums
677 .iter()
678 .enumerate()
679 .map(|(i, d)| {
680 TypedDatum::new(*d, &types[i])
681 .json(&JsonNumberPolicy::ConvertNumberToString)
682 })
683 .collect(),
684 );
685 }
686 }
687
688 let tag = format!("SELECT {}", rows.len());
689 Ok(SqlResult::Rows {
690 tag,
691 rows,
692 desc: Description::from(desc),
693 notices: make_notices(client),
694 })
695 }
696
697 fn err(client: &mut SessionClient, error: impl Into<SqlError>) -> SqlResult {
698 SqlResult::Err {
699 error: error.into(),
700 notices: make_notices(client),
701 }
702 }
703
704 fn ok(client: &mut SessionClient, tag: String, params: Vec<ParameterStatus>) -> SqlResult {
705 SqlResult::Ok {
706 ok: tag,
707 parameters: params,
708 notices: make_notices(client),
709 }
710 }
711}
712
/// A SQL error as reported to clients.
#[derive(Debug, Deserialize, Serialize)]
pub struct SqlError {
    /// Human-readable error message.
    pub message: String,
    /// SQLSTATE code.
    pub code: String,
    /// Optional extra detail; omitted from the JSON when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
    /// Optional hint for resolving the error; omitted from the JSON when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub hint: Option<String>,
    /// Optional position within the query text; omitted from the JSON when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub position: Option<usize>,
}
724
725impl From<Error> for SqlError {
726 fn from(err: Error) -> Self {
727 SqlError {
728 message: err.to_string(),
729 code: err.code().code().to_string(),
730 detail: err.detail(),
731 hint: err.hint(),
732 position: err.position(),
733 }
734 }
735}
736
737impl From<AdapterError> for SqlError {
738 fn from(value: AdapterError) -> Self {
739 Error::from(value).into()
740 }
741}
742
/// A message sent from the server to a WebSocket client.
///
/// Serialized adjacently tagged as `{"type": "<Variant>", "payload": ...}`.
#[derive(Debug, Deserialize, Serialize)]
#[serde(tag = "type", content = "payload")]
pub enum WebSocketResponse {
    /// The session's transaction status code; the server is ready for the
    /// next request.
    ReadyForQuery(String),
    /// A notice raised by the session.
    Notice(Notice),
    /// Column descriptions, sent before the rows of a row-returning result.
    Rows(Description),
    /// One row of values.
    Row(Vec<serde_json::Value>),
    /// Sent before each statement's result, announcing its shape.
    CommandStarting(CommandStarting),
    /// The command tag of a completed statement.
    CommandComplete(String),
    /// A statement or request error.
    Error(SqlError),
    /// A session parameter's current value.
    ParameterStatus(ParameterStatus),
    /// The connection id and secret key of the session.
    BackendKeyData(BackendKeyData),
}
756
/// A notice raised during execution, as sent to clients.
#[derive(Debug, Serialize, Deserialize)]
pub struct Notice {
    /// Human-readable notice text.
    message: String,
    /// SQLSTATE code.
    code: String,
    /// Lowercased severity name.
    severity: String,
    /// Optional extra detail; omitted from the JSON when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
    /// Optional hint; omitted from the JSON when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub hint: Option<String>,
}
767
768impl Notice {
769 pub fn message(&self) -> &str {
770 &self.message
771 }
772}
773
/// Column descriptions for a row-returning result.
#[derive(Debug, Serialize, Deserialize)]
pub struct Description {
    /// One entry per column, in column order.
    pub columns: Vec<Column>,
}
778
779impl From<&RelationDesc> for Description {
780 fn from(desc: &RelationDesc) -> Self {
781 let columns = desc
782 .iter()
783 .map(|(name, typ)| {
784 let pg_type = mz_pgrepr::Type::from(&typ.scalar_type);
785 Column {
786 name: name.to_string(),
787 type_oid: pg_type.oid(),
788 type_len: pg_type.typlen(),
789 type_mod: pg_type.typmod(),
790 }
791 })
792 .collect();
793 Description { columns }
794 }
795}
796
/// One column of a [`Description`].
#[derive(Debug, Serialize, Deserialize)]
pub struct Column {
    /// Column name.
    pub name: String,
    /// PostgreSQL type OID of the column.
    pub type_oid: u32,
    /// PostgreSQL type length (`typlen`).
    pub type_len: i16,
    /// PostgreSQL type modifier (`typmod`).
    pub type_mod: i32,
}
804
/// The current value of a session variable, reported to the client.
#[derive(Debug, Serialize, Deserialize)]
pub struct ParameterStatus {
    /// Variable name.
    name: String,
    /// Variable value.
    value: String,
}
810
/// Identifies the session to the client (presumably so it can later request
/// cancellation; confirm).
#[derive(Debug, Serialize, Deserialize)]
pub struct BackendKeyData {
    /// The session's connection id.
    conn_id: u32,
    /// The session's secret key.
    secret_key: u32,
}
816
/// Announces the shape of the result that is about to be sent.
#[derive(Debug, Serialize, Deserialize)]
pub struct CommandStarting {
    /// Whether `Rows`/`Row` messages will follow.
    has_rows: bool,
    /// Whether rows are streamed open-endedly (`SUBSCRIBE`).
    is_streaming: bool,
}
822
/// A destination for statement results, such as a buffered HTTP response or a
/// live WebSocket.
#[async_trait]
pub(in crate::http) trait ResultSender: Send {
    /// Whether notices may be pushed to the client while a statement is still
    /// running (see [`ResultSender::emit_streaming_notices`]).
    const SUPPORTS_STREAMING_NOTICES: bool = false;

    /// Delivers `res` to the client.
    ///
    /// The outer `Result` is `Err` when delivery itself failed. The inner
    /// `Result` is `Err(())` when the statement's result was an error. The
    /// second element, if present, pairs the reason the statement ended with
    /// its statement-logging context, which the caller retires (see
    /// `send_and_retire`).
    async fn add_result(
        &mut self,
        client: &mut SessionClient,
        res: StatementResult,
    ) -> (
        Result<Result<(), ()>, Error>,
        Option<(StatementEndedExecutionReason, ExecuteContextGuard)>,
    );

    /// Returns a future that resolves only if the underlying connection fails.
    fn connection_error(&mut self) -> BoxFuture<'_, Error>;
    /// Whether `SUBSCRIBE` statements may be executed through this sender.
    fn allow_subscribe(&self) -> bool;

    /// Pushes `notices` to the client immediately.
    ///
    /// # Panics
    ///
    /// The default implementation panics. It must only be called when
    /// `SUPPORTS_STREAMING_NOTICES` is `true`.
    async fn emit_streaming_notices(&mut self, _: Vec<AdapterNotice>) -> Result<(), Error> {
        unreachable!("streaming notices marked as unsupported")
    }
}
858
859#[async_trait]
860impl ResultSender for SqlResponse {
861 async fn add_result(
869 &mut self,
870 _client: &mut SessionClient,
871 res: StatementResult,
872 ) -> (
873 Result<Result<(), ()>, Error>,
874 Option<(StatementEndedExecutionReason, ExecuteContextGuard)>,
875 ) {
876 let (res, stmt_logging) = match res {
877 StatementResult::SqlResult(res) => {
878 let is_err = matches!(res, SqlResult::Err { .. });
879 self.results.push(res);
880 let res = if is_err { Err(()) } else { Ok(()) };
881 (res, None)
882 }
883 StatementResult::Subscribe { ctx_extra, .. } => {
884 let message = "SUBSCRIBE only supported over websocket";
885 self.results.push(SqlResult::Err {
886 error: Error::SubscribeOnlyOverWs.into(),
887 notices: Vec::new(),
888 });
889 (
890 Err(()),
891 Some((
892 StatementEndedExecutionReason::Errored {
893 error: message.into(),
894 },
895 ctx_extra,
896 )),
897 )
898 }
899 };
900 (Ok(res), stmt_logging)
901 }
902
903 fn connection_error(&mut self) -> BoxFuture<'_, Error> {
904 Box::pin(futures::future::pending())
905 }
906
907 fn allow_subscribe(&self) -> bool {
908 false
909 }
910}
911
/// Streams results to a WebSocket client as they become available, including
/// `SUBSCRIBE` output and notices raised mid-statement.
#[async_trait]
impl ResultSender for WebSocket {
    const SUPPORTS_STREAMING_NOTICES: bool = true;

    /// Sends one statement's result as a sequence of WebSocket messages:
    /// 1. `CommandStarting`.
    /// 2. For row-returning results, `Rows` followed by one `Row` per row.
    /// 3. `CommandComplete` or `Error`.
    /// 4. Any notices and changed parameters.
    ///
    /// For `SUBSCRIBE`, rows are forwarded batch by batch until the stream
    /// ends, errors, or is canceled. The returned logging reason records which
    /// of these happened. If a send fails after the `Rows` header, the pending
    /// statement log is marked `Canceled`.
    async fn add_result(
        &mut self,
        client: &mut SessionClient,
        res: StatementResult,
    ) -> (
        Result<Result<(), ()>, Error>,
        Option<(StatementEndedExecutionReason, ExecuteContextGuard)>,
    ) {
        let (has_rows, is_streaming) = match res {
            StatementResult::SqlResult(SqlResult::Err { .. }) => (false, false),
            StatementResult::SqlResult(SqlResult::Ok { .. }) => (false, false),
            StatementResult::SqlResult(SqlResult::Rows { .. }) => (true, false),
            StatementResult::Subscribe { .. } => (true, true),
        };
        // NOTE(review): on this early return a `Subscribe`'s `ctx_extra` is
        // dropped rather than handed back; presumably `ExecuteContextGuard`'s
        // drop handles retirement. Confirm.
        if let Err(e) = send_ws_response(
            self,
            WebSocketResponse::CommandStarting(CommandStarting {
                has_rows,
                is_streaming,
            }),
        )
        .await
        {
            return (Err(e), None);
        }

        // Each arm yields (is the result an error?, trailing messages still to
        // send, statement-logging outcome).
        let (is_err, msgs, stmt_logging) = match res {
            StatementResult::SqlResult(SqlResult::Rows {
                tag,
                rows,
                desc,
                notices,
            }) => {
                if let Err(e) = send_ws_response(self, WebSocketResponse::Rows(desc)).await {
                    return (Err(e), None);
                }
                for row in rows {
                    if let Err(e) = send_ws_response(self, WebSocketResponse::Row(row)).await {
                        return (Err(e), None);
                    }
                }
                let mut msgs = vec![WebSocketResponse::CommandComplete(tag)];
                msgs.extend(notices.into_iter().map(WebSocketResponse::Notice));
                (false, msgs, None)
            }
            StatementResult::SqlResult(SqlResult::Ok {
                ok,
                parameters,
                notices,
            }) => {
                let mut msgs = vec![WebSocketResponse::CommandComplete(ok)];
                msgs.extend(notices.into_iter().map(WebSocketResponse::Notice));
                msgs.extend(
                    parameters
                        .into_iter()
                        .map(WebSocketResponse::ParameterStatus),
                );
                (false, msgs, None)
            }
            StatementResult::SqlResult(SqlResult::Err { error, notices }) => {
                let mut msgs = vec![WebSocketResponse::Error(error)];
                msgs.extend(notices.into_iter().map(WebSocketResponse::Notice));
                (true, msgs, None)
            }
            StatementResult::Subscribe {
                ref desc,
                tag,
                mut rx,
                ctx_extra,
            } => {
                if let Err(e) = send_ws_response(self, WebSocketResponse::Rows(desc.into())).await {
                    return (
                        Err(e),
                        Some((StatementEndedExecutionReason::Canceled, ctx_extra)),
                    );
                }

                let mut datum_vec = mz_repr::DatumVec::new();
                let mut result_size: usize = 0;
                let mut rows_returned = 0;
                // Runs until the stream ends or fails; `break` yields the
                // trailing messages and the logging outcome.
                loop {
                    // `await_rows` also forwards notices and watches the
                    // connection while waiting for the next batch.
                    let res = match await_rows(self, client, rx.recv()).await {
                        Ok(res) => res,
                        Err(e) => {
                            return (
                                Err(e),
                                Some((StatementEndedExecutionReason::Canceled, ctx_extra)),
                            );
                        }
                    };
                    match res {
                        Some(PeekResponseUnary::Rows(mut rows)) => {
                            if let Err(err) = verify_datum_desc(desc, &mut rows) {
                                let error = err.to_string();
                                break (
                                    true,
                                    vec![WebSocketResponse::Error(err.into())],
                                    Some((
                                        StatementEndedExecutionReason::Errored { error },
                                        ctx_extra,
                                    )),
                                );
                            }

                            rows_returned += rows.count();
                            while let Some(row) = rows.next() {
                                result_size += row.byte_len();
                                let datums = datum_vec.borrow_with(row);
                                let types = &desc.typ().column_types;
                                if let Err(e) = send_ws_response(
                                    self,
                                    WebSocketResponse::Row(
                                        datums
                                            .iter()
                                            .enumerate()
                                            .map(|(i, d)| {
                                                TypedDatum::new(*d, &types[i])
                                                    .json(&JsonNumberPolicy::ConvertNumberToString)
                                            })
                                            .collect(),
                                    ),
                                )
                                .await
                                {
                                    return (
                                        Err(e),
                                        Some((StatementEndedExecutionReason::Canceled, ctx_extra)),
                                    );
                                }
                            }
                        }
                        Some(PeekResponseUnary::Error(error)) => {
                            // Cloned because the original string is also
                            // recorded in the statement log.
                            break (
                                true,
                                vec![WebSocketResponse::Error(
                                    Error::Unstructured(anyhow!(error.clone())).into(),
                                )],
                                Some((StatementEndedExecutionReason::Errored { error }, ctx_extra)),
                            );
                        }
                        Some(PeekResponseUnary::DependencyDropped(dep)) => {
                            let err = dep.to_concurrent_dependency_drop();
                            let error = err.to_string();
                            break (
                                true,
                                vec![WebSocketResponse::Error(err.into())],
                                Some((StatementEndedExecutionReason::Errored { error }, ctx_extra)),
                            );
                        }
                        Some(PeekResponseUnary::Canceled) => {
                            break (
                                true,
                                vec![WebSocketResponse::Error(AdapterError::Canceled.into())],
                                Some((StatementEndedExecutionReason::Canceled, ctx_extra)),
                            );
                        }
                        // The stream ended normally.
                        None => {
                            break (
                                false,
                                vec![WebSocketResponse::CommandComplete(tag)],
                                Some((
                                    StatementEndedExecutionReason::Success {
                                        result_size: Some(u64::cast_from(result_size)),
                                        rows_returned: Some(u64::cast_from(rows_returned)),
                                        execution_strategy: Some(
                                            StatementExecutionStrategy::Standard,
                                        ),
                                    },
                                    ctx_extra,
                                )),
                            );
                        }
                    }
                }
            }
        };
        for msg in msgs {
            if let Err(e) = send_ws_response(self, msg).await {
                // The client never saw the outcome, so log it as canceled.
                return (
                    Err(e),
                    stmt_logging.map(|(_old_reason, ctx_extra)| {
                        (StatementEndedExecutionReason::Canceled, ctx_extra)
                    }),
                );
            }
        }
        (Ok(if is_err { Err(()) } else { Ok(()) }), stmt_logging)
    }

    /// Pings the peer once per second and resolves with the first send error.
    fn connection_error(&mut self) -> BoxFuture<'_, Error> {
        Box::pin(async {
            let mut tick = time::interval(Duration::from_secs(1));
            // A tokio interval's first tick completes immediately; consume it
            // so the first ping goes out after one second.
            tick.tick().await;
            loop {
                tick.tick().await;
                if let Err(err) = self.send(Message::Ping(Vec::new().into())).await {
                    return err.into();
                }
            }
        })
    }

    fn allow_subscribe(&self) -> bool {
        true
    }

    async fn emit_streaming_notices(&mut self, notices: Vec<AdapterNotice>) -> Result<(), Error> {
        forward_notices(self, notices).await
    }
}
1141
1142async fn await_rows<S, F, R>(sender: &mut S, client: &mut SessionClient, f: F) -> Result<R, Error>
1143where
1144 S: ResultSender,
1145 F: Future<Output = R> + Send,
1146{
1147 let mut f = pin!(f);
1148 loop {
1149 tokio::select! {
1150 notice = client.session().recv_notice(), if S::SUPPORTS_STREAMING_NOTICES => {
1151 sender.emit_streaming_notices(vec![notice]).await?;
1152 }
1153 e = sender.connection_error() => return Err(e),
1154 r = &mut f => return Ok(r),
1155 }
1156 }
1157}
1158
1159async fn send_and_retire<S: ResultSender>(
1160 res: StatementResult,
1161 client: &mut SessionClient,
1162 sender: &mut S,
1163) -> Result<Result<(), ()>, Error> {
1164 let (res, stmt_logging) = sender.add_result(client, res).await;
1165 if let Some((reason, ctx_extra)) = stmt_logging {
1166 client.retire_execute(ctx_extra, reason);
1167 }
1168 res
1169}
1170
/// Executes a group of statements in order, reporting each result through
/// `sender`.
///
/// Returns `Ok(Err(()))` as soon as a statement is rejected or fails. The
/// rejection cases are:
/// - the transaction is already aborted and the statement does not exit it;
/// - the transaction cannot be started;
/// - the statement's result is an error.
///
/// Returns `Ok(Ok(()))` if every statement succeeded.
///
/// # Errors
///
/// Returns `Err` if executing a statement or delivering a result through
/// `sender` fails.
///
/// # Panics
///
/// Panics if a group of more than one statement carries parameters.
async fn execute_stmt_group<S: ResultSender>(
    client: &mut SessionClient,
    sender: &mut S,
    stmt_group: Vec<(Statement<Raw>, String, Vec<Option<String>>)>,
) -> Result<Result<(), ()>, Error> {
    let num_stmts = stmt_group.len();
    for (stmt, sql, params) in stmt_group {
        assert!(
            num_stmts <= 1 || params.is_empty(),
            "statement groups contain more than 1 statement iff Simple request, which does not support parameters"
        );

        // In a failed transaction, only statements that exit the transaction
        // may run; anything else is rejected.
        let is_aborted_txn = matches!(client.session().transaction(), TransactionStatus::Failed(_));
        if is_aborted_txn && !is_txn_exit_stmt(&stmt) {
            let err = SqlResult::err(client, Error::AbortedTransaction);
            let _ = send_and_retire(err.into(), client, sender).await?;
            return Ok(Err(()));
        }

        // The group size is passed on (presumably so the session can tell a
        // single statement from a multi-statement implicit transaction;
        // confirm in `start_transaction`).
        if let Err(e) = client.start_transaction(Some(num_stmts)) {
            let err = SqlResult::err(client, e);
            let _ = send_and_retire(err.into(), client, sender).await?;
            return Ok(Err(()));
        }
        let res = execute_stmt(client, sender, stmt, sql, params).await?;
        let is_err = send_and_retire(res, client, sender).await?;

        if is_err.is_err() {
            // Clean up the transaction according to its state.
            let txn = client.session().transaction();
            match txn {
                // No active transaction, or already failed: nothing to do.
                TransactionStatus::Default | TransactionStatus::Failed(_) => {}
                // Single-statement and implicit transactions are rolled back; a
                // rollback failure is reported as an additional result.
                TransactionStatus::Started(_) | TransactionStatus::InTransactionImplicit(_) => {
                    if let Err(err) = client.end_transaction(EndTransactionAction::Rollback).await {
                        let err = SqlResult::err(client, err);
                        let _ = send_and_retire(err.into(), client, sender).await?;
                    }
                }
                // Explicit transactions are marked failed, so later statements
                // are rejected until the transaction is exited.
                TransactionStatus::InTransaction(_) => {
                    client.fail_transaction();
                }
            }
            return Ok(Err(()));
        }
    }
    Ok(Ok(()))
}
1226
1227pub(in crate::http) async fn execute_request<S: ResultSender>(
1236 client: &mut AuthedClient,
1237 request: SqlRequest,
1238 sender: &mut S,
1239) -> Result<(), Error> {
1240 let client = &mut client.client;
1241
1242 fn check_prohibited_stmts<S: ResultSender>(
1245 sender: &S,
1246 stmt: &Statement<Raw>,
1247 ) -> Result<(), Error> {
1248 let kind: StatementKind = stmt.into();
1249 let execute_responses = Plan::generated_from(&kind)
1250 .into_iter()
1251 .map(ExecuteResponse::generated_from)
1252 .flatten()
1253 .collect::<Vec<_>>();
1254
1255 let is_valid_copy = matches!(
1259 stmt,
1260 Statement::Copy(CopyStatement {
1261 direction: CopyDirection::To,
1262 target: CopyTarget::Expr(_),
1263 ..
1264 }) | Statement::Copy(CopyStatement {
1265 direction: CopyDirection::From,
1266 target: CopyTarget::Expr(_),
1267 ..
1268 })
1269 );
1270
1271 if !is_valid_copy
1272 && execute_responses.iter().any(|execute_response| {
1273 match execute_response {
1275 ExecuteResponseKind::Subscribing if sender.allow_subscribe() => false,
1276 ExecuteResponseKind::Fetch
1277 | ExecuteResponseKind::Subscribing
1278 | ExecuteResponseKind::CopyFrom
1279 | ExecuteResponseKind::DeclaredCursor
1280 | ExecuteResponseKind::ClosedCursor => true,
1281 ExecuteResponseKind::CopyTo if matches!(kind, StatementKind::Copy) => true,
1286 _ => false,
1287 }
1288 })
1289 {
1290 return Err(Error::Unsupported(stmt.to_ast_string_simple()));
1291 }
1292 Ok(())
1293 }
1294
1295 fn parse<'a>(
1296 client: &SessionClient,
1297 query: &'a str,
1298 ) -> Result<Vec<StatementParseResult<'a>>, Error> {
1299 let result = client
1300 .parse(query)
1301 .map_err(|e| Error::Unstructured(anyhow!(e)))?;
1302 result.map_err(|e| AdapterError::from(e).into())
1303 }
1304
1305 let mut stmt_groups = vec![];
1306
1307 match request {
1308 SqlRequest::Simple { query } => match parse(client, &query) {
1309 Ok(stmts) => {
1310 let mut stmt_group = Vec::with_capacity(stmts.len());
1311 let mut stmt_err = None;
1312 for StatementParseResult { ast: stmt, sql } in stmts {
1313 if let Err(err) = check_prohibited_stmts(sender, &stmt) {
1314 stmt_err = Some(err);
1315 break;
1316 }
1317 stmt_group.push((stmt, sql.to_string(), vec![]));
1318 }
1319 stmt_groups.push(stmt_err.map(Err).unwrap_or_else(|| Ok(stmt_group)));
1320 }
1321 Err(e) => stmt_groups.push(Err(e)),
1322 },
1323 SqlRequest::Extended { queries } => {
1324 for ExtendedRequest { query, params } in queries {
1325 match parse(client, &query) {
1326 Ok(mut stmts) => {
1327 if stmts.len() != 1 {
1328 return Err(Error::Unstructured(anyhow!(
1329 "each query must contain exactly 1 statement, but \"{}\" contains {}",
1330 query,
1331 stmts.len()
1332 )));
1333 }
1334
1335 let StatementParseResult { ast: stmt, sql } = stmts.pop().unwrap();
1336 stmt_groups.push(
1337 check_prohibited_stmts(sender, &stmt)
1338 .map(|_| vec![(stmt, sql.to_string(), params)]),
1339 );
1340 }
1341 Err(e) => stmt_groups.push(Err(e)),
1342 };
1343 }
1344 }
1345 }
1346
1347 for stmt_group_res in stmt_groups {
1348 let executed = match stmt_group_res {
1349 Ok(stmt_group) => execute_stmt_group(client, sender, stmt_group).await,
1350 Err(e) => {
1351 let err = SqlResult::err(client, e);
1352 let _ = send_and_retire(err.into(), client, sender).await?;
1353 Ok(Err(()))
1354 }
1355 };
1356 if client.session().transaction().is_implicit() {
1359 let ended = client.end_transaction(EndTransactionAction::Commit).await;
1360 if let Err(err) = ended {
1361 let err = SqlResult::err(client, err);
1362 let _ = send_and_retire(StatementResult::SqlResult(err), client, sender).await?;
1363 }
1364 }
1365 if executed?.is_err() {
1366 break;
1367 }
1368 }
1369
1370 Ok(())
1371}
1372
1373async fn execute_stmt<S: ResultSender>(
1375 client: &mut SessionClient,
1376 sender: &mut S,
1377 stmt: Statement<Raw>,
1378 sql: String,
1379 raw_params: Vec<Option<String>>,
1380) -> Result<StatementResult, Error> {
1381 const EMPTY_PORTAL: &str = "";
1382 if let Err(e) = client
1383 .prepare(EMPTY_PORTAL.into(), Some(stmt.clone()), sql, vec![])
1384 .await
1385 {
1386 return Ok(SqlResult::err(client, e).into());
1387 }
1388
1389 let prep_stmt = match client.get_prepared_statement(EMPTY_PORTAL).await {
1390 Ok(stmt) => stmt,
1391 Err(err) => {
1392 return Ok(SqlResult::err(client, err).into());
1393 }
1394 };
1395
1396 let param_types = &prep_stmt.desc().param_types;
1397 if param_types.len() != raw_params.len() {
1398 let message = anyhow!(
1399 "request supplied {actual} parameters, \
1400 but {statement} requires {expected}",
1401 statement = stmt.to_ast_string_simple(),
1402 actual = raw_params.len(),
1403 expected = param_types.len()
1404 );
1405 return Ok(SqlResult::err(client, Error::Unstructured(message)).into());
1406 }
1407
1408 let buf = RowArena::new();
1409 let mut params = vec![];
1410 for (raw_param, mz_typ) in raw_params.into_iter().zip_eq(param_types) {
1411 let pg_typ = mz_pgrepr::Type::from(mz_typ);
1412 let datum = match raw_param {
1413 None => Datum::Null,
1414 Some(raw_param) => {
1415 match mz_pgrepr::Value::decode(
1416 mz_pgwire_common::Format::Text,
1417 &pg_typ,
1418 raw_param.as_bytes(),
1419 ) {
1420 Ok(param) => match param.into_datum_decode_error(&buf, &pg_typ, "parameter") {
1421 Ok(datum) => datum,
1422 Err(msg) => {
1423 return Ok(
1424 SqlResult::err(client, Error::Unstructured(anyhow!(msg))).into()
1425 );
1426 }
1427 },
1428 Err(err) => {
1429 let msg = anyhow!("unable to decode parameter: {}", err);
1430 return Ok(SqlResult::err(client, Error::Unstructured(msg)).into());
1431 }
1432 }
1433 }
1434 };
1435 params.push((datum, mz_typ.clone()))
1436 }
1437
1438 let result_formats = vec![
1439 mz_pgwire_common::Format::Text;
1440 prep_stmt
1441 .desc()
1442 .relation_desc
1443 .clone()
1444 .map(|desc| desc.typ().column_types.len())
1445 .unwrap_or(0)
1446 ];
1447
1448 let desc = prep_stmt.desc().clone();
1449 let logging = Arc::clone(prep_stmt.logging());
1450 let stmt_ast = prep_stmt.stmt().cloned();
1451 let state_revision = prep_stmt.state_revision;
1452 if let Err(err) = client.session().set_portal(
1453 EMPTY_PORTAL.into(),
1454 desc,
1455 stmt_ast,
1456 logging,
1457 params,
1458 result_formats,
1459 state_revision,
1460 ) {
1461 return Ok(SqlResult::err(client, err).into());
1462 }
1463
1464 let desc = client
1465 .session()
1466 .get_portal_unverified(EMPTY_PORTAL)
1468 .map(|portal| portal.desc.clone())
1469 .expect("unnamed portal should be present");
1470
1471 let res = client
1472 .execute(EMPTY_PORTAL.into(), futures::future::pending(), None)
1473 .await;
1474
1475 if S::SUPPORTS_STREAMING_NOTICES {
1476 sender
1477 .emit_streaming_notices(client.session().drain_notices())
1478 .await?;
1479 }
1480
1481 let (res, execute_started) = match res {
1482 Ok(res) => res,
1483 Err(e) => {
1484 return Ok(SqlResult::err(client, e).into());
1485 }
1486 };
1487 let tag = res.tag();
1488
1489 Ok(match res {
1490 ExecuteResponse::CreatedConnection { .. }
1491 | ExecuteResponse::CreatedDatabase { .. }
1492 | ExecuteResponse::CreatedSchema { .. }
1493 | ExecuteResponse::CreatedRole
1494 | ExecuteResponse::CreatedCluster { .. }
1495 | ExecuteResponse::CreatedClusterReplica { .. }
1496 | ExecuteResponse::CreatedTable { .. }
1497 | ExecuteResponse::CreatedIndex { .. }
1498 | ExecuteResponse::CreatedIntrospectionSubscribe
1499 | ExecuteResponse::CreatedSecret { .. }
1500 | ExecuteResponse::CreatedSource { .. }
1501 | ExecuteResponse::CreatedSink { .. }
1502 | ExecuteResponse::CreatedView { .. }
1503 | ExecuteResponse::CreatedViews { .. }
1504 | ExecuteResponse::CreatedMaterializedView { .. }
1505 | ExecuteResponse::CreatedType
1506 | ExecuteResponse::CreatedNetworkPolicy
1507 | ExecuteResponse::Comment
1508 | ExecuteResponse::Deleted(_)
1509 | ExecuteResponse::DiscardedTemp
1510 | ExecuteResponse::DiscardedAll
1511 | ExecuteResponse::DroppedObject(_)
1512 | ExecuteResponse::DroppedOwned
1513 | ExecuteResponse::EmptyQuery
1514 | ExecuteResponse::GrantedPrivilege
1515 | ExecuteResponse::GrantedRole
1516 | ExecuteResponse::Inserted(_)
1517 | ExecuteResponse::Copied(_)
1518 | ExecuteResponse::Raised
1519 | ExecuteResponse::ReassignOwned
1520 | ExecuteResponse::RevokedPrivilege
1521 | ExecuteResponse::AlteredDefaultPrivileges
1522 | ExecuteResponse::RevokedRole
1523 | ExecuteResponse::StartedTransaction { .. }
1524 | ExecuteResponse::Updated(_)
1525 | ExecuteResponse::AlteredObject(_)
1526 | ExecuteResponse::AlteredRole
1527 | ExecuteResponse::AlteredSystemConfiguration
1528 | ExecuteResponse::Deallocate { .. }
1529 | ExecuteResponse::ValidatedConnection
1530 | ExecuteResponse::Prepare => SqlResult::ok(
1531 client,
1532 tag.expect("ok only called on tag-generating results"),
1533 Vec::default(),
1534 )
1535 .into(),
1536 ExecuteResponse::TransactionCommitted { params }
1537 | ExecuteResponse::TransactionRolledBack { params } => {
1538 let notify_set: mz_ore::collections::HashSet<_> = client
1539 .session()
1540 .vars()
1541 .notify_set()
1542 .map(|v| v.name().to_string())
1543 .collect();
1544 let params = params
1545 .into_iter()
1546 .filter(|(name, _value)| notify_set.contains(*name))
1547 .map(|(name, value)| ParameterStatus {
1548 name: name.to_string(),
1549 value,
1550 })
1551 .collect();
1552 SqlResult::ok(
1553 client,
1554 tag.expect("ok only called on tag-generating results"),
1555 params,
1556 )
1557 .into()
1558 }
1559 ExecuteResponse::SetVariable { name, .. } => {
1560 let mut params = Vec::with_capacity(1);
1561 if let Some(var) = client
1562 .session()
1563 .vars()
1564 .notify_set()
1565 .find(|v| v.name() == &name)
1566 {
1567 params.push(ParameterStatus {
1568 name,
1569 value: var.value(),
1570 });
1571 };
1572 SqlResult::ok(
1573 client,
1574 tag.expect("ok only called on tag-generating results"),
1575 params,
1576 )
1577 .into()
1578 }
1579 ExecuteResponse::SendingRowsStreaming {
1580 rows,
1581 instance_id,
1582 strategy,
1583 } => {
1584 let max_query_result_size =
1585 usize::cast_from(client.get_system_vars().await.max_result_size());
1586
1587 let rows_stream = RecordFirstRowStream::new(
1588 Box::new(rows),
1589 execute_started,
1590 client,
1591 Some(instance_id),
1592 Some(strategy),
1593 );
1594
1595 SqlResult::rows(
1596 sender,
1597 client,
1598 rows_stream,
1599 max_query_result_size,
1600 &desc.relation_desc.expect("RelationDesc must exist"),
1601 )
1602 .await?
1603 .into()
1604 }
1605 ExecuteResponse::SendingRowsImmediate { rows } => {
1606 let max_query_result_size =
1607 usize::cast_from(client.get_system_vars().await.max_result_size());
1608
1609 let rows = futures::stream::once(futures::future::ready(PeekResponseUnary::Rows(rows)));
1610 let rows_stream =
1611 RecordFirstRowStream::new(Box::new(rows), execute_started, client, None, None);
1612
1613 SqlResult::rows(
1614 sender,
1615 client,
1616 rows_stream,
1617 max_query_result_size,
1618 &desc.relation_desc.expect("RelationDesc must exist"),
1619 )
1620 .await?
1621 .into()
1622 }
1623 ExecuteResponse::Subscribing {
1624 rx,
1625 ctx_extra,
1626 instance_id,
1627 } => StatementResult::Subscribe {
1628 tag: "SUBSCRIBE".into(),
1629 desc: desc.relation_desc.unwrap(),
1630 rx: RecordFirstRowStream::new(
1631 Box::new(UnboundedReceiverStream::new(rx)),
1632 execute_started,
1633 client,
1634 Some(instance_id),
1635 None,
1636 ),
1637 ctx_extra,
1638 },
1639 res @ (ExecuteResponse::Fetch { .. }
1640 | ExecuteResponse::CopyTo { .. }
1641 | ExecuteResponse::CopyFrom { .. }
1642 | ExecuteResponse::DeclaredCursor
1643 | ExecuteResponse::ClosedCursor) => SqlResult::err(
1644 client,
1645 Error::Unstructured(anyhow!(
1646 "internal error: encountered prohibited ExecuteResponse {:?}.\n\n
1647 This is a bug. Can you please file an bug report letting us know?\n
1648 https://github.com/MaterializeInc/materialize/discussions/new?category=bug-reports",
1649 ExecuteResponseKind::from(res)
1650 )),
1651 )
1652 .into(),
1653 })
1654}
1655
1656fn make_notices(client: &mut SessionClient) -> Vec<Notice> {
1657 client
1658 .session()
1659 .drain_notices()
1660 .into_iter()
1661 .map(|notice| Notice {
1662 message: notice.to_string(),
1663 code: notice.code().code().to_string(),
1664 severity: notice.severity().as_str().to_lowercase(),
1665 detail: notice.detail(),
1666 hint: notice.hint(),
1667 })
1668 .collect()
1669}
1670
1671fn is_txn_exit_stmt(stmt: &Statement<Raw>) -> bool {
1674 matches!(
1675 stmt,
1676 Statement::Commit(_) | Statement::Rollback(_) | Statement::Prepare(_)
1677 )
1678}
1679
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use super::{Password, WebSocketAuth};

    /// Checks that each supported shape of the WebSocket authentication JSON
    /// deserializes into the expected `WebSocketAuth` variant. These shapes are
    /// basic auth with and without an empty `options` object, and bearer auth
    /// with and without options.
    #[mz_ore::test]
    fn smoke_test_websocket_auth_parse() {
        // Two of the basic-auth inputs must produce the same value.
        let basic_mz = || WebSocketAuth::Basic {
            user: "mz".to_string(),
            password: Password("1234".to_string()),
            options: BTreeMap::default(),
        };

        let cases: Vec<(&'static str, WebSocketAuth)> = vec![
            (r#"{ "user": "mz", "password": "1234" }"#, basic_mz()),
            (
                r#"{ "user": "mz", "password": "1234", "options": {} }"#,
                basic_mz(),
            ),
            (
                r#"{ "token": "i_am_a_token" }"#,
                WebSocketAuth::Bearer {
                    token: "i_am_a_token".to_string(),
                    options: BTreeMap::default(),
                },
            ),
            (
                r#"{ "token": "i_am_a_token", "options": { "foo": "bar" } }"#,
                WebSocketAuth::Bearer {
                    token: "i_am_a_token".to_string(),
                    options: BTreeMap::from([("foo".to_string(), "bar".to_string())]),
                },
            ),
        ];

        for (json, expected) in cases {
            let parsed: WebSocketAuth = serde_json::from_str(json).unwrap();
            assert_eq!(parsed, expected);
        }
    }
}