1use std::borrow::Cow;
11use std::collections::BTreeMap;
12use std::fmt::{Debug, Display, Formatter};
13use std::future::Future;
14use std::pin::{self, Pin};
15use std::sync::Arc;
16use std::time::{Duration, Instant};
17
18use anyhow::bail;
19use chrono::{DateTime, Utc};
20use derivative::Derivative;
21use futures::{Stream, StreamExt};
22use itertools::Itertools;
23use mz_adapter_types::connection::{ConnectionId, ConnectionIdType};
24use mz_auth::password::Password;
25use mz_build_info::BuildInfo;
26use mz_compute_types::ComputeInstanceId;
27use mz_ore::channel::OneshotReceiverExt;
28use mz_ore::collections::CollectionExt;
29use mz_ore::id_gen::{IdAllocator, IdAllocatorInnerBitSet, MAX_ORG_ID, org_id_conn_bits};
30use mz_ore::instrument;
31use mz_ore::now::{EpochMillis, NowFn, to_datetime};
32use mz_ore::result::ResultExt;
33use mz_ore::task::AbortOnDropHandle;
34use mz_ore::thread::JoinOnDropHandle;
35use mz_ore::tracing::OpenTelemetryContext;
36use mz_repr::{CatalogItemId, ColumnIndex, Row, SqlScalarType};
37use mz_sql::ast::{Raw, Statement};
38use mz_sql::catalog::{EnvironmentId, SessionCatalog};
39use mz_sql::session::hint::ApplicationNameHint;
40use mz_sql::session::metadata::SessionMetadata;
41use mz_sql::session::user::SUPPORT_USER;
42use mz_sql::session::vars::{
43 CLUSTER, ENABLE_FRONTEND_PEEK_SEQUENCING, OwnedVarInput, SystemVars, Var,
44};
45use mz_sql_parser::parser::{ParserStatementError, StatementParseResult};
46use prometheus::Histogram;
47use serde_json::json;
48use tokio::sync::{mpsc, oneshot};
49use tracing::{debug, error};
50use uuid::Uuid;
51
52use crate::catalog::Catalog;
53use crate::command::{
54 AuthResponse, CatalogDump, CatalogSnapshot, Command, ExecuteResponse, Response,
55 SASLChallengeResponse, SASLVerifyProofResponse,
56};
57use crate::coord::{Coordinator, ExecuteContextExtra};
58use crate::error::AdapterError;
59use crate::metrics::Metrics;
60use crate::optimize::dataflows::{EvalTime, ExprPrepStyle};
61use crate::optimize::{self, Optimize};
62use crate::session::{
63 EndTransactionAction, PreparedStatement, Session, SessionConfig, StateRevision, TransactionId,
64};
65use crate::statement_logging::{StatementEndedExecutionReason, StatementExecutionStrategy};
66use crate::telemetry::{self, EventDetails, SegmentClientExt, StatementFailureType};
67use crate::webhook::AppendWebhookResponse;
68use crate::{AdapterNotice, AppendWebhookError, PeekClient, PeekResponseUnary, StartupResponse};
69
70pub struct Handle {
76 pub(crate) session_id: Uuid,
77 pub(crate) start_instant: Instant,
78 pub(crate) _thread: JoinOnDropHandle<()>,
79}
80
81impl Handle {
82 pub fn session_id(&self) -> Uuid {
88 self.session_id
89 }
90
91 pub fn start_instant(&self) -> Instant {
93 self.start_instant
94 }
95}
96
97#[derive(Debug, Clone)]
105pub struct Client {
106 build_info: &'static BuildInfo,
107 inner_cmd_tx: mpsc::UnboundedSender<(OpenTelemetryContext, Command)>,
108 id_alloc: IdAllocator<IdAllocatorInnerBitSet>,
109 now: NowFn,
110 metrics: Metrics,
111 environment_id: EnvironmentId,
112 segment_client: Option<mz_segment::Client>,
113}
114
115impl Client {
116 pub(crate) fn new(
117 build_info: &'static BuildInfo,
118 cmd_tx: mpsc::UnboundedSender<(OpenTelemetryContext, Command)>,
119 metrics: Metrics,
120 now: NowFn,
121 environment_id: EnvironmentId,
122 segment_client: Option<mz_segment::Client>,
123 ) -> Client {
124 let env_lower = org_id_conn_bits(&environment_id.organization_id());
132 Client {
133 build_info,
134 inner_cmd_tx: cmd_tx,
135 id_alloc: IdAllocator::new(1, MAX_ORG_ID, env_lower),
136 now,
137 metrics,
138 environment_id,
139 segment_client,
140 }
141 }
142
143 pub fn new_conn_id(&self) -> Result<ConnectionId, AdapterError> {
145 self.id_alloc.alloc().ok_or(AdapterError::IdExhaustionError)
146 }
147
148 pub fn new_session(&self, config: SessionConfig) -> Session {
152 Session::new(self.build_info, config, self.metrics().session_metrics())
156 }
157
158 pub async fn authenticate(
160 &self,
161 user: &String,
162 password: &Password,
163 ) -> Result<AuthResponse, AdapterError> {
164 let (tx, rx) = oneshot::channel();
165 self.send(Command::AuthenticatePassword {
166 role_name: user.to_string(),
167 password: Some(password.clone()),
168 tx,
169 });
170 let response = rx.await.expect("sender dropped")?;
171 Ok(response)
172 }
173
174 pub async fn generate_sasl_challenge(
175 &self,
176 user: &String,
177 client_nonce: &String,
178 ) -> Result<SASLChallengeResponse, AdapterError> {
179 let (tx, rx) = oneshot::channel();
180 self.send(Command::AuthenticateGetSASLChallenge {
181 role_name: user.to_string(),
182 nonce: client_nonce.to_string(),
183 tx,
184 });
185 let response = rx.await.expect("sender dropped")?;
186 Ok(response)
187 }
188
189 pub async fn verify_sasl_proof(
190 &self,
191 user: &String,
192 proof: &String,
193 nonce: &String,
194 mock_hash: &String,
195 ) -> Result<SASLVerifyProofResponse, AdapterError> {
196 let (tx, rx) = oneshot::channel();
197 self.send(Command::AuthenticateVerifySASLProof {
198 role_name: user.to_string(),
199 proof: proof.to_string(),
200 auth_message: nonce.to_string(),
201 mock_hash: mock_hash.to_string(),
202 tx,
203 });
204 let response = rx.await.expect("sender dropped")?;
205 Ok(response)
206 }
207
208 #[mz_ore::instrument(level = "debug")]
217 pub async fn startup(&self, session: Session) -> Result<SessionClient, AdapterError> {
218 let user = session.user().clone();
219 let conn_id = session.conn_id().clone();
220 let secret_key = session.secret_key();
221 let uuid = session.uuid();
222 let client_ip = session.client_ip();
223 let application_name = session.application_name().into();
224 let notice_tx = session.retain_notice_transmitter();
225
226 let (tx, rx) = oneshot::channel();
227
228 let rx = rx.with_guard(|_| {
234 self.send(Command::Terminate {
235 conn_id: conn_id.clone(),
236 tx: None,
237 });
238 });
239
240 self.send(Command::Startup {
241 tx,
242 user,
243 conn_id: conn_id.clone(),
244 secret_key,
245 uuid,
246 client_ip: client_ip.copied(),
247 application_name,
248 notice_tx,
249 });
250
251 let response = rx.await.expect("sender dropped")?;
254
255 let StartupResponse {
259 role_id,
260 write_notify,
261 session_defaults,
262 catalog,
263 storage_collections,
264 transient_id_gen,
265 optimizer_metrics,
266 persist_client,
267 } = response;
268
269 let peek_client = PeekClient::new(
270 self.clone(),
271 storage_collections,
272 transient_id_gen,
273 optimizer_metrics,
274 persist_client,
275 );
276
277 let mut client = SessionClient {
278 inner: Some(self.clone()),
279 session: Some(session),
280 timeouts: Timeout::new(),
281 environment_id: self.environment_id.clone(),
282 segment_client: self.segment_client.clone(),
283 peek_client,
284 enable_frontend_peek_sequencing: false, };
286
287 let session = client.session();
288 session.initialize_role_metadata(role_id);
289 let vars_mut = session.vars_mut();
290 for (name, val) in session_defaults {
291 if let Err(err) = vars_mut.set_default(&name, val.borrow()) {
292 tracing::error!("failed to set peristed default, {err:?}");
295 }
296 }
297 session
298 .vars_mut()
299 .end_transaction(EndTransactionAction::Commit);
300
301 session.set_builtin_table_updates(write_notify);
309
310 let catalog = catalog.for_session(session);
311
312 let cluster_active = session.vars().cluster().to_string();
313 if session.vars().welcome_message() {
314 let cluster_info = if catalog.resolve_cluster(Some(&cluster_active)).is_err() {
315 format!("{cluster_active} (does not exist)")
316 } else {
317 cluster_active.to_string()
318 };
319
320 session.add_notice(AdapterNotice::Welcome(format!(
324 "connected to Materialize v{}
325 Org ID: {}
326 Region: {}
327 User: {}
328 Cluster: {}
329 Database: {}
330 {}
331 Session UUID: {}
332
333Issue a SQL query to get started. Need help?
334 View documentation: https://materialize.com/s/docs
335 Join our Slack community: https://materialize.com/s/chat
336 ",
337 session.vars().build_info().semver_version(),
338 self.environment_id.organization_id(),
339 self.environment_id.region(),
340 session.vars().user().name,
341 cluster_info,
342 session.vars().database(),
343 match session.vars().search_path() {
344 [schema] => format!("Schema: {}", schema),
345 schemas => format!(
346 "Search path: {}",
347 schemas.iter().map(|id| id.to_string()).join(", ")
348 ),
349 },
350 session.uuid(),
351 )));
352 }
353
354 if session.vars().current_object_missing_warnings() {
355 if catalog.active_database().is_none() {
356 let db = session.vars().database().into();
357 session.add_notice(AdapterNotice::UnknownSessionDatabase(db));
358 }
359 }
360
361 let cluster_var = session
364 .vars()
365 .inspect(CLUSTER.name())
366 .expect("cluster should exist");
367 if session.vars().current_object_missing_warnings()
368 && catalog.resolve_cluster(Some(&cluster_active)).is_err()
369 {
370 let cluster_notice = 'notice: {
371 if cluster_var.inspect_session_value().is_some() {
372 break 'notice Some(AdapterNotice::DefaultClusterDoesNotExist {
373 name: cluster_active,
374 kind: "session",
375 suggested_action: "Pick an extant cluster with SET CLUSTER = name. Run SHOW CLUSTERS to see available clusters.".into(),
376 });
377 }
378
379 let role_default = catalog.get_role(catalog.active_role_id());
380 let role_cluster = match role_default.vars().get(CLUSTER.name()) {
381 Some(OwnedVarInput::Flat(name)) => Some(name),
382 None => None,
383 Some(v @ OwnedVarInput::SqlSet(_)) => {
385 tracing::warn!(?v, "SqlSet found for cluster Role Default");
386 break 'notice None;
387 }
388 };
389
390 let alter_role = "with `ALTER ROLE <role> SET cluster TO <cluster>;`";
391 match role_cluster {
392 None => Some(AdapterNotice::DefaultClusterDoesNotExist {
394 name: cluster_active,
395 kind: "system",
396 suggested_action: format!(
397 "Set a default cluster for the current role {alter_role}."
398 ),
399 }),
400 Some(_) => Some(AdapterNotice::DefaultClusterDoesNotExist {
402 name: cluster_active,
403 kind: "role",
404 suggested_action: format!(
405 "Change the default cluster for the current role {alter_role}."
406 ),
407 }),
408 }
409 };
410
411 if let Some(notice) = cluster_notice {
412 session.add_notice(notice);
413 }
414 }
415
416 client.enable_frontend_peek_sequencing = ENABLE_FRONTEND_PEEK_SEQUENCING
417 .require(catalog.system_vars())
418 .is_ok();
419
420 Ok(client)
421 }
422
423 pub fn cancel_request(&self, conn_id: ConnectionIdType, secret_key: u32) {
425 self.send(Command::CancelRequest {
426 conn_id,
427 secret_key,
428 });
429 }
430
431 pub async fn support_execute_one(
434 &self,
435 sql: &str,
436 ) -> Result<Pin<Box<dyn Stream<Item = PeekResponseUnary> + Send>>, anyhow::Error> {
437 let conn_id = self.new_conn_id()?;
439 let session = self.new_session(SessionConfig {
440 conn_id,
441 uuid: Uuid::new_v4(),
442 user: SUPPORT_USER.name.clone(),
443 client_ip: None,
444 external_metadata_rx: None,
445 internal_user_metadata: None,
446 helm_chart_version: None,
447 });
448 let mut session_client = self.startup(session).await?;
449
450 let stmts = mz_sql::parse::parse(sql)?;
452 if stmts.len() != 1 {
453 bail!("must supply exactly one query");
454 }
455 let StatementParseResult { ast: stmt, sql } = stmts.into_element();
456
457 const EMPTY_PORTAL: &str = "";
458 session_client.start_transaction(Some(1))?;
459 session_client
460 .declare(EMPTY_PORTAL.into(), stmt, sql.to_string())
461 .await?;
462
463 match session_client
464 .execute(EMPTY_PORTAL.into(), futures::future::pending(), None)
465 .await?
466 {
467 (ExecuteResponse::SendingRowsStreaming { mut rows, .. }, _) => {
468 let owning_response_stream = async_stream::stream! {
473 while let Some(rows) = rows.next().await {
474 yield rows;
475 }
476 drop(session_client);
477 };
478 Ok(Box::pin(owning_response_stream))
479 }
480 r => bail!("unsupported response type: {r:?}"),
481 }
482 }
483
484 pub fn metrics(&self) -> &Metrics {
486 &self.metrics
487 }
488
489 pub fn now(&self) -> DateTime<Utc> {
491 to_datetime((self.now)())
492 }
493
494 pub async fn get_webhook_appender(
496 &self,
497 database: String,
498 schema: String,
499 name: String,
500 ) -> Result<AppendWebhookResponse, AppendWebhookError> {
501 let (tx, rx) = oneshot::channel();
502
503 self.send(Command::GetWebhook {
505 database,
506 schema,
507 name,
508 tx,
509 });
510
511 let response = rx
513 .await
514 .map_err(|_| anyhow::anyhow!("failed to receive webhook response"))?;
515
516 response
517 }
518
519 pub async fn get_system_vars(&self) -> SystemVars {
521 let (tx, rx) = oneshot::channel();
522 self.send(Command::GetSystemVars { tx });
523 rx.await.expect("coordinator unexpectedly gone")
524 }
525
526 #[instrument(level = "debug")]
527 pub(crate) fn send(&self, cmd: Command) {
528 self.inner_cmd_tx
529 .send((OpenTelemetryContext::obtain(), cmd))
530 .expect("coordinator unexpectedly gone");
531 }
532}
533
534pub struct SessionClient {
538 inner: Option<Client>,
542 session: Option<Session>,
545 timeouts: Timeout,
546 segment_client: Option<mz_segment::Client>,
547 environment_id: EnvironmentId,
548 peek_client: PeekClient,
550 pub enable_frontend_peek_sequencing: bool,
555}
556
557impl SessionClient {
558 pub fn parse<'a>(
561 &self,
562 sql: &'a str,
563 ) -> Result<Result<Vec<StatementParseResult<'a>>, ParserStatementError>, String> {
564 match mz_sql::parse::parse_with_limit(sql) {
565 Ok(Err(e)) => {
566 self.track_statement_parse_failure(&e);
567 Ok(Err(e))
568 }
569 r => r,
570 }
571 }
572
573 fn track_statement_parse_failure(&self, parse_error: &ParserStatementError) {
574 let session = self.session.as_ref().expect("session invariant violated");
575 let Some(user_id) = session.user().external_metadata.as_ref().map(|m| m.user_id) else {
576 return;
577 };
578 let Some(segment_client) = &self.segment_client else {
579 return;
580 };
581 let Some(statement_kind) = parse_error.statement else {
582 return;
583 };
584 let Some((action, object_type)) = telemetry::analyze_audited_statement(statement_kind)
585 else {
586 return;
587 };
588 let event_type = StatementFailureType::ParseFailure;
589 let event_name = format!(
590 "{} {} {}",
591 object_type.as_title_case(),
592 action.as_title_case(),
593 event_type.as_title_case(),
594 );
595 segment_client.environment_track(
596 &self.environment_id,
597 event_name,
598 json!({
599 "statement_kind": statement_kind,
600 "error": &parse_error.error,
601 }),
602 EventDetails {
603 user_id: Some(user_id),
604 application_name: Some(session.application_name()),
605 ..Default::default()
606 },
607 );
608 }
609
610 pub async fn get_prepared_statement(
613 &mut self,
614 name: &str,
615 ) -> Result<&PreparedStatement, AdapterError> {
616 let catalog = self.catalog_snapshot("get_prepared_statement").await;
617 Coordinator::verify_prepared_statement(&catalog, self.session(), name)?;
618 Ok(self
619 .session()
620 .get_prepared_statement_unverified(name)
621 .expect("must exist"))
622 }
623
624 pub async fn prepare(
629 &mut self,
630 name: String,
631 stmt: Option<Statement<Raw>>,
632 sql: String,
633 param_types: Vec<Option<SqlScalarType>>,
634 ) -> Result<(), AdapterError> {
635 let catalog = self.catalog_snapshot("prepare").await;
636
637 let mut async_pause = false;
640 (|| {
641 fail::fail_point!("async_prepare", |val| {
642 async_pause = val.map_or(false, |val| val.parse().unwrap_or(false))
643 });
644 })();
645 if async_pause {
646 tokio::time::sleep(Duration::from_secs(1)).await;
647 };
648
649 let desc = Coordinator::describe(&catalog, self.session(), stmt.clone(), param_types)?;
650 let now = self.now();
651 let state_revision = StateRevision {
652 catalog_revision: catalog.transient_revision(),
653 session_state_revision: self.session().state_revision(),
654 };
655 self.session()
656 .set_prepared_statement(name, stmt, sql, desc, state_revision, now);
657 Ok(())
658 }
659
660 #[mz_ore::instrument(level = "debug")]
662 pub async fn declare(
663 &mut self,
664 name: String,
665 stmt: Statement<Raw>,
666 sql: String,
667 ) -> Result<(), AdapterError> {
668 let catalog = self.catalog_snapshot("declare").await;
669 let param_types = vec![];
670 let desc =
671 Coordinator::describe(&catalog, self.session(), Some(stmt.clone()), param_types)?;
672 let params = vec![];
673 let result_formats = vec![mz_pgwire_common::Format::Text; desc.arity()];
674 let now = self.now();
675 let logging = self.session().mint_logging(sql, Some(&stmt), now);
676 let state_revision = StateRevision {
677 catalog_revision: catalog.transient_revision(),
678 session_state_revision: self.session().state_revision(),
679 };
680 self.session().set_portal(
681 name,
682 desc,
683 Some(stmt),
684 logging,
685 params,
686 result_formats,
687 state_revision,
688 )?;
689 Ok(())
690 }
691
692 #[mz_ore::instrument(level = "debug")]
696 pub async fn execute(
697 &mut self,
698 portal_name: String,
699 cancel_future: impl Future<Output = std::io::Error> + Send,
700 outer_ctx_extra: Option<ExecuteContextExtra>,
701 ) -> Result<(ExecuteResponse, Instant), AdapterError> {
702 let execute_started = Instant::now();
703
704 if let Some(resp) = self.try_frontend_peek(&portal_name).await? {
708 debug!("frontend peek succeeded");
709 return Ok((resp, execute_started));
710 } else {
711 debug!("frontend peek did not happen");
712 }
713
714 let response = self
715 .send_with_cancel(
716 |tx, session| Command::Execute {
717 portal_name,
718 session,
719 tx,
720 outer_ctx_extra,
721 },
722 cancel_future,
723 )
724 .await?;
725 Ok((response, execute_started))
726 }
727
728 fn now(&self) -> EpochMillis {
729 (self.inner().now)()
730 }
731
732 fn now_datetime(&self) -> DateTime<Utc> {
733 to_datetime(self.now())
734 }
735
736 pub fn start_transaction(&mut self, implicit: Option<usize>) -> Result<(), AdapterError> {
742 let now = self.now_datetime();
743 let session = self.session.as_mut().expect("session invariant violated");
744 let result = match implicit {
745 None => session.start_transaction(now, None, None),
746 Some(stmts) => {
747 session.start_transaction_implicit(now, stmts);
748 Ok(())
749 }
750 };
751 result
752 }
753
754 #[instrument(level = "debug")]
757 pub async fn end_transaction(
758 &mut self,
759 action: EndTransactionAction,
760 ) -> Result<ExecuteResponse, AdapterError> {
761 let res = self
762 .send(|tx, session| Command::Commit {
763 action,
764 session,
765 tx,
766 })
767 .await;
768 let _ = self.session().clear_transaction();
772 res
773 }
774
775 pub fn fail_transaction(&mut self) {
777 let session = self.session.take().expect("session invariant violated");
778 let session = session.fail_transaction();
779 self.session = Some(session);
780 }
781
782 #[instrument(level = "debug")]
784 pub async fn catalog_snapshot(&self, context: &str) -> Arc<Catalog> {
785 let start = std::time::Instant::now();
786 let CatalogSnapshot { catalog } = self
787 .send_without_session(|tx| Command::CatalogSnapshot { tx })
788 .await;
789 self.inner()
790 .metrics()
791 .catalog_snapshot_seconds
792 .with_label_values(&[context])
793 .observe(start.elapsed().as_secs_f64());
794 catalog
795 }
796
797 pub async fn dump_catalog(&self) -> Result<CatalogDump, AdapterError> {
802 let catalog = self.catalog_snapshot("dump_catalog").await;
803 catalog.dump().map_err(AdapterError::from)
804 }
805
806 pub async fn check_catalog(&self) -> Result<(), serde_json::Value> {
812 let catalog = self.catalog_snapshot("check_catalog").await;
813 catalog.check_consistency()
814 }
815
816 pub async fn check_coordinator(&self) -> Result<(), serde_json::Value> {
822 self.send_without_session(|tx| Command::CheckConsistency { tx })
823 .await
824 .map_err(|inconsistencies| {
825 serde_json::to_value(inconsistencies).unwrap_or_else(|_| {
826 serde_json::Value::String("failed to serialize inconsistencies".to_string())
827 })
828 })
829 }
830
831 pub async fn dump_coordinator_state(&self) -> Result<serde_json::Value, anyhow::Error> {
832 self.send_without_session(|tx| Command::Dump { tx }).await
833 }
834
835 pub fn retire_execute(&self, data: ExecuteContextExtra, reason: StatementEndedExecutionReason) {
838 if !data.is_trivial() {
839 let cmd = Command::RetireExecute { data, reason };
840 self.inner().send(cmd);
841 }
842 }
843
844 pub async fn insert_rows(
850 &mut self,
851 target_id: CatalogItemId,
852 target_name: String,
853 columns: Vec<ColumnIndex>,
854 rows: Vec<Row>,
855 ctx_extra: ExecuteContextExtra,
856 ) -> Result<ExecuteResponse, AdapterError> {
857 let pcx = self.session().pcx().clone();
860
861 let session_meta = self.session().meta();
862
863 let catalog = self.catalog_snapshot("insert_rows").await;
864 let conn_catalog = catalog.for_session(self.session());
865 let catalog_state = conn_catalog.state();
866
867 let optimizer_config = optimize::OptimizerConfig::from(conn_catalog.system_vars());
869 let prep = ExprPrepStyle::OneShot {
870 logical_time: EvalTime::NotAvailable,
871 session: &session_meta,
872 catalog_state,
873 };
874 let mut optimizer =
875 optimize::view::Optimizer::new_with_prep_no_limit(optimizer_config.clone(), None, prep);
876
877 let result: Result<_, AdapterError> = mz_sql::plan::plan_copy_from(
878 &pcx,
879 &conn_catalog,
880 target_id,
881 target_name,
882 columns,
883 rows,
884 )
885 .err_into()
886 .and_then(|values| optimizer.optimize(values).err_into())
887 .and_then(|values| {
888 Coordinator::insert_constant(&catalog, self.session(), target_id, values.into_inner())
890 });
891 self.retire_execute(ctx_extra, (&result).into());
892 result
893 }
894
895 pub async fn get_system_vars(&self) -> SystemVars {
897 self.inner().get_system_vars().await
898 }
899
900 pub async fn set_system_vars(
902 &mut self,
903 vars: BTreeMap<String, String>,
904 ) -> Result<(), AdapterError> {
905 let conn_id = self.session().conn_id().clone();
906 self.send_without_session(|tx| Command::SetSystemVars { vars, conn_id, tx })
907 .await
908 }
909
910 pub async fn terminate(&mut self) {
912 let conn_id = self.session().conn_id().clone();
913 let res = self
914 .send_without_session(|tx| Command::Terminate {
915 conn_id,
916 tx: Some(tx),
917 })
918 .await;
919 if let Err(e) = res {
920 error!("Unable to terminate session: {e:?}");
922 }
923 self.inner = None;
925 }
926
927 pub fn session(&mut self) -> &mut Session {
929 self.session.as_mut().expect("session invariant violated")
930 }
931
932 pub fn inner(&self) -> &Client {
934 self.inner.as_ref().expect("inner invariant violated")
935 }
936
937 async fn send_without_session<T, F>(&self, f: F) -> T
938 where
939 F: FnOnce(oneshot::Sender<T>) -> Command,
940 {
941 let (tx, rx) = oneshot::channel();
942 self.inner().send(f(tx));
943 rx.await.expect("sender dropped")
944 }
945
946 #[instrument(level = "debug")]
947 async fn send<T, F>(&mut self, f: F) -> Result<T, AdapterError>
948 where
949 F: FnOnce(oneshot::Sender<Response<T>>, Session) -> Command,
950 {
951 self.send_with_cancel(f, futures::future::pending()).await
952 }
953
954 #[instrument(level = "debug")]
958 async fn send_with_cancel<T, F>(
959 &mut self,
960 f: F,
961 cancel_future: impl Future<Output = std::io::Error> + Send,
962 ) -> Result<T, AdapterError>
963 where
964 F: FnOnce(oneshot::Sender<Response<T>>, Session) -> Command,
965 {
966 let session = self.session.take().expect("session invariant violated");
967 let mut typ = None;
968 let application_name = session.application_name();
969 let name_hint = ApplicationNameHint::from_str(application_name);
970 let conn_id = session.conn_id().clone();
971 let (tx, rx) = oneshot::channel();
972
973 let Self {
976 inner: inner_client,
977 session: client_session,
978 ..
979 } = self;
980
981 let inner_client = inner_client.as_ref().expect("inner invariant violated");
984
985 let mut guarded_rx = rx.with_guard(|response: Response<_>| {
991 *client_session = Some(response.session);
992 });
993
994 inner_client.send({
995 let cmd = f(tx, session);
996 match cmd {
1000 Command::Execute { .. } => typ = Some("execute"),
1001 Command::GetWebhook { .. } => typ = Some("webhook"),
1002 Command::Startup { .. }
1003 | Command::AuthenticatePassword { .. }
1004 | Command::AuthenticateGetSASLChallenge { .. }
1005 | Command::AuthenticateVerifySASLProof { .. }
1006 | Command::CatalogSnapshot { .. }
1007 | Command::Commit { .. }
1008 | Command::CancelRequest { .. }
1009 | Command::PrivilegedCancelRequest { .. }
1010 | Command::GetSystemVars { .. }
1011 | Command::SetSystemVars { .. }
1012 | Command::Terminate { .. }
1013 | Command::RetireExecute { .. }
1014 | Command::CheckConsistency { .. }
1015 | Command::Dump { .. }
1016 | Command::GetComputeInstanceClient { .. }
1017 | Command::GetOracle { .. }
1018 | Command::DetermineRealTimeRecentTimestamp { .. }
1019 | Command::GetTransactionReadHoldsBundle { .. }
1020 | Command::StoreTransactionReadHolds { .. }
1021 | Command::ExecuteSlowPathPeek { .. }
1022 | Command::ExecuteCopyTo { .. } => {}
1023 };
1024 cmd
1025 });
1026
1027 let mut cancel_future = pin::pin!(cancel_future);
1028 let mut cancelled = false;
1029 loop {
1030 tokio::select! {
1031 res = &mut guarded_rx => {
1032 drop(guarded_rx);
1034
1035 let res = res.expect("sender dropped");
1036 let status = res.result.is_ok().then_some("success").unwrap_or("error");
1037 if let Err(err) = res.result.as_ref() {
1038 if name_hint.should_trace_errors() {
1039 tracing::warn!(?err, ?name_hint, "adapter response error");
1040 }
1041 }
1042
1043 if let Some(typ) = typ {
1044 inner_client
1045 .metrics
1046 .commands
1047 .with_label_values(&[typ, status, name_hint.as_str()])
1048 .inc();
1049 }
1050 *client_session = Some(res.session);
1051 return res.result;
1052 },
1053 _err = &mut cancel_future, if !cancelled => {
1054 cancelled = true;
1055 inner_client.send(Command::PrivilegedCancelRequest {
1056 conn_id: conn_id.clone(),
1057 });
1058 }
1059 };
1060 }
1061 }
1062
1063 pub fn add_idle_in_transaction_session_timeout(&mut self) {
1064 let session = self.session();
1065 let timeout_dur = session.vars().idle_in_transaction_session_timeout();
1066 if !timeout_dur.is_zero() {
1067 let timeout_dur = timeout_dur.clone();
1068 if let Some(txn) = session.transaction().inner() {
1069 let txn_id = txn.id.clone();
1070 let timeout = TimeoutType::IdleInTransactionSession(txn_id);
1071 self.timeouts.add_timeout(timeout, timeout_dur);
1072 }
1073 }
1074 }
1075
1076 pub fn remove_idle_in_transaction_session_timeout(&mut self) {
1077 let session = self.session();
1078 if let Some(txn) = session.transaction().inner() {
1079 let txn_id = txn.id.clone();
1080 self.timeouts
1081 .remove_timeout(&TimeoutType::IdleInTransactionSession(txn_id));
1082 }
1083 }
1084
1085 pub async fn recv_timeout(&mut self) -> Option<TimeoutType> {
1092 self.timeouts.recv().await
1093 }
1094
1095 pub fn peek_client(&self) -> &PeekClient {
1097 &self.peek_client
1098 }
1099
1100 pub fn peek_client_mut(&mut self) -> &mut PeekClient {
1102 &mut self.peek_client
1103 }
1104
1105 pub(crate) async fn try_frontend_peek(
1110 &mut self,
1111 portal_name: &str,
1112 ) -> Result<Option<ExecuteResponse>, AdapterError> {
1113 if self.enable_frontend_peek_sequencing {
1114 let session = self.session.as_mut().expect("SessionClient invariant");
1115 self.peek_client
1116 .try_frontend_peek_inner(portal_name, session)
1117 .await
1118 } else {
1119 Ok(None)
1120 }
1121 }
1122}
1123
1124impl Drop for SessionClient {
1125 fn drop(&mut self) {
1126 if let Some(session) = self.session.take() {
1130 if let Some(inner) = &self.inner {
1133 inner.send(Command::Terminate {
1134 conn_id: session.conn_id().clone(),
1135 tx: None,
1136 })
1137 }
1138 }
1139 }
1140}
1141
1142#[derive(Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Debug)]
1143pub enum TimeoutType {
1144 IdleInTransactionSession(TransactionId),
1145}
1146
1147impl Display for TimeoutType {
1148 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1149 match self {
1150 TimeoutType::IdleInTransactionSession(txn_id) => {
1151 writeln!(f, "Idle in transaction session for transaction '{txn_id}'")
1152 }
1153 }
1154 }
1155}
1156
1157impl From<TimeoutType> for AdapterError {
1158 fn from(timeout: TimeoutType) -> Self {
1159 match timeout {
1160 TimeoutType::IdleInTransactionSession(_) => {
1161 AdapterError::IdleInTransactionSessionTimeout
1162 }
1163 }
1164 }
1165}
1166
1167struct Timeout {
1168 tx: mpsc::UnboundedSender<TimeoutType>,
1169 rx: mpsc::UnboundedReceiver<TimeoutType>,
1170 active_timeouts: BTreeMap<TimeoutType, AbortOnDropHandle<()>>,
1171}
1172
1173impl Timeout {
1174 fn new() -> Self {
1175 let (tx, rx) = mpsc::unbounded_channel();
1176 Timeout {
1177 tx,
1178 rx,
1179 active_timeouts: BTreeMap::new(),
1180 }
1181 }
1182
1183 async fn recv(&mut self) -> Option<TimeoutType> {
1192 self.rx.recv().await
1193 }
1194
1195 fn add_timeout(&mut self, timeout: TimeoutType, duration: Duration) {
1196 let tx = self.tx.clone();
1197 let timeout_key = timeout.clone();
1198 let handle = mz_ore::task::spawn(|| format!("{timeout_key}"), async move {
1199 tokio::time::sleep(duration).await;
1200 let _ = tx.send(timeout);
1201 })
1202 .abort_on_drop();
1203 self.active_timeouts.insert(timeout_key, handle);
1204 }
1205
1206 fn remove_timeout(&mut self, timeout: &TimeoutType) {
1207 self.active_timeouts.remove(timeout);
1208
1209 let mut timeouts = Vec::new();
1211 while let Ok(pending_timeout) = self.rx.try_recv() {
1212 if timeout != &pending_timeout {
1213 timeouts.push(pending_timeout);
1214 }
1215 }
1216 for pending_timeout in timeouts {
1217 self.tx.send(pending_timeout).expect("rx is in this struct");
1218 }
1219 }
1220}
1221
1222#[derive(Derivative)]
1226#[derivative(Debug)]
1227pub struct RecordFirstRowStream {
1228 #[derivative(Debug = "ignore")]
1230 pub rows: Box<dyn Stream<Item = PeekResponseUnary> + Unpin + Send + Sync>,
1231 pub execute_started: Instant,
1233 pub time_to_first_row_seconds: Histogram,
1236 pub saw_rows: bool,
1238 pub recorded_first_row_instant: Option<Instant>,
1240 pub no_more_rows: bool,
1242}
1243
1244impl RecordFirstRowStream {
1245 pub fn new(
1247 rows: Box<dyn Stream<Item = PeekResponseUnary> + Unpin + Send + Sync>,
1248 execute_started: Instant,
1249 client: &SessionClient,
1250 instance_id: Option<ComputeInstanceId>,
1251 strategy: Option<StatementExecutionStrategy>,
1252 ) -> Self {
1253 let histogram = Self::histogram(client, instance_id, strategy);
1254 Self {
1255 rows,
1256 execute_started,
1257 time_to_first_row_seconds: histogram,
1258 saw_rows: false,
1259 recorded_first_row_instant: None,
1260 no_more_rows: false,
1261 }
1262 }
1263
1264 fn histogram(
1265 client: &SessionClient,
1266 instance_id: Option<ComputeInstanceId>,
1267 strategy: Option<StatementExecutionStrategy>,
1268 ) -> Histogram {
1269 let isolation_level = *client
1270 .session
1271 .as_ref()
1272 .expect("session invariant")
1273 .vars()
1274 .transaction_isolation();
1275 let instance = match instance_id {
1276 Some(i) => Cow::Owned(i.to_string()),
1277 None => Cow::Borrowed("none"),
1278 };
1279 let strategy = match strategy {
1280 Some(s) => s.name(),
1281 None => "none",
1282 };
1283
1284 client
1285 .inner()
1286 .metrics()
1287 .time_to_first_row_seconds
1288 .with_label_values(&[instance.as_ref(), isolation_level.as_str(), strategy])
1289 }
1290
1291 pub fn record(
1294 execute_started: Instant,
1295 client: &SessionClient,
1296 instance_id: Option<ComputeInstanceId>,
1297 strategy: Option<StatementExecutionStrategy>,
1298 ) {
1299 Self::histogram(client, instance_id, strategy)
1300 .observe(execute_started.elapsed().as_secs_f64());
1301 }
1302
1303 pub async fn recv(&mut self) -> Option<PeekResponseUnary> {
1304 let msg = self.rows.next().await;
1305 if !self.saw_rows && matches!(msg, Some(PeekResponseUnary::Rows(_))) {
1306 self.saw_rows = true;
1307 self.time_to_first_row_seconds
1308 .observe(self.execute_started.elapsed().as_secs_f64());
1309 self.recorded_first_row_instant = Some(Instant::now());
1310 }
1311 if msg.is_none() {
1312 self.no_more_rows = true;
1313 }
1314 msg
1315 }
1316}