1use std::borrow::Cow;
11use std::collections::BTreeMap;
12use std::fmt::{Debug, Display, Formatter};
13use std::future::Future;
14use std::pin::{self, Pin};
15use std::sync::Arc;
16use std::time::{Duration, Instant};
17
18use anyhow::bail;
19use chrono::{DateTime, Utc};
20use derivative::Derivative;
21use futures::{Stream, StreamExt};
22use itertools::Itertools;
23use mz_adapter_types::connection::{ConnectionId, ConnectionIdType};
24use mz_auth::Authenticated;
25use mz_auth::password::Password;
26use mz_build_info::BuildInfo;
27use mz_compute_types::ComputeInstanceId;
28use mz_ore::channel::OneshotReceiverExt;
29use mz_ore::collections::CollectionExt;
30use mz_ore::id_gen::{IdAllocator, IdAllocatorInnerBitSet, MAX_ORG_ID, org_id_conn_bits};
31use mz_ore::instrument;
32use mz_ore::now::{EpochMillis, NowFn, to_datetime};
33use mz_ore::task::AbortOnDropHandle;
34use mz_ore::thread::JoinOnDropHandle;
35use mz_ore::tracing::OpenTelemetryContext;
36use mz_repr::user::InternalUserMetadata;
37use mz_repr::{CatalogItemId, ColumnIndex, SqlScalarType};
38use mz_sql::ast::{Raw, Statement};
39use mz_sql::catalog::{EnvironmentId, SessionCatalog};
40use mz_sql::session::hint::ApplicationNameHint;
41use mz_sql::session::metadata::SessionMetadata;
42use mz_sql::session::user::SUPPORT_USER;
43use mz_sql::session::vars::{
44 CLUSTER, ENABLE_FRONTEND_PEEK_SEQUENCING, OwnedVarInput, SystemVars, Var,
45};
46use mz_sql_parser::parser::{ParserStatementError, StatementParseResult};
47use prometheus::Histogram;
48use serde_json::json;
49use tokio::sync::{mpsc, oneshot};
50use tracing::{debug, error};
51use uuid::Uuid;
52
53use crate::catalog::Catalog;
54use crate::command::{
55 CatalogDump, CatalogSnapshot, Command, CopyFromStdinWriter, ExecuteResponse, Response,
56 SASLChallengeResponse, SASLVerifyProofResponse, SuperuserAttribute,
57};
58use crate::coord::{Coordinator, ExecuteContextGuard};
59use crate::error::AdapterError;
60use crate::metrics::Metrics;
61use crate::session::{
62 EndTransactionAction, PreparedStatement, Session, SessionConfig, StateRevision, TransactionId,
63};
64use crate::statement_logging::{StatementEndedExecutionReason, StatementExecutionStrategy};
65use crate::telemetry::{self, EventDetails, SegmentClientExt, StatementFailureType};
66use crate::webhook::AppendWebhookResponse;
67use crate::{AdapterNotice, AppendWebhookError, PeekClient, PeekResponseUnary, StartupResponse};
68
69pub struct Handle {
75 pub(crate) session_id: Uuid,
76 pub(crate) start_instant: Instant,
77 pub(crate) _thread: JoinOnDropHandle<()>,
78}
79
80impl Handle {
81 pub fn session_id(&self) -> Uuid {
87 self.session_id
88 }
89
90 pub fn start_instant(&self) -> Instant {
92 self.start_instant
93 }
94}
95
96#[derive(Debug, Clone)]
104pub struct Client {
105 build_info: &'static BuildInfo,
106 inner_cmd_tx: mpsc::UnboundedSender<(OpenTelemetryContext, Command)>,
107 id_alloc: IdAllocator<IdAllocatorInnerBitSet>,
108 now: NowFn,
109 metrics: Metrics,
110 environment_id: EnvironmentId,
111 segment_client: Option<mz_segment::Client>,
112}
113
114impl Client {
115 pub(crate) fn new(
116 build_info: &'static BuildInfo,
117 cmd_tx: mpsc::UnboundedSender<(OpenTelemetryContext, Command)>,
118 metrics: Metrics,
119 now: NowFn,
120 environment_id: EnvironmentId,
121 segment_client: Option<mz_segment::Client>,
122 ) -> Client {
123 let env_lower = org_id_conn_bits(&environment_id.organization_id());
131 Client {
132 build_info,
133 inner_cmd_tx: cmd_tx,
134 id_alloc: IdAllocator::new(1, MAX_ORG_ID, env_lower),
135 now,
136 metrics,
137 environment_id,
138 segment_client,
139 }
140 }
141
142 pub fn new_conn_id(&self) -> Result<ConnectionId, AdapterError> {
144 self.id_alloc.alloc().ok_or(AdapterError::IdExhaustionError)
145 }
146
147 pub fn new_session(&self, config: SessionConfig, _authenticated: Authenticated) -> Session {
153 Session::new(self.build_info, config, self.metrics().session_metrics())
157 }
158
159 pub async fn authenticate(
162 &self,
163 user: &String,
164 password: &Password,
165 ) -> Result<Authenticated, AdapterError> {
166 let (tx, rx) = oneshot::channel();
167 self.send(Command::AuthenticatePassword {
168 role_name: user.to_string(),
169 password: Some(password.clone()),
170 tx,
171 });
172 rx.await.expect("sender dropped")?;
173 Ok(Authenticated)
174 }
175
176 pub async fn generate_sasl_challenge(
177 &self,
178 user: &String,
179 client_nonce: &String,
180 ) -> Result<SASLChallengeResponse, AdapterError> {
181 let (tx, rx) = oneshot::channel();
182 self.send(Command::AuthenticateGetSASLChallenge {
183 role_name: user.to_string(),
184 nonce: client_nonce.to_string(),
185 tx,
186 });
187 let response = rx.await.expect("sender dropped")?;
188 Ok(response)
189 }
190
191 pub async fn verify_sasl_proof(
192 &self,
193 user: &String,
194 proof: &String,
195 nonce: &String,
196 mock_hash: &String,
197 ) -> Result<(SASLVerifyProofResponse, Authenticated), AdapterError> {
198 let (tx, rx) = oneshot::channel();
199 self.send(Command::AuthenticateVerifySASLProof {
200 role_name: user.to_string(),
201 proof: proof.to_string(),
202 auth_message: nonce.to_string(),
203 mock_hash: mock_hash.to_string(),
204 tx,
205 });
206 let response = rx.await.expect("sender dropped")?;
207 Ok((response, Authenticated))
208 }
209
210 #[mz_ore::instrument(level = "debug")]
219 pub async fn startup(&self, session: Session) -> Result<SessionClient, AdapterError> {
220 let user = session.user().clone();
221 let conn_id = session.conn_id().clone();
222 let secret_key = session.secret_key();
223 let uuid = session.uuid();
224 let client_ip = session.client_ip();
225 let application_name = session.application_name().into();
226 let notice_tx = session.retain_notice_transmitter();
227
228 let (tx, rx) = oneshot::channel();
229
230 let rx = rx.with_guard(|_| {
236 self.send(Command::Terminate {
237 conn_id: conn_id.clone(),
238 tx: None,
239 });
240 });
241
242 self.send(Command::Startup {
243 tx,
244 user,
245 conn_id: conn_id.clone(),
246 secret_key,
247 uuid,
248 client_ip: client_ip.copied(),
249 application_name,
250 notice_tx,
251 });
252
253 let response = rx.await.expect("sender dropped")?;
256
257 let StartupResponse {
261 role_id,
262 write_notify,
263 session_defaults,
264 catalog,
265 storage_collections,
266 transient_id_gen,
267 optimizer_metrics,
268 persist_client,
269 statement_logging_frontend,
270 superuser_attribute,
271 } = response;
272
273 let peek_client = PeekClient::new(
274 self.clone(),
275 storage_collections,
276 transient_id_gen,
277 optimizer_metrics,
278 persist_client,
279 statement_logging_frontend,
280 );
281
282 let mut client = SessionClient {
283 inner: Some(self.clone()),
284 session: Some(session),
285 timeouts: Timeout::new(),
286 environment_id: self.environment_id.clone(),
287 segment_client: self.segment_client.clone(),
288 peek_client,
289 enable_frontend_peek_sequencing: false, };
291
292 let session = client.session();
293
294 if let SuperuserAttribute(Some(superuser)) = superuser_attribute {
297 session.apply_internal_user_metadata(InternalUserMetadata { superuser });
298 }
299
300 session.initialize_role_metadata(role_id);
301 let vars_mut = session.vars_mut();
302 for (name, val) in session_defaults {
303 if let Err(err) = vars_mut.set_default(&name, val.borrow()) {
304 tracing::error!("failed to set peristed default, {err:?}");
307 }
308 }
309 session
310 .vars_mut()
311 .end_transaction(EndTransactionAction::Commit);
312
313 session.set_builtin_table_updates(write_notify);
321
322 let catalog = catalog.for_session(session);
323
324 let cluster_active = session.vars().cluster().to_string();
325 if session.vars().welcome_message() {
326 let cluster_info = if catalog.resolve_cluster(Some(&cluster_active)).is_err() {
327 format!("{cluster_active} (does not exist)")
328 } else {
329 cluster_active.to_string()
330 };
331
332 session.add_notice(AdapterNotice::Welcome(format!(
336 "connected to Materialize v{}
337 Environment ID: {}
338 Region: {}
339 User: {}
340 Cluster: {}
341 Database: {}
342 {}
343 Session UUID: {}
344
345Issue a SQL query to get started. Need help?
346 View documentation: https://materialize.com/s/docs
347 Join our Slack community: https://materialize.com/s/chat
348 ",
349 session.vars().build_info().semver_version(),
350 self.environment_id,
351 self.environment_id.region(),
352 session.vars().user().name,
353 cluster_info,
354 session.vars().database(),
355 match session.vars().search_path() {
356 [schema] => format!("Schema: {}", schema),
357 schemas => format!(
358 "Search path: {}",
359 schemas.iter().map(|id| id.to_string()).join(", ")
360 ),
361 },
362 session.uuid(),
363 )));
364 }
365
366 if session.vars().current_object_missing_warnings() {
367 if catalog.active_database().is_none() {
368 let db = session.vars().database().into();
369 session.add_notice(AdapterNotice::UnknownSessionDatabase(db));
370 }
371 }
372
373 let cluster_var = session
376 .vars()
377 .inspect(CLUSTER.name())
378 .expect("cluster should exist");
379 if session.vars().current_object_missing_warnings()
380 && catalog.resolve_cluster(Some(&cluster_active)).is_err()
381 {
382 let cluster_notice = 'notice: {
383 if cluster_var.inspect_session_value().is_some() {
384 break 'notice Some(AdapterNotice::DefaultClusterDoesNotExist {
385 name: cluster_active,
386 kind: "session",
387 suggested_action: "Pick an extant cluster with SET CLUSTER = name. Run SHOW CLUSTERS to see available clusters.".into(),
388 });
389 }
390
391 let role_default = catalog.get_role(catalog.active_role_id());
392 let role_cluster = match role_default.vars().get(CLUSTER.name()) {
393 Some(OwnedVarInput::Flat(name)) => Some(name),
394 None => None,
395 Some(v @ OwnedVarInput::SqlSet(_)) => {
397 tracing::warn!(?v, "SqlSet found for cluster Role Default");
398 break 'notice None;
399 }
400 };
401
402 let alter_role = "with `ALTER ROLE <role> SET cluster TO <cluster>;`";
403 match role_cluster {
404 None => Some(AdapterNotice::DefaultClusterDoesNotExist {
406 name: cluster_active,
407 kind: "system",
408 suggested_action: format!(
409 "Set a default cluster for the current role {alter_role}."
410 ),
411 }),
412 Some(_) => Some(AdapterNotice::DefaultClusterDoesNotExist {
414 name: cluster_active,
415 kind: "role",
416 suggested_action: format!(
417 "Change the default cluster for the current role {alter_role}."
418 ),
419 }),
420 }
421 };
422
423 if let Some(notice) = cluster_notice {
424 session.add_notice(notice);
425 }
426 }
427
428 client.enable_frontend_peek_sequencing = ENABLE_FRONTEND_PEEK_SEQUENCING
429 .require(catalog.system_vars())
430 .is_ok();
431
432 Ok(client)
433 }
434
435 pub fn cancel_request(&self, conn_id: ConnectionIdType, secret_key: u32) {
437 self.send(Command::CancelRequest {
438 conn_id,
439 secret_key,
440 });
441 }
442
443 pub async fn support_execute_one(
446 &self,
447 sql: &str,
448 ) -> Result<Pin<Box<dyn Stream<Item = PeekResponseUnary> + Send>>, anyhow::Error> {
449 let conn_id = self.new_conn_id()?;
451 let session = self.new_session(
452 SessionConfig {
453 conn_id,
454 uuid: Uuid::new_v4(),
455 user: SUPPORT_USER.name.clone(),
456 client_ip: None,
457 external_metadata_rx: None,
458 helm_chart_version: None,
459 },
460 Authenticated,
461 );
462 let mut session_client = self.startup(session).await?;
463
464 let stmts = mz_sql::parse::parse(sql)?;
466 if stmts.len() != 1 {
467 bail!("must supply exactly one query");
468 }
469 let StatementParseResult { ast: stmt, sql } = stmts.into_element();
470
471 const EMPTY_PORTAL: &str = "";
472 session_client.start_transaction(Some(1))?;
473 session_client
474 .declare(EMPTY_PORTAL.into(), stmt, sql.to_string())
475 .await?;
476
477 match session_client
478 .execute(EMPTY_PORTAL.into(), futures::future::pending(), None)
479 .await?
480 {
481 (ExecuteResponse::SendingRowsStreaming { mut rows, .. }, _) => {
482 let owning_response_stream = async_stream::stream! {
487 while let Some(rows) = rows.next().await {
488 yield rows;
489 }
490 drop(session_client);
491 };
492 Ok(Box::pin(owning_response_stream))
493 }
494 r => bail!("unsupported response type: {r:?}"),
495 }
496 }
497
498 pub fn metrics(&self) -> &Metrics {
500 &self.metrics
501 }
502
503 pub fn now(&self) -> DateTime<Utc> {
505 to_datetime((self.now)())
506 }
507
508 pub async fn get_webhook_appender(
510 &self,
511 database: String,
512 schema: String,
513 name: String,
514 ) -> Result<AppendWebhookResponse, AppendWebhookError> {
515 let (tx, rx) = oneshot::channel();
516
517 self.send(Command::GetWebhook {
519 database,
520 schema,
521 name,
522 tx,
523 });
524
525 let response = rx
527 .await
528 .map_err(|_| anyhow::anyhow!("failed to receive webhook response"))?;
529
530 response
531 }
532
533 pub async fn get_system_vars(&self) -> SystemVars {
535 let (tx, rx) = oneshot::channel();
536 self.send(Command::GetSystemVars { tx });
537 rx.await.expect("coordinator unexpectedly gone")
538 }
539
540 #[instrument(level = "debug")]
541 pub(crate) fn send(&self, cmd: Command) {
542 self.inner_cmd_tx
543 .send((OpenTelemetryContext::obtain(), cmd))
544 .expect("coordinator unexpectedly gone");
545 }
546}
547
548pub struct SessionClient {
552 inner: Option<Client>,
556 session: Option<Session>,
559 timeouts: Timeout,
560 segment_client: Option<mz_segment::Client>,
561 environment_id: EnvironmentId,
562 peek_client: PeekClient,
564 pub enable_frontend_peek_sequencing: bool,
569}
570
571impl SessionClient {
572 pub fn parse<'a>(
575 &self,
576 sql: &'a str,
577 ) -> Result<Result<Vec<StatementParseResult<'a>>, ParserStatementError>, String> {
578 match mz_sql::parse::parse_with_limit(sql) {
579 Ok(Err(e)) => {
580 self.track_statement_parse_failure(&e);
581 Ok(Err(e))
582 }
583 r => r,
584 }
585 }
586
587 fn track_statement_parse_failure(&self, parse_error: &ParserStatementError) {
588 let session = self.session.as_ref().expect("session invariant violated");
589 let Some(user_id) = session.user().external_metadata.as_ref().map(|m| m.user_id) else {
590 return;
591 };
592 let Some(segment_client) = &self.segment_client else {
593 return;
594 };
595 let Some(statement_kind) = parse_error.statement else {
596 return;
597 };
598 let Some((action, object_type)) = telemetry::analyze_audited_statement(statement_kind)
599 else {
600 return;
601 };
602 let event_type = StatementFailureType::ParseFailure;
603 let event_name = format!(
604 "{} {} {}",
605 object_type.as_title_case(),
606 action.as_title_case(),
607 event_type.as_title_case(),
608 );
609 segment_client.environment_track(
610 &self.environment_id,
611 event_name,
612 json!({
613 "statement_kind": statement_kind,
614 "error": &parse_error.error,
615 }),
616 EventDetails {
617 user_id: Some(user_id),
618 application_name: Some(session.application_name()),
619 ..Default::default()
620 },
621 );
622 }
623
624 pub async fn get_prepared_statement(
627 &mut self,
628 name: &str,
629 ) -> Result<&PreparedStatement, AdapterError> {
630 let catalog = self.catalog_snapshot("get_prepared_statement").await;
631 Coordinator::verify_prepared_statement(&catalog, self.session(), name)?;
632 Ok(self
633 .session()
634 .get_prepared_statement_unverified(name)
635 .expect("must exist"))
636 }
637
638 pub async fn prepare(
643 &mut self,
644 name: String,
645 stmt: Option<Statement<Raw>>,
646 sql: String,
647 param_types: Vec<Option<SqlScalarType>>,
648 ) -> Result<(), AdapterError> {
649 let catalog = self.catalog_snapshot("prepare").await;
650
651 let mut async_pause = false;
654 (|| {
655 fail::fail_point!("async_prepare", |val| {
656 async_pause = val.map_or(false, |val| val.parse().unwrap_or(false))
657 });
658 })();
659 if async_pause {
660 tokio::time::sleep(Duration::from_secs(1)).await;
661 };
662
663 let desc = Coordinator::describe(&catalog, self.session(), stmt.clone(), param_types)?;
664 let now = self.now();
665 let state_revision = StateRevision {
666 catalog_revision: catalog.transient_revision(),
667 session_state_revision: self.session().state_revision(),
668 };
669 self.session()
670 .set_prepared_statement(name, stmt, sql, desc, state_revision, now);
671 Ok(())
672 }
673
674 #[mz_ore::instrument(level = "debug")]
676 pub async fn declare(
677 &mut self,
678 name: String,
679 stmt: Statement<Raw>,
680 sql: String,
681 ) -> Result<(), AdapterError> {
682 let catalog = self.catalog_snapshot("declare").await;
683 let param_types = vec![];
684 let desc =
685 Coordinator::describe(&catalog, self.session(), Some(stmt.clone()), param_types)?;
686 let params = vec![];
687 let result_formats = vec![mz_pgwire_common::Format::Text; desc.arity()];
688 let now = self.now();
689 let logging = self.session().mint_logging(sql, Some(&stmt), now);
690 let state_revision = StateRevision {
691 catalog_revision: catalog.transient_revision(),
692 session_state_revision: self.session().state_revision(),
693 };
694 self.session().set_portal(
695 name,
696 desc,
697 Some(stmt),
698 logging,
699 params,
700 result_formats,
701 state_revision,
702 )?;
703 Ok(())
704 }
705
706 #[mz_ore::instrument(level = "debug")]
713 pub async fn execute(
714 &mut self,
715 portal_name: String,
716 cancel_future: impl Future<Output = std::io::Error> + Send,
717 outer_ctx_extra: Option<ExecuteContextGuard>,
718 ) -> Result<(ExecuteResponse, Instant), AdapterError> {
719 let execute_started = Instant::now();
720
721 let mut outer_ctx_extra = outer_ctx_extra;
725 if let Some(resp) = self
726 .try_frontend_peek(&portal_name, &mut outer_ctx_extra)
727 .await?
728 {
729 debug!("frontend peek succeeded");
730 return Ok((resp, execute_started));
733 } else {
734 debug!("frontend peek did not happen, falling back to `Command::Execute`");
735 }
740
741 let response = self
742 .send_with_cancel(
743 |tx, session| Command::Execute {
744 portal_name,
745 session,
746 tx,
747 outer_ctx_extra,
748 },
749 cancel_future,
750 )
751 .await?;
752 Ok((response, execute_started))
753 }
754
755 fn now(&self) -> EpochMillis {
756 (self.inner().now)()
757 }
758
759 fn now_datetime(&self) -> DateTime<Utc> {
760 to_datetime(self.now())
761 }
762
763 pub fn start_transaction(&mut self, implicit: Option<usize>) -> Result<(), AdapterError> {
769 let now = self.now_datetime();
770 let session = self.session.as_mut().expect("session invariant violated");
771 let result = match implicit {
772 None => session.start_transaction(now, None, None),
773 Some(stmts) => {
774 session.start_transaction_implicit(now, stmts);
775 Ok(())
776 }
777 };
778 result
779 }
780
781 #[instrument(level = "debug")]
784 pub async fn end_transaction(
785 &mut self,
786 action: EndTransactionAction,
787 ) -> Result<ExecuteResponse, AdapterError> {
788 let res = self
789 .send(|tx, session| Command::Commit {
790 action,
791 session,
792 tx,
793 })
794 .await;
795 let _ = self.session().clear_transaction();
799 res
800 }
801
802 pub fn fail_transaction(&mut self) {
804 let session = self.session.take().expect("session invariant violated");
805 let session = session.fail_transaction();
806 self.session = Some(session);
807 }
808
809 #[instrument(level = "debug")]
811 pub async fn catalog_snapshot(&self, context: &str) -> Arc<Catalog> {
812 let start = std::time::Instant::now();
813 let CatalogSnapshot { catalog } = self
814 .send_without_session(|tx| Command::CatalogSnapshot { tx })
815 .await;
816 self.inner()
817 .metrics()
818 .catalog_snapshot_seconds
819 .with_label_values(&[context])
820 .observe(start.elapsed().as_secs_f64());
821 catalog
822 }
823
824 pub async fn dump_catalog(&self) -> Result<CatalogDump, AdapterError> {
829 let catalog = self.catalog_snapshot("dump_catalog").await;
830 catalog.dump().map_err(AdapterError::from)
831 }
832
833 pub async fn check_catalog(&self) -> Result<(), serde_json::Value> {
839 let catalog = self.catalog_snapshot("check_catalog").await;
840 catalog.check_consistency()
841 }
842
843 pub async fn check_coordinator(&self) -> Result<(), serde_json::Value> {
849 self.send_without_session(|tx| Command::CheckConsistency { tx })
850 .await
851 .map_err(|inconsistencies| {
852 serde_json::to_value(inconsistencies).unwrap_or_else(|_| {
853 serde_json::Value::String("failed to serialize inconsistencies".to_string())
854 })
855 })
856 }
857
858 pub async fn dump_coordinator_state(&self) -> Result<serde_json::Value, anyhow::Error> {
859 self.send_without_session(|tx| Command::Dump { tx }).await
860 }
861
862 pub fn retire_execute(
865 &self,
866 guard: ExecuteContextGuard,
867 reason: StatementEndedExecutionReason,
868 ) {
869 if !guard.is_trivial() {
870 let data = guard.defuse();
871 let cmd = Command::RetireExecute { data, reason };
872 self.inner().send(cmd);
873 }
874 }
875
876 pub async fn start_copy_from_stdin(
882 &mut self,
883 target_id: CatalogItemId,
884 target_name: String,
885 columns: Vec<ColumnIndex>,
886 row_desc: mz_repr::RelationDesc,
887 params: mz_pgcopy::CopyFormatParams<'static>,
888 ) -> Result<CopyFromStdinWriter, AdapterError> {
889 self.send(|tx, session| Command::StartCopyFromStdin {
890 target_id,
891 target_name,
892 columns,
893 row_desc,
894 params,
895 session,
896 tx,
897 })
898 .await
899 }
900
901 pub fn stage_copy_from_stdin_batches(
906 &mut self,
907 target_id: CatalogItemId,
908 batches: Vec<mz_persist_client::batch::ProtoBatch>,
909 ) -> Result<(), AdapterError> {
910 use crate::session::{TransactionOps, WriteOp};
911 use mz_storage_client::client::TableData;
912
913 self.session()
914 .add_transaction_ops(TransactionOps::Writes(vec![WriteOp {
915 id: target_id,
916 rows: TableData::Batches(batches.into()),
917 }]))?;
918 Ok(())
919 }
920
921 pub async fn get_system_vars(&self) -> SystemVars {
923 self.inner().get_system_vars().await
924 }
925
926 pub async fn set_system_vars(
928 &mut self,
929 vars: BTreeMap<String, String>,
930 ) -> Result<(), AdapterError> {
931 let conn_id = self.session().conn_id().clone();
932 self.send_without_session(|tx| Command::SetSystemVars { vars, conn_id, tx })
933 .await
934 }
935
936 pub async fn terminate(&mut self) {
938 let conn_id = self.session().conn_id().clone();
939 let res = self
940 .send_without_session(|tx| Command::Terminate {
941 conn_id,
942 tx: Some(tx),
943 })
944 .await;
945 if let Err(e) = res {
946 error!("Unable to terminate session: {e:?}");
948 }
949 self.inner = None;
951 }
952
953 pub fn session(&mut self) -> &mut Session {
955 self.session.as_mut().expect("session invariant violated")
956 }
957
958 pub fn inner(&self) -> &Client {
960 self.inner.as_ref().expect("inner invariant violated")
961 }
962
963 async fn send_without_session<T, F>(&self, f: F) -> T
964 where
965 F: FnOnce(oneshot::Sender<T>) -> Command,
966 {
967 let (tx, rx) = oneshot::channel();
968 self.inner().send(f(tx));
969 rx.await.expect("sender dropped")
970 }
971
972 #[instrument(level = "debug")]
973 async fn send<T, F>(&mut self, f: F) -> Result<T, AdapterError>
974 where
975 F: FnOnce(oneshot::Sender<Response<T>>, Session) -> Command,
976 {
977 self.send_with_cancel(f, futures::future::pending()).await
978 }
979
980 #[instrument(level = "debug")]
984 async fn send_with_cancel<T, F>(
985 &mut self,
986 f: F,
987 cancel_future: impl Future<Output = std::io::Error> + Send,
988 ) -> Result<T, AdapterError>
989 where
990 F: FnOnce(oneshot::Sender<Response<T>>, Session) -> Command,
991 {
992 let session = self.session.take().expect("session invariant violated");
993 let mut typ = None;
994 let application_name = session.application_name();
995 let name_hint = ApplicationNameHint::from_str(application_name);
996 let conn_id = session.conn_id().clone();
997 let (tx, rx) = oneshot::channel();
998
999 let Self {
1002 inner: inner_client,
1003 session: client_session,
1004 ..
1005 } = self;
1006
1007 let inner_client = inner_client.as_ref().expect("inner invariant violated");
1010
1011 let mut guarded_rx = rx.with_guard(|response: Response<_>| {
1017 *client_session = Some(response.session);
1018 });
1019
1020 inner_client.send({
1021 let cmd = f(tx, session);
1022 match cmd {
1026 Command::Execute { .. } => typ = Some("execute"),
1027 Command::GetWebhook { .. } => typ = Some("webhook"),
1028 Command::StartCopyFromStdin { .. }
1029 | Command::Startup { .. }
1030 | Command::AuthenticatePassword { .. }
1031 | Command::AuthenticateGetSASLChallenge { .. }
1032 | Command::AuthenticateVerifySASLProof { .. }
1033 | Command::CatalogSnapshot { .. }
1034 | Command::Commit { .. }
1035 | Command::CancelRequest { .. }
1036 | Command::PrivilegedCancelRequest { .. }
1037 | Command::GetSystemVars { .. }
1038 | Command::SetSystemVars { .. }
1039 | Command::Terminate { .. }
1040 | Command::RetireExecute { .. }
1041 | Command::CheckConsistency { .. }
1042 | Command::Dump { .. }
1043 | Command::GetComputeInstanceClient { .. }
1044 | Command::GetOracle { .. }
1045 | Command::DetermineRealTimeRecentTimestamp { .. }
1046 | Command::GetTransactionReadHoldsBundle { .. }
1047 | Command::StoreTransactionReadHolds { .. }
1048 | Command::ExecuteSlowPathPeek { .. }
1049 | Command::CopyToPreflight { .. }
1050 | Command::ExecuteCopyTo { .. }
1051 | Command::ExecuteSideEffectingFunc { .. }
1052 | Command::RegisterFrontendPeek { .. }
1053 | Command::UnregisterFrontendPeek { .. }
1054 | Command::ExplainTimestamp { .. }
1055 | Command::FrontendStatementLogging(..) => {}
1056 };
1057 cmd
1058 });
1059
1060 let mut cancel_future = pin::pin!(cancel_future);
1061 let mut cancelled = false;
1062 loop {
1063 tokio::select! {
1064 res = &mut guarded_rx => {
1065 drop(guarded_rx);
1067
1068 let res = res.expect("sender dropped");
1069 let status = res.result.is_ok().then_some("success").unwrap_or("error");
1070 if let Err(err) = res.result.as_ref() {
1071 if name_hint.should_trace_errors() {
1072 tracing::warn!(?err, ?name_hint, "adapter response error");
1073 }
1074 }
1075
1076 if let Some(typ) = typ {
1077 inner_client
1078 .metrics
1079 .commands
1080 .with_label_values(&[typ, status, name_hint.as_str()])
1081 .inc();
1082 }
1083 *client_session = Some(res.session);
1084 return res.result;
1085 },
1086 _err = &mut cancel_future, if !cancelled => {
1087 cancelled = true;
1088 inner_client.send(Command::PrivilegedCancelRequest {
1089 conn_id: conn_id.clone(),
1090 });
1091 }
1092 };
1093 }
1094 }
1095
1096 pub fn add_idle_in_transaction_session_timeout(&mut self) {
1097 let session = self.session();
1098 let timeout_dur = session.vars().idle_in_transaction_session_timeout();
1099 if !timeout_dur.is_zero() {
1100 let timeout_dur = timeout_dur.clone();
1101 if let Some(txn) = session.transaction().inner() {
1102 let txn_id = txn.id.clone();
1103 let timeout = TimeoutType::IdleInTransactionSession(txn_id);
1104 self.timeouts.add_timeout(timeout, timeout_dur);
1105 }
1106 }
1107 }
1108
1109 pub fn remove_idle_in_transaction_session_timeout(&mut self) {
1110 let session = self.session();
1111 if let Some(txn) = session.transaction().inner() {
1112 let txn_id = txn.id.clone();
1113 self.timeouts
1114 .remove_timeout(&TimeoutType::IdleInTransactionSession(txn_id));
1115 }
1116 }
1117
1118 pub async fn recv_timeout(&mut self) -> Option<TimeoutType> {
1125 self.timeouts.recv().await
1126 }
1127
1128 pub fn peek_client(&self) -> &PeekClient {
1130 &self.peek_client
1131 }
1132
1133 pub fn peek_client_mut(&mut self) -> &mut PeekClient {
1135 &mut self.peek_client
1136 }
1137
1138 pub(crate) async fn try_frontend_peek(
1146 &mut self,
1147 portal_name: &str,
1148 outer_ctx_extra: &mut Option<ExecuteContextGuard>,
1149 ) -> Result<Option<ExecuteResponse>, AdapterError> {
1150 if self.enable_frontend_peek_sequencing {
1151 let session = self.session.as_mut().expect("SessionClient invariant");
1152 self.peek_client
1153 .try_frontend_peek(portal_name, session, outer_ctx_extra)
1154 .await
1155 } else {
1156 Ok(None)
1157 }
1158 }
1159}
1160
1161impl Drop for SessionClient {
1162 fn drop(&mut self) {
1163 if let Some(session) = self.session.take() {
1167 if let Some(inner) = &self.inner {
1170 inner.send(Command::Terminate {
1171 conn_id: session.conn_id().clone(),
1172 tx: None,
1173 })
1174 }
1175 }
1176 }
1177}
1178
1179#[derive(Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Debug)]
1180pub enum TimeoutType {
1181 IdleInTransactionSession(TransactionId),
1182}
1183
1184impl Display for TimeoutType {
1185 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1186 match self {
1187 TimeoutType::IdleInTransactionSession(txn_id) => {
1188 writeln!(f, "Idle in transaction session for transaction '{txn_id}'")
1189 }
1190 }
1191 }
1192}
1193
1194impl From<TimeoutType> for AdapterError {
1195 fn from(timeout: TimeoutType) -> Self {
1196 match timeout {
1197 TimeoutType::IdleInTransactionSession(_) => {
1198 AdapterError::IdleInTransactionSessionTimeout
1199 }
1200 }
1201 }
1202}
1203
1204struct Timeout {
1205 tx: mpsc::UnboundedSender<TimeoutType>,
1206 rx: mpsc::UnboundedReceiver<TimeoutType>,
1207 active_timeouts: BTreeMap<TimeoutType, AbortOnDropHandle<()>>,
1208}
1209
1210impl Timeout {
1211 fn new() -> Self {
1212 let (tx, rx) = mpsc::unbounded_channel();
1213 Timeout {
1214 tx,
1215 rx,
1216 active_timeouts: BTreeMap::new(),
1217 }
1218 }
1219
1220 async fn recv(&mut self) -> Option<TimeoutType> {
1229 self.rx.recv().await
1230 }
1231
1232 fn add_timeout(&mut self, timeout: TimeoutType, duration: Duration) {
1233 let tx = self.tx.clone();
1234 let timeout_key = timeout.clone();
1235 let handle = mz_ore::task::spawn(|| format!("{timeout_key}"), async move {
1236 tokio::time::sleep(duration).await;
1237 let _ = tx.send(timeout);
1238 })
1239 .abort_on_drop();
1240 self.active_timeouts.insert(timeout_key, handle);
1241 }
1242
1243 fn remove_timeout(&mut self, timeout: &TimeoutType) {
1244 self.active_timeouts.remove(timeout);
1245
1246 let mut timeouts = Vec::new();
1248 while let Ok(pending_timeout) = self.rx.try_recv() {
1249 if timeout != &pending_timeout {
1250 timeouts.push(pending_timeout);
1251 }
1252 }
1253 for pending_timeout in timeouts {
1254 self.tx.send(pending_timeout).expect("rx is in this struct");
1255 }
1256 }
1257}
1258
1259#[derive(Derivative)]
1263#[derivative(Debug)]
1264pub struct RecordFirstRowStream {
1265 #[derivative(Debug = "ignore")]
1267 pub rows: Box<dyn Stream<Item = PeekResponseUnary> + Unpin + Send + Sync>,
1268 pub execute_started: Instant,
1270 pub time_to_first_row_seconds: Histogram,
1273 pub saw_rows: bool,
1275 pub recorded_first_row_instant: Option<Instant>,
1277 pub no_more_rows: bool,
1279}
1280
1281impl RecordFirstRowStream {
1282 pub fn new(
1284 rows: Box<dyn Stream<Item = PeekResponseUnary> + Unpin + Send + Sync>,
1285 execute_started: Instant,
1286 client: &SessionClient,
1287 instance_id: Option<ComputeInstanceId>,
1288 strategy: Option<StatementExecutionStrategy>,
1289 ) -> Self {
1290 let histogram = Self::histogram(client, instance_id, strategy);
1291 Self {
1292 rows,
1293 execute_started,
1294 time_to_first_row_seconds: histogram,
1295 saw_rows: false,
1296 recorded_first_row_instant: None,
1297 no_more_rows: false,
1298 }
1299 }
1300
1301 fn histogram(
1302 client: &SessionClient,
1303 instance_id: Option<ComputeInstanceId>,
1304 strategy: Option<StatementExecutionStrategy>,
1305 ) -> Histogram {
1306 let isolation_level = *client
1307 .session
1308 .as_ref()
1309 .expect("session invariant")
1310 .vars()
1311 .transaction_isolation();
1312 let instance = match instance_id {
1313 Some(i) => Cow::Owned(i.to_string()),
1314 None => Cow::Borrowed("none"),
1315 };
1316 let strategy = match strategy {
1317 Some(s) => s.name(),
1318 None => "none",
1319 };
1320
1321 client
1322 .inner()
1323 .metrics()
1324 .time_to_first_row_seconds
1325 .with_label_values(&[instance.as_ref(), isolation_level.as_str(), strategy])
1326 }
1327
1328 pub fn record(
1331 execute_started: Instant,
1332 client: &SessionClient,
1333 instance_id: Option<ComputeInstanceId>,
1334 strategy: Option<StatementExecutionStrategy>,
1335 ) {
1336 Self::histogram(client, instance_id, strategy)
1337 .observe(execute_started.elapsed().as_secs_f64());
1338 }
1339
1340 pub async fn recv(&mut self) -> Option<PeekResponseUnary> {
1341 let msg = self.rows.next().await;
1342 if !self.saw_rows && matches!(msg, Some(PeekResponseUnary::Rows(_))) {
1343 self.saw_rows = true;
1344 self.time_to_first_row_seconds
1345 .observe(self.execute_started.elapsed().as_secs_f64());
1346 self.recorded_first_row_instant = Some(Instant::now());
1347 }
1348 if msg.is_none() {
1349 self.no_more_rows = true;
1350 }
1351 msg
1352 }
1353}