1use std::borrow::Cow;
11use std::collections::BTreeMap;
12use std::fmt::{Debug, Display, Formatter};
13use std::future::Future;
14use std::pin::{self, Pin};
15use std::sync::Arc;
16use std::time::{Duration, Instant};
17
18use anyhow::bail;
19use chrono::{DateTime, Utc};
20use derivative::Derivative;
21use futures::{Stream, StreamExt};
22use itertools::Itertools;
23use mz_adapter_types::connection::{ConnectionId, ConnectionIdType};
24use mz_auth::password::Password;
25use mz_auth::{Authenticated, AuthenticatorKind};
26use mz_build_info::BuildInfo;
27use mz_compute_types::ComputeInstanceId;
28use mz_ore::channel::OneshotReceiverExt;
29use mz_ore::collections::CollectionExt;
30use mz_ore::id_gen::{IdAllocator, IdAllocatorInnerBitSet, MAX_ORG_ID, org_id_conn_bits};
31use mz_ore::instrument;
32use mz_ore::now::{EpochMillis, NowFn, to_datetime};
33use mz_ore::task::AbortOnDropHandle;
34use mz_ore::thread::JoinOnDropHandle;
35use mz_ore::tracing::OpenTelemetryContext;
36use mz_repr::user::InternalUserMetadata;
37use mz_repr::{CatalogItemId, ColumnIndex, SqlScalarType};
38use mz_sql::ast::{Raw, Statement};
39use mz_sql::catalog::{EnvironmentId, SessionCatalog};
40use mz_sql::session::hint::ApplicationNameHint;
41use mz_sql::session::metadata::SessionMetadata;
42use mz_sql::session::user::SUPPORT_USER;
43use mz_sql::session::vars::{
44 CLUSTER, ENABLE_FRONTEND_PEEK_SEQUENCING, OwnedVarInput, SystemVars, Var,
45};
46use mz_sql_parser::parser::{ParserStatementError, StatementParseResult};
47use prometheus::Histogram;
48use serde_json::json;
49use tokio::sync::{mpsc, oneshot};
50use tracing::{debug, error};
51use uuid::Uuid;
52
53use crate::catalog::Catalog;
54use crate::command::{
55 CatalogDump, CatalogSnapshot, Command, CopyFromStdinWriter, ExecuteResponse, Response,
56 SASLChallengeResponse, SASLVerifyProofResponse, SuperuserAttribute,
57};
58use crate::coord::{Coordinator, ExecuteContextGuard};
59use crate::error::AdapterError;
60use crate::metrics::Metrics;
61use crate::session::{
62 EndTransactionAction, PreparedStatement, Session, SessionConfig, StateRevision, TransactionId,
63};
64use crate::statement_logging::{StatementEndedExecutionReason, StatementExecutionStrategy};
65use crate::telemetry::{self, EventDetails, SegmentClientExt, StatementFailureType};
66use crate::webhook::AppendWebhookResponse;
67use crate::{AdapterNotice, AppendWebhookError, PeekClient, PeekResponseUnary, StartupResponse};
68
/// Handle that owns a background thread together with some identifying metadata.
///
/// NOTE(review): presumably this is the coordinator's handle; confirm at the construction site.
pub struct Handle {
    // Identifier of the session this handle belongs to; set by whoever constructs the handle.
    pub(crate) session_id: Uuid,
    // The instant recorded at construction; exposed via `start_instant()`.
    pub(crate) start_instant: Instant,
    // Keeps the thread alive; `JoinOnDropHandle` joins it when this handle is dropped.
    pub(crate) _thread: JoinOnDropHandle<()>,
}

impl Handle {
    /// Returns the session ID stored in this handle.
    pub fn session_id(&self) -> Uuid {
        self.session_id
    }

    /// Returns the instant stored at construction time.
    pub fn start_instant(&self) -> Instant {
        self.start_instant
    }
}
95
/// A client for sending [`Command`]s to the coordinator.
///
/// Every request ultimately goes through [`Client::send`], which pushes the command onto an
/// unbounded channel and panics if the coordinator side of that channel is gone.
#[derive(Debug, Clone)]
pub struct Client {
    build_info: &'static BuildInfo,
    // Channel to the coordinator; each command is paired with the sender's tracing context.
    inner_cmd_tx: mpsc::UnboundedSender<(OpenTelemetryContext, Command)>,
    // Allocator for connection IDs; see `Client::new` for how it is seeded.
    id_alloc: IdAllocator<IdAllocatorInnerBitSet>,
    now: NowFn,
    metrics: Metrics,
    environment_id: EnvironmentId,
    // Optional Segment telemetry client; when `None`, telemetry events (e.g. parse failures)
    // are not sent.
    segment_client: Option<mz_segment::Client>,
}
113
114impl Client {
115 pub(crate) fn new(
116 build_info: &'static BuildInfo,
117 cmd_tx: mpsc::UnboundedSender<(OpenTelemetryContext, Command)>,
118 metrics: Metrics,
119 now: NowFn,
120 environment_id: EnvironmentId,
121 segment_client: Option<mz_segment::Client>,
122 ) -> Client {
123 let env_lower = org_id_conn_bits(&environment_id.organization_id());
131 Client {
132 build_info,
133 inner_cmd_tx: cmd_tx,
134 id_alloc: IdAllocator::new(1, MAX_ORG_ID, env_lower),
135 now,
136 metrics,
137 environment_id,
138 segment_client,
139 }
140 }
141
    /// Allocates a fresh connection ID.
    ///
    /// # Errors
    /// Returns [`AdapterError::IdExhaustionError`] when the allocator has no IDs left.
    pub fn new_conn_id(&self) -> Result<ConnectionId, AdapterError> {
        self.id_alloc.alloc().ok_or(AdapterError::IdExhaustionError)
    }
146
    /// Creates a new [`Session`] from `config` that is not yet registered with the coordinator.
    ///
    /// The `Authenticated` argument is not inspected. Requiring it presumably forces callers to
    /// finish authentication first; confirm. Pass the result to [`Client::startup`] to register it.
    pub fn new_session(&self, config: SessionConfig, _authenticated: Authenticated) -> Session {
        Session::new(self.build_info, config, self.metrics().session_metrics())
    }
158
159 pub async fn authenticate(
163 &self,
164 user: &String,
165 password: &Password,
166 ) -> Result<Authenticated, AdapterError> {
167 let (tx, rx) = oneshot::channel();
168 self.send(Command::AuthenticatePassword {
169 role_name: user.to_string(),
170 password: Some(password.clone()),
171 tx,
172 });
173 rx.await.expect("sender dropped")?;
174 Ok(Authenticated)
175 }
176
177 pub async fn generate_sasl_challenge(
180 &self,
181 user: &String,
182 client_nonce: &String,
183 ) -> Result<SASLChallengeResponse, AdapterError> {
184 let (tx, rx) = oneshot::channel();
185 self.send(Command::AuthenticateGetSASLChallenge {
186 role_name: user.to_string(),
187 nonce: client_nonce.to_string(),
188 tx,
189 });
190 let response = rx.await.expect("sender dropped")?;
191 Ok(response)
192 }
193
194 pub async fn verify_sasl_proof(
197 &self,
198 user: &String,
199 proof: &String,
200 nonce: &String,
201 mock_hash: &String,
202 ) -> Result<(SASLVerifyProofResponse, Authenticated), AdapterError> {
203 let (tx, rx) = oneshot::channel();
204 self.send(Command::AuthenticateVerifySASLProof {
205 role_name: user.to_string(),
206 proof: proof.to_string(),
207 auth_message: nonce.to_string(),
208 mock_hash: mock_hash.to_string(),
209 tx,
210 });
211 let response = rx.await.expect("sender dropped")?;
212 Ok((response, Authenticated))
213 }
214
215 pub async fn role_can_login(&self, role_name: &str) -> Result<(), AdapterError> {
217 let (tx, rx) = oneshot::channel();
218 self.send(Command::CheckRoleCanLogin {
219 role_name: role_name.to_string(),
220 tx,
221 });
222 rx.await.expect("sender dropped")
223 }
224
225 #[mz_ore::instrument(level = "debug")]
234 pub async fn startup(&self, session: Session) -> Result<SessionClient, AdapterError> {
235 let user = session.user().clone();
236 let conn_id = session.conn_id().clone();
237 let secret_key = session.secret_key();
238 let uuid = session.uuid();
239 let client_ip = session.client_ip();
240 let application_name = session.application_name().into();
241 let notice_tx = session.retain_notice_transmitter();
242
243 let (tx, rx) = oneshot::channel();
244
245 let rx = rx.with_guard(|_| {
251 self.send(Command::Terminate {
252 conn_id: conn_id.clone(),
253 tx: None,
254 });
255 });
256
257 self.send(Command::Startup {
258 tx,
259 user,
260 conn_id: conn_id.clone(),
261 secret_key,
262 uuid,
263 client_ip: client_ip.copied(),
264 application_name,
265 notice_tx,
266 });
267
268 let response = rx.await.expect("sender dropped")?;
271
272 let StartupResponse {
276 role_id,
277 write_notify,
278 session_defaults,
279 catalog,
280 storage_collections,
281 transient_id_gen,
282 optimizer_metrics,
283 persist_client,
284 statement_logging_frontend,
285 superuser_attribute,
286 } = response;
287
288 let peek_client = PeekClient::new(
289 self.clone(),
290 storage_collections,
291 transient_id_gen,
292 optimizer_metrics,
293 persist_client,
294 statement_logging_frontend,
295 );
296
297 let mut client = SessionClient {
298 inner: Some(self.clone()),
299 session: Some(session),
300 timeouts: Timeout::new(),
301 environment_id: self.environment_id.clone(),
302 segment_client: self.segment_client.clone(),
303 peek_client,
304 enable_frontend_peek_sequencing: false, };
306
307 let session = client.session();
308
309 if let SuperuserAttribute(Some(superuser)) = superuser_attribute {
312 session.apply_internal_user_metadata(InternalUserMetadata { superuser });
313 }
314
315 session.initialize_role_metadata(role_id);
316 let vars_mut = session.vars_mut();
317 for (name, val) in session_defaults {
318 if let Err(err) = vars_mut.set_default(&name, val.borrow()) {
319 tracing::error!("failed to set peristed default, {err:?}");
322 }
323 }
324 session
325 .vars_mut()
326 .end_transaction(EndTransactionAction::Commit);
327
328 session.set_builtin_table_updates(write_notify);
336
337 let catalog = catalog.for_session(session);
338
339 let cluster_active = session.vars().cluster().to_string();
340 if session.vars().welcome_message() {
341 let cluster_info = if catalog.resolve_cluster(Some(&cluster_active)).is_err() {
342 format!("{cluster_active} (does not exist)")
343 } else {
344 cluster_active.to_string()
345 };
346
347 session.add_notice(AdapterNotice::Welcome(format!(
351 "connected to Materialize v{}
352 Environment ID: {}
353 Region: {}
354 User: {}
355 Cluster: {}
356 Database: {}
357 {}
358 Session UUID: {}
359
360Issue a SQL query to get started. Need help?
361 View documentation: https://materialize.com/s/docs
362 Join our Slack community: https://materialize.com/s/chat
363 ",
364 session.vars().build_info().semver_version(),
365 self.environment_id,
366 self.environment_id.region(),
367 session.vars().user().name,
368 cluster_info,
369 session.vars().database(),
370 match session.vars().search_path() {
371 [schema] => format!("Schema: {}", schema),
372 schemas => format!(
373 "Search path: {}",
374 schemas.iter().map(|id| id.to_string()).join(", ")
375 ),
376 },
377 session.uuid(),
378 )));
379 }
380
381 if session.vars().current_object_missing_warnings() {
382 if catalog.active_database().is_none() {
383 let db = session.vars().database().into();
384 session.add_notice(AdapterNotice::UnknownSessionDatabase(db));
385 }
386 }
387
388 let cluster_var = session
391 .vars()
392 .inspect(CLUSTER.name())
393 .expect("cluster should exist");
394 if session.vars().current_object_missing_warnings()
395 && catalog.resolve_cluster(Some(&cluster_active)).is_err()
396 {
397 let cluster_notice = 'notice: {
398 if cluster_var.inspect_session_value().is_some() {
399 break 'notice Some(AdapterNotice::DefaultClusterDoesNotExist {
400 name: cluster_active,
401 kind: "session",
402 suggested_action: "Pick an extant cluster with SET CLUSTER = name. Run SHOW CLUSTERS to see available clusters.".into(),
403 });
404 }
405
406 let role_default = catalog.get_role(catalog.active_role_id());
407 let role_cluster = match role_default.vars().get(CLUSTER.name()) {
408 Some(OwnedVarInput::Flat(name)) => Some(name),
409 None => None,
410 Some(v @ OwnedVarInput::SqlSet(_)) => {
412 tracing::warn!(?v, "SqlSet found for cluster Role Default");
413 break 'notice None;
414 }
415 };
416
417 let alter_role = "with `ALTER ROLE <role> SET cluster TO <cluster>;`";
418 match role_cluster {
419 None => Some(AdapterNotice::DefaultClusterDoesNotExist {
421 name: cluster_active,
422 kind: "system",
423 suggested_action: format!(
424 "Set a default cluster for the current role {alter_role}."
425 ),
426 }),
427 Some(_) => Some(AdapterNotice::DefaultClusterDoesNotExist {
429 name: cluster_active,
430 kind: "role",
431 suggested_action: format!(
432 "Change the default cluster for the current role {alter_role}."
433 ),
434 }),
435 }
436 };
437
438 if let Some(notice) = cluster_notice {
439 session.add_notice(notice);
440 }
441 }
442
443 client.enable_frontend_peek_sequencing = ENABLE_FRONTEND_PEEK_SEQUENCING
444 .require(catalog.system_vars())
445 .is_ok();
446
447 Ok(client)
448 }
449
    /// Forwards a cancellation request for connection `conn_id` to the coordinator.
    ///
    /// `secret_key` is sent along so the coordinator can validate the request. No reply is
    /// awaited.
    pub fn cancel_request(&self, conn_id: ConnectionIdType, secret_key: u32) {
        self.send(Command::CancelRequest {
            conn_id,
            secret_key,
        });
    }
457
458 pub async fn support_execute_one(
461 &self,
462 sql: &str,
463 ) -> Result<Pin<Box<dyn Stream<Item = PeekResponseUnary> + Send>>, anyhow::Error> {
464 let conn_id = self.new_conn_id()?;
466 let session = self.new_session(
467 SessionConfig {
468 conn_id,
469 uuid: Uuid::new_v4(),
470 user: SUPPORT_USER.name.clone(),
471 client_ip: None,
472 external_metadata_rx: None,
473 helm_chart_version: None,
474 authenticator_kind: AuthenticatorKind::None,
475 },
476 Authenticated,
477 );
478 let mut session_client = self.startup(session).await?;
479
480 let stmts = mz_sql::parse::parse(sql)?;
482 if stmts.len() != 1 {
483 bail!("must supply exactly one query");
484 }
485 let StatementParseResult { ast: stmt, sql } = stmts.into_element();
486
487 const EMPTY_PORTAL: &str = "";
488 session_client.start_transaction(Some(1))?;
489 session_client
490 .declare(EMPTY_PORTAL.into(), stmt, sql.to_string())
491 .await?;
492
493 let execute_result = session_client
494 .execute(EMPTY_PORTAL.into(), futures::future::pending(), None)
495 .await?;
496 match execute_result {
497 (ExecuteResponse::SendingRowsStreaming { mut rows, .. }, _) => {
498 let owning_response_stream = async_stream::stream! {
503 while let Some(rows) = rows.next().await {
504 yield rows;
505 }
506 drop(session_client);
507 };
508 Ok(Box::pin(owning_response_stream))
509 }
510 r => bail!("unsupported response type: {r:?}"),
511 }
512 }
513
    /// Returns the adapter metrics shared by this client.
    pub fn metrics(&self) -> &Metrics {
        &self.metrics
    }
518
    /// Returns the current time, according to this client's [`NowFn`], as a UTC datetime.
    pub fn now(&self) -> DateTime<Utc> {
        to_datetime((self.now)())
    }
523
524 pub async fn get_webhook_appender(
526 &self,
527 database: String,
528 schema: String,
529 name: String,
530 ) -> Result<AppendWebhookResponse, AppendWebhookError> {
531 let (tx, rx) = oneshot::channel();
532
533 self.send(Command::GetWebhook {
535 database,
536 schema,
537 name,
538 tx,
539 });
540
541 let response = rx
543 .await
544 .map_err(|_| anyhow::anyhow!("failed to receive webhook response"))?;
545
546 response
547 }
548
    /// Fetches a copy of the coordinator's current system variables.
    ///
    /// # Panics
    /// Panics if the coordinator drops the reply channel without answering.
    pub async fn get_system_vars(&self) -> SystemVars {
        let (tx, rx) = oneshot::channel();
        self.send(Command::GetSystemVars { tx });
        rx.await.expect("coordinator unexpectedly gone")
    }
555
    /// Sends `cmd` to the coordinator, tagged with the caller's current tracing context.
    ///
    /// # Panics
    /// Panics if the coordinator's receiving end of the channel has been dropped.
    #[instrument(level = "debug")]
    pub(crate) fn send(&self, cmd: Command) {
        self.inner_cmd_tx
            .send((OpenTelemetryContext::obtain(), cmd))
            .expect("coordinator unexpectedly gone");
    }
562}
563
/// A [`Client`] bound to one [`Session`].
///
/// The session is temporarily moved out while a command that needs it is in flight and is
/// put back when the coordinator's response arrives; see `send_with_cancel`.
pub struct SessionClient {
    // `None` only after `terminate`, which also stops `Drop` from sending a second `Terminate`.
    inner: Option<Client>,
    // `None` only while the session has been handed to the coordinator, or after `Drop` took it.
    session: Option<Session>,
    // Pending per-session timeouts (e.g. idle-in-transaction).
    timeouts: Timeout,
    segment_client: Option<mz_segment::Client>,
    environment_id: EnvironmentId,
    peek_client: PeekClient,
    /// Whether `execute` tries a frontend peek before falling back to `Command::Execute`.
    /// Decided during `Client::startup` from the `ENABLE_FRONTEND_PEEK_SEQUENCING` system var.
    pub enable_frontend_peek_sequencing: bool,
}
586
587impl SessionClient {
588 pub fn parse<'a>(
591 &self,
592 sql: &'a str,
593 ) -> Result<Result<Vec<StatementParseResult<'a>>, ParserStatementError>, String> {
594 match mz_sql::parse::parse_with_limit(sql) {
595 Ok(Err(e)) => {
596 self.track_statement_parse_failure(&e);
597 Ok(Err(e))
598 }
599 r => r,
600 }
601 }
602
603 fn track_statement_parse_failure(&self, parse_error: &ParserStatementError) {
604 let session = self.session.as_ref().expect("session invariant violated");
605 let Some(user_id) = session.user().external_metadata.as_ref().map(|m| m.user_id) else {
606 return;
607 };
608 let Some(segment_client) = &self.segment_client else {
609 return;
610 };
611 let Some(statement_kind) = parse_error.statement else {
612 return;
613 };
614 let Some((action, object_type)) = telemetry::analyze_audited_statement(statement_kind)
615 else {
616 return;
617 };
618 let event_type = StatementFailureType::ParseFailure;
619 let event_name = format!(
620 "{} {} {}",
621 object_type.as_title_case(),
622 action.as_title_case(),
623 event_type.as_title_case(),
624 );
625 segment_client.environment_track(
626 &self.environment_id,
627 event_name,
628 json!({
629 "statement_kind": statement_kind,
630 "error": &parse_error.error,
631 }),
632 EventDetails {
633 user_id: Some(user_id),
634 application_name: Some(session.application_name()),
635 ..Default::default()
636 },
637 );
638 }
639
    /// Returns the prepared statement `name` after verifying it against a fresh catalog snapshot.
    ///
    /// # Errors
    /// Returns whatever `Coordinator::verify_prepared_statement` reports.
    ///
    /// # Panics
    /// Panics if verification succeeded but the statement is then not found in the session.
    pub async fn get_prepared_statement(
        &mut self,
        name: &str,
    ) -> Result<&PreparedStatement, AdapterError> {
        let catalog = self.catalog_snapshot("get_prepared_statement").await;
        Coordinator::verify_prepared_statement(&catalog, self.session(), name)?;
        Ok(self
            .session()
            .get_prepared_statement_unverified(name)
            .expect("must exist"))
    }
653
    /// Describes `stmt` against a fresh catalog snapshot and stores it in the session as the
    /// prepared statement `name`.
    ///
    /// The catalog and session revisions it was described at are stored with it.
    ///
    /// # Errors
    /// Returns any error from `Coordinator::describe`.
    pub async fn prepare(
        &mut self,
        name: String,
        stmt: Option<Statement<Raw>>,
        sql: String,
        param_types: Vec<Option<SqlScalarType>>,
    ) -> Result<(), AdapterError> {
        let catalog = self.catalog_snapshot("prepare").await;

        // Test hook: when the `async_prepare` fail point is configured with a value that parses
        // as `true`, sleep for one second after taking the catalog snapshot.
        let mut async_pause = false;
        // The immediately-invoked closure gives `fail_point!` (which returns from its enclosing
        // function when triggered with a closure) something other than `prepare` to return from.
        (|| {
            fail::fail_point!("async_prepare", |val| {
                async_pause = val.map_or(false, |val| val.parse().unwrap_or(false))
            });
        })();
        if async_pause {
            tokio::time::sleep(Duration::from_secs(1)).await;
        };

        let desc = Coordinator::describe(&catalog, self.session(), stmt.clone(), param_types)?;
        let now = self.now();
        let state_revision = StateRevision {
            catalog_revision: catalog.transient_revision(),
            session_state_revision: self.session().state_revision(),
        };
        self.session()
            .set_prepared_statement(name, stmt, sql, desc, state_revision, now);
        Ok(())
    }
689
690 #[mz_ore::instrument(level = "debug")]
692 pub async fn declare(
693 &mut self,
694 name: String,
695 stmt: Statement<Raw>,
696 sql: String,
697 ) -> Result<(), AdapterError> {
698 let catalog = self.catalog_snapshot("declare").await;
699 let param_types = vec![];
700 let desc =
701 Coordinator::describe(&catalog, self.session(), Some(stmt.clone()), param_types)?;
702 let params = vec![];
703 let result_formats = vec![mz_pgwire_common::Format::Text; desc.arity()];
704 let now = self.now();
705 let logging = self.session().mint_logging(sql, Some(&stmt), now);
706 let state_revision = StateRevision {
707 catalog_revision: catalog.transient_revision(),
708 session_state_revision: self.session().state_revision(),
709 };
710 self.session().set_portal(
711 name,
712 desc,
713 Some(stmt),
714 logging,
715 params,
716 result_formats,
717 state_revision,
718 )?;
719 Ok(())
720 }
721
722 #[mz_ore::instrument(level = "debug")]
729 pub async fn execute(
730 &mut self,
731 portal_name: String,
732 cancel_future: impl Future<Output = std::io::Error> + Send,
733 outer_ctx_extra: Option<ExecuteContextGuard>,
734 ) -> Result<(ExecuteResponse, Instant), AdapterError> {
735 let execute_started = Instant::now();
736
737 let mut outer_ctx_extra = outer_ctx_extra;
741 let peek_result = self
742 .try_frontend_peek(&portal_name, &mut outer_ctx_extra)
743 .await?;
744 if let Some(resp) = peek_result {
745 debug!("frontend peek succeeded");
746 return Ok((resp, execute_started));
749 } else {
750 debug!("frontend peek did not happen, falling back to `Command::Execute`");
751 }
756
757 let response = self
758 .send_with_cancel(
759 |tx, session| Command::Execute {
760 portal_name,
761 session,
762 tx,
763 outer_ctx_extra,
764 },
765 cancel_future,
766 )
767 .await?;
768 Ok((response, execute_started))
769 }
770
    /// Returns the current time in epoch milliseconds, per the inner client's clock.
    fn now(&self) -> EpochMillis {
        (self.inner().now)()
    }
774
    /// Returns the current time, per the inner client's clock, as a UTC datetime.
    fn now_datetime(&self) -> DateTime<Utc> {
        to_datetime(self.now())
    }
778
779 pub fn start_transaction(&mut self, implicit: Option<usize>) -> Result<(), AdapterError> {
785 let now = self.now_datetime();
786 let session = self.session.as_mut().expect("session invariant violated");
787 let result = match implicit {
788 None => session.start_transaction(now, None, None),
789 Some(stmts) => {
790 session.start_transaction_implicit(now, stmts);
791 Ok(())
792 }
793 };
794 result
795 }
796
    /// Ends the current transaction by sending [`Command::Commit`] with `action`.
    ///
    /// The session's transaction state is then cleared regardless of the coordinator's
    /// outcome, and that outcome is returned.
    #[instrument(level = "debug")]
    pub async fn end_transaction(
        &mut self,
        action: EndTransactionAction,
    ) -> Result<ExecuteResponse, AdapterError> {
        let res = self
            .send(|tx, session| Command::Commit {
                action,
                session,
                tx,
            })
            .await;
        // Cleared unconditionally; the value returned by `clear_transaction` is not needed here.
        let _ = self.session().clear_transaction();
        res
    }
817
818 pub fn fail_transaction(&mut self) {
820 let session = self.session.take().expect("session invariant violated");
821 let session = session.fail_transaction();
822 self.session = Some(session);
823 }
824
825 #[instrument(level = "debug")]
827 pub async fn catalog_snapshot(&self, context: &str) -> Arc<Catalog> {
828 let start = std::time::Instant::now();
829 let CatalogSnapshot { catalog } = self
830 .send_without_session(|tx| Command::CatalogSnapshot { tx })
831 .await;
832 self.inner()
833 .metrics()
834 .catalog_snapshot_seconds
835 .with_label_values(&[context])
836 .observe(start.elapsed().as_secs_f64());
837 catalog
838 }
839
    /// Dumps a fresh catalog snapshot via `Catalog::dump`.
    ///
    /// # Errors
    /// Converts any dump error into an [`AdapterError`].
    pub async fn dump_catalog(&self) -> Result<CatalogDump, AdapterError> {
        let catalog = self.catalog_snapshot("dump_catalog").await;
        catalog.dump().map_err(AdapterError::from)
    }
848
    /// Runs `Catalog::check_consistency` on a fresh catalog snapshot.
    ///
    /// # Errors
    /// Returns the inconsistencies reported by the check, as JSON.
    pub async fn check_catalog(&self) -> Result<(), serde_json::Value> {
        let catalog = self.catalog_snapshot("check_catalog").await;
        catalog.check_consistency()
    }
858
859 pub async fn check_coordinator(&self) -> Result<(), serde_json::Value> {
865 self.send_without_session(|tx| Command::CheckConsistency { tx })
866 .await
867 .map_err(|inconsistencies| {
868 serde_json::to_value(inconsistencies).unwrap_or_else(|_| {
869 serde_json::Value::String("failed to serialize inconsistencies".to_string())
870 })
871 })
872 }
873
    /// Requests a JSON dump of the coordinator's state via [`Command::Dump`].
    pub async fn dump_coordinator_state(&self) -> Result<serde_json::Value, anyhow::Error> {
        self.send_without_session(|tx| Command::Dump { tx }).await
    }
877
878 pub fn retire_execute(
881 &self,
882 guard: ExecuteContextGuard,
883 reason: StatementEndedExecutionReason,
884 ) {
885 if !guard.is_trivial() {
886 let data = guard.defuse();
887 let cmd = Command::RetireExecute { data, reason };
888 self.inner().send(cmd);
889 }
890 }
891
    /// Starts a `COPY ... FROM STDIN` into `target_id` by sending
    /// [`Command::StartCopyFromStdin`] along with the session.
    ///
    /// Returns the writer produced by the coordinator.
    pub async fn start_copy_from_stdin(
        &mut self,
        target_id: CatalogItemId,
        target_name: String,
        columns: Vec<ColumnIndex>,
        row_desc: mz_repr::RelationDesc,
        params: mz_pgcopy::CopyFormatParams<'static>,
    ) -> Result<CopyFromStdinWriter, AdapterError> {
        self.send(|tx, session| Command::StartCopyFromStdin {
            target_id,
            target_name,
            columns,
            row_desc,
            params,
            session,
            tx,
        })
        .await
    }
916
917 pub fn stage_copy_from_stdin_batches(
922 &mut self,
923 target_id: CatalogItemId,
924 batches: Vec<mz_persist_client::batch::ProtoBatch>,
925 ) -> Result<(), AdapterError> {
926 use crate::session::{TransactionOps, WriteOp};
927 use mz_storage_client::client::TableData;
928
929 self.session()
930 .add_transaction_ops(TransactionOps::Writes(vec![WriteOp {
931 id: target_id,
932 rows: TableData::Batches(batches.into()),
933 }]))?;
934 Ok(())
935 }
936
    /// Fetches the coordinator's current system variables via the inner client.
    pub async fn get_system_vars(&self) -> SystemVars {
        self.inner().get_system_vars().await
    }
941
    /// Asks the coordinator to set the given system variables (name to value) on behalf of
    /// this session's connection.
    pub async fn set_system_vars(
        &mut self,
        vars: BTreeMap<String, String>,
    ) -> Result<(), AdapterError> {
        let conn_id = self.session().conn_id().clone();
        self.send_without_session(|tx| Command::SetSystemVars { vars, conn_id, tx })
            .await
    }
951
    /// Sends `events` to the coordinator via [`Command::InjectAuditEvents`], tagged with this
    /// session's connection ID.
    pub async fn inject_audit_events(
        &mut self,
        events: Vec<crate::catalog::InjectedAuditEvent>,
    ) -> Result<(), AdapterError> {
        let conn_id = self.session().conn_id().clone();
        self.send_without_session(|tx| Command::InjectAuditEvents {
            events,
            conn_id,
            tx,
        })
        .await
    }
968
    /// Asks the coordinator to terminate this connection and waits for the acknowledgement.
    ///
    /// Failures are logged, not returned. Afterwards the inner [`Client`] is dropped, so `Drop`
    /// will not send a second `Terminate`. Any later call that needs `inner()` will panic.
    pub async fn terminate(&mut self) {
        let conn_id = self.session().conn_id().clone();
        let res = self
            .send_without_session(|tx| Command::Terminate {
                conn_id,
                tx: Some(tx),
            })
            .await;
        if let Err(e) = res {
            error!("Unable to terminate session: {e:?}");
        }
        self.inner = None;
    }
985
    /// Returns the session.
    ///
    /// # Panics
    /// Panics if the session is currently handed off (a command is in flight) or already gone.
    pub fn session(&mut self) -> &mut Session {
        self.session.as_mut().expect("session invariant violated")
    }
990
    /// Returns the underlying [`Client`].
    ///
    /// # Panics
    /// Panics after [`SessionClient::terminate`] has cleared it.
    pub fn inner(&self) -> &Client {
        self.inner.as_ref().expect("inner invariant violated")
    }
995
996 async fn send_without_session<T, F>(&self, f: F) -> T
997 where
998 F: FnOnce(oneshot::Sender<T>) -> Command,
999 {
1000 let (tx, rx) = oneshot::channel();
1001 self.inner().send(f(tx));
1002 rx.await.expect("sender dropped")
1003 }
1004
    /// Like [`SessionClient::send_with_cancel`], but with a cancellation future that never
    /// resolves.
    #[instrument(level = "debug")]
    async fn send<T, F>(&mut self, f: F) -> Result<T, AdapterError>
    where
        F: FnOnce(oneshot::Sender<Response<T>>, Session) -> Command,
    {
        self.send_with_cancel(f, futures::future::pending()).await
    }
1012
    /// Hands the session to the coordinator inside the command built by `f` and waits for the
    /// [`Response`], which carries the session back.
    ///
    /// If `cancel_future` resolves first, a [`Command::PrivilegedCancelRequest`] is sent for
    /// this connection at most once, and waiting continues until the coordinator responds.
    /// `execute` and `webhook` commands are counted in the `commands` metric.
    ///
    /// # Panics
    /// Panics if the session or inner client is missing, or if the coordinator drops the reply
    /// sender.
    #[instrument(level = "debug")]
    async fn send_with_cancel<T, F>(
        &mut self,
        f: F,
        cancel_future: impl Future<Output = std::io::Error> + Send,
    ) -> Result<T, AdapterError>
    where
        F: FnOnce(oneshot::Sender<Response<T>>, Session) -> Command,
    {
        let session = self.session.take().expect("session invariant violated");
        // Metric label for the command kind; stays `None` for commands that are not counted.
        let mut typ = None;
        let application_name = session.application_name();
        let name_hint = ApplicationNameHint::from_str(application_name);
        let conn_id = session.conn_id().clone();
        let (tx, rx) = oneshot::channel();

        // Split the borrow of `self` so the guard below can hold `session` mutably while
        // `inner` is borrowed immutably.
        let Self {
            inner: inner_client,
            session: client_session,
            ..
        } = self;

        let inner_client = inner_client.as_ref().expect("inner invariant violated");

        // If this future is dropped while a response is pending or unconsumed, the guard still
        // puts the returned session back into `self.session`.
        // NOTE(review): presumed semantics of `with_guard`; confirm against `OneshotReceiverExt`.
        let mut guarded_rx = rx.with_guard(|response: Response<_>| {
            *client_session = Some(response.session);
        });

        inner_client.send({
            let cmd = f(tx, session);
            // Exhaustive on purpose, so a new command variant forces a decision about metrics.
            match cmd {
                Command::Execute { .. } => typ = Some("execute"),
                Command::GetWebhook { .. } => typ = Some("webhook"),
                Command::StartCopyFromStdin { .. }
                | Command::Startup { .. }
                | Command::AuthenticatePassword { .. }
                | Command::AuthenticateGetSASLChallenge { .. }
                | Command::AuthenticateVerifySASLProof { .. }
                | Command::CheckRoleCanLogin { .. }
                | Command::CatalogSnapshot { .. }
                | Command::Commit { .. }
                | Command::CancelRequest { .. }
                | Command::PrivilegedCancelRequest { .. }
                | Command::GetSystemVars { .. }
                | Command::SetSystemVars { .. }
                | Command::Terminate { .. }
                | Command::RetireExecute { .. }
                | Command::CheckConsistency { .. }
                | Command::Dump { .. }
                | Command::GetComputeInstanceClient { .. }
                | Command::GetOracle { .. }
                | Command::DetermineRealTimeRecentTimestamp { .. }
                | Command::GetTransactionReadHoldsBundle { .. }
                | Command::StoreTransactionReadHolds { .. }
                | Command::ExecuteSlowPathPeek { .. }
                | Command::ExecuteSubscribe { .. }
                | Command::CopyToPreflight { .. }
                | Command::ExecuteCopyTo { .. }
                | Command::ExecuteSideEffectingFunc { .. }
                | Command::RegisterFrontendPeek { .. }
                | Command::UnregisterFrontendPeek { .. }
                | Command::ExplainTimestamp { .. }
                | Command::FrontendStatementLogging(..)
                | Command::InjectAuditEvents { .. } => {}
            };
            cmd
        });

        let mut cancel_future = pin::pin!(cancel_future);
        let mut cancelled = false;
        loop {
            tokio::select! {
                res = &mut guarded_rx => {
                    // Release the guard (and its mutable borrow of `client_session`) before the
                    // session is written back below.
                    drop(guarded_rx);

                    let res = res.expect("sender dropped");
                    let status = res.result.is_ok().then_some("success").unwrap_or("error");
                    if let Err(err) = res.result.as_ref() {
                        if name_hint.should_trace_errors() {
                            tracing::warn!(?err, ?name_hint, "adapter response error");
                        }
                    }

                    if let Some(typ) = typ {
                        inner_client
                            .metrics
                            .commands
                            .with_label_values(&[typ, status, name_hint.as_str()])
                            .inc();
                    }
                    *client_session = Some(res.session);
                    return res.result;
                },
                // Disabled after the first cancellation so it is only requested once.
                _err = &mut cancel_future, if !cancelled => {
                    cancelled = true;
                    inner_client.send(Command::PrivilegedCancelRequest {
                        conn_id: conn_id.clone(),
                    });
                }
            };
        }
    }
1131
1132 pub fn add_idle_in_transaction_session_timeout(&mut self) {
1133 let session = self.session();
1134 let timeout_dur = session.vars().idle_in_transaction_session_timeout();
1135 if !timeout_dur.is_zero() {
1136 let timeout_dur = timeout_dur.clone();
1137 if let Some(txn) = session.transaction().inner() {
1138 let txn_id = txn.id.clone();
1139 let timeout = TimeoutType::IdleInTransactionSession(txn_id);
1140 self.timeouts.add_timeout(timeout, timeout_dur);
1141 }
1142 }
1143 }
1144
1145 pub fn remove_idle_in_transaction_session_timeout(&mut self) {
1146 let session = self.session();
1147 if let Some(txn) = session.transaction().inner() {
1148 let txn_id = txn.id.clone();
1149 self.timeouts
1150 .remove_timeout(&TimeoutType::IdleInTransactionSession(txn_id));
1151 }
1152 }
1153
    /// Waits for the next fired timeout.
    ///
    /// Delegates to `Timeout::recv`. Since `Timeout` keeps its own sender alive, `None` is
    /// presumably never returned in practice; confirm.
    pub async fn recv_timeout(&mut self) -> Option<TimeoutType> {
        self.timeouts.recv().await
    }
1163
1164 pub(crate) async fn try_frontend_peek(
1172 &mut self,
1173 portal_name: &str,
1174 outer_ctx_extra: &mut Option<ExecuteContextGuard>,
1175 ) -> Result<Option<ExecuteResponse>, AdapterError> {
1176 if self.enable_frontend_peek_sequencing {
1177 let session = self.session.as_mut().expect("SessionClient invariant");
1178 self.peek_client
1179 .try_frontend_peek(portal_name, session, outer_ctx_extra)
1180 .await
1181 } else {
1182 Ok(None)
1183 }
1184 }
1185}
1186
1187impl Drop for SessionClient {
1188 fn drop(&mut self) {
1189 if let Some(session) = self.session.take() {
1193 if let Some(inner) = &self.inner {
1196 inner.send(Command::Terminate {
1197 conn_id: session.conn_id().clone(),
1198 tx: None,
1199 })
1200 }
1201 }
1202 }
1203}
1204
1205#[derive(Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Debug)]
1206pub enum TimeoutType {
1207 IdleInTransactionSession(TransactionId),
1208}
1209
1210impl Display for TimeoutType {
1211 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1212 match self {
1213 TimeoutType::IdleInTransactionSession(txn_id) => {
1214 writeln!(f, "Idle in transaction session for transaction '{txn_id}'")
1215 }
1216 }
1217 }
1218}
1219
/// Converts a fired timeout into the [`AdapterError`] reported to the client.
impl From<TimeoutType> for AdapterError {
    fn from(timeout: TimeoutType) -> Self {
        match timeout {
            TimeoutType::IdleInTransactionSession(_) => {
                AdapterError::IdleInTransactionSessionTimeout
            }
        }
    }
}
1229
1230struct Timeout {
1231 tx: mpsc::UnboundedSender<TimeoutType>,
1232 rx: mpsc::UnboundedReceiver<TimeoutType>,
1233 active_timeouts: BTreeMap<TimeoutType, AbortOnDropHandle<()>>,
1234}
1235
1236impl Timeout {
1237 fn new() -> Self {
1238 let (tx, rx) = mpsc::unbounded_channel();
1239 Timeout {
1240 tx,
1241 rx,
1242 active_timeouts: BTreeMap::new(),
1243 }
1244 }
1245
1246 async fn recv(&mut self) -> Option<TimeoutType> {
1255 self.rx.recv().await
1256 }
1257
1258 fn add_timeout(&mut self, timeout: TimeoutType, duration: Duration) {
1259 let tx = self.tx.clone();
1260 let timeout_key = timeout.clone();
1261 let handle = mz_ore::task::spawn(|| format!("{timeout_key}"), async move {
1262 tokio::time::sleep(duration).await;
1263 let _ = tx.send(timeout);
1264 })
1265 .abort_on_drop();
1266 self.active_timeouts.insert(timeout_key, handle);
1267 }
1268
1269 fn remove_timeout(&mut self, timeout: &TimeoutType) {
1270 self.active_timeouts.remove(timeout);
1271
1272 let mut timeouts = Vec::new();
1274 while let Ok(pending_timeout) = self.rx.try_recv() {
1275 if timeout != &pending_timeout {
1276 timeouts.push(pending_timeout);
1277 }
1278 }
1279 for pending_timeout in timeouts {
1280 self.tx.send(pending_timeout).expect("rx is in this struct");
1281 }
1282 }
1283}
1284
#[derive(Derivative)]
/// Wraps a stream of peek responses and records, once, the time from the start of execution
/// until the first `Rows` response arrives.
#[derivative(Debug)]
pub struct RecordFirstRowStream {
    /// The wrapped response stream; excluded from `Debug`.
    #[derivative(Debug = "ignore")]
    pub rows: Box<dyn Stream<Item = PeekResponseUnary> + Unpin + Send + Sync>,
    /// When execution began; the reference point for the time-to-first-row observation.
    pub execute_started: Instant,
    /// Histogram that receives the time-to-first-row observation.
    pub time_to_first_row_seconds: Histogram,
    /// Whether a `Rows` response has been seen, and so the histogram already observed.
    pub saw_rows: bool,
    /// When the first `Rows` response was received, if any.
    pub recorded_first_row_instant: Option<Instant>,
    /// Set once the wrapped stream has returned `None`.
    pub no_more_rows: bool,
}
1306
1307impl RecordFirstRowStream {
1308 pub fn new(
1310 rows: Box<dyn Stream<Item = PeekResponseUnary> + Unpin + Send + Sync>,
1311 execute_started: Instant,
1312 client: &SessionClient,
1313 instance_id: Option<ComputeInstanceId>,
1314 strategy: Option<StatementExecutionStrategy>,
1315 ) -> Self {
1316 let histogram = Self::histogram(client, instance_id, strategy);
1317 Self {
1318 rows,
1319 execute_started,
1320 time_to_first_row_seconds: histogram,
1321 saw_rows: false,
1322 recorded_first_row_instant: None,
1323 no_more_rows: false,
1324 }
1325 }
1326
1327 fn histogram(
1328 client: &SessionClient,
1329 instance_id: Option<ComputeInstanceId>,
1330 strategy: Option<StatementExecutionStrategy>,
1331 ) -> Histogram {
1332 let isolation_level = *client
1333 .session
1334 .as_ref()
1335 .expect("session invariant")
1336 .vars()
1337 .transaction_isolation();
1338 let instance = match instance_id {
1339 Some(i) => Cow::Owned(i.to_string()),
1340 None => Cow::Borrowed("none"),
1341 };
1342 let strategy = match strategy {
1343 Some(s) => s.name(),
1344 None => "none",
1345 };
1346
1347 client
1348 .inner()
1349 .metrics()
1350 .time_to_first_row_seconds
1351 .with_label_values(&[instance.as_ref(), isolation_level.as_str(), strategy])
1352 }
1353
1354 pub fn record(
1357 execute_started: Instant,
1358 client: &SessionClient,
1359 instance_id: Option<ComputeInstanceId>,
1360 strategy: Option<StatementExecutionStrategy>,
1361 ) {
1362 Self::histogram(client, instance_id, strategy)
1363 .observe(execute_started.elapsed().as_secs_f64());
1364 }
1365
1366 pub async fn recv(&mut self) -> Option<PeekResponseUnary> {
1367 let msg = self.rows.next().await;
1368 if !self.saw_rows && matches!(msg, Some(PeekResponseUnary::Rows(_))) {
1369 self.saw_rows = true;
1370 self.time_to_first_row_seconds
1371 .observe(self.execute_started.elapsed().as_secs_f64());
1372 self.recorded_first_row_instant = Some(Instant::now());
1373 }
1374 if msg.is_none() {
1375 self.no_more_rows = true;
1376 }
1377 msg
1378 }
1379}