// mz_adapter/client.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::borrow::Cow;
11use std::collections::BTreeMap;
12use std::fmt::{Debug, Display, Formatter};
13use std::future::Future;
14use std::pin;
15use std::sync::Arc;
16use std::time::{Duration, Instant};
17
18use anyhow::bail;
19use chrono::{DateTime, Utc};
20use derivative::Derivative;
21use futures::{Stream, StreamExt};
22use itertools::Itertools;
23use mz_adapter_types::connection::{ConnectionId, ConnectionIdType};
24use mz_auth::password::Password;
25use mz_build_info::BuildInfo;
26use mz_compute_types::ComputeInstanceId;
27use mz_ore::channel::OneshotReceiverExt;
28use mz_ore::collections::CollectionExt;
29use mz_ore::id_gen::{IdAllocator, IdAllocatorInnerBitSet, MAX_ORG_ID, org_id_conn_bits};
30use mz_ore::instrument;
31use mz_ore::now::{EpochMillis, NowFn, to_datetime};
32use mz_ore::result::ResultExt;
33use mz_ore::task::AbortOnDropHandle;
34use mz_ore::thread::JoinOnDropHandle;
35use mz_ore::tracing::OpenTelemetryContext;
36use mz_repr::{CatalogItemId, ColumnIndex, Row, RowIterator, ScalarType};
37use mz_sql::ast::{Raw, Statement};
38use mz_sql::catalog::{EnvironmentId, SessionCatalog};
39use mz_sql::session::hint::ApplicationNameHint;
40use mz_sql::session::metadata::SessionMetadata;
41use mz_sql::session::user::SUPPORT_USER;
42use mz_sql::session::vars::{CLUSTER, OwnedVarInput, SystemVars, Var};
43use mz_sql_parser::parser::{ParserStatementError, StatementParseResult};
44use prometheus::Histogram;
45use serde_json::json;
46use tokio::sync::{mpsc, oneshot};
47use tracing::error;
48use uuid::Uuid;
49
50use crate::catalog::Catalog;
51use crate::command::{
52    AuthResponse, CatalogDump, CatalogSnapshot, Command, ExecuteResponse, Response,
53};
54use crate::coord::{Coordinator, ExecuteContextExtra};
55use crate::error::AdapterError;
56use crate::metrics::Metrics;
57use crate::optimize::{self, Optimize};
58use crate::session::{
59    EndTransactionAction, PreparedStatement, Session, SessionConfig, TransactionId,
60};
61use crate::statement_logging::{StatementEndedExecutionReason, StatementExecutionStrategy};
62use crate::telemetry::{self, EventDetails, SegmentClientExt, StatementFailureType};
63use crate::webhook::AppendWebhookResponse;
64use crate::{AdapterNotice, AppendWebhookError, PeekResponseUnary, StartupResponse};
65
/// A handle to a running coordinator.
///
/// The coordinator runs on its own thread. Dropping the handle will wait for
/// the coordinator's thread to exit, which will only occur after all
/// outstanding [`Client`]s for the coordinator have dropped.
pub struct Handle {
    // Unique id generated at coordinator boot; see `Handle::session_id`.
    pub(crate) session_id: Uuid,
    // Instant at which the coordinator booted; see `Handle::start_instant`.
    pub(crate) start_instant: Instant,
    // Dropping this handle joins the coordinator thread, which is what makes
    // dropping a `Handle` block until the coordinator exits.
    pub(crate) _thread: JoinOnDropHandle<()>,
}
76
77impl Handle {
78    /// Returns the session ID associated with this coordinator.
79    ///
80    /// The session ID is generated on coordinator boot. It lasts for the
81    /// lifetime of the coordinator. Restarting the coordinator will result
82    /// in a new session ID.
83    pub fn session_id(&self) -> Uuid {
84        self.session_id
85    }
86
87    /// Returns the instant at which the coordinator booted.
88    pub fn start_instant(&self) -> Instant {
89        self.start_instant
90    }
91}
92
/// A coordinator client.
///
/// A coordinator client is a simple handle to a communication channel with the
/// coordinator. It can be cheaply cloned.
///
/// Clients keep the coordinator alive. The coordinator will not exit until all
/// outstanding clients have dropped.
#[derive(Debug, Clone)]
pub struct Client {
    // Build information for the running binary; handed to each new `Session`.
    build_info: &'static BuildInfo,
    // Channel for sending commands (paired with the current tracing context)
    // to the coordinator; see `Client::send`.
    inner_cmd_tx: mpsc::UnboundedSender<(OpenTelemetryContext, Command)>,
    // Allocates connection ids; see the layout comment in `Client::new`.
    id_alloc: IdAllocator<IdAllocatorInnerBitSet>,
    // Source of wall-clock time; see `Client::now`.
    now: NowFn,
    // Adapter-layer metrics; see `Client::metrics`.
    metrics: Metrics,
    // The environment this coordinator serves; its org id seeds the
    // connection-id allocator.
    environment_id: EnvironmentId,
    // Optional Segment telemetry client; `None` disables telemetry.
    segment_client: Option<mz_segment::Client>,
}
110
111impl Client {
112    pub(crate) fn new(
113        build_info: &'static BuildInfo,
114        cmd_tx: mpsc::UnboundedSender<(OpenTelemetryContext, Command)>,
115        metrics: Metrics,
116        now: NowFn,
117        environment_id: EnvironmentId,
118        segment_client: Option<mz_segment::Client>,
119    ) -> Client {
120        // Connection ids are 32 bits and have 3 parts.
121        // 1. MSB bit is always 0 because these are interpreted as an i32, and it is possible some
122        //    driver will not handle a negative id since postgres has never produced one because it
123        //    uses process ids.
124        // 2. Next 12 bits are the lower 12 bits of the org id. This allows balancerd to route
125        //    incoming cancel messages to a subset of the environments.
126        // 3. Last 19 bits are random.
127        let env_lower = org_id_conn_bits(&environment_id.organization_id());
128        Client {
129            build_info,
130            inner_cmd_tx: cmd_tx,
131            id_alloc: IdAllocator::new(1, MAX_ORG_ID, env_lower),
132            now,
133            metrics,
134            environment_id,
135            segment_client,
136        }
137    }
138
139    /// Allocates a client for an incoming connection.
140    pub fn new_conn_id(&self) -> Result<ConnectionId, AdapterError> {
141        self.id_alloc.alloc().ok_or(AdapterError::IdExhaustionError)
142    }
143
144    /// Creates a new session associated with this client for the given user.
145    ///
146    /// It is the caller's responsibility to have authenticated the user.
147    pub fn new_session(&self, config: SessionConfig) -> Session {
148        // We use the system clock to determine when a session connected to Materialize. This is not
149        // intended to be 100% accurate and correct, so we don't burden the timestamp oracle with
150        // generating a more correct timestamp.
151        Session::new(self.build_info, config, self.metrics().session_metrics())
152    }
153
154    /// Preforms an authentication check for the given user.
155    pub async fn authenticate(
156        &self,
157        user: &String,
158        password: &Password,
159    ) -> Result<AuthResponse, AdapterError> {
160        let (tx, rx) = oneshot::channel();
161        self.send(Command::AuthenticatePassword {
162            role_name: user.to_string(),
163            password: Some(password.clone()),
164            tx,
165        });
166        let response = rx.await.expect("sender dropped")?;
167        Ok(response)
168    }
169
170    /// Upgrades this client to a session client.
171    ///
172    /// A session is a connection that has successfully negotiated parameters,
173    /// like the user. Most coordinator operations are available only after
174    /// upgrading a connection to a session.
175    ///
176    /// Returns a new client that is bound to the session and a response
177    /// containing various details about the startup.
178    #[mz_ore::instrument(level = "debug")]
179    pub async fn startup(&self, session: Session) -> Result<SessionClient, AdapterError> {
180        let user = session.user().clone();
181        let conn_id = session.conn_id().clone();
182        let secret_key = session.secret_key();
183        let uuid = session.uuid();
184        let client_ip = session.client_ip();
185        let application_name = session.application_name().into();
186        let notice_tx = session.retain_notice_transmitter();
187
188        let (tx, rx) = oneshot::channel();
189
190        // ~~SPOOKY ZONE~~
191        //
192        // This guard prevents a race where the startup command finishes, but the Future returned
193        // by this function is concurrently dropped, so we never create a `SessionClient` and thus
194        // never cleanup the initialized Session.
195        let rx = rx.with_guard(|_| {
196            self.send(Command::Terminate {
197                conn_id: conn_id.clone(),
198                tx: None,
199            });
200        });
201
202        self.send(Command::Startup {
203            tx,
204            user,
205            conn_id: conn_id.clone(),
206            secret_key,
207            uuid,
208            client_ip: client_ip.copied(),
209            application_name,
210            notice_tx,
211        });
212
213        // When startup fails, no need to call terminate (handle_startup does this). Delay creating
214        // the client until after startup to sidestep the panic in its `Drop` implementation.
215        let response = rx.await.expect("sender dropped")?;
216
217        // Create the client as soon as startup succeeds (before any await points) so its `Drop` can
218        // handle termination.
219        let mut client = SessionClient {
220            inner: Some(self.clone()),
221            session: Some(session),
222            timeouts: Timeout::new(),
223            environment_id: self.environment_id.clone(),
224            segment_client: self.segment_client.clone(),
225        };
226
227        let StartupResponse {
228            role_id,
229            write_notify,
230            session_defaults,
231            catalog,
232        } = response;
233
234        let session = client.session();
235        session.initialize_role_metadata(role_id);
236        let vars_mut = session.vars_mut();
237        for (name, val) in session_defaults {
238            if let Err(err) = vars_mut.set_default(&name, val.borrow()) {
239                // Note: erroring here is unexpected, but we don't want to panic if somehow our
240                // assumptions are wrong.
241                tracing::error!("failed to set peristed default, {err:?}");
242            }
243        }
244        session
245            .vars_mut()
246            .end_transaction(EndTransactionAction::Commit);
247
248        // Stash the future that notifies us of builtin table writes completing, we'll block on
249        // this future before allowing queries from this session against relevant relations.
250        //
251        // Note: We stash the future as opposed to waiting on it here to prevent blocking session
252        // creation on builtin table updates. This improves the latency for session creation and
253        // reduces scheduling load on any dataflows that read from these builtin relations, since
254        // it allows updates to be batched.
255        session.set_builtin_table_updates(write_notify);
256
257        let catalog = catalog.for_session(session);
258
259        let cluster_active = session.vars().cluster().to_string();
260        if session.vars().welcome_message() {
261            let cluster_info = if catalog.resolve_cluster(Some(&cluster_active)).is_err() {
262                format!("{cluster_active} (does not exist)")
263            } else {
264                cluster_active.to_string()
265            };
266
267            // Emit a welcome message, optimized for readability by humans using
268            // interactive tools. If you change the message, make sure that it
269            // formats nicely in both `psql` and the console's SQL shell.
270            session.add_notice(AdapterNotice::Welcome(format!(
271                "connected to Materialize v{}
272  Org ID: {}
273  Region: {}
274  User: {}
275  Cluster: {}
276  Database: {}
277  {}
278  Session UUID: {}
279
280Issue a SQL query to get started. Need help?
281  View documentation: https://materialize.com/s/docs
282  Join our Slack community: https://materialize.com/s/chat
283    ",
284                session.vars().build_info().semver_version(),
285                self.environment_id.organization_id(),
286                self.environment_id.region(),
287                session.vars().user().name,
288                cluster_info,
289                session.vars().database(),
290                match session.vars().search_path() {
291                    [schema] => format!("Schema: {}", schema),
292                    schemas => format!(
293                        "Search path: {}",
294                        schemas.iter().map(|id| id.to_string()).join(", ")
295                    ),
296                },
297                session.uuid(),
298            )));
299        }
300
301        if session.vars().current_object_missing_warnings() {
302            if catalog.active_database().is_none() {
303                let db = session.vars().database().into();
304                session.add_notice(AdapterNotice::UnknownSessionDatabase(db));
305            }
306        }
307
308        // Users stub their toe on their default cluster not existing, so we provide a notice to
309        // help guide them on what do to.
310        let cluster_var = session
311            .vars()
312            .inspect(CLUSTER.name())
313            .expect("cluster should exist");
314        if session.vars().current_object_missing_warnings()
315            && catalog.resolve_cluster(Some(&cluster_active)).is_err()
316        {
317            let cluster_notice = 'notice: {
318                if cluster_var.inspect_session_value().is_some() {
319                    break 'notice Some(AdapterNotice::DefaultClusterDoesNotExist {
320                        name: cluster_active,
321                        kind: "session",
322                        suggested_action: "Pick an extant cluster with SET CLUSTER = name. Run SHOW CLUSTERS to see available clusters.".into(),
323                    });
324                }
325
326                let role_default = catalog.get_role(catalog.active_role_id());
327                let role_cluster = match role_default.vars().get(CLUSTER.name()) {
328                    Some(OwnedVarInput::Flat(name)) => Some(name),
329                    None => None,
330                    // This is unexpected!
331                    Some(v @ OwnedVarInput::SqlSet(_)) => {
332                        tracing::warn!(?v, "SqlSet found for cluster Role Default");
333                        break 'notice None;
334                    }
335                };
336
337                let alter_role = "with `ALTER ROLE <role> SET cluster TO <cluster>;`";
338                match role_cluster {
339                    // If there is no default, suggest a Role default.
340                    None => Some(AdapterNotice::DefaultClusterDoesNotExist {
341                        name: cluster_active,
342                        kind: "system",
343                        suggested_action: format!(
344                            "Set a default cluster for the current role {alter_role}."
345                        ),
346                    }),
347                    // If the default does not exist, suggest to change it.
348                    Some(_) => Some(AdapterNotice::DefaultClusterDoesNotExist {
349                        name: cluster_active,
350                        kind: "role",
351                        suggested_action: format!(
352                            "Change the default cluster for the current role {alter_role}."
353                        ),
354                    }),
355                }
356            };
357
358            if let Some(notice) = cluster_notice {
359                session.add_notice(notice);
360            }
361        }
362
363        Ok(client)
364    }
365
366    /// Cancels the query currently running on the specified connection.
367    pub fn cancel_request(&self, conn_id: ConnectionIdType, secret_key: u32) {
368        self.send(Command::CancelRequest {
369            conn_id,
370            secret_key,
371        });
372    }
373
374    /// Executes a single SQL statement that returns rows as the
375    /// `mz_support` user.
376    pub async fn support_execute_one(
377        &self,
378        sql: &str,
379    ) -> Result<Box<dyn RowIterator>, anyhow::Error> {
380        // Connect to the coordinator.
381        let conn_id = self.new_conn_id()?;
382        let session = self.new_session(SessionConfig {
383            conn_id,
384            uuid: Uuid::new_v4(),
385            user: SUPPORT_USER.name.clone(),
386            client_ip: None,
387            external_metadata_rx: None,
388            internal_user_metadata: None,
389            helm_chart_version: None,
390        });
391        let mut session_client = self.startup(session).await?;
392
393        // Parse the SQL statement.
394        let stmts = mz_sql::parse::parse(sql)?;
395        if stmts.len() != 1 {
396            bail!("must supply exactly one query");
397        }
398        let StatementParseResult { ast: stmt, sql } = stmts.into_element();
399
400        const EMPTY_PORTAL: &str = "";
401        session_client.start_transaction(Some(1))?;
402        session_client
403            .declare(EMPTY_PORTAL.into(), stmt, sql.to_string())
404            .await?;
405        match session_client
406            .execute(EMPTY_PORTAL.into(), futures::future::pending(), None)
407            .await?
408        {
409            (ExecuteResponse::SendingRows { future, .. }, _) => match future.await {
410                PeekResponseUnary::Rows(rows) => Ok(rows),
411                PeekResponseUnary::Canceled => bail!("query canceled"),
412                PeekResponseUnary::Error(e) => bail!(e),
413            },
414            r => bail!("unsupported response type: {r:?}"),
415        }
416    }
417
418    /// Returns the metrics associated with the adapter layer.
419    pub fn metrics(&self) -> &Metrics {
420        &self.metrics
421    }
422
423    /// The current time according to the [`Client`].
424    pub fn now(&self) -> DateTime<Utc> {
425        to_datetime((self.now)())
426    }
427
428    /// Get a metadata and a channel that can be used to append to a webhook source.
429    pub async fn get_webhook_appender(
430        &self,
431        database: String,
432        schema: String,
433        name: String,
434    ) -> Result<AppendWebhookResponse, AppendWebhookError> {
435        let (tx, rx) = oneshot::channel();
436
437        // Send our request.
438        self.send(Command::GetWebhook {
439            database,
440            schema,
441            name,
442            tx,
443        });
444
445        // Using our one shot channel to get the result, returning an error if the sender dropped.
446        let response = rx
447            .await
448            .map_err(|_| anyhow::anyhow!("failed to receive webhook response"))?;
449
450        response
451    }
452
453    /// Gets the current value of all system variables.
454    pub async fn get_system_vars(&self) -> SystemVars {
455        let (tx, rx) = oneshot::channel();
456        self.send(Command::GetSystemVars { tx });
457        rx.await.expect("coordinator unexpectedly gone")
458    }
459
460    #[instrument(level = "debug")]
461    fn send(&self, cmd: Command) {
462        self.inner_cmd_tx
463            .send((OpenTelemetryContext::obtain(), cmd))
464            .expect("coordinator unexpectedly gone");
465    }
466}
467
/// A coordinator client that is bound to a connection.
///
/// See also [`Client`].
pub struct SessionClient {
    // Invariant: inner may only be `None` after the session has been terminated.
    // Once the session is terminated, no communication to the Coordinator
    // should be attempted.
    inner: Option<Client>,
    // Invariant: session may only be `None` during a method call. Every public
    // method must ensure that `Session` is `Some` before it returns.
    session: Option<Session>,
    // Timeout state for this session (see `Timeout`).
    timeouts: Timeout,
    // Optional Segment telemetry client; `None` disables telemetry events
    // such as statement parse failures.
    segment_client: Option<mz_segment::Client>,
    // The environment this session belongs to; attached to telemetry events.
    environment_id: EnvironmentId,
}
483
484impl SessionClient {
485    /// Parses a SQL expression, reporting failures as a telemetry event if
486    /// possible.
487    pub fn parse<'a>(
488        &self,
489        sql: &'a str,
490    ) -> Result<Result<Vec<StatementParseResult<'a>>, ParserStatementError>, String> {
491        match mz_sql::parse::parse_with_limit(sql) {
492            Ok(Err(e)) => {
493                self.track_statement_parse_failure(&e);
494                Ok(Err(e))
495            }
496            r => r,
497        }
498    }
499
    // Records a statement parse failure as a Segment telemetry event.
    //
    // Silently does nothing unless all of the following are available: the
    // session user's external metadata (for a user id), a Segment client, a
    // statement kind attached to the parse error, and an audited
    // (action, object type) classification for that statement kind.
    fn track_statement_parse_failure(&self, parse_error: &ParserStatementError) {
        let session = self.session.as_ref().expect("session invariant violated");
        // No external metadata means no user id to attribute the event to.
        let Some(user_id) = session.user().external_metadata.as_ref().map(|m| m.user_id) else {
            return;
        };
        // Telemetry is optional; nothing to do without a Segment client.
        let Some(segment_client) = &self.segment_client else {
            return;
        };
        // The parse error may not be attributable to a single statement kind.
        let Some(statement_kind) = parse_error.statement else {
            return;
        };
        // Only audited statement kinds produce telemetry events.
        let Some((action, object_type)) = telemetry::analyze_audited_statement(statement_kind)
        else {
            return;
        };
        let event_type = StatementFailureType::ParseFailure;
        // Event name takes the form "<Object> <Action> <Failure Type>".
        let event_name = format!(
            "{} {} {}",
            object_type.as_title_case(),
            action.as_title_case(),
            event_type.as_title_case(),
        );
        segment_client.environment_track(
            &self.environment_id,
            event_name,
            json!({
                "statement_kind": statement_kind,
                "error": &parse_error.error,
            }),
            EventDetails {
                user_id: Some(user_id),
                application_name: Some(session.application_name()),
                ..Default::default()
            },
        );
    }
536
    /// Verifies and returns the named prepared statement.
    ///
    /// Each use must be re-verified (against a fresh catalog snapshot) to make
    /// sure the prepared statement is still safe to use.
    pub async fn get_prepared_statement(
        &mut self,
        name: &str,
    ) -> Result<&PreparedStatement, AdapterError> {
        let catalog = self.catalog_snapshot().await;
        Coordinator::verify_prepared_statement(&catalog, self.session(), name)?;
        // Verification succeeding implies the statement exists in the session.
        Ok(self
            .session()
            .get_prepared_statement_unverified(name)
            .expect("must exist"))
    }
550
    /// Saves the parsed statement as a prepared statement.
    ///
    /// The prepared statement is saved in the connection's [`crate::session::Session`]
    /// under the specified name.
    pub async fn prepare(
        &mut self,
        name: String,
        stmt: Option<Statement<Raw>>,
        sql: String,
        param_types: Vec<Option<ScalarType>>,
    ) -> Result<(), AdapterError> {
        let catalog = self.catalog_snapshot().await;

        // Note: This failpoint is used to simulate a request outliving the external connection
        // that made it.
        let mut async_pause = false;
        // The closure is invoked immediately; it exists so the failpoint macro
        // can set `async_pause` from the failpoint's configured value.
        (|| {
            fail::fail_point!("async_prepare", |val| {
                async_pause = val.map_or(false, |val| val.parse().unwrap_or(false))
            });
        })();
        if async_pause {
            tokio::time::sleep(Duration::from_secs(1)).await;
        };

        // Describe the statement against the catalog snapshot; this is the
        // step that can fail with an `AdapterError`.
        let desc = Coordinator::describe(&catalog, self.session(), stmt.clone(), param_types)?;
        let now = self.now();
        self.session().set_prepared_statement(
            name,
            stmt,
            sql,
            desc,
            catalog.transient_revision(),
            now,
        );
        Ok(())
    }
588
589    /// Binds a statement to a portal.
590    #[mz_ore::instrument(level = "debug")]
591    pub async fn declare(
592        &mut self,
593        name: String,
594        stmt: Statement<Raw>,
595        sql: String,
596    ) -> Result<(), AdapterError> {
597        let catalog = self.catalog_snapshot().await;
598        let param_types = vec![];
599        let desc =
600            Coordinator::describe(&catalog, self.session(), Some(stmt.clone()), param_types)?;
601        let params = vec![];
602        let result_formats = vec![mz_pgwire_common::Format::Text; desc.arity()];
603        let now = self.now();
604        let logging = self.session().mint_logging(sql, Some(&stmt), now);
605        self.session().set_portal(
606            name,
607            desc,
608            Some(stmt),
609            logging,
610            params,
611            result_formats,
612            catalog.transient_revision(),
613        )?;
614        Ok(())
615    }
616
617    /// Executes a previously-bound portal.
618    ///
619    /// Note: the provided `cancel_future` must be cancel-safe as it's polled in a `select!` loop.
620    #[mz_ore::instrument(level = "debug")]
621    pub async fn execute(
622        &mut self,
623        portal_name: String,
624        cancel_future: impl Future<Output = std::io::Error> + Send,
625        outer_ctx_extra: Option<ExecuteContextExtra>,
626    ) -> Result<(ExecuteResponse, Instant), AdapterError> {
627        let execute_started = Instant::now();
628        let response = self
629            .send_with_cancel(
630                |tx, session| Command::Execute {
631                    portal_name,
632                    session,
633                    tx,
634                    outer_ctx_extra,
635                },
636                cancel_future,
637            )
638            .await?;
639        Ok((response, execute_started))
640    }
641
    // Returns the current time, in milliseconds since the Unix epoch, as
    // reported by the inner client's `NowFn`.
    fn now(&self) -> EpochMillis {
        (self.inner().now)()
    }

    // Like `now`, but converted to a UTC `DateTime`.
    fn now_datetime(&self) -> DateTime<Utc> {
        to_datetime(self.now())
    }
649
650    /// Starts a transaction based on implicit:
651    /// - `None`: InTransaction
652    /// - `Some(1)`: Started
653    /// - `Some(n > 1)`: InTransactionImplicit
654    /// - `Some(0)`: no change
655    pub fn start_transaction(&mut self, implicit: Option<usize>) -> Result<(), AdapterError> {
656        let now = self.now_datetime();
657        let session = self.session.as_mut().expect("session invariant violated");
658        let result = match implicit {
659            None => session.start_transaction(now, None, None),
660            Some(stmts) => {
661                session.start_transaction_implicit(now, stmts);
662                Ok(())
663            }
664        };
665        result
666    }
667
    /// Ends a transaction. Even if an error is returned, guarantees that the transaction in the
    /// session and Coordinator has cleared its state.
    #[instrument(level = "debug")]
    pub async fn end_transaction(
        &mut self,
        action: EndTransactionAction,
    ) -> Result<ExecuteResponse, AdapterError> {
        let res = self
            .send(|tx, session| Command::Commit {
                action,
                session,
                tx,
            })
            .await;
        // Commit isn't guaranteed to set the session's state to anything specific, so clear it
        // here. It's safe to ignore the returned `TransactionStatus` because that doesn't contain
        // any data that the Coordinator must act on for correctness.
        //
        // Note: this runs regardless of whether `res` is `Ok` or `Err`, which
        // is what provides the "state is cleared even on error" guarantee.
        let _ = self.session().clear_transaction();
        res
    }
688
689    /// Fails a transaction.
690    pub fn fail_transaction(&mut self) {
691        let session = self.session.take().expect("session invariant violated");
692        let session = session.fail_transaction();
693        self.session = Some(session);
694    }
695
696    /// Fetches the catalog.
697    #[instrument(level = "debug")]
698    pub async fn catalog_snapshot(&self) -> Arc<Catalog> {
699        let CatalogSnapshot { catalog } = self
700            .send_without_session(|tx| Command::CatalogSnapshot { tx })
701            .await;
702        catalog
703    }
704
705    /// Dumps the catalog to a JSON string.
706    ///
707    /// No authorization is performed, so access to this function must be limited to internal
708    /// servers or superusers.
709    pub async fn dump_catalog(&self) -> Result<CatalogDump, AdapterError> {
710        let catalog = self.catalog_snapshot().await;
711        catalog.dump().map_err(AdapterError::from)
712    }
713
714    /// Checks the catalog for internal consistency, returning a JSON object describing the
715    /// inconsistencies, if there are any.
716    ///
717    /// No authorization is performed, so access to this function must be limited to internal
718    /// servers or superusers.
719    pub async fn check_catalog(&self) -> Result<(), serde_json::Value> {
720        let catalog = self.catalog_snapshot().await;
721        catalog.check_consistency()
722    }
723
724    /// Checks the coordinator for internal consistency, returning a JSON object describing the
725    /// inconsistencies, if there are any. This is a superset of checks that check_catalog performs,
726    ///
727    /// No authorization is performed, so access to this function must be limited to internal
728    /// servers or superusers.
729    pub async fn check_coordinator(&self) -> Result<(), serde_json::Value> {
730        self.send_without_session(|tx| Command::CheckConsistency { tx })
731            .await
732            .map_err(|inconsistencies| {
733                serde_json::to_value(inconsistencies).unwrap_or_else(|_| {
734                    serde_json::Value::String("failed to serialize inconsistencies".to_string())
735                })
736            })
737    }
738
    /// Dumps the coordinator's internal state as a JSON value.
    pub async fn dump_coordinator_state(&self) -> Result<serde_json::Value, anyhow::Error> {
        self.send_without_session(|tx| Command::Dump { tx }).await
    }
742
743    /// Tells the coordinator a statement has finished execution, in the cases
744    /// where we have no other reason to communicate with the coordinator.
745    pub fn retire_execute(&self, data: ExecuteContextExtra, reason: StatementEndedExecutionReason) {
746        if !data.is_trivial() {
747            let cmd = Command::RetireExecute { data, reason };
748            self.inner().send(cmd);
749        }
750    }
751
    /// Inserts a set of rows into the given table.
    ///
    /// The rows only contain the columns positions in `columns`, so they
    /// must be re-encoded for adding the default values for the remaining
    /// ones.
    pub async fn insert_rows(
        &mut self,
        id: CatalogItemId,
        columns: Vec<ColumnIndex>,
        rows: Vec<Row>,
        ctx_extra: ExecuteContextExtra,
    ) -> Result<ExecuteResponse, AdapterError> {
        // TODO: Remove this clone once we always have the session. It's currently needed because
        // self.session returns a mut ref, so we can't call it twice.
        let pcx = self.session().pcx().clone();

        let catalog = self.catalog_snapshot().await;
        let conn_catalog = catalog.for_session(self.session());

        // Collect optimizer parameters.
        let optimizer_config = optimize::OptimizerConfig::from(conn_catalog.system_vars());
        let mut optimizer = optimize::view::Optimizer::new(optimizer_config, None);

        // Plan the copy, optimize the resulting values, then hand the
        // constant result to the coordinator for insertion. Each step's error
        // is converted into an `AdapterError`.
        let result: Result<_, AdapterError> =
            mz_sql::plan::plan_copy_from(&pcx, &conn_catalog, id, columns, rows)
                .err_into()
                .and_then(|values| optimizer.optimize(values).err_into())
                .and_then(|values| {
                    // Copied rows must always be constants.
                    Coordinator::insert_constant(&catalog, self.session(), id, values.into_inner())
                });
        // Retire the execution context regardless of success or failure so
        // statement logging records how the statement ended.
        self.retire_execute(ctx_extra, (&result).into());
        result
    }
786
    /// Gets the current value of all system variables.
    ///
    /// Delegates to [`Client::get_system_vars`] on the inner client.
    pub async fn get_system_vars(&self) -> SystemVars {
        self.inner().get_system_vars().await
    }
791
792    /// Updates the specified system variables to the specified values.
793    pub async fn set_system_vars(
794        &mut self,
795        vars: BTreeMap<String, String>,
796    ) -> Result<(), AdapterError> {
797        let conn_id = self.session().conn_id().clone();
798        self.send_without_session(|tx| Command::SetSystemVars { vars, conn_id, tx })
799            .await
800    }
801
802    /// Terminates the client session.
803    pub async fn terminate(&mut self) {
804        let conn_id = self.session().conn_id().clone();
805        let res = self
806            .send_without_session(|tx| Command::Terminate {
807                conn_id,
808                tx: Some(tx),
809            })
810            .await;
811        if let Err(e) = res {
812            // Nothing we can do to handle a failed terminate so we just log and ignore it.
813            error!("Unable to terminate session: {e:?}");
814        }
815        // Prevent any communication with Coordinator after session is terminated.
816        self.inner = None;
817    }
818
819    /// Returns a mutable reference to the session bound to this client.
820    pub fn session(&mut self) -> &mut Session {
821        self.session.as_mut().expect("session invariant violated")
822    }
823
824    /// Returns a reference to the inner client.
825    pub fn inner(&self) -> &Client {
826        self.inner.as_ref().expect("inner invariant violated")
827    }
828
829    async fn send_without_session<T, F>(&self, f: F) -> T
830    where
831        F: FnOnce(oneshot::Sender<T>) -> Command,
832    {
833        let (tx, rx) = oneshot::channel();
834        self.inner().send(f(tx));
835        rx.await.expect("sender dropped")
836    }
837
838    #[instrument(level = "debug")]
839    async fn send<T, F>(&mut self, f: F) -> Result<T, AdapterError>
840    where
841        F: FnOnce(oneshot::Sender<Response<T>>, Session) -> Command,
842    {
843        self.send_with_cancel(f, futures::future::pending()).await
844    }
845
    /// Send a [`Command`] to the Coordinator, with the ability to cancel the command.
    ///
    /// The `Session` is moved into the command and is returned by the Coordinator
    /// alongside its response; a guard on the response channel ensures the session
    /// is placed back into `self` even if this future is dropped concurrently with
    /// the Coordinator's reply (see the "SPOOKY ZONE" below).
    ///
    /// If `cancel_future` resolves first, a [`Command::PrivilegedCancelRequest`] is
    /// sent for this connection and we continue waiting for the original command's
    /// response.
    ///
    /// Note: the provided `cancel_future` must be cancel-safe as it's polled in a `select!` loop.
    #[instrument(level = "debug")]
    async fn send_with_cancel<T, F>(
        &mut self,
        f: F,
        cancel_future: impl Future<Output = std::io::Error> + Send,
    ) -> Result<T, AdapterError>
    where
        F: FnOnce(oneshot::Sender<Response<T>>, Session) -> Command,
    {
        let session = self.session.take().expect("session invariant violated");
        // Metric label for commands whose success/error rate we report; set below.
        let mut typ = None;
        let application_name = session.application_name();
        let name_hint = ApplicationNameHint::from_str(application_name);
        let conn_id = session.conn_id().clone();
        let (tx, rx) = oneshot::channel();

        // Destructure self so we can hold a mutable reference to the inner client and session at
        // the same time.
        let Self {
            inner: inner_client,
            session: client_session,
            ..
        } = self;

        // TODO(parkmycar): Leaking this invariant here doesn't feel great, but calling
        // `self.client()` doesn't work because then Rust takes a borrow on the entirity of self.
        let inner_client = inner_client.as_ref().expect("inner invariant violated");

        // ~~SPOOKY ZONE~~
        //
        // This guard prevents a race where a `Session` is returned on `rx` but never placed
        // back in `self` because the Future returned by this function is concurrently dropped
        // with the Coordinator sending a response.
        let mut guarded_rx = rx.with_guard(|response: Response<_>| {
            *client_session = Some(response.session);
        });

        inner_client.send({
            let cmd = f(tx, session);
            // Measure the success and error rate of certain commands:
            // - declare reports success of SQL statement planning
            // - execute reports success of dataflow execution
            match cmd {
                Command::Execute { .. } => typ = Some("execute"),
                Command::GetWebhook { .. } => typ = Some("webhook"),
                Command::Startup { .. }
                | Command::AuthenticatePassword { .. }
                | Command::CatalogSnapshot { .. }
                | Command::Commit { .. }
                | Command::CancelRequest { .. }
                | Command::PrivilegedCancelRequest { .. }
                | Command::GetSystemVars { .. }
                | Command::SetSystemVars { .. }
                | Command::Terminate { .. }
                | Command::RetireExecute { .. }
                | Command::CheckConsistency { .. }
                | Command::Dump { .. } => {}
            };
            cmd
        });

        // Wait for the response, racing against cancellation.
        let mut cancel_future = pin::pin!(cancel_future);
        let mut cancelled = false;
        loop {
            tokio::select! {
                res = &mut guarded_rx => {
                    // We received a result, so drop our guard to drop our borrows.
                    drop(guarded_rx);

                    let res = res.expect("sender dropped");
                    let status = res.result.is_ok().then_some("success").unwrap_or("error");
                    if let Err(err) = res.result.as_ref() {
                        if name_hint.should_trace_errors() {
                            tracing::warn!(?err, ?name_hint, "adapter response error");
                        }
                    }

                    if let Some(typ) = typ {
                        inner_client
                            .metrics
                            .commands
                            .with_label_values(&[typ, status, name_hint.as_str()])
                            .inc();
                    }
                    *client_session = Some(res.session);
                    return res.result;
                },
                // Send the cancellation at most once, then keep looping until
                // the Coordinator responds to the original command.
                _err = &mut cancel_future, if !cancelled => {
                    cancelled = true;
                    inner_client.send(Command::PrivilegedCancelRequest {
                        conn_id: conn_id.clone(),
                    });
                }
            };
        }
    }
945
946    pub fn add_idle_in_transaction_session_timeout(&mut self) {
947        let session = self.session();
948        let timeout_dur = session.vars().idle_in_transaction_session_timeout();
949        if !timeout_dur.is_zero() {
950            let timeout_dur = timeout_dur.clone();
951            if let Some(txn) = session.transaction().inner() {
952                let txn_id = txn.id.clone();
953                let timeout = TimeoutType::IdleInTransactionSession(txn_id);
954                self.timeouts.add_timeout(timeout, timeout_dur);
955            }
956        }
957    }
958
959    pub fn remove_idle_in_transaction_session_timeout(&mut self) {
960        let session = self.session();
961        if let Some(txn) = session.transaction().inner() {
962            let txn_id = txn.id.clone();
963            self.timeouts
964                .remove_timeout(&TimeoutType::IdleInTransactionSession(txn_id));
965        }
966    }
967
968    /// # Cancel safety
969    ///
970    /// This method is cancel safe. If `recv` is used as the event in a
971    /// `tokio::select!` statement and some other branch
972    /// completes first, it is guaranteed that no messages were received on this
973    /// channel.
974    pub async fn recv_timeout(&mut self) -> Option<TimeoutType> {
975        self.timeouts.recv().await
976    }
977}
978
979impl Drop for SessionClient {
980    fn drop(&mut self) {
981        // We may not have a session if this client was dropped while awaiting
982        // a response. In this case, it is the coordinator's responsibility to
983        // terminate the session.
984        if let Some(session) = self.session.take() {
985            // We may not have a connection to the Coordinator if the session was
986            // prematurely terminated, for example due to a timeout.
987            if let Some(inner) = &self.inner {
988                inner.send(Command::Terminate {
989                    conn_id: session.conn_id().clone(),
990                    tx: None,
991                })
992            }
993        }
994    }
995}
996
/// The kinds of timeouts a session client can schedule for itself via
/// [`Timeout`].
#[derive(Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Debug)]
pub enum TimeoutType {
    /// Fires when the session has been idle inside the transaction with the
    /// contained [`TransactionId`] for longer than the configured duration.
    IdleInTransactionSession(TransactionId),
}
1001
1002impl Display for TimeoutType {
1003    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1004        match self {
1005            TimeoutType::IdleInTransactionSession(txn_id) => {
1006                writeln!(f, "Idle in transaction session for transaction '{txn_id}'")
1007            }
1008        }
1009    }
1010}
1011
1012impl From<TimeoutType> for AdapterError {
1013    fn from(timeout: TimeoutType) -> Self {
1014        match timeout {
1015            TimeoutType::IdleInTransactionSession(_) => {
1016                AdapterError::IdleInTransactionSessionTimeout
1017            }
1018        }
1019    }
1020}
1021
/// Tracks the timeouts scheduled for a session and delivers those that fire.
///
/// Each timeout is driven by a spawned task that sleeps for the configured
/// duration and then pushes a notification onto the channel below.
struct Timeout {
    // Sender cloned into each spawned timeout task.
    tx: mpsc::UnboundedSender<TimeoutType>,
    // Receives notifications for timeouts that have fired.
    rx: mpsc::UnboundedReceiver<TimeoutType>,
    // Handles for pending timeout tasks; dropping a handle aborts its task.
    active_timeouts: BTreeMap<TimeoutType, AbortOnDropHandle<()>>,
}
1027
1028impl Timeout {
1029    fn new() -> Self {
1030        let (tx, rx) = mpsc::unbounded_channel();
1031        Timeout {
1032            tx,
1033            rx,
1034            active_timeouts: BTreeMap::new(),
1035        }
1036    }
1037
1038    /// # Cancel safety
1039    ///
1040    /// This method is cancel safe. If `recv` is used as the event in a
1041    /// `tokio::select!` statement and some other branch
1042    /// completes first, it is guaranteed that no messages were received on this
1043    /// channel.
1044    ///
1045    /// <https://docs.rs/tokio/latest/tokio/sync/mpsc/struct.UnboundedReceiver.html#cancel-safety>
1046    async fn recv(&mut self) -> Option<TimeoutType> {
1047        self.rx.recv().await
1048    }
1049
1050    fn add_timeout(&mut self, timeout: TimeoutType, duration: Duration) {
1051        let tx = self.tx.clone();
1052        let timeout_key = timeout.clone();
1053        let handle = mz_ore::task::spawn(|| format!("{timeout_key}"), async move {
1054            tokio::time::sleep(duration).await;
1055            let _ = tx.send(timeout);
1056        })
1057        .abort_on_drop();
1058        self.active_timeouts.insert(timeout_key, handle);
1059    }
1060
1061    fn remove_timeout(&mut self, timeout: &TimeoutType) {
1062        self.active_timeouts.remove(timeout);
1063
1064        // Remove the timeout from the rx queue if it exists.
1065        let mut timeouts = Vec::new();
1066        while let Ok(pending_timeout) = self.rx.try_recv() {
1067            if timeout != &pending_timeout {
1068                timeouts.push(pending_timeout);
1069            }
1070        }
1071        for pending_timeout in timeouts {
1072            self.tx.send(pending_timeout).expect("rx is in this struct");
1073        }
1074    }
1075}
1076
/// A wrapper around an UnboundedReceiver of PeekResponseUnary that records when it sees the
/// first row data in the given histogram
#[derive(Derivative)]
#[derivative(Debug)]
pub struct RecordFirstRowStream {
    /// The underlying stream of peek responses.
    #[derivative(Debug = "ignore")]
    pub rows: Box<dyn Stream<Item = PeekResponseUnary> + Unpin + Send + Sync>,
    /// When execution of the statement started.
    pub execute_started: Instant,
    /// Histogram into which the time to first row is recorded.
    pub time_to_first_row_seconds: Histogram,
    // Whether a `Rows` response has been seen yet; the observation is
    // recorded only once, on the first such response.
    saw_rows: bool,
}
1088
1089impl RecordFirstRowStream {
1090    /// Create a new [`RecordFirstRowStream`]
1091    pub fn new(
1092        rows: Box<dyn Stream<Item = PeekResponseUnary> + Unpin + Send + Sync>,
1093        execute_started: Instant,
1094        client: &SessionClient,
1095        instance_id: Option<ComputeInstanceId>,
1096        strategy: Option<StatementExecutionStrategy>,
1097    ) -> Self {
1098        let histogram = Self::histogram(client, instance_id, strategy);
1099        Self {
1100            rows,
1101            execute_started,
1102            time_to_first_row_seconds: histogram,
1103            saw_rows: false,
1104        }
1105    }
1106
1107    fn histogram(
1108        client: &SessionClient,
1109        instance_id: Option<ComputeInstanceId>,
1110        strategy: Option<StatementExecutionStrategy>,
1111    ) -> Histogram {
1112        let isolation_level = *client
1113            .session
1114            .as_ref()
1115            .expect("session invariant")
1116            .vars()
1117            .transaction_isolation();
1118        let instance = match instance_id {
1119            Some(i) => Cow::Owned(i.to_string()),
1120            None => Cow::Borrowed("none"),
1121        };
1122        let strategy = match strategy {
1123            Some(s) => s.name(),
1124            None => "none",
1125        };
1126
1127        client
1128            .inner()
1129            .metrics()
1130            .time_to_first_row_seconds
1131            .with_label_values(&[&instance, isolation_level.as_str(), strategy])
1132    }
1133
1134    /// If you want to match [`RecordFirstRowStream`]'s logic but don't need
1135    /// a UnboundedReceiver, you can tell it when to record an observation.
1136    pub fn record(
1137        execute_started: Instant,
1138        client: &SessionClient,
1139        instance_id: Option<ComputeInstanceId>,
1140        strategy: Option<StatementExecutionStrategy>,
1141    ) {
1142        Self::histogram(client, instance_id, strategy)
1143            .observe(execute_started.elapsed().as_secs_f64());
1144    }
1145
1146    pub async fn recv(&mut self) -> Option<PeekResponseUnary> {
1147        let msg = self.rows.next().await;
1148        if !self.saw_rows && matches!(msg, Some(PeekResponseUnary::Rows(_))) {
1149            self.saw_rows = true;
1150            self.time_to_first_row_seconds
1151                .observe(self.execute_started.elapsed().as_secs_f64());
1152        }
1153        msg
1154    }
1155}