1use std::borrow::Cow;
11use std::collections::BTreeMap;
12use std::fmt::{Debug, Display, Formatter};
13use std::future::Future;
14use std::pin;
15use std::sync::Arc;
16use std::time::{Duration, Instant};
17
18use anyhow::bail;
19use chrono::{DateTime, Utc};
20use derivative::Derivative;
21use futures::{Stream, StreamExt};
22use itertools::Itertools;
23use mz_adapter_types::connection::{ConnectionId, ConnectionIdType};
24use mz_auth::password::Password;
25use mz_build_info::BuildInfo;
26use mz_compute_types::ComputeInstanceId;
27use mz_ore::channel::OneshotReceiverExt;
28use mz_ore::collections::CollectionExt;
29use mz_ore::id_gen::{IdAllocator, IdAllocatorInnerBitSet, MAX_ORG_ID, org_id_conn_bits};
30use mz_ore::instrument;
31use mz_ore::now::{EpochMillis, NowFn, to_datetime};
32use mz_ore::result::ResultExt;
33use mz_ore::task::AbortOnDropHandle;
34use mz_ore::thread::JoinOnDropHandle;
35use mz_ore::tracing::OpenTelemetryContext;
36use mz_repr::{CatalogItemId, ColumnIndex, Row, RowIterator, ScalarType};
37use mz_sql::ast::{Raw, Statement};
38use mz_sql::catalog::{EnvironmentId, SessionCatalog};
39use mz_sql::session::hint::ApplicationNameHint;
40use mz_sql::session::metadata::SessionMetadata;
41use mz_sql::session::user::SUPPORT_USER;
42use mz_sql::session::vars::{CLUSTER, OwnedVarInput, SystemVars, Var};
43use mz_sql_parser::parser::{ParserStatementError, StatementParseResult};
44use prometheus::Histogram;
45use serde_json::json;
46use tokio::sync::{mpsc, oneshot};
47use tracing::error;
48use uuid::Uuid;
49
50use crate::catalog::Catalog;
51use crate::command::{
52 AuthResponse, CatalogDump, CatalogSnapshot, Command, ExecuteResponse, Response,
53};
54use crate::coord::{Coordinator, ExecuteContextExtra};
55use crate::error::AdapterError;
56use crate::metrics::Metrics;
57use crate::optimize::{self, Optimize};
58use crate::session::{
59 EndTransactionAction, PreparedStatement, Session, SessionConfig, TransactionId,
60};
61use crate::statement_logging::{StatementEndedExecutionReason, StatementExecutionStrategy};
62use crate::telemetry::{self, EventDetails, SegmentClientExt, StatementFailureType};
63use crate::webhook::AppendWebhookResponse;
64use crate::{AdapterNotice, AppendWebhookError, PeekResponseUnary, StartupResponse};
65
/// Owns a background thread together with the identifying metadata captured when it started.
///
/// NOTE(review): presumably the thread is the coordinator's main loop — confirm at the
/// construction site.
pub struct Handle {
    /// UUID returned by [`Handle::session_id`].
    pub(crate) session_id: Uuid,
    /// Instant returned by [`Handle::start_instant`].
    pub(crate) start_instant: Instant,
    /// Never read; held only for its drop behaviour. `JoinOnDropHandle` presumably joins the
    /// thread when the `Handle` is dropped.
    pub(crate) _thread: JoinOnDropHandle<()>,
}
76
77impl Handle {
78 pub fn session_id(&self) -> Uuid {
84 self.session_id
85 }
86
87 pub fn start_instant(&self) -> Instant {
89 self.start_instant
90 }
91}
92
/// A cloneable handle for submitting [`Command`]s to the coordinator.
///
/// Commands travel over an unbounded channel, each tagged with the caller's
/// [`OpenTelemetryContext`] (see `Client::send`).
#[derive(Debug, Clone)]
pub struct Client {
    /// Build information handed to every new [`Session`].
    build_info: &'static BuildInfo,
    /// Sending half of the coordinator's command channel.
    inner_cmd_tx: mpsc::UnboundedSender<(OpenTelemetryContext, Command)>,
    /// Allocator for connection ids, returned by `new_conn_id`.
    id_alloc: IdAllocator<IdAllocatorInnerBitSet>,
    /// Clock used for timestamps.
    now: NowFn,
    /// Adapter metrics; session metrics are derived from these.
    metrics: Metrics,
    /// The environment this client serves; reported in the welcome notice and telemetry.
    environment_id: EnvironmentId,
    /// Optional Segment telemetry client, copied into each [`SessionClient`].
    segment_client: Option<mz_segment::Client>,
}
110
111impl Client {
112 pub(crate) fn new(
113 build_info: &'static BuildInfo,
114 cmd_tx: mpsc::UnboundedSender<(OpenTelemetryContext, Command)>,
115 metrics: Metrics,
116 now: NowFn,
117 environment_id: EnvironmentId,
118 segment_client: Option<mz_segment::Client>,
119 ) -> Client {
120 let env_lower = org_id_conn_bits(&environment_id.organization_id());
128 Client {
129 build_info,
130 inner_cmd_tx: cmd_tx,
131 id_alloc: IdAllocator::new(1, MAX_ORG_ID, env_lower),
132 now,
133 metrics,
134 environment_id,
135 segment_client,
136 }
137 }
138
139 pub fn new_conn_id(&self) -> Result<ConnectionId, AdapterError> {
141 self.id_alloc.alloc().ok_or(AdapterError::IdExhaustionError)
142 }
143
144 pub fn new_session(&self, config: SessionConfig) -> Session {
148 Session::new(self.build_info, config, self.metrics().session_metrics())
152 }
153
154 pub async fn authenticate(
156 &self,
157 user: &String,
158 password: &Password,
159 ) -> Result<AuthResponse, AdapterError> {
160 let (tx, rx) = oneshot::channel();
161 self.send(Command::AuthenticatePassword {
162 role_name: user.to_string(),
163 password: Some(password.clone()),
164 tx,
165 });
166 let response = rx.await.expect("sender dropped")?;
167 Ok(response)
168 }
169
170 #[mz_ore::instrument(level = "debug")]
179 pub async fn startup(&self, session: Session) -> Result<SessionClient, AdapterError> {
180 let user = session.user().clone();
181 let conn_id = session.conn_id().clone();
182 let secret_key = session.secret_key();
183 let uuid = session.uuid();
184 let client_ip = session.client_ip();
185 let application_name = session.application_name().into();
186 let notice_tx = session.retain_notice_transmitter();
187
188 let (tx, rx) = oneshot::channel();
189
190 let rx = rx.with_guard(|_| {
196 self.send(Command::Terminate {
197 conn_id: conn_id.clone(),
198 tx: None,
199 });
200 });
201
202 self.send(Command::Startup {
203 tx,
204 user,
205 conn_id: conn_id.clone(),
206 secret_key,
207 uuid,
208 client_ip: client_ip.copied(),
209 application_name,
210 notice_tx,
211 });
212
213 let response = rx.await.expect("sender dropped")?;
216
217 let mut client = SessionClient {
220 inner: Some(self.clone()),
221 session: Some(session),
222 timeouts: Timeout::new(),
223 environment_id: self.environment_id.clone(),
224 segment_client: self.segment_client.clone(),
225 };
226
227 let StartupResponse {
228 role_id,
229 write_notify,
230 session_defaults,
231 catalog,
232 } = response;
233
234 let session = client.session();
235 session.initialize_role_metadata(role_id);
236 let vars_mut = session.vars_mut();
237 for (name, val) in session_defaults {
238 if let Err(err) = vars_mut.set_default(&name, val.borrow()) {
239 tracing::error!("failed to set peristed default, {err:?}");
242 }
243 }
244 session
245 .vars_mut()
246 .end_transaction(EndTransactionAction::Commit);
247
248 session.set_builtin_table_updates(write_notify);
256
257 let catalog = catalog.for_session(session);
258
259 let cluster_active = session.vars().cluster().to_string();
260 if session.vars().welcome_message() {
261 let cluster_info = if catalog.resolve_cluster(Some(&cluster_active)).is_err() {
262 format!("{cluster_active} (does not exist)")
263 } else {
264 cluster_active.to_string()
265 };
266
267 session.add_notice(AdapterNotice::Welcome(format!(
271 "connected to Materialize v{}
272 Org ID: {}
273 Region: {}
274 User: {}
275 Cluster: {}
276 Database: {}
277 {}
278 Session UUID: {}
279
280Issue a SQL query to get started. Need help?
281 View documentation: https://materialize.com/s/docs
282 Join our Slack community: https://materialize.com/s/chat
283 ",
284 session.vars().build_info().semver_version(),
285 self.environment_id.organization_id(),
286 self.environment_id.region(),
287 session.vars().user().name,
288 cluster_info,
289 session.vars().database(),
290 match session.vars().search_path() {
291 [schema] => format!("Schema: {}", schema),
292 schemas => format!(
293 "Search path: {}",
294 schemas.iter().map(|id| id.to_string()).join(", ")
295 ),
296 },
297 session.uuid(),
298 )));
299 }
300
301 if session.vars().current_object_missing_warnings() {
302 if catalog.active_database().is_none() {
303 let db = session.vars().database().into();
304 session.add_notice(AdapterNotice::UnknownSessionDatabase(db));
305 }
306 }
307
308 let cluster_var = session
311 .vars()
312 .inspect(CLUSTER.name())
313 .expect("cluster should exist");
314 if session.vars().current_object_missing_warnings()
315 && catalog.resolve_cluster(Some(&cluster_active)).is_err()
316 {
317 let cluster_notice = 'notice: {
318 if cluster_var.inspect_session_value().is_some() {
319 break 'notice Some(AdapterNotice::DefaultClusterDoesNotExist {
320 name: cluster_active,
321 kind: "session",
322 suggested_action: "Pick an extant cluster with SET CLUSTER = name. Run SHOW CLUSTERS to see available clusters.".into(),
323 });
324 }
325
326 let role_default = catalog.get_role(catalog.active_role_id());
327 let role_cluster = match role_default.vars().get(CLUSTER.name()) {
328 Some(OwnedVarInput::Flat(name)) => Some(name),
329 None => None,
330 Some(v @ OwnedVarInput::SqlSet(_)) => {
332 tracing::warn!(?v, "SqlSet found for cluster Role Default");
333 break 'notice None;
334 }
335 };
336
337 let alter_role = "with `ALTER ROLE <role> SET cluster TO <cluster>;`";
338 match role_cluster {
339 None => Some(AdapterNotice::DefaultClusterDoesNotExist {
341 name: cluster_active,
342 kind: "system",
343 suggested_action: format!(
344 "Set a default cluster for the current role {alter_role}."
345 ),
346 }),
347 Some(_) => Some(AdapterNotice::DefaultClusterDoesNotExist {
349 name: cluster_active,
350 kind: "role",
351 suggested_action: format!(
352 "Change the default cluster for the current role {alter_role}."
353 ),
354 }),
355 }
356 };
357
358 if let Some(notice) = cluster_notice {
359 session.add_notice(notice);
360 }
361 }
362
363 Ok(client)
364 }
365
366 pub fn cancel_request(&self, conn_id: ConnectionIdType, secret_key: u32) {
368 self.send(Command::CancelRequest {
369 conn_id,
370 secret_key,
371 });
372 }
373
374 pub async fn support_execute_one(
377 &self,
378 sql: &str,
379 ) -> Result<Box<dyn RowIterator>, anyhow::Error> {
380 let conn_id = self.new_conn_id()?;
382 let session = self.new_session(SessionConfig {
383 conn_id,
384 uuid: Uuid::new_v4(),
385 user: SUPPORT_USER.name.clone(),
386 client_ip: None,
387 external_metadata_rx: None,
388 internal_user_metadata: None,
389 helm_chart_version: None,
390 });
391 let mut session_client = self.startup(session).await?;
392
393 let stmts = mz_sql::parse::parse(sql)?;
395 if stmts.len() != 1 {
396 bail!("must supply exactly one query");
397 }
398 let StatementParseResult { ast: stmt, sql } = stmts.into_element();
399
400 const EMPTY_PORTAL: &str = "";
401 session_client.start_transaction(Some(1))?;
402 session_client
403 .declare(EMPTY_PORTAL.into(), stmt, sql.to_string())
404 .await?;
405 match session_client
406 .execute(EMPTY_PORTAL.into(), futures::future::pending(), None)
407 .await?
408 {
409 (ExecuteResponse::SendingRows { future, .. }, _) => match future.await {
410 PeekResponseUnary::Rows(rows) => Ok(rows),
411 PeekResponseUnary::Canceled => bail!("query canceled"),
412 PeekResponseUnary::Error(e) => bail!(e),
413 },
414 r => bail!("unsupported response type: {r:?}"),
415 }
416 }
417
418 pub fn metrics(&self) -> &Metrics {
420 &self.metrics
421 }
422
423 pub fn now(&self) -> DateTime<Utc> {
425 to_datetime((self.now)())
426 }
427
428 pub async fn get_webhook_appender(
430 &self,
431 database: String,
432 schema: String,
433 name: String,
434 ) -> Result<AppendWebhookResponse, AppendWebhookError> {
435 let (tx, rx) = oneshot::channel();
436
437 self.send(Command::GetWebhook {
439 database,
440 schema,
441 name,
442 tx,
443 });
444
445 let response = rx
447 .await
448 .map_err(|_| anyhow::anyhow!("failed to receive webhook response"))?;
449
450 response
451 }
452
453 pub async fn get_system_vars(&self) -> SystemVars {
455 let (tx, rx) = oneshot::channel();
456 self.send(Command::GetSystemVars { tx });
457 rx.await.expect("coordinator unexpectedly gone")
458 }
459
460 #[instrument(level = "debug")]
461 fn send(&self, cmd: Command) {
462 self.inner_cmd_tx
463 .send((OpenTelemetryContext::obtain(), cmd))
464 .expect("coordinator unexpectedly gone");
465 }
466}
467
/// A client bound to one started [`Session`], created by `Client::startup`.
pub struct SessionClient {
    /// The underlying client. Set to `None` by `SessionClient::terminate`, which also keeps
    /// `Drop` from sending a second `Terminate`.
    inner: Option<Client>,
    /// The session. Taken (left `None`) while a command that carries the session is in
    /// flight, and restored from the coordinator's response (see `send_with_cancel`).
    session: Option<Session>,
    /// Armed session timeouts, such as the idle-in-transaction timeout.
    timeouts: Timeout,
    /// Optional Segment client, used to report statement parse failures.
    segment_client: Option<mz_segment::Client>,
    /// Environment reported alongside telemetry events.
    environment_id: EnvironmentId,
}
483
484impl SessionClient {
485 pub fn parse<'a>(
488 &self,
489 sql: &'a str,
490 ) -> Result<Result<Vec<StatementParseResult<'a>>, ParserStatementError>, String> {
491 match mz_sql::parse::parse_with_limit(sql) {
492 Ok(Err(e)) => {
493 self.track_statement_parse_failure(&e);
494 Ok(Err(e))
495 }
496 r => r,
497 }
498 }
499
500 fn track_statement_parse_failure(&self, parse_error: &ParserStatementError) {
501 let session = self.session.as_ref().expect("session invariant violated");
502 let Some(user_id) = session.user().external_metadata.as_ref().map(|m| m.user_id) else {
503 return;
504 };
505 let Some(segment_client) = &self.segment_client else {
506 return;
507 };
508 let Some(statement_kind) = parse_error.statement else {
509 return;
510 };
511 let Some((action, object_type)) = telemetry::analyze_audited_statement(statement_kind)
512 else {
513 return;
514 };
515 let event_type = StatementFailureType::ParseFailure;
516 let event_name = format!(
517 "{} {} {}",
518 object_type.as_title_case(),
519 action.as_title_case(),
520 event_type.as_title_case(),
521 );
522 segment_client.environment_track(
523 &self.environment_id,
524 event_name,
525 json!({
526 "statement_kind": statement_kind,
527 "error": &parse_error.error,
528 }),
529 EventDetails {
530 user_id: Some(user_id),
531 application_name: Some(session.application_name()),
532 ..Default::default()
533 },
534 );
535 }
536
537 pub async fn get_prepared_statement(
540 &mut self,
541 name: &str,
542 ) -> Result<&PreparedStatement, AdapterError> {
543 let catalog = self.catalog_snapshot().await;
544 Coordinator::verify_prepared_statement(&catalog, self.session(), name)?;
545 Ok(self
546 .session()
547 .get_prepared_statement_unverified(name)
548 .expect("must exist"))
549 }
550
551 pub async fn prepare(
556 &mut self,
557 name: String,
558 stmt: Option<Statement<Raw>>,
559 sql: String,
560 param_types: Vec<Option<ScalarType>>,
561 ) -> Result<(), AdapterError> {
562 let catalog = self.catalog_snapshot().await;
563
564 let mut async_pause = false;
567 (|| {
568 fail::fail_point!("async_prepare", |val| {
569 async_pause = val.map_or(false, |val| val.parse().unwrap_or(false))
570 });
571 })();
572 if async_pause {
573 tokio::time::sleep(Duration::from_secs(1)).await;
574 };
575
576 let desc = Coordinator::describe(&catalog, self.session(), stmt.clone(), param_types)?;
577 let now = self.now();
578 self.session().set_prepared_statement(
579 name,
580 stmt,
581 sql,
582 desc,
583 catalog.transient_revision(),
584 now,
585 );
586 Ok(())
587 }
588
589 #[mz_ore::instrument(level = "debug")]
591 pub async fn declare(
592 &mut self,
593 name: String,
594 stmt: Statement<Raw>,
595 sql: String,
596 ) -> Result<(), AdapterError> {
597 let catalog = self.catalog_snapshot().await;
598 let param_types = vec![];
599 let desc =
600 Coordinator::describe(&catalog, self.session(), Some(stmt.clone()), param_types)?;
601 let params = vec![];
602 let result_formats = vec![mz_pgwire_common::Format::Text; desc.arity()];
603 let now = self.now();
604 let logging = self.session().mint_logging(sql, Some(&stmt), now);
605 self.session().set_portal(
606 name,
607 desc,
608 Some(stmt),
609 logging,
610 params,
611 result_formats,
612 catalog.transient_revision(),
613 )?;
614 Ok(())
615 }
616
617 #[mz_ore::instrument(level = "debug")]
621 pub async fn execute(
622 &mut self,
623 portal_name: String,
624 cancel_future: impl Future<Output = std::io::Error> + Send,
625 outer_ctx_extra: Option<ExecuteContextExtra>,
626 ) -> Result<(ExecuteResponse, Instant), AdapterError> {
627 let execute_started = Instant::now();
628 let response = self
629 .send_with_cancel(
630 |tx, session| Command::Execute {
631 portal_name,
632 session,
633 tx,
634 outer_ctx_extra,
635 },
636 cancel_future,
637 )
638 .await?;
639 Ok((response, execute_started))
640 }
641
642 fn now(&self) -> EpochMillis {
643 (self.inner().now)()
644 }
645
646 fn now_datetime(&self) -> DateTime<Utc> {
647 to_datetime(self.now())
648 }
649
650 pub fn start_transaction(&mut self, implicit: Option<usize>) -> Result<(), AdapterError> {
656 let now = self.now_datetime();
657 let session = self.session.as_mut().expect("session invariant violated");
658 let result = match implicit {
659 None => session.start_transaction(now, None, None),
660 Some(stmts) => {
661 session.start_transaction_implicit(now, stmts);
662 Ok(())
663 }
664 };
665 result
666 }
667
668 #[instrument(level = "debug")]
671 pub async fn end_transaction(
672 &mut self,
673 action: EndTransactionAction,
674 ) -> Result<ExecuteResponse, AdapterError> {
675 let res = self
676 .send(|tx, session| Command::Commit {
677 action,
678 session,
679 tx,
680 })
681 .await;
682 let _ = self.session().clear_transaction();
686 res
687 }
688
689 pub fn fail_transaction(&mut self) {
691 let session = self.session.take().expect("session invariant violated");
692 let session = session.fail_transaction();
693 self.session = Some(session);
694 }
695
696 #[instrument(level = "debug")]
698 pub async fn catalog_snapshot(&self) -> Arc<Catalog> {
699 let CatalogSnapshot { catalog } = self
700 .send_without_session(|tx| Command::CatalogSnapshot { tx })
701 .await;
702 catalog
703 }
704
705 pub async fn dump_catalog(&self) -> Result<CatalogDump, AdapterError> {
710 let catalog = self.catalog_snapshot().await;
711 catalog.dump().map_err(AdapterError::from)
712 }
713
714 pub async fn check_catalog(&self) -> Result<(), serde_json::Value> {
720 let catalog = self.catalog_snapshot().await;
721 catalog.check_consistency()
722 }
723
724 pub async fn check_coordinator(&self) -> Result<(), serde_json::Value> {
730 self.send_without_session(|tx| Command::CheckConsistency { tx })
731 .await
732 .map_err(|inconsistencies| {
733 serde_json::to_value(inconsistencies).unwrap_or_else(|_| {
734 serde_json::Value::String("failed to serialize inconsistencies".to_string())
735 })
736 })
737 }
738
739 pub async fn dump_coordinator_state(&self) -> Result<serde_json::Value, anyhow::Error> {
740 self.send_without_session(|tx| Command::Dump { tx }).await
741 }
742
743 pub fn retire_execute(&self, data: ExecuteContextExtra, reason: StatementEndedExecutionReason) {
746 if !data.is_trivial() {
747 let cmd = Command::RetireExecute { data, reason };
748 self.inner().send(cmd);
749 }
750 }
751
752 pub async fn insert_rows(
758 &mut self,
759 id: CatalogItemId,
760 columns: Vec<ColumnIndex>,
761 rows: Vec<Row>,
762 ctx_extra: ExecuteContextExtra,
763 ) -> Result<ExecuteResponse, AdapterError> {
764 let pcx = self.session().pcx().clone();
767
768 let catalog = self.catalog_snapshot().await;
769 let conn_catalog = catalog.for_session(self.session());
770
771 let optimizer_config = optimize::OptimizerConfig::from(conn_catalog.system_vars());
773 let mut optimizer = optimize::view::Optimizer::new(optimizer_config, None);
774
775 let result: Result<_, AdapterError> =
776 mz_sql::plan::plan_copy_from(&pcx, &conn_catalog, id, columns, rows)
777 .err_into()
778 .and_then(|values| optimizer.optimize(values).err_into())
779 .and_then(|values| {
780 Coordinator::insert_constant(&catalog, self.session(), id, values.into_inner())
782 });
783 self.retire_execute(ctx_extra, (&result).into());
784 result
785 }
786
787 pub async fn get_system_vars(&self) -> SystemVars {
789 self.inner().get_system_vars().await
790 }
791
792 pub async fn set_system_vars(
794 &mut self,
795 vars: BTreeMap<String, String>,
796 ) -> Result<(), AdapterError> {
797 let conn_id = self.session().conn_id().clone();
798 self.send_without_session(|tx| Command::SetSystemVars { vars, conn_id, tx })
799 .await
800 }
801
802 pub async fn terminate(&mut self) {
804 let conn_id = self.session().conn_id().clone();
805 let res = self
806 .send_without_session(|tx| Command::Terminate {
807 conn_id,
808 tx: Some(tx),
809 })
810 .await;
811 if let Err(e) = res {
812 error!("Unable to terminate session: {e:?}");
814 }
815 self.inner = None;
817 }
818
819 pub fn session(&mut self) -> &mut Session {
821 self.session.as_mut().expect("session invariant violated")
822 }
823
824 pub fn inner(&self) -> &Client {
826 self.inner.as_ref().expect("inner invariant violated")
827 }
828
829 async fn send_without_session<T, F>(&self, f: F) -> T
830 where
831 F: FnOnce(oneshot::Sender<T>) -> Command,
832 {
833 let (tx, rx) = oneshot::channel();
834 self.inner().send(f(tx));
835 rx.await.expect("sender dropped")
836 }
837
838 #[instrument(level = "debug")]
839 async fn send<T, F>(&mut self, f: F) -> Result<T, AdapterError>
840 where
841 F: FnOnce(oneshot::Sender<Response<T>>, Session) -> Command,
842 {
843 self.send_with_cancel(f, futures::future::pending()).await
844 }
845
    /// Sends the command built by `f`, lending it this client's session, and waits for the
    /// coordinator's [`Response`] while also watching `cancel_future`.
    ///
    /// The session is moved into the command and restored from the response once it
    /// arrives. If `cancel_future` resolves first, one [`Command::PrivilegedCancelRequest`]
    /// is sent for this connection and waiting continues. Cancellation never abandons the
    /// response.
    ///
    /// For `Execute` and `GetWebhook` commands, the `commands` metric is incremented. It is
    /// labeled with the command type, `success`/`error`, and the application-name hint.
    ///
    /// # Panics
    /// Panics if the session or inner client is missing, or if the coordinator drops the
    /// response sender.
    #[instrument(level = "debug")]
    async fn send_with_cancel<T, F>(
        &mut self,
        f: F,
        cancel_future: impl Future<Output = std::io::Error> + Send,
    ) -> Result<T, AdapterError>
    where
        F: FnOnce(oneshot::Sender<Response<T>>, Session) -> Command,
    {
        let session = self.session.take().expect("session invariant violated");
        // Metric label for the command type; stays `None` for unmetered commands.
        let mut typ = None;
        let application_name = session.application_name();
        let name_hint = ApplicationNameHint::from_str(application_name);
        let conn_id = session.conn_id().clone();
        let (tx, rx) = oneshot::channel();

        // Split the borrow of `self`: the guard closure below holds `client_session`
        // mutably while `inner_client` is still used to send commands.
        let Self {
            inner: inner_client,
            session: client_session,
            ..
        } = self;

        let inner_client = inner_client.as_ref().expect("inner invariant violated");

        // If a response is delivered but never received (e.g. this future is dropped while
        // waiting), the guard still puts the returned session back into `self.session`.
        // NOTE(review): relies on `OneshotReceiverExt::with_guard` invoking the closure with
        // an unreceived value — confirm against mz_ore.
        let mut guarded_rx = rx.with_guard(|response: Response<_>| {
            *client_session = Some(response.session);
        });

        inner_client.send({
            let cmd = f(tx, session);
            // Decide which commands are metered. The match has no wildcard arm, so adding
            // a `Command` variant forces this list to be revisited.
            match cmd {
                Command::Execute { .. } => typ = Some("execute"),
                Command::GetWebhook { .. } => typ = Some("webhook"),
                Command::Startup { .. }
                | Command::AuthenticatePassword { .. }
                | Command::CatalogSnapshot { .. }
                | Command::Commit { .. }
                | Command::CancelRequest { .. }
                | Command::PrivilegedCancelRequest { .. }
                | Command::GetSystemVars { .. }
                | Command::SetSystemVars { .. }
                | Command::Terminate { .. }
                | Command::RetireExecute { .. }
                | Command::CheckConsistency { .. }
                | Command::Dump { .. } => {}
            };
            cmd
        });

        let mut cancel_future = pin::pin!(cancel_future);
        // Ensures the cancel branch fires at most once; afterwards only the response is awaited.
        let mut cancelled = false;
        loop {
            tokio::select! {
                res = &mut guarded_rx => {
                    // The response has been received, so the guard is no longer needed.
                    // Dropping it also ends its mutable borrow of `client_session`, which
                    // is written below.
                    drop(guarded_rx);

                    let res = res.expect("sender dropped");
                    let status = res.result.is_ok().then_some("success").unwrap_or("error");
                    if let Err(err) = res.result.as_ref() {
                        if name_hint.should_trace_errors() {
                            tracing::warn!(?err, ?name_hint, "adapter response error");
                        }
                    }

                    if let Some(typ) = typ {
                        inner_client
                            .metrics
                            .commands
                            .with_label_values(&[typ, status, name_hint.as_str()])
                            .inc();
                    }
                    *client_session = Some(res.session);
                    return res.result;
                },
                _err = &mut cancel_future, if !cancelled => {
                    cancelled = true;
                    inner_client.send(Command::PrivilegedCancelRequest {
                        conn_id: conn_id.clone(),
                    });
                }
            };
        }
    }
945
946 pub fn add_idle_in_transaction_session_timeout(&mut self) {
947 let session = self.session();
948 let timeout_dur = session.vars().idle_in_transaction_session_timeout();
949 if !timeout_dur.is_zero() {
950 let timeout_dur = timeout_dur.clone();
951 if let Some(txn) = session.transaction().inner() {
952 let txn_id = txn.id.clone();
953 let timeout = TimeoutType::IdleInTransactionSession(txn_id);
954 self.timeouts.add_timeout(timeout, timeout_dur);
955 }
956 }
957 }
958
959 pub fn remove_idle_in_transaction_session_timeout(&mut self) {
960 let session = self.session();
961 if let Some(txn) = session.transaction().inner() {
962 let txn_id = txn.id.clone();
963 self.timeouts
964 .remove_timeout(&TimeoutType::IdleInTransactionSession(txn_id));
965 }
966 }
967
968 pub async fn recv_timeout(&mut self) -> Option<TimeoutType> {
975 self.timeouts.recv().await
976 }
977}
978
979impl Drop for SessionClient {
980 fn drop(&mut self) {
981 if let Some(session) = self.session.take() {
985 if let Some(inner) = &self.inner {
988 inner.send(Command::Terminate {
989 conn_id: session.conn_id().clone(),
990 tx: None,
991 })
992 }
993 }
994 }
995}
996
/// A session timeout that can be armed in, and fire from, a `Timeout` tracker.
#[derive(Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Debug)]
pub enum TimeoutType {
    /// Idle-in-transaction timeout, keyed by the transaction it was armed for.
    /// NOTE(review): the id presumably lets a stale timeout be told apart from one for a
    /// newer transaction — confirm with the consumer of `recv_timeout`.
    IdleInTransactionSession(TransactionId),
}
1001
1002impl Display for TimeoutType {
1003 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1004 match self {
1005 TimeoutType::IdleInTransactionSession(txn_id) => {
1006 writeln!(f, "Idle in transaction session for transaction '{txn_id}'")
1007 }
1008 }
1009 }
1010}
1011
1012impl From<TimeoutType> for AdapterError {
1013 fn from(timeout: TimeoutType) -> Self {
1014 match timeout {
1015 TimeoutType::IdleInTransactionSession(_) => {
1016 AdapterError::IdleInTransactionSessionTimeout
1017 }
1018 }
1019 }
1020}
1021
/// Tracks armed session timeouts.
///
/// Each armed timeout is a spawned task that sleeps and then sends its [`TimeoutType`] on
/// `tx`; fired timeouts are read from `rx`.
struct Timeout {
    /// Sender cloned into each timer task. Also used to re-queue notifications in
    /// `remove_timeout`.
    tx: mpsc::UnboundedSender<TimeoutType>,
    /// Receives timeouts that have fired.
    rx: mpsc::UnboundedReceiver<TimeoutType>,
    /// Timer task handles by timeout. Removing an entry drops its `AbortOnDropHandle`,
    /// which presumably aborts the task.
    active_timeouts: BTreeMap<TimeoutType, AbortOnDropHandle<()>>,
}
1027
1028impl Timeout {
1029 fn new() -> Self {
1030 let (tx, rx) = mpsc::unbounded_channel();
1031 Timeout {
1032 tx,
1033 rx,
1034 active_timeouts: BTreeMap::new(),
1035 }
1036 }
1037
1038 async fn recv(&mut self) -> Option<TimeoutType> {
1047 self.rx.recv().await
1048 }
1049
1050 fn add_timeout(&mut self, timeout: TimeoutType, duration: Duration) {
1051 let tx = self.tx.clone();
1052 let timeout_key = timeout.clone();
1053 let handle = mz_ore::task::spawn(|| format!("{timeout_key}"), async move {
1054 tokio::time::sleep(duration).await;
1055 let _ = tx.send(timeout);
1056 })
1057 .abort_on_drop();
1058 self.active_timeouts.insert(timeout_key, handle);
1059 }
1060
1061 fn remove_timeout(&mut self, timeout: &TimeoutType) {
1062 self.active_timeouts.remove(timeout);
1063
1064 let mut timeouts = Vec::new();
1066 while let Ok(pending_timeout) = self.rx.try_recv() {
1067 if timeout != &pending_timeout {
1068 timeouts.push(pending_timeout);
1069 }
1070 }
1071 for pending_timeout in timeouts {
1072 self.tx.send(pending_timeout).expect("rx is in this struct");
1073 }
1074 }
1075}
1076
/// Wraps a stream of peek responses and records time-to-first-row.
///
/// The first [`PeekResponseUnary::Rows`] message observed through
/// `RecordFirstRowStream::recv` causes the time elapsed since `execute_started` to be
/// observed in `time_to_first_row_seconds`.
#[derive(Derivative)]
#[derivative(Debug)]
pub struct RecordFirstRowStream {
    /// The wrapped response stream (excluded from `Debug`).
    #[derivative(Debug = "ignore")]
    pub rows: Box<dyn Stream<Item = PeekResponseUnary> + Unpin + Send + Sync>,
    /// When execution started; the reference point for the measurement.
    pub execute_started: Instant,
    /// Histogram child, already labeled, that receives the measurement.
    pub time_to_first_row_seconds: Histogram,
    /// Whether a `Rows` message has already been seen (and the time recorded).
    saw_rows: bool,
}
1088
1089impl RecordFirstRowStream {
1090 pub fn new(
1092 rows: Box<dyn Stream<Item = PeekResponseUnary> + Unpin + Send + Sync>,
1093 execute_started: Instant,
1094 client: &SessionClient,
1095 instance_id: Option<ComputeInstanceId>,
1096 strategy: Option<StatementExecutionStrategy>,
1097 ) -> Self {
1098 let histogram = Self::histogram(client, instance_id, strategy);
1099 Self {
1100 rows,
1101 execute_started,
1102 time_to_first_row_seconds: histogram,
1103 saw_rows: false,
1104 }
1105 }
1106
1107 fn histogram(
1108 client: &SessionClient,
1109 instance_id: Option<ComputeInstanceId>,
1110 strategy: Option<StatementExecutionStrategy>,
1111 ) -> Histogram {
1112 let isolation_level = *client
1113 .session
1114 .as_ref()
1115 .expect("session invariant")
1116 .vars()
1117 .transaction_isolation();
1118 let instance = match instance_id {
1119 Some(i) => Cow::Owned(i.to_string()),
1120 None => Cow::Borrowed("none"),
1121 };
1122 let strategy = match strategy {
1123 Some(s) => s.name(),
1124 None => "none",
1125 };
1126
1127 client
1128 .inner()
1129 .metrics()
1130 .time_to_first_row_seconds
1131 .with_label_values(&[&instance, isolation_level.as_str(), strategy])
1132 }
1133
1134 pub fn record(
1137 execute_started: Instant,
1138 client: &SessionClient,
1139 instance_id: Option<ComputeInstanceId>,
1140 strategy: Option<StatementExecutionStrategy>,
1141 ) {
1142 Self::histogram(client, instance_id, strategy)
1143 .observe(execute_started.elapsed().as_secs_f64());
1144 }
1145
1146 pub async fn recv(&mut self) -> Option<PeekResponseUnary> {
1147 let msg = self.rows.next().await;
1148 if !self.saw_rows && matches!(msg, Some(PeekResponseUnary::Rows(_))) {
1149 self.saw_rows = true;
1150 self.time_to_first_row_seconds
1151 .observe(self.execute_started.elapsed().as_secs_f64());
1152 }
1153 msg
1154 }
1155}