mz_environmentd/
http.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Embedded HTTP server.
11//!
//! environmentd embeds an HTTP server for introspection into the running
//! process. Its exports include Prometheus metrics, heap profiles, and
//! catalog dumps, as well as the SQL-over-HTTP and WebSocket APIs, webhook
//! ingestion, and the leader status/promotion endpoints.
15
16// Axum handlers must use async, but often don't actually use `await`.
17#![allow(clippy::unused_async)]
18
19use std::borrow::Cow;
20use std::collections::BTreeMap;
21use std::fmt::Debug;
22use std::net::{IpAddr, SocketAddr};
23use std::pin::Pin;
24use std::sync::Arc;
25use std::time::{Duration, SystemTime};
26
27use anyhow::Context;
28use async_trait::async_trait;
29use axum::error_handling::HandleErrorLayer;
30use axum::extract::ws::{Message, WebSocket};
31use axum::extract::{ConnectInfo, DefaultBodyLimit, FromRequestParts, Query, Request, State};
32use axum::middleware::{self, Next};
33use axum::response::{IntoResponse, Redirect, Response};
34use axum::{Extension, Json, Router, routing};
35use futures::future::{Shared, TryFutureExt};
36use headers::authorization::{Authorization, Basic, Bearer};
37use headers::{HeaderMapExt, HeaderName};
38use http::header::{AUTHORIZATION, CONTENT_TYPE};
39use http::uri::Scheme;
40use http::{HeaderMap, HeaderValue, Method, StatusCode, Uri};
41use hyper_openssl::SslStream;
42use hyper_openssl::client::legacy::MaybeHttpsStream;
43use hyper_util::rt::TokioIo;
44use mz_adapter::session::{Session as AdapterSession, SessionConfig as AdapterSessionConfig};
45use mz_adapter::{AdapterError, AdapterNotice, Client, SessionClient, WebhookAppenderCache};
46use mz_auth::password::Password;
47use mz_authenticator::Authenticator;
48use mz_frontegg_auth::Error as FronteggError;
49use mz_http_util::DynamicFilterTarget;
50use mz_ore::cast::u64_to_usize;
51use mz_ore::metrics::MetricsRegistry;
52use mz_ore::now::{NowFn, SYSTEM_TIME, epoch_to_uuid_v7};
53use mz_ore::str::StrExt;
54use mz_pgwire_common::{ConnectionCounter, ConnectionHandle};
55use mz_repr::user::{ExternalUserMetadata, InternalUserMetadata};
56use mz_server_core::listeners::{AllowedRoles, AuthenticatorKind, HttpRoutesEnabled};
57use mz_server_core::{Connection, ConnectionHandler, ReloadingSslContext, Server};
58use mz_sql::session::metadata::SessionMetadata;
59use mz_sql::session::user::{
60    HTTP_DEFAULT_USER, INTERNAL_USER_NAMES, SUPPORT_USER_NAME, SYSTEM_USER_NAME,
61};
62use mz_sql::session::vars::{Value, Var, VarInput, WELCOME_MESSAGE};
63use openssl::ssl::Ssl;
64use prometheus::{
65    COMPUTE_METRIC_QUERIES, FRONTIER_METRIC_QUERIES, STORAGE_METRIC_QUERIES, USAGE_METRIC_QUERIES,
66};
67use serde::{Deserialize, Serialize};
68use serde_json::json;
69use thiserror::Error;
70use tokio::io::AsyncWriteExt;
71use tokio::sync::oneshot::Receiver;
72use tokio::sync::{oneshot, watch};
73use tokio_metrics::TaskMetrics;
74use tower::limit::GlobalConcurrencyLimitLayer;
75use tower::{Service, ServiceBuilder};
76use tower_http::cors::{AllowOrigin, Any, CorsLayer};
77use tower_sessions::{
78    MemoryStore as TowerSessionMemoryStore, Session as TowerSession,
79    SessionManagerLayer as TowerSessionManagerLayer,
80};
81use tracing::{error, warn};
82
83use crate::BUILD_INFO;
84use crate::deployment::state::DeploymentStateHandle;
85use crate::http::sql::SqlError;
86
87mod catalog;
88mod console;
89mod memory;
90mod metrics;
91mod probe;
92mod prometheus;
93mod root;
94mod sql;
95mod webhook;
96
97pub use metrics::Metrics;
98pub use sql::{SqlResponse, WebSocketAuth, WebSocketResponse};
99
100/// Maximum allowed size for a request.
101pub const MAX_REQUEST_SIZE: usize = u64_to_usize(5 * bytesize::MIB);
102
/// How long an HTTP login session may go without activity before it expires
/// (eight hours).
const SESSION_DURATION: Duration = Duration::from_secs(8 * 60 * 60);

/// Paths of the memory and profiling endpoints mounted by `HttpServer::new`.
const PROFILING_API_ENDPOINTS: &[&str] = &["/memory", "/hierarchical-memory", "/prof/"];
106
/// Configuration consumed by [`HttpServer::new`].
#[derive(Debug)]
pub struct HttpConfig {
    /// Passed to `apply_default_layers` together with `metrics`.
    /// NOTE(review): presumably a label identifying this listener; confirm.
    pub source: &'static str,
    /// TLS context. When `Some`, connections are TLS-terminated and the
    /// session cookie is marked `Secure`.
    pub tls: Option<ReloadingSslContext>,
    /// Selects which extra auth layers are installed (e.g. cookie sessions
    /// for `Password`, `x-materialize-user` header auth for `None`).
    pub authenticator_kind: AuthenticatorKind,
    /// Authenticator, delivered once it is available; awaited per request by
    /// the auth middleware.
    pub authenticator_rx: Shared<Receiver<Arc<Authenticator>>>,
    /// Adapter client, delivered once it is available.
    pub adapter_client_rx: Shared<Receiver<Client>>,
    /// CORS allowed origin for the base routes.
    pub allowed_origin: AllowOrigin,
    /// Counter that sessions opened over HTTP allocate connections from.
    pub active_connection_counter: ConnectionCounter,
    /// Helm chart version handed to WebSocket sessions.
    pub helm_chart_version: Option<String>,
    /// Bounds the number of in-flight webhook requests; requests beyond the
    /// limit are load-shed.
    pub concurrent_webhook_req: Arc<tokio::sync::Semaphore>,
    /// Passed to `apply_default_layers` together with `source`.
    pub metrics: Metrics,
    /// Registry served at `/metrics`.
    pub metrics_registry: MetricsRegistry,
    /// Roles allowed to authenticate on this listener.
    pub allowed_roles: AllowedRoles,
    /// Settings for the internal route group.
    pub internal_route_config: Arc<InternalRouteConfig>,
    /// Which route groups to mount.
    pub routes_enabled: HttpRoutesEnabled,
}
124
/// Settings used only by the internal route group.
#[derive(Debug, Clone)]
pub struct InternalRouteConfig {
    /// Deployment state used as router state by the `/api/leader/*` routes.
    pub deployment_state_handle: DeploymentStateHandle,
    /// URL passed to the `/internal-console` proxy configuration.
    pub internal_console_redirect_url: Option<String>,
}
130
/// Router state for the WebSocket SQL endpoint (`/api/experimental/sql`).
#[derive(Clone)]
pub struct WsState {
    /// Authenticator, delivered once it is available.
    authenticator_rx: Delayed<Arc<Authenticator>>,
    /// Adapter client, delivered once it is available.
    adapter_client_rx: Delayed<mz_adapter::Client>,
    /// Counter that WebSocket sessions allocate connections from.
    active_connection_counter: ConnectionCounter,
    /// Helm chart version reported for WebSocket sessions.
    helm_chart_version: Option<String>,
    /// Roles allowed to authenticate on this listener.
    allowed_roles: AllowedRoles,
}
139
/// Router state for the webhook endpoint
/// (`/api/webhook/:database/:schema/:id`).
#[derive(Clone)]
pub struct WebhookState {
    /// Adapter client, delivered once it is available.
    adapter_client_rx: Delayed<mz_adapter::Client>,
    /// Cache of webhook appenders, created once per server in
    /// `HttpServer::new` and shared by all webhook requests.
    webhook_cache: WebhookAppenderCache,
}
145
/// The embedded HTTP server: the fully assembled router plus optional TLS.
#[derive(Debug)]
pub struct HttpServer {
    /// When `Some`, each connection is TLS-terminated in `handle_connection`
    /// before being served.
    tls: Option<ReloadingSslContext>,
    /// All enabled routes with their middleware already applied.
    router: Router,
}
151
impl HttpServer {
    /// Assembles the server's router from the given [`HttpConfig`].
    ///
    /// Each route group (`base`, `profiling`, `webhook`, `internal`,
    /// `metrics`) is mounted only if enabled in `routes_enabled`, and
    /// `authenticator_kind` decides which extra auth layers are installed:
    /// `Password` adds cookie sessions plus `/api/login` and `/api/logout`;
    /// `None` lets the base and WebSocket routes authenticate via the
    /// `x-materialize-user` header.
    ///
    /// Note: axum's `Router::layer` only wraps routes that are already on the
    /// router when it is called, and each call wraps everything added before
    /// it (so the last layer added sees the request first). The order of the
    /// `.route`, `.merge`, and `.layer` calls below is therefore significant.
    pub fn new(
        HttpConfig {
            source,
            tls,
            authenticator_kind,
            authenticator_rx,
            adapter_client_rx,
            allowed_origin,
            active_connection_counter,
            helm_chart_version,
            concurrent_webhook_req,
            metrics,
            metrics_registry,
            allowed_roles,
            internal_route_config,
            routes_enabled,
        }: HttpConfig,
    ) -> HttpServer {
        let tls_enabled = tls.is_some();
        let webhook_cache = WebhookAppenderCache::new();

        // Cookie-based login sessions, kept in an in-process memory store.
        // Only installed below when `authenticator_kind` is `Password`.
        let session_store = TowerSessionMemoryStore::default();
        let session_layer = TowerSessionManagerLayer::new(session_store)
            .with_secure(tls_enabled) // HTTPS-only cookie, when TLS is enabled
            .with_same_site(tower_sessions::cookie::SameSite::Strict) // CSRF mitigation
            .with_http_only(true) // Hidden from JavaScript
            .with_name("mz_session") // Custom cookie name
            .with_path("/"); // Set cookie path

        // The authenticator arrives through a shared one-shot channel that may
        // not have fired when the router is built, so each request awaits it.
        let auth_middleware_authenticator_rx = authenticator_rx.clone();
        let auth_middleware = middleware::from_fn(move |req, next| {
            let authenticator_rx = auth_middleware_authenticator_rx.clone();
            async move {
                let authenticator = authenticator_rx
                    .await
                    .expect("sender not dropped before sending once");
                http_auth(req, next, tls_enabled, authenticator, allowed_roles).await
            }
        });

        // `router` collects groups that carry their own layers; `base_router`
        // collects routes that share the auth/extension/CORS layers added
        // after all groups are registered.
        let mut router = Router::new();
        let mut base_router = Router::new();
        if routes_enabled.base {
            base_router = base_router
                .route(
                    "/",
                    routing::get(move || async move {
                        root::handle_home(routes_enabled.profiling).await
                    }),
                )
                .route("/api/sql", routing::post(sql::handle_sql))
                .route("/memory", routing::get(memory::handle_memory))
                .route(
                    "/hierarchical-memory",
                    routing::get(memory::handle_hierarchical_memory),
                )
                .route("/static/*path", routing::get(root::handle_static));

            // The WebSocket endpoint is not behind `auth_middleware`; it gets
            // the authenticator through its own state instead.
            let mut ws_router = Router::new()
                .route("/api/experimental/sql", routing::get(sql::handle_sql_ws))
                .with_state(WsState {
                    authenticator_rx: authenticator_rx.clone(),
                    adapter_client_rx: adapter_client_rx.clone(),
                    active_connection_counter: active_connection_counter.clone(),
                    helm_chart_version,
                    allowed_roles,
                });
            if let AuthenticatorKind::None = authenticator_kind {
                ws_router = ws_router.layer(middleware::from_fn(x_materialize_user_header_auth));
            }
            router = router.merge(ws_router);
        }
        if routes_enabled.profiling {
            base_router = base_router.nest("/prof/", mz_prof_http::router(&BUILD_INFO));
        }

        // Webhook requests are decompressed, accept CORS POSTs from any
        // origin, and are load-shed once `concurrent_webhook_req` has no
        // permits left. They are not behind `auth_middleware`.
        if routes_enabled.webhook {
            let webhook_router = Router::new()
                .route(
                    "/api/webhook/:database/:schema/:id",
                    routing::post(webhook::handle_webhook),
                )
                .with_state(WebhookState {
                    adapter_client_rx: adapter_client_rx.clone(),
                    webhook_cache,
                })
                .layer(
                    tower_http::decompression::RequestDecompressionLayer::new()
                        .gzip(true)
                        .deflate(true)
                        .br(true)
                        .zstd(true),
                )
                .layer(
                    CorsLayer::new()
                        .allow_methods(Method::POST)
                        .allow_origin(AllowOrigin::mirror_request())
                        .allow_headers(Any),
                )
                .layer(
                    ServiceBuilder::new()
                        .layer(HandleErrorLayer::new(handle_load_error))
                        .load_shed()
                        .layer(GlobalConcurrencyLimitLayer::with_semaphore(
                            concurrent_webhook_req,
                        )),
                );
            router = router.merge(webhook_router);
        }

        if routes_enabled.internal {
            let console_config = Arc::new(console::ConsoleProxyConfig::new(
                internal_route_config.internal_console_redirect_url.clone(),
                "/internal-console".to_string(),
            ));
            // The two `*/config` routes are retired and always answer 400,
            // pointing callers at the replacement system variables.
            base_router = base_router
                .route(
                    "/api/opentelemetry/config",
                    routing::put({
                        move |_: axum::Json<DynamicFilterTarget>| async {
                            (
                                StatusCode::BAD_REQUEST,
                                "This endpoint has been replaced. \
                            Use the `opentelemetry_filter` system variable."
                                    .to_string(),
                            )
                        }
                    }),
                )
                .route(
                    "/api/stderr/config",
                    routing::put({
                        move |_: axum::Json<DynamicFilterTarget>| async {
                            (
                                StatusCode::BAD_REQUEST,
                                "This endpoint has been replaced. \
                            Use the `log_filter` system variable."
                                    .to_string(),
                            )
                        }
                    }),
                )
                .route("/api/tracing", routing::get(mz_http_util::handle_tracing))
                .route(
                    "/api/catalog/dump",
                    routing::get(catalog::handle_catalog_dump),
                )
                .route(
                    "/api/catalog/check",
                    routing::get(catalog::handle_catalog_check),
                )
                .route(
                    "/api/coordinator/check",
                    routing::get(catalog::handle_coordinator_check),
                )
                .route(
                    "/api/coordinator/dump",
                    routing::get(catalog::handle_coordinator_dump),
                )
                .route(
                    "/internal-console",
                    routing::get(|| async { Redirect::temporary("/internal-console/") }),
                )
                .route(
                    "/internal-console/*path",
                    routing::get(console::handle_internal_console),
                )
                .route(
                    "/internal-console/",
                    routing::get(console::handle_internal_console),
                )
                .layer(Extension(console_config));
            // Leader routes use the deployment state as router state, so
            // they live in their own router with their own auth layer.
            let leader_router = Router::new()
                .route("/api/leader/status", routing::get(handle_leader_status))
                .route("/api/leader/promote", routing::post(handle_leader_promote))
                .route(
                    "/api/leader/skip-catchup",
                    routing::post(handle_leader_skip_catchup),
                )
                .layer(auth_middleware.clone())
                .with_state(internal_route_config.deployment_state_handle.clone());
            router = router.merge(leader_router);
        }

        // The `/metrics/mz_*` routes open a SQL session per request through
        // the `AuthedClient` extractor, which reads the extensions added here.
        if routes_enabled.metrics {
            let metrics_router = Router::new()
                .route(
                    "/metrics",
                    routing::get(move || async move {
                        mz_http_util::handle_prometheus(&metrics_registry).await
                    }),
                )
                .route(
                    "/metrics/mz_usage",
                    routing::get(|client: AuthedClient| async move {
                        let registry = sql::handle_promsql(client, USAGE_METRIC_QUERIES).await;
                        mz_http_util::handle_prometheus(&registry).await
                    }),
                )
                .route(
                    "/metrics/mz_frontier",
                    routing::get(|client: AuthedClient| async move {
                        let registry = sql::handle_promsql(client, FRONTIER_METRIC_QUERIES).await;
                        mz_http_util::handle_prometheus(&registry).await
                    }),
                )
                .route(
                    "/metrics/mz_compute",
                    routing::get(|client: AuthedClient| async move {
                        let registry = sql::handle_promsql(client, COMPUTE_METRIC_QUERIES).await;
                        mz_http_util::handle_prometheus(&registry).await
                    }),
                )
                .route(
                    "/metrics/mz_storage",
                    routing::get(|client: AuthedClient| async move {
                        let registry = sql::handle_promsql(client, STORAGE_METRIC_QUERIES).await;
                        mz_http_util::handle_prometheus(&registry).await
                    }),
                )
                .route(
                    "/api/livez",
                    routing::get(mz_http_util::handle_liveness_check),
                )
                .route("/api/readyz", routing::get(probe::handle_ready))
                .layer(auth_middleware.clone())
                .layer(Extension(adapter_client_rx.clone()))
                .layer(Extension(active_connection_counter.clone()));
            router = router.merge(metrics_router);
        }

        // Shared layers for every route registered on `base_router` so far.
        // CORS is added last, so it is the outermost of these layers.
        base_router = base_router
            .layer(auth_middleware.clone())
            .layer(Extension(adapter_client_rx.clone()))
            .layer(Extension(active_connection_counter.clone()))
            .layer(
                CorsLayer::new()
                    .allow_credentials(false)
                    .allow_headers([
                        AUTHORIZATION,
                        CONTENT_TYPE,
                        HeaderName::from_static("x-materialize-version"),
                    ])
                    .allow_methods(Any)
                    .allow_origin(allowed_origin)
                    .expose_headers(Any)
                    .max_age(Duration::from_secs(60) * 60),
            );

        match authenticator_kind {
            AuthenticatorKind::Password => {
                // The session layer must sit outside `auth_middleware`, since
                // `http_auth` looks for the session in the request extensions.
                base_router = base_router.layer(session_layer.clone());

                let login_router = Router::new()
                    .route("/api/login", routing::post(handle_login))
                    .route("/api/logout", routing::post(handle_logout))
                    .layer(Extension(adapter_client_rx));
                // Wraps every group merged into `router` so far, plus login.
                router = router.merge(login_router).layer(session_layer);
            }
            AuthenticatorKind::None => {
                // Added outermost, so a header-supplied user is already in the
                // extensions when `http_auth` runs.
                base_router =
                    base_router.layer(middleware::from_fn(x_materialize_user_header_auth));
            }
            _ => {}
        }

        router = router
            .merge(base_router)
            .apply_default_layers(source, metrics);

        HttpServer { tls, router }
    }
}
427
428impl Server for HttpServer {
429    const NAME: &'static str = "http";
430
431    fn handle_connection(
432        &self,
433        conn: Connection,
434        _tokio_metrics_intervals: impl Iterator<Item = TaskMetrics> + Send + 'static,
435    ) -> ConnectionHandler {
436        let router = self.router.clone();
437        let tls_context = self.tls.clone();
438        let mut conn = TokioIo::new(conn);
439
440        Box::pin(async {
441            let direct_peer_addr = conn.inner().peer_addr().context("fetching peer addr")?;
442            let peer_addr = conn
443                .inner_mut()
444                .take_proxy_header_address()
445                .await
446                .map(|a| a.source)
447                .unwrap_or(direct_peer_addr);
448
449            let (conn, conn_protocol) = match tls_context {
450                Some(tls_context) => {
451                    let mut ssl_stream = SslStream::new(Ssl::new(&tls_context.get())?, conn)?;
452                    if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
453                        let _ = ssl_stream.get_mut().inner_mut().shutdown().await;
454                        return Err(e.into());
455                    }
456                    (MaybeHttpsStream::Https(ssl_stream), ConnProtocol::Https)
457                }
458                _ => (MaybeHttpsStream::Http(conn), ConnProtocol::Http),
459            };
460            let mut make_tower_svc = router
461                .layer(Extension(conn_protocol))
462                .into_make_service_with_connect_info::<SocketAddr>();
463            let tower_svc = make_tower_svc.call(peer_addr).await.unwrap();
464            let hyper_svc = hyper::service::service_fn(|req| tower_svc.clone().call(req));
465            let http = hyper::server::conn::http1::Builder::new();
466            http.serve_connection(conn, hyper_svc)
467                .with_upgrades()
468                .err_into()
469                .await
470        })
471    }
472}
473
474pub async fn handle_leader_status(
475    State(deployment_state_handle): State<DeploymentStateHandle>,
476) -> impl IntoResponse {
477    let status = deployment_state_handle.status();
478    (StatusCode::OK, Json(json!({ "status": status })))
479}
480
481pub async fn handle_leader_promote(
482    State(deployment_state_handle): State<DeploymentStateHandle>,
483) -> impl IntoResponse {
484    match deployment_state_handle.try_promote() {
485        Ok(()) => {
486            // TODO(benesch): the body here is redundant. Should just return
487            // 204.
488            let status = StatusCode::OK;
489            let body = Json(json!({
490                "result": "Success",
491            }));
492            (status, body)
493        }
494        Err(()) => {
495            // TODO(benesch): the nesting here is redundant given the error
496            // code. Should just return the `{"message": "..."}` object.
497            let status = StatusCode::BAD_REQUEST;
498            let body = Json(json!({
499                "result": {"Failure": {"message": "cannot promote leader while initializing"}},
500            }));
501            (status, body)
502        }
503    }
504}
505
506pub async fn handle_leader_skip_catchup(
507    State(deployment_state_handle): State<DeploymentStateHandle>,
508) -> impl IntoResponse {
509    match deployment_state_handle.try_skip_catchup() {
510        Ok(()) => StatusCode::NO_CONTENT.into_response(),
511        Err(()) => {
512            let status = StatusCode::BAD_REQUEST;
513            let body = Json(json!({
514                "message": "cannot skip catchup in this phase of initialization; try again later",
515            }));
516            (status, body).into_response()
517        }
518    }
519}
520
521async fn x_materialize_user_header_auth(mut req: Request, next: Next) -> impl IntoResponse {
522    // TODO migrate teleport to basic auth and remove this.
523    if let Some(username) = req.headers().get("x-materialize-user").map(|h| h.to_str()) {
524        let username = match username {
525            Ok(name @ (SUPPORT_USER_NAME | SYSTEM_USER_NAME)) => name.to_string(),
526            _ => {
527                return Err(AuthError::MismatchedUser(format!(
528                    "user specified in x-materialize-user must be {SUPPORT_USER_NAME} or {SYSTEM_USER_NAME}"
529                )));
530            }
531        };
532        let superuser = matches!(username.as_str(), SYSTEM_USER_NAME);
533        req.extensions_mut().insert(AuthedUser {
534            name: username,
535            external_metadata_rx: None,
536            internal_metadata: Some(InternalUserMetadata { superuser }),
537        });
538    }
539    Ok(next.run(req).await)
540}
541
/// A value delivered once through a one-shot channel. `Shared` lets any
/// number of clones await the same value.
type Delayed<T> = Shared<oneshot::Receiver<T>>;
543
/// Transport a request arrived over. `handle_connection` attaches it to every
/// request as an extension.
#[derive(Clone)]
enum ConnProtocol {
    /// Plain-text HTTP.
    Http,
    /// HTTP over a completed TLS handshake.
    Https,
}
549
/// Identity of an authenticated HTTP caller.
///
/// Auth middleware inserts this into the request extensions; the
/// [`AuthedClient`] extractor reads it to open an adapter session.
#[derive(Clone, Debug)]
pub struct AuthedUser {
    /// Name of the user the adapter session is opened as.
    name: String,
    /// Receiver for externally supplied user metadata, when the auth method
    /// provides one.
    external_metadata_rx: Option<watch::Receiver<ExternalUserMetadata>>,
    /// Internal metadata (e.g. superuser status), when the auth method
    /// supplies it.
    internal_metadata: Option<InternalUserMetadata>,
}
556
/// An adapter session opened on behalf of an authenticated HTTP user.
pub struct AuthedClient {
    /// Client for the started-up adapter session.
    pub client: SessionClient,
    /// Handle returned when allocating from the active connection counter,
    /// if one was issued.
    /// NOTE(review): presumably releases the connection slot on drop; confirm.
    pub connection_guard: Option<ConnectionHandle>,
}
561
562impl AuthedClient {
563    async fn new<F>(
564        adapter_client: &Client,
565        user: AuthedUser,
566        peer_addr: IpAddr,
567        active_connection_counter: ConnectionCounter,
568        helm_chart_version: Option<String>,
569        session_config: F,
570        options: BTreeMap<String, String>,
571        now: NowFn,
572    ) -> Result<Self, AdapterError>
573    where
574        F: FnOnce(&mut AdapterSession),
575    {
576        let conn_id = adapter_client.new_conn_id()?;
577        let mut session = adapter_client.new_session(AdapterSessionConfig {
578            conn_id,
579            uuid: epoch_to_uuid_v7(&(now)()),
580            user: user.name,
581            client_ip: Some(peer_addr),
582            external_metadata_rx: user.external_metadata_rx,
583            internal_user_metadata: user.internal_metadata,
584            helm_chart_version,
585        });
586        let connection_guard = active_connection_counter.allocate_connection(session.user())?;
587
588        session_config(&mut session);
589        let system_vars = adapter_client.get_system_vars().await;
590        for (key, val) in options {
591            const LOCAL: bool = false;
592            if let Err(err) =
593                session
594                    .vars_mut()
595                    .set(&system_vars, &key, VarInput::Flat(&val), LOCAL)
596            {
597                session.add_notice(AdapterNotice::BadStartupSetting {
598                    name: key.to_string(),
599                    reason: err.to_string(),
600                })
601            }
602        }
603        let adapter_client = adapter_client.startup(session).await?;
604        Ok(AuthedClient {
605            client: adapter_client,
606            connection_guard,
607        })
608    }
609}
610
611#[async_trait]
612impl<S> FromRequestParts<S> for AuthedClient
613where
614    S: Send + Sync,
615{
616    type Rejection = Response;
617
618    async fn from_request_parts(
619        req: &mut http::request::Parts,
620        state: &S,
621    ) -> Result<Self, Self::Rejection> {
622        #[derive(Debug, Default, Deserialize)]
623        struct Params {
624            #[serde(default)]
625            options: String,
626        }
627        let params: Query<Params> = Query::from_request_parts(req, state)
628            .await
629            .unwrap_or_default();
630
631        let peer_addr = req
632            .extensions
633            .get::<ConnectInfo<SocketAddr>>()
634            .expect("ConnectInfo extension guaranteed to exist")
635            .0
636            .ip();
637
638        let user = req.extensions.get::<AuthedUser>().unwrap();
639        let adapter_client = req
640            .extensions
641            .get::<Delayed<mz_adapter::Client>>()
642            .unwrap()
643            .clone();
644        let adapter_client = adapter_client.await.map_err(|_| {
645            (StatusCode::INTERNAL_SERVER_ERROR, "adapter client missing").into_response()
646        })?;
647        let active_connection_counter = req.extensions.get::<ConnectionCounter>().unwrap();
648        let helm_chart_version = None;
649
650        let options = if params.options.is_empty() {
651            // It's possible 'options' simply wasn't provided, we don't want that to
652            // count as a failure to deserialize
653            BTreeMap::<String, String>::default()
654        } else {
655            match serde_json::from_str(&params.options) {
656                Ok(options) => options,
657                Err(_e) => {
658                    // If we fail to deserialize options, fail the request.
659                    let code = StatusCode::BAD_REQUEST;
660                    let msg = format!("Failed to deserialize {} map", "options".quoted());
661                    return Err((code, msg).into_response());
662                }
663            }
664        };
665
666        let client = AuthedClient::new(
667            &adapter_client,
668            user.clone(),
669            peer_addr,
670            active_connection_counter.clone(),
671            helm_chart_version,
672            |session| {
673                session
674                    .vars_mut()
675                    .set_default(WELCOME_MESSAGE.name(), VarInput::Flat(&false.format()))
676                    .expect("known to exist")
677            },
678            options,
679            SYSTEM_TIME.clone(),
680        )
681        .await
682        .map_err(|e| {
683            let status = match e {
684                AdapterError::UserSessionsDisallowed | AdapterError::NetworkPolicyDenied(_) => {
685                    StatusCode::FORBIDDEN
686                }
687                _ => StatusCode::INTERNAL_SERVER_ERROR,
688            };
689            (status, Json(SqlError::from(e))).into_response()
690        })?;
691
692        Ok(client)
693    }
694}
695
696#[derive(Debug, Error)]
697enum AuthError {
698    #[error("role dissallowed")]
699    RoleDisallowed(String),
700    #[error("{0}")]
701    Frontegg(#[from] FronteggError),
702    #[error("missing authorization header")]
703    MissingHttpAuthentication {
704        include_www_authenticate_header: bool,
705    },
706    #[error("{0}")]
707    MismatchedUser(String),
708    #[error("session expired")]
709    SessionExpired,
710    #[error("failed to update session")]
711    FailedToUpdateSession,
712    #[error("invalid credentials")]
713    InvalidCredentials,
714}
715
716impl IntoResponse for AuthError {
717    fn into_response(self) -> Response {
718        warn!("HTTP request failed authentication: {}", self);
719        let mut headers = HeaderMap::new();
720        match self {
721            AuthError::MissingHttpAuthentication {
722                include_www_authenticate_header,
723            } if include_www_authenticate_header => {
724                headers.insert(
725                    http::header::WWW_AUTHENTICATE,
726                    HeaderValue::from_static("Basic realm=Materialize"),
727                );
728            }
729            _ => {}
730        };
731        // We omit most detail from the error message we send to the client, to
732        // avoid giving attackers unnecessary information.
733        (StatusCode::UNAUTHORIZED, headers, "unauthorized").into_response()
734    }
735}
736
737// Simplified login handler
738pub async fn handle_login(
739    session: Option<Extension<TowerSession>>,
740    Extension(adapter_client_rx): Extension<Delayed<Client>>,
741    Json(LoginCredentials { username, password }): Json<LoginCredentials>,
742) -> impl IntoResponse {
743    let Ok(adapter_client) = adapter_client_rx.clone().await else {
744        return StatusCode::INTERNAL_SERVER_ERROR;
745    };
746    let auth_response = match adapter_client.authenticate(&username, &password).await {
747        Ok(auth_response) => auth_response,
748        Err(err) => {
749            warn!(?err, "HTTP login failed authentication");
750            return StatusCode::UNAUTHORIZED;
751        }
752    };
753
754    // Create session data
755    let session_data = TowerSessionData {
756        username,
757        created_at: SystemTime::now(),
758        last_activity: SystemTime::now(),
759        internal_metadata: InternalUserMetadata {
760            superuser: auth_response.superuser,
761        },
762    };
763    // Store session data
764    let session = session.and_then(|Extension(session)| Some(session));
765    let Some(session) = session else {
766        return StatusCode::INTERNAL_SERVER_ERROR;
767    };
768    match session.insert("data", &session_data).await {
769        Err(_) => StatusCode::INTERNAL_SERVER_ERROR,
770        Ok(_) => StatusCode::OK,
771    }
772}
773
774// Simplified logout handler
775pub async fn handle_logout(session: Option<Extension<TowerSession>>) -> impl IntoResponse {
776    let session = session.and_then(|Extension(session)| Some(session));
777    let Some(session) = session else {
778        return StatusCode::INTERNAL_SERVER_ERROR;
779    };
780    // Delete session
781    match session.delete().await {
782        Err(_) => StatusCode::INTERNAL_SERVER_ERROR,
783        Ok(_) => StatusCode::OK,
784    }
785}
786
787async fn http_auth(
788    mut req: Request,
789    next: Next,
790    tls_enabled: bool,
791    authenticator: Arc<Authenticator>,
792    allowed_roles: AllowedRoles,
793) -> impl IntoResponse + use<> {
794    // First check for session authentication
795    if let Some(session) = req.extensions().get::<TowerSession>() {
796        if let Ok(Some(session_data)) = session.get::<TowerSessionData>("data").await {
797            // Check session expiration
798            if session_data
799                .last_activity
800                .elapsed()
801                .unwrap_or(Duration::MAX)
802                > SESSION_DURATION
803            {
804                let _ = session.delete().await;
805                return Err(AuthError::SessionExpired);
806            }
807            // Update last activity
808            let mut updated_data = session_data.clone();
809            updated_data.last_activity = SystemTime::now();
810            session
811                .insert("data", &updated_data)
812                .await
813                .map_err(|_| AuthError::FailedToUpdateSession)?;
814            // User is authenticated via session
815            req.extensions_mut().insert(AuthedUser {
816                name: session_data.username,
817                external_metadata_rx: None,
818                internal_metadata: Some(session_data.internal_metadata),
819            });
820            return Ok(next.run(req).await);
821        }
822    }
823
824    // First, extract the username from the certificate, validating that the
825    // connection matches the TLS configuration along the way.
826    // Fall back to existing authentication methods.
827    let conn_protocol = req.extensions().get::<ConnProtocol>().unwrap();
828    match (tls_enabled, &conn_protocol) {
829        (false, ConnProtocol::Http) => {}
830        (false, ConnProtocol::Https { .. }) => unreachable!(),
831        (true, ConnProtocol::Http) => {
832            let mut parts = req.uri().clone().into_parts();
833            parts.scheme = Some(Scheme::HTTPS);
834            return Ok(Redirect::permanent(
835                &Uri::from_parts(parts)
836                    .expect("it was already a URI, just changed the scheme")
837                    .to_string(),
838            )
839            .into_response());
840        }
841        (true, ConnProtocol::Https { .. }) => {}
842    }
843    // If we've already passed some other auth, just use that.
844    if req.extensions().get::<AuthedUser>().is_some() {
845        return Ok(next.run(req).await);
846    }
847    let creds = if let Some(basic) = req.headers().typed_get::<Authorization<Basic>>() {
848        Some(Credentials::Password {
849            username: basic.username().to_owned(),
850            password: Password(basic.password().to_owned()),
851        })
852    } else if let Some(bearer) = req.headers().typed_get::<Authorization<Bearer>>() {
853        Some(Credentials::Token {
854            token: bearer.token().to_owned(),
855        })
856    } else {
857        None
858    };
859
860    let path = req.uri().path();
861    let include_www_authenticate_header = path == "/"
862        || PROFILING_API_ENDPOINTS
863            .iter()
864            .any(|prefix| path.starts_with(prefix));
865    let user = auth(
866        &authenticator,
867        creds,
868        allowed_roles,
869        include_www_authenticate_header,
870    )
871    .await?;
872
873    // Add the authenticated user as an extension so downstream handlers can
874    // inspect it if necessary.
875    req.extensions_mut().insert(user);
876
877    // Run the request.
878    Ok(next.run(req).await)
879}
880
881async fn init_ws(
882    WsState {
883        authenticator_rx,
884        adapter_client_rx,
885        active_connection_counter,
886        helm_chart_version,
887        allowed_roles,
888    }: &WsState,
889    existing_user: Option<AuthedUser>,
890    peer_addr: IpAddr,
891    ws: &mut WebSocket,
892) -> Result<AuthedClient, anyhow::Error> {
893    let authenticator = authenticator_rx.clone().await.expect("sender not dropped");
894    // TODO: Add a timeout here to prevent resource leaks by clients that
895    // connect then never send a message.
896    let ws_auth: WebSocketAuth = loop {
897        let init_msg = ws.recv().await.ok_or_else(|| anyhow::anyhow!("closed"))??;
898        match init_msg {
899            Message::Text(data) => break serde_json::from_str(&data)?,
900            Message::Binary(data) => break serde_json::from_slice(&data)?,
901            // Handled automatically by the server.
902            Message::Ping(_) => {
903                continue;
904            }
905            Message::Pong(_) => {
906                continue;
907            }
908            Message::Close(_) => {
909                anyhow::bail!("closed");
910            }
911        }
912    };
913
914    let (user, options) = if let Some(existing_user) = existing_user {
915        match ws_auth {
916            WebSocketAuth::OptionsOnly { options } => (existing_user, options),
917            _ => {
918                warn!("Unexpected bearer or basic auth provided when using user header");
919                anyhow::bail!("unexpected")
920            }
921        }
922    } else {
923        let (creds, options) = match ws_auth {
924            WebSocketAuth::Basic {
925                user,
926                password,
927                options,
928            } => {
929                let creds = Credentials::Password {
930                    username: user,
931                    password,
932                };
933                (creds, options)
934            }
935            WebSocketAuth::Bearer { token, options } => {
936                let creds = Credentials::Token { token };
937                (creds, options)
938            }
939            WebSocketAuth::OptionsOnly { .. } => {
940                anyhow::bail!("expected auth information");
941            }
942        };
943        let user = auth(&authenticator, Some(creds), *allowed_roles, false).await?;
944        (user, options)
945    };
946
947    let client = AuthedClient::new(
948        &adapter_client_rx.clone().await?,
949        user,
950        peer_addr,
951        active_connection_counter.clone(),
952        helm_chart_version.clone(),
953        |_session| (),
954        options,
955        SYSTEM_TIME.clone(),
956    )
957    .await?;
958
959    Ok(client)
960}
961
962enum Credentials {
963    Password {
964        username: String,
965        password: Password,
966    },
967    Token {
968        token: String,
969    },
970}
971
972async fn auth(
973    authenticator: &Authenticator,
974    creds: Option<Credentials>,
975    allowed_roles: AllowedRoles,
976    include_www_authenticate_header: bool,
977) -> Result<AuthedUser, AuthError> {
978    let (name, external_metadata_rx, internal_metadata) = match authenticator {
979        Authenticator::Frontegg(frontegg) => match creds {
980            Some(Credentials::Password { username, password }) => {
981                let auth_session = frontegg.authenticate(&username, &password.0).await?;
982                let name = auth_session.user().into();
983                let external_metadata_rx = Some(auth_session.external_metadata_rx());
984                (name, external_metadata_rx, None)
985            }
986            Some(Credentials::Token { token }) => {
987                let claims = frontegg.validate_access_token(&token, None)?;
988                let (_, external_metadata_rx) = watch::channel(ExternalUserMetadata {
989                    user_id: claims.user_id,
990                    admin: claims.is_admin,
991                });
992                (claims.user, Some(external_metadata_rx), None)
993            }
994            None => {
995                return Err(AuthError::MissingHttpAuthentication {
996                    include_www_authenticate_header,
997                });
998            }
999        },
1000        Authenticator::Password(adapter_client) => match creds {
1001            Some(Credentials::Password { username, password }) => {
1002                let auth_response = adapter_client
1003                    .authenticate(&username, &password)
1004                    .await
1005                    .map_err(|_| AuthError::InvalidCredentials)?;
1006                let internal_metadata = InternalUserMetadata {
1007                    superuser: auth_response.superuser,
1008                };
1009                (username, None, Some(internal_metadata))
1010            }
1011            _ => {
1012                return Err(AuthError::MissingHttpAuthentication {
1013                    include_www_authenticate_header,
1014                });
1015            }
1016        },
1017        Authenticator::Sasl(_) => {
1018            // We shouldn't ever end up here as the configuration is validated at startup.
1019            // If we do, it's a server misconfiguration.
1020            // Just in case, we return a 401 rather than panic.
1021            return Err(AuthError::MissingHttpAuthentication {
1022                include_www_authenticate_header,
1023            });
1024        }
1025        Authenticator::None => {
1026            // If no authentication, use whatever is in the HTTP auth
1027            // header (without checking the password), or fall back to the
1028            // default user.
1029            let name = match creds {
1030                Some(Credentials::Password { username, .. }) => username,
1031                _ => HTTP_DEFAULT_USER.name.to_owned(),
1032            };
1033            (name, None, None)
1034        }
1035    };
1036
1037    check_role_allowed(&name, allowed_roles)?;
1038
1039    Ok(AuthedUser {
1040        name,
1041        external_metadata_rx,
1042        internal_metadata,
1043    })
1044}
1045
1046// TODO move this somewhere it can be shared with PGWIRE
1047fn check_role_allowed(name: &str, allowed_roles: AllowedRoles) -> Result<(), AuthError> {
1048    let is_internal_user = INTERNAL_USER_NAMES.contains(name);
1049    // this is a superset of internal users
1050    let is_reserved_user = mz_adapter::catalog::is_reserved_role_name(name);
1051    let role_allowed = match allowed_roles {
1052        AllowedRoles::Normal => !is_reserved_user,
1053        AllowedRoles::Internal => is_internal_user,
1054        AllowedRoles::NormalAndInternal => !is_reserved_user || is_internal_user,
1055    };
1056    if role_allowed {
1057        Ok(())
1058    } else {
1059        Err(AuthError::RoleDisallowed(name.to_owned()))
1060    }
1061}
1062
1063/// Default layers that should be applied to all routes, and should get applied to both the
1064/// internal http and external http routers.
1065trait DefaultLayers {
1066    fn apply_default_layers(self, source: &'static str, metrics: Metrics) -> Self;
1067}
1068
1069impl DefaultLayers for Router {
1070    fn apply_default_layers(self, source: &'static str, metrics: Metrics) -> Self {
1071        self.layer(DefaultBodyLimit::max(MAX_REQUEST_SIZE))
1072            .layer(metrics::PrometheusLayer::new(source, metrics))
1073    }
1074}
1075
1076/// Glue code to make [`tower`] work with [`axum`].
1077///
1078/// `axum` requires `Layer`s not return Errors, i.e. they must be `Result<_, Infallible>`,
1079/// instead you must return a type that can be converted into a response. `tower` on the other
1080/// hand does return Errors, so to make the two work together we need to convert our `tower` errors
1081/// into responses.
1082async fn handle_load_error(error: tower::BoxError) -> impl IntoResponse {
1083    if error.is::<tower::load_shed::error::Overloaded>() {
1084        return (
1085            StatusCode::TOO_MANY_REQUESTS,
1086            Cow::from("too many requests, try again later"),
1087        );
1088    }
1089
1090    // Note: This should be unreachable because at the time of writing our only use case is a
1091    // layer that emits `tower::load_shed::error::Overloaded`, which is handled above.
1092    (
1093        StatusCode::INTERNAL_SERVER_ERROR,
1094        Cow::from(format!("Unhandled internal error: {}", error)),
1095    )
1096}
1097
1098#[derive(Debug, Deserialize, Serialize, PartialEq)]
1099pub struct LoginCredentials {
1100    username: String,
1101    password: Password,
1102}
1103
1104#[derive(Debug, Clone, Serialize, Deserialize)]
1105pub struct TowerSessionData {
1106    username: String,
1107    created_at: SystemTime,
1108    last_activity: SystemTime,
1109    internal_metadata: InternalUserMetadata,
1110}
1111
#[cfg(test)]
mod tests {
    use super::{AllowedRoles, check_role_allowed};

    #[mz_ore::test]
    fn test_check_role_allowed() {
        // Each row is:
        // (role name, allowed for Internal, allowed for NormalAndInternal, allowed for Normal).
        let cases: &[(&str, bool, bool, bool)] = &[
            // Internal users.
            ("mz_system", true, true, false),
            ("mz_support", true, true, false),
            ("mz_analytics", true, true, false),
            // Normal users.
            ("materialize", false, true, true),
            ("anonymous_http_user", false, true, true),
            ("alex", false, true, true),
            // Denied by reserved role prefix.
            ("external_asdf", false, false, false),
            ("pg_somebody", false, false, false),
            ("mz_unknown", false, false, false),
            // Denied by literal PUBLIC.
            ("PUBLIC", false, false, false),
        ];

        for &(name, internal, normal_and_internal, normal) in cases {
            assert_eq!(
                check_role_allowed(name, AllowedRoles::Internal).is_ok(),
                internal,
                "{name}: AllowedRoles::Internal"
            );
            assert_eq!(
                check_role_allowed(name, AllowedRoles::NormalAndInternal).is_ok(),
                normal_and_internal,
                "{name}: AllowedRoles::NormalAndInternal"
            );
            assert_eq!(
                check_role_allowed(name, AllowedRoles::Normal).is_ok(),
                normal,
                "{name}: AllowedRoles::Normal"
            );
        }
    }
}