Skip to main content

mz_environmentd/
http.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Embedded HTTP server.
11//!
12//! environmentd embeds an HTTP server for introspection into the running
13//! process. At the moment, its primary exports are Prometheus metrics, heap
14//! profiles, and catalog dumps.
15//!
16//! ## Authentication flow
17//!
18//! The server supports several authentication modes, controlled by the
19//! configured [`listeners::AuthenticatorKind`]. The general flow is:
20//!
21//! 1. **Identity resolution.** An authentication middleware runs on every
22//!    protected request and resolves the caller's identity via one of:
23//!    - **Credentials in headers.** The caller supplies a username/password or
24//!      token in the request headers. Supported by all [`listeners::AuthenticatorKind`]s.
25//!    - **Session reuse.** If the caller has an active authenticated session
26//!      (established via `POST /api/login`) and has not supplied credentials
27//!      in the request headers, the session is reused. Only available for
28//!      [`listeners::AuthenticatorKind::Password`] and [`listeners::AuthenticatorKind::Oidc`].
29//!    - **Trusted header injection.** A trusted upstream proxy (e.g. Teleport)
30//!      may inject the caller's identity into the request headers. Only available
31//!      for [`listeners::AuthenticatorKind::None`].
32//!
33//! 2. **Session initialization.** Once the caller's identity is known, an
34//!    adapter session is opened on their behalf. This happens as part of
35//!    request processing, after all middleware has run.
36//!
37//! 3. **Request handling.** The handler executes the request (e.g. runs SQL)
38//!    using the initialized adapter session.
39//!
40//! ### WebSocket
41//!
42//! The WebSocket flow is identical to the HTTP flow with two differences:
43//!
44//! - Credentials are not read from request headers. Instead, the first
45//!   message sent by the client is treated as the authentication message.
//! - Session initialization (step 2) happens inside the WebSocket handler
//!   itself, once the authentication message has been received, rather than
//!   before the handler runs.
48
49// Axum handlers must use async, but often don't actually use `await`.
50#![allow(clippy::unused_async)]
51
52use std::borrow::Cow;
53use std::collections::BTreeMap;
54use std::fmt::Debug;
55use std::net::{IpAddr, SocketAddr};
56use std::pin::Pin;
57use std::sync::Arc;
58use std::time::{Duration, SystemTime};
59
60use anyhow::Context;
61use axum::error_handling::HandleErrorLayer;
62use axum::extract::ws::{Message, WebSocket};
63use axum::extract::{ConnectInfo, DefaultBodyLimit, FromRequestParts, Query, Request, State};
64use axum::middleware::{self, Next};
65use axum::response::{IntoResponse, Redirect, Response};
66use axum::{Extension, Json, Router, routing};
67use futures::future::{Shared, TryFutureExt};
68use headers::authorization::{Authorization, Basic, Bearer};
69use headers::{HeaderMapExt, HeaderName};
70use http::header::{AUTHORIZATION, CONTENT_TYPE};
71use http::uri::Scheme;
72use http::{HeaderMap, HeaderValue, Method, StatusCode, Uri};
73use hyper_openssl::SslStream;
74use hyper_openssl::client::legacy::MaybeHttpsStream;
75use hyper_util::rt::TokioIo;
76use mz_adapter::session::{Session as AdapterSession, SessionConfig as AdapterSessionConfig};
77use mz_adapter::{AdapterError, AdapterNotice, Client, SessionClient, WebhookAppenderCache};
78use mz_auth::Authenticated;
79use mz_auth::password::Password;
80use mz_authenticator::Authenticator;
81use mz_controller::ReplicaHttpLocator;
82use mz_frontegg_auth::Error as FronteggError;
83use mz_http_util::DynamicFilterTarget;
84use mz_ore::cast::u64_to_usize;
85use mz_ore::metrics::MetricsRegistry;
86use mz_ore::now::{NowFn, SYSTEM_TIME, epoch_to_uuid_v7};
87use mz_ore::str::StrExt;
88use mz_pgwire_common::{ConnectionCounter, ConnectionHandle};
89use mz_repr::user::ExternalUserMetadata;
90use mz_server_core::listeners::{self, AllowedRoles, HttpRoutesEnabled};
91use mz_server_core::{Connection, ConnectionHandler, ReloadingSslContext, Server};
92use mz_sql::session::metadata::SessionMetadata;
93use mz_sql::session::user::{
94    HTTP_DEFAULT_USER, INTERNAL_USER_NAMES, SUPPORT_USER_NAME, SYSTEM_USER_NAME,
95};
96use mz_sql::session::vars::{Value, Var, VarInput, WELCOME_MESSAGE};
97use openssl::ssl::Ssl;
98use prometheus::{
99    COMPUTE_METRIC_QUERIES, FRONTIER_METRIC_QUERIES, STORAGE_METRIC_QUERIES, USAGE_METRIC_QUERIES,
100};
101use serde::{Deserialize, Serialize};
102use serde_json::json;
103use thiserror::Error;
104use tokio::io::AsyncWriteExt;
105use tokio::sync::oneshot::Receiver;
106use tokio::sync::{oneshot, watch};
107use tokio_metrics::TaskMetrics;
108use tower::limit::GlobalConcurrencyLimitLayer;
109use tower::{Service, ServiceBuilder};
110use tower_http::cors::{AllowOrigin, Any, CorsLayer};
111use tower_sessions::{
112    MemoryStore as TowerSessionMemoryStore, Session as TowerSession,
113    SessionManagerLayer as TowerSessionManagerLayer,
114};
115use tracing::warn;
116
117use crate::BUILD_INFO;
118use crate::deployment::state::DeploymentStateHandle;
119use crate::http::sql::{ExistingUser, SqlError};
120
121mod catalog;
122mod cluster;
123mod console;
124mod mcp;
125mod memory;
126mod metrics;
127mod metrics_viz;
128mod probe;
129mod prometheus;
130mod root;
131mod sql;
132mod webhook;
133
134pub use metrics::Metrics;
135pub use sql::{SqlResponse, WebSocketAuth, WebSocketResponse};
136
137/// Maximum allowed size for a request.
138pub const MAX_REQUEST_SIZE: usize = u64_to_usize(5 * bytesize::MIB);
139
140const SESSION_DURATION: Duration = Duration::from_secs(8 * 3600); // 8 hours
141
142const PROFILING_API_ENDPOINTS: &[&str] = &["/memory", "/hierarchical-memory", "/prof/"];
143
/// Everything needed to build an [`HttpServer`].
///
/// Consumed (and fully destructured) by [`HttpServer::new`].
#[derive(Debug)]
pub struct HttpConfig {
    /// Label for this listener; passed to `apply_default_layers` together with
    /// `metrics` when the router is finalized.
    pub source: &'static str,
    /// TLS context for the listener, or `None` to serve plain HTTP.
    pub tls: Option<ReloadingSslContext>,
    /// Authentication scheme used by this listener's protected routes.
    pub authenticator_kind: listeners::AuthenticatorKind,
    /// Frontegg authenticator, if one is configured.
    pub frontegg: Option<mz_frontegg_auth::Authenticator>,
    /// Future that yields the OIDC authenticator once it has been sent.
    pub oidc_rx: Delayed<mz_authenticator::GenericOidcAuthenticator>,
    /// Future that yields the adapter client once it has been sent.
    pub adapter_client_rx: Shared<Receiver<Client>>,
    /// CORS origin policy applied to the base and MCP routes.
    pub allowed_origin: AllowOrigin,
    /// Raw list of allowed CORS origins. The MCP endpoints use it for
    /// server-side Origin validation, to defend against DNS rebinding.
    pub allowed_origin_list: Vec<HeaderValue>,
    /// Counter from which each adapter connection allocates a handle.
    pub active_connection_counter: ConnectionCounter,
    /// Helm chart version. It is only threaded into the WebSocket state; the
    /// HTTP `AuthedClient` extractor opens sessions with `None`.
    pub helm_chart_version: Option<String>,
    /// Caps concurrent webhook requests. Requests beyond the available permits
    /// are shed by the webhook router's load-shedding layer.
    pub concurrent_webhook_req: Arc<tokio::sync::Semaphore>,
    /// HTTP server metrics.
    pub metrics: Metrics,
    /// Registry exported at `/metrics`.
    pub metrics_registry: MetricsRegistry,
    /// Roles that may authenticate on this listener; enforced by the auth
    /// middlewares.
    pub allowed_roles: AllowedRoles,
    /// Configuration for the internal-only routes.
    pub internal_route_config: Arc<InternalRouteConfig>,
    /// Which groups of routes to mount.
    pub routes_enabled: HttpRoutesEnabled,
    /// Locator for cluster replica HTTP addresses, used for proxying requests.
    pub replica_http_locator: Arc<ReplicaHttpLocator>,
}
167
/// Configuration for the routes mounted when `routes_enabled.internal` is set.
#[derive(Debug, Clone)]
pub struct InternalRouteConfig {
    /// Handle used as the state of the `/api/leader/*` routes.
    pub deployment_state_handle: DeploymentStateHandle,
    /// URL handed to the internal console proxy (`console::ConsoleProxyConfig`)
    /// that serves `/internal-console`; `None` if not configured.
    pub internal_console_redirect_url: Option<String>,
}
173
/// State for the WebSocket SQL endpoint (`/api/experimental/sql`).
///
/// The WebSocket handler authenticates the client from its first message
/// instead of going through the HTTP auth middleware. It therefore carries the
/// authenticators and the adapter client directly.
#[derive(Clone)]
pub struct WsState {
    /// Frontegg authenticator, if one is configured.
    frontegg: Option<mz_frontegg_auth::Authenticator>,
    /// Future that yields the OIDC authenticator.
    oidc_rx: Delayed<mz_authenticator::GenericOidcAuthenticator>,
    /// Authentication scheme configured for the listener.
    authenticator_kind: listeners::AuthenticatorKind,
    /// Future that yields the adapter client.
    adapter_client_rx: Delayed<mz_adapter::Client>,
    /// Counter from which each WebSocket session allocates a connection handle.
    active_connection_counter: ConnectionCounter,
    /// Helm chart version recorded on sessions opened by this endpoint.
    helm_chart_version: Option<String>,
    /// Roles that may authenticate on this listener.
    allowed_roles: AllowedRoles,
}
184
/// State for the webhook ingestion route
/// (`/api/webhook/{database}/{schema}/{id}`).
#[derive(Clone)]
pub struct WebhookState {
    /// Future that yields the adapter client.
    adapter_client_rx: Delayed<mz_adapter::Client>,
    /// Appender cache, created once in `HttpServer::new` and shared by every
    /// request through the router state.
    webhook_cache: WebhookAppenderCache,
}
190
/// The environmentd HTTP server: a fully assembled router plus optional TLS.
///
/// Built by [`HttpServer::new`]; each accepted connection is served by
/// `Server::handle_connection`.
#[derive(Debug)]
pub struct HttpServer {
    /// TLS context; when `None`, connections are served as plain HTTP.
    tls: Option<ReloadingSslContext>,
    /// Router with every enabled route and middleware layer already applied.
    router: Router,
}
196
197impl HttpServer {
198    pub fn new(
199        HttpConfig {
200            source,
201            tls,
202            authenticator_kind,
203            frontegg,
204            oidc_rx,
205            adapter_client_rx,
206            allowed_origin,
207            allowed_origin_list,
208            active_connection_counter,
209            helm_chart_version,
210            concurrent_webhook_req,
211            metrics,
212            metrics_registry,
213            allowed_roles,
214            internal_route_config,
215            routes_enabled,
216            replica_http_locator,
217        }: HttpConfig,
218    ) -> HttpServer {
219        let tls_enabled = tls.is_some();
220        let webhook_cache = WebhookAppenderCache::new();
221
222        // Create secure session store and manager
223        let session_store = TowerSessionMemoryStore::default();
224        let session_layer = TowerSessionManagerLayer::new(session_store)
225            .with_secure(tls_enabled) // Enforce HTTPS
226            .with_same_site(tower_sessions::cookie::SameSite::Strict) // Prevent CSRF
227            .with_http_only(true) // Prevent XSS
228            .with_name("mz_session") // Custom cookie name
229            .with_path("/"); // Set cookie path
230
231        let frontegg_middleware = frontegg.clone();
232        let oidc_middleware_rx = oidc_rx.clone();
233        let adapter_client_middleware_rx = adapter_client_rx.clone();
234        let auth_middleware = middleware::from_fn(move |req, next| {
235            let frontegg = frontegg_middleware.clone();
236            let oidc_rx = oidc_middleware_rx.clone();
237            let adapter_client_rx = adapter_client_middleware_rx.clone();
238            async move {
239                http_auth(
240                    req,
241                    next,
242                    tls_enabled,
243                    authenticator_kind,
244                    frontegg,
245                    oidc_rx,
246                    adapter_client_rx,
247                    allowed_roles,
248                )
249                .await
250            }
251        });
252
253        let mut router = Router::new();
254        let mut base_router = Router::new();
255        if routes_enabled.base {
256            base_router = base_router
257                .route(
258                    "/",
259                    routing::get(move || async move { root::handle_home(routes_enabled).await }),
260                )
261                .route("/api/sql", routing::post(sql::handle_sql))
262                .route("/memory", routing::get(memory::handle_memory))
263                .route(
264                    "/hierarchical-memory",
265                    routing::get(memory::handle_hierarchical_memory),
266                )
267                .route(
268                    "/metrics-viz",
269                    routing::get(metrics_viz::handle_metrics_viz),
270                )
271                .route("/static/{*path}", routing::get(root::handle_static));
272
273            let mut ws_router = Router::new()
274                .route("/api/experimental/sql", routing::get(sql::handle_sql_ws))
275                .with_state(WsState {
276                    frontegg,
277                    oidc_rx: oidc_rx.clone(),
278                    authenticator_kind,
279                    adapter_client_rx: adapter_client_rx.clone(),
280                    active_connection_counter: active_connection_counter.clone(),
281                    helm_chart_version,
282                    allowed_roles,
283                });
284            if let listeners::AuthenticatorKind::None = authenticator_kind {
285                ws_router = ws_router.layer(middleware::from_fn_with_state(
286                    allowed_roles,
287                    x_materialize_user_header_auth,
288                ));
289            }
290            router = router.merge(ws_router);
291        }
292        if routes_enabled.profiling {
293            base_router = base_router.nest("/prof/", mz_prof_http::router(&BUILD_INFO));
294        }
295
296        if routes_enabled.webhook {
297            let webhook_router = Router::new()
298                .route(
299                    "/api/webhook/{:database}/{:schema}/{:id}",
300                    routing::post(webhook::handle_webhook),
301                )
302                .with_state(WebhookState {
303                    adapter_client_rx: adapter_client_rx.clone(),
304                    webhook_cache,
305                })
306                .layer(
307                    tower_http::decompression::RequestDecompressionLayer::new()
308                        .gzip(true)
309                        .deflate(true)
310                        .br(true)
311                        .zstd(true),
312                )
313                .layer(
314                    CorsLayer::new()
315                        .allow_methods(Method::POST)
316                        .allow_origin(AllowOrigin::mirror_request())
317                        .allow_headers(Any),
318                )
319                .layer(
320                    ServiceBuilder::new()
321                        .layer(HandleErrorLayer::new(handle_load_error))
322                        .load_shed()
323                        .layer(GlobalConcurrencyLimitLayer::with_semaphore(
324                            concurrent_webhook_req,
325                        )),
326                );
327            router = router.merge(webhook_router);
328        }
329
330        if routes_enabled.internal {
331            let console_config = Arc::new(console::ConsoleProxyConfig::new(
332                internal_route_config.internal_console_redirect_url.clone(),
333                "/internal-console".to_string(),
334            ));
335            base_router = base_router
336                .route(
337                    "/api/opentelemetry/config",
338                    routing::put({
339                        move |_: axum::Json<DynamicFilterTarget>| async {
340                            (
341                                StatusCode::BAD_REQUEST,
342                                "This endpoint has been replaced. \
343                            Use the `opentelemetry_filter` system variable."
344                                    .to_string(),
345                            )
346                        }
347                    }),
348                )
349                .route(
350                    "/api/stderr/config",
351                    routing::put({
352                        move |_: axum::Json<DynamicFilterTarget>| async {
353                            (
354                                StatusCode::BAD_REQUEST,
355                                "This endpoint has been replaced. \
356                            Use the `log_filter` system variable."
357                                    .to_string(),
358                            )
359                        }
360                    }),
361                )
362                .route("/api/tracing", routing::get(mz_http_util::handle_tracing))
363                .route(
364                    "/api/catalog/dump",
365                    routing::get(catalog::handle_catalog_dump),
366                )
367                .route(
368                    "/api/catalog/check",
369                    routing::get(catalog::handle_catalog_check),
370                )
371                .route(
372                    "/api/catalog/inject-audit-events",
373                    routing::post(catalog::handle_inject_audit_events),
374                )
375                .route(
376                    "/api/coordinator/check",
377                    routing::get(catalog::handle_coordinator_check),
378                )
379                .route(
380                    "/api/coordinator/dump",
381                    routing::get(catalog::handle_coordinator_dump),
382                )
383                .route(
384                    "/internal-console",
385                    routing::get(|| async { Redirect::temporary("/internal-console/") }),
386                )
387                .route(
388                    "/internal-console/{*path}",
389                    routing::get(console::handle_internal_console),
390                )
391                .route(
392                    "/internal-console/",
393                    routing::get(console::handle_internal_console),
394                )
395                .layer(Extension(console_config));
396
397            // Cluster HTTP proxy routes.
398            let cluster_proxy_config = Arc::new(cluster::ClusterProxyConfig::new(Arc::clone(
399                &replica_http_locator,
400            )));
401            base_router = base_router
402                .route("/clusters", routing::get(cluster::handle_clusters))
403                .route(
404                    "/api/cluster/{:cluster_id}/replica/{:replica_id}/process/{:process}/",
405                    routing::any(cluster::handle_cluster_proxy_root),
406                )
407                .route(
408                    "/api/cluster/{:cluster_id}/replica/{:replica_id}/process/{:process}/{*path}",
409                    routing::any(cluster::handle_cluster_proxy),
410                )
411                .layer(Extension(cluster_proxy_config));
412
413            let leader_router = Router::new()
414                .route("/api/leader/status", routing::get(handle_leader_status))
415                .route("/api/leader/promote", routing::post(handle_leader_promote))
416                .route(
417                    "/api/leader/skip-catchup",
418                    routing::post(handle_leader_skip_catchup),
419                )
420                .layer(auth_middleware.clone())
421                .with_state(internal_route_config.deployment_state_handle.clone());
422            router = router.merge(leader_router);
423        }
424
425        if routes_enabled.metrics {
426            let metrics_router = Router::new()
427                .route(
428                    "/metrics",
429                    routing::get(move || async move {
430                        mz_http_util::handle_prometheus(&metrics_registry).await
431                    }),
432                )
433                .route(
434                    "/metrics/mz_usage",
435                    routing::get(|client: AuthedClient| async move {
436                        let registry = sql::handle_promsql(client, USAGE_METRIC_QUERIES).await;
437                        mz_http_util::handle_prometheus(&registry).await
438                    }),
439                )
440                .route(
441                    "/metrics/mz_frontier",
442                    routing::get(|client: AuthedClient| async move {
443                        let registry = sql::handle_promsql(client, FRONTIER_METRIC_QUERIES).await;
444                        mz_http_util::handle_prometheus(&registry).await
445                    }),
446                )
447                .route(
448                    "/metrics/mz_compute",
449                    routing::get(|client: AuthedClient| async move {
450                        let registry = sql::handle_promsql(client, COMPUTE_METRIC_QUERIES).await;
451                        mz_http_util::handle_prometheus(&registry).await
452                    }),
453                )
454                .route(
455                    "/metrics/mz_storage",
456                    routing::get(|client: AuthedClient| async move {
457                        let registry = sql::handle_promsql(client, STORAGE_METRIC_QUERIES).await;
458                        mz_http_util::handle_prometheus(&registry).await
459                    }),
460                )
461                .route(
462                    "/api/livez",
463                    routing::get(mz_http_util::handle_liveness_check),
464                )
465                .route("/api/readyz", routing::get(probe::handle_ready))
466                .layer(auth_middleware.clone())
467                .layer(Extension(adapter_client_rx.clone()))
468                .layer(Extension(active_connection_counter.clone()));
469            router = router.merge(metrics_router);
470        }
471
472        if routes_enabled.console_config {
473            let console_config_router = Router::new()
474                .route(
475                    "/api/console/config",
476                    routing::get(console::handle_console_config),
477                )
478                .layer(Extension(adapter_client_rx.clone()))
479                .layer(Extension(active_connection_counter.clone()));
480            router = router.merge(console_config_router);
481        }
482
483        // MCP (Model Context Protocol) endpoints
484        // Enabled via runtime `routes_enabled.mcp_agent` and `routes_enabled.mcp_developer` configuration
485        if routes_enabled.mcp_agent || routes_enabled.mcp_developer {
486            use tracing::info;
487
488            let mut mcp_router = Router::new();
489
490            if routes_enabled.mcp_agent {
491                info!("Enabling MCP agent endpoint: /api/mcp/agent");
492                mcp_router = mcp_router.route(
493                    "/api/mcp/agent",
494                    routing::post(mcp::handle_mcp_agent).get(mcp::handle_mcp_method_not_allowed),
495                );
496            }
497
498            if routes_enabled.mcp_developer {
499                info!("Enabling MCP developer endpoint: /api/mcp/developer");
500                mcp_router = mcp_router.route(
501                    "/api/mcp/developer",
502                    routing::post(mcp::handle_mcp_developer)
503                        .get(mcp::handle_mcp_method_not_allowed),
504                );
505            }
506
507            // The MCP handlers perform a server-side Origin check against this
508            // allowlist to defend against DNS rebinding attacks (see
509            // database-issues#11311). The CorsLayer alone is not enough: in a
510            // DNS rebinding attack the browser considers the request
511            // same-origin, so no preflight fires and CORS enforcement is
512            // bypassed.
513            let mcp_allowed_origins = Arc::new(allowed_origin_list.clone());
514            mcp_router = mcp_router
515                .layer(auth_middleware.clone())
516                .layer(Extension(adapter_client_rx.clone()))
517                .layer(Extension(active_connection_counter.clone()))
518                .layer(Extension(mcp_allowed_origins))
519                .layer(
520                    CorsLayer::new()
521                        .allow_methods(Method::POST)
522                        .allow_origin(allowed_origin.clone())
523                        .allow_headers([AUTHORIZATION, CONTENT_TYPE]),
524                );
525            router = router.merge(mcp_router);
526        }
527
528        base_router = base_router
529            .layer(auth_middleware.clone())
530            .layer(Extension(adapter_client_rx.clone()))
531            .layer(Extension(active_connection_counter.clone()))
532            .layer(
533                CorsLayer::new()
534                    .allow_credentials(false)
535                    .allow_headers([
536                        AUTHORIZATION,
537                        CONTENT_TYPE,
538                        HeaderName::from_static("x-materialize-version"),
539                    ])
540                    .allow_methods(Any)
541                    .allow_origin(allowed_origin)
542                    .expose_headers(Any)
543                    .max_age(Duration::from_secs(60) * 60),
544            );
545
546        match authenticator_kind {
547            listeners::AuthenticatorKind::Password | listeners::AuthenticatorKind::Oidc => {
548                base_router = base_router.layer(session_layer.clone());
549
550                let login_router = Router::new()
551                    .route("/api/login", routing::post(handle_login))
552                    .route("/api/logout", routing::post(handle_logout))
553                    .layer(Extension(adapter_client_rx))
554                    .layer(Extension(allowed_roles));
555                router = router.merge(login_router).layer(session_layer);
556            }
557            listeners::AuthenticatorKind::None => {
558                base_router = base_router.layer(middleware::from_fn_with_state(
559                    allowed_roles,
560                    x_materialize_user_header_auth,
561                ));
562            }
563            _ => {}
564        }
565
566        router = router
567            .merge(base_router)
568            .apply_default_layers(source, metrics);
569
570        HttpServer { tls, router }
571    }
572}
573
574impl Server for HttpServer {
575    const NAME: &'static str = "http";
576
577    fn handle_connection(
578        &self,
579        conn: Connection,
580        _tokio_metrics_intervals: impl Iterator<Item = TaskMetrics> + Send + 'static,
581    ) -> ConnectionHandler {
582        let router = self.router.clone();
583        let tls_context = self.tls.clone();
584        let mut conn = TokioIo::new(conn);
585
586        Box::pin(async {
587            let direct_peer_addr = conn.inner().peer_addr().context("fetching peer addr")?;
588            let peer_addr = conn
589                .inner_mut()
590                .take_proxy_header_address()
591                .await
592                .map(|a| a.source)
593                .unwrap_or(direct_peer_addr);
594
595            let (conn, conn_protocol) = match tls_context {
596                Some(tls_context) => {
597                    let mut ssl_stream = SslStream::new(Ssl::new(&tls_context.get())?, conn)?;
598                    if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
599                        let _ = ssl_stream.get_mut().inner_mut().shutdown().await;
600                        return Err(e.into());
601                    }
602                    (MaybeHttpsStream::Https(ssl_stream), ConnProtocol::Https)
603                }
604                _ => (MaybeHttpsStream::Http(conn), ConnProtocol::Http),
605            };
606            let mut make_tower_svc = router
607                .layer(Extension(conn_protocol))
608                .into_make_service_with_connect_info::<SocketAddr>();
609            let tower_svc = make_tower_svc.call(peer_addr).await.unwrap();
610            let hyper_svc = hyper::service::service_fn(|req| tower_svc.clone().call(req));
611            let http = hyper::server::conn::http1::Builder::new();
612            http.serve_connection(conn, hyper_svc)
613                .with_upgrades()
614                .err_into()
615                .await
616        })
617    }
618}
619
620pub async fn handle_leader_status(
621    State(deployment_state_handle): State<DeploymentStateHandle>,
622) -> impl IntoResponse {
623    let status = deployment_state_handle.status();
624    (StatusCode::OK, Json(json!({ "status": status })))
625}
626
627pub async fn handle_leader_promote(
628    State(deployment_state_handle): State<DeploymentStateHandle>,
629) -> impl IntoResponse {
630    match deployment_state_handle.try_promote() {
631        Ok(()) => {
632            // TODO(benesch): the body here is redundant. Should just return
633            // 204.
634            let status = StatusCode::OK;
635            let body = Json(json!({
636                "result": "Success",
637            }));
638            (status, body)
639        }
640        Err(()) => {
641            // TODO(benesch): the nesting here is redundant given the error
642            // code. Should just return the `{"message": "..."}` object.
643            let status = StatusCode::BAD_REQUEST;
644            let body = Json(json!({
645                "result": {"Failure": {"message": "cannot promote leader while initializing"}},
646            }));
647            (status, body)
648        }
649    }
650}
651
652pub async fn handle_leader_skip_catchup(
653    State(deployment_state_handle): State<DeploymentStateHandle>,
654) -> impl IntoResponse {
655    match deployment_state_handle.try_skip_catchup() {
656        Ok(()) => StatusCode::NO_CONTENT.into_response(),
657        Err(()) => {
658            let status = StatusCode::BAD_REQUEST;
659            let body = Json(json!({
660                "message": "cannot skip catchup in this phase of initialization; try again later",
661            }));
662            (status, body).into_response()
663        }
664    }
665}
666
667async fn x_materialize_user_header_auth(
668    State(allowed_roles): State<AllowedRoles>,
669    mut req: Request,
670    next: Next,
671) -> impl IntoResponse {
672    // TODO migrate teleport to basic auth and remove this.
673    if let Some(username) = req.headers().get("x-materialize-user").map(|h| h.to_str()) {
674        let username = match username {
675            Ok(name @ (SUPPORT_USER_NAME | SYSTEM_USER_NAME)) => name.to_string(),
676            _ => {
677                return Err(AuthError::MismatchedUser(format!(
678                    "user specified in x-materialize-user must be {SUPPORT_USER_NAME} or {SYSTEM_USER_NAME}"
679                )));
680            }
681        };
682        // Enforce the listener's `allowed_roles` policy here. Without this,
683        // a listener with `authenticator_kind=None` and `allowed_roles=Normal`
684        // would let any caller assert `x-materialize-user: mz_system` and
685        // bypass the role restriction.
686        check_role_allowed(&username, allowed_roles)?;
687        req.extensions_mut().insert(AuthedUser {
688            name: username,
689            external_metadata_rx: None,
690            authenticated: Authenticated,
691            authenticator_kind: mz_auth::AuthenticatorKind::None,
692        });
693    }
694    Ok(next.run(req).await)
695}
696
/// A value that becomes available later, delivered once over a oneshot channel.
///
/// Wrapping the receiver in [`Shared`] makes the handle cloneable. Awaiting a
/// clone yields `Result<T, oneshot::error::RecvError>`, which is an error if
/// the sender was dropped without sending.
type Delayed<T> = Shared<oneshot::Receiver<T>>;
698
/// Transport over which a connection was accepted.
///
/// `Server::handle_connection` attaches it to every request as an `Extension`,
/// so handlers and middleware can tell whether the client connected over TLS.
///
/// It is a fieldless enum, so it is `Copy` and comparable.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ConnProtocol {
    /// Plain-text HTTP.
    Http,
    /// HTTP over TLS.
    Https,
}
704
/// A caller whose identity has been established.
///
/// The authentication middleware (e.g. `x_materialize_user_header_auth`)
/// inserts it into the request extensions. It is later consumed when an adapter
/// session is opened for the caller.
#[derive(Clone, Debug)]
pub struct AuthedUser {
    /// Name of the role the caller authenticated as.
    name: String,
    /// Receiver for externally managed user metadata, when the authenticator
    /// provides one (`None` for header-asserted users).
    external_metadata_rx: Option<watch::Receiver<ExternalUserMetadata>>,
    /// Marker passed to `new_session` attesting that authentication succeeded.
    authenticated: Authenticated,
    /// Which authenticator vouched for this user; recorded on the session.
    authenticator_kind: mz_auth::AuthenticatorKind,
}
712
/// A started adapter session opened on behalf of an [`AuthedUser`].
///
/// Also usable as a request extractor by HTTP handlers.
pub struct AuthedClient {
    /// Client for the started adapter session.
    pub client: SessionClient,
    /// Handle from the active connection counter, kept alive with the session.
    /// `None` when the counter did not hand one out.
    pub connection_guard: Option<ConnectionHandle>,
}
717
718impl AuthedClient {
719    async fn new<F>(
720        adapter_client: &Client,
721        user: AuthedUser,
722        peer_addr: IpAddr,
723        active_connection_counter: ConnectionCounter,
724        helm_chart_version: Option<String>,
725        session_config: F,
726        options: BTreeMap<String, String>,
727        now: NowFn,
728    ) -> Result<Self, AdapterError>
729    where
730        F: FnOnce(&mut AdapterSession),
731    {
732        let conn_id = adapter_client.new_conn_id()?;
733        let mut session = adapter_client.new_session(
734            AdapterSessionConfig {
735                conn_id,
736                uuid: epoch_to_uuid_v7(&(now)()),
737                user: user.name,
738                client_ip: Some(peer_addr),
739                external_metadata_rx: user.external_metadata_rx,
740                helm_chart_version,
741                authenticator_kind: user.authenticator_kind,
742            },
743            user.authenticated,
744        );
745        let connection_guard = active_connection_counter.allocate_connection(session.user())?;
746
747        session_config(&mut session);
748        let system_vars = adapter_client.get_system_vars().await;
749        for (key, val) in options {
750            const LOCAL: bool = false;
751            if let Err(err) =
752                session
753                    .vars_mut()
754                    .set(&system_vars, &key, VarInput::Flat(&val), LOCAL)
755            {
756                session.add_notice(AdapterNotice::BadStartupSetting {
757                    name: key.to_string(),
758                    reason: err.to_string(),
759                })
760            }
761        }
762        let adapter_client = adapter_client.startup(session).await?;
763        Ok(AuthedClient {
764            client: adapter_client,
765            connection_guard,
766        })
767    }
768}
769
770impl<S> FromRequestParts<S> for AuthedClient
771where
772    S: Send + Sync,
773{
774    type Rejection = Response;
775
776    async fn from_request_parts(
777        req: &mut http::request::Parts,
778        state: &S,
779    ) -> Result<Self, Self::Rejection> {
780        #[derive(Debug, Default, Deserialize)]
781        struct Params {
782            #[serde(default)]
783            options: String,
784        }
785        let params: Query<Params> = Query::from_request_parts(req, state)
786            .await
787            .unwrap_or_default();
788
789        let peer_addr = req
790            .extensions
791            .get::<ConnectInfo<SocketAddr>>()
792            .expect("ConnectInfo extension guaranteed to exist")
793            .0
794            .ip();
795
796        let user = req.extensions.get::<AuthedUser>().unwrap();
797        let adapter_client = req
798            .extensions
799            .get::<Delayed<mz_adapter::Client>>()
800            .unwrap()
801            .clone();
802        let adapter_client = adapter_client.await.map_err(|_| {
803            (StatusCode::INTERNAL_SERVER_ERROR, "adapter client missing").into_response()
804        })?;
805        let active_connection_counter = req.extensions.get::<ConnectionCounter>().unwrap();
806        let helm_chart_version = None;
807
808        let options = if params.options.is_empty() {
809            // It's possible 'options' simply wasn't provided, we don't want that to
810            // count as a failure to deserialize
811            BTreeMap::<String, String>::default()
812        } else {
813            match serde_json::from_str(&params.options) {
814                Ok(options) => options,
815                Err(_e) => {
816                    // If we fail to deserialize options, fail the request.
817                    let code = StatusCode::BAD_REQUEST;
818                    let msg = format!("Failed to deserialize {} map", "options".quoted());
819                    return Err((code, msg).into_response());
820                }
821            }
822        };
823
824        let client = AuthedClient::new(
825            &adapter_client,
826            user.clone(),
827            peer_addr,
828            active_connection_counter.clone(),
829            helm_chart_version,
830            |session| {
831                session
832                    .vars_mut()
833                    .set_default(WELCOME_MESSAGE.name(), VarInput::Flat(&false.format()))
834                    .expect("known to exist")
835            },
836            options,
837            SYSTEM_TIME.clone(),
838        )
839        .await
840        .map_err(|e| {
841            let status = match e {
842                AdapterError::UserSessionsDisallowed | AdapterError::NetworkPolicyDenied(_) => {
843                    StatusCode::FORBIDDEN
844                }
845                _ => StatusCode::INTERNAL_SERVER_ERROR,
846            };
847            (status, Json(SqlError::from(e))).into_response()
848        })?;
849
850        Ok(client)
851    }
852}
853
854#[derive(Debug, Error)]
855pub(crate) enum AuthError {
856    #[error("role dissallowed")]
857    RoleDisallowed(String),
858    #[error("{0}")]
859    Frontegg(#[from] FronteggError),
860    #[error("missing authorization header")]
861    MissingHttpAuthentication {
862        include_www_authenticate_header: bool,
863    },
864    #[error("{0}")]
865    MismatchedUser(String),
866    #[error("session expired")]
867    SessionExpired,
868    #[error("failed to update session")]
869    FailedToUpdateSession,
870    #[error("invalid credentials")]
871    InvalidCredentials,
872}
873
874impl IntoResponse for AuthError {
875    fn into_response(self) -> Response {
876        warn!("HTTP request failed authentication: {}", self);
877        let mut headers = HeaderMap::new();
878        match self {
879            AuthError::MissingHttpAuthentication {
880                include_www_authenticate_header,
881            } if include_www_authenticate_header => {
882                headers.insert(
883                    http::header::WWW_AUTHENTICATE,
884                    HeaderValue::from_static("Basic realm=Materialize"),
885                );
886            }
887            _ => {}
888        };
889        // We omit most detail from the error message we send to the client, to
890        // avoid giving attackers unnecessary information.
891        (StatusCode::UNAUTHORIZED, headers, "unauthorized").into_response()
892    }
893}
894
895// Simplified login handler
896pub async fn handle_login(
897    session: Option<Extension<TowerSession>>,
898    Extension(adapter_client_rx): Extension<Delayed<Client>>,
899    Extension(allowed_roles): Extension<AllowedRoles>,
900    Json(LoginCredentials { username, password }): Json<LoginCredentials>,
901) -> impl IntoResponse {
902    // Enforce the listener's allowed_roles policy before doing any
903    // authentication work, mirroring the check performed in `auth` for
904    // header-based credentials. Without this, a session-based caller could
905    // log in as a role that header-based callers are forbidden to use.
906    if let Err(err) = check_role_allowed(&username, allowed_roles) {
907        warn!(
908            ?err,
909            "HTTP login rejected: role not allowed on this listener"
910        );
911        return StatusCode::UNAUTHORIZED;
912    }
913    let Ok(adapter_client) = adapter_client_rx.clone().await else {
914        return StatusCode::INTERNAL_SERVER_ERROR;
915    };
916    let authenticated = match adapter_client.authenticate(&username, &password).await {
917        Ok(authenticated) => authenticated,
918        Err(err) => {
919            warn!(?err, "HTTP login failed authentication");
920            return StatusCode::UNAUTHORIZED;
921        }
922    };
923    // Create session data
924    let session_data = TowerSessionData {
925        username,
926        created_at: SystemTime::now(),
927        last_activity: SystemTime::now(),
928        authenticated,
929        authenticator_kind: mz_auth::AuthenticatorKind::Password,
930    };
931    // Store session data
932    let session = session.and_then(|Extension(session)| Some(session));
933    let Some(session) = session else {
934        return StatusCode::INTERNAL_SERVER_ERROR;
935    };
936    match session.insert("data", &session_data).await {
937        Err(_) => StatusCode::INTERNAL_SERVER_ERROR,
938        Ok(_) => StatusCode::OK,
939    }
940}
941
942// Simplified logout handler
943pub async fn handle_logout(session: Option<Extension<TowerSession>>) -> impl IntoResponse {
944    let session = session.and_then(|Extension(session)| Some(session));
945    let Some(session) = session else {
946        return StatusCode::INTERNAL_SERVER_ERROR;
947    };
948    // Delete session
949    match session.delete().await {
950        Err(_) => StatusCode::INTERNAL_SERVER_ERROR,
951        Ok(_) => StatusCode::OK,
952    }
953}
954
955async fn http_auth(
956    mut req: Request,
957    next: Next,
958    tls_enabled: bool,
959    authenticator_kind: listeners::AuthenticatorKind,
960    frontegg: Option<mz_frontegg_auth::Authenticator>,
961    oidc_rx: Delayed<mz_authenticator::GenericOidcAuthenticator>,
962    adapter_client_rx: Delayed<Client>,
963    allowed_roles: AllowedRoles,
964) -> Result<impl IntoResponse, AuthError> {
965    let creds = if let Some(basic) = req.headers().typed_get::<Authorization<Basic>>() {
966        Some(Credentials::Password {
967            username: basic.username().to_owned(),
968            password: Password(basic.password().to_owned()),
969        })
970    } else if let Some(bearer) = req.headers().typed_get::<Authorization<Bearer>>() {
971        Some(Credentials::Token {
972            token: bearer.token().to_owned(),
973        })
974    } else {
975        None
976    };
977
978    // Reuses an authenticated session if one already exists.
979    // If credentials are provided, we perform a new authentication,
980    // separate from the existing session.
981    if creds.is_none()
982        && let Some((session, session_data)) =
983            maybe_get_authenticated_session(req.extensions().get::<TowerSession>()).await
984    {
985        let user = ensure_session_unexpired(session, session_data).await?;
986        // Defense-in-depth: re-check the listener's `allowed_roles` policy on
987        // every session-authenticated request. The same check runs at
988        // `/api/login`, but enforcing it here too prevents a session minted
989        // under a more permissive configuration (or a future bug in the login
990        // path) from bypassing role restrictions.
991        check_role_allowed(&user.name, allowed_roles)?;
992        // Need this to set the user of the Adapter client.
993        req.extensions_mut().insert(user);
994        return Ok(next.run(req).await);
995    }
996
997    // First, extract the username from the certificate, validating that the
998    // connection matches the TLS configuration along the way.
999    // Fall back to existing authentication methods.
1000    let conn_protocol = req.extensions().get::<ConnProtocol>().unwrap();
1001    match (tls_enabled, &conn_protocol) {
1002        (false, ConnProtocol::Http) => {}
1003        (false, ConnProtocol::Https { .. }) => unreachable!(),
1004        (true, ConnProtocol::Http) => {
1005            let mut parts = req.uri().clone().into_parts();
1006            parts.scheme = Some(Scheme::HTTPS);
1007            return Ok(Redirect::permanent(
1008                &Uri::from_parts(parts)
1009                    .expect("it was already a URI, just changed the scheme")
1010                    .to_string(),
1011            )
1012            .into_response());
1013        }
1014        (true, ConnProtocol::Https { .. }) => {}
1015    }
1016    // If we've already passed some other auth, just use that.
1017    if req.extensions().get::<AuthedUser>().is_some() {
1018        return Ok(next.run(req).await);
1019    }
1020
1021    let path = req.uri().path();
1022    let include_www_authenticate_header = path == "/"
1023        || PROFILING_API_ENDPOINTS
1024            .iter()
1025            .any(|prefix| path.starts_with(prefix));
1026    let authenticator = get_authenticator(
1027        authenticator_kind,
1028        creds.as_ref(),
1029        frontegg,
1030        &oidc_rx,
1031        &adapter_client_rx,
1032    )
1033    .await;
1034
1035    let user = auth(
1036        &authenticator,
1037        creds,
1038        allowed_roles,
1039        include_www_authenticate_header,
1040    )
1041    .await?;
1042
1043    // Add the authenticated user as an extension so downstream handlers can
1044    // inspect it if necessary.
1045    req.extensions_mut().insert(user);
1046
1047    // Run the request.
1048    Ok(next.run(req).await)
1049}
1050
1051async fn init_ws(
1052    WsState {
1053        frontegg,
1054        oidc_rx,
1055        authenticator_kind,
1056        adapter_client_rx,
1057        active_connection_counter,
1058        helm_chart_version,
1059        allowed_roles,
1060    }: WsState,
1061    existing_user: Option<ExistingUser>,
1062    peer_addr: IpAddr,
1063    ws: &mut WebSocket,
1064) -> Result<AuthedClient, anyhow::Error> {
1065    // TODO: Add a timeout here to prevent resource leaks by clients that
1066    // connect then never send a message.
1067    let ws_auth: WebSocketAuth = loop {
1068        let init_msg = ws.recv().await.ok_or_else(|| anyhow::anyhow!("closed"))??;
1069        match init_msg {
1070            Message::Text(data) => break serde_json::from_str(&data)?,
1071            Message::Binary(data) => break serde_json::from_slice(&data)?,
1072            // Handled automatically by the server.
1073            Message::Ping(_) => {
1074                continue;
1075            }
1076            Message::Pong(_) => {
1077                continue;
1078            }
1079            Message::Close(_) => {
1080                anyhow::bail!("closed");
1081            }
1082        }
1083    };
1084
1085    // If credentials are provided, we perform a new authentication,
1086    // separate from the existing session.
1087    let (creds, options) = match ws_auth {
1088        WebSocketAuth::Basic {
1089            user,
1090            password,
1091            options,
1092        } => {
1093            let creds = Credentials::Password {
1094                username: user,
1095                password,
1096            };
1097            (Some(creds), options)
1098        }
1099        WebSocketAuth::Bearer { token, options } => {
1100            let creds = Credentials::Token { token };
1101            (Some(creds), options)
1102        }
1103        WebSocketAuth::OptionsOnly { options } => (None, options),
1104    };
1105
1106    let user = match (existing_user, creds) {
1107        (Some(ExistingUser::XMaterializeUserHeader(_)), Some(_creds)) => {
1108            warn!("Unexpected bearer or basic auth provided when using user header");
1109            anyhow::bail!("unexpected")
1110        }
1111        (Some(ExistingUser::Session(user)), None) => {
1112            // Defense-in-depth: re-enforce the listener's `allowed_roles`
1113            // policy on session-authenticated WebSocket connections. The same
1114            // check runs at `/api/login`, but enforcing it here too prevents a
1115            // session minted under a more permissive configuration (or a
1116            // future bug in the login path) from bypassing role restrictions.
1117            check_role_allowed(&user.name, allowed_roles)?;
1118            user
1119        }
1120        (Some(ExistingUser::XMaterializeUserHeader(user)), None) => user,
1121        (_, Some(creds)) => {
1122            let authenticator = get_authenticator(
1123                authenticator_kind,
1124                Some(&creds),
1125                frontegg,
1126                &oidc_rx,
1127                &adapter_client_rx,
1128            )
1129            .await;
1130            let user = auth(&authenticator, Some(creds), allowed_roles, false).await?;
1131            user
1132        }
1133        (None, None) => anyhow::bail!("expected auth information"),
1134    };
1135
1136    let client = AuthedClient::new(
1137        &adapter_client_rx.clone().await?,
1138        user,
1139        peer_addr,
1140        active_connection_counter.clone(),
1141        helm_chart_version.clone(),
1142        |_session| (),
1143        options,
1144        SYSTEM_TIME.clone(),
1145    )
1146    .await?;
1147
1148    Ok(client)
1149}
1150
/// Credentials presented by a caller, taken from HTTP `Authorization`
/// headers or from a WebSocket auth message.
enum Credentials {
    /// A username and password (HTTP Basic auth or WebSocket `Basic`).
    Password {
        username: String,
        password: Password,
    },
    /// A bearer token (HTTP Bearer auth or WebSocket `Bearer`).
    Token {
        token: String,
    },
}
1160
1161async fn get_authenticator(
1162    kind: listeners::AuthenticatorKind,
1163    creds: Option<&Credentials>,
1164    frontegg: Option<mz_frontegg_auth::Authenticator>,
1165    oidc_rx: &Delayed<mz_authenticator::GenericOidcAuthenticator>,
1166    adapter_client_rx: &Delayed<Client>,
1167) -> Authenticator {
1168    match kind {
1169        listeners::AuthenticatorKind::Frontegg => Authenticator::Frontegg(frontegg.expect(
1170            "Frontegg authenticator should exist with listeners::AuthenticatorKind::Frontegg",
1171        )),
1172        listeners::AuthenticatorKind::Password | listeners::AuthenticatorKind::Sasl => {
1173            let client = adapter_client_rx.clone().await.expect("sender not dropped");
1174            Authenticator::Password(client)
1175        }
1176        listeners::AuthenticatorKind::Oidc => match creds {
1177            // Use the password authenticator if the credentials are password-based
1178            Some(Credentials::Password { .. }) => {
1179                let client = adapter_client_rx.clone().await.expect("sender not dropped");
1180                Authenticator::Password(client)
1181            }
1182            _ => Authenticator::Oidc(oidc_rx.clone().await.expect("sender not dropped")),
1183        },
1184        listeners::AuthenticatorKind::None => Authenticator::None,
1185    }
1186}
1187
1188/// Attempts to retrieve session data from a [`TowerSession`], if available.
1189/// Session data is present only if an authenticated session has been
1190/// established via [`handle_login`].
1191pub(crate) async fn maybe_get_authenticated_session(
1192    session: Option<&TowerSession>,
1193) -> Option<(&TowerSession, TowerSessionData)> {
1194    if let Some(session) = session {
1195        if let Ok(Some(session_data)) = session.get::<TowerSessionData>("data").await {
1196            return Some((session, session_data));
1197        }
1198    }
1199    None
1200}
1201
1202/// Ensures the session is still valid by checking for expiration,
1203/// and returns the associated user if the session remains active.
1204pub(crate) async fn ensure_session_unexpired(
1205    session: &TowerSession,
1206    session_data: TowerSessionData,
1207) -> Result<AuthedUser, AuthError> {
1208    if session_data
1209        .last_activity
1210        .elapsed()
1211        .unwrap_or(Duration::MAX)
1212        > SESSION_DURATION
1213    {
1214        let _ = session.delete().await;
1215        return Err(AuthError::SessionExpired);
1216    }
1217    let mut updated_data = session_data.clone();
1218    updated_data.last_activity = SystemTime::now();
1219    session
1220        .insert("data", &updated_data)
1221        .await
1222        .map_err(|_| AuthError::FailedToUpdateSession)?;
1223
1224    Ok(AuthedUser {
1225        name: session_data.username,
1226        external_metadata_rx: None,
1227        authenticated: session_data.authenticated,
1228        authenticator_kind: session_data.authenticator_kind,
1229    })
1230}
1231
1232async fn auth(
1233    authenticator: &Authenticator,
1234    creds: Option<Credentials>,
1235    allowed_roles: AllowedRoles,
1236    include_www_authenticate_header: bool,
1237) -> Result<AuthedUser, AuthError> {
1238    let (name, external_metadata_rx, authenticated) = match authenticator {
1239        Authenticator::Frontegg(frontegg) => match creds {
1240            Some(Credentials::Password { username, password }) => {
1241                let (auth_session, authenticated) =
1242                    frontegg.authenticate(&username, password.as_str()).await?;
1243                let name = auth_session.user().into();
1244                let external_metadata_rx = Some(auth_session.external_metadata_rx());
1245                (name, external_metadata_rx, authenticated)
1246            }
1247            Some(Credentials::Token { token }) => {
1248                let (claims, authenticated) = frontegg.validate_access_token(&token, None)?;
1249                let (_, external_metadata_rx) = watch::channel(ExternalUserMetadata {
1250                    user_id: claims.user_id,
1251                    admin: claims.is_admin,
1252                });
1253                (claims.user, Some(external_metadata_rx), authenticated)
1254            }
1255            None => {
1256                return Err(AuthError::MissingHttpAuthentication {
1257                    include_www_authenticate_header,
1258                });
1259            }
1260        },
1261        Authenticator::Password(adapter_client) => match creds {
1262            Some(Credentials::Password { username, password }) => {
1263                let authenticated = adapter_client
1264                    .authenticate(&username, &password)
1265                    .await
1266                    .map_err(|_| AuthError::InvalidCredentials)?;
1267                (username, None, authenticated)
1268            }
1269            _ => {
1270                return Err(AuthError::MissingHttpAuthentication {
1271                    include_www_authenticate_header,
1272                });
1273            }
1274        },
1275        Authenticator::Sasl(_) => {
1276            // We shouldn't ever end up here as the configuration is validated at startup.
1277            // If we do, it's a server misconfiguration.
1278            // Just in case, we return a 401 rather than panic.
1279            return Err(AuthError::MissingHttpAuthentication {
1280                include_www_authenticate_header,
1281            });
1282        }
1283        Authenticator::Oidc(oidc) => match creds {
1284            Some(Credentials::Token { token }) => {
1285                // Validate JWT token
1286                let (mut claims, authenticated) = oidc
1287                    .authenticate(&token, None)
1288                    .await
1289                    .map_err(|_| AuthError::InvalidCredentials)?;
1290                let name = std::mem::take(&mut claims.user);
1291                (name, None, authenticated)
1292            }
1293            _ => {
1294                return Err(AuthError::MissingHttpAuthentication {
1295                    include_www_authenticate_header,
1296                });
1297            }
1298        },
1299        Authenticator::None => {
1300            // If no authentication, use whatever is in the HTTP auth
1301            // header (without checking the password), or fall back to the
1302            // default user.
1303            let name = match creds {
1304                Some(Credentials::Password { username, .. }) => username,
1305                _ => HTTP_DEFAULT_USER.name.to_owned(),
1306            };
1307            (name, None, Authenticated)
1308        }
1309    };
1310
1311    check_role_allowed(&name, allowed_roles)?;
1312
1313    Ok(AuthedUser {
1314        name,
1315        external_metadata_rx,
1316        authenticated,
1317        authenticator_kind: authenticator.kind(),
1318    })
1319}
1320
1321// TODO move this somewhere it can be shared with PGWIRE
1322fn check_role_allowed(name: &str, allowed_roles: AllowedRoles) -> Result<(), AuthError> {
1323    let is_internal_user = INTERNAL_USER_NAMES.contains(name);
1324    // this is a superset of internal users
1325    let is_reserved_user = mz_adapter::catalog::is_reserved_role_name(name);
1326    let role_allowed = match allowed_roles {
1327        AllowedRoles::Normal => !is_reserved_user,
1328        AllowedRoles::Internal => is_internal_user,
1329        AllowedRoles::NormalAndInternal => !is_reserved_user || is_internal_user,
1330    };
1331    if role_allowed {
1332        Ok(())
1333    } else {
1334        Err(AuthError::RoleDisallowed(name.to_owned()))
1335    }
1336}
1337
1338/// Default layers that should be applied to all routes, and should get applied to both the
1339/// internal http and external http routers.
1340trait DefaultLayers {
1341    fn apply_default_layers(self, source: &'static str, metrics: Metrics) -> Self;
1342}
1343
1344impl DefaultLayers for Router {
1345    fn apply_default_layers(self, source: &'static str, metrics: Metrics) -> Self {
1346        self.layer(DefaultBodyLimit::max(MAX_REQUEST_SIZE))
1347            .layer(metrics::PrometheusLayer::new(source, metrics))
1348    }
1349}
1350
1351/// Glue code to make [`tower`] work with [`axum`].
1352///
1353/// `axum` requires `Layer`s not return Errors, i.e. they must be `Result<_, Infallible>`,
1354/// instead you must return a type that can be converted into a response. `tower` on the other
1355/// hand does return Errors, so to make the two work together we need to convert our `tower` errors
1356/// into responses.
1357async fn handle_load_error(error: tower::BoxError) -> impl IntoResponse {
1358    if error.is::<tower::load_shed::error::Overloaded>() {
1359        return (
1360            StatusCode::TOO_MANY_REQUESTS,
1361            Cow::from("too many requests, try again later"),
1362        );
1363    }
1364
1365    // Note: This should be unreachable because at the time of writing our only use case is a
1366    // layer that emits `tower::load_shed::error::Overloaded`, which is handled above.
1367    (
1368        StatusCode::INTERNAL_SERVER_ERROR,
1369        Cow::from(format!("Unhandled internal error: {}", error)),
1370    )
1371}
1372
/// JSON request body accepted by [`handle_login`].
#[derive(Debug, Deserialize, Serialize, PartialEq)]
pub struct LoginCredentials {
    /// Role name to log in as.
    username: String,
    /// Password for `username`, verified by the adapter.
    password: Password,
}
1378
/// Authenticated identity stored in a tower session under the `"data"` key.
///
/// Written by [`handle_login`] and read back by
/// [`maybe_get_authenticated_session`] on later requests.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TowerSessionData {
    /// Role name the session is authenticated as.
    username: String,
    /// When the session was established at login. NOTE(review): not read in
    /// this part of the file; sessions have no absolute lifetime here.
    created_at: SystemTime,
    /// Time of the most recent request that used this session. Sessions idle
    /// for longer than `SESSION_DURATION` are treated as expired.
    last_activity: SystemTime,
    /// Marker of successful authentication, reused for each adapter session.
    authenticated: Authenticated,
    /// The kind of authenticator that established the session.
    authenticator_kind: mz_auth::AuthenticatorKind,
}
1387
#[cfg(test)]
mod tests {
    use super::{AllowedRoles, check_role_allowed};

    #[mz_ore::test]
    fn test_check_role_allowed() {
        // Each row gives a role name and whether it should be allowed under
        // `Internal`, `NormalAndInternal`, and `Normal`, in that order.
        let cases: &[(&str, bool, bool, bool)] = &[
            // Internal users.
            ("mz_system", true, true, false),
            ("mz_support", true, true, false),
            ("mz_analytics", true, true, false),
            // Normal users.
            ("materialize", false, true, true),
            ("anonymous_http_user", false, true, true),
            ("alex", false, true, true),
            // Denied by reserved role prefix.
            ("external_asdf", false, false, false),
            ("pg_somebody", false, false, false),
            ("mz_unknown", false, false, false),
            // Denied by literal PUBLIC.
            ("PUBLIC", false, false, false),
        ];

        for &(name, internal, normal_and_internal, normal) in cases {
            assert_eq!(
                check_role_allowed(name, AllowedRoles::Internal).is_ok(),
                internal,
                "{name} under Internal",
            );
            assert_eq!(
                check_role_allowed(name, AllowedRoles::NormalAndInternal).is_ok(),
                normal_and_internal,
                "{name} under NormalAndInternal",
            );
            assert_eq!(
                check_role_allowed(name, AllowedRoles::Normal).is_ok(),
                normal,
                "{name} under Normal",
            );
        }
    }
}