mz_environmentd/
http.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Embedded HTTP server.
11//!
12//! environmentd embeds an HTTP server for introspection into the running
13//! process. At the moment, its primary exports are Prometheus metrics, heap
14//! profiles, and catalog dumps.
15
16// Axum handlers must use async, but often don't actually use `await`.
17#![allow(clippy::unused_async)]
18
19use std::borrow::Cow;
20use std::collections::BTreeMap;
21use std::fmt::Debug;
22use std::net::{IpAddr, SocketAddr};
23use std::pin::Pin;
24use std::sync::Arc;
25use std::time::Duration;
26
27use anyhow::Context;
28use async_trait::async_trait;
29use axum::error_handling::HandleErrorLayer;
30use axum::extract::ws::{Message, WebSocket};
31use axum::extract::{ConnectInfo, DefaultBodyLimit, FromRequestParts, Query, Request, State};
32use axum::middleware::{self, Next};
33use axum::response::{IntoResponse, Redirect, Response};
34use axum::{Extension, Json, Router, routing};
35use futures::future::{FutureExt, Shared, TryFutureExt};
36use headers::authorization::{Authorization, Basic, Bearer};
37use headers::{HeaderMapExt, HeaderName};
38use http::header::{AUTHORIZATION, CONTENT_TYPE};
39use http::{Method, StatusCode};
40use hyper_openssl::SslStream;
41use hyper_openssl::client::legacy::MaybeHttpsStream;
42use hyper_util::rt::TokioIo;
43use mz_adapter::session::{Session, SessionConfig};
44use mz_adapter::{AdapterError, AdapterNotice, Client, SessionClient, WebhookAppenderCache};
45use mz_frontegg_auth::{Authenticator as FronteggAuthentication, Error as FronteggError};
46use mz_http_util::DynamicFilterTarget;
47use mz_ore::cast::u64_to_usize;
48use mz_ore::metrics::MetricsRegistry;
49use mz_ore::now::{NowFn, SYSTEM_TIME, epoch_to_uuid_v7};
50use mz_ore::str::StrExt;
51use mz_pgwire_common::{ConnectionCounter, ConnectionHandle};
52use mz_repr::user::ExternalUserMetadata;
53use mz_server_core::{Connection, ConnectionHandler, ReloadingSslContext, Server};
54use mz_sql::session::metadata::SessionMetadata;
55use mz_sql::session::user::{HTTP_DEFAULT_USER, SUPPORT_USER_NAME, SYSTEM_USER_NAME};
56use mz_sql::session::vars::{Value, Var, VarInput, WELCOME_MESSAGE};
57use openssl::ssl::Ssl;
58use prometheus::{
59    COMPUTE_METRIC_QUERIES, FRONTIER_METRIC_QUERIES, STORAGE_METRIC_QUERIES, USAGE_METRIC_QUERIES,
60};
61use serde::Deserialize;
62use serde_json::json;
63use thiserror::Error;
64use tokio::io::AsyncWriteExt;
65use tokio::sync::{oneshot, watch};
66use tower::limit::GlobalConcurrencyLimitLayer;
67use tower::{Service, ServiceBuilder};
68use tower_http::cors::{AllowOrigin, Any, CorsLayer};
69use tracing::{error, warn};
70
71use crate::BUILD_INFO;
72use crate::deployment::state::DeploymentStateHandle;
73use crate::http::sql::SqlError;
74
75mod catalog;
76mod console;
77mod memory;
78mod metrics;
79mod probe;
80mod prometheus;
81mod root;
82mod sql;
83mod webhook;
84
85pub use metrics::Metrics;
86pub use sql::{SqlResponse, WebSocketAuth, WebSocketResponse};
87
/// Maximum allowed size for a request.
///
/// Evaluated at compile time as two `bytesize::MB` units converted to `usize`.
pub const MAX_REQUEST_SIZE: usize = u64_to_usize(2 * bytesize::MB);
90
/// Configuration consumed by [`HttpServer::new`].
#[derive(Debug, Clone)]
pub struct HttpConfig {
    /// Label forwarded to `apply_default_layers` when the final router is built.
    pub source: &'static str,
    /// TLS settings. When `None`, connections are served without a TLS handshake.
    pub tls: Option<ReloadingTlsConfig>,
    /// Frontegg authenticator. When `None`, `http_auth` takes the username from
    /// the HTTP Basic header without checking the password, or falls back to
    /// the default user.
    pub frontegg: Option<FronteggAuthentication>,
    /// Adapter client handed to the request handlers and the webhook router.
    pub adapter_client: mz_adapter::Client,
    /// Origin policy installed on the base router's CORS layer.
    pub allowed_origin: AllowOrigin,
    /// Counter from which each authenticated session allocates a connection.
    pub active_connection_counter: ConnectionCounter,
    /// Helm chart version forwarded to sessions opened over the WebSocket API.
    pub helm_chart_version: Option<String>,
    /// Semaphore that bounds the number of in-flight webhook requests.
    pub concurrent_webhook_req: Arc<tokio::sync::Semaphore>,
    /// Metrics handle passed to `apply_default_layers`.
    pub metrics: Metrics,
}
103
/// TLS settings for the HTTP server: an SSL context plus the enforcement mode.
#[derive(Debug, Clone)]
pub struct ReloadingTlsConfig {
    /// SSL context; `get()` is called on it for every new connection.
    pub context: ReloadingSslContext,
    /// Whether requests must arrive over HTTPS.
    pub mode: TlsMode,
}
109
/// How `http_auth` treats the protocol a request arrived on.
#[derive(Debug, Clone, Copy)]
pub enum TlsMode {
    /// Plain HTTP requests are accepted.
    Disable,
    /// Plain HTTP requests are rejected with `AuthError::HttpsRequired`.
    Require,
}
115
/// Router state for the `/api/experimental/sql` WebSocket route; consumed by
/// `init_ws` to authenticate and open a session.
#[derive(Clone)]
pub struct WsState {
    /// Frontegg authenticator, if any. The internal server always sets this to
    /// `None`.
    frontegg: Arc<Option<FronteggAuthentication>>,
    /// Shared future that resolves to the adapter client.
    adapter_client_rx: Delayed<mz_adapter::Client>,
    /// Counter from which the session's connection is allocated.
    active_connection_counter: ConnectionCounter,
    /// Helm chart version forwarded into the session configuration.
    helm_chart_version: Option<String>,
}
123
/// Router state for the `/api/webhook/...` route.
#[derive(Clone)]
pub struct WebhookState {
    /// Adapter client available to the webhook handler.
    adapter_client: mz_adapter::Client,
    /// Appender cache created once in `HttpServer::new` and stored in this state.
    webhook_cache: WebhookAppenderCache,
}
129
/// The externally facing HTTP server: a fully assembled router plus the
/// optional TLS configuration applied to each incoming connection.
#[derive(Debug)]
pub struct HttpServer {
    /// TLS settings; when `Some`, every connection performs a TLS handshake.
    tls: Option<ReloadingTlsConfig>,
    /// Merged base, WebSocket and webhook routers.
    router: Router,
}
135
136impl HttpServer {
137    pub fn new(
138        HttpConfig {
139            source,
140            tls,
141            frontegg,
142            adapter_client,
143            allowed_origin,
144            active_connection_counter,
145            helm_chart_version,
146            concurrent_webhook_req,
147            metrics,
148        }: HttpConfig,
149    ) -> HttpServer {
150        let tls_mode = tls.as_ref().map(|tls| tls.mode).unwrap_or(TlsMode::Disable);
151        let frontegg = Arc::new(frontegg);
152        let base_frontegg = Arc::clone(&frontegg);
153        let (adapter_client_tx, adapter_client_rx) = oneshot::channel();
154        adapter_client_tx
155            .send(adapter_client.clone())
156            .expect("rx known to be live");
157        let adapter_client_rx = adapter_client_rx.shared();
158        let webhook_cache = WebhookAppenderCache::new();
159
160        let base_router = base_router(BaseRouterConfig { profiling: false })
161            .layer(middleware::from_fn(move |req, next| {
162                let base_frontegg = Arc::clone(&base_frontegg);
163                async move { http_auth(req, next, tls_mode, base_frontegg.as_ref().as_ref()).await }
164            }))
165            .layer(Extension(adapter_client_rx.clone()))
166            .layer(Extension(active_connection_counter.clone()))
167            .layer(
168                CorsLayer::new()
169                    .allow_credentials(false)
170                    .allow_headers([
171                        AUTHORIZATION,
172                        CONTENT_TYPE,
173                        HeaderName::from_static("x-materialize-version"),
174                    ])
175                    .allow_methods(Any)
176                    .allow_origin(allowed_origin)
177                    .expose_headers(Any)
178                    .max_age(Duration::from_secs(60) * 60),
179            );
180        let ws_router = Router::new()
181            .route("/api/experimental/sql", routing::get(sql::handle_sql_ws))
182            .with_state(WsState {
183                frontegg,
184                adapter_client_rx,
185                active_connection_counter,
186                helm_chart_version,
187            });
188
189        let webhook_router = Router::new()
190            .route(
191                "/api/webhook/{:database}/{:schema}/{:id}",
192                routing::post(webhook::handle_webhook),
193            )
194            .with_state(WebhookState {
195                adapter_client,
196                webhook_cache,
197            })
198            .layer(
199                CorsLayer::new()
200                    .allow_methods(Method::POST)
201                    .allow_origin(AllowOrigin::mirror_request())
202                    .allow_headers(Any),
203            )
204            .layer(
205                ServiceBuilder::new()
206                    .layer(HandleErrorLayer::new(handle_load_error))
207                    .load_shed()
208                    .layer(GlobalConcurrencyLimitLayer::with_semaphore(
209                        concurrent_webhook_req,
210                    )),
211            );
212
213        let router = Router::new()
214            .merge(base_router)
215            .merge(ws_router)
216            .merge(webhook_router)
217            .apply_default_layers(source, metrics);
218
219        HttpServer { tls, router }
220    }
221}
222
223impl Server for HttpServer {
224    const NAME: &'static str = "http";
225
226    fn handle_connection(&self, conn: Connection) -> ConnectionHandler {
227        let router = self.router.clone();
228        let tls_config = self.tls.clone();
229        let mut conn = TokioIo::new(conn);
230        Box::pin(async {
231            let direct_peer_addr = conn.inner().peer_addr().context("fetching peer addr")?;
232            let peer_addr = conn
233                .inner_mut()
234                .take_proxy_header_address()
235                .await
236                .map(|a| a.source)
237                .unwrap_or(direct_peer_addr);
238
239            let (conn, conn_protocol) = match tls_config {
240                Some(tls_config) => {
241                    let mut ssl_stream =
242                        SslStream::new(Ssl::new(&tls_config.context.get())?, conn)?;
243                    if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
244                        let _ = ssl_stream.get_mut().inner_mut().shutdown().await;
245                        return Err(e.into());
246                    }
247                    (MaybeHttpsStream::Https(ssl_stream), ConnProtocol::Https)
248                }
249                _ => (MaybeHttpsStream::Http(conn), ConnProtocol::Http),
250            };
251            let mut make_tower_svc = router
252                .layer(Extension(conn_protocol))
253                .into_make_service_with_connect_info::<SocketAddr>();
254            let tower_svc = make_tower_svc.call(peer_addr).await.unwrap();
255            let hyper_svc = hyper::service::service_fn(|req| tower_svc.clone().call(req));
256            let http = hyper::server::conn::http1::Builder::new();
257            http.serve_connection(conn, hyper_svc)
258                .with_upgrades()
259                .err_into()
260                .await
261        })
262    }
263}
264
/// Configuration consumed by [`InternalHttpServer::new`].
pub struct InternalHttpConfig {
    /// Registry served at `/metrics`; the server's own metrics are also
    /// registered into it under the `mz_internal_http` name.
    pub metrics_registry: MetricsRegistry,
    /// One-shot receiver that yields the adapter client.
    pub adapter_client_rx: oneshot::Receiver<mz_adapter::Client>,
    /// Counter from which each authenticated session allocates a connection.
    pub active_connection_counter: ConnectionCounter,
    /// Helm chart version forwarded to sessions opened over the WebSocket API.
    pub helm_chart_version: Option<String>,
    /// Handle backing the `/api/leader/*` endpoints.
    pub deployment_state_handle: DeploymentStateHandle,
    /// Optional URL handed to `console::ConsoleProxyConfig::new` for the
    /// `/internal-console` routes.
    pub internal_console_redirect_url: Option<String>,
}
273
/// The internal HTTP server. Unlike `HttpServer` it carries no TLS
/// configuration; its requests are attributed to a user by
/// `internal_http_auth`.
pub struct InternalHttpServer {
    /// Fully assembled router, cloned for every connection.
    router: Router,
}
277
278pub async fn handle_leader_status(
279    State(deployment_state_handle): State<DeploymentStateHandle>,
280) -> impl IntoResponse {
281    let status = deployment_state_handle.status();
282    (StatusCode::OK, Json(json!({ "status": status })))
283}
284
285pub async fn handle_leader_promote(
286    State(deployment_state_handle): State<DeploymentStateHandle>,
287) -> impl IntoResponse {
288    match deployment_state_handle.try_promote() {
289        Ok(()) => {
290            // TODO(benesch): the body here is redundant. Should just return
291            // 204.
292            let status = StatusCode::OK;
293            let body = Json(json!({
294                "result": "Success",
295            }));
296            (status, body)
297        }
298        Err(()) => {
299            // TODO(benesch): the nesting here is redundant given the error
300            // code. Should just return the `{"message": "..."}` object.
301            let status = StatusCode::BAD_REQUEST;
302            let body = Json(json!({
303                "result": {"Failure": {"message": "cannot promote leader while initializing"}},
304            }));
305            (status, body)
306        }
307    }
308}
309
310pub async fn handle_leader_skip_catchup(
311    State(deployment_state_handle): State<DeploymentStateHandle>,
312) -> impl IntoResponse {
313    match deployment_state_handle.try_skip_catchup() {
314        Ok(()) => StatusCode::NO_CONTENT.into_response(),
315        Err(()) => {
316            let status = StatusCode::BAD_REQUEST;
317            let body = Json(json!({
318                "message": "cannot skip catchup in this phase of initialization; try again later",
319            }));
320            (status, body).into_response()
321        }
322    }
323}
324
impl InternalHttpServer {
    /// Assembles the internal HTTP server from an [`InternalHttpConfig`].
    ///
    /// The resulting router is the merge of:
    ///
    /// * the base router (built with profiling enabled) extended with metrics,
    ///   health, tracing, catalog/coordinator and internal-console routes;
    /// * the WebSocket SQL router at `/api/experimental/sql`;
    /// * the `/api/leader/*` router backed by `deployment_state_handle`.
    ///
    /// The merged router is finished with `apply_default_layers("internal", ..)`.
    pub fn new(
        InternalHttpConfig {
            metrics_registry,
            adapter_client_rx,
            active_connection_counter,
            helm_chart_version,
            deployment_state_handle,
            internal_console_redirect_url,
        }: InternalHttpConfig,
    ) -> InternalHttpServer {
        let metrics = Metrics::register_into(&metrics_registry, "mz_internal_http");
        let console_config = Arc::new(console::ConsoleProxyConfig::new(
            internal_console_redirect_url,
            "/internal-console".to_string(),
        ));

        // Make the one-shot receiver clonable so that every handler can await it.
        let adapter_client_rx = adapter_client_rx.shared();
        let router = base_router(BaseRouterConfig { profiling: true })
            // Process-level metrics straight from the registry.
            .route(
                "/metrics",
                routing::get(move || async move {
                    mz_http_util::handle_prometheus(&metrics_registry).await
                }),
            )
            // The `/metrics/mz_*` routes build a registry via `sql::handle_promsql`
            // from a fixed query set and render it in Prometheus format.
            .route(
                "/metrics/mz_usage",
                routing::get(|client: AuthedClient| async move {
                    let registry = sql::handle_promsql(client, USAGE_METRIC_QUERIES).await;
                    mz_http_util::handle_prometheus(&registry).await
                }),
            )
            .route(
                "/metrics/mz_frontier",
                routing::get(|client: AuthedClient| async move {
                    let registry = sql::handle_promsql(client, FRONTIER_METRIC_QUERIES).await;
                    mz_http_util::handle_prometheus(&registry).await
                }),
            )
            .route(
                "/metrics/mz_compute",
                routing::get(|client: AuthedClient| async move {
                    let registry = sql::handle_promsql(client, COMPUTE_METRIC_QUERIES).await;
                    mz_http_util::handle_prometheus(&registry).await
                }),
            )
            .route(
                "/metrics/mz_storage",
                routing::get(|client: AuthedClient| async move {
                    let registry = sql::handle_promsql(client, STORAGE_METRIC_QUERIES).await;
                    mz_http_util::handle_prometheus(&registry).await
                }),
            )
            .route(
                "/api/livez",
                routing::get(mz_http_util::handle_liveness_check),
            )
            .route("/api/readyz", routing::get(probe::handle_ready))
            // Retired endpoint: always answers 400 and points callers at the
            // replacement system variable.
            .route(
                "/api/opentelemetry/config",
                routing::put({
                    move |_: axum::Json<DynamicFilterTarget>| async {
                        (
                            StatusCode::BAD_REQUEST,
                            "This endpoint has been replaced. \
                            Use the `opentelemetry_filter` system variable."
                                .to_string(),
                        )
                    }
                }),
            )
            // Retired endpoint: always answers 400 and points callers at the
            // replacement system variable.
            .route(
                "/api/stderr/config",
                routing::put({
                    move |_: axum::Json<DynamicFilterTarget>| async {
                        (
                            StatusCode::BAD_REQUEST,
                            "This endpoint has been replaced. \
                            Use the `log_filter` system variable."
                                .to_string(),
                        )
                    }
                }),
            )
            .route("/api/tracing", routing::get(mz_http_util::handle_tracing))
            .route(
                "/api/catalog/dump",
                routing::get(catalog::handle_catalog_dump),
            )
            .route(
                "/api/catalog/check",
                routing::get(catalog::handle_catalog_check),
            )
            .route(
                "/api/coordinator/check",
                routing::get(catalog::handle_coordinator_check),
            )
            .route(
                "/api/coordinator/dump",
                routing::get(catalog::handle_coordinator_dump),
            )
            // The bare path redirects to the trailing-slash form, which is served
            // by the console handler.
            .route(
                "/internal-console",
                routing::get(|| async { Redirect::temporary("/internal-console/") }),
            )
            .route(
                "/internal-console/{*path}",
                routing::get(console::handle_internal_console),
            )
            .route(
                "/internal-console/",
                routing::get(console::handle_internal_console),
            )
            .layer(middleware::from_fn(internal_http_auth))
            .layer(Extension(adapter_client_rx.clone()))
            .layer(Extension(console_config))
            .layer(Extension(active_connection_counter.clone()));

        let ws_router = Router::new()
            .route("/api/experimental/sql", routing::get(sql::handle_sql_ws))
            // This middleware extracts the MZ user from the x-materialize-user http header.
            // Normally, browser-initiated websocket requests do not support headers, however for the
            // Internal HTTP Server the browser would be connecting through teleport, which should
            // attach the x-materialize-user header to all requests it proxies to this api.
            .layer(middleware::from_fn(internal_http_auth))
            .with_state(WsState {
                frontegg: Arc::new(None),
                adapter_client_rx,
                active_connection_counter,
                helm_chart_version: helm_chart_version.clone(),
            });

        // NOTE(review): no `internal_http_auth` layer is attached to this router —
        // confirm that the leader endpoints are meant to skip it.
        let leader_router = Router::new()
            .route("/api/leader/status", routing::get(handle_leader_status))
            .route("/api/leader/promote", routing::post(handle_leader_promote))
            .route(
                "/api/leader/skip-catchup",
                routing::post(handle_leader_skip_catchup),
            )
            .with_state(deployment_state_handle);

        let router = router
            .merge(ws_router)
            .merge(leader_router)
            .apply_default_layers("internal", metrics);

        InternalHttpServer { router }
    }
}
474
475async fn internal_http_auth(mut req: Request, next: Next) -> impl IntoResponse {
476    let user_name = req
477        .headers()
478        .get("x-materialize-user")
479        .map(|h| h.to_str())
480        .unwrap_or_else(|| Ok(SYSTEM_USER_NAME));
481    let user_name = match user_name {
482        Ok(name @ (SUPPORT_USER_NAME | SYSTEM_USER_NAME)) => name.to_string(),
483        _ => {
484            return Err(AuthError::MismatchedUser(format!(
485                "user specified in x-materialize-user must be {SUPPORT_USER_NAME} or {SYSTEM_USER_NAME}"
486            )));
487        }
488    };
489    req.extensions_mut().insert(AuthedUser {
490        name: user_name,
491        external_metadata_rx: None,
492    });
493    Ok(next.run(req).await)
494}
495
496#[async_trait]
497impl Server for InternalHttpServer {
498    const NAME: &'static str = "internal_http";
499
500    fn handle_connection(&self, mut conn: Connection) -> ConnectionHandler {
501        let router = self.router.clone();
502
503        Box::pin(async {
504            let mut make_tower_svc = router.into_make_service_with_connect_info::<SocketAddr>();
505            let direct_peer_addr = conn.peer_addr().context("fetching peer addr")?;
506            let peer_addr = conn
507                .take_proxy_header_address()
508                .await
509                .map(|a| a.source)
510                .unwrap_or(direct_peer_addr);
511            let tower_svc = make_tower_svc.call(peer_addr).await?;
512            let service = hyper::service::service_fn(|req| tower_svc.clone().call(req));
513            let http = hyper::server::conn::http1::Builder::new();
514            http.serve_connection(TokioIo::new(conn), service)
515                .with_upgrades()
516                .err_into()
517                .await
518        })
519    }
520}
521
/// A clonable future over a one-shot receiver. Used to hand the adapter client
/// to request handlers, each of which awaits its own clone.
type Delayed<T> = Shared<oneshot::Receiver<T>>;
523
/// The protocol a connection was accepted on. `HttpServer::handle_connection`
/// attaches it to the router as an `Extension`, and `http_auth` reads it back
/// to enforce the configured `TlsMode`.
#[derive(Clone)]
enum ConnProtocol {
    /// No TLS handshake was performed.
    Http,
    /// The TLS handshake completed successfully.
    Https,
}
529
/// The user a request has been attributed to. Inserted into the request
/// extensions by the auth middleware and consumed when opening a session.
#[derive(Clone, Debug)]
pub struct AuthedUser {
    /// User name the session is opened as.
    name: String,
    /// Optional channel of external user metadata, forwarded into the session
    /// configuration. `internal_http_auth` always leaves this as `None`.
    external_metadata_rx: Option<watch::Receiver<ExternalUserMetadata>>,
}
535
/// A started adapter session for an authenticated user, together with the
/// connection handle allocated for it.
pub struct AuthedClient {
    /// Session client returned by the adapter's `startup`.
    pub client: SessionClient,
    /// Handle returned by `ConnectionCounter::allocate_connection`, if any.
    pub connection_guard: Option<ConnectionHandle>,
}
540
541impl AuthedClient {
542    async fn new<F>(
543        adapter_client: &Client,
544        user: AuthedUser,
545        peer_addr: IpAddr,
546        active_connection_counter: ConnectionCounter,
547        helm_chart_version: Option<String>,
548        session_config: F,
549        options: BTreeMap<String, String>,
550        now: NowFn,
551    ) -> Result<Self, AdapterError>
552    where
553        F: FnOnce(&mut Session),
554    {
555        let conn_id = adapter_client.new_conn_id()?;
556        let mut session = adapter_client.new_session(SessionConfig {
557            conn_id,
558            uuid: epoch_to_uuid_v7(&(now)()),
559            user: user.name,
560            client_ip: Some(peer_addr),
561            external_metadata_rx: user.external_metadata_rx,
562            //TODO(dov): Add support for internal user metadata when we support auth here
563            internal_user_metadata: None,
564            helm_chart_version,
565        });
566        let connection_guard = active_connection_counter.allocate_connection(session.user())?;
567
568        session_config(&mut session);
569        let system_vars = adapter_client.get_system_vars().await;
570        for (key, val) in options {
571            const LOCAL: bool = false;
572            if let Err(err) =
573                session
574                    .vars_mut()
575                    .set(&system_vars, &key, VarInput::Flat(&val), LOCAL)
576            {
577                session.add_notice(AdapterNotice::BadStartupSetting {
578                    name: key.to_string(),
579                    reason: err.to_string(),
580                })
581            }
582        }
583        let adapter_client = adapter_client.startup(session).await?;
584        Ok(AuthedClient {
585            client: adapter_client,
586            connection_guard,
587        })
588    }
589}
590
591impl<S> FromRequestParts<S> for AuthedClient
592where
593    S: Send + Sync,
594{
595    type Rejection = Response;
596
597    fn from_request_parts(
598        req: &mut http::request::Parts,
599        state: &S,
600    ) -> impl Future<Output = Result<Self, Self::Rejection>> + Send {
601        Box::pin(async move {
602            #[derive(Debug, Default, Deserialize)]
603            struct Params {
604                #[serde(default)]
605                options: String,
606            }
607            let params: Query<Params> = Query::from_request_parts(req, state)
608                .await
609                .unwrap_or_default();
610
611            let peer_addr = req
612                .extensions
613                .get::<ConnectInfo<SocketAddr>>()
614                .expect("ConnectInfo extension guaranteed to exist")
615                .0
616                .ip();
617
618            let user = req.extensions.get::<AuthedUser>().unwrap();
619            let adapter_client = req
620                .extensions
621                .get::<Delayed<mz_adapter::Client>>()
622                .unwrap()
623                .clone();
624            let adapter_client = adapter_client.await.map_err(|_| {
625                (StatusCode::INTERNAL_SERVER_ERROR, "adapter client missing").into_response()
626            })?;
627            let active_connection_counter = req.extensions.get::<ConnectionCounter>().unwrap();
628            let helm_chart_version = None;
629
630            let options = if params.options.is_empty() {
631                // It's possible 'options' simply wasn't provided, we don't want that to
632                // count as a failure to deserialize
633                BTreeMap::<String, String>::default()
634            } else {
635                match serde_json::from_str(&params.options) {
636                    Ok(options) => options,
637                    Err(_e) => {
638                        // If we fail to deserialize options, fail the request.
639                        let code = StatusCode::BAD_REQUEST;
640                        let msg = format!("Failed to deserialize {} map", "options".quoted());
641                        return Err((code, msg).into_response());
642                    }
643                }
644            };
645
646            let client = AuthedClient::new(
647                &adapter_client,
648                user.clone(),
649                peer_addr,
650                active_connection_counter.clone(),
651                helm_chart_version,
652                |session| {
653                    session
654                        .vars_mut()
655                        .set_default(WELCOME_MESSAGE.name(), VarInput::Flat(&false.format()))
656                        .expect("known to exist")
657                },
658                options,
659                SYSTEM_TIME.clone(),
660            )
661            .await
662            .map_err(|e| {
663                let status = match e {
664                    AdapterError::UserSessionsDisallowed | AdapterError::NetworkPolicyDenied(_) => {
665                        StatusCode::FORBIDDEN
666                    }
667                    _ => StatusCode::INTERNAL_SERVER_ERROR,
668                };
669                (status, Json(SqlError::from(e))).into_response()
670            })?;
671
672            Ok(client)
673        })
674    }
675}
676
/// Reasons an HTTP request can fail authentication.
///
/// The `Display` text is written to the log; when converted into a response,
/// only `HttpsRequired` exposes its message to the client.
#[derive(Debug, Error)]
enum AuthError {
    /// The server requires TLS but the request arrived over plain HTTP.
    #[error("HTTPS is required")]
    HttpsRequired,
    /// The payload is not included in the displayed message.
    #[error("invalid username in client certificate")]
    InvalidLogin(String),
    /// An error reported by the Frontegg authenticator.
    #[error("{0}")]
    Frontegg(#[from] FronteggError),
    /// Frontegg is configured but the request carried neither Basic nor Bearer
    /// credentials.
    #[error("missing authorization header")]
    MissingHttpAuthentication,
    /// The requested user is not acceptable; the payload is the full message.
    #[error("{0}")]
    MismatchedUser(String),
    #[error("unexpected credentials")]
    UnexpectedCredentials,
}
692
693impl IntoResponse for AuthError {
694    fn into_response(self) -> Response {
695        warn!("HTTP request failed authentication: {}", self);
696        // We omit most detail from the error message we send to the client, to
697        // avoid giving attackers unnecessary information.
698        let message = match self {
699            AuthError::HttpsRequired => self.to_string(),
700            _ => "unauthorized".into(),
701        };
702        (
703            StatusCode::UNAUTHORIZED,
704            [(http::header::WWW_AUTHENTICATE, "Basic realm=Materialize")],
705            message,
706        )
707            .into_response()
708    }
709}
710
711async fn http_auth(
712    mut req: Request,
713    next: Next,
714    tls_mode: TlsMode,
715    frontegg: Option<&FronteggAuthentication>,
716) -> impl IntoResponse + use<> {
717    // First, extract the username from the certificate, validating that the
718    // connection matches the TLS configuration along the way.
719    let conn_protocol = req.extensions().get::<ConnProtocol>().unwrap();
720    match (tls_mode, &conn_protocol) {
721        (TlsMode::Disable, ConnProtocol::Http) => {}
722        (TlsMode::Disable, ConnProtocol::Https { .. }) => unreachable!(),
723        (TlsMode::Require, ConnProtocol::Http) => return Err(AuthError::HttpsRequired),
724        (TlsMode::Require, ConnProtocol::Https { .. }) => {}
725    }
726    let creds = match frontegg {
727        None => {
728            // If no Frontegg authentication, use whatever is in the HTTP auth
729            // header (without checking the password), or fall back to the
730            // default user.
731            if let Some(basic) = req.headers().typed_get::<Authorization<Basic>>() {
732                Credentials::User(basic.username().to_string())
733            } else {
734                Credentials::DefaultUser
735            }
736        }
737        Some(_) => {
738            if let Some(basic) = req.headers().typed_get::<Authorization<Basic>>() {
739                Credentials::Password {
740                    username: basic.username().to_string(),
741                    password: basic.password().to_string(),
742                }
743            } else if let Some(bearer) = req.headers().typed_get::<Authorization<Bearer>>() {
744                Credentials::Token {
745                    token: bearer.token().to_string(),
746                }
747            } else {
748                return Err(AuthError::MissingHttpAuthentication);
749            }
750        }
751    };
752
753    let user = auth(frontegg, creds).await?;
754
755    // Add the authenticated user as an extension so downstream handlers can
756    // inspect it if necessary.
757    req.extensions_mut().insert(user);
758
759    // Run the request.
760    Ok(next.run(req).await)
761}
762
763async fn init_ws(
764    WsState {
765        frontegg,
766        adapter_client_rx,
767        active_connection_counter,
768        helm_chart_version,
769    }: &WsState,
770    existing_user: Option<AuthedUser>,
771    peer_addr: IpAddr,
772    ws: &mut WebSocket,
773) -> Result<AuthedClient, anyhow::Error> {
774    // TODO: Add a timeout here to prevent resource leaks by clients that
775    // connect then never send a message.
776    let init_msg = ws.recv().await.ok_or_else(|| anyhow::anyhow!("closed"))??;
777    let ws_auth: WebSocketAuth = loop {
778        match init_msg {
779            Message::Text(data) => break serde_json::from_str(&data)?,
780            Message::Binary(data) => break serde_json::from_slice(&data)?,
781            // Handled automatically by the server.
782            Message::Ping(_) => {
783                continue;
784            }
785            Message::Pong(_) => {
786                continue;
787            }
788            Message::Close(_) => {
789                anyhow::bail!("closed");
790            }
791        }
792    };
793    let (user, options) = match (frontegg.as_ref(), existing_user, ws_auth) {
794        (Some(frontegg), None, ws_auth) => {
795            let (creds, options) = match ws_auth {
796                WebSocketAuth::Basic {
797                    user,
798                    password,
799                    options,
800                } => {
801                    let creds = Credentials::Password {
802                        username: user,
803                        password,
804                    };
805                    (creds, options)
806                }
807                WebSocketAuth::Bearer { token, options } => {
808                    let creds = Credentials::Token { token };
809                    (creds, options)
810                }
811                WebSocketAuth::OptionsOnly { options: _ } => {
812                    anyhow::bail!("expected auth information");
813                }
814            };
815            (auth(Some(frontegg), creds).await?, options)
816        }
817        (
818            None,
819            None,
820            WebSocketAuth::Basic {
821                user,
822                password: _,
823                options,
824            },
825        ) => (auth(None, Credentials::User(user)).await?, options),
826        // No frontegg, specified existing user, we only accept options only.
827        (None, Some(existing_user), WebSocketAuth::OptionsOnly { options }) => {
828            (existing_user, options)
829        }
830        // No frontegg, specified existing user, we do not expect basic or bearer auth.
831        (None, Some(_), WebSocketAuth::Basic { .. } | WebSocketAuth::Bearer { .. }) => {
832            warn!("Unexpected bearer or basic auth provided when using user header");
833            anyhow::bail!("unexpected")
834        }
835        // Specifying both frontegg and an existing user should not be possible.
836        (Some(_), Some(_), _) => anyhow::bail!("unexpected"),
837        // No frontegg, no existing user, and no passed username.
838        (None, None, WebSocketAuth::Bearer { .. } | WebSocketAuth::OptionsOnly { .. }) => {
839            warn!("Unexpected auth type when not using frontegg or user header");
840            anyhow::bail!("unexpected")
841        }
842    };
843
844    let client = AuthedClient::new(
845        &adapter_client_rx.clone().await?,
846        user,
847        peer_addr,
848        active_connection_counter.clone(),
849        helm_chart_version.clone(),
850        |_session| (),
851        options,
852        SYSTEM_TIME.clone(),
853    )
854    .await?;
855
856    Ok(client)
857}
858
/// The credentials a client presented, in the form consumed by `auth`.
///
/// Deliberately not `Debug`: two of the variants carry secrets.
enum Credentials {
    /// A bare user name with no secret attached. `auth` only accepts this
    /// when Frontegg authentication is not configured.
    User(String),
    /// No credentials at all. `auth` maps this to `HTTP_DEFAULT_USER` when
    /// Frontegg authentication is not configured.
    DefaultUser,
    /// A user name and password pair, passed to Frontegg's `authenticate`.
    Password {
        /// The user name supplied by the client.
        username: String,
        /// The password supplied by the client.
        password: String,
    },
    /// An access token, passed to Frontegg's `validate_access_token`.
    Token {
        /// The raw token supplied by the client.
        token: String,
    },
}
865
866async fn auth(
867    frontegg: Option<&FronteggAuthentication>,
868    creds: Credentials,
869) -> Result<AuthedUser, AuthError> {
870    // There are three places a username may be specified:
871    //
872    //   - certificate common name
873    //   - HTTP Basic authentication
874    //   - JWT email address
875    //
876    // We verify that if any of these are present, they must match any other
877    // that is also present.
878
879    // Then, handle Frontegg authentication if required.
880    let (name, external_metadata_rx) = match (frontegg, creds) {
881        // If no Frontegg authentication, allow the default user.
882        (None, Credentials::DefaultUser) => (HTTP_DEFAULT_USER.name.to_string(), None),
883        // If no Frontegg authentication, allow a protocol-specified user.
884        (None, Credentials::User(name)) => (name, None),
885        // With frontegg disabled, specifying credentials is an error.
886        (None, _) => return Err(AuthError::UnexpectedCredentials),
887        // If we require Frontegg auth, fetch credentials from the HTTP auth
888        // header. Basic auth comes with a username/password, where the password
889        // is the client+secret pair. Bearer auth is an existing JWT that must
890        // be validated. In either case, if a username was specified in the
891        // client cert, it must match that of the JWT.
892        (Some(frontegg), creds) => match creds {
893            Credentials::Password { username, password } => {
894                let auth_session = frontegg.authenticate(&username, &password).await?;
895                let user = auth_session.user().into();
896                let external_metadata_rx = Some(auth_session.external_metadata_rx());
897                (user, external_metadata_rx)
898            }
899            Credentials::Token { token } => {
900                let claims = frontegg.validate_access_token(&token, None)?;
901                let (_, external_metadata_rx) = watch::channel(ExternalUserMetadata {
902                    user_id: claims.user_id,
903                    admin: claims.is_admin,
904                });
905                (claims.user, Some(external_metadata_rx))
906            }
907            Credentials::DefaultUser | Credentials::User(_) => {
908                return Err(AuthError::MissingHttpAuthentication);
909            }
910        },
911    };
912
913    if mz_adapter::catalog::is_reserved_role_name(name.as_str()) {
914        return Err(AuthError::InvalidLogin(name));
915    }
916    Ok(AuthedUser {
917        name,
918        external_metadata_rx,
919    })
920}
921
/// Configuration for [`base_router`].
struct BaseRouterConfig {
    /// Whether to enable the profiling routes.
    ///
    /// When `true`, [`base_router`] nests the `mz_prof_http` router under
    /// `/prof/`. The flag is also forwarded to the home page handler.
    profiling: bool,
}
927
928/// Returns the router for routes that are shared between the internal and
929/// external HTTP servers.
930fn base_router(BaseRouterConfig { profiling }: BaseRouterConfig) -> Router {
931    // Adding a layer with in this function will only apply to the routes defined in this function.
932    // https://docs.rs/axum/0.6.1/axum/routing/struct.Router.html#method.layer
933    let mut router = Router::new()
934        .route(
935            "/",
936            routing::get(move || async move { root::handle_home(profiling).await }),
937        )
938        .route("/api/sql", routing::post(sql::handle_sql))
939        .route("/memory", routing::get(memory::handle_memory))
940        .route(
941            "/hierarchical-memory",
942            routing::get(memory::handle_hierarchical_memory),
943        )
944        .route("/static/{*path}", routing::get(root::handle_static));
945
946    if profiling {
947        router = router.nest("/prof/", mz_prof_http::router(&BUILD_INFO));
948    }
949
950    router
951}
952
/// Default layers that should be applied to all routes, and should get applied to both the
/// internal http and external http routers.
trait DefaultLayers {
    /// Consumes `self`, wraps it in the default layers, and returns the result.
    ///
    /// The `Router` implementation in this file hands `source` and `metrics` to
    /// `metrics::PrometheusLayer::new`. Presumably `source` identifies which server the
    /// metrics belong to; confirm against `PrometheusLayer`.
    fn apply_default_layers(self, source: &'static str, metrics: Metrics) -> Self;
}
958
959impl DefaultLayers for Router {
960    fn apply_default_layers(self, source: &'static str, metrics: Metrics) -> Self {
961        self.layer(DefaultBodyLimit::max(MAX_REQUEST_SIZE))
962            .layer(metrics::PrometheusLayer::new(source, metrics))
963    }
964}
965
966/// Glue code to make [`tower`] work with [`axum`].
967///
968/// `axum` requires `Layer`s not return Errors, i.e. they must be `Result<_, Infallible>`,
969/// instead you must return a type that can be converted into a response. `tower` on the other
970/// hand does return Errors, so to make the two work together we need to convert our `tower` errors
971/// into responses.
972async fn handle_load_error(error: tower::BoxError) -> impl IntoResponse {
973    if error.is::<tower::load_shed::error::Overloaded>() {
974        return (
975            StatusCode::TOO_MANY_REQUESTS,
976            Cow::from("too many requests, try again later"),
977        );
978    }
979
980    // Note: This should be unreachable because at the time of writing our only use case is a
981    // layer that emits `tower::load_shed::error::Overloaded`, which is handled above.
982    (
983        StatusCode::INTERNAL_SERVER_ERROR,
984        Cow::from(format!("Unhandled internal error: {}", error)),
985    )
986}