1use std::borrow::Cow;
11use std::collections::BTreeMap;
12use std::net::{IpAddr, SocketAddr};
13use std::pin::pin;
14use std::sync::Arc;
15use std::time::{Duration, SystemTime};
16
17use anyhow::anyhow;
18use async_trait::async_trait;
19use axum::extract::connect_info::ConnectInfo;
20use axum::extract::ws::{CloseFrame, Message, WebSocket};
21use axum::extract::{State, WebSocketUpgrade};
22use axum::response::IntoResponse;
23use axum::{Extension, Json};
24use futures::Future;
25use futures::future::BoxFuture;
26
27use http::StatusCode;
28use itertools::izip;
29use mz_adapter::client::RecordFirstRowStream;
30use mz_adapter::session::{EndTransactionAction, TransactionStatus};
31use mz_adapter::statement_logging::{StatementEndedExecutionReason, StatementExecutionStrategy};
32use mz_adapter::{
33 AdapterError, AdapterNotice, ExecuteContextExtra, ExecuteResponse, ExecuteResponseKind,
34 PeekResponseUnary, SessionClient, verify_datum_desc,
35};
36use mz_auth::password::Password;
37use mz_catalog::memory::objects::{Cluster, ClusterReplica};
38use mz_interchange::encode::TypedDatum;
39use mz_interchange::json::{JsonNumberPolicy, ToJson};
40use mz_ore::cast::CastFrom;
41use mz_ore::metrics::{MakeCollectorOpts, MetricsRegistry};
42use mz_ore::result::ResultExt;
43use mz_repr::{Datum, RelationDesc, RowArena, RowIterator};
44use mz_sql::ast::display::AstDisplay;
45use mz_sql::ast::{CopyDirection, CopyStatement, CopyTarget, Raw, Statement, StatementKind};
46use mz_sql::parse::StatementParseResult;
47use mz_sql::plan::Plan;
48use mz_sql::session::metadata::SessionMetadata;
49use prometheus::Opts;
50use prometheus::core::{AtomicF64, GenericGaugeVec};
51use serde::{Deserialize, Serialize};
52use tokio::{select, time};
53use tokio_postgres::error::SqlState;
54use tokio_stream::wrappers::UnboundedReceiverStream;
55use tower_sessions::Session as TowerSession;
56use tracing::{debug, error, info};
57use tungstenite::protocol::frame::coding::CloseCode;
58
59use crate::http::prometheus::PrometheusSqlQuery;
60use crate::http::{
61 AuthError, AuthedClient, AuthedUser, MAX_REQUEST_SIZE, SESSION_DURATION, TowerSessionData,
62 WsState, init_ws,
63};
64
/// Errors produced while serving SQL over the HTTP and WebSocket endpoints.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// An error reported by the adapter; it supplies its own SQLSTATE, detail, hint and position.
    #[error(transparent)]
    Adapter(#[from] AdapterError),
    /// A request or response could not be (de)serialized as JSON.
    #[error(transparent)]
    Json(#[from] serde_json::Error),
    /// An error from the axum transport, e.g. a failed WebSocket send.
    #[error(transparent)]
    Axum(#[from] axum::Error),
    /// `SUBSCRIBE` was issued through a sender that cannot stream (the plain HTTP endpoint).
    #[error("SUBSCRIBE only supported over websocket")]
    SubscribeOnlyOverWs,
    /// A non-transaction-exit statement was issued while the transaction is in a failed state.
    #[error("current transaction is aborted, commands ignored until end of transaction block")]
    AbortedTransaction,
    /// The statement cannot be executed through this API; carries the statement's SQL text.
    #[error("unsupported via this API: {0}")]
    Unsupported(String),
    /// Any other error, reported with its display text.
    #[error("{0}")]
    Unstructured(anyhow::Error),
}
82
83impl Error {
84 pub fn detail(&self) -> Option<String> {
85 match self {
86 Error::Adapter(err) => err.detail(),
87 _ => None,
88 }
89 }
90
91 pub fn hint(&self) -> Option<String> {
92 match self {
93 Error::Adapter(err) => err.hint(),
94 _ => None,
95 }
96 }
97
98 pub fn position(&self) -> Option<usize> {
99 match self {
100 Error::Adapter(err) => err.position(),
101 _ => None,
102 }
103 }
104
105 pub fn code(&self) -> SqlState {
106 match self {
107 Error::Adapter(err) => err.code(),
108 Error::AbortedTransaction => SqlState::IN_FAILED_SQL_TRANSACTION,
109 _ => SqlState::INTERNAL_ERROR,
110 }
111 }
112}
113
/// Label names appended to every per-replica Prometheus metric.
///
/// The order must match the order in which `execute_promsql_query` pushes the values:
/// `<cluster>.<replica>` name, cluster ID (exported under the name `instance_id`), replica ID.
static PER_REPLICA_LABELS: &[&str] = &["replica_full_name", "instance_id", "replica_id"];
115
116async fn execute_promsql_query(
117 client: &mut AuthedClient,
118 query: &PrometheusSqlQuery<'_>,
119 metrics_registry: &MetricsRegistry,
120 metrics_by_name: &mut BTreeMap<String, GenericGaugeVec<AtomicF64>>,
121 cluster: Option<(&Cluster, &ClusterReplica)>,
122) {
123 assert_eq!(query.per_replica, cluster.is_some());
124
125 let mut res = SqlResponse {
126 results: Vec::new(),
127 };
128
129 execute_request(client, query.to_sql_request(cluster), &mut res)
130 .await
131 .expect("valid SQL query");
132
133 let result = match res.results.as_slice() {
134 [
137 SqlResult::Ok { .. },
138 SqlResult::Ok { .. },
139 SqlResult::Ok { .. },
140 result,
141 ] => result,
142 _ => {
146 info!(
147 "error executing prometheus query {}: {:?}",
148 query.metric_name, res
149 );
150 return;
151 }
152 };
153
154 let SqlResult::Rows { desc, rows, .. } = result else {
155 info!(
156 "did not receive rows for SQL query for prometheus metric {}: {:?}, {:?}",
157 query.metric_name, result, cluster
158 );
159 return;
160 };
161
162 let gauge_vec = metrics_by_name
163 .entry(query.metric_name.to_string())
164 .or_insert_with(|| {
165 let mut label_names: Vec<String> = desc
166 .columns
167 .iter()
168 .filter(|col| col.name != query.value_column_name)
169 .map(|col| col.name.clone())
170 .collect();
171
172 if query.per_replica {
173 label_names.extend(PER_REPLICA_LABELS.iter().map(|label| label.to_string()));
174 }
175
176 metrics_registry.register::<GenericGaugeVec<AtomicF64>>(MakeCollectorOpts {
177 opts: Opts::new(query.metric_name, query.help).variable_labels(label_names),
178 buckets: None,
179 })
180 });
181
182 for row in rows {
183 let mut label_values = desc
184 .columns
185 .iter()
186 .zip(row)
187 .filter(|(col, _)| col.name != query.value_column_name)
188 .map(|(_, val)| val.as_str().expect("must be string"))
189 .collect::<Vec<_>>();
190
191 let value = desc
192 .columns
193 .iter()
194 .zip(row)
195 .find(|(col, _)| col.name == query.value_column_name)
196 .map(|(_, val)| val.as_str().unwrap_or("0").parse::<f64>().unwrap_or(0.0))
197 .unwrap_or(0.0);
198
199 match cluster {
200 Some((cluster, replica)) => {
201 let replica_full_name = format!("{}.{}", cluster.name, replica.name);
202 let cluster_id = cluster.id.to_string();
203 let replica_id = replica.replica_id.to_string();
204
205 label_values.push(&replica_full_name);
206 label_values.push(&cluster_id);
207 label_values.push(&replica_id);
208
209 gauge_vec
210 .get_metric_with_label_values(&label_values)
211 .expect("valid labels")
212 .set(value);
213 }
214 None => {
215 gauge_vec
216 .get_metric_with_label_values(&label_values)
217 .expect("valid labels")
218 .set(value);
219 }
220 }
221 }
222}
223
224async fn handle_promsql_query(
225 client: &mut AuthedClient,
226 query: &PrometheusSqlQuery<'_>,
227 metrics_registry: &MetricsRegistry,
228 metrics_by_name: &mut BTreeMap<String, GenericGaugeVec<AtomicF64>>,
229) {
230 if !query.per_replica {
231 execute_promsql_query(client, query, metrics_registry, metrics_by_name, None).await;
232 return;
233 }
234
235 let catalog = client.client.catalog_snapshot("handle_promsql_query").await;
236 let clusters: Vec<&Cluster> = catalog.clusters().collect();
237
238 for cluster in clusters {
239 for replica in cluster.replicas() {
240 execute_promsql_query(
241 client,
242 query,
243 metrics_registry,
244 metrics_by_name,
245 Some((cluster, replica)),
246 )
247 .await;
248 }
249 }
250}
251
252pub async fn handle_promsql(
253 mut client: AuthedClient,
254 queries: &[PrometheusSqlQuery<'_>],
255) -> MetricsRegistry {
256 let metrics_registry = MetricsRegistry::new();
257 let mut metrics_by_name = BTreeMap::new();
258
259 for query in queries {
260 handle_promsql_query(&mut client, query, &metrics_registry, &mut metrics_by_name).await;
261 }
262
263 metrics_registry
264}
265
266pub async fn handle_sql(
267 mut client: AuthedClient,
268 Json(request): Json<SqlRequest>,
269) -> impl IntoResponse {
270 let mut res = SqlResponse {
271 results: Vec::new(),
272 };
273 match execute_request(&mut client, request, &mut res).await {
276 Ok(()) => Ok(Json(res)),
277 Err(e) => Err((StatusCode::BAD_REQUEST, e.to_string())),
278 }
279}
280
281pub async fn handle_sql_ws(
282 State(state): State<WsState>,
283 existing_user: Option<Extension<AuthedUser>>,
284 ws: WebSocketUpgrade,
285 ConnectInfo(addr): ConnectInfo<SocketAddr>,
286 tower_session: Option<Extension<TowerSession>>,
287) -> impl IntoResponse {
288 let user = match (existing_user, tower_session) {
290 (Some(Extension(user)), _) => Some(user),
291 (None, Some(session)) => {
292 if let Ok(Some(session_data)) = session.get::<TowerSessionData>("data").await {
293 if session_data
295 .last_activity
296 .elapsed()
297 .unwrap_or(Duration::MAX)
298 > SESSION_DURATION
299 {
300 let _ = session.delete().await;
301 return Err(AuthError::SessionExpired);
302 }
303 let mut updated_data = session_data.clone();
305 updated_data.last_activity = SystemTime::now();
306 session
307 .insert("data", &updated_data)
308 .await
309 .map_err(|_| AuthError::FailedToUpdateSession)?;
310 Some(AuthedUser {
312 name: session_data.username,
313 external_metadata_rx: None,
314 })
315 } else {
316 None
317 }
318 }
319 _ => None,
320 };
321
322 let addr = Box::new(addr.ip());
323 Ok(ws
324 .max_message_size(MAX_REQUEST_SIZE)
325 .on_upgrade(|ws| async move { run_ws(&state, user, *addr, ws).await }))
326}
327
/// Credentials and session options a WebSocket client can supply.
///
/// The enum is `#[serde(untagged)]`, so serde tries the variants in declaration order.
/// Reordering them changes which variant a payload deserializes to:
/// - an object with `user` and `password` becomes `Basic`;
/// - one with `token` becomes `Bearer`;
/// - any other object falls through to `OptionsOnly`.
///
/// `options` defaults to an empty map in every variant.
/// NOTE(review): consumed outside this chunk (presumably by `init_ws`); `options` is presumably
/// a set of session variables — confirm.
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
#[serde(untagged)]
pub enum WebSocketAuth {
    Basic {
        user: String,
        password: Password,
        #[serde(default)]
        options: BTreeMap<String, String>,
    },
    Bearer {
        token: String,
        #[serde(default)]
        options: BTreeMap<String, String>,
    },
    OptionsOnly {
        #[serde(default)]
        options: BTreeMap<String, String>,
    },
}
347
/// Drives one SQL-over-WebSocket connection from authentication until it closes.
///
/// 1. Authenticates with `init_ws`. On failure, sends a `Close` frame with code `Protocol`. The
///    close reason is the error text when the error is an `AdapterError`, and `"unauthorized"`
///    otherwise.
/// 2. Sends the startup messages, then forwards any pending notices:
///    - a `ParameterStatus` for each variable in the session's notify set;
///    - `BackendKeyData`;
///    - `ReadyForQuery`.
/// 3. Loops over incoming frames:
///    - Text and binary frames are parsed as JSON `SqlRequest`s and executed. Each request is
///      followed by its request-level error (if any), the drained notices, and a fresh
///      `ReadyForQuery`.
///    - Ping and pong frames are skipped.
///    - A close frame, a receive error, or the end of the stream ends the connection.
///
/// If `recv_timeout` yields a timeout, the session is terminated and the timeout error is sent.
/// NOTE(review): the next client frame is awaited and discarded before the error is sent,
/// presumably so the error answers the client's next request — confirm.
async fn run_ws(state: &WsState, user: Option<AuthedUser>, peer_addr: IpAddr, mut ws: WebSocket) {
    let mut client = match init_ws(state, user, peer_addr, &mut ws).await {
        Ok(client) => client,
        Err(e) => {
            debug!("WS request failed init: {}", e);
            // Only adapter errors are surfaced verbatim; anything else is reported generically.
            let reason = match e.downcast_ref::<AdapterError>() {
                Some(error) => Cow::Owned(error.to_string()),
                None => "unauthorized".into(),
            };
            let _ = ws
                .send(Message::Close(Some(CloseFrame {
                    code: CloseCode::Protocol.into(),
                    reason,
                })))
                .await;
            return;
        }
    };

    let mut msgs = Vec::new();
    let session = client.client.session();
    for var in session.vars().notify_set() {
        msgs.push(WebSocketResponse::ParameterStatus(ParameterStatus {
            name: var.name().to_string(),
            value: var.value(),
        }));
    }
    msgs.push(WebSocketResponse::BackendKeyData(BackendKeyData {
        conn_id: session.conn_id().unhandled(),
        secret_key: session.secret_key(),
    }));
    msgs.push(WebSocketResponse::ReadyForQuery(
        session.transaction_code().into(),
    ));
    // Send failures of the startup messages are ignored here.
    for msg in msgs {
        let _ = ws
            .send(Message::Text(
                serde_json::to_string(&msg).expect("must serialize"),
            ))
            .await;
    }

    let notices = session.drain_notices();
    if let Err(err) = forward_notices(&mut ws, notices).await {
        debug!("failed to forward notices to WebSocket, {err:?}");
        return;
    }

    loop {
        let msg = select! {
            // `biased`: the timeout branch is always polled before reading the next frame.
            biased;

            Some(timeout) = client.client.recv_timeout() => {
                client.client.terminate().await;
                let _ = ws.recv().await;
                let err = Error::from(AdapterError::from(timeout));
                let _ = send_ws_response(&mut ws, WebSocketResponse::Error(err.into())).await;
                return;
            },
            message = ws.recv() => message,
        };

        // Any received frame counts as activity; clear the pending idle timeout.
        client.client.remove_idle_in_transaction_session_timeout();

        let msg = match msg {
            Some(Ok(msg)) => msg,
            _ => {
                // Receive error or end of stream.
                return;
            }
        };

        let req: Result<SqlRequest, Error> = match msg {
            Message::Text(data) => serde_json::from_str(&data).err_into(),
            Message::Binary(data) => serde_json::from_slice(&data).err_into(),
            Message::Ping(_) => {
                continue;
            }
            Message::Pong(_) => {
                continue;
            }
            Message::Close(_) => {
                return;
            }
        };

        let err = match run_ws_request(req, &mut client, &mut ws).await {
            Ok(()) => None,
            Err(err) => Some(WebSocketResponse::Error(err.into())),
        };

        // Post-request epilogue: request-level error, drained notices, then `ReadyForQuery`.
        let ws_response = || async {
            if let Some(e_resp) = err {
                send_ws_response(&mut ws, e_resp).await?;
            }

            let notices = client.client.session().drain_notices();
            forward_notices(&mut ws, notices).await?;

            let ready =
                WebSocketResponse::ReadyForQuery(client.client.session().transaction_code().into());
            send_ws_response(&mut ws, ready).await?;

            Ok::<_, Error>(())
        };

        if let Err(err) = ws_response().await {
            debug!("failed to send response over WebSocket, {err:?}");
            return;
        }
    }
}
481
482async fn run_ws_request(
483 req: Result<SqlRequest, Error>,
484 client: &mut AuthedClient,
485 ws: &mut WebSocket,
486) -> Result<(), Error> {
487 let req = req?;
488 execute_request(client, req, ws).await
489}
490
491async fn send_ws_response(ws: &mut WebSocket, resp: WebSocketResponse) -> Result<(), Error> {
493 let msg = serde_json::to_string(&resp).unwrap();
494 let msg = Message::Text(msg);
495 ws.send(msg).await?;
496
497 Ok(())
498}
499
500async fn forward_notices(
502 ws: &mut WebSocket,
503 notices: impl IntoIterator<Item = AdapterNotice>,
504) -> Result<(), Error> {
505 let ws_notices = notices.into_iter().map(|notice| {
506 WebSocketResponse::Notice(Notice {
507 message: notice.to_string(),
508 code: notice.code().code().to_string(),
509 severity: notice.severity().as_str().to_lowercase(),
510 detail: notice.detail(),
511 hint: notice.hint(),
512 })
513 });
514
515 for notice in ws_notices {
516 send_ws_response(ws, notice).await?;
517 }
518
519 Ok(())
520}
521
/// A SQL request, as accepted by the HTTP endpoint (`handle_sql`) and the WebSocket endpoint.
///
/// `#[serde(untagged)]`: an object with a `query` string is `Simple`, and one with a `queries`
/// array is `Extended`.
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum SqlRequest {
    /// A single SQL string that may contain several statements. All of them run as one
    /// statement group, without parameters.
    Simple {
        /// The SQL text.
        query: String,
    },
    /// A list of queries. Each must contain exactly one statement and may carry parameters.
    Extended {
        /// The queries, executed in order.
        queries: Vec<ExtendedRequest>,
    },
}
538
/// One query of an extended `SqlRequest`.
#[derive(Serialize, Deserialize, Debug)]
pub struct ExtendedRequest {
    /// SQL text; must contain exactly one statement.
    query: String,
    /// Parameter values, bound positionally. The count must match the statement's parameters.
    /// Defaults to no parameters.
    /// NOTE(review): `None` presumably binds SQL `NULL` — confirm in `execute_stmt`.
    #[serde(default)]
    params: Vec<Option<String>>,
}
548
/// Response body of the HTTP SQL endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub struct SqlResponse {
    /// One entry per reported outcome, in execution order. This includes errors not tied to a
    /// single statement, such as a failed implicit commit.
    results: Vec<SqlResult>,
}
555
/// The outcome of executing one statement, before it is handed to a `ResultSender`.
enum StatementResult {
    /// A fully materialized result.
    SqlResult(SqlResult),
    /// A `SUBSCRIBE` whose rows are streamed from `rx` by senders that allow it.
    Subscribe {
        /// Shape of the streamed rows.
        desc: RelationDesc,
        /// Command tag sent in `CommandComplete` once the stream ends.
        tag: String,
        /// Stream of row batches.
        rx: RecordFirstRowStream,
        /// Statement-logging context; must be retired when the subscribe ends.
        ctx_extra: ExecuteContextExtra,
    },
}
565
566impl From<SqlResult> for StatementResult {
567 fn from(inner: SqlResult) -> Self {
568 Self::SqlResult(inner)
569 }
570}
571
572#[derive(Debug, Serialize, Deserialize)]
574#[serde(untagged)]
575pub enum SqlResult {
576 Rows {
578 tag: String,
580 rows: Vec<Vec<serde_json::Value>>,
582 desc: Description,
584 notices: Vec<Notice>,
586 },
587 Ok {
589 ok: String,
591 notices: Vec<Notice>,
593 #[serde(skip_serializing_if = "Vec::is_empty")]
597 parameters: Vec<ParameterStatus>,
598 },
599 Err {
601 error: SqlError,
602 notices: Vec<Notice>,
604 },
605}
606
607impl SqlResult {
608 async fn rows<S>(
612 sender: &mut S,
613 client: &mut SessionClient,
614 mut rows_stream: RecordFirstRowStream,
615 max_query_result_size: usize,
616 desc: &RelationDesc,
617 ) -> Result<SqlResult, Error>
618 where
619 S: ResultSender,
620 {
621 let mut rows: Vec<Vec<serde_json::Value>> = vec![];
622 let mut datum_vec = mz_repr::DatumVec::new();
623 let types = &desc.typ().column_types;
624
625 let mut query_result_size = 0;
626
627 loop {
628 let peek_response = tokio::select! {
629 notice = client.session().recv_notice(), if S::SUPPORTS_STREAMING_NOTICES => {
630 sender.emit_streaming_notices(vec![notice]).await?;
631 continue;
632 }
633 e = sender.connection_error() => return Err(e),
634 r = rows_stream.recv() => {
635 match r {
636 Some(r) => r,
637 None => break,
638 }
639 },
640 };
641
642 let mut sql_rows = match peek_response {
643 PeekResponseUnary::Rows(rows) => rows,
644 PeekResponseUnary::Error(e) => {
645 return Ok(SqlResult::err(client, Error::Unstructured(anyhow!(e))));
646 }
647 PeekResponseUnary::Canceled => {
648 return Ok(SqlResult::err(client, AdapterError::Canceled));
649 }
650 };
651
652 if let Err(err) = verify_datum_desc(desc, &mut sql_rows) {
653 return Ok(SqlResult::Err {
654 error: err.into(),
655 notices: make_notices(client),
656 });
657 }
658
659 while let Some(row) = sql_rows.next() {
660 query_result_size += row.byte_len();
661 if query_result_size > max_query_result_size {
662 use bytesize::ByteSize;
663 return Ok(SqlResult::err(
664 client,
665 AdapterError::ResultSize(format!(
666 "result exceeds max size of {}",
667 ByteSize::b(u64::cast_from(max_query_result_size))
668 )),
669 ));
670 }
671
672 let datums = datum_vec.borrow_with(row);
673 rows.push(
674 datums
675 .iter()
676 .enumerate()
677 .map(|(i, d)| {
678 TypedDatum::new(*d, &types[i])
679 .json(&JsonNumberPolicy::ConvertNumberToString)
680 })
681 .collect(),
682 );
683 }
684 }
685
686 let tag = format!("SELECT {}", rows.len());
687 Ok(SqlResult::Rows {
688 tag,
689 rows,
690 desc: Description::from(desc),
691 notices: make_notices(client),
692 })
693 }
694
695 fn err(client: &mut SessionClient, error: impl Into<SqlError>) -> SqlResult {
696 SqlResult::Err {
697 error: error.into(),
698 notices: make_notices(client),
699 }
700 }
701
702 fn ok(client: &mut SessionClient, tag: String, params: Vec<ParameterStatus>) -> SqlResult {
703 SqlResult::Ok {
704 ok: tag,
705 parameters: params,
706 notices: make_notices(client),
707 }
708 }
709}
710
/// A serializable error reported to clients.
///
/// Built from an `Error`: its display text, SQLSTATE code, and optional detail, hint and
/// position. Absent optional fields are omitted from the JSON.
#[derive(Debug, Deserialize, Serialize)]
pub struct SqlError {
    /// Error message.
    pub message: String,
    /// SQLSTATE code.
    pub code: String,
    /// Additional detail, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
    /// Hint for resolving the error, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub hint: Option<String>,
    /// Position in the query text the error refers to, if known.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub position: Option<usize>,
}
722
723impl From<Error> for SqlError {
724 fn from(err: Error) -> Self {
725 SqlError {
726 message: err.to_string(),
727 code: err.code().code().to_string(),
728 detail: err.detail(),
729 hint: err.hint(),
730 position: err.position(),
731 }
732 }
733}
734
735impl From<AdapterError> for SqlError {
736 fn from(value: AdapterError) -> Self {
737 Error::from(value).into()
738 }
739}
740
/// A message sent to WebSocket clients.
///
/// Adjacently tagged: serialized as `{"type": "<Variant>", "payload": <value>}`.
#[derive(Debug, Deserialize, Serialize)]
#[serde(tag = "type", content = "payload")]
pub enum WebSocketResponse {
    /// Ready for the next request; carries the session's transaction status code.
    ReadyForQuery(String),
    /// A notice.
    Notice(Notice),
    /// Column description sent before the `Row` messages of a statement.
    Rows(Description),
    /// One row's values, rendered as JSON.
    Row(Vec<serde_json::Value>),
    /// Announces a statement's results and whether rows (and streaming) follow.
    CommandStarting(CommandStarting),
    /// A statement finished; carries its command tag.
    CommandComplete(String),
    /// A statement- or request-level error.
    Error(SqlError),
    /// A session variable's current value.
    ParameterStatus(ParameterStatus),
    /// Connection identifiers sent after authentication.
    BackendKeyData(BackendKeyData),
}
754
/// A notice forwarded to the client.
#[derive(Debug, Serialize, Deserialize)]
pub struct Notice {
    /// Notice text.
    message: String,
    /// SQLSTATE code.
    code: String,
    /// Severity name (lowercased when built by `forward_notices`).
    severity: String,
    /// Additional detail, if any; omitted from the JSON when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
    /// Hint, if any; omitted from the JSON when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub hint: Option<String>,
}
765
766impl Notice {
767 pub fn message(&self) -> &str {
768 &self.message
769 }
770}
771
/// Column metadata for a set of result rows.
#[derive(Debug, Serialize, Deserialize)]
pub struct Description {
    /// One entry per column, in result order.
    pub columns: Vec<Column>,
}
776
777impl From<&RelationDesc> for Description {
778 fn from(desc: &RelationDesc) -> Self {
779 let columns = desc
780 .iter()
781 .map(|(name, typ)| {
782 let pg_type = mz_pgrepr::Type::from(&typ.scalar_type);
783 Column {
784 name: name.to_string(),
785 type_oid: pg_type.oid(),
786 type_len: pg_type.typlen(),
787 type_mod: pg_type.typmod(),
788 }
789 })
790 .collect();
791 Description { columns }
792 }
793}
794
/// Metadata for one result column.
#[derive(Debug, Serialize, Deserialize)]
pub struct Column {
    /// Column name.
    pub name: String,
    /// OID of the column's Postgres type.
    pub type_oid: u32,
    /// Type length, as reported by `mz_pgrepr::Type::typlen`.
    pub type_len: i16,
    /// Type modifier, as reported by `mz_pgrepr::Type::typmod`.
    pub type_mod: i32,
}
802
/// The current value of a session variable, reported to the client.
#[derive(Debug, Serialize, Deserialize)]
pub struct ParameterStatus {
    /// Variable name.
    name: String,
    /// Variable value, as text.
    value: String,
}
808
/// Connection identifiers sent to a WebSocket client after authentication.
///
/// NOTE(review): presumably used for query cancellation, like the pgwire message of the same
/// name — confirm.
#[derive(Debug, Serialize, Deserialize)]
pub struct BackendKeyData {
    /// Connection ID (`conn_id().unhandled()`).
    conn_id: u32,
    /// The session's secret key.
    secret_key: u32,
}
814
/// Sent over WebSocket before each statement's results, describing what follows.
#[derive(Debug, Serialize, Deserialize)]
pub struct CommandStarting {
    /// Whether `Rows`/`Row` messages follow.
    has_rows: bool,
    /// Whether the rows are streamed as they arrive (`SUBSCRIBE`) rather than sent at once.
    is_streaming: bool,
}
820
/// A destination for statement results: an HTTP response being built (`SqlResponse`) or a live
/// WebSocket.
#[async_trait]
trait ResultSender: Send {
    /// Whether notices can be sent while a statement is still running. When `false`,
    /// `emit_streaming_notices` must not be called.
    const SUPPORTS_STREAMING_NOTICES: bool = false;

    /// Delivers one statement's result.
    ///
    /// Returns `(outcome, stmt_logging)`:
    /// - `outcome`'s outer `Result` is `Err` on a transport failure, and its inner `Result` is
    ///   `Err(())` when the statement itself failed;
    /// - when `stmt_logging` is `Some`, the caller retires the statement's execution with that
    ///   reason and context (see `send_and_retire`).
    async fn add_result(
        &mut self,
        client: &mut SessionClient,
        res: StatementResult,
    ) -> (
        Result<Result<(), ()>, Error>,
        Option<(StatementEndedExecutionReason, ExecuteContextExtra)>,
    );

    /// Returns a future that resolves with an error if the underlying connection breaks.
    fn connection_error(&mut self) -> BoxFuture<Error>;
    /// Whether `SUBSCRIBE` may be executed through this sender.
    fn allow_subscribe(&self) -> bool;

    /// Sends notices while a statement is still running.
    ///
    /// The default panics; only senders with `SUPPORTS_STREAMING_NOTICES` override it.
    async fn emit_streaming_notices(&mut self, _: Vec<AdapterNotice>) -> Result<(), Error> {
        unreachable!("streaming notices marked as unsupported")
    }
}
856
857#[async_trait]
858impl ResultSender for SqlResponse {
859 async fn add_result(
867 &mut self,
868 _client: &mut SessionClient,
869 res: StatementResult,
870 ) -> (
871 Result<Result<(), ()>, Error>,
872 Option<(StatementEndedExecutionReason, ExecuteContextExtra)>,
873 ) {
874 let (res, stmt_logging) = match res {
875 StatementResult::SqlResult(res) => {
876 let is_err = matches!(res, SqlResult::Err { .. });
877 self.results.push(res);
878 let res = if is_err { Err(()) } else { Ok(()) };
879 (res, None)
880 }
881 StatementResult::Subscribe { ctx_extra, .. } => {
882 let message = "SUBSCRIBE only supported over websocket";
883 self.results.push(SqlResult::Err {
884 error: Error::SubscribeOnlyOverWs.into(),
885 notices: Vec::new(),
886 });
887 (
888 Err(()),
889 Some((
890 StatementEndedExecutionReason::Errored {
891 error: message.into(),
892 },
893 ctx_extra,
894 )),
895 )
896 }
897 };
898 (Ok(res), stmt_logging)
899 }
900
901 fn connection_error(&mut self) -> BoxFuture<Error> {
902 Box::pin(futures::future::pending())
903 }
904
905 fn allow_subscribe(&self) -> bool {
906 false
907 }
908}
909
910#[async_trait]
911impl ResultSender for WebSocket {
912 const SUPPORTS_STREAMING_NOTICES: bool = true;
913
914 async fn add_result(
920 &mut self,
921 client: &mut SessionClient,
922 res: StatementResult,
923 ) -> (
924 Result<Result<(), ()>, Error>,
925 Option<(StatementEndedExecutionReason, ExecuteContextExtra)>,
926 ) {
927 let (has_rows, is_streaming) = match res {
928 StatementResult::SqlResult(SqlResult::Err { .. }) => (false, false),
929 StatementResult::SqlResult(SqlResult::Ok { .. }) => (false, false),
930 StatementResult::SqlResult(SqlResult::Rows { .. }) => (true, false),
931 StatementResult::Subscribe { .. } => (true, true),
932 };
933 if let Err(e) = send_ws_response(
934 self,
935 WebSocketResponse::CommandStarting(CommandStarting {
936 has_rows,
937 is_streaming,
938 }),
939 )
940 .await
941 {
942 return (Err(e), None);
943 }
944
945 let (is_err, msgs, stmt_logging) = match res {
946 StatementResult::SqlResult(SqlResult::Rows {
947 tag,
948 rows,
949 desc,
950 notices,
951 }) => {
952 let mut msgs = vec![WebSocketResponse::Rows(desc)];
953 msgs.extend(rows.into_iter().map(WebSocketResponse::Row));
954 msgs.push(WebSocketResponse::CommandComplete(tag));
955 msgs.extend(notices.into_iter().map(WebSocketResponse::Notice));
956 (false, msgs, None)
957 }
958 StatementResult::SqlResult(SqlResult::Ok {
959 ok,
960 parameters,
961 notices,
962 }) => {
963 let mut msgs = vec![WebSocketResponse::CommandComplete(ok)];
964 msgs.extend(notices.into_iter().map(WebSocketResponse::Notice));
965 msgs.extend(
966 parameters
967 .into_iter()
968 .map(WebSocketResponse::ParameterStatus),
969 );
970 (false, msgs, None)
971 }
972 StatementResult::SqlResult(SqlResult::Err { error, notices }) => {
973 let mut msgs = vec![WebSocketResponse::Error(error)];
974 msgs.extend(notices.into_iter().map(WebSocketResponse::Notice));
975 (true, msgs, None)
976 }
977 StatementResult::Subscribe {
978 ref desc,
979 tag,
980 mut rx,
981 ctx_extra,
982 } => {
983 if let Err(e) = send_ws_response(self, WebSocketResponse::Rows(desc.into())).await {
984 return (
987 Err(e),
988 Some((StatementEndedExecutionReason::Canceled, ctx_extra)),
989 );
990 }
991
992 let mut datum_vec = mz_repr::DatumVec::new();
993 let mut result_size: usize = 0;
994 let mut rows_returned = 0;
995 loop {
996 let res = match await_rows(self, client, rx.recv()).await {
997 Ok(res) => res,
998 Err(e) => {
999 return (
1002 Err(e),
1003 Some((StatementEndedExecutionReason::Canceled, ctx_extra)),
1004 );
1005 }
1006 };
1007 match res {
1008 Some(PeekResponseUnary::Rows(mut rows)) => {
1009 if let Err(err) = verify_datum_desc(desc, &mut rows) {
1010 let error = err.to_string();
1011 break (
1012 true,
1013 vec![WebSocketResponse::Error(err.into())],
1014 Some((
1015 StatementEndedExecutionReason::Errored { error },
1016 ctx_extra,
1017 )),
1018 );
1019 }
1020
1021 rows_returned += rows.count();
1022 while let Some(row) = rows.next() {
1023 result_size += row.byte_len();
1024 let datums = datum_vec.borrow_with(row);
1025 let types = &desc.typ().column_types;
1026 if let Err(e) = send_ws_response(
1027 self,
1028 WebSocketResponse::Row(
1029 datums
1030 .iter()
1031 .enumerate()
1032 .map(|(i, d)| {
1033 TypedDatum::new(*d, &types[i])
1034 .json(&JsonNumberPolicy::ConvertNumberToString)
1035 })
1036 .collect(),
1037 ),
1038 )
1039 .await
1040 {
1041 return (
1044 Err(e),
1045 Some((StatementEndedExecutionReason::Canceled, ctx_extra)),
1046 );
1047 }
1048 }
1049 }
1050 Some(PeekResponseUnary::Error(error)) => {
1051 break (
1052 true,
1053 vec![WebSocketResponse::Error(
1054 Error::Unstructured(anyhow!(error.clone())).into(),
1055 )],
1056 Some((StatementEndedExecutionReason::Errored { error }, ctx_extra)),
1057 );
1058 }
1059 Some(PeekResponseUnary::Canceled) => {
1060 break (
1061 true,
1062 vec![WebSocketResponse::Error(AdapterError::Canceled.into())],
1063 Some((StatementEndedExecutionReason::Canceled, ctx_extra)),
1064 );
1065 }
1066 None => {
1067 break (
1068 false,
1069 vec![WebSocketResponse::CommandComplete(tag)],
1070 Some((
1071 StatementEndedExecutionReason::Success {
1072 result_size: Some(u64::cast_from(result_size)),
1073 rows_returned: Some(u64::cast_from(rows_returned)),
1074 execution_strategy: Some(
1075 StatementExecutionStrategy::Standard,
1076 ),
1077 },
1078 ctx_extra,
1079 )),
1080 );
1081 }
1082 }
1083 }
1084 }
1085 };
1086 for msg in msgs {
1087 if let Err(e) = send_ws_response(self, msg).await {
1088 return (
1089 Err(e),
1090 stmt_logging.map(|(_old_reason, ctx_extra)| {
1091 (StatementEndedExecutionReason::Canceled, ctx_extra)
1092 }),
1093 );
1094 }
1095 }
1096 (Ok(if is_err { Err(()) } else { Ok(()) }), stmt_logging)
1097 }
1098
1099 fn connection_error(&mut self) -> BoxFuture<Error> {
1102 Box::pin(async {
1103 let mut tick = time::interval(Duration::from_secs(1));
1104 tick.tick().await;
1105 loop {
1106 tick.tick().await;
1107 if let Err(err) = self.send(Message::Ping(Vec::new())).await {
1108 return err.into();
1109 }
1110 }
1111 })
1112 }
1113
1114 fn allow_subscribe(&self) -> bool {
1115 true
1116 }
1117
1118 async fn emit_streaming_notices(&mut self, notices: Vec<AdapterNotice>) -> Result<(), Error> {
1119 forward_notices(self, notices).await
1120 }
1121}
1122
1123async fn await_rows<S, F, R>(sender: &mut S, client: &mut SessionClient, f: F) -> Result<R, Error>
1124where
1125 S: ResultSender,
1126 F: Future<Output = R> + Send,
1127{
1128 let mut f = pin!(f);
1129 loop {
1130 tokio::select! {
1131 notice = client.session().recv_notice(), if S::SUPPORTS_STREAMING_NOTICES => {
1132 sender.emit_streaming_notices(vec![notice]).await?;
1133 }
1134 e = sender.connection_error() => return Err(e),
1135 r = &mut f => return Ok(r),
1136 }
1137 }
1138}
1139
1140async fn send_and_retire<S: ResultSender>(
1141 res: StatementResult,
1142 client: &mut SessionClient,
1143 sender: &mut S,
1144) -> Result<Result<(), ()>, Error> {
1145 let (res, stmt_logging) = sender.add_result(client, res).await;
1146 if let Some((reason, ctx_extra)) = stmt_logging {
1147 client.retire_execute(ctx_extra, reason);
1148 }
1149 res
1150}
1151
/// Executes the statements of one group in order, delivering each result through `sender`.
///
/// Before each statement:
/// - if the transaction has failed, only transaction-exit statements are allowed; anything else
///   is rejected with `AbortedTransaction`;
/// - a transaction is started, sized to the group.
///
/// When a statement fails, the current transaction is handled by state:
/// - an implicit or just-started transaction is rolled back;
/// - an explicit transaction is marked failed.
///
/// Returns `Ok(Err(()))` as soon as a statement fails, `Ok(Ok(()))` when all succeed, and `Err`
/// when delivering a result fails.
///
/// # Panics
///
/// Panics if a group with more than one statement carries parameters. Only `Simple` requests
/// form multi-statement groups, and they never have parameters.
async fn execute_stmt_group<S: ResultSender>(
    client: &mut SessionClient,
    sender: &mut S,
    stmt_group: Vec<(Statement<Raw>, String, Vec<Option<String>>)>,
) -> Result<Result<(), ()>, Error> {
    let num_stmts = stmt_group.len();
    for (stmt, sql, params) in stmt_group {
        assert!(
            num_stmts <= 1 || params.is_empty(),
            "statement groups contain more than 1 statement iff Simple request, which does not support parameters"
        );

        let is_aborted_txn = matches!(client.session().transaction(), TransactionStatus::Failed(_));
        if is_aborted_txn && !is_txn_exit_stmt(&stmt) {
            let err = SqlResult::err(client, Error::AbortedTransaction);
            let _ = send_and_retire(err.into(), client, sender).await?;
            return Ok(Err(()));
        }

        if let Err(e) = client.start_transaction(Some(num_stmts)) {
            let err = SqlResult::err(client, e);
            let _ = send_and_retire(err.into(), client, sender).await?;
            return Ok(Err(()));
        }
        let res = execute_stmt(client, sender, stmt, sql, params).await?;
        let is_err = send_and_retire(res, client, sender).await?;

        if is_err.is_err() {
            let txn = client.session().transaction();
            match txn {
                // No transaction to clean up, or it is already failed.
                TransactionStatus::Default | TransactionStatus::Failed(_) => {}
                // Roll back an implicit/just-started transaction; report rollback errors too.
                TransactionStatus::Started(_) | TransactionStatus::InTransactionImplicit(_) => {
                    if let Err(err) = client.end_transaction(EndTransactionAction::Rollback).await {
                        let err = SqlResult::err(client, err);
                        let _ = send_and_retire(err.into(), client, sender).await?;
                    }
                }
                // An explicit transaction stays open but failed, so later statements are
                // rejected until the client exits it.
                TransactionStatus::InTransaction(_) => {
                    client.fail_transaction();
                }
            }
            return Ok(Err(()));
        }
    }
    Ok(Ok(()))
}
1207
1208async fn execute_request<S: ResultSender>(
1213 client: &mut AuthedClient,
1214 request: SqlRequest,
1215 sender: &mut S,
1216) -> Result<(), Error> {
1217 let client = &mut client.client;
1218
1219 fn check_prohibited_stmts<S: ResultSender>(
1222 sender: &S,
1223 stmt: &Statement<Raw>,
1224 ) -> Result<(), Error> {
1225 let kind: StatementKind = stmt.into();
1226 let execute_responses = Plan::generated_from(&kind)
1227 .into_iter()
1228 .map(ExecuteResponse::generated_from)
1229 .flatten()
1230 .collect::<Vec<_>>();
1231
1232 let is_valid_copy = matches!(
1236 stmt,
1237 Statement::Copy(CopyStatement {
1238 direction: CopyDirection::To,
1239 target: CopyTarget::Expr(_),
1240 ..
1241 }) | Statement::Copy(CopyStatement {
1242 direction: CopyDirection::From,
1243 target: CopyTarget::Expr(_),
1244 ..
1245 })
1246 );
1247
1248 if !is_valid_copy
1249 && execute_responses.iter().any(|execute_response| {
1250 match execute_response {
1252 ExecuteResponseKind::Subscribing if sender.allow_subscribe() => false,
1253 ExecuteResponseKind::Fetch
1254 | ExecuteResponseKind::Subscribing
1255 | ExecuteResponseKind::CopyFrom
1256 | ExecuteResponseKind::DeclaredCursor
1257 | ExecuteResponseKind::ClosedCursor => true,
1258 ExecuteResponseKind::CopyTo if matches!(kind, StatementKind::Copy) => true,
1263 _ => false,
1264 }
1265 })
1266 {
1267 return Err(Error::Unsupported(stmt.to_ast_string_simple()));
1268 }
1269 Ok(())
1270 }
1271
1272 fn parse<'a>(
1273 client: &SessionClient,
1274 query: &'a str,
1275 ) -> Result<Vec<StatementParseResult<'a>>, Error> {
1276 let result = client
1277 .parse(query)
1278 .map_err(|e| Error::Unstructured(anyhow!(e)))?;
1279 result.map_err(|e| AdapterError::from(e).into())
1280 }
1281
1282 let mut stmt_groups = vec![];
1283
1284 match request {
1285 SqlRequest::Simple { query } => match parse(client, &query) {
1286 Ok(stmts) => {
1287 let mut stmt_group = Vec::with_capacity(stmts.len());
1288 let mut stmt_err = None;
1289 for StatementParseResult { ast: stmt, sql } in stmts {
1290 if let Err(err) = check_prohibited_stmts(sender, &stmt) {
1291 stmt_err = Some(err);
1292 break;
1293 }
1294 stmt_group.push((stmt, sql.to_string(), vec![]));
1295 }
1296 stmt_groups.push(stmt_err.map(Err).unwrap_or_else(|| Ok(stmt_group)));
1297 }
1298 Err(e) => stmt_groups.push(Err(e)),
1299 },
1300 SqlRequest::Extended { queries } => {
1301 for ExtendedRequest { query, params } in queries {
1302 match parse(client, &query) {
1303 Ok(mut stmts) => {
1304 if stmts.len() != 1 {
1305 return Err(Error::Unstructured(anyhow!(
1306 "each query must contain exactly 1 statement, but \"{}\" contains {}",
1307 query,
1308 stmts.len()
1309 )));
1310 }
1311
1312 let StatementParseResult { ast: stmt, sql } = stmts.pop().unwrap();
1313 stmt_groups.push(
1314 check_prohibited_stmts(sender, &stmt)
1315 .map(|_| vec![(stmt, sql.to_string(), params)]),
1316 );
1317 }
1318 Err(e) => stmt_groups.push(Err(e)),
1319 };
1320 }
1321 }
1322 }
1323
1324 for stmt_group_res in stmt_groups {
1325 let executed = match stmt_group_res {
1326 Ok(stmt_group) => execute_stmt_group(client, sender, stmt_group).await,
1327 Err(e) => {
1328 let err = SqlResult::err(client, e);
1329 let _ = send_and_retire(err.into(), client, sender).await?;
1330 Ok(Err(()))
1331 }
1332 };
1333 if client.session().transaction().is_implicit() {
1336 let ended = client.end_transaction(EndTransactionAction::Commit).await;
1337 if let Err(err) = ended {
1338 let err = SqlResult::err(client, err);
1339 let _ = send_and_retire(StatementResult::SqlResult(err), client, sender).await?;
1340 }
1341 }
1342 if executed?.is_err() {
1343 break;
1344 }
1345 }
1346
1347 Ok(())
1348}
1349
1350async fn execute_stmt<S: ResultSender>(
1352 client: &mut SessionClient,
1353 sender: &mut S,
1354 stmt: Statement<Raw>,
1355 sql: String,
1356 raw_params: Vec<Option<String>>,
1357) -> Result<StatementResult, Error> {
1358 const EMPTY_PORTAL: &str = "";
1359 if let Err(e) = client
1360 .prepare(EMPTY_PORTAL.into(), Some(stmt.clone()), sql, vec![])
1361 .await
1362 {
1363 return Ok(SqlResult::err(client, e).into());
1364 }
1365
1366 let prep_stmt = match client.get_prepared_statement(EMPTY_PORTAL).await {
1367 Ok(stmt) => stmt,
1368 Err(err) => {
1369 return Ok(SqlResult::err(client, err).into());
1370 }
1371 };
1372
1373 let param_types = &prep_stmt.desc().param_types;
1374 if param_types.len() != raw_params.len() {
1375 let message = anyhow!(
1376 "request supplied {actual} parameters, \
1377 but {statement} requires {expected}",
1378 statement = stmt.to_ast_string_simple(),
1379 actual = raw_params.len(),
1380 expected = param_types.len()
1381 );
1382 return Ok(SqlResult::err(client, Error::Unstructured(message)).into());
1383 }
1384
1385 let buf = RowArena::new();
1386 let mut params = vec![];
1387 for (raw_param, mz_typ) in izip!(raw_params, param_types) {
1388 let pg_typ = mz_pgrepr::Type::from(mz_typ);
1389 let datum = match raw_param {
1390 None => Datum::Null,
1391 Some(raw_param) => {
1392 match mz_pgrepr::Value::decode(
1393 mz_pgwire_common::Format::Text,
1394 &pg_typ,
1395 raw_param.as_bytes(),
1396 ) {
1397 Ok(param) => param.into_datum(&buf, &pg_typ),
1398 Err(err) => {
1399 let msg = anyhow!("unable to decode parameter: {}", err);
1400 return Ok(SqlResult::err(client, Error::Unstructured(msg)).into());
1401 }
1402 }
1403 }
1404 };
1405 params.push((datum, mz_typ.clone()))
1406 }
1407
1408 let result_formats = vec![
1409 mz_pgwire_common::Format::Text;
1410 prep_stmt
1411 .desc()
1412 .relation_desc
1413 .clone()
1414 .map(|desc| desc.typ().column_types.len())
1415 .unwrap_or(0)
1416 ];
1417
1418 let desc = prep_stmt.desc().clone();
1419 let revision = prep_stmt.catalog_revision;
1420 let stmt = prep_stmt.stmt().cloned();
1421 let logging = Arc::clone(prep_stmt.logging());
1422 if let Err(err) = client.session().set_portal(
1423 EMPTY_PORTAL.into(),
1424 desc,
1425 stmt,
1426 logging,
1427 params,
1428 result_formats,
1429 revision,
1430 ) {
1431 return Ok(SqlResult::err(client, err).into());
1432 }
1433
1434 let desc = client
1435 .session()
1436 .get_portal_unverified(EMPTY_PORTAL)
1438 .map(|portal| portal.desc.clone())
1439 .expect("unnamed portal should be present");
1440
1441 let res = client
1442 .execute(EMPTY_PORTAL.into(), futures::future::pending(), None)
1443 .await;
1444
1445 if S::SUPPORTS_STREAMING_NOTICES {
1446 sender
1447 .emit_streaming_notices(client.session().drain_notices())
1448 .await?;
1449 }
1450
1451 let (res, execute_started) = match res {
1452 Ok(res) => res,
1453 Err(e) => {
1454 return Ok(SqlResult::err(client, e).into());
1455 }
1456 };
1457 let tag = res.tag();
1458
1459 Ok(match res {
1460 ExecuteResponse::CreatedConnection { .. }
1461 | ExecuteResponse::CreatedDatabase { .. }
1462 | ExecuteResponse::CreatedSchema { .. }
1463 | ExecuteResponse::CreatedRole
1464 | ExecuteResponse::CreatedCluster { .. }
1465 | ExecuteResponse::CreatedClusterReplica { .. }
1466 | ExecuteResponse::CreatedTable { .. }
1467 | ExecuteResponse::CreatedIndex { .. }
1468 | ExecuteResponse::CreatedIntrospectionSubscribe
1469 | ExecuteResponse::CreatedSecret { .. }
1470 | ExecuteResponse::CreatedSource { .. }
1471 | ExecuteResponse::CreatedSink { .. }
1472 | ExecuteResponse::CreatedView { .. }
1473 | ExecuteResponse::CreatedViews { .. }
1474 | ExecuteResponse::CreatedMaterializedView { .. }
1475 | ExecuteResponse::CreatedContinualTask { .. }
1476 | ExecuteResponse::CreatedType
1477 | ExecuteResponse::CreatedNetworkPolicy
1478 | ExecuteResponse::Comment
1479 | ExecuteResponse::Deleted(_)
1480 | ExecuteResponse::DiscardedTemp
1481 | ExecuteResponse::DiscardedAll
1482 | ExecuteResponse::DroppedObject(_)
1483 | ExecuteResponse::DroppedOwned
1484 | ExecuteResponse::EmptyQuery
1485 | ExecuteResponse::GrantedPrivilege
1486 | ExecuteResponse::GrantedRole
1487 | ExecuteResponse::Inserted(_)
1488 | ExecuteResponse::Copied(_)
1489 | ExecuteResponse::Raised
1490 | ExecuteResponse::ReassignOwned
1491 | ExecuteResponse::RevokedPrivilege
1492 | ExecuteResponse::AlteredDefaultPrivileges
1493 | ExecuteResponse::RevokedRole
1494 | ExecuteResponse::StartedTransaction { .. }
1495 | ExecuteResponse::Updated(_)
1496 | ExecuteResponse::AlteredObject(_)
1497 | ExecuteResponse::AlteredRole
1498 | ExecuteResponse::AlteredSystemConfiguration
1499 | ExecuteResponse::Deallocate { .. }
1500 | ExecuteResponse::ValidatedConnection
1501 | ExecuteResponse::Prepare => SqlResult::ok(
1502 client,
1503 tag.expect("ok only called on tag-generating results"),
1504 Vec::default(),
1505 )
1506 .into(),
1507 ExecuteResponse::TransactionCommitted { params }
1508 | ExecuteResponse::TransactionRolledBack { params } => {
1509 let notify_set: mz_ore::collections::HashSet<_> = client
1510 .session()
1511 .vars()
1512 .notify_set()
1513 .map(|v| v.name().to_string())
1514 .collect();
1515 let params = params
1516 .into_iter()
1517 .filter(|(name, _value)| notify_set.contains(*name))
1518 .map(|(name, value)| ParameterStatus {
1519 name: name.to_string(),
1520 value,
1521 })
1522 .collect();
1523 SqlResult::ok(
1524 client,
1525 tag.expect("ok only called on tag-generating results"),
1526 params,
1527 )
1528 .into()
1529 }
1530 ExecuteResponse::SetVariable { name, .. } => {
1531 let mut params = Vec::with_capacity(1);
1532 if let Some(var) = client
1533 .session()
1534 .vars()
1535 .notify_set()
1536 .find(|v| v.name() == &name)
1537 {
1538 params.push(ParameterStatus {
1539 name,
1540 value: var.value(),
1541 });
1542 };
1543 SqlResult::ok(
1544 client,
1545 tag.expect("ok only called on tag-generating results"),
1546 params,
1547 )
1548 .into()
1549 }
1550 ExecuteResponse::SendingRowsStreaming {
1551 rows,
1552 instance_id,
1553 strategy,
1554 } => {
1555 let max_query_result_size =
1556 usize::cast_from(client.get_system_vars().await.max_result_size());
1557
1558 let rows_stream = RecordFirstRowStream::new(
1559 Box::new(rows),
1560 execute_started,
1561 client,
1562 Some(instance_id),
1563 Some(strategy),
1564 );
1565
1566 SqlResult::rows(
1567 sender,
1568 client,
1569 rows_stream,
1570 max_query_result_size,
1571 &desc.relation_desc.expect("RelationDesc must exist"),
1572 )
1573 .await?
1574 .into()
1575 }
1576 ExecuteResponse::SendingRowsImmediate { rows } => {
1577 let max_query_result_size =
1578 usize::cast_from(client.get_system_vars().await.max_result_size());
1579
1580 let rows = futures::stream::once(futures::future::ready(PeekResponseUnary::Rows(rows)));
1581 let rows_stream =
1582 RecordFirstRowStream::new(Box::new(rows), execute_started, client, None, None);
1583
1584 SqlResult::rows(
1585 sender,
1586 client,
1587 rows_stream,
1588 max_query_result_size,
1589 &desc.relation_desc.expect("RelationDesc must exist"),
1590 )
1591 .await?
1592 .into()
1593 }
1594 ExecuteResponse::Subscribing {
1595 rx,
1596 ctx_extra,
1597 instance_id,
1598 } => StatementResult::Subscribe {
1599 tag: "SUBSCRIBE".into(),
1600 desc: desc.relation_desc.unwrap(),
1601 rx: RecordFirstRowStream::new(
1602 Box::new(UnboundedReceiverStream::new(rx)),
1603 execute_started,
1604 client,
1605 Some(instance_id),
1606 None,
1607 ),
1608 ctx_extra,
1609 },
1610 res @ (ExecuteResponse::Fetch { .. }
1611 | ExecuteResponse::CopyTo { .. }
1612 | ExecuteResponse::CopyFrom { .. }
1613 | ExecuteResponse::DeclaredCursor
1614 | ExecuteResponse::ClosedCursor) => SqlResult::err(
1615 client,
1616 Error::Unstructured(anyhow!(
1617 "internal error: encountered prohibited ExecuteResponse {:?}.\n\n
1618 This is a bug. Can you please file an bug report letting us know?\n
1619 https://github.com/MaterializeInc/materialize/discussions/new?category=bug-reports",
1620 ExecuteResponseKind::from(res)
1621 )),
1622 )
1623 .into(),
1624 })
1625}
1626
1627fn make_notices(client: &mut SessionClient) -> Vec<Notice> {
1628 client
1629 .session()
1630 .drain_notices()
1631 .into_iter()
1632 .map(|notice| Notice {
1633 message: notice.to_string(),
1634 code: notice.code().code().to_string(),
1635 severity: notice.severity().as_str().to_lowercase(),
1636 detail: notice.detail(),
1637 hint: notice.hint(),
1638 })
1639 .collect()
1640}
1641
1642fn is_txn_exit_stmt(stmt: &Statement<Raw>) -> bool {
1645 matches!(
1646 stmt,
1647 Statement::Commit(_) | Statement::Rollback(_) | Statement::Prepare(_)
1648 )
1649}
1650
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use super::{Password, WebSocketAuth};

    /// Checks that WebSocket authentication payloads deserialize into the
    /// expected [`WebSocketAuth`] variants. Covers both basic and bearer
    /// payloads, with and without an explicit `options` map.
    #[mz_ore::test]
    fn smoke_test_websocket_auth_parse() {
        // Both basic-auth payloads are expected to yield the same value.
        let basic_mz = || WebSocketAuth::Basic {
            user: "mz".to_string(),
            password: Password("1234".to_string()),
            options: BTreeMap::default(),
        };

        let cases: Vec<(&'static str, WebSocketAuth)> = vec![
            (r#"{ "user": "mz", "password": "1234" }"#, basic_mz()),
            (
                r#"{ "user": "mz", "password": "1234", "options": {} }"#,
                basic_mz(),
            ),
            (
                r#"{ "token": "i_am_a_token" }"#,
                WebSocketAuth::Bearer {
                    token: "i_am_a_token".to_string(),
                    options: BTreeMap::default(),
                },
            ),
            (
                r#"{ "token": "i_am_a_token", "options": { "foo": "bar" } }"#,
                WebSocketAuth::Bearer {
                    token: "i_am_a_token".to_string(),
                    options: BTreeMap::from([("foo".to_string(), "bar".to_string())]),
                },
            ),
        ];

        for (json, expected) in cases {
            let parsed: WebSocketAuth = serde_json::from_str(json).unwrap();
            assert_eq!(parsed, expected);
        }
    }
}