mz_environmentd/
http.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Embedded HTTP server.
11//!
12//! environmentd embeds an HTTP server for introspection into the running
13//! process. At the moment, its primary exports are Prometheus metrics, heap
14//! profiles, and catalog dumps.
15
16// Axum handlers must use async, but often don't actually use `await`.
17#![allow(clippy::unused_async)]
18
19use std::borrow::Cow;
20use std::collections::BTreeMap;
21use std::fmt::Debug;
22use std::net::{IpAddr, SocketAddr};
23use std::pin::Pin;
24use std::sync::Arc;
25use std::time::{Duration, SystemTime};
26
27use anyhow::Context;
28use async_trait::async_trait;
29use axum::error_handling::HandleErrorLayer;
30use axum::extract::ws::{Message, WebSocket};
31use axum::extract::{ConnectInfo, DefaultBodyLimit, FromRequestParts, Query, Request, State};
32use axum::middleware::{self, Next};
33use axum::response::{IntoResponse, Redirect, Response};
34use axum::{Extension, Json, Router, routing};
35use futures::future::{Shared, TryFutureExt};
36use headers::authorization::{Authorization, Basic, Bearer};
37use headers::{HeaderMapExt, HeaderName};
38use http::header::{AUTHORIZATION, CONTENT_TYPE};
39use http::uri::Scheme;
40use http::{HeaderMap, HeaderValue, Method, StatusCode, Uri};
41use hyper_openssl::SslStream;
42use hyper_openssl::client::legacy::MaybeHttpsStream;
43use hyper_util::rt::TokioIo;
44use mz_adapter::session::{Session as AdapterSession, SessionConfig as AdapterSessionConfig};
45use mz_adapter::{AdapterError, AdapterNotice, Client, SessionClient, WebhookAppenderCache};
46use mz_auth::password::Password;
47use mz_authenticator::Authenticator;
48use mz_frontegg_auth::Error as FronteggError;
49use mz_http_util::DynamicFilterTarget;
50use mz_ore::cast::u64_to_usize;
51use mz_ore::metrics::MetricsRegistry;
52use mz_ore::now::{NowFn, SYSTEM_TIME, epoch_to_uuid_v7};
53use mz_ore::str::StrExt;
54use mz_pgwire_common::{ConnectionCounter, ConnectionHandle};
55use mz_repr::user::ExternalUserMetadata;
56use mz_server_core::listeners::{AllowedRoles, AuthenticatorKind, HttpRoutesEnabled};
57use mz_server_core::{Connection, ConnectionHandler, ReloadingSslContext, Server};
58use mz_sql::session::metadata::SessionMetadata;
59use mz_sql::session::user::{
60    HTTP_DEFAULT_USER, INTERNAL_USER_NAMES, SUPPORT_USER_NAME, SYSTEM_USER_NAME,
61};
62use mz_sql::session::vars::{Value, Var, VarInput, WELCOME_MESSAGE};
63use openssl::ssl::Ssl;
64use prometheus::{
65    COMPUTE_METRIC_QUERIES, FRONTIER_METRIC_QUERIES, STORAGE_METRIC_QUERIES, USAGE_METRIC_QUERIES,
66};
67use serde::{Deserialize, Serialize};
68use serde_json::json;
69use thiserror::Error;
70use tokio::io::AsyncWriteExt;
71use tokio::sync::oneshot::Receiver;
72use tokio::sync::{oneshot, watch};
73use tokio_metrics::TaskMetrics;
74use tower::limit::GlobalConcurrencyLimitLayer;
75use tower::{Service, ServiceBuilder};
76use tower_http::cors::{AllowOrigin, Any, CorsLayer};
77use tower_sessions::{
78    MemoryStore as TowerSessionMemoryStore, Session as TowerSession,
79    SessionManagerLayer as TowerSessionManagerLayer,
80};
81use tracing::{error, warn};
82
83use crate::BUILD_INFO;
84use crate::deployment::state::DeploymentStateHandle;
85use crate::http::sql::SqlError;
86
87mod catalog;
88mod console;
89mod memory;
90mod metrics;
91mod probe;
92mod prometheus;
93mod root;
94mod sql;
95mod webhook;
96
97pub use metrics::Metrics;
98pub use sql::{SqlResponse, WebSocketAuth, WebSocketResponse};
99
/// Maximum allowed size for a request.
///
/// Evaluates to 5 MiB, converted to `usize`.
pub const MAX_REQUEST_SIZE: usize = u64_to_usize(5 * bytesize::MIB);

/// Idle timeout for cookie-based sessions.
///
/// `http_auth` deletes a session and rejects the request as expired when the
/// session's last recorded activity is older than this duration.
const SESSION_DURATION: Duration = Duration::from_secs(3600); // 1 hour

/// Path fragments identifying the profiling-related endpoints.
// NOTE(review): the use site of this constant is not visible in this part of
// the file — confirm how it is matched against request paths.
const PROFILING_API_ENDPOINTS: &[&str] = &["/memory", "/hierarchical-memory", "/prof/"];
106
/// Everything [`HttpServer::new`] needs to assemble the router for one HTTP
/// listener.
#[derive(Debug)]
pub struct HttpConfig {
    /// Label handed to `apply_default_layers` together with `metrics`.
    pub source: &'static str,
    /// TLS context for the listener. When present, accepted connections are
    /// wrapped in TLS and the session cookie is marked `Secure`.
    pub tls: Option<ReloadingSslContext>,
    /// Selects which authentication layers are installed: `Password` adds the
    /// cookie session layer plus the login/logout routes, `None` adds the
    /// `x-materialize-user` header middleware.
    pub authenticator_kind: AuthenticatorKind,
    /// Resolves to the authenticator used by the auth middleware; it is
    /// awaited on every authenticated request.
    pub authenticator_rx: Shared<Receiver<Arc<Authenticator>>>,
    /// Resolves to the adapter client; attached to the routers as state or as
    /// a request extension.
    pub adapter_client_rx: Shared<Receiver<Client>>,
    /// Origin policy for the CORS layer on the base routes.
    pub allowed_origin: AllowOrigin,
    /// Connection counter attached as a request extension and cloned into the
    /// WebSocket state.
    pub active_connection_counter: ConnectionCounter,
    /// Helm chart version, forwarded into the WebSocket state.
    pub helm_chart_version: Option<String>,
    /// Semaphore backing the global concurrency limit on the webhook routes.
    pub concurrent_webhook_req: Arc<tokio::sync::Semaphore>,
    /// HTTP server metrics, handed to `apply_default_layers`.
    pub metrics: Metrics,
    /// Registry served by the `/metrics` route.
    pub metrics_registry: MetricsRegistry,
    /// Role restriction passed to the auth middleware and the WebSocket state.
    pub allowed_roles: AllowedRoles,
    /// Configuration used only when the internal routes are enabled.
    pub internal_route_config: Arc<InternalRouteConfig>,
    /// Switches selecting which route groups (`base`, `profiling`, `webhook`,
    /// `internal`, `metrics`) are mounted.
    pub routes_enabled: HttpRoutesEnabled,
}
124
/// Configuration consumed by the internal route group.
#[derive(Debug, Clone)]
pub struct InternalRouteConfig {
    /// Handle used as the state of the `/api/leader/*` routes.
    pub deployment_state_handle: DeploymentStateHandle,
    /// Optional URL passed to the console proxy configuration that backs the
    /// `/internal-console` routes.
    pub internal_console_redirect_url: Option<String>,
}
130
/// Router state for the WebSocket SQL route (`/api/experimental/sql`).
#[derive(Clone)]
pub struct WsState {
    /// Resolves to the authenticator once it is available.
    authenticator_rx: Delayed<Arc<Authenticator>>,
    /// Resolves to the adapter client once it is available.
    adapter_client_rx: Delayed<mz_adapter::Client>,
    /// Counter of active connections, shared with the other routes.
    active_connection_counter: ConnectionCounter,
    /// Helm chart version taken from the server configuration.
    helm_chart_version: Option<String>,
    /// Role restriction taken from the server configuration.
    allowed_roles: AllowedRoles,
}
139
/// Router state for the webhook route (`/api/webhook/:database/:schema/:id`).
#[derive(Clone)]
pub struct WebhookState {
    /// Resolves to the adapter client once it is available.
    adapter_client_rx: Delayed<mz_adapter::Client>,
    /// Appender cache created once per server in [`HttpServer::new`].
    webhook_cache: WebhookAppenderCache,
}
145
/// An HTTP server: a fully assembled router plus the optional TLS context
/// used when accepting connections.
#[derive(Debug)]
pub struct HttpServer {
    /// TLS context; `None` means connections are served as plain HTTP.
    tls: Option<ReloadingSslContext>,
    /// Router with every enabled route group and its layers applied.
    router: Router,
}
151
impl HttpServer {
    /// Assembles the router for one HTTP listener from `HttpConfig`.
    ///
    /// Route groups are mounted according to `routes_enabled`:
    ///
    /// * `base`: home page, `/api/sql`, memory endpoints, static assets, and
    ///   the WebSocket SQL endpoint.
    /// * `profiling`: the `/prof/` router.
    /// * `webhook`: `/api/webhook/:database/:schema/:id`.
    /// * `internal`: tracing, catalog/coordinator dumps and checks, the
    ///   internal console, and the `/api/leader/*` endpoints.
    /// * `metrics`: Prometheus endpoints and liveness/readiness probes.
    ///
    /// `authenticator_kind` then decides which additional authentication
    /// layers and routes (cookie sessions, login/logout, header auth) are
    /// installed.
    pub fn new(
        HttpConfig {
            source,
            tls,
            authenticator_kind,
            authenticator_rx,
            adapter_client_rx,
            allowed_origin,
            active_connection_counter,
            helm_chart_version,
            concurrent_webhook_req,
            metrics,
            metrics_registry,
            allowed_roles,
            internal_route_config,
            routes_enabled,
        }: HttpConfig,
    ) -> HttpServer {
        let tls_enabled = tls.is_some();
        let webhook_cache = WebhookAppenderCache::new();

        // Create the in-memory session store and the cookie session layer.
        let session_store = TowerSessionMemoryStore::default();
        let session_layer = TowerSessionManagerLayer::new(session_store)
            .with_secure(tls_enabled) // Mark the cookie `Secure` only when TLS is on
            .with_same_site(tower_sessions::cookie::SameSite::Strict) // Prevent CSRF
            .with_http_only(true) // Prevent XSS
            .with_name("mz_session") // Custom cookie name
            .with_path("/"); // Set cookie path

        // The authenticator arrives asynchronously, so the middleware awaits
        // the shared receiver on every request before delegating to
        // `http_auth`.
        let auth_middleware_authenticator_rx = authenticator_rx.clone();
        let auth_middleware = middleware::from_fn(move |req, next| {
            let authenticator_rx = auth_middleware_authenticator_rx.clone();
            async move {
                let authenticator = authenticator_rx
                    .await
                    .expect("sender not dropped before sending once");
                http_auth(req, next, tls_enabled, authenticator, allowed_roles).await
            }
        });

        // `router` collects groups that carry their own layers; `base_router`
        // collects routes that share the auth/extension/CORS layers applied
        // further down.
        let mut router = Router::new();
        let mut base_router = Router::new();
        if routes_enabled.base {
            base_router = base_router
                .route(
                    "/",
                    routing::get(move || async move {
                        root::handle_home(routes_enabled.profiling).await
                    }),
                )
                .route("/api/sql", routing::post(sql::handle_sql))
                .route("/memory", routing::get(memory::handle_memory))
                .route(
                    "/hierarchical-memory",
                    routing::get(memory::handle_hierarchical_memory),
                )
                .route("/static/*path", routing::get(root::handle_static));

            // The WebSocket route is merged into `router` directly, so it does
            // not pass through `auth_middleware`; it receives the
            // authenticator through its state instead.
            let mut ws_router = Router::new()
                .route("/api/experimental/sql", routing::get(sql::handle_sql_ws))
                .with_state(WsState {
                    authenticator_rx: authenticator_rx.clone(),
                    adapter_client_rx: adapter_client_rx.clone(),
                    active_connection_counter: active_connection_counter.clone(),
                    helm_chart_version,
                    allowed_roles,
                });
            if let AuthenticatorKind::None = authenticator_kind {
                ws_router = ws_router.layer(middleware::from_fn(x_materialize_user_header_auth));
            }
            router = router.merge(ws_router);
        }
        if routes_enabled.profiling {
            base_router = base_router.nest("/prof/", mz_prof_http::router(&BUILD_INFO));
        }

        if routes_enabled.webhook {
            // Webhook requests are not routed through `auth_middleware`. They
            // get request decompression, a permissive POST-only CORS policy,
            // and a load-shedding concurrency limit backed by the shared
            // semaphore.
            let webhook_router = Router::new()
                .route(
                    "/api/webhook/:database/:schema/:id",
                    routing::post(webhook::handle_webhook),
                )
                .with_state(WebhookState {
                    adapter_client_rx: adapter_client_rx.clone(),
                    webhook_cache,
                })
                .layer(
                    tower_http::decompression::RequestDecompressionLayer::new()
                        .gzip(true)
                        .deflate(true)
                        .br(true)
                        .zstd(true),
                )
                .layer(
                    CorsLayer::new()
                        .allow_methods(Method::POST)
                        .allow_origin(AllowOrigin::mirror_request())
                        .allow_headers(Any),
                )
                .layer(
                    ServiceBuilder::new()
                        .layer(HandleErrorLayer::new(handle_load_error))
                        .load_shed()
                        .layer(GlobalConcurrencyLimitLayer::with_semaphore(
                            concurrent_webhook_req,
                        )),
                );
            router = router.merge(webhook_router);
        }

        if routes_enabled.internal {
            let console_config = Arc::new(console::ConsoleProxyConfig::new(
                internal_route_config.internal_console_redirect_url.clone(),
                "/internal-console".to_string(),
            ));
            base_router = base_router
                // The two filter-config endpoints below are tombstones: they
                // always answer 400 and point callers at the replacement
                // system variable.
                .route(
                    "/api/opentelemetry/config",
                    routing::put({
                        move |_: axum::Json<DynamicFilterTarget>| async {
                            (
                                StatusCode::BAD_REQUEST,
                                "This endpoint has been replaced. \
                            Use the `opentelemetry_filter` system variable."
                                    .to_string(),
                            )
                        }
                    }),
                )
                .route(
                    "/api/stderr/config",
                    routing::put({
                        move |_: axum::Json<DynamicFilterTarget>| async {
                            (
                                StatusCode::BAD_REQUEST,
                                "This endpoint has been replaced. \
                            Use the `log_filter` system variable."
                                    .to_string(),
                            )
                        }
                    }),
                )
                .route("/api/tracing", routing::get(mz_http_util::handle_tracing))
                .route(
                    "/api/catalog/dump",
                    routing::get(catalog::handle_catalog_dump),
                )
                .route(
                    "/api/catalog/check",
                    routing::get(catalog::handle_catalog_check),
                )
                .route(
                    "/api/coordinator/check",
                    routing::get(catalog::handle_coordinator_check),
                )
                .route(
                    "/api/coordinator/dump",
                    routing::get(catalog::handle_coordinator_dump),
                )
                // Normalize the console path to its trailing-slash form.
                .route(
                    "/internal-console",
                    routing::get(|| async { Redirect::temporary("/internal-console/") }),
                )
                .route(
                    "/internal-console/*path",
                    routing::get(console::handle_internal_console),
                )
                .route(
                    "/internal-console/",
                    routing::get(console::handle_internal_console),
                )
                .layer(Extension(console_config));
            // The leader routes use the deployment state handle as their
            // state, so they live in a separate router with their own copy of
            // the auth middleware.
            let leader_router = Router::new()
                .route("/api/leader/status", routing::get(handle_leader_status))
                .route("/api/leader/promote", routing::post(handle_leader_promote))
                .route(
                    "/api/leader/skip-catchup",
                    routing::post(handle_leader_skip_catchup),
                )
                .layer(auth_middleware.clone())
                .with_state(internal_route_config.deployment_state_handle.clone());
            router = router.merge(leader_router);
        }

        if routes_enabled.metrics {
            let metrics_router = Router::new()
                // `/metrics` serves the process registry directly; the
                // `/metrics/mz_*` routes first run SQL-backed queries through
                // an `AuthedClient` and serve the resulting registry.
                .route(
                    "/metrics",
                    routing::get(move || async move {
                        mz_http_util::handle_prometheus(&metrics_registry).await
                    }),
                )
                .route(
                    "/metrics/mz_usage",
                    routing::get(|client: AuthedClient| async move {
                        let registry = sql::handle_promsql(client, USAGE_METRIC_QUERIES).await;
                        mz_http_util::handle_prometheus(&registry).await
                    }),
                )
                .route(
                    "/metrics/mz_frontier",
                    routing::get(|client: AuthedClient| async move {
                        let registry = sql::handle_promsql(client, FRONTIER_METRIC_QUERIES).await;
                        mz_http_util::handle_prometheus(&registry).await
                    }),
                )
                .route(
                    "/metrics/mz_compute",
                    routing::get(|client: AuthedClient| async move {
                        let registry = sql::handle_promsql(client, COMPUTE_METRIC_QUERIES).await;
                        mz_http_util::handle_prometheus(&registry).await
                    }),
                )
                .route(
                    "/metrics/mz_storage",
                    routing::get(|client: AuthedClient| async move {
                        let registry = sql::handle_promsql(client, STORAGE_METRIC_QUERIES).await;
                        mz_http_util::handle_prometheus(&registry).await
                    }),
                )
                .route(
                    "/api/livez",
                    routing::get(mz_http_util::handle_liveness_check),
                )
                .route("/api/readyz", routing::get(probe::handle_ready))
                .layer(auth_middleware.clone())
                // The `AuthedClient` extractor reads both of these extensions.
                .layer(Extension(adapter_client_rx.clone()))
                .layer(Extension(active_connection_counter.clone()));
            router = router.merge(metrics_router);
        }

        // Layers shared by every route collected in `base_router`.
        base_router = base_router
            .layer(auth_middleware.clone())
            .layer(Extension(adapter_client_rx.clone()))
            .layer(Extension(active_connection_counter.clone()))
            .layer(
                CorsLayer::new()
                    .allow_credentials(false)
                    .allow_headers([
                        AUTHORIZATION,
                        CONTENT_TYPE,
                        HeaderName::from_static("x-materialize-version"),
                    ])
                    .allow_methods(Any)
                    .allow_origin(allowed_origin)
                    .expose_headers(Any)
                    // One hour.
                    .max_age(Duration::from_secs(60) * 60),
            );

        match authenticator_kind {
            AuthenticatorKind::Password => {
                base_router = base_router.layer(session_layer.clone());

                let login_router = Router::new()
                    .route("/api/login", routing::post(handle_login))
                    .route("/api/logout", routing::post(handle_logout))
                    .layer(Extension(adapter_client_rx));
                // NOTE(review): `layer` is applied after the merge, so the
                // session layer also wraps every route merged into `router`
                // so far, not just login/logout — confirm this is intended.
                router = router.merge(login_router).layer(session_layer);
            }
            AuthenticatorKind::None => {
                base_router =
                    base_router.layer(middleware::from_fn(x_materialize_user_header_auth));
            }
            _ => {}
        }

        router = router
            .merge(base_router)
            .apply_default_layers(source, metrics);

        HttpServer { tls, router }
    }
}
427
impl Server for HttpServer {
    const NAME: &'static str = "http";

    /// Serves one accepted connection over HTTP/1.
    ///
    /// The returned future resolves the client address, performs the TLS
    /// handshake when a TLS context is configured, and then drives the router
    /// on the connection until it closes. The task-metrics iterator is unused.
    fn handle_connection(
        &self,
        conn: Connection,
        _tokio_metrics_intervals: impl Iterator<Item = TaskMetrics> + Send + 'static,
    ) -> ConnectionHandler {
        let router = self.router.clone();
        let tls_context = self.tls.clone();
        let mut conn = TokioIo::new(conn);

        Box::pin(async {
            // Prefer the source address from a proxy header, if the
            // connection carried one, over the address of the direct peer.
            let direct_peer_addr = conn.inner().peer_addr().context("fetching peer addr")?;
            let peer_addr = conn
                .inner_mut()
                .take_proxy_header_address()
                .await
                .map(|a| a.source)
                .unwrap_or(direct_peer_addr);

            let (conn, conn_protocol) = match tls_context {
                Some(tls_context) => {
                    let mut ssl_stream = SslStream::new(Ssl::new(&tls_context.get())?, conn)?;
                    if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
                        // Best-effort shutdown of the underlying stream; the
                        // handshake error is what gets reported.
                        let _ = ssl_stream.get_mut().inner_mut().shutdown().await;
                        return Err(e.into());
                    }
                    (MaybeHttpsStream::Https(ssl_stream), ConnProtocol::Https)
                }
                _ => (MaybeHttpsStream::Http(conn), ConnProtocol::Http),
            };
            // Expose the protocol to handlers (`http_auth` reads it) and the
            // peer address via `ConnectInfo<SocketAddr>`.
            let mut make_tower_svc = router
                .layer(Extension(conn_protocol))
                .into_make_service_with_connect_info::<SocketAddr>();
            // NOTE(review): presumably this service's error type is
            // infallible, which is what makes the `unwrap` safe — confirm.
            let tower_svc = make_tower_svc.call(peer_addr).await.unwrap();
            let hyper_svc = hyper::service::service_fn(|req| tower_svc.clone().call(req));
            let http = hyper::server::conn::http1::Builder::new();
            http.serve_connection(conn, hyper_svc)
                // Upgrades are enabled on the connection; presumably needed
                // for the WebSocket route.
                .with_upgrades()
                .err_into()
                .await
        })
    }
}
473
474pub async fn handle_leader_status(
475    State(deployment_state_handle): State<DeploymentStateHandle>,
476) -> impl IntoResponse {
477    let status = deployment_state_handle.status();
478    (StatusCode::OK, Json(json!({ "status": status })))
479}
480
481pub async fn handle_leader_promote(
482    State(deployment_state_handle): State<DeploymentStateHandle>,
483) -> impl IntoResponse {
484    match deployment_state_handle.try_promote() {
485        Ok(()) => {
486            // TODO(benesch): the body here is redundant. Should just return
487            // 204.
488            let status = StatusCode::OK;
489            let body = Json(json!({
490                "result": "Success",
491            }));
492            (status, body)
493        }
494        Err(()) => {
495            // TODO(benesch): the nesting here is redundant given the error
496            // code. Should just return the `{"message": "..."}` object.
497            let status = StatusCode::BAD_REQUEST;
498            let body = Json(json!({
499                "result": {"Failure": {"message": "cannot promote leader while initializing"}},
500            }));
501            (status, body)
502        }
503    }
504}
505
506pub async fn handle_leader_skip_catchup(
507    State(deployment_state_handle): State<DeploymentStateHandle>,
508) -> impl IntoResponse {
509    match deployment_state_handle.try_skip_catchup() {
510        Ok(()) => StatusCode::NO_CONTENT.into_response(),
511        Err(()) => {
512            let status = StatusCode::BAD_REQUEST;
513            let body = Json(json!({
514                "message": "cannot skip catchup in this phase of initialization; try again later",
515            }));
516            (status, body).into_response()
517        }
518    }
519}
520
521async fn x_materialize_user_header_auth(mut req: Request, next: Next) -> impl IntoResponse {
522    // TODO migrate teleport to basic auth and remove this.
523    if let Some(username) = req.headers().get("x-materialize-user").map(|h| h.to_str()) {
524        let username = match username {
525            Ok(name @ (SUPPORT_USER_NAME | SYSTEM_USER_NAME)) => name.to_string(),
526            _ => {
527                return Err(AuthError::MismatchedUser(format!(
528                    "user specified in x-materialize-user must be {SUPPORT_USER_NAME} or {SYSTEM_USER_NAME}"
529                )));
530            }
531        };
532        req.extensions_mut().insert(AuthedUser {
533            name: username,
534            external_metadata_rx: None,
535        });
536    }
537    Ok(next.run(req).await)
538}
539
/// A value delivered through a oneshot channel, wrapped in `Shared` so the
/// receiving future can be cloned and awaited from many places.
type Delayed<T> = Shared<oneshot::Receiver<T>>;
541
/// Protocol a connection was accepted with.
///
/// Attached to every request on the connection as an extension so that
/// `http_auth` can compare it against the server's TLS configuration.
#[derive(Clone)]
enum ConnProtocol {
    /// Plain-text HTTP.
    Http,
    /// HTTP over TLS.
    Https,
}
547
/// Identity established for a request by one of the authentication layers.
///
/// Stored as a request extension and later read by the [`AuthedClient`]
/// extractor.
#[derive(Clone, Debug)]
pub struct AuthedUser {
    /// User name the adapter session is opened as.
    name: String,
    /// Optional channel of external user metadata, forwarded to the adapter
    /// session configuration. Header- and cookie-session-based authentication
    /// set this to `None`.
    external_metadata_rx: Option<watch::Receiver<ExternalUserMetadata>>,
}
553
/// An adapter session client opened on behalf of an authenticated user.
pub struct AuthedClient {
    /// Session client returned by the adapter once the session has started.
    pub client: SessionClient,
    /// Handle obtained from the active-connection counter when the session
    /// was created.
    // NOTE(review): presumably dropping the handle releases the connection
    // slot — confirm against `ConnectionCounter`.
    pub connection_guard: Option<ConnectionHandle>,
}
558
559impl AuthedClient {
560    async fn new<F>(
561        adapter_client: &Client,
562        user: AuthedUser,
563        peer_addr: IpAddr,
564        active_connection_counter: ConnectionCounter,
565        helm_chart_version: Option<String>,
566        session_config: F,
567        options: BTreeMap<String, String>,
568        now: NowFn,
569    ) -> Result<Self, AdapterError>
570    where
571        F: FnOnce(&mut AdapterSession),
572    {
573        let conn_id = adapter_client.new_conn_id()?;
574        let mut session = adapter_client.new_session(AdapterSessionConfig {
575            conn_id,
576            uuid: epoch_to_uuid_v7(&(now)()),
577            user: user.name,
578            client_ip: Some(peer_addr),
579            external_metadata_rx: user.external_metadata_rx,
580            //TODO(dov): Add support for internal user metadata when we support auth here
581            internal_user_metadata: None,
582            helm_chart_version,
583        });
584        let connection_guard = active_connection_counter.allocate_connection(session.user())?;
585
586        session_config(&mut session);
587        let system_vars = adapter_client.get_system_vars().await;
588        for (key, val) in options {
589            const LOCAL: bool = false;
590            if let Err(err) =
591                session
592                    .vars_mut()
593                    .set(&system_vars, &key, VarInput::Flat(&val), LOCAL)
594            {
595                session.add_notice(AdapterNotice::BadStartupSetting {
596                    name: key.to_string(),
597                    reason: err.to_string(),
598                })
599            }
600        }
601        let adapter_client = adapter_client.startup(session).await?;
602        Ok(AuthedClient {
603            client: adapter_client,
604            connection_guard,
605        })
606    }
607}
608
609#[async_trait]
610impl<S> FromRequestParts<S> for AuthedClient
611where
612    S: Send + Sync,
613{
614    type Rejection = Response;
615
616    async fn from_request_parts(
617        req: &mut http::request::Parts,
618        state: &S,
619    ) -> Result<Self, Self::Rejection> {
620        #[derive(Debug, Default, Deserialize)]
621        struct Params {
622            #[serde(default)]
623            options: String,
624        }
625        let params: Query<Params> = Query::from_request_parts(req, state)
626            .await
627            .unwrap_or_default();
628
629        let peer_addr = req
630            .extensions
631            .get::<ConnectInfo<SocketAddr>>()
632            .expect("ConnectInfo extension guaranteed to exist")
633            .0
634            .ip();
635
636        let user = req.extensions.get::<AuthedUser>().unwrap();
637        let adapter_client = req
638            .extensions
639            .get::<Delayed<mz_adapter::Client>>()
640            .unwrap()
641            .clone();
642        let adapter_client = adapter_client.await.map_err(|_| {
643            (StatusCode::INTERNAL_SERVER_ERROR, "adapter client missing").into_response()
644        })?;
645        let active_connection_counter = req.extensions.get::<ConnectionCounter>().unwrap();
646        let helm_chart_version = None;
647
648        let options = if params.options.is_empty() {
649            // It's possible 'options' simply wasn't provided, we don't want that to
650            // count as a failure to deserialize
651            BTreeMap::<String, String>::default()
652        } else {
653            match serde_json::from_str(&params.options) {
654                Ok(options) => options,
655                Err(_e) => {
656                    // If we fail to deserialize options, fail the request.
657                    let code = StatusCode::BAD_REQUEST;
658                    let msg = format!("Failed to deserialize {} map", "options".quoted());
659                    return Err((code, msg).into_response());
660                }
661            }
662        };
663
664        let client = AuthedClient::new(
665            &adapter_client,
666            user.clone(),
667            peer_addr,
668            active_connection_counter.clone(),
669            helm_chart_version,
670            |session| {
671                session
672                    .vars_mut()
673                    .set_default(WELCOME_MESSAGE.name(), VarInput::Flat(&false.format()))
674                    .expect("known to exist")
675            },
676            options,
677            SYSTEM_TIME.clone(),
678        )
679        .await
680        .map_err(|e| {
681            let status = match e {
682                AdapterError::UserSessionsDisallowed | AdapterError::NetworkPolicyDenied(_) => {
683                    StatusCode::FORBIDDEN
684                }
685                _ => StatusCode::INTERNAL_SERVER_ERROR,
686            };
687            (status, Json(SqlError::from(e))).into_response()
688        })?;
689
690        Ok(client)
691    }
692}
693
694#[derive(Debug, Error)]
695enum AuthError {
696    #[error("role dissallowed")]
697    RoleDisallowed(String),
698    #[error("{0}")]
699    Frontegg(#[from] FronteggError),
700    #[error("missing authorization header")]
701    MissingHttpAuthentication {
702        include_www_authenticate_header: bool,
703    },
704    #[error("{0}")]
705    MismatchedUser(String),
706    #[error("session expired")]
707    SessionExpired,
708    #[error("failed to update session")]
709    FailedToUpdateSession,
710    #[error("invalid credentials")]
711    InvalidCredentials,
712}
713
714impl IntoResponse for AuthError {
715    fn into_response(self) -> Response {
716        warn!("HTTP request failed authentication: {}", self);
717        let mut headers = HeaderMap::new();
718        match self {
719            AuthError::MissingHttpAuthentication {
720                include_www_authenticate_header,
721            } if include_www_authenticate_header => {
722                headers.insert(
723                    http::header::WWW_AUTHENTICATE,
724                    HeaderValue::from_static("Basic realm=Materialize"),
725                );
726            }
727            _ => {}
728        };
729        // We omit most detail from the error message we send to the client, to
730        // avoid giving attackers unnecessary information.
731        (StatusCode::UNAUTHORIZED, headers, "unauthorized").into_response()
732    }
733}
734
735// Simplified login handler
736pub async fn handle_login(
737    session: Option<Extension<TowerSession>>,
738    Extension(adapter_client_rx): Extension<Delayed<Client>>,
739    Json(LoginCredentials { username, password }): Json<LoginCredentials>,
740) -> impl IntoResponse {
741    let Ok(adapter_client) = adapter_client_rx.clone().await else {
742        return StatusCode::INTERNAL_SERVER_ERROR;
743    };
744    if let Err(err) = adapter_client.authenticate(&username, &password).await {
745        warn!(?err, "HTTP login failed authentication");
746        return StatusCode::UNAUTHORIZED;
747    };
748
749    // Create session data
750    let session_data = TowerSessionData {
751        username,
752        created_at: SystemTime::now(),
753        last_activity: SystemTime::now(),
754    };
755    // Store session data
756    let session = session.and_then(|Extension(session)| Some(session));
757    let Some(session) = session else {
758        return StatusCode::INTERNAL_SERVER_ERROR;
759    };
760    match session.insert("data", &session_data).await {
761        Err(_) => StatusCode::INTERNAL_SERVER_ERROR,
762        Ok(_) => StatusCode::OK,
763    }
764}
765
766// Simplified logout handler
767pub async fn handle_logout(session: Option<Extension<TowerSession>>) -> impl IntoResponse {
768    let session = session.and_then(|Extension(session)| Some(session));
769    let Some(session) = session else {
770        return StatusCode::INTERNAL_SERVER_ERROR;
771    };
772    // Delete session
773    match session.delete().await {
774        Err(_) => StatusCode::INTERNAL_SERVER_ERROR,
775        Ok(_) => StatusCode::OK,
776    }
777}
778
779async fn http_auth(
780    mut req: Request,
781    next: Next,
782    tls_enabled: bool,
783    authenticator: Arc<Authenticator>,
784    allowed_roles: AllowedRoles,
785) -> impl IntoResponse + use<> {
786    // First check for session authentication
787    if let Some(session) = req.extensions().get::<TowerSession>() {
788        if let Ok(Some(session_data)) = session.get::<TowerSessionData>("data").await {
789            // Check session expiration
790            if session_data
791                .last_activity
792                .elapsed()
793                .unwrap_or(Duration::MAX)
794                > SESSION_DURATION
795            {
796                let _ = session.delete().await;
797                return Err(AuthError::SessionExpired);
798            }
799            // Update last activity
800            let mut updated_data = session_data.clone();
801            updated_data.last_activity = SystemTime::now();
802            session
803                .insert("data", &updated_data)
804                .await
805                .map_err(|_| AuthError::FailedToUpdateSession)?;
806            // User is authenticated via session
807            req.extensions_mut().insert(AuthedUser {
808                name: session_data.username,
809                external_metadata_rx: None,
810            });
811            return Ok(next.run(req).await);
812        }
813    }
814
815    // First, extract the username from the certificate, validating that the
816    // connection matches the TLS configuration along the way.
817    // Fall back to existing authentication methods.
818    let conn_protocol = req.extensions().get::<ConnProtocol>().unwrap();
819    match (tls_enabled, &conn_protocol) {
820        (false, ConnProtocol::Http) => {}
821        (false, ConnProtocol::Https { .. }) => unreachable!(),
822        (true, ConnProtocol::Http) => {
823            let mut parts = req.uri().clone().into_parts();
824            parts.scheme = Some(Scheme::HTTPS);
825            return Ok(Redirect::permanent(
826                &Uri::from_parts(parts)
827                    .expect("it was already a URI, just changed the scheme")
828                    .to_string(),
829            )
830            .into_response());
831        }
832        (true, ConnProtocol::Https { .. }) => {}
833    }
834    // If we've already passed some other auth, just use that.
835    if req.extensions().get::<AuthedUser>().is_some() {
836        return Ok(next.run(req).await);
837    }
838    let creds = if let Some(basic) = req.headers().typed_get::<Authorization<Basic>>() {
839        Some(Credentials::Password {
840            username: basic.username().to_owned(),
841            password: Password(basic.password().to_owned()),
842        })
843    } else if let Some(bearer) = req.headers().typed_get::<Authorization<Bearer>>() {
844        Some(Credentials::Token {
845            token: bearer.token().to_owned(),
846        })
847    } else {
848        None
849    };
850
851    let path = req.uri().path();
852    let include_www_authenticate_header = path == "/"
853        || PROFILING_API_ENDPOINTS
854            .iter()
855            .any(|prefix| path.starts_with(prefix));
856    let user = auth(
857        &authenticator,
858        creds,
859        allowed_roles,
860        include_www_authenticate_header,
861    )
862    .await?;
863
864    // Add the authenticated user as an extension so downstream handlers can
865    // inspect it if necessary.
866    req.extensions_mut().insert(user);
867
868    // Run the request.
869    Ok(next.run(req).await)
870}
871
872async fn init_ws(
873    WsState {
874        authenticator_rx,
875        adapter_client_rx,
876        active_connection_counter,
877        helm_chart_version,
878        allowed_roles,
879    }: &WsState,
880    existing_user: Option<AuthedUser>,
881    peer_addr: IpAddr,
882    ws: &mut WebSocket,
883) -> Result<AuthedClient, anyhow::Error> {
884    let authenticator = authenticator_rx.clone().await.expect("sender not dropped");
885    // TODO: Add a timeout here to prevent resource leaks by clients that
886    // connect then never send a message.
887    let ws_auth: WebSocketAuth = loop {
888        let init_msg = ws.recv().await.ok_or_else(|| anyhow::anyhow!("closed"))??;
889        match init_msg {
890            Message::Text(data) => break serde_json::from_str(&data)?,
891            Message::Binary(data) => break serde_json::from_slice(&data)?,
892            // Handled automatically by the server.
893            Message::Ping(_) => {
894                continue;
895            }
896            Message::Pong(_) => {
897                continue;
898            }
899            Message::Close(_) => {
900                anyhow::bail!("closed");
901            }
902        }
903    };
904
905    let (user, options) = if let Some(existing_user) = existing_user {
906        match ws_auth {
907            WebSocketAuth::OptionsOnly { options } => (existing_user, options),
908            _ => {
909                warn!("Unexpected bearer or basic auth provided when using user header");
910                anyhow::bail!("unexpected")
911            }
912        }
913    } else {
914        let (creds, options) = match ws_auth {
915            WebSocketAuth::Basic {
916                user,
917                password,
918                options,
919            } => {
920                let creds = Credentials::Password {
921                    username: user,
922                    password,
923                };
924                (creds, options)
925            }
926            WebSocketAuth::Bearer { token, options } => {
927                let creds = Credentials::Token { token };
928                (creds, options)
929            }
930            WebSocketAuth::OptionsOnly { .. } => {
931                anyhow::bail!("expected auth information");
932            }
933        };
934        let user = auth(&authenticator, Some(creds), *allowed_roles, false).await?;
935        (user, options)
936    };
937
938    let client = AuthedClient::new(
939        &adapter_client_rx.clone().await?,
940        user,
941        peer_addr,
942        active_connection_counter.clone(),
943        helm_chart_version.clone(),
944        |_session| (),
945        options,
946        SYSTEM_TIME.clone(),
947    )
948    .await?;
949
950    Ok(client)
951}
952
/// Credentials presented by a client, before they have been verified.
enum Credentials {
    /// A username/password pair, built from an HTTP `Authorization: Basic`
    /// header or from a WebSocket `Basic` auth message.
    Password {
        username: String,
        password: Password,
    },
    /// A token, built from an HTTP `Authorization: Bearer` header or from a
    /// WebSocket `Bearer` auth message.
    Token {
        token: String,
    },
}
962
963async fn auth(
964    authenticator: &Authenticator,
965    creds: Option<Credentials>,
966    allowed_roles: AllowedRoles,
967    include_www_authenticate_header: bool,
968) -> Result<AuthedUser, AuthError> {
969    // TODO pass session data here?
970    let (name, external_metadata_rx) = match authenticator {
971        Authenticator::Frontegg(frontegg) => match creds {
972            Some(Credentials::Password { username, password }) => {
973                let auth_session = frontegg.authenticate(&username, &password.0).await?;
974                let name = auth_session.user().into();
975                let external_metadata_rx = Some(auth_session.external_metadata_rx());
976                (name, external_metadata_rx)
977            }
978            Some(Credentials::Token { token }) => {
979                let claims = frontegg.validate_access_token(&token, None)?;
980                let (_, external_metadata_rx) = watch::channel(ExternalUserMetadata {
981                    user_id: claims.user_id,
982                    admin: claims.is_admin,
983                });
984                (claims.user, Some(external_metadata_rx))
985            }
986            None => {
987                return Err(AuthError::MissingHttpAuthentication {
988                    include_www_authenticate_header,
989                });
990            }
991        },
992        Authenticator::Password(adapter_client) => match creds {
993            Some(Credentials::Password { username, password }) => {
994                if let Err(_) = adapter_client.authenticate(&username, &password).await {
995                    return Err(AuthError::InvalidCredentials);
996                }
997                (username, None)
998            }
999            _ => {
1000                return Err(AuthError::MissingHttpAuthentication {
1001                    include_www_authenticate_header,
1002                });
1003            }
1004        },
1005        Authenticator::None => {
1006            // If no authentication, use whatever is in the HTTP auth
1007            // header (without checking the password), or fall back to the
1008            // default user.
1009            let name = match creds {
1010                Some(Credentials::Password { username, .. }) => username,
1011                _ => HTTP_DEFAULT_USER.name.to_owned(),
1012            };
1013            (name, None)
1014        }
1015    };
1016
1017    check_role_allowed(&name, allowed_roles)?;
1018
1019    Ok(AuthedUser {
1020        name,
1021        external_metadata_rx,
1022    })
1023}
1024
1025// TODO move this somewhere it can be shared with PGWIRE
1026fn check_role_allowed(name: &str, allowed_roles: AllowedRoles) -> Result<(), AuthError> {
1027    let is_internal_user = INTERNAL_USER_NAMES.contains(name);
1028    // this is a superset of internal users
1029    let is_reserved_user = mz_adapter::catalog::is_reserved_role_name(name);
1030    let role_allowed = match allowed_roles {
1031        AllowedRoles::Normal => !is_reserved_user,
1032        AllowedRoles::Internal => is_internal_user,
1033        AllowedRoles::NormalAndInternal => !is_reserved_user || is_internal_user,
1034    };
1035    if role_allowed {
1036        Ok(())
1037    } else {
1038        Err(AuthError::RoleDisallowed(name.to_owned()))
1039    }
1040}
1041
/// Default layers that should be applied to all routes, and should get applied to both the
/// internal http and external http routers.
trait DefaultLayers {
    /// Consumes `self` and returns it with the default layers applied.
    ///
    /// In the `Router` implementation in this file, `source` and `metrics` are
    /// passed to the Prometheus metrics layer.
    fn apply_default_layers(self, source: &'static str, metrics: Metrics) -> Self;
}
1047
1048impl DefaultLayers for Router {
1049    fn apply_default_layers(self, source: &'static str, metrics: Metrics) -> Self {
1050        self.layer(DefaultBodyLimit::max(MAX_REQUEST_SIZE))
1051            .layer(metrics::PrometheusLayer::new(source, metrics))
1052    }
1053}
1054
1055/// Glue code to make [`tower`] work with [`axum`].
1056///
1057/// `axum` requires `Layer`s not return Errors, i.e. they must be `Result<_, Infallible>`,
1058/// instead you must return a type that can be converted into a response. `tower` on the other
1059/// hand does return Errors, so to make the two work together we need to convert our `tower` errors
1060/// into responses.
1061async fn handle_load_error(error: tower::BoxError) -> impl IntoResponse {
1062    if error.is::<tower::load_shed::error::Overloaded>() {
1063        return (
1064            StatusCode::TOO_MANY_REQUESTS,
1065            Cow::from("too many requests, try again later"),
1066        );
1067    }
1068
1069    // Note: This should be unreachable because at the time of writing our only use case is a
1070    // layer that emits `tower::load_shed::error::Overloaded`, which is handled above.
1071    (
1072        StatusCode::INTERNAL_SERVER_ERROR,
1073        Cow::from(format!("Unhandled internal error: {}", error)),
1074    )
1075}
1076
/// A username/password pair that can be serialized and deserialized.
///
/// NOTE(review): presumably the payload accepted by the login handler --
/// confirm against that handler's signature.
#[derive(Debug, Deserialize, Serialize, PartialEq)]
pub struct LoginCredentials {
    /// Name of the user logging in.
    username: String,
    /// Password presented for `username`.
    password: Password,
}
1082
/// Per-session state stored in the `TowerSession` under the `"data"` key.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TowerSessionData {
    /// User the session authenticates as; `http_auth` uses it as the name of
    /// the `AuthedUser` it attaches to session-authenticated requests.
    username: String,
    /// Time at which the session data was created.
    created_at: SystemTime,
    /// Time of the most recent session-authenticated request. `http_auth`
    /// compares its age against `SESSION_DURATION` to expire idle sessions and
    /// resets it to the current time on every accepted request.
    last_activity: SystemTime,
}
1089
#[cfg(test)]
mod tests {
    use super::{AllowedRoles, check_role_allowed};

    /// Exercises `check_role_allowed` for every policy against internal,
    /// normal, and reserved role names.
    #[mz_ore::test]
    fn test_check_role_allowed() {
        // Columns: role name, then whether the role is admitted under
        // `Internal`, `NormalAndInternal`, and `Normal` respectively.
        let cases = [
            // Internal users
            ("mz_system", true, true, false),
            ("mz_support", true, true, false),
            ("mz_analytics", true, true, false),
            // Normal users
            ("materialize", false, true, true),
            ("anonymous_http_user", false, true, true),
            ("alex", false, true, true),
            // Denied by reserved role prefix
            ("external_asdf", false, false, false),
            ("pg_somebody", false, false, false),
            ("mz_unknown", false, false, false),
            // Denied by literal PUBLIC
            ("PUBLIC", false, false, false),
        ];

        for (name, internal, normal_and_internal, normal) in cases {
            assert_eq!(
                check_role_allowed(name, AllowedRoles::Internal).is_ok(),
                internal,
                "role {name} under AllowedRoles::Internal"
            );
            assert_eq!(
                check_role_allowed(name, AllowedRoles::NormalAndInternal).is_ok(),
                normal_and_internal,
                "role {name} under AllowedRoles::NormalAndInternal"
            );
            assert_eq!(
                check_role_allowed(name, AllowedRoles::Normal).is_ok(),
                normal,
                "role {name} under AllowedRoles::Normal"
            );
        }
    }
}