1use std::borrow::Cow;
11use std::collections::BTreeMap;
12use std::net::{IpAddr, SocketAddr};
13use std::pin::pin;
14use std::sync::Arc;
15use std::time::{Duration, SystemTime};
16
17use anyhow::anyhow;
18use async_trait::async_trait;
19use axum::extract::connect_info::ConnectInfo;
20use axum::extract::ws::{CloseFrame, Message, WebSocket};
21use axum::extract::{State, WebSocketUpgrade};
22use axum::response::IntoResponse;
23use axum::{Extension, Json};
24use futures::Future;
25use futures::future::BoxFuture;
26
27use http::StatusCode;
28use itertools::Itertools;
29use mz_adapter::client::RecordFirstRowStream;
30use mz_adapter::session::{EndTransactionAction, TransactionStatus};
31use mz_adapter::statement_logging::{StatementEndedExecutionReason, StatementExecutionStrategy};
32use mz_adapter::{
33 AdapterError, AdapterNotice, ExecuteContextExtra, ExecuteResponse, ExecuteResponseKind,
34 PeekResponseUnary, SessionClient, verify_datum_desc,
35};
36use mz_auth::password::Password;
37use mz_catalog::memory::objects::{Cluster, ClusterReplica};
38use mz_interchange::encode::TypedDatum;
39use mz_interchange::json::{JsonNumberPolicy, ToJson};
40use mz_ore::cast::CastFrom;
41use mz_ore::metrics::{MakeCollectorOpts, MetricsRegistry};
42use mz_ore::result::ResultExt;
43use mz_repr::{Datum, RelationDesc, RowArena, RowIterator};
44use mz_sql::ast::display::AstDisplay;
45use mz_sql::ast::{CopyDirection, CopyStatement, CopyTarget, Raw, Statement, StatementKind};
46use mz_sql::parse::StatementParseResult;
47use mz_sql::plan::Plan;
48use mz_sql::session::metadata::SessionMetadata;
49use prometheus::Opts;
50use prometheus::core::{AtomicF64, GenericGaugeVec};
51use serde::{Deserialize, Serialize};
52use tokio::{select, time};
53use tokio_postgres::error::SqlState;
54use tokio_stream::wrappers::UnboundedReceiverStream;
55use tower_sessions::Session as TowerSession;
56use tracing::{debug, error, info};
57use tungstenite::protocol::frame::coding::CloseCode;
58
59use crate::http::prometheus::PrometheusSqlQuery;
60use crate::http::{
61 AuthError, AuthedClient, AuthedUser, MAX_REQUEST_SIZE, SESSION_DURATION, TowerSessionData,
62 WsState, init_ws,
63};
64
/// Errors produced while handling SQL requests over HTTP or WebSocket.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// An error raised by the adapter; it supplies its own SQLSTATE, detail,
    /// hint and position (see the accessors on `impl Error`).
    #[error(transparent)]
    Adapter(#[from] AdapterError),
    /// A request body failed to deserialize as JSON.
    #[error(transparent)]
    Json(#[from] serde_json::Error),
    /// A transport error from axum, e.g. a failed WebSocket send.
    #[error(transparent)]
    Axum(#[from] axum::Error),
    /// `SUBSCRIBE` was issued through the buffered (non-WebSocket) endpoint.
    #[error("SUBSCRIBE only supported over websocket")]
    SubscribeOnlyOverWs,
    /// A statement other than a transaction-exit statement was issued while
    /// the session's transaction is in the failed state.
    #[error("current transaction is aborted, commands ignored until end of transaction block")]
    AbortedTransaction,
    /// The statement cannot be executed through this API; holds its SQL text.
    #[error("unsupported via this API: {0}")]
    Unsupported(String),
    /// Any other error, displayed as-is.
    #[error("{0}")]
    Unstructured(anyhow::Error),
}
82
83impl Error {
84 pub fn detail(&self) -> Option<String> {
85 match self {
86 Error::Adapter(err) => err.detail(),
87 _ => None,
88 }
89 }
90
91 pub fn hint(&self) -> Option<String> {
92 match self {
93 Error::Adapter(err) => err.hint(),
94 _ => None,
95 }
96 }
97
98 pub fn position(&self) -> Option<usize> {
99 match self {
100 Error::Adapter(err) => err.position(),
101 _ => None,
102 }
103 }
104
105 pub fn code(&self) -> SqlState {
106 match self {
107 Error::Adapter(err) => err.code(),
108 Error::AbortedTransaction => SqlState::IN_FAILED_SQL_TRANSACTION,
109 _ => SqlState::INTERNAL_ERROR,
110 }
111 }
112}
113
/// Extra label names appended to every per-replica Prometheus metric.
///
/// The order must match the label values `execute_promsql_query` appends for
/// each row: `<cluster>.<replica>` name, cluster id, replica id.
static PER_REPLICA_LABELS: &[&str] = &["replica_full_name", "instance_id", "replica_id"];
115
116async fn execute_promsql_query(
117 client: &mut AuthedClient,
118 query: &PrometheusSqlQuery<'_>,
119 metrics_registry: &MetricsRegistry,
120 metrics_by_name: &mut BTreeMap<String, GenericGaugeVec<AtomicF64>>,
121 cluster: Option<(&Cluster, &ClusterReplica)>,
122) {
123 assert_eq!(query.per_replica, cluster.is_some());
124
125 let mut res = SqlResponse {
126 results: Vec::new(),
127 };
128
129 execute_request(client, query.to_sql_request(cluster), &mut res)
130 .await
131 .expect("valid SQL query");
132
133 let result = match res.results.as_slice() {
134 [
137 SqlResult::Ok { .. },
138 SqlResult::Ok { .. },
139 SqlResult::Ok { .. },
140 result,
141 ] => result,
142 _ => {
146 info!(
147 "error executing prometheus query {}: {:?}",
148 query.metric_name, res
149 );
150 return;
151 }
152 };
153
154 let SqlResult::Rows { desc, rows, .. } = result else {
155 info!(
156 "did not receive rows for SQL query for prometheus metric {}: {:?}, {:?}",
157 query.metric_name, result, cluster
158 );
159 return;
160 };
161
162 let gauge_vec = metrics_by_name
163 .entry(query.metric_name.to_string())
164 .or_insert_with(|| {
165 let mut label_names: Vec<String> = desc
166 .columns
167 .iter()
168 .filter(|col| col.name != query.value_column_name)
169 .map(|col| col.name.clone())
170 .collect();
171
172 if query.per_replica {
173 label_names.extend(PER_REPLICA_LABELS.iter().map(|label| label.to_string()));
174 }
175
176 metrics_registry.register::<GenericGaugeVec<AtomicF64>>(MakeCollectorOpts {
177 opts: Opts::new(query.metric_name, query.help).variable_labels(label_names),
178 buckets: None,
179 })
180 });
181
182 for row in rows {
183 let mut label_values = desc
184 .columns
185 .iter()
186 .zip_eq(row)
187 .filter(|(col, _)| col.name != query.value_column_name)
188 .map(|(_, val)| val.as_str().expect("must be string"))
189 .collect::<Vec<_>>();
190
191 let value = desc
192 .columns
193 .iter()
194 .zip_eq(row)
195 .find(|(col, _)| col.name == query.value_column_name)
196 .map(|(_, val)| val.as_str().unwrap_or("0").parse::<f64>().unwrap_or(0.0))
197 .unwrap_or(0.0);
198
199 match cluster {
200 Some((cluster, replica)) => {
201 let replica_full_name = format!("{}.{}", cluster.name, replica.name);
202 let cluster_id = cluster.id.to_string();
203 let replica_id = replica.replica_id.to_string();
204
205 label_values.push(&replica_full_name);
206 label_values.push(&cluster_id);
207 label_values.push(&replica_id);
208
209 gauge_vec
210 .get_metric_with_label_values(&label_values)
211 .expect("valid labels")
212 .set(value);
213 }
214 None => {
215 gauge_vec
216 .get_metric_with_label_values(&label_values)
217 .expect("valid labels")
218 .set(value);
219 }
220 }
221 }
222}
223
224async fn handle_promsql_query(
225 client: &mut AuthedClient,
226 query: &PrometheusSqlQuery<'_>,
227 metrics_registry: &MetricsRegistry,
228 metrics_by_name: &mut BTreeMap<String, GenericGaugeVec<AtomicF64>>,
229) {
230 if !query.per_replica {
231 execute_promsql_query(client, query, metrics_registry, metrics_by_name, None).await;
232 return;
233 }
234
235 let catalog = client.client.catalog_snapshot("handle_promsql_query").await;
236 let clusters: Vec<&Cluster> = catalog.clusters().collect();
237
238 for cluster in clusters {
239 for replica in cluster.replicas() {
240 execute_promsql_query(
241 client,
242 query,
243 metrics_registry,
244 metrics_by_name,
245 Some((cluster, replica)),
246 )
247 .await;
248 }
249 }
250}
251
252pub async fn handle_promsql(
253 mut client: AuthedClient,
254 queries: &[PrometheusSqlQuery<'_>],
255) -> MetricsRegistry {
256 let metrics_registry = MetricsRegistry::new();
257 let mut metrics_by_name = BTreeMap::new();
258
259 for query in queries {
260 handle_promsql_query(&mut client, query, &metrics_registry, &mut metrics_by_name).await;
261 }
262
263 metrics_registry
264}
265
266pub async fn handle_sql(
267 mut client: AuthedClient,
268 Json(request): Json<SqlRequest>,
269) -> impl IntoResponse {
270 let mut res = SqlResponse {
271 results: Vec::new(),
272 };
273 match execute_request(&mut client, request, &mut res).await {
276 Ok(()) => Ok(Json(res)),
277 Err(e) => Err((StatusCode::BAD_REQUEST, e.to_string())),
278 }
279}
280
281pub async fn handle_sql_ws(
282 State(state): State<WsState>,
283 existing_user: Option<Extension<AuthedUser>>,
284 ws: WebSocketUpgrade,
285 ConnectInfo(addr): ConnectInfo<SocketAddr>,
286 tower_session: Option<Extension<TowerSession>>,
287) -> impl IntoResponse {
288 let user = match (existing_user, tower_session) {
290 (Some(Extension(user)), _) => Some(user),
291 (None, Some(session)) => {
292 if let Ok(Some(session_data)) = session.get::<TowerSessionData>("data").await {
293 if session_data
295 .last_activity
296 .elapsed()
297 .unwrap_or(Duration::MAX)
298 > SESSION_DURATION
299 {
300 let _ = session.delete().await;
301 return Err(AuthError::SessionExpired);
302 }
303 let mut updated_data = session_data.clone();
305 updated_data.last_activity = SystemTime::now();
306 session
307 .insert("data", &updated_data)
308 .await
309 .map_err(|_| AuthError::FailedToUpdateSession)?;
310 Some(AuthedUser {
312 name: session_data.username,
313 external_metadata_rx: None,
314 internal_metadata: Some(session_data.internal_metadata),
315 })
316 } else {
317 None
318 }
319 }
320 _ => None,
321 };
322
323 let addr = Box::new(addr.ip());
324 Ok(ws
325 .max_message_size(MAX_REQUEST_SIZE)
326 .on_upgrade(|ws| async move { run_ws(&state, user, *addr, ws).await }))
327}
328
/// Credentials/options object for a WebSocket session (presumably the first
/// message consumed by `init_ws` — verify against that function).
///
/// `#[serde(untagged)]`: serde tries the variants in declaration order, so
/// `Basic` (user + password) is attempted before `Bearer` (token). Because
/// `options` defaults to empty and unknown fields are ignored, `OptionsOnly`
/// matches any JSON object that neither of the others does. Do not reorder.
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
#[serde(untagged)]
pub enum WebSocketAuth {
    Basic {
        user: String,
        password: Password,
        #[serde(default)]
        options: BTreeMap<String, String>,
    },
    Bearer {
        token: String,
        #[serde(default)]
        options: BTreeMap<String, String>,
    },
    OptionsOnly {
        #[serde(default)]
        options: BTreeMap<String, String>,
    },
}
348
/// Drives one SQL-over-WebSocket session until the client disconnects or a
/// response can no longer be sent.
///
/// If `init_ws` fails, a `Close` frame with `CloseCode::Protocol` is sent. Its
/// reason is the adapter error text when the failure is an `AdapterError`, and
/// `"unauthorized"` otherwise. On success, the client first receives a
/// `ParameterStatus` for every variable in the session's notify set,
/// `BackendKeyData`, `ReadyForQuery`, and any pending notices. Each following
/// text or binary frame is parsed as a [`SqlRequest`] and executed. Any
/// request error is then sent, followed by pending notices and a fresh
/// `ReadyForQuery`.
async fn run_ws(state: &WsState, user: Option<AuthedUser>, peer_addr: IpAddr, mut ws: WebSocket) {
    let mut client = match init_ws(state, user, peer_addr, &mut ws).await {
        Ok(client) => client,
        Err(e) => {
            debug!("WS request failed init: {}", e);
            let reason = match e.downcast_ref::<AdapterError>() {
                Some(error) => Cow::Owned(error.to_string()),
                None => "unauthorized".into(),
            };
            let _ = ws
                .send(Message::Close(Some(CloseFrame {
                    code: CloseCode::Protocol.into(),
                    reason,
                })))
                .await;
            return;
        }
    };

    // Startup handshake; send errors here are ignored.
    let mut msgs = Vec::new();
    let session = client.client.session();
    for var in session.vars().notify_set() {
        msgs.push(WebSocketResponse::ParameterStatus(ParameterStatus {
            name: var.name().to_string(),
            value: var.value(),
        }));
    }
    msgs.push(WebSocketResponse::BackendKeyData(BackendKeyData {
        conn_id: session.conn_id().unhandled(),
        secret_key: session.secret_key(),
    }));
    msgs.push(WebSocketResponse::ReadyForQuery(
        session.transaction_code().into(),
    ));
    for msg in msgs {
        let _ = ws
            .send(Message::Text(
                serde_json::to_string(&msg).expect("must serialize"),
            ))
            .await;
    }

    let notices = session.drain_notices();
    if let Err(err) = forward_notices(&mut ws, notices).await {
        debug!("failed to forward notices to WebSocket, {err:?}");
        return;
    }

    loop {
        let msg = select! {
            // `biased`: a pending session timeout is handled before the next
            // client frame is read.
            biased;

            Some(timeout) = client.client.recv_timeout() => {
                client.client.terminate().await;
                // NOTE(review): one more client frame is read and discarded
                // before the timeout error is sent; the intent isn't evident
                // from this code — confirm before changing.
                let _ = ws.recv().await;
                let err = Error::from(AdapterError::from(timeout));
                let _ = send_ws_response(&mut ws, WebSocketResponse::Error(err.into())).await;
                return;
            },
            message = ws.recv() => message,
        };

        // Runs for every received frame, and also when the socket closed or
        // errored, before the frame is inspected.
        client.client.remove_idle_in_transaction_session_timeout();

        let msg = match msg {
            Some(Ok(msg)) => msg,
            // The socket closed or failed to receive: end the session.
            _ => {
                return;
            }
        };

        let req: Result<SqlRequest, Error> = match msg {
            Message::Text(data) => serde_json::from_str(&data).err_into(),
            Message::Binary(data) => serde_json::from_slice(&data).err_into(),
            // Control frames carry no request.
            Message::Ping(_) => {
                continue;
            }
            Message::Pong(_) => {
                continue;
            }
            Message::Close(_) => {
                return;
            }
        };

        let err = match run_ws_request(req, &mut client, &mut ws).await {
            Ok(()) => None,
            Err(err) => Some(WebSocketResponse::Error(err.into())),
        };

        // Trailer for every request: the request-level error (if any), drained
        // notices, then `ReadyForQuery` with the current transaction status.
        let ws_response = || async {
            if let Some(e_resp) = err {
                send_ws_response(&mut ws, e_resp).await?;
            }

            let notices = client.client.session().drain_notices();
            forward_notices(&mut ws, notices).await?;

            let ready =
                WebSocketResponse::ReadyForQuery(client.client.session().transaction_code().into());
            send_ws_response(&mut ws, ready).await?;

            Ok::<_, Error>(())
        };

        if let Err(err) = ws_response().await {
            debug!("failed to send response over WebSocket, {err:?}");
            return;
        }
    }
}
482
483async fn run_ws_request(
484 req: Result<SqlRequest, Error>,
485 client: &mut AuthedClient,
486 ws: &mut WebSocket,
487) -> Result<(), Error> {
488 let req = req?;
489 execute_request(client, req, ws).await
490}
491
492async fn send_ws_response(ws: &mut WebSocket, resp: WebSocketResponse) -> Result<(), Error> {
494 let msg = serde_json::to_string(&resp).unwrap();
495 let msg = Message::Text(msg);
496 ws.send(msg).await?;
497
498 Ok(())
499}
500
501async fn forward_notices(
503 ws: &mut WebSocket,
504 notices: impl IntoIterator<Item = AdapterNotice>,
505) -> Result<(), Error> {
506 let ws_notices = notices.into_iter().map(|notice| {
507 WebSocketResponse::Notice(Notice {
508 message: notice.to_string(),
509 code: notice.code().code().to_string(),
510 severity: notice.severity().as_str().to_lowercase(),
511 detail: notice.detail(),
512 hint: notice.hint(),
513 })
514 });
515
516 for notice in ws_notices {
517 send_ws_response(ws, notice).await?;
518 }
519
520 Ok(())
521}
522
/// A SQL request received over HTTP or WebSocket.
///
/// `#[serde(untagged)]`: `{"query": "..."}` deserializes as `Simple` and
/// `{"queries": [...]}` as `Extended`.
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum SqlRequest {
    /// A single query string that may contain several statements. All of its
    /// statements form one group, and parameters are not supported.
    Simple {
        query: String,
    },
    /// A list of queries. Each must contain exactly one statement and may
    /// carry its own parameters.
    Extended {
        queries: Vec<ExtendedRequest>,
    },
}
539
/// One query of an `SqlRequest::Extended` request.
#[derive(Serialize, Deserialize, Debug)]
pub struct ExtendedRequest {
    /// SQL text; must parse to exactly one statement.
    query: String,
    /// Parameter values as text; defaults to empty when omitted. `None`
    /// presumably binds SQL NULL — verify in `execute_stmt`.
    #[serde(default)]
    params: Vec<Option<String>>,
}
549
/// Buffered response for the HTTP SQL endpoint: one [`SqlResult`] per
/// statement result delivered, in execution order.
#[derive(Debug, Serialize, Deserialize)]
pub struct SqlResponse {
    results: Vec<SqlResult>,
}
556
/// The outcome of one statement, before it is handed to a [`ResultSender`].
enum StatementResult {
    /// A fully materialized result.
    SqlResult(SqlResult),
    /// A live `SUBSCRIBE`. Only a streaming sender can deliver it; the
    /// buffered HTTP sender reports it as an error instead.
    Subscribe {
        /// Shape of the rows produced by `rx`.
        desc: RelationDesc,
        /// Command tag sent once the stream ends.
        tag: String,
        /// Stream of row batches.
        rx: RecordFirstRowStream,
        /// Statement-logging context; must be retired via `retire_execute`.
        ctx_extra: ExecuteContextExtra,
    },
}
566
567impl From<SqlResult> for StatementResult {
568 fn from(inner: SqlResult) -> Self {
569 Self::SqlResult(inner)
570 }
571}
572
573#[derive(Debug, Serialize, Deserialize)]
575#[serde(untagged)]
576pub enum SqlResult {
577 Rows {
579 tag: String,
581 rows: Vec<Vec<serde_json::Value>>,
583 desc: Description,
585 notices: Vec<Notice>,
587 },
588 Ok {
590 ok: String,
592 notices: Vec<Notice>,
594 #[serde(skip_serializing_if = "Vec::is_empty")]
598 parameters: Vec<ParameterStatus>,
599 },
600 Err {
602 error: SqlError,
603 notices: Vec<Notice>,
605 },
606}
607
608impl SqlResult {
609 async fn rows<S>(
613 sender: &mut S,
614 client: &mut SessionClient,
615 mut rows_stream: RecordFirstRowStream,
616 max_query_result_size: usize,
617 desc: &RelationDesc,
618 ) -> Result<SqlResult, Error>
619 where
620 S: ResultSender,
621 {
622 let mut rows: Vec<Vec<serde_json::Value>> = vec![];
623 let mut datum_vec = mz_repr::DatumVec::new();
624 let types = &desc.typ().column_types;
625
626 let mut query_result_size = 0;
627
628 loop {
629 let peek_response = tokio::select! {
630 notice = client.session().recv_notice(), if S::SUPPORTS_STREAMING_NOTICES => {
631 sender.emit_streaming_notices(vec![notice]).await?;
632 continue;
633 }
634 e = sender.connection_error() => return Err(e),
635 r = rows_stream.recv() => {
636 match r {
637 Some(r) => r,
638 None => break,
639 }
640 },
641 };
642
643 let mut sql_rows = match peek_response {
644 PeekResponseUnary::Rows(rows) => rows,
645 PeekResponseUnary::Error(e) => {
646 return Ok(SqlResult::err(client, Error::Unstructured(anyhow!(e))));
647 }
648 PeekResponseUnary::Canceled => {
649 return Ok(SqlResult::err(client, AdapterError::Canceled));
650 }
651 };
652
653 if let Err(err) = verify_datum_desc(desc, &mut sql_rows) {
654 return Ok(SqlResult::Err {
655 error: err.into(),
656 notices: make_notices(client),
657 });
658 }
659
660 while let Some(row) = sql_rows.next() {
661 query_result_size += row.byte_len();
662 if query_result_size > max_query_result_size {
663 use bytesize::ByteSize;
664 return Ok(SqlResult::err(
665 client,
666 AdapterError::ResultSize(format!(
667 "result exceeds max size of {}",
668 ByteSize::b(u64::cast_from(max_query_result_size))
669 )),
670 ));
671 }
672
673 let datums = datum_vec.borrow_with(row);
674 rows.push(
675 datums
676 .iter()
677 .enumerate()
678 .map(|(i, d)| {
679 TypedDatum::new(*d, &types[i])
680 .json(&JsonNumberPolicy::ConvertNumberToString)
681 })
682 .collect(),
683 );
684 }
685 }
686
687 let tag = format!("SELECT {}", rows.len());
688 Ok(SqlResult::Rows {
689 tag,
690 rows,
691 desc: Description::from(desc),
692 notices: make_notices(client),
693 })
694 }
695
696 fn err(client: &mut SessionClient, error: impl Into<SqlError>) -> SqlResult {
697 SqlResult::Err {
698 error: error.into(),
699 notices: make_notices(client),
700 }
701 }
702
703 fn ok(client: &mut SessionClient, tag: String, params: Vec<ParameterStatus>) -> SqlResult {
704 SqlResult::Ok {
705 ok: tag,
706 parameters: params,
707 notices: make_notices(client),
708 }
709 }
710}
711
/// Client-facing description of an error, modeled on PostgreSQL error fields.
#[derive(Debug, Deserialize, Serialize)]
pub struct SqlError {
    /// Human-readable message (the error's `Display` text).
    pub message: String,
    /// SQLSTATE code, e.g. `XX000`.
    pub code: String,
    /// Optional detail; omitted from the JSON when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
    /// Optional hint; omitted from the JSON when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub hint: Option<String>,
    /// Optional position within the query; omitted from the JSON when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub position: Option<usize>,
}
723
724impl From<Error> for SqlError {
725 fn from(err: Error) -> Self {
726 SqlError {
727 message: err.to_string(),
728 code: err.code().code().to_string(),
729 detail: err.detail(),
730 hint: err.hint(),
731 position: err.position(),
732 }
733 }
734}
735
736impl From<AdapterError> for SqlError {
737 fn from(value: AdapterError) -> Self {
738 Error::from(value).into()
739 }
740}
741
/// A message sent to a WebSocket client.
///
/// Adjacently tagged: serialized as `{"type": "<Variant>", "payload": ...}`.
#[derive(Debug, Deserialize, Serialize)]
#[serde(tag = "type", content = "payload")]
pub enum WebSocketResponse {
    /// The session is ready for the next request. The payload is derived from
    /// the session's `transaction_code()`.
    ReadyForQuery(String),
    /// A notice raised by the session.
    Notice(Notice),
    /// Column descriptions preceding a sequence of `Row` messages.
    Rows(Description),
    /// One row, with one JSON value per column.
    Row(Vec<serde_json::Value>),
    /// Announces the start of one statement's results.
    CommandStarting(CommandStarting),
    /// A statement finished; the payload is its command tag.
    CommandComplete(String),
    /// A statement or request failed.
    Error(SqlError),
    /// A session parameter's current value.
    ParameterStatus(ParameterStatus),
    /// Connection id and secret key, sent once at session start.
    BackendKeyData(BackendKeyData),
}
755
/// A session notice as reported to clients.
#[derive(Debug, Serialize, Deserialize)]
pub struct Notice {
    /// Notice text.
    message: String,
    /// SQLSTATE code.
    code: String,
    /// Lowercased severity name (see `forward_notices`).
    severity: String,
    /// Optional detail; omitted from the JSON when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
    /// Optional hint; omitted from the JSON when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub hint: Option<String>,
}
766
767impl Notice {
768 pub fn message(&self) -> &str {
769 &self.message
770 }
771}
772
/// Column metadata for a set of result rows.
#[derive(Debug, Serialize, Deserialize)]
pub struct Description {
    /// One entry per column, in column order.
    pub columns: Vec<Column>,
}
777
778impl From<&RelationDesc> for Description {
779 fn from(desc: &RelationDesc) -> Self {
780 let columns = desc
781 .iter()
782 .map(|(name, typ)| {
783 let pg_type = mz_pgrepr::Type::from(&typ.scalar_type);
784 Column {
785 name: name.to_string(),
786 type_oid: pg_type.oid(),
787 type_len: pg_type.typlen(),
788 type_mod: pg_type.typmod(),
789 }
790 })
791 .collect();
792 Description { columns }
793 }
794}
795
/// Metadata for one result column.
#[derive(Debug, Serialize, Deserialize)]
pub struct Column {
    /// Column name.
    pub name: String,
    /// PostgreSQL type OID.
    pub type_oid: u32,
    /// PostgreSQL `typlen` of the column type.
    pub type_len: i16,
    /// PostgreSQL type modifier.
    pub type_mod: i32,
}
803
/// The name and current value of a session variable.
#[derive(Debug, Serialize, Deserialize)]
pub struct ParameterStatus {
    name: String,
    value: String,
}
809
/// Connection identity sent to a WebSocket client at session start.
#[derive(Debug, Serialize, Deserialize)]
pub struct BackendKeyData {
    /// The session's connection id.
    conn_id: u32,
    /// The session's secret key.
    secret_key: u32,
}
815
/// Announces how a statement's results will be delivered.
#[derive(Debug, Serialize, Deserialize)]
pub struct CommandStarting {
    /// Whether `Rows`/`Row` messages follow.
    has_rows: bool,
    /// Whether rows are streamed until the stream ends (`SUBSCRIBE`).
    is_streaming: bool,
}
821
/// A destination for statement results: the buffered HTTP [`SqlResponse`] or a
/// streaming [`WebSocket`].
#[async_trait]
trait ResultSender: Send {
    /// Whether session notices may be forwarded while a statement is still
    /// running. This gates the notice branch of the `select!` loops in
    /// `SqlResult::rows` and `await_rows`.
    const SUPPORTS_STREAMING_NOTICES: bool = false;

    /// Delivers one statement result.
    ///
    /// The first element is `Err` if delivery itself failed. Otherwise it is
    /// `Ok(Err(()))` when the statement errored and `Ok(Ok(()))` when it
    /// succeeded. The second element, when present, is the statement-logging
    /// context that `send_and_retire` passes to `retire_execute`.
    async fn add_result(
        &mut self,
        client: &mut SessionClient,
        res: StatementResult,
    ) -> (
        Result<Result<(), ()>, Error>,
        Option<(StatementEndedExecutionReason, ExecuteContextExtra)>,
    );

    /// Returns a future that resolves only once the underlying connection has
    /// failed. It is used to abort waits for results.
    fn connection_error(&mut self) -> BoxFuture<'_, Error>;
    /// Whether `SUBSCRIBE` statements may be executed through this sender.
    fn allow_subscribe(&self) -> bool;

    /// Forwards notices while a statement is running. This is only called when
    /// `SUPPORTS_STREAMING_NOTICES` is `true`, so implementors that set that
    /// flag must override it.
    async fn emit_streaming_notices(&mut self, _: Vec<AdapterNotice>) -> Result<(), Error> {
        unreachable!("streaming notices marked as unsupported")
    }
}
857
858#[async_trait]
859impl ResultSender for SqlResponse {
860 async fn add_result(
868 &mut self,
869 _client: &mut SessionClient,
870 res: StatementResult,
871 ) -> (
872 Result<Result<(), ()>, Error>,
873 Option<(StatementEndedExecutionReason, ExecuteContextExtra)>,
874 ) {
875 let (res, stmt_logging) = match res {
876 StatementResult::SqlResult(res) => {
877 let is_err = matches!(res, SqlResult::Err { .. });
878 self.results.push(res);
879 let res = if is_err { Err(()) } else { Ok(()) };
880 (res, None)
881 }
882 StatementResult::Subscribe { ctx_extra, .. } => {
883 let message = "SUBSCRIBE only supported over websocket";
884 self.results.push(SqlResult::Err {
885 error: Error::SubscribeOnlyOverWs.into(),
886 notices: Vec::new(),
887 });
888 (
889 Err(()),
890 Some((
891 StatementEndedExecutionReason::Errored {
892 error: message.into(),
893 },
894 ctx_extra,
895 )),
896 )
897 }
898 };
899 (Ok(res), stmt_logging)
900 }
901
902 fn connection_error(&mut self) -> BoxFuture<'_, Error> {
903 Box::pin(futures::future::pending())
904 }
905
906 fn allow_subscribe(&self) -> bool {
907 false
908 }
909}
910
911#[async_trait]
912impl ResultSender for WebSocket {
913 const SUPPORTS_STREAMING_NOTICES: bool = true;
914
915 async fn add_result(
921 &mut self,
922 client: &mut SessionClient,
923 res: StatementResult,
924 ) -> (
925 Result<Result<(), ()>, Error>,
926 Option<(StatementEndedExecutionReason, ExecuteContextExtra)>,
927 ) {
928 let (has_rows, is_streaming) = match res {
929 StatementResult::SqlResult(SqlResult::Err { .. }) => (false, false),
930 StatementResult::SqlResult(SqlResult::Ok { .. }) => (false, false),
931 StatementResult::SqlResult(SqlResult::Rows { .. }) => (true, false),
932 StatementResult::Subscribe { .. } => (true, true),
933 };
934 if let Err(e) = send_ws_response(
935 self,
936 WebSocketResponse::CommandStarting(CommandStarting {
937 has_rows,
938 is_streaming,
939 }),
940 )
941 .await
942 {
943 return (Err(e), None);
944 }
945
946 let (is_err, msgs, stmt_logging) = match res {
947 StatementResult::SqlResult(SqlResult::Rows {
948 tag,
949 rows,
950 desc,
951 notices,
952 }) => {
953 let mut msgs = vec![WebSocketResponse::Rows(desc)];
954 msgs.extend(rows.into_iter().map(WebSocketResponse::Row));
955 msgs.push(WebSocketResponse::CommandComplete(tag));
956 msgs.extend(notices.into_iter().map(WebSocketResponse::Notice));
957 (false, msgs, None)
958 }
959 StatementResult::SqlResult(SqlResult::Ok {
960 ok,
961 parameters,
962 notices,
963 }) => {
964 let mut msgs = vec![WebSocketResponse::CommandComplete(ok)];
965 msgs.extend(notices.into_iter().map(WebSocketResponse::Notice));
966 msgs.extend(
967 parameters
968 .into_iter()
969 .map(WebSocketResponse::ParameterStatus),
970 );
971 (false, msgs, None)
972 }
973 StatementResult::SqlResult(SqlResult::Err { error, notices }) => {
974 let mut msgs = vec![WebSocketResponse::Error(error)];
975 msgs.extend(notices.into_iter().map(WebSocketResponse::Notice));
976 (true, msgs, None)
977 }
978 StatementResult::Subscribe {
979 ref desc,
980 tag,
981 mut rx,
982 ctx_extra,
983 } => {
984 if let Err(e) = send_ws_response(self, WebSocketResponse::Rows(desc.into())).await {
985 return (
988 Err(e),
989 Some((StatementEndedExecutionReason::Canceled, ctx_extra)),
990 );
991 }
992
993 let mut datum_vec = mz_repr::DatumVec::new();
994 let mut result_size: usize = 0;
995 let mut rows_returned = 0;
996 loop {
997 let res = match await_rows(self, client, rx.recv()).await {
998 Ok(res) => res,
999 Err(e) => {
1000 return (
1003 Err(e),
1004 Some((StatementEndedExecutionReason::Canceled, ctx_extra)),
1005 );
1006 }
1007 };
1008 match res {
1009 Some(PeekResponseUnary::Rows(mut rows)) => {
1010 if let Err(err) = verify_datum_desc(desc, &mut rows) {
1011 let error = err.to_string();
1012 break (
1013 true,
1014 vec![WebSocketResponse::Error(err.into())],
1015 Some((
1016 StatementEndedExecutionReason::Errored { error },
1017 ctx_extra,
1018 )),
1019 );
1020 }
1021
1022 rows_returned += rows.count();
1023 while let Some(row) = rows.next() {
1024 result_size += row.byte_len();
1025 let datums = datum_vec.borrow_with(row);
1026 let types = &desc.typ().column_types;
1027 if let Err(e) = send_ws_response(
1028 self,
1029 WebSocketResponse::Row(
1030 datums
1031 .iter()
1032 .enumerate()
1033 .map(|(i, d)| {
1034 TypedDatum::new(*d, &types[i])
1035 .json(&JsonNumberPolicy::ConvertNumberToString)
1036 })
1037 .collect(),
1038 ),
1039 )
1040 .await
1041 {
1042 return (
1045 Err(e),
1046 Some((StatementEndedExecutionReason::Canceled, ctx_extra)),
1047 );
1048 }
1049 }
1050 }
1051 Some(PeekResponseUnary::Error(error)) => {
1052 break (
1053 true,
1054 vec![WebSocketResponse::Error(
1055 Error::Unstructured(anyhow!(error.clone())).into(),
1056 )],
1057 Some((StatementEndedExecutionReason::Errored { error }, ctx_extra)),
1058 );
1059 }
1060 Some(PeekResponseUnary::Canceled) => {
1061 break (
1062 true,
1063 vec![WebSocketResponse::Error(AdapterError::Canceled.into())],
1064 Some((StatementEndedExecutionReason::Canceled, ctx_extra)),
1065 );
1066 }
1067 None => {
1068 break (
1069 false,
1070 vec![WebSocketResponse::CommandComplete(tag)],
1071 Some((
1072 StatementEndedExecutionReason::Success {
1073 result_size: Some(u64::cast_from(result_size)),
1074 rows_returned: Some(u64::cast_from(rows_returned)),
1075 execution_strategy: Some(
1076 StatementExecutionStrategy::Standard,
1077 ),
1078 },
1079 ctx_extra,
1080 )),
1081 );
1082 }
1083 }
1084 }
1085 }
1086 };
1087 for msg in msgs {
1088 if let Err(e) = send_ws_response(self, msg).await {
1089 return (
1090 Err(e),
1091 stmt_logging.map(|(_old_reason, ctx_extra)| {
1092 (StatementEndedExecutionReason::Canceled, ctx_extra)
1093 }),
1094 );
1095 }
1096 }
1097 (Ok(if is_err { Err(()) } else { Ok(()) }), stmt_logging)
1098 }
1099
1100 fn connection_error(&mut self) -> BoxFuture<'_, Error> {
1103 Box::pin(async {
1104 let mut tick = time::interval(Duration::from_secs(1));
1105 tick.tick().await;
1106 loop {
1107 tick.tick().await;
1108 if let Err(err) = self.send(Message::Ping(Vec::new())).await {
1109 return err.into();
1110 }
1111 }
1112 })
1113 }
1114
1115 fn allow_subscribe(&self) -> bool {
1116 true
1117 }
1118
1119 async fn emit_streaming_notices(&mut self, notices: Vec<AdapterNotice>) -> Result<(), Error> {
1120 forward_notices(self, notices).await
1121 }
1122}
1123
1124async fn await_rows<S, F, R>(sender: &mut S, client: &mut SessionClient, f: F) -> Result<R, Error>
1125where
1126 S: ResultSender,
1127 F: Future<Output = R> + Send,
1128{
1129 let mut f = pin!(f);
1130 loop {
1131 tokio::select! {
1132 notice = client.session().recv_notice(), if S::SUPPORTS_STREAMING_NOTICES => {
1133 sender.emit_streaming_notices(vec![notice]).await?;
1134 }
1135 e = sender.connection_error() => return Err(e),
1136 r = &mut f => return Ok(r),
1137 }
1138 }
1139}
1140
1141async fn send_and_retire<S: ResultSender>(
1142 res: StatementResult,
1143 client: &mut SessionClient,
1144 sender: &mut S,
1145) -> Result<Result<(), ()>, Error> {
1146 let (res, stmt_logging) = sender.add_result(client, res).await;
1147 if let Some((reason, ctx_extra)) = stmt_logging {
1148 client.retire_execute(ctx_extra, reason);
1149 }
1150 res
1151}
1152
/// Executes one group of statements in order, reporting each result to `sender`.
///
/// Calls `start_transaction(Some(num_stmts))` before each statement. Returns
/// `Ok(Err(()))` as soon as a statement fails. A failure is any of: the
/// session's transaction is already failed and the statement is not a
/// transaction-exit statement, starting the transaction fails, or the
/// statement itself errors. Returns `Ok(Ok(()))` if every statement succeeds,
/// and `Err` if a result could not be delivered or `execute_stmt` failed.
///
/// # Panics
///
/// Panics if a group with more than one statement carries parameters.
async fn execute_stmt_group<S: ResultSender>(
    client: &mut SessionClient,
    sender: &mut S,
    stmt_group: Vec<(Statement<Raw>, String, Vec<Option<String>>)>,
) -> Result<Result<(), ()>, Error> {
    let num_stmts = stmt_group.len();
    for (stmt, sql, params) in stmt_group {
        assert!(
            num_stmts <= 1 || params.is_empty(),
            "statement groups contain more than 1 statement iff Simple request, which does not support parameters"
        );

        // Inside a failed transaction, only statements that end it may run.
        let is_aborted_txn = matches!(client.session().transaction(), TransactionStatus::Failed(_));
        if is_aborted_txn && !is_txn_exit_stmt(&stmt) {
            let err = SqlResult::err(client, Error::AbortedTransaction);
            let _ = send_and_retire(err.into(), client, sender).await?;
            return Ok(Err(()));
        }

        if let Err(e) = client.start_transaction(Some(num_stmts)) {
            let err = SqlResult::err(client, e);
            let _ = send_and_retire(err.into(), client, sender).await?;
            return Ok(Err(()));
        }
        let res = execute_stmt(client, sender, stmt, sql, params).await?;
        let is_err = send_and_retire(res, client, sender).await?;

        if is_err.is_err() {
            let txn = client.session().transaction();
            match txn {
                // No open transaction, or it is already marked failed.
                TransactionStatus::Default | TransactionStatus::Failed(_) => {}
                // Implicit or just-started transaction: roll it back.
                TransactionStatus::Started(_) | TransactionStatus::InTransactionImplicit(_) => {
                    if let Err(err) = client.end_transaction(EndTransactionAction::Rollback).await {
                        let err = SqlResult::err(client, err);
                        let _ = send_and_retire(err.into(), client, sender).await?;
                    }
                }
                // Explicit transaction: mark it failed so later statements are
                // rejected with `AbortedTransaction` until it is ended.
                TransactionStatus::InTransaction(_) => {
                    client.fail_transaction();
                }
            }
            return Ok(Err(()));
        }
    }
    Ok(Ok(()))
}
1208
1209async fn execute_request<S: ResultSender>(
1214 client: &mut AuthedClient,
1215 request: SqlRequest,
1216 sender: &mut S,
1217) -> Result<(), Error> {
1218 let client = &mut client.client;
1219
1220 fn check_prohibited_stmts<S: ResultSender>(
1223 sender: &S,
1224 stmt: &Statement<Raw>,
1225 ) -> Result<(), Error> {
1226 let kind: StatementKind = stmt.into();
1227 let execute_responses = Plan::generated_from(&kind)
1228 .into_iter()
1229 .map(ExecuteResponse::generated_from)
1230 .flatten()
1231 .collect::<Vec<_>>();
1232
1233 let is_valid_copy = matches!(
1237 stmt,
1238 Statement::Copy(CopyStatement {
1239 direction: CopyDirection::To,
1240 target: CopyTarget::Expr(_),
1241 ..
1242 }) | Statement::Copy(CopyStatement {
1243 direction: CopyDirection::From,
1244 target: CopyTarget::Expr(_),
1245 ..
1246 })
1247 );
1248
1249 if !is_valid_copy
1250 && execute_responses.iter().any(|execute_response| {
1251 match execute_response {
1253 ExecuteResponseKind::Subscribing if sender.allow_subscribe() => false,
1254 ExecuteResponseKind::Fetch
1255 | ExecuteResponseKind::Subscribing
1256 | ExecuteResponseKind::CopyFrom
1257 | ExecuteResponseKind::DeclaredCursor
1258 | ExecuteResponseKind::ClosedCursor => true,
1259 ExecuteResponseKind::CopyTo if matches!(kind, StatementKind::Copy) => true,
1264 _ => false,
1265 }
1266 })
1267 {
1268 return Err(Error::Unsupported(stmt.to_ast_string_simple()));
1269 }
1270 Ok(())
1271 }
1272
1273 fn parse<'a>(
1274 client: &SessionClient,
1275 query: &'a str,
1276 ) -> Result<Vec<StatementParseResult<'a>>, Error> {
1277 let result = client
1278 .parse(query)
1279 .map_err(|e| Error::Unstructured(anyhow!(e)))?;
1280 result.map_err(|e| AdapterError::from(e).into())
1281 }
1282
1283 let mut stmt_groups = vec![];
1284
1285 match request {
1286 SqlRequest::Simple { query } => match parse(client, &query) {
1287 Ok(stmts) => {
1288 let mut stmt_group = Vec::with_capacity(stmts.len());
1289 let mut stmt_err = None;
1290 for StatementParseResult { ast: stmt, sql } in stmts {
1291 if let Err(err) = check_prohibited_stmts(sender, &stmt) {
1292 stmt_err = Some(err);
1293 break;
1294 }
1295 stmt_group.push((stmt, sql.to_string(), vec![]));
1296 }
1297 stmt_groups.push(stmt_err.map(Err).unwrap_or_else(|| Ok(stmt_group)));
1298 }
1299 Err(e) => stmt_groups.push(Err(e)),
1300 },
1301 SqlRequest::Extended { queries } => {
1302 for ExtendedRequest { query, params } in queries {
1303 match parse(client, &query) {
1304 Ok(mut stmts) => {
1305 if stmts.len() != 1 {
1306 return Err(Error::Unstructured(anyhow!(
1307 "each query must contain exactly 1 statement, but \"{}\" contains {}",
1308 query,
1309 stmts.len()
1310 )));
1311 }
1312
1313 let StatementParseResult { ast: stmt, sql } = stmts.pop().unwrap();
1314 stmt_groups.push(
1315 check_prohibited_stmts(sender, &stmt)
1316 .map(|_| vec![(stmt, sql.to_string(), params)]),
1317 );
1318 }
1319 Err(e) => stmt_groups.push(Err(e)),
1320 };
1321 }
1322 }
1323 }
1324
1325 for stmt_group_res in stmt_groups {
1326 let executed = match stmt_group_res {
1327 Ok(stmt_group) => execute_stmt_group(client, sender, stmt_group).await,
1328 Err(e) => {
1329 let err = SqlResult::err(client, e);
1330 let _ = send_and_retire(err.into(), client, sender).await?;
1331 Ok(Err(()))
1332 }
1333 };
1334 if client.session().transaction().is_implicit() {
1337 let ended = client.end_transaction(EndTransactionAction::Commit).await;
1338 if let Err(err) = ended {
1339 let err = SqlResult::err(client, err);
1340 let _ = send_and_retire(StatementResult::SqlResult(err), client, sender).await?;
1341 }
1342 }
1343 if executed?.is_err() {
1344 break;
1345 }
1346 }
1347
1348 Ok(())
1349}
1350
1351async fn execute_stmt<S: ResultSender>(
1353 client: &mut SessionClient,
1354 sender: &mut S,
1355 stmt: Statement<Raw>,
1356 sql: String,
1357 raw_params: Vec<Option<String>>,
1358) -> Result<StatementResult, Error> {
1359 const EMPTY_PORTAL: &str = "";
1360 if let Err(e) = client
1361 .prepare(EMPTY_PORTAL.into(), Some(stmt.clone()), sql, vec![])
1362 .await
1363 {
1364 return Ok(SqlResult::err(client, e).into());
1365 }
1366
1367 let prep_stmt = match client.get_prepared_statement(EMPTY_PORTAL).await {
1368 Ok(stmt) => stmt,
1369 Err(err) => {
1370 return Ok(SqlResult::err(client, err).into());
1371 }
1372 };
1373
1374 let param_types = &prep_stmt.desc().param_types;
1375 if param_types.len() != raw_params.len() {
1376 let message = anyhow!(
1377 "request supplied {actual} parameters, \
1378 but {statement} requires {expected}",
1379 statement = stmt.to_ast_string_simple(),
1380 actual = raw_params.len(),
1381 expected = param_types.len()
1382 );
1383 return Ok(SqlResult::err(client, Error::Unstructured(message)).into());
1384 }
1385
1386 let buf = RowArena::new();
1387 let mut params = vec![];
1388 for (raw_param, mz_typ) in raw_params.into_iter().zip_eq(param_types) {
1389 let pg_typ = mz_pgrepr::Type::from(mz_typ);
1390 let datum = match raw_param {
1391 None => Datum::Null,
1392 Some(raw_param) => {
1393 match mz_pgrepr::Value::decode(
1394 mz_pgwire_common::Format::Text,
1395 &pg_typ,
1396 raw_param.as_bytes(),
1397 ) {
1398 Ok(param) => param.into_datum(&buf, &pg_typ),
1399 Err(err) => {
1400 let msg = anyhow!("unable to decode parameter: {}", err);
1401 return Ok(SqlResult::err(client, Error::Unstructured(msg)).into());
1402 }
1403 }
1404 }
1405 };
1406 params.push((datum, mz_typ.clone()))
1407 }
1408
1409 let result_formats = vec![
1410 mz_pgwire_common::Format::Text;
1411 prep_stmt
1412 .desc()
1413 .relation_desc
1414 .clone()
1415 .map(|desc| desc.typ().column_types.len())
1416 .unwrap_or(0)
1417 ];
1418
1419 let desc = prep_stmt.desc().clone();
1420 let logging = Arc::clone(prep_stmt.logging());
1421 let stmt_ast = prep_stmt.stmt().cloned();
1422 let state_revision = prep_stmt.state_revision;
1423 if let Err(err) = client.session().set_portal(
1424 EMPTY_PORTAL.into(),
1425 desc,
1426 stmt_ast,
1427 logging,
1428 params,
1429 result_formats,
1430 state_revision,
1431 ) {
1432 return Ok(SqlResult::err(client, err).into());
1433 }
1434
1435 let desc = client
1436 .session()
1437 .get_portal_unverified(EMPTY_PORTAL)
1439 .map(|portal| portal.desc.clone())
1440 .expect("unnamed portal should be present");
1441
1442 let res = client
1443 .execute(EMPTY_PORTAL.into(), futures::future::pending(), None)
1444 .await;
1445
1446 if S::SUPPORTS_STREAMING_NOTICES {
1447 sender
1448 .emit_streaming_notices(client.session().drain_notices())
1449 .await?;
1450 }
1451
1452 let (res, execute_started) = match res {
1453 Ok(res) => res,
1454 Err(e) => {
1455 return Ok(SqlResult::err(client, e).into());
1456 }
1457 };
1458 let tag = res.tag();
1459
1460 Ok(match res {
1461 ExecuteResponse::CreatedConnection { .. }
1462 | ExecuteResponse::CreatedDatabase { .. }
1463 | ExecuteResponse::CreatedSchema { .. }
1464 | ExecuteResponse::CreatedRole
1465 | ExecuteResponse::CreatedCluster { .. }
1466 | ExecuteResponse::CreatedClusterReplica { .. }
1467 | ExecuteResponse::CreatedTable { .. }
1468 | ExecuteResponse::CreatedIndex { .. }
1469 | ExecuteResponse::CreatedIntrospectionSubscribe
1470 | ExecuteResponse::CreatedSecret { .. }
1471 | ExecuteResponse::CreatedSource { .. }
1472 | ExecuteResponse::CreatedSink { .. }
1473 | ExecuteResponse::CreatedView { .. }
1474 | ExecuteResponse::CreatedViews { .. }
1475 | ExecuteResponse::CreatedMaterializedView { .. }
1476 | ExecuteResponse::CreatedContinualTask { .. }
1477 | ExecuteResponse::CreatedType
1478 | ExecuteResponse::CreatedNetworkPolicy
1479 | ExecuteResponse::Comment
1480 | ExecuteResponse::Deleted(_)
1481 | ExecuteResponse::DiscardedTemp
1482 | ExecuteResponse::DiscardedAll
1483 | ExecuteResponse::DroppedObject(_)
1484 | ExecuteResponse::DroppedOwned
1485 | ExecuteResponse::EmptyQuery
1486 | ExecuteResponse::GrantedPrivilege
1487 | ExecuteResponse::GrantedRole
1488 | ExecuteResponse::Inserted(_)
1489 | ExecuteResponse::Copied(_)
1490 | ExecuteResponse::Raised
1491 | ExecuteResponse::ReassignOwned
1492 | ExecuteResponse::RevokedPrivilege
1493 | ExecuteResponse::AlteredDefaultPrivileges
1494 | ExecuteResponse::RevokedRole
1495 | ExecuteResponse::StartedTransaction { .. }
1496 | ExecuteResponse::Updated(_)
1497 | ExecuteResponse::AlteredObject(_)
1498 | ExecuteResponse::AlteredRole
1499 | ExecuteResponse::AlteredSystemConfiguration
1500 | ExecuteResponse::Deallocate { .. }
1501 | ExecuteResponse::ValidatedConnection
1502 | ExecuteResponse::Prepare => SqlResult::ok(
1503 client,
1504 tag.expect("ok only called on tag-generating results"),
1505 Vec::default(),
1506 )
1507 .into(),
1508 ExecuteResponse::TransactionCommitted { params }
1509 | ExecuteResponse::TransactionRolledBack { params } => {
1510 let notify_set: mz_ore::collections::HashSet<_> = client
1511 .session()
1512 .vars()
1513 .notify_set()
1514 .map(|v| v.name().to_string())
1515 .collect();
1516 let params = params
1517 .into_iter()
1518 .filter(|(name, _value)| notify_set.contains(*name))
1519 .map(|(name, value)| ParameterStatus {
1520 name: name.to_string(),
1521 value,
1522 })
1523 .collect();
1524 SqlResult::ok(
1525 client,
1526 tag.expect("ok only called on tag-generating results"),
1527 params,
1528 )
1529 .into()
1530 }
1531 ExecuteResponse::SetVariable { name, .. } => {
1532 let mut params = Vec::with_capacity(1);
1533 if let Some(var) = client
1534 .session()
1535 .vars()
1536 .notify_set()
1537 .find(|v| v.name() == &name)
1538 {
1539 params.push(ParameterStatus {
1540 name,
1541 value: var.value(),
1542 });
1543 };
1544 SqlResult::ok(
1545 client,
1546 tag.expect("ok only called on tag-generating results"),
1547 params,
1548 )
1549 .into()
1550 }
1551 ExecuteResponse::SendingRowsStreaming {
1552 rows,
1553 instance_id,
1554 strategy,
1555 } => {
1556 let max_query_result_size =
1557 usize::cast_from(client.get_system_vars().await.max_result_size());
1558
1559 let rows_stream = RecordFirstRowStream::new(
1560 Box::new(rows),
1561 execute_started,
1562 client,
1563 Some(instance_id),
1564 Some(strategy),
1565 );
1566
1567 SqlResult::rows(
1568 sender,
1569 client,
1570 rows_stream,
1571 max_query_result_size,
1572 &desc.relation_desc.expect("RelationDesc must exist"),
1573 )
1574 .await?
1575 .into()
1576 }
1577 ExecuteResponse::SendingRowsImmediate { rows } => {
1578 let max_query_result_size =
1579 usize::cast_from(client.get_system_vars().await.max_result_size());
1580
1581 let rows = futures::stream::once(futures::future::ready(PeekResponseUnary::Rows(rows)));
1582 let rows_stream =
1583 RecordFirstRowStream::new(Box::new(rows), execute_started, client, None, None);
1584
1585 SqlResult::rows(
1586 sender,
1587 client,
1588 rows_stream,
1589 max_query_result_size,
1590 &desc.relation_desc.expect("RelationDesc must exist"),
1591 )
1592 .await?
1593 .into()
1594 }
1595 ExecuteResponse::Subscribing {
1596 rx,
1597 ctx_extra,
1598 instance_id,
1599 } => StatementResult::Subscribe {
1600 tag: "SUBSCRIBE".into(),
1601 desc: desc.relation_desc.unwrap(),
1602 rx: RecordFirstRowStream::new(
1603 Box::new(UnboundedReceiverStream::new(rx)),
1604 execute_started,
1605 client,
1606 Some(instance_id),
1607 None,
1608 ),
1609 ctx_extra,
1610 },
1611 res @ (ExecuteResponse::Fetch { .. }
1612 | ExecuteResponse::CopyTo { .. }
1613 | ExecuteResponse::CopyFrom { .. }
1614 | ExecuteResponse::DeclaredCursor
1615 | ExecuteResponse::ClosedCursor) => SqlResult::err(
1616 client,
1617 Error::Unstructured(anyhow!(
1618 "internal error: encountered prohibited ExecuteResponse {:?}.\n\n
1619 This is a bug. Can you please file an bug report letting us know?\n
1620 https://github.com/MaterializeInc/materialize/discussions/new?category=bug-reports",
1621 ExecuteResponseKind::from(res)
1622 )),
1623 )
1624 .into(),
1625 })
1626}
1627
1628fn make_notices(client: &mut SessionClient) -> Vec<Notice> {
1629 client
1630 .session()
1631 .drain_notices()
1632 .into_iter()
1633 .map(|notice| Notice {
1634 message: notice.to_string(),
1635 code: notice.code().code().to_string(),
1636 severity: notice.severity().as_str().to_lowercase(),
1637 detail: notice.detail(),
1638 hint: notice.hint(),
1639 })
1640 .collect()
1641}
1642
1643fn is_txn_exit_stmt(stmt: &Statement<Raw>) -> bool {
1646 matches!(
1647 stmt,
1648 Statement::Commit(_) | Statement::Rollback(_) | Statement::Prepare(_)
1649 )
1650}
1651
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use super::{Password, WebSocketAuth};

    /// Deserializes representative websocket authentication payloads and
    /// checks that each produces the expected [`WebSocketAuth`] value.
    #[mz_ore::test]
    fn smoke_test_websocket_auth_parse() {
        // Builders for the two credential shapes exercised below.
        let basic = || WebSocketAuth::Basic {
            user: "mz".to_string(),
            password: Password("1234".to_string()),
            options: BTreeMap::default(),
        };
        let bearer = |options: BTreeMap<String, String>| WebSocketAuth::Bearer {
            token: "i_am_a_token".to_string(),
            options,
        };

        let cases = [
            (r#"{ "user": "mz", "password": "1234" }"#, basic()),
            (
                r#"{ "user": "mz", "password": "1234", "options": {} }"#,
                basic(),
            ),
            (r#"{ "token": "i_am_a_token" }"#, bearer(BTreeMap::default())),
            (
                r#"{ "token": "i_am_a_token", "options": { "foo": "bar" } }"#,
                bearer(BTreeMap::from([("foo".to_string(), "bar".to_string())])),
            ),
        ];

        for (payload, want) in cases {
            let got: WebSocketAuth = serde_json::from_str(payload).unwrap();
            assert_eq!(got, want);
        }
    }
}