mz_adapter/client.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::borrow::Cow;
11use std::collections::BTreeMap;
12use std::fmt::{Debug, Display, Formatter};
13use std::future::Future;
14use std::pin::{self, Pin};
15use std::sync::Arc;
16use std::time::{Duration, Instant};
17
18use anyhow::bail;
19use chrono::{DateTime, Utc};
20use derivative::Derivative;
21use futures::{Stream, StreamExt};
22use itertools::Itertools;
23use mz_adapter_types::connection::{ConnectionId, ConnectionIdType};
24use mz_auth::password::Password;
25use mz_build_info::BuildInfo;
26use mz_compute_types::ComputeInstanceId;
27use mz_ore::channel::OneshotReceiverExt;
28use mz_ore::collections::CollectionExt;
29use mz_ore::id_gen::{IdAllocator, IdAllocatorInnerBitSet, MAX_ORG_ID, org_id_conn_bits};
30use mz_ore::instrument;
31use mz_ore::now::{EpochMillis, NowFn, to_datetime};
32use mz_ore::result::ResultExt;
33use mz_ore::task::AbortOnDropHandle;
34use mz_ore::thread::JoinOnDropHandle;
35use mz_ore::tracing::OpenTelemetryContext;
36use mz_repr::{CatalogItemId, ColumnIndex, Row, SqlScalarType};
37use mz_sql::ast::{Raw, Statement};
38use mz_sql::catalog::{EnvironmentId, SessionCatalog};
39use mz_sql::session::hint::ApplicationNameHint;
40use mz_sql::session::metadata::SessionMetadata;
41use mz_sql::session::user::SUPPORT_USER;
42use mz_sql::session::vars::{CLUSTER, OwnedVarInput, SystemVars, Var};
43use mz_sql_parser::parser::{ParserStatementError, StatementParseResult};
44use prometheus::Histogram;
45use serde_json::json;
46use tokio::sync::{mpsc, oneshot};
47use tracing::error;
48use uuid::Uuid;
49
50use crate::catalog::Catalog;
51use crate::command::{
52    AuthResponse, CatalogDump, CatalogSnapshot, Command, ExecuteResponse, Response,
53    SASLChallengeResponse, SASLVerifyProofResponse,
54};
55use crate::coord::{Coordinator, ExecuteContextExtra};
56use crate::error::AdapterError;
57use crate::metrics::Metrics;
58use crate::optimize::dataflows::{EvalTime, ExprPrepStyle};
59use crate::optimize::{self, Optimize};
60use crate::session::{
61    EndTransactionAction, PreparedStatement, Session, SessionConfig, StateRevision, TransactionId,
62};
63use crate::statement_logging::{StatementEndedExecutionReason, StatementExecutionStrategy};
64use crate::telemetry::{self, EventDetails, SegmentClientExt, StatementFailureType};
65use crate::webhook::AppendWebhookResponse;
66use crate::{AdapterNotice, AppendWebhookError, PeekResponseUnary, StartupResponse};
67
68/// A handle to a running coordinator.
69///
70/// The coordinator runs on its own thread. Dropping the handle will wait for
71/// the coordinator's thread to exit, which will only occur after all
72/// outstanding [`Client`]s for the coordinator have dropped.
73pub struct Handle {
74    pub(crate) session_id: Uuid,
75    pub(crate) start_instant: Instant,
76    pub(crate) _thread: JoinOnDropHandle<()>,
77}
78
79impl Handle {
80    /// Returns the session ID associated with this coordinator.
81    ///
82    /// The session ID is generated on coordinator boot. It lasts for the
83    /// lifetime of the coordinator. Restarting the coordinator will result
84    /// in a new session ID.
85    pub fn session_id(&self) -> Uuid {
86        self.session_id
87    }
88
89    /// Returns the instant at which the coordinator booted.
90    pub fn start_instant(&self) -> Instant {
91        self.start_instant
92    }
93}
94
95/// A coordinator client.
96///
97/// A coordinator client is a simple handle to a communication channel with the
98/// coordinator. It can be cheaply cloned.
99///
100/// Clients keep the coordinator alive. The coordinator will not exit until all
101/// outstanding clients have dropped.
102#[derive(Debug, Clone)]
103pub struct Client {
104    build_info: &'static BuildInfo,
105    inner_cmd_tx: mpsc::UnboundedSender<(OpenTelemetryContext, Command)>,
106    id_alloc: IdAllocator<IdAllocatorInnerBitSet>,
107    now: NowFn,
108    metrics: Metrics,
109    environment_id: EnvironmentId,
110    segment_client: Option<mz_segment::Client>,
111}
112
113impl Client {
114    pub(crate) fn new(
115        build_info: &'static BuildInfo,
116        cmd_tx: mpsc::UnboundedSender<(OpenTelemetryContext, Command)>,
117        metrics: Metrics,
118        now: NowFn,
119        environment_id: EnvironmentId,
120        segment_client: Option<mz_segment::Client>,
121    ) -> Client {
122        // Connection ids are 32 bits and have 3 parts.
123        // 1. MSB bit is always 0 because these are interpreted as an i32, and it is possible some
124        //    driver will not handle a negative id since postgres has never produced one because it
125        //    uses process ids.
126        // 2. Next 12 bits are the lower 12 bits of the org id. This allows balancerd to route
127        //    incoming cancel messages to a subset of the environments.
128        // 3. Last 19 bits are random.
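        //
        // Illustrative layout of the resulting 32-bit id (bit 31 is the most
        // significant bit):
        //
        //   bit 31     | bits 30..=19             | bits 18..=0
        //   always 0   | lower 12 bits of org id  | 19 random bits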
129        let env_lower = org_id_conn_bits(&environment_id.organization_id());
130        Client {
131            build_info,
132            inner_cmd_tx: cmd_tx,
133            id_alloc: IdAllocator::new(1, MAX_ORG_ID, env_lower),
134            now,
135            metrics,
136            environment_id,
137            segment_client,
138        }
139    }
140
141    /// Allocates a new connection ID for an incoming connection.
142    pub fn new_conn_id(&self) -> Result<ConnectionId, AdapterError> {
143        self.id_alloc.alloc().ok_or(AdapterError::IdExhaustionError)
144    }
145
146    /// Creates a new session associated with this client for the given user.
147    ///
148    /// It is the caller's responsibility to have authenticated the user.
149    pub fn new_session(&self, config: SessionConfig) -> Session {
150        // We use the system clock to determine when a session connected to Materialize. This is not
151        // intended to be 100% accurate and correct, so we don't burden the timestamp oracle with
152        // generating a more correct timestamp.
153        Session::new(self.build_info, config, self.metrics().session_metrics())
154    }
155
156    /// Performs an authentication check for the given user.
157    pub async fn authenticate(
158        &self,
159        user: &String,
160        password: &Password,
161    ) -> Result<AuthResponse, AdapterError> {
162        let (tx, rx) = oneshot::channel();
163        self.send(Command::AuthenticatePassword {
164            role_name: user.to_string(),
165            password: Some(password.clone()),
166            tx,
167        });
168        let response = rx.await.expect("sender dropped")?;
169        Ok(response)
170    }
171
172    pub async fn generate_sasl_challenge(
173        &self,
174        user: &String,
175        client_nonce: &String,
176    ) -> Result<SASLChallengeResponse, AdapterError> {
177        let (tx, rx) = oneshot::channel();
178        self.send(Command::AuthenticateGetSASLChallenge {
179            role_name: user.to_string(),
180            nonce: client_nonce.to_string(),
181            tx,
182        });
183        let response = rx.await.expect("sender dropped")?;
184        Ok(response)
185    }
186
187    pub async fn verify_sasl_proof(
188        &self,
189        user: &String,
190        proof: &String,
191        nonce: &String,
192        mock_hash: &String,
193    ) -> Result<SASLVerifyProofResponse, AdapterError> {
194        let (tx, rx) = oneshot::channel();
195        self.send(Command::AuthenticateVerifySASLProof {
196            role_name: user.to_string(),
197            proof: proof.to_string(),
198            auth_message: nonce.to_string(),
199            mock_hash: mock_hash.to_string(),
200            tx,
201        });
202        let response = rx.await.expect("sender dropped")?;
203        Ok(response)
204    }
205
206    /// Upgrades this client to a session client.
207    ///
208    /// A session is a connection that has successfully negotiated parameters,
209    /// like the user. Most coordinator operations are available only after
210    /// upgrading a connection to a session.
211    ///
212    /// Returns a new client that is bound to the session. The startup response
213    /// received from the coordinator is applied to the session before returning.
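    ///
    /// A minimal usage sketch (not compiled; the remaining `SessionConfig`
    /// fields are elided):
    ///
    /// ```ignore
    /// let conn_id = client.new_conn_id()?;
    /// let session = client.new_session(SessionConfig { conn_id, /* ... */ });
    /// let session_client = client.startup(session).await?;
    /// ```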
214    #[mz_ore::instrument(level = "debug")]
215    pub async fn startup(&self, session: Session) -> Result<SessionClient, AdapterError> {
216        let user = session.user().clone();
217        let conn_id = session.conn_id().clone();
218        let secret_key = session.secret_key();
219        let uuid = session.uuid();
220        let client_ip = session.client_ip();
221        let application_name = session.application_name().into();
222        let notice_tx = session.retain_notice_transmitter();
223
224        let (tx, rx) = oneshot::channel();
225
226        // ~~SPOOKY ZONE~~
227        //
228        // This guard prevents a race where the startup command finishes, but the Future returned
229        // by this function is concurrently dropped, so we never create a `SessionClient` and thus
230        // never clean up the initialized Session.
231        let rx = rx.with_guard(|_| {
232            self.send(Command::Terminate {
233                conn_id: conn_id.clone(),
234                tx: None,
235            });
236        });
237
238        self.send(Command::Startup {
239            tx,
240            user,
241            conn_id: conn_id.clone(),
242            secret_key,
243            uuid,
244            client_ip: client_ip.copied(),
245            application_name,
246            notice_tx,
247        });
248
249        // When startup fails, no need to call terminate (handle_startup does this). Delay creating
250        // the client until after startup to sidestep the panic in its `Drop` implementation.
251        let response = rx.await.expect("sender dropped")?;
252
253        // Create the client as soon as startup succeeds (before any await points) so its `Drop` can
254        // handle termination.
255        let mut client = SessionClient {
256            inner: Some(self.clone()),
257            session: Some(session),
258            timeouts: Timeout::new(),
259            environment_id: self.environment_id.clone(),
260            segment_client: self.segment_client.clone(),
261        };
262
263        let StartupResponse {
264            role_id,
265            write_notify,
266            session_defaults,
267            catalog,
268        } = response;
269
270        let session = client.session();
271        session.initialize_role_metadata(role_id);
272        let vars_mut = session.vars_mut();
273        for (name, val) in session_defaults {
274            if let Err(err) = vars_mut.set_default(&name, val.borrow()) {
275                // Note: erroring here is unexpected, but we don't want to panic if somehow our
276                // assumptions are wrong.
277                tracing::error!("failed to set persisted default, {err:?}");
278            }
279        }
280        session
281            .vars_mut()
282            .end_transaction(EndTransactionAction::Commit);
283
284        // Stash the future that notifies us of builtin table writes completing; we'll block on
285        // this future before allowing queries from this session against relevant relations.
286        //
287        // Note: We stash the future as opposed to waiting on it here to prevent blocking session
288        // creation on builtin table updates. This improves the latency for session creation and
289        // reduces scheduling load on any dataflows that read from these builtin relations, since
290        // it allows updates to be batched.
291        session.set_builtin_table_updates(write_notify);
292
293        let catalog = catalog.for_session(session);
294
295        let cluster_active = session.vars().cluster().to_string();
296        if session.vars().welcome_message() {
297            let cluster_info = if catalog.resolve_cluster(Some(&cluster_active)).is_err() {
298                format!("{cluster_active} (does not exist)")
299            } else {
300                cluster_active.to_string()
301            };
302
303            // Emit a welcome message, optimized for readability by humans using
304            // interactive tools. If you change the message, make sure that it
305            // formats nicely in both `psql` and the console's SQL shell.
306            session.add_notice(AdapterNotice::Welcome(format!(
307                "connected to Materialize v{}
308  Org ID: {}
309  Region: {}
310  User: {}
311  Cluster: {}
312  Database: {}
313  {}
314  Session UUID: {}
315
316Issue a SQL query to get started. Need help?
317  View documentation: https://materialize.com/s/docs
318  Join our Slack community: https://materialize.com/s/chat
319    ",
320                session.vars().build_info().semver_version(),
321                self.environment_id.organization_id(),
322                self.environment_id.region(),
323                session.vars().user().name,
324                cluster_info,
325                session.vars().database(),
326                match session.vars().search_path() {
327                    [schema] => format!("Schema: {}", schema),
328                    schemas => format!(
329                        "Search path: {}",
330                        schemas.iter().map(|id| id.to_string()).join(", ")
331                    ),
332                },
333                session.uuid(),
334            )));
335        }
336
337        if session.vars().current_object_missing_warnings() {
338            if catalog.active_database().is_none() {
339                let db = session.vars().database().into();
340                session.add_notice(AdapterNotice::UnknownSessionDatabase(db));
341            }
342        }
343
344        // Users stub their toe on their default cluster not existing, so we provide a notice to
345        // help guide them on what to do.
346        let cluster_var = session
347            .vars()
348            .inspect(CLUSTER.name())
349            .expect("cluster should exist");
350        if session.vars().current_object_missing_warnings()
351            && catalog.resolve_cluster(Some(&cluster_active)).is_err()
352        {
353            let cluster_notice = 'notice: {
354                if cluster_var.inspect_session_value().is_some() {
355                    break 'notice Some(AdapterNotice::DefaultClusterDoesNotExist {
356                        name: cluster_active,
357                        kind: "session",
358                        suggested_action: "Pick an extant cluster with SET CLUSTER = name. Run SHOW CLUSTERS to see available clusters.".into(),
359                    });
360                }
361
362                let role_default = catalog.get_role(catalog.active_role_id());
363                let role_cluster = match role_default.vars().get(CLUSTER.name()) {
364                    Some(OwnedVarInput::Flat(name)) => Some(name),
365                    None => None,
366                    // This is unexpected!
367                    Some(v @ OwnedVarInput::SqlSet(_)) => {
368                        tracing::warn!(?v, "SqlSet found for cluster Role Default");
369                        break 'notice None;
370                    }
371                };
372
373                let alter_role = "with `ALTER ROLE <role> SET cluster TO <cluster>;`";
374                match role_cluster {
375                    // If there is no default, suggest a Role default.
376                    None => Some(AdapterNotice::DefaultClusterDoesNotExist {
377                        name: cluster_active,
378                        kind: "system",
379                        suggested_action: format!(
380                            "Set a default cluster for the current role {alter_role}."
381                        ),
382                    }),
383                    // If the default does not exist, suggest to change it.
384                    Some(_) => Some(AdapterNotice::DefaultClusterDoesNotExist {
385                        name: cluster_active,
386                        kind: "role",
387                        suggested_action: format!(
388                            "Change the default cluster for the current role {alter_role}."
389                        ),
390                    }),
391                }
392            };
393
394            if let Some(notice) = cluster_notice {
395                session.add_notice(notice);
396            }
397        }
398
399        Ok(client)
400    }
401
402    /// Cancels the query currently running on the specified connection.
403    pub fn cancel_request(&self, conn_id: ConnectionIdType, secret_key: u32) {
404        self.send(Command::CancelRequest {
405            conn_id,
406            secret_key,
407        });
408    }
409
410    /// Executes a single SQL statement that returns rows as the
411    /// `mz_support` user.
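    ///
    /// A usage sketch (not compiled; `next` comes from `futures::StreamExt`):
    ///
    /// ```ignore
    /// let mut rows = client.support_execute_one("SELECT 1").await?;
    /// while let Some(response) = rows.next().await {
    ///     // Each item is a `PeekResponseUnary`.
    /// }
    /// ```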
412    pub async fn support_execute_one(
413        &self,
414        sql: &str,
415    ) -> Result<Pin<Box<dyn Stream<Item = PeekResponseUnary> + Send + Sync>>, anyhow::Error> {
416        // Connect to the coordinator.
417        let conn_id = self.new_conn_id()?;
418        let session = self.new_session(SessionConfig {
419            conn_id,
420            uuid: Uuid::new_v4(),
421            user: SUPPORT_USER.name.clone(),
422            client_ip: None,
423            external_metadata_rx: None,
424            internal_user_metadata: None,
425            helm_chart_version: None,
426        });
427        let mut session_client = self.startup(session).await?;
428
429        // Parse the SQL statement.
430        let stmts = mz_sql::parse::parse(sql)?;
431        if stmts.len() != 1 {
432            bail!("must supply exactly one query");
433        }
434        let StatementParseResult { ast: stmt, sql } = stmts.into_element();
435
436        const EMPTY_PORTAL: &str = "";
437        session_client.start_transaction(Some(1))?;
438        session_client
439            .declare(EMPTY_PORTAL.into(), stmt, sql.to_string())
440            .await?;
441
442        match session_client
443            .execute(EMPTY_PORTAL.into(), futures::future::pending(), None)
444            .await?
445        {
446            (ExecuteResponse::SendingRowsStreaming { mut rows, .. }, _) => {
447                // We must only drop the session client _after_ we have read the
448                // result. Otherwise the peek would get cancelled right when we
449                // drop the session client. So we wrap things up in an extra
450                // stream like this, which owns the client and drops it at the end.
451                let owning_response_stream = async_stream::stream! {
452                    while let Some(rows) = rows.next().await {
453                        yield rows;
454                    }
455                    drop(session_client);
456                };
457                Ok(Box::pin(owning_response_stream))
458            }
459            r => bail!("unsupported response type: {r:?}"),
460        }
461    }
462
463    /// Returns the metrics associated with the adapter layer.
464    pub fn metrics(&self) -> &Metrics {
465        &self.metrics
466    }
467
468    /// The current time according to the [`Client`].
469    pub fn now(&self) -> DateTime<Utc> {
470        to_datetime((self.now)())
471    }
472
473    /// Gets metadata and a channel that can be used to append to a webhook source.
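    ///
    /// A sketch with hypothetical object names (not compiled):
    ///
    /// ```ignore
    /// let appender = client
    ///     .get_webhook_appender("materialize".into(), "public".into(), "my_webhook".into())
    ///     .await?;
    /// ```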
474    pub async fn get_webhook_appender(
475        &self,
476        database: String,
477        schema: String,
478        name: String,
479    ) -> Result<AppendWebhookResponse, AppendWebhookError> {
480        let (tx, rx) = oneshot::channel();
481
482        // Send our request.
483        self.send(Command::GetWebhook {
484            database,
485            schema,
486            name,
487            tx,
488        });
489
490        // Use our oneshot channel to get the result, returning an error if the sender dropped.
491        let response = rx
492            .await
493            .map_err(|_| anyhow::anyhow!("failed to receive webhook response"))?;
494
495        response
496    }
497
498    /// Gets the current value of all system variables.
499    pub async fn get_system_vars(&self) -> SystemVars {
500        let (tx, rx) = oneshot::channel();
501        self.send(Command::GetSystemVars { tx });
502        rx.await.expect("coordinator unexpectedly gone")
503    }
504
505    #[instrument(level = "debug")]
506    fn send(&self, cmd: Command) {
507        self.inner_cmd_tx
508            .send((OpenTelemetryContext::obtain(), cmd))
509            .expect("coordinator unexpectedly gone");
510    }
511}
512
513/// A coordinator client that is bound to a connection.
514///
515/// See also [`Client`].
516pub struct SessionClient {
517    // Invariant: inner may only be `None` after the session has been terminated.
518    // Once the session is terminated, no communication to the Coordinator
519    // should be attempted.
520    inner: Option<Client>,
521    // Invariant: session may only be `None` during a method call. Every public
522    // method must ensure that `Session` is `Some` before it returns.
523    session: Option<Session>,
524    timeouts: Timeout,
525    segment_client: Option<mz_segment::Client>,
526    environment_id: EnvironmentId,
527}
528
529impl SessionClient {
530    /// Parses a SQL expression, reporting failures as a telemetry event if
531    /// possible.
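    ///
    /// The outer `Result` reports failures from `parse_with_limit` as a `String`
    /// (for example, when the statement text exceeds the parser's size limit); the
    /// inner `Result` carries per-statement parse errors. A sketch of how a caller
    /// might unpack it (not compiled):
    ///
    /// ```ignore
    /// match session_client.parse("SELECT 1") {
    ///     Ok(Ok(stmts)) => { /* successfully parsed statements */ }
    ///     Ok(Err(parse_err)) => { /* SQL parse error */ }
    ///     Err(msg) => { /* parsing was refused outright */ }
    /// }
    /// ```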
532    pub fn parse<'a>(
533        &self,
534        sql: &'a str,
535    ) -> Result<Result<Vec<StatementParseResult<'a>>, ParserStatementError>, String> {
536        match mz_sql::parse::parse_with_limit(sql) {
537            Ok(Err(e)) => {
538                self.track_statement_parse_failure(&e);
539                Ok(Err(e))
540            }
541            r => r,
542        }
543    }
544
545    fn track_statement_parse_failure(&self, parse_error: &ParserStatementError) {
546        let session = self.session.as_ref().expect("session invariant violated");
547        let Some(user_id) = session.user().external_metadata.as_ref().map(|m| m.user_id) else {
548            return;
549        };
550        let Some(segment_client) = &self.segment_client else {
551            return;
552        };
553        let Some(statement_kind) = parse_error.statement else {
554            return;
555        };
556        let Some((action, object_type)) = telemetry::analyze_audited_statement(statement_kind)
557        else {
558            return;
559        };
560        let event_type = StatementFailureType::ParseFailure;
561        let event_name = format!(
562            "{} {} {}",
563            object_type.as_title_case(),
564            action.as_title_case(),
565            event_type.as_title_case(),
566        );
567        segment_client.environment_track(
568            &self.environment_id,
569            event_name,
570            json!({
571                "statement_kind": statement_kind,
572                "error": &parse_error.error,
573            }),
574            EventDetails {
575                user_id: Some(user_id),
576                application_name: Some(session.application_name()),
577                ..Default::default()
578            },
579        );
580    }
581
582    /// Verifies and returns the named prepared statement. Each use must be
583    /// re-verified to ensure the prepared statement is still safe to use.
584    pub async fn get_prepared_statement(
585        &mut self,
586        name: &str,
587    ) -> Result<&PreparedStatement, AdapterError> {
588        let catalog = self.catalog_snapshot("get_prepared_statement").await;
589        Coordinator::verify_prepared_statement(&catalog, self.session(), name)?;
590        Ok(self
591            .session()
592            .get_prepared_statement_unverified(name)
593            .expect("must exist"))
594    }
595
596    /// Saves the parsed statement as a prepared statement.
597    ///
598    /// The prepared statement is saved in the connection's [`crate::session::Session`]
599    /// under the specified name.
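    ///
    /// A sketch of preparing a statement under a hypothetical name (not
    /// compiled; `stmt` and `sql` come from an earlier `parse`):
    ///
    /// ```ignore
    /// session_client
    ///     .prepare("my_statement".into(), Some(stmt), sql.to_string(), vec![])
    ///     .await?;
    /// ```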
600    pub async fn prepare(
601        &mut self,
602        name: String,
603        stmt: Option<Statement<Raw>>,
604        sql: String,
605        param_types: Vec<Option<SqlScalarType>>,
606    ) -> Result<(), AdapterError> {
607        let catalog = self.catalog_snapshot("prepare").await;
608
609        // Note: This failpoint is used to simulate a request outliving the external connection
610        // that made it.
611        let mut async_pause = false;
612        (|| {
613            fail::fail_point!("async_prepare", |val| {
614                async_pause = val.map_or(false, |val| val.parse().unwrap_or(false))
615            });
616        })();
617        if async_pause {
618            tokio::time::sleep(Duration::from_secs(1)).await;
619        };
620
621        let desc = Coordinator::describe(&catalog, self.session(), stmt.clone(), param_types)?;
622        let now = self.now();
623        let state_revision = StateRevision {
624            catalog_revision: catalog.transient_revision(),
625            session_state_revision: self.session().state_revision(),
626        };
627        self.session()
628            .set_prepared_statement(name, stmt, sql, desc, state_revision, now);
629        Ok(())
630    }
631
632    /// Binds a statement to a portal.
633    #[mz_ore::instrument(level = "debug")]
634    pub async fn declare(
635        &mut self,
636        name: String,
637        stmt: Statement<Raw>,
638        sql: String,
639    ) -> Result<(), AdapterError> {
640        let catalog = self.catalog_snapshot("declare").await;
641        let param_types = vec![];
642        let desc =
643            Coordinator::describe(&catalog, self.session(), Some(stmt.clone()), param_types)?;
644        let params = vec![];
645        let result_formats = vec![mz_pgwire_common::Format::Text; desc.arity()];
646        let now = self.now();
647        let logging = self.session().mint_logging(sql, Some(&stmt), now);
648        let state_revision = StateRevision {
649            catalog_revision: catalog.transient_revision(),
650            session_state_revision: self.session().state_revision(),
651        };
652        self.session().set_portal(
653            name,
654            desc,
655            Some(stmt),
656            logging,
657            params,
658            result_formats,
659            state_revision,
660        )?;
661        Ok(())
662    }
663
664    /// Executes a previously-bound portal.
665    ///
666    /// Note: the provided `cancel_future` must be cancel-safe as it's polled in a `select!` loop.
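    ///
    /// A sketch mirroring the flow in `support_execute_one` (not compiled):
    ///
    /// ```ignore
    /// session_client.declare("".into(), stmt, sql.to_string()).await?;
    /// let (response, execute_started) = session_client
    ///     .execute("".into(), futures::future::pending(), None)
    ///     .await?;
    /// ```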
667    #[mz_ore::instrument(level = "debug")]
668    pub async fn execute(
669        &mut self,
670        portal_name: String,
671        cancel_future: impl Future<Output = std::io::Error> + Send,
672        outer_ctx_extra: Option<ExecuteContextExtra>,
673    ) -> Result<(ExecuteResponse, Instant), AdapterError> {
674        let execute_started = Instant::now();
675        let response = self
676            .send_with_cancel(
677                |tx, session| Command::Execute {
678                    portal_name,
679                    session,
680                    tx,
681                    outer_ctx_extra,
682                },
683                cancel_future,
684            )
685            .await?;
686        Ok((response, execute_started))
687    }
688
689    fn now(&self) -> EpochMillis {
690        (self.inner().now)()
691    }
692
693    fn now_datetime(&self) -> DateTime<Utc> {
694        to_datetime(self.now())
695    }
696
697    /// Starts a transaction. The resulting state depends on `implicit` (see the sketch below):
698    /// - `None`: InTransaction
699    /// - `Some(1)`: Started
700    /// - `Some(n > 1)`: InTransactionImplicit
701    /// - `Some(0)`: no change
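    ///
    /// For example (sketch, not compiled):
    ///
    /// ```ignore
    /// // Explicit transaction (e.g. from `BEGIN`):
    /// session_client.start_transaction(None)?;
    /// // Implicit single-statement transaction:
    /// session_client.start_transaction(Some(1))?;
    /// ```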
702    pub fn start_transaction(&mut self, implicit: Option<usize>) -> Result<(), AdapterError> {
703        let now = self.now_datetime();
704        let session = self.session.as_mut().expect("session invariant violated");
705        let result = match implicit {
706            None => session.start_transaction(now, None, None),
707            Some(stmts) => {
708                session.start_transaction_implicit(now, stmts);
709                Ok(())
710            }
711        };
712        result
713    }
714
715    /// Ends a transaction. Even if an error is returned, guarantees that the transaction in the
716    /// session and Coordinator has cleared its state.
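    ///
    /// A sketch of committing (not compiled):
    ///
    /// ```ignore
    /// let response = session_client
    ///     .end_transaction(EndTransactionAction::Commit)
    ///     .await?;
    /// ```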
717    #[instrument(level = "debug")]
718    pub async fn end_transaction(
719        &mut self,
720        action: EndTransactionAction,
721    ) -> Result<ExecuteResponse, AdapterError> {
722        let res = self
723            .send(|tx, session| Command::Commit {
724                action,
725                session,
726                tx,
727            })
728            .await;
729        // Commit isn't guaranteed to set the session's state to anything specific, so clear it
730        // here. It's safe to ignore the returned `TransactionStatus` because that doesn't contain
731        // any data that the Coordinator must act on for correctness.
732        let _ = self.session().clear_transaction();
733        res
734    }
735
736    /// Fails a transaction.
737    pub fn fail_transaction(&mut self) {
738        let session = self.session.take().expect("session invariant violated");
739        let session = session.fail_transaction();
740        self.session = Some(session);
741    }
742
743    /// Fetches the catalog.
744    #[instrument(level = "debug")]
745    pub async fn catalog_snapshot(&self, context: &str) -> Arc<Catalog> {
746        let start = std::time::Instant::now();
747        let CatalogSnapshot { catalog } = self
748            .send_without_session(|tx| Command::CatalogSnapshot { tx })
749            .await;
750        self.inner()
751            .metrics()
752            .catalog_snapshot_seconds
753            .with_label_values(&[context])
754            .observe(start.elapsed().as_secs_f64());
755        catalog
756    }
757
758    /// Dumps the catalog to a JSON string.
759    ///
760    /// No authorization is performed, so access to this function must be limited to internal
761    /// servers or superusers.
762    pub async fn dump_catalog(&self) -> Result<CatalogDump, AdapterError> {
763        let catalog = self.catalog_snapshot("dump_catalog").await;
764        catalog.dump().map_err(AdapterError::from)
765    }
766
767    /// Checks the catalog for internal consistency, returning a JSON object describing the
768    /// inconsistencies, if there are any.
769    ///
770    /// No authorization is performed, so access to this function must be limited to internal
771    /// servers or superusers.
772    pub async fn check_catalog(&self) -> Result<(), serde_json::Value> {
773        let catalog = self.catalog_snapshot("check_catalog").await;
774        catalog.check_consistency()
775    }
776
777    /// Checks the coordinator for internal consistency, returning a JSON object describing the
778    /// inconsistencies, if there are any. This is a superset of the checks that `check_catalog` performs.
779    ///
780    /// No authorization is performed, so access to this function must be limited to internal
781    /// servers or superusers.
782    pub async fn check_coordinator(&self) -> Result<(), serde_json::Value> {
783        self.send_without_session(|tx| Command::CheckConsistency { tx })
784            .await
785            .map_err(|inconsistencies| {
786                serde_json::to_value(inconsistencies).unwrap_or_else(|_| {
787                    serde_json::Value::String("failed to serialize inconsistencies".to_string())
788                })
789            })
790    }
791
792    pub async fn dump_coordinator_state(&self) -> Result<serde_json::Value, anyhow::Error> {
793        self.send_without_session(|tx| Command::Dump { tx }).await
794    }
795
796    /// Tells the coordinator a statement has finished execution, in the cases
797    /// where we have no other reason to communicate with the coordinator.
798    pub fn retire_execute(&self, data: ExecuteContextExtra, reason: StatementEndedExecutionReason) {
799        if !data.is_trivial() {
800            let cmd = Command::RetireExecute { data, reason };
801            self.inner().send(cmd);
802        }
803    }
804
805    /// Inserts a set of rows into the given table.
806    ///
807    /// The rows only contain values for the column positions listed in `columns`,
808    /// so they must be re-encoded to add the default values for the remaining
809    /// columns.
810    pub async fn insert_rows(
811        &mut self,
812        target_id: CatalogItemId,
813        target_name: String,
814        columns: Vec<ColumnIndex>,
815        rows: Vec<Row>,
816        ctx_extra: ExecuteContextExtra,
817    ) -> Result<ExecuteResponse, AdapterError> {
818        // TODO: Remove this clone once we always have the session. It's currently needed because
819        // self.session returns a mut ref, so we can't call it twice.
820        let pcx = self.session().pcx().clone();
821
822        let session_meta = self.session().meta();
823
824        let catalog = self.catalog_snapshot("insert_rows").await;
825        let conn_catalog = catalog.for_session(self.session());
826        let catalog_state = conn_catalog.state();
827
828        // Collect optimizer parameters.
829        let optimizer_config = optimize::OptimizerConfig::from(conn_catalog.system_vars());
830        let prep = ExprPrepStyle::OneShot {
831            logical_time: EvalTime::NotAvailable,
832            session: &session_meta,
833            catalog_state,
834        };
835        let mut optimizer =
836            optimize::view::Optimizer::new_with_prep_no_limit(optimizer_config.clone(), None, prep);
837
838        let result: Result<_, AdapterError> = mz_sql::plan::plan_copy_from(
839            &pcx,
840            &conn_catalog,
841            target_id,
842            target_name,
843            columns,
844            rows,
845        )
846        .err_into()
847        .and_then(|values| optimizer.optimize(values).err_into())
848        .and_then(|values| {
849            // Copied rows must always be constants.
850            Coordinator::insert_constant(&catalog, self.session(), target_id, values.into_inner())
851        });
852        self.retire_execute(ctx_extra, (&result).into());
853        result
854    }
855
856    /// Gets the current value of all system variables.
857    pub async fn get_system_vars(&self) -> SystemVars {
858        self.inner().get_system_vars().await
859    }
860
861    /// Updates the specified system variables to the specified values.
862    pub async fn set_system_vars(
863        &mut self,
864        vars: BTreeMap<String, String>,
865    ) -> Result<(), AdapterError> {
866        let conn_id = self.session().conn_id().clone();
867        self.send_without_session(|tx| Command::SetSystemVars { vars, conn_id, tx })
868            .await
869    }
870
871    /// Terminates the client session.
872    pub async fn terminate(&mut self) {
873        let conn_id = self.session().conn_id().clone();
874        let res = self
875            .send_without_session(|tx| Command::Terminate {
876                conn_id,
877                tx: Some(tx),
878            })
879            .await;
880        if let Err(e) = res {
881            // Nothing we can do to handle a failed terminate so we just log and ignore it.
882            error!("Unable to terminate session: {e:?}");
883        }
884        // Prevent any communication with Coordinator after session is terminated.
885        self.inner = None;
886    }
887
888    /// Returns a mutable reference to the session bound to this client.
889    pub fn session(&mut self) -> &mut Session {
890        self.session.as_mut().expect("session invariant violated")
891    }
892
893    /// Returns a reference to the inner client.
894    pub fn inner(&self) -> &Client {
895        self.inner.as_ref().expect("inner invariant violated")
896    }
897
898    async fn send_without_session<T, F>(&self, f: F) -> T
899    where
900        F: FnOnce(oneshot::Sender<T>) -> Command,
901    {
902        let (tx, rx) = oneshot::channel();
903        self.inner().send(f(tx));
904        rx.await.expect("sender dropped")
905    }
906
907    #[instrument(level = "debug")]
908    async fn send<T, F>(&mut self, f: F) -> Result<T, AdapterError>
909    where
910        F: FnOnce(oneshot::Sender<Response<T>>, Session) -> Command,
911    {
912        self.send_with_cancel(f, futures::future::pending()).await
913    }
914
915    /// Send a [`Command`] to the Coordinator, with the ability to cancel the command.
916    ///
917    /// Note: the provided `cancel_future` must be cancel-safe as it's polled in a `select!` loop.
918    #[instrument(level = "debug")]
919    async fn send_with_cancel<T, F>(
920        &mut self,
921        f: F,
922        cancel_future: impl Future<Output = std::io::Error> + Send,
923    ) -> Result<T, AdapterError>
924    where
925        F: FnOnce(oneshot::Sender<Response<T>>, Session) -> Command,
926    {
927        let session = self.session.take().expect("session invariant violated");
928        let mut typ = None;
929        let application_name = session.application_name();
930        let name_hint = ApplicationNameHint::from_str(application_name);
931        let conn_id = session.conn_id().clone();
932        let (tx, rx) = oneshot::channel();
933
934        // Destructure self so we can hold a mutable reference to the inner client and session at
935        // the same time.
936        let Self {
937            inner: inner_client,
938            session: client_session,
939            ..
940        } = self;
941
942        // TODO(parkmycar): Leaking this invariant here doesn't feel great, but calling
943        // `self.client()` doesn't work because then Rust takes a borrow on the entirety of self.
944        let inner_client = inner_client.as_ref().expect("inner invariant violated");
945
946        // ~~SPOOKY ZONE~~
947        //
948        // This guard prevents a race where a `Session` is returned on `rx` but never placed
949        // back in `self` because the Future returned by this function is concurrently dropped
950        // with the Coordinator sending a response.
951        let mut guarded_rx = rx.with_guard(|response: Response<_>| {
952            *client_session = Some(response.session);
953        });
954
955        inner_client.send({
956            let cmd = f(tx, session);
957            // Measure the success and error rate of certain commands:
958            // - execute reports success of dataflow execution
959            // - webhook reports success of fetching a webhook appender
960            match cmd {
961                Command::Execute { .. } => typ = Some("execute"),
962                Command::GetWebhook { .. } => typ = Some("webhook"),
963                Command::Startup { .. }
964                | Command::AuthenticatePassword { .. }
965                | Command::AuthenticateGetSASLChallenge { .. }
966                | Command::AuthenticateVerifySASLProof { .. }
967                | Command::CatalogSnapshot { .. }
968                | Command::Commit { .. }
969                | Command::CancelRequest { .. }
970                | Command::PrivilegedCancelRequest { .. }
971                | Command::GetSystemVars { .. }
972                | Command::SetSystemVars { .. }
973                | Command::Terminate { .. }
974                | Command::RetireExecute { .. }
975                | Command::CheckConsistency { .. }
976                | Command::Dump { .. } => {}
977            };
978            cmd
979        });
980
981        let mut cancel_future = pin::pin!(cancel_future);
982        let mut cancelled = false;
983        loop {
984            tokio::select! {
985                res = &mut guarded_rx => {
986                    // We received a result, so drop our guard to drop our borrows.
987                    drop(guarded_rx);
988
989                    let res = res.expect("sender dropped");
990                    let status = res.result.is_ok().then_some("success").unwrap_or("error");
991                    if let Err(err) = res.result.as_ref() {
992                        if name_hint.should_trace_errors() {
993                            tracing::warn!(?err, ?name_hint, "adapter response error");
994                        }
995                    }
996
997                    if let Some(typ) = typ {
998                        inner_client
999                            .metrics
1000                            .commands
1001                            .with_label_values(&[typ, status, name_hint.as_str()])
1002                            .inc();
1003                    }
1004                    *client_session = Some(res.session);
1005                    return res.result;
1006                },
1007                _err = &mut cancel_future, if !cancelled => {
1008                    cancelled = true;
1009                    inner_client.send(Command::PrivilegedCancelRequest {
1010                        conn_id: conn_id.clone(),
1011                    });
1012                }
1013            };
1014        }
1015    }
1016
1017    pub fn add_idle_in_transaction_session_timeout(&mut self) {
1018        let session = self.session();
1019        let timeout_dur = session.vars().idle_in_transaction_session_timeout();
1020        if !timeout_dur.is_zero() {
1021            let timeout_dur = timeout_dur.clone();
1022            if let Some(txn) = session.transaction().inner() {
1023                let txn_id = txn.id.clone();
1024                let timeout = TimeoutType::IdleInTransactionSession(txn_id);
1025                self.timeouts.add_timeout(timeout, timeout_dur);
1026            }
1027        }
1028    }
1029
1030    pub fn remove_idle_in_transaction_session_timeout(&mut self) {
1031        let session = self.session();
1032        if let Some(txn) = session.transaction().inner() {
1033            let txn_id = txn.id.clone();
1034            self.timeouts
1035                .remove_timeout(&TimeoutType::IdleInTransactionSession(txn_id));
1036        }
1037    }
1038
1039    /// # Cancel safety
1040    ///
1041    /// This method is cancel safe. If `recv_timeout` is used as the event in a
1042    /// `tokio::select!` statement and some other branch
1043    /// completes first, it is guaranteed that no messages were received on this
1044    /// channel.
1045    pub async fn recv_timeout(&mut self) -> Option<TimeoutType> {
1046        self.timeouts.recv().await
1047    }
1048}
1049
1050impl Drop for SessionClient {
1051    fn drop(&mut self) {
1052        // We may not have a session if this client was dropped while awaiting
1053        // a response. In this case, it is the coordinator's responsibility to
1054        // terminate the session.
1055        if let Some(session) = self.session.take() {
1056            // We may not have a connection to the Coordinator if the session was
1057            // prematurely terminated, for example due to a timeout.
1058            if let Some(inner) = &self.inner {
1059                inner.send(Command::Terminate {
1060                    conn_id: session.conn_id().clone(),
1061                    tx: None,
1062                })
1063            }
1064        }
1065    }
1066}
1067
1068#[derive(Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Debug)]
1069pub enum TimeoutType {
1070    IdleInTransactionSession(TransactionId),
1071}
1072
1073impl Display for TimeoutType {
1074    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1075        match self {
1076            TimeoutType::IdleInTransactionSession(txn_id) => {
1077                writeln!(f, "Idle in transaction session for transaction '{txn_id}'")
1078            }
1079        }
1080    }
1081}
1082
1083impl From<TimeoutType> for AdapterError {
1084    fn from(timeout: TimeoutType) -> Self {
1085        match timeout {
1086            TimeoutType::IdleInTransactionSession(_) => {
1087                AdapterError::IdleInTransactionSessionTimeout
1088            }
1089        }
1090    }
1091}
1092
1093struct Timeout {
1094    tx: mpsc::UnboundedSender<TimeoutType>,
1095    rx: mpsc::UnboundedReceiver<TimeoutType>,
1096    active_timeouts: BTreeMap<TimeoutType, AbortOnDropHandle<()>>,
1097}
1098
1099impl Timeout {
1100    fn new() -> Self {
1101        let (tx, rx) = mpsc::unbounded_channel();
1102        Timeout {
1103            tx,
1104            rx,
1105            active_timeouts: BTreeMap::new(),
1106        }
1107    }
1108
1109    /// # Cancel safety
1110    ///
1111    /// This method is cancel safe. If `recv` is used as the event in a
1112    /// `tokio::select!` statement and some other branch
1113    /// completes first, it is guaranteed that no messages were received on this
1114    /// channel.
1115    ///
1116    /// <https://docs.rs/tokio/latest/tokio/sync/mpsc/struct.UnboundedReceiver.html#cancel-safety>
1117    async fn recv(&mut self) -> Option<TimeoutType> {
1118        self.rx.recv().await
1119    }
1120
1121    fn add_timeout(&mut self, timeout: TimeoutType, duration: Duration) {
1122        let tx = self.tx.clone();
1123        let timeout_key = timeout.clone();
1124        let handle = mz_ore::task::spawn(|| format!("{timeout_key}"), async move {
1125            tokio::time::sleep(duration).await;
1126            let _ = tx.send(timeout);
1127        })
1128        .abort_on_drop();
1129        self.active_timeouts.insert(timeout_key, handle);
1130    }
1131
1132    fn remove_timeout(&mut self, timeout: &TimeoutType) {
1133        self.active_timeouts.remove(timeout);
1134
1135        // Remove the timeout from the rx queue if it exists.
1136        let mut timeouts = Vec::new();
1137        while let Ok(pending_timeout) = self.rx.try_recv() {
1138            if timeout != &pending_timeout {
1139                timeouts.push(pending_timeout);
1140            }
1141        }
1142        for pending_timeout in timeouts {
1143            self.tx.send(pending_timeout).expect("rx is in this struct");
1144        }
1145    }
1146}
1147
1148/// A wrapper around a Stream of PeekResponseUnary that records, in the given
1149/// histogram, the time until it sees the first row. It also keeps track of whether we have
1150/// already observed the end of the underlying stream.
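///
/// A construction sketch (not compiled; `rows` and `client` are assumed to be in
/// scope):
///
/// ```ignore
/// let mut stream = RecordFirstRowStream::new(rows, Instant::now(), &client, None, None);
/// while let Some(response) = stream.recv().await {
///     // Handle each `PeekResponseUnary`.
/// }
/// ```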
1151#[derive(Derivative)]
1152#[derivative(Debug)]
1153pub struct RecordFirstRowStream {
1154    /// The underlying stream of rows.
1155    #[derivative(Debug = "ignore")]
1156    pub rows: Box<dyn Stream<Item = PeekResponseUnary> + Unpin + Send + Sync>,
1157    /// The Instant when execution started.
1158    pub execute_started: Instant,
1159    /// The histogram where the time since `execute_started` will be recorded when we see the first
1160    /// row.
1161    pub time_to_first_row_seconds: Histogram,
1162    /// Whether we've seen any rows.
1163    pub saw_rows: bool,
1164    /// The Instant when we saw the first row.
1165    pub recorded_first_row_instant: Option<Instant>,
1166    /// Whether we have already observed the end of the underlying stream.
1167    pub no_more_rows: bool,
1168}
1169
1170impl RecordFirstRowStream {
1171    /// Create a new [`RecordFirstRowStream`]
1172    pub fn new(
1173        rows: Box<dyn Stream<Item = PeekResponseUnary> + Unpin + Send + Sync>,
1174        execute_started: Instant,
1175        client: &SessionClient,
1176        instance_id: Option<ComputeInstanceId>,
1177        strategy: Option<StatementExecutionStrategy>,
1178    ) -> Self {
1179        let histogram = Self::histogram(client, instance_id, strategy);
1180        Self {
1181            rows,
1182            execute_started,
1183            time_to_first_row_seconds: histogram,
1184            saw_rows: false,
1185            recorded_first_row_instant: None,
1186            no_more_rows: false,
1187        }
1188    }
1189
1190    fn histogram(
1191        client: &SessionClient,
1192        instance_id: Option<ComputeInstanceId>,
1193        strategy: Option<StatementExecutionStrategy>,
1194    ) -> Histogram {
1195        let isolation_level = *client
1196            .session
1197            .as_ref()
1198            .expect("session invariant")
1199            .vars()
1200            .transaction_isolation();
1201        let instance = match instance_id {
1202            Some(i) => Cow::Owned(i.to_string()),
1203            None => Cow::Borrowed("none"),
1204        };
1205        let strategy = match strategy {
1206            Some(s) => s.name(),
1207            None => "none",
1208        };
1209
1210        client
1211            .inner()
1212            .metrics()
1213            .time_to_first_row_seconds
1214            .with_label_values(&[&instance, isolation_level.as_str(), strategy])
1215    }
1216
1217    /// If you want to match [`RecordFirstRowStream`]'s logic but don't need
1218    /// the stream wrapper itself, you can tell it when to record an observation.
1219    pub fn record(
1220        execute_started: Instant,
1221        client: &SessionClient,
1222        instance_id: Option<ComputeInstanceId>,
1223        strategy: Option<StatementExecutionStrategy>,
1224    ) {
1225        Self::histogram(client, instance_id, strategy)
1226            .observe(execute_started.elapsed().as_secs_f64());
1227    }
1228
1229    pub async fn recv(&mut self) -> Option<PeekResponseUnary> {
1230        let msg = self.rows.next().await;
1231        if !self.saw_rows && matches!(msg, Some(PeekResponseUnary::Rows(_))) {
1232            self.saw_rows = true;
1233            self.time_to_first_row_seconds
1234                .observe(self.execute_started.elapsed().as_secs_f64());
1235            self.recorded_first_row_instant = Some(Instant::now());
1236        }
1237        if msg.is_none() {
1238            self.no_more_rows = true;
1239        }
1240        msg
1241    }
1242}