// mz_adapter/src/client.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.
9
10use std::borrow::Cow;
11use std::collections::BTreeMap;
12use std::fmt::{Debug, Display, Formatter};
13use std::future::Future;
14use std::pin::{self, Pin};
15use std::sync::Arc;
16use std::time::{Duration, Instant};
17
18use anyhow::bail;
19use chrono::{DateTime, Utc};
20use derivative::Derivative;
21use futures::{Stream, StreamExt};
22use itertools::Itertools;
23use mz_adapter_types::connection::{ConnectionId, ConnectionIdType};
24use mz_auth::Authenticated;
25use mz_auth::password::Password;
26use mz_build_info::BuildInfo;
27use mz_compute_types::ComputeInstanceId;
28use mz_ore::channel::OneshotReceiverExt;
29use mz_ore::collections::CollectionExt;
30use mz_ore::id_gen::{IdAllocator, IdAllocatorInnerBitSet, MAX_ORG_ID, org_id_conn_bits};
31use mz_ore::instrument;
32use mz_ore::now::{EpochMillis, NowFn, to_datetime};
33use mz_ore::task::AbortOnDropHandle;
34use mz_ore::thread::JoinOnDropHandle;
35use mz_ore::tracing::OpenTelemetryContext;
36use mz_repr::user::InternalUserMetadata;
37use mz_repr::{CatalogItemId, ColumnIndex, SqlScalarType};
38use mz_sql::ast::{Raw, Statement};
39use mz_sql::catalog::{EnvironmentId, SessionCatalog};
40use mz_sql::session::hint::ApplicationNameHint;
41use mz_sql::session::metadata::SessionMetadata;
42use mz_sql::session::user::SUPPORT_USER;
43use mz_sql::session::vars::{
44    CLUSTER, ENABLE_FRONTEND_PEEK_SEQUENCING, OwnedVarInput, SystemVars, Var,
45};
46use mz_sql_parser::parser::{ParserStatementError, StatementParseResult};
47use prometheus::Histogram;
48use serde_json::json;
49use tokio::sync::{mpsc, oneshot};
50use tracing::{debug, error};
51use uuid::Uuid;
52
53use crate::catalog::Catalog;
54use crate::command::{
55    CatalogDump, CatalogSnapshot, Command, CopyFromStdinWriter, ExecuteResponse, Response,
56    SASLChallengeResponse, SASLVerifyProofResponse, SuperuserAttribute,
57};
58use crate::coord::{Coordinator, ExecuteContextGuard};
59use crate::error::AdapterError;
60use crate::metrics::Metrics;
61use crate::session::{
62    EndTransactionAction, PreparedStatement, Session, SessionConfig, StateRevision, TransactionId,
63};
64use crate::statement_logging::{StatementEndedExecutionReason, StatementExecutionStrategy};
65use crate::telemetry::{self, EventDetails, SegmentClientExt, StatementFailureType};
66use crate::webhook::AppendWebhookResponse;
67use crate::{AdapterNotice, AppendWebhookError, PeekClient, PeekResponseUnary, StartupResponse};
68
/// A handle to a running coordinator.
///
/// The coordinator runs on its own thread. Dropping the handle will wait for
/// the coordinator's thread to exit, which will only occur after all
/// outstanding [`Client`]s for the coordinator have dropped.
pub struct Handle {
    /// Unique identifier generated at coordinator boot; see [`Handle::session_id`].
    pub(crate) session_id: Uuid,
    /// The instant at which the coordinator booted; see [`Handle::start_instant`].
    pub(crate) start_instant: Instant,
    /// Joined on drop, so dropping this handle blocks until the coordinator
    /// thread exits.
    pub(crate) _thread: JoinOnDropHandle<()>,
}
79
impl Handle {
    /// Returns the session ID associated with this coordinator.
    ///
    /// The session ID is generated on coordinator boot. It lasts for the
    /// lifetime of the coordinator. Restarting the coordinator will result
    /// in a new session ID.
    pub fn session_id(&self) -> Uuid {
        // `Uuid` is `Copy`, so returning by value is cheap.
        self.session_id
    }

    /// Returns the instant at which the coordinator booted.
    pub fn start_instant(&self) -> Instant {
        self.start_instant
    }
}
95
/// A coordinator client.
///
/// A coordinator client is a simple handle to a communication channel with the
/// coordinator. It can be cheaply cloned.
///
/// Clients keep the coordinator alive. The coordinator will not exit until all
/// outstanding clients have dropped.
#[derive(Debug, Clone)]
pub struct Client {
    /// Static build information baked into the binary.
    build_info: &'static BuildInfo,
    /// Channel over which commands, each tagged with the caller's tracing
    /// context, are sent to the coordinator task.
    inner_cmd_tx: mpsc::UnboundedSender<(OpenTelemetryContext, Command)>,
    /// Allocator for connection ids; see [`Client::new_conn_id`].
    id_alloc: IdAllocator<IdAllocatorInnerBitSet>,
    /// Source of wall-clock time.
    now: NowFn,
    /// Metrics for the adapter layer.
    metrics: Metrics,
    /// The environment this client belongs to.
    environment_id: EnvironmentId,
    /// Optional client for emitting telemetry events to Segment.
    segment_client: Option<mz_segment::Client>,
}
113
114impl Client {
115    pub(crate) fn new(
116        build_info: &'static BuildInfo,
117        cmd_tx: mpsc::UnboundedSender<(OpenTelemetryContext, Command)>,
118        metrics: Metrics,
119        now: NowFn,
120        environment_id: EnvironmentId,
121        segment_client: Option<mz_segment::Client>,
122    ) -> Client {
123        // Connection ids are 32 bits and have 3 parts.
124        // 1. MSB bit is always 0 because these are interpreted as an i32, and it is possible some
125        //    driver will not handle a negative id since postgres has never produced one because it
126        //    uses process ids.
127        // 2. Next 12 bits are the lower 12 bits of the org id. This allows balancerd to route
128        //    incoming cancel messages to a subset of the environments.
129        // 3. Last 19 bits are random.
130        let env_lower = org_id_conn_bits(&environment_id.organization_id());
131        Client {
132            build_info,
133            inner_cmd_tx: cmd_tx,
134            id_alloc: IdAllocator::new(1, MAX_ORG_ID, env_lower),
135            now,
136            metrics,
137            environment_id,
138            segment_client,
139        }
140    }
141
142    /// Allocates a client for an incoming connection.
143    pub fn new_conn_id(&self) -> Result<ConnectionId, AdapterError> {
144        self.id_alloc.alloc().ok_or(AdapterError::IdExhaustionError)
145    }
146
147    /// Creates a new session associated with this client for the given user.
148    ///
149    /// It is the caller's responsibility to have authenticated the user.
150    /// We pass in an Authenticated marker as a guardrail to ensure the
151    /// user has authenticated with an authenticator before creating a session.
152    pub fn new_session(&self, config: SessionConfig, _authenticated: Authenticated) -> Session {
153        // We use the system clock to determine when a session connected to Materialize. This is not
154        // intended to be 100% accurate and correct, so we don't burden the timestamp oracle with
155        // generating a more correct timestamp.
156        Session::new(self.build_info, config, self.metrics().session_metrics())
157    }
158
159    /// Verifies the provided user's password against the
160    /// stored credentials in the catalog.
161    pub async fn authenticate(
162        &self,
163        user: &String,
164        password: &Password,
165    ) -> Result<Authenticated, AdapterError> {
166        let (tx, rx) = oneshot::channel();
167        self.send(Command::AuthenticatePassword {
168            role_name: user.to_string(),
169            password: Some(password.clone()),
170            tx,
171        });
172        rx.await.expect("sender dropped")?;
173        Ok(Authenticated)
174    }
175
176    pub async fn generate_sasl_challenge(
177        &self,
178        user: &String,
179        client_nonce: &String,
180    ) -> Result<SASLChallengeResponse, AdapterError> {
181        let (tx, rx) = oneshot::channel();
182        self.send(Command::AuthenticateGetSASLChallenge {
183            role_name: user.to_string(),
184            nonce: client_nonce.to_string(),
185            tx,
186        });
187        let response = rx.await.expect("sender dropped")?;
188        Ok(response)
189    }
190
191    pub async fn verify_sasl_proof(
192        &self,
193        user: &String,
194        proof: &String,
195        nonce: &String,
196        mock_hash: &String,
197    ) -> Result<(SASLVerifyProofResponse, Authenticated), AdapterError> {
198        let (tx, rx) = oneshot::channel();
199        self.send(Command::AuthenticateVerifySASLProof {
200            role_name: user.to_string(),
201            proof: proof.to_string(),
202            auth_message: nonce.to_string(),
203            mock_hash: mock_hash.to_string(),
204            tx,
205        });
206        let response = rx.await.expect("sender dropped")?;
207        Ok((response, Authenticated))
208    }
209
210    /// Upgrades this client to a session client.
211    ///
212    /// A session is a connection that has successfully negotiated parameters,
213    /// like the user. Most coordinator operations are available only after
214    /// upgrading a connection to a session.
215    ///
216    /// Returns a new client that is bound to the session and a response
217    /// containing various details about the startup.
218    #[mz_ore::instrument(level = "debug")]
219    pub async fn startup(&self, session: Session) -> Result<SessionClient, AdapterError> {
220        let user = session.user().clone();
221        let conn_id = session.conn_id().clone();
222        let secret_key = session.secret_key();
223        let uuid = session.uuid();
224        let client_ip = session.client_ip();
225        let application_name = session.application_name().into();
226        let notice_tx = session.retain_notice_transmitter();
227
228        let (tx, rx) = oneshot::channel();
229
230        // ~~SPOOKY ZONE~~
231        //
232        // This guard prevents a race where the startup command finishes, but the Future returned
233        // by this function is concurrently dropped, so we never create a `SessionClient` and thus
234        // never cleanup the initialized Session.
235        let rx = rx.with_guard(|_| {
236            self.send(Command::Terminate {
237                conn_id: conn_id.clone(),
238                tx: None,
239            });
240        });
241
242        self.send(Command::Startup {
243            tx,
244            user,
245            conn_id: conn_id.clone(),
246            secret_key,
247            uuid,
248            client_ip: client_ip.copied(),
249            application_name,
250            notice_tx,
251        });
252
253        // When startup fails, no need to call terminate (handle_startup does this). Delay creating
254        // the client until after startup to sidestep the panic in its `Drop` implementation.
255        let response = rx.await.expect("sender dropped")?;
256
257        // Create the client as soon as startup succeeds (before any await points) so its `Drop` can
258        // handle termination.
259        // Build the PeekClient with controller handles returned from startup.
260        let StartupResponse {
261            role_id,
262            write_notify,
263            session_defaults,
264            catalog,
265            storage_collections,
266            transient_id_gen,
267            optimizer_metrics,
268            persist_client,
269            statement_logging_frontend,
270            superuser_attribute,
271        } = response;
272
273        let peek_client = PeekClient::new(
274            self.clone(),
275            storage_collections,
276            transient_id_gen,
277            optimizer_metrics,
278            persist_client,
279            statement_logging_frontend,
280        );
281
282        let mut client = SessionClient {
283            inner: Some(self.clone()),
284            session: Some(session),
285            timeouts: Timeout::new(),
286            environment_id: self.environment_id.clone(),
287            segment_client: self.segment_client.clone(),
288            peek_client,
289            enable_frontend_peek_sequencing: false, // initialized below, once we have a ConnCatalog
290        };
291
292        let session = client.session();
293
294        // Apply the superuser attribute to the session's user if
295        // it exists.
296        if let SuperuserAttribute(Some(superuser)) = superuser_attribute {
297            session.apply_internal_user_metadata(InternalUserMetadata { superuser });
298        }
299
300        session.initialize_role_metadata(role_id);
301        let vars_mut = session.vars_mut();
302        for (name, val) in session_defaults {
303            if let Err(err) = vars_mut.set_default(&name, val.borrow()) {
304                // Note: erroring here is unexpected, but we don't want to panic if somehow our
305                // assumptions are wrong.
306                tracing::error!("failed to set peristed default, {err:?}");
307            }
308        }
309        session
310            .vars_mut()
311            .end_transaction(EndTransactionAction::Commit);
312
313        // Stash the future that notifies us of builtin table writes completing, we'll block on
314        // this future before allowing queries from this session against relevant relations.
315        //
316        // Note: We stash the future as opposed to waiting on it here to prevent blocking session
317        // creation on builtin table updates. This improves the latency for session creation and
318        // reduces scheduling load on any dataflows that read from these builtin relations, since
319        // it allows updates to be batched.
320        session.set_builtin_table_updates(write_notify);
321
322        let catalog = catalog.for_session(session);
323
324        let cluster_active = session.vars().cluster().to_string();
325        if session.vars().welcome_message() {
326            let cluster_info = if catalog.resolve_cluster(Some(&cluster_active)).is_err() {
327                format!("{cluster_active} (does not exist)")
328            } else {
329                cluster_active.to_string()
330            };
331
332            // Emit a welcome message, optimized for readability by humans using
333            // interactive tools. If you change the message, make sure that it
334            // formats nicely in both `psql` and the console's SQL shell.
335            session.add_notice(AdapterNotice::Welcome(format!(
336                "connected to Materialize v{}
337  Environment ID: {}
338  Region: {}
339  User: {}
340  Cluster: {}
341  Database: {}
342  {}
343  Session UUID: {}
344
345Issue a SQL query to get started. Need help?
346  View documentation: https://materialize.com/s/docs
347  Join our Slack community: https://materialize.com/s/chat
348    ",
349                session.vars().build_info().semver_version(),
350                self.environment_id,
351                self.environment_id.region(),
352                session.vars().user().name,
353                cluster_info,
354                session.vars().database(),
355                match session.vars().search_path() {
356                    [schema] => format!("Schema: {}", schema),
357                    schemas => format!(
358                        "Search path: {}",
359                        schemas.iter().map(|id| id.to_string()).join(", ")
360                    ),
361                },
362                session.uuid(),
363            )));
364        }
365
366        if session.vars().current_object_missing_warnings() {
367            if catalog.active_database().is_none() {
368                let db = session.vars().database().into();
369                session.add_notice(AdapterNotice::UnknownSessionDatabase(db));
370            }
371        }
372
373        // Users stub their toe on their default cluster not existing, so we provide a notice to
374        // help guide them on what do to.
375        let cluster_var = session
376            .vars()
377            .inspect(CLUSTER.name())
378            .expect("cluster should exist");
379        if session.vars().current_object_missing_warnings()
380            && catalog.resolve_cluster(Some(&cluster_active)).is_err()
381        {
382            let cluster_notice = 'notice: {
383                if cluster_var.inspect_session_value().is_some() {
384                    break 'notice Some(AdapterNotice::DefaultClusterDoesNotExist {
385                        name: cluster_active,
386                        kind: "session",
387                        suggested_action: "Pick an extant cluster with SET CLUSTER = name. Run SHOW CLUSTERS to see available clusters.".into(),
388                    });
389                }
390
391                let role_default = catalog.get_role(catalog.active_role_id());
392                let role_cluster = match role_default.vars().get(CLUSTER.name()) {
393                    Some(OwnedVarInput::Flat(name)) => Some(name),
394                    None => None,
395                    // This is unexpected!
396                    Some(v @ OwnedVarInput::SqlSet(_)) => {
397                        tracing::warn!(?v, "SqlSet found for cluster Role Default");
398                        break 'notice None;
399                    }
400                };
401
402                let alter_role = "with `ALTER ROLE <role> SET cluster TO <cluster>;`";
403                match role_cluster {
404                    // If there is no default, suggest a Role default.
405                    None => Some(AdapterNotice::DefaultClusterDoesNotExist {
406                        name: cluster_active,
407                        kind: "system",
408                        suggested_action: format!(
409                            "Set a default cluster for the current role {alter_role}."
410                        ),
411                    }),
412                    // If the default does not exist, suggest to change it.
413                    Some(_) => Some(AdapterNotice::DefaultClusterDoesNotExist {
414                        name: cluster_active,
415                        kind: "role",
416                        suggested_action: format!(
417                            "Change the default cluster for the current role {alter_role}."
418                        ),
419                    }),
420                }
421            };
422
423            if let Some(notice) = cluster_notice {
424                session.add_notice(notice);
425            }
426        }
427
428        client.enable_frontend_peek_sequencing = ENABLE_FRONTEND_PEEK_SEQUENCING
429            .require(catalog.system_vars())
430            .is_ok();
431
432        Ok(client)
433    }
434
435    /// Cancels the query currently running on the specified connection.
436    pub fn cancel_request(&self, conn_id: ConnectionIdType, secret_key: u32) {
437        self.send(Command::CancelRequest {
438            conn_id,
439            secret_key,
440        });
441    }
442
443    /// Executes a single SQL statement that returns rows as the
444    /// `mz_support` user.
445    pub async fn support_execute_one(
446        &self,
447        sql: &str,
448    ) -> Result<Pin<Box<dyn Stream<Item = PeekResponseUnary> + Send>>, anyhow::Error> {
449        // Connect to the coordinator.
450        let conn_id = self.new_conn_id()?;
451        let session = self.new_session(
452            SessionConfig {
453                conn_id,
454                uuid: Uuid::new_v4(),
455                user: SUPPORT_USER.name.clone(),
456                client_ip: None,
457                external_metadata_rx: None,
458                helm_chart_version: None,
459            },
460            Authenticated,
461        );
462        let mut session_client = self.startup(session).await?;
463
464        // Parse the SQL statement.
465        let stmts = mz_sql::parse::parse(sql)?;
466        if stmts.len() != 1 {
467            bail!("must supply exactly one query");
468        }
469        let StatementParseResult { ast: stmt, sql } = stmts.into_element();
470
471        const EMPTY_PORTAL: &str = "";
472        session_client.start_transaction(Some(1))?;
473        session_client
474            .declare(EMPTY_PORTAL.into(), stmt, sql.to_string())
475            .await?;
476
477        match session_client
478            .execute(EMPTY_PORTAL.into(), futures::future::pending(), None)
479            .await?
480        {
481            (ExecuteResponse::SendingRowsStreaming { mut rows, .. }, _) => {
482                // We have to only drop the session client _after_ we read the
483                // result. Otherwise the peek will get cancelled right when we
484                // drop the session client. So we wrap it up in an extra stream
485                // like this, which owns the client and can return it.
486                let owning_response_stream = async_stream::stream! {
487                    while let Some(rows) = rows.next().await {
488                        yield rows;
489                    }
490                    drop(session_client);
491                };
492                Ok(Box::pin(owning_response_stream))
493            }
494            r => bail!("unsupported response type: {r:?}"),
495        }
496    }
497
498    /// Returns the metrics associated with the adapter layer.
499    pub fn metrics(&self) -> &Metrics {
500        &self.metrics
501    }
502
503    /// The current time according to the [`Client`].
504    pub fn now(&self) -> DateTime<Utc> {
505        to_datetime((self.now)())
506    }
507
508    /// Get a metadata and a channel that can be used to append to a webhook source.
509    pub async fn get_webhook_appender(
510        &self,
511        database: String,
512        schema: String,
513        name: String,
514    ) -> Result<AppendWebhookResponse, AppendWebhookError> {
515        let (tx, rx) = oneshot::channel();
516
517        // Send our request.
518        self.send(Command::GetWebhook {
519            database,
520            schema,
521            name,
522            tx,
523        });
524
525        // Using our one shot channel to get the result, returning an error if the sender dropped.
526        let response = rx
527            .await
528            .map_err(|_| anyhow::anyhow!("failed to receive webhook response"))?;
529
530        response
531    }
532
533    /// Gets the current value of all system variables.
534    pub async fn get_system_vars(&self) -> SystemVars {
535        let (tx, rx) = oneshot::channel();
536        self.send(Command::GetSystemVars { tx });
537        rx.await.expect("coordinator unexpectedly gone")
538    }
539
540    #[instrument(level = "debug")]
541    pub(crate) fn send(&self, cmd: Command) {
542        self.inner_cmd_tx
543            .send((OpenTelemetryContext::obtain(), cmd))
544            .expect("coordinator unexpectedly gone");
545    }
546}
547
/// A coordinator client that is bound to a connection.
///
/// See also [`Client`].
pub struct SessionClient {
    // Invariant: inner may only be `None` after the session has been terminated.
    // Once the session is terminated, no communication to the Coordinator
    // should be attempted.
    inner: Option<Client>,
    // Invariant: session may only be `None` during a method call. Every public
    // method must ensure that `Session` is `Some` before it returns.
    session: Option<Session>,
    // Timeout state for this connection (see `Timeout`).
    timeouts: Timeout,
    // Optional client for emitting telemetry events to Segment.
    segment_client: Option<mz_segment::Client>,
    // The environment this connection belongs to.
    environment_id: EnvironmentId,
    /// Client for frontend peek sequencing; populated at connection startup.
    peek_client: PeekClient,
    /// Whether frontend peek sequencing is enabled; initialized at connection startup.
    // TODO(peek-seq): Currently, this is initialized only at session startup. We'll be able to
    // check the actual feature flag value at every peek (without a Coordinator call) once we'll
    // always have a catalog snapshot at hand.
    pub enable_frontend_peek_sequencing: bool,
}
570
571impl SessionClient {
572    /// Parses a SQL expression, reporting failures as a telemetry event if
573    /// possible.
574    pub fn parse<'a>(
575        &self,
576        sql: &'a str,
577    ) -> Result<Result<Vec<StatementParseResult<'a>>, ParserStatementError>, String> {
578        match mz_sql::parse::parse_with_limit(sql) {
579            Ok(Err(e)) => {
580                self.track_statement_parse_failure(&e);
581                Ok(Err(e))
582            }
583            r => r,
584        }
585    }
586
    /// Emits a Segment telemetry event for a statement that failed to parse.
    ///
    /// Best-effort: silently returns unless all required context is available
    /// (external user metadata, a Segment client, a known statement kind, and
    /// an audited action/object-type pair).
    fn track_statement_parse_failure(&self, parse_error: &ParserStatementError) {
        let session = self.session.as_ref().expect("session invariant violated");
        // Only track failures for users with external metadata (i.e., a user id).
        let Some(user_id) = session.user().external_metadata.as_ref().map(|m| m.user_id) else {
            return;
        };
        // Telemetry is optional; without a Segment client there is nowhere to send.
        let Some(segment_client) = &self.segment_client else {
            return;
        };
        // The parser may not be able to identify which statement kind failed.
        let Some(statement_kind) = parse_error.statement else {
            return;
        };
        // Only statement kinds that are audited produce events.
        let Some((action, object_type)) = telemetry::analyze_audited_statement(statement_kind)
        else {
            return;
        };
        let event_type = StatementFailureType::ParseFailure;
        let event_name = format!(
            "{} {} {}",
            object_type.as_title_case(),
            action.as_title_case(),
            event_type.as_title_case(),
        );
        segment_client.environment_track(
            &self.environment_id,
            event_name,
            json!({
                "statement_kind": statement_kind,
                "error": &parse_error.error,
            }),
            EventDetails {
                user_id: Some(user_id),
                application_name: Some(session.application_name()),
                ..Default::default()
            },
        );
    }
623
    // Verify and return the named prepared statement. We need to verify each use
    // to make sure the prepared statement is still safe to use.
    pub async fn get_prepared_statement(
        &mut self,
        name: &str,
    ) -> Result<&PreparedStatement, AdapterError> {
        // Take a fresh catalog snapshot so verification sees any DDL that has
        // happened since the statement was prepared.
        let catalog = self.catalog_snapshot("get_prepared_statement").await;
        Coordinator::verify_prepared_statement(&catalog, self.session(), name)?;
        // Verification succeeded just above, so the statement must exist.
        Ok(self
            .session()
            .get_prepared_statement_unverified(name)
            .expect("must exist"))
    }
637
    /// Saves the parsed statement as a prepared statement.
    ///
    /// The prepared statement is saved in the connection's [`crate::session::Session`]
    /// under the specified name.
    pub async fn prepare(
        &mut self,
        name: String,
        stmt: Option<Statement<Raw>>,
        sql: String,
        param_types: Vec<Option<SqlScalarType>>,
    ) -> Result<(), AdapterError> {
        let catalog = self.catalog_snapshot("prepare").await;

        // Note: This failpoint is used to simulate a request outliving the external connection
        // that made it.
        let mut async_pause = false;
        (|| {
            fail::fail_point!("async_prepare", |val| {
                async_pause = val.map_or(false, |val| val.parse().unwrap_or(false))
            });
        })();
        if async_pause {
            tokio::time::sleep(Duration::from_secs(1)).await;
        };

        // Describe the statement against the catalog snapshot to obtain its
        // result description.
        let desc = Coordinator::describe(&catalog, self.session(), stmt.clone(), param_types)?;
        let now = self.now();
        // Capture the catalog and session revisions at preparation time —
        // presumably consumed by later verification (see
        // `get_prepared_statement`); confirm before relying on this.
        let state_revision = StateRevision {
            catalog_revision: catalog.transient_revision(),
            session_state_revision: self.session().state_revision(),
        };
        self.session()
            .set_prepared_statement(name, stmt, sql, desc, state_revision, now);
        Ok(())
    }
673
    /// Binds a statement to a portal.
    #[mz_ore::instrument(level = "debug")]
    pub async fn declare(
        &mut self,
        name: String,
        stmt: Statement<Raw>,
        sql: String,
    ) -> Result<(), AdapterError> {
        let catalog = self.catalog_snapshot("declare").await;
        // Declared portals carry no parameter types and no parameter bindings.
        let param_types = vec![];
        let desc =
            Coordinator::describe(&catalog, self.session(), Some(stmt.clone()), param_types)?;
        let params = vec![];
        // Every result column is returned in text format.
        let result_formats = vec![mz_pgwire_common::Format::Text; desc.arity()];
        let now = self.now();
        let logging = self.session().mint_logging(sql, Some(&stmt), now);
        // Capture the catalog and session revisions at binding time.
        let state_revision = StateRevision {
            catalog_revision: catalog.transient_revision(),
            session_state_revision: self.session().state_revision(),
        };
        self.session().set_portal(
            name,
            desc,
            Some(stmt),
            logging,
            params,
            result_formats,
            state_revision,
        )?;
        Ok(())
    }
705
    /// Executes a previously-bound portal.
    ///
    /// Note: the provided `cancel_future` must be cancel-safe as it's polled in a `select!` loop.
    ///
    /// `outer_ctx_extra` is Some when we are executing as part of an outer statement, e.g., a FETCH
    /// triggering the execution of the underlying query.
    ///
    /// Returns the execution response along with the instant at which
    /// execution started (for latency accounting by the caller).
    #[mz_ore::instrument(level = "debug")]
    pub async fn execute(
        &mut self,
        portal_name: String,
        cancel_future: impl Future<Output = std::io::Error> + Send,
        outer_ctx_extra: Option<ExecuteContextGuard>,
    ) -> Result<(ExecuteResponse, Instant), AdapterError> {
        let execute_started = Instant::now();

        // Attempt peek sequencing in the session task.
        // If unsupported, fall back to the Coordinator path.
        // TODO(peek-seq): wire up cancel_future
        let mut outer_ctx_extra = outer_ctx_extra;
        if let Some(resp) = self
            .try_frontend_peek(&portal_name, &mut outer_ctx_extra)
            .await?
        {
            debug!("frontend peek succeeded");
            // Frontend peek handled the execution and retired outer_ctx_extra if it existed.
            // No additional work needed here.
            return Ok((resp, execute_started));
        } else {
            debug!("frontend peek did not happen, falling back to `Command::Execute`");
            // If we bailed out, outer_ctx_extra is still present (if it was originally).
            // `Command::Execute` will handle it.
            // (This is not true if we bailed out _after_ the frontend peek sequencing has already
            // begun its own statement logging. That case would be a bug.)
        }

        // Coordinator fallback path: ship the portal to the coordinator and
        // race the response against `cancel_future`.
        let response = self
            .send_with_cancel(
                |tx, session| Command::Execute {
                    portal_name,
                    session,
                    tx,
                    outer_ctx_extra,
                },
                cancel_future,
            )
            .await?;
        Ok((response, execute_started))
    }
754
    /// The current time, in milliseconds since the Unix epoch, according to
    /// the client's [`NowFn`].
    fn now(&self) -> EpochMillis {
        (self.inner().now)()
    }

    /// Like [`SessionClient::now`], but converted to a [`DateTime`].
    fn now_datetime(&self) -> DateTime<Utc> {
        to_datetime(self.now())
    }
762
763    /// Starts a transaction based on implicit:
764    /// - `None`: InTransaction
765    /// - `Some(1)`: Started
766    /// - `Some(n > 1)`: InTransactionImplicit
767    /// - `Some(0)`: no change
768    pub fn start_transaction(&mut self, implicit: Option<usize>) -> Result<(), AdapterError> {
769        let now = self.now_datetime();
770        let session = self.session.as_mut().expect("session invariant violated");
771        let result = match implicit {
772            None => session.start_transaction(now, None, None),
773            Some(stmts) => {
774                session.start_transaction_implicit(now, stmts);
775                Ok(())
776            }
777        };
778        result
779    }
780
    /// Ends a transaction. Even if an error is returned, guarantees that the transaction in the
    /// session and Coordinator has cleared its state.
    #[instrument(level = "debug")]
    pub async fn end_transaction(
        &mut self,
        action: EndTransactionAction,
    ) -> Result<ExecuteResponse, AdapterError> {
        // Ask the Coordinator to commit or roll back per `action`.
        let res = self
            .send(|tx, session| Command::Commit {
                action,
                session,
                tx,
            })
            .await;
        // Commit isn't guaranteed to set the session's state to anything specific, so clear it
        // here. It's safe to ignore the returned `TransactionStatus` because that doesn't contain
        // any data that the Coordinator must act on for correctness.
        let _ = self.session().clear_transaction();
        res
    }
801
802    /// Fails a transaction.
803    pub fn fail_transaction(&mut self) {
804        let session = self.session.take().expect("session invariant violated");
805        let session = session.fail_transaction();
806        self.session = Some(session);
807    }
808
809    /// Fetches the catalog.
810    #[instrument(level = "debug")]
811    pub async fn catalog_snapshot(&self, context: &str) -> Arc<Catalog> {
812        let start = std::time::Instant::now();
813        let CatalogSnapshot { catalog } = self
814            .send_without_session(|tx| Command::CatalogSnapshot { tx })
815            .await;
816        self.inner()
817            .metrics()
818            .catalog_snapshot_seconds
819            .with_label_values(&[context])
820            .observe(start.elapsed().as_secs_f64());
821        catalog
822    }
823
824    /// Dumps the catalog to a JSON string.
825    ///
826    /// No authorization is performed, so access to this function must be limited to internal
827    /// servers or superusers.
828    pub async fn dump_catalog(&self) -> Result<CatalogDump, AdapterError> {
829        let catalog = self.catalog_snapshot("dump_catalog").await;
830        catalog.dump().map_err(AdapterError::from)
831    }
832
833    /// Checks the catalog for internal consistency, returning a JSON object describing the
834    /// inconsistencies, if there are any.
835    ///
836    /// No authorization is performed, so access to this function must be limited to internal
837    /// servers or superusers.
838    pub async fn check_catalog(&self) -> Result<(), serde_json::Value> {
839        let catalog = self.catalog_snapshot("check_catalog").await;
840        catalog.check_consistency()
841    }
842
843    /// Checks the coordinator for internal consistency, returning a JSON object describing the
844    /// inconsistencies, if there are any. This is a superset of checks that check_catalog performs,
845    ///
846    /// No authorization is performed, so access to this function must be limited to internal
847    /// servers or superusers.
848    pub async fn check_coordinator(&self) -> Result<(), serde_json::Value> {
849        self.send_without_session(|tx| Command::CheckConsistency { tx })
850            .await
851            .map_err(|inconsistencies| {
852                serde_json::to_value(inconsistencies).unwrap_or_else(|_| {
853                    serde_json::Value::String("failed to serialize inconsistencies".to_string())
854                })
855            })
856    }
857
    /// Dumps the Coordinator's state as a JSON value.
    pub async fn dump_coordinator_state(&self) -> Result<serde_json::Value, anyhow::Error> {
        self.send_without_session(|tx| Command::Dump { tx }).await
    }
861
862    /// Tells the coordinator a statement has finished execution, in the cases
863    /// where we have no other reason to communicate with the coordinator.
864    pub fn retire_execute(
865        &self,
866        guard: ExecuteContextGuard,
867        reason: StatementEndedExecutionReason,
868    ) {
869        if !guard.is_trivial() {
870            let data = guard.defuse();
871            let cmd = Command::RetireExecute { data, reason };
872            self.inner().send(cmd);
873        }
874    }
875
    /// Sets up a streaming COPY FROM STDIN operation.
    ///
    /// Sends a command to the coordinator to create a background batch
    /// builder task. Returns a [`CopyFromStdinWriter`] that pgwire uses
    /// to stream decoded rows.
    ///
    /// * `target_id`/`target_name` identify the table being copied into.
    /// * `columns` lists the column indexes the copy targets.
    /// * `row_desc` describes the relation of the decoded rows.
    /// * `params` carries the COPY format options.
    pub async fn start_copy_from_stdin(
        &mut self,
        target_id: CatalogItemId,
        target_name: String,
        columns: Vec<ColumnIndex>,
        row_desc: mz_repr::RelationDesc,
        params: mz_pgcopy::CopyFormatParams<'static>,
    ) -> Result<CopyFromStdinWriter, AdapterError> {
        self.send(|tx, session| Command::StartCopyFromStdin {
            target_id,
            target_name,
            columns,
            row_desc,
            params,
            session,
            tx,
        })
        .await
    }
900
    /// Stages COPY FROM STDIN batches for commit to a table.
    ///
    /// Adds the pre-built persist batches to the session's transaction
    /// operations. The actual commit happens when the transaction ends.
    pub fn stage_copy_from_stdin_batches(
        &mut self,
        target_id: CatalogItemId,
        batches: Vec<mz_persist_client::batch::ProtoBatch>,
    ) -> Result<(), AdapterError> {
        use crate::session::{TransactionOps, WriteOp};
        use mz_storage_client::client::TableData;

        // Record the batches as a write op on the session's transaction; any
        // error from `add_transaction_ops` is propagated to the caller.
        self.session()
            .add_transaction_ops(TransactionOps::Writes(vec![WriteOp {
                id: target_id,
                rows: TableData::Batches(batches.into()),
            }]))?;
        Ok(())
    }
920
    /// Gets the current value of all system variables.
    ///
    /// Delegates directly to the inner client.
    pub async fn get_system_vars(&self) -> SystemVars {
        self.inner().get_system_vars().await
    }
925
926    /// Updates the specified system variables to the specified values.
927    pub async fn set_system_vars(
928        &mut self,
929        vars: BTreeMap<String, String>,
930    ) -> Result<(), AdapterError> {
931        let conn_id = self.session().conn_id().clone();
932        self.send_without_session(|tx| Command::SetSystemVars { vars, conn_id, tx })
933            .await
934    }
935
    /// Terminates the client session.
    ///
    /// After this returns, `self.inner` is `None`, so any further call that
    /// goes through `inner()` will panic.
    pub async fn terminate(&mut self) {
        let conn_id = self.session().conn_id().clone();
        let res = self
            .send_without_session(|tx| Command::Terminate {
                conn_id,
                tx: Some(tx),
            })
            .await;
        if let Err(e) = res {
            // Nothing we can do to handle a failed terminate so we just log and ignore it.
            error!("Unable to terminate session: {e:?}");
        }
        // Prevent any communication with Coordinator after session is terminated.
        self.inner = None;
    }
952
953    /// Returns a mutable reference to the session bound to this client.
954    pub fn session(&mut self) -> &mut Session {
955        self.session.as_mut().expect("session invariant violated")
956    }
957
958    /// Returns a reference to the inner client.
959    pub fn inner(&self) -> &Client {
960        self.inner.as_ref().expect("inner invariant violated")
961    }
962
963    async fn send_without_session<T, F>(&self, f: F) -> T
964    where
965        F: FnOnce(oneshot::Sender<T>) -> Command,
966    {
967        let (tx, rx) = oneshot::channel();
968        self.inner().send(f(tx));
969        rx.await.expect("sender dropped")
970    }
971
    /// Sends a [`Command`] to the Coordinator without the ability to cancel
    /// it; equivalent to `send_with_cancel` with a cancellation future that
    /// never completes.
    #[instrument(level = "debug")]
    async fn send<T, F>(&mut self, f: F) -> Result<T, AdapterError>
    where
        F: FnOnce(oneshot::Sender<Response<T>>, Session) -> Command,
    {
        // `pending()` never resolves, so the cancel branch never fires.
        self.send_with_cancel(f, futures::future::pending()).await
    }
979
    /// Send a [`Command`] to the Coordinator, with the ability to cancel the command.
    ///
    /// Note: the provided `cancel_future` must be cancel-safe as it's polled in a `select!` loop.
    #[instrument(level = "debug")]
    async fn send_with_cancel<T, F>(
        &mut self,
        f: F,
        cancel_future: impl Future<Output = std::io::Error> + Send,
    ) -> Result<T, AdapterError>
    where
        F: FnOnce(oneshot::Sender<Response<T>>, Session) -> Command,
    {
        // The session travels to the Coordinator inside the command and comes
        // back to us inside the `Response` on `rx`.
        let session = self.session.take().expect("session invariant violated");
        // `typ` labels the command in the metrics below; only some command
        // kinds are measured.
        let mut typ = None;
        let application_name = session.application_name();
        let name_hint = ApplicationNameHint::from_str(application_name);
        let conn_id = session.conn_id().clone();
        let (tx, rx) = oneshot::channel();

        // Destructure self so we can hold a mutable reference to the inner client and session at
        // the same time.
        let Self {
            inner: inner_client,
            session: client_session,
            ..
        } = self;

        // TODO(parkmycar): Leaking this invariant here doesn't feel great, but calling
        // `self.client()` doesn't work because then Rust takes a borrow on the entirity of self.
        let inner_client = inner_client.as_ref().expect("inner invariant violated");

        // ~~SPOOKY ZONE~~
        //
        // This guard prevents a race where a `Session` is returned on `rx` but never placed
        // back in `self` because the Future returned by this function is concurrently dropped
        // with the Coordinator sending a response.
        let mut guarded_rx = rx.with_guard(|response: Response<_>| {
            *client_session = Some(response.session);
        });

        inner_client.send({
            let cmd = f(tx, session);
            // Measure the success and error rate of certain commands:
            // - declare reports success of SQL statement planning
            // - execute reports success of dataflow execution
            match cmd {
                Command::Execute { .. } => typ = Some("execute"),
                Command::GetWebhook { .. } => typ = Some("webhook"),
                Command::StartCopyFromStdin { .. }
                | Command::Startup { .. }
                | Command::AuthenticatePassword { .. }
                | Command::AuthenticateGetSASLChallenge { .. }
                | Command::AuthenticateVerifySASLProof { .. }
                | Command::CatalogSnapshot { .. }
                | Command::Commit { .. }
                | Command::CancelRequest { .. }
                | Command::PrivilegedCancelRequest { .. }
                | Command::GetSystemVars { .. }
                | Command::SetSystemVars { .. }
                | Command::Terminate { .. }
                | Command::RetireExecute { .. }
                | Command::CheckConsistency { .. }
                | Command::Dump { .. }
                | Command::GetComputeInstanceClient { .. }
                | Command::GetOracle { .. }
                | Command::DetermineRealTimeRecentTimestamp { .. }
                | Command::GetTransactionReadHoldsBundle { .. }
                | Command::StoreTransactionReadHolds { .. }
                | Command::ExecuteSlowPathPeek { .. }
                | Command::CopyToPreflight { .. }
                | Command::ExecuteCopyTo { .. }
                | Command::ExecuteSideEffectingFunc { .. }
                | Command::RegisterFrontendPeek { .. }
                | Command::UnregisterFrontendPeek { .. }
                | Command::ExplainTimestamp { .. }
                | Command::FrontendStatementLogging(..) => {}
            };
            cmd
        });

        // Race the Coordinator's response against the cancellation signal.
        let mut cancel_future = pin::pin!(cancel_future);
        let mut cancelled = false;
        loop {
            tokio::select! {
                res = &mut guarded_rx => {
                    // We received a result, so drop our guard to drop our borrows.
                    drop(guarded_rx);

                    let res = res.expect("sender dropped");
                    let status = res.result.is_ok().then_some("success").unwrap_or("error");
                    if let Err(err) = res.result.as_ref() {
                        if name_hint.should_trace_errors() {
                            tracing::warn!(?err, ?name_hint, "adapter response error");
                        }
                    }

                    if let Some(typ) = typ {
                        inner_client
                            .metrics
                            .commands
                            .with_label_values(&[typ, status, name_hint.as_str()])
                            .inc();
                    }
                    // Put the session back so the client is usable again.
                    *client_session = Some(res.session);
                    return res.result;
                },
                // The `if !cancelled` guard ensures we issue at most one
                // cancel request; we then keep looping because the response
                // to the original command (carrying our session) still
                // arrives on `guarded_rx`.
                _err = &mut cancel_future, if !cancelled => {
                    cancelled = true;
                    inner_client.send(Command::PrivilegedCancelRequest {
                        conn_id: conn_id.clone(),
                    });
                }
            };
        }
    }
1095
1096    pub fn add_idle_in_transaction_session_timeout(&mut self) {
1097        let session = self.session();
1098        let timeout_dur = session.vars().idle_in_transaction_session_timeout();
1099        if !timeout_dur.is_zero() {
1100            let timeout_dur = timeout_dur.clone();
1101            if let Some(txn) = session.transaction().inner() {
1102                let txn_id = txn.id.clone();
1103                let timeout = TimeoutType::IdleInTransactionSession(txn_id);
1104                self.timeouts.add_timeout(timeout, timeout_dur);
1105            }
1106        }
1107    }
1108
1109    pub fn remove_idle_in_transaction_session_timeout(&mut self) {
1110        let session = self.session();
1111        if let Some(txn) = session.transaction().inner() {
1112            let txn_id = txn.id.clone();
1113            self.timeouts
1114                .remove_timeout(&TimeoutType::IdleInTransactionSession(txn_id));
1115        }
1116    }
1117
    /// Waits for the next session timeout to fire.
    ///
    /// # Cancel safety
    ///
    /// This method is cancel safe. If `recv` is used as the event in a
    /// `tokio::select!` statement and some other branch
    /// completes first, it is guaranteed that no messages were received on this
    /// channel.
    pub async fn recv_timeout(&mut self) -> Option<TimeoutType> {
        self.timeouts.recv().await
    }
1127
    /// Returns a shared reference to the [`PeekClient`] used for frontend peek sequencing.
    pub fn peek_client(&self) -> &PeekClient {
        &self.peek_client
    }
1132
    /// Returns a mutable reference to the [`PeekClient`] used for frontend peek sequencing.
    pub fn peek_client_mut(&mut self) -> &mut PeekClient {
        &mut self.peek_client
    }
1137
1138    /// Attempt to sequence a peek from the session task.
1139    ///
1140    /// Returns `Ok(Some(response))` if we handled the peek, or `Ok(None)` to fall back to the
1141    /// Coordinator's sequencing. If it returns an error, it should be returned to the user.
1142    ///
1143    /// `outer_ctx_extra` is Some when we are executing as part of an outer statement, e.g., a FETCH
1144    /// triggering the execution of the underlying query.
1145    pub(crate) async fn try_frontend_peek(
1146        &mut self,
1147        portal_name: &str,
1148        outer_ctx_extra: &mut Option<ExecuteContextGuard>,
1149    ) -> Result<Option<ExecuteResponse>, AdapterError> {
1150        if self.enable_frontend_peek_sequencing {
1151            let session = self.session.as_mut().expect("SessionClient invariant");
1152            self.peek_client
1153                .try_frontend_peek(portal_name, session, outer_ctx_extra)
1154                .await
1155        } else {
1156            Ok(None)
1157        }
1158    }
1159}
1160
1161impl Drop for SessionClient {
1162    fn drop(&mut self) {
1163        // We may not have a session if this client was dropped while awaiting
1164        // a response. In this case, it is the coordinator's responsibility to
1165        // terminate the session.
1166        if let Some(session) = self.session.take() {
1167            // We may not have a connection to the Coordinator if the session was
1168            // prematurely terminated, for example due to a timeout.
1169            if let Some(inner) = &self.inner {
1170                inner.send(Command::Terminate {
1171                    conn_id: session.conn_id().clone(),
1172                    tx: None,
1173                })
1174            }
1175        }
1176    }
1177}
1178
/// The kinds of session-level timeouts that can fire.
#[derive(Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Debug)]
pub enum TimeoutType {
    /// The session sat idle inside the identified transaction longer than the
    /// `idle_in_transaction_session_timeout` session variable allows.
    IdleInTransactionSession(TransactionId),
}
1183
1184impl Display for TimeoutType {
1185    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1186        match self {
1187            TimeoutType::IdleInTransactionSession(txn_id) => {
1188                writeln!(f, "Idle in transaction session for transaction '{txn_id}'")
1189            }
1190        }
1191    }
1192}
1193
impl From<TimeoutType> for AdapterError {
    /// Maps a fired timeout to the error reported to the client.
    fn from(timeout: TimeoutType) -> Self {
        // Exhaustive match: adding a `TimeoutType` variant forces a decision
        // about its corresponding error here.
        match timeout {
            TimeoutType::IdleInTransactionSession(_) => {
                AdapterError::IdleInTransactionSessionTimeout
            }
        }
    }
}
1203
/// Tracks pending session timeouts and delivers fired ones over a channel.
struct Timeout {
    // Sender half; cloned into each spawned timeout task.
    tx: mpsc::UnboundedSender<TimeoutType>,
    // Receiver polled by `recv` to observe fired timeouts.
    rx: mpsc::UnboundedReceiver<TimeoutType>,
    // Sleeper tasks for timeouts that have not yet fired; dropping a handle
    // aborts the corresponding task.
    active_timeouts: BTreeMap<TimeoutType, AbortOnDropHandle<()>>,
}
1209
impl Timeout {
    /// Creates an empty timeout registry with a fresh notification channel.
    fn new() -> Self {
        let (tx, rx) = mpsc::unbounded_channel();
        Timeout {
            tx,
            rx,
            active_timeouts: BTreeMap::new(),
        }
    }

    /// Waits for the next fired timeout.
    ///
    /// # Cancel safety
    ///
    /// This method is cancel safe. If `recv` is used as the event in a
    /// `tokio::select!` statement and some other branch
    /// completes first, it is guaranteed that no messages were received on this
    /// channel.
    ///
    /// <https://docs.rs/tokio/latest/tokio/sync/mpsc/struct.UnboundedReceiver.html#cancel-safety>
    async fn recv(&mut self) -> Option<TimeoutType> {
        self.rx.recv().await
    }

    /// Schedules `timeout` to fire after `duration` by spawning a sleeper
    /// task that sends it on `tx`. Inserting over an existing entry for the
    /// same timeout drops the old handle, aborting the previous sleeper.
    fn add_timeout(&mut self, timeout: TimeoutType, duration: Duration) {
        let tx = self.tx.clone();
        let timeout_key = timeout.clone();
        // `abort_on_drop` ties the sleeper's lifetime to its entry in
        // `active_timeouts`, so removing the entry cancels the timeout.
        let handle = mz_ore::task::spawn(|| format!("{timeout_key}"), async move {
            tokio::time::sleep(duration).await;
            // Ignore send errors: the receiver may be gone if the whole
            // `Timeout` struct was dropped.
            let _ = tx.send(timeout);
        })
        .abort_on_drop();
        self.active_timeouts.insert(timeout_key, handle);
    }

    /// Cancels a pending timeout: aborts its sleeper task (if it hasn't fired
    /// yet) and scrubs any already-fired notification from the channel.
    fn remove_timeout(&mut self, timeout: &TimeoutType) {
        self.active_timeouts.remove(timeout);

        // Remove the timeout from the rx queue if it exists. The channel has
        // no random access, so drain everything and re-send the survivors.
        let mut timeouts = Vec::new();
        while let Ok(pending_timeout) = self.rx.try_recv() {
            if timeout != &pending_timeout {
                timeouts.push(pending_timeout);
            }
        }
        for pending_timeout in timeouts {
            self.tx.send(pending_timeout).expect("rx is in this struct");
        }
    }
}
1258
/// A wrapper around a Stream of [`PeekResponseUnary`] that records when it sees the
/// first row data in the given histogram. It also keeps track of whether we have already observed
/// the end of the underlying stream.
#[derive(Derivative)]
#[derivative(Debug)]
pub struct RecordFirstRowStream {
    /// The underlying stream of rows.
    // Excluded from `Debug`: a boxed trait object has no useful representation.
    #[derivative(Debug = "ignore")]
    pub rows: Box<dyn Stream<Item = PeekResponseUnary> + Unpin + Send + Sync>,
    /// The Instant when execution started.
    pub execute_started: Instant,
    /// The histogram where the time since `execute_started` will be recorded when we see the first
    /// row.
    pub time_to_first_row_seconds: Histogram,
    /// Whether we've seen any rows.
    pub saw_rows: bool,
    /// The Instant when we saw the first row, if we have seen one.
    pub recorded_first_row_instant: Option<Instant>,
    /// Whether we have already observed the end of the underlying stream.
    pub no_more_rows: bool,
}
1280
impl RecordFirstRowStream {
    /// Create a new [`RecordFirstRowStream`], resolving the histogram to
    /// record into from `client`, `instance_id`, and `strategy`.
    pub fn new(
        rows: Box<dyn Stream<Item = PeekResponseUnary> + Unpin + Send + Sync>,
        execute_started: Instant,
        client: &SessionClient,
        instance_id: Option<ComputeInstanceId>,
        strategy: Option<StatementExecutionStrategy>,
    ) -> Self {
        let histogram = Self::histogram(client, instance_id, strategy);
        Self {
            rows,
            execute_started,
            time_to_first_row_seconds: histogram,
            saw_rows: false,
            recorded_first_row_instant: None,
            no_more_rows: false,
        }
    }

    /// Looks up the `time_to_first_row_seconds` histogram labeled by compute
    /// instance, the session's isolation level, and execution strategy. A
    /// missing instance or strategy is labeled `"none"`.
    fn histogram(
        client: &SessionClient,
        instance_id: Option<ComputeInstanceId>,
        strategy: Option<StatementExecutionStrategy>,
    ) -> Histogram {
        let isolation_level = *client
            .session
            .as_ref()
            .expect("session invariant")
            .vars()
            .transaction_isolation();
        let instance = match instance_id {
            Some(i) => Cow::Owned(i.to_string()),
            None => Cow::Borrowed("none"),
        };
        let strategy = match strategy {
            Some(s) => s.name(),
            None => "none",
        };

        client
            .inner()
            .metrics()
            .time_to_first_row_seconds
            .with_label_values(&[instance.as_ref(), isolation_level.as_str(), strategy])
    }

    /// If you want to match [`RecordFirstRowStream`]'s logic but don't have
    /// an actual stream to wrap, you can tell it when to record an
    /// observation.
    pub fn record(
        execute_started: Instant,
        client: &SessionClient,
        instance_id: Option<ComputeInstanceId>,
        strategy: Option<StatementExecutionStrategy>,
    ) {
        Self::histogram(client, instance_id, strategy)
            .observe(execute_started.elapsed().as_secs_f64());
    }

    /// Returns the next response from the stream, observing the histogram on
    /// the first `Rows` response and setting `no_more_rows` once the stream
    /// is exhausted.
    pub async fn recv(&mut self) -> Option<PeekResponseUnary> {
        let msg = self.rows.next().await;
        if !self.saw_rows && matches!(msg, Some(PeekResponseUnary::Rows(_))) {
            self.saw_rows = true;
            self.time_to_first_row_seconds
                .observe(self.execute_started.elapsed().as_secs_f64());
            self.recorded_first_row_instant = Some(Instant::now());
        }
        if msg.is_none() {
            self.no_more_rows = true;
        }
        msg
    }
}