mz_environmentd/
http.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Embedded HTTP server.
11//!
12//! environmentd embeds an HTTP server for introspection into the running
13//! process. At the moment, its primary exports are Prometheus metrics, heap
14//! profiles, and catalog dumps.
15
16// Axum handlers must use async, but often don't actually use `await`.
17#![allow(clippy::unused_async)]
18
19use std::borrow::Cow;
20use std::collections::BTreeMap;
21use std::fmt::Debug;
22use std::net::{IpAddr, SocketAddr};
23use std::pin::Pin;
24use std::sync::Arc;
25use std::time::{Duration, SystemTime};
26
27use anyhow::Context;
28use async_trait::async_trait;
29use axum::error_handling::HandleErrorLayer;
30use axum::extract::ws::{Message, WebSocket};
31use axum::extract::{ConnectInfo, DefaultBodyLimit, FromRequestParts, Query, Request, State};
32use axum::middleware::{self, Next};
33use axum::response::{IntoResponse, Redirect, Response};
34use axum::{Extension, Json, Router, routing};
35use futures::future::{Shared, TryFutureExt};
36use headers::authorization::{Authorization, Basic, Bearer};
37use headers::{HeaderMapExt, HeaderName};
38use http::header::{AUTHORIZATION, CONTENT_TYPE};
39use http::{Method, StatusCode};
40use hyper_openssl::SslStream;
41use hyper_openssl::client::legacy::MaybeHttpsStream;
42use hyper_util::rt::TokioIo;
43use mz_adapter::session::{Session as AdapterSession, SessionConfig as AdapterSessionConfig};
44use mz_adapter::{AdapterError, AdapterNotice, Client, SessionClient, WebhookAppenderCache};
45use mz_auth::password::Password;
46use mz_authenticator::Authenticator;
47use mz_frontegg_auth::Error as FronteggError;
48use mz_http_util::DynamicFilterTarget;
49use mz_ore::cast::u64_to_usize;
50use mz_ore::metrics::MetricsRegistry;
51use mz_ore::now::{NowFn, SYSTEM_TIME, epoch_to_uuid_v7};
52use mz_ore::str::StrExt;
53use mz_pgwire_common::{ConnectionCounter, ConnectionHandle};
54use mz_repr::user::ExternalUserMetadata;
55use mz_server_core::listeners::{AllowedRoles, AuthenticatorKind, HttpRoutesEnabled};
56use mz_server_core::{Connection, ConnectionHandler, ReloadingSslContext, Server};
57use mz_sql::session::metadata::SessionMetadata;
58use mz_sql::session::user::{
59    HTTP_DEFAULT_USER, INTERNAL_USER_NAMES, SUPPORT_USER_NAME, SYSTEM_USER_NAME,
60};
61use mz_sql::session::vars::{Value, Var, VarInput, WELCOME_MESSAGE};
62use openssl::ssl::Ssl;
63use prometheus::{
64    COMPUTE_METRIC_QUERIES, FRONTIER_METRIC_QUERIES, STORAGE_METRIC_QUERIES, USAGE_METRIC_QUERIES,
65};
66use serde::{Deserialize, Serialize};
67use serde_json::json;
68use thiserror::Error;
69use tokio::io::AsyncWriteExt;
70use tokio::sync::oneshot::Receiver;
71use tokio::sync::{oneshot, watch};
72use tower::limit::GlobalConcurrencyLimitLayer;
73use tower::{Service, ServiceBuilder};
74use tower_http::cors::{AllowOrigin, Any, CorsLayer};
75use tower_sessions::{
76    MemoryStore as TowerSessionMemoryStore, Session as TowerSession,
77    SessionManagerLayer as TowerSessionManagerLayer,
78};
79use tracing::{error, warn};
80
81use crate::BUILD_INFO;
82use crate::deployment::state::DeploymentStateHandle;
83use crate::http::sql::SqlError;
84
85mod catalog;
86mod console;
87mod memory;
88mod metrics;
89mod probe;
90mod prometheus;
91mod root;
92mod sql;
93mod webhook;
94
95pub use metrics::Metrics;
96pub use sql::{SqlResponse, WebSocketAuth, WebSocketResponse};
97
/// Maximum allowed size for a request (5 MiB).
pub const MAX_REQUEST_SIZE: usize = u64_to_usize(5 * bytesize::MIB);

/// Maximum idle time of a cookie-backed login session. `http_auth` compares
/// the time elapsed since a session's `last_activity` against this value and
/// deletes the session once it is exceeded.
const SESSION_DURATION: Duration = Duration::from_secs(3600); // 1 hour
102
/// Configuration consumed by [`HttpServer::new`].
#[derive(Debug)]
pub struct HttpConfig {
    /// Passed, together with `metrics`, to the router's default layers.
    pub source: &'static str,
    /// TLS context. When present, every accepted connection goes through a
    /// TLS handshake and session cookies are marked `Secure`.
    pub tls: Option<ReloadingSslContext>,
    /// Selects the extra authentication plumbing: `Password` mounts the
    /// login/logout routes and the session layer, `None` enables the
    /// `x-materialize-user` header middleware.
    pub authenticator_kind: AuthenticatorKind,
    /// Resolves to the authenticator; awaited by the auth middleware and by
    /// the WebSocket route state.
    pub authenticator_rx: Shared<Receiver<Arc<Authenticator>>>,
    /// Resolves to the adapter client; handed to handlers as an extension or
    /// as route state.
    pub adapter_client_rx: Shared<Receiver<Client>>,
    /// Origin policy for the CORS layer on the base routes.
    pub allowed_origin: AllowOrigin,
    /// Counter from which each `AuthedClient` allocates its connection guard.
    pub active_connection_counter: ConnectionCounter,
    /// Forwarded to the WebSocket route state.
    pub helm_chart_version: Option<String>,
    /// Semaphore backing the global concurrency limit on the webhook routes.
    pub concurrent_webhook_req: Arc<tokio::sync::Semaphore>,
    /// Passed, together with `source`, to the router's default layers.
    pub metrics: Metrics,
    /// Registry served by the `/metrics` route.
    pub metrics_registry: MetricsRegistry,
    /// Passed to `http_auth` and to the WebSocket route state.
    pub allowed_roles: AllowedRoles,
    /// Settings used only when the internal routes are enabled.
    pub internal_route_config: Arc<InternalRouteConfig>,
    /// Flags selecting which route groups (base, profiling, webhook,
    /// internal, metrics) are mounted.
    pub routes_enabled: HttpRoutesEnabled,
}
120
/// Settings that are only consulted when the internal routes are enabled.
#[derive(Debug, Clone)]
pub struct InternalRouteConfig {
    /// State shared with the `/api/leader/*` handlers.
    pub deployment_state_handle: DeploymentStateHandle,
    /// Handed to the console proxy configuration for the `/internal-console`
    /// routes.
    pub internal_console_redirect_url: Option<String>,
}
126
/// State attached to the WebSocket SQL route (`/api/experimental/sql`).
#[derive(Clone)]
pub struct WsState {
    /// Resolves to the authenticator once it has been sent.
    authenticator_rx: Delayed<Arc<Authenticator>>,
    /// Resolves to the adapter client once it has been sent.
    adapter_client_rx: Delayed<mz_adapter::Client>,
    /// Counter of active connections.
    active_connection_counter: ConnectionCounter,
    /// Copied from `HttpConfig::helm_chart_version`.
    helm_chart_version: Option<String>,
    /// Copied from `HttpConfig::allowed_roles`.
    allowed_roles: AllowedRoles,
}
135
/// State attached to the webhook route (`/api/webhook/:database/:schema/:id`).
#[derive(Clone)]
pub struct WebhookState {
    /// Resolves to the adapter client once it has been sent.
    adapter_client_rx: Delayed<mz_adapter::Client>,
    /// Cache created once in `HttpServer::new` and shared by all webhook
    /// requests.
    webhook_cache: WebhookAppenderCache,
}
141
/// The embedded HTTP server: a fully assembled router plus the optional TLS
/// context used when accepting connections.
#[derive(Debug)]
pub struct HttpServer {
    /// When `Some`, each connection is wrapped in a TLS stream before being
    /// served.
    tls: Option<ReloadingSslContext>,
    /// Router cloned for every accepted connection.
    router: Router,
}
147
impl HttpServer {
    /// Assembles the server's router from `HttpConfig`.
    ///
    /// Two routers are built side by side:
    ///
    /// * `base_router` collects routes that all receive the same trailing
    ///   layers (auth middleware, adapter-client and connection-counter
    ///   extensions, CORS with `allowed_origin`).
    /// * `router` collects sub-routers that bring their own layers
    ///   (WebSocket, webhook, leader, metrics, login) and finally absorbs
    ///   `base_router`.
    ///
    /// Which route groups are mounted is controlled by `routes_enabled`.
    pub fn new(
        HttpConfig {
            source,
            tls,
            authenticator_kind,
            authenticator_rx,
            adapter_client_rx,
            allowed_origin,
            active_connection_counter,
            helm_chart_version,
            concurrent_webhook_req,
            metrics,
            metrics_registry,
            allowed_roles,
            internal_route_config,
            routes_enabled,
        }: HttpConfig,
    ) -> HttpServer {
        let tls_enabled = tls.is_some();
        let webhook_cache = WebhookAppenderCache::new();

        // Create secure session store and manager
        let session_store = TowerSessionMemoryStore::default();
        let session_layer = TowerSessionManagerLayer::new(session_store)
            .with_secure(tls_enabled) // Enforce HTTPS
            .with_same_site(tower_sessions::cookie::SameSite::Strict) // Prevent CSRF
            .with_http_only(true) // Prevent XSS
            .with_name("mz_session") // Custom cookie name
            .with_path("/"); // Set cookie path

        // Middleware that waits for the authenticator to be delivered and then
        // runs `http_auth` on the request.
        let auth_middleware_authenticator_rx = authenticator_rx.clone();
        let auth_middleware = middleware::from_fn(move |req, next| {
            let authenticator_rx = auth_middleware_authenticator_rx.clone();
            async move {
                let authenticator = authenticator_rx
                    .await
                    .expect("sender not dropped before sending once");
                http_auth(req, next, tls_enabled, authenticator, allowed_roles).await
            }
        });

        let mut router = Router::new();
        let mut base_router = Router::new();
        if routes_enabled.base {
            base_router = base_router
                .route(
                    "/",
                    routing::get(move || async move {
                        root::handle_home(routes_enabled.profiling).await
                    }),
                )
                .route("/api/sql", routing::post(sql::handle_sql))
                .route("/memory", routing::get(memory::handle_memory))
                .route(
                    "/hierarchical-memory",
                    routing::get(memory::handle_hierarchical_memory),
                )
                .route("/static/*path", routing::get(root::handle_static));

            // The WebSocket route is merged into `router` directly, so it does
            // not receive `auth_middleware`; it gets the authenticator through
            // `WsState` instead.
            let mut ws_router = Router::new()
                .route("/api/experimental/sql", routing::get(sql::handle_sql_ws))
                .with_state(WsState {
                    authenticator_rx: authenticator_rx.clone(),
                    adapter_client_rx: adapter_client_rx.clone(),
                    active_connection_counter: active_connection_counter.clone(),
                    helm_chart_version,
                    allowed_roles,
                });
            if let AuthenticatorKind::None = authenticator_kind {
                ws_router = ws_router.layer(middleware::from_fn(x_materialize_user_header_auth));
            }
            router = router.merge(ws_router);
        }
        if routes_enabled.profiling {
            base_router = base_router.nest("/prof/", mz_prof_http::router(&BUILD_INFO));
        }

        if routes_enabled.webhook {
            // Webhook requests are not wrapped in `auth_middleware`. They get
            // request decompression, a permissive CORS policy that mirrors the
            // request origin, and load shedding on top of a global concurrency
            // limit backed by `concurrent_webhook_req`.
            let webhook_router = Router::new()
                .route(
                    "/api/webhook/:database/:schema/:id",
                    routing::post(webhook::handle_webhook),
                )
                .with_state(WebhookState {
                    adapter_client_rx: adapter_client_rx.clone(),
                    webhook_cache,
                })
                .layer(
                    tower_http::decompression::RequestDecompressionLayer::new()
                        .gzip(true)
                        .deflate(true)
                        .br(true)
                        .zstd(true),
                )
                .layer(
                    CorsLayer::new()
                        .allow_methods(Method::POST)
                        .allow_origin(AllowOrigin::mirror_request())
                        .allow_headers(Any),
                )
                .layer(
                    ServiceBuilder::new()
                        .layer(HandleErrorLayer::new(handle_load_error))
                        .load_shed()
                        .layer(GlobalConcurrencyLimitLayer::with_semaphore(
                            concurrent_webhook_req,
                        )),
                );
            router = router.merge(webhook_router);
        }

        if routes_enabled.internal {
            let console_config = Arc::new(console::ConsoleProxyConfig::new(
                internal_route_config.internal_console_redirect_url.clone(),
                "/internal-console".to_string(),
            ));
            base_router = base_router
                // The two `.../config` endpoints below are retired: they
                // always answer 400 and point callers at the replacement
                // system variable.
                .route(
                    "/api/opentelemetry/config",
                    routing::put({
                        move |_: axum::Json<DynamicFilterTarget>| async {
                            (
                                StatusCode::BAD_REQUEST,
                                "This endpoint has been replaced. \
                            Use the `opentelemetry_filter` system variable."
                                    .to_string(),
                            )
                        }
                    }),
                )
                .route(
                    "/api/stderr/config",
                    routing::put({
                        move |_: axum::Json<DynamicFilterTarget>| async {
                            (
                                StatusCode::BAD_REQUEST,
                                "This endpoint has been replaced. \
                            Use the `log_filter` system variable."
                                    .to_string(),
                            )
                        }
                    }),
                )
                .route("/api/tracing", routing::get(mz_http_util::handle_tracing))
                .route(
                    "/api/catalog/dump",
                    routing::get(catalog::handle_catalog_dump),
                )
                .route(
                    "/api/catalog/check",
                    routing::get(catalog::handle_catalog_check),
                )
                .route(
                    "/api/coordinator/check",
                    routing::get(catalog::handle_coordinator_check),
                )
                .route(
                    "/api/coordinator/dump",
                    routing::get(catalog::handle_coordinator_dump),
                )
                // Canonicalize the console path to the trailing-slash form.
                .route(
                    "/internal-console",
                    routing::get(|| async { Redirect::temporary("/internal-console/") }),
                )
                .route(
                    "/internal-console/*path",
                    routing::get(console::handle_internal_console),
                )
                .route(
                    "/internal-console/",
                    routing::get(console::handle_internal_console),
                )
                // NOTE(review): this layer wraps every route added to
                // `base_router` so far, not only the console routes.
                .layer(Extension(console_config));
            // Leader routes carry their own state and auth layer and are
            // merged into `router` rather than `base_router`.
            let leader_router = Router::new()
                .route("/api/leader/status", routing::get(handle_leader_status))
                .route("/api/leader/promote", routing::post(handle_leader_promote))
                .route(
                    "/api/leader/skip-catchup",
                    routing::post(handle_leader_skip_catchup),
                )
                .layer(auth_middleware.clone())
                .with_state(internal_route_config.deployment_state_handle.clone());
            router = router.merge(leader_router);
        }

        if routes_enabled.metrics {
            // `/metrics` serves the process-wide registry; the
            // `/metrics/mz_*` routes build a registry per request by running
            // the corresponding query set through an authenticated client.
            let metrics_router = Router::new()
                .route(
                    "/metrics",
                    routing::get(move || async move {
                        mz_http_util::handle_prometheus(&metrics_registry).await
                    }),
                )
                .route(
                    "/metrics/mz_usage",
                    routing::get(|client: AuthedClient| async move {
                        let registry = sql::handle_promsql(client, USAGE_METRIC_QUERIES).await;
                        mz_http_util::handle_prometheus(&registry).await
                    }),
                )
                .route(
                    "/metrics/mz_frontier",
                    routing::get(|client: AuthedClient| async move {
                        let registry = sql::handle_promsql(client, FRONTIER_METRIC_QUERIES).await;
                        mz_http_util::handle_prometheus(&registry).await
                    }),
                )
                .route(
                    "/metrics/mz_compute",
                    routing::get(|client: AuthedClient| async move {
                        let registry = sql::handle_promsql(client, COMPUTE_METRIC_QUERIES).await;
                        mz_http_util::handle_prometheus(&registry).await
                    }),
                )
                .route(
                    "/metrics/mz_storage",
                    routing::get(|client: AuthedClient| async move {
                        let registry = sql::handle_promsql(client, STORAGE_METRIC_QUERIES).await;
                        mz_http_util::handle_prometheus(&registry).await
                    }),
                )
                .route(
                    "/api/livez",
                    routing::get(mz_http_util::handle_liveness_check),
                )
                .route("/api/readyz", routing::get(probe::handle_ready))
                // The `AuthedClient` extractor reads the `AuthedUser`,
                // adapter-client and connection-counter extensions, so all
                // three layers are required here.
                .layer(auth_middleware.clone())
                .layer(Extension(adapter_client_rx.clone()))
                .layer(Extension(active_connection_counter.clone()));
            router = router.merge(metrics_router);
        }

        // Layers shared by every route collected in `base_router`.
        base_router = base_router
            .layer(auth_middleware.clone())
            .layer(Extension(adapter_client_rx.clone()))
            .layer(Extension(active_connection_counter.clone()))
            .layer(
                CorsLayer::new()
                    .allow_credentials(false)
                    .allow_headers([
                        AUTHORIZATION,
                        CONTENT_TYPE,
                        HeaderName::from_static("x-materialize-version"),
                    ])
                    .allow_methods(Any)
                    .allow_origin(allowed_origin)
                    .expose_headers(Any)
                    // 60 s * 60 = one hour.
                    .max_age(Duration::from_secs(60) * 60),
            );

        match authenticator_kind {
            AuthenticatorKind::Password => {
                // Added after `auth_middleware`, so (per axum's layer
                // ordering) the session layer runs first and `http_auth` can
                // find the `TowerSession` extension.
                base_router = base_router.layer(session_layer.clone());

                let login_router = Router::new()
                    .route("/api/login", routing::post(handle_login))
                    .route("/api/logout", routing::post(handle_logout))
                    .layer(Extension(adapter_client_rx));
                // NOTE(review): `.layer` here applies the session layer to
                // everything merged into `router` so far, not only the login
                // routes.
                router = router.merge(login_router).layer(session_layer);
            }
            AuthenticatorKind::None => {
                // Lets a request pre-authenticate through the
                // `x-materialize-user` header; `http_auth` then accepts the
                // already-present `AuthedUser`.
                base_router =
                    base_router.layer(middleware::from_fn(x_materialize_user_header_auth));
            }
            _ => {}
        }

        router = router
            .merge(base_router)
            .apply_default_layers(source, metrics);

        HttpServer { tls, router }
    }
}
423
impl Server for HttpServer {
    const NAME: &'static str = "http";

    /// Serves a single accepted connection.
    ///
    /// The returned future resolves the peer address, performs the TLS
    /// handshake when a TLS context is configured, and then serves HTTP/1
    /// (with upgrade support) using a clone of the server's router.
    fn handle_connection(&self, conn: Connection) -> ConnectionHandler {
        let router = self.router.clone();
        let tls_context = self.tls.clone();
        let mut conn = TokioIo::new(conn);

        Box::pin(async {
            // Prefer the source address taken from the proxy header, falling
            // back to the socket's own peer address when there is none.
            let direct_peer_addr = conn.inner().peer_addr().context("fetching peer addr")?;
            let peer_addr = conn
                .inner_mut()
                .take_proxy_header_address()
                .await
                .map(|a| a.source)
                .unwrap_or(direct_peer_addr);

            let (conn, conn_protocol) = match tls_context {
                Some(tls_context) => {
                    let mut ssl_stream = SslStream::new(Ssl::new(&tls_context.get())?, conn)?;
                    if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
                        // Best-effort shutdown of the underlying stream; the
                        // handshake error is what gets reported.
                        let _ = ssl_stream.get_mut().inner_mut().shutdown().await;
                        return Err(e.into());
                    }
                    (MaybeHttpsStream::Https(ssl_stream), ConnProtocol::Https)
                }
                _ => (MaybeHttpsStream::Http(conn), ConnProtocol::Http),
            };
            // Record the protocol as a request extension (read by `http_auth`)
            // and expose the peer address to handlers through `ConnectInfo`.
            let mut make_tower_svc = router
                .layer(Extension(conn_protocol))
                .into_make_service_with_connect_info::<SocketAddr>();
            let tower_svc = make_tower_svc.call(peer_addr).await.unwrap();
            let hyper_svc = hyper::service::service_fn(|req| tower_svc.clone().call(req));
            let http = hyper::server::conn::http1::Builder::new();
            http.serve_connection(conn, hyper_svc)
                .with_upgrades()
                .err_into()
                .await
        })
    }
}
465
466pub async fn handle_leader_status(
467    State(deployment_state_handle): State<DeploymentStateHandle>,
468) -> impl IntoResponse {
469    let status = deployment_state_handle.status();
470    (StatusCode::OK, Json(json!({ "status": status })))
471}
472
473pub async fn handle_leader_promote(
474    State(deployment_state_handle): State<DeploymentStateHandle>,
475) -> impl IntoResponse {
476    match deployment_state_handle.try_promote() {
477        Ok(()) => {
478            // TODO(benesch): the body here is redundant. Should just return
479            // 204.
480            let status = StatusCode::OK;
481            let body = Json(json!({
482                "result": "Success",
483            }));
484            (status, body)
485        }
486        Err(()) => {
487            // TODO(benesch): the nesting here is redundant given the error
488            // code. Should just return the `{"message": "..."}` object.
489            let status = StatusCode::BAD_REQUEST;
490            let body = Json(json!({
491                "result": {"Failure": {"message": "cannot promote leader while initializing"}},
492            }));
493            (status, body)
494        }
495    }
496}
497
498pub async fn handle_leader_skip_catchup(
499    State(deployment_state_handle): State<DeploymentStateHandle>,
500) -> impl IntoResponse {
501    match deployment_state_handle.try_skip_catchup() {
502        Ok(()) => StatusCode::NO_CONTENT.into_response(),
503        Err(()) => {
504            let status = StatusCode::BAD_REQUEST;
505            let body = Json(json!({
506                "message": "cannot skip catchup in this phase of initialization; try again later",
507            }));
508            (status, body).into_response()
509        }
510    }
511}
512
513async fn x_materialize_user_header_auth(mut req: Request, next: Next) -> impl IntoResponse {
514    // TODO migrate teleport to basic auth and remove this.
515    if let Some(username) = req.headers().get("x-materialize-user").map(|h| h.to_str()) {
516        let username = match username {
517            Ok(name @ (SUPPORT_USER_NAME | SYSTEM_USER_NAME)) => name.to_string(),
518            _ => {
519                return Err(AuthError::MismatchedUser(format!(
520                    "user specified in x-materialize-user must be {SUPPORT_USER_NAME} or {SYSTEM_USER_NAME}"
521                )));
522            }
523        };
524        req.extensions_mut().insert(AuthedUser {
525            name: username,
526            external_metadata_rx: None,
527        });
528    }
529    Ok(next.run(req).await)
530}
531
/// A value delivered exactly once over a oneshot channel, wrapped in `Shared`
/// so the receiving side can be cloned and awaited from several places.
type Delayed<T> = Shared<oneshot::Receiver<T>>;
533
/// Protocol a connection was accepted over. `handle_connection` attaches it to
/// every request as an extension and `http_auth` reads it back.
#[derive(Clone)]
enum ConnProtocol {
    /// Plain HTTP: no TLS context was configured.
    Http,
    /// The TLS handshake completed successfully.
    Https,
}
539
/// Identity attached to a request's extensions once authentication succeeded.
#[derive(Clone, Debug)]
pub struct AuthedUser {
    /// Name used as the session's user.
    name: String,
    /// Forwarded to the adapter session configuration. `None` for users
    /// authenticated via a session cookie or the `x-materialize-user` header.
    external_metadata_rx: Option<watch::Receiver<ExternalUserMetadata>>,
}
545
/// An adapter session client opened on behalf of an authenticated HTTP user.
pub struct AuthedClient {
    /// Client returned by starting up the adapter session.
    pub client: SessionClient,
    /// Handle obtained from the active connection counter when the client was
    /// created; `None` when the counter returned no handle.
    pub connection_guard: Option<ConnectionHandle>,
}
550
551impl AuthedClient {
552    async fn new<F>(
553        adapter_client: &Client,
554        user: AuthedUser,
555        peer_addr: IpAddr,
556        active_connection_counter: ConnectionCounter,
557        helm_chart_version: Option<String>,
558        session_config: F,
559        options: BTreeMap<String, String>,
560        now: NowFn,
561    ) -> Result<Self, AdapterError>
562    where
563        F: FnOnce(&mut AdapterSession),
564    {
565        let conn_id = adapter_client.new_conn_id()?;
566        let mut session = adapter_client.new_session(AdapterSessionConfig {
567            conn_id,
568            uuid: epoch_to_uuid_v7(&(now)()),
569            user: user.name,
570            client_ip: Some(peer_addr),
571            external_metadata_rx: user.external_metadata_rx,
572            //TODO(dov): Add support for internal user metadata when we support auth here
573            internal_user_metadata: None,
574            helm_chart_version,
575        });
576        let connection_guard = active_connection_counter.allocate_connection(session.user())?;
577
578        session_config(&mut session);
579        let system_vars = adapter_client.get_system_vars().await;
580        for (key, val) in options {
581            const LOCAL: bool = false;
582            if let Err(err) =
583                session
584                    .vars_mut()
585                    .set(&system_vars, &key, VarInput::Flat(&val), LOCAL)
586            {
587                session.add_notice(AdapterNotice::BadStartupSetting {
588                    name: key.to_string(),
589                    reason: err.to_string(),
590                })
591            }
592        }
593        let adapter_client = adapter_client.startup(session).await?;
594        Ok(AuthedClient {
595            client: adapter_client,
596            connection_guard,
597        })
598    }
599}
600
601#[async_trait]
602impl<S> FromRequestParts<S> for AuthedClient
603where
604    S: Send + Sync,
605{
606    type Rejection = Response;
607
608    async fn from_request_parts(
609        req: &mut http::request::Parts,
610        state: &S,
611    ) -> Result<Self, Self::Rejection> {
612        #[derive(Debug, Default, Deserialize)]
613        struct Params {
614            #[serde(default)]
615            options: String,
616        }
617        let params: Query<Params> = Query::from_request_parts(req, state)
618            .await
619            .unwrap_or_default();
620
621        let peer_addr = req
622            .extensions
623            .get::<ConnectInfo<SocketAddr>>()
624            .expect("ConnectInfo extension guaranteed to exist")
625            .0
626            .ip();
627
628        let user = req.extensions.get::<AuthedUser>().unwrap();
629        let adapter_client = req
630            .extensions
631            .get::<Delayed<mz_adapter::Client>>()
632            .unwrap()
633            .clone();
634        let adapter_client = adapter_client.await.map_err(|_| {
635            (StatusCode::INTERNAL_SERVER_ERROR, "adapter client missing").into_response()
636        })?;
637        let active_connection_counter = req.extensions.get::<ConnectionCounter>().unwrap();
638        let helm_chart_version = None;
639
640        let options = if params.options.is_empty() {
641            // It's possible 'options' simply wasn't provided, we don't want that to
642            // count as a failure to deserialize
643            BTreeMap::<String, String>::default()
644        } else {
645            match serde_json::from_str(&params.options) {
646                Ok(options) => options,
647                Err(_e) => {
648                    // If we fail to deserialize options, fail the request.
649                    let code = StatusCode::BAD_REQUEST;
650                    let msg = format!("Failed to deserialize {} map", "options".quoted());
651                    return Err((code, msg).into_response());
652                }
653            }
654        };
655
656        let client = AuthedClient::new(
657            &adapter_client,
658            user.clone(),
659            peer_addr,
660            active_connection_counter.clone(),
661            helm_chart_version,
662            |session| {
663                session
664                    .vars_mut()
665                    .set_default(WELCOME_MESSAGE.name(), VarInput::Flat(&false.format()))
666                    .expect("known to exist")
667            },
668            options,
669            SYSTEM_TIME.clone(),
670        )
671        .await
672        .map_err(|e| {
673            let status = match e {
674                AdapterError::UserSessionsDisallowed | AdapterError::NetworkPolicyDenied(_) => {
675                    StatusCode::FORBIDDEN
676                }
677                _ => StatusCode::INTERNAL_SERVER_ERROR,
678            };
679            (status, Json(SqlError::from(e))).into_response()
680        })?;
681
682        Ok(client)
683    }
684}
685
/// Reasons an HTTP request can fail authentication.
///
/// The `Display` text of each variant is what gets logged. The text sent to
/// the client is decided by the `IntoResponse` implementation, which hides
/// most of this detail.
#[derive(Debug, Error)]
enum AuthError {
    /// TLS is enabled but the request arrived over plain HTTP.
    #[error("HTTPS is required")]
    HttpsRequired,
    // NOTE(review): the `String` payload is not included in the message.
    #[error("invalid username in client certificate")]
    InvalidLogin(String),
    /// Error converted from the Frontegg authentication library.
    #[error("{0}")]
    Frontegg(#[from] FronteggError),
    #[error("missing authorization header")]
    MissingHttpAuthentication,
    /// The requested user is not acceptable; the payload is the full message.
    #[error("{0}")]
    MismatchedUser(String),
    /// The cookie session was idle for longer than `SESSION_DURATION`.
    #[error("session expired")]
    SessionExpired,
    /// Refreshing the session's `last_activity` in the session store failed.
    #[error("failed to update session")]
    FailedToUpdateSession,
}
703
704impl IntoResponse for AuthError {
705    fn into_response(self) -> Response {
706        warn!("HTTP request failed authentication: {}", self);
707        // We omit most detail from the error message we send to the client, to
708        // avoid giving attackers unnecessary information.
709        let message = match self {
710            AuthError::HttpsRequired => self.to_string(),
711            _ => "unauthorized".into(),
712        };
713        (
714            StatusCode::UNAUTHORIZED,
715            [(http::header::WWW_AUTHENTICATE, "Basic realm=Materialize")],
716            message,
717        )
718            .into_response()
719    }
720}
721
722// Simplified login handler
723pub async fn handle_login(
724    session: Option<Extension<TowerSession>>,
725    Extension(adapter_client_rx): Extension<Delayed<Client>>,
726    Json(LoginCredentials { username, password }): Json<LoginCredentials>,
727) -> impl IntoResponse {
728    let Ok(adapter_client) = adapter_client_rx.clone().await else {
729        return StatusCode::INTERNAL_SERVER_ERROR;
730    };
731    if let Err(err) = adapter_client.authenticate(&username, &password).await {
732        warn!(?err, "HTTP login failed authentication");
733        return StatusCode::UNAUTHORIZED;
734    };
735
736    // Create session data
737    let session_data = TowerSessionData {
738        username,
739        created_at: SystemTime::now(),
740        last_activity: SystemTime::now(),
741    };
742    // Store session data
743    let session = session.and_then(|Extension(session)| Some(session));
744    let Some(session) = session else {
745        return StatusCode::INTERNAL_SERVER_ERROR;
746    };
747    match session.insert("data", &session_data).await {
748        Err(_) => StatusCode::INTERNAL_SERVER_ERROR,
749        Ok(_) => StatusCode::OK,
750    }
751}
752
753// Simplified logout handler
754pub async fn handle_logout(session: Option<Extension<TowerSession>>) -> impl IntoResponse {
755    let session = session.and_then(|Extension(session)| Some(session));
756    let Some(session) = session else {
757        return StatusCode::INTERNAL_SERVER_ERROR;
758    };
759    // Delete session
760    match session.delete().await {
761        Err(_) => StatusCode::INTERNAL_SERVER_ERROR,
762        Ok(_) => StatusCode::OK,
763    }
764}
765
/// Authentication middleware applied to HTTP routes.
///
/// The request is authenticated by the first mechanism that applies:
///
/// 1. **Cookie session**: if a `TowerSession` extension holds a `"data"`
///    record, the session is checked against `SESSION_DURATION`, its
///    `last_activity` is refreshed, and the stored username is used.
/// 2. **Pre-authenticated**: if an earlier layer already attached an
///    [`AuthedUser`], it is kept as is.
/// 3. **Authorization header**: `Basic` or `Bearer` credentials (or none) are
///    handed to `auth` together with `allowed_roles`.
///
/// Mechanisms 2 and 3 are only reached after verifying that the connection's
/// protocol matches the TLS configuration.
///
/// On success the [`AuthedUser`] is stored in the request extensions and the
/// request is forwarded to `next`; on failure an [`AuthError`] is returned.
async fn http_auth(
    mut req: Request,
    next: Next,
    tls_enabled: bool,
    authenticator: Arc<Authenticator>,
    allowed_roles: AllowedRoles,
) -> impl IntoResponse + use<> {
    // First check for session authentication
    if let Some(session) = req.extensions().get::<TowerSession>() {
        // A session without a readable "data" record (or a store error) is
        // not an error here: fall through to the other mechanisms.
        if let Ok(Some(session_data)) = session.get::<TowerSessionData>("data").await {
            // Check session expiration. If the elapsed time cannot be
            // computed, treat the session as expired.
            if session_data
                .last_activity
                .elapsed()
                .unwrap_or(Duration::MAX)
                > SESSION_DURATION
            {
                let _ = session.delete().await;
                return Err(AuthError::SessionExpired);
            }
            // Update last activity
            let mut updated_data = session_data.clone();
            updated_data.last_activity = SystemTime::now();
            session
                .insert("data", &updated_data)
                .await
                .map_err(|_| AuthError::FailedToUpdateSession)?;
            // User is authenticated via session
            req.extensions_mut().insert(AuthedUser {
                name: session_data.username,
                external_metadata_rx: None,
            });
            return Ok(next.run(req).await);
        }
    }

    // No usable session: validate that the connection's protocol matches the
    // TLS configuration, then fall back to the other authentication methods.
    let conn_protocol = req.extensions().get::<ConnProtocol>().unwrap();
    match (tls_enabled, &conn_protocol) {
        (false, ConnProtocol::Http) => {}
        // Without a TLS context, connections are always tagged as HTTP.
        (false, ConnProtocol::Https { .. }) => unreachable!(),
        (true, ConnProtocol::Http) => return Err(AuthError::HttpsRequired),
        (true, ConnProtocol::Https { .. }) => {}
    }
    // If we've already passed some other auth, just use that.
    if req.extensions().get::<AuthedUser>().is_some() {
        return Ok(next.run(req).await);
    }
    // Basic credentials take precedence over a bearer token; with neither
    // header present, `auth` is called without credentials.
    let creds = if let Some(basic) = req.headers().typed_get::<Authorization<Basic>>() {
        Some(Credentials::Password {
            username: basic.username().to_owned(),
            password: Password(basic.password().to_owned()),
        })
    } else if let Some(bearer) = req.headers().typed_get::<Authorization<Bearer>>() {
        Some(Credentials::Token {
            token: bearer.token().to_owned(),
        })
    } else {
        None
    };

    let user = auth(&authenticator, creds, allowed_roles).await?;

    // Add the authenticated user as an extension so downstream handlers can
    // inspect it if necessary.
    req.extensions_mut().insert(user);

    // Run the request.
    Ok(next.run(req).await)
}
838
839async fn init_ws(
840    WsState {
841        authenticator_rx,
842        adapter_client_rx,
843        active_connection_counter,
844        helm_chart_version,
845        allowed_roles,
846    }: &WsState,
847    existing_user: Option<AuthedUser>,
848    peer_addr: IpAddr,
849    ws: &mut WebSocket,
850) -> Result<AuthedClient, anyhow::Error> {
851    let authenticator = authenticator_rx.clone().await.expect("sender not dropped");
852    // TODO: Add a timeout here to prevent resource leaks by clients that
853    // connect then never send a message.
854    let init_msg = ws.recv().await.ok_or_else(|| anyhow::anyhow!("closed"))??;
855    let ws_auth: WebSocketAuth = loop {
856        match init_msg {
857            Message::Text(data) => break serde_json::from_str(&data)?,
858            Message::Binary(data) => break serde_json::from_slice(&data)?,
859            // Handled automatically by the server.
860            Message::Ping(_) => {
861                continue;
862            }
863            Message::Pong(_) => {
864                continue;
865            }
866            Message::Close(_) => {
867                anyhow::bail!("closed");
868            }
869        }
870    };
871
872    let (user, options) = if let Some(existing_user) = existing_user {
873        match ws_auth {
874            WebSocketAuth::OptionsOnly { options } => (existing_user, options),
875            _ => {
876                warn!("Unexpected bearer or basic auth provided when using user header");
877                anyhow::bail!("unexpected")
878            }
879        }
880    } else {
881        let (creds, options) = match ws_auth {
882            WebSocketAuth::Basic {
883                user,
884                password,
885                options,
886            } => {
887                let creds = Credentials::Password {
888                    username: user,
889                    password,
890                };
891                (creds, options)
892            }
893            WebSocketAuth::Bearer { token, options } => {
894                let creds = Credentials::Token { token };
895                (creds, options)
896            }
897            WebSocketAuth::OptionsOnly { .. } => {
898                anyhow::bail!("expected auth information");
899            }
900        };
901        let user = auth(&authenticator, Some(creds), *allowed_roles).await?;
902        (user, options)
903    };
904
905    let client = AuthedClient::new(
906        &adapter_client_rx.clone().await?,
907        user,
908        peer_addr,
909        active_connection_counter.clone(),
910        helm_chart_version.clone(),
911        |_session| (),
912        options,
913        SYSTEM_TIME.clone(),
914    )
915    .await?;
916
917    Ok(client)
918}
919
/// Credentials presented by a client, before they have been verified.
enum Credentials {
    /// A username and password pair, built from an HTTP Basic `Authorization`
    /// header or from a `WebSocketAuth::Basic` message.
    Password {
        username: String,
        password: Password,
    },
    /// A token, built from an HTTP Bearer `Authorization` header or from a
    /// `WebSocketAuth::Bearer` message.
    Token {
        token: String,
    },
}
929
930async fn auth(
931    authenticator: &Authenticator,
932    creds: Option<Credentials>,
933    allowed_roles: AllowedRoles,
934) -> Result<AuthedUser, AuthError> {
935    // TODO pass session data here?
936    let (name, external_metadata_rx) = match authenticator {
937        Authenticator::Frontegg(frontegg) => match creds {
938            Some(Credentials::Password { username, password }) => {
939                let auth_session = frontegg.authenticate(&username, &password.0).await?;
940                let name = auth_session.user().into();
941                let external_metadata_rx = Some(auth_session.external_metadata_rx());
942                (name, external_metadata_rx)
943            }
944            Some(Credentials::Token { token }) => {
945                let claims = frontegg.validate_access_token(&token, None)?;
946                let (_, external_metadata_rx) = watch::channel(ExternalUserMetadata {
947                    user_id: claims.user_id,
948                    admin: claims.is_admin,
949                });
950                (claims.user, Some(external_metadata_rx))
951            }
952            None => return Err(AuthError::MissingHttpAuthentication),
953        },
954        Authenticator::Password(_) => {
955            warn!("self hosted auth only, but somehow missing session data");
956            // TODO tell the user to login here?
957            return Err(AuthError::MissingHttpAuthentication);
958        }
959        Authenticator::None => {
960            // If no authentication, use whatever is in the HTTP auth
961            // header (without checking the password), or fall back to the
962            // default user.
963            let name = match creds {
964                Some(Credentials::Password { username, .. }) => username,
965                _ => HTTP_DEFAULT_USER.name.to_owned(),
966            };
967            (name, None)
968        }
969    };
970
971    check_role_allowed(&name, allowed_roles)?;
972
973    Ok(AuthedUser {
974        name,
975        external_metadata_rx,
976    })
977}
978
979// TODO move this somewhere it can be shared with PGWIRE
980fn check_role_allowed(name: &str, allowed_roles: AllowedRoles) -> Result<(), AuthError> {
981    let is_internal_user = INTERNAL_USER_NAMES.contains(name);
982    // this is a superset of internal users
983    let is_reserved_user = mz_adapter::catalog::is_reserved_role_name(name);
984    let role_allowed = match allowed_roles {
985        AllowedRoles::Normal => !is_reserved_user,
986        AllowedRoles::Internal => is_internal_user,
987        AllowedRoles::NormalAndInternal => !is_reserved_user || is_internal_user,
988    };
989    if role_allowed {
990        Ok(())
991    } else {
992        Err(AuthError::InvalidLogin(name.to_owned()))
993    }
994}
995
/// Default layers that should be applied to all routes, and should get applied to both the
/// internal http and external http routers.
trait DefaultLayers {
    /// Returns `self` with the default layers applied.
    ///
    /// The `Router` implementation adds a request body size limit and a
    /// Prometheus metrics layer built from `source` and `metrics`.
    fn apply_default_layers(self, source: &'static str, metrics: Metrics) -> Self;
}
1001
1002impl DefaultLayers for Router {
1003    fn apply_default_layers(self, source: &'static str, metrics: Metrics) -> Self {
1004        self.layer(DefaultBodyLimit::max(MAX_REQUEST_SIZE))
1005            .layer(metrics::PrometheusLayer::new(source, metrics))
1006    }
1007}
1008
1009/// Glue code to make [`tower`] work with [`axum`].
1010///
1011/// `axum` requires `Layer`s not return Errors, i.e. they must be `Result<_, Infallible>`,
1012/// instead you must return a type that can be converted into a response. `tower` on the other
1013/// hand does return Errors, so to make the two work together we need to convert our `tower` errors
1014/// into responses.
1015async fn handle_load_error(error: tower::BoxError) -> impl IntoResponse {
1016    if error.is::<tower::load_shed::error::Overloaded>() {
1017        return (
1018            StatusCode::TOO_MANY_REQUESTS,
1019            Cow::from("too many requests, try again later"),
1020        );
1021    }
1022
1023    // Note: This should be unreachable because at the time of writing our only use case is a
1024    // layer that emits `tower::load_shed::error::Overloaded`, which is handled above.
1025    (
1026        StatusCode::INTERNAL_SERVER_ERROR,
1027        Cow::from(format!("Unhandled internal error: {}", error)),
1028    )
1029}
1030
/// A username and password pair that can be serialized and deserialized.
///
/// NOTE(review): presumably the body of a login request -- the handler that
/// consumes it is not visible here; confirm.
#[derive(Debug, Deserialize, Serialize, PartialEq)]
pub struct LoginCredentials {
    username: String,
    password: Password,
}
1036
/// Per-session state, stored in a `TowerSession` under the `"data"` key.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TowerSessionData {
    /// Name of the user the session authenticates as.
    username: String,
    /// NOTE(review): presumably when the session was created -- it is not
    /// read by the session check in `http_auth`; confirm against the login
    /// handler.
    created_at: SystemTime,
    /// Time of the most recent request authenticated through this session.
    /// `http_auth` compares it with `SESSION_DURATION` to expire idle
    /// sessions and refreshes it on each successful check.
    last_activity: SystemTime,
}
1043
#[cfg(test)]
mod tests {
    use super::{AllowedRoles, check_role_allowed};

    /// Checks `check_role_allowed` for every combination of a representative
    /// role name and an `AllowedRoles` policy.
    #[mz_ore::test]
    fn test_check_role_allowed() {
        // (role name, ok under Internal, ok under NormalAndInternal, ok under Normal)
        let cases = [
            // Internal users
            ("mz_system", true, true, false),
            ("mz_support", true, true, false),
            ("mz_analytics", true, true, false),
            // Normal users
            ("materialize", false, true, true),
            ("anonymous_http_user", false, true, true),
            ("alex", false, true, true),
            // Denied by reserved role prefix
            ("external_asdf", false, false, false),
            ("pg_somebody", false, false, false),
            ("mz_unknown", false, false, false),
            // Denied by literal PUBLIC
            ("PUBLIC", false, false, false),
        ];

        for (name, internal, normal_and_internal, normal) in cases {
            assert_eq!(
                check_role_allowed(name, AllowedRoles::Internal).is_ok(),
                internal,
                "role {name} under AllowedRoles::Internal"
            );
            assert_eq!(
                check_role_allowed(name, AllowedRoles::NormalAndInternal).is_ok(),
                normal_and_internal,
                "role {name} under AllowedRoles::NormalAndInternal"
            );
            assert_eq!(
                check_role_allowed(name, AllowedRoles::Normal).is_ok(),
                normal,
                "role {name} under AllowedRoles::Normal"
            );
        }
    }
}