1#![allow(clippy::unused_async)]
51
52use std::borrow::Cow;
53use std::collections::BTreeMap;
54use std::fmt::Debug;
55use std::net::{IpAddr, SocketAddr};
56use std::pin::Pin;
57use std::sync::Arc;
58use std::time::{Duration, SystemTime};
59
60use anyhow::Context;
61use axum::error_handling::HandleErrorLayer;
62use axum::extract::ws::{Message, WebSocket};
63use axum::extract::{ConnectInfo, DefaultBodyLimit, FromRequestParts, Query, Request, State};
64use axum::middleware::{self, Next};
65use axum::response::{IntoResponse, Redirect, Response};
66use axum::{Extension, Json, Router, routing};
67use futures::future::{Shared, TryFutureExt};
68use headers::authorization::{Authorization, Basic, Bearer};
69use headers::{HeaderMapExt, HeaderName};
70use http::header::{AUTHORIZATION, CONTENT_TYPE};
71use http::uri::Scheme;
72use http::{HeaderMap, HeaderValue, Method, StatusCode, Uri};
73use hyper_openssl::SslStream;
74use hyper_openssl::client::legacy::MaybeHttpsStream;
75use hyper_util::rt::TokioIo;
76use mz_adapter::session::{Session as AdapterSession, SessionConfig as AdapterSessionConfig};
77use mz_adapter::{AdapterError, AdapterNotice, Client, SessionClient, WebhookAppenderCache};
78use mz_auth::Authenticated;
79use mz_auth::password::Password;
80use mz_authenticator::Authenticator;
81use mz_controller::ReplicaHttpLocator;
82use mz_frontegg_auth::Error as FronteggError;
83use mz_http_util::DynamicFilterTarget;
84use mz_ore::cast::u64_to_usize;
85use mz_ore::metrics::MetricsRegistry;
86use mz_ore::now::{NowFn, SYSTEM_TIME, epoch_to_uuid_v7};
87use mz_ore::str::StrExt;
88use mz_pgwire_common::{ConnectionCounter, ConnectionHandle};
89use mz_repr::user::ExternalUserMetadata;
90use mz_server_core::listeners::{self, AllowedRoles, HttpRoutesEnabled};
91use mz_server_core::{Connection, ConnectionHandler, ReloadingSslContext, Server};
92use mz_sql::session::metadata::SessionMetadata;
93use mz_sql::session::user::{
94 HTTP_DEFAULT_USER, INTERNAL_USER_NAMES, SUPPORT_USER_NAME, SYSTEM_USER_NAME,
95};
96use mz_sql::session::vars::{Value, Var, VarInput, WELCOME_MESSAGE};
97use openssl::ssl::Ssl;
98use prometheus::{
99 COMPUTE_METRIC_QUERIES, FRONTIER_METRIC_QUERIES, STORAGE_METRIC_QUERIES, USAGE_METRIC_QUERIES,
100};
101use serde::{Deserialize, Serialize};
102use serde_json::json;
103use thiserror::Error;
104use tokio::io::AsyncWriteExt;
105use tokio::sync::oneshot::Receiver;
106use tokio::sync::{oneshot, watch};
107use tokio_metrics::TaskMetrics;
108use tower::limit::GlobalConcurrencyLimitLayer;
109use tower::{Service, ServiceBuilder};
110use tower_http::cors::{AllowOrigin, Any, CorsLayer};
111use tower_sessions::{
112 MemoryStore as TowerSessionMemoryStore, Session as TowerSession,
113 SessionManagerLayer as TowerSessionManagerLayer,
114};
115use tracing::warn;
116
117use crate::BUILD_INFO;
118use crate::deployment::state::DeploymentStateHandle;
119use crate::http::sql::{ExistingUser, SqlError};
120
121mod catalog;
122mod cluster;
123mod console;
124mod mcp;
125mod memory;
126mod metrics;
127mod metrics_viz;
128mod probe;
129mod prometheus;
130mod root;
131mod sql;
132mod webhook;
133
134pub use metrics::Metrics;
135pub use sql::{SqlResponse, WebSocketAuth, WebSocketResponse};
136
/// Maximum accepted HTTP request size, in bytes (5 MiB).
///
/// NOTE(review): not referenced in this section of the file; presumably enforced by a
/// body-limit layer (e.g. `DefaultBodyLimit`) elsewhere — confirm.
pub const MAX_REQUEST_SIZE: usize = u64_to_usize(5 * bytesize::MIB);
139
/// Idle timeout for HTTP login sessions: a session whose last activity is older than
/// this (eight hours) is deleted and treated as expired.
const SESSION_DURATION: Duration = Duration::from_secs(8 * 60 * 60);

/// Path prefixes for which an unauthenticated request receives a
/// `WWW-Authenticate: Basic` challenge (in addition to the root path `/`).
const PROFILING_API_ENDPOINTS: &[&str] = &["/memory", "/hierarchical-memory", "/prof/"];
143
/// Everything needed to build one HTTP listener with [`HttpServer::new`].
#[derive(Debug)]
pub struct HttpConfig {
    /// Label passed, together with `metrics`, to the router's default layers.
    pub source: &'static str,
    /// TLS context. When `Some`, every connection must complete a TLS handshake and
    /// session cookies are marked `Secure`.
    pub tls: Option<ReloadingSslContext>,
    /// Authentication scheme used by this listener.
    pub authenticator_kind: listeners::AuthenticatorKind,
    /// Frontegg authenticator; must be `Some` when `authenticator_kind` is Frontegg.
    pub frontegg: Option<mz_frontegg_auth::Authenticator>,
    /// Resolves to the OIDC authenticator once it is available.
    pub oidc_rx: Delayed<mz_authenticator::GenericOidcAuthenticator>,
    /// Resolves to the adapter client once it is available.
    pub adapter_client_rx: Shared<Receiver<Client>>,
    /// CORS origin policy for the base and MCP routers.
    pub allowed_origin: AllowOrigin,
    /// Explicit list of allowed origins, exposed to the MCP handlers as a request
    /// extension (`Arc<Vec<HeaderValue>>`).
    pub allowed_origin_list: Vec<HeaderValue>,
    /// Counter from which each new adapter session allocates a connection slot.
    pub active_connection_counter: ConnectionCounter,
    /// Helm chart version recorded on new adapter sessions, if known.
    pub helm_chart_version: Option<String>,
    /// Bounds concurrent webhook requests; requests beyond the limit are load-shed.
    pub concurrent_webhook_req: Arc<tokio::sync::Semaphore>,
    /// HTTP metrics handed to the router's default layers.
    pub metrics: Metrics,
    /// Registry served at `/metrics`.
    pub metrics_registry: MetricsRegistry,
    /// Roles permitted to authenticate on this listener.
    pub allowed_roles: AllowedRoles,
    /// Settings for the internal-only routes.
    pub internal_route_config: Arc<InternalRouteConfig>,
    /// Which route groups to mount on this listener.
    pub routes_enabled: HttpRoutesEnabled,
    /// Used by the cluster-replica HTTP proxy routes to locate replica endpoints.
    pub replica_http_locator: Arc<ReplicaHttpLocator>,
}
167
/// Configuration for the routes mounted when `routes_enabled.internal` is set.
#[derive(Debug, Clone)]
pub struct InternalRouteConfig {
    /// Backs the `/api/leader/*` endpoints.
    pub deployment_state_handle: DeploymentStateHandle,
    /// URL handed to `console::ConsoleProxyConfig` for the `/internal-console` routes.
    pub internal_console_redirect_url: Option<String>,
}
173
/// State for the WebSocket SQL endpoint (`/api/experimental/sql`).
///
/// That endpoint is not wrapped by the `http_auth` middleware. Instead, `init_ws`
/// authenticates using the first message on the socket, so this carries everything
/// needed to authenticate and open an adapter session.
#[derive(Clone)]
pub struct WsState {
    frontegg: Option<mz_frontegg_auth::Authenticator>,
    oidc_rx: Delayed<mz_authenticator::GenericOidcAuthenticator>,
    authenticator_kind: listeners::AuthenticatorKind,
    adapter_client_rx: Delayed<mz_adapter::Client>,
    active_connection_counter: ConnectionCounter,
    helm_chart_version: Option<String>,
    allowed_roles: AllowedRoles,
}
184
/// State for the webhook append route.
#[derive(Clone)]
pub struct WebhookState {
    adapter_client_rx: Delayed<mz_adapter::Client>,
    /// Created once per `HttpServer` and shared by all webhook requests on it.
    webhook_cache: WebhookAppenderCache,
}
190
/// Request-extension wrapper for the optional Helm chart version.
///
/// Extensions are keyed by type, so a dedicated newtype avoids ambiguity with any
/// other `Option<String>` extension.
#[derive(Clone, Debug)]
struct HelmChartVersion(Option<String>);
193
/// One HTTP listener: a fully assembled router plus optional TLS configuration.
#[derive(Debug)]
pub struct HttpServer {
    /// When `Some`, every accepted connection must complete a TLS handshake.
    tls: Option<ReloadingSslContext>,
    /// Router built by `HttpServer::new`; cloned for each accepted connection.
    router: Router,
}
199
200impl HttpServer {
201 pub fn new(
202 HttpConfig {
203 source,
204 tls,
205 authenticator_kind,
206 frontegg,
207 oidc_rx,
208 adapter_client_rx,
209 allowed_origin,
210 allowed_origin_list,
211 active_connection_counter,
212 helm_chart_version,
213 concurrent_webhook_req,
214 metrics,
215 metrics_registry,
216 allowed_roles,
217 internal_route_config,
218 routes_enabled,
219 replica_http_locator,
220 }: HttpConfig,
221 ) -> HttpServer {
222 let tls_enabled = tls.is_some();
223 let webhook_cache = WebhookAppenderCache::new();
224
225 let session_store = TowerSessionMemoryStore::default();
227 let session_layer = TowerSessionManagerLayer::new(session_store)
228 .with_secure(tls_enabled) .with_same_site(tower_sessions::cookie::SameSite::Strict) .with_http_only(true) .with_name("mz_session") .with_path("/"); let frontegg_middleware = frontegg.clone();
235 let oidc_middleware_rx = oidc_rx.clone();
236 let adapter_client_middleware_rx = adapter_client_rx.clone();
237 let auth_middleware = middleware::from_fn(move |req, next| {
238 let frontegg = frontegg_middleware.clone();
239 let oidc_rx = oidc_middleware_rx.clone();
240 let adapter_client_rx = adapter_client_middleware_rx.clone();
241 async move {
242 http_auth(
243 req,
244 next,
245 tls_enabled,
246 authenticator_kind,
247 frontegg,
248 oidc_rx,
249 adapter_client_rx,
250 allowed_roles,
251 )
252 .await
253 }
254 });
255
256 let mut router = Router::new();
257 let mut base_router = Router::new();
258 if routes_enabled.base {
259 base_router = base_router
260 .route(
261 "/",
262 routing::get(move || async move { root::handle_home(routes_enabled).await }),
263 )
264 .route("/api/sql", routing::post(sql::handle_sql))
265 .route("/memory", routing::get(memory::handle_memory))
266 .route(
267 "/hierarchical-memory",
268 routing::get(memory::handle_hierarchical_memory),
269 )
270 .route(
271 "/metrics-viz",
272 routing::get(metrics_viz::handle_metrics_viz),
273 )
274 .route("/static/{*path}", routing::get(root::handle_static));
275
276 let mut ws_router = Router::new()
277 .route("/api/experimental/sql", routing::get(sql::handle_sql_ws))
278 .with_state(WsState {
279 frontegg,
280 oidc_rx: oidc_rx.clone(),
281 authenticator_kind,
282 adapter_client_rx: adapter_client_rx.clone(),
283 active_connection_counter: active_connection_counter.clone(),
284 helm_chart_version: helm_chart_version.clone(),
285 allowed_roles,
286 });
287 if let listeners::AuthenticatorKind::None = authenticator_kind {
288 ws_router = ws_router.layer(middleware::from_fn_with_state(
289 allowed_roles,
290 x_materialize_user_header_auth,
291 ));
292 }
293 router = router.merge(ws_router);
294 }
295 if routes_enabled.profiling {
296 base_router = base_router.nest("/prof/", mz_prof_http::router(&BUILD_INFO));
297 }
298
299 if routes_enabled.webhook {
300 let webhook_router = Router::new()
301 .route(
302 "/api/webhook/{:database}/{:schema}/{:id}",
303 routing::post(webhook::handle_webhook),
304 )
305 .with_state(WebhookState {
306 adapter_client_rx: adapter_client_rx.clone(),
307 webhook_cache,
308 })
309 .layer(
310 tower_http::decompression::RequestDecompressionLayer::new()
311 .gzip(true)
312 .deflate(true)
313 .br(true)
314 .zstd(true),
315 )
316 .layer(
317 CorsLayer::new()
318 .allow_methods(Method::POST)
319 .allow_origin(AllowOrigin::mirror_request())
320 .allow_headers(Any),
321 )
322 .layer(
323 ServiceBuilder::new()
324 .layer(HandleErrorLayer::new(handle_load_error))
325 .load_shed()
326 .layer(GlobalConcurrencyLimitLayer::with_semaphore(
327 concurrent_webhook_req,
328 )),
329 );
330 router = router.merge(webhook_router);
331 }
332
333 if routes_enabled.internal {
334 let console_config = Arc::new(console::ConsoleProxyConfig::new(
335 internal_route_config.internal_console_redirect_url.clone(),
336 "/internal-console".to_string(),
337 ));
338 base_router = base_router
339 .route(
340 "/api/opentelemetry/config",
341 routing::put({
342 move |_: axum::Json<DynamicFilterTarget>| async {
343 (
344 StatusCode::BAD_REQUEST,
345 "This endpoint has been replaced. \
346 Use the `opentelemetry_filter` system variable."
347 .to_string(),
348 )
349 }
350 }),
351 )
352 .route(
353 "/api/stderr/config",
354 routing::put({
355 move |_: axum::Json<DynamicFilterTarget>| async {
356 (
357 StatusCode::BAD_REQUEST,
358 "This endpoint has been replaced. \
359 Use the `log_filter` system variable."
360 .to_string(),
361 )
362 }
363 }),
364 )
365 .route("/api/tracing", routing::get(mz_http_util::handle_tracing))
366 .route(
367 "/api/catalog/dump",
368 routing::get(catalog::handle_catalog_dump),
369 )
370 .route(
371 "/api/catalog/check",
372 routing::get(catalog::handle_catalog_check),
373 )
374 .route(
375 "/api/catalog/inject-audit-events",
376 routing::post(catalog::handle_inject_audit_events),
377 )
378 .route(
379 "/api/coordinator/check",
380 routing::get(catalog::handle_coordinator_check),
381 )
382 .route(
383 "/api/coordinator/dump",
384 routing::get(catalog::handle_coordinator_dump),
385 )
386 .route(
387 "/internal-console",
388 routing::get(|| async { Redirect::temporary("/internal-console/") }),
389 )
390 .route(
391 "/internal-console/{*path}",
392 routing::get(console::handle_internal_console),
393 )
394 .route(
395 "/internal-console/",
396 routing::get(console::handle_internal_console),
397 )
398 .layer(Extension(console_config));
399
400 let cluster_proxy_config = Arc::new(cluster::ClusterProxyConfig::new(Arc::clone(
402 &replica_http_locator,
403 )));
404 base_router = base_router
405 .route("/clusters", routing::get(cluster::handle_clusters))
406 .route(
407 "/api/cluster/{:cluster_id}/replica/{:replica_id}/process/{:process}/",
408 routing::any(cluster::handle_cluster_proxy_root),
409 )
410 .route(
411 "/api/cluster/{:cluster_id}/replica/{:replica_id}/process/{:process}/{*path}",
412 routing::any(cluster::handle_cluster_proxy),
413 )
414 .layer(Extension(cluster_proxy_config));
415
416 let leader_router = Router::new()
417 .route("/api/leader/status", routing::get(handle_leader_status))
418 .route("/api/leader/promote", routing::post(handle_leader_promote))
419 .route(
420 "/api/leader/skip-catchup",
421 routing::post(handle_leader_skip_catchup),
422 )
423 .layer(auth_middleware.clone())
424 .with_state(internal_route_config.deployment_state_handle.clone());
425 router = router.merge(leader_router);
426 }
427
428 if routes_enabled.metrics {
429 let metrics_router = Router::new()
430 .route(
431 "/metrics",
432 routing::get(move || async move {
433 mz_http_util::handle_prometheus(&metrics_registry).await
434 }),
435 )
436 .route(
437 "/metrics/mz_usage",
438 routing::get(|client: AuthedClient| async move {
439 let registry = sql::handle_promsql(client, USAGE_METRIC_QUERIES).await;
440 mz_http_util::handle_prometheus(®istry).await
441 }),
442 )
443 .route(
444 "/metrics/mz_frontier",
445 routing::get(|client: AuthedClient| async move {
446 let registry = sql::handle_promsql(client, FRONTIER_METRIC_QUERIES).await;
447 mz_http_util::handle_prometheus(®istry).await
448 }),
449 )
450 .route(
451 "/metrics/mz_compute",
452 routing::get(|client: AuthedClient| async move {
453 let registry = sql::handle_promsql(client, COMPUTE_METRIC_QUERIES).await;
454 mz_http_util::handle_prometheus(®istry).await
455 }),
456 )
457 .route(
458 "/metrics/mz_storage",
459 routing::get(|client: AuthedClient| async move {
460 let registry = sql::handle_promsql(client, STORAGE_METRIC_QUERIES).await;
461 mz_http_util::handle_prometheus(®istry).await
462 }),
463 )
464 .route(
465 "/api/livez",
466 routing::get(mz_http_util::handle_liveness_check),
467 )
468 .route("/api/readyz", routing::get(probe::handle_ready))
469 .layer(auth_middleware.clone())
470 .layer(Extension(adapter_client_rx.clone()))
471 .layer(Extension(active_connection_counter.clone()))
472 .layer(Extension(HelmChartVersion(helm_chart_version.clone())));
473 router = router.merge(metrics_router);
474 }
475
476 if routes_enabled.console_config {
477 let console_config_router = Router::new()
478 .route(
479 "/api/console/config",
480 routing::get(console::handle_console_config),
481 )
482 .layer(Extension(adapter_client_rx.clone()))
483 .layer(Extension(active_connection_counter.clone()));
484 router = router.merge(console_config_router);
485 }
486
487 if routes_enabled.mcp_agent || routes_enabled.mcp_developer {
490 use tracing::info;
491
492 let mut mcp_router = Router::new();
493
494 if routes_enabled.mcp_agent {
495 info!("Enabling MCP agent endpoint: /api/mcp/agent");
496 mcp_router = mcp_router.route(
497 "/api/mcp/agent",
498 routing::post(mcp::handle_mcp_agent).get(mcp::handle_mcp_method_not_allowed),
499 );
500 }
501
502 if routes_enabled.mcp_developer {
503 info!("Enabling MCP developer endpoint: /api/mcp/developer");
504 mcp_router = mcp_router.route(
505 "/api/mcp/developer",
506 routing::post(mcp::handle_mcp_developer)
507 .get(mcp::handle_mcp_method_not_allowed),
508 );
509 }
510
511 let mcp_allowed_origins = Arc::new(allowed_origin_list.clone());
518 mcp_router = mcp_router
519 .layer(auth_middleware.clone())
520 .layer(Extension(adapter_client_rx.clone()))
521 .layer(Extension(active_connection_counter.clone()))
522 .layer(Extension(HelmChartVersion(helm_chart_version.clone())))
523 .layer(Extension(mcp_allowed_origins))
524 .layer(
525 CorsLayer::new()
526 .allow_methods(Method::POST)
527 .allow_origin(allowed_origin.clone())
528 .allow_headers([AUTHORIZATION, CONTENT_TYPE]),
529 );
530 router = router.merge(mcp_router);
531 }
532
533 base_router = base_router
534 .layer(auth_middleware.clone())
535 .layer(Extension(adapter_client_rx.clone()))
536 .layer(Extension(active_connection_counter.clone()))
537 .layer(Extension(HelmChartVersion(helm_chart_version)))
538 .layer(
539 CorsLayer::new()
540 .allow_credentials(false)
541 .allow_headers([
542 AUTHORIZATION,
543 CONTENT_TYPE,
544 HeaderName::from_static("x-materialize-version"),
545 ])
546 .allow_methods(Any)
547 .allow_origin(allowed_origin)
548 .expose_headers(Any)
549 .max_age(Duration::from_secs(60) * 60),
550 );
551
552 match authenticator_kind {
553 listeners::AuthenticatorKind::Password | listeners::AuthenticatorKind::Oidc => {
554 base_router = base_router.layer(session_layer.clone());
555
556 let login_router = Router::new()
557 .route("/api/login", routing::post(handle_login))
558 .route("/api/logout", routing::post(handle_logout))
559 .layer(Extension(adapter_client_rx))
560 .layer(Extension(allowed_roles));
561 router = router.merge(login_router).layer(session_layer);
562 }
563 listeners::AuthenticatorKind::None => {
564 base_router = base_router.layer(middleware::from_fn_with_state(
565 allowed_roles,
566 x_materialize_user_header_auth,
567 ));
568 }
569 _ => {}
570 }
571
572 router = router
573 .merge(base_router)
574 .apply_default_layers(source, metrics);
575
576 HttpServer { tls, router }
577 }
578}
579
impl Server for HttpServer {
    const NAME: &'static str = "http";

    /// Serves one accepted connection with this listener's router.
    ///
    /// Resolves the client address, preferring the address from
    /// `take_proxy_header_address` when one was supplied. It then performs the TLS
    /// handshake if TLS is configured and serves HTTP/1 on the stream. Upgrades are
    /// enabled, which the WebSocket SQL endpoint relies on.
    fn handle_connection(
        &self,
        conn: Connection,
        _tokio_metrics_intervals: impl Iterator<Item = TaskMetrics> + Send + 'static,
    ) -> ConnectionHandler {
        let router = self.router.clone();
        let tls_context = self.tls.clone();
        let mut conn = TokioIo::new(conn);

        Box::pin(async {
            let direct_peer_addr = conn.inner().peer_addr().context("fetching peer addr")?;
            // Prefer the source address carried by a proxy header (presumably the
            // PROXY protocol) over the direct TCP peer.
            let peer_addr = conn
                .inner_mut()
                .take_proxy_header_address()
                .await
                .map(|a| a.source)
                .unwrap_or(direct_peer_addr);

            let (conn, conn_protocol) = match tls_context {
                Some(tls_context) => {
                    let mut ssl_stream = SslStream::new(Ssl::new(&tls_context.get())?, conn)?;
                    if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
                        // Best-effort close of the raw stream; the handshake error is
                        // what gets reported.
                        let _ = ssl_stream.get_mut().inner_mut().shutdown().await;
                        return Err(e.into());
                    }
                    (MaybeHttpsStream::Https(ssl_stream), ConnProtocol::Https)
                }
                _ => (MaybeHttpsStream::Http(conn), ConnProtocol::Http),
            };
            // Expose the transport to `http_auth` (used for the HTTPS redirect) and the
            // peer address to handlers via `ConnectInfo<SocketAddr>`.
            let mut make_tower_svc = router
                .layer(Extension(conn_protocol))
                .into_make_service_with_connect_info::<SocketAddr>();
            // NOTE(review): the make-service future is presumably infallible for axum
            // routers, which is why `unwrap` is used here.
            let tower_svc = make_tower_svc.call(peer_addr).await.unwrap();
            let hyper_svc = hyper::service::service_fn(|req| tower_svc.clone().call(req));
            let http = hyper::server::conn::http1::Builder::new();
            http.serve_connection(conn, hyper_svc)
                .with_upgrades()
                .err_into()
                .await
        })
    }
}
625
626pub async fn handle_leader_status(
627 State(deployment_state_handle): State<DeploymentStateHandle>,
628) -> impl IntoResponse {
629 let status = deployment_state_handle.status();
630 (StatusCode::OK, Json(json!({ "status": status })))
631}
632
633pub async fn handle_leader_promote(
634 State(deployment_state_handle): State<DeploymentStateHandle>,
635) -> impl IntoResponse {
636 match deployment_state_handle.try_promote() {
637 Ok(()) => {
638 let status = StatusCode::OK;
641 let body = Json(json!({
642 "result": "Success",
643 }));
644 (status, body)
645 }
646 Err(()) => {
647 let status = StatusCode::BAD_REQUEST;
650 let body = Json(json!({
651 "result": {"Failure": {"message": "cannot promote leader while initializing"}},
652 }));
653 (status, body)
654 }
655 }
656}
657
658pub async fn handle_leader_skip_catchup(
659 State(deployment_state_handle): State<DeploymentStateHandle>,
660) -> impl IntoResponse {
661 match deployment_state_handle.try_skip_catchup() {
662 Ok(()) => StatusCode::NO_CONTENT.into_response(),
663 Err(()) => {
664 let status = StatusCode::BAD_REQUEST;
665 let body = Json(json!({
666 "message": "cannot skip catchup in this phase of initialization; try again later",
667 }));
668 (status, body).into_response()
669 }
670 }
671}
672
673async fn x_materialize_user_header_auth(
674 State(allowed_roles): State<AllowedRoles>,
675 mut req: Request,
676 next: Next,
677) -> impl IntoResponse {
678 if let Some(username) = req.headers().get("x-materialize-user").map(|h| h.to_str()) {
680 let username = match username {
681 Ok(name @ (SUPPORT_USER_NAME | SYSTEM_USER_NAME)) => name.to_string(),
682 _ => {
683 return Err(AuthError::MismatchedUser(format!(
684 "user specified in x-materialize-user must be {SUPPORT_USER_NAME} or {SYSTEM_USER_NAME}"
685 )));
686 }
687 };
688 check_role_allowed(&username, allowed_roles)?;
693 req.extensions_mut().insert(AuthedUser {
694 name: username,
695 external_metadata_rx: None,
696 authenticated: Authenticated,
697 authenticator_kind: mz_auth::AuthenticatorKind::None,
698 groups: None,
699 });
700 }
701 Ok(next.run(req).await)
702}
703
/// A value that becomes available later: a cloneable, shareable handle to a oneshot
/// receiver. Each clone can be awaited on its own to obtain the value, or the error
/// produced if the sender was dropped.
type Delayed<T> = Shared<oneshot::Receiver<T>>;
705
/// Transport a request arrived on.
///
/// `handle_connection` attaches it to every request as an extension. `http_auth`
/// consults it to redirect plain HTTP when TLS is enabled.
#[derive(Clone)]
enum ConnProtocol {
    /// Plain-text HTTP.
    Http,
    /// HTTP over TLS.
    Https,
}
711
/// A user whose identity has been established by one of the auth paths in this module.
/// It is stored as a request extension and consumed when opening an adapter session.
#[derive(Clone, Debug)]
pub struct AuthedUser {
    /// User/role name the adapter session is opened as.
    name: String,
    /// Watch channel for externally supplied user metadata. It is `None` for the
    /// header- and session-based users built in this file.
    external_metadata_rx: Option<watch::Receiver<ExternalUserMetadata>>,
    /// `mz_auth` marker handed to `new_session` as proof of authentication.
    authenticated: Authenticated,
    /// Which authenticator vouched for this user.
    authenticator_kind: mz_auth::AuthenticatorKind,
    /// Group memberships, if the authenticator supplied any; forwarded to the session.
    groups: Option<Vec<String>>,
}
721
/// An adapter session opened on behalf of an authenticated HTTP user.
///
/// It can also be used directly as an axum extractor, via its `FromRequestParts` impl.
pub struct AuthedClient {
    /// The started-up adapter session.
    pub client: SessionClient,
    /// Slot reserved in the active-connection counter. NOTE(review): presumably
    /// released when dropped — confirm.
    pub connection_guard: Option<ConnectionHandle>,
}
726
727impl AuthedClient {
728 async fn new<F>(
729 adapter_client: &Client,
730 user: AuthedUser,
731 peer_addr: IpAddr,
732 active_connection_counter: ConnectionCounter,
733 helm_chart_version: Option<String>,
734 session_config: F,
735 options: BTreeMap<String, String>,
736 now: NowFn,
737 ) -> Result<Self, AdapterError>
738 where
739 F: FnOnce(&mut AdapterSession),
740 {
741 let conn_id = adapter_client.new_conn_id()?;
742 let mut session = adapter_client.new_session(
743 AdapterSessionConfig {
744 conn_id,
745 uuid: epoch_to_uuid_v7(&(now)()),
746 user: user.name,
747 client_ip: Some(peer_addr),
748 external_metadata_rx: user.external_metadata_rx,
749 helm_chart_version,
750 authenticator_kind: user.authenticator_kind,
751 groups: user.groups,
752 },
753 user.authenticated,
754 );
755 let connection_guard = active_connection_counter.allocate_connection(session.user())?;
756
757 session_config(&mut session);
758 let system_vars = adapter_client.get_system_vars().await;
759 for (key, val) in options {
760 const LOCAL: bool = false;
761 if let Err(err) =
762 session
763 .vars_mut()
764 .set(&system_vars, &key, VarInput::Flat(&val), LOCAL)
765 {
766 session.add_notice(AdapterNotice::BadStartupSetting {
767 name: key.to_string(),
768 reason: err.to_string(),
769 })
770 }
771 }
772 let adapter_client = adapter_client.startup(session).await?;
773 Ok(AuthedClient {
774 client: adapter_client,
775 connection_guard,
776 })
777 }
778}
779
780impl<S> FromRequestParts<S> for AuthedClient
781where
782 S: Send + Sync,
783{
784 type Rejection = Response;
785
786 async fn from_request_parts(
787 req: &mut http::request::Parts,
788 state: &S,
789 ) -> Result<Self, Self::Rejection> {
790 #[derive(Debug, Default, Deserialize)]
791 struct Params {
792 #[serde(default)]
793 options: String,
794 }
795 let params: Query<Params> = Query::from_request_parts(req, state)
796 .await
797 .unwrap_or_default();
798
799 let peer_addr = req
800 .extensions
801 .get::<ConnectInfo<SocketAddr>>()
802 .expect("ConnectInfo extension guaranteed to exist")
803 .0
804 .ip();
805
806 let user = req.extensions.get::<AuthedUser>().unwrap();
807 let adapter_client = req
808 .extensions
809 .get::<Delayed<mz_adapter::Client>>()
810 .unwrap()
811 .clone();
812 let adapter_client = adapter_client.await.map_err(|_| {
813 (StatusCode::INTERNAL_SERVER_ERROR, "adapter client missing").into_response()
814 })?;
815 let active_connection_counter = req.extensions.get::<ConnectionCounter>().unwrap();
816 let helm_chart_version = req
817 .extensions
818 .get::<HelmChartVersion>()
819 .map(|h| h.0.clone())
820 .unwrap_or(None);
821
822 let options = if params.options.is_empty() {
823 BTreeMap::<String, String>::default()
826 } else {
827 match serde_json::from_str(¶ms.options) {
828 Ok(options) => options,
829 Err(_e) => {
830 let code = StatusCode::BAD_REQUEST;
832 let msg = format!("Failed to deserialize {} map", "options".quoted());
833 return Err((code, msg).into_response());
834 }
835 }
836 };
837
838 let client = AuthedClient::new(
839 &adapter_client,
840 user.clone(),
841 peer_addr,
842 active_connection_counter.clone(),
843 helm_chart_version,
844 |session| {
845 session
846 .vars_mut()
847 .set_default(WELCOME_MESSAGE.name(), VarInput::Flat(&false.format()))
848 .expect("known to exist")
849 },
850 options,
851 SYSTEM_TIME.clone(),
852 )
853 .await
854 .map_err(|e| {
855 let status = match e {
856 AdapterError::UserSessionsDisallowed | AdapterError::NetworkPolicyDenied(_) => {
857 StatusCode::FORBIDDEN
858 }
859 _ => StatusCode::INTERNAL_SERVER_ERROR,
860 };
861 (status, Json(SqlError::from(e))).into_response()
862 })?;
863
864 Ok(client)
865 }
866}
867
868#[derive(Debug, Error)]
869pub(crate) enum AuthError {
870 #[error("role dissallowed")]
871 RoleDisallowed(String),
872 #[error("{0}")]
873 Frontegg(#[from] FronteggError),
874 #[error("missing authorization header")]
875 MissingHttpAuthentication {
876 include_www_authenticate_header: bool,
877 },
878 #[error("{0}")]
879 MismatchedUser(String),
880 #[error("session expired")]
881 SessionExpired,
882 #[error("failed to update session")]
883 FailedToUpdateSession,
884 #[error("invalid credentials")]
885 InvalidCredentials,
886 #[error("{0}")]
888 OidcFailed(String),
889}
890
891impl IntoResponse for AuthError {
892 fn into_response(self) -> Response {
893 warn!("HTTP request failed authentication: {}", self);
894 let mut headers = HeaderMap::new();
895 let body = match &self {
900 AuthError::MissingHttpAuthentication {
901 include_www_authenticate_header,
902 } if *include_www_authenticate_header => {
903 headers.insert(
904 http::header::WWW_AUTHENTICATE,
905 HeaderValue::from_static("Basic realm=Materialize"),
906 );
907 "unauthorized".to_string()
908 }
909 AuthError::OidcFailed(message) => message.clone(),
910 _ => "unauthorized".to_string(),
911 };
912 (StatusCode::UNAUTHORIZED, headers, body).into_response()
913 }
914}
915
916pub async fn handle_login(
918 session: Option<Extension<TowerSession>>,
919 Extension(adapter_client_rx): Extension<Delayed<Client>>,
920 Extension(allowed_roles): Extension<AllowedRoles>,
921 Json(LoginCredentials { username, password }): Json<LoginCredentials>,
922) -> impl IntoResponse {
923 if let Err(err) = check_role_allowed(&username, allowed_roles) {
928 warn!(
929 ?err,
930 "HTTP login rejected: role not allowed on this listener"
931 );
932 return StatusCode::UNAUTHORIZED;
933 }
934 let Ok(adapter_client) = adapter_client_rx.clone().await else {
935 return StatusCode::INTERNAL_SERVER_ERROR;
936 };
937 let authenticated = match adapter_client.authenticate(&username, &password).await {
938 Ok(authenticated) => authenticated,
939 Err(err) => {
940 warn!(?err, "HTTP login failed authentication");
941 return StatusCode::UNAUTHORIZED;
942 }
943 };
944 let session_data = TowerSessionData {
946 username,
947 created_at: SystemTime::now(),
948 last_activity: SystemTime::now(),
949 authenticated,
950 authenticator_kind: mz_auth::AuthenticatorKind::Password,
951 };
952 let session = session.and_then(|Extension(session)| Some(session));
954 let Some(session) = session else {
955 return StatusCode::INTERNAL_SERVER_ERROR;
956 };
957 match session.insert("data", &session_data).await {
958 Err(_) => StatusCode::INTERNAL_SERVER_ERROR,
959 Ok(_) => StatusCode::OK,
960 }
961}
962
963pub async fn handle_logout(session: Option<Extension<TowerSession>>) -> impl IntoResponse {
965 let session = session.and_then(|Extension(session)| Some(session));
966 let Some(session) = session else {
967 return StatusCode::INTERNAL_SERVER_ERROR;
968 };
969 match session.delete().await {
971 Err(_) => StatusCode::INTERNAL_SERVER_ERROR,
972 Ok(_) => StatusCode::OK,
973 }
974}
975
976async fn http_auth(
977 mut req: Request,
978 next: Next,
979 tls_enabled: bool,
980 authenticator_kind: listeners::AuthenticatorKind,
981 frontegg: Option<mz_frontegg_auth::Authenticator>,
982 oidc_rx: Delayed<mz_authenticator::GenericOidcAuthenticator>,
983 adapter_client_rx: Delayed<Client>,
984 allowed_roles: AllowedRoles,
985) -> Result<impl IntoResponse, AuthError> {
986 let creds = if let Some(basic) = req.headers().typed_get::<Authorization<Basic>>() {
987 Some(Credentials::Password {
988 username: basic.username().to_owned(),
989 password: Password(basic.password().to_owned()),
990 })
991 } else if let Some(bearer) = req.headers().typed_get::<Authorization<Bearer>>() {
992 Some(Credentials::Token {
993 token: bearer.token().to_owned(),
994 })
995 } else {
996 None
997 };
998
999 if creds.is_none()
1003 && let Some((session, session_data)) =
1004 maybe_get_authenticated_session(req.extensions().get::<TowerSession>()).await
1005 {
1006 let user = ensure_session_unexpired(session, session_data).await?;
1007 check_role_allowed(&user.name, allowed_roles)?;
1013 req.extensions_mut().insert(user);
1015 return Ok(next.run(req).await);
1016 }
1017
1018 let conn_protocol = req.extensions().get::<ConnProtocol>().unwrap();
1022 match (tls_enabled, &conn_protocol) {
1023 (false, ConnProtocol::Http) => {}
1024 (false, ConnProtocol::Https { .. }) => unreachable!(),
1025 (true, ConnProtocol::Http) => {
1026 let mut parts = req.uri().clone().into_parts();
1027 parts.scheme = Some(Scheme::HTTPS);
1028 return Ok(Redirect::permanent(
1029 &Uri::from_parts(parts)
1030 .expect("it was already a URI, just changed the scheme")
1031 .to_string(),
1032 )
1033 .into_response());
1034 }
1035 (true, ConnProtocol::Https { .. }) => {}
1036 }
1037 if req.extensions().get::<AuthedUser>().is_some() {
1039 return Ok(next.run(req).await);
1040 }
1041
1042 let path = req.uri().path();
1043 let include_www_authenticate_header = path == "/"
1044 || PROFILING_API_ENDPOINTS
1045 .iter()
1046 .any(|prefix| path.starts_with(prefix));
1047 let authenticator = get_authenticator(
1048 authenticator_kind,
1049 creds.as_ref(),
1050 frontegg,
1051 &oidc_rx,
1052 &adapter_client_rx,
1053 )
1054 .await;
1055
1056 let user = auth(
1057 &authenticator,
1058 creds,
1059 allowed_roles,
1060 include_www_authenticate_header,
1061 )
1062 .await?;
1063
1064 req.extensions_mut().insert(user);
1067
1068 Ok(next.run(req).await)
1070}
1071
1072async fn init_ws(
1073 WsState {
1074 frontegg,
1075 oidc_rx,
1076 authenticator_kind,
1077 adapter_client_rx,
1078 active_connection_counter,
1079 helm_chart_version,
1080 allowed_roles,
1081 }: WsState,
1082 existing_user: Option<ExistingUser>,
1083 peer_addr: IpAddr,
1084 ws: &mut WebSocket,
1085) -> Result<AuthedClient, anyhow::Error> {
1086 let ws_auth: WebSocketAuth = loop {
1089 let init_msg = ws.recv().await.ok_or_else(|| anyhow::anyhow!("closed"))??;
1090 match init_msg {
1091 Message::Text(data) => break serde_json::from_str(&data)?,
1092 Message::Binary(data) => break serde_json::from_slice(&data)?,
1093 Message::Ping(_) => {
1095 continue;
1096 }
1097 Message::Pong(_) => {
1098 continue;
1099 }
1100 Message::Close(_) => {
1101 anyhow::bail!("closed");
1102 }
1103 }
1104 };
1105
1106 let (creds, options) = match ws_auth {
1109 WebSocketAuth::Basic {
1110 user,
1111 password,
1112 options,
1113 } => {
1114 let creds = Credentials::Password {
1115 username: user,
1116 password,
1117 };
1118 (Some(creds), options)
1119 }
1120 WebSocketAuth::Bearer { token, options } => {
1121 let creds = Credentials::Token { token };
1122 (Some(creds), options)
1123 }
1124 WebSocketAuth::OptionsOnly { options } => (None, options),
1125 };
1126
1127 let user = match (existing_user, creds) {
1128 (Some(ExistingUser::XMaterializeUserHeader(_)), Some(_creds)) => {
1129 warn!("Unexpected bearer or basic auth provided when using user header");
1130 anyhow::bail!("unexpected")
1131 }
1132 (Some(ExistingUser::Session(user)), None) => {
1133 check_role_allowed(&user.name, allowed_roles)?;
1139 user
1140 }
1141 (Some(ExistingUser::XMaterializeUserHeader(user)), None) => user,
1142 (_, Some(creds)) => {
1143 let authenticator = get_authenticator(
1144 authenticator_kind,
1145 Some(&creds),
1146 frontegg,
1147 &oidc_rx,
1148 &adapter_client_rx,
1149 )
1150 .await;
1151 let user = auth(&authenticator, Some(creds), allowed_roles, false).await?;
1152 user
1153 }
1154 (None, None) => anyhow::bail!("expected auth information"),
1155 };
1156
1157 let client = AuthedClient::new(
1158 &adapter_client_rx.clone().await?,
1159 user,
1160 peer_addr,
1161 active_connection_counter.clone(),
1162 helm_chart_version.clone(),
1163 |_session| (),
1164 options,
1165 SYSTEM_TIME.clone(),
1166 )
1167 .await?;
1168
1169 Ok(client)
1170}
1171
/// Credentials presented by a client, from an HTTP `Authorization` header or a
/// WebSocket auth message.
enum Credentials {
    /// Username/password: HTTP Basic auth or `WebSocketAuth::Basic`.
    Password {
        username: String,
        password: Password,
    },
    /// Bearer token: HTTP Bearer auth or `WebSocketAuth::Bearer`.
    Token {
        token: String,
    },
}
1181
1182async fn get_authenticator(
1183 kind: listeners::AuthenticatorKind,
1184 creds: Option<&Credentials>,
1185 frontegg: Option<mz_frontegg_auth::Authenticator>,
1186 oidc_rx: &Delayed<mz_authenticator::GenericOidcAuthenticator>,
1187 adapter_client_rx: &Delayed<Client>,
1188) -> Authenticator {
1189 match kind {
1190 listeners::AuthenticatorKind::Frontegg => Authenticator::Frontegg(frontegg.expect(
1191 "Frontegg authenticator should exist with listeners::AuthenticatorKind::Frontegg",
1192 )),
1193 listeners::AuthenticatorKind::Password | listeners::AuthenticatorKind::Sasl => {
1194 let client = adapter_client_rx.clone().await.expect("sender not dropped");
1195 Authenticator::Password(client)
1196 }
1197 listeners::AuthenticatorKind::Oidc => match creds {
1198 Some(Credentials::Password { .. }) => {
1200 let client = adapter_client_rx.clone().await.expect("sender not dropped");
1201 Authenticator::Password(client)
1202 }
1203 _ => Authenticator::Oidc(oidc_rx.clone().await.expect("sender not dropped")),
1204 },
1205 listeners::AuthenticatorKind::None => Authenticator::None,
1206 }
1207}
1208
1209pub(crate) async fn maybe_get_authenticated_session(
1213 session: Option<&TowerSession>,
1214) -> Option<(&TowerSession, TowerSessionData)> {
1215 if let Some(session) = session {
1216 if let Ok(Some(session_data)) = session.get::<TowerSessionData>("data").await {
1217 return Some((session, session_data));
1218 }
1219 }
1220 None
1221}
1222
1223pub(crate) async fn ensure_session_unexpired(
1226 session: &TowerSession,
1227 session_data: TowerSessionData,
1228) -> Result<AuthedUser, AuthError> {
1229 if session_data
1230 .last_activity
1231 .elapsed()
1232 .unwrap_or(Duration::MAX)
1233 > SESSION_DURATION
1234 {
1235 let _ = session.delete().await;
1236 return Err(AuthError::SessionExpired);
1237 }
1238 let mut updated_data = session_data.clone();
1239 updated_data.last_activity = SystemTime::now();
1240 session
1241 .insert("data", &updated_data)
1242 .await
1243 .map_err(|_| AuthError::FailedToUpdateSession)?;
1244
1245 Ok(AuthedUser {
1246 name: session_data.username,
1247 external_metadata_rx: None,
1248 authenticated: session_data.authenticated,
1249 authenticator_kind: session_data.authenticator_kind,
1250 groups: None,
1251 })
1252}
1253
1254async fn auth(
1255 authenticator: &Authenticator,
1256 creds: Option<Credentials>,
1257 allowed_roles: AllowedRoles,
1258 include_www_authenticate_header: bool,
1259) -> Result<AuthedUser, AuthError> {
1260 let (name, external_metadata_rx, authenticated, groups) = match authenticator {
1261 Authenticator::Frontegg(frontegg) => match creds {
1262 Some(Credentials::Password { username, password }) => {
1263 let (auth_session, authenticated) =
1264 frontegg.authenticate(&username, password.as_str()).await?;
1265 let name = auth_session.user().into();
1266 let external_metadata_rx = Some(auth_session.external_metadata_rx());
1267 (name, external_metadata_rx, authenticated, None)
1268 }
1269 Some(Credentials::Token { token }) => {
1270 let (claims, authenticated) = frontegg.validate_access_token(&token, None)?;
1271 let (_, external_metadata_rx) = watch::channel(ExternalUserMetadata {
1272 user_id: claims.user_id,
1273 admin: claims.is_admin,
1274 });
1275 (claims.user, Some(external_metadata_rx), authenticated, None)
1276 }
1277 None => {
1278 return Err(AuthError::MissingHttpAuthentication {
1279 include_www_authenticate_header,
1280 });
1281 }
1282 },
1283 Authenticator::Password(adapter_client) => match creds {
1284 Some(Credentials::Password { username, password }) => {
1285 let authenticated = adapter_client
1286 .authenticate(&username, &password)
1287 .await
1288 .map_err(|_| AuthError::InvalidCredentials)?;
1289 (username, None, authenticated, None)
1290 }
1291 _ => {
1292 return Err(AuthError::MissingHttpAuthentication {
1293 include_www_authenticate_header,
1294 });
1295 }
1296 },
1297 Authenticator::Sasl(_) => {
1298 return Err(AuthError::MissingHttpAuthentication {
1302 include_www_authenticate_header,
1303 });
1304 }
1305 Authenticator::Oidc(oidc) => match creds {
1306 Some(Credentials::Token { token }) => {
1307 let (mut claims, authenticated) = oidc
1308 .authenticate(&token, None)
1309 .await
1310 .map_err(|e| AuthError::OidcFailed(e.to_string()))?;
1311 let name = std::mem::take(&mut claims.user);
1312 let groups = claims.groups.take();
1313 (name, None, authenticated, groups)
1314 }
1315 _ => {
1316 return Err(AuthError::MissingHttpAuthentication {
1317 include_www_authenticate_header,
1318 });
1319 }
1320 },
1321 Authenticator::None => {
1322 let name = match creds {
1326 Some(Credentials::Password { username, .. }) => username,
1327 _ => HTTP_DEFAULT_USER.name.to_owned(),
1328 };
1329 (name, None, Authenticated, None)
1330 }
1331 };
1332
1333 check_role_allowed(&name, allowed_roles)?;
1334
1335 Ok(AuthedUser {
1336 name,
1337 external_metadata_rx,
1338 authenticated,
1339 authenticator_kind: authenticator.kind(),
1340 groups,
1341 })
1342}
1343
1344fn check_role_allowed(name: &str, allowed_roles: AllowedRoles) -> Result<(), AuthError> {
1346 let is_internal_user = INTERNAL_USER_NAMES.contains(name);
1347 let is_reserved_user = mz_adapter::catalog::is_reserved_role_name(name);
1349 let role_allowed = match allowed_roles {
1350 AllowedRoles::Normal => !is_reserved_user,
1351 AllowedRoles::Internal => is_internal_user,
1352 AllowedRoles::NormalAndInternal => !is_reserved_user || is_internal_user,
1353 };
1354 if role_allowed {
1355 Ok(())
1356 } else {
1357 Err(AuthError::RoleDisallowed(name.to_owned()))
1358 }
1359}
1360
1361trait DefaultLayers {
1364 fn apply_default_layers(self, source: &'static str, metrics: Metrics) -> Self;
1365}
1366
1367impl DefaultLayers for Router {
1368 fn apply_default_layers(self, source: &'static str, metrics: Metrics) -> Self {
1369 self.layer(DefaultBodyLimit::max(MAX_REQUEST_SIZE))
1370 .layer(metrics::PrometheusLayer::new(source, metrics))
1371 }
1372}
1373
1374async fn handle_load_error(error: tower::BoxError) -> impl IntoResponse {
1381 if error.is::<tower::load_shed::error::Overloaded>() {
1382 return (
1383 StatusCode::TOO_MANY_REQUESTS,
1384 Cow::from("too many requests, try again later"),
1385 );
1386 }
1387
1388 (
1391 StatusCode::INTERNAL_SERVER_ERROR,
1392 Cow::from(format!("Unhandled internal error: {}", error)),
1393 )
1394}
1395
/// A username/password pair that can be (de)serialized with serde (presumably
/// the body of a login request; confirm against the route that consumes it).
///
/// NOTE(review): `Debug` is derived and includes `password`. Confirm that
/// `Password`'s `Debug` impl redacts its contents before this type is logged.
#[derive(Debug, Deserialize, Serialize, PartialEq)]
pub struct LoginCredentials {
    /// Name of the role attempting to log in.
    username: String,
    /// The password supplied for `username`.
    password: Password,
}
1401
/// Authentication state persisted in a tower session under the `"data"` key.
///
/// Read by `maybe_get_authenticated_session` and refreshed by
/// `ensure_session_unexpired`. The latter expires the session once
/// `last_activity` is more than `SESSION_DURATION` in the past.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TowerSessionData {
    /// Name of the authenticated user; becomes `AuthedUser::name`.
    username: String,
    /// When the session was created (not read in this part of the file).
    created_at: SystemTime,
    /// Time of the most recent request on this session; drives idle expiry.
    last_activity: SystemTime,
    /// Authentication marker carried into `AuthedUser::authenticated`.
    authenticated: Authenticated,
    /// Which authenticator established the session; carried into
    /// `AuthedUser::authenticator_kind`.
    authenticator_kind: mz_auth::AuthenticatorKind,
}
1410
#[cfg(test)]
mod tests {
    use super::{AllowedRoles, check_role_allowed};

    #[mz_ore::test]
    fn test_check_role_allowed() {
        // Each row: (role name, allowed via Internal, allowed via NormalAndInternal,
        // allowed via Normal).
        let cases = [
            ("mz_system", true, true, false),
            ("mz_support", true, true, false),
            ("mz_analytics", true, true, false),
            ("materialize", false, true, true),
            ("anonymous_http_user", false, true, true),
            ("alex", false, true, true),
            ("external_asdf", false, false, false),
            ("pg_somebody", false, false, false),
            ("mz_unknown", false, false, false),
            ("PUBLIC", false, false, false),
        ];

        for (name, internal, normal_and_internal, normal) in cases {
            assert_eq!(
                check_role_allowed(name, AllowedRoles::Internal).is_ok(),
                internal,
                "{name} via AllowedRoles::Internal",
            );
            assert_eq!(
                check_role_allowed(name, AllowedRoles::NormalAndInternal).is_ok(),
                normal_and_internal,
                "{name} via AllowedRoles::NormalAndInternal",
            );
            assert_eq!(
                check_role_allowed(name, AllowedRoles::Normal).is_ok(),
                normal,
                "{name} via AllowedRoles::Normal",
            );
        }
    }
}