1#![allow(clippy::unused_async)]
51
52use std::borrow::Cow;
53use std::collections::BTreeMap;
54use std::fmt::Debug;
55use std::net::{IpAddr, SocketAddr};
56use std::pin::Pin;
57use std::sync::Arc;
58use std::time::{Duration, SystemTime};
59
60use anyhow::Context;
61use axum::error_handling::HandleErrorLayer;
62use axum::extract::ws::{Message, WebSocket};
63use axum::extract::{ConnectInfo, DefaultBodyLimit, FromRequestParts, Query, Request, State};
64use axum::middleware::{self, Next};
65use axum::response::{IntoResponse, Redirect, Response};
66use axum::{Extension, Json, Router, routing};
67use futures::future::{Shared, TryFutureExt};
68use headers::authorization::{Authorization, Basic, Bearer};
69use headers::{HeaderMapExt, HeaderName};
70use http::header::{AUTHORIZATION, CONTENT_TYPE};
71use http::uri::Scheme;
72use http::{HeaderMap, HeaderValue, Method, StatusCode, Uri};
73use hyper_openssl::SslStream;
74use hyper_openssl::client::legacy::MaybeHttpsStream;
75use hyper_util::rt::TokioIo;
76use mz_adapter::session::{Session as AdapterSession, SessionConfig as AdapterSessionConfig};
77use mz_adapter::{AdapterError, AdapterNotice, Client, SessionClient, WebhookAppenderCache};
78use mz_auth::Authenticated;
79use mz_auth::password::Password;
80use mz_authenticator::Authenticator;
81use mz_controller::ReplicaHttpLocator;
82use mz_frontegg_auth::Error as FronteggError;
83use mz_http_util::DynamicFilterTarget;
84use mz_ore::cast::u64_to_usize;
85use mz_ore::metrics::MetricsRegistry;
86use mz_ore::now::{NowFn, SYSTEM_TIME, epoch_to_uuid_v7};
87use mz_ore::str::StrExt;
88use mz_pgwire_common::{ConnectionCounter, ConnectionHandle};
89use mz_repr::user::ExternalUserMetadata;
90use mz_server_core::listeners::{self, AllowedRoles, HttpRoutesEnabled};
91use mz_server_core::{Connection, ConnectionHandler, ReloadingSslContext, Server};
92use mz_sql::session::metadata::SessionMetadata;
93use mz_sql::session::user::{
94 HTTP_DEFAULT_USER, INTERNAL_USER_NAMES, SUPPORT_USER_NAME, SYSTEM_USER_NAME,
95};
96use mz_sql::session::vars::{Value, Var, VarInput, WELCOME_MESSAGE};
97use openssl::ssl::Ssl;
98use prometheus::{
99 COMPUTE_METRIC_QUERIES, FRONTIER_METRIC_QUERIES, STORAGE_METRIC_QUERIES, USAGE_METRIC_QUERIES,
100};
101use serde::{Deserialize, Serialize};
102use serde_json::json;
103use thiserror::Error;
104use tokio::io::AsyncWriteExt;
105use tokio::sync::oneshot::Receiver;
106use tokio::sync::{oneshot, watch};
107use tokio_metrics::TaskMetrics;
108use tower::limit::GlobalConcurrencyLimitLayer;
109use tower::{Service, ServiceBuilder};
110use tower_http::cors::{AllowOrigin, Any, CorsLayer};
111use tower_sessions::{
112 MemoryStore as TowerSessionMemoryStore, Session as TowerSession,
113 SessionManagerLayer as TowerSessionManagerLayer,
114};
115use tracing::warn;
116
117use crate::BUILD_INFO;
118use crate::deployment::state::DeploymentStateHandle;
119use crate::http::sql::{ExistingUser, SqlError};
120
121mod catalog;
122mod cluster;
123mod console;
124mod mcp;
125mod memory;
126mod metrics;
127mod metrics_viz;
128mod probe;
129mod prometheus;
130mod root;
131mod sql;
132mod webhook;
133
134pub use metrics::Metrics;
135pub use sql::{SqlResponse, WebSocketAuth, WebSocketResponse};
136
137pub const MAX_REQUEST_SIZE: usize = u64_to_usize(5 * bytesize::MIB);
139
140const SESSION_DURATION: Duration = Duration::from_secs(8 * 3600); const PROFILING_API_ENDPOINTS: &[&str] = &["/memory", "/hierarchical-memory", "/prof/"];
143
/// Everything needed to build an [`HttpServer`]; consumed by [`HttpServer::new`].
#[derive(Debug)]
pub struct HttpConfig {
    /// Label handed to `apply_default_layers` together with `metrics`.
    pub source: &'static str,
    /// TLS context. When present, every connection must complete a TLS
    /// handshake and the session cookie is marked `Secure`.
    pub tls: Option<ReloadingSslContext>,
    /// Which authenticator the listener uses (Frontegg, password, SASL, OIDC or none).
    pub authenticator_kind: listeners::AuthenticatorKind,
    /// Frontegg authenticator; required when `authenticator_kind` is `Frontegg`.
    pub frontegg: Option<mz_frontegg_auth::Authenticator>,
    /// Delayed OIDC authenticator, awaited when an OIDC token must be verified.
    pub oidc_rx: Delayed<mz_authenticator::GenericOidcAuthenticator>,
    /// Delayed adapter client, shared by handlers and authenticators.
    pub adapter_client_rx: Shared<Receiver<Client>>,
    /// CORS origin policy for the base and MCP routes.
    pub allowed_origin: AllowOrigin,
    /// Allowed origins as a list, exposed to the MCP handlers as an extension.
    pub allowed_origin_list: Vec<HeaderValue>,
    /// Counter used to allocate per-session connection slots.
    pub active_connection_counter: ConnectionCounter,
    /// Helm chart version recorded on WebSocket sessions.
    pub helm_chart_version: Option<String>,
    /// Caps concurrent webhook requests; excess requests are load-shed.
    pub concurrent_webhook_req: Arc<tokio::sync::Semaphore>,
    /// HTTP metrics installed by the default layers.
    pub metrics: Metrics,
    /// Registry served at `/metrics`.
    pub metrics_registry: MetricsRegistry,
    /// Roles permitted to authenticate through this listener.
    pub allowed_roles: AllowedRoles,
    /// Settings for the internal (leader / console) routes.
    pub internal_route_config: Arc<InternalRouteConfig>,
    /// Which route groups to mount.
    pub routes_enabled: HttpRoutesEnabled,
    /// Locates replica HTTP endpoints for the cluster proxy routes.
    pub replica_http_locator: Arc<ReplicaHttpLocator>,
}
167
/// Configuration for the internal route group.
#[derive(Debug, Clone)]
pub struct InternalRouteConfig {
    /// State handle backing the `/api/leader/*` endpoints.
    pub deployment_state_handle: DeploymentStateHandle,
    /// Passed to the `/internal-console` proxy configuration.
    pub internal_console_redirect_url: Option<String>,
}
173
/// Router state for the `/api/experimental/sql` WebSocket route.
///
/// The WebSocket route is not wrapped in `auth_middleware`, so this state
/// carries what `init_ws` needs to authenticate the socket itself.
#[derive(Clone)]
pub struct WsState {
    /// Frontegg authenticator, if configured.
    frontegg: Option<mz_frontegg_auth::Authenticator>,
    /// Delayed OIDC authenticator.
    oidc_rx: Delayed<mz_authenticator::GenericOidcAuthenticator>,
    /// Authenticator kind of the listener.
    authenticator_kind: listeners::AuthenticatorKind,
    /// Delayed adapter client used to open the session.
    adapter_client_rx: Delayed<mz_adapter::Client>,
    /// Counter used to allocate a connection slot for the session.
    active_connection_counter: ConnectionCounter,
    /// Helm chart version recorded on the session.
    helm_chart_version: Option<String>,
    /// Roles permitted on this listener.
    allowed_roles: AllowedRoles,
}
184
/// Router state for the webhook ingestion route.
#[derive(Clone)]
pub struct WebhookState {
    /// Delayed adapter client.
    adapter_client_rx: Delayed<mz_adapter::Client>,
    /// Cache of webhook appenders, created once in `HttpServer::new`.
    webhook_cache: WebhookAppenderCache,
}
190
/// An HTTP server: a fully layered axum router plus the optional TLS context
/// applied to each accepted connection.
#[derive(Debug)]
pub struct HttpServer {
    /// TLS context; `None` means connections are served over plain HTTP.
    tls: Option<ReloadingSslContext>,
    /// Router with every enabled route group and default layers applied.
    router: Router,
}
196
197impl HttpServer {
198 pub fn new(
199 HttpConfig {
200 source,
201 tls,
202 authenticator_kind,
203 frontegg,
204 oidc_rx,
205 adapter_client_rx,
206 allowed_origin,
207 allowed_origin_list,
208 active_connection_counter,
209 helm_chart_version,
210 concurrent_webhook_req,
211 metrics,
212 metrics_registry,
213 allowed_roles,
214 internal_route_config,
215 routes_enabled,
216 replica_http_locator,
217 }: HttpConfig,
218 ) -> HttpServer {
219 let tls_enabled = tls.is_some();
220 let webhook_cache = WebhookAppenderCache::new();
221
222 let session_store = TowerSessionMemoryStore::default();
224 let session_layer = TowerSessionManagerLayer::new(session_store)
225 .with_secure(tls_enabled) .with_same_site(tower_sessions::cookie::SameSite::Strict) .with_http_only(true) .with_name("mz_session") .with_path("/"); let frontegg_middleware = frontegg.clone();
232 let oidc_middleware_rx = oidc_rx.clone();
233 let adapter_client_middleware_rx = adapter_client_rx.clone();
234 let auth_middleware = middleware::from_fn(move |req, next| {
235 let frontegg = frontegg_middleware.clone();
236 let oidc_rx = oidc_middleware_rx.clone();
237 let adapter_client_rx = adapter_client_middleware_rx.clone();
238 async move {
239 http_auth(
240 req,
241 next,
242 tls_enabled,
243 authenticator_kind,
244 frontegg,
245 oidc_rx,
246 adapter_client_rx,
247 allowed_roles,
248 )
249 .await
250 }
251 });
252
253 let mut router = Router::new();
254 let mut base_router = Router::new();
255 if routes_enabled.base {
256 base_router = base_router
257 .route(
258 "/",
259 routing::get(move || async move { root::handle_home(routes_enabled).await }),
260 )
261 .route("/api/sql", routing::post(sql::handle_sql))
262 .route("/memory", routing::get(memory::handle_memory))
263 .route(
264 "/hierarchical-memory",
265 routing::get(memory::handle_hierarchical_memory),
266 )
267 .route(
268 "/metrics-viz",
269 routing::get(metrics_viz::handle_metrics_viz),
270 )
271 .route("/static/{*path}", routing::get(root::handle_static));
272
273 let mut ws_router = Router::new()
274 .route("/api/experimental/sql", routing::get(sql::handle_sql_ws))
275 .with_state(WsState {
276 frontegg,
277 oidc_rx: oidc_rx.clone(),
278 authenticator_kind,
279 adapter_client_rx: adapter_client_rx.clone(),
280 active_connection_counter: active_connection_counter.clone(),
281 helm_chart_version,
282 allowed_roles,
283 });
284 if let listeners::AuthenticatorKind::None = authenticator_kind {
285 ws_router = ws_router.layer(middleware::from_fn_with_state(
286 allowed_roles,
287 x_materialize_user_header_auth,
288 ));
289 }
290 router = router.merge(ws_router);
291 }
292 if routes_enabled.profiling {
293 base_router = base_router.nest("/prof/", mz_prof_http::router(&BUILD_INFO));
294 }
295
296 if routes_enabled.webhook {
297 let webhook_router = Router::new()
298 .route(
299 "/api/webhook/{:database}/{:schema}/{:id}",
300 routing::post(webhook::handle_webhook),
301 )
302 .with_state(WebhookState {
303 adapter_client_rx: adapter_client_rx.clone(),
304 webhook_cache,
305 })
306 .layer(
307 tower_http::decompression::RequestDecompressionLayer::new()
308 .gzip(true)
309 .deflate(true)
310 .br(true)
311 .zstd(true),
312 )
313 .layer(
314 CorsLayer::new()
315 .allow_methods(Method::POST)
316 .allow_origin(AllowOrigin::mirror_request())
317 .allow_headers(Any),
318 )
319 .layer(
320 ServiceBuilder::new()
321 .layer(HandleErrorLayer::new(handle_load_error))
322 .load_shed()
323 .layer(GlobalConcurrencyLimitLayer::with_semaphore(
324 concurrent_webhook_req,
325 )),
326 );
327 router = router.merge(webhook_router);
328 }
329
330 if routes_enabled.internal {
331 let console_config = Arc::new(console::ConsoleProxyConfig::new(
332 internal_route_config.internal_console_redirect_url.clone(),
333 "/internal-console".to_string(),
334 ));
335 base_router = base_router
336 .route(
337 "/api/opentelemetry/config",
338 routing::put({
339 move |_: axum::Json<DynamicFilterTarget>| async {
340 (
341 StatusCode::BAD_REQUEST,
342 "This endpoint has been replaced. \
343 Use the `opentelemetry_filter` system variable."
344 .to_string(),
345 )
346 }
347 }),
348 )
349 .route(
350 "/api/stderr/config",
351 routing::put({
352 move |_: axum::Json<DynamicFilterTarget>| async {
353 (
354 StatusCode::BAD_REQUEST,
355 "This endpoint has been replaced. \
356 Use the `log_filter` system variable."
357 .to_string(),
358 )
359 }
360 }),
361 )
362 .route("/api/tracing", routing::get(mz_http_util::handle_tracing))
363 .route(
364 "/api/catalog/dump",
365 routing::get(catalog::handle_catalog_dump),
366 )
367 .route(
368 "/api/catalog/check",
369 routing::get(catalog::handle_catalog_check),
370 )
371 .route(
372 "/api/catalog/inject-audit-events",
373 routing::post(catalog::handle_inject_audit_events),
374 )
375 .route(
376 "/api/coordinator/check",
377 routing::get(catalog::handle_coordinator_check),
378 )
379 .route(
380 "/api/coordinator/dump",
381 routing::get(catalog::handle_coordinator_dump),
382 )
383 .route(
384 "/internal-console",
385 routing::get(|| async { Redirect::temporary("/internal-console/") }),
386 )
387 .route(
388 "/internal-console/{*path}",
389 routing::get(console::handle_internal_console),
390 )
391 .route(
392 "/internal-console/",
393 routing::get(console::handle_internal_console),
394 )
395 .layer(Extension(console_config));
396
397 let cluster_proxy_config = Arc::new(cluster::ClusterProxyConfig::new(Arc::clone(
399 &replica_http_locator,
400 )));
401 base_router = base_router
402 .route("/clusters", routing::get(cluster::handle_clusters))
403 .route(
404 "/api/cluster/{:cluster_id}/replica/{:replica_id}/process/{:process}/",
405 routing::any(cluster::handle_cluster_proxy_root),
406 )
407 .route(
408 "/api/cluster/{:cluster_id}/replica/{:replica_id}/process/{:process}/{*path}",
409 routing::any(cluster::handle_cluster_proxy),
410 )
411 .layer(Extension(cluster_proxy_config));
412
413 let leader_router = Router::new()
414 .route("/api/leader/status", routing::get(handle_leader_status))
415 .route("/api/leader/promote", routing::post(handle_leader_promote))
416 .route(
417 "/api/leader/skip-catchup",
418 routing::post(handle_leader_skip_catchup),
419 )
420 .layer(auth_middleware.clone())
421 .with_state(internal_route_config.deployment_state_handle.clone());
422 router = router.merge(leader_router);
423 }
424
425 if routes_enabled.metrics {
426 let metrics_router = Router::new()
427 .route(
428 "/metrics",
429 routing::get(move || async move {
430 mz_http_util::handle_prometheus(&metrics_registry).await
431 }),
432 )
433 .route(
434 "/metrics/mz_usage",
435 routing::get(|client: AuthedClient| async move {
436 let registry = sql::handle_promsql(client, USAGE_METRIC_QUERIES).await;
437 mz_http_util::handle_prometheus(®istry).await
438 }),
439 )
440 .route(
441 "/metrics/mz_frontier",
442 routing::get(|client: AuthedClient| async move {
443 let registry = sql::handle_promsql(client, FRONTIER_METRIC_QUERIES).await;
444 mz_http_util::handle_prometheus(®istry).await
445 }),
446 )
447 .route(
448 "/metrics/mz_compute",
449 routing::get(|client: AuthedClient| async move {
450 let registry = sql::handle_promsql(client, COMPUTE_METRIC_QUERIES).await;
451 mz_http_util::handle_prometheus(®istry).await
452 }),
453 )
454 .route(
455 "/metrics/mz_storage",
456 routing::get(|client: AuthedClient| async move {
457 let registry = sql::handle_promsql(client, STORAGE_METRIC_QUERIES).await;
458 mz_http_util::handle_prometheus(®istry).await
459 }),
460 )
461 .route(
462 "/api/livez",
463 routing::get(mz_http_util::handle_liveness_check),
464 )
465 .route("/api/readyz", routing::get(probe::handle_ready))
466 .layer(auth_middleware.clone())
467 .layer(Extension(adapter_client_rx.clone()))
468 .layer(Extension(active_connection_counter.clone()));
469 router = router.merge(metrics_router);
470 }
471
472 if routes_enabled.console_config {
473 let console_config_router = Router::new()
474 .route(
475 "/api/console/config",
476 routing::get(console::handle_console_config),
477 )
478 .layer(Extension(adapter_client_rx.clone()))
479 .layer(Extension(active_connection_counter.clone()));
480 router = router.merge(console_config_router);
481 }
482
483 if routes_enabled.mcp_agent || routes_enabled.mcp_developer {
486 use tracing::info;
487
488 let mut mcp_router = Router::new();
489
490 if routes_enabled.mcp_agent {
491 info!("Enabling MCP agent endpoint: /api/mcp/agent");
492 mcp_router = mcp_router.route(
493 "/api/mcp/agent",
494 routing::post(mcp::handle_mcp_agent).get(mcp::handle_mcp_method_not_allowed),
495 );
496 }
497
498 if routes_enabled.mcp_developer {
499 info!("Enabling MCP developer endpoint: /api/mcp/developer");
500 mcp_router = mcp_router.route(
501 "/api/mcp/developer",
502 routing::post(mcp::handle_mcp_developer)
503 .get(mcp::handle_mcp_method_not_allowed),
504 );
505 }
506
507 let mcp_allowed_origins = Arc::new(allowed_origin_list.clone());
514 mcp_router = mcp_router
515 .layer(auth_middleware.clone())
516 .layer(Extension(adapter_client_rx.clone()))
517 .layer(Extension(active_connection_counter.clone()))
518 .layer(Extension(mcp_allowed_origins))
519 .layer(
520 CorsLayer::new()
521 .allow_methods(Method::POST)
522 .allow_origin(allowed_origin.clone())
523 .allow_headers([AUTHORIZATION, CONTENT_TYPE]),
524 );
525 router = router.merge(mcp_router);
526 }
527
528 base_router = base_router
529 .layer(auth_middleware.clone())
530 .layer(Extension(adapter_client_rx.clone()))
531 .layer(Extension(active_connection_counter.clone()))
532 .layer(
533 CorsLayer::new()
534 .allow_credentials(false)
535 .allow_headers([
536 AUTHORIZATION,
537 CONTENT_TYPE,
538 HeaderName::from_static("x-materialize-version"),
539 ])
540 .allow_methods(Any)
541 .allow_origin(allowed_origin)
542 .expose_headers(Any)
543 .max_age(Duration::from_secs(60) * 60),
544 );
545
546 match authenticator_kind {
547 listeners::AuthenticatorKind::Password | listeners::AuthenticatorKind::Oidc => {
548 base_router = base_router.layer(session_layer.clone());
549
550 let login_router = Router::new()
551 .route("/api/login", routing::post(handle_login))
552 .route("/api/logout", routing::post(handle_logout))
553 .layer(Extension(adapter_client_rx))
554 .layer(Extension(allowed_roles));
555 router = router.merge(login_router).layer(session_layer);
556 }
557 listeners::AuthenticatorKind::None => {
558 base_router = base_router.layer(middleware::from_fn_with_state(
559 allowed_roles,
560 x_materialize_user_header_auth,
561 ));
562 }
563 _ => {}
564 }
565
566 router = router
567 .merge(base_router)
568 .apply_default_layers(source, metrics);
569
570 HttpServer { tls, router }
571 }
572}
573
574impl Server for HttpServer {
575 const NAME: &'static str = "http";
576
577 fn handle_connection(
578 &self,
579 conn: Connection,
580 _tokio_metrics_intervals: impl Iterator<Item = TaskMetrics> + Send + 'static,
581 ) -> ConnectionHandler {
582 let router = self.router.clone();
583 let tls_context = self.tls.clone();
584 let mut conn = TokioIo::new(conn);
585
586 Box::pin(async {
587 let direct_peer_addr = conn.inner().peer_addr().context("fetching peer addr")?;
588 let peer_addr = conn
589 .inner_mut()
590 .take_proxy_header_address()
591 .await
592 .map(|a| a.source)
593 .unwrap_or(direct_peer_addr);
594
595 let (conn, conn_protocol) = match tls_context {
596 Some(tls_context) => {
597 let mut ssl_stream = SslStream::new(Ssl::new(&tls_context.get())?, conn)?;
598 if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
599 let _ = ssl_stream.get_mut().inner_mut().shutdown().await;
600 return Err(e.into());
601 }
602 (MaybeHttpsStream::Https(ssl_stream), ConnProtocol::Https)
603 }
604 _ => (MaybeHttpsStream::Http(conn), ConnProtocol::Http),
605 };
606 let mut make_tower_svc = router
607 .layer(Extension(conn_protocol))
608 .into_make_service_with_connect_info::<SocketAddr>();
609 let tower_svc = make_tower_svc.call(peer_addr).await.unwrap();
610 let hyper_svc = hyper::service::service_fn(|req| tower_svc.clone().call(req));
611 let http = hyper::server::conn::http1::Builder::new();
612 http.serve_connection(conn, hyper_svc)
613 .with_upgrades()
614 .err_into()
615 .await
616 })
617 }
618}
619
620pub async fn handle_leader_status(
621 State(deployment_state_handle): State<DeploymentStateHandle>,
622) -> impl IntoResponse {
623 let status = deployment_state_handle.status();
624 (StatusCode::OK, Json(json!({ "status": status })))
625}
626
627pub async fn handle_leader_promote(
628 State(deployment_state_handle): State<DeploymentStateHandle>,
629) -> impl IntoResponse {
630 match deployment_state_handle.try_promote() {
631 Ok(()) => {
632 let status = StatusCode::OK;
635 let body = Json(json!({
636 "result": "Success",
637 }));
638 (status, body)
639 }
640 Err(()) => {
641 let status = StatusCode::BAD_REQUEST;
644 let body = Json(json!({
645 "result": {"Failure": {"message": "cannot promote leader while initializing"}},
646 }));
647 (status, body)
648 }
649 }
650}
651
652pub async fn handle_leader_skip_catchup(
653 State(deployment_state_handle): State<DeploymentStateHandle>,
654) -> impl IntoResponse {
655 match deployment_state_handle.try_skip_catchup() {
656 Ok(()) => StatusCode::NO_CONTENT.into_response(),
657 Err(()) => {
658 let status = StatusCode::BAD_REQUEST;
659 let body = Json(json!({
660 "message": "cannot skip catchup in this phase of initialization; try again later",
661 }));
662 (status, body).into_response()
663 }
664 }
665}
666
667async fn x_materialize_user_header_auth(
668 State(allowed_roles): State<AllowedRoles>,
669 mut req: Request,
670 next: Next,
671) -> impl IntoResponse {
672 if let Some(username) = req.headers().get("x-materialize-user").map(|h| h.to_str()) {
674 let username = match username {
675 Ok(name @ (SUPPORT_USER_NAME | SYSTEM_USER_NAME)) => name.to_string(),
676 _ => {
677 return Err(AuthError::MismatchedUser(format!(
678 "user specified in x-materialize-user must be {SUPPORT_USER_NAME} or {SYSTEM_USER_NAME}"
679 )));
680 }
681 };
682 check_role_allowed(&username, allowed_roles)?;
687 req.extensions_mut().insert(AuthedUser {
688 name: username,
689 external_metadata_rx: None,
690 authenticated: Authenticated,
691 authenticator_kind: mz_auth::AuthenticatorKind::None,
692 });
693 }
694 Ok(next.run(req).await)
695}
696
/// A value delivered later through a oneshot channel. `Shared` lets many
/// clones await the same value; awaiting yields `Err` if the sender was
/// dropped without sending.
type Delayed<T> = Shared<oneshot::Receiver<T>>;
698
/// Transport a connection was accepted over.
///
/// `handle_connection` attaches it to every request as an extension, and
/// `http_auth` uses it to redirect plain HTTP when TLS is enabled.
#[derive(Clone)]
enum ConnProtocol {
    Http,
    Https,
}
704
/// A user that passed authentication.
///
/// The auth middlewares attach it to the request as an extension, and it is
/// consumed when an adapter session is created.
#[derive(Clone, Debug)]
pub struct AuthedUser {
    /// User (role) name for the session.
    name: String,
    /// External user metadata channel. Header- and cookie-session auth in
    /// this file always leave it `None`.
    external_metadata_rx: Option<watch::Receiver<ExternalUserMetadata>>,
    /// Token proving authentication succeeded, handed to `new_session`.
    authenticated: Authenticated,
    /// Authenticator that verified this user.
    authenticator_kind: mz_auth::AuthenticatorKind,
}
712
/// A started adapter session for an authenticated user.
pub struct AuthedClient {
    /// The session client returned by `Client::startup`.
    pub client: SessionClient,
    /// Connection slot from `ConnectionCounter::allocate_connection`.
    /// NOTE(review): presumably `None` when the user is not counted — confirm.
    pub connection_guard: Option<ConnectionHandle>,
}
717
718impl AuthedClient {
719 async fn new<F>(
720 adapter_client: &Client,
721 user: AuthedUser,
722 peer_addr: IpAddr,
723 active_connection_counter: ConnectionCounter,
724 helm_chart_version: Option<String>,
725 session_config: F,
726 options: BTreeMap<String, String>,
727 now: NowFn,
728 ) -> Result<Self, AdapterError>
729 where
730 F: FnOnce(&mut AdapterSession),
731 {
732 let conn_id = adapter_client.new_conn_id()?;
733 let mut session = adapter_client.new_session(
734 AdapterSessionConfig {
735 conn_id,
736 uuid: epoch_to_uuid_v7(&(now)()),
737 user: user.name,
738 client_ip: Some(peer_addr),
739 external_metadata_rx: user.external_metadata_rx,
740 helm_chart_version,
741 authenticator_kind: user.authenticator_kind,
742 },
743 user.authenticated,
744 );
745 let connection_guard = active_connection_counter.allocate_connection(session.user())?;
746
747 session_config(&mut session);
748 let system_vars = adapter_client.get_system_vars().await;
749 for (key, val) in options {
750 const LOCAL: bool = false;
751 if let Err(err) =
752 session
753 .vars_mut()
754 .set(&system_vars, &key, VarInput::Flat(&val), LOCAL)
755 {
756 session.add_notice(AdapterNotice::BadStartupSetting {
757 name: key.to_string(),
758 reason: err.to_string(),
759 })
760 }
761 }
762 let adapter_client = adapter_client.startup(session).await?;
763 Ok(AuthedClient {
764 client: adapter_client,
765 connection_guard,
766 })
767 }
768}
769
770impl<S> FromRequestParts<S> for AuthedClient
771where
772 S: Send + Sync,
773{
774 type Rejection = Response;
775
776 async fn from_request_parts(
777 req: &mut http::request::Parts,
778 state: &S,
779 ) -> Result<Self, Self::Rejection> {
780 #[derive(Debug, Default, Deserialize)]
781 struct Params {
782 #[serde(default)]
783 options: String,
784 }
785 let params: Query<Params> = Query::from_request_parts(req, state)
786 .await
787 .unwrap_or_default();
788
789 let peer_addr = req
790 .extensions
791 .get::<ConnectInfo<SocketAddr>>()
792 .expect("ConnectInfo extension guaranteed to exist")
793 .0
794 .ip();
795
796 let user = req.extensions.get::<AuthedUser>().unwrap();
797 let adapter_client = req
798 .extensions
799 .get::<Delayed<mz_adapter::Client>>()
800 .unwrap()
801 .clone();
802 let adapter_client = adapter_client.await.map_err(|_| {
803 (StatusCode::INTERNAL_SERVER_ERROR, "adapter client missing").into_response()
804 })?;
805 let active_connection_counter = req.extensions.get::<ConnectionCounter>().unwrap();
806 let helm_chart_version = None;
807
808 let options = if params.options.is_empty() {
809 BTreeMap::<String, String>::default()
812 } else {
813 match serde_json::from_str(¶ms.options) {
814 Ok(options) => options,
815 Err(_e) => {
816 let code = StatusCode::BAD_REQUEST;
818 let msg = format!("Failed to deserialize {} map", "options".quoted());
819 return Err((code, msg).into_response());
820 }
821 }
822 };
823
824 let client = AuthedClient::new(
825 &adapter_client,
826 user.clone(),
827 peer_addr,
828 active_connection_counter.clone(),
829 helm_chart_version,
830 |session| {
831 session
832 .vars_mut()
833 .set_default(WELCOME_MESSAGE.name(), VarInput::Flat(&false.format()))
834 .expect("known to exist")
835 },
836 options,
837 SYSTEM_TIME.clone(),
838 )
839 .await
840 .map_err(|e| {
841 let status = match e {
842 AdapterError::UserSessionsDisallowed | AdapterError::NetworkPolicyDenied(_) => {
843 StatusCode::FORBIDDEN
844 }
845 _ => StatusCode::INTERNAL_SERVER_ERROR,
846 };
847 (status, Json(SqlError::from(e))).into_response()
848 })?;
849
850 Ok(client)
851 }
852}
853
854#[derive(Debug, Error)]
855pub(crate) enum AuthError {
856 #[error("role dissallowed")]
857 RoleDisallowed(String),
858 #[error("{0}")]
859 Frontegg(#[from] FronteggError),
860 #[error("missing authorization header")]
861 MissingHttpAuthentication {
862 include_www_authenticate_header: bool,
863 },
864 #[error("{0}")]
865 MismatchedUser(String),
866 #[error("session expired")]
867 SessionExpired,
868 #[error("failed to update session")]
869 FailedToUpdateSession,
870 #[error("invalid credentials")]
871 InvalidCredentials,
872}
873
874impl IntoResponse for AuthError {
875 fn into_response(self) -> Response {
876 warn!("HTTP request failed authentication: {}", self);
877 let mut headers = HeaderMap::new();
878 match self {
879 AuthError::MissingHttpAuthentication {
880 include_www_authenticate_header,
881 } if include_www_authenticate_header => {
882 headers.insert(
883 http::header::WWW_AUTHENTICATE,
884 HeaderValue::from_static("Basic realm=Materialize"),
885 );
886 }
887 _ => {}
888 };
889 (StatusCode::UNAUTHORIZED, headers, "unauthorized").into_response()
892 }
893}
894
895pub async fn handle_login(
897 session: Option<Extension<TowerSession>>,
898 Extension(adapter_client_rx): Extension<Delayed<Client>>,
899 Extension(allowed_roles): Extension<AllowedRoles>,
900 Json(LoginCredentials { username, password }): Json<LoginCredentials>,
901) -> impl IntoResponse {
902 if let Err(err) = check_role_allowed(&username, allowed_roles) {
907 warn!(
908 ?err,
909 "HTTP login rejected: role not allowed on this listener"
910 );
911 return StatusCode::UNAUTHORIZED;
912 }
913 let Ok(adapter_client) = adapter_client_rx.clone().await else {
914 return StatusCode::INTERNAL_SERVER_ERROR;
915 };
916 let authenticated = match adapter_client.authenticate(&username, &password).await {
917 Ok(authenticated) => authenticated,
918 Err(err) => {
919 warn!(?err, "HTTP login failed authentication");
920 return StatusCode::UNAUTHORIZED;
921 }
922 };
923 let session_data = TowerSessionData {
925 username,
926 created_at: SystemTime::now(),
927 last_activity: SystemTime::now(),
928 authenticated,
929 authenticator_kind: mz_auth::AuthenticatorKind::Password,
930 };
931 let session = session.and_then(|Extension(session)| Some(session));
933 let Some(session) = session else {
934 return StatusCode::INTERNAL_SERVER_ERROR;
935 };
936 match session.insert("data", &session_data).await {
937 Err(_) => StatusCode::INTERNAL_SERVER_ERROR,
938 Ok(_) => StatusCode::OK,
939 }
940}
941
942pub async fn handle_logout(session: Option<Extension<TowerSession>>) -> impl IntoResponse {
944 let session = session.and_then(|Extension(session)| Some(session));
945 let Some(session) = session else {
946 return StatusCode::INTERNAL_SERVER_ERROR;
947 };
948 match session.delete().await {
950 Err(_) => StatusCode::INTERNAL_SERVER_ERROR,
951 Ok(_) => StatusCode::OK,
952 }
953}
954
955async fn http_auth(
956 mut req: Request,
957 next: Next,
958 tls_enabled: bool,
959 authenticator_kind: listeners::AuthenticatorKind,
960 frontegg: Option<mz_frontegg_auth::Authenticator>,
961 oidc_rx: Delayed<mz_authenticator::GenericOidcAuthenticator>,
962 adapter_client_rx: Delayed<Client>,
963 allowed_roles: AllowedRoles,
964) -> Result<impl IntoResponse, AuthError> {
965 let creds = if let Some(basic) = req.headers().typed_get::<Authorization<Basic>>() {
966 Some(Credentials::Password {
967 username: basic.username().to_owned(),
968 password: Password(basic.password().to_owned()),
969 })
970 } else if let Some(bearer) = req.headers().typed_get::<Authorization<Bearer>>() {
971 Some(Credentials::Token {
972 token: bearer.token().to_owned(),
973 })
974 } else {
975 None
976 };
977
978 if creds.is_none()
982 && let Some((session, session_data)) =
983 maybe_get_authenticated_session(req.extensions().get::<TowerSession>()).await
984 {
985 let user = ensure_session_unexpired(session, session_data).await?;
986 check_role_allowed(&user.name, allowed_roles)?;
992 req.extensions_mut().insert(user);
994 return Ok(next.run(req).await);
995 }
996
997 let conn_protocol = req.extensions().get::<ConnProtocol>().unwrap();
1001 match (tls_enabled, &conn_protocol) {
1002 (false, ConnProtocol::Http) => {}
1003 (false, ConnProtocol::Https { .. }) => unreachable!(),
1004 (true, ConnProtocol::Http) => {
1005 let mut parts = req.uri().clone().into_parts();
1006 parts.scheme = Some(Scheme::HTTPS);
1007 return Ok(Redirect::permanent(
1008 &Uri::from_parts(parts)
1009 .expect("it was already a URI, just changed the scheme")
1010 .to_string(),
1011 )
1012 .into_response());
1013 }
1014 (true, ConnProtocol::Https { .. }) => {}
1015 }
1016 if req.extensions().get::<AuthedUser>().is_some() {
1018 return Ok(next.run(req).await);
1019 }
1020
1021 let path = req.uri().path();
1022 let include_www_authenticate_header = path == "/"
1023 || PROFILING_API_ENDPOINTS
1024 .iter()
1025 .any(|prefix| path.starts_with(prefix));
1026 let authenticator = get_authenticator(
1027 authenticator_kind,
1028 creds.as_ref(),
1029 frontegg,
1030 &oidc_rx,
1031 &adapter_client_rx,
1032 )
1033 .await;
1034
1035 let user = auth(
1036 &authenticator,
1037 creds,
1038 allowed_roles,
1039 include_www_authenticate_header,
1040 )
1041 .await?;
1042
1043 req.extensions_mut().insert(user);
1046
1047 Ok(next.run(req).await)
1049}
1050
1051async fn init_ws(
1052 WsState {
1053 frontegg,
1054 oidc_rx,
1055 authenticator_kind,
1056 adapter_client_rx,
1057 active_connection_counter,
1058 helm_chart_version,
1059 allowed_roles,
1060 }: WsState,
1061 existing_user: Option<ExistingUser>,
1062 peer_addr: IpAddr,
1063 ws: &mut WebSocket,
1064) -> Result<AuthedClient, anyhow::Error> {
1065 let ws_auth: WebSocketAuth = loop {
1068 let init_msg = ws.recv().await.ok_or_else(|| anyhow::anyhow!("closed"))??;
1069 match init_msg {
1070 Message::Text(data) => break serde_json::from_str(&data)?,
1071 Message::Binary(data) => break serde_json::from_slice(&data)?,
1072 Message::Ping(_) => {
1074 continue;
1075 }
1076 Message::Pong(_) => {
1077 continue;
1078 }
1079 Message::Close(_) => {
1080 anyhow::bail!("closed");
1081 }
1082 }
1083 };
1084
1085 let (creds, options) = match ws_auth {
1088 WebSocketAuth::Basic {
1089 user,
1090 password,
1091 options,
1092 } => {
1093 let creds = Credentials::Password {
1094 username: user,
1095 password,
1096 };
1097 (Some(creds), options)
1098 }
1099 WebSocketAuth::Bearer { token, options } => {
1100 let creds = Credentials::Token { token };
1101 (Some(creds), options)
1102 }
1103 WebSocketAuth::OptionsOnly { options } => (None, options),
1104 };
1105
1106 let user = match (existing_user, creds) {
1107 (Some(ExistingUser::XMaterializeUserHeader(_)), Some(_creds)) => {
1108 warn!("Unexpected bearer or basic auth provided when using user header");
1109 anyhow::bail!("unexpected")
1110 }
1111 (Some(ExistingUser::Session(user)), None) => {
1112 check_role_allowed(&user.name, allowed_roles)?;
1118 user
1119 }
1120 (Some(ExistingUser::XMaterializeUserHeader(user)), None) => user,
1121 (_, Some(creds)) => {
1122 let authenticator = get_authenticator(
1123 authenticator_kind,
1124 Some(&creds),
1125 frontegg,
1126 &oidc_rx,
1127 &adapter_client_rx,
1128 )
1129 .await;
1130 let user = auth(&authenticator, Some(creds), allowed_roles, false).await?;
1131 user
1132 }
1133 (None, None) => anyhow::bail!("expected auth information"),
1134 };
1135
1136 let client = AuthedClient::new(
1137 &adapter_client_rx.clone().await?,
1138 user,
1139 peer_addr,
1140 active_connection_counter.clone(),
1141 helm_chart_version.clone(),
1142 |_session| (),
1143 options,
1144 SYSTEM_TIME.clone(),
1145 )
1146 .await?;
1147
1148 Ok(client)
1149}
1150
/// Credentials presented by a client: over HTTP they come from the
/// `Authorization` header (Basic or Bearer); over WebSockets they come from
/// the initial auth message.
enum Credentials {
    /// Username and password (HTTP Basic, or WebSocket `Basic` auth).
    Password {
        username: String,
        password: Password,
    },
    /// Bearer token (HTTP Bearer, or WebSocket `Bearer` auth).
    Token {
        token: String,
    },
}
1160
1161async fn get_authenticator(
1162 kind: listeners::AuthenticatorKind,
1163 creds: Option<&Credentials>,
1164 frontegg: Option<mz_frontegg_auth::Authenticator>,
1165 oidc_rx: &Delayed<mz_authenticator::GenericOidcAuthenticator>,
1166 adapter_client_rx: &Delayed<Client>,
1167) -> Authenticator {
1168 match kind {
1169 listeners::AuthenticatorKind::Frontegg => Authenticator::Frontegg(frontegg.expect(
1170 "Frontegg authenticator should exist with listeners::AuthenticatorKind::Frontegg",
1171 )),
1172 listeners::AuthenticatorKind::Password | listeners::AuthenticatorKind::Sasl => {
1173 let client = adapter_client_rx.clone().await.expect("sender not dropped");
1174 Authenticator::Password(client)
1175 }
1176 listeners::AuthenticatorKind::Oidc => match creds {
1177 Some(Credentials::Password { .. }) => {
1179 let client = adapter_client_rx.clone().await.expect("sender not dropped");
1180 Authenticator::Password(client)
1181 }
1182 _ => Authenticator::Oidc(oidc_rx.clone().await.expect("sender not dropped")),
1183 },
1184 listeners::AuthenticatorKind::None => Authenticator::None,
1185 }
1186}
1187
1188pub(crate) async fn maybe_get_authenticated_session(
1192 session: Option<&TowerSession>,
1193) -> Option<(&TowerSession, TowerSessionData)> {
1194 if let Some(session) = session {
1195 if let Ok(Some(session_data)) = session.get::<TowerSessionData>("data").await {
1196 return Some((session, session_data));
1197 }
1198 }
1199 None
1200}
1201
1202pub(crate) async fn ensure_session_unexpired(
1205 session: &TowerSession,
1206 session_data: TowerSessionData,
1207) -> Result<AuthedUser, AuthError> {
1208 if session_data
1209 .last_activity
1210 .elapsed()
1211 .unwrap_or(Duration::MAX)
1212 > SESSION_DURATION
1213 {
1214 let _ = session.delete().await;
1215 return Err(AuthError::SessionExpired);
1216 }
1217 let mut updated_data = session_data.clone();
1218 updated_data.last_activity = SystemTime::now();
1219 session
1220 .insert("data", &updated_data)
1221 .await
1222 .map_err(|_| AuthError::FailedToUpdateSession)?;
1223
1224 Ok(AuthedUser {
1225 name: session_data.username,
1226 external_metadata_rx: None,
1227 authenticated: session_data.authenticated,
1228 authenticator_kind: session_data.authenticator_kind,
1229 })
1230}
1231
1232async fn auth(
1233 authenticator: &Authenticator,
1234 creds: Option<Credentials>,
1235 allowed_roles: AllowedRoles,
1236 include_www_authenticate_header: bool,
1237) -> Result<AuthedUser, AuthError> {
1238 let (name, external_metadata_rx, authenticated) = match authenticator {
1239 Authenticator::Frontegg(frontegg) => match creds {
1240 Some(Credentials::Password { username, password }) => {
1241 let (auth_session, authenticated) =
1242 frontegg.authenticate(&username, password.as_str()).await?;
1243 let name = auth_session.user().into();
1244 let external_metadata_rx = Some(auth_session.external_metadata_rx());
1245 (name, external_metadata_rx, authenticated)
1246 }
1247 Some(Credentials::Token { token }) => {
1248 let (claims, authenticated) = frontegg.validate_access_token(&token, None)?;
1249 let (_, external_metadata_rx) = watch::channel(ExternalUserMetadata {
1250 user_id: claims.user_id,
1251 admin: claims.is_admin,
1252 });
1253 (claims.user, Some(external_metadata_rx), authenticated)
1254 }
1255 None => {
1256 return Err(AuthError::MissingHttpAuthentication {
1257 include_www_authenticate_header,
1258 });
1259 }
1260 },
1261 Authenticator::Password(adapter_client) => match creds {
1262 Some(Credentials::Password { username, password }) => {
1263 let authenticated = adapter_client
1264 .authenticate(&username, &password)
1265 .await
1266 .map_err(|_| AuthError::InvalidCredentials)?;
1267 (username, None, authenticated)
1268 }
1269 _ => {
1270 return Err(AuthError::MissingHttpAuthentication {
1271 include_www_authenticate_header,
1272 });
1273 }
1274 },
1275 Authenticator::Sasl(_) => {
1276 return Err(AuthError::MissingHttpAuthentication {
1280 include_www_authenticate_header,
1281 });
1282 }
1283 Authenticator::Oidc(oidc) => match creds {
1284 Some(Credentials::Token { token }) => {
1285 let (mut claims, authenticated) = oidc
1287 .authenticate(&token, None)
1288 .await
1289 .map_err(|_| AuthError::InvalidCredentials)?;
1290 let name = std::mem::take(&mut claims.user);
1291 (name, None, authenticated)
1292 }
1293 _ => {
1294 return Err(AuthError::MissingHttpAuthentication {
1295 include_www_authenticate_header,
1296 });
1297 }
1298 },
1299 Authenticator::None => {
1300 let name = match creds {
1304 Some(Credentials::Password { username, .. }) => username,
1305 _ => HTTP_DEFAULT_USER.name.to_owned(),
1306 };
1307 (name, None, Authenticated)
1308 }
1309 };
1310
1311 check_role_allowed(&name, allowed_roles)?;
1312
1313 Ok(AuthedUser {
1314 name,
1315 external_metadata_rx,
1316 authenticated,
1317 authenticator_kind: authenticator.kind(),
1318 })
1319}
1320
1321fn check_role_allowed(name: &str, allowed_roles: AllowedRoles) -> Result<(), AuthError> {
1323 let is_internal_user = INTERNAL_USER_NAMES.contains(name);
1324 let is_reserved_user = mz_adapter::catalog::is_reserved_role_name(name);
1326 let role_allowed = match allowed_roles {
1327 AllowedRoles::Normal => !is_reserved_user,
1328 AllowedRoles::Internal => is_internal_user,
1329 AllowedRoles::NormalAndInternal => !is_reserved_user || is_internal_user,
1330 };
1331 if role_allowed {
1332 Ok(())
1333 } else {
1334 Err(AuthError::RoleDisallowed(name.to_owned()))
1335 }
1336}
1337
1338trait DefaultLayers {
1341 fn apply_default_layers(self, source: &'static str, metrics: Metrics) -> Self;
1342}
1343
1344impl DefaultLayers for Router {
1345 fn apply_default_layers(self, source: &'static str, metrics: Metrics) -> Self {
1346 self.layer(DefaultBodyLimit::max(MAX_REQUEST_SIZE))
1347 .layer(metrics::PrometheusLayer::new(source, metrics))
1348 }
1349}
1350
1351async fn handle_load_error(error: tower::BoxError) -> impl IntoResponse {
1358 if error.is::<tower::load_shed::error::Overloaded>() {
1359 return (
1360 StatusCode::TOO_MANY_REQUESTS,
1361 Cow::from("too many requests, try again later"),
1362 );
1363 }
1364
1365 (
1368 StatusCode::INTERNAL_SERVER_ERROR,
1369 Cow::from(format!("Unhandled internal error: {}", error)),
1370 )
1371}
1372
1373#[derive(Debug, Deserialize, Serialize, PartialEq)]
1374pub struct LoginCredentials {
1375 username: String,
1376 password: Password,
1377}
1378
1379#[derive(Debug, Clone, Serialize, Deserialize)]
1380pub struct TowerSessionData {
1381 username: String,
1382 created_at: SystemTime,
1383 last_activity: SystemTime,
1384 authenticated: Authenticated,
1385 authenticator_kind: mz_auth::AuthenticatorKind,
1386}
1387
#[cfg(test)]
mod tests {
    use super::{AllowedRoles, check_role_allowed};

    #[mz_ore::test]
    fn test_check_role_allowed() {
        // Each row: (role name, allowed under `Internal`,
        // allowed under `NormalAndInternal`, allowed under `Normal`).
        let cases: &[(&str, bool, bool, bool)] = &[
            // Internal users: rejected only by the `Normal` policy.
            ("mz_system", true, true, false),
            ("mz_support", true, true, false),
            ("mz_analytics", true, true, false),
            // Ordinary users: rejected only by the `Internal` policy.
            ("materialize", false, true, true),
            ("anonymous_http_user", false, true, true),
            ("alex", false, true, true),
            // Reserved names that are not internal users: always rejected.
            ("external_asdf", false, false, false),
            ("pg_somebody", false, false, false),
            ("mz_unknown", false, false, false),
            ("PUBLIC", false, false, false),
        ];

        for &(name, internal, normal_and_internal, normal) in cases {
            assert_eq!(
                check_role_allowed(name, AllowedRoles::Internal).is_ok(),
                internal,
                "{name} under Internal",
            );
            assert_eq!(
                check_role_allowed(name, AllowedRoles::NormalAndInternal).is_ok(),
                normal_and_internal,
                "{name} under NormalAndInternal",
            );
            assert_eq!(
                check_role_allowed(name, AllowedRoles::Normal).is_ok(),
                normal,
                "{name} under Normal",
            );
        }
    }
}