Skip to main content

mz_environmentd/
http.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Embedded HTTP server.
11//!
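//! environmentd embeds an HTTP server for introspection into the running
//! process. Besides introspection endpoints such as Prometheus metrics, heap
//! profiles, and catalog dumps, it also serves the SQL-over-HTTP and WebSocket
//! SQL APIs, webhook ingestion, and (when enabled) the MCP and console endpoints.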
15//!
16//! ## Authentication flow
17//!
18//! The server supports several authentication modes, controlled by the
19//! configured [`listeners::AuthenticatorKind`]. The general flow is:
20//!
21//! 1. **Identity resolution.** An authentication middleware runs on every
22//!    protected request and resolves the caller's identity via one of:
23//!    - **Credentials in headers.** The caller supplies a username/password or
24//!      token in the request headers. Supported by all [`listeners::AuthenticatorKind`]s.
25//!    - **Session reuse.** If the caller has an active authenticated session
26//!      (established via `POST /api/login`) and has not supplied credentials
27//!      in the request headers, the session is reused. Only available for
28//!      [`listeners::AuthenticatorKind::Password`] and [`listeners::AuthenticatorKind::Oidc`].
29//!    - **Trusted header injection.** A trusted upstream proxy (e.g. Teleport)
30//!      may inject the caller's identity into the request headers. Only available
31//!      for [`listeners::AuthenticatorKind::None`].
32//!
33//! 2. **Session initialization.** Once the caller's identity is known, an
34//!    adapter session is opened on their behalf. This happens as part of
35//!    request processing, after all middleware has run.
36//!
37//! 3. **Request handling.** The handler executes the request (e.g. runs SQL)
38//!    using the initialized adapter session.
39//!
40//! ### WebSocket
41//!
42//! The WebSocket flow is identical to the HTTP flow with two differences:
43//!
44//! - Credentials are not read from request headers. Instead, the first
45//!   message sent by the client is treated as the authentication message.
46//! - Session initialization (step 2) happens inside the WebSocket handler
47//!   itself, rather than as a separate middleware step.
48
49// Axum handlers must use async, but often don't actually use `await`.
50#![allow(clippy::unused_async)]
51
52use std::borrow::Cow;
53use std::collections::BTreeMap;
54use std::fmt::Debug;
55use std::net::{IpAddr, SocketAddr};
56use std::pin::Pin;
57use std::sync::Arc;
58use std::time::{Duration, SystemTime};
59
60use anyhow::Context;
61use axum::error_handling::HandleErrorLayer;
62use axum::extract::ws::{Message, WebSocket};
63use axum::extract::{ConnectInfo, DefaultBodyLimit, FromRequestParts, Query, Request, State};
64use axum::middleware::{self, Next};
65use axum::response::{IntoResponse, Redirect, Response};
66use axum::{Extension, Json, Router, routing};
67use futures::future::{Shared, TryFutureExt};
68use headers::authorization::{Authorization, Basic, Bearer};
69use headers::{HeaderMapExt, HeaderName};
70use http::header::{AUTHORIZATION, CONTENT_TYPE};
71use http::uri::Scheme;
72use http::{HeaderMap, HeaderValue, Method, StatusCode, Uri};
73use hyper_openssl::SslStream;
74use hyper_openssl::client::legacy::MaybeHttpsStream;
75use hyper_util::rt::TokioIo;
76use mz_adapter::session::{Session as AdapterSession, SessionConfig as AdapterSessionConfig};
77use mz_adapter::{AdapterError, AdapterNotice, Client, SessionClient, WebhookAppenderCache};
78use mz_auth::Authenticated;
79use mz_auth::password::Password;
80use mz_authenticator::Authenticator;
81use mz_controller::ReplicaHttpLocator;
82use mz_frontegg_auth::Error as FronteggError;
83use mz_http_util::DynamicFilterTarget;
84use mz_ore::cast::u64_to_usize;
85use mz_ore::metrics::MetricsRegistry;
86use mz_ore::now::{NowFn, SYSTEM_TIME, epoch_to_uuid_v7};
87use mz_ore::str::StrExt;
88use mz_pgwire_common::{ConnectionCounter, ConnectionHandle};
89use mz_repr::user::ExternalUserMetadata;
90use mz_server_core::listeners::{self, AllowedRoles, HttpRoutesEnabled};
91use mz_server_core::{Connection, ConnectionHandler, ReloadingSslContext, Server};
92use mz_sql::session::metadata::SessionMetadata;
93use mz_sql::session::user::{
94    HTTP_DEFAULT_USER, INTERNAL_USER_NAMES, SUPPORT_USER_NAME, SYSTEM_USER_NAME,
95};
96use mz_sql::session::vars::{Value, Var, VarInput, WELCOME_MESSAGE};
97use openssl::ssl::Ssl;
98use prometheus::{
99    COMPUTE_METRIC_QUERIES, FRONTIER_METRIC_QUERIES, STORAGE_METRIC_QUERIES, USAGE_METRIC_QUERIES,
100};
101use serde::{Deserialize, Serialize};
102use serde_json::json;
103use thiserror::Error;
104use tokio::io::AsyncWriteExt;
105use tokio::sync::oneshot::Receiver;
106use tokio::sync::{oneshot, watch};
107use tokio_metrics::TaskMetrics;
108use tower::limit::GlobalConcurrencyLimitLayer;
109use tower::{Service, ServiceBuilder};
110use tower_http::cors::{AllowOrigin, Any, CorsLayer};
111use tower_sessions::{
112    MemoryStore as TowerSessionMemoryStore, Session as TowerSession,
113    SessionManagerLayer as TowerSessionManagerLayer,
114};
115use tracing::warn;
116
117use crate::BUILD_INFO;
118use crate::deployment::state::DeploymentStateHandle;
119use crate::http::sql::{ExistingUser, SqlError};
120
121mod catalog;
122mod cluster;
123mod console;
124mod mcp;
125mod memory;
126mod metrics;
127mod metrics_viz;
128mod probe;
129mod prometheus;
130mod root;
131mod sql;
132mod webhook;
133
134pub use metrics::Metrics;
135pub use sql::{SqlResponse, WebSocketAuth, WebSocketResponse};
136
137/// Maximum allowed size for a request.
138pub const MAX_REQUEST_SIZE: usize = u64_to_usize(5 * bytesize::MIB);
139
/// Lifetime of an HTTP login session: eight hours. Presumably applied where
/// sessions are created (e.g. by `handle_login`) — confirm at the use site.
const SESSION_DURATION: Duration = Duration::from_secs(8 * 60 * 60);

/// Paths of the profiling and memory-introspection endpoints; the `/prof/`
/// entry corresponds to the nested `mz_prof_http` router.
const PROFILING_API_ENDPOINTS: &[&str] = &["/memory", "/hierarchical-memory", "/prof/"];
143
/// Configuration for one HTTP listener; consumed by [`HttpServer::new`].
#[derive(Debug)]
pub struct HttpConfig {
    /// Label for this listener, passed to `apply_default_layers` together with
    /// `metrics` (presumably to label per-listener metrics — confirm).
    pub source: &'static str,
    /// TLS context. When `Some`, each connection completes a TLS handshake
    /// before HTTP is served, and session cookies are marked `Secure`.
    pub tls: Option<ReloadingSslContext>,
    /// Selects the authentication flow and which auth-related layers are added.
    pub authenticator_kind: listeners::AuthenticatorKind,
    /// Frontegg authenticator, if configured; handed to the auth middleware and
    /// to the WebSocket SQL endpoint.
    pub frontegg: Option<mz_frontegg_auth::Authenticator>,
    /// Resolves to the OIDC authenticator once it is available.
    pub oidc_rx: Delayed<mz_authenticator::GenericOidcAuthenticator>,
    /// Resolves to the adapter client once it is available.
    pub adapter_client_rx: Shared<Receiver<Client>>,
    /// CORS allowed origin(s) for the base and MCP routes.
    pub allowed_origin: AllowOrigin,
    /// Raw list of allowed CORS origins, used by the MCP endpoints for
    /// server-side Origin validation to defend against DNS rebinding.
    pub allowed_origin_list: Vec<HeaderValue>,
    /// Counter from which each adapter session reserves a connection slot.
    pub active_connection_counter: ConnectionCounter,
    /// Helm chart version, forwarded into the adapter sessions this server opens.
    pub helm_chart_version: Option<String>,
    /// Bounds the number of concurrent webhook requests; requests beyond the
    /// limit are load-shed rather than queued.
    pub concurrent_webhook_req: Arc<tokio::sync::Semaphore>,
    /// HTTP server metrics, passed to `apply_default_layers`.
    pub metrics: Metrics,
    /// Registry exported verbatim at `/metrics`.
    pub metrics_registry: MetricsRegistry,
    /// Role policy enforced on authenticated users (see `check_role_allowed`).
    pub allowed_roles: AllowedRoles,
    /// State for the internal-only routes.
    pub internal_route_config: Arc<InternalRouteConfig>,
    /// Which route groups are mounted.
    pub routes_enabled: HttpRoutesEnabled,
    /// Locator for cluster replica HTTP addresses, used for proxying requests.
    pub replica_http_locator: Arc<ReplicaHttpLocator>,
}
167
/// Configuration needed by the internal-only routes.
#[derive(Debug, Clone)]
pub struct InternalRouteConfig {
    /// Deployment state; used as router state for the `/api/leader/*` routes.
    pub deployment_state_handle: DeploymentStateHandle,
    /// Passed to `console::ConsoleProxyConfig` for the `/internal-console`
    /// routes, if configured.
    pub internal_console_redirect_url: Option<String>,
}
173
/// Router state for the WebSocket SQL endpoint (`/api/experimental/sql`).
///
/// That endpoint is not wrapped in the shared `http_auth` middleware, so it
/// carries everything it needs to authenticate and open a session itself.
#[derive(Clone)]
pub struct WsState {
    /// Frontegg authenticator, if configured.
    frontegg: Option<mz_frontegg_auth::Authenticator>,
    /// Resolves to the OIDC authenticator once it is available.
    oidc_rx: Delayed<mz_authenticator::GenericOidcAuthenticator>,
    /// The listener's authentication mode.
    authenticator_kind: listeners::AuthenticatorKind,
    /// Resolves to the adapter client once it is available.
    adapter_client_rx: Delayed<mz_adapter::Client>,
    /// Counter from which sessions reserve a connection slot.
    active_connection_counter: ConnectionCounter,
    /// Helm chart version forwarded into sessions.
    helm_chart_version: Option<String>,
    /// Role policy for this listener.
    allowed_roles: AllowedRoles,
}
184
/// Router state for the webhook ingestion endpoint.
#[derive(Clone)]
pub struct WebhookState {
    /// Resolves to the adapter client once it is available.
    adapter_client_rx: Delayed<mz_adapter::Client>,
    /// Appender cache created once in `HttpServer::new` and shared by every
    /// webhook request through this state (presumably cheap to clone — confirm).
    webhook_cache: WebhookAppenderCache,
}
190
/// The optional Helm chart version wrapped in a distinct type so it can be
/// stored as a request `Extension` (extensions are looked up by type).
#[derive(Clone, Debug)]
struct HelmChartVersion(Option<String>);
193
/// The embedded HTTP server: the fully assembled router plus an optional TLS
/// context used when accepting connections.
#[derive(Debug)]
pub struct HttpServer {
    /// When `Some`, every connection is TLS-terminated before HTTP is served.
    tls: Option<ReloadingSslContext>,
    /// All enabled routes with their layers applied.
    router: Router,
}
199
impl HttpServer {
    /// Builds the server and its [`Router`] from `config`.
    ///
    /// Route groups are mounted according to `routes_enabled`, and
    /// authentication-related layers are attached according to
    /// `authenticator_kind`. The TLS context, if any, is stored for use when
    /// accepting connections.
    ///
    /// NOTE(review): in axum a layer wraps everything added to the router before
    /// it, so layers added later run earlier on an incoming request. Several
    /// groups below depend on that ordering. For example, `session_layer` is
    /// added after `auth_middleware` so that it wraps it, presumably so that
    /// the middleware can read the session. Preserve the order when editing.
    pub fn new(
        HttpConfig {
            source,
            tls,
            authenticator_kind,
            frontegg,
            oidc_rx,
            adapter_client_rx,
            allowed_origin,
            allowed_origin_list,
            active_connection_counter,
            helm_chart_version,
            concurrent_webhook_req,
            metrics,
            metrics_registry,
            allowed_roles,
            internal_route_config,
            routes_enabled,
            replica_http_locator,
        }: HttpConfig,
    ) -> HttpServer {
        let tls_enabled = tls.is_some();
        // Created once here and shared (via `WebhookState`) by all webhook requests.
        let webhook_cache = WebhookAppenderCache::new();

        // Create secure session store and manager
        // NOTE(review): `MemoryStore` presumably keeps sessions in this process
        // only, so logins would not survive a restart — confirm.
        let session_store = TowerSessionMemoryStore::default();
        let session_layer = TowerSessionManagerLayer::new(session_store)
            .with_secure(tls_enabled) // Enforce HTTPS (only when this listener uses TLS)
            .with_same_site(tower_sessions::cookie::SameSite::Strict) // Prevent CSRF
            .with_http_only(true) // Prevent XSS
            .with_name("mz_session") // Custom cookie name
            .with_path("/"); // Set cookie path

        // Shared authentication middleware (`http_auth`). The closure clones its
        // captured handles on every call so each request's future owns its copies.
        let frontegg_middleware = frontegg.clone();
        let oidc_middleware_rx = oidc_rx.clone();
        let adapter_client_middleware_rx = adapter_client_rx.clone();
        let auth_middleware = middleware::from_fn(move |req, next| {
            let frontegg = frontegg_middleware.clone();
            let oidc_rx = oidc_middleware_rx.clone();
            let adapter_client_rx = adapter_client_middleware_rx.clone();
            async move {
                http_auth(
                    req,
                    next,
                    tls_enabled,
                    authenticator_kind,
                    frontegg,
                    oidc_rx,
                    adapter_client_rx,
                    allowed_roles,
                )
                .await
            }
        });

        // `router` collects groups that carry their own layers; `base_router`
        // collects groups that share the common layer stack applied near the end.
        let mut router = Router::new();
        let mut base_router = Router::new();
        if routes_enabled.base {
            // Core routes: home page, SQL over HTTP, memory views, metrics
            // visualization, and static assets.
            base_router = base_router
                .route(
                    "/",
                    routing::get(move || async move { root::handle_home(routes_enabled).await }),
                )
                .route("/api/sql", routing::post(sql::handle_sql))
                .route("/memory", routing::get(memory::handle_memory))
                .route(
                    "/hierarchical-memory",
                    routing::get(memory::handle_hierarchical_memory),
                )
                .route(
                    "/metrics-viz",
                    routing::get(metrics_viz::handle_metrics_viz),
                )
                .route("/static/{*path}", routing::get(root::handle_static));

            // The WebSocket SQL endpoint is merged into `router`, not
            // `base_router`, so `auth_middleware` is not applied to it; per the
            // module docs, credentials arrive in the first WebSocket message.
            let mut ws_router = Router::new()
                .route("/api/experimental/sql", routing::get(sql::handle_sql_ws))
                .with_state(WsState {
                    frontegg,
                    oidc_rx: oidc_rx.clone(),
                    authenticator_kind,
                    adapter_client_rx: adapter_client_rx.clone(),
                    active_connection_counter: active_connection_counter.clone(),
                    helm_chart_version: helm_chart_version.clone(),
                    allowed_roles,
                });
            // With no authenticator, honor the trusted `x-materialize-user` header.
            if let listeners::AuthenticatorKind::None = authenticator_kind {
                ws_router = ws_router.layer(middleware::from_fn_with_state(
                    allowed_roles,
                    x_materialize_user_header_auth,
                ));
            }
            router = router.merge(ws_router);
        }
        if routes_enabled.profiling {
            // Profiling routes share `base_router`'s layers.
            base_router = base_router.nest("/prof/", mz_prof_http::router(&BUILD_INFO));
        }

        if routes_enabled.webhook {
            // Webhooks: no `auth_middleware` is attached here. Request bodies may
            // be compressed. CORS mirrors the request origin for POST. Requests
            // beyond the semaphore's capacity are shed (handled by
            // `handle_load_error`) rather than queued.
            let webhook_router = Router::new()
                .route(
                    "/api/webhook/{:database}/{:schema}/{:id}",
                    routing::post(webhook::handle_webhook),
                )
                .with_state(WebhookState {
                    adapter_client_rx: adapter_client_rx.clone(),
                    webhook_cache,
                })
                .layer(
                    tower_http::decompression::RequestDecompressionLayer::new()
                        .gzip(true)
                        .deflate(true)
                        .br(true)
                        .zstd(true),
                )
                .layer(
                    CorsLayer::new()
                        .allow_methods(Method::POST)
                        .allow_origin(AllowOrigin::mirror_request())
                        .allow_headers(Any),
                )
                .layer(
                    ServiceBuilder::new()
                        .layer(HandleErrorLayer::new(handle_load_error))
                        .load_shed()
                        .layer(GlobalConcurrencyLimitLayer::with_semaphore(
                            concurrent_webhook_req,
                        )),
                );
            router = router.merge(webhook_router);
        }

        if routes_enabled.internal {
            // Internal-only routes: retired config endpoints, tracing,
            // catalog/coordinator introspection, the internal console proxy,
            // the cluster proxy, and leader management.
            let console_config = Arc::new(console::ConsoleProxyConfig::new(
                internal_route_config.internal_console_redirect_url.clone(),
                "/internal-console".to_string(),
            ));
            base_router = base_router
                // Retired endpoints: always answer 400 and point at the
                // replacement system variable.
                .route(
                    "/api/opentelemetry/config",
                    routing::put({
                        move |_: axum::Json<DynamicFilterTarget>| async {
                            (
                                StatusCode::BAD_REQUEST,
                                "This endpoint has been replaced. \
                            Use the `opentelemetry_filter` system variable."
                                    .to_string(),
                            )
                        }
                    }),
                )
                .route(
                    "/api/stderr/config",
                    routing::put({
                        move |_: axum::Json<DynamicFilterTarget>| async {
                            (
                                StatusCode::BAD_REQUEST,
                                "This endpoint has been replaced. \
                            Use the `log_filter` system variable."
                                    .to_string(),
                            )
                        }
                    }),
                )
                .route("/api/tracing", routing::get(mz_http_util::handle_tracing))
                .route(
                    "/api/catalog/dump",
                    routing::get(catalog::handle_catalog_dump),
                )
                .route(
                    "/api/catalog/check",
                    routing::get(catalog::handle_catalog_check),
                )
                .route(
                    "/api/catalog/inject-audit-events",
                    routing::post(catalog::handle_inject_audit_events),
                )
                .route(
                    "/api/coordinator/check",
                    routing::get(catalog::handle_coordinator_check),
                )
                .route(
                    "/api/coordinator/dump",
                    routing::get(catalog::handle_coordinator_dump),
                )
                // Normalize `/internal-console` to the trailing-slash form.
                .route(
                    "/internal-console",
                    routing::get(|| async { Redirect::temporary("/internal-console/") }),
                )
                .route(
                    "/internal-console/{*path}",
                    routing::get(console::handle_internal_console),
                )
                .route(
                    "/internal-console/",
                    routing::get(console::handle_internal_console),
                )
                .layer(Extension(console_config));

            // Cluster HTTP proxy routes.
            let cluster_proxy_config = Arc::new(cluster::ClusterProxyConfig::new(Arc::clone(
                &replica_http_locator,
            )));
            base_router = base_router
                .route("/clusters", routing::get(cluster::handle_clusters))
                .route(
                    "/api/cluster/{:cluster_id}/replica/{:replica_id}/process/{:process}/",
                    routing::any(cluster::handle_cluster_proxy_root),
                )
                .route(
                    "/api/cluster/{:cluster_id}/replica/{:replica_id}/process/{:process}/{*path}",
                    routing::any(cluster::handle_cluster_proxy),
                )
                .layer(Extension(cluster_proxy_config));

            // Leader routes use `DeploymentStateHandle` as router state, so they
            // are built as a separate router with their own auth layer.
            let leader_router = Router::new()
                .route("/api/leader/status", routing::get(handle_leader_status))
                .route("/api/leader/promote", routing::post(handle_leader_promote))
                .route(
                    "/api/leader/skip-catchup",
                    routing::post(handle_leader_skip_catchup),
                )
                .layer(auth_middleware.clone())
                .with_state(internal_route_config.deployment_state_handle.clone());
            router = router.merge(leader_router);
        }

        if routes_enabled.metrics {
            // `/metrics` exports this process's registry. The `/metrics/mz_*`
            // routes run SQL queries through an authenticated adapter session
            // (`AuthedClient`) and export the results in Prometheus format.
            let metrics_router = Router::new()
                .route(
                    "/metrics",
                    routing::get(move || async move {
                        mz_http_util::handle_prometheus(&metrics_registry).await
                    }),
                )
                .route(
                    "/metrics/mz_usage",
                    routing::get(|client: AuthedClient| async move {
                        let registry = sql::handle_promsql(client, USAGE_METRIC_QUERIES).await;
                        mz_http_util::handle_prometheus(&registry).await
                    }),
                )
                .route(
                    "/metrics/mz_frontier",
                    routing::get(|client: AuthedClient| async move {
                        let registry = sql::handle_promsql(client, FRONTIER_METRIC_QUERIES).await;
                        mz_http_util::handle_prometheus(&registry).await
                    }),
                )
                .route(
                    "/metrics/mz_compute",
                    routing::get(|client: AuthedClient| async move {
                        let registry = sql::handle_promsql(client, COMPUTE_METRIC_QUERIES).await;
                        mz_http_util::handle_prometheus(&registry).await
                    }),
                )
                .route(
                    "/metrics/mz_storage",
                    routing::get(|client: AuthedClient| async move {
                        let registry = sql::handle_promsql(client, STORAGE_METRIC_QUERIES).await;
                        mz_http_util::handle_prometheus(&registry).await
                    }),
                )
                .route(
                    "/api/livez",
                    routing::get(mz_http_util::handle_liveness_check),
                )
                .route("/api/readyz", routing::get(probe::handle_ready))
                .layer(auth_middleware.clone())
                .layer(Extension(adapter_client_rx.clone()))
                .layer(Extension(active_connection_counter.clone()))
                .layer(Extension(HelmChartVersion(helm_chart_version.clone())));
            router = router.merge(metrics_router);
        }

        if routes_enabled.console_config {
            // Console config is served without `auth_middleware`.
            let console_config_router = Router::new()
                .route(
                    "/api/console/config",
                    routing::get(console::handle_console_config),
                )
                .layer(Extension(adapter_client_rx.clone()))
                .layer(Extension(active_connection_counter.clone()));
            router = router.merge(console_config_router);
        }

        // MCP (Model Context Protocol) endpoints
        // Enabled via runtime `routes_enabled.mcp_agent` and `routes_enabled.mcp_developer` configuration
        if routes_enabled.mcp_agent || routes_enabled.mcp_developer {
            use tracing::info;

            let mut mcp_router = Router::new();

            if routes_enabled.mcp_agent {
                info!("Enabling MCP agent endpoint: /api/mcp/agent");
                mcp_router = mcp_router.route(
                    "/api/mcp/agent",
                    routing::post(mcp::handle_mcp_agent).get(mcp::handle_mcp_method_not_allowed),
                );
            }

            if routes_enabled.mcp_developer {
                info!("Enabling MCP developer endpoint: /api/mcp/developer");
                mcp_router = mcp_router.route(
                    "/api/mcp/developer",
                    routing::post(mcp::handle_mcp_developer)
                        .get(mcp::handle_mcp_method_not_allowed),
                );
            }

            // The MCP handlers perform a server-side Origin check against this
            // allowlist to defend against DNS rebinding attacks (see
            // database-issues#11311). The CorsLayer alone is not enough: in a
            // DNS rebinding attack the browser considers the request
            // same-origin, so no preflight fires and CORS enforcement is
            // bypassed.
            let mcp_allowed_origins = Arc::new(allowed_origin_list.clone());
            mcp_router = mcp_router
                .layer(auth_middleware.clone())
                .layer(Extension(adapter_client_rx.clone()))
                .layer(Extension(active_connection_counter.clone()))
                .layer(Extension(HelmChartVersion(helm_chart_version.clone())))
                .layer(Extension(mcp_allowed_origins))
                .layer(
                    CorsLayer::new()
                        .allow_methods(Method::POST)
                        .allow_origin(allowed_origin.clone())
                        .allow_headers([AUTHORIZATION, CONTENT_TYPE]),
                );
            router = router.merge(mcp_router);
        }

        // Layers shared by every base route: authentication, the adapter and
        // connection-counter extensions, and CORS restricted to
        // `allowed_origin` without credentials (preflight cached for an hour).
        base_router = base_router
            .layer(auth_middleware.clone())
            .layer(Extension(adapter_client_rx.clone()))
            .layer(Extension(active_connection_counter.clone()))
            .layer(Extension(HelmChartVersion(helm_chart_version)))
            .layer(
                CorsLayer::new()
                    .allow_credentials(false)
                    .allow_headers([
                        AUTHORIZATION,
                        CONTENT_TYPE,
                        HeaderName::from_static("x-materialize-version"),
                    ])
                    .allow_methods(Any)
                    .allow_origin(allowed_origin)
                    .expose_headers(Any)
                    .max_age(Duration::from_secs(60) * 60),
            );

        match authenticator_kind {
            listeners::AuthenticatorKind::Password | listeners::AuthenticatorKind::Oidc => {
                // Cookie sessions: `base_router` gets the session layer
                // directly. The login/logout routes are merged into `router`
                // before the layer is applied, so the same layer also covers
                // every group merged into `router` so far.
                base_router = base_router.layer(session_layer.clone());

                let login_router = Router::new()
                    .route("/api/login", routing::post(handle_login))
                    .route("/api/logout", routing::post(handle_logout))
                    .layer(Extension(adapter_client_rx))
                    .layer(Extension(allowed_roles));
                router = router.merge(login_router).layer(session_layer);
            }
            listeners::AuthenticatorKind::None => {
                // Trusted-proxy mode: honor `x-materialize-user` on base routes.
                base_router = base_router.layer(middleware::from_fn_with_state(
                    allowed_roles,
                    x_materialize_user_header_auth,
                ));
            }
            // Other authenticator kinds add no extra layers here.
            _ => {}
        }

        // `apply_default_layers` is defined elsewhere. It presumably adds the
        // layers common to every route, labelled by `source` — confirm.
        router = router
            .merge(base_router)
            .apply_default_layers(source, metrics);

        HttpServer { tls, router }
    }
}
579
580impl Server for HttpServer {
581    const NAME: &'static str = "http";
582
583    fn handle_connection(
584        &self,
585        conn: Connection,
586        _tokio_metrics_intervals: impl Iterator<Item = TaskMetrics> + Send + 'static,
587    ) -> ConnectionHandler {
588        let router = self.router.clone();
589        let tls_context = self.tls.clone();
590        let mut conn = TokioIo::new(conn);
591
592        Box::pin(async {
593            let direct_peer_addr = conn.inner().peer_addr().context("fetching peer addr")?;
594            let peer_addr = conn
595                .inner_mut()
596                .take_proxy_header_address()
597                .await
598                .map(|a| a.source)
599                .unwrap_or(direct_peer_addr);
600
601            let (conn, conn_protocol) = match tls_context {
602                Some(tls_context) => {
603                    let mut ssl_stream = SslStream::new(Ssl::new(&tls_context.get())?, conn)?;
604                    if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
605                        let _ = ssl_stream.get_mut().inner_mut().shutdown().await;
606                        return Err(e.into());
607                    }
608                    (MaybeHttpsStream::Https(ssl_stream), ConnProtocol::Https)
609                }
610                _ => (MaybeHttpsStream::Http(conn), ConnProtocol::Http),
611            };
612            let mut make_tower_svc = router
613                .layer(Extension(conn_protocol))
614                .into_make_service_with_connect_info::<SocketAddr>();
615            let tower_svc = make_tower_svc.call(peer_addr).await.unwrap();
616            let hyper_svc = hyper::service::service_fn(|req| tower_svc.clone().call(req));
617            let http = hyper::server::conn::http1::Builder::new();
618            http.serve_connection(conn, hyper_svc)
619                .with_upgrades()
620                .err_into()
621                .await
622        })
623    }
624}
625
626pub async fn handle_leader_status(
627    State(deployment_state_handle): State<DeploymentStateHandle>,
628) -> impl IntoResponse {
629    let status = deployment_state_handle.status();
630    (StatusCode::OK, Json(json!({ "status": status })))
631}
632
633pub async fn handle_leader_promote(
634    State(deployment_state_handle): State<DeploymentStateHandle>,
635) -> impl IntoResponse {
636    match deployment_state_handle.try_promote() {
637        Ok(()) => {
638            // TODO(benesch): the body here is redundant. Should just return
639            // 204.
640            let status = StatusCode::OK;
641            let body = Json(json!({
642                "result": "Success",
643            }));
644            (status, body)
645        }
646        Err(()) => {
647            // TODO(benesch): the nesting here is redundant given the error
648            // code. Should just return the `{"message": "..."}` object.
649            let status = StatusCode::BAD_REQUEST;
650            let body = Json(json!({
651                "result": {"Failure": {"message": "cannot promote leader while initializing"}},
652            }));
653            (status, body)
654        }
655    }
656}
657
658pub async fn handle_leader_skip_catchup(
659    State(deployment_state_handle): State<DeploymentStateHandle>,
660) -> impl IntoResponse {
661    match deployment_state_handle.try_skip_catchup() {
662        Ok(()) => StatusCode::NO_CONTENT.into_response(),
663        Err(()) => {
664            let status = StatusCode::BAD_REQUEST;
665            let body = Json(json!({
666                "message": "cannot skip catchup in this phase of initialization; try again later",
667            }));
668            (status, body).into_response()
669        }
670    }
671}
672
673async fn x_materialize_user_header_auth(
674    State(allowed_roles): State<AllowedRoles>,
675    mut req: Request,
676    next: Next,
677) -> impl IntoResponse {
678    // TODO migrate teleport to basic auth and remove this.
679    if let Some(username) = req.headers().get("x-materialize-user").map(|h| h.to_str()) {
680        let username = match username {
681            Ok(name @ (SUPPORT_USER_NAME | SYSTEM_USER_NAME)) => name.to_string(),
682            _ => {
683                return Err(AuthError::MismatchedUser(format!(
684                    "user specified in x-materialize-user must be {SUPPORT_USER_NAME} or {SYSTEM_USER_NAME}"
685                )));
686            }
687        };
688        // Enforce the listener's `allowed_roles` policy here. Without this,
689        // a listener with `authenticator_kind=None` and `allowed_roles=Normal`
690        // would let any caller assert `x-materialize-user: mz_system` and
691        // bypass the role restriction.
692        check_role_allowed(&username, allowed_roles)?;
693        req.extensions_mut().insert(AuthedUser {
694            name: username,
695            external_metadata_rx: None,
696            authenticated: Authenticated,
697            authenticator_kind: mz_auth::AuthenticatorKind::None,
698            groups: None,
699        });
700    }
701    Ok(next.run(req).await)
702}
703
/// A value delivered later through a oneshot channel. It is wrapped in
/// `Shared` so that it can be cloned and awaited by multiple holders.
type Delayed<T> = Shared<oneshot::Receiver<T>>;

/// The transport a connection was accepted on. `handle_connection` inserts it
/// as a request extension.
#[derive(Clone)]
enum ConnProtocol {
    /// Plain-text HTTP (no TLS context configured).
    Http,
    /// HTTP over a completed TLS handshake.
    Https,
}
711
/// An authenticated caller, attached to requests as an extension by the
/// authentication middleware and consumed when opening an adapter session.
#[derive(Clone, Debug)]
pub struct AuthedUser {
    /// The user name the adapter session is opened as.
    name: String,
    /// Updates to externally managed user metadata, if any. This is `None` for
    /// trusted-header authentication.
    external_metadata_rx: Option<watch::Receiver<ExternalUserMetadata>>,
    /// Authentication marker handed to `Client::new_session`.
    authenticated: Authenticated,
    /// Which authenticator established this identity.
    authenticator_kind: mz_auth::AuthenticatorKind,
    /// Groups from JWT claims for OIDC group-to-role sync.
    groups: Option<Vec<String>>,
}
721
/// A started-up adapter session for an authenticated caller.
pub struct AuthedClient {
    /// The session client returned by `Client::startup`.
    pub client: SessionClient,
    /// Connection slot reserved from the active connection counter. It is held
    /// for as long as this client lives; `None` presumably means the user is
    /// not subject to connection limits — confirm.
    pub connection_guard: Option<ConnectionHandle>,
}
726
727impl AuthedClient {
728    async fn new<F>(
729        adapter_client: &Client,
730        user: AuthedUser,
731        peer_addr: IpAddr,
732        active_connection_counter: ConnectionCounter,
733        helm_chart_version: Option<String>,
734        session_config: F,
735        options: BTreeMap<String, String>,
736        now: NowFn,
737    ) -> Result<Self, AdapterError>
738    where
739        F: FnOnce(&mut AdapterSession),
740    {
741        let conn_id = adapter_client.new_conn_id()?;
742        let mut session = adapter_client.new_session(
743            AdapterSessionConfig {
744                conn_id,
745                uuid: epoch_to_uuid_v7(&(now)()),
746                user: user.name,
747                client_ip: Some(peer_addr),
748                external_metadata_rx: user.external_metadata_rx,
749                helm_chart_version,
750                authenticator_kind: user.authenticator_kind,
751                groups: user.groups,
752            },
753            user.authenticated,
754        );
755        let connection_guard = active_connection_counter.allocate_connection(session.user())?;
756
757        session_config(&mut session);
758        let system_vars = adapter_client.get_system_vars().await;
759        for (key, val) in options {
760            const LOCAL: bool = false;
761            if let Err(err) =
762                session
763                    .vars_mut()
764                    .set(&system_vars, &key, VarInput::Flat(&val), LOCAL)
765            {
766                session.add_notice(AdapterNotice::BadStartupSetting {
767                    name: key.to_string(),
768                    reason: err.to_string(),
769                })
770            }
771        }
772        let adapter_client = adapter_client.startup(session).await?;
773        Ok(AuthedClient {
774            client: adapter_client,
775            connection_guard,
776        })
777    }
778}
779
780impl<S> FromRequestParts<S> for AuthedClient
781where
782    S: Send + Sync,
783{
784    type Rejection = Response;
785
786    async fn from_request_parts(
787        req: &mut http::request::Parts,
788        state: &S,
789    ) -> Result<Self, Self::Rejection> {
790        #[derive(Debug, Default, Deserialize)]
791        struct Params {
792            #[serde(default)]
793            options: String,
794        }
795        let params: Query<Params> = Query::from_request_parts(req, state)
796            .await
797            .unwrap_or_default();
798
799        let peer_addr = req
800            .extensions
801            .get::<ConnectInfo<SocketAddr>>()
802            .expect("ConnectInfo extension guaranteed to exist")
803            .0
804            .ip();
805
806        let user = req.extensions.get::<AuthedUser>().unwrap();
807        let adapter_client = req
808            .extensions
809            .get::<Delayed<mz_adapter::Client>>()
810            .unwrap()
811            .clone();
812        let adapter_client = adapter_client.await.map_err(|_| {
813            (StatusCode::INTERNAL_SERVER_ERROR, "adapter client missing").into_response()
814        })?;
815        let active_connection_counter = req.extensions.get::<ConnectionCounter>().unwrap();
816        let helm_chart_version = req
817            .extensions
818            .get::<HelmChartVersion>()
819            .map(|h| h.0.clone())
820            .unwrap_or(None);
821
822        let options = if params.options.is_empty() {
823            // It's possible 'options' simply wasn't provided, we don't want that to
824            // count as a failure to deserialize
825            BTreeMap::<String, String>::default()
826        } else {
827            match serde_json::from_str(&params.options) {
828                Ok(options) => options,
829                Err(_e) => {
830                    // If we fail to deserialize options, fail the request.
831                    let code = StatusCode::BAD_REQUEST;
832                    let msg = format!("Failed to deserialize {} map", "options".quoted());
833                    return Err((code, msg).into_response());
834                }
835            }
836        };
837
838        let client = AuthedClient::new(
839            &adapter_client,
840            user.clone(),
841            peer_addr,
842            active_connection_counter.clone(),
843            helm_chart_version,
844            |session| {
845                session
846                    .vars_mut()
847                    .set_default(WELCOME_MESSAGE.name(), VarInput::Flat(&false.format()))
848                    .expect("known to exist")
849            },
850            options,
851            SYSTEM_TIME.clone(),
852        )
853        .await
854        .map_err(|e| {
855            let status = match e {
856                AdapterError::UserSessionsDisallowed | AdapterError::NetworkPolicyDenied(_) => {
857                    StatusCode::FORBIDDEN
858                }
859                _ => StatusCode::INTERNAL_SERVER_ERROR,
860            };
861            (status, Json(SqlError::from(e))).into_response()
862        })?;
863
864        Ok(client)
865    }
866}
867
868#[derive(Debug, Error)]
869pub(crate) enum AuthError {
870    #[error("role dissallowed")]
871    RoleDisallowed(String),
872    #[error("{0}")]
873    Frontegg(#[from] FronteggError),
874    #[error("missing authorization header")]
875    MissingHttpAuthentication {
876        include_www_authenticate_header: bool,
877    },
878    #[error("{0}")]
879    MismatchedUser(String),
880    #[error("session expired")]
881    SessionExpired,
882    #[error("failed to update session")]
883    FailedToUpdateSession,
884    #[error("invalid credentials")]
885    InvalidCredentials,
886    /// Payload is `OidcError`'s sanitized `Display` (no expected-values leaks).
887    #[error("{0}")]
888    OidcFailed(String),
889}
890
891impl IntoResponse for AuthError {
892    fn into_response(self) -> Response {
893        warn!("HTTP request failed authentication: {}", self);
894        let mut headers = HeaderMap::new();
895        // We omit most detail from the error message we send to the client, to
896        // avoid giving attackers unnecessary information. `OidcFailed` is the
897        // exception — its payload is a sanitized `OidcError::Display` that the
898        // console embeds in the login-page error.
899        let body = match &self {
900            AuthError::MissingHttpAuthentication {
901                include_www_authenticate_header,
902            } if *include_www_authenticate_header => {
903                headers.insert(
904                    http::header::WWW_AUTHENTICATE,
905                    HeaderValue::from_static("Basic realm=Materialize"),
906                );
907                "unauthorized".to_string()
908            }
909            AuthError::OidcFailed(message) => message.clone(),
910            _ => "unauthorized".to_string(),
911        };
912        (StatusCode::UNAUTHORIZED, headers, body).into_response()
913    }
914}
915
916// Simplified login handler
917pub async fn handle_login(
918    session: Option<Extension<TowerSession>>,
919    Extension(adapter_client_rx): Extension<Delayed<Client>>,
920    Extension(allowed_roles): Extension<AllowedRoles>,
921    Json(LoginCredentials { username, password }): Json<LoginCredentials>,
922) -> impl IntoResponse {
923    // Enforce the listener's allowed_roles policy before doing any
924    // authentication work, mirroring the check performed in `auth` for
925    // header-based credentials. Without this, a session-based caller could
926    // log in as a role that header-based callers are forbidden to use.
927    if let Err(err) = check_role_allowed(&username, allowed_roles) {
928        warn!(
929            ?err,
930            "HTTP login rejected: role not allowed on this listener"
931        );
932        return StatusCode::UNAUTHORIZED;
933    }
934    let Ok(adapter_client) = adapter_client_rx.clone().await else {
935        return StatusCode::INTERNAL_SERVER_ERROR;
936    };
937    let authenticated = match adapter_client.authenticate(&username, &password).await {
938        Ok(authenticated) => authenticated,
939        Err(err) => {
940            warn!(?err, "HTTP login failed authentication");
941            return StatusCode::UNAUTHORIZED;
942        }
943    };
944    // Create session data
945    let session_data = TowerSessionData {
946        username,
947        created_at: SystemTime::now(),
948        last_activity: SystemTime::now(),
949        authenticated,
950        authenticator_kind: mz_auth::AuthenticatorKind::Password,
951    };
952    // Store session data
953    let session = session.and_then(|Extension(session)| Some(session));
954    let Some(session) = session else {
955        return StatusCode::INTERNAL_SERVER_ERROR;
956    };
957    match session.insert("data", &session_data).await {
958        Err(_) => StatusCode::INTERNAL_SERVER_ERROR,
959        Ok(_) => StatusCode::OK,
960    }
961}
962
963// Simplified logout handler
964pub async fn handle_logout(session: Option<Extension<TowerSession>>) -> impl IntoResponse {
965    let session = session.and_then(|Extension(session)| Some(session));
966    let Some(session) = session else {
967        return StatusCode::INTERNAL_SERVER_ERROR;
968    };
969    // Delete session
970    match session.delete().await {
971        Err(_) => StatusCode::INTERNAL_SERVER_ERROR,
972        Ok(_) => StatusCode::OK,
973    }
974}
975
976async fn http_auth(
977    mut req: Request,
978    next: Next,
979    tls_enabled: bool,
980    authenticator_kind: listeners::AuthenticatorKind,
981    frontegg: Option<mz_frontegg_auth::Authenticator>,
982    oidc_rx: Delayed<mz_authenticator::GenericOidcAuthenticator>,
983    adapter_client_rx: Delayed<Client>,
984    allowed_roles: AllowedRoles,
985) -> Result<impl IntoResponse, AuthError> {
986    let creds = if let Some(basic) = req.headers().typed_get::<Authorization<Basic>>() {
987        Some(Credentials::Password {
988            username: basic.username().to_owned(),
989            password: Password(basic.password().to_owned()),
990        })
991    } else if let Some(bearer) = req.headers().typed_get::<Authorization<Bearer>>() {
992        Some(Credentials::Token {
993            token: bearer.token().to_owned(),
994        })
995    } else {
996        None
997    };
998
999    // Reuses an authenticated session if one already exists.
1000    // If credentials are provided, we perform a new authentication,
1001    // separate from the existing session.
1002    if creds.is_none()
1003        && let Some((session, session_data)) =
1004            maybe_get_authenticated_session(req.extensions().get::<TowerSession>()).await
1005    {
1006        let user = ensure_session_unexpired(session, session_data).await?;
1007        // Defense-in-depth: re-check the listener's `allowed_roles` policy on
1008        // every session-authenticated request. The same check runs at
1009        // `/api/login`, but enforcing it here too prevents a session minted
1010        // under a more permissive configuration (or a future bug in the login
1011        // path) from bypassing role restrictions.
1012        check_role_allowed(&user.name, allowed_roles)?;
1013        // Need this to set the user of the Adapter client.
1014        req.extensions_mut().insert(user);
1015        return Ok(next.run(req).await);
1016    }
1017
1018    // First, extract the username from the certificate, validating that the
1019    // connection matches the TLS configuration along the way.
1020    // Fall back to existing authentication methods.
1021    let conn_protocol = req.extensions().get::<ConnProtocol>().unwrap();
1022    match (tls_enabled, &conn_protocol) {
1023        (false, ConnProtocol::Http) => {}
1024        (false, ConnProtocol::Https { .. }) => unreachable!(),
1025        (true, ConnProtocol::Http) => {
1026            let mut parts = req.uri().clone().into_parts();
1027            parts.scheme = Some(Scheme::HTTPS);
1028            return Ok(Redirect::permanent(
1029                &Uri::from_parts(parts)
1030                    .expect("it was already a URI, just changed the scheme")
1031                    .to_string(),
1032            )
1033            .into_response());
1034        }
1035        (true, ConnProtocol::Https { .. }) => {}
1036    }
1037    // If we've already passed some other auth, just use that.
1038    if req.extensions().get::<AuthedUser>().is_some() {
1039        return Ok(next.run(req).await);
1040    }
1041
1042    let path = req.uri().path();
1043    let include_www_authenticate_header = path == "/"
1044        || PROFILING_API_ENDPOINTS
1045            .iter()
1046            .any(|prefix| path.starts_with(prefix));
1047    let authenticator = get_authenticator(
1048        authenticator_kind,
1049        creds.as_ref(),
1050        frontegg,
1051        &oidc_rx,
1052        &adapter_client_rx,
1053    )
1054    .await;
1055
1056    let user = auth(
1057        &authenticator,
1058        creds,
1059        allowed_roles,
1060        include_www_authenticate_header,
1061    )
1062    .await?;
1063
1064    // Add the authenticated user as an extension so downstream handlers can
1065    // inspect it if necessary.
1066    req.extensions_mut().insert(user);
1067
1068    // Run the request.
1069    Ok(next.run(req).await)
1070}
1071
1072async fn init_ws(
1073    WsState {
1074        frontegg,
1075        oidc_rx,
1076        authenticator_kind,
1077        adapter_client_rx,
1078        active_connection_counter,
1079        helm_chart_version,
1080        allowed_roles,
1081    }: WsState,
1082    existing_user: Option<ExistingUser>,
1083    peer_addr: IpAddr,
1084    ws: &mut WebSocket,
1085) -> Result<AuthedClient, anyhow::Error> {
1086    // TODO: Add a timeout here to prevent resource leaks by clients that
1087    // connect then never send a message.
1088    let ws_auth: WebSocketAuth = loop {
1089        let init_msg = ws.recv().await.ok_or_else(|| anyhow::anyhow!("closed"))??;
1090        match init_msg {
1091            Message::Text(data) => break serde_json::from_str(&data)?,
1092            Message::Binary(data) => break serde_json::from_slice(&data)?,
1093            // Handled automatically by the server.
1094            Message::Ping(_) => {
1095                continue;
1096            }
1097            Message::Pong(_) => {
1098                continue;
1099            }
1100            Message::Close(_) => {
1101                anyhow::bail!("closed");
1102            }
1103        }
1104    };
1105
1106    // If credentials are provided, we perform a new authentication,
1107    // separate from the existing session.
1108    let (creds, options) = match ws_auth {
1109        WebSocketAuth::Basic {
1110            user,
1111            password,
1112            options,
1113        } => {
1114            let creds = Credentials::Password {
1115                username: user,
1116                password,
1117            };
1118            (Some(creds), options)
1119        }
1120        WebSocketAuth::Bearer { token, options } => {
1121            let creds = Credentials::Token { token };
1122            (Some(creds), options)
1123        }
1124        WebSocketAuth::OptionsOnly { options } => (None, options),
1125    };
1126
1127    let user = match (existing_user, creds) {
1128        (Some(ExistingUser::XMaterializeUserHeader(_)), Some(_creds)) => {
1129            warn!("Unexpected bearer or basic auth provided when using user header");
1130            anyhow::bail!("unexpected")
1131        }
1132        (Some(ExistingUser::Session(user)), None) => {
1133            // Defense-in-depth: re-enforce the listener's `allowed_roles`
1134            // policy on session-authenticated WebSocket connections. The same
1135            // check runs at `/api/login`, but enforcing it here too prevents a
1136            // session minted under a more permissive configuration (or a
1137            // future bug in the login path) from bypassing role restrictions.
1138            check_role_allowed(&user.name, allowed_roles)?;
1139            user
1140        }
1141        (Some(ExistingUser::XMaterializeUserHeader(user)), None) => user,
1142        (_, Some(creds)) => {
1143            let authenticator = get_authenticator(
1144                authenticator_kind,
1145                Some(&creds),
1146                frontegg,
1147                &oidc_rx,
1148                &adapter_client_rx,
1149            )
1150            .await;
1151            let user = auth(&authenticator, Some(creds), allowed_roles, false).await?;
1152            user
1153        }
1154        (None, None) => anyhow::bail!("expected auth information"),
1155    };
1156
1157    let client = AuthedClient::new(
1158        &adapter_client_rx.clone().await?,
1159        user,
1160        peer_addr,
1161        active_connection_counter.clone(),
1162        helm_chart_version.clone(),
1163        |_session| (),
1164        options,
1165        SYSTEM_TIME.clone(),
1166    )
1167    .await?;
1168
1169    Ok(client)
1170}
1171
/// Credentials presented by a caller. They come either from HTTP
/// `Authorization` headers or from the WebSocket auth message.
enum Credentials {
    /// Username and password (HTTP Basic auth or `WebSocketAuth::Basic`).
    Password {
        username: String,
        password: Password,
    },
    /// Bearer token (HTTP Bearer auth or `WebSocketAuth::Bearer`).
    Token {
        token: String,
    },
}
1181
1182async fn get_authenticator(
1183    kind: listeners::AuthenticatorKind,
1184    creds: Option<&Credentials>,
1185    frontegg: Option<mz_frontegg_auth::Authenticator>,
1186    oidc_rx: &Delayed<mz_authenticator::GenericOidcAuthenticator>,
1187    adapter_client_rx: &Delayed<Client>,
1188) -> Authenticator {
1189    match kind {
1190        listeners::AuthenticatorKind::Frontegg => Authenticator::Frontegg(frontegg.expect(
1191            "Frontegg authenticator should exist with listeners::AuthenticatorKind::Frontegg",
1192        )),
1193        listeners::AuthenticatorKind::Password | listeners::AuthenticatorKind::Sasl => {
1194            let client = adapter_client_rx.clone().await.expect("sender not dropped");
1195            Authenticator::Password(client)
1196        }
1197        listeners::AuthenticatorKind::Oidc => match creds {
1198            // Use the password authenticator if the credentials are password-based
1199            Some(Credentials::Password { .. }) => {
1200                let client = adapter_client_rx.clone().await.expect("sender not dropped");
1201                Authenticator::Password(client)
1202            }
1203            _ => Authenticator::Oidc(oidc_rx.clone().await.expect("sender not dropped")),
1204        },
1205        listeners::AuthenticatorKind::None => Authenticator::None,
1206    }
1207}
1208
1209/// Attempts to retrieve session data from a [`TowerSession`], if available.
1210/// Session data is present only if an authenticated session has been
1211/// established via [`handle_login`].
1212pub(crate) async fn maybe_get_authenticated_session(
1213    session: Option<&TowerSession>,
1214) -> Option<(&TowerSession, TowerSessionData)> {
1215    if let Some(session) = session {
1216        if let Ok(Some(session_data)) = session.get::<TowerSessionData>("data").await {
1217            return Some((session, session_data));
1218        }
1219    }
1220    None
1221}
1222
1223/// Ensures the session is still valid by checking for expiration,
1224/// and returns the associated user if the session remains active.
1225pub(crate) async fn ensure_session_unexpired(
1226    session: &TowerSession,
1227    session_data: TowerSessionData,
1228) -> Result<AuthedUser, AuthError> {
1229    if session_data
1230        .last_activity
1231        .elapsed()
1232        .unwrap_or(Duration::MAX)
1233        > SESSION_DURATION
1234    {
1235        let _ = session.delete().await;
1236        return Err(AuthError::SessionExpired);
1237    }
1238    let mut updated_data = session_data.clone();
1239    updated_data.last_activity = SystemTime::now();
1240    session
1241        .insert("data", &updated_data)
1242        .await
1243        .map_err(|_| AuthError::FailedToUpdateSession)?;
1244
1245    Ok(AuthedUser {
1246        name: session_data.username,
1247        external_metadata_rx: None,
1248        authenticated: session_data.authenticated,
1249        authenticator_kind: session_data.authenticator_kind,
1250        groups: None,
1251    })
1252}
1253
/// Authenticates `creds` with `authenticator`, then checks the resulting role
/// name against `allowed_roles`.
///
/// Behavior per authenticator:
/// - `Frontegg`: accepts password or token credentials. The user name comes
///   from the Frontegg session or the token claims.
/// - `Password`: accepts only password credentials, verified via the adapter
///   client.
/// - `Sasl`: always rejected (see the inline comment).
/// - `Oidc`: accepts only token credentials. The user name and any groups
///   come from the token claims.
/// - `None`: performs no verification. The user name is taken from password
///   credentials, and the password is ignored. Any other input, including a
///   token or no credentials at all, falls back to `HTTP_DEFAULT_USER`.
///
/// `include_www_authenticate_header` is forwarded into
/// [`AuthError::MissingHttpAuthentication`] when credentials are missing or
/// not of the kind the authenticator needs.
///
/// # Errors
///
/// Returns an [`AuthError`] when any of these holds:
/// - credentials are missing or unsupported;
/// - verification fails;
/// - the resolved role is not permitted by `allowed_roles`.
async fn auth(
    authenticator: &Authenticator,
    creds: Option<Credentials>,
    allowed_roles: AllowedRoles,
    include_www_authenticate_header: bool,
) -> Result<AuthedUser, AuthError> {
    let (name, external_metadata_rx, authenticated, groups) = match authenticator {
        Authenticator::Frontegg(frontegg) => match creds {
            Some(Credentials::Password { username, password }) => {
                let (auth_session, authenticated) =
                    frontegg.authenticate(&username, password.as_str()).await?;
                let name = auth_session.user().into();
                let external_metadata_rx = Some(auth_session.external_metadata_rx());
                (name, external_metadata_rx, authenticated, None)
            }
            Some(Credentials::Token { token }) => {
                let (claims, authenticated) = frontegg.validate_access_token(&token, None)?;
                // The sender is discarded immediately, so this receiver only
                // ever observes the metadata taken from these token claims.
                let (_, external_metadata_rx) = watch::channel(ExternalUserMetadata {
                    user_id: claims.user_id,
                    admin: claims.is_admin,
                });
                (claims.user, Some(external_metadata_rx), authenticated, None)
            }
            None => {
                return Err(AuthError::MissingHttpAuthentication {
                    include_www_authenticate_header,
                });
            }
        },
        Authenticator::Password(adapter_client) => match creds {
            Some(Credentials::Password { username, password }) => {
                // Any verification failure is reported uniformly as
                // `InvalidCredentials`; the underlying error is discarded.
                let authenticated = adapter_client
                    .authenticate(&username, &password)
                    .await
                    .map_err(|_| AuthError::InvalidCredentials)?;
                (username, None, authenticated, None)
            }
            _ => {
                return Err(AuthError::MissingHttpAuthentication {
                    include_www_authenticate_header,
                });
            }
        },
        Authenticator::Sasl(_) => {
            // We shouldn't ever end up here as the configuration is validated at startup.
            // If we do, it's a server misconfiguration.
            // Just in case, we return a 401 rather than panic.
            return Err(AuthError::MissingHttpAuthentication {
                include_www_authenticate_header,
            });
        }
        Authenticator::Oidc(oidc) => match creds {
            Some(Credentials::Token { token }) => {
                // The error's `Display` becomes the `OidcFailed` payload.
                let (mut claims, authenticated) = oidc
                    .authenticate(&token, None)
                    .await
                    .map_err(|e| AuthError::OidcFailed(e.to_string()))?;
                let name = std::mem::take(&mut claims.user);
                let groups = claims.groups.take();
                (name, None, authenticated, groups)
            }
            _ => {
                return Err(AuthError::MissingHttpAuthentication {
                    include_www_authenticate_header,
                });
            }
        },
        Authenticator::None => {
            // If no authentication, use whatever is in the HTTP auth
            // header (without checking the password), or fall back to the
            // default user.
            let name = match creds {
                Some(Credentials::Password { username, .. }) => username,
                _ => HTTP_DEFAULT_USER.name.to_owned(),
            };
            (name, None, Authenticated, None)
        }
    };

    // Apply the listener's role policy to whichever name was resolved above.
    check_role_allowed(&name, allowed_roles)?;

    Ok(AuthedUser {
        name,
        external_metadata_rx,
        authenticated,
        authenticator_kind: authenticator.kind(),
        groups,
    })
}
1343
1344// TODO move this somewhere it can be shared with PGWIRE
1345fn check_role_allowed(name: &str, allowed_roles: AllowedRoles) -> Result<(), AuthError> {
1346    let is_internal_user = INTERNAL_USER_NAMES.contains(name);
1347    // this is a superset of internal users
1348    let is_reserved_user = mz_adapter::catalog::is_reserved_role_name(name);
1349    let role_allowed = match allowed_roles {
1350        AllowedRoles::Normal => !is_reserved_user,
1351        AllowedRoles::Internal => is_internal_user,
1352        AllowedRoles::NormalAndInternal => !is_reserved_user || is_internal_user,
1353    };
1354    if role_allowed {
1355        Ok(())
1356    } else {
1357        Err(AuthError::RoleDisallowed(name.to_owned()))
1358    }
1359}
1360
1361/// Default layers that should be applied to all routes, and should get applied to both the
1362/// internal http and external http routers.
1363trait DefaultLayers {
1364    fn apply_default_layers(self, source: &'static str, metrics: Metrics) -> Self;
1365}
1366
1367impl DefaultLayers for Router {
1368    fn apply_default_layers(self, source: &'static str, metrics: Metrics) -> Self {
1369        self.layer(DefaultBodyLimit::max(MAX_REQUEST_SIZE))
1370            .layer(metrics::PrometheusLayer::new(source, metrics))
1371    }
1372}
1373
1374/// Glue code to make [`tower`] work with [`axum`].
1375///
1376/// `axum` requires `Layer`s not return Errors, i.e. they must be `Result<_, Infallible>`,
1377/// instead you must return a type that can be converted into a response. `tower` on the other
1378/// hand does return Errors, so to make the two work together we need to convert our `tower` errors
1379/// into responses.
1380async fn handle_load_error(error: tower::BoxError) -> impl IntoResponse {
1381    if error.is::<tower::load_shed::error::Overloaded>() {
1382        return (
1383            StatusCode::TOO_MANY_REQUESTS,
1384            Cow::from("too many requests, try again later"),
1385        );
1386    }
1387
1388    // Note: This should be unreachable because at the time of writing our only use case is a
1389    // layer that emits `tower::load_shed::error::Overloaded`, which is handled above.
1390    (
1391        StatusCode::INTERNAL_SERVER_ERROR,
1392        Cow::from(format!("Unhandled internal error: {}", error)),
1393    )
1394}
1395
/// JSON request body accepted by [`handle_login`].
///
/// NOTE(review): `Debug` is derived, so this relies on `Password`'s own
/// `Debug` impl not revealing the secret — confirm.
#[derive(Debug, Deserialize, Serialize, PartialEq)]
pub struct LoginCredentials {
    /// The role to log in as.
    username: String,
    /// The password to verify for `username`.
    password: Password,
}
1401
/// Per-session state stored under the `"data"` key of a [`TowerSession`] once
/// a caller logs in via [`handle_login`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TowerSessionData {
    /// The authenticated role name.
    username: String,
    /// When the session was established.
    created_at: SystemTime,
    /// When the session was last used. `ensure_session_unexpired` refreshes it
    /// on each reuse and expires sessions idle longer than `SESSION_DURATION`.
    last_activity: SystemTime,
    /// Evidence of successful authentication, reused for the adapter session.
    authenticated: Authenticated,
    /// Which authenticator established the session.
    authenticator_kind: mz_auth::AuthenticatorKind,
}
1410
#[cfg(test)]
mod tests {
    use super::{AllowedRoles, check_role_allowed};

    #[mz_ore::test]
    fn test_check_role_allowed() {
        // Each row is a role name followed by whether it is allowed under
        // `Internal`, `NormalAndInternal`, and `Normal`, in that order.
        let cases: &[(&str, bool, bool, bool)] = &[
            // Internal users.
            ("mz_system", true, true, false),
            ("mz_support", true, true, false),
            ("mz_analytics", true, true, false),
            // Normal users.
            ("materialize", false, true, true),
            ("anonymous_http_user", false, true, true),
            ("alex", false, true, true),
            // Denied by reserved role prefix.
            ("external_asdf", false, false, false),
            ("pg_somebody", false, false, false),
            ("mz_unknown", false, false, false),
            // Denied by literal PUBLIC.
            ("PUBLIC", false, false, false),
        ];

        for &(name, internal, normal_and_internal, normal) in cases {
            assert_eq!(
                check_role_allowed(name, AllowedRoles::Internal).is_ok(),
                internal,
                "{name} with AllowedRoles::Internal"
            );
            assert_eq!(
                check_role_allowed(name, AllowedRoles::NormalAndInternal).is_ok(),
                normal_and_internal,
                "{name} with AllowedRoles::NormalAndInternal"
            );
            assert_eq!(
                check_role_allowed(name, AllowedRoles::Normal).is_ok(),
                normal,
                "{name} with AllowedRoles::Normal"
            );
        }
    }
}