1use std::borrow::Cow;
11use std::collections::BTreeMap;
12use std::net::{IpAddr, SocketAddr};
13use std::pin::pin;
14use std::sync::Arc;
15use std::time::Duration;
16
17use anyhow::anyhow;
18use async_trait::async_trait;
19use axum::extract::connect_info::ConnectInfo;
20use axum::extract::ws::{CloseFrame, Message, Utf8Bytes, WebSocket};
21use axum::extract::{State, WebSocketUpgrade};
22use axum::response::IntoResponse;
23use axum::{Extension, Json};
24use futures::Future;
25use futures::future::BoxFuture;
26
27use http::StatusCode;
28use itertools::izip;
29use mz_adapter::client::RecordFirstRowStream;
30use mz_adapter::session::{EndTransactionAction, TransactionStatus};
31use mz_adapter::statement_logging::{StatementEndedExecutionReason, StatementExecutionStrategy};
32use mz_adapter::{
33 AdapterError, AdapterNotice, ExecuteContextExtra, ExecuteResponse, ExecuteResponseKind,
34 PeekResponseUnary, SessionClient, verify_datum_desc,
35};
36use mz_catalog::memory::objects::{Cluster, ClusterReplica};
37use mz_interchange::encode::TypedDatum;
38use mz_interchange::json::{JsonNumberPolicy, ToJson};
39use mz_ore::cast::CastFrom;
40use mz_ore::metrics::{MakeCollectorOpts, MetricsRegistry};
41use mz_ore::result::ResultExt;
42use mz_repr::{Datum, RelationDesc, RowArena, RowIterator};
43use mz_sql::ast::display::AstDisplay;
44use mz_sql::ast::{CopyDirection, CopyStatement, CopyTarget, Raw, Statement, StatementKind};
45use mz_sql::parse::StatementParseResult;
46use mz_sql::plan::Plan;
47use mz_sql::session::metadata::SessionMetadata;
48use prometheus::Opts;
49use prometheus::core::{AtomicF64, GenericGaugeVec};
50use serde::{Deserialize, Serialize};
51use tokio::{select, time};
52use tokio_postgres::error::SqlState;
53use tokio_stream::wrappers::UnboundedReceiverStream;
54use tracing::{debug, error, info};
55use tungstenite::protocol::frame::coding::CloseCode;
56
57use crate::http::prometheus::PrometheusSqlQuery;
58use crate::http::{AuthedClient, AuthedUser, MAX_REQUEST_SIZE, WsState, init_ws};
59
/// Errors produced while serving SQL requests over HTTP or WebSocket.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// An error reported by the adapter; its message is displayed unchanged.
    #[error(transparent)]
    Adapter(#[from] AdapterError),
    /// A JSON serialization or deserialization failure.
    #[error(transparent)]
    Json(#[from] serde_json::Error),
    /// An error reported by axum.
    #[error(transparent)]
    Axum(#[from] axum::Error),
    /// A `SUBSCRIBE` reached a sender that does not stream results.
    #[error("SUBSCRIBE only supported over websocket")]
    SubscribeOnlyOverWs,
    /// A statement other than a transaction-exit statement was issued while
    /// the session's transaction is in the failed state.
    #[error("current transaction is aborted, commands ignored until end of transaction block")]
    AbortedTransaction,
    /// The statement (carried as its simple AST string) is prohibited here.
    #[error("unsupported via this API: {0}")]
    Unsupported(String),
    /// Any other error, carried as an `anyhow::Error`.
    #[error("{0}")]
    Unstructured(anyhow::Error),
}
77
78impl Error {
79 pub fn detail(&self) -> Option<String> {
80 match self {
81 Error::Adapter(err) => err.detail(),
82 _ => None,
83 }
84 }
85
86 pub fn hint(&self) -> Option<String> {
87 match self {
88 Error::Adapter(err) => err.hint(),
89 _ => None,
90 }
91 }
92
93 pub fn position(&self) -> Option<usize> {
94 match self {
95 Error::Adapter(err) => err.position(),
96 _ => None,
97 }
98 }
99
100 pub fn code(&self) -> SqlState {
101 match self {
102 Error::Adapter(err) => err.code(),
103 Error::AbortedTransaction => SqlState::IN_FAILED_SQL_TRANSACTION,
104 _ => SqlState::INTERNAL_ERROR,
105 }
106 }
107}
108
/// Label names appended to a metric's own labels when its query is executed
/// once per cluster replica. `execute_promsql_query` pushes the matching label
/// values in this same order.
static PER_REPLICA_LABELS: &[&str] = &["replica_full_name", "instance_id", "replica_id"];
110
111async fn execute_promsql_query(
112 client: &mut AuthedClient,
113 query: &PrometheusSqlQuery<'_>,
114 metrics_registry: &MetricsRegistry,
115 metrics_by_name: &mut BTreeMap<String, GenericGaugeVec<AtomicF64>>,
116 cluster: Option<(&Cluster, &ClusterReplica)>,
117) {
118 assert_eq!(query.per_replica, cluster.is_some());
119
120 let mut res = SqlResponse {
121 results: Vec::new(),
122 };
123
124 execute_request(client, query.to_sql_request(cluster), &mut res)
125 .await
126 .expect("valid SQL query");
127
128 let result = match res.results.as_slice() {
129 [
132 SqlResult::Ok { .. },
133 SqlResult::Ok { .. },
134 SqlResult::Ok { .. },
135 result,
136 ] => result,
137 _ => {
141 info!(
142 "error executing prometheus query {}: {:?}",
143 query.metric_name, res
144 );
145 return;
146 }
147 };
148
149 let SqlResult::Rows { desc, rows, .. } = result else {
150 info!(
151 "did not receive rows for SQL query for prometheus metric {}: {:?}, {:?}",
152 query.metric_name, result, cluster
153 );
154 return;
155 };
156
157 let gauge_vec = metrics_by_name
158 .entry(query.metric_name.to_string())
159 .or_insert_with(|| {
160 let mut label_names: Vec<String> = desc
161 .columns
162 .iter()
163 .filter(|col| col.name != query.value_column_name)
164 .map(|col| col.name.clone())
165 .collect();
166
167 if query.per_replica {
168 label_names.extend(PER_REPLICA_LABELS.iter().map(|label| label.to_string()));
169 }
170
171 metrics_registry.register::<GenericGaugeVec<AtomicF64>>(MakeCollectorOpts {
172 opts: Opts::new(query.metric_name, query.help).variable_labels(label_names),
173 buckets: None,
174 })
175 });
176
177 for row in rows {
178 let mut label_values = desc
179 .columns
180 .iter()
181 .zip(row)
182 .filter(|(col, _)| col.name != query.value_column_name)
183 .map(|(_, val)| val.as_str().expect("must be string"))
184 .collect::<Vec<_>>();
185
186 let value = desc
187 .columns
188 .iter()
189 .zip(row)
190 .find(|(col, _)| col.name == query.value_column_name)
191 .map(|(_, val)| val.as_str().unwrap_or("0").parse::<f64>().unwrap_or(0.0))
192 .unwrap_or(0.0);
193
194 match cluster {
195 Some((cluster, replica)) => {
196 let replica_full_name = format!("{}.{}", cluster.name, replica.name);
197 let cluster_id = cluster.id.to_string();
198 let replica_id = replica.replica_id.to_string();
199
200 label_values.push(&replica_full_name);
201 label_values.push(&cluster_id);
202 label_values.push(&replica_id);
203
204 gauge_vec
205 .get_metric_with_label_values(&label_values)
206 .expect("valid labels")
207 .set(value);
208 }
209 None => {
210 gauge_vec
211 .get_metric_with_label_values(&label_values)
212 .expect("valid labels")
213 .set(value);
214 }
215 }
216 }
217}
218
219async fn handle_promsql_query(
220 client: &mut AuthedClient,
221 query: &PrometheusSqlQuery<'_>,
222 metrics_registry: &MetricsRegistry,
223 metrics_by_name: &mut BTreeMap<String, GenericGaugeVec<AtomicF64>>,
224) {
225 if !query.per_replica {
226 execute_promsql_query(client, query, metrics_registry, metrics_by_name, None).await;
227 return;
228 }
229
230 let catalog = client.client.catalog_snapshot().await;
231 let clusters: Vec<&Cluster> = catalog.clusters().collect();
232
233 for cluster in clusters {
234 for replica in cluster.replicas() {
235 execute_promsql_query(
236 client,
237 query,
238 metrics_registry,
239 metrics_by_name,
240 Some((cluster, replica)),
241 )
242 .await;
243 }
244 }
245}
246
247pub async fn handle_promsql(
248 mut client: AuthedClient,
249 queries: &[PrometheusSqlQuery<'_>],
250) -> MetricsRegistry {
251 let metrics_registry = MetricsRegistry::new();
252 let mut metrics_by_name = BTreeMap::new();
253
254 for query in queries {
255 handle_promsql_query(&mut client, query, &metrics_registry, &mut metrics_by_name).await;
256 }
257
258 metrics_registry
259}
260
261pub async fn handle_sql(
262 mut client: AuthedClient,
263 Json(request): Json<SqlRequest>,
264) -> impl IntoResponse {
265 let mut res = SqlResponse {
266 results: Vec::new(),
267 };
268 match execute_request(&mut client, request, &mut res).await {
271 Ok(()) => Ok(Json(res)),
272 Err(e) => Err((StatusCode::BAD_REQUEST, e.to_string())),
273 }
274}
275
276pub async fn handle_sql_ws(
277 State(state): State<WsState>,
278 existing_user: Option<Extension<AuthedUser>>,
279 ws: WebSocketUpgrade,
280 ConnectInfo(addr): ConnectInfo<SocketAddr>,
281) -> impl IntoResponse {
282 let user = existing_user.and_then(|Extension(user)| Some(user));
284 let addr = Box::new(addr.ip());
285 ws.max_message_size(MAX_REQUEST_SIZE)
286 .on_upgrade(|ws| async move { run_ws(&state, user, *addr, ws).await })
287}
288
/// Authentication payload for a WebSocket session.
///
/// The enum is `untagged`, so on deserialization serde tries the variants in
/// declaration order and picks the first whose fields match.
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
#[serde(untagged)]
pub enum WebSocketAuth {
    /// Username and password credentials.
    Basic {
        user: String,
        password: String,
        // Defaults to an empty map when absent.
        #[serde(default)]
        options: BTreeMap<String, String>,
    },
    /// Token credentials.
    Bearer {
        token: String,
        // Defaults to an empty map when absent.
        #[serde(default)]
        options: BTreeMap<String, String>,
    },
    /// No credentials, only options.
    OptionsOnly {
        // Defaults to an empty map when absent.
        #[serde(default)]
        options: BTreeMap<String, String>,
    },
}
308
/// Drives one WebSocket SQL session from initialization until the connection
/// ends.
///
/// After `init_ws` succeeds, the session's notify-set variables, its backend
/// key data and an initial `ReadyForQuery` are sent, followed by any pending
/// notices. The function then loops: it receives a message, executes it as a
/// [`SqlRequest`], and answers with an optional error, any notices and a
/// `ReadyForQuery`. It returns when the peer closes, a receive or send fails,
/// or the client reports a timeout.
async fn run_ws(state: &WsState, user: Option<AuthedUser>, peer_addr: IpAddr, mut ws: WebSocket) {
    let mut client = match init_ws(state, user, peer_addr, &mut ws).await {
        Ok(client) => client,
        Err(e) => {
            debug!("WS request failed init: {}", e);
            // Only adapter errors are reported verbatim; anything else is
            // reported as a generic "unauthorized".
            let reason = match e.downcast_ref::<AdapterError>() {
                Some(error) => Cow::Owned(error.to_string()),
                None => "unauthorized".into(),
            };
            // Best effort: the close frame may fail to send, which is ignored.
            let _ = ws
                .send(Message::Close(Some(CloseFrame {
                    code: CloseCode::Protocol.into(),
                    reason: Utf8Bytes::from(reason.as_ref()),
                })))
                .await;
            return;
        }
    };

    // Startup messages: parameter statuses, backend key data, ReadyForQuery.
    let mut msgs = Vec::new();
    let session = client.client.session();
    for var in session.vars().notify_set() {
        msgs.push(WebSocketResponse::ParameterStatus(ParameterStatus {
            name: var.name().to_string(),
            value: var.value(),
        }));
    }
    msgs.push(WebSocketResponse::BackendKeyData(BackendKeyData {
        conn_id: session.conn_id().unhandled(),
        secret_key: session.secret_key(),
    }));
    msgs.push(WebSocketResponse::ReadyForQuery(
        session.transaction_code().into(),
    ));
    for msg in msgs {
        // Send failures are ignored here; a broken socket is detected by the
        // notice forwarding below or by the receive loop.
        let _ = ws
            .send(Message::Text(
                serde_json::to_string(&msg).expect("must serialize").into(),
            ))
            .await;
    }

    let notices = session.drain_notices();
    if let Err(err) = forward_notices(&mut ws, notices).await {
        debug!("failed to forward notices to WebSocket, {err:?}");
        return;
    }

    loop {
        let msg = select! {
            // `biased` polls the branches in declaration order, so a pending
            // timeout is checked before an incoming message.
            biased;

            Some(timeout) = client.client.recv_timeout() => {
                client.client.terminate().await;
                // Waits for one more incoming message before reporting the
                // timeout. NOTE(review): presumably the error can only be
                // delivered in reply to a client message; confirm.
                let _ = ws.recv().await;
                let err = Error::from(AdapterError::from(timeout));
                let _ = send_ws_response(&mut ws, WebSocketResponse::Error(err.into())).await;
                return;
            },
            message = ws.recv() => message,
        };

        client.client.remove_idle_in_transaction_session_timeout();

        // `None` (stream ended) and `Some(Err(_))` both end the session.
        let msg = match msg {
            Some(Ok(msg)) => msg,
            _ => {
                return;
            }
        };

        // Text and binary frames are both parsed as JSON requests; a parse
        // failure is carried along and reported by `run_ws_request`.
        let req: Result<SqlRequest, Error> = match msg {
            Message::Text(data) => serde_json::from_str(&data).err_into(),
            Message::Binary(data) => serde_json::from_slice(&data).err_into(),
            Message::Ping(_) => {
                continue;
            }
            Message::Pong(_) => {
                continue;
            }
            Message::Close(_) => {
                return;
            }
        };

        let err = match run_ws_request(req, &mut client, &mut ws).await {
            Ok(()) => None,
            Err(err) => Some(WebSocketResponse::Error(err.into())),
        };

        // Grouped in a closure so `?` can short-circuit on the first send
        // failure. Order: error (if any), notices, then ReadyForQuery.
        let ws_response = || async {
            if let Some(e_resp) = err {
                send_ws_response(&mut ws, e_resp).await?;
            }

            let notices = client.client.session().drain_notices();
            forward_notices(&mut ws, notices).await?;

            let ready =
                WebSocketResponse::ReadyForQuery(client.client.session().transaction_code().into());
            send_ws_response(&mut ws, ready).await?;

            Ok::<_, Error>(())
        };

        if let Err(err) = ws_response().await {
            debug!("failed to send response over WebSocket, {err:?}");
            return;
        }
    }
}
442
443async fn run_ws_request(
444 req: Result<SqlRequest, Error>,
445 client: &mut AuthedClient,
446 ws: &mut WebSocket,
447) -> Result<(), Error> {
448 let req = req?;
449 execute_request(client, req, ws).await
450}
451
452async fn send_ws_response(ws: &mut WebSocket, resp: WebSocketResponse) -> Result<(), Error> {
454 let msg = serde_json::to_string(&resp).unwrap();
455 let msg = Message::Text(msg.into());
456 ws.send(msg).await?;
457
458 Ok(())
459}
460
461async fn forward_notices(
463 ws: &mut WebSocket,
464 notices: impl IntoIterator<Item = AdapterNotice>,
465) -> Result<(), Error> {
466 let ws_notices = notices.into_iter().map(|notice| {
467 WebSocketResponse::Notice(Notice {
468 message: notice.to_string(),
469 code: notice.code().code().to_string(),
470 severity: notice.severity().as_str().to_lowercase(),
471 detail: notice.detail(),
472 hint: notice.hint(),
473 })
474 });
475
476 for notice in ws_notices {
477 send_ws_response(ws, notice).await?;
478 }
479
480 Ok(())
481}
482
/// A SQL request body.
///
/// The enum is `untagged`: a body with a `query` field is a `Simple` request,
/// one with a `queries` field is an `Extended` request.
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum SqlRequest {
    /// A single SQL string that may hold several statements; the statements
    /// are executed as one group and take no parameters.
    Simple {
        query: String,
    },
    /// A list of queries, each holding exactly one statement and its
    /// parameters.
    Extended {
        queries: Vec<ExtendedRequest>,
    },
}
499
/// One query of an [`SqlRequest::Extended`] request.
#[derive(Serialize, Deserialize, Debug)]
pub struct ExtendedRequest {
    /// SQL text; must parse to exactly one statement.
    query: String,
    /// Parameter values for the statement; `None` denotes a NULL parameter.
    /// Defaults to no parameters when the field is absent.
    #[serde(default)]
    params: Vec<Option<String>>,
}
509
/// Response of the HTTP SQL endpoint: the results collected while executing
/// a request, in the order they were added.
#[derive(Debug, Serialize, Deserialize)]
pub struct SqlResponse {
    results: Vec<SqlResult>,
}
516
/// Outcome of executing one statement, as handed to a `ResultSender`.
enum StatementResult {
    /// A completed result (rows, ok, or error).
    SqlResult(SqlResult),
    /// A subscription whose rows still have to be streamed from `rx`.
    Subscribe {
        /// Description of the streamed rows.
        desc: RelationDesc,
        /// Command tag reported once the stream ends.
        tag: String,
        /// Source of the streamed responses.
        rx: RecordFirstRowStream,
        /// Statement-logging context that is retired when the stream ends.
        ctx_extra: ExecuteContextExtra,
    },
}
526
527impl From<SqlResult> for StatementResult {
528 fn from(inner: SqlResult) -> Self {
529 Self::SqlResult(inner)
530 }
531}
532
/// Result of a single statement as serialized to clients.
///
/// The enum is `untagged`, so each variant is identified by its fields.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum SqlResult {
    /// The statement returned rows.
    Rows {
        /// Command tag for the statement.
        tag: String,
        /// The rows, each a list of JSON-encoded column values.
        rows: Vec<Vec<serde_json::Value>>,
        /// Description of the returned columns.
        desc: Description,
        /// Notices captured alongside the result.
        notices: Vec<Notice>,
    },
    /// The statement succeeded without returning rows.
    Ok {
        /// Command tag for the statement.
        ok: String,
        /// Notices captured alongside the result.
        notices: Vec<Notice>,
        /// Parameter statuses to report; omitted from the JSON when empty.
        #[serde(skip_serializing_if = "Vec::is_empty")]
        parameters: Vec<ParameterStatus>,
    },
    /// The statement failed.
    Err {
        error: SqlError,
        /// Notices captured alongside the error.
        notices: Vec<Notice>,
    },
}
567
568impl SqlResult {
569 fn rows(
572 client: &mut SessionClient,
573 mut sql_rows: Box<dyn RowIterator>,
574 desc: &RelationDesc,
575 ) -> SqlResult {
576 if let Err(err) = verify_datum_desc(desc, &mut sql_rows) {
577 return SqlResult::Err {
578 error: err.into(),
579 notices: make_notices(client),
580 };
581 }
582
583 let mut rows: Vec<Vec<serde_json::Value>> = vec![];
584 let mut datum_vec = mz_repr::DatumVec::new();
585 let types = &desc.typ().column_types;
586
587 while let Some(row) = sql_rows.next() {
588 let datums = datum_vec.borrow_with(row);
589 rows.push(
590 datums
591 .iter()
592 .enumerate()
593 .map(|(i, d)| {
594 TypedDatum::new(*d, &types[i])
595 .json(&JsonNumberPolicy::ConvertNumberToString)
596 })
597 .collect(),
598 );
599 }
600
601 let tag = format!("SELECT {}", rows.len());
602 SqlResult::Rows {
603 tag,
604 rows,
605 desc: Description::from(desc),
606 notices: make_notices(client),
607 }
608 }
609
610 fn err(client: &mut SessionClient, error: impl Into<SqlError>) -> SqlResult {
611 SqlResult::Err {
612 error: error.into(),
613 notices: make_notices(client),
614 }
615 }
616
617 fn ok(client: &mut SessionClient, tag: String, params: Vec<ParameterStatus>) -> SqlResult {
618 SqlResult::Ok {
619 ok: tag,
620 parameters: params,
621 notices: make_notices(client),
622 }
623 }
624}
625
/// A SQL error as serialized to clients. Optional fields are omitted from the
/// JSON when they are `None`.
#[derive(Debug, Deserialize, Serialize)]
pub struct SqlError {
    /// Human-readable error message.
    pub message: String,
    /// SQLSTATE code of the error.
    pub code: String,
    /// Additional detail, when available.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
    /// Hint text, when available.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub hint: Option<String>,
    /// Error position, when available.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub position: Option<usize>,
}
637
638impl From<Error> for SqlError {
639 fn from(err: Error) -> Self {
640 SqlError {
641 message: err.to_string(),
642 code: err.code().code().to_string(),
643 detail: err.detail(),
644 hint: err.hint(),
645 position: err.position(),
646 }
647 }
648}
649
650impl From<AdapterError> for SqlError {
651 fn from(value: AdapterError) -> Self {
652 Error::from(value).into()
653 }
654}
655
/// A message sent to a WebSocket client.
///
/// Serialized adjacently tagged: the variant name goes in `type` and its
/// contents in `payload`.
#[derive(Debug, Deserialize, Serialize)]
#[serde(tag = "type", content = "payload")]
pub enum WebSocketResponse {
    /// Sent after startup and after each request; carries the session's
    /// transaction code.
    ReadyForQuery(String),
    /// A notice produced by the session.
    Notice(Notice),
    /// Description of the columns of the rows that follow.
    Rows(Description),
    /// One result row, as JSON-encoded column values.
    Row(Vec<serde_json::Value>),
    /// Sent before the messages of each statement result.
    CommandStarting(CommandStarting),
    /// A statement finished; carries its command tag.
    CommandComplete(String),
    /// A statement or request failed.
    Error(SqlError),
    /// Name and value of a session variable.
    ParameterStatus(ParameterStatus),
    /// Connection id and secret key of the session.
    BackendKeyData(BackendKeyData),
}
669
/// A notice as serialized to clients. Optional fields are omitted from the
/// JSON when they are `None`.
#[derive(Debug, Serialize, Deserialize)]
pub struct Notice {
    /// Human-readable notice text.
    message: String,
    /// SQLSTATE code of the notice.
    code: String,
    /// Severity of the notice.
    severity: String,
    /// Additional detail, when available.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
    /// Hint text, when available.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub hint: Option<String>,
}
680
681impl Notice {
682 pub fn message(&self) -> &str {
683 &self.message
684 }
685}
686
/// Description of the columns of a row-returning result.
#[derive(Debug, Serialize, Deserialize)]
pub struct Description {
    /// One entry per column, in column order.
    pub columns: Vec<Column>,
}
691
692impl From<&RelationDesc> for Description {
693 fn from(desc: &RelationDesc) -> Self {
694 let columns = desc
695 .iter()
696 .map(|(name, typ)| {
697 let pg_type = mz_pgrepr::Type::from(&typ.scalar_type);
698 Column {
699 name: name.to_string(),
700 type_oid: pg_type.oid(),
701 type_len: pg_type.typlen(),
702 type_mod: pg_type.typmod(),
703 }
704 })
705 .collect();
706 Description { columns }
707 }
708}
709
/// Name and type information of one result column.
#[derive(Debug, Serialize, Deserialize)]
pub struct Column {
    /// Column name.
    pub name: String,
    /// OID of the column's type.
    pub type_oid: u32,
    /// Length of the column's type, as reported by `typlen()`.
    pub type_len: i16,
    /// Type modifier of the column's type, as reported by `typmod()`.
    pub type_mod: i32,
}
717
/// Name and current value of a session variable reported to the client.
#[derive(Debug, Serialize, Deserialize)]
pub struct ParameterStatus {
    name: String,
    value: String,
}
723
/// Connection id and secret key of the session, sent once at WebSocket startup.
#[derive(Debug, Serialize, Deserialize)]
pub struct BackendKeyData {
    conn_id: u32,
    secret_key: u32,
}
729
/// Announces the start of a statement result on a WebSocket.
#[derive(Debug, Serialize, Deserialize)]
pub struct CommandStarting {
    /// True when the result carries rows (a row result or a subscription).
    has_rows: bool,
    /// True only for subscriptions, whose rows are streamed.
    is_streaming: bool,
}
735
/// A destination for statement results: either a buffered [`SqlResponse`] or
/// a live [`WebSocket`].
#[async_trait]
trait ResultSender: Send {
    /// Whether notices may be forwarded while a statement is still running.
    /// `emit_streaming_notices` is only called when this is `true`.
    const SUPPORTS_STREAMING_NOTICES: bool = false;

    /// Delivers one statement result.
    ///
    /// The first tuple element is `Err` when delivery itself failed;
    /// otherwise it holds `Ok(())` for a successful statement or `Err(())`
    /// for a statement that resulted in an error. The second element, when
    /// present, is the reason and context with which the caller must retire
    /// the statement's execution.
    async fn add_result(
        &mut self,
        client: &mut SessionClient,
        res: StatementResult,
    ) -> (
        Result<Result<(), ()>, Error>,
        Option<(StatementEndedExecutionReason, ExecuteContextExtra)>,
    );

    /// Returns a future that resolves with an error once the connection to
    /// the receiver is found to be broken.
    fn connection_error(&mut self) -> BoxFuture<Error>;
    /// Whether this sender accepts `SUBSCRIBE` statements.
    fn allow_subscribe(&self) -> bool;

    /// Forwards notices while a statement is running. The default
    /// implementation must never be reached, because senders that keep the
    /// default leave `SUPPORTS_STREAMING_NOTICES` at `false`.
    async fn emit_streaming_notices(&mut self, _: Vec<AdapterNotice>) -> Result<(), Error> {
        unreachable!("streaming notices marked as unsupported")
    }
}
771
772#[async_trait]
773impl ResultSender for SqlResponse {
774 async fn add_result(
782 &mut self,
783 _client: &mut SessionClient,
784 res: StatementResult,
785 ) -> (
786 Result<Result<(), ()>, Error>,
787 Option<(StatementEndedExecutionReason, ExecuteContextExtra)>,
788 ) {
789 let (res, stmt_logging) = match res {
790 StatementResult::SqlResult(res) => {
791 let is_err = matches!(res, SqlResult::Err { .. });
792 self.results.push(res);
793 let res = if is_err { Err(()) } else { Ok(()) };
794 (res, None)
795 }
796 StatementResult::Subscribe { ctx_extra, .. } => {
797 let message = "SUBSCRIBE only supported over websocket";
798 self.results.push(SqlResult::Err {
799 error: Error::SubscribeOnlyOverWs.into(),
800 notices: Vec::new(),
801 });
802 (
803 Err(()),
804 Some((
805 StatementEndedExecutionReason::Errored {
806 error: message.into(),
807 },
808 ctx_extra,
809 )),
810 )
811 }
812 };
813 (Ok(res), stmt_logging)
814 }
815
816 fn connection_error(&mut self) -> BoxFuture<Error> {
817 Box::pin(futures::future::pending())
818 }
819
820 fn allow_subscribe(&self) -> bool {
821 false
822 }
823}
824
#[async_trait]
impl ResultSender for WebSocket {
    const SUPPORTS_STREAMING_NOTICES: bool = true;

    /// Sends one statement result over the socket.
    ///
    /// A `CommandStarting` message is sent first. Completed results are then
    /// sent as a batch of messages; subscriptions are streamed row by row
    /// until the stream ends, fails, or is canceled. A send failure is
    /// returned as the outer `Err`, and a subscription interrupted that way
    /// is retired as canceled.
    async fn add_result(
        &mut self,
        client: &mut SessionClient,
        res: StatementResult,
    ) -> (
        Result<Result<(), ()>, Error>,
        Option<(StatementEndedExecutionReason, ExecuteContextExtra)>,
    ) {
        let (has_rows, is_streaming) = match res {
            StatementResult::SqlResult(SqlResult::Err { .. }) => (false, false),
            StatementResult::SqlResult(SqlResult::Ok { .. }) => (false, false),
            StatementResult::SqlResult(SqlResult::Rows { .. }) => (true, false),
            StatementResult::Subscribe { .. } => (true, true),
        };
        if let Err(e) = send_ws_response(
            self,
            WebSocketResponse::CommandStarting(CommandStarting {
                has_rows,
                is_streaming,
            }),
        )
        .await
        {
            // NOTE(review): when `res` is a subscription, its `ctx_extra` is
            // dropped here without being returned for retirement; confirm
            // this is intended.
            return (Err(e), None);
        }

        // For each result kind, produce: whether the statement failed, the
        // trailing messages still to send, and the statement-logging info.
        let (is_err, msgs, stmt_logging) = match res {
            StatementResult::SqlResult(SqlResult::Rows {
                tag,
                rows,
                desc,
                notices,
            }) => {
                let mut msgs = vec![WebSocketResponse::Rows(desc)];
                msgs.extend(rows.into_iter().map(WebSocketResponse::Row));
                msgs.push(WebSocketResponse::CommandComplete(tag));
                msgs.extend(notices.into_iter().map(WebSocketResponse::Notice));
                (false, msgs, None)
            }
            StatementResult::SqlResult(SqlResult::Ok {
                ok,
                parameters,
                notices,
            }) => {
                let mut msgs = vec![WebSocketResponse::CommandComplete(ok)];
                msgs.extend(notices.into_iter().map(WebSocketResponse::Notice));
                msgs.extend(
                    parameters
                        .into_iter()
                        .map(WebSocketResponse::ParameterStatus),
                );
                (false, msgs, None)
            }
            StatementResult::SqlResult(SqlResult::Err { error, notices }) => {
                let mut msgs = vec![WebSocketResponse::Error(error)];
                msgs.extend(notices.into_iter().map(WebSocketResponse::Notice));
                (true, msgs, None)
            }
            StatementResult::Subscribe {
                ref desc,
                tag,
                mut rx,
                ctx_extra,
            } => {
                // The column description precedes the streamed rows.
                if let Err(e) = send_ws_response(self, WebSocketResponse::Rows(desc.into())).await {
                    return (
                        Err(e),
                        Some((StatementEndedExecutionReason::Canceled, ctx_extra)),
                    );
                }

                let mut datum_vec = mz_repr::DatumVec::new();
                // Totals reported to statement logging on success.
                let mut result_size: usize = 0;
                let mut rows_returned = 0;
                loop {
                    // Wait for the next batch while also forwarding notices
                    // and watching for connection failure.
                    let res = match await_rows(self, client, rx.recv()).await {
                        Ok(res) => res,
                        Err(e) => {
                            return (
                                Err(e),
                                Some((StatementEndedExecutionReason::Canceled, ctx_extra)),
                            );
                        }
                    };
                    match res {
                        Some(PeekResponseUnary::Rows(mut rows)) => {
                            if let Err(err) = verify_datum_desc(desc, &mut rows) {
                                let error = err.to_string();
                                break (
                                    true,
                                    vec![WebSocketResponse::Error(err.into())],
                                    Some((
                                        StatementEndedExecutionReason::Errored { error },
                                        ctx_extra,
                                    )),
                                );
                            }

                            rows_returned += rows.count();
                            while let Some(row) = rows.next() {
                                result_size += row.byte_len();
                                let datums = datum_vec.borrow_with(row);
                                let types = &desc.typ().column_types;
                                // Each row is sent as its own message.
                                if let Err(e) = send_ws_response(
                                    self,
                                    WebSocketResponse::Row(
                                        datums
                                            .iter()
                                            .enumerate()
                                            .map(|(i, d)| {
                                                TypedDatum::new(*d, &types[i])
                                                    .json(&JsonNumberPolicy::ConvertNumberToString)
                                            })
                                            .collect(),
                                    ),
                                )
                                .await
                                {
                                    return (
                                        Err(e),
                                        Some((StatementEndedExecutionReason::Canceled, ctx_extra)),
                                    );
                                }
                            }
                        }
                        Some(PeekResponseUnary::Error(error)) => {
                            break (
                                true,
                                vec![WebSocketResponse::Error(
                                    Error::Unstructured(anyhow!(error.clone())).into(),
                                )],
                                Some((StatementEndedExecutionReason::Errored { error }, ctx_extra)),
                            );
                        }
                        Some(PeekResponseUnary::Canceled) => {
                            break (
                                true,
                                vec![WebSocketResponse::Error(AdapterError::Canceled.into())],
                                Some((StatementEndedExecutionReason::Canceled, ctx_extra)),
                            );
                        }
                        // The stream ended: the subscription completed.
                        None => {
                            break (
                                false,
                                vec![WebSocketResponse::CommandComplete(tag)],
                                Some((
                                    StatementEndedExecutionReason::Success {
                                        result_size: Some(u64::cast_from(result_size)),
                                        rows_returned: Some(u64::cast_from(rows_returned)),
                                        execution_strategy: Some(
                                            StatementExecutionStrategy::Standard,
                                        ),
                                    },
                                    ctx_extra,
                                )),
                            );
                        }
                    }
                }
            }
        };
        for msg in msgs {
            if let Err(e) = send_ws_response(self, msg).await {
                // The client never received the full result, so whatever
                // reason was computed is replaced by `Canceled`.
                return (
                    Err(e),
                    stmt_logging.map(|(_old_reason, ctx_extra)| {
                        (StatementEndedExecutionReason::Canceled, ctx_extra)
                    }),
                );
            }
        }
        (Ok(if is_err { Err(()) } else { Ok(()) }), stmt_logging)
    }

    /// Sends a ping every second and resolves with the send error once a
    /// ping fails. The interval's first tick completes immediately and is
    /// skipped, so the first ping goes out after one second.
    fn connection_error(&mut self) -> BoxFuture<Error> {
        Box::pin(async {
            let mut tick = time::interval(Duration::from_secs(1));
            tick.tick().await;
            loop {
                tick.tick().await;
                if let Err(err) = self.send(Message::Ping(Vec::new().into())).await {
                    return err.into();
                }
            }
        })
    }

    /// Subscriptions can be streamed over a WebSocket.
    fn allow_subscribe(&self) -> bool {
        true
    }

    /// Forwards the notices to the client immediately.
    async fn emit_streaming_notices(&mut self, notices: Vec<AdapterNotice>) -> Result<(), Error> {
        forward_notices(self, notices).await
    }
}
1037
/// Awaits `f` while keeping the sender serviced.
///
/// Until `f` resolves, session notices are forwarded as they arrive (only
/// for senders that support streaming notices) and the sender's connection
/// is monitored. Returns `f`'s output, or an error if forwarding a notice
/// fails or the connection breaks first.
async fn await_rows<S, F, R>(sender: &mut S, client: &mut SessionClient, f: F) -> Result<R, Error>
where
    S: ResultSender,
    F: Future<Output = R> + Send,
{
    // Pinned once so the same future can be polled across loop iterations.
    let mut f = pin!(f);
    loop {
        tokio::select! {
            notice = client.session().recv_notice(), if S::SUPPORTS_STREAMING_NOTICES => {
                sender.emit_streaming_notices(vec![notice]).await?;
            }
            // A fresh connection-error future is created on every iteration.
            e = sender.connection_error() => return Err(e),
            r = &mut f => return Ok(r),
        }
    }
}
1054
1055async fn send_and_retire<S: ResultSender>(
1056 res: StatementResult,
1057 client: &mut SessionClient,
1058 sender: &mut S,
1059) -> Result<Result<(), ()>, Error> {
1060 let (res, stmt_logging) = sender.add_result(client, res).await;
1061 if let Some((reason, ctx_extra)) = stmt_logging {
1062 client.retire_execute(ctx_extra, reason);
1063 }
1064 res
1065}
1066
1067async fn execute_stmt_group<S: ResultSender>(
1069 client: &mut SessionClient,
1070 sender: &mut S,
1071 stmt_group: Vec<(Statement<Raw>, String, Vec<Option<String>>)>,
1072) -> Result<Result<(), ()>, Error> {
1073 let num_stmts = stmt_group.len();
1074 for (stmt, sql, params) in stmt_group {
1075 assert!(
1076 num_stmts <= 1 || params.is_empty(),
1077 "statement groups contain more than 1 statement iff Simple request, which does not support parameters"
1078 );
1079
1080 let is_aborted_txn = matches!(client.session().transaction(), TransactionStatus::Failed(_));
1081 if is_aborted_txn && !is_txn_exit_stmt(&stmt) {
1082 let err = SqlResult::err(client, Error::AbortedTransaction);
1083 let _ = send_and_retire(err.into(), client, sender).await?;
1084 return Ok(Err(()));
1085 }
1086
1087 if let Err(e) = client.start_transaction(Some(num_stmts)) {
1090 let err = SqlResult::err(client, e);
1091 let _ = send_and_retire(err.into(), client, sender).await?;
1092 return Ok(Err(()));
1093 }
1094 let res = execute_stmt(client, sender, stmt, sql, params).await?;
1095 let is_err = send_and_retire(res, client, sender).await?;
1096
1097 if is_err.is_err() {
1098 let txn = client.session().transaction();
1101 match txn {
1102 TransactionStatus::Default | TransactionStatus::Failed(_) => {}
1105 TransactionStatus::Started(_) | TransactionStatus::InTransactionImplicit(_) => {
1107 if let Err(err) = client.end_transaction(EndTransactionAction::Rollback).await {
1108 let err = SqlResult::err(client, err);
1109 let _ = send_and_retire(err.into(), client, sender).await?;
1110 }
1111 }
1112 TransactionStatus::InTransaction(_) => {
1114 client.fail_transaction();
1115 }
1116 }
1117 return Ok(Err(()));
1118 }
1119 }
1120 Ok(Ok(()))
1121}
1122
/// Parses `request`, executes its statement groups in order, and delivers
/// every result to `sender`.
///
/// A `Simple` request forms a single group holding all of its statements; an
/// `Extended` request forms one single-statement group per query. Parse
/// errors and prohibited statements are delivered to `sender` as error
/// results when their group is reached, and execution stops after the first
/// failing group.
///
/// # Errors
///
/// Returns `Err` when an `Extended` query does not contain exactly one
/// statement, or when delivering a result to `sender` fails.
async fn execute_request<S: ResultSender>(
    client: &mut AuthedClient,
    request: SqlRequest,
    sender: &mut S,
) -> Result<(), Error> {
    let client = &mut client.client;

    /// Rejects statements whose possible responses cannot be served by
    /// `sender`.
    fn check_prohibited_stmts<S: ResultSender>(
        sender: &S,
        stmt: &Statement<Raw>,
    ) -> Result<(), Error> {
        let kind: StatementKind = stmt.into();
        // Every response kind this statement kind could produce.
        let execute_responses = Plan::generated_from(&kind)
            .into_iter()
            .map(ExecuteResponse::generated_from)
            .flatten()
            .collect::<Vec<_>>();

        // `COPY` in either direction is accepted when its target is an
        // expression; such statements skip the response-kind check below.
        let is_valid_copy = matches!(
            stmt,
            Statement::Copy(CopyStatement {
                direction: CopyDirection::To,
                target: CopyTarget::Expr(_),
                ..
            }) | Statement::Copy(CopyStatement {
                direction: CopyDirection::From,
                target: CopyTarget::Expr(_),
                ..
            })
        );

        if !is_valid_copy
            && execute_responses.iter().any(|execute_response| {
                match execute_response {
                    // Subscribing is allowed only for senders that support it.
                    ExecuteResponseKind::Subscribing if sender.allow_subscribe() => false,
                    ExecuteResponseKind::Fetch
                    | ExecuteResponseKind::Subscribing
                    | ExecuteResponseKind::CopyFrom
                    | ExecuteResponseKind::DeclaredCursor
                    | ExecuteResponseKind::ClosedCursor => true,
                    // A `CopyTo` response is prohibited only when it comes
                    // from a `COPY` statement.
                    ExecuteResponseKind::CopyTo if matches!(kind, StatementKind::Copy) => true,
                    _ => false,
                }
            })
        {
            return Err(Error::Unsupported(stmt.to_ast_string_simple()));
        }
        Ok(())
    }

    /// Parses `query` into statements, mapping both layers of parse failure
    /// into [`Error`].
    fn parse<'a>(
        client: &SessionClient,
        query: &'a str,
    ) -> Result<Vec<StatementParseResult<'a>>, Error> {
        let result = client
            .parse(query)
            .map_err(|e| Error::Unstructured(anyhow!(e)))?;
        result.map_err(|e| AdapterError::from(e).into())
    }

    // One entry per group: either the statements to run or the error to report.
    let mut stmt_groups = vec![];

    match request {
        SqlRequest::Simple { query } => match parse(client, &query) {
            Ok(stmts) => {
                let mut stmt_group = Vec::with_capacity(stmts.len());
                let mut stmt_err = None;
                for StatementParseResult { ast: stmt, sql } in stmts {
                    // A single prohibited statement rejects the whole group.
                    if let Err(err) = check_prohibited_stmts(sender, &stmt) {
                        stmt_err = Some(err);
                        break;
                    }
                    stmt_group.push((stmt, sql.to_string(), vec![]));
                }
                stmt_groups.push(stmt_err.map(Err).unwrap_or_else(|| Ok(stmt_group)));
            }
            Err(e) => stmt_groups.push(Err(e)),
        },
        SqlRequest::Extended { queries } => {
            for ExtendedRequest { query, params } in queries {
                match parse(client, &query) {
                    Ok(mut stmts) => {
                        // Unlike other failures, this aborts the whole request
                        // before anything has been executed.
                        if stmts.len() != 1 {
                            return Err(Error::Unstructured(anyhow!(
                                "each query must contain exactly 1 statement, but \"{}\" contains {}",
                                query,
                                stmts.len()
                            )));
                        }

                        let StatementParseResult { ast: stmt, sql } = stmts.pop().unwrap();
                        stmt_groups.push(
                            check_prohibited_stmts(sender, &stmt)
                                .map(|_| vec![(stmt, sql.to_string(), params)]),
                        );
                    }
                    Err(e) => stmt_groups.push(Err(e)),
                };
            }
        }
    }

    for stmt_group_res in stmt_groups {
        let executed = match stmt_group_res {
            Ok(stmt_group) => execute_stmt_group(client, sender, stmt_group).await,
            Err(e) => {
                let err = SqlResult::err(client, e);
                let _ = send_and_retire(err.into(), client, sender).await?;
                Ok(Err(()))
            }
        };
        // An implicit transaction is committed after its group, before the
        // group's outcome is inspected.
        if client.session().transaction().is_implicit() {
            let ended = client.end_transaction(EndTransactionAction::Commit).await;
            if let Err(err) = ended {
                let err = SqlResult::err(client, err);
                let _ = send_and_retire(StatementResult::SqlResult(err), client, sender).await?;
            }
        }
        // A sender failure propagates; a failed statement stops the request.
        if executed?.is_err() {
            break;
        }
    }

    Ok(())
}
1264
1265async fn execute_stmt<S: ResultSender>(
1267 client: &mut SessionClient,
1268 sender: &mut S,
1269 stmt: Statement<Raw>,
1270 sql: String,
1271 raw_params: Vec<Option<String>>,
1272) -> Result<StatementResult, Error> {
1273 const EMPTY_PORTAL: &str = "";
1274 if let Err(e) = client
1275 .prepare(EMPTY_PORTAL.into(), Some(stmt.clone()), sql, vec![])
1276 .await
1277 {
1278 return Ok(SqlResult::err(client, e).into());
1279 }
1280
1281 let prep_stmt = match client.get_prepared_statement(EMPTY_PORTAL).await {
1282 Ok(stmt) => stmt,
1283 Err(err) => {
1284 return Ok(SqlResult::err(client, err).into());
1285 }
1286 };
1287
1288 let param_types = &prep_stmt.desc().param_types;
1289 if param_types.len() != raw_params.len() {
1290 let message = anyhow!(
1291 "request supplied {actual} parameters, \
1292 but {statement} requires {expected}",
1293 statement = stmt.to_ast_string_simple(),
1294 actual = raw_params.len(),
1295 expected = param_types.len()
1296 );
1297 return Ok(SqlResult::err(client, Error::Unstructured(message)).into());
1298 }
1299
1300 let buf = RowArena::new();
1301 let mut params = vec![];
1302 for (raw_param, mz_typ) in izip!(raw_params, param_types) {
1303 let pg_typ = mz_pgrepr::Type::from(mz_typ);
1304 let datum = match raw_param {
1305 None => Datum::Null,
1306 Some(raw_param) => {
1307 match mz_pgrepr::Value::decode(
1308 mz_pgwire_common::Format::Text,
1309 &pg_typ,
1310 raw_param.as_bytes(),
1311 ) {
1312 Ok(param) => param.into_datum(&buf, &pg_typ),
1313 Err(err) => {
1314 let msg = anyhow!("unable to decode parameter: {}", err);
1315 return Ok(SqlResult::err(client, Error::Unstructured(msg)).into());
1316 }
1317 }
1318 }
1319 };
1320 params.push((datum, mz_typ.clone()))
1321 }
1322
1323 let result_formats = vec![
1324 mz_pgwire_common::Format::Text;
1325 prep_stmt
1326 .desc()
1327 .relation_desc
1328 .clone()
1329 .map(|desc| desc.typ().column_types.len())
1330 .unwrap_or(0)
1331 ];
1332
1333 let desc = prep_stmt.desc().clone();
1334 let revision = prep_stmt.catalog_revision;
1335 let stmt = prep_stmt.stmt().cloned();
1336 let logging = Arc::clone(prep_stmt.logging());
1337 if let Err(err) = client.session().set_portal(
1338 EMPTY_PORTAL.into(),
1339 desc,
1340 stmt,
1341 logging,
1342 params,
1343 result_formats,
1344 revision,
1345 ) {
1346 return Ok(SqlResult::err(client, err).into());
1347 }
1348
1349 let desc = client
1350 .session()
1351 .get_portal_unverified(EMPTY_PORTAL)
1353 .map(|portal| portal.desc.clone())
1354 .expect("unnamed portal should be present");
1355
1356 let res = client
1357 .execute(EMPTY_PORTAL.into(), futures::future::pending(), None)
1358 .await;
1359
1360 if S::SUPPORTS_STREAMING_NOTICES {
1361 sender
1362 .emit_streaming_notices(client.session().drain_notices())
1363 .await?;
1364 }
1365
1366 let (res, execute_started) = match res {
1367 Ok(res) => res,
1368 Err(e) => {
1369 return Ok(SqlResult::err(client, e).into());
1370 }
1371 };
1372 let tag = res.tag();
1373
1374 Ok(match res {
1375 ExecuteResponse::CreatedConnection { .. }
1376 | ExecuteResponse::CreatedDatabase { .. }
1377 | ExecuteResponse::CreatedSchema { .. }
1378 | ExecuteResponse::CreatedRole
1379 | ExecuteResponse::CreatedCluster { .. }
1380 | ExecuteResponse::CreatedClusterReplica { .. }
1381 | ExecuteResponse::CreatedTable { .. }
1382 | ExecuteResponse::CreatedIndex { .. }
1383 | ExecuteResponse::CreatedIntrospectionSubscribe
1384 | ExecuteResponse::CreatedSecret { .. }
1385 | ExecuteResponse::CreatedSource { .. }
1386 | ExecuteResponse::CreatedSink { .. }
1387 | ExecuteResponse::CreatedView { .. }
1388 | ExecuteResponse::CreatedViews { .. }
1389 | ExecuteResponse::CreatedMaterializedView { .. }
1390 | ExecuteResponse::CreatedContinualTask { .. }
1391 | ExecuteResponse::CreatedType
1392 | ExecuteResponse::CreatedNetworkPolicy
1393 | ExecuteResponse::Comment
1394 | ExecuteResponse::Deleted(_)
1395 | ExecuteResponse::DiscardedTemp
1396 | ExecuteResponse::DiscardedAll
1397 | ExecuteResponse::DroppedObject(_)
1398 | ExecuteResponse::DroppedOwned
1399 | ExecuteResponse::EmptyQuery
1400 | ExecuteResponse::GrantedPrivilege
1401 | ExecuteResponse::GrantedRole
1402 | ExecuteResponse::Inserted(_)
1403 | ExecuteResponse::Copied(_)
1404 | ExecuteResponse::Raised
1405 | ExecuteResponse::ReassignOwned
1406 | ExecuteResponse::RevokedPrivilege
1407 | ExecuteResponse::AlteredDefaultPrivileges
1408 | ExecuteResponse::RevokedRole
1409 | ExecuteResponse::StartedTransaction { .. }
1410 | ExecuteResponse::Updated(_)
1411 | ExecuteResponse::AlteredObject(_)
1412 | ExecuteResponse::AlteredRole
1413 | ExecuteResponse::AlteredSystemConfiguration
1414 | ExecuteResponse::Deallocate { .. }
1415 | ExecuteResponse::ValidatedConnection
1416 | ExecuteResponse::Prepare => SqlResult::ok(
1417 client,
1418 tag.expect("ok only called on tag-generating results"),
1419 Vec::default(),
1420 )
1421 .into(),
1422 ExecuteResponse::TransactionCommitted { params }
1423 | ExecuteResponse::TransactionRolledBack { params } => {
1424 let notify_set: mz_ore::collections::HashSet<_> = client
1425 .session()
1426 .vars()
1427 .notify_set()
1428 .map(|v| v.name().to_string())
1429 .collect();
1430 let params = params
1431 .into_iter()
1432 .filter(|(name, _value)| notify_set.contains(*name))
1433 .map(|(name, value)| ParameterStatus {
1434 name: name.to_string(),
1435 value,
1436 })
1437 .collect();
1438 SqlResult::ok(
1439 client,
1440 tag.expect("ok only called on tag-generating results"),
1441 params,
1442 )
1443 .into()
1444 }
1445 ExecuteResponse::SetVariable { name, .. } => {
1446 let mut params = Vec::with_capacity(1);
1447 if let Some(var) = client
1448 .session()
1449 .vars()
1450 .notify_set()
1451 .find(|v| v.name() == &name)
1452 {
1453 params.push(ParameterStatus {
1454 name,
1455 value: var.value(),
1456 });
1457 };
1458 SqlResult::ok(
1459 client,
1460 tag.expect("ok only called on tag-generating results"),
1461 params,
1462 )
1463 .into()
1464 }
1465 ExecuteResponse::SendingRows {
1466 future: mut rows,
1467 instance_id,
1468 strategy,
1469 } => {
1470 let rows = match await_rows(sender, client, &mut rows).await? {
1471 PeekResponseUnary::Rows(rows) => {
1472 RecordFirstRowStream::record(
1473 execute_started,
1474 client,
1475 Some(instance_id),
1476 Some(strategy),
1477 );
1478 rows
1479 }
1480 PeekResponseUnary::Error(e) => {
1481 return Ok(SqlResult::err(client, Error::Unstructured(anyhow!(e))).into());
1482 }
1483 PeekResponseUnary::Canceled => {
1484 return Ok(SqlResult::err(client, AdapterError::Canceled).into());
1485 }
1486 };
1487 SqlResult::rows(
1488 client,
1489 rows,
1490 &desc.relation_desc.expect("RelationDesc must exist"),
1491 )
1492 .into()
1493 }
1494 ExecuteResponse::SendingRowsImmediate { rows } => SqlResult::rows(
1495 client,
1496 rows,
1497 &desc.relation_desc.expect("RelationDesc must exist"),
1498 )
1499 .into(),
1500 ExecuteResponse::Subscribing {
1501 rx,
1502 ctx_extra,
1503 instance_id,
1504 } => StatementResult::Subscribe {
1505 tag: "SUBSCRIBE".into(),
1506 desc: desc.relation_desc.unwrap(),
1507 rx: RecordFirstRowStream::new(
1508 Box::new(UnboundedReceiverStream::new(rx)),
1509 execute_started,
1510 client,
1511 Some(instance_id),
1512 None,
1513 ),
1514 ctx_extra,
1515 },
1516 res @ (ExecuteResponse::Fetch { .. }
1517 | ExecuteResponse::CopyTo { .. }
1518 | ExecuteResponse::CopyFrom { .. }
1519 | ExecuteResponse::DeclaredCursor
1520 | ExecuteResponse::ClosedCursor) => SqlResult::err(
1521 client,
1522 Error::Unstructured(anyhow!(
1523 "internal error: encountered prohibited ExecuteResponse {:?}.\n\n
1524 This is a bug. Can you please file an bug report letting us know?\n
1525 https://github.com/MaterializeInc/materialize/discussions/new?category=bug-reports",
1526 ExecuteResponseKind::from(res)
1527 )),
1528 )
1529 .into(),
1530 })
1531}
1532
1533fn make_notices(client: &mut SessionClient) -> Vec<Notice> {
1534 client
1535 .session()
1536 .drain_notices()
1537 .into_iter()
1538 .map(|notice| Notice {
1539 message: notice.to_string(),
1540 code: notice.code().code().to_string(),
1541 severity: notice.severity().as_str().to_lowercase(),
1542 detail: notice.detail(),
1543 hint: notice.hint(),
1544 })
1545 .collect()
1546}
1547
1548fn is_txn_exit_stmt(stmt: &Statement<Raw>) -> bool {
1551 matches!(
1552 stmt,
1553 Statement::Commit(_) | Statement::Rollback(_) | Statement::Prepare(_)
1554 )
1555}
1556
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use super::WebSocketAuth;

    /// Deserializes a handful of JSON authentication payloads and checks that each one yields
    /// the expected `WebSocketAuth` value.
    #[mz_ore::test]
    fn smoke_test_websocket_auth_parse() {
        // Each entry pairs a JSON payload with the value it must deserialize into.
        let cases: Vec<(&'static str, WebSocketAuth)> = vec![
            // User/password credentials without an `options` key.
            (
                r#"{ "user": "mz", "password": "1234" }"#,
                WebSocketAuth::Basic {
                    user: String::from("mz"),
                    password: String::from("1234"),
                    options: BTreeMap::new(),
                },
            ),
            // User/password credentials with an explicitly empty `options` object.
            (
                r#"{ "user": "mz", "password": "1234", "options": {} }"#,
                WebSocketAuth::Basic {
                    user: String::from("mz"),
                    password: String::from("1234"),
                    options: BTreeMap::new(),
                },
            ),
            // Token credentials without an `options` key.
            (
                r#"{ "token": "i_am_a_token" }"#,
                WebSocketAuth::Bearer {
                    token: String::from("i_am_a_token"),
                    options: BTreeMap::new(),
                },
            ),
            // Token credentials with one option set.
            (
                r#"{ "token": "i_am_a_token", "options": { "foo": "bar" } }"#,
                WebSocketAuth::Bearer {
                    token: String::from("i_am_a_token"),
                    options: BTreeMap::from([(String::from("foo"), String::from("bar"))]),
                },
            ),
        ];

        for (json, expected) in cases {
            let parsed: WebSocketAuth = serde_json::from_str(json).unwrap();
            assert_eq!(parsed, expected);
        }
    }
}