Skip to main content

mz_environmentd/
http.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Embedded HTTP server.
11//!
12//! environmentd embeds an HTTP server for introspection into the running
13//! process. At the moment, its primary exports are Prometheus metrics, heap
14//! profiles, and catalog dumps.
15
16// Axum handlers must use async, but often don't actually use `await`.
17#![allow(clippy::unused_async)]
18
19use std::borrow::Cow;
20use std::collections::BTreeMap;
21use std::fmt::Debug;
22use std::net::{IpAddr, SocketAddr};
23use std::pin::Pin;
24use std::sync::Arc;
25use std::time::{Duration, SystemTime};
26
27use anyhow::Context;
28use axum::error_handling::HandleErrorLayer;
29use axum::extract::ws::{Message, WebSocket};
30use axum::extract::{ConnectInfo, DefaultBodyLimit, FromRequestParts, Query, Request, State};
31use axum::middleware::{self, Next};
32use axum::response::{IntoResponse, Redirect, Response};
33use axum::{Extension, Json, Router, routing};
34use futures::future::{Shared, TryFutureExt};
35use headers::authorization::{Authorization, Basic, Bearer};
36use headers::{HeaderMapExt, HeaderName};
37use http::header::{AUTHORIZATION, CONTENT_TYPE};
38use http::uri::Scheme;
39use http::{HeaderMap, HeaderValue, Method, StatusCode, Uri};
40use hyper_openssl::SslStream;
41use hyper_openssl::client::legacy::MaybeHttpsStream;
42use hyper_util::rt::TokioIo;
43use mz_adapter::session::{Session as AdapterSession, SessionConfig as AdapterSessionConfig};
44use mz_adapter::{AdapterError, AdapterNotice, Client, SessionClient, WebhookAppenderCache};
45use mz_auth::Authenticated;
46use mz_auth::password::Password;
47use mz_authenticator::Authenticator;
48use mz_controller::ReplicaHttpLocator;
49use mz_frontegg_auth::Error as FronteggError;
50use mz_http_util::DynamicFilterTarget;
51use mz_ore::cast::u64_to_usize;
52use mz_ore::metrics::MetricsRegistry;
53use mz_ore::now::{NowFn, SYSTEM_TIME, epoch_to_uuid_v7};
54use mz_ore::str::StrExt;
55use mz_pgwire_common::{ConnectionCounter, ConnectionHandle};
56use mz_repr::user::ExternalUserMetadata;
57use mz_server_core::listeners::{AllowedRoles, AuthenticatorKind, HttpRoutesEnabled};
58use mz_server_core::{Connection, ConnectionHandler, ReloadingSslContext, Server};
59use mz_sql::session::metadata::SessionMetadata;
60use mz_sql::session::user::{
61    HTTP_DEFAULT_USER, INTERNAL_USER_NAMES, SUPPORT_USER_NAME, SYSTEM_USER_NAME,
62};
63use mz_sql::session::vars::{Value, Var, VarInput, WELCOME_MESSAGE};
64use openssl::ssl::Ssl;
65use prometheus::{
66    COMPUTE_METRIC_QUERIES, FRONTIER_METRIC_QUERIES, STORAGE_METRIC_QUERIES, USAGE_METRIC_QUERIES,
67};
68use serde::{Deserialize, Serialize};
69use serde_json::json;
70use thiserror::Error;
71use tokio::io::AsyncWriteExt;
72use tokio::sync::oneshot::Receiver;
73use tokio::sync::{oneshot, watch};
74use tokio_metrics::TaskMetrics;
75use tower::limit::GlobalConcurrencyLimitLayer;
76use tower::{Service, ServiceBuilder};
77use tower_http::cors::{AllowOrigin, Any, CorsLayer};
78use tower_sessions::{
79    MemoryStore as TowerSessionMemoryStore, Session as TowerSession,
80    SessionManagerLayer as TowerSessionManagerLayer,
81};
82use tracing::warn;
83
84use crate::BUILD_INFO;
85use crate::deployment::state::DeploymentStateHandle;
86use crate::http::sql::SqlError;
87
88mod catalog;
89mod cluster;
90mod console;
91mod memory;
92mod metrics;
93mod metrics_viz;
94mod probe;
95mod prometheus;
96mod root;
97mod sql;
98mod webhook;
99
100pub use metrics::Metrics;
101pub use sql::{SqlResponse, WebSocketAuth, WebSocketResponse};
102
/// Maximum allowed size for a request.
pub const MAX_REQUEST_SIZE: usize = u64_to_usize(5 * bytesize::MIB);

/// Maximum time a login session may stay idle, measured from its last
/// recorded activity, before it is deleted and the request is rejected as
/// expired.
const SESSION_DURATION: Duration = Duration::from_secs(8 * 3600); // 8 hours

/// Path fragments identifying the profiling endpoints.
// NOTE(review): not referenced in the visible part of this file; confirm how
// (prefix match vs. exact match) it is consumed.
const PROFILING_API_ENDPOINTS: &[&str] = &["/memory", "/hierarchical-memory", "/prof/"];
109
/// Configuration consumed by [`HttpServer::new`] to assemble the HTTP router.
#[derive(Debug)]
pub struct HttpConfig {
    /// Label forwarded, together with `metrics`, to `apply_default_layers`.
    pub source: &'static str,
    /// TLS context. When present, incoming connections are wrapped in TLS and
    /// the session cookie is configured as secure.
    pub tls: Option<ReloadingSslContext>,
    /// Selects which authentication-specific layers and routes are installed
    /// (session login for `Password`, `x-materialize-user` header for `None`).
    pub authenticator_kind: AuthenticatorKind,
    /// Resolves to the authenticator; awaited by the auth middleware on each
    /// request and handed to the WebSocket state.
    pub authenticator_rx: Shared<Receiver<Arc<Authenticator>>>,
    /// Resolves to the adapter client; shared with handlers via an
    /// `Extension` and via the WebSocket/webhook state.
    pub adapter_client_rx: Shared<Receiver<Client>>,
    /// Origins allowed by the CORS layer applied to the base routes.
    pub allowed_origin: AllowOrigin,
    /// Counter shared with handlers via an `Extension` and with the WebSocket
    /// state; `AuthedClient::new` allocates a connection from it.
    pub active_connection_counter: ConnectionCounter,
    /// Helm chart version, passed to the WebSocket state.
    pub helm_chart_version: Option<String>,
    /// Semaphore backing the global concurrency limit on webhook requests.
    pub concurrent_webhook_req: Arc<tokio::sync::Semaphore>,
    /// Server metrics forwarded to `apply_default_layers`.
    pub metrics: Metrics,
    /// Registry served by the `/metrics` route.
    pub metrics_registry: MetricsRegistry,
    /// Role restriction passed to the auth middleware and the WebSocket state.
    pub allowed_roles: AllowedRoles,
    /// Settings for the internal routes (console redirect URL, deployment
    /// state handle).
    pub internal_route_config: Arc<InternalRouteConfig>,
    /// Toggles for each group of routes (base, profiling, webhook, internal,
    /// metrics).
    pub routes_enabled: HttpRoutesEnabled,
    /// Locator for cluster replica HTTP addresses, used for proxying requests.
    pub replica_http_locator: Arc<ReplicaHttpLocator>,
}
129
/// Settings used only when the internal route group is enabled.
#[derive(Debug, Clone)]
pub struct InternalRouteConfig {
    /// Handle used as the state of the `/api/leader/*` routes.
    pub deployment_state_handle: DeploymentStateHandle,
    /// Optional URL handed to the console proxy configuration that serves
    /// `/internal-console`.
    pub internal_console_redirect_url: Option<String>,
}
135
/// Router state for the WebSocket SQL endpoint (`/api/experimental/sql`).
///
/// Every field is copied from the corresponding [`HttpConfig`] field when the
/// server is constructed.
#[derive(Clone)]
pub struct WsState {
    /// Resolves to the authenticator once it is available.
    authenticator_rx: Delayed<Arc<Authenticator>>,
    /// Resolves to the adapter client once it is available.
    adapter_client_rx: Delayed<mz_adapter::Client>,
    /// Shared counter of active connections.
    active_connection_counter: ConnectionCounter,
    /// Helm chart version, if configured.
    helm_chart_version: Option<String>,
    /// Role restriction configured for this listener.
    allowed_roles: AllowedRoles,
}
144
/// Router state for the webhook endpoint (`/api/webhook/...`).
#[derive(Clone)]
pub struct WebhookState {
    /// Resolves to the adapter client once it is available.
    adapter_client_rx: Delayed<mz_adapter::Client>,
    /// Cache of webhook appenders, created fresh when the server is built.
    webhook_cache: WebhookAppenderCache,
}
150
/// The embedded HTTP server: a fully assembled router plus optional TLS.
#[derive(Debug)]
pub struct HttpServer {
    /// TLS context; when `Some`, each accepted connection performs a TLS
    /// handshake before being served.
    tls: Option<ReloadingSslContext>,
    /// Router holding every enabled route and its layers; cloned per
    /// connection.
    router: Router,
}
156
impl HttpServer {
    /// Builds the server's router from `HttpConfig`.
    ///
    /// Routes are grouped by the flags in `routes_enabled`. Most routes are
    /// collected in `base_router`, which receives the auth middleware, the
    /// shared extensions and the CORS layer near the end. The WebSocket,
    /// webhook, leader, metrics and login routers are merged into `router`
    /// directly, each carrying its own layers. Finally both are merged and
    /// `apply_default_layers` is applied to the result.
    pub fn new(
        HttpConfig {
            source,
            tls,
            authenticator_kind,
            authenticator_rx,
            adapter_client_rx,
            allowed_origin,
            active_connection_counter,
            helm_chart_version,
            concurrent_webhook_req,
            metrics,
            metrics_registry,
            allowed_roles,
            internal_route_config,
            routes_enabled,
            replica_http_locator,
        }: HttpConfig,
    ) -> HttpServer {
        let tls_enabled = tls.is_some();
        let webhook_cache = WebhookAppenderCache::new();

        // Create secure session store and manager
        let session_store = TowerSessionMemoryStore::default();
        let session_layer = TowerSessionManagerLayer::new(session_store)
            .with_secure(tls_enabled) // Secure cookies only when TLS is configured
            .with_same_site(tower_sessions::cookie::SameSite::Strict) // Prevent CSRF
            .with_http_only(true) // Prevent XSS
            .with_name("mz_session") // Custom cookie name
            .with_path("/"); // Set cookie path

        // Middleware that waits for the authenticator to become available and
        // then delegates to `http_auth` for every request.
        let auth_middleware_authenticator_rx = authenticator_rx.clone();
        let auth_middleware = middleware::from_fn(move |req, next| {
            let authenticator_rx = auth_middleware_authenticator_rx.clone();
            async move {
                let authenticator = authenticator_rx
                    .await
                    .expect("sender not dropped before sending once");
                http_auth(req, next, tls_enabled, authenticator, allowed_roles).await
            }
        });

        let mut router = Router::new();
        let mut base_router = Router::new();
        if routes_enabled.base {
            // Home page, HTTP SQL API, memory views, metrics visualization and
            // static assets.
            base_router = base_router
                .route(
                    "/",
                    routing::get(move || async move { root::handle_home(routes_enabled).await }),
                )
                .route("/api/sql", routing::post(sql::handle_sql))
                .route("/memory", routing::get(memory::handle_memory))
                .route(
                    "/hierarchical-memory",
                    routing::get(memory::handle_hierarchical_memory),
                )
                .route(
                    "/metrics-viz",
                    routing::get(metrics_viz::handle_metrics_viz),
                )
                .route("/static/{*path}", routing::get(root::handle_static));

            // The WebSocket SQL route carries its own state (including the
            // authenticator) and is merged into `router` directly, so the
            // layers added to `base_router` further down are not applied to it.
            let mut ws_router = Router::new()
                .route("/api/experimental/sql", routing::get(sql::handle_sql_ws))
                .with_state(WsState {
                    authenticator_rx: authenticator_rx.clone(),
                    adapter_client_rx: adapter_client_rx.clone(),
                    active_connection_counter: active_connection_counter.clone(),
                    helm_chart_version,
                    allowed_roles,
                });
            if let AuthenticatorKind::None = authenticator_kind {
                ws_router = ws_router.layer(middleware::from_fn(x_materialize_user_header_auth));
            }
            router = router.merge(ws_router);
        }
        if routes_enabled.profiling {
            base_router = base_router.nest("/prof/", mz_prof_http::router(&BUILD_INFO));
        }

        if routes_enabled.webhook {
            // Webhook ingestion: accepts compressed request bodies, allows
            // cross-origin POSTs from any origin, and bounds concurrency with
            // the shared semaphore; `load_shed` errors are converted by
            // `handle_load_error`.
            let webhook_router = Router::new()
                .route(
                    "/api/webhook/{:database}/{:schema}/{:id}",
                    routing::post(webhook::handle_webhook),
                )
                .with_state(WebhookState {
                    adapter_client_rx: adapter_client_rx.clone(),
                    webhook_cache,
                })
                .layer(
                    tower_http::decompression::RequestDecompressionLayer::new()
                        .gzip(true)
                        .deflate(true)
                        .br(true)
                        .zstd(true),
                )
                .layer(
                    CorsLayer::new()
                        .allow_methods(Method::POST)
                        .allow_origin(AllowOrigin::mirror_request())
                        .allow_headers(Any),
                )
                .layer(
                    ServiceBuilder::new()
                        .layer(HandleErrorLayer::new(handle_load_error))
                        .load_shed()
                        .layer(GlobalConcurrencyLimitLayer::with_semaphore(
                            concurrent_webhook_req,
                        )),
                );
            router = router.merge(webhook_router);
        }

        if routes_enabled.internal {
            let console_config = Arc::new(console::ConsoleProxyConfig::new(
                internal_route_config.internal_console_redirect_url.clone(),
                "/internal-console".to_string(),
            ));
            base_router = base_router
                // The two `*/config` endpoints below are retired: they always
                // answer 400 and point callers at the replacement system
                // variable.
                .route(
                    "/api/opentelemetry/config",
                    routing::put({
                        move |_: axum::Json<DynamicFilterTarget>| async {
                            (
                                StatusCode::BAD_REQUEST,
                                "This endpoint has been replaced. \
                            Use the `opentelemetry_filter` system variable."
                                    .to_string(),
                            )
                        }
                    }),
                )
                .route(
                    "/api/stderr/config",
                    routing::put({
                        move |_: axum::Json<DynamicFilterTarget>| async {
                            (
                                StatusCode::BAD_REQUEST,
                                "This endpoint has been replaced. \
                            Use the `log_filter` system variable."
                                    .to_string(),
                            )
                        }
                    }),
                )
                .route("/api/tracing", routing::get(mz_http_util::handle_tracing))
                .route(
                    "/api/catalog/dump",
                    routing::get(catalog::handle_catalog_dump),
                )
                .route(
                    "/api/catalog/check",
                    routing::get(catalog::handle_catalog_check),
                )
                .route(
                    "/api/coordinator/check",
                    routing::get(catalog::handle_coordinator_check),
                )
                .route(
                    "/api/coordinator/dump",
                    routing::get(catalog::handle_coordinator_dump),
                )
                // Redirect the bare path to the trailing-slash form, which is
                // served by the console handler.
                .route(
                    "/internal-console",
                    routing::get(|| async { Redirect::temporary("/internal-console/") }),
                )
                .route(
                    "/internal-console/{*path}",
                    routing::get(console::handle_internal_console),
                )
                .route(
                    "/internal-console/",
                    routing::get(console::handle_internal_console),
                )
                .layer(Extension(console_config));

            // Cluster HTTP proxy routes.
            let cluster_proxy_config = Arc::new(cluster::ClusterProxyConfig::new(Arc::clone(
                &replica_http_locator,
            )));
            base_router = base_router
                .route("/clusters", routing::get(cluster::handle_clusters))
                .route(
                    "/api/cluster/{:cluster_id}/replica/{:replica_id}/process/{:process}/",
                    routing::any(cluster::handle_cluster_proxy_root),
                )
                .route(
                    "/api/cluster/{:cluster_id}/replica/{:replica_id}/process/{:process}/{*path}",
                    routing::any(cluster::handle_cluster_proxy),
                )
                .layer(Extension(cluster_proxy_config));

            // Leader (deployment state) routes get the auth middleware and the
            // deployment state handle as router state, and are merged into
            // `router` directly rather than into `base_router`.
            let leader_router = Router::new()
                .route("/api/leader/status", routing::get(handle_leader_status))
                .route("/api/leader/promote", routing::post(handle_leader_promote))
                .route(
                    "/api/leader/skip-catchup",
                    routing::post(handle_leader_skip_catchup),
                )
                .layer(auth_middleware.clone())
                .with_state(internal_route_config.deployment_state_handle.clone());
            router = router.merge(leader_router);
        }

        if routes_enabled.metrics {
            // `/metrics` serves the process registry; the `/metrics/mz_*`
            // routes build a registry from SQL queries run through an
            // `AuthedClient`. Liveness and readiness probes live here too.
            let metrics_router = Router::new()
                .route(
                    "/metrics",
                    routing::get(move || async move {
                        mz_http_util::handle_prometheus(&metrics_registry).await
                    }),
                )
                .route(
                    "/metrics/mz_usage",
                    routing::get(|client: AuthedClient| async move {
                        let registry = sql::handle_promsql(client, USAGE_METRIC_QUERIES).await;
                        mz_http_util::handle_prometheus(&registry).await
                    }),
                )
                .route(
                    "/metrics/mz_frontier",
                    routing::get(|client: AuthedClient| async move {
                        let registry = sql::handle_promsql(client, FRONTIER_METRIC_QUERIES).await;
                        mz_http_util::handle_prometheus(&registry).await
                    }),
                )
                .route(
                    "/metrics/mz_compute",
                    routing::get(|client: AuthedClient| async move {
                        let registry = sql::handle_promsql(client, COMPUTE_METRIC_QUERIES).await;
                        mz_http_util::handle_prometheus(&registry).await
                    }),
                )
                .route(
                    "/metrics/mz_storage",
                    routing::get(|client: AuthedClient| async move {
                        let registry = sql::handle_promsql(client, STORAGE_METRIC_QUERIES).await;
                        mz_http_util::handle_prometheus(&registry).await
                    }),
                )
                .route(
                    "/api/livez",
                    routing::get(mz_http_util::handle_liveness_check),
                )
                .route("/api/readyz", routing::get(probe::handle_ready))
                .layer(auth_middleware.clone())
                .layer(Extension(adapter_client_rx.clone()))
                .layer(Extension(active_connection_counter.clone()));
            router = router.merge(metrics_router);
        }

        // Everything collected in `base_router` gets the auth middleware, the
        // extensions that the `AuthedClient` extractor reads, and the CORS
        // policy for the configured origins.
        base_router = base_router
            .layer(auth_middleware.clone())
            .layer(Extension(adapter_client_rx.clone()))
            .layer(Extension(active_connection_counter.clone()))
            .layer(
                CorsLayer::new()
                    .allow_credentials(false)
                    .allow_headers([
                        AUTHORIZATION,
                        CONTENT_TYPE,
                        HeaderName::from_static("x-materialize-version"),
                    ])
                    .allow_methods(Any)
                    .allow_origin(allowed_origin)
                    .expose_headers(Any)
                    // 60 seconds * 60 = one hour.
                    .max_age(Duration::from_secs(60) * 60),
            );

        match authenticator_kind {
            AuthenticatorKind::Password => {
                // Password auth adds cookie sessions: the session layer wraps
                // `base_router`, and the login/logout routes are merged into
                // `router` before the same layer is applied to `router`.
                base_router = base_router.layer(session_layer.clone());

                let login_router = Router::new()
                    .route("/api/login", routing::post(handle_login))
                    .route("/api/logout", routing::post(handle_logout))
                    .layer(Extension(adapter_client_rx));
                router = router.merge(login_router).layer(session_layer);
            }
            AuthenticatorKind::None => {
                // Without an authenticator, honor the `x-materialize-user`
                // header on the base routes.
                base_router =
                    base_router.layer(middleware::from_fn(x_materialize_user_header_auth));
            }
            _ => {}
        }

        router = router
            .merge(base_router)
            .apply_default_layers(source, metrics);

        HttpServer { tls, router }
    }
}
452
impl Server for HttpServer {
    const NAME: &'static str = "http";

    /// Serves a single accepted connection over HTTP/1, performing a TLS
    /// handshake first when the server was configured with TLS.
    ///
    /// The returned future resolves with an error if the peer address cannot
    /// be read, the TLS setup or handshake fails, or serving the connection
    /// fails.
    fn handle_connection(
        &self,
        conn: Connection,
        _tokio_metrics_intervals: impl Iterator<Item = TaskMetrics> + Send + 'static,
    ) -> ConnectionHandler {
        let router = self.router.clone();
        let tls_context = self.tls.clone();
        let mut conn = TokioIo::new(conn);

        Box::pin(async {
            // Prefer the source address from a proxy header, if one was sent;
            // otherwise fall back to the socket's own peer address.
            let direct_peer_addr = conn.inner().peer_addr().context("fetching peer addr")?;
            let peer_addr = conn
                .inner_mut()
                .take_proxy_header_address()
                .await
                .map(|a| a.source)
                .unwrap_or(direct_peer_addr);

            let (conn, conn_protocol) = match tls_context {
                Some(tls_context) => {
                    let mut ssl_stream = SslStream::new(Ssl::new(&tls_context.get())?, conn)?;
                    if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
                        // Best-effort shutdown of the underlying stream; the
                        // handshake error is what gets reported.
                        let _ = ssl_stream.get_mut().inner_mut().shutdown().await;
                        return Err(e.into());
                    }
                    (MaybeHttpsStream::Https(ssl_stream), ConnProtocol::Https)
                }
                _ => (MaybeHttpsStream::Http(conn), ConnProtocol::Http),
            };
            // Expose the connection's protocol and peer address to handlers
            // via the `ConnProtocol` extension and `ConnectInfo<SocketAddr>`.
            let mut make_tower_svc = router
                .layer(Extension(conn_protocol))
                .into_make_service_with_connect_info::<SocketAddr>();
            let tower_svc = make_tower_svc.call(peer_addr).await.unwrap();
            let hyper_svc = hyper::service::service_fn(|req| tower_svc.clone().call(req));
            let http = hyper::server::conn::http1::Builder::new();
            // `with_upgrades` keeps protocol upgrades (e.g. WebSocket) working.
            http.serve_connection(conn, hyper_svc)
                .with_upgrades()
                .err_into()
                .await
        })
    }
}
498
499pub async fn handle_leader_status(
500    State(deployment_state_handle): State<DeploymentStateHandle>,
501) -> impl IntoResponse {
502    let status = deployment_state_handle.status();
503    (StatusCode::OK, Json(json!({ "status": status })))
504}
505
506pub async fn handle_leader_promote(
507    State(deployment_state_handle): State<DeploymentStateHandle>,
508) -> impl IntoResponse {
509    match deployment_state_handle.try_promote() {
510        Ok(()) => {
511            // TODO(benesch): the body here is redundant. Should just return
512            // 204.
513            let status = StatusCode::OK;
514            let body = Json(json!({
515                "result": "Success",
516            }));
517            (status, body)
518        }
519        Err(()) => {
520            // TODO(benesch): the nesting here is redundant given the error
521            // code. Should just return the `{"message": "..."}` object.
522            let status = StatusCode::BAD_REQUEST;
523            let body = Json(json!({
524                "result": {"Failure": {"message": "cannot promote leader while initializing"}},
525            }));
526            (status, body)
527        }
528    }
529}
530
531pub async fn handle_leader_skip_catchup(
532    State(deployment_state_handle): State<DeploymentStateHandle>,
533) -> impl IntoResponse {
534    match deployment_state_handle.try_skip_catchup() {
535        Ok(()) => StatusCode::NO_CONTENT.into_response(),
536        Err(()) => {
537            let status = StatusCode::BAD_REQUEST;
538            let body = Json(json!({
539                "message": "cannot skip catchup in this phase of initialization; try again later",
540            }));
541            (status, body).into_response()
542        }
543    }
544}
545
546async fn x_materialize_user_header_auth(mut req: Request, next: Next) -> impl IntoResponse {
547    // TODO migrate teleport to basic auth and remove this.
548    if let Some(username) = req.headers().get("x-materialize-user").map(|h| h.to_str()) {
549        let username = match username {
550            Ok(name @ (SUPPORT_USER_NAME | SYSTEM_USER_NAME)) => name.to_string(),
551            _ => {
552                return Err(AuthError::MismatchedUser(format!(
553                    "user specified in x-materialize-user must be {SUPPORT_USER_NAME} or {SYSTEM_USER_NAME}"
554                )));
555            }
556        };
557        req.extensions_mut().insert(AuthedUser {
558            name: username,
559            external_metadata_rx: None,
560            authenticated: Authenticated,
561        });
562    }
563    Ok(next.run(req).await)
564}
565
/// A cloneable handle on a value delivered later through a oneshot channel.
type Delayed<T> = Shared<oneshot::Receiver<T>>;
567
/// Transport of the current connection, attached to every request as an
/// extension by `HttpServer::handle_connection`.
#[derive(Clone)]
enum ConnProtocol {
    /// Plain-text connection (no TLS context configured).
    Http,
    /// Connection on which a TLS handshake completed.
    Https,
}
573
/// Identity of an authenticated user, stored in the request extensions by the
/// authentication middleware and read back by the `AuthedClient` extractor.
#[derive(Clone, Debug)]
pub struct AuthedUser {
    /// User name the session will run as.
    name: String,
    /// Optional channel of external user metadata; `None` for header- and
    /// session-based authentication.
    external_metadata_rx: Option<watch::Receiver<ExternalUserMetadata>>,
    /// Proof that authentication took place, forwarded when the adapter
    /// session is created.
    authenticated: Authenticated,
}
580
/// An adapter session client opened on behalf of an authenticated user.
pub struct AuthedClient {
    /// Session client returned by the adapter's `startup`.
    pub client: SessionClient,
    /// Handle returned by `ConnectionCounter::allocate_connection` for this
    /// session, if any.
    // NOTE(review): presumably releases the connection slot when dropped; confirm.
    pub connection_guard: Option<ConnectionHandle>,
}
585
586impl AuthedClient {
587    async fn new<F>(
588        adapter_client: &Client,
589        user: AuthedUser,
590        peer_addr: IpAddr,
591        active_connection_counter: ConnectionCounter,
592        helm_chart_version: Option<String>,
593        session_config: F,
594        options: BTreeMap<String, String>,
595        now: NowFn,
596    ) -> Result<Self, AdapterError>
597    where
598        F: FnOnce(&mut AdapterSession),
599    {
600        let conn_id = adapter_client.new_conn_id()?;
601        let mut session = adapter_client.new_session(
602            AdapterSessionConfig {
603                conn_id,
604                uuid: epoch_to_uuid_v7(&(now)()),
605                user: user.name,
606                client_ip: Some(peer_addr),
607                external_metadata_rx: user.external_metadata_rx,
608                helm_chart_version,
609            },
610            user.authenticated,
611        );
612        let connection_guard = active_connection_counter.allocate_connection(session.user())?;
613
614        session_config(&mut session);
615        let system_vars = adapter_client.get_system_vars().await;
616        for (key, val) in options {
617            const LOCAL: bool = false;
618            if let Err(err) =
619                session
620                    .vars_mut()
621                    .set(&system_vars, &key, VarInput::Flat(&val), LOCAL)
622            {
623                session.add_notice(AdapterNotice::BadStartupSetting {
624                    name: key.to_string(),
625                    reason: err.to_string(),
626                })
627            }
628        }
629        let adapter_client = adapter_client.startup(session).await?;
630        Ok(AuthedClient {
631            client: adapter_client,
632            connection_guard,
633        })
634    }
635}
636
637impl<S> FromRequestParts<S> for AuthedClient
638where
639    S: Send + Sync,
640{
641    type Rejection = Response;
642
643    async fn from_request_parts(
644        req: &mut http::request::Parts,
645        state: &S,
646    ) -> Result<Self, Self::Rejection> {
647        #[derive(Debug, Default, Deserialize)]
648        struct Params {
649            #[serde(default)]
650            options: String,
651        }
652        let params: Query<Params> = Query::from_request_parts(req, state)
653            .await
654            .unwrap_or_default();
655
656        let peer_addr = req
657            .extensions
658            .get::<ConnectInfo<SocketAddr>>()
659            .expect("ConnectInfo extension guaranteed to exist")
660            .0
661            .ip();
662
663        let user = req.extensions.get::<AuthedUser>().unwrap();
664        let adapter_client = req
665            .extensions
666            .get::<Delayed<mz_adapter::Client>>()
667            .unwrap()
668            .clone();
669        let adapter_client = adapter_client.await.map_err(|_| {
670            (StatusCode::INTERNAL_SERVER_ERROR, "adapter client missing").into_response()
671        })?;
672        let active_connection_counter = req.extensions.get::<ConnectionCounter>().unwrap();
673        let helm_chart_version = None;
674
675        let options = if params.options.is_empty() {
676            // It's possible 'options' simply wasn't provided, we don't want that to
677            // count as a failure to deserialize
678            BTreeMap::<String, String>::default()
679        } else {
680            match serde_json::from_str(&params.options) {
681                Ok(options) => options,
682                Err(_e) => {
683                    // If we fail to deserialize options, fail the request.
684                    let code = StatusCode::BAD_REQUEST;
685                    let msg = format!("Failed to deserialize {} map", "options".quoted());
686                    return Err((code, msg).into_response());
687                }
688            }
689        };
690
691        let client = AuthedClient::new(
692            &adapter_client,
693            user.clone(),
694            peer_addr,
695            active_connection_counter.clone(),
696            helm_chart_version,
697            |session| {
698                session
699                    .vars_mut()
700                    .set_default(WELCOME_MESSAGE.name(), VarInput::Flat(&false.format()))
701                    .expect("known to exist")
702            },
703            options,
704            SYSTEM_TIME.clone(),
705        )
706        .await
707        .map_err(|e| {
708            let status = match e {
709                AdapterError::UserSessionsDisallowed | AdapterError::NetworkPolicyDenied(_) => {
710                    StatusCode::FORBIDDEN
711                }
712                _ => StatusCode::INTERNAL_SERVER_ERROR,
713            };
714            (status, Json(SqlError::from(e))).into_response()
715        })?;
716
717        Ok(client)
718    }
719}
720
721#[derive(Debug, Error)]
722enum AuthError {
723    #[error("role dissallowed")]
724    RoleDisallowed(String),
725    #[error("{0}")]
726    Frontegg(#[from] FronteggError),
727    #[error("missing authorization header")]
728    MissingHttpAuthentication {
729        include_www_authenticate_header: bool,
730    },
731    #[error("{0}")]
732    MismatchedUser(String),
733    #[error("session expired")]
734    SessionExpired,
735    #[error("failed to update session")]
736    FailedToUpdateSession,
737    #[error("invalid credentials")]
738    InvalidCredentials,
739}
740
741impl IntoResponse for AuthError {
742    fn into_response(self) -> Response {
743        warn!("HTTP request failed authentication: {}", self);
744        let mut headers = HeaderMap::new();
745        match self {
746            AuthError::MissingHttpAuthentication {
747                include_www_authenticate_header,
748            } if include_www_authenticate_header => {
749                headers.insert(
750                    http::header::WWW_AUTHENTICATE,
751                    HeaderValue::from_static("Basic realm=Materialize"),
752                );
753            }
754            _ => {}
755        };
756        // We omit most detail from the error message we send to the client, to
757        // avoid giving attackers unnecessary information.
758        (StatusCode::UNAUTHORIZED, headers, "unauthorized").into_response()
759    }
760}
761
762// Simplified login handler
763pub async fn handle_login(
764    session: Option<Extension<TowerSession>>,
765    Extension(adapter_client_rx): Extension<Delayed<Client>>,
766    Json(LoginCredentials { username, password }): Json<LoginCredentials>,
767) -> impl IntoResponse {
768    let Ok(adapter_client) = adapter_client_rx.clone().await else {
769        return StatusCode::INTERNAL_SERVER_ERROR;
770    };
771    let authenticated = match adapter_client.authenticate(&username, &password).await {
772        Ok(authenticated) => authenticated,
773        Err(err) => {
774            warn!(?err, "HTTP login failed authentication");
775            return StatusCode::UNAUTHORIZED;
776        }
777    };
778    // Create session data
779    let session_data = TowerSessionData {
780        username,
781        created_at: SystemTime::now(),
782        last_activity: SystemTime::now(),
783        authenticated,
784    };
785    // Store session data
786    let session = session.and_then(|Extension(session)| Some(session));
787    let Some(session) = session else {
788        return StatusCode::INTERNAL_SERVER_ERROR;
789    };
790    match session.insert("data", &session_data).await {
791        Err(_) => StatusCode::INTERNAL_SERVER_ERROR,
792        Ok(_) => StatusCode::OK,
793    }
794}
795
796// Simplified logout handler
797pub async fn handle_logout(session: Option<Extension<TowerSession>>) -> impl IntoResponse {
798    let session = session.and_then(|Extension(session)| Some(session));
799    let Some(session) = session else {
800        return StatusCode::INTERNAL_SERVER_ERROR;
801    };
802    // Delete session
803    match session.delete().await {
804        Err(_) => StatusCode::INTERNAL_SERVER_ERROR,
805        Ok(_) => StatusCode::OK,
806    }
807}
808
809async fn http_auth(
810    mut req: Request,
811    next: Next,
812    tls_enabled: bool,
813    authenticator: Arc<Authenticator>,
814    allowed_roles: AllowedRoles,
815) -> impl IntoResponse + use<> {
816    // First check for session authentication
817    if let Some(session) = req.extensions().get::<TowerSession>() {
818        if let Ok(Some(session_data)) = session.get::<TowerSessionData>("data").await {
819            // Check session expiration
820            if session_data
821                .last_activity
822                .elapsed()
823                .unwrap_or(Duration::MAX)
824                > SESSION_DURATION
825            {
826                let _ = session.delete().await;
827                return Err(AuthError::SessionExpired);
828            }
829            // Update last activity
830            let mut updated_data = session_data.clone();
831            updated_data.last_activity = SystemTime::now();
832            session
833                .insert("data", &updated_data)
834                .await
835                .map_err(|_| AuthError::FailedToUpdateSession)?;
836            // User is authenticated via session
837            req.extensions_mut().insert(AuthedUser {
838                name: session_data.username,
839                external_metadata_rx: None,
840                authenticated: session_data.authenticated,
841            });
842            return Ok(next.run(req).await);
843        }
844    }
845
846    // First, extract the username from the certificate, validating that the
847    // connection matches the TLS configuration along the way.
848    // Fall back to existing authentication methods.
849    let conn_protocol = req.extensions().get::<ConnProtocol>().unwrap();
850    match (tls_enabled, &conn_protocol) {
851        (false, ConnProtocol::Http) => {}
852        (false, ConnProtocol::Https { .. }) => unreachable!(),
853        (true, ConnProtocol::Http) => {
854            let mut parts = req.uri().clone().into_parts();
855            parts.scheme = Some(Scheme::HTTPS);
856            return Ok(Redirect::permanent(
857                &Uri::from_parts(parts)
858                    .expect("it was already a URI, just changed the scheme")
859                    .to_string(),
860            )
861            .into_response());
862        }
863        (true, ConnProtocol::Https { .. }) => {}
864    }
865    // If we've already passed some other auth, just use that.
866    if req.extensions().get::<AuthedUser>().is_some() {
867        return Ok(next.run(req).await);
868    }
869    let creds = if let Some(basic) = req.headers().typed_get::<Authorization<Basic>>() {
870        Some(Credentials::Password {
871            username: basic.username().to_owned(),
872            password: Password(basic.password().to_owned()),
873        })
874    } else if let Some(bearer) = req.headers().typed_get::<Authorization<Bearer>>() {
875        Some(Credentials::Token {
876            token: bearer.token().to_owned(),
877        })
878    } else {
879        None
880    };
881
882    let path = req.uri().path();
883    let include_www_authenticate_header = path == "/"
884        || PROFILING_API_ENDPOINTS
885            .iter()
886            .any(|prefix| path.starts_with(prefix));
887    let user = auth(
888        &authenticator,
889        creds,
890        allowed_roles,
891        include_www_authenticate_header,
892    )
893    .await?;
894
895    // Add the authenticated user as an extension so downstream handlers can
896    // inspect it if necessary.
897    req.extensions_mut().insert(user);
898
899    // Run the request.
900    Ok(next.run(req).await)
901}
902
903async fn init_ws(
904    WsState {
905        authenticator_rx,
906        adapter_client_rx,
907        active_connection_counter,
908        helm_chart_version,
909        allowed_roles,
910    }: &WsState,
911    existing_user: Option<AuthedUser>,
912    peer_addr: IpAddr,
913    ws: &mut WebSocket,
914) -> Result<AuthedClient, anyhow::Error> {
915    let authenticator = authenticator_rx.clone().await.expect("sender not dropped");
916    // TODO: Add a timeout here to prevent resource leaks by clients that
917    // connect then never send a message.
918    let ws_auth: WebSocketAuth = loop {
919        let init_msg = ws.recv().await.ok_or_else(|| anyhow::anyhow!("closed"))??;
920        match init_msg {
921            Message::Text(data) => break serde_json::from_str(&data)?,
922            Message::Binary(data) => break serde_json::from_slice(&data)?,
923            // Handled automatically by the server.
924            Message::Ping(_) => {
925                continue;
926            }
927            Message::Pong(_) => {
928                continue;
929            }
930            Message::Close(_) => {
931                anyhow::bail!("closed");
932            }
933        }
934    };
935
936    let (user, options) = if let Some(existing_user) = existing_user {
937        match ws_auth {
938            WebSocketAuth::OptionsOnly { options } => (existing_user, options),
939            _ => {
940                warn!("Unexpected bearer or basic auth provided when using user header");
941                anyhow::bail!("unexpected")
942            }
943        }
944    } else {
945        let (creds, options) = match ws_auth {
946            WebSocketAuth::Basic {
947                user,
948                password,
949                options,
950            } => {
951                let creds = Credentials::Password {
952                    username: user,
953                    password,
954                };
955                (creds, options)
956            }
957            WebSocketAuth::Bearer { token, options } => {
958                let creds = Credentials::Token { token };
959                (creds, options)
960            }
961            WebSocketAuth::OptionsOnly { .. } => {
962                anyhow::bail!("expected auth information");
963            }
964        };
965        let user = auth(&authenticator, Some(creds), *allowed_roles, false).await?;
966        (user, options)
967    };
968
969    let client = AuthedClient::new(
970        &adapter_client_rx.clone().await?,
971        user,
972        peer_addr,
973        active_connection_counter.clone(),
974        helm_chart_version.clone(),
975        |_session| (),
976        options,
977        SYSTEM_TIME.clone(),
978    )
979    .await?;
980
981    Ok(client)
982}
983
/// Credentials presented by a client, before they have been validated.
///
/// Built from an HTTP `Authorization` header or from the initial WebSocket
/// auth message, and consumed by [`auth`].
enum Credentials {
    /// A username/password pair (HTTP basic auth or the WebSocket basic
    /// variant).
    Password {
        /// The user name supplied by the client.
        username: String,
        /// The password supplied by the client.
        password: Password,
    },
    /// A bearer token (HTTP bearer auth or the WebSocket bearer variant).
    Token {
        /// The raw token string supplied by the client.
        token: String,
    },
}
993
994async fn auth(
995    authenticator: &Authenticator,
996    creds: Option<Credentials>,
997    allowed_roles: AllowedRoles,
998    include_www_authenticate_header: bool,
999) -> Result<AuthedUser, AuthError> {
1000    let (name, external_metadata_rx, authenticated) = match authenticator {
1001        Authenticator::Frontegg(frontegg) => match creds {
1002            Some(Credentials::Password { username, password }) => {
1003                let (auth_session, authenticated) =
1004                    frontegg.authenticate(&username, &password.0).await?;
1005                let name = auth_session.user().into();
1006                let external_metadata_rx = Some(auth_session.external_metadata_rx());
1007                (name, external_metadata_rx, authenticated)
1008            }
1009            Some(Credentials::Token { token }) => {
1010                let (claims, authenticated) = frontegg.validate_access_token(&token, None)?;
1011                let (_, external_metadata_rx) = watch::channel(ExternalUserMetadata {
1012                    user_id: claims.user_id,
1013                    admin: claims.is_admin,
1014                });
1015                (claims.user, Some(external_metadata_rx), authenticated)
1016            }
1017            None => {
1018                return Err(AuthError::MissingHttpAuthentication {
1019                    include_www_authenticate_header,
1020                });
1021            }
1022        },
1023        Authenticator::Password(adapter_client) => match creds {
1024            Some(Credentials::Password { username, password }) => {
1025                let authenticated = adapter_client
1026                    .authenticate(&username, &password)
1027                    .await
1028                    .map_err(|_| AuthError::InvalidCredentials)?;
1029                (username, None, authenticated)
1030            }
1031            _ => {
1032                return Err(AuthError::MissingHttpAuthentication {
1033                    include_www_authenticate_header,
1034                });
1035            }
1036        },
1037        Authenticator::Sasl(_) => {
1038            // We shouldn't ever end up here as the configuration is validated at startup.
1039            // If we do, it's a server misconfiguration.
1040            // Just in case, we return a 401 rather than panic.
1041            return Err(AuthError::MissingHttpAuthentication {
1042                include_www_authenticate_header,
1043            });
1044        }
1045        // TODO (Oidc): Implement password auth flow
1046        // for this authenticator variant.
1047        Authenticator::Oidc(oidc) => match creds {
1048            Some(Credentials::Token { token }) => {
1049                // Validate JWT token
1050                let (claims, authenticated) = oidc
1051                    .authenticate(&token, None)
1052                    .await
1053                    .map_err(|_| AuthError::InvalidCredentials)?;
1054                let name = claims.username().to_string();
1055                (name, None, authenticated)
1056            }
1057            Some(Credentials::Password { username, password }) => {
1058                // Allow JWT to be passed as password
1059                let (claims, authenticated) = oidc
1060                    .authenticate(&password.0, Some(&username))
1061                    .await
1062                    .map_err(|_| AuthError::InvalidCredentials)?;
1063                let name = claims.username().to_string();
1064                (name, None, authenticated)
1065            }
1066            None => {
1067                return Err(AuthError::MissingHttpAuthentication {
1068                    include_www_authenticate_header,
1069                });
1070            }
1071        },
1072        Authenticator::None => {
1073            // If no authentication, use whatever is in the HTTP auth
1074            // header (without checking the password), or fall back to the
1075            // default user.
1076            let name = match creds {
1077                Some(Credentials::Password { username, .. }) => username,
1078                _ => HTTP_DEFAULT_USER.name.to_owned(),
1079            };
1080            (name, None, Authenticated)
1081        }
1082    };
1083
1084    check_role_allowed(&name, allowed_roles)?;
1085
1086    Ok(AuthedUser {
1087        name,
1088        external_metadata_rx,
1089        authenticated,
1090    })
1091}
1092
1093// TODO move this somewhere it can be shared with PGWIRE
1094fn check_role_allowed(name: &str, allowed_roles: AllowedRoles) -> Result<(), AuthError> {
1095    let is_internal_user = INTERNAL_USER_NAMES.contains(name);
1096    // this is a superset of internal users
1097    let is_reserved_user = mz_adapter::catalog::is_reserved_role_name(name);
1098    let role_allowed = match allowed_roles {
1099        AllowedRoles::Normal => !is_reserved_user,
1100        AllowedRoles::Internal => is_internal_user,
1101        AllowedRoles::NormalAndInternal => !is_reserved_user || is_internal_user,
1102    };
1103    if role_allowed {
1104        Ok(())
1105    } else {
1106        Err(AuthError::RoleDisallowed(name.to_owned()))
1107    }
1108}
1109
/// Default layers that should be applied to all routes, and should get applied to both the
/// internal http and external http routers.
trait DefaultLayers {
    /// Returns `self` with the default layers applied.
    ///
    /// In the [`Router`] implementation, `source` and `metrics` are handed to
    /// the Prometheus metrics layer.
    fn apply_default_layers(self, source: &'static str, metrics: Metrics) -> Self;
}
1115
1116impl DefaultLayers for Router {
1117    fn apply_default_layers(self, source: &'static str, metrics: Metrics) -> Self {
1118        self.layer(DefaultBodyLimit::max(MAX_REQUEST_SIZE))
1119            .layer(metrics::PrometheusLayer::new(source, metrics))
1120    }
1121}
1122
1123/// Glue code to make [`tower`] work with [`axum`].
1124///
1125/// `axum` requires `Layer`s not return Errors, i.e. they must be `Result<_, Infallible>`,
1126/// instead you must return a type that can be converted into a response. `tower` on the other
1127/// hand does return Errors, so to make the two work together we need to convert our `tower` errors
1128/// into responses.
1129async fn handle_load_error(error: tower::BoxError) -> impl IntoResponse {
1130    if error.is::<tower::load_shed::error::Overloaded>() {
1131        return (
1132            StatusCode::TOO_MANY_REQUESTS,
1133            Cow::from("too many requests, try again later"),
1134        );
1135    }
1136
1137    // Note: This should be unreachable because at the time of writing our only use case is a
1138    // layer that emits `tower::load_shed::error::Overloaded`, which is handled above.
1139    (
1140        StatusCode::INTERNAL_SERVER_ERROR,
1141        Cow::from(format!("Unhandled internal error: {}", error)),
1142    )
1143}
1144
/// JSON request body accepted by [`handle_login`].
#[derive(Debug, Deserialize, Serialize, PartialEq)]
pub struct LoginCredentials {
    /// Name of the user attempting to log in.
    username: String,
    /// Password to check for `username`.
    password: Password,
}
1150
/// Per-session state stored under the `"data"` session key by
/// [`handle_login`] and read back by the HTTP auth middleware.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TowerSessionData {
    /// Name of the user that logged in.
    username: String,
    /// When the session was created at login.
    created_at: SystemTime,
    /// When the session was last used; the auth middleware refreshes it on
    /// each session-authenticated request and compares it against
    /// `SESSION_DURATION` to expire idle sessions.
    last_activity: SystemTime,
    /// Authentication proof obtained from the adapter client at login.
    authenticated: Authenticated,
}
1158
#[cfg(test)]
mod tests {
    use super::{AllowedRoles, check_role_allowed};

    /// Exercises every role policy against internal, normal, and reserved
    /// role names.
    #[mz_ore::test]
    fn test_check_role_allowed() {
        // (role name, ok under Internal, ok under NormalAndInternal, ok under Normal)
        let cases = [
            // Internal users
            ("mz_system", true, true, false),
            ("mz_support", true, true, false),
            ("mz_analytics", true, true, false),
            // Normal users
            ("materialize", false, true, true),
            ("anonymous_http_user", false, true, true),
            ("alex", false, true, true),
            // Denied by reserved role prefix
            ("external_asdf", false, false, false),
            ("pg_somebody", false, false, false),
            ("mz_unknown", false, false, false),
            // Denied by literal PUBLIC
            ("PUBLIC", false, false, false),
        ];

        for (name, internal, normal_and_internal, normal) in cases {
            let expectations = [
                (AllowedRoles::Internal, internal),
                (AllowedRoles::NormalAndInternal, normal_and_internal),
                (AllowedRoles::Normal, normal),
            ];
            for (policy, expected_ok) in expectations {
                let outcome = check_role_allowed(name, policy);
                if expected_ok {
                    assert!(outcome.is_ok());
                } else {
                    assert!(outcome.is_err());
                }
            }
        }
    }
}