1use std::borrow::Cow;
11use std::collections::BTreeMap;
12use std::fmt::{Debug, Display, Formatter};
13use std::future::Future;
14use std::pin::{self, Pin};
15use std::sync::Arc;
16use std::time::{Duration, Instant};
17
18use anyhow::bail;
19use chrono::{DateTime, Utc};
20use derivative::Derivative;
21use futures::{Stream, StreamExt};
22use itertools::Itertools;
23use mz_adapter_types::connection::{ConnectionId, ConnectionIdType};
24use mz_auth::password::Password;
25use mz_build_info::BuildInfo;
26use mz_compute_types::ComputeInstanceId;
27use mz_ore::channel::OneshotReceiverExt;
28use mz_ore::collections::CollectionExt;
29use mz_ore::id_gen::{IdAllocator, IdAllocatorInnerBitSet, MAX_ORG_ID, org_id_conn_bits};
30use mz_ore::instrument;
31use mz_ore::now::{EpochMillis, NowFn, to_datetime};
32use mz_ore::result::ResultExt;
33use mz_ore::task::AbortOnDropHandle;
34use mz_ore::thread::JoinOnDropHandle;
35use mz_ore::tracing::OpenTelemetryContext;
36use mz_repr::{CatalogItemId, ColumnIndex, Row, SqlScalarType};
37use mz_sql::ast::{Raw, Statement};
38use mz_sql::catalog::{EnvironmentId, SessionCatalog};
39use mz_sql::session::hint::ApplicationNameHint;
40use mz_sql::session::metadata::SessionMetadata;
41use mz_sql::session::user::SUPPORT_USER;
42use mz_sql::session::vars::{CLUSTER, OwnedVarInput, SystemVars, Var};
43use mz_sql_parser::parser::{ParserStatementError, StatementParseResult};
44use prometheus::Histogram;
45use serde_json::json;
46use tokio::sync::{mpsc, oneshot};
47use tracing::error;
48use uuid::Uuid;
49
50use crate::catalog::Catalog;
51use crate::command::{
52 AuthResponse, CatalogDump, CatalogSnapshot, Command, ExecuteResponse, Response,
53 SASLChallengeResponse, SASLVerifyProofResponse,
54};
55use crate::coord::{Coordinator, ExecuteContextExtra};
56use crate::error::AdapterError;
57use crate::metrics::Metrics;
58use crate::optimize::dataflows::{EvalTime, ExprPrepStyle};
59use crate::optimize::{self, Optimize};
60use crate::session::{
61 EndTransactionAction, PreparedStatement, Session, SessionConfig, StateRevision, TransactionId,
62};
63use crate::statement_logging::{StatementEndedExecutionReason, StatementExecutionStrategy};
64use crate::telemetry::{self, EventDetails, SegmentClientExt, StatementFailureType};
65use crate::webhook::AppendWebhookResponse;
66use crate::{AdapterNotice, AppendWebhookError, PeekResponseUnary, StartupResponse};
67
/// A handle holding a session id, a start instant, and a background thread.
///
/// NOTE(review): the thread is held in a `JoinOnDropHandle`, so dropping this
/// handle presumably joins it — confirm against `mz_ore::thread`.
pub struct Handle {
    /// UUID returned by [`Handle::session_id`].
    pub(crate) session_id: Uuid,
    /// Instant returned by [`Handle::start_instant`].
    pub(crate) start_instant: Instant,
    /// Never read; kept so the thread handle lives exactly as long as this `Handle`.
    pub(crate) _thread: JoinOnDropHandle<()>,
}
78
79impl Handle {
80 pub fn session_id(&self) -> Uuid {
86 self.session_id
87 }
88
89 pub fn start_instant(&self) -> Instant {
91 self.start_instant
92 }
93}
94
/// A cloneable handle used to send [`Command`]s to the coordinator and to create
/// and start sessions.
#[derive(Debug, Clone)]
pub struct Client {
    /// Build metadata handed to every new [`Session`].
    build_info: &'static BuildInfo,
    /// Channel to the coordinator. Each command is paired with the sender's
    /// OpenTelemetry context.
    inner_cmd_tx: mpsc::UnboundedSender<(OpenTelemetryContext, Command)>,
    /// Allocator backing [`Client::new_conn_id`].
    id_alloc: IdAllocator<IdAllocatorInnerBitSet>,
    /// Clock used for timestamps.
    now: NowFn,
    /// Adapter metrics, shared with sessions and session clients.
    metrics: Metrics,
    /// Environment this client serves; shown in the welcome notice and used for telemetry.
    environment_id: EnvironmentId,
    /// Optional Segment telemetry client, passed on to each [`SessionClient`].
    segment_client: Option<mz_segment::Client>,
}
112
impl Client {
    /// Creates a client that sends commands to the coordinator over `cmd_tx`.
    ///
    /// The connection-id allocator is parameterized with bits derived from the
    /// environment's organization id via `org_id_conn_bits`.
    /// NOTE(review): presumably this keeps connection ids distinct per
    /// organization — confirm against `mz_ore::id_gen`.
    pub(crate) fn new(
        build_info: &'static BuildInfo,
        cmd_tx: mpsc::UnboundedSender<(OpenTelemetryContext, Command)>,
        metrics: Metrics,
        now: NowFn,
        environment_id: EnvironmentId,
        segment_client: Option<mz_segment::Client>,
    ) -> Client {
        let env_lower = org_id_conn_bits(&environment_id.organization_id());
        Client {
            build_info,
            inner_cmd_tx: cmd_tx,
            id_alloc: IdAllocator::new(1, MAX_ORG_ID, env_lower),
            now,
            metrics,
            environment_id,
            segment_client,
        }
    }

    /// Allocates a fresh connection id.
    ///
    /// # Errors
    ///
    /// Returns [`AdapterError::IdExhaustionError`] when the allocator has no ids left.
    pub fn new_conn_id(&self) -> Result<ConnectionId, AdapterError> {
        self.id_alloc.alloc().ok_or(AdapterError::IdExhaustionError)
    }

    /// Creates a new local [`Session`] from `config`.
    ///
    /// This performs no coordinator round trip; pass the session to
    /// [`Client::startup`] to announce it to the coordinator.
    pub fn new_session(&self, config: SessionConfig) -> Session {
        Session::new(self.build_info, config, self.metrics().session_metrics())
    }

    /// Asks the coordinator to verify `password` for the role named `user`.
    ///
    /// # Errors
    ///
    /// Propagates the coordinator's error if authentication fails.
    ///
    /// # Panics
    ///
    /// Panics if the coordinator drops the response sender.
    pub async fn authenticate(
        &self,
        user: &String,
        password: &Password,
    ) -> Result<AuthResponse, AdapterError> {
        let (tx, rx) = oneshot::channel();
        self.send(Command::AuthenticatePassword {
            role_name: user.to_string(),
            password: Some(password.clone()),
            tx,
        });
        let response = rx.await.expect("sender dropped")?;
        Ok(response)
    }

    /// Requests a SASL challenge for role `user`, given the client's nonce.
    ///
    /// # Errors
    ///
    /// Propagates the coordinator's error.
    ///
    /// # Panics
    ///
    /// Panics if the coordinator drops the response sender.
    pub async fn generate_sasl_challenge(
        &self,
        user: &String,
        client_nonce: &String,
    ) -> Result<SASLChallengeResponse, AdapterError> {
        let (tx, rx) = oneshot::channel();
        self.send(Command::AuthenticateGetSASLChallenge {
            role_name: user.to_string(),
            nonce: client_nonce.to_string(),
            tx,
        });
        let response = rx.await.expect("sender dropped")?;
        Ok(response)
    }

    /// Asks the coordinator to verify a SASL client proof for role `user`.
    ///
    /// NOTE(review): the `nonce` argument is forwarded as the command's
    /// `auth_message` field; confirm callers pass the full auth message here.
    ///
    /// # Errors
    ///
    /// Propagates the coordinator's error if verification fails.
    ///
    /// # Panics
    ///
    /// Panics if the coordinator drops the response sender.
    pub async fn verify_sasl_proof(
        &self,
        user: &String,
        proof: &String,
        nonce: &String,
        mock_hash: &String,
    ) -> Result<SASLVerifyProofResponse, AdapterError> {
        let (tx, rx) = oneshot::channel();
        self.send(Command::AuthenticateVerifySASLProof {
            role_name: user.to_string(),
            proof: proof.to_string(),
            auth_message: nonce.to_string(),
            mock_hash: mock_hash.to_string(),
            tx,
        });
        let response = rx.await.expect("sender dropped")?;
        Ok(response)
    }

    /// Announces `session` to the coordinator and returns a [`SessionClient`]
    /// that owns it.
    ///
    /// After the coordinator replies, this:
    /// - initializes the session's role metadata,
    /// - applies the persisted session defaults from the response (logging, not
    ///   failing, on errors) and commits them,
    /// - hands `write_notify` to the session via `set_builtin_table_updates`,
    /// - queues startup notices: a welcome message, an unknown-database notice,
    ///   and a missing-cluster notice, each gated by session variables.
    ///
    /// # Errors
    ///
    /// Returns the coordinator's error if startup is rejected.
    ///
    /// # Panics
    ///
    /// Panics if the coordinator drops the response sender.
    #[mz_ore::instrument(level = "debug")]
    pub async fn startup(&self, session: Session) -> Result<SessionClient, AdapterError> {
        let user = session.user().clone();
        let conn_id = session.conn_id().clone();
        let secret_key = session.secret_key();
        let uuid = session.uuid();
        let client_ip = session.client_ip();
        let application_name = session.application_name().into();
        let notice_tx = session.retain_notice_transmitter();

        let (tx, rx) = oneshot::channel();

        // If `rx` goes away without the response being consumed (e.g. this
        // future is cancelled mid-startup), ask the coordinator to terminate the
        // connection. NOTE(review): semantics assumed from
        // `OneshotReceiverExt::with_guard` — confirm.
        let rx = rx.with_guard(|_| {
            self.send(Command::Terminate {
                conn_id: conn_id.clone(),
                tx: None,
            });
        });

        self.send(Command::Startup {
            tx,
            user,
            conn_id: conn_id.clone(),
            secret_key,
            uuid,
            client_ip: client_ip.copied(),
            application_name,
            notice_tx,
        });

        let response = rx.await.expect("sender dropped")?;

        // From here on the session is owned by `client`; `SessionClient`'s
        // `Drop` impl sends `Terminate` if it is dropped.
        let mut client = SessionClient {
            inner: Some(self.clone()),
            session: Some(session),
            timeouts: Timeout::new(),
            environment_id: self.environment_id.clone(),
            segment_client: self.segment_client.clone(),
        };

        let StartupResponse {
            role_id,
            write_notify,
            session_defaults,
            catalog,
        } = response;

        let session = client.session();
        session.initialize_role_metadata(role_id);
        // Apply persisted defaults; a bad default is logged rather than failing startup.
        let vars_mut = session.vars_mut();
        for (name, val) in session_defaults {
            if let Err(err) = vars_mut.set_default(&name, val.borrow()) {
                tracing::error!("failed to set peristed default, {err:?}");
            }
        }
        // Commit the variable changes made above.
        session
            .vars_mut()
            .end_transaction(EndTransactionAction::Commit);

        session.set_builtin_table_updates(write_notify);

        let catalog = catalog.for_session(session);

        let cluster_active = session.vars().cluster().to_string();
        if session.vars().welcome_message() {
            // Flag the active cluster in the welcome text if the catalog can't resolve it.
            let cluster_info = if catalog.resolve_cluster(Some(&cluster_active)).is_err() {
                format!("{cluster_active} (does not exist)")
            } else {
                cluster_active.to_string()
            };

            session.add_notice(AdapterNotice::Welcome(format!(
                "connected to Materialize v{}
 Org ID: {}
 Region: {}
 User: {}
 Cluster: {}
 Database: {}
 {}
 Session UUID: {}

Issue a SQL query to get started. Need help?
 View documentation: https://materialize.com/s/docs
 Join our Slack community: https://materialize.com/s/chat
 ",
                session.vars().build_info().semver_version(),
                self.environment_id.organization_id(),
                self.environment_id.region(),
                session.vars().user().name,
                cluster_info,
                session.vars().database(),
                // A single-entry search path is shown as "Schema", otherwise the full list.
                match session.vars().search_path() {
                    [schema] => format!("Schema: {}", schema),
                    schemas => format!(
                        "Search path: {}",
                        schemas.iter().map(|id| id.to_string()).join(", ")
                    ),
                },
                session.uuid(),
            )));
        }

        if session.vars().current_object_missing_warnings() {
            if catalog.active_database().is_none() {
                let db = session.vars().database().into();
                session.add_notice(AdapterNotice::UnknownSessionDatabase(db));
            }
        }

        let cluster_var = session
            .vars()
            .inspect(CLUSTER.name())
            .expect("cluster should exist");
        if session.vars().current_object_missing_warnings()
            && catalog.resolve_cluster(Some(&cluster_active)).is_err()
        {
            // Work out which layer chose the missing cluster (session value,
            // role default, or system default) so the notice can suggest the
            // matching fix.
            let cluster_notice = 'notice: {
                if cluster_var.inspect_session_value().is_some() {
                    break 'notice Some(AdapterNotice::DefaultClusterDoesNotExist {
                        name: cluster_active,
                        kind: "session",
                        suggested_action: "Pick an extant cluster with SET CLUSTER = name. Run SHOW CLUSTERS to see available clusters.".into(),
                    });
                }

                let role_default = catalog.get_role(catalog.active_role_id());
                let role_cluster = match role_default.vars().get(CLUSTER.name()) {
                    Some(OwnedVarInput::Flat(name)) => Some(name),
                    None => None,
                    // A role default for `cluster` is expected to be a flat value;
                    // log and emit no notice otherwise.
                    Some(v @ OwnedVarInput::SqlSet(_)) => {
                        tracing::warn!(?v, "SqlSet found for cluster Role Default");
                        break 'notice None;
                    }
                };

                let alter_role = "with `ALTER ROLE <role> SET cluster TO <cluster>;`";
                match role_cluster {
                    // No role default: the system default is what's missing.
                    None => Some(AdapterNotice::DefaultClusterDoesNotExist {
                        name: cluster_active,
                        kind: "system",
                        suggested_action: format!(
                            "Set a default cluster for the current role {alter_role}."
                        ),
                    }),
                    // The role's own default points at a missing cluster.
                    Some(_) => Some(AdapterNotice::DefaultClusterDoesNotExist {
                        name: cluster_active,
                        kind: "role",
                        suggested_action: format!(
                            "Change the default cluster for the current role {alter_role}."
                        ),
                    }),
                }
            };

            if let Some(notice) = cluster_notice {
                session.add_notice(notice);
            }
        }

        Ok(client)
    }

    /// Forwards a cancellation request for `conn_id`, authenticated by
    /// `secret_key`, to the coordinator. Fire-and-forget.
    pub fn cancel_request(&self, conn_id: ConnectionIdType, secret_key: u32) {
        self.send(Command::CancelRequest {
            conn_id,
            secret_key,
        });
    }

    /// Runs exactly one SQL statement as the support user and returns its rows
    /// as a stream.
    ///
    /// A dedicated session is started for this call and moved into the
    /// returned stream. It is dropped, which terminates it, once the rows are
    /// exhausted or when the stream itself is dropped.
    ///
    /// # Errors
    ///
    /// Fails if no connection id is available, if startup or parsing fails, if
    /// `sql` does not contain exactly one statement, or if execution does not
    /// produce a streaming row response.
    pub async fn support_execute_one(
        &self,
        sql: &str,
    ) -> Result<Pin<Box<dyn Stream<Item = PeekResponseUnary> + Send + Sync>>, anyhow::Error> {
        let conn_id = self.new_conn_id()?;
        let session = self.new_session(SessionConfig {
            conn_id,
            uuid: Uuid::new_v4(),
            user: SUPPORT_USER.name.clone(),
            client_ip: None,
            external_metadata_rx: None,
            internal_user_metadata: None,
            helm_chart_version: None,
        });
        let mut session_client = self.startup(session).await?;

        let stmts = mz_sql::parse::parse(sql)?;
        if stmts.len() != 1 {
            bail!("must supply exactly one query");
        }
        let StatementParseResult { ast: stmt, sql } = stmts.into_element();

        const EMPTY_PORTAL: &str = "";
        session_client.start_transaction(Some(1))?;
        session_client
            .declare(EMPTY_PORTAL.into(), stmt, sql.to_string())
            .await?;

        // A never-resolving cancel future: this execution is never cancelled.
        match session_client
            .execute(EMPTY_PORTAL.into(), futures::future::pending(), None)
            .await?
        {
            (ExecuteResponse::SendingRowsStreaming { mut rows, .. }, _) => {
                // The stream owns `session_client` so the session outlives the rows.
                let owning_response_stream = async_stream::stream! {
                    while let Some(rows) = rows.next().await {
                        yield rows;
                    }
                    drop(session_client);
                };
                Ok(Box::pin(owning_response_stream))
            }
            r => bail!("unsupported response type: {r:?}"),
        }
    }

    /// Returns the adapter metrics.
    pub fn metrics(&self) -> &Metrics {
        &self.metrics
    }

    /// Returns the current time according to this client's clock.
    pub fn now(&self) -> DateTime<Utc> {
        to_datetime((self.now)())
    }

    /// Asks the coordinator for an appender for the webhook source
    /// `database.schema.name`.
    ///
    /// # Errors
    ///
    /// Returns the coordinator's error. If the response channel is closed, an
    /// `anyhow` error is converted into [`AppendWebhookError`] via `?`.
    pub async fn get_webhook_appender(
        &self,
        database: String,
        schema: String,
        name: String,
    ) -> Result<AppendWebhookResponse, AppendWebhookError> {
        let (tx, rx) = oneshot::channel();

        self.send(Command::GetWebhook {
            database,
            schema,
            name,
            tx,
        });

        let response = rx
            .await
            .map_err(|_| anyhow::anyhow!("failed to receive webhook response"))?;

        response
    }

    /// Fetches the coordinator's current system variables.
    ///
    /// # Panics
    ///
    /// Panics if the coordinator drops the response sender.
    pub async fn get_system_vars(&self) -> SystemVars {
        let (tx, rx) = oneshot::channel();
        self.send(Command::GetSystemVars { tx });
        rx.await.expect("coordinator unexpectedly gone")
    }

    /// Sends `cmd` to the coordinator, tagged with the current OpenTelemetry context.
    ///
    /// # Panics
    ///
    /// Panics if the coordinator's receiving end has been dropped.
    #[instrument(level = "debug")]
    fn send(&self, cmd: Command) {
        self.inner_cmd_tx
            .send((OpenTelemetryContext::obtain(), cmd))
            .expect("coordinator unexpectedly gone");
    }
}
512
/// A client bound to a single started session; created by [`Client::startup`].
///
/// Dropping it sends a `Terminate` command for its connection, unless
/// [`SessionClient::terminate`] has already cleared `inner`.
pub struct SessionClient {
    /// The underlying client; set to `None` by `terminate`.
    inner: Option<Client>,
    /// The session. It is temporarily `None` while `send_with_cancel` has
    /// handed it to the coordinator, and restored when the response arrives.
    session: Option<Session>,
    /// Pending session timeouts (currently only idle-in-transaction).
    timeouts: Timeout,
    /// Optional Segment client used for telemetry events.
    segment_client: Option<mz_segment::Client>,
    /// Environment reported with telemetry events.
    environment_id: EnvironmentId,
}
528
impl SessionClient {
    /// Parses `sql` with `mz_sql::parse::parse_with_limit`.
    ///
    /// When a statement-level parse error occurs (`Ok(Err(_))`), a telemetry
    /// event is recorded before the error is returned unchanged. The outer
    /// `Err(String)` is passed through untouched.
    /// NOTE(review): the outer error is presumably the size-limit failure — confirm.
    pub fn parse<'a>(
        &self,
        sql: &'a str,
    ) -> Result<Result<Vec<StatementParseResult<'a>>, ParserStatementError>, String> {
        match mz_sql::parse::parse_with_limit(sql) {
            Ok(Err(e)) => {
                self.track_statement_parse_failure(&e);
                Ok(Err(e))
            }
            r => r,
        }
    }

    /// Reports a parse failure to Segment, best-effort.
    ///
    /// Returns silently unless every one of these holds: the user has external
    /// metadata with a user id, a Segment client is configured, the error
    /// identifies a statement kind, and `analyze_audited_statement` classifies
    /// that kind.
    fn track_statement_parse_failure(&self, parse_error: &ParserStatementError) {
        let session = self.session.as_ref().expect("session invariant violated");
        let Some(user_id) = session.user().external_metadata.as_ref().map(|m| m.user_id) else {
            return;
        };
        let Some(segment_client) = &self.segment_client else {
            return;
        };
        let Some(statement_kind) = parse_error.statement else {
            return;
        };
        let Some((action, object_type)) = telemetry::analyze_audited_statement(statement_kind)
        else {
            return;
        };
        let event_type = StatementFailureType::ParseFailure;
        let event_name = format!(
            "{} {} {}",
            object_type.as_title_case(),
            action.as_title_case(),
            event_type.as_title_case(),
        );
        segment_client.environment_track(
            &self.environment_id,
            event_name,
            json!({
                "statement_kind": statement_kind,
                "error": &parse_error.error,
            }),
            EventDetails {
                user_id: Some(user_id),
                application_name: Some(session.application_name()),
                ..Default::default()
            },
        );
    }

    /// Returns prepared statement `name` after verifying it against a fresh
    /// catalog snapshot.
    ///
    /// # Errors
    ///
    /// Propagates the error from `Coordinator::verify_prepared_statement`.
    /// NOTE(review): the trailing `expect` assumes verification rejects unknown names.
    pub async fn get_prepared_statement(
        &mut self,
        name: &str,
    ) -> Result<&PreparedStatement, AdapterError> {
        let catalog = self.catalog_snapshot("get_prepared_statement").await;
        Coordinator::verify_prepared_statement(&catalog, self.session(), name)?;
        Ok(self
            .session()
            .get_prepared_statement_unverified(name)
            .expect("must exist"))
    }

    /// Describes `stmt` against a fresh catalog snapshot and stores it as
    /// prepared statement `name`.
    ///
    /// The statement is tagged with the catalog and session revisions it was
    /// described against.
    ///
    /// # Errors
    ///
    /// Propagates the error from `Coordinator::describe`.
    pub async fn prepare(
        &mut self,
        name: String,
        stmt: Option<Statement<Raw>>,
        sql: String,
        param_types: Vec<Option<SqlScalarType>>,
    ) -> Result<(), AdapterError> {
        let catalog = self.catalog_snapshot("prepare").await;

        // Test hook: the `async_prepare` fail point can inject a one-second pause
        // between taking the snapshot and describing the statement.
        let mut async_pause = false;
        (|| {
            fail::fail_point!("async_prepare", |val| {
                async_pause = val.map_or(false, |val| val.parse().unwrap_or(false))
            });
        })();
        if async_pause {
            tokio::time::sleep(Duration::from_secs(1)).await;
        };

        let desc = Coordinator::describe(&catalog, self.session(), stmt.clone(), param_types)?;
        let now = self.now();
        let state_revision = StateRevision {
            catalog_revision: catalog.transient_revision(),
            session_state_revision: self.session().state_revision(),
        };
        self.session()
            .set_prepared_statement(name, stmt, sql, desc, state_revision, now);
        Ok(())
    }

    /// Describes `stmt` and binds it to portal `name`, with no parameters and a
    /// text result format for every output column.
    ///
    /// # Errors
    ///
    /// Propagates errors from `Coordinator::describe` and `Session::set_portal`.
    #[mz_ore::instrument(level = "debug")]
    pub async fn declare(
        &mut self,
        name: String,
        stmt: Statement<Raw>,
        sql: String,
    ) -> Result<(), AdapterError> {
        let catalog = self.catalog_snapshot("declare").await;
        let param_types = vec![];
        let desc =
            Coordinator::describe(&catalog, self.session(), Some(stmt.clone()), param_types)?;
        let params = vec![];
        let result_formats = vec![mz_pgwire_common::Format::Text; desc.arity()];
        let now = self.now();
        let logging = self.session().mint_logging(sql, Some(&stmt), now);
        let state_revision = StateRevision {
            catalog_revision: catalog.transient_revision(),
            session_state_revision: self.session().state_revision(),
        };
        self.session().set_portal(
            name,
            desc,
            Some(stmt),
            logging,
            params,
            result_formats,
            state_revision,
        )?;
        Ok(())
    }

    /// Executes portal `portal_name`.
    ///
    /// If `cancel_future` resolves first, a privileged cancel request is sent
    /// for this connection, and the call still waits for the coordinator's
    /// response. Returns the response together with the [`Instant`] at which
    /// execution was requested.
    #[mz_ore::instrument(level = "debug")]
    pub async fn execute(
        &mut self,
        portal_name: String,
        cancel_future: impl Future<Output = std::io::Error> + Send,
        outer_ctx_extra: Option<ExecuteContextExtra>,
    ) -> Result<(ExecuteResponse, Instant), AdapterError> {
        let execute_started = Instant::now();
        let response = self
            .send_with_cancel(
                |tx, session| Command::Execute {
                    portal_name,
                    session,
                    tx,
                    outer_ctx_extra,
                },
                cancel_future,
            )
            .await?;
        Ok((response, execute_started))
    }

    /// Current time in epoch milliseconds from the inner client's clock.
    fn now(&self) -> EpochMillis {
        (self.inner().now)()
    }

    /// Current time from the inner client's clock, as a `DateTime<Utc>`.
    fn now_datetime(&self) -> DateTime<Utc> {
        to_datetime(self.now())
    }

    /// Starts a transaction on the session.
    ///
    /// `Some(stmts)` calls `start_transaction_implicit` with that statement
    /// count, which cannot fail. `None` calls `start_transaction` (presumably
    /// an explicit `BEGIN`-style transaction) and propagates its error.
    pub fn start_transaction(&mut self, implicit: Option<usize>) -> Result<(), AdapterError> {
        let now = self.now_datetime();
        let session = self.session.as_mut().expect("session invariant violated");
        let result = match implicit {
            None => session.start_transaction(now, None, None),
            Some(stmts) => {
                session.start_transaction_implicit(now, stmts);
                Ok(())
            }
        };
        result
    }

    /// Asks the coordinator to end the current transaction according to `action`.
    #[instrument(level = "debug")]
    pub async fn end_transaction(
        &mut self,
        action: EndTransactionAction,
    ) -> Result<ExecuteResponse, AdapterError> {
        let res = self
            .send(|tx, session| Command::Commit {
                action,
                session,
                tx,
            })
            .await;
        // Clear local transaction state whether or not the command succeeded.
        let _ = self.session().clear_transaction();
        res
    }

    /// Marks the session's current transaction as failed.
    ///
    /// `Session::fail_transaction` takes the session by value, so it is moved
    /// out and put back.
    pub fn fail_transaction(&mut self) {
        let session = self.session.take().expect("session invariant violated");
        let session = session.fail_transaction();
        self.session = Some(session);
    }

    /// Fetches a catalog snapshot from the coordinator.
    ///
    /// The round-trip latency is recorded in `catalog_snapshot_seconds`,
    /// labeled with `context`.
    #[instrument(level = "debug")]
    pub async fn catalog_snapshot(&self, context: &str) -> Arc<Catalog> {
        let start = std::time::Instant::now();
        let CatalogSnapshot { catalog } = self
            .send_without_session(|tx| Command::CatalogSnapshot { tx })
            .await;
        self.inner()
            .metrics()
            .catalog_snapshot_seconds
            .with_label_values(&[context])
            .observe(start.elapsed().as_secs_f64());
        catalog
    }

    /// Dumps a fresh catalog snapshot.
    pub async fn dump_catalog(&self) -> Result<CatalogDump, AdapterError> {
        let catalog = self.catalog_snapshot("dump_catalog").await;
        catalog.dump().map_err(AdapterError::from)
    }

    /// Runs the catalog's own consistency check on a fresh snapshot.
    pub async fn check_catalog(&self) -> Result<(), serde_json::Value> {
        let catalog = self.catalog_snapshot("check_catalog").await;
        catalog.check_consistency()
    }

    /// Asks the coordinator to check its consistency.
    ///
    /// Any inconsistencies are serialized to JSON. If serialization fails, a
    /// placeholder JSON string is returned instead.
    pub async fn check_coordinator(&self) -> Result<(), serde_json::Value> {
        self.send_without_session(|tx| Command::CheckConsistency { tx })
            .await
            .map_err(|inconsistencies| {
                serde_json::to_value(inconsistencies).unwrap_or_else(|_| {
                    serde_json::Value::String("failed to serialize inconsistencies".to_string())
                })
            })
    }

    /// Asks the coordinator to dump its state as JSON.
    pub async fn dump_coordinator_state(&self) -> Result<serde_json::Value, anyhow::Error> {
        self.send_without_session(|tx| Command::Dump { tx }).await
    }

    /// Tells the coordinator that the execution described by `data` ended for
    /// `reason`. Nothing is sent when `data` is trivial.
    pub fn retire_execute(&self, data: ExecuteContextExtra, reason: StatementEndedExecutionReason) {
        if !data.is_trivial() {
            let cmd = Command::RetireExecute { data, reason };
            self.inner().send(cmd);
        }
    }

    /// Inserts `rows` into `target_id`.
    ///
    /// The rows are planned as a COPY FROM, the resulting values are
    /// optimized, and they are inserted as a constant. `ctx_extra` is always
    /// retired with the outcome before returning.
    pub async fn insert_rows(
        &mut self,
        target_id: CatalogItemId,
        target_name: String,
        columns: Vec<ColumnIndex>,
        rows: Vec<Row>,
        ctx_extra: ExecuteContextExtra,
    ) -> Result<ExecuteResponse, AdapterError> {
        let pcx = self.session().pcx().clone();

        let session_meta = self.session().meta();

        let catalog = self.catalog_snapshot("insert_rows").await;
        let conn_catalog = catalog.for_session(self.session());
        let catalog_state = conn_catalog.state();

        let optimizer_config = optimize::OptimizerConfig::from(conn_catalog.system_vars());
        // No logical timestamp is available for evaluation here.
        let prep = ExprPrepStyle::OneShot {
            logical_time: EvalTime::NotAvailable,
            session: &session_meta,
            catalog_state,
        };
        let mut optimizer =
            optimize::view::Optimizer::new_with_prep_no_limit(optimizer_config.clone(), None, prep);

        let result: Result<_, AdapterError> = mz_sql::plan::plan_copy_from(
            &pcx,
            &conn_catalog,
            target_id,
            target_name,
            columns,
            rows,
        )
        .err_into()
        .and_then(|values| optimizer.optimize(values).err_into())
        .and_then(|values| {
            Coordinator::insert_constant(&catalog, self.session(), target_id, values.into_inner())
        });
        self.retire_execute(ctx_extra, (&result).into());
        result
    }

    /// Fetches the coordinator's current system variables.
    pub async fn get_system_vars(&self) -> SystemVars {
        self.inner().get_system_vars().await
    }

    /// Asks the coordinator to set the given system variables on behalf of
    /// this connection.
    pub async fn set_system_vars(
        &mut self,
        vars: BTreeMap<String, String>,
    ) -> Result<(), AdapterError> {
        let conn_id = self.session().conn_id().clone();
        self.send_without_session(|tx| Command::SetSystemVars { vars, conn_id, tx })
            .await
    }

    /// Terminates the connection and waits for the coordinator to acknowledge it.
    ///
    /// Failures are logged, not returned. Afterwards `inner` is cleared, so
    /// `Drop` does not send a second `Terminate`.
    pub async fn terminate(&mut self) {
        let conn_id = self.session().conn_id().clone();
        let res = self
            .send_without_session(|tx| Command::Terminate {
                conn_id,
                tx: Some(tx),
            })
            .await;
        if let Err(e) = res {
            error!("Unable to terminate session: {e:?}");
        }
        self.inner = None;
    }

    /// Returns the session.
    ///
    /// # Panics
    ///
    /// Panics if the session is currently lent out (see `send_with_cancel`).
    pub fn session(&mut self) -> &mut Session {
        self.session.as_mut().expect("session invariant violated")
    }

    /// Returns the inner client.
    ///
    /// # Panics
    ///
    /// Panics after [`SessionClient::terminate`] has run.
    pub fn inner(&self) -> &Client {
        self.inner.as_ref().expect("inner invariant violated")
    }

    /// Sends a command built by `f` without lending out the session, and waits
    /// for the reply.
    ///
    /// # Panics
    ///
    /// Panics if the coordinator drops the response sender.
    async fn send_without_session<T, F>(&self, f: F) -> T
    where
        F: FnOnce(oneshot::Sender<T>) -> Command,
    {
        let (tx, rx) = oneshot::channel();
        self.inner().send(f(tx));
        rx.await.expect("sender dropped")
    }

    /// Like `send_with_cancel`, but never cancelled.
    #[instrument(level = "debug")]
    async fn send<T, F>(&mut self, f: F) -> Result<T, AdapterError>
    where
        F: FnOnce(oneshot::Sender<Response<T>>, Session) -> Command,
    {
        self.send_with_cancel(f, futures::future::pending()).await
    }

    /// Lends the session to the coordinator in a command built by `f`, and
    /// waits for the response, which hands the session back.
    ///
    /// While waiting, if `cancel_future` resolves, a single
    /// `PrivilegedCancelRequest` is sent for this connection and waiting
    /// continues. For `Execute` and `GetWebhook` commands, a `commands`
    /// metric is incremented with the outcome.
    ///
    /// # Panics
    ///
    /// Panics if the session is already lent out, or if the coordinator drops
    /// the response sender.
    #[instrument(level = "debug")]
    async fn send_with_cancel<T, F>(
        &mut self,
        f: F,
        cancel_future: impl Future<Output = std::io::Error> + Send,
    ) -> Result<T, AdapterError>
    where
        F: FnOnce(oneshot::Sender<Response<T>>, Session) -> Command,
    {
        let session = self.session.take().expect("session invariant violated");
        let mut typ = None;
        let application_name = session.application_name();
        let name_hint = ApplicationNameHint::from_str(application_name);
        let conn_id = session.conn_id().clone();
        let (tx, rx) = oneshot::channel();

        // Split the borrow of `self` so the guard below can hold `client_session`
        // while `inner_client` is used to send.
        let Self {
            inner: inner_client,
            session: client_session,
            ..
        } = self;

        let inner_client = inner_client.as_ref().expect("inner invariant violated");

        // If this future is dropped before the response is consumed, the guard
        // puts the returned session back into `self`. NOTE(review): semantics
        // assumed from `OneshotReceiverExt::with_guard` — confirm.
        let mut guarded_rx = rx.with_guard(|response: Response<_>| {
            *client_session = Some(response.session);
        });

        inner_client.send({
            let cmd = f(tx, session);
            // Only these command kinds are counted in the `commands` metric.
            match cmd {
                Command::Execute { .. } => typ = Some("execute"),
                Command::GetWebhook { .. } => typ = Some("webhook"),
                Command::Startup { .. }
                | Command::AuthenticatePassword { .. }
                | Command::AuthenticateGetSASLChallenge { .. }
                | Command::AuthenticateVerifySASLProof { .. }
                | Command::CatalogSnapshot { .. }
                | Command::Commit { .. }
                | Command::CancelRequest { .. }
                | Command::PrivilegedCancelRequest { .. }
                | Command::GetSystemVars { .. }
                | Command::SetSystemVars { .. }
                | Command::Terminate { .. }
                | Command::RetireExecute { .. }
                | Command::CheckConsistency { .. }
                | Command::Dump { .. } => {}
            };
            cmd
        });

        let mut cancel_future = pin::pin!(cancel_future);
        let mut cancelled = false;
        loop {
            tokio::select! {
                res = &mut guarded_rx => {
                    // Drop the guard first; it borrows `client_session`, which
                    // is written below.
                    drop(guarded_rx);

                    let res = res.expect("sender dropped");
                    let status = res.result.is_ok().then_some("success").unwrap_or("error");
                    if let Err(err) = res.result.as_ref() {
                        if name_hint.should_trace_errors() {
                            tracing::warn!(?err, ?name_hint, "adapter response error");
                        }
                    }

                    if let Some(typ) = typ {
                        inner_client
                            .metrics
                            .commands
                            .with_label_values(&[typ, status, name_hint.as_str()])
                            .inc();
                    }
                    *client_session = Some(res.session);
                    return res.result;
                },
                // Cancel at most once; afterwards keep waiting for the response,
                // which carries the session back.
                _err = &mut cancel_future, if !cancelled => {
                    cancelled = true;
                    inner_client.send(Command::PrivilegedCancelRequest {
                        conn_id: conn_id.clone(),
                    });
                }
            };
        }
    }

    /// Schedules an idle-in-transaction timeout for the current transaction.
    ///
    /// Does nothing if the configured `idle_in_transaction_session_timeout` is
    /// zero or if there is no active transaction.
    pub fn add_idle_in_transaction_session_timeout(&mut self) {
        let session = self.session();
        let timeout_dur = session.vars().idle_in_transaction_session_timeout();
        if !timeout_dur.is_zero() {
            let timeout_dur = timeout_dur.clone();
            if let Some(txn) = session.transaction().inner() {
                let txn_id = txn.id.clone();
                let timeout = TimeoutType::IdleInTransactionSession(txn_id);
                self.timeouts.add_timeout(timeout, timeout_dur);
            }
        }
    }

    /// Cancels the idle-in-transaction timeout for the current transaction, if any.
    pub fn remove_idle_in_transaction_session_timeout(&mut self) {
        let session = self.session();
        if let Some(txn) = session.transaction().inner() {
            let txn_id = txn.id.clone();
            self.timeouts
                .remove_timeout(&TimeoutType::IdleInTransactionSession(txn_id));
        }
    }

    /// Waits for the next session timeout to fire.
    pub async fn recv_timeout(&mut self) -> Option<TimeoutType> {
        self.timeouts.recv().await
    }
}
1049
1050impl Drop for SessionClient {
1051 fn drop(&mut self) {
1052 if let Some(session) = self.session.take() {
1056 if let Some(inner) = &self.inner {
1059 inner.send(Command::Terminate {
1060 conn_id: session.conn_id().clone(),
1061 tx: None,
1062 })
1063 }
1064 }
1065 }
1066}
1067
/// A kind of session timeout tracked for a [`SessionClient`].
#[derive(Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Debug)]
pub enum TimeoutType {
    /// The session stayed idle inside the identified transaction for longer
    /// than its `idle_in_transaction_session_timeout`.
    IdleInTransactionSession(TransactionId),
}
1072
1073impl Display for TimeoutType {
1074 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1075 match self {
1076 TimeoutType::IdleInTransactionSession(txn_id) => {
1077 writeln!(f, "Idle in transaction session for transaction '{txn_id}'")
1078 }
1079 }
1080 }
1081}
1082
1083impl From<TimeoutType> for AdapterError {
1084 fn from(timeout: TimeoutType) -> Self {
1085 match timeout {
1086 TimeoutType::IdleInTransactionSession(_) => {
1087 AdapterError::IdleInTransactionSessionTimeout
1088 }
1089 }
1090 }
1091}
1092
/// Tracks pending timeouts for a session.
///
/// Each scheduled timeout is a spawned task that sleeps and then sends its
/// [`TimeoutType`] on `tx`. Fired timeouts are read back from `rx`.
struct Timeout {
    /// Sender cloned into each timeout task; also used to re-queue entries.
    tx: mpsc::UnboundedSender<TimeoutType>,
    /// Receives timeouts that have fired.
    rx: mpsc::UnboundedReceiver<TimeoutType>,
    /// Task handles for timeouts that have been scheduled and not removed.
    /// NOTE(review): `AbortOnDropHandle` presumably aborts its task when
    /// dropped, which is how removal cancels a timeout — confirm.
    active_timeouts: BTreeMap<TimeoutType, AbortOnDropHandle<()>>,
}
1098
1099impl Timeout {
1100 fn new() -> Self {
1101 let (tx, rx) = mpsc::unbounded_channel();
1102 Timeout {
1103 tx,
1104 rx,
1105 active_timeouts: BTreeMap::new(),
1106 }
1107 }
1108
1109 async fn recv(&mut self) -> Option<TimeoutType> {
1118 self.rx.recv().await
1119 }
1120
1121 fn add_timeout(&mut self, timeout: TimeoutType, duration: Duration) {
1122 let tx = self.tx.clone();
1123 let timeout_key = timeout.clone();
1124 let handle = mz_ore::task::spawn(|| format!("{timeout_key}"), async move {
1125 tokio::time::sleep(duration).await;
1126 let _ = tx.send(timeout);
1127 })
1128 .abort_on_drop();
1129 self.active_timeouts.insert(timeout_key, handle);
1130 }
1131
1132 fn remove_timeout(&mut self, timeout: &TimeoutType) {
1133 self.active_timeouts.remove(timeout);
1134
1135 let mut timeouts = Vec::new();
1137 while let Ok(pending_timeout) = self.rx.try_recv() {
1138 if timeout != &pending_timeout {
1139 timeouts.push(pending_timeout);
1140 }
1141 }
1142 for pending_timeout in timeouts {
1143 self.tx.send(pending_timeout).expect("rx is in this struct");
1144 }
1145 }
1146}
1147
/// Wraps a stream of peek responses and records the time, measured from
/// `execute_started`, at which the first batch of rows arrives.
#[derive(Derivative)]
#[derivative(Debug)]
pub struct RecordFirstRowStream {
    /// The wrapped response stream (excluded from `Debug`).
    #[derivative(Debug = "ignore")]
    pub rows: Box<dyn Stream<Item = PeekResponseUnary> + Unpin + Send + Sync>,
    /// When execution of the statement began.
    pub execute_started: Instant,
    /// Histogram that receives the single time-to-first-row observation.
    pub time_to_first_row_seconds: Histogram,
    /// Whether a `PeekResponseUnary::Rows` item has been seen yet.
    pub saw_rows: bool,
    /// When the first rows were observed, once they have been.
    pub recorded_first_row_instant: Option<Instant>,
    /// Whether the wrapped stream has returned `None`.
    pub no_more_rows: bool,
}
1169
1170impl RecordFirstRowStream {
1171 pub fn new(
1173 rows: Box<dyn Stream<Item = PeekResponseUnary> + Unpin + Send + Sync>,
1174 execute_started: Instant,
1175 client: &SessionClient,
1176 instance_id: Option<ComputeInstanceId>,
1177 strategy: Option<StatementExecutionStrategy>,
1178 ) -> Self {
1179 let histogram = Self::histogram(client, instance_id, strategy);
1180 Self {
1181 rows,
1182 execute_started,
1183 time_to_first_row_seconds: histogram,
1184 saw_rows: false,
1185 recorded_first_row_instant: None,
1186 no_more_rows: false,
1187 }
1188 }
1189
1190 fn histogram(
1191 client: &SessionClient,
1192 instance_id: Option<ComputeInstanceId>,
1193 strategy: Option<StatementExecutionStrategy>,
1194 ) -> Histogram {
1195 let isolation_level = *client
1196 .session
1197 .as_ref()
1198 .expect("session invariant")
1199 .vars()
1200 .transaction_isolation();
1201 let instance = match instance_id {
1202 Some(i) => Cow::Owned(i.to_string()),
1203 None => Cow::Borrowed("none"),
1204 };
1205 let strategy = match strategy {
1206 Some(s) => s.name(),
1207 None => "none",
1208 };
1209
1210 client
1211 .inner()
1212 .metrics()
1213 .time_to_first_row_seconds
1214 .with_label_values(&[&instance, isolation_level.as_str(), strategy])
1215 }
1216
1217 pub fn record(
1220 execute_started: Instant,
1221 client: &SessionClient,
1222 instance_id: Option<ComputeInstanceId>,
1223 strategy: Option<StatementExecutionStrategy>,
1224 ) {
1225 Self::histogram(client, instance_id, strategy)
1226 .observe(execute_started.elapsed().as_secs_f64());
1227 }
1228
1229 pub async fn recv(&mut self) -> Option<PeekResponseUnary> {
1230 let msg = self.rows.next().await;
1231 if !self.saw_rows && matches!(msg, Some(PeekResponseUnary::Rows(_))) {
1232 self.saw_rows = true;
1233 self.time_to_first_row_seconds
1234 .observe(self.execute_started.elapsed().as_secs_f64());
1235 self.recorded_first_row_instant = Some(Instant::now());
1236 }
1237 if msg.is_none() {
1238 self.no_more_rows = true;
1239 }
1240 msg
1241 }
1242}