Skip to main content

mz_environmentd/
http.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Embedded HTTP server.
11//!
12//! environmentd embeds an HTTP server for introspection into the running
13//! process. At the moment, its primary exports are Prometheus metrics, heap
14//! profiles, and catalog dumps.
15
16// Axum handlers must use async, but often don't actually use `await`.
17#![allow(clippy::unused_async)]
18
19use std::borrow::Cow;
20use std::collections::BTreeMap;
21use std::fmt::Debug;
22use std::net::{IpAddr, SocketAddr};
23use std::pin::Pin;
24use std::sync::Arc;
25use std::time::{Duration, SystemTime};
26
27use anyhow::Context;
28use axum::error_handling::HandleErrorLayer;
29use axum::extract::ws::{Message, WebSocket};
30use axum::extract::{ConnectInfo, DefaultBodyLimit, FromRequestParts, Query, Request, State};
31use axum::middleware::{self, Next};
32use axum::response::{IntoResponse, Redirect, Response};
33use axum::{Extension, Json, Router, routing};
34use futures::future::{Shared, TryFutureExt};
35use headers::authorization::{Authorization, Basic, Bearer};
36use headers::{HeaderMapExt, HeaderName};
37use http::header::{AUTHORIZATION, CONTENT_TYPE};
38use http::uri::Scheme;
39use http::{HeaderMap, HeaderValue, Method, StatusCode, Uri};
40use hyper_openssl::SslStream;
41use hyper_openssl::client::legacy::MaybeHttpsStream;
42use hyper_util::rt::TokioIo;
43use mz_adapter::session::{Session as AdapterSession, SessionConfig as AdapterSessionConfig};
44use mz_adapter::{AdapterError, AdapterNotice, Client, SessionClient, WebhookAppenderCache};
45use mz_auth::password::Password;
46use mz_authenticator::Authenticator;
47use mz_controller::ReplicaHttpLocator;
48use mz_frontegg_auth::Error as FronteggError;
49use mz_http_util::DynamicFilterTarget;
50use mz_ore::cast::u64_to_usize;
51use mz_ore::metrics::MetricsRegistry;
52use mz_ore::now::{NowFn, SYSTEM_TIME, epoch_to_uuid_v7};
53use mz_ore::str::StrExt;
54use mz_pgwire_common::{ConnectionCounter, ConnectionHandle};
55use mz_repr::user::{ExternalUserMetadata, InternalUserMetadata};
56use mz_server_core::listeners::{AllowedRoles, AuthenticatorKind, HttpRoutesEnabled};
57use mz_server_core::{Connection, ConnectionHandler, ReloadingSslContext, Server};
58use mz_sql::session::metadata::SessionMetadata;
59use mz_sql::session::user::{
60    HTTP_DEFAULT_USER, INTERNAL_USER_NAMES, SUPPORT_USER_NAME, SYSTEM_USER_NAME,
61};
62use mz_sql::session::vars::{Value, Var, VarInput, WELCOME_MESSAGE};
63use openssl::ssl::Ssl;
64use prometheus::{
65    COMPUTE_METRIC_QUERIES, FRONTIER_METRIC_QUERIES, STORAGE_METRIC_QUERIES, USAGE_METRIC_QUERIES,
66};
67use serde::{Deserialize, Serialize};
68use serde_json::json;
69use thiserror::Error;
70use tokio::io::AsyncWriteExt;
71use tokio::sync::oneshot::Receiver;
72use tokio::sync::{oneshot, watch};
73use tokio_metrics::TaskMetrics;
74use tower::limit::GlobalConcurrencyLimitLayer;
75use tower::{Service, ServiceBuilder};
76use tower_http::cors::{AllowOrigin, Any, CorsLayer};
77use tower_sessions::{
78    MemoryStore as TowerSessionMemoryStore, Session as TowerSession,
79    SessionManagerLayer as TowerSessionManagerLayer,
80};
81use tracing::warn;
82
83use crate::BUILD_INFO;
84use crate::deployment::state::DeploymentStateHandle;
85use crate::http::sql::SqlError;
86
87mod catalog;
88mod cluster;
89mod console;
90mod memory;
91mod metrics;
92mod probe;
93mod prometheus;
94mod root;
95mod sql;
96mod webhook;
97
98pub use metrics::Metrics;
99pub use sql::{SqlResponse, WebSocketAuth, WebSocketResponse};
100
/// Maximum allowed size for a request (5 MiB).
pub const MAX_REQUEST_SIZE: usize = u64_to_usize(5 * bytesize::MIB);

/// Idle lifetime of a cookie-backed login session.
///
/// `http_auth` compares the time elapsed since a session's `last_activity`
/// against this value; once exceeded, the session is deleted and the request
/// is rejected with `AuthError::SessionExpired`.
const SESSION_DURATION: Duration = Duration::from_secs(8 * 3600); // 8 hours

/// Paths (or path prefixes) of the memory/heap profiling endpoints registered
/// by `HttpServer::new`.
// NOTE(review): the consumer of this list is not visible in this part of the
// file; presumably it is used to gate access to these routes -- confirm.
const PROFILING_API_ENDPOINTS: &[&str] = &["/memory", "/hierarchical-memory", "/prof/"];
107
/// Everything [`HttpServer::new`] needs to assemble the HTTP router.
#[derive(Debug)]
pub struct HttpConfig {
    /// Forwarded to `apply_default_layers` when the router is finalized.
    // NOTE(review): presumably a label identifying the listener in
    // logs/metrics -- confirm against `apply_default_layers`.
    pub source: &'static str,
    /// TLS context. When `Some`, incoming connections are wrapped in TLS and
    /// session cookies are marked `Secure`.
    pub tls: Option<ReloadingSslContext>,
    /// Selects the extra authentication plumbing: `Password` adds the session
    /// layer and the login/logout routes, `None` adds the
    /// `x-materialize-user` header middleware.
    pub authenticator_kind: AuthenticatorKind,
    /// Resolves to the authenticator used by the auth middleware and the
    /// WebSocket route.
    pub authenticator_rx: Shared<Receiver<Arc<Authenticator>>>,
    /// Resolves to the adapter client; installed as a request extension and
    /// handed to the WebSocket and webhook states.
    pub adapter_client_rx: Shared<Receiver<Client>>,
    /// Origin policy for the CORS layer on the base routes.
    pub allowed_origin: AllowOrigin,
    /// Connection counter installed as a request extension and shared with the
    /// WebSocket state.
    pub active_connection_counter: ConnectionCounter,
    /// Helm chart version, passed to the WebSocket state.
    pub helm_chart_version: Option<String>,
    /// Semaphore backing the global concurrency limit on the webhook route.
    pub concurrent_webhook_req: Arc<tokio::sync::Semaphore>,
    /// HTTP server metrics, forwarded to `apply_default_layers`.
    pub metrics: Metrics,
    /// Registry rendered by the `/metrics` endpoint.
    pub metrics_registry: MetricsRegistry,
    /// Role restriction handed to `http_auth` and to the WebSocket state.
    pub allowed_roles: AllowedRoles,
    /// Configuration consumed by the internal routes (console proxy and
    /// leader/deployment endpoints).
    pub internal_route_config: Arc<InternalRouteConfig>,
    /// Flags selecting which groups of routes are mounted (`base`,
    /// `profiling`, `webhook`, `internal`, `metrics`).
    pub routes_enabled: HttpRoutesEnabled,
    /// Locator for cluster replica HTTP addresses, used for proxying requests.
    pub replica_http_locator: Arc<ReplicaHttpLocator>,
}
127
/// Configuration used only by the internal route group.
#[derive(Debug, Clone)]
pub struct InternalRouteConfig {
    /// State handle backing the `/api/leader/*` endpoints.
    pub deployment_state_handle: DeploymentStateHandle,
    /// Optional URL passed to the console proxy configuration that serves
    /// `/internal-console`.
    pub internal_console_redirect_url: Option<String>,
}
133
/// Router state for the WebSocket SQL endpoint (`/api/experimental/sql`).
#[derive(Clone)]
pub struct WsState {
    /// Resolves to the authenticator once it becomes available.
    authenticator_rx: Delayed<Arc<Authenticator>>,
    /// Resolves to the adapter client once it becomes available.
    adapter_client_rx: Delayed<mz_adapter::Client>,
    /// Connection counter shared with the rest of the HTTP server.
    active_connection_counter: ConnectionCounter,
    /// Helm chart version taken from `HttpConfig`.
    helm_chart_version: Option<String>,
    /// Role restriction taken from `HttpConfig`.
    allowed_roles: AllowedRoles,
}
142
/// Router state for the webhook endpoint (`/api/webhook/...`).
#[derive(Clone)]
pub struct WebhookState {
    /// Resolves to the adapter client once it becomes available.
    adapter_client_rx: Delayed<mz_adapter::Client>,
    /// Appender cache created once in `HttpServer::new` and cloned with this
    /// state.
    webhook_cache: WebhookAppenderCache,
}
148
/// The embedded HTTP server: a fully assembled router plus optional TLS.
#[derive(Debug)]
pub struct HttpServer {
    /// TLS context; when `Some`, each accepted connection performs a TLS
    /// handshake before being served.
    tls: Option<ReloadingSslContext>,
    /// Router cloned for every accepted connection.
    router: Router,
}
154
impl HttpServer {
    /// Assembles the router for the HTTP server from `HttpConfig`.
    ///
    /// Two routers are built side by side:
    ///
    /// * `base_router` collects the routes that all share the same stack
    ///   (auth middleware, adapter-client/connection-counter extensions and
    ///   the CORS policy built from `allowed_origin`), which is applied once
    ///   near the end.
    /// * `router` collects route groups that bring their own layers
    ///   (WebSocket, webhook, leader, metrics, login) and finally absorbs
    ///   `base_router`.
    ///
    /// Which groups are mounted is controlled by `routes_enabled`; additional
    /// authentication plumbing depends on `authenticator_kind`.
    pub fn new(
        HttpConfig {
            source,
            tls,
            authenticator_kind,
            authenticator_rx,
            adapter_client_rx,
            allowed_origin,
            active_connection_counter,
            helm_chart_version,
            concurrent_webhook_req,
            metrics,
            metrics_registry,
            allowed_roles,
            internal_route_config,
            routes_enabled,
            replica_http_locator,
        }: HttpConfig,
    ) -> HttpServer {
        let tls_enabled = tls.is_some();
        let webhook_cache = WebhookAppenderCache::new();

        // Create secure session store and manager
        let session_store = TowerSessionMemoryStore::default();
        let session_layer = TowerSessionManagerLayer::new(session_store)
            .with_secure(tls_enabled) // Mark the cookie `Secure` only when TLS is on
            .with_same_site(tower_sessions::cookie::SameSite::Strict) // Prevent CSRF
            .with_http_only(true) // Prevent XSS
            .with_name("mz_session") // Custom cookie name
            .with_path("/"); // Set cookie path

        // Middleware that waits for the authenticator to be delivered and then
        // delegates to `http_auth`. It is cloned onto every router that needs
        // authenticated requests.
        let auth_middleware_authenticator_rx = authenticator_rx.clone();
        let auth_middleware = middleware::from_fn(move |req, next| {
            let authenticator_rx = auth_middleware_authenticator_rx.clone();
            async move {
                let authenticator = authenticator_rx
                    .await
                    .expect("sender not dropped before sending once");
                http_auth(req, next, tls_enabled, authenticator, allowed_roles).await
            }
        });

        let mut router = Router::new();
        let mut base_router = Router::new();
        if routes_enabled.base {
            base_router = base_router
                .route(
                    "/",
                    routing::get(move || async move { root::handle_home(routes_enabled).await }),
                )
                .route("/api/sql", routing::post(sql::handle_sql))
                .route("/memory", routing::get(memory::handle_memory))
                .route(
                    "/hierarchical-memory",
                    routing::get(memory::handle_hierarchical_memory),
                )
                .route("/static/{*path}", routing::get(root::handle_static));

            // The WebSocket endpoint carries its own state and is merged into
            // `router` directly, so `auth_middleware` is not layered onto it
            // here.
            // NOTE(review): presumably the handler authenticates using the
            // `authenticator_rx` in `WsState` -- confirm in `sql::handle_sql_ws`.
            let mut ws_router = Router::new()
                .route("/api/experimental/sql", routing::get(sql::handle_sql_ws))
                .with_state(WsState {
                    authenticator_rx: authenticator_rx.clone(),
                    adapter_client_rx: adapter_client_rx.clone(),
                    active_connection_counter: active_connection_counter.clone(),
                    helm_chart_version,
                    allowed_roles,
                });
            if let AuthenticatorKind::None = authenticator_kind {
                ws_router = ws_router.layer(middleware::from_fn(x_materialize_user_header_auth));
            }
            router = router.merge(ws_router);
        }
        if routes_enabled.profiling {
            base_router = base_router.nest("/prof/", mz_prof_http::router(&BUILD_INFO));
        }

        if routes_enabled.webhook {
            // NOTE(review): the path parameters are spelled `{:name}` rather
            // than `{name}`; this looks like it only affects the captured
            // parameter names, not matching -- confirm against the axum
            // version in use and how `handle_webhook` extracts them.
            let webhook_router = Router::new()
                .route(
                    "/api/webhook/{:database}/{:schema}/{:id}",
                    routing::post(webhook::handle_webhook),
                )
                .with_state(WebhookState {
                    adapter_client_rx: adapter_client_rx.clone(),
                    webhook_cache,
                })
                // Accept compressed request bodies in any of these encodings.
                .layer(
                    tower_http::decompression::RequestDecompressionLayer::new()
                        .gzip(true)
                        .deflate(true)
                        .br(true)
                        .zstd(true),
                )
                // Webhooks use their own CORS policy (POST only, request
                // origin mirrored) instead of `allowed_origin`.
                .layer(
                    CorsLayer::new()
                        .allow_methods(Method::POST)
                        .allow_origin(AllowOrigin::mirror_request())
                        .allow_headers(Any),
                )
                // Limit concurrent webhook requests with the shared semaphore;
                // errors surfaced by the load-shedding layer are turned into
                // responses by `handle_load_error`.
                .layer(
                    ServiceBuilder::new()
                        .layer(HandleErrorLayer::new(handle_load_error))
                        .load_shed()
                        .layer(GlobalConcurrencyLimitLayer::with_semaphore(
                            concurrent_webhook_req,
                        )),
                );
            router = router.merge(webhook_router);
        }

        if routes_enabled.internal {
            let console_config = Arc::new(console::ConsoleProxyConfig::new(
                internal_route_config.internal_console_redirect_url.clone(),
                "/internal-console".to_string(),
            ));
            base_router = base_router
                // The two dynamic-filter endpoints below are retired: they
                // still accept the old payload but always answer 400 with a
                // pointer to the replacement system variable.
                .route(
                    "/api/opentelemetry/config",
                    routing::put({
                        move |_: axum::Json<DynamicFilterTarget>| async {
                            (
                                StatusCode::BAD_REQUEST,
                                "This endpoint has been replaced. \
                            Use the `opentelemetry_filter` system variable."
                                    .to_string(),
                            )
                        }
                    }),
                )
                .route(
                    "/api/stderr/config",
                    routing::put({
                        move |_: axum::Json<DynamicFilterTarget>| async {
                            (
                                StatusCode::BAD_REQUEST,
                                "This endpoint has been replaced. \
                            Use the `log_filter` system variable."
                                    .to_string(),
                            )
                        }
                    }),
                )
                .route("/api/tracing", routing::get(mz_http_util::handle_tracing))
                .route(
                    "/api/catalog/dump",
                    routing::get(catalog::handle_catalog_dump),
                )
                .route(
                    "/api/catalog/check",
                    routing::get(catalog::handle_catalog_check),
                )
                .route(
                    "/api/coordinator/check",
                    routing::get(catalog::handle_coordinator_check),
                )
                .route(
                    "/api/coordinator/dump",
                    routing::get(catalog::handle_coordinator_dump),
                )
                // Redirect the bare path to the trailing-slash form served by
                // the console handler.
                .route(
                    "/internal-console",
                    routing::get(|| async { Redirect::temporary("/internal-console/") }),
                )
                .route(
                    "/internal-console/{*path}",
                    routing::get(console::handle_internal_console),
                )
                .route(
                    "/internal-console/",
                    routing::get(console::handle_internal_console),
                )
                .layer(Extension(console_config));

            // Cluster HTTP proxy routes.
            let cluster_proxy_config = Arc::new(cluster::ClusterProxyConfig::new(Arc::clone(
                &replica_http_locator,
            )));
            base_router = base_router
                .route("/clusters", routing::get(cluster::handle_clusters))
                .route(
                    "/api/cluster/{:cluster_id}/replica/{:replica_id}/process/{:process}/",
                    routing::any(cluster::handle_cluster_proxy_root),
                )
                .route(
                    "/api/cluster/{:cluster_id}/replica/{:replica_id}/process/{:process}/{*path}",
                    routing::any(cluster::handle_cluster_proxy),
                )
                .layer(Extension(cluster_proxy_config));

            // Leader endpoints need the deployment state handle as router
            // state, so they live in their own router with their own copy of
            // the auth middleware.
            let leader_router = Router::new()
                .route("/api/leader/status", routing::get(handle_leader_status))
                .route("/api/leader/promote", routing::post(handle_leader_promote))
                .route(
                    "/api/leader/skip-catchup",
                    routing::post(handle_leader_skip_catchup),
                )
                .layer(auth_middleware.clone())
                .with_state(internal_route_config.deployment_state_handle.clone());
            router = router.merge(leader_router);
        }

        if routes_enabled.metrics {
            // `/metrics` renders the process registry directly; the
            // `/metrics/mz_*` endpoints build a registry from SQL queries run
            // through an `AuthedClient`. Unlike `base_router`, this group gets
            // no CORS layer.
            let metrics_router = Router::new()
                .route(
                    "/metrics",
                    routing::get(move || async move {
                        mz_http_util::handle_prometheus(&metrics_registry).await
                    }),
                )
                .route(
                    "/metrics/mz_usage",
                    routing::get(|client: AuthedClient| async move {
                        let registry = sql::handle_promsql(client, USAGE_METRIC_QUERIES).await;
                        mz_http_util::handle_prometheus(&registry).await
                    }),
                )
                .route(
                    "/metrics/mz_frontier",
                    routing::get(|client: AuthedClient| async move {
                        let registry = sql::handle_promsql(client, FRONTIER_METRIC_QUERIES).await;
                        mz_http_util::handle_prometheus(&registry).await
                    }),
                )
                .route(
                    "/metrics/mz_compute",
                    routing::get(|client: AuthedClient| async move {
                        let registry = sql::handle_promsql(client, COMPUTE_METRIC_QUERIES).await;
                        mz_http_util::handle_prometheus(&registry).await
                    }),
                )
                .route(
                    "/metrics/mz_storage",
                    routing::get(|client: AuthedClient| async move {
                        let registry = sql::handle_promsql(client, STORAGE_METRIC_QUERIES).await;
                        mz_http_util::handle_prometheus(&registry).await
                    }),
                )
                .route(
                    "/api/livez",
                    routing::get(mz_http_util::handle_liveness_check),
                )
                .route("/api/readyz", routing::get(probe::handle_ready))
                .layer(auth_middleware.clone())
                .layer(Extension(adapter_client_rx.clone()))
                .layer(Extension(active_connection_counter.clone()));
            router = router.merge(metrics_router);
        }

        // Apply the shared stack to every route collected in `base_router`.
        // The two extensions are what the `AuthedClient` extractor reads.
        base_router = base_router
            .layer(auth_middleware.clone())
            .layer(Extension(adapter_client_rx.clone()))
            .layer(Extension(active_connection_counter.clone()))
            .layer(
                CorsLayer::new()
                    .allow_credentials(false)
                    .allow_headers([
                        AUTHORIZATION,
                        CONTENT_TYPE,
                        HeaderName::from_static("x-materialize-version"),
                    ])
                    .allow_methods(Any)
                    .allow_origin(allowed_origin)
                    .expose_headers(Any)
                    .max_age(Duration::from_secs(60) * 60),
            );

        match authenticator_kind {
            AuthenticatorKind::Password => {
                // Password auth uses cookie sessions: the session layer wraps
                // the base routes, and also everything merged into `router`
                // so far together with the login/logout routes.
                base_router = base_router.layer(session_layer.clone());

                let login_router = Router::new()
                    .route("/api/login", routing::post(handle_login))
                    .route("/api/logout", routing::post(handle_logout))
                    .layer(Extension(adapter_client_rx));
                router = router.merge(login_router).layer(session_layer);
            }
            AuthenticatorKind::None => {
                // Added after `auth_middleware`, so this layer wraps it and
                // sees the request first.
                base_router =
                    base_router.layer(middleware::from_fn(x_materialize_user_header_auth));
            }
            _ => {}
        }

        router = router
            .merge(base_router)
            .apply_default_layers(source, metrics);

        HttpServer { tls, router }
    }
}
446
447impl Server for HttpServer {
448    const NAME: &'static str = "http";
449
450    fn handle_connection(
451        &self,
452        conn: Connection,
453        _tokio_metrics_intervals: impl Iterator<Item = TaskMetrics> + Send + 'static,
454    ) -> ConnectionHandler {
455        let router = self.router.clone();
456        let tls_context = self.tls.clone();
457        let mut conn = TokioIo::new(conn);
458
459        Box::pin(async {
460            let direct_peer_addr = conn.inner().peer_addr().context("fetching peer addr")?;
461            let peer_addr = conn
462                .inner_mut()
463                .take_proxy_header_address()
464                .await
465                .map(|a| a.source)
466                .unwrap_or(direct_peer_addr);
467
468            let (conn, conn_protocol) = match tls_context {
469                Some(tls_context) => {
470                    let mut ssl_stream = SslStream::new(Ssl::new(&tls_context.get())?, conn)?;
471                    if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
472                        let _ = ssl_stream.get_mut().inner_mut().shutdown().await;
473                        return Err(e.into());
474                    }
475                    (MaybeHttpsStream::Https(ssl_stream), ConnProtocol::Https)
476                }
477                _ => (MaybeHttpsStream::Http(conn), ConnProtocol::Http),
478            };
479            let mut make_tower_svc = router
480                .layer(Extension(conn_protocol))
481                .into_make_service_with_connect_info::<SocketAddr>();
482            let tower_svc = make_tower_svc.call(peer_addr).await.unwrap();
483            let hyper_svc = hyper::service::service_fn(|req| tower_svc.clone().call(req));
484            let http = hyper::server::conn::http1::Builder::new();
485            http.serve_connection(conn, hyper_svc)
486                .with_upgrades()
487                .err_into()
488                .await
489        })
490    }
491}
492
493pub async fn handle_leader_status(
494    State(deployment_state_handle): State<DeploymentStateHandle>,
495) -> impl IntoResponse {
496    let status = deployment_state_handle.status();
497    (StatusCode::OK, Json(json!({ "status": status })))
498}
499
500pub async fn handle_leader_promote(
501    State(deployment_state_handle): State<DeploymentStateHandle>,
502) -> impl IntoResponse {
503    match deployment_state_handle.try_promote() {
504        Ok(()) => {
505            // TODO(benesch): the body here is redundant. Should just return
506            // 204.
507            let status = StatusCode::OK;
508            let body = Json(json!({
509                "result": "Success",
510            }));
511            (status, body)
512        }
513        Err(()) => {
514            // TODO(benesch): the nesting here is redundant given the error
515            // code. Should just return the `{"message": "..."}` object.
516            let status = StatusCode::BAD_REQUEST;
517            let body = Json(json!({
518                "result": {"Failure": {"message": "cannot promote leader while initializing"}},
519            }));
520            (status, body)
521        }
522    }
523}
524
525pub async fn handle_leader_skip_catchup(
526    State(deployment_state_handle): State<DeploymentStateHandle>,
527) -> impl IntoResponse {
528    match deployment_state_handle.try_skip_catchup() {
529        Ok(()) => StatusCode::NO_CONTENT.into_response(),
530        Err(()) => {
531            let status = StatusCode::BAD_REQUEST;
532            let body = Json(json!({
533                "message": "cannot skip catchup in this phase of initialization; try again later",
534            }));
535            (status, body).into_response()
536        }
537    }
538}
539
540async fn x_materialize_user_header_auth(mut req: Request, next: Next) -> impl IntoResponse {
541    // TODO migrate teleport to basic auth and remove this.
542    if let Some(username) = req.headers().get("x-materialize-user").map(|h| h.to_str()) {
543        let username = match username {
544            Ok(name @ (SUPPORT_USER_NAME | SYSTEM_USER_NAME)) => name.to_string(),
545            _ => {
546                return Err(AuthError::MismatchedUser(format!(
547                    "user specified in x-materialize-user must be {SUPPORT_USER_NAME} or {SYSTEM_USER_NAME}"
548                )));
549            }
550        };
551        let superuser = matches!(username.as_str(), SYSTEM_USER_NAME);
552        req.extensions_mut().insert(AuthedUser {
553            name: username,
554            external_metadata_rx: None,
555            internal_metadata: Some(InternalUserMetadata { superuser }),
556        });
557    }
558    Ok(next.run(req).await)
559}
560
/// A value delivered once over a oneshot channel, wrapped in `Shared` so the
/// receiving future can be cloned and awaited from many places.
type Delayed<T> = Shared<oneshot::Receiver<T>>;
562
/// Protocol a connection was accepted over.
///
/// `handle_connection` attaches this as a request extension: `Https` when the
/// server has a TLS context, `Http` otherwise.
#[derive(Clone)]
enum ConnProtocol {
    /// Plaintext connection.
    Http,
    /// Connection wrapped in TLS.
    Https,
}
568
/// Identity attached to a request (as an extension) once it has been
/// authenticated; consumed by the `AuthedClient` extractor.
#[derive(Clone, Debug)]
pub struct AuthedUser {
    /// User name the adapter session is opened as.
    name: String,
    /// Optional receiver of external user metadata, forwarded to the adapter
    /// session config. The session-cookie and `x-materialize-user` paths set
    /// this to `None`.
    external_metadata_rx: Option<watch::Receiver<ExternalUserMetadata>>,
    /// Optional internal metadata (e.g. the superuser flag), forwarded to the
    /// adapter session config.
    internal_metadata: Option<InternalUserMetadata>,
}
575
/// An adapter session client opened on behalf of an authenticated HTTP user.
pub struct AuthedClient {
    /// Session client returned by the adapter's `startup`.
    pub client: SessionClient,
    /// Value returned by the connection counter's `allocate_connection` for
    /// this session, kept alongside the client.
    // NOTE(review): presumably releases the connection slot when dropped --
    // confirm against `ConnectionHandle`.
    pub connection_guard: Option<ConnectionHandle>,
}
580
581impl AuthedClient {
582    async fn new<F>(
583        adapter_client: &Client,
584        user: AuthedUser,
585        peer_addr: IpAddr,
586        active_connection_counter: ConnectionCounter,
587        helm_chart_version: Option<String>,
588        session_config: F,
589        options: BTreeMap<String, String>,
590        now: NowFn,
591    ) -> Result<Self, AdapterError>
592    where
593        F: FnOnce(&mut AdapterSession),
594    {
595        let conn_id = adapter_client.new_conn_id()?;
596        let mut session = adapter_client.new_session(AdapterSessionConfig {
597            conn_id,
598            uuid: epoch_to_uuid_v7(&(now)()),
599            user: user.name,
600            client_ip: Some(peer_addr),
601            external_metadata_rx: user.external_metadata_rx,
602            internal_user_metadata: user.internal_metadata,
603            helm_chart_version,
604        });
605        let connection_guard = active_connection_counter.allocate_connection(session.user())?;
606
607        session_config(&mut session);
608        let system_vars = adapter_client.get_system_vars().await;
609        for (key, val) in options {
610            const LOCAL: bool = false;
611            if let Err(err) =
612                session
613                    .vars_mut()
614                    .set(&system_vars, &key, VarInput::Flat(&val), LOCAL)
615            {
616                session.add_notice(AdapterNotice::BadStartupSetting {
617                    name: key.to_string(),
618                    reason: err.to_string(),
619                })
620            }
621        }
622        let adapter_client = adapter_client.startup(session).await?;
623        Ok(AuthedClient {
624            client: adapter_client,
625            connection_guard,
626        })
627    }
628}
629
630impl<S> FromRequestParts<S> for AuthedClient
631where
632    S: Send + Sync,
633{
634    type Rejection = Response;
635
636    async fn from_request_parts(
637        req: &mut http::request::Parts,
638        state: &S,
639    ) -> Result<Self, Self::Rejection> {
640        #[derive(Debug, Default, Deserialize)]
641        struct Params {
642            #[serde(default)]
643            options: String,
644        }
645        let params: Query<Params> = Query::from_request_parts(req, state)
646            .await
647            .unwrap_or_default();
648
649        let peer_addr = req
650            .extensions
651            .get::<ConnectInfo<SocketAddr>>()
652            .expect("ConnectInfo extension guaranteed to exist")
653            .0
654            .ip();
655
656        let user = req.extensions.get::<AuthedUser>().unwrap();
657        let adapter_client = req
658            .extensions
659            .get::<Delayed<mz_adapter::Client>>()
660            .unwrap()
661            .clone();
662        let adapter_client = adapter_client.await.map_err(|_| {
663            (StatusCode::INTERNAL_SERVER_ERROR, "adapter client missing").into_response()
664        })?;
665        let active_connection_counter = req.extensions.get::<ConnectionCounter>().unwrap();
666        let helm_chart_version = None;
667
668        let options = if params.options.is_empty() {
669            // It's possible 'options' simply wasn't provided, we don't want that to
670            // count as a failure to deserialize
671            BTreeMap::<String, String>::default()
672        } else {
673            match serde_json::from_str(&params.options) {
674                Ok(options) => options,
675                Err(_e) => {
676                    // If we fail to deserialize options, fail the request.
677                    let code = StatusCode::BAD_REQUEST;
678                    let msg = format!("Failed to deserialize {} map", "options".quoted());
679                    return Err((code, msg).into_response());
680                }
681            }
682        };
683
684        let client = AuthedClient::new(
685            &adapter_client,
686            user.clone(),
687            peer_addr,
688            active_connection_counter.clone(),
689            helm_chart_version,
690            |session| {
691                session
692                    .vars_mut()
693                    .set_default(WELCOME_MESSAGE.name(), VarInput::Flat(&false.format()))
694                    .expect("known to exist")
695            },
696            options,
697            SYSTEM_TIME.clone(),
698        )
699        .await
700        .map_err(|e| {
701            let status = match e {
702                AdapterError::UserSessionsDisallowed | AdapterError::NetworkPolicyDenied(_) => {
703                    StatusCode::FORBIDDEN
704                }
705                _ => StatusCode::INTERNAL_SERVER_ERROR,
706            };
707            (status, Json(SqlError::from(e))).into_response()
708        })?;
709
710        Ok(client)
711    }
712}
713
714#[derive(Debug, Error)]
715enum AuthError {
716    #[error("role dissallowed")]
717    RoleDisallowed(String),
718    #[error("{0}")]
719    Frontegg(#[from] FronteggError),
720    #[error("missing authorization header")]
721    MissingHttpAuthentication {
722        include_www_authenticate_header: bool,
723    },
724    #[error("{0}")]
725    MismatchedUser(String),
726    #[error("session expired")]
727    SessionExpired,
728    #[error("failed to update session")]
729    FailedToUpdateSession,
730    #[error("invalid credentials")]
731    InvalidCredentials,
732}
733
734impl IntoResponse for AuthError {
735    fn into_response(self) -> Response {
736        warn!("HTTP request failed authentication: {}", self);
737        let mut headers = HeaderMap::new();
738        match self {
739            AuthError::MissingHttpAuthentication {
740                include_www_authenticate_header,
741            } if include_www_authenticate_header => {
742                headers.insert(
743                    http::header::WWW_AUTHENTICATE,
744                    HeaderValue::from_static("Basic realm=Materialize"),
745                );
746            }
747            _ => {}
748        };
749        // We omit most detail from the error message we send to the client, to
750        // avoid giving attackers unnecessary information.
751        (StatusCode::UNAUTHORIZED, headers, "unauthorized").into_response()
752    }
753}
754
755// Simplified login handler
756pub async fn handle_login(
757    session: Option<Extension<TowerSession>>,
758    Extension(adapter_client_rx): Extension<Delayed<Client>>,
759    Json(LoginCredentials { username, password }): Json<LoginCredentials>,
760) -> impl IntoResponse {
761    let Ok(adapter_client) = adapter_client_rx.clone().await else {
762        return StatusCode::INTERNAL_SERVER_ERROR;
763    };
764    let auth_response = match adapter_client.authenticate(&username, &password).await {
765        Ok(auth_response) => auth_response,
766        Err(err) => {
767            warn!(?err, "HTTP login failed authentication");
768            return StatusCode::UNAUTHORIZED;
769        }
770    };
771
772    // Create session data
773    let session_data = TowerSessionData {
774        username,
775        created_at: SystemTime::now(),
776        last_activity: SystemTime::now(),
777        internal_metadata: InternalUserMetadata {
778            superuser: auth_response.superuser,
779        },
780    };
781    // Store session data
782    let session = session.and_then(|Extension(session)| Some(session));
783    let Some(session) = session else {
784        return StatusCode::INTERNAL_SERVER_ERROR;
785    };
786    match session.insert("data", &session_data).await {
787        Err(_) => StatusCode::INTERNAL_SERVER_ERROR,
788        Ok(_) => StatusCode::OK,
789    }
790}
791
792// Simplified logout handler
793pub async fn handle_logout(session: Option<Extension<TowerSession>>) -> impl IntoResponse {
794    let session = session.and_then(|Extension(session)| Some(session));
795    let Some(session) = session else {
796        return StatusCode::INTERNAL_SERVER_ERROR;
797    };
798    // Delete session
799    match session.delete().await {
800        Err(_) => StatusCode::INTERNAL_SERVER_ERROR,
801        Ok(_) => StatusCode::OK,
802    }
803}
804
805async fn http_auth(
806    mut req: Request,
807    next: Next,
808    tls_enabled: bool,
809    authenticator: Arc<Authenticator>,
810    allowed_roles: AllowedRoles,
811) -> impl IntoResponse + use<> {
812    // First check for session authentication
813    if let Some(session) = req.extensions().get::<TowerSession>() {
814        if let Ok(Some(session_data)) = session.get::<TowerSessionData>("data").await {
815            // Check session expiration
816            if session_data
817                .last_activity
818                .elapsed()
819                .unwrap_or(Duration::MAX)
820                > SESSION_DURATION
821            {
822                let _ = session.delete().await;
823                return Err(AuthError::SessionExpired);
824            }
825            // Update last activity
826            let mut updated_data = session_data.clone();
827            updated_data.last_activity = SystemTime::now();
828            session
829                .insert("data", &updated_data)
830                .await
831                .map_err(|_| AuthError::FailedToUpdateSession)?;
832            // User is authenticated via session
833            req.extensions_mut().insert(AuthedUser {
834                name: session_data.username,
835                external_metadata_rx: None,
836                internal_metadata: Some(session_data.internal_metadata),
837            });
838            return Ok(next.run(req).await);
839        }
840    }
841
842    // First, extract the username from the certificate, validating that the
843    // connection matches the TLS configuration along the way.
844    // Fall back to existing authentication methods.
845    let conn_protocol = req.extensions().get::<ConnProtocol>().unwrap();
846    match (tls_enabled, &conn_protocol) {
847        (false, ConnProtocol::Http) => {}
848        (false, ConnProtocol::Https { .. }) => unreachable!(),
849        (true, ConnProtocol::Http) => {
850            let mut parts = req.uri().clone().into_parts();
851            parts.scheme = Some(Scheme::HTTPS);
852            return Ok(Redirect::permanent(
853                &Uri::from_parts(parts)
854                    .expect("it was already a URI, just changed the scheme")
855                    .to_string(),
856            )
857            .into_response());
858        }
859        (true, ConnProtocol::Https { .. }) => {}
860    }
861    // If we've already passed some other auth, just use that.
862    if req.extensions().get::<AuthedUser>().is_some() {
863        return Ok(next.run(req).await);
864    }
865    let creds = if let Some(basic) = req.headers().typed_get::<Authorization<Basic>>() {
866        Some(Credentials::Password {
867            username: basic.username().to_owned(),
868            password: Password(basic.password().to_owned()),
869        })
870    } else if let Some(bearer) = req.headers().typed_get::<Authorization<Bearer>>() {
871        Some(Credentials::Token {
872            token: bearer.token().to_owned(),
873        })
874    } else {
875        None
876    };
877
878    let path = req.uri().path();
879    let include_www_authenticate_header = path == "/"
880        || PROFILING_API_ENDPOINTS
881            .iter()
882            .any(|prefix| path.starts_with(prefix));
883    let user = auth(
884        &authenticator,
885        creds,
886        allowed_roles,
887        include_www_authenticate_header,
888    )
889    .await?;
890
891    // Add the authenticated user as an extension so downstream handlers can
892    // inspect it if necessary.
893    req.extensions_mut().insert(user);
894
895    // Run the request.
896    Ok(next.run(req).await)
897}
898
899async fn init_ws(
900    WsState {
901        authenticator_rx,
902        adapter_client_rx,
903        active_connection_counter,
904        helm_chart_version,
905        allowed_roles,
906    }: &WsState,
907    existing_user: Option<AuthedUser>,
908    peer_addr: IpAddr,
909    ws: &mut WebSocket,
910) -> Result<AuthedClient, anyhow::Error> {
911    let authenticator = authenticator_rx.clone().await.expect("sender not dropped");
912    // TODO: Add a timeout here to prevent resource leaks by clients that
913    // connect then never send a message.
914    let ws_auth: WebSocketAuth = loop {
915        let init_msg = ws.recv().await.ok_or_else(|| anyhow::anyhow!("closed"))??;
916        match init_msg {
917            Message::Text(data) => break serde_json::from_str(&data)?,
918            Message::Binary(data) => break serde_json::from_slice(&data)?,
919            // Handled automatically by the server.
920            Message::Ping(_) => {
921                continue;
922            }
923            Message::Pong(_) => {
924                continue;
925            }
926            Message::Close(_) => {
927                anyhow::bail!("closed");
928            }
929        }
930    };
931
932    let (user, options) = if let Some(existing_user) = existing_user {
933        match ws_auth {
934            WebSocketAuth::OptionsOnly { options } => (existing_user, options),
935            _ => {
936                warn!("Unexpected bearer or basic auth provided when using user header");
937                anyhow::bail!("unexpected")
938            }
939        }
940    } else {
941        let (creds, options) = match ws_auth {
942            WebSocketAuth::Basic {
943                user,
944                password,
945                options,
946            } => {
947                let creds = Credentials::Password {
948                    username: user,
949                    password,
950                };
951                (creds, options)
952            }
953            WebSocketAuth::Bearer { token, options } => {
954                let creds = Credentials::Token { token };
955                (creds, options)
956            }
957            WebSocketAuth::OptionsOnly { .. } => {
958                anyhow::bail!("expected auth information");
959            }
960        };
961        let user = auth(&authenticator, Some(creds), *allowed_roles, false).await?;
962        (user, options)
963    };
964
965    let client = AuthedClient::new(
966        &adapter_client_rx.clone().await?,
967        user,
968        peer_addr,
969        active_connection_counter.clone(),
970        helm_chart_version.clone(),
971        |_session| (),
972        options,
973        SYSTEM_TIME.clone(),
974    )
975    .await?;
976
977    Ok(client)
978}
979
980enum Credentials {
981    Password {
982        username: String,
983        password: Password,
984    },
985    Token {
986        token: String,
987    },
988}
989
990async fn auth(
991    authenticator: &Authenticator,
992    creds: Option<Credentials>,
993    allowed_roles: AllowedRoles,
994    include_www_authenticate_header: bool,
995) -> Result<AuthedUser, AuthError> {
996    let (name, external_metadata_rx, internal_metadata) = match authenticator {
997        Authenticator::Frontegg(frontegg) => match creds {
998            Some(Credentials::Password { username, password }) => {
999                let auth_session = frontegg.authenticate(&username, &password.0).await?;
1000                let name = auth_session.user().into();
1001                let external_metadata_rx = Some(auth_session.external_metadata_rx());
1002                (name, external_metadata_rx, None)
1003            }
1004            Some(Credentials::Token { token }) => {
1005                let claims = frontegg.validate_access_token(&token, None)?;
1006                let (_, external_metadata_rx) = watch::channel(ExternalUserMetadata {
1007                    user_id: claims.user_id,
1008                    admin: claims.is_admin,
1009                });
1010                (claims.user, Some(external_metadata_rx), None)
1011            }
1012            None => {
1013                return Err(AuthError::MissingHttpAuthentication {
1014                    include_www_authenticate_header,
1015                });
1016            }
1017        },
1018        Authenticator::Password(adapter_client) => match creds {
1019            Some(Credentials::Password { username, password }) => {
1020                let auth_response = adapter_client
1021                    .authenticate(&username, &password)
1022                    .await
1023                    .map_err(|_| AuthError::InvalidCredentials)?;
1024                let internal_metadata = InternalUserMetadata {
1025                    superuser: auth_response.superuser,
1026                };
1027                (username, None, Some(internal_metadata))
1028            }
1029            _ => {
1030                return Err(AuthError::MissingHttpAuthentication {
1031                    include_www_authenticate_header,
1032                });
1033            }
1034        },
1035        Authenticator::Sasl(_) => {
1036            // We shouldn't ever end up here as the configuration is validated at startup.
1037            // If we do, it's a server misconfiguration.
1038            // Just in case, we return a 401 rather than panic.
1039            return Err(AuthError::MissingHttpAuthentication {
1040                include_www_authenticate_header,
1041            });
1042        }
1043        Authenticator::None => {
1044            // If no authentication, use whatever is in the HTTP auth
1045            // header (without checking the password), or fall back to the
1046            // default user.
1047            let name = match creds {
1048                Some(Credentials::Password { username, .. }) => username,
1049                _ => HTTP_DEFAULT_USER.name.to_owned(),
1050            };
1051            (name, None, None)
1052        }
1053    };
1054
1055    check_role_allowed(&name, allowed_roles)?;
1056
1057    Ok(AuthedUser {
1058        name,
1059        external_metadata_rx,
1060        internal_metadata,
1061    })
1062}
1063
1064// TODO move this somewhere it can be shared with PGWIRE
1065fn check_role_allowed(name: &str, allowed_roles: AllowedRoles) -> Result<(), AuthError> {
1066    let is_internal_user = INTERNAL_USER_NAMES.contains(name);
1067    // this is a superset of internal users
1068    let is_reserved_user = mz_adapter::catalog::is_reserved_role_name(name);
1069    let role_allowed = match allowed_roles {
1070        AllowedRoles::Normal => !is_reserved_user,
1071        AllowedRoles::Internal => is_internal_user,
1072        AllowedRoles::NormalAndInternal => !is_reserved_user || is_internal_user,
1073    };
1074    if role_allowed {
1075        Ok(())
1076    } else {
1077        Err(AuthError::RoleDisallowed(name.to_owned()))
1078    }
1079}
1080
1081/// Default layers that should be applied to all routes, and should get applied to both the
1082/// internal http and external http routers.
1083trait DefaultLayers {
1084    fn apply_default_layers(self, source: &'static str, metrics: Metrics) -> Self;
1085}
1086
1087impl DefaultLayers for Router {
1088    fn apply_default_layers(self, source: &'static str, metrics: Metrics) -> Self {
1089        self.layer(DefaultBodyLimit::max(MAX_REQUEST_SIZE))
1090            .layer(metrics::PrometheusLayer::new(source, metrics))
1091    }
1092}
1093
1094/// Glue code to make [`tower`] work with [`axum`].
1095///
1096/// `axum` requires `Layer`s not return Errors, i.e. they must be `Result<_, Infallible>`,
1097/// instead you must return a type that can be converted into a response. `tower` on the other
1098/// hand does return Errors, so to make the two work together we need to convert our `tower` errors
1099/// into responses.
1100async fn handle_load_error(error: tower::BoxError) -> impl IntoResponse {
1101    if error.is::<tower::load_shed::error::Overloaded>() {
1102        return (
1103            StatusCode::TOO_MANY_REQUESTS,
1104            Cow::from("too many requests, try again later"),
1105        );
1106    }
1107
1108    // Note: This should be unreachable because at the time of writing our only use case is a
1109    // layer that emits `tower::load_shed::error::Overloaded`, which is handled above.
1110    (
1111        StatusCode::INTERNAL_SERVER_ERROR,
1112        Cow::from(format!("Unhandled internal error: {}", error)),
1113    )
1114}
1115
1116#[derive(Debug, Deserialize, Serialize, PartialEq)]
1117pub struct LoginCredentials {
1118    username: String,
1119    password: Password,
1120}
1121
1122#[derive(Debug, Clone, Serialize, Deserialize)]
1123pub struct TowerSessionData {
1124    username: String,
1125    created_at: SystemTime,
1126    last_activity: SystemTime,
1127    internal_metadata: InternalUserMetadata,
1128}
1129
#[cfg(test)]
mod tests {
    use super::{AllowedRoles, check_role_allowed};

    /// Exercises `check_role_allowed` for every combination of a sample role
    /// name and an `AllowedRoles` setting.
    #[mz_ore::test]
    fn test_check_role_allowed() {
        // Each row: (role name, allowed by `Internal`, allowed by
        // `NormalAndInternal`, allowed by `Normal`).
        let cases = [
            // Internal users
            ("mz_system", true, true, false),
            ("mz_support", true, true, false),
            ("mz_analytics", true, true, false),
            // Normal users
            ("materialize", false, true, true),
            ("anonymous_http_user", false, true, true),
            ("alex", false, true, true),
            // Denied by reserved role prefix
            ("external_asdf", false, false, false),
            ("pg_somebody", false, false, false),
            ("mz_unknown", false, false, false),
            // Denied by literal PUBLIC
            ("PUBLIC", false, false, false),
        ];

        for (role, internal, normal_and_internal, normal) in cases {
            assert_eq!(
                check_role_allowed(role, AllowedRoles::Internal).is_ok(),
                internal,
                "{role} with AllowedRoles::Internal"
            );
            assert_eq!(
                check_role_allowed(role, AllowedRoles::NormalAndInternal).is_ok(),
                normal_and_internal,
                "{role} with AllowedRoles::NormalAndInternal"
            );
            assert_eq!(
                check_role_allowed(role, AllowedRoles::Normal).is_ok(),
                normal,
                "{role} with AllowedRoles::Normal"
            );
        }
    }
}