Skip to main content

mz_environmentd/
http.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Embedded HTTP server.
11//!
//! environmentd embeds an HTTP server for introspection into the running
//! process. Its exports include Prometheus metrics, heap profiles, and catalog
//! dumps, as well as the HTTP and WebSocket SQL APIs, webhook ingestion, and
//! (when enabled) MCP endpoints.
15//!
16//! ## Authentication flow
17//!
18//! The server supports several authentication modes, controlled by the
19//! configured [`listeners::AuthenticatorKind`]. The general flow is:
20//!
21//! 1. **Identity resolution.** An authentication middleware runs on every
22//!    protected request and resolves the caller's identity via one of:
23//!    - **Credentials in headers.** The caller supplies a username/password or
24//!      token in the request headers. Supported by all [`listeners::AuthenticatorKind`]s.
25//!    - **Session reuse.** If the caller has an active authenticated session
26//!      (established via `POST /api/login`) and has not supplied credentials
27//!      in the request headers, the session is reused. Only available for
28//!      [`listeners::AuthenticatorKind::Password`] and [`listeners::AuthenticatorKind::Oidc`].
29//!    - **Trusted header injection.** A trusted upstream proxy (e.g. Teleport)
30//!      may inject the caller's identity into the request headers. Only available
31//!      for [`listeners::AuthenticatorKind::None`].
32//!
33//! 2. **Session initialization.** Once the caller's identity is known, an
34//!    adapter session is opened on their behalf. This happens as part of
35//!    request processing, after all middleware has run.
36//!
37//! 3. **Request handling.** The handler executes the request (e.g. runs SQL)
38//!    using the initialized adapter session.
39//!
40//! ### WebSocket
41//!
42//! The WebSocket flow is identical to the HTTP flow with two differences:
43//!
44//! - Credentials are not read from request headers. Instead, the first
45//!   message sent by the client is treated as the authentication message.
//! - Session initialization (step 2) happens inside the WebSocket handler
//!   itself, rather than in the [`AuthedClient`] extractor that HTTP
//!   handlers use.
48
49// Axum handlers must use async, but often don't actually use `await`.
50#![allow(clippy::unused_async)]
51
52use std::borrow::Cow;
53use std::collections::BTreeMap;
54use std::fmt::Debug;
55use std::net::{IpAddr, SocketAddr};
56use std::pin::Pin;
57use std::sync::Arc;
58use std::time::{Duration, SystemTime};
59
60use anyhow::Context;
61use axum::error_handling::HandleErrorLayer;
62use axum::extract::ws::{Message, WebSocket};
63use axum::extract::{ConnectInfo, DefaultBodyLimit, FromRequestParts, Query, Request, State};
64use axum::middleware::{self, Next};
65use axum::response::{IntoResponse, Redirect, Response};
66use axum::{Extension, Json, Router, routing};
67use futures::future::{Shared, TryFutureExt};
68use headers::authorization::{Authorization, Basic, Bearer};
69use headers::{HeaderMapExt, HeaderName};
70use http::header::{AUTHORIZATION, CONTENT_TYPE};
71use http::uri::Scheme;
72use http::{HeaderMap, HeaderValue, Method, StatusCode, Uri};
73use hyper_openssl::SslStream;
74use hyper_openssl::client::legacy::MaybeHttpsStream;
75use hyper_util::rt::TokioIo;
76use mz_adapter::session::{Session as AdapterSession, SessionConfig as AdapterSessionConfig};
77use mz_adapter::{AdapterError, AdapterNotice, Client, SessionClient, WebhookAppenderCache};
78use mz_auth::Authenticated;
79use mz_auth::password::Password;
80use mz_authenticator::Authenticator;
81use mz_controller::ReplicaHttpLocator;
82use mz_frontegg_auth::Error as FronteggError;
83use mz_http_util::DynamicFilterTarget;
84use mz_ore::cast::u64_to_usize;
85use mz_ore::metrics::MetricsRegistry;
86use mz_ore::now::{NowFn, SYSTEM_TIME, epoch_to_uuid_v7};
87use mz_ore::str::StrExt;
88use mz_pgwire_common::{ConnectionCounter, ConnectionHandle};
89use mz_repr::user::ExternalUserMetadata;
90use mz_server_core::listeners::{self, AllowedRoles, HttpRoutesEnabled};
91use mz_server_core::{Connection, ConnectionHandler, ReloadingSslContext, Server};
92use mz_sql::session::metadata::SessionMetadata;
93use mz_sql::session::user::{
94    HTTP_DEFAULT_USER, INTERNAL_USER_NAMES, SUPPORT_USER_NAME, SYSTEM_USER_NAME,
95};
96use mz_sql::session::vars::{Value, Var, VarInput, WELCOME_MESSAGE};
97use openssl::ssl::Ssl;
98use prometheus::{
99    COMPUTE_METRIC_QUERIES, FRONTIER_METRIC_QUERIES, STORAGE_METRIC_QUERIES, USAGE_METRIC_QUERIES,
100};
101use serde::{Deserialize, Serialize};
102use serde_json::json;
103use thiserror::Error;
104use tokio::io::AsyncWriteExt;
105use tokio::sync::oneshot::Receiver;
106use tokio::sync::{oneshot, watch};
107use tokio_metrics::TaskMetrics;
108use tower::limit::GlobalConcurrencyLimitLayer;
109use tower::{Service, ServiceBuilder};
110use tower_http::cors::{AllowOrigin, Any, CorsLayer};
111use tower_sessions::{
112    MemoryStore as TowerSessionMemoryStore, Session as TowerSession,
113    SessionManagerLayer as TowerSessionManagerLayer,
114};
115use tracing::warn;
116
117use crate::BUILD_INFO;
118use crate::deployment::state::DeploymentStateHandle;
119use crate::http::sql::{ExistingUser, SqlError};
120
121mod catalog;
122mod cluster;
123mod console;
124mod mcp;
125mod memory;
126mod metrics;
127mod metrics_viz;
128mod probe;
129mod prometheus;
130mod root;
131mod sql;
132mod webhook;
133
134pub use metrics::Metrics;
135pub use sql::{SqlResponse, WebSocketAuth, WebSocketResponse};
136
137/// Maximum allowed size for a request.
138pub const MAX_REQUEST_SIZE: usize = u64_to_usize(5 * bytesize::MIB);
139
/// Eight hours. NOTE(review): consumed outside this view — presumably the
/// lifetime of a login session established via `/api/login`.
const SESSION_DURATION: Duration = Duration::from_secs(8 * 60 * 60);

/// Paths (or path prefixes) of the profiling-related endpoints that
/// `HttpServer::new` can mount: `/memory`, `/hierarchical-memory`, and the
/// `/prof/` subtree.
const PROFILING_API_ENDPOINTS: &[&str] = &["/memory", "/hierarchical-memory", "/prof/"];
143
144#[derive(Debug)]
145pub struct HttpConfig {
146    pub source: &'static str,
147    pub tls: Option<ReloadingSslContext>,
148    pub authenticator_kind: listeners::AuthenticatorKind,
149    pub frontegg: Option<mz_frontegg_auth::Authenticator>,
150    pub oidc_rx: Delayed<mz_authenticator::GenericOidcAuthenticator>,
151    pub adapter_client_rx: Shared<Receiver<Client>>,
152    pub allowed_origin: AllowOrigin,
153    /// Raw list of allowed CORS origins, used by the MCP endpoints for
154    /// server-side Origin validation to defend against DNS rebinding.
155    pub allowed_origin_list: Vec<HeaderValue>,
156    pub active_connection_counter: ConnectionCounter,
157    pub helm_chart_version: Option<String>,
158    pub concurrent_webhook_req: Arc<tokio::sync::Semaphore>,
159    pub metrics: Metrics,
160    pub metrics_registry: MetricsRegistry,
161    pub allowed_roles: AllowedRoles,
162    pub internal_route_config: Arc<InternalRouteConfig>,
163    pub routes_enabled: HttpRoutesEnabled,
164    /// Locator for cluster replica HTTP addresses, used for proxying requests.
165    pub replica_http_locator: Arc<ReplicaHttpLocator>,
166}
167
168#[derive(Debug, Clone)]
169pub struct InternalRouteConfig {
170    pub deployment_state_handle: DeploymentStateHandle,
171    pub internal_console_redirect_url: Option<String>,
172}
173
174#[derive(Clone)]
175pub struct WsState {
176    frontegg: Option<mz_frontegg_auth::Authenticator>,
177    oidc_rx: Delayed<mz_authenticator::GenericOidcAuthenticator>,
178    authenticator_kind: listeners::AuthenticatorKind,
179    adapter_client_rx: Delayed<mz_adapter::Client>,
180    active_connection_counter: ConnectionCounter,
181    helm_chart_version: Option<String>,
182    allowed_roles: AllowedRoles,
183}
184
185#[derive(Clone)]
186pub struct WebhookState {
187    adapter_client_rx: Delayed<mz_adapter::Client>,
188    webhook_cache: WebhookAppenderCache,
189}
190
191#[derive(Debug)]
192pub struct HttpServer {
193    tls: Option<ReloadingSslContext>,
194    router: Router,
195}
196
impl HttpServer {
    /// Builds an [`HttpServer`], assembling the axum [`Router`] from the
    /// route groups enabled in `routes_enabled`.
    ///
    /// Each route group is built as its own router, given the layers it needs,
    /// and merged into `router` (or into `base_router`, which is merged last).
    /// Per axum's documented semantics, `Router::layer` only wraps routes that
    /// already exist on the router, and a layer added later wraps (and so runs
    /// before) layers added earlier — the order of the `.layer(...)` and
    /// `.merge(...)` calls below is therefore significant.
    pub fn new(
        HttpConfig {
            source,
            tls,
            authenticator_kind,
            frontegg,
            oidc_rx,
            adapter_client_rx,
            allowed_origin,
            allowed_origin_list,
            active_connection_counter,
            helm_chart_version,
            concurrent_webhook_req,
            metrics,
            metrics_registry,
            allowed_roles,
            internal_route_config,
            routes_enabled,
            replica_http_locator,
        }: HttpConfig,
    ) -> HttpServer {
        let tls_enabled = tls.is_some();
        let webhook_cache = WebhookAppenderCache::new();

        // Create secure session store and manager. Sessions are held in
        // process memory (`MemoryStore`). The layer is only installed for the
        // Password and OIDC authenticator kinds (see the match at the end).
        let session_store = TowerSessionMemoryStore::default();
        let session_layer = TowerSessionManagerLayer::new(session_store)
            .with_secure(tls_enabled) // Enforce HTTPS (cookie is `Secure` only when TLS is on)
            .with_same_site(tower_sessions::cookie::SameSite::Strict) // Prevent CSRF
            .with_http_only(true) // Hide the cookie from JavaScript (limits XSS impact)
            .with_name("mz_session") // Custom cookie name
            .with_path("/"); // Set cookie path

        // Authentication middleware shared by every protected route group.
        // Each request clones the captured handles and delegates to
        // `http_auth` (defined elsewhere in this file).
        let frontegg_middleware = frontegg.clone();
        let oidc_middleware_rx = oidc_rx.clone();
        let adapter_client_middleware_rx = adapter_client_rx.clone();
        let auth_middleware = middleware::from_fn(move |req, next| {
            let frontegg = frontegg_middleware.clone();
            let oidc_rx = oidc_middleware_rx.clone();
            let adapter_client_rx = adapter_client_middleware_rx.clone();
            async move {
                http_auth(
                    req,
                    next,
                    tls_enabled,
                    authenticator_kind,
                    frontegg,
                    oidc_rx,
                    adapter_client_rx,
                    allowed_roles,
                )
                .await
            }
        });

        // `router` collects independently-layered groups; `base_router`
        // collects the browser-facing routes that share one set of layers
        // (applied near the end of this function).
        let mut router = Router::new();
        let mut base_router = Router::new();
        if routes_enabled.base {
            base_router = base_router
                .route(
                    "/",
                    routing::get(move || async move { root::handle_home(routes_enabled).await }),
                )
                .route("/api/sql", routing::post(sql::handle_sql))
                .route("/memory", routing::get(memory::handle_memory))
                .route(
                    "/hierarchical-memory",
                    routing::get(memory::handle_hierarchical_memory),
                )
                .route(
                    "/metrics-viz",
                    routing::get(metrics_viz::handle_metrics_viz),
                )
                .route("/static/{*path}", routing::get(root::handle_static));

            // The WebSocket endpoint does not get `auth_middleware`: it
            // authenticates from the first client message inside the handler,
            // using the handles carried in `WsState`.
            let mut ws_router = Router::new()
                .route("/api/experimental/sql", routing::get(sql::handle_sql_ws))
                .with_state(WsState {
                    frontegg,
                    oidc_rx: oidc_rx.clone(),
                    authenticator_kind,
                    adapter_client_rx: adapter_client_rx.clone(),
                    active_connection_counter: active_connection_counter.clone(),
                    helm_chart_version,
                    allowed_roles,
                });
            // With no authenticator, a trusted proxy may still inject the
            // user via the `x-materialize-user` header.
            if let listeners::AuthenticatorKind::None = authenticator_kind {
                ws_router = ws_router.layer(middleware::from_fn(x_materialize_user_header_auth));
            }
            router = router.merge(ws_router);
        }
        if routes_enabled.profiling {
            base_router = base_router.nest("/prof/", mz_prof_http::router(&BUILD_INFO));
        }

        if routes_enabled.webhook {
            // Webhooks bypass `auth_middleware`. Request bodies may be
            // compressed; CORS mirrors the request origin for POST. The
            // outermost ServiceBuilder bounds concurrency with the shared
            // semaphore and sheds (rather than queues) excess requests, with
            // the resulting errors converted by `handle_load_error`.
            let webhook_router = Router::new()
                .route(
                    "/api/webhook/{:database}/{:schema}/{:id}",
                    routing::post(webhook::handle_webhook),
                )
                .with_state(WebhookState {
                    adapter_client_rx: adapter_client_rx.clone(),
                    webhook_cache,
                })
                .layer(
                    tower_http::decompression::RequestDecompressionLayer::new()
                        .gzip(true)
                        .deflate(true)
                        .br(true)
                        .zstd(true),
                )
                .layer(
                    CorsLayer::new()
                        .allow_methods(Method::POST)
                        .allow_origin(AllowOrigin::mirror_request())
                        .allow_headers(Any),
                )
                .layer(
                    ServiceBuilder::new()
                        .layer(HandleErrorLayer::new(handle_load_error))
                        .load_shed()
                        .layer(GlobalConcurrencyLimitLayer::with_semaphore(
                            concurrent_webhook_req,
                        )),
                );
            router = router.merge(webhook_router);
        }

        if routes_enabled.internal {
            let console_config = Arc::new(console::ConsoleProxyConfig::new(
                internal_route_config.internal_console_redirect_url.clone(),
                "/internal-console".to_string(),
            ));
            // The two `*/config` endpoints are retired: they always answer
            // 400 with a pointer to the replacement system variable.
            base_router = base_router
                .route(
                    "/api/opentelemetry/config",
                    routing::put({
                        move |_: axum::Json<DynamicFilterTarget>| async {
                            (
                                StatusCode::BAD_REQUEST,
                                "This endpoint has been replaced. \
                            Use the `opentelemetry_filter` system variable."
                                    .to_string(),
                            )
                        }
                    }),
                )
                .route(
                    "/api/stderr/config",
                    routing::put({
                        move |_: axum::Json<DynamicFilterTarget>| async {
                            (
                                StatusCode::BAD_REQUEST,
                                "This endpoint has been replaced. \
                            Use the `log_filter` system variable."
                                    .to_string(),
                            )
                        }
                    }),
                )
                .route("/api/tracing", routing::get(mz_http_util::handle_tracing))
                .route(
                    "/api/catalog/dump",
                    routing::get(catalog::handle_catalog_dump),
                )
                .route(
                    "/api/catalog/check",
                    routing::get(catalog::handle_catalog_check),
                )
                .route(
                    "/api/catalog/inject-audit-events",
                    routing::post(catalog::handle_inject_audit_events),
                )
                .route(
                    "/api/coordinator/check",
                    routing::get(catalog::handle_coordinator_check),
                )
                .route(
                    "/api/coordinator/dump",
                    routing::get(catalog::handle_coordinator_dump),
                )
                .route(
                    "/internal-console",
                    routing::get(|| async { Redirect::temporary("/internal-console/") }),
                )
                .route(
                    "/internal-console/{*path}",
                    routing::get(console::handle_internal_console),
                )
                .route(
                    "/internal-console/",
                    routing::get(console::handle_internal_console),
                )
                .layer(Extension(console_config));

            // Cluster HTTP proxy routes.
            let cluster_proxy_config = Arc::new(cluster::ClusterProxyConfig::new(Arc::clone(
                &replica_http_locator,
            )));
            base_router = base_router
                .route("/clusters", routing::get(cluster::handle_clusters))
                .route(
                    "/api/cluster/{:cluster_id}/replica/{:replica_id}/process/{:process}/",
                    routing::any(cluster::handle_cluster_proxy_root),
                )
                .route(
                    "/api/cluster/{:cluster_id}/replica/{:replica_id}/process/{:process}/{*path}",
                    routing::any(cluster::handle_cluster_proxy),
                )
                .layer(Extension(cluster_proxy_config));

            // Leader endpoints carry their own auth layer and use the
            // deployment state handle as router state, so they are kept out
            // of `base_router`.
            let leader_router = Router::new()
                .route("/api/leader/status", routing::get(handle_leader_status))
                .route("/api/leader/promote", routing::post(handle_leader_promote))
                .route(
                    "/api/leader/skip-catchup",
                    routing::post(handle_leader_skip_catchup),
                )
                .layer(auth_middleware.clone())
                .with_state(internal_route_config.deployment_state_handle.clone());
            router = router.merge(leader_router);
        }

        if routes_enabled.metrics {
            // Note: `auth_middleware` wraps every route in this group,
            // including `/metrics` and the liveness/readiness probes.
            let metrics_router = Router::new()
                .route(
                    "/metrics",
                    routing::get(move || async move {
                        mz_http_util::handle_prometheus(&metrics_registry).await
                    }),
                )
                .route(
                    "/metrics/mz_usage",
                    routing::get(|client: AuthedClient| async move {
                        let registry = sql::handle_promsql(client, USAGE_METRIC_QUERIES).await;
                        mz_http_util::handle_prometheus(&registry).await
                    }),
                )
                .route(
                    "/metrics/mz_frontier",
                    routing::get(|client: AuthedClient| async move {
                        let registry = sql::handle_promsql(client, FRONTIER_METRIC_QUERIES).await;
                        mz_http_util::handle_prometheus(&registry).await
                    }),
                )
                .route(
                    "/metrics/mz_compute",
                    routing::get(|client: AuthedClient| async move {
                        let registry = sql::handle_promsql(client, COMPUTE_METRIC_QUERIES).await;
                        mz_http_util::handle_prometheus(&registry).await
                    }),
                )
                .route(
                    "/metrics/mz_storage",
                    routing::get(|client: AuthedClient| async move {
                        let registry = sql::handle_promsql(client, STORAGE_METRIC_QUERIES).await;
                        mz_http_util::handle_prometheus(&registry).await
                    }),
                )
                .route(
                    "/api/livez",
                    routing::get(mz_http_util::handle_liveness_check),
                )
                .route("/api/readyz", routing::get(probe::handle_ready))
                .layer(auth_middleware.clone())
                .layer(Extension(adapter_client_rx.clone()))
                .layer(Extension(active_connection_counter.clone()));
            router = router.merge(metrics_router);
        }

        if routes_enabled.console_config {
            // No `auth_middleware` layer is applied to this group.
            let console_config_router = Router::new()
                .route(
                    "/api/console/config",
                    routing::get(console::handle_console_config),
                )
                .layer(Extension(adapter_client_rx.clone()))
                .layer(Extension(active_connection_counter.clone()));
            router = router.merge(console_config_router);
        }

        // MCP (Model Context Protocol) endpoints
        // Enabled via runtime `routes_enabled.mcp_agent` and `routes_enabled.mcp_developer` configuration
        if routes_enabled.mcp_agent || routes_enabled.mcp_developer {
            use tracing::info;

            let mut mcp_router = Router::new();

            if routes_enabled.mcp_agent {
                info!("Enabling MCP agent endpoint: /api/mcp/agent");
                mcp_router = mcp_router.route(
                    "/api/mcp/agent",
                    routing::post(mcp::handle_mcp_agent).get(mcp::handle_mcp_method_not_allowed),
                );
            }

            if routes_enabled.mcp_developer {
                info!("Enabling MCP developer endpoint: /api/mcp/developer");
                mcp_router = mcp_router.route(
                    "/api/mcp/developer",
                    routing::post(mcp::handle_mcp_developer)
                        .get(mcp::handle_mcp_method_not_allowed),
                );
            }

            // The MCP handlers perform a server-side Origin check against this
            // allowlist to defend against DNS rebinding attacks (see
            // database-issues#11311). The CorsLayer alone is not enough: in a
            // DNS rebinding attack the browser considers the request
            // same-origin, so no preflight fires and CORS enforcement is
            // bypassed.
            let mcp_allowed_origins = Arc::new(allowed_origin_list.clone());
            mcp_router = mcp_router
                .layer(auth_middleware.clone())
                .layer(Extension(adapter_client_rx.clone()))
                .layer(Extension(active_connection_counter.clone()))
                .layer(Extension(mcp_allowed_origins))
                .layer(
                    CorsLayer::new()
                        .allow_methods(Method::POST)
                        .allow_origin(allowed_origin.clone())
                        .allow_headers([AUTHORIZATION, CONTENT_TYPE]),
                );
            router = router.merge(mcp_router);
        }

        // Shared layers for every route collected in `base_router`: auth,
        // the adapter client / connection counter extensions, and CORS
        // (preflight responses cacheable for one hour).
        base_router = base_router
            .layer(auth_middleware.clone())
            .layer(Extension(adapter_client_rx.clone()))
            .layer(Extension(active_connection_counter.clone()))
            .layer(
                CorsLayer::new()
                    .allow_credentials(false)
                    .allow_headers([
                        AUTHORIZATION,
                        CONTENT_TYPE,
                        HeaderName::from_static("x-materialize-version"),
                    ])
                    .allow_methods(Any)
                    .allow_origin(allowed_origin)
                    .expose_headers(Any)
                    .max_age(Duration::from_secs(60) * 60),
            );

        match authenticator_kind {
            listeners::AuthenticatorKind::Password | listeners::AuthenticatorKind::Oidc => {
                // Session cookies enable login/logout and session reuse. The
                // session layer is added to `base_router` directly, and via
                // `router.layer` to every group merged into `router` so far
                // (plus the login routes). `base_router` is merged after this
                // point, which is why it needs its own copy of the layer.
                base_router = base_router.layer(session_layer.clone());

                let login_router = Router::new()
                    .route("/api/login", routing::post(handle_login))
                    .route("/api/logout", routing::post(handle_logout))
                    .layer(Extension(adapter_client_rx));
                router = router.merge(login_router).layer(session_layer);
            }
            listeners::AuthenticatorKind::None => {
                // Allow a trusted proxy to inject the user via header. Added
                // after `auth_middleware`, so it runs before it.
                base_router =
                    base_router.layer(middleware::from_fn(x_materialize_user_header_auth));
            }
            // Other authenticator kinds (presumably Frontegg) need no extra
            // layers here.
            _ => {}
        }

        // `apply_default_layers` is defined elsewhere in this file; it wraps
        // the complete router, including `base_router`.
        router = router
            .merge(base_router)
            .apply_default_layers(source, metrics);

        HttpServer { tls, router }
    }
}
567
568impl Server for HttpServer {
569    const NAME: &'static str = "http";
570
571    fn handle_connection(
572        &self,
573        conn: Connection,
574        _tokio_metrics_intervals: impl Iterator<Item = TaskMetrics> + Send + 'static,
575    ) -> ConnectionHandler {
576        let router = self.router.clone();
577        let tls_context = self.tls.clone();
578        let mut conn = TokioIo::new(conn);
579
580        Box::pin(async {
581            let direct_peer_addr = conn.inner().peer_addr().context("fetching peer addr")?;
582            let peer_addr = conn
583                .inner_mut()
584                .take_proxy_header_address()
585                .await
586                .map(|a| a.source)
587                .unwrap_or(direct_peer_addr);
588
589            let (conn, conn_protocol) = match tls_context {
590                Some(tls_context) => {
591                    let mut ssl_stream = SslStream::new(Ssl::new(&tls_context.get())?, conn)?;
592                    if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
593                        let _ = ssl_stream.get_mut().inner_mut().shutdown().await;
594                        return Err(e.into());
595                    }
596                    (MaybeHttpsStream::Https(ssl_stream), ConnProtocol::Https)
597                }
598                _ => (MaybeHttpsStream::Http(conn), ConnProtocol::Http),
599            };
600            let mut make_tower_svc = router
601                .layer(Extension(conn_protocol))
602                .into_make_service_with_connect_info::<SocketAddr>();
603            let tower_svc = make_tower_svc.call(peer_addr).await.unwrap();
604            let hyper_svc = hyper::service::service_fn(|req| tower_svc.clone().call(req));
605            let http = hyper::server::conn::http1::Builder::new();
606            http.serve_connection(conn, hyper_svc)
607                .with_upgrades()
608                .err_into()
609                .await
610        })
611    }
612}
613
614pub async fn handle_leader_status(
615    State(deployment_state_handle): State<DeploymentStateHandle>,
616) -> impl IntoResponse {
617    let status = deployment_state_handle.status();
618    (StatusCode::OK, Json(json!({ "status": status })))
619}
620
621pub async fn handle_leader_promote(
622    State(deployment_state_handle): State<DeploymentStateHandle>,
623) -> impl IntoResponse {
624    match deployment_state_handle.try_promote() {
625        Ok(()) => {
626            // TODO(benesch): the body here is redundant. Should just return
627            // 204.
628            let status = StatusCode::OK;
629            let body = Json(json!({
630                "result": "Success",
631            }));
632            (status, body)
633        }
634        Err(()) => {
635            // TODO(benesch): the nesting here is redundant given the error
636            // code. Should just return the `{"message": "..."}` object.
637            let status = StatusCode::BAD_REQUEST;
638            let body = Json(json!({
639                "result": {"Failure": {"message": "cannot promote leader while initializing"}},
640            }));
641            (status, body)
642        }
643    }
644}
645
646pub async fn handle_leader_skip_catchup(
647    State(deployment_state_handle): State<DeploymentStateHandle>,
648) -> impl IntoResponse {
649    match deployment_state_handle.try_skip_catchup() {
650        Ok(()) => StatusCode::NO_CONTENT.into_response(),
651        Err(()) => {
652            let status = StatusCode::BAD_REQUEST;
653            let body = Json(json!({
654                "message": "cannot skip catchup in this phase of initialization; try again later",
655            }));
656            (status, body).into_response()
657        }
658    }
659}
660
661async fn x_materialize_user_header_auth(mut req: Request, next: Next) -> impl IntoResponse {
662    // TODO migrate teleport to basic auth and remove this.
663    if let Some(username) = req.headers().get("x-materialize-user").map(|h| h.to_str()) {
664        let username = match username {
665            Ok(name @ (SUPPORT_USER_NAME | SYSTEM_USER_NAME)) => name.to_string(),
666            _ => {
667                return Err(AuthError::MismatchedUser(format!(
668                    "user specified in x-materialize-user must be {SUPPORT_USER_NAME} or {SYSTEM_USER_NAME}"
669                )));
670            }
671        };
672        req.extensions_mut().insert(AuthedUser {
673            name: username,
674            external_metadata_rx: None,
675            authenticated: Authenticated,
676            authenticator_kind: mz_auth::AuthenticatorKind::None,
677        });
678    }
679    Ok(next.run(req).await)
680}
681
682type Delayed<T> = Shared<oneshot::Receiver<T>>;
683
/// Transport a connection was accepted over; attached to every request as an
/// extension by `handle_connection`.
#[derive(Clone)]
enum ConnProtocol {
    /// Plaintext HTTP.
    Http,
    /// HTTP over TLS.
    Https,
}
689
690#[derive(Clone, Debug)]
691pub struct AuthedUser {
692    name: String,
693    external_metadata_rx: Option<watch::Receiver<ExternalUserMetadata>>,
694    authenticated: Authenticated,
695    authenticator_kind: mz_auth::AuthenticatorKind,
696}
697
698pub struct AuthedClient {
699    pub client: SessionClient,
700    pub connection_guard: Option<ConnectionHandle>,
701}
702
703impl AuthedClient {
704    async fn new<F>(
705        adapter_client: &Client,
706        user: AuthedUser,
707        peer_addr: IpAddr,
708        active_connection_counter: ConnectionCounter,
709        helm_chart_version: Option<String>,
710        session_config: F,
711        options: BTreeMap<String, String>,
712        now: NowFn,
713    ) -> Result<Self, AdapterError>
714    where
715        F: FnOnce(&mut AdapterSession),
716    {
717        let conn_id = adapter_client.new_conn_id()?;
718        let mut session = adapter_client.new_session(
719            AdapterSessionConfig {
720                conn_id,
721                uuid: epoch_to_uuid_v7(&(now)()),
722                user: user.name,
723                client_ip: Some(peer_addr),
724                external_metadata_rx: user.external_metadata_rx,
725                helm_chart_version,
726                authenticator_kind: user.authenticator_kind,
727            },
728            user.authenticated,
729        );
730        let connection_guard = active_connection_counter.allocate_connection(session.user())?;
731
732        session_config(&mut session);
733        let system_vars = adapter_client.get_system_vars().await;
734        for (key, val) in options {
735            const LOCAL: bool = false;
736            if let Err(err) =
737                session
738                    .vars_mut()
739                    .set(&system_vars, &key, VarInput::Flat(&val), LOCAL)
740            {
741                session.add_notice(AdapterNotice::BadStartupSetting {
742                    name: key.to_string(),
743                    reason: err.to_string(),
744                })
745            }
746        }
747        let adapter_client = adapter_client.startup(session).await?;
748        Ok(AuthedClient {
749            client: adapter_client,
750            connection_guard,
751        })
752    }
753}
754
755impl<S> FromRequestParts<S> for AuthedClient
756where
757    S: Send + Sync,
758{
759    type Rejection = Response;
760
761    async fn from_request_parts(
762        req: &mut http::request::Parts,
763        state: &S,
764    ) -> Result<Self, Self::Rejection> {
765        #[derive(Debug, Default, Deserialize)]
766        struct Params {
767            #[serde(default)]
768            options: String,
769        }
770        let params: Query<Params> = Query::from_request_parts(req, state)
771            .await
772            .unwrap_or_default();
773
774        let peer_addr = req
775            .extensions
776            .get::<ConnectInfo<SocketAddr>>()
777            .expect("ConnectInfo extension guaranteed to exist")
778            .0
779            .ip();
780
781        let user = req.extensions.get::<AuthedUser>().unwrap();
782        let adapter_client = req
783            .extensions
784            .get::<Delayed<mz_adapter::Client>>()
785            .unwrap()
786            .clone();
787        let adapter_client = adapter_client.await.map_err(|_| {
788            (StatusCode::INTERNAL_SERVER_ERROR, "adapter client missing").into_response()
789        })?;
790        let active_connection_counter = req.extensions.get::<ConnectionCounter>().unwrap();
791        let helm_chart_version = None;
792
793        let options = if params.options.is_empty() {
794            // It's possible 'options' simply wasn't provided, we don't want that to
795            // count as a failure to deserialize
796            BTreeMap::<String, String>::default()
797        } else {
798            match serde_json::from_str(&params.options) {
799                Ok(options) => options,
800                Err(_e) => {
801                    // If we fail to deserialize options, fail the request.
802                    let code = StatusCode::BAD_REQUEST;
803                    let msg = format!("Failed to deserialize {} map", "options".quoted());
804                    return Err((code, msg).into_response());
805                }
806            }
807        };
808
809        let client = AuthedClient::new(
810            &adapter_client,
811            user.clone(),
812            peer_addr,
813            active_connection_counter.clone(),
814            helm_chart_version,
815            |session| {
816                session
817                    .vars_mut()
818                    .set_default(WELCOME_MESSAGE.name(), VarInput::Flat(&false.format()))
819                    .expect("known to exist")
820            },
821            options,
822            SYSTEM_TIME.clone(),
823        )
824        .await
825        .map_err(|e| {
826            let status = match e {
827                AdapterError::UserSessionsDisallowed | AdapterError::NetworkPolicyDenied(_) => {
828                    StatusCode::FORBIDDEN
829                }
830                _ => StatusCode::INTERNAL_SERVER_ERROR,
831            };
832            (status, Json(SqlError::from(e))).into_response()
833        })?;
834
835        Ok(client)
836    }
837}
838
839#[derive(Debug, Error)]
840pub(crate) enum AuthError {
841    #[error("role dissallowed")]
842    RoleDisallowed(String),
843    #[error("{0}")]
844    Frontegg(#[from] FronteggError),
845    #[error("missing authorization header")]
846    MissingHttpAuthentication {
847        include_www_authenticate_header: bool,
848    },
849    #[error("{0}")]
850    MismatchedUser(String),
851    #[error("session expired")]
852    SessionExpired,
853    #[error("failed to update session")]
854    FailedToUpdateSession,
855    #[error("invalid credentials")]
856    InvalidCredentials,
857}
858
859impl IntoResponse for AuthError {
860    fn into_response(self) -> Response {
861        warn!("HTTP request failed authentication: {}", self);
862        let mut headers = HeaderMap::new();
863        match self {
864            AuthError::MissingHttpAuthentication {
865                include_www_authenticate_header,
866            } if include_www_authenticate_header => {
867                headers.insert(
868                    http::header::WWW_AUTHENTICATE,
869                    HeaderValue::from_static("Basic realm=Materialize"),
870                );
871            }
872            _ => {}
873        };
874        // We omit most detail from the error message we send to the client, to
875        // avoid giving attackers unnecessary information.
876        (StatusCode::UNAUTHORIZED, headers, "unauthorized").into_response()
877    }
878}
879
880// Simplified login handler
881pub async fn handle_login(
882    session: Option<Extension<TowerSession>>,
883    Extension(adapter_client_rx): Extension<Delayed<Client>>,
884    Json(LoginCredentials { username, password }): Json<LoginCredentials>,
885) -> impl IntoResponse {
886    let Ok(adapter_client) = adapter_client_rx.clone().await else {
887        return StatusCode::INTERNAL_SERVER_ERROR;
888    };
889    let authenticated = match adapter_client.authenticate(&username, &password).await {
890        Ok(authenticated) => authenticated,
891        Err(err) => {
892            warn!(?err, "HTTP login failed authentication");
893            return StatusCode::UNAUTHORIZED;
894        }
895    };
896    // Create session data
897    let session_data = TowerSessionData {
898        username,
899        created_at: SystemTime::now(),
900        last_activity: SystemTime::now(),
901        authenticated,
902        authenticator_kind: mz_auth::AuthenticatorKind::Password,
903    };
904    // Store session data
905    let session = session.and_then(|Extension(session)| Some(session));
906    let Some(session) = session else {
907        return StatusCode::INTERNAL_SERVER_ERROR;
908    };
909    match session.insert("data", &session_data).await {
910        Err(_) => StatusCode::INTERNAL_SERVER_ERROR,
911        Ok(_) => StatusCode::OK,
912    }
913}
914
915// Simplified logout handler
916pub async fn handle_logout(session: Option<Extension<TowerSession>>) -> impl IntoResponse {
917    let session = session.and_then(|Extension(session)| Some(session));
918    let Some(session) = session else {
919        return StatusCode::INTERNAL_SERVER_ERROR;
920    };
921    // Delete session
922    match session.delete().await {
923        Err(_) => StatusCode::INTERNAL_SERVER_ERROR,
924        Ok(_) => StatusCode::OK,
925    }
926}
927
928async fn http_auth(
929    mut req: Request,
930    next: Next,
931    tls_enabled: bool,
932    authenticator_kind: listeners::AuthenticatorKind,
933    frontegg: Option<mz_frontegg_auth::Authenticator>,
934    oidc_rx: Delayed<mz_authenticator::GenericOidcAuthenticator>,
935    adapter_client_rx: Delayed<Client>,
936    allowed_roles: AllowedRoles,
937) -> Result<impl IntoResponse, AuthError> {
938    let creds = if let Some(basic) = req.headers().typed_get::<Authorization<Basic>>() {
939        Some(Credentials::Password {
940            username: basic.username().to_owned(),
941            password: Password(basic.password().to_owned()),
942        })
943    } else if let Some(bearer) = req.headers().typed_get::<Authorization<Bearer>>() {
944        Some(Credentials::Token {
945            token: bearer.token().to_owned(),
946        })
947    } else {
948        None
949    };
950
951    // Reuses an authenticated session if one already exists.
952    // If credentials are provided, we perform a new authentication,
953    // separate from the existing session.
954    if creds.is_none()
955        && let Some((session, session_data)) =
956            maybe_get_authenticated_session(req.extensions().get::<TowerSession>()).await
957    {
958        let user = ensure_session_unexpired(session, session_data).await?;
959        // Need this to set the user of the Adapter client.
960        req.extensions_mut().insert(user);
961        return Ok(next.run(req).await);
962    }
963
964    // First, extract the username from the certificate, validating that the
965    // connection matches the TLS configuration along the way.
966    // Fall back to existing authentication methods.
967    let conn_protocol = req.extensions().get::<ConnProtocol>().unwrap();
968    match (tls_enabled, &conn_protocol) {
969        (false, ConnProtocol::Http) => {}
970        (false, ConnProtocol::Https { .. }) => unreachable!(),
971        (true, ConnProtocol::Http) => {
972            let mut parts = req.uri().clone().into_parts();
973            parts.scheme = Some(Scheme::HTTPS);
974            return Ok(Redirect::permanent(
975                &Uri::from_parts(parts)
976                    .expect("it was already a URI, just changed the scheme")
977                    .to_string(),
978            )
979            .into_response());
980        }
981        (true, ConnProtocol::Https { .. }) => {}
982    }
983    // If we've already passed some other auth, just use that.
984    if req.extensions().get::<AuthedUser>().is_some() {
985        return Ok(next.run(req).await);
986    }
987
988    let path = req.uri().path();
989    let include_www_authenticate_header = path == "/"
990        || PROFILING_API_ENDPOINTS
991            .iter()
992            .any(|prefix| path.starts_with(prefix));
993    let authenticator = get_authenticator(
994        authenticator_kind,
995        creds.as_ref(),
996        frontegg,
997        &oidc_rx,
998        &adapter_client_rx,
999    )
1000    .await;
1001
1002    let user = auth(
1003        &authenticator,
1004        creds,
1005        allowed_roles,
1006        include_www_authenticate_header,
1007    )
1008    .await?;
1009
1010    // Add the authenticated user as an extension so downstream handlers can
1011    // inspect it if necessary.
1012    req.extensions_mut().insert(user);
1013
1014    // Run the request.
1015    Ok(next.run(req).await)
1016}
1017
1018async fn init_ws(
1019    WsState {
1020        frontegg,
1021        oidc_rx,
1022        authenticator_kind,
1023        adapter_client_rx,
1024        active_connection_counter,
1025        helm_chart_version,
1026        allowed_roles,
1027    }: WsState,
1028    existing_user: Option<ExistingUser>,
1029    peer_addr: IpAddr,
1030    ws: &mut WebSocket,
1031) -> Result<AuthedClient, anyhow::Error> {
1032    // TODO: Add a timeout here to prevent resource leaks by clients that
1033    // connect then never send a message.
1034    let ws_auth: WebSocketAuth = loop {
1035        let init_msg = ws.recv().await.ok_or_else(|| anyhow::anyhow!("closed"))??;
1036        match init_msg {
1037            Message::Text(data) => break serde_json::from_str(&data)?,
1038            Message::Binary(data) => break serde_json::from_slice(&data)?,
1039            // Handled automatically by the server.
1040            Message::Ping(_) => {
1041                continue;
1042            }
1043            Message::Pong(_) => {
1044                continue;
1045            }
1046            Message::Close(_) => {
1047                anyhow::bail!("closed");
1048            }
1049        }
1050    };
1051
1052    // If credentials are provided, we perform a new authentication,
1053    // separate from the existing session.
1054    let (creds, options) = match ws_auth {
1055        WebSocketAuth::Basic {
1056            user,
1057            password,
1058            options,
1059        } => {
1060            let creds = Credentials::Password {
1061                username: user,
1062                password,
1063            };
1064            (Some(creds), options)
1065        }
1066        WebSocketAuth::Bearer { token, options } => {
1067            let creds = Credentials::Token { token };
1068            (Some(creds), options)
1069        }
1070        WebSocketAuth::OptionsOnly { options } => (None, options),
1071    };
1072
1073    let user = match (existing_user, creds) {
1074        (Some(ExistingUser::XMaterializeUserHeader(_)), Some(_creds)) => {
1075            warn!("Unexpected bearer or basic auth provided when using user header");
1076            anyhow::bail!("unexpected")
1077        }
1078        (Some(ExistingUser::Session(user)), None)
1079        | (Some(ExistingUser::XMaterializeUserHeader(user)), None) => user,
1080        (_, Some(creds)) => {
1081            let authenticator = get_authenticator(
1082                authenticator_kind,
1083                Some(&creds),
1084                frontegg,
1085                &oidc_rx,
1086                &adapter_client_rx,
1087            )
1088            .await;
1089            let user = auth(&authenticator, Some(creds), allowed_roles, false).await?;
1090            user
1091        }
1092        (None, None) => anyhow::bail!("expected auth information"),
1093    };
1094
1095    let client = AuthedClient::new(
1096        &adapter_client_rx.clone().await?,
1097        user,
1098        peer_addr,
1099        active_connection_counter.clone(),
1100        helm_chart_version.clone(),
1101        |_session| (),
1102        options,
1103        SYSTEM_TIME.clone(),
1104    )
1105    .await?;
1106
1107    Ok(client)
1108}
1109
/// Credentials presented by an HTTP or WebSocket client.
///
/// In this file they are built from an HTTP `Authorization` header (Basic or
/// Bearer) or from a WebSocket auth message (`Basic` or `Bearer`).
enum Credentials {
    /// A username/password pair.
    Password {
        username: String,
        password: Password,
    },
    /// A bearer token.
    Token {
        token: String,
    },
}
1119
1120async fn get_authenticator(
1121    kind: listeners::AuthenticatorKind,
1122    creds: Option<&Credentials>,
1123    frontegg: Option<mz_frontegg_auth::Authenticator>,
1124    oidc_rx: &Delayed<mz_authenticator::GenericOidcAuthenticator>,
1125    adapter_client_rx: &Delayed<Client>,
1126) -> Authenticator {
1127    match kind {
1128        listeners::AuthenticatorKind::Frontegg => Authenticator::Frontegg(frontegg.expect(
1129            "Frontegg authenticator should exist with listeners::AuthenticatorKind::Frontegg",
1130        )),
1131        listeners::AuthenticatorKind::Password | listeners::AuthenticatorKind::Sasl => {
1132            let client = adapter_client_rx.clone().await.expect("sender not dropped");
1133            Authenticator::Password(client)
1134        }
1135        listeners::AuthenticatorKind::Oidc => match creds {
1136            // Use the password authenticator if the credentials are password-based
1137            Some(Credentials::Password { .. }) => {
1138                let client = adapter_client_rx.clone().await.expect("sender not dropped");
1139                Authenticator::Password(client)
1140            }
1141            _ => Authenticator::Oidc(oidc_rx.clone().await.expect("sender not dropped")),
1142        },
1143        listeners::AuthenticatorKind::None => Authenticator::None,
1144    }
1145}
1146
1147/// Attempts to retrieve session data from a [`TowerSession`], if available.
1148/// Session data is present only if an authenticated session has been
1149/// established via [`handle_login`].
1150pub(crate) async fn maybe_get_authenticated_session(
1151    session: Option<&TowerSession>,
1152) -> Option<(&TowerSession, TowerSessionData)> {
1153    if let Some(session) = session {
1154        if let Ok(Some(session_data)) = session.get::<TowerSessionData>("data").await {
1155            return Some((session, session_data));
1156        }
1157    }
1158    None
1159}
1160
1161/// Ensures the session is still valid by checking for expiration,
1162/// and returns the associated user if the session remains active.
1163pub(crate) async fn ensure_session_unexpired(
1164    session: &TowerSession,
1165    session_data: TowerSessionData,
1166) -> Result<AuthedUser, AuthError> {
1167    if session_data
1168        .last_activity
1169        .elapsed()
1170        .unwrap_or(Duration::MAX)
1171        > SESSION_DURATION
1172    {
1173        let _ = session.delete().await;
1174        return Err(AuthError::SessionExpired);
1175    }
1176    let mut updated_data = session_data.clone();
1177    updated_data.last_activity = SystemTime::now();
1178    session
1179        .insert("data", &updated_data)
1180        .await
1181        .map_err(|_| AuthError::FailedToUpdateSession)?;
1182
1183    Ok(AuthedUser {
1184        name: session_data.username,
1185        external_metadata_rx: None,
1186        authenticated: session_data.authenticated,
1187        authenticator_kind: session_data.authenticator_kind,
1188    })
1189}
1190
/// Authenticates `creds` with `authenticator` and checks that the resulting
/// role is permitted by `allowed_roles`.
///
/// Accepted credentials per authenticator:
/// - `Frontegg`: password or token. For a token, the external-metadata
///   receiver is built from the token's claims and its sender is dropped
///   immediately, so that metadata cannot change afterwards.
/// - `Password`: password only, verified by the adapter.
/// - `Oidc`: token only, validated by the OIDC authenticator.
/// - `None`: no verification at all. The username from password credentials
///   is trusted (the password is ignored); otherwise the default HTTP user is
///   used.
/// - `Sasl`: always rejected.
///
/// `include_www_authenticate_header` is forwarded into any
/// [`AuthError::MissingHttpAuthentication`] error.
///
/// # Errors
///
/// - [`AuthError::MissingHttpAuthentication`] if credentials are absent or of
///   a kind the authenticator does not accept.
/// - [`AuthError::Frontegg`] if the Frontegg authenticator returns an error.
/// - [`AuthError::InvalidCredentials`] if password or OIDC verification fails.
/// - [`AuthError::RoleDisallowed`] if the role is not permitted.
async fn auth(
    authenticator: &Authenticator,
    creds: Option<Credentials>,
    allowed_roles: AllowedRoles,
    include_www_authenticate_header: bool,
) -> Result<AuthedUser, AuthError> {
    let (name, external_metadata_rx, authenticated) = match authenticator {
        Authenticator::Frontegg(frontegg) => match creds {
            Some(Credentials::Password { username, password }) => {
                let (auth_session, authenticated) =
                    frontegg.authenticate(&username, password.as_str()).await?;
                let name = auth_session.user().into();
                let external_metadata_rx = Some(auth_session.external_metadata_rx());
                (name, external_metadata_rx, authenticated)
            }
            Some(Credentials::Token { token }) => {
                let (claims, authenticated) = frontegg.validate_access_token(&token, None)?;
                // The sender half is dropped right away, so this receiver only
                // ever observes the metadata taken from the token's claims.
                let (_, external_metadata_rx) = watch::channel(ExternalUserMetadata {
                    user_id: claims.user_id,
                    admin: claims.is_admin,
                });
                (claims.user, Some(external_metadata_rx), authenticated)
            }
            None => {
                return Err(AuthError::MissingHttpAuthentication {
                    include_www_authenticate_header,
                });
            }
        },
        Authenticator::Password(adapter_client) => match creds {
            Some(Credentials::Password { username, password }) => {
                // The underlying adapter error is not propagated; callers only
                // learn that the credentials were invalid.
                let authenticated = adapter_client
                    .authenticate(&username, &password)
                    .await
                    .map_err(|_| AuthError::InvalidCredentials)?;
                (username, None, authenticated)
            }
            // Tokens are not accepted by the password authenticator.
            _ => {
                return Err(AuthError::MissingHttpAuthentication {
                    include_www_authenticate_header,
                });
            }
        },
        Authenticator::Sasl(_) => {
            // We shouldn't ever end up here as the configuration is validated at startup.
            // If we do, it's a server misconfiguration.
            // Just in case, we return a 401 rather than panic.
            return Err(AuthError::MissingHttpAuthentication {
                include_www_authenticate_header,
            });
        }
        Authenticator::Oidc(oidc) => match creds {
            Some(Credentials::Token { token }) => {
                // Validate JWT token
                let (mut claims, authenticated) = oidc
                    .authenticate(&token, None)
                    .await
                    .map_err(|_| AuthError::InvalidCredentials)?;
                let name = std::mem::take(&mut claims.user);
                (name, None, authenticated)
            }
            // Password credentials are not accepted by the OIDC authenticator.
            _ => {
                return Err(AuthError::MissingHttpAuthentication {
                    include_www_authenticate_header,
                });
            }
        },
        Authenticator::None => {
            // If no authentication, use whatever is in the HTTP auth
            // header (without checking the password), or fall back to the
            // default user.
            let name = match creds {
                Some(Credentials::Password { username, .. }) => username,
                _ => HTTP_DEFAULT_USER.name.to_owned(),
            };
            (name, None, Authenticated)
        }
    };

    // Applies to every authenticator, including `None`.
    check_role_allowed(&name, allowed_roles)?;

    Ok(AuthedUser {
        name,
        external_metadata_rx,
        authenticated,
        authenticator_kind: authenticator.kind(),
    })
}
1279
1280// TODO move this somewhere it can be shared with PGWIRE
1281fn check_role_allowed(name: &str, allowed_roles: AllowedRoles) -> Result<(), AuthError> {
1282    let is_internal_user = INTERNAL_USER_NAMES.contains(name);
1283    // this is a superset of internal users
1284    let is_reserved_user = mz_adapter::catalog::is_reserved_role_name(name);
1285    let role_allowed = match allowed_roles {
1286        AllowedRoles::Normal => !is_reserved_user,
1287        AllowedRoles::Internal => is_internal_user,
1288        AllowedRoles::NormalAndInternal => !is_reserved_user || is_internal_user,
1289    };
1290    if role_allowed {
1291        Ok(())
1292    } else {
1293        Err(AuthError::RoleDisallowed(name.to_owned()))
1294    }
1295}
1296
1297/// Default layers that should be applied to all routes, and should get applied to both the
1298/// internal http and external http routers.
1299trait DefaultLayers {
1300    fn apply_default_layers(self, source: &'static str, metrics: Metrics) -> Self;
1301}
1302
1303impl DefaultLayers for Router {
1304    fn apply_default_layers(self, source: &'static str, metrics: Metrics) -> Self {
1305        self.layer(DefaultBodyLimit::max(MAX_REQUEST_SIZE))
1306            .layer(metrics::PrometheusLayer::new(source, metrics))
1307    }
1308}
1309
1310/// Glue code to make [`tower`] work with [`axum`].
1311///
1312/// `axum` requires `Layer`s not return Errors, i.e. they must be `Result<_, Infallible>`,
1313/// instead you must return a type that can be converted into a response. `tower` on the other
1314/// hand does return Errors, so to make the two work together we need to convert our `tower` errors
1315/// into responses.
1316async fn handle_load_error(error: tower::BoxError) -> impl IntoResponse {
1317    if error.is::<tower::load_shed::error::Overloaded>() {
1318        return (
1319            StatusCode::TOO_MANY_REQUESTS,
1320            Cow::from("too many requests, try again later"),
1321        );
1322    }
1323
1324    // Note: This should be unreachable because at the time of writing our only use case is a
1325    // layer that emits `tower::load_shed::error::Overloaded`, which is handled above.
1326    (
1327        StatusCode::INTERNAL_SERVER_ERROR,
1328        Cow::from(format!("Unhandled internal error: {}", error)),
1329    )
1330}
1331
/// JSON request body accepted by [`handle_login`].
///
/// NOTE(review): this derives `Debug` over a password field and relies on
/// `Password`'s own `Debug` impl not to leak the secret — confirm.
#[derive(Debug, Deserialize, Serialize, PartialEq)]
pub struct LoginCredentials {
    username: String,
    password: Password,
}
1337
/// Login state stored in the HTTP session under the `"data"` key.
///
/// It is written by [`handle_login`] and refreshed by
/// [`ensure_session_unexpired`] each time the session is reused.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TowerSessionData {
    /// Name of the user that logged in.
    username: String,
    /// When the session was established.
    created_at: SystemTime,
    /// Last time the session was used; compared against `SESSION_DURATION` to
    /// expire idle sessions.
    last_activity: SystemTime,
    /// Authentication result returned by the adapter at login.
    authenticated: Authenticated,
    /// The authenticator kind that established the session.
    authenticator_kind: mz_auth::AuthenticatorKind,
}
1346
#[cfg(test)]
mod tests {
    use super::{AllowedRoles, check_role_allowed};

    #[mz_ore::test]
    fn test_check_role_allowed() {
        // Each row gives a role name, then whether it is allowed under
        // `Internal`, `NormalAndInternal` and `Normal`, in that order.
        const CASES: &[(&str, bool, bool, bool)] = &[
            // Internal users.
            ("mz_system", true, true, false),
            ("mz_support", true, true, false),
            ("mz_analytics", true, true, false),
            // Normal users.
            ("materialize", false, true, true),
            ("anonymous_http_user", false, true, true),
            ("alex", false, true, true),
            // Denied by reserved role prefix.
            ("external_asdf", false, false, false),
            ("pg_somebody", false, false, false),
            ("mz_unknown", false, false, false),
            // Denied by literal PUBLIC.
            ("PUBLIC", false, false, false),
        ];

        for &(role, internal, normal_and_internal, normal) in CASES {
            for (mode, allowed_roles, expected) in [
                ("Internal", AllowedRoles::Internal, internal),
                (
                    "NormalAndInternal",
                    AllowedRoles::NormalAndInternal,
                    normal_and_internal,
                ),
                ("Normal", AllowedRoles::Normal, normal),
            ] {
                assert_eq!(
                    check_role_allowed(role, allowed_roles).is_ok(),
                    expected,
                    "role {role:?} under {mode}"
                );
            }
        }
    }
}