Skip to main content

mz_environmentd/
http.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Embedded HTTP server.
11//!
12//! environmentd embeds an HTTP server for introspection into the running
13//! process. At the moment, its primary exports are Prometheus metrics, heap
14//! profiles, and catalog dumps.
15//!
16//! ## Authentication flow
17//!
18//! The server supports several authentication modes, controlled by the
19//! configured [`listeners::AuthenticatorKind`]. The general flow is:
20//!
21//! 1. **Identity resolution.** An authentication middleware runs on every
22//!    protected request and resolves the caller's identity via one of:
23//!    - **Credentials in headers.** The caller supplies a username/password or
24//!      token in the request headers. Supported by all [`listeners::AuthenticatorKind`]s.
25//!    - **Session reuse.** If the caller has an active authenticated session
26//!      (established via `POST /api/login`) and has not supplied credentials
27//!      in the request headers, the session is reused. Only available for
28//!      [`listeners::AuthenticatorKind::Password`] and [`listeners::AuthenticatorKind::Oidc`].
29//!    - **Trusted header injection.** A trusted upstream proxy (e.g. Teleport)
30//!      may inject the caller's identity into the request headers. Only available
31//!      for [`listeners::AuthenticatorKind::None`].
32//!
33//! 2. **Session initialization.** Once the caller's identity is known, an
34//!    adapter session is opened on their behalf. This happens as part of
35//!    request processing, after all middleware has run.
36//!
37//! 3. **Request handling.** The handler executes the request (e.g. runs SQL)
38//!    using the initialized adapter session.
39//!
40//! ### WebSocket
41//!
42//! The WebSocket flow is identical to the HTTP flow with two differences:
43//!
44//! - Credentials are not read from request headers. Instead, the first
45//!   message sent by the client is treated as the authentication message.
46//! - Session initialization (step 2) happens inside the WebSocket handler
47//!   itself, rather than as a separate step before the handler runs.
48
49// Axum handlers must use async, but often don't actually use `await`.
50#![allow(clippy::unused_async)]
51
52use std::borrow::Cow;
53use std::collections::BTreeMap;
54use std::fmt::Debug;
55use std::net::{IpAddr, SocketAddr};
56use std::pin::Pin;
57use std::sync::Arc;
58use std::time::{Duration, SystemTime};
59
60use anyhow::Context;
61use axum::error_handling::HandleErrorLayer;
62use axum::extract::ws::{Message, WebSocket};
63use axum::extract::{ConnectInfo, DefaultBodyLimit, FromRequestParts, Query, Request, State};
64use axum::middleware::{self, Next};
65use axum::response::{IntoResponse, Redirect, Response};
66use axum::{Extension, Json, Router, routing};
67use futures::future::{Shared, TryFutureExt};
68use headers::authorization::{Authorization, Basic, Bearer};
69use headers::{HeaderMapExt, HeaderName};
70use http::header::{AUTHORIZATION, CONTENT_TYPE};
71use http::uri::Scheme;
72use http::{HeaderMap, HeaderValue, Method, StatusCode, Uri};
73use hyper_openssl::SslStream;
74use hyper_openssl::client::legacy::MaybeHttpsStream;
75use hyper_util::rt::TokioIo;
76use mz_adapter::session::{Session as AdapterSession, SessionConfig as AdapterSessionConfig};
77use mz_adapter::{AdapterError, AdapterNotice, Client, SessionClient, WebhookAppenderCache};
78use mz_auth::Authenticated;
79use mz_auth::password::Password;
80use mz_authenticator::Authenticator;
81use mz_controller::ReplicaHttpLocator;
82use mz_frontegg_auth::Error as FronteggError;
83use mz_http_util::DynamicFilterTarget;
84use mz_ore::cast::u64_to_usize;
85use mz_ore::metrics::MetricsRegistry;
86use mz_ore::now::{NowFn, SYSTEM_TIME, epoch_to_uuid_v7};
87use mz_ore::str::StrExt;
88use mz_pgwire_common::{ConnectionCounter, ConnectionHandle};
89use mz_repr::user::ExternalUserMetadata;
90use mz_server_core::listeners::{self, AllowedRoles, HttpRoutesEnabled};
91use mz_server_core::{Connection, ConnectionHandler, ReloadingSslContext, Server};
92use mz_sql::session::metadata::SessionMetadata;
93use mz_sql::session::user::{
94    HTTP_DEFAULT_USER, INTERNAL_USER_NAMES, SUPPORT_USER_NAME, SYSTEM_USER_NAME,
95};
96use mz_sql::session::vars::{Value, Var, VarInput, WELCOME_MESSAGE};
97use openssl::ssl::Ssl;
98use prometheus::{
99    COMPUTE_METRIC_QUERIES, FRONTIER_METRIC_QUERIES, STORAGE_METRIC_QUERIES, USAGE_METRIC_QUERIES,
100};
101use serde::{Deserialize, Serialize};
102use serde_json::json;
103use thiserror::Error;
104use tokio::io::AsyncWriteExt;
105use tokio::sync::oneshot::Receiver;
106use tokio::sync::{oneshot, watch};
107use tokio_metrics::TaskMetrics;
108use tower::limit::GlobalConcurrencyLimitLayer;
109use tower::{Service, ServiceBuilder};
110use tower_http::cors::{AllowOrigin, Any, CorsLayer};
111use tower_sessions::{
112    MemoryStore as TowerSessionMemoryStore, Session as TowerSession,
113    SessionManagerLayer as TowerSessionManagerLayer,
114};
115use tracing::warn;
116
117use crate::BUILD_INFO;
118use crate::deployment::state::DeploymentStateHandle;
119use crate::http::sql::{ExistingUser, SqlError};
120
121mod catalog;
122mod cluster;
123mod console;
124mod mcp;
125pub mod mcp_metrics;
126mod memory;
127mod metrics;
128mod metrics_public;
129mod metrics_viz;
130mod probe;
131mod prometheus;
132mod root;
133mod sql;
134mod webhook;
135
136pub use metrics::Metrics;
137pub use sql::{SqlResponse, WebSocketAuth, WebSocketResponse};
138
139/// Maximum allowed size for a request.
140pub const MAX_REQUEST_SIZE: usize = u64_to_usize(5 * bytesize::MIB);
141
142const SESSION_DURATION: Duration = Duration::from_secs(8 * 3600); // 8 hours
143
144const PROFILING_API_ENDPOINTS: &[&str] = &["/memory", "/hierarchical-memory", "/prof/"];
145
146#[derive(Debug)]
147pub struct HttpConfig {
148    pub source: &'static str,
149    pub tls: Option<ReloadingSslContext>,
150    pub authenticator_kind: listeners::AuthenticatorKind,
151    pub frontegg: Option<mz_frontegg_auth::Authenticator>,
152    pub oidc_rx: Delayed<mz_authenticator::GenericOidcAuthenticator>,
153    pub adapter_client_rx: Shared<Receiver<Client>>,
154    pub allowed_origin: AllowOrigin,
155    /// Raw list of allowed CORS origins, used by the MCP endpoints for
156    /// server-side Origin validation to defend against DNS rebinding.
157    pub allowed_origin_list: Vec<HeaderValue>,
158    pub active_connection_counter: ConnectionCounter,
159    pub helm_chart_version: Option<String>,
160    pub concurrent_webhook_req: Arc<tokio::sync::Semaphore>,
161    pub metrics: Metrics,
162    pub metrics_registry: MetricsRegistry,
163    pub mcp_metrics: mcp_metrics::McpMetrics,
164    pub allowed_roles: AllowedRoles,
165    pub internal_route_config: Arc<InternalRouteConfig>,
166    pub routes_enabled: HttpRoutesEnabled,
167    /// Locator for cluster replica HTTP addresses, used for proxying requests.
168    pub replica_http_locator: Arc<ReplicaHttpLocator>,
169}
170
171#[derive(Debug, Clone)]
172pub struct InternalRouteConfig {
173    pub deployment_state_handle: DeploymentStateHandle,
174    pub internal_console_redirect_url: Option<String>,
175}
176
177#[derive(Clone)]
178pub struct WsState {
179    frontegg: Option<mz_frontegg_auth::Authenticator>,
180    oidc_rx: Delayed<mz_authenticator::GenericOidcAuthenticator>,
181    authenticator_kind: listeners::AuthenticatorKind,
182    adapter_client_rx: Delayed<mz_adapter::Client>,
183    active_connection_counter: ConnectionCounter,
184    helm_chart_version: Option<String>,
185    allowed_roles: AllowedRoles,
186}
187
188#[derive(Clone)]
189pub struct WebhookState {
190    adapter_client_rx: Delayed<mz_adapter::Client>,
191    webhook_cache: WebhookAppenderCache,
192}
193
/// Newtype around the optional Helm chart version so that it can be attached
/// to requests as a distinct `Extension`.
#[derive(Debug, Clone)]
struct HelmChartVersion(Option<String>);
196
197#[derive(Debug)]
198pub struct HttpServer {
199    tls: Option<ReloadingSslContext>,
200    router: Router,
201}
202
203impl HttpServer {
204    pub fn new(
205        HttpConfig {
206            source,
207            tls,
208            authenticator_kind,
209            frontegg,
210            oidc_rx,
211            adapter_client_rx,
212            allowed_origin,
213            allowed_origin_list,
214            active_connection_counter,
215            helm_chart_version,
216            concurrent_webhook_req,
217            metrics,
218            metrics_registry,
219            mcp_metrics,
220            allowed_roles,
221            internal_route_config,
222            routes_enabled,
223            replica_http_locator,
224        }: HttpConfig,
225    ) -> HttpServer {
226        let tls_enabled = tls.is_some();
227        let webhook_cache = WebhookAppenderCache::new();
228
229        // Create secure session store and manager
230        let session_store = TowerSessionMemoryStore::default();
231        let session_layer = TowerSessionManagerLayer::new(session_store)
232            .with_secure(tls_enabled) // Enforce HTTPS
233            .with_same_site(tower_sessions::cookie::SameSite::Strict) // Prevent CSRF
234            .with_http_only(true) // Prevent XSS
235            .with_name("mz_session") // Custom cookie name
236            .with_path("/"); // Set cookie path
237
238        let frontegg_middleware = frontegg.clone();
239        let oidc_middleware_rx = oidc_rx.clone();
240        let adapter_client_middleware_rx = adapter_client_rx.clone();
241        let auth_middleware = middleware::from_fn(move |req, next| {
242            let frontegg = frontegg_middleware.clone();
243            let oidc_rx = oidc_middleware_rx.clone();
244            let adapter_client_rx = adapter_client_middleware_rx.clone();
245            async move {
246                http_auth(
247                    req,
248                    next,
249                    tls_enabled,
250                    authenticator_kind,
251                    frontegg,
252                    oidc_rx,
253                    adapter_client_rx,
254                    allowed_roles,
255                )
256                .await
257            }
258        });
259
260        let mut router = Router::new();
261        let mut base_router = Router::new();
262        let cluster_proxy_config = Arc::new(cluster::ClusterProxyConfig::new(Arc::clone(
263            &replica_http_locator,
264        )));
265        if routes_enabled.base {
266            base_router = base_router
267                .route(
268                    "/",
269                    routing::get(move || async move { root::handle_home(routes_enabled).await }),
270                )
271                .route("/api/sql", routing::post(sql::handle_sql))
272                .route("/memory", routing::get(memory::handle_memory))
273                .route(
274                    "/hierarchical-memory",
275                    routing::get(memory::handle_hierarchical_memory),
276                )
277                .route(
278                    "/metrics-viz",
279                    routing::get(metrics_viz::handle_metrics_viz),
280                )
281                .route("/static/{*path}", routing::get(root::handle_static))
282                .route(
283                    "/metrics/public",
284                    routing::get(metrics_public::handle_public_metrics),
285                )
286                .layer(Extension(metrics_registry.clone()))
287                .layer(Extension(Arc::clone(&cluster_proxy_config)));
288
289            let mut ws_router = Router::new()
290                .route("/api/experimental/sql", routing::get(sql::handle_sql_ws))
291                .with_state(WsState {
292                    frontegg,
293                    oidc_rx: oidc_rx.clone(),
294                    authenticator_kind,
295                    adapter_client_rx: adapter_client_rx.clone(),
296                    active_connection_counter: active_connection_counter.clone(),
297                    helm_chart_version: helm_chart_version.clone(),
298                    allowed_roles,
299                });
300            if let listeners::AuthenticatorKind::None = authenticator_kind {
301                ws_router = ws_router.layer(middleware::from_fn_with_state(
302                    allowed_roles,
303                    x_materialize_user_header_auth,
304                ));
305            }
306            router = router.merge(ws_router);
307        }
308        if routes_enabled.profiling {
309            base_router = base_router.nest("/prof/", mz_prof_http::router(&BUILD_INFO));
310        }
311
312        if routes_enabled.webhook {
313            let webhook_router = Router::new()
314                .route(
315                    "/api/webhook/{:database}/{:schema}/{:id}",
316                    routing::post(webhook::handle_webhook),
317                )
318                .with_state(WebhookState {
319                    adapter_client_rx: adapter_client_rx.clone(),
320                    webhook_cache,
321                })
322                .layer(
323                    tower_http::decompression::RequestDecompressionLayer::new()
324                        .gzip(true)
325                        .deflate(true)
326                        .br(true)
327                        .zstd(true),
328                )
329                .layer(
330                    CorsLayer::new()
331                        .allow_methods(Method::POST)
332                        .allow_origin(AllowOrigin::mirror_request())
333                        .allow_headers(Any),
334                )
335                .layer(
336                    ServiceBuilder::new()
337                        .layer(HandleErrorLayer::new(handle_load_error))
338                        .load_shed()
339                        .layer(GlobalConcurrencyLimitLayer::with_semaphore(
340                            concurrent_webhook_req,
341                        )),
342                );
343            router = router.merge(webhook_router);
344        }
345
346        if routes_enabled.internal {
347            let console_config = Arc::new(console::ConsoleProxyConfig::new(
348                internal_route_config.internal_console_redirect_url.clone(),
349                "/internal-console".to_string(),
350            ));
351            base_router = base_router
352                .route(
353                    "/api/opentelemetry/config",
354                    routing::put({
355                        move |_: axum::Json<DynamicFilterTarget>| async {
356                            (
357                                StatusCode::BAD_REQUEST,
358                                "This endpoint has been replaced. \
359                            Use the `opentelemetry_filter` system variable."
360                                    .to_string(),
361                            )
362                        }
363                    }),
364                )
365                .route(
366                    "/api/stderr/config",
367                    routing::put({
368                        move |_: axum::Json<DynamicFilterTarget>| async {
369                            (
370                                StatusCode::BAD_REQUEST,
371                                "This endpoint has been replaced. \
372                            Use the `log_filter` system variable."
373                                    .to_string(),
374                            )
375                        }
376                    }),
377                )
378                .route("/api/tracing", routing::get(mz_http_util::handle_tracing))
379                .route(
380                    "/api/catalog/dump",
381                    routing::get(catalog::handle_catalog_dump),
382                )
383                .route(
384                    "/api/catalog/check",
385                    routing::get(catalog::handle_catalog_check),
386                )
387                .route(
388                    "/api/catalog/inject-audit-events",
389                    routing::post(catalog::handle_inject_audit_events),
390                )
391                .route(
392                    "/api/coordinator/check",
393                    routing::get(catalog::handle_coordinator_check),
394                )
395                .route(
396                    "/api/coordinator/dump",
397                    routing::get(catalog::handle_coordinator_dump),
398                )
399                .route(
400                    "/internal-console",
401                    routing::get(|| async { Redirect::temporary("/internal-console/") }),
402                )
403                .route(
404                    "/internal-console/{*path}",
405                    routing::get(console::handle_internal_console),
406                )
407                .route(
408                    "/internal-console/",
409                    routing::get(console::handle_internal_console),
410                )
411                .layer(Extension(console_config));
412
413            // Cluster HTTP proxy routes.
414            base_router = base_router
415                .route("/clusters", routing::get(cluster::handle_clusters))
416                .route(
417                    "/api/cluster/{:cluster_id}/replica/{:replica_id}/process/{:process}/",
418                    routing::any(cluster::handle_cluster_proxy_root),
419                )
420                .route(
421                    "/api/cluster/{:cluster_id}/replica/{:replica_id}/process/{:process}/{*path}",
422                    routing::any(cluster::handle_cluster_proxy),
423                )
424                .layer(Extension(Arc::clone(&cluster_proxy_config)));
425
426            let leader_router = Router::new()
427                .route("/api/leader/status", routing::get(handle_leader_status))
428                .route("/api/leader/promote", routing::post(handle_leader_promote))
429                .route(
430                    "/api/leader/skip-catchup",
431                    routing::post(handle_leader_skip_catchup),
432                )
433                .layer(auth_middleware.clone())
434                .with_state(internal_route_config.deployment_state_handle.clone());
435            router = router.merge(leader_router);
436        }
437
438        if routes_enabled.metrics {
439            // Clone into the closure so the outer `metrics_registry` binding
440            // stays available for other route blocks below (e.g. MCP metric
441            // registration).
442            let metrics_registry_for_handler = metrics_registry.clone();
443            let metrics_router = Router::new()
444                .route(
445                    "/metrics",
446                    routing::get(move |headers: HeaderMap| async move {
447                        mz_http_util::handle_prometheus(&metrics_registry_for_handler, headers)
448                            .await
449                    }),
450                )
451                .route(
452                    "/metrics/mz_usage",
453                    routing::get(|client: AuthedClient, headers: HeaderMap| async move {
454                        let registry = sql::handle_promsql(client, USAGE_METRIC_QUERIES).await;
455                        mz_http_util::handle_prometheus(&registry, headers).await
456                    }),
457                )
458                .route(
459                    "/metrics/mz_frontier",
460                    routing::get(|client: AuthedClient, headers: HeaderMap| async move {
461                        let registry = sql::handle_promsql(client, FRONTIER_METRIC_QUERIES).await;
462                        mz_http_util::handle_prometheus(&registry, headers).await
463                    }),
464                )
465                .route(
466                    "/metrics/mz_compute",
467                    routing::get(|client: AuthedClient, headers: HeaderMap| async move {
468                        let registry = sql::handle_promsql(client, COMPUTE_METRIC_QUERIES).await;
469                        mz_http_util::handle_prometheus(&registry, headers).await
470                    }),
471                )
472                .route(
473                    "/metrics/mz_storage",
474                    routing::get(|client: AuthedClient, headers: HeaderMap| async move {
475                        let registry = sql::handle_promsql(client, STORAGE_METRIC_QUERIES).await;
476                        mz_http_util::handle_prometheus(&registry, headers).await
477                    }),
478                )
479                .route(
480                    "/api/livez",
481                    routing::get(mz_http_util::handle_liveness_check),
482                )
483                .route("/api/readyz", routing::get(probe::handle_ready))
484                .layer(auth_middleware.clone())
485                .layer(Extension(adapter_client_rx.clone()))
486                .layer(Extension(active_connection_counter.clone()))
487                .layer(Extension(HelmChartVersion(helm_chart_version.clone())));
488            router = router.merge(metrics_router);
489        }
490
491        if routes_enabled.console_config {
492            let console_config_router = Router::new()
493                .route(
494                    "/api/console/config",
495                    routing::get(console::handle_console_config),
496                )
497                .layer(Extension(adapter_client_rx.clone()))
498                .layer(Extension(active_connection_counter.clone()));
499            router = router.merge(console_config_router);
500        }
501
502        // MCP (Model Context Protocol) endpoints
503        // Enabled via runtime `routes_enabled.mcp_agent` and `routes_enabled.mcp_developer` configuration
504        if routes_enabled.mcp_agent || routes_enabled.mcp_developer {
505            use tracing::info;
506
507            let mut mcp_router = Router::new();
508
509            if routes_enabled.mcp_agent {
510                info!("Enabling MCP agent endpoint: /api/mcp/agent");
511                mcp_router = mcp_router.route(
512                    "/api/mcp/agent",
513                    routing::post(mcp::handle_mcp_agent).get(mcp::handle_mcp_method_not_allowed),
514                );
515            }
516
517            if routes_enabled.mcp_developer {
518                info!("Enabling MCP developer endpoint: /api/mcp/developer");
519                mcp_router = mcp_router.route(
520                    "/api/mcp/developer",
521                    routing::post(mcp::handle_mcp_developer)
522                        .get(mcp::handle_mcp_method_not_allowed),
523                );
524            }
525
526            // The MCP handlers perform a server-side Origin check against this
527            // allowlist to defend against DNS rebinding attacks (see
528            // database-issues#11311). The CorsLayer alone is not enough: in a
529            // DNS rebinding attack the browser considers the request
530            // same-origin, so no preflight fires and CORS enforcement is
531            // bypassed.
532            let mcp_allowed_origins = Arc::new(allowed_origin_list.clone());
533            mcp_router = mcp_router
534                .layer(auth_middleware.clone())
535                .layer(Extension(adapter_client_rx.clone()))
536                .layer(Extension(active_connection_counter.clone()))
537                .layer(Extension(HelmChartVersion(helm_chart_version.clone())))
538                .layer(Extension(mcp_allowed_origins))
539                .layer(Extension(mcp_metrics))
540                .layer(
541                    CorsLayer::new()
542                        .allow_methods(Method::POST)
543                        .allow_origin(allowed_origin.clone())
544                        .allow_headers([AUTHORIZATION, CONTENT_TYPE]),
545                );
546            router = router.merge(mcp_router);
547        }
548
549        base_router = base_router
550            .layer(auth_middleware.clone())
551            .layer(Extension(adapter_client_rx.clone()))
552            .layer(Extension(active_connection_counter.clone()))
553            .layer(Extension(HelmChartVersion(helm_chart_version)))
554            .layer(
555                CorsLayer::new()
556                    .allow_credentials(false)
557                    .allow_headers([
558                        AUTHORIZATION,
559                        CONTENT_TYPE,
560                        HeaderName::from_static("x-materialize-version"),
561                    ])
562                    .allow_methods(Any)
563                    .allow_origin(allowed_origin)
564                    .expose_headers(Any)
565                    .max_age(Duration::from_secs(60) * 60),
566            );
567
568        match authenticator_kind {
569            listeners::AuthenticatorKind::Password | listeners::AuthenticatorKind::Oidc => {
570                base_router = base_router.layer(session_layer.clone());
571
572                let login_router = Router::new()
573                    .route("/api/login", routing::post(handle_login))
574                    .route("/api/logout", routing::post(handle_logout))
575                    .layer(Extension(adapter_client_rx))
576                    .layer(Extension(allowed_roles));
577                router = router.merge(login_router).layer(session_layer);
578            }
579            listeners::AuthenticatorKind::None => {
580                base_router = base_router.layer(middleware::from_fn_with_state(
581                    allowed_roles,
582                    x_materialize_user_header_auth,
583                ));
584            }
585            _ => {}
586        }
587
588        router = router
589            .merge(base_router)
590            .apply_default_layers(source, metrics);
591
592        HttpServer { tls, router }
593    }
594}
595
596impl Server for HttpServer {
597    const NAME: &'static str = "http";
598
599    fn handle_connection(
600        &self,
601        conn: Connection,
602        _tokio_metrics_intervals: impl Iterator<Item = TaskMetrics> + Send + 'static,
603    ) -> ConnectionHandler {
604        let router = self.router.clone();
605        let tls_context = self.tls.clone();
606        let mut conn = TokioIo::new(conn);
607
608        Box::pin(async {
609            let direct_peer_addr = conn.inner().peer_addr().context("fetching peer addr")?;
610            let peer_addr = conn
611                .inner_mut()
612                .take_proxy_header_address()
613                .await
614                .map(|a| a.source)
615                .unwrap_or(direct_peer_addr);
616
617            let (conn, conn_protocol) = match tls_context {
618                Some(tls_context) => {
619                    let mut ssl_stream = SslStream::new(Ssl::new(&tls_context.get())?, conn)?;
620                    if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
621                        let _ = ssl_stream.get_mut().inner_mut().shutdown().await;
622                        return Err(e.into());
623                    }
624                    (MaybeHttpsStream::Https(ssl_stream), ConnProtocol::Https)
625                }
626                _ => (MaybeHttpsStream::Http(conn), ConnProtocol::Http),
627            };
628            let mut make_tower_svc = router
629                .layer(Extension(conn_protocol))
630                .into_make_service_with_connect_info::<SocketAddr>();
631            let tower_svc = make_tower_svc.call(peer_addr).await.unwrap();
632            let hyper_svc = hyper::service::service_fn(|req| tower_svc.clone().call(req));
633            let http = hyper::server::conn::http1::Builder::new();
634            http.serve_connection(conn, hyper_svc)
635                .with_upgrades()
636                .err_into()
637                .await
638        })
639    }
640}
641
642pub async fn handle_leader_status(
643    State(deployment_state_handle): State<DeploymentStateHandle>,
644) -> impl IntoResponse {
645    let status = deployment_state_handle.status();
646    (StatusCode::OK, Json(json!({ "status": status })))
647}
648
649pub async fn handle_leader_promote(
650    State(deployment_state_handle): State<DeploymentStateHandle>,
651) -> impl IntoResponse {
652    match deployment_state_handle.try_promote() {
653        Ok(()) => {
654            // TODO(benesch): the body here is redundant. Should just return
655            // 204.
656            let status = StatusCode::OK;
657            let body = Json(json!({
658                "result": "Success",
659            }));
660            (status, body)
661        }
662        Err(()) => {
663            // TODO(benesch): the nesting here is redundant given the error
664            // code. Should just return the `{"message": "..."}` object.
665            let status = StatusCode::BAD_REQUEST;
666            let body = Json(json!({
667                "result": {"Failure": {"message": "cannot promote leader while initializing"}},
668            }));
669            (status, body)
670        }
671    }
672}
673
674pub async fn handle_leader_skip_catchup(
675    State(deployment_state_handle): State<DeploymentStateHandle>,
676) -> impl IntoResponse {
677    match deployment_state_handle.try_skip_catchup() {
678        Ok(()) => StatusCode::NO_CONTENT.into_response(),
679        Err(()) => {
680            let status = StatusCode::BAD_REQUEST;
681            let body = Json(json!({
682                "message": "cannot skip catchup in this phase of initialization; try again later",
683            }));
684            (status, body).into_response()
685        }
686    }
687}
688
689async fn x_materialize_user_header_auth(
690    State(allowed_roles): State<AllowedRoles>,
691    mut req: Request,
692    next: Next,
693) -> impl IntoResponse {
694    // TODO migrate teleport to basic auth and remove this.
695    if let Some(username) = req.headers().get("x-materialize-user").map(|h| h.to_str()) {
696        let username = match username {
697            Ok(name @ (SUPPORT_USER_NAME | SYSTEM_USER_NAME)) => name.to_string(),
698            _ => {
699                return Err(AuthError::MismatchedUser(format!(
700                    "user specified in x-materialize-user must be {SUPPORT_USER_NAME} or {SYSTEM_USER_NAME}"
701                )));
702            }
703        };
704        // Enforce the listener's `allowed_roles` policy here. Without this,
705        // a listener with `authenticator_kind=None` and `allowed_roles=Normal`
706        // would let any caller assert `x-materialize-user: mz_system` and
707        // bypass the role restriction.
708        check_role_allowed(&username, allowed_roles)?;
709        req.extensions_mut().insert(AuthedUser {
710            name: username,
711            external_metadata_rx: None,
712            authenticated: Authenticated,
713            authenticator_kind: mz_auth::AuthenticatorKind::None,
714            groups: None,
715        });
716    }
717    Ok(next.run(req).await)
718}
719
720type Delayed<T> = Shared<oneshot::Receiver<T>>;
721
/// Transport over which a connection was accepted. `handle_connection`
/// attaches it to the router as an `Extension` for that connection.
722#[derive(Clone)]
723enum ConnProtocol {
    /// Plain HTTP: the server has no TLS context configured.
724    Http,
    /// HTTPS: the TLS handshake on the connection succeeded.
725    Https,
726}
727
728#[derive(Clone, Debug)]
729pub struct AuthedUser {
730    name: String,
731    external_metadata_rx: Option<watch::Receiver<ExternalUserMetadata>>,
732    authenticated: Authenticated,
733    authenticator_kind: mz_auth::AuthenticatorKind,
734    /// Groups from JWT claims for OIDC group-to-role sync.
735    groups: Option<Vec<String>>,
736}
737
738pub struct AuthedClient {
739    pub client: SessionClient,
740    pub connection_guard: Option<ConnectionHandle>,
741}
742
743impl AuthedClient {
744    async fn new<F>(
745        adapter_client: &Client,
746        user: AuthedUser,
747        peer_addr: IpAddr,
748        active_connection_counter: ConnectionCounter,
749        helm_chart_version: Option<String>,
750        session_config: F,
751        options: BTreeMap<String, String>,
752        now: NowFn,
753    ) -> Result<Self, AdapterError>
754    where
755        F: FnOnce(&mut AdapterSession),
756    {
757        let conn_id = adapter_client.new_conn_id()?;
758        let mut session = adapter_client.new_session(
759            AdapterSessionConfig {
760                conn_id,
761                uuid: epoch_to_uuid_v7(&(now)()),
762                user: user.name,
763                client_ip: Some(peer_addr),
764                external_metadata_rx: user.external_metadata_rx,
765                helm_chart_version,
766                authenticator_kind: user.authenticator_kind,
767                groups: user.groups,
768            },
769            user.authenticated,
770        );
771        let connection_guard = active_connection_counter.allocate_connection(session.user())?;
772
773        session_config(&mut session);
774        let system_vars = adapter_client.get_system_vars().await;
775        for (key, val) in options {
776            const LOCAL: bool = false;
777            if let Err(err) =
778                session
779                    .vars_mut()
780                    .set(&system_vars, &key, VarInput::Flat(&val), LOCAL)
781            {
782                session.add_notice(AdapterNotice::BadStartupSetting {
783                    name: key.to_string(),
784                    reason: err.to_string(),
785                })
786            }
787        }
788        let adapter_client = adapter_client.startup(session).await?;
789        Ok(AuthedClient {
790            client: adapter_client,
791            connection_guard,
792        })
793    }
794}
795
796impl<S> FromRequestParts<S> for AuthedClient
797where
798    S: Send + Sync,
799{
800    type Rejection = Response;
801
802    async fn from_request_parts(
803        req: &mut http::request::Parts,
804        state: &S,
805    ) -> Result<Self, Self::Rejection> {
806        #[derive(Debug, Default, Deserialize)]
807        struct Params {
808            #[serde(default)]
809            options: String,
810        }
811        let params: Query<Params> = Query::from_request_parts(req, state)
812            .await
813            .unwrap_or_default();
814
815        let peer_addr = req
816            .extensions
817            .get::<ConnectInfo<SocketAddr>>()
818            .expect("ConnectInfo extension guaranteed to exist")
819            .0
820            .ip();
821
822        let user = req.extensions.get::<AuthedUser>().unwrap();
823        let adapter_client = req
824            .extensions
825            .get::<Delayed<mz_adapter::Client>>()
826            .unwrap()
827            .clone();
828        let adapter_client = adapter_client.await.map_err(|_| {
829            (StatusCode::INTERNAL_SERVER_ERROR, "adapter client missing").into_response()
830        })?;
831        let active_connection_counter = req.extensions.get::<ConnectionCounter>().unwrap();
832        let helm_chart_version = req
833            .extensions
834            .get::<HelmChartVersion>()
835            .map(|h| h.0.clone())
836            .unwrap_or(None);
837
838        let options = if params.options.is_empty() {
839            // It's possible 'options' simply wasn't provided, we don't want that to
840            // count as a failure to deserialize
841            BTreeMap::<String, String>::default()
842        } else {
843            match serde_json::from_str(&params.options) {
844                Ok(options) => options,
845                Err(_e) => {
846                    // If we fail to deserialize options, fail the request.
847                    let code = StatusCode::BAD_REQUEST;
848                    let msg = format!("Failed to deserialize {} map", "options".quoted());
849                    return Err((code, msg).into_response());
850                }
851            }
852        };
853
854        let client = AuthedClient::new(
855            &adapter_client,
856            user.clone(),
857            peer_addr,
858            active_connection_counter.clone(),
859            helm_chart_version,
860            |session| {
861                session
862                    .vars_mut()
863                    .set_default(WELCOME_MESSAGE.name(), VarInput::Flat(&false.format()))
864                    .expect("known to exist")
865            },
866            options,
867            SYSTEM_TIME.clone(),
868        )
869        .await
870        .map_err(|e| {
871            let status = match e {
872                AdapterError::UserSessionsDisallowed | AdapterError::NetworkPolicyDenied(_) => {
873                    StatusCode::FORBIDDEN
874                }
875                _ => StatusCode::INTERNAL_SERVER_ERROR,
876            };
877            (status, Json(SqlError::from(e))).into_response()
878        })?;
879
880        Ok(client)
881    }
882}
883
884#[derive(Debug, Error)]
885pub(crate) enum AuthError {
886    #[error("role dissallowed")]
887    RoleDisallowed(String),
888    #[error("{0}")]
889    Frontegg(#[from] FronteggError),
890    #[error("missing authorization header")]
891    MissingHttpAuthentication {
892        include_www_authenticate_header: bool,
893    },
894    #[error("{0}")]
895    MismatchedUser(String),
896    #[error("session expired")]
897    SessionExpired,
898    #[error("failed to update session")]
899    FailedToUpdateSession,
900    #[error("invalid credentials")]
901    InvalidCredentials,
902    /// Payload is `OidcError`'s sanitized `Display` (no expected-values leaks).
903    #[error("{0}")]
904    OidcFailed(String),
905}
906
907impl IntoResponse for AuthError {
908    fn into_response(self) -> Response {
909        warn!("HTTP request failed authentication: {}", self);
910        let mut headers = HeaderMap::new();
911        // We omit most detail from the error message we send to the client, to
912        // avoid giving attackers unnecessary information. `OidcFailed` is the
913        // exception — its payload is a sanitized `OidcError::Display` that the
914        // console embeds in the login-page error.
915        let body = match &self {
916            AuthError::MissingHttpAuthentication {
917                include_www_authenticate_header,
918            } if *include_www_authenticate_header => {
919                headers.insert(
920                    http::header::WWW_AUTHENTICATE,
921                    HeaderValue::from_static("Basic realm=Materialize"),
922                );
923                "unauthorized".to_string()
924            }
925            AuthError::OidcFailed(message) => message.clone(),
926            _ => "unauthorized".to_string(),
927        };
928        (StatusCode::UNAUTHORIZED, headers, body).into_response()
929    }
930}
931
932// Simplified login handler
933pub async fn handle_login(
934    session: Option<Extension<TowerSession>>,
935    Extension(adapter_client_rx): Extension<Delayed<Client>>,
936    Extension(allowed_roles): Extension<AllowedRoles>,
937    Json(LoginCredentials { username, password }): Json<LoginCredentials>,
938) -> impl IntoResponse {
939    // Enforce the listener's allowed_roles policy before doing any
940    // authentication work, mirroring the check performed in `auth` for
941    // header-based credentials. Without this, a session-based caller could
942    // log in as a role that header-based callers are forbidden to use.
943    if let Err(err) = check_role_allowed(&username, allowed_roles) {
944        warn!(
945            ?err,
946            "HTTP login rejected: role not allowed on this listener"
947        );
948        return StatusCode::UNAUTHORIZED;
949    }
950    let Ok(adapter_client) = adapter_client_rx.clone().await else {
951        return StatusCode::INTERNAL_SERVER_ERROR;
952    };
953    let authenticated = match adapter_client.authenticate(&username, &password).await {
954        Ok(authenticated) => authenticated,
955        Err(err) => {
956            warn!(?err, "HTTP login failed authentication");
957            return StatusCode::UNAUTHORIZED;
958        }
959    };
960    // Create session data
961    let session_data = TowerSessionData {
962        username,
963        created_at: SystemTime::now(),
964        last_activity: SystemTime::now(),
965        authenticated,
966        authenticator_kind: mz_auth::AuthenticatorKind::Password,
967    };
968    // Store session data
969    let session = session.and_then(|Extension(session)| Some(session));
970    let Some(session) = session else {
971        return StatusCode::INTERNAL_SERVER_ERROR;
972    };
973    match session.insert("data", &session_data).await {
974        Err(_) => StatusCode::INTERNAL_SERVER_ERROR,
975        Ok(_) => StatusCode::OK,
976    }
977}
978
979// Simplified logout handler
980pub async fn handle_logout(session: Option<Extension<TowerSession>>) -> impl IntoResponse {
981    let session = session.and_then(|Extension(session)| Some(session));
982    let Some(session) = session else {
983        return StatusCode::INTERNAL_SERVER_ERROR;
984    };
985    // Delete session
986    match session.delete().await {
987        Err(_) => StatusCode::INTERNAL_SERVER_ERROR,
988        Ok(_) => StatusCode::OK,
989    }
990}
991
992async fn http_auth(
993    mut req: Request,
994    next: Next,
995    tls_enabled: bool,
996    authenticator_kind: listeners::AuthenticatorKind,
997    frontegg: Option<mz_frontegg_auth::Authenticator>,
998    oidc_rx: Delayed<mz_authenticator::GenericOidcAuthenticator>,
999    adapter_client_rx: Delayed<Client>,
1000    allowed_roles: AllowedRoles,
1001) -> Result<impl IntoResponse, AuthError> {
1002    let creds = if let Some(basic) = req.headers().typed_get::<Authorization<Basic>>() {
1003        Some(Credentials::Password {
1004            username: basic.username().to_owned(),
1005            password: Password(basic.password().to_owned()),
1006        })
1007    } else if let Some(bearer) = req.headers().typed_get::<Authorization<Bearer>>() {
1008        Some(Credentials::Token {
1009            token: bearer.token().to_owned(),
1010        })
1011    } else {
1012        None
1013    };
1014
1015    // Reuses an authenticated session if one already exists.
1016    // If credentials are provided, we perform a new authentication,
1017    // separate from the existing session.
1018    if creds.is_none()
1019        && let Some((session, session_data)) =
1020            maybe_get_authenticated_session(req.extensions().get::<TowerSession>()).await
1021    {
1022        let user = ensure_session_unexpired(session, session_data).await?;
1023        // Defense-in-depth: re-check the listener's `allowed_roles` policy on
1024        // every session-authenticated request. The same check runs at
1025        // `/api/login`, but enforcing it here too prevents a session minted
1026        // under a more permissive configuration (or a future bug in the login
1027        // path) from bypassing role restrictions.
1028        check_role_allowed(&user.name, allowed_roles)?;
1029        // Need this to set the user of the Adapter client.
1030        req.extensions_mut().insert(user);
1031        return Ok(next.run(req).await);
1032    }
1033
1034    // First, extract the username from the certificate, validating that the
1035    // connection matches the TLS configuration along the way.
1036    // Fall back to existing authentication methods.
1037    let conn_protocol = req.extensions().get::<ConnProtocol>().unwrap();
1038    match (tls_enabled, &conn_protocol) {
1039        (false, ConnProtocol::Http) => {}
1040        (false, ConnProtocol::Https { .. }) => unreachable!(),
1041        (true, ConnProtocol::Http) => {
1042            let mut parts = req.uri().clone().into_parts();
1043            parts.scheme = Some(Scheme::HTTPS);
1044            return Ok(Redirect::permanent(
1045                &Uri::from_parts(parts)
1046                    .expect("it was already a URI, just changed the scheme")
1047                    .to_string(),
1048            )
1049            .into_response());
1050        }
1051        (true, ConnProtocol::Https { .. }) => {}
1052    }
1053    // If we've already passed some other auth, just use that.
1054    if req.extensions().get::<AuthedUser>().is_some() {
1055        return Ok(next.run(req).await);
1056    }
1057
1058    let path = req.uri().path();
1059    let include_www_authenticate_header = path == "/"
1060        || PROFILING_API_ENDPOINTS
1061            .iter()
1062            .any(|prefix| path.starts_with(prefix));
1063    let authenticator = get_authenticator(
1064        authenticator_kind,
1065        creds.as_ref(),
1066        frontegg,
1067        &oidc_rx,
1068        &adapter_client_rx,
1069    )
1070    .await;
1071
1072    let user = auth(
1073        &authenticator,
1074        creds,
1075        allowed_roles,
1076        include_www_authenticate_header,
1077    )
1078    .await?;
1079
1080    // Add the authenticated user as an extension so downstream handlers can
1081    // inspect it if necessary.
1082    req.extensions_mut().insert(user);
1083
1084    // Run the request.
1085    Ok(next.run(req).await)
1086}
1087
1088async fn init_ws(
1089    WsState {
1090        frontegg,
1091        oidc_rx,
1092        authenticator_kind,
1093        adapter_client_rx,
1094        active_connection_counter,
1095        helm_chart_version,
1096        allowed_roles,
1097    }: WsState,
1098    existing_user: Option<ExistingUser>,
1099    peer_addr: IpAddr,
1100    ws: &mut WebSocket,
1101) -> Result<AuthedClient, anyhow::Error> {
1102    // TODO: Add a timeout here to prevent resource leaks by clients that
1103    // connect then never send a message.
1104    let ws_auth: WebSocketAuth = loop {
1105        let init_msg = ws.recv().await.ok_or_else(|| anyhow::anyhow!("closed"))??;
1106        match init_msg {
1107            Message::Text(data) => break serde_json::from_str(&data)?,
1108            Message::Binary(data) => break serde_json::from_slice(&data)?,
1109            // Handled automatically by the server.
1110            Message::Ping(_) => {
1111                continue;
1112            }
1113            Message::Pong(_) => {
1114                continue;
1115            }
1116            Message::Close(_) => {
1117                anyhow::bail!("closed");
1118            }
1119        }
1120    };
1121
1122    // If credentials are provided, we perform a new authentication,
1123    // separate from the existing session.
1124    let (creds, options) = match ws_auth {
1125        WebSocketAuth::Basic {
1126            user,
1127            password,
1128            options,
1129        } => {
1130            let creds = Credentials::Password {
1131                username: user,
1132                password,
1133            };
1134            (Some(creds), options)
1135        }
1136        WebSocketAuth::Bearer { token, options } => {
1137            let creds = Credentials::Token { token };
1138            (Some(creds), options)
1139        }
1140        WebSocketAuth::OptionsOnly { options } => (None, options),
1141    };
1142
1143    let user = match (existing_user, creds) {
1144        (Some(ExistingUser::XMaterializeUserHeader(_)), Some(_creds)) => {
1145            warn!("Unexpected bearer or basic auth provided when using user header");
1146            anyhow::bail!("unexpected")
1147        }
1148        (Some(ExistingUser::Session(user)), None) => {
1149            // Defense-in-depth: re-enforce the listener's `allowed_roles`
1150            // policy on session-authenticated WebSocket connections. The same
1151            // check runs at `/api/login`, but enforcing it here too prevents a
1152            // session minted under a more permissive configuration (or a
1153            // future bug in the login path) from bypassing role restrictions.
1154            check_role_allowed(&user.name, allowed_roles)?;
1155            user
1156        }
1157        (Some(ExistingUser::XMaterializeUserHeader(user)), None) => user,
1158        (_, Some(creds)) => {
1159            let authenticator = get_authenticator(
1160                authenticator_kind,
1161                Some(&creds),
1162                frontegg,
1163                &oidc_rx,
1164                &adapter_client_rx,
1165            )
1166            .await;
1167            let user = auth(&authenticator, Some(creds), allowed_roles, false).await?;
1168            user
1169        }
1170        (None, None) => anyhow::bail!("expected auth information"),
1171    };
1172
1173    let client = AuthedClient::new(
1174        &adapter_client_rx.clone().await?,
1175        user,
1176        peer_addr,
1177        active_connection_counter.clone(),
1178        helm_chart_version.clone(),
1179        |_session| (),
1180        options,
1181        SYSTEM_TIME.clone(),
1182    )
1183    .await?;
1184
1185    Ok(client)
1186}
1187
1188enum Credentials {
1189    Password {
1190        username: String,
1191        password: Password,
1192    },
1193    Token {
1194        token: String,
1195    },
1196}
1197
1198async fn get_authenticator(
1199    kind: listeners::AuthenticatorKind,
1200    creds: Option<&Credentials>,
1201    frontegg: Option<mz_frontegg_auth::Authenticator>,
1202    oidc_rx: &Delayed<mz_authenticator::GenericOidcAuthenticator>,
1203    adapter_client_rx: &Delayed<Client>,
1204) -> Authenticator {
1205    match kind {
1206        listeners::AuthenticatorKind::Frontegg => Authenticator::Frontegg(frontegg.expect(
1207            "Frontegg authenticator should exist with listeners::AuthenticatorKind::Frontegg",
1208        )),
1209        listeners::AuthenticatorKind::Password | listeners::AuthenticatorKind::Sasl => {
1210            let client = adapter_client_rx.clone().await.expect("sender not dropped");
1211            Authenticator::Password(client)
1212        }
1213        listeners::AuthenticatorKind::Oidc => match creds {
1214            // Use the password authenticator if the credentials are password-based
1215            Some(Credentials::Password { .. }) => {
1216                let client = adapter_client_rx.clone().await.expect("sender not dropped");
1217                Authenticator::Password(client)
1218            }
1219            _ => Authenticator::Oidc(oidc_rx.clone().await.expect("sender not dropped")),
1220        },
1221        listeners::AuthenticatorKind::None => Authenticator::None,
1222    }
1223}
1224
1225/// Attempts to retrieve session data from a [`TowerSession`], if available.
1226/// Session data is present only if an authenticated session has been
1227/// established via [`handle_login`].
1228pub(crate) async fn maybe_get_authenticated_session(
1229    session: Option<&TowerSession>,
1230) -> Option<(&TowerSession, TowerSessionData)> {
1231    if let Some(session) = session {
1232        if let Ok(Some(session_data)) = session.get::<TowerSessionData>("data").await {
1233            return Some((session, session_data));
1234        }
1235    }
1236    None
1237}
1238
1239/// Ensures the session is still valid by checking for expiration,
1240/// and returns the associated user if the session remains active.
1241pub(crate) async fn ensure_session_unexpired(
1242    session: &TowerSession,
1243    session_data: TowerSessionData,
1244) -> Result<AuthedUser, AuthError> {
1245    if session_data
1246        .last_activity
1247        .elapsed()
1248        .unwrap_or(Duration::MAX)
1249        > SESSION_DURATION
1250    {
1251        let _ = session.delete().await;
1252        return Err(AuthError::SessionExpired);
1253    }
1254    let mut updated_data = session_data.clone();
1255    updated_data.last_activity = SystemTime::now();
1256    session
1257        .insert("data", &updated_data)
1258        .await
1259        .map_err(|_| AuthError::FailedToUpdateSession)?;
1260
1261    Ok(AuthedUser {
1262        name: session_data.username,
1263        external_metadata_rx: None,
1264        authenticated: session_data.authenticated,
1265        authenticator_kind: session_data.authenticator_kind,
1266        groups: None,
1267    })
1268}
1269
1270async fn auth(
1271    authenticator: &Authenticator,
1272    creds: Option<Credentials>,
1273    allowed_roles: AllowedRoles,
1274    include_www_authenticate_header: bool,
1275) -> Result<AuthedUser, AuthError> {
1276    let (name, external_metadata_rx, authenticated, groups) = match authenticator {
1277        Authenticator::Frontegg(frontegg) => match creds {
1278            Some(Credentials::Password { username, password }) => {
1279                let (auth_session, authenticated) =
1280                    frontegg.authenticate(&username, password.as_str()).await?;
1281                let name = auth_session.user().into();
1282                let external_metadata_rx = Some(auth_session.external_metadata_rx());
1283                (name, external_metadata_rx, authenticated, None)
1284            }
1285            Some(Credentials::Token { token }) => {
1286                let (claims, authenticated) = frontegg.validate_access_token(&token, None)?;
1287                let (_, external_metadata_rx) = watch::channel(ExternalUserMetadata {
1288                    user_id: claims.user_id,
1289                    admin: claims.is_admin,
1290                });
1291                (claims.user, Some(external_metadata_rx), authenticated, None)
1292            }
1293            None => {
1294                return Err(AuthError::MissingHttpAuthentication {
1295                    include_www_authenticate_header,
1296                });
1297            }
1298        },
1299        Authenticator::Password(adapter_client) => match creds {
1300            Some(Credentials::Password { username, password }) => {
1301                let authenticated = adapter_client
1302                    .authenticate(&username, &password)
1303                    .await
1304                    .map_err(|_| AuthError::InvalidCredentials)?;
1305                (username, None, authenticated, None)
1306            }
1307            _ => {
1308                return Err(AuthError::MissingHttpAuthentication {
1309                    include_www_authenticate_header,
1310                });
1311            }
1312        },
1313        Authenticator::Sasl(_) => {
1314            // We shouldn't ever end up here as the configuration is validated at startup.
1315            // If we do, it's a server misconfiguration.
1316            // Just in case, we return a 401 rather than panic.
1317            return Err(AuthError::MissingHttpAuthentication {
1318                include_www_authenticate_header,
1319            });
1320        }
1321        Authenticator::Oidc(oidc) => match creds {
1322            Some(Credentials::Token { token }) => {
1323                let (mut claims, authenticated) = oidc
1324                    .authenticate(&token, None)
1325                    .await
1326                    .map_err(|e| AuthError::OidcFailed(e.to_string()))?;
1327                let name = std::mem::take(&mut claims.user);
1328                let groups = claims.groups.take();
1329                (name, None, authenticated, groups)
1330            }
1331            _ => {
1332                return Err(AuthError::MissingHttpAuthentication {
1333                    include_www_authenticate_header,
1334                });
1335            }
1336        },
1337        Authenticator::None => {
1338            // If no authentication, use whatever is in the HTTP auth
1339            // header (without checking the password), or fall back to the
1340            // default user.
1341            let name = match creds {
1342                Some(Credentials::Password { username, .. }) => username,
1343                _ => HTTP_DEFAULT_USER.name.to_owned(),
1344            };
1345            (name, None, Authenticated, None)
1346        }
1347    };
1348
1349    check_role_allowed(&name, allowed_roles)?;
1350
1351    Ok(AuthedUser {
1352        name,
1353        external_metadata_rx,
1354        authenticated,
1355        authenticator_kind: authenticator.kind(),
1356        groups,
1357    })
1358}
1359
1360// TODO move this somewhere it can be shared with PGWIRE
1361fn check_role_allowed(name: &str, allowed_roles: AllowedRoles) -> Result<(), AuthError> {
1362    let is_internal_user = INTERNAL_USER_NAMES.contains(name);
1363    // this is a superset of internal users
1364    let is_reserved_user = mz_adapter::catalog::is_reserved_role_name(name);
1365    let role_allowed = match allowed_roles {
1366        AllowedRoles::Normal => !is_reserved_user,
1367        AllowedRoles::Internal => is_internal_user,
1368        AllowedRoles::NormalAndInternal => !is_reserved_user || is_internal_user,
1369    };
1370    if role_allowed {
1371        Ok(())
1372    } else {
1373        Err(AuthError::RoleDisallowed(name.to_owned()))
1374    }
1375}
1376
1377/// Default layers that should be applied to all routes, and should get applied to both the
1378/// internal http and external http routers.
1379trait DefaultLayers {
1380    fn apply_default_layers(self, source: &'static str, metrics: Metrics) -> Self;
1381}
1382
1383impl DefaultLayers for Router {
1384    fn apply_default_layers(self, source: &'static str, metrics: Metrics) -> Self {
1385        self.layer(DefaultBodyLimit::max(MAX_REQUEST_SIZE))
1386            .layer(metrics::PrometheusLayer::new(source, metrics))
1387    }
1388}
1389
1390/// Glue code to make [`tower`] work with [`axum`].
1391///
1392/// `axum` requires `Layer`s not return Errors, i.e. they must be `Result<_, Infallible>`,
1393/// instead you must return a type that can be converted into a response. `tower` on the other
1394/// hand does return Errors, so to make the two work together we need to convert our `tower` errors
1395/// into responses.
1396async fn handle_load_error(error: tower::BoxError) -> impl IntoResponse {
1397    if error.is::<tower::load_shed::error::Overloaded>() {
1398        return (
1399            StatusCode::TOO_MANY_REQUESTS,
1400            Cow::from("too many requests, try again later"),
1401        );
1402    }
1403
1404    // Note: This should be unreachable because at the time of writing our only use case is a
1405    // layer that emits `tower::load_shed::error::Overloaded`, which is handled above.
1406    (
1407        StatusCode::INTERNAL_SERVER_ERROR,
1408        Cow::from(format!("Unhandled internal error: {}", error)),
1409    )
1410}
1411
1412#[derive(Debug, Deserialize, Serialize, PartialEq)]
1413pub struct LoginCredentials {
1414    username: String,
1415    password: Password,
1416}
1417
1418#[derive(Debug, Clone, Serialize, Deserialize)]
1419pub struct TowerSessionData {
1420    username: String,
1421    created_at: SystemTime,
1422    last_activity: SystemTime,
1423    authenticated: Authenticated,
1424    authenticator_kind: mz_auth::AuthenticatorKind,
1425}
1426
#[cfg(test)]
mod tests {
    use super::{AllowedRoles, check_role_allowed};

    #[mz_ore::test]
    fn test_check_role_allowed() {
        // Columns: role name, then whether the role is expected to be admitted
        // under `Internal`, `NormalAndInternal` and `Normal`, in that order.
        let cases = [
            // Internal users
            ("mz_system", true, true, false),
            ("mz_support", true, true, false),
            ("mz_analytics", true, true, false),
            // Normal users
            ("materialize", false, true, true),
            ("anonymous_http_user", false, true, true),
            ("alex", false, true, true),
            // Denied by reserved role prefix
            ("external_asdf", false, false, false),
            ("pg_somebody", false, false, false),
            ("mz_unknown", false, false, false),
            // Denied by literal PUBLIC
            ("PUBLIC", false, false, false),
        ];

        for (role, internal, normal_and_internal, normal) in cases {
            assert_eq!(
                check_role_allowed(role, AllowedRoles::Internal).is_ok(),
                internal,
                "{role} under AllowedRoles::Internal"
            );
            assert_eq!(
                check_role_allowed(role, AllowedRoles::NormalAndInternal).is_ok(),
                normal_and_internal,
                "{role} under AllowedRoles::NormalAndInternal"
            );
            assert_eq!(
                check_role_allowed(role, AllowedRoles::Normal).is_ok(),
                normal,
                "{role} under AllowedRoles::Normal"
            );
        }
    }
}