1use std::borrow::Cow;
11use std::collections::BTreeMap;
12use std::fmt::{Debug, Display, Formatter};
13use std::future::Future;
14use std::pin::{self, Pin};
15use std::sync::Arc;
16use std::time::{Duration, Instant};
17
18use anyhow::bail;
19use chrono::{DateTime, Utc};
20use derivative::Derivative;
21use futures::{Stream, StreamExt};
22use itertools::Itertools;
23use mz_adapter_types::connection::{ConnectionId, ConnectionIdType};
24use mz_auth::password::Password;
25use mz_build_info::BuildInfo;
26use mz_compute_types::ComputeInstanceId;
27use mz_ore::channel::OneshotReceiverExt;
28use mz_ore::collections::CollectionExt;
29use mz_ore::id_gen::{IdAllocator, IdAllocatorInnerBitSet, MAX_ORG_ID, org_id_conn_bits};
30use mz_ore::instrument;
31use mz_ore::now::{EpochMillis, NowFn, to_datetime};
32use mz_ore::result::ResultExt;
33use mz_ore::task::AbortOnDropHandle;
34use mz_ore::thread::JoinOnDropHandle;
35use mz_ore::tracing::OpenTelemetryContext;
36use mz_repr::{CatalogItemId, ColumnIndex, Row, SqlScalarType};
37use mz_sql::ast::{Raw, Statement};
38use mz_sql::catalog::{EnvironmentId, SessionCatalog};
39use mz_sql::session::hint::ApplicationNameHint;
40use mz_sql::session::metadata::SessionMetadata;
41use mz_sql::session::user::SUPPORT_USER;
42use mz_sql::session::vars::{CLUSTER, OwnedVarInput, SystemVars, Var};
43use mz_sql_parser::parser::{ParserStatementError, StatementParseResult};
44use prometheus::Histogram;
45use serde_json::json;
46use tokio::sync::{mpsc, oneshot};
47use tracing::error;
48use uuid::Uuid;
49
50use crate::catalog::Catalog;
51use crate::command::{
52 AuthResponse, CatalogDump, CatalogSnapshot, Command, ExecuteResponse, Response,
53};
54use crate::coord::{Coordinator, ExecuteContextExtra};
55use crate::error::AdapterError;
56use crate::metrics::Metrics;
57use crate::optimize::dataflows::{EvalTime, ExprPrepStyle};
58use crate::optimize::{self, Optimize};
59use crate::session::{
60 EndTransactionAction, PreparedStatement, Session, SessionConfig, TransactionId,
61};
62use crate::statement_logging::{StatementEndedExecutionReason, StatementExecutionStrategy};
63use crate::telemetry::{self, EventDetails, SegmentClientExt, StatementFailureType};
64use crate::webhook::AppendWebhookResponse;
65use crate::{AdapterNotice, AppendWebhookError, PeekResponseUnary, StartupResponse};
66
67pub struct Handle {
73 pub(crate) session_id: Uuid,
74 pub(crate) start_instant: Instant,
75 pub(crate) _thread: JoinOnDropHandle<()>,
76}
77
78impl Handle {
79 pub fn session_id(&self) -> Uuid {
85 self.session_id
86 }
87
88 pub fn start_instant(&self) -> Instant {
90 self.start_instant
91 }
92}
93
/// A cloneable handle for sending [`Command`]s to the coordinator.
///
/// Cloning copies each field. Whether clones share one connection-ID pool
/// depends on `IdAllocator`'s `Clone` impl (not visible here).
#[derive(Debug, Clone)]
pub struct Client {
    /// Build information passed to every new [`Session`].
    build_info: &'static BuildInfo,
    /// Command channel to the coordinator. Each command is paired with the
    /// caller's `OpenTelemetryContext` (see `Client::send`).
    inner_cmd_tx: mpsc::UnboundedSender<(OpenTelemetryContext, Command)>,
    /// Source of connection IDs for `Client::new_conn_id`.
    id_alloc: IdAllocator<IdAllocatorInnerBitSet>,
    /// Clock used by `Client::now` and by `SessionClient`.
    now: NowFn,
    /// Adapter metrics; session metrics are derived from these.
    metrics: Metrics,
    /// Environment (organization + region) this client serves.
    environment_id: EnvironmentId,
    /// Optional Segment telemetry client, copied into each `SessionClient`.
    segment_client: Option<mz_segment::Client>,
}
111
112impl Client {
113 pub(crate) fn new(
114 build_info: &'static BuildInfo,
115 cmd_tx: mpsc::UnboundedSender<(OpenTelemetryContext, Command)>,
116 metrics: Metrics,
117 now: NowFn,
118 environment_id: EnvironmentId,
119 segment_client: Option<mz_segment::Client>,
120 ) -> Client {
121 let env_lower = org_id_conn_bits(&environment_id.organization_id());
129 Client {
130 build_info,
131 inner_cmd_tx: cmd_tx,
132 id_alloc: IdAllocator::new(1, MAX_ORG_ID, env_lower),
133 now,
134 metrics,
135 environment_id,
136 segment_client,
137 }
138 }
139
140 pub fn new_conn_id(&self) -> Result<ConnectionId, AdapterError> {
142 self.id_alloc.alloc().ok_or(AdapterError::IdExhaustionError)
143 }
144
145 pub fn new_session(&self, config: SessionConfig) -> Session {
149 Session::new(self.build_info, config, self.metrics().session_metrics())
153 }
154
155 pub async fn authenticate(
157 &self,
158 user: &String,
159 password: &Password,
160 ) -> Result<AuthResponse, AdapterError> {
161 let (tx, rx) = oneshot::channel();
162 self.send(Command::AuthenticatePassword {
163 role_name: user.to_string(),
164 password: Some(password.clone()),
165 tx,
166 });
167 let response = rx.await.expect("sender dropped")?;
168 Ok(response)
169 }
170
171 #[mz_ore::instrument(level = "debug")]
180 pub async fn startup(&self, session: Session) -> Result<SessionClient, AdapterError> {
181 let user = session.user().clone();
182 let conn_id = session.conn_id().clone();
183 let secret_key = session.secret_key();
184 let uuid = session.uuid();
185 let client_ip = session.client_ip();
186 let application_name = session.application_name().into();
187 let notice_tx = session.retain_notice_transmitter();
188
189 let (tx, rx) = oneshot::channel();
190
191 let rx = rx.with_guard(|_| {
197 self.send(Command::Terminate {
198 conn_id: conn_id.clone(),
199 tx: None,
200 });
201 });
202
203 self.send(Command::Startup {
204 tx,
205 user,
206 conn_id: conn_id.clone(),
207 secret_key,
208 uuid,
209 client_ip: client_ip.copied(),
210 application_name,
211 notice_tx,
212 });
213
214 let response = rx.await.expect("sender dropped")?;
217
218 let mut client = SessionClient {
221 inner: Some(self.clone()),
222 session: Some(session),
223 timeouts: Timeout::new(),
224 environment_id: self.environment_id.clone(),
225 segment_client: self.segment_client.clone(),
226 };
227
228 let StartupResponse {
229 role_id,
230 write_notify,
231 session_defaults,
232 catalog,
233 } = response;
234
235 let session = client.session();
236 session.initialize_role_metadata(role_id);
237 let vars_mut = session.vars_mut();
238 for (name, val) in session_defaults {
239 if let Err(err) = vars_mut.set_default(&name, val.borrow()) {
240 tracing::error!("failed to set peristed default, {err:?}");
243 }
244 }
245 session
246 .vars_mut()
247 .end_transaction(EndTransactionAction::Commit);
248
249 session.set_builtin_table_updates(write_notify);
257
258 let catalog = catalog.for_session(session);
259
260 let cluster_active = session.vars().cluster().to_string();
261 if session.vars().welcome_message() {
262 let cluster_info = if catalog.resolve_cluster(Some(&cluster_active)).is_err() {
263 format!("{cluster_active} (does not exist)")
264 } else {
265 cluster_active.to_string()
266 };
267
268 session.add_notice(AdapterNotice::Welcome(format!(
272 "connected to Materialize v{}
273 Org ID: {}
274 Region: {}
275 User: {}
276 Cluster: {}
277 Database: {}
278 {}
279 Session UUID: {}
280
281Issue a SQL query to get started. Need help?
282 View documentation: https://materialize.com/s/docs
283 Join our Slack community: https://materialize.com/s/chat
284 ",
285 session.vars().build_info().semver_version(),
286 self.environment_id.organization_id(),
287 self.environment_id.region(),
288 session.vars().user().name,
289 cluster_info,
290 session.vars().database(),
291 match session.vars().search_path() {
292 [schema] => format!("Schema: {}", schema),
293 schemas => format!(
294 "Search path: {}",
295 schemas.iter().map(|id| id.to_string()).join(", ")
296 ),
297 },
298 session.uuid(),
299 )));
300 }
301
302 if session.vars().current_object_missing_warnings() {
303 if catalog.active_database().is_none() {
304 let db = session.vars().database().into();
305 session.add_notice(AdapterNotice::UnknownSessionDatabase(db));
306 }
307 }
308
309 let cluster_var = session
312 .vars()
313 .inspect(CLUSTER.name())
314 .expect("cluster should exist");
315 if session.vars().current_object_missing_warnings()
316 && catalog.resolve_cluster(Some(&cluster_active)).is_err()
317 {
318 let cluster_notice = 'notice: {
319 if cluster_var.inspect_session_value().is_some() {
320 break 'notice Some(AdapterNotice::DefaultClusterDoesNotExist {
321 name: cluster_active,
322 kind: "session",
323 suggested_action: "Pick an extant cluster with SET CLUSTER = name. Run SHOW CLUSTERS to see available clusters.".into(),
324 });
325 }
326
327 let role_default = catalog.get_role(catalog.active_role_id());
328 let role_cluster = match role_default.vars().get(CLUSTER.name()) {
329 Some(OwnedVarInput::Flat(name)) => Some(name),
330 None => None,
331 Some(v @ OwnedVarInput::SqlSet(_)) => {
333 tracing::warn!(?v, "SqlSet found for cluster Role Default");
334 break 'notice None;
335 }
336 };
337
338 let alter_role = "with `ALTER ROLE <role> SET cluster TO <cluster>;`";
339 match role_cluster {
340 None => Some(AdapterNotice::DefaultClusterDoesNotExist {
342 name: cluster_active,
343 kind: "system",
344 suggested_action: format!(
345 "Set a default cluster for the current role {alter_role}."
346 ),
347 }),
348 Some(_) => Some(AdapterNotice::DefaultClusterDoesNotExist {
350 name: cluster_active,
351 kind: "role",
352 suggested_action: format!(
353 "Change the default cluster for the current role {alter_role}."
354 ),
355 }),
356 }
357 };
358
359 if let Some(notice) = cluster_notice {
360 session.add_notice(notice);
361 }
362 }
363
364 Ok(client)
365 }
366
367 pub fn cancel_request(&self, conn_id: ConnectionIdType, secret_key: u32) {
369 self.send(Command::CancelRequest {
370 conn_id,
371 secret_key,
372 });
373 }
374
375 pub async fn support_execute_one(
378 &self,
379 sql: &str,
380 ) -> Result<Pin<Box<dyn Stream<Item = PeekResponseUnary> + Send + Sync>>, anyhow::Error> {
381 let conn_id = self.new_conn_id()?;
383 let session = self.new_session(SessionConfig {
384 conn_id,
385 uuid: Uuid::new_v4(),
386 user: SUPPORT_USER.name.clone(),
387 client_ip: None,
388 external_metadata_rx: None,
389 internal_user_metadata: None,
390 helm_chart_version: None,
391 });
392 let mut session_client = self.startup(session).await?;
393
394 let stmts = mz_sql::parse::parse(sql)?;
396 if stmts.len() != 1 {
397 bail!("must supply exactly one query");
398 }
399 let StatementParseResult { ast: stmt, sql } = stmts.into_element();
400
401 const EMPTY_PORTAL: &str = "";
402 session_client.start_transaction(Some(1))?;
403 session_client
404 .declare(EMPTY_PORTAL.into(), stmt, sql.to_string())
405 .await?;
406
407 match session_client
408 .execute(EMPTY_PORTAL.into(), futures::future::pending(), None)
409 .await?
410 {
411 (ExecuteResponse::SendingRowsStreaming { mut rows, .. }, _) => {
412 let owning_response_stream = async_stream::stream! {
417 while let Some(rows) = rows.next().await {
418 yield rows;
419 }
420 drop(session_client);
421 };
422 Ok(Box::pin(owning_response_stream))
423 }
424 r => bail!("unsupported response type: {r:?}"),
425 }
426 }
427
428 pub fn metrics(&self) -> &Metrics {
430 &self.metrics
431 }
432
433 pub fn now(&self) -> DateTime<Utc> {
435 to_datetime((self.now)())
436 }
437
438 pub async fn get_webhook_appender(
440 &self,
441 database: String,
442 schema: String,
443 name: String,
444 ) -> Result<AppendWebhookResponse, AppendWebhookError> {
445 let (tx, rx) = oneshot::channel();
446
447 self.send(Command::GetWebhook {
449 database,
450 schema,
451 name,
452 tx,
453 });
454
455 let response = rx
457 .await
458 .map_err(|_| anyhow::anyhow!("failed to receive webhook response"))?;
459
460 response
461 }
462
463 pub async fn get_system_vars(&self) -> SystemVars {
465 let (tx, rx) = oneshot::channel();
466 self.send(Command::GetSystemVars { tx });
467 rx.await.expect("coordinator unexpectedly gone")
468 }
469
470 #[instrument(level = "debug")]
471 fn send(&self, cmd: Command) {
472 self.inner_cmd_tx
473 .send((OpenTelemetryContext::obtain(), cmd))
474 .expect("coordinator unexpectedly gone");
475 }
476}
477
/// A [`Client`] bound to one [`Session`].
pub struct SessionClient {
    /// The underlying client. Set to `None` by `SessionClient::terminate`, so
    /// that `Drop` does not send a second `Terminate`.
    inner: Option<Client>,
    /// The session. It is taken out while lent to the coordinator in
    /// `send_with_cancel` and put back from the response. Outside that window
    /// it is expected to be `Some`.
    session: Option<Session>,
    /// Timers armed for this session (e.g. idle-in-transaction).
    timeouts: Timeout,
    /// Optional Segment client used for parse-failure telemetry.
    segment_client: Option<mz_segment::Client>,
    /// Environment reported with telemetry events.
    environment_id: EnvironmentId,
}
493
494impl SessionClient {
495 pub fn parse<'a>(
498 &self,
499 sql: &'a str,
500 ) -> Result<Result<Vec<StatementParseResult<'a>>, ParserStatementError>, String> {
501 match mz_sql::parse::parse_with_limit(sql) {
502 Ok(Err(e)) => {
503 self.track_statement_parse_failure(&e);
504 Ok(Err(e))
505 }
506 r => r,
507 }
508 }
509
510 fn track_statement_parse_failure(&self, parse_error: &ParserStatementError) {
511 let session = self.session.as_ref().expect("session invariant violated");
512 let Some(user_id) = session.user().external_metadata.as_ref().map(|m| m.user_id) else {
513 return;
514 };
515 let Some(segment_client) = &self.segment_client else {
516 return;
517 };
518 let Some(statement_kind) = parse_error.statement else {
519 return;
520 };
521 let Some((action, object_type)) = telemetry::analyze_audited_statement(statement_kind)
522 else {
523 return;
524 };
525 let event_type = StatementFailureType::ParseFailure;
526 let event_name = format!(
527 "{} {} {}",
528 object_type.as_title_case(),
529 action.as_title_case(),
530 event_type.as_title_case(),
531 );
532 segment_client.environment_track(
533 &self.environment_id,
534 event_name,
535 json!({
536 "statement_kind": statement_kind,
537 "error": &parse_error.error,
538 }),
539 EventDetails {
540 user_id: Some(user_id),
541 application_name: Some(session.application_name()),
542 ..Default::default()
543 },
544 );
545 }
546
547 pub async fn get_prepared_statement(
550 &mut self,
551 name: &str,
552 ) -> Result<&PreparedStatement, AdapterError> {
553 let catalog = self.catalog_snapshot("get_prepared_statement").await;
554 Coordinator::verify_prepared_statement(&catalog, self.session(), name)?;
555 Ok(self
556 .session()
557 .get_prepared_statement_unverified(name)
558 .expect("must exist"))
559 }
560
561 pub async fn prepare(
566 &mut self,
567 name: String,
568 stmt: Option<Statement<Raw>>,
569 sql: String,
570 param_types: Vec<Option<SqlScalarType>>,
571 ) -> Result<(), AdapterError> {
572 let catalog = self.catalog_snapshot("prepare").await;
573
574 let mut async_pause = false;
577 (|| {
578 fail::fail_point!("async_prepare", |val| {
579 async_pause = val.map_or(false, |val| val.parse().unwrap_or(false))
580 });
581 })();
582 if async_pause {
583 tokio::time::sleep(Duration::from_secs(1)).await;
584 };
585
586 let desc = Coordinator::describe(&catalog, self.session(), stmt.clone(), param_types)?;
587 let now = self.now();
588 self.session().set_prepared_statement(
589 name,
590 stmt,
591 sql,
592 desc,
593 catalog.transient_revision(),
594 now,
595 );
596 Ok(())
597 }
598
599 #[mz_ore::instrument(level = "debug")]
601 pub async fn declare(
602 &mut self,
603 name: String,
604 stmt: Statement<Raw>,
605 sql: String,
606 ) -> Result<(), AdapterError> {
607 let catalog = self.catalog_snapshot("declare").await;
608 let param_types = vec![];
609 let desc =
610 Coordinator::describe(&catalog, self.session(), Some(stmt.clone()), param_types)?;
611 let params = vec![];
612 let result_formats = vec![mz_pgwire_common::Format::Text; desc.arity()];
613 let now = self.now();
614 let logging = self.session().mint_logging(sql, Some(&stmt), now);
615 self.session().set_portal(
616 name,
617 desc,
618 Some(stmt),
619 logging,
620 params,
621 result_formats,
622 catalog.transient_revision(),
623 )?;
624 Ok(())
625 }
626
627 #[mz_ore::instrument(level = "debug")]
631 pub async fn execute(
632 &mut self,
633 portal_name: String,
634 cancel_future: impl Future<Output = std::io::Error> + Send,
635 outer_ctx_extra: Option<ExecuteContextExtra>,
636 ) -> Result<(ExecuteResponse, Instant), AdapterError> {
637 let execute_started = Instant::now();
638 let response = self
639 .send_with_cancel(
640 |tx, session| Command::Execute {
641 portal_name,
642 session,
643 tx,
644 outer_ctx_extra,
645 },
646 cancel_future,
647 )
648 .await?;
649 Ok((response, execute_started))
650 }
651
652 fn now(&self) -> EpochMillis {
653 (self.inner().now)()
654 }
655
656 fn now_datetime(&self) -> DateTime<Utc> {
657 to_datetime(self.now())
658 }
659
660 pub fn start_transaction(&mut self, implicit: Option<usize>) -> Result<(), AdapterError> {
666 let now = self.now_datetime();
667 let session = self.session.as_mut().expect("session invariant violated");
668 let result = match implicit {
669 None => session.start_transaction(now, None, None),
670 Some(stmts) => {
671 session.start_transaction_implicit(now, stmts);
672 Ok(())
673 }
674 };
675 result
676 }
677
678 #[instrument(level = "debug")]
681 pub async fn end_transaction(
682 &mut self,
683 action: EndTransactionAction,
684 ) -> Result<ExecuteResponse, AdapterError> {
685 let res = self
686 .send(|tx, session| Command::Commit {
687 action,
688 session,
689 tx,
690 })
691 .await;
692 let _ = self.session().clear_transaction();
696 res
697 }
698
699 pub fn fail_transaction(&mut self) {
701 let session = self.session.take().expect("session invariant violated");
702 let session = session.fail_transaction();
703 self.session = Some(session);
704 }
705
706 #[instrument(level = "debug")]
708 pub async fn catalog_snapshot(&self, context: &str) -> Arc<Catalog> {
709 let start = std::time::Instant::now();
710 let CatalogSnapshot { catalog } = self
711 .send_without_session(|tx| Command::CatalogSnapshot { tx })
712 .await;
713 self.inner()
714 .metrics()
715 .catalog_snapshot_seconds
716 .with_label_values(&[context])
717 .observe(start.elapsed().as_secs_f64());
718 catalog
719 }
720
721 pub async fn dump_catalog(&self) -> Result<CatalogDump, AdapterError> {
726 let catalog = self.catalog_snapshot("dump_catalog").await;
727 catalog.dump().map_err(AdapterError::from)
728 }
729
730 pub async fn check_catalog(&self) -> Result<(), serde_json::Value> {
736 let catalog = self.catalog_snapshot("check_catalog").await;
737 catalog.check_consistency()
738 }
739
740 pub async fn check_coordinator(&self) -> Result<(), serde_json::Value> {
746 self.send_without_session(|tx| Command::CheckConsistency { tx })
747 .await
748 .map_err(|inconsistencies| {
749 serde_json::to_value(inconsistencies).unwrap_or_else(|_| {
750 serde_json::Value::String("failed to serialize inconsistencies".to_string())
751 })
752 })
753 }
754
755 pub async fn dump_coordinator_state(&self) -> Result<serde_json::Value, anyhow::Error> {
756 self.send_without_session(|tx| Command::Dump { tx }).await
757 }
758
759 pub fn retire_execute(&self, data: ExecuteContextExtra, reason: StatementEndedExecutionReason) {
762 if !data.is_trivial() {
763 let cmd = Command::RetireExecute { data, reason };
764 self.inner().send(cmd);
765 }
766 }
767
768 pub async fn insert_rows(
774 &mut self,
775 id: CatalogItemId,
776 columns: Vec<ColumnIndex>,
777 rows: Vec<Row>,
778 ctx_extra: ExecuteContextExtra,
779 ) -> Result<ExecuteResponse, AdapterError> {
780 let pcx = self.session().pcx().clone();
783
784 let session_meta = self.session().meta();
785
786 let catalog = self.catalog_snapshot("insert_rows").await;
787 let conn_catalog = catalog.for_session(self.session());
788 let catalog_state = conn_catalog.state();
789
790 let optimizer_config = optimize::OptimizerConfig::from(conn_catalog.system_vars());
792 let prep = ExprPrepStyle::OneShot {
793 logical_time: EvalTime::NotAvailable,
794 session: &session_meta,
795 catalog_state,
796 };
797 let mut optimizer =
798 optimize::view::Optimizer::new_with_prep_no_limit(optimizer_config.clone(), None, prep);
799
800 let result: Result<_, AdapterError> =
801 mz_sql::plan::plan_copy_from(&pcx, &conn_catalog, id, columns, rows)
802 .err_into()
803 .and_then(|values| optimizer.optimize(values).err_into())
804 .and_then(|values| {
805 Coordinator::insert_constant(&catalog, self.session(), id, values.into_inner())
807 });
808 self.retire_execute(ctx_extra, (&result).into());
809 result
810 }
811
812 pub async fn get_system_vars(&self) -> SystemVars {
814 self.inner().get_system_vars().await
815 }
816
817 pub async fn set_system_vars(
819 &mut self,
820 vars: BTreeMap<String, String>,
821 ) -> Result<(), AdapterError> {
822 let conn_id = self.session().conn_id().clone();
823 self.send_without_session(|tx| Command::SetSystemVars { vars, conn_id, tx })
824 .await
825 }
826
827 pub async fn terminate(&mut self) {
829 let conn_id = self.session().conn_id().clone();
830 let res = self
831 .send_without_session(|tx| Command::Terminate {
832 conn_id,
833 tx: Some(tx),
834 })
835 .await;
836 if let Err(e) = res {
837 error!("Unable to terminate session: {e:?}");
839 }
840 self.inner = None;
842 }
843
844 pub fn session(&mut self) -> &mut Session {
846 self.session.as_mut().expect("session invariant violated")
847 }
848
849 pub fn inner(&self) -> &Client {
851 self.inner.as_ref().expect("inner invariant violated")
852 }
853
854 async fn send_without_session<T, F>(&self, f: F) -> T
855 where
856 F: FnOnce(oneshot::Sender<T>) -> Command,
857 {
858 let (tx, rx) = oneshot::channel();
859 self.inner().send(f(tx));
860 rx.await.expect("sender dropped")
861 }
862
863 #[instrument(level = "debug")]
864 async fn send<T, F>(&mut self, f: F) -> Result<T, AdapterError>
865 where
866 F: FnOnce(oneshot::Sender<Response<T>>, Session) -> Command,
867 {
868 self.send_with_cancel(f, futures::future::pending()).await
869 }
870
    /// Lends the session to the coordinator for the command built by `f`, and
    /// waits for the response.
    ///
    /// The session is taken out of `self.session` and put back from the
    /// returned `Response`. If `cancel_future` resolves first, one
    /// `PrivilegedCancelRequest` is sent and waiting continues until the
    /// response arrives.
    ///
    /// For `Execute` and `GetWebhook` commands, the `commands` metric is
    /// incremented with labels for the command type, success/error status and
    /// application-name hint.
    ///
    /// # Panics
    /// Panics if the session or inner client is missing, or if the reply
    /// channel is dropped without a value.
    #[instrument(level = "debug")]
    async fn send_with_cancel<T, F>(
        &mut self,
        f: F,
        cancel_future: impl Future<Output = std::io::Error> + Send,
    ) -> Result<T, AdapterError>
    where
        F: FnOnce(oneshot::Sender<Response<T>>, Session) -> Command,
    {
        let session = self.session.take().expect("session invariant violated");
        // Metric label for the command type; set below only for commands that
        // are counted.
        let mut typ = None;
        let application_name = session.application_name();
        let name_hint = ApplicationNameHint::from_str(application_name);
        let conn_id = session.conn_id().clone();
        let (tx, rx) = oneshot::channel();

        // Split the borrow of `self`. The guard closure below holds
        // `client_session` mutably while `inner_client` is used independently.
        let Self {
            inner: inner_client,
            session: client_session,
            ..
        } = self;

        let inner_client = inner_client.as_ref().expect("inner invariant violated");

        // NOTE(review): the guard presumably runs if `guarded_rx` is dropped
        // before the response is consumed (e.g. this future is cancelled). It
        // puts the returned session back so the client is not left without
        // one. Confirm against `OneshotReceiverExt::with_guard`.
        let mut guarded_rx = rx.with_guard(|response: Response<_>| {
            *client_session = Some(response.session);
        });

        inner_client.send({
            let cmd = f(tx, session);
            // Exhaustive on purpose: adding a `Command` variant forces a
            // decision about whether it is counted in the `commands` metric.
            match cmd {
                Command::Execute { .. } => typ = Some("execute"),
                Command::GetWebhook { .. } => typ = Some("webhook"),
                Command::Startup { .. }
                | Command::AuthenticatePassword { .. }
                | Command::CatalogSnapshot { .. }
                | Command::Commit { .. }
                | Command::CancelRequest { .. }
                | Command::PrivilegedCancelRequest { .. }
                | Command::GetSystemVars { .. }
                | Command::SetSystemVars { .. }
                | Command::Terminate { .. }
                | Command::RetireExecute { .. }
                | Command::CheckConsistency { .. }
                | Command::Dump { .. } => {}
            };
            cmd
        });

        let mut cancel_future = pin::pin!(cancel_future);
        // Ensures the cancel branch fires at most once.
        let mut cancelled = false;
        loop {
            tokio::select! {
                res = &mut guarded_rx => {
                    // Drop the guard before writing `client_session` directly;
                    // this releases the closure's mutable borrow of it.
                    drop(guarded_rx);

                    let res = res.expect("sender dropped");
                    let status = res.result.is_ok().then_some("success").unwrap_or("error");
                    if let Err(err) = res.result.as_ref() {
                        if name_hint.should_trace_errors() {
                            tracing::warn!(?err, ?name_hint, "adapter response error");
                        }
                    }

                    if let Some(typ) = typ {
                        inner_client
                            .metrics
                            .commands
                            .with_label_values(&[typ, status, name_hint.as_str()])
                            .inc();
                    }
                    *client_session = Some(res.session);
                    return res.result;
                },
                _err = &mut cancel_future, if !cancelled => {
                    cancelled = true;
                    inner_client.send(Command::PrivilegedCancelRequest {
                        conn_id: conn_id.clone(),
                    });
                }
            };
        }
    }
970
971 pub fn add_idle_in_transaction_session_timeout(&mut self) {
972 let session = self.session();
973 let timeout_dur = session.vars().idle_in_transaction_session_timeout();
974 if !timeout_dur.is_zero() {
975 let timeout_dur = timeout_dur.clone();
976 if let Some(txn) = session.transaction().inner() {
977 let txn_id = txn.id.clone();
978 let timeout = TimeoutType::IdleInTransactionSession(txn_id);
979 self.timeouts.add_timeout(timeout, timeout_dur);
980 }
981 }
982 }
983
984 pub fn remove_idle_in_transaction_session_timeout(&mut self) {
985 let session = self.session();
986 if let Some(txn) = session.transaction().inner() {
987 let txn_id = txn.id.clone();
988 self.timeouts
989 .remove_timeout(&TimeoutType::IdleInTransactionSession(txn_id));
990 }
991 }
992
993 pub async fn recv_timeout(&mut self) -> Option<TimeoutType> {
1000 self.timeouts.recv().await
1001 }
1002}
1003
1004impl Drop for SessionClient {
1005 fn drop(&mut self) {
1006 if let Some(session) = self.session.take() {
1010 if let Some(inner) = &self.inner {
1013 inner.send(Command::Terminate {
1014 conn_id: session.conn_id().clone(),
1015 tx: None,
1016 })
1017 }
1018 }
1019 }
1020}
1021
1022#[derive(Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Debug)]
1023pub enum TimeoutType {
1024 IdleInTransactionSession(TransactionId),
1025}
1026
1027impl Display for TimeoutType {
1028 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1029 match self {
1030 TimeoutType::IdleInTransactionSession(txn_id) => {
1031 writeln!(f, "Idle in transaction session for transaction '{txn_id}'")
1032 }
1033 }
1034 }
1035}
1036
1037impl From<TimeoutType> for AdapterError {
1038 fn from(timeout: TimeoutType) -> Self {
1039 match timeout {
1040 TimeoutType::IdleInTransactionSession(_) => {
1041 AdapterError::IdleInTransactionSessionTimeout
1042 }
1043 }
1044 }
1045}
1046
1047struct Timeout {
1048 tx: mpsc::UnboundedSender<TimeoutType>,
1049 rx: mpsc::UnboundedReceiver<TimeoutType>,
1050 active_timeouts: BTreeMap<TimeoutType, AbortOnDropHandle<()>>,
1051}
1052
1053impl Timeout {
1054 fn new() -> Self {
1055 let (tx, rx) = mpsc::unbounded_channel();
1056 Timeout {
1057 tx,
1058 rx,
1059 active_timeouts: BTreeMap::new(),
1060 }
1061 }
1062
1063 async fn recv(&mut self) -> Option<TimeoutType> {
1072 self.rx.recv().await
1073 }
1074
1075 fn add_timeout(&mut self, timeout: TimeoutType, duration: Duration) {
1076 let tx = self.tx.clone();
1077 let timeout_key = timeout.clone();
1078 let handle = mz_ore::task::spawn(|| format!("{timeout_key}"), async move {
1079 tokio::time::sleep(duration).await;
1080 let _ = tx.send(timeout);
1081 })
1082 .abort_on_drop();
1083 self.active_timeouts.insert(timeout_key, handle);
1084 }
1085
1086 fn remove_timeout(&mut self, timeout: &TimeoutType) {
1087 self.active_timeouts.remove(timeout);
1088
1089 let mut timeouts = Vec::new();
1091 while let Ok(pending_timeout) = self.rx.try_recv() {
1092 if timeout != &pending_timeout {
1093 timeouts.push(pending_timeout);
1094 }
1095 }
1096 for pending_timeout in timeouts {
1097 self.tx.send(pending_timeout).expect("rx is in this struct");
1098 }
1099 }
1100}
1101
/// Wraps a row stream and records how long it took to produce the first rows.
///
/// The latency is measured from `execute_started` and observed into
/// `time_to_first_row_seconds` the first time `recv` yields
/// `PeekResponseUnary::Rows`.
#[derive(Derivative)]
#[derivative(Debug)]
pub struct RecordFirstRowStream {
    /// The wrapped stream. Excluded from `Debug` output.
    #[derivative(Debug = "ignore")]
    pub rows: Box<dyn Stream<Item = PeekResponseUnary> + Unpin + Send + Sync>,
    /// Instant execution started; the reference point for first-row latency.
    pub execute_started: Instant,
    /// Histogram that receives the first-row latency, in seconds.
    pub time_to_first_row_seconds: Histogram,
    /// Whether a `PeekResponseUnary::Rows` message has been seen yet.
    pub saw_rows: bool,
    /// Instant at which the first rows were recorded, if any.
    pub recorded_first_row_instant: Option<Instant>,
    /// Set once the wrapped stream has returned `None`.
    pub no_more_rows: bool,
}
1123
1124impl RecordFirstRowStream {
1125 pub fn new(
1127 rows: Box<dyn Stream<Item = PeekResponseUnary> + Unpin + Send + Sync>,
1128 execute_started: Instant,
1129 client: &SessionClient,
1130 instance_id: Option<ComputeInstanceId>,
1131 strategy: Option<StatementExecutionStrategy>,
1132 ) -> Self {
1133 let histogram = Self::histogram(client, instance_id, strategy);
1134 Self {
1135 rows,
1136 execute_started,
1137 time_to_first_row_seconds: histogram,
1138 saw_rows: false,
1139 recorded_first_row_instant: None,
1140 no_more_rows: false,
1141 }
1142 }
1143
1144 fn histogram(
1145 client: &SessionClient,
1146 instance_id: Option<ComputeInstanceId>,
1147 strategy: Option<StatementExecutionStrategy>,
1148 ) -> Histogram {
1149 let isolation_level = *client
1150 .session
1151 .as_ref()
1152 .expect("session invariant")
1153 .vars()
1154 .transaction_isolation();
1155 let instance = match instance_id {
1156 Some(i) => Cow::Owned(i.to_string()),
1157 None => Cow::Borrowed("none"),
1158 };
1159 let strategy = match strategy {
1160 Some(s) => s.name(),
1161 None => "none",
1162 };
1163
1164 client
1165 .inner()
1166 .metrics()
1167 .time_to_first_row_seconds
1168 .with_label_values(&[&instance, isolation_level.as_str(), strategy])
1169 }
1170
1171 pub fn record(
1174 execute_started: Instant,
1175 client: &SessionClient,
1176 instance_id: Option<ComputeInstanceId>,
1177 strategy: Option<StatementExecutionStrategy>,
1178 ) {
1179 Self::histogram(client, instance_id, strategy)
1180 .observe(execute_started.elapsed().as_secs_f64());
1181 }
1182
1183 pub async fn recv(&mut self) -> Option<PeekResponseUnary> {
1184 let msg = self.rows.next().await;
1185 if !self.saw_rows && matches!(msg, Some(PeekResponseUnary::Rows(_))) {
1186 self.saw_rows = true;
1187 self.time_to_first_row_seconds
1188 .observe(self.execute_started.elapsed().as_secs_f64());
1189 self.recorded_first_row_instant = Some(Instant::now());
1190 }
1191 if msg.is_none() {
1192 self.no_more_rows = true;
1193 }
1194 msg
1195 }
1196}