Skip to main content

mz_environmentd/
http.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Embedded HTTP server.
11//!
12//! environmentd embeds an HTTP server for introspection into the running
13//! process. At the moment, its primary exports are Prometheus metrics, heap
14//! profiles, and catalog dumps.
15//!
16//! ## Authentication flow
17//!
18//! The server supports several authentication modes, controlled by the
19//! configured [`listeners::AuthenticatorKind`]. The general flow is:
20//!
21//! 1. **Identity resolution.** An authentication middleware runs on every
22//!    protected request and resolves the caller's identity via one of:
23//!    - **Credentials in headers.** The caller supplies a username/password or
24//!      token in the request headers. Supported by all [`listeners::AuthenticatorKind`]s.
25//!    - **Session reuse.** If the caller has an active authenticated session
26//!      (established via `POST /api/login`) and has not supplied credentials
27//!      in the request headers, the session is reused. Only available for
28//!      [`listeners::AuthenticatorKind::Password`] and [`listeners::AuthenticatorKind::Oidc`].
29//!    - **Trusted header injection.** A trusted upstream proxy (e.g. Teleport)
30//!      may inject the caller's identity into the request headers. Only available
31//!      for [`listeners::AuthenticatorKind::None`].
32//!
33//! 2. **Session initialization.** Once the caller's identity is known, an
34//!    adapter session is opened on their behalf. This happens as part of
35//!    request processing, after all middleware has run.
36//!
37//! 3. **Request handling.** The handler executes the request (e.g. runs SQL)
38//!    using the initialized adapter session.
39//!
40//! ### WebSocket
41//!
42//! The WebSocket flow is identical to the HTTP flow with two differences:
43//!
44//! - Credentials are not read from request headers. Instead, the first
45//!   message sent by the client is treated as the authentication message.
46//! - Session initialization (step 2) happens inside the WebSocket handler
47//!   itself, rather than as a separate middleware step.
48
49// Axum handlers must use async, but often don't actually use `await`.
50#![allow(clippy::unused_async)]
51
52use std::borrow::Cow;
53use std::collections::BTreeMap;
54use std::fmt::Debug;
55use std::net::{IpAddr, SocketAddr};
56use std::pin::Pin;
57use std::sync::Arc;
58use std::time::{Duration, SystemTime};
59
60use anyhow::Context;
61use axum::error_handling::HandleErrorLayer;
62use axum::extract::ws::{Message, WebSocket};
63use axum::extract::{ConnectInfo, DefaultBodyLimit, FromRequestParts, Query, Request, State};
64use axum::middleware::{self, Next};
65use axum::response::{IntoResponse, Redirect, Response};
66use axum::{Extension, Json, Router, routing};
67use futures::future::{Shared, TryFutureExt};
68use headers::authorization::{Authorization, Basic, Bearer};
69use headers::{HeaderMapExt, HeaderName};
70use http::header::{AUTHORIZATION, CONTENT_TYPE};
71use http::uri::Scheme;
72use http::{HeaderMap, HeaderValue, Method, StatusCode, Uri};
73use hyper_openssl::SslStream;
74use hyper_openssl::client::legacy::MaybeHttpsStream;
75use hyper_util::rt::TokioIo;
76use mz_adapter::session::{Session as AdapterSession, SessionConfig as AdapterSessionConfig};
77use mz_adapter::{AdapterError, AdapterNotice, Client, SessionClient, WebhookAppenderCache};
78use mz_adapter_types::dyncfgs::OIDC_GROUP_CLAIM;
79use mz_auth::Authenticated;
80use mz_auth::password::Password;
81use mz_authenticator::Authenticator;
82use mz_controller::ReplicaHttpLocator;
83use mz_frontegg_auth::Error as FronteggError;
84use mz_http_util::DynamicFilterTarget;
85use mz_ore::cast::u64_to_usize;
86use mz_ore::metrics::MetricsRegistry;
87use mz_ore::now::{NowFn, SYSTEM_TIME, epoch_to_uuid_v7};
88use mz_ore::str::StrExt;
89use mz_pgwire_common::{ConnectionCounter, ConnectionHandle};
90use mz_repr::user::ExternalUserMetadata;
91use mz_server_core::listeners::{self, AllowedRoles, HttpRoutesEnabled};
92use mz_server_core::{Connection, ConnectionHandler, ReloadingSslContext, Server};
93use mz_sql::session::metadata::SessionMetadata;
94use mz_sql::session::user::{
95    HTTP_DEFAULT_USER, INTERNAL_USER_NAMES, SUPPORT_USER_NAME, SYSTEM_USER_NAME,
96};
97use mz_sql::session::vars::{Value, Var, VarInput, WELCOME_MESSAGE};
98use openssl::ssl::Ssl;
99use prometheus::{
100    COMPUTE_METRIC_QUERIES, FRONTIER_METRIC_QUERIES, STORAGE_METRIC_QUERIES, USAGE_METRIC_QUERIES,
101};
102use serde::{Deserialize, Serialize};
103use serde_json::json;
104use thiserror::Error;
105use tokio::io::AsyncWriteExt;
106use tokio::sync::oneshot::Receiver;
107use tokio::sync::{oneshot, watch};
108use tokio_metrics::TaskMetrics;
109use tower::limit::GlobalConcurrencyLimitLayer;
110use tower::{Service, ServiceBuilder};
111use tower_http::cors::{AllowOrigin, Any, CorsLayer};
112use tower_sessions::{
113    MemoryStore as TowerSessionMemoryStore, Session as TowerSession,
114    SessionManagerLayer as TowerSessionManagerLayer,
115};
116use tracing::warn;
117
118use crate::BUILD_INFO;
119use crate::deployment::state::DeploymentStateHandle;
120use crate::http::sql::{ExistingUser, SqlError};
121
122mod catalog;
123mod cluster;
124mod console;
125mod mcp;
126pub mod mcp_metrics;
127mod memory;
128mod metrics;
129mod metrics_public;
130mod metrics_viz;
131pub(crate) mod oauth_metadata;
132mod probe;
133mod prometheus;
134mod root;
135mod sql;
136mod webhook;
137
138pub use metrics::Metrics;
139pub use sql::{SqlResponse, WebSocketAuth, WebSocketResponse};
140
141/// Maximum allowed size for a request.
142pub const MAX_REQUEST_SIZE: usize = u64_to_usize(5 * bytesize::MIB);
143
144const SESSION_DURATION: Duration = Duration::from_secs(8 * 3600); // 8 hours
145
146const PROFILING_API_ENDPOINTS: &[&str] = &["/memory", "/hierarchical-memory", "/prof/"];
147
/// Configuration for a single embedded HTTP listener, consumed by
/// [`HttpServer::new`].
#[derive(Debug)]
pub struct HttpConfig {
    /// Identifier for this listener; passed, together with `metrics`, to
    /// `apply_default_layers` when the router is finalized.
    pub source: &'static str,
    /// TLS configuration. When `Some`, each connection completes a TLS
    /// handshake before HTTP is served, and the session cookie is marked
    /// secure.
    pub tls: Option<ReloadingSslContext>,
    /// Authentication mode for this listener. Selects which identity
    /// resolution mechanisms apply and whether the session layer with
    /// `/api/login`/`/api/logout` or the `x-materialize-user` header
    /// middleware is installed.
    pub authenticator_kind: listeners::AuthenticatorKind,
    /// Frontegg authenticator, if any; handed to the auth middleware and to
    /// the WebSocket handler state.
    pub frontegg: Option<mz_frontegg_auth::Authenticator>,
    /// Handle to the generic OIDC authenticator; handed to the auth
    /// middleware and to the WebSocket handler state.
    pub oidc_rx: Delayed<mz_authenticator::GenericOidcAuthenticator>,
    /// Shared oneshot receiver that yields the adapter client once it is
    /// sent. Cloned into the auth middleware, handler extensions, and
    /// router state.
    pub adapter_client_rx: Shared<Receiver<Client>>,
    /// CORS origin policy applied to the base routes and the MCP routes.
    pub allowed_origin: AllowOrigin,
    /// Raw list of allowed CORS origins, used by the MCP endpoints for
    /// server-side Origin validation to defend against DNS rebinding.
    pub allowed_origin_list: Vec<HeaderValue>,
    /// Active-connection counter, exposed to handlers as an `Extension` and
    /// stored in the WebSocket handler state.
    pub active_connection_counter: ConnectionCounter,
    /// Helm chart version, if known. Exposed to handlers through the
    /// `HelmChartVersion` extension and stored in the WebSocket state.
    pub helm_chart_version: Option<String>,
    /// Externally-visible host name for this environment (without scheme).
    ///
    /// Used as the canonical host when constructing absolute URLs that the
    /// server needs to publish (e.g. the OAuth Protected Resource Metadata
    /// `resource` field, RFC 9728). When `None`, callers fall back to the
    /// request's `Host` header, which is correct for unproxied dev setups
    /// but loses fidelity behind a load balancer that rewrites Host.
    ///
    /// We deliberately do NOT consult `X-Forwarded-Host` or
    /// `X-Forwarded-Proto`: there is no proxy-trust model in environmentd
    /// today, and an attacker reaching the server directly can otherwise
    /// poison the published metadata URLs.
    pub http_host_name: Option<String>,
    /// Frontegg OAuth issuer URL, used (with `authenticator_kind`) to derive
    /// the MCP OAuth discovery metadata for this listener.
    pub frontegg_oauth_issuer_url: Option<String>,
    /// Semaphore bounding concurrent webhook requests. The webhook router
    /// places a load-shedding layer in front of the concurrency limit, so
    /// requests over the limit are rejected rather than queued.
    pub concurrent_webhook_req: Arc<tokio::sync::Semaphore>,
    /// HTTP server metrics; passed to `apply_default_layers`.
    pub metrics: Metrics,
    /// Registry served at `/metrics` and provided to the base routes as an
    /// `Extension`.
    pub metrics_registry: MetricsRegistry,
    /// Metrics for the MCP endpoints, provided to them as an `Extension`.
    pub mcp_metrics: mcp_metrics::McpMetrics,
    /// Metrics for the OAuth protected-resource metadata routes.
    pub oauth_metadata_metrics: oauth_metadata::OauthMetadataMetrics,
    /// Role restrictions for this listener. Passed to the auth middleware,
    /// the WebSocket state, the login route, and the `x-materialize-user`
    /// header middleware.
    pub allowed_roles: AllowedRoles,
    /// Configuration for the internal routes (leader control and the
    /// internal console proxy).
    pub internal_route_config: Arc<InternalRouteConfig>,
    /// Which groups of routes to mount on this listener.
    pub routes_enabled: HttpRoutesEnabled,
    /// Locator for cluster replica HTTP addresses, used for proxying requests.
    pub replica_http_locator: Arc<ReplicaHttpLocator>,
}
187
/// Configuration used by the internal-only routes mounted when
/// `routes_enabled.internal` is set.
#[derive(Debug, Clone)]
pub struct InternalRouteConfig {
    /// Deployment state handle; used as the router state for the
    /// `/api/leader/*` routes.
    pub deployment_state_handle: DeploymentStateHandle,
    /// Optional URL handed to `console::ConsoleProxyConfig` for the
    /// `/internal-console` routes.
    pub internal_console_redirect_url: Option<String>,
}
193
/// Router state for the WebSocket SQL endpoint (`/api/experimental/sql`).
///
/// For WebSockets, credentials come from the client's first message rather
/// than from request headers, and the adapter session is initialized inside
/// the handler itself. The handler therefore needs the authenticators and
/// connection settings directly, rather than relying on the auth middleware.
#[derive(Clone)]
pub struct WsState {
    // Frontegg authenticator, if configured.
    frontegg: Option<mz_frontegg_auth::Authenticator>,
    // Handle to the generic OIDC authenticator.
    oidc_rx: Delayed<mz_authenticator::GenericOidcAuthenticator>,
    // Authentication mode of the listener this route is mounted on.
    authenticator_kind: listeners::AuthenticatorKind,
    // Handle to the adapter client used to open sessions.
    adapter_client_rx: Delayed<mz_adapter::Client>,
    // Shared active-connection counter for this listener.
    active_connection_counter: ConnectionCounter,
    // Helm chart version, if known.
    helm_chart_version: Option<String>,
    // Role restrictions for this listener.
    allowed_roles: AllowedRoles,
}
204
/// Router state for the webhook append endpoint
/// (`/api/webhook/{database}/{schema}/{id}`).
#[derive(Clone)]
pub struct WebhookState {
    // Handle to the adapter client used to process webhook requests.
    adapter_client_rx: Delayed<mz_adapter::Client>,
    // Appender cache, created once per server in `HttpServer::new`.
    webhook_cache: WebhookAppenderCache,
}
210
/// Helm chart version, wrapped so it can be handed to handlers as an axum
/// `Extension`. Extensions are looked up by type, so the newtype keeps this
/// value distinct from any other `Option<String>` extension.
#[derive(Clone, Debug)]
struct HelmChartVersion(Option<String>);
213
/// The embedded HTTP server: an optional TLS context plus the fully layered
/// router built by [`HttpServer::new`].
#[derive(Debug)]
pub struct HttpServer {
    // TLS context; when `Some`, every accepted connection is served over TLS.
    tls: Option<ReloadingSslContext>,
    // Router cloned for each accepted connection in `handle_connection`.
    router: Router,
}
219
impl HttpServer {
    /// Builds an [`HttpServer`] for one listener from its [`HttpConfig`].
    ///
    /// Route groups are mounted according to `routes_enabled`:
    ///
    /// - `base`: the home page, `/api/sql`, memory and metrics
    ///   visualizations, static assets, public metrics, and the WebSocket SQL
    ///   endpoint.
    /// - `profiling`: the `/prof/` profiling router, nested in the base router.
    /// - `webhook`: the webhook append endpoint. It has no auth middleware,
    ///   its CORS mirrors the request origin, and its concurrency is bounded
    ///   by `concurrent_webhook_req` with load shedding.
    /// - `internal`: several route sets:
    ///   - the retired OpenTelemetry/stderr filter endpoints;
    ///   - tracing;
    ///   - catalog/coordinator dumps and checks;
    ///   - the internal console proxy;
    ///   - the cluster replica proxy;
    ///   - the `/api/leader/*` routes.
    /// - `metrics`: `/metrics`, the SQL-backed `/metrics/mz_*` endpoints, and
    ///   the liveness/readiness probes.
    /// - `console_config`: `/api/console/config`.
    /// - `mcp_agent` / `mcp_developer`: the MCP endpoints plus the public
    ///   OAuth protected-resource metadata routes.
    ///
    /// The authentication middleware is layered onto the base, leader,
    /// metrics, and MCP routers. Depending on `authenticator_kind`, one of
    /// two things is then added:
    ///
    /// - password/OIDC: the cookie session layer and the
    ///   `/api/login`/`/api/logout` routes;
    /// - none: the `x-materialize-user` header middleware.
    ///
    /// Note that axum's `Router::layer` only wraps routes registered before
    /// the call, and later layers wrap earlier ones. Much of the ordering
    /// below depends on that.
    pub fn new(
        HttpConfig {
            source,
            tls,
            authenticator_kind,
            frontegg,
            oidc_rx,
            adapter_client_rx,
            allowed_origin,
            allowed_origin_list,
            active_connection_counter,
            helm_chart_version,
            http_host_name,
            frontegg_oauth_issuer_url,
            concurrent_webhook_req,
            metrics,
            metrics_registry,
            mcp_metrics,
            oauth_metadata_metrics,
            allowed_roles,
            internal_route_config,
            routes_enabled,
            replica_http_locator,
        }: HttpConfig,
    ) -> HttpServer {
        let tls_enabled = tls.is_some();
        let webhook_cache = WebhookAppenderCache::new();

        // Compute OAuth discovery once per listener so the Bearer challenge
        // and the discovery handler always agree, and the middleware doesn't
        // re-derive (and re-allocate the Frontegg issuer) on each request.
        let oauth_discovery = Arc::new(oauth_metadata::McpOAuthDiscovery::for_authenticator(
            authenticator_kind,
            frontegg_oauth_issuer_url.as_deref(),
        ));

        // Create the session store and cookie manager. Sessions are held in
        // this process's memory (`MemoryStore`).
        let session_store = TowerSessionMemoryStore::default();
        let session_layer = TowerSessionManagerLayer::new(session_store)
            .with_secure(tls_enabled) // `Secure` cookie flag only when serving TLS
            .with_same_site(tower_sessions::cookie::SameSite::Strict) // Prevent CSRF
            .with_http_only(true) // Hide the cookie from page scripts (XSS)
            .with_name("mz_session") // Custom cookie name
            .with_path("/"); // Set cookie path

        // The auth middleware closure must be cloneable and reusable across
        // routers, so each invocation clones its own handles.
        let frontegg_middleware = frontegg.clone();
        let oidc_middleware_rx = oidc_rx.clone();
        let adapter_client_middleware_rx = adapter_client_rx.clone();
        let auth_middleware = middleware::from_fn(move |req, next| {
            let frontegg = frontegg_middleware.clone();
            let oidc_rx = oidc_middleware_rx.clone();
            let adapter_client_rx = adapter_client_middleware_rx.clone();
            async move {
                http_auth(
                    req,
                    next,
                    tls_enabled,
                    authenticator_kind,
                    frontegg,
                    oidc_rx,
                    adapter_client_rx,
                    allowed_roles,
                )
                .await
            }
        });

        // `router` collects independently-layered sub-routers; `base_router`
        // collects routes that share the common layers applied at the end.
        let mut router = Router::new();
        let mut base_router = Router::new();
        let cluster_proxy_config = Arc::new(cluster::ClusterProxyConfig::new(Arc::clone(
            &replica_http_locator,
        )));
        if routes_enabled.base {
            base_router = base_router
                .route(
                    "/",
                    routing::get(move || async move { root::handle_home(routes_enabled).await }),
                )
                .route("/api/sql", routing::post(sql::handle_sql))
                .route("/memory", routing::get(memory::handle_memory))
                .route(
                    "/hierarchical-memory",
                    routing::get(memory::handle_hierarchical_memory),
                )
                .route(
                    "/metrics-viz",
                    routing::get(metrics_viz::handle_metrics_viz),
                )
                .route("/static/{*path}", routing::get(root::handle_static))
                .route(
                    "/metrics/public",
                    routing::get(metrics_public::handle_public_metrics),
                )
                .layer(Extension(metrics_registry.clone()))
                .layer(Extension(Arc::clone(&cluster_proxy_config)));

            // The WebSocket endpoint authenticates inside its handler (first
            // message), so it gets its own state instead of `auth_middleware`.
            let mut ws_router = Router::new()
                .route("/api/experimental/sql", routing::get(sql::handle_sql_ws))
                .with_state(WsState {
                    frontegg,
                    oidc_rx: oidc_rx.clone(),
                    authenticator_kind,
                    adapter_client_rx: adapter_client_rx.clone(),
                    active_connection_counter: active_connection_counter.clone(),
                    helm_chart_version: helm_chart_version.clone(),
                    allowed_roles,
                });
            if let listeners::AuthenticatorKind::None = authenticator_kind {
                ws_router = ws_router.layer(middleware::from_fn_with_state(
                    allowed_roles,
                    x_materialize_user_header_auth,
                ));
            }
            router = router.merge(ws_router);
        }
        if routes_enabled.profiling {
            base_router = base_router.nest("/prof/", mz_prof_http::router(&BUILD_INFO));
        }

        if routes_enabled.webhook {
            // Layer order (innermost first): decompression, CORS, then load
            // shedding in front of the global concurrency limit, so requests
            // over the limit are rejected (via `handle_load_error`) rather
            // than queued.
            let webhook_router = Router::new()
                .route(
                    "/api/webhook/{:database}/{:schema}/{:id}",
                    routing::post(webhook::handle_webhook),
                )
                .with_state(WebhookState {
                    adapter_client_rx: adapter_client_rx.clone(),
                    webhook_cache,
                })
                .layer(
                    tower_http::decompression::RequestDecompressionLayer::new()
                        .gzip(true)
                        .deflate(true)
                        .br(true)
                        .zstd(true),
                )
                .layer(
                    CorsLayer::new()
                        .allow_methods(Method::POST)
                        .allow_origin(AllowOrigin::mirror_request())
                        .allow_headers(Any),
                )
                .layer(
                    ServiceBuilder::new()
                        .layer(HandleErrorLayer::new(handle_load_error))
                        .load_shed()
                        .layer(GlobalConcurrencyLimitLayer::with_semaphore(
                            concurrent_webhook_req,
                        )),
                );
            router = router.merge(webhook_router);
        }

        if routes_enabled.internal {
            let console_config = Arc::new(console::ConsoleProxyConfig::new(
                internal_route_config.internal_console_redirect_url.clone(),
                "/internal-console".to_string(),
            ));
            base_router = base_router
                // Retired endpoints: kept so old callers get an explanatory
                // 400 instead of a 404.
                .route(
                    "/api/opentelemetry/config",
                    routing::put({
                        move |_: axum::Json<DynamicFilterTarget>| async {
                            (
                                StatusCode::BAD_REQUEST,
                                "This endpoint has been replaced. \
                            Use the `opentelemetry_filter` system variable."
                                    .to_string(),
                            )
                        }
                    }),
                )
                .route(
                    "/api/stderr/config",
                    routing::put({
                        move |_: axum::Json<DynamicFilterTarget>| async {
                            (
                                StatusCode::BAD_REQUEST,
                                "This endpoint has been replaced. \
                            Use the `log_filter` system variable."
                                    .to_string(),
                            )
                        }
                    }),
                )
                .route("/api/tracing", routing::get(mz_http_util::handle_tracing))
                .route(
                    "/api/catalog/dump",
                    routing::get(catalog::handle_catalog_dump),
                )
                .route(
                    "/api/catalog/check",
                    routing::get(catalog::handle_catalog_check),
                )
                .route(
                    "/api/catalog/inject-audit-events",
                    routing::post(catalog::handle_inject_audit_events),
                )
                .route(
                    "/api/coordinator/check",
                    routing::get(catalog::handle_coordinator_check),
                )
                .route(
                    "/api/coordinator/dump",
                    routing::get(catalog::handle_coordinator_dump),
                )
                .route(
                    "/internal-console",
                    routing::get(|| async { Redirect::temporary("/internal-console/") }),
                )
                .route(
                    "/internal-console/{*path}",
                    routing::get(console::handle_internal_console),
                )
                .route(
                    "/internal-console/",
                    routing::get(console::handle_internal_console),
                )
                .layer(Extension(console_config));

            // Cluster HTTP proxy routes. `Router::layer` only wraps routes
            // registered before it, so these routes need the proxy config
            // extension layered again here (the base block may also be
            // disabled).
            base_router = base_router
                .route("/clusters", routing::get(cluster::handle_clusters))
                .route(
                    "/api/cluster/{:cluster_id}/replica/{:replica_id}/process/{:process}/",
                    routing::any(cluster::handle_cluster_proxy_root),
                )
                .route(
                    "/api/cluster/{:cluster_id}/replica/{:replica_id}/process/{:process}/{*path}",
                    routing::any(cluster::handle_cluster_proxy),
                )
                .layer(Extension(Arc::clone(&cluster_proxy_config)));

            let leader_router = Router::new()
                .route("/api/leader/status", routing::get(handle_leader_status))
                .route("/api/leader/promote", routing::post(handle_leader_promote))
                .route(
                    "/api/leader/skip-catchup",
                    routing::post(handle_leader_skip_catchup),
                )
                .layer(auth_middleware.clone())
                .with_state(internal_route_config.deployment_state_handle.clone());
            router = router.merge(leader_router);
        }

        if routes_enabled.metrics {
            // Give the `/metrics` handler closure its own clone of the
            // registry. (Nothing later in this function reads
            // `metrics_registry`; the clone just keeps the capture explicit.)
            let metrics_registry_for_handler = metrics_registry.clone();
            let metrics_router = Router::new()
                .route(
                    "/metrics",
                    routing::get(move |headers: HeaderMap| async move {
                        mz_http_util::handle_prometheus(&metrics_registry_for_handler, headers)
                            .await
                    }),
                )
                .route(
                    "/metrics/mz_usage",
                    routing::get(|client: AuthedClient, headers: HeaderMap| async move {
                        let registry = sql::handle_promsql(client, USAGE_METRIC_QUERIES).await;
                        mz_http_util::handle_prometheus(&registry, headers).await
                    }),
                )
                .route(
                    "/metrics/mz_frontier",
                    routing::get(|client: AuthedClient, headers: HeaderMap| async move {
                        let registry = sql::handle_promsql(client, FRONTIER_METRIC_QUERIES).await;
                        mz_http_util::handle_prometheus(&registry, headers).await
                    }),
                )
                .route(
                    "/metrics/mz_compute",
                    routing::get(|client: AuthedClient, headers: HeaderMap| async move {
                        let registry = sql::handle_promsql(client, COMPUTE_METRIC_QUERIES).await;
                        mz_http_util::handle_prometheus(&registry, headers).await
                    }),
                )
                .route(
                    "/metrics/mz_storage",
                    routing::get(|client: AuthedClient, headers: HeaderMap| async move {
                        let registry = sql::handle_promsql(client, STORAGE_METRIC_QUERIES).await;
                        mz_http_util::handle_prometheus(&registry, headers).await
                    }),
                )
                .route(
                    "/api/livez",
                    routing::get(mz_http_util::handle_liveness_check),
                )
                .route("/api/readyz", routing::get(probe::handle_ready))
                .layer(auth_middleware.clone())
                .layer(Extension(adapter_client_rx.clone()))
                .layer(Extension(active_connection_counter.clone()))
                .layer(Extension(HelmChartVersion(helm_chart_version.clone())));
            router = router.merge(metrics_router);
        }

        if routes_enabled.console_config {
            // Mounted without `auth_middleware`.
            let console_config_router = Router::new()
                .route(
                    "/api/console/config",
                    routing::get(console::handle_console_config),
                )
                .layer(Extension(adapter_client_rx.clone()))
                .layer(Extension(active_connection_counter.clone()));
            router = router.merge(console_config_router);
        }

        // MCP (Model Context Protocol) endpoints
        // Enabled via runtime `routes_enabled.mcp_agent` and `routes_enabled.mcp_developer` configuration
        if routes_enabled.mcp_agent || routes_enabled.mcp_developer {
            use tracing::info;

            // RFC 9728 Protected Resource Metadata. Public route: MCP
            // clients fetch it before they have a token. Sits on its own
            // router so the auth middleware never runs on it. The handler
            // 404s when the listener does not advertise OAuth (see
            // `McpOAuthDiscovery`) or `oidc_issuer` is unset, so it is safe
            // to enable unconditionally whenever MCP is enabled.
            // RFC 9728 §3.1 lets clients look up per-resource metadata
            // via a path-suffixed well-known URI before falling back to
            // the bare one. The MCP endpoints share an identical
            // metadata view today, so we serve the same handler at all
            // three paths.
            let oauth_metadata_router = Router::new()
                .route(
                    oauth_metadata::PROTECTED_RESOURCE_METADATA_PATH,
                    routing::get(oauth_metadata::handle_protected_resource_metadata),
                )
                .route(
                    oauth_metadata::PROTECTED_RESOURCE_METADATA_PATH_AGENT,
                    routing::get(oauth_metadata::handle_protected_resource_metadata),
                )
                .route(
                    oauth_metadata::PROTECTED_RESOURCE_METADATA_PATH_DEVELOPER,
                    routing::get(oauth_metadata::handle_protected_resource_metadata),
                )
                .layer(Extension(adapter_client_rx.clone()))
                .layer(Extension(oauth_metadata::McpOAuthConfig {
                    http_host_name: http_host_name.clone(),
                    discovery: Arc::clone(&oauth_discovery),
                }))
                .layer(Extension(oauth_metadata_metrics.clone()));
            router = router.merge(oauth_metadata_router);

            let mut mcp_router = Router::new();

            if routes_enabled.mcp_agent {
                info!("Enabling MCP agent endpoint: /api/mcp/agent");
                mcp_router = mcp_router.route(
                    "/api/mcp/agent",
                    routing::post(mcp::handle_mcp_agent).get(mcp::handle_mcp_method_not_allowed),
                );
            }

            if routes_enabled.mcp_developer {
                info!("Enabling MCP developer endpoint: /api/mcp/developer");
                mcp_router = mcp_router.route(
                    "/api/mcp/developer",
                    routing::post(mcp::handle_mcp_developer)
                        .get(mcp::handle_mcp_method_not_allowed),
                );
            }

            // The MCP handlers perform a server-side Origin check against this
            // allowlist to defend against DNS rebinding attacks (see
            // database-issues#11311). The CorsLayer alone is not enough: in a
            // DNS rebinding attack the browser considers the request
            // same-origin, so no preflight fires and CORS enforcement is
            // bypassed.
            let mcp_allowed_origins = Arc::new(allowed_origin_list.clone());
            // Layers added after `auth_middleware` wrap it, so the extensions
            // below (including `McpOAuthConfig`) are already present on the
            // request when the auth middleware runs.
            mcp_router = mcp_router
                .layer(auth_middleware.clone())
                .layer(Extension(oauth_metadata::McpOAuthConfig {
                    http_host_name: http_host_name.clone(),
                    discovery: Arc::clone(&oauth_discovery),
                }))
                .layer(Extension(adapter_client_rx.clone()))
                .layer(Extension(active_connection_counter.clone()))
                .layer(Extension(HelmChartVersion(helm_chart_version.clone())))
                .layer(Extension(mcp_allowed_origins))
                .layer(Extension(mcp_metrics))
                .layer(
                    CorsLayer::new()
                        .allow_methods(Method::POST)
                        .allow_origin(allowed_origin.clone())
                        .allow_headers([AUTHORIZATION, CONTENT_TYPE]),
                );
            router = router.merge(mcp_router);
        }

        // Common layers for every route collected in `base_router`.
        base_router = base_router
            .layer(auth_middleware.clone())
            .layer(Extension(adapter_client_rx.clone()))
            .layer(Extension(active_connection_counter.clone()))
            .layer(Extension(HelmChartVersion(helm_chart_version)))
            .layer(
                CorsLayer::new()
                    .allow_credentials(false)
                    .allow_headers([
                        AUTHORIZATION,
                        CONTENT_TYPE,
                        HeaderName::from_static("x-materialize-version"),
                    ])
                    .allow_methods(Any)
                    .allow_origin(allowed_origin)
                    .expose_headers(Any)
                    // Cache preflight results for one hour.
                    .max_age(Duration::from_secs(60) * 60),
            );

        match authenticator_kind {
            listeners::AuthenticatorKind::Password | listeners::AuthenticatorKind::Oidc => {
                base_router = base_router.layer(session_layer.clone());

                // The session layer applied to `router` here covers every
                // sub-router merged into it so far (plus the login routes);
                // `base_router` received its own copy just above.
                let login_router = Router::new()
                    .route("/api/login", routing::post(handle_login))
                    .route("/api/logout", routing::post(handle_logout))
                    .layer(Extension(adapter_client_rx))
                    .layer(Extension(allowed_roles));
                router = router.merge(login_router).layer(session_layer);
            }
            listeners::AuthenticatorKind::None => {
                // Layered outside `auth_middleware`, so the header is
                // processed before the auth middleware runs.
                base_router = base_router.layer(middleware::from_fn_with_state(
                    allowed_roles,
                    x_materialize_user_header_auth,
                ));
            }
            // Other authenticator kinds need no additional layers here.
            _ => {}
        }

        router = router
            .merge(base_router)
            .apply_default_layers(source, metrics);

        HttpServer { tls, router }
    }
}
659
660impl Server for HttpServer {
661    const NAME: &'static str = "http";
662
663    fn handle_connection(
664        &self,
665        conn: Connection,
666        _tokio_metrics_intervals: impl Iterator<Item = TaskMetrics> + Send + 'static,
667    ) -> ConnectionHandler {
668        let router = self.router.clone();
669        let tls_context = self.tls.clone();
670        let mut conn = TokioIo::new(conn);
671
672        Box::pin(async {
673            let direct_peer_addr = conn.inner().peer_addr().context("fetching peer addr")?;
674            let peer_addr = conn
675                .inner_mut()
676                .take_proxy_header_address()
677                .await
678                .map(|a| a.source)
679                .unwrap_or(direct_peer_addr);
680
681            let (conn, conn_protocol) = match tls_context {
682                Some(tls_context) => {
683                    let mut ssl_stream = SslStream::new(Ssl::new(&tls_context.get())?, conn)?;
684                    if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
685                        let _ = ssl_stream.get_mut().inner_mut().shutdown().await;
686                        return Err(e.into());
687                    }
688                    (MaybeHttpsStream::Https(ssl_stream), ConnProtocol::Https)
689                }
690                _ => (MaybeHttpsStream::Http(conn), ConnProtocol::Http),
691            };
692            let mut make_tower_svc = router
693                .layer(Extension(conn_protocol))
694                .into_make_service_with_connect_info::<SocketAddr>();
695            let tower_svc = make_tower_svc.call(peer_addr).await.unwrap();
696            let hyper_svc = hyper::service::service_fn(|req| tower_svc.clone().call(req));
697            let http = hyper::server::conn::http1::Builder::new();
698            http.serve_connection(conn, hyper_svc)
699                .with_upgrades()
700                .err_into()
701                .await
702        })
703    }
704}
705
706pub async fn handle_leader_status(
707    State(deployment_state_handle): State<DeploymentStateHandle>,
708) -> impl IntoResponse {
709    let status = deployment_state_handle.status();
710    (StatusCode::OK, Json(json!({ "status": status })))
711}
712
713pub async fn handle_leader_promote(
714    State(deployment_state_handle): State<DeploymentStateHandle>,
715) -> impl IntoResponse {
716    match deployment_state_handle.try_promote() {
717        Ok(()) => {
718            // TODO(benesch): the body here is redundant. Should just return
719            // 204.
720            let status = StatusCode::OK;
721            let body = Json(json!({
722                "result": "Success",
723            }));
724            (status, body)
725        }
726        Err(()) => {
727            // TODO(benesch): the nesting here is redundant given the error
728            // code. Should just return the `{"message": "..."}` object.
729            let status = StatusCode::BAD_REQUEST;
730            let body = Json(json!({
731                "result": {"Failure": {"message": "cannot promote leader while initializing"}},
732            }));
733            (status, body)
734        }
735    }
736}
737
738pub async fn handle_leader_skip_catchup(
739    State(deployment_state_handle): State<DeploymentStateHandle>,
740) -> impl IntoResponse {
741    match deployment_state_handle.try_skip_catchup() {
742        Ok(()) => StatusCode::NO_CONTENT.into_response(),
743        Err(()) => {
744            let status = StatusCode::BAD_REQUEST;
745            let body = Json(json!({
746                "message": "cannot skip catchup in this phase of initialization; try again later",
747            }));
748            (status, body).into_response()
749        }
750    }
751}
752
753async fn x_materialize_user_header_auth(
754    State(allowed_roles): State<AllowedRoles>,
755    mut req: Request,
756    next: Next,
757) -> impl IntoResponse {
758    // TODO migrate teleport to basic auth and remove this.
759    if let Some(username) = req.headers().get("x-materialize-user").map(|h| h.to_str()) {
760        let username = match username {
761            Ok(name @ (SUPPORT_USER_NAME | SYSTEM_USER_NAME)) => name.to_string(),
762            _ => {
763                return Err(AuthError::MismatchedUser(format!(
764                    "user specified in x-materialize-user must be {SUPPORT_USER_NAME} or {SYSTEM_USER_NAME}"
765                )));
766            }
767        };
768        // Enforce the listener's `allowed_roles` policy here. Without this,
769        // a listener with `authenticator_kind=None` and `allowed_roles=Normal`
770        // would let any caller assert `x-materialize-user: mz_system` and
771        // bypass the role restriction.
772        check_role_allowed(&username, allowed_roles)?;
773        req.extensions_mut().insert(AuthedUser {
774            name: username,
775            external_metadata_rx: None,
776            authenticated: Authenticated,
777            authenticator_kind: mz_auth::AuthenticatorKind::None,
778            groups: None,
779        });
780    }
781    Ok(next.run(req).await)
782}
783
/// A value that becomes available later, such as the adapter client, and that
/// many tasks can await.
///
/// Callers in this file await a clone (`.clone().await`). Resolution fails if
/// the sending half of the oneshot channel is dropped without sending.
pub(crate) type Delayed<T> = Shared<oneshot::Receiver<T>>;
785
786/// Resolve the dyncfg-configured group claim path from a delayed adapter
787/// client. Callers must already have driven `adapter_client_rx` to readiness
788/// (e.g. via `get_authenticator`), so the await here is non-blocking.
789async fn group_claim_for(adapter_client_rx: &Delayed<Client>) -> String {
790    let client = adapter_client_rx
791        .clone()
792        .await
793        .expect("adapter client receiver dropped");
794    OIDC_GROUP_CLAIM.get(client.get_system_vars().await.dyncfgs())
795}
796
/// Transport protocol of the connection a request arrived on.
///
/// The connection handler attaches it to every request as an extension. The
/// auth middleware reads it to enforce the listener's TLS configuration.
#[derive(Clone)]
enum ConnProtocol {
    /// Plaintext HTTP.
    Http,
    /// HTTP over TLS.
    Https,
}
802
/// An authenticated caller.
///
/// The auth middleware stores it as a request extension, and it is consumed
/// when an adapter session is opened on the caller's behalf.
#[derive(Clone, Debug)]
pub struct AuthedUser {
    /// User name the adapter session is opened for. It is also the name
    /// checked against the listener's `allowed_roles`.
    name: String,
    /// Metadata stream (user id, admin flag) produced by Frontegg
    /// authentication; `None` for the other authentication paths.
    external_metadata_rx: Option<watch::Receiver<ExternalUserMetadata>>,
    /// Marker that authentication succeeded.
    authenticated: Authenticated,
    /// Which authenticator established this identity.
    authenticator_kind: mz_auth::AuthenticatorKind,
    /// Groups from JWT claims for OIDC group-to-role sync.
    groups: Option<Vec<String>>,
}
812
/// An adapter session that has been opened and started up on behalf of an
/// [`AuthedUser`].
pub struct AuthedClient {
    /// Client for the started-up adapter session.
    pub client: SessionClient,
    /// Slot reserved in the active-connection counter for this session.
    /// NOTE(review): presumably `None` when the counter does not track this
    /// user, and dropping the handle presumably releases the slot. TODO
    /// confirm against `ConnectionCounter`.
    pub connection_guard: Option<ConnectionHandle>,
}
817
818impl AuthedClient {
819    async fn new<F>(
820        adapter_client: &Client,
821        user: AuthedUser,
822        peer_addr: IpAddr,
823        active_connection_counter: ConnectionCounter,
824        helm_chart_version: Option<String>,
825        session_config: F,
826        options: BTreeMap<String, String>,
827        now: NowFn,
828    ) -> Result<Self, AdapterError>
829    where
830        F: FnOnce(&mut AdapterSession),
831    {
832        let conn_id = adapter_client.new_conn_id()?;
833        let mut session = adapter_client.new_session(
834            AdapterSessionConfig {
835                conn_id,
836                uuid: epoch_to_uuid_v7(&(now)()),
837                user: user.name,
838                client_ip: Some(peer_addr),
839                external_metadata_rx: user.external_metadata_rx,
840                helm_chart_version,
841                authenticator_kind: user.authenticator_kind,
842                groups: user.groups,
843            },
844            user.authenticated,
845        );
846        let connection_guard = active_connection_counter.allocate_connection(session.user())?;
847
848        session_config(&mut session);
849        let system_vars = adapter_client.get_system_vars().await;
850        for (key, val) in options {
851            const LOCAL: bool = false;
852            if let Err(err) =
853                session
854                    .vars_mut()
855                    .set(&system_vars, &key, VarInput::Flat(&val), LOCAL)
856            {
857                session.add_notice(AdapterNotice::BadStartupSetting {
858                    name: key.to_string(),
859                    reason: err.to_string(),
860                })
861            }
862        }
863        let adapter_client = adapter_client.startup(session).await?;
864        Ok(AuthedClient {
865            client: adapter_client,
866            connection_guard,
867        })
868    }
869}
870
871impl<S> FromRequestParts<S> for AuthedClient
872where
873    S: Send + Sync,
874{
875    type Rejection = Response;
876
877    async fn from_request_parts(
878        req: &mut http::request::Parts,
879        state: &S,
880    ) -> Result<Self, Self::Rejection> {
881        #[derive(Debug, Default, Deserialize)]
882        struct Params {
883            #[serde(default)]
884            options: String,
885        }
886        let params: Query<Params> = Query::from_request_parts(req, state)
887            .await
888            .unwrap_or_default();
889
890        let peer_addr = req
891            .extensions
892            .get::<ConnectInfo<SocketAddr>>()
893            .expect("ConnectInfo extension guaranteed to exist")
894            .0
895            .ip();
896
897        let user = req.extensions.get::<AuthedUser>().unwrap();
898        let adapter_client = req
899            .extensions
900            .get::<Delayed<mz_adapter::Client>>()
901            .unwrap()
902            .clone();
903        let adapter_client = adapter_client.await.map_err(|_| {
904            (StatusCode::INTERNAL_SERVER_ERROR, "adapter client missing").into_response()
905        })?;
906        let active_connection_counter = req.extensions.get::<ConnectionCounter>().unwrap();
907        let helm_chart_version = req
908            .extensions
909            .get::<HelmChartVersion>()
910            .map(|h| h.0.clone())
911            .unwrap_or(None);
912
913        let options = if params.options.is_empty() {
914            // It's possible 'options' simply wasn't provided, we don't want that to
915            // count as a failure to deserialize
916            BTreeMap::<String, String>::default()
917        } else {
918            match serde_json::from_str(&params.options) {
919                Ok(options) => options,
920                Err(_e) => {
921                    // If we fail to deserialize options, fail the request.
922                    let code = StatusCode::BAD_REQUEST;
923                    let msg = format!("Failed to deserialize {} map", "options".quoted());
924                    return Err((code, msg).into_response());
925                }
926            }
927        };
928
929        let client = AuthedClient::new(
930            &adapter_client,
931            user.clone(),
932            peer_addr,
933            active_connection_counter.clone(),
934            helm_chart_version,
935            |session| {
936                session
937                    .vars_mut()
938                    .set_default(WELCOME_MESSAGE.name(), VarInput::Flat(&false.format()))
939                    .expect("known to exist")
940            },
941            options,
942            SYSTEM_TIME.clone(),
943        )
944        .await
945        .map_err(|e| {
946            let status = match e {
947                AdapterError::UserSessionsDisallowed | AdapterError::NetworkPolicyDenied(_) => {
948                    StatusCode::FORBIDDEN
949                }
950                _ => StatusCode::INTERNAL_SERVER_ERROR,
951            };
952            (status, Json(SqlError::from(e))).into_response()
953        })?;
954
955        Ok(client)
956    }
957}
958
/// Per-request decision about which `WWW-Authenticate` challenges to emit
/// on a 401, computed by the auth middleware.
///
/// Carries both the `Basic` toggle (today's behavior, kept for the SQL HTTP
/// layer and friends) and an optional `Bearer` challenge with a
/// `resource_metadata` URL per RFC 9728. The Bearer challenge is only set on
/// routes that attach an [`oauth_metadata::McpOAuthConfig`] extension; other
/// routes emit only `Basic` so their behavior is unchanged.
///
/// The `Default` value requests no challenges at all. WebSocket initialization
/// uses it because that client reads WebSocket frames rather than HTTP
/// headers.
#[derive(Debug, Clone, Default)]
pub(crate) struct WwwAuthenticateChallenges {
    /// Whether to emit `WWW-Authenticate: Basic realm=Materialize`.
    pub include_basic: bool,
    /// If `Some`, also emit `WWW-Authenticate: Bearer
    /// resource_metadata="<url>"`. The URL points at this server's RFC 9728
    /// Protected Resource Metadata document, which advertises the
    /// authorization server the client should use.
    pub bearer_resource_metadata: Option<String>,
    /// If `Some`, also emit `scope="<scope>"` inside the Bearer challenge.
    /// Tells clients which OAuth scope to request a token with for this
    /// resource. Only set in conjunction with `bearer_resource_metadata`
    /// (a scope challenge with no resource hint would be confusing).
    pub bearer_scope: Option<&'static str>,
}
982
983#[derive(Debug, Error)]
984pub(crate) enum AuthError {
985    #[error("role dissallowed")]
986    RoleDisallowed(String),
987    #[error("{0}")]
988    Frontegg(#[from] FronteggError),
989    #[error("missing authorization header")]
990    MissingHttpAuthentication {
991        challenges: WwwAuthenticateChallenges,
992    },
993    #[error("{0}")]
994    MismatchedUser(String),
995    #[error("session expired")]
996    SessionExpired,
997    #[error("failed to update session")]
998    FailedToUpdateSession,
999    #[error("invalid credentials")]
1000    InvalidCredentials,
1001    /// Payload is `OidcError`'s sanitized `Display` (no expected-values leaks).
1002    #[error("{0}")]
1003    OidcFailed(String),
1004}
1005
1006impl IntoResponse for AuthError {
1007    fn into_response(self) -> Response {
1008        warn!("HTTP request failed authentication: {}", self);
1009        let mut headers = HeaderMap::new();
1010        // We omit most detail from the error message we send to the client, to
1011        // avoid giving attackers unnecessary information. `OidcFailed` is the
1012        // exception: its payload is a sanitized `OidcError::Display` that the
1013        // console embeds in the login-page error.
1014        let body = match &self {
1015            // Bearer goes first so OAuth-aware clients see it before the
1016            // Basic fallback. RFC 7235 allows emitting multiple
1017            // `WWW-Authenticate` headers; we use one per scheme so each
1018            // challenge is unambiguously framed; some parsers struggle
1019            // with multiple schemes on a single header value.
1020            AuthError::MissingHttpAuthentication { challenges } => {
1021                if let Some(resource_metadata) = &challenges.bearer_resource_metadata {
1022                    // `scope` is hard-coded to a vetted constant
1023                    // (`MCP_SCOPE`); only `resource_metadata` is derived
1024                    // from a header value, and `resolve_host` has already
1025                    // round-tripped it through the URI grammar. The quoted
1026                    // form follows RFC 6749 ยง3.3 / RFC 6750 ยง3.
1027                    let value = match &challenges.bearer_scope {
1028                        Some(scope) => format!(
1029                            "Bearer scope=\"{scope}\", resource_metadata=\"{resource_metadata}\"",
1030                        ),
1031                        None => format!("Bearer resource_metadata=\"{resource_metadata}\""),
1032                    };
1033                    match HeaderValue::from_str(&value) {
1034                        Ok(v) => {
1035                            headers.append(http::header::WWW_AUTHENTICATE, v);
1036                        }
1037                        Err(e) => {
1038                            warn!(
1039                                "skipping Bearer WWW-Authenticate challenge: invalid header \
1040                                 value derived from resource_metadata={resource_metadata:?}: {e}",
1041                            );
1042                        }
1043                    }
1044                }
1045                if challenges.include_basic {
1046                    headers.append(
1047                        http::header::WWW_AUTHENTICATE,
1048                        HeaderValue::from_static("Basic realm=Materialize"),
1049                    );
1050                }
1051                "unauthorized".to_string()
1052            }
1053            AuthError::OidcFailed(message) => message.clone(),
1054            _ => "unauthorized".to_string(),
1055        };
1056        (StatusCode::UNAUTHORIZED, headers, body).into_response()
1057    }
1058}
1059
1060// Simplified login handler
1061pub async fn handle_login(
1062    session: Option<Extension<TowerSession>>,
1063    Extension(adapter_client_rx): Extension<Delayed<Client>>,
1064    Extension(allowed_roles): Extension<AllowedRoles>,
1065    Json(LoginCredentials { username, password }): Json<LoginCredentials>,
1066) -> impl IntoResponse {
1067    // Enforce the listener's allowed_roles policy before doing any
1068    // authentication work, mirroring the check performed in `auth` for
1069    // header-based credentials. Without this, a session-based caller could
1070    // log in as a role that header-based callers are forbidden to use.
1071    if let Err(err) = check_role_allowed(&username, allowed_roles) {
1072        warn!(
1073            ?err,
1074            "HTTP login rejected: role not allowed on this listener"
1075        );
1076        return StatusCode::UNAUTHORIZED;
1077    }
1078    let Ok(adapter_client) = adapter_client_rx.clone().await else {
1079        return StatusCode::INTERNAL_SERVER_ERROR;
1080    };
1081    let authenticated = match adapter_client.authenticate(&username, &password).await {
1082        Ok(authenticated) => authenticated,
1083        Err(err) => {
1084            warn!(?err, "HTTP login failed authentication");
1085            return StatusCode::UNAUTHORIZED;
1086        }
1087    };
1088    // Create session data
1089    let session_data = TowerSessionData {
1090        username,
1091        created_at: SystemTime::now(),
1092        last_activity: SystemTime::now(),
1093        authenticated,
1094        authenticator_kind: mz_auth::AuthenticatorKind::Password,
1095    };
1096    // Store session data
1097    let session = session.and_then(|Extension(session)| Some(session));
1098    let Some(session) = session else {
1099        return StatusCode::INTERNAL_SERVER_ERROR;
1100    };
1101    match session.insert("data", &session_data).await {
1102        Err(_) => StatusCode::INTERNAL_SERVER_ERROR,
1103        Ok(_) => StatusCode::OK,
1104    }
1105}
1106
1107// Simplified logout handler
1108pub async fn handle_logout(session: Option<Extension<TowerSession>>) -> impl IntoResponse {
1109    let session = session.and_then(|Extension(session)| Some(session));
1110    let Some(session) = session else {
1111        return StatusCode::INTERNAL_SERVER_ERROR;
1112    };
1113    // Delete session
1114    match session.delete().await {
1115        Err(_) => StatusCode::INTERNAL_SERVER_ERROR,
1116        Ok(_) => StatusCode::OK,
1117    }
1118}
1119
1120async fn http_auth(
1121    mut req: Request,
1122    next: Next,
1123    tls_enabled: bool,
1124    authenticator_kind: listeners::AuthenticatorKind,
1125    frontegg: Option<mz_frontegg_auth::Authenticator>,
1126    oidc_rx: Delayed<mz_authenticator::GenericOidcAuthenticator>,
1127    adapter_client_rx: Delayed<Client>,
1128    allowed_roles: AllowedRoles,
1129) -> Result<impl IntoResponse, AuthError> {
1130    let creds = if let Some(basic) = req.headers().typed_get::<Authorization<Basic>>() {
1131        Some(Credentials::Password {
1132            username: basic.username().to_owned(),
1133            password: Password(basic.password().to_owned()),
1134        })
1135    } else if let Some(bearer) = req.headers().typed_get::<Authorization<Bearer>>() {
1136        Some(Credentials::Token {
1137            token: bearer.token().to_owned(),
1138        })
1139    } else {
1140        None
1141    };
1142
1143    // Reuses an authenticated session if one already exists.
1144    // If credentials are provided, we perform a new authentication,
1145    // separate from the existing session.
1146    if creds.is_none()
1147        && let Some((session, session_data)) =
1148            maybe_get_authenticated_session(req.extensions().get::<TowerSession>()).await
1149    {
1150        let user = ensure_session_unexpired(session, session_data).await?;
1151        // Defense-in-depth: re-check the listener's `allowed_roles` policy on
1152        // every session-authenticated request. The same check runs at
1153        // `/api/login`, but enforcing it here too prevents a session minted
1154        // under a more permissive configuration (or a future bug in the login
1155        // path) from bypassing role restrictions.
1156        check_role_allowed(&user.name, allowed_roles)?;
1157        // Need this to set the user of the Adapter client.
1158        req.extensions_mut().insert(user);
1159        return Ok(next.run(req).await);
1160    }
1161
1162    // First, extract the username from the certificate, validating that the
1163    // connection matches the TLS configuration along the way.
1164    // Fall back to existing authentication methods.
1165    let conn_protocol = req.extensions().get::<ConnProtocol>().unwrap();
1166    match (tls_enabled, &conn_protocol) {
1167        (false, ConnProtocol::Http) => {}
1168        (false, ConnProtocol::Https { .. }) => unreachable!(),
1169        (true, ConnProtocol::Http) => {
1170            let mut parts = req.uri().clone().into_parts();
1171            parts.scheme = Some(Scheme::HTTPS);
1172            return Ok(Redirect::permanent(
1173                &Uri::from_parts(parts)
1174                    .expect("it was already a URI, just changed the scheme")
1175                    .to_string(),
1176            )
1177            .into_response());
1178        }
1179        (true, ConnProtocol::Https { .. }) => {}
1180    }
1181    // If we've already passed some other auth, just use that.
1182    if req.extensions().get::<AuthedUser>().is_some() {
1183        return Ok(next.run(req).await);
1184    }
1185
1186    let path = req.uri().path();
1187    // Routes that advertise OAuth opt in by attaching an `McpOAuthConfig`
1188    // extension; the middleware stays path-agnostic. Routes that opt in also
1189    // get a `Basic` challenge so existing curl/Bearer-already users still see
1190    // a usable challenge. See `crate::http::oauth_metadata` for the discovery
1191    // document; the challenge and the discovery handler read the same
1192    // `McpOAuthConfig`, so they emit the same authorization server and the
1193    // same host for a given listener.
1194    let oauth_config = req
1195        .extensions()
1196        .get::<oauth_metadata::McpOAuthConfig>()
1197        .cloned();
1198    let include_basic = path == "/"
1199        || PROFILING_API_ENDPOINTS
1200            .iter()
1201            .any(|prefix| path.starts_with(prefix))
1202        || oauth_config.is_some();
1203    let (bearer_resource_metadata, bearer_scope) = if let Some(config) = &oauth_config
1204        && config.discovery.is_enabled()
1205    {
1206        (
1207            oauth_metadata::metadata_url(&req, config.http_host_name.as_deref()),
1208            Some(config.scope()),
1209        )
1210    } else {
1211        (None, None)
1212    };
1213    let challenges = WwwAuthenticateChallenges {
1214        include_basic,
1215        bearer_resource_metadata,
1216        bearer_scope,
1217    };
1218    let authenticator = get_authenticator(
1219        authenticator_kind,
1220        creds.as_ref(),
1221        frontegg,
1222        &oidc_rx,
1223        &adapter_client_rx,
1224    )
1225    .await;
1226
1227    // Only the Frontegg arm consumes `group_claim`; resolving it requires a
1228    // coordinator round-trip (`Command::GetSystemVars`), so skip it for
1229    // None/Password/OIDC paths. Otherwise unauth'd probes like `/api/livez`,
1230    // `/api/readyz`, and `/metrics` on a `None`-auth listener would couple
1231    // liveness to coordinator health. Mirrors the pgwire path.
1232    let group_claim = if matches!(authenticator, Authenticator::Frontegg(_)) {
1233        Some(group_claim_for(&adapter_client_rx).await)
1234    } else {
1235        None
1236    };
1237    let user = auth(
1238        &authenticator,
1239        creds,
1240        allowed_roles,
1241        &challenges,
1242        group_claim.as_deref(),
1243    )
1244    .await?;
1245
1246    // Add the authenticated user as an extension so downstream handlers can
1247    // inspect it if necessary.
1248    req.extensions_mut().insert(user);
1249
1250    // Run the request.
1251    Ok(next.run(req).await)
1252}
1253
1254async fn init_ws(
1255    WsState {
1256        frontegg,
1257        oidc_rx,
1258        authenticator_kind,
1259        adapter_client_rx,
1260        active_connection_counter,
1261        helm_chart_version,
1262        allowed_roles,
1263    }: WsState,
1264    existing_user: Option<ExistingUser>,
1265    peer_addr: IpAddr,
1266    ws: &mut WebSocket,
1267) -> Result<AuthedClient, anyhow::Error> {
1268    // TODO: Add a timeout here to prevent resource leaks by clients that
1269    // connect then never send a message.
1270    let ws_auth: WebSocketAuth = loop {
1271        let init_msg = ws.recv().await.ok_or_else(|| anyhow::anyhow!("closed"))??;
1272        match init_msg {
1273            Message::Text(data) => break serde_json::from_str(&data)?,
1274            Message::Binary(data) => break serde_json::from_slice(&data)?,
1275            // Handled automatically by the server.
1276            Message::Ping(_) => {
1277                continue;
1278            }
1279            Message::Pong(_) => {
1280                continue;
1281            }
1282            Message::Close(_) => {
1283                anyhow::bail!("closed");
1284            }
1285        }
1286    };
1287
1288    // If credentials are provided, we perform a new authentication,
1289    // separate from the existing session.
1290    let (creds, options) = match ws_auth {
1291        WebSocketAuth::Basic {
1292            user,
1293            password,
1294            options,
1295        } => {
1296            let creds = Credentials::Password {
1297                username: user,
1298                password,
1299            };
1300            (Some(creds), options)
1301        }
1302        WebSocketAuth::Bearer { token, options } => {
1303            let creds = Credentials::Token { token };
1304            (Some(creds), options)
1305        }
1306        WebSocketAuth::OptionsOnly { options } => (None, options),
1307    };
1308
1309    let user = match (existing_user, creds) {
1310        (Some(ExistingUser::XMaterializeUserHeader(_)), Some(_creds)) => {
1311            warn!("Unexpected bearer or basic auth provided when using user header");
1312            anyhow::bail!("unexpected")
1313        }
1314        (Some(ExistingUser::Session(user)), None) => {
1315            // Defense-in-depth: re-enforce the listener's `allowed_roles`
1316            // policy on session-authenticated WebSocket connections. The same
1317            // check runs at `/api/login`, but enforcing it here too prevents a
1318            // session minted under a more permissive configuration (or a
1319            // future bug in the login path) from bypassing role restrictions.
1320            check_role_allowed(&user.name, allowed_roles)?;
1321            user
1322        }
1323        (Some(ExistingUser::XMaterializeUserHeader(user)), None) => user,
1324        (_, Some(creds)) => {
1325            let authenticator = get_authenticator(
1326                authenticator_kind,
1327                Some(&creds),
1328                frontegg,
1329                &oidc_rx,
1330                &adapter_client_rx,
1331            )
1332            .await;
1333            // WebSocket init: no 401-with-challenge contract, the
1334            // client is reading WS frames, not parsing HTTP headers, so
1335            // we just suppress challenge emission entirely.
1336            let no_challenges = WwwAuthenticateChallenges::default();
1337            // See `http_auth`: only Frontegg uses `group_claim`, and the
1338            // fetch costs a coordinator round-trip.
1339            let group_claim = if matches!(authenticator, Authenticator::Frontegg(_)) {
1340                Some(group_claim_for(&adapter_client_rx).await)
1341            } else {
1342                None
1343            };
1344            let user = auth(
1345                &authenticator,
1346                Some(creds),
1347                allowed_roles,
1348                &no_challenges,
1349                group_claim.as_deref(),
1350            )
1351            .await?;
1352            user
1353        }
1354        (None, None) => anyhow::bail!("expected auth information"),
1355    };
1356
1357    let client = AuthedClient::new(
1358        &adapter_client_rx.clone().await?,
1359        user,
1360        peer_addr,
1361        active_connection_counter.clone(),
1362        helm_chart_version.clone(),
1363        |_session| (),
1364        options,
1365        SYSTEM_TIME.clone(),
1366    )
1367    .await?;
1368
1369    Ok(client)
1370}
1371
/// Credentials presented by a caller, either in an `Authorization` header or
/// in a WebSocket auth message.
enum Credentials {
    /// Username and password (HTTP Basic auth or `WebSocketAuth::Basic`).
    Password {
        username: String,
        password: Password,
    },
    /// Bearer token (HTTP Bearer auth or `WebSocketAuth::Bearer`).
    Token {
        token: String,
    },
}
1381
1382async fn get_authenticator(
1383    kind: listeners::AuthenticatorKind,
1384    creds: Option<&Credentials>,
1385    frontegg: Option<mz_frontegg_auth::Authenticator>,
1386    oidc_rx: &Delayed<mz_authenticator::GenericOidcAuthenticator>,
1387    adapter_client_rx: &Delayed<Client>,
1388) -> Authenticator {
1389    match kind {
1390        listeners::AuthenticatorKind::Frontegg => Authenticator::Frontegg(frontegg.expect(
1391            "Frontegg authenticator should exist with listeners::AuthenticatorKind::Frontegg",
1392        )),
1393        listeners::AuthenticatorKind::Password | listeners::AuthenticatorKind::Sasl => {
1394            let client = adapter_client_rx.clone().await.expect("sender not dropped");
1395            Authenticator::Password(client)
1396        }
1397        listeners::AuthenticatorKind::Oidc => match creds {
1398            // Use the password authenticator if the credentials are password-based
1399            Some(Credentials::Password { .. }) => {
1400                let client = adapter_client_rx.clone().await.expect("sender not dropped");
1401                Authenticator::Password(client)
1402            }
1403            _ => Authenticator::Oidc(oidc_rx.clone().await.expect("sender not dropped")),
1404        },
1405        listeners::AuthenticatorKind::None => Authenticator::None,
1406    }
1407}
1408
1409/// Attempts to retrieve session data from a [`TowerSession`], if available.
1410/// Session data is present only if an authenticated session has been
1411/// established via [`handle_login`].
1412pub(crate) async fn maybe_get_authenticated_session(
1413    session: Option<&TowerSession>,
1414) -> Option<(&TowerSession, TowerSessionData)> {
1415    if let Some(session) = session {
1416        if let Ok(Some(session_data)) = session.get::<TowerSessionData>("data").await {
1417            return Some((session, session_data));
1418        }
1419    }
1420    None
1421}
1422
1423/// Ensures the session is still valid by checking for expiration,
1424/// and returns the associated user if the session remains active.
1425pub(crate) async fn ensure_session_unexpired(
1426    session: &TowerSession,
1427    session_data: TowerSessionData,
1428) -> Result<AuthedUser, AuthError> {
1429    if session_data
1430        .last_activity
1431        .elapsed()
1432        .unwrap_or(Duration::MAX)
1433        > SESSION_DURATION
1434    {
1435        let _ = session.delete().await;
1436        return Err(AuthError::SessionExpired);
1437    }
1438    let mut updated_data = session_data.clone();
1439    updated_data.last_activity = SystemTime::now();
1440    session
1441        .insert("data", &updated_data)
1442        .await
1443        .map_err(|_| AuthError::FailedToUpdateSession)?;
1444
1445    Ok(AuthedUser {
1446        name: session_data.username,
1447        external_metadata_rx: None,
1448        authenticated: session_data.authenticated,
1449        authenticator_kind: session_data.authenticator_kind,
1450        groups: None,
1451    })
1452}
1453
1454async fn auth(
1455    authenticator: &Authenticator,
1456    creds: Option<Credentials>,
1457    allowed_roles: AllowedRoles,
1458    challenges: &WwwAuthenticateChallenges,
1459    group_claim: Option<&str>,
1460) -> Result<AuthedUser, AuthError> {
1461    let (name, external_metadata_rx, authenticated, groups) = match authenticator {
1462        Authenticator::Frontegg(frontegg) => match creds {
1463            Some(Credentials::Password { username, password }) => {
1464                let (auth_session, authenticated) = frontegg
1465                    .authenticate(&username, password.as_str(), group_claim)
1466                    .await?;
1467                let name = auth_session.user().into();
1468                let groups = auth_session.groups();
1469                let external_metadata_rx = Some(auth_session.external_metadata_rx());
1470                (name, external_metadata_rx, authenticated, groups)
1471            }
1472            Some(Credentials::Token { token }) => {
1473                let (claims, authenticated) =
1474                    frontegg.validate_access_token(&token, None, group_claim)?;
1475                let (_, external_metadata_rx) = watch::channel(ExternalUserMetadata {
1476                    user_id: claims.user_id,
1477                    admin: claims.is_admin,
1478                });
1479                (
1480                    claims.user,
1481                    Some(external_metadata_rx),
1482                    authenticated,
1483                    claims.groups,
1484                )
1485            }
1486            None => {
1487                return Err(AuthError::MissingHttpAuthentication {
1488                    challenges: challenges.clone(),
1489                });
1490            }
1491        },
1492        Authenticator::Password(adapter_client) => match creds {
1493            Some(Credentials::Password { username, password }) => {
1494                let authenticated = adapter_client
1495                    .authenticate(&username, &password)
1496                    .await
1497                    .map_err(|_| AuthError::InvalidCredentials)?;
1498                (username, None, authenticated, None)
1499            }
1500            _ => {
1501                return Err(AuthError::MissingHttpAuthentication {
1502                    challenges: challenges.clone(),
1503                });
1504            }
1505        },
1506        Authenticator::Sasl(_) => {
1507            // We shouldn't ever end up here as the configuration is validated at startup.
1508            // If we do, it's a server misconfiguration.
1509            // Just in case, we return a 401 rather than panic.
1510            return Err(AuthError::MissingHttpAuthentication {
1511                challenges: challenges.clone(),
1512            });
1513        }
1514        Authenticator::Oidc(oidc) => match creds {
1515            Some(Credentials::Token { token }) => {
1516                let (mut claims, authenticated) = oidc
1517                    .authenticate(&token, None)
1518                    .await
1519                    .map_err(|e| AuthError::OidcFailed(e.to_string()))?;
1520                let name = std::mem::take(&mut claims.user);
1521                let groups = claims.groups.take();
1522                (name, None, authenticated, groups)
1523            }
1524            _ => {
1525                return Err(AuthError::MissingHttpAuthentication {
1526                    challenges: challenges.clone(),
1527                });
1528            }
1529        },
1530        Authenticator::None => {
1531            // If no authentication, use whatever is in the HTTP auth
1532            // header (without checking the password), or fall back to the
1533            // default user.
1534            let name = match creds {
1535                Some(Credentials::Password { username, .. }) => username,
1536                _ => HTTP_DEFAULT_USER.name.to_owned(),
1537            };
1538            (name, None, Authenticated, None)
1539        }
1540    };
1541
1542    check_role_allowed(&name, allowed_roles)?;
1543
1544    Ok(AuthedUser {
1545        name,
1546        external_metadata_rx,
1547        authenticated,
1548        authenticator_kind: authenticator.kind(),
1549        groups,
1550    })
1551}
1552
1553// TODO move this somewhere it can be shared with PGWIRE
1554fn check_role_allowed(name: &str, allowed_roles: AllowedRoles) -> Result<(), AuthError> {
1555    let is_internal_user = INTERNAL_USER_NAMES.contains(name);
1556    // this is a superset of internal users
1557    let is_reserved_user = mz_adapter::catalog::is_reserved_role_name(name);
1558    let role_allowed = match allowed_roles {
1559        AllowedRoles::Normal => !is_reserved_user,
1560        AllowedRoles::Internal => is_internal_user,
1561        AllowedRoles::NormalAndInternal => !is_reserved_user || is_internal_user,
1562    };
1563    if role_allowed {
1564        Ok(())
1565    } else {
1566        Err(AuthError::RoleDisallowed(name.to_owned()))
1567    }
1568}
1569
1570/// Default layers that should be applied to all routes, and should get applied to both the
1571/// internal http and external http routers.
1572trait DefaultLayers {
1573    fn apply_default_layers(self, source: &'static str, metrics: Metrics) -> Self;
1574}
1575
1576impl DefaultLayers for Router {
1577    fn apply_default_layers(self, source: &'static str, metrics: Metrics) -> Self {
1578        self.layer(DefaultBodyLimit::max(MAX_REQUEST_SIZE))
1579            .layer(metrics::PrometheusLayer::new(source, metrics))
1580    }
1581}
1582
1583/// Glue code to make [`tower`] work with [`axum`].
1584///
1585/// `axum` requires `Layer`s not return Errors, i.e. they must be `Result<_, Infallible>`,
1586/// instead you must return a type that can be converted into a response. `tower` on the other
1587/// hand does return Errors, so to make the two work together we need to convert our `tower` errors
1588/// into responses.
1589async fn handle_load_error(error: tower::BoxError) -> impl IntoResponse {
1590    if error.is::<tower::load_shed::error::Overloaded>() {
1591        return (
1592            StatusCode::TOO_MANY_REQUESTS,
1593            Cow::from("too many requests, try again later"),
1594        );
1595    }
1596
1597    // Note: This should be unreachable because at the time of writing our only use case is a
1598    // layer that emits `tower::load_shed::error::Overloaded`, which is handled above.
1599    (
1600        StatusCode::INTERNAL_SERVER_ERROR,
1601        Cow::from(format!("Unhandled internal error: {}", error)),
1602    )
1603}
1604
/// Username/password pair submitted when logging in.
///
/// NOTE(review): presumably the request body of `POST /api/login`; confirm against the
/// login handler.
#[derive(Debug, Deserialize, Serialize, PartialEq)]
pub struct LoginCredentials {
    /// Name of the role logging in.
    username: String,
    /// Password for `username`.
    password: Password,
}
1610
/// Authentication state kept in a caller's HTTP session. Later requests can reuse it
/// without resending credentials.
///
/// The session-reuse path stores this under the session key `"data"`. On each reuse it
/// refreshes `last_activity` and builds an `AuthedUser` from `username`, `authenticated`
/// and `authenticator_kind`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TowerSessionData {
    /// Name of the user the session belongs to.
    username: String,
    /// When the session was created (presumably at login — confirm against the writer).
    created_at: SystemTime,
    /// Wall-clock time of the most recent request that reused this session.
    last_activity: SystemTime,
    /// Copied into `AuthedUser::authenticated` when the session is reused.
    authenticated: Authenticated,
    /// Kind of authenticator that established the session.
    authenticator_kind: mz_auth::AuthenticatorKind,
}
1619
#[cfg(test)]
mod tests {
    use super::{AllowedRoles, check_role_allowed};

    #[mz_ore::test]
    fn test_check_role_allowed() {
        // Each row: (role name,
        //            admitted under Internal,
        //            admitted under NormalAndInternal,
        //            admitted under Normal).
        let cases = [
            // Internal users.
            ("mz_system", true, true, false),
            ("mz_support", true, true, false),
            ("mz_analytics", true, true, false),
            // Normal users.
            ("materialize", false, true, true),
            ("anonymous_http_user", false, true, true),
            ("alex", false, true, true),
            // Denied by a reserved role prefix.
            ("external_asdf", false, false, false),
            ("pg_somebody", false, false, false),
            ("mz_unknown", false, false, false),
            // Denied by the literal PUBLIC.
            ("PUBLIC", false, false, false),
        ];

        for (name, internal, normal_and_internal, normal) in cases {
            assert_eq!(
                check_role_allowed(name, AllowedRoles::Internal).is_ok(),
                internal,
                "{name} under AllowedRoles::Internal"
            );
            assert_eq!(
                check_role_allowed(name, AllowedRoles::NormalAndInternal).is_ok(),
                normal_and_internal,
                "{name} under AllowedRoles::NormalAndInternal"
            );
            assert_eq!(
                check_role_allowed(name, AllowedRoles::Normal).is_ok(),
                normal,
                "{name} under AllowedRoles::Normal"
            );
        }
    }
}