1#![allow(clippy::unused_async)]
51
52use std::borrow::Cow;
53use std::collections::BTreeMap;
54use std::fmt::Debug;
55use std::net::{IpAddr, SocketAddr};
56use std::pin::Pin;
57use std::sync::Arc;
58use std::time::{Duration, SystemTime};
59
60use anyhow::Context;
61use axum::error_handling::HandleErrorLayer;
62use axum::extract::ws::{Message, WebSocket};
63use axum::extract::{ConnectInfo, DefaultBodyLimit, FromRequestParts, Query, Request, State};
64use axum::middleware::{self, Next};
65use axum::response::{IntoResponse, Redirect, Response};
66use axum::{Extension, Json, Router, routing};
67use futures::future::{Shared, TryFutureExt};
68use headers::authorization::{Authorization, Basic, Bearer};
69use headers::{HeaderMapExt, HeaderName};
70use http::header::{AUTHORIZATION, CONTENT_TYPE};
71use http::uri::Scheme;
72use http::{HeaderMap, HeaderValue, Method, StatusCode, Uri};
73use hyper_openssl::SslStream;
74use hyper_openssl::client::legacy::MaybeHttpsStream;
75use hyper_util::rt::TokioIo;
76use mz_adapter::session::{Session as AdapterSession, SessionConfig as AdapterSessionConfig};
77use mz_adapter::{AdapterError, AdapterNotice, Client, SessionClient, WebhookAppenderCache};
78use mz_auth::Authenticated;
79use mz_auth::password::Password;
80use mz_authenticator::Authenticator;
81use mz_controller::ReplicaHttpLocator;
82use mz_frontegg_auth::Error as FronteggError;
83use mz_http_util::DynamicFilterTarget;
84use mz_ore::cast::u64_to_usize;
85use mz_ore::metrics::MetricsRegistry;
86use mz_ore::now::{NowFn, SYSTEM_TIME, epoch_to_uuid_v7};
87use mz_ore::str::StrExt;
88use mz_pgwire_common::{ConnectionCounter, ConnectionHandle};
89use mz_repr::user::ExternalUserMetadata;
90use mz_server_core::listeners::{self, AllowedRoles, HttpRoutesEnabled};
91use mz_server_core::{Connection, ConnectionHandler, ReloadingSslContext, Server};
92use mz_sql::session::metadata::SessionMetadata;
93use mz_sql::session::user::{
94 HTTP_DEFAULT_USER, INTERNAL_USER_NAMES, SUPPORT_USER_NAME, SYSTEM_USER_NAME,
95};
96use mz_sql::session::vars::{Value, Var, VarInput, WELCOME_MESSAGE};
97use openssl::ssl::Ssl;
98use prometheus::{
99 COMPUTE_METRIC_QUERIES, FRONTIER_METRIC_QUERIES, STORAGE_METRIC_QUERIES, USAGE_METRIC_QUERIES,
100};
101use serde::{Deserialize, Serialize};
102use serde_json::json;
103use thiserror::Error;
104use tokio::io::AsyncWriteExt;
105use tokio::sync::oneshot::Receiver;
106use tokio::sync::{oneshot, watch};
107use tokio_metrics::TaskMetrics;
108use tower::limit::GlobalConcurrencyLimitLayer;
109use tower::{Service, ServiceBuilder};
110use tower_http::cors::{AllowOrigin, Any, CorsLayer};
111use tower_sessions::{
112 MemoryStore as TowerSessionMemoryStore, Session as TowerSession,
113 SessionManagerLayer as TowerSessionManagerLayer,
114};
115use tracing::warn;
116
117use crate::BUILD_INFO;
118use crate::deployment::state::DeploymentStateHandle;
119use crate::http::sql::{ExistingUser, SqlError};
120
121mod catalog;
122mod cluster;
123mod console;
124mod mcp;
125mod memory;
126mod metrics;
127mod metrics_viz;
128mod probe;
129mod prometheus;
130mod root;
131mod sql;
132mod webhook;
133
134pub use metrics::Metrics;
135pub use sql::{SqlResponse, WebSocketAuth, WebSocketResponse};
136
/// Maximum size, in bytes, of an HTTP request body accepted by the server (5 MiB).
pub const MAX_REQUEST_SIZE: usize = u64_to_usize(5 * bytesize::MIB);

/// Maximum idle time of a cookie-backed login session.
///
/// Compared against the time elapsed since the session's `last_activity` timestamp in
/// `ensure_session_unexpired`.
const SESSION_DURATION: Duration = Duration::from_secs(8 * 3600);

/// Path prefixes of the profiling endpoints. For unauthenticated requests to these paths (and
/// to `/`), `http_auth` asks for a `WWW-Authenticate` challenge to be included in the
/// `401 Unauthorized` response.
const PROFILING_API_ENDPOINTS: &[&str] = &["/memory", "/hierarchical-memory", "/prof/"];
143
/// Everything needed to construct an [`HttpServer`].
#[derive(Debug)]
pub struct HttpConfig {
    /// Label for this listener; passed to `apply_default_layers` when building the router.
    pub source: &'static str,
    /// TLS context. When present, connections are TLS-accepted and the session cookie is
    /// marked `Secure`.
    pub tls: Option<ReloadingSslContext>,
    /// Authentication mechanism used by this listener.
    pub authenticator_kind: listeners::AuthenticatorKind,
    /// Frontegg authenticator; must be `Some` when `authenticator_kind` is `Frontegg`.
    pub frontegg: Option<mz_frontegg_auth::Authenticator>,
    /// Delayed handle to the OIDC authenticator.
    pub oidc_rx: Delayed<mz_authenticator::GenericOidcAuthenticator>,
    /// Delayed handle to the adapter client.
    pub adapter_client_rx: Shared<Receiver<Client>>,
    /// CORS origin policy for the base and MCP routes.
    pub allowed_origin: AllowOrigin,
    /// Explicit list of allowed origins, exposed to the MCP handlers as a request extension.
    pub allowed_origin_list: Vec<HeaderValue>,
    /// Counter from which every adapter session allocates a connection slot.
    pub active_connection_counter: ConnectionCounter,
    /// Helm chart version forwarded into WebSocket sessions' configuration.
    pub helm_chart_version: Option<String>,
    /// Bounds concurrent webhook requests; requests beyond the limit are load-shed.
    pub concurrent_webhook_req: Arc<tokio::sync::Semaphore>,
    /// HTTP metrics, passed to `apply_default_layers`.
    pub metrics: Metrics,
    /// Registry served by the `/metrics` endpoint.
    pub metrics_registry: MetricsRegistry,
    /// Role restrictions forwarded to the credential-authentication routine (`auth`).
    pub allowed_roles: AllowedRoles,
    /// Configuration used only by the internal routes.
    pub internal_route_config: Arc<InternalRouteConfig>,
    /// Which route groups are mounted on this listener.
    pub routes_enabled: HttpRoutesEnabled,
    /// Used by the cluster proxy routes to locate replica HTTP endpoints.
    pub replica_http_locator: Arc<ReplicaHttpLocator>,
}
167
/// Configuration needed only by the internal HTTP routes.
#[derive(Debug, Clone)]
pub struct InternalRouteConfig {
    /// Router state for the `/api/leader/*` endpoints (status, promote, skip-catchup).
    pub deployment_state_handle: DeploymentStateHandle,
    /// Passed to the `/internal-console` proxy configuration (`ConsoleProxyConfig::new`).
    pub internal_console_redirect_url: Option<String>,
}
173
/// Router state for the WebSocket SQL endpoint (`/api/experimental/sql`).
///
/// WebSocket clients may send credentials in their first message, so the inputs needed to
/// authenticate them are carried here and consumed by `init_ws`.
#[derive(Clone)]
pub struct WsState {
    frontegg: Option<mz_frontegg_auth::Authenticator>,
    oidc_rx: Delayed<mz_authenticator::GenericOidcAuthenticator>,
    authenticator_kind: listeners::AuthenticatorKind,
    adapter_client_rx: Delayed<mz_adapter::Client>,
    active_connection_counter: ConnectionCounter,
    helm_chart_version: Option<String>,
    allowed_roles: AllowedRoles,
}
184
/// Router state for the webhook append endpoint.
#[derive(Clone)]
pub struct WebhookState {
    /// Delayed handle to the adapter client.
    adapter_client_rx: Delayed<mz_adapter::Client>,
    /// Appender cache, created once per `HttpServer` and shared by all webhook requests.
    webhook_cache: WebhookAppenderCache,
}
190
/// An HTTP(S) server that serves a fully assembled axum [`Router`] on each accepted connection.
#[derive(Debug)]
pub struct HttpServer {
    /// TLS context; when `None`, connections are served as plain HTTP.
    tls: Option<ReloadingSslContext>,
    /// Router containing every enabled route and its layers.
    router: Router,
}
196
197impl HttpServer {
198 pub fn new(
199 HttpConfig {
200 source,
201 tls,
202 authenticator_kind,
203 frontegg,
204 oidc_rx,
205 adapter_client_rx,
206 allowed_origin,
207 allowed_origin_list,
208 active_connection_counter,
209 helm_chart_version,
210 concurrent_webhook_req,
211 metrics,
212 metrics_registry,
213 allowed_roles,
214 internal_route_config,
215 routes_enabled,
216 replica_http_locator,
217 }: HttpConfig,
218 ) -> HttpServer {
219 let tls_enabled = tls.is_some();
220 let webhook_cache = WebhookAppenderCache::new();
221
222 let session_store = TowerSessionMemoryStore::default();
224 let session_layer = TowerSessionManagerLayer::new(session_store)
225 .with_secure(tls_enabled) .with_same_site(tower_sessions::cookie::SameSite::Strict) .with_http_only(true) .with_name("mz_session") .with_path("/"); let frontegg_middleware = frontegg.clone();
232 let oidc_middleware_rx = oidc_rx.clone();
233 let adapter_client_middleware_rx = adapter_client_rx.clone();
234 let auth_middleware = middleware::from_fn(move |req, next| {
235 let frontegg = frontegg_middleware.clone();
236 let oidc_rx = oidc_middleware_rx.clone();
237 let adapter_client_rx = adapter_client_middleware_rx.clone();
238 async move {
239 http_auth(
240 req,
241 next,
242 tls_enabled,
243 authenticator_kind,
244 frontegg,
245 oidc_rx,
246 adapter_client_rx,
247 allowed_roles,
248 )
249 .await
250 }
251 });
252
253 let mut router = Router::new();
254 let mut base_router = Router::new();
255 if routes_enabled.base {
256 base_router = base_router
257 .route(
258 "/",
259 routing::get(move || async move { root::handle_home(routes_enabled).await }),
260 )
261 .route("/api/sql", routing::post(sql::handle_sql))
262 .route("/memory", routing::get(memory::handle_memory))
263 .route(
264 "/hierarchical-memory",
265 routing::get(memory::handle_hierarchical_memory),
266 )
267 .route(
268 "/metrics-viz",
269 routing::get(metrics_viz::handle_metrics_viz),
270 )
271 .route("/static/{*path}", routing::get(root::handle_static));
272
273 let mut ws_router = Router::new()
274 .route("/api/experimental/sql", routing::get(sql::handle_sql_ws))
275 .with_state(WsState {
276 frontegg,
277 oidc_rx: oidc_rx.clone(),
278 authenticator_kind,
279 adapter_client_rx: adapter_client_rx.clone(),
280 active_connection_counter: active_connection_counter.clone(),
281 helm_chart_version,
282 allowed_roles,
283 });
284 if let listeners::AuthenticatorKind::None = authenticator_kind {
285 ws_router = ws_router.layer(middleware::from_fn(x_materialize_user_header_auth));
286 }
287 router = router.merge(ws_router);
288 }
289 if routes_enabled.profiling {
290 base_router = base_router.nest("/prof/", mz_prof_http::router(&BUILD_INFO));
291 }
292
293 if routes_enabled.webhook {
294 let webhook_router = Router::new()
295 .route(
296 "/api/webhook/{:database}/{:schema}/{:id}",
297 routing::post(webhook::handle_webhook),
298 )
299 .with_state(WebhookState {
300 adapter_client_rx: adapter_client_rx.clone(),
301 webhook_cache,
302 })
303 .layer(
304 tower_http::decompression::RequestDecompressionLayer::new()
305 .gzip(true)
306 .deflate(true)
307 .br(true)
308 .zstd(true),
309 )
310 .layer(
311 CorsLayer::new()
312 .allow_methods(Method::POST)
313 .allow_origin(AllowOrigin::mirror_request())
314 .allow_headers(Any),
315 )
316 .layer(
317 ServiceBuilder::new()
318 .layer(HandleErrorLayer::new(handle_load_error))
319 .load_shed()
320 .layer(GlobalConcurrencyLimitLayer::with_semaphore(
321 concurrent_webhook_req,
322 )),
323 );
324 router = router.merge(webhook_router);
325 }
326
327 if routes_enabled.internal {
328 let console_config = Arc::new(console::ConsoleProxyConfig::new(
329 internal_route_config.internal_console_redirect_url.clone(),
330 "/internal-console".to_string(),
331 ));
332 base_router = base_router
333 .route(
334 "/api/opentelemetry/config",
335 routing::put({
336 move |_: axum::Json<DynamicFilterTarget>| async {
337 (
338 StatusCode::BAD_REQUEST,
339 "This endpoint has been replaced. \
340 Use the `opentelemetry_filter` system variable."
341 .to_string(),
342 )
343 }
344 }),
345 )
346 .route(
347 "/api/stderr/config",
348 routing::put({
349 move |_: axum::Json<DynamicFilterTarget>| async {
350 (
351 StatusCode::BAD_REQUEST,
352 "This endpoint has been replaced. \
353 Use the `log_filter` system variable."
354 .to_string(),
355 )
356 }
357 }),
358 )
359 .route("/api/tracing", routing::get(mz_http_util::handle_tracing))
360 .route(
361 "/api/catalog/dump",
362 routing::get(catalog::handle_catalog_dump),
363 )
364 .route(
365 "/api/catalog/check",
366 routing::get(catalog::handle_catalog_check),
367 )
368 .route(
369 "/api/catalog/inject-audit-events",
370 routing::post(catalog::handle_inject_audit_events),
371 )
372 .route(
373 "/api/coordinator/check",
374 routing::get(catalog::handle_coordinator_check),
375 )
376 .route(
377 "/api/coordinator/dump",
378 routing::get(catalog::handle_coordinator_dump),
379 )
380 .route(
381 "/internal-console",
382 routing::get(|| async { Redirect::temporary("/internal-console/") }),
383 )
384 .route(
385 "/internal-console/{*path}",
386 routing::get(console::handle_internal_console),
387 )
388 .route(
389 "/internal-console/",
390 routing::get(console::handle_internal_console),
391 )
392 .layer(Extension(console_config));
393
394 let cluster_proxy_config = Arc::new(cluster::ClusterProxyConfig::new(Arc::clone(
396 &replica_http_locator,
397 )));
398 base_router = base_router
399 .route("/clusters", routing::get(cluster::handle_clusters))
400 .route(
401 "/api/cluster/{:cluster_id}/replica/{:replica_id}/process/{:process}/",
402 routing::any(cluster::handle_cluster_proxy_root),
403 )
404 .route(
405 "/api/cluster/{:cluster_id}/replica/{:replica_id}/process/{:process}/{*path}",
406 routing::any(cluster::handle_cluster_proxy),
407 )
408 .layer(Extension(cluster_proxy_config));
409
410 let leader_router = Router::new()
411 .route("/api/leader/status", routing::get(handle_leader_status))
412 .route("/api/leader/promote", routing::post(handle_leader_promote))
413 .route(
414 "/api/leader/skip-catchup",
415 routing::post(handle_leader_skip_catchup),
416 )
417 .layer(auth_middleware.clone())
418 .with_state(internal_route_config.deployment_state_handle.clone());
419 router = router.merge(leader_router);
420 }
421
422 if routes_enabled.metrics {
423 let metrics_router = Router::new()
424 .route(
425 "/metrics",
426 routing::get(move || async move {
427 mz_http_util::handle_prometheus(&metrics_registry).await
428 }),
429 )
430 .route(
431 "/metrics/mz_usage",
432 routing::get(|client: AuthedClient| async move {
433 let registry = sql::handle_promsql(client, USAGE_METRIC_QUERIES).await;
434 mz_http_util::handle_prometheus(®istry).await
435 }),
436 )
437 .route(
438 "/metrics/mz_frontier",
439 routing::get(|client: AuthedClient| async move {
440 let registry = sql::handle_promsql(client, FRONTIER_METRIC_QUERIES).await;
441 mz_http_util::handle_prometheus(®istry).await
442 }),
443 )
444 .route(
445 "/metrics/mz_compute",
446 routing::get(|client: AuthedClient| async move {
447 let registry = sql::handle_promsql(client, COMPUTE_METRIC_QUERIES).await;
448 mz_http_util::handle_prometheus(®istry).await
449 }),
450 )
451 .route(
452 "/metrics/mz_storage",
453 routing::get(|client: AuthedClient| async move {
454 let registry = sql::handle_promsql(client, STORAGE_METRIC_QUERIES).await;
455 mz_http_util::handle_prometheus(®istry).await
456 }),
457 )
458 .route(
459 "/api/livez",
460 routing::get(mz_http_util::handle_liveness_check),
461 )
462 .route("/api/readyz", routing::get(probe::handle_ready))
463 .layer(auth_middleware.clone())
464 .layer(Extension(adapter_client_rx.clone()))
465 .layer(Extension(active_connection_counter.clone()));
466 router = router.merge(metrics_router);
467 }
468
469 if routes_enabled.console_config {
470 let console_config_router = Router::new()
471 .route(
472 "/api/console/config",
473 routing::get(console::handle_console_config),
474 )
475 .layer(Extension(adapter_client_rx.clone()))
476 .layer(Extension(active_connection_counter.clone()));
477 router = router.merge(console_config_router);
478 }
479
480 if routes_enabled.mcp_agent || routes_enabled.mcp_developer {
483 use tracing::info;
484
485 let mut mcp_router = Router::new();
486
487 if routes_enabled.mcp_agent {
488 info!("Enabling MCP agent endpoint: /api/mcp/agent");
489 mcp_router = mcp_router.route(
490 "/api/mcp/agent",
491 routing::post(mcp::handle_mcp_agent).get(mcp::handle_mcp_method_not_allowed),
492 );
493 }
494
495 if routes_enabled.mcp_developer {
496 info!("Enabling MCP developer endpoint: /api/mcp/developer");
497 mcp_router = mcp_router.route(
498 "/api/mcp/developer",
499 routing::post(mcp::handle_mcp_developer)
500 .get(mcp::handle_mcp_method_not_allowed),
501 );
502 }
503
504 let mcp_allowed_origins = Arc::new(allowed_origin_list.clone());
511 mcp_router = mcp_router
512 .layer(auth_middleware.clone())
513 .layer(Extension(adapter_client_rx.clone()))
514 .layer(Extension(active_connection_counter.clone()))
515 .layer(Extension(mcp_allowed_origins))
516 .layer(
517 CorsLayer::new()
518 .allow_methods(Method::POST)
519 .allow_origin(allowed_origin.clone())
520 .allow_headers([AUTHORIZATION, CONTENT_TYPE]),
521 );
522 router = router.merge(mcp_router);
523 }
524
525 base_router = base_router
526 .layer(auth_middleware.clone())
527 .layer(Extension(adapter_client_rx.clone()))
528 .layer(Extension(active_connection_counter.clone()))
529 .layer(
530 CorsLayer::new()
531 .allow_credentials(false)
532 .allow_headers([
533 AUTHORIZATION,
534 CONTENT_TYPE,
535 HeaderName::from_static("x-materialize-version"),
536 ])
537 .allow_methods(Any)
538 .allow_origin(allowed_origin)
539 .expose_headers(Any)
540 .max_age(Duration::from_secs(60) * 60),
541 );
542
543 match authenticator_kind {
544 listeners::AuthenticatorKind::Password | listeners::AuthenticatorKind::Oidc => {
545 base_router = base_router.layer(session_layer.clone());
546
547 let login_router = Router::new()
548 .route("/api/login", routing::post(handle_login))
549 .route("/api/logout", routing::post(handle_logout))
550 .layer(Extension(adapter_client_rx));
551 router = router.merge(login_router).layer(session_layer);
552 }
553 listeners::AuthenticatorKind::None => {
554 base_router =
555 base_router.layer(middleware::from_fn(x_materialize_user_header_auth));
556 }
557 _ => {}
558 }
559
560 router = router
561 .merge(base_router)
562 .apply_default_layers(source, metrics);
563
564 HttpServer { tls, router }
565 }
566}
567
impl Server for HttpServer {
    const NAME: &'static str = "http";

    /// Serves HTTP/1 requests on one accepted connection.
    ///
    /// 1. The peer address comes from `take_proxy_header_address` when it yields one; otherwise
    ///    the socket's direct peer address is used.
    /// 2. If TLS is configured, the TLS handshake runs first. If the handshake fails, the
    ///    underlying stream is shut down (best effort) and the error is returned.
    /// 3. The router is then served with a `ConnProtocol` extension, and
    ///    `ConnectInfo<SocketAddr>` is set to the resolved peer address.
    fn handle_connection(
        &self,
        conn: Connection,
        _tokio_metrics_intervals: impl Iterator<Item = TaskMetrics> + Send + 'static,
    ) -> ConnectionHandler {
        let router = self.router.clone();
        let tls_context = self.tls.clone();
        let mut conn = TokioIo::new(conn);

        Box::pin(async {
            let direct_peer_addr = conn.inner().peer_addr().context("fetching peer addr")?;
            // Prefer the source address reported by a proxy header, if one is present.
            let peer_addr = conn
                .inner_mut()
                .take_proxy_header_address()
                .await
                .map(|a| a.source)
                .unwrap_or(direct_peer_addr);

            let (conn, conn_protocol) = match tls_context {
                Some(tls_context) => {
                    let mut ssl_stream = SslStream::new(Ssl::new(&tls_context.get())?, conn)?;
                    if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
                        // Best-effort shutdown of the raw stream; its result is ignored.
                        let _ = ssl_stream.get_mut().inner_mut().shutdown().await;
                        return Err(e.into());
                    }
                    (MaybeHttpsStream::Https(ssl_stream), ConnProtocol::Https)
                }
                _ => (MaybeHttpsStream::Http(conn), ConnProtocol::Http),
            };
            // Middleware reads the transport protocol (`ConnProtocol`) and the peer address
            // (`ConnectInfo`) from request extensions.
            let mut make_tower_svc = router
                .layer(Extension(conn_protocol))
                .into_make_service_with_connect_info::<SocketAddr>();
            // NOTE(review): the unwrap assumes the make-service cannot fail; axum documents its
            // error type as `Infallible` — confirm on upgrade.
            let tower_svc = make_tower_svc.call(peer_addr).await.unwrap();
            let hyper_svc = hyper::service::service_fn(|req| tower_svc.clone().call(req));
            let http = hyper::server::conn::http1::Builder::new();
            // `with_upgrades` lets upgraded connections (e.g. the WebSocket endpoint) proceed.
            http.serve_connection(conn, hyper_svc)
                .with_upgrades()
                .err_into()
                .await
        })
    }
}
613
614pub async fn handle_leader_status(
615 State(deployment_state_handle): State<DeploymentStateHandle>,
616) -> impl IntoResponse {
617 let status = deployment_state_handle.status();
618 (StatusCode::OK, Json(json!({ "status": status })))
619}
620
621pub async fn handle_leader_promote(
622 State(deployment_state_handle): State<DeploymentStateHandle>,
623) -> impl IntoResponse {
624 match deployment_state_handle.try_promote() {
625 Ok(()) => {
626 let status = StatusCode::OK;
629 let body = Json(json!({
630 "result": "Success",
631 }));
632 (status, body)
633 }
634 Err(()) => {
635 let status = StatusCode::BAD_REQUEST;
638 let body = Json(json!({
639 "result": {"Failure": {"message": "cannot promote leader while initializing"}},
640 }));
641 (status, body)
642 }
643 }
644}
645
646pub async fn handle_leader_skip_catchup(
647 State(deployment_state_handle): State<DeploymentStateHandle>,
648) -> impl IntoResponse {
649 match deployment_state_handle.try_skip_catchup() {
650 Ok(()) => StatusCode::NO_CONTENT.into_response(),
651 Err(()) => {
652 let status = StatusCode::BAD_REQUEST;
653 let body = Json(json!({
654 "message": "cannot skip catchup in this phase of initialization; try again later",
655 }));
656 (status, body).into_response()
657 }
658 }
659}
660
661async fn x_materialize_user_header_auth(mut req: Request, next: Next) -> impl IntoResponse {
662 if let Some(username) = req.headers().get("x-materialize-user").map(|h| h.to_str()) {
664 let username = match username {
665 Ok(name @ (SUPPORT_USER_NAME | SYSTEM_USER_NAME)) => name.to_string(),
666 _ => {
667 return Err(AuthError::MismatchedUser(format!(
668 "user specified in x-materialize-user must be {SUPPORT_USER_NAME} or {SYSTEM_USER_NAME}"
669 )));
670 }
671 };
672 req.extensions_mut().insert(AuthedUser {
673 name: username,
674 external_metadata_rx: None,
675 authenticated: Authenticated,
676 authenticator_kind: mz_auth::AuthenticatorKind::None,
677 });
678 }
679 Ok(next.run(req).await)
680}
681
/// A value that becomes available later: a cloneable (`Shared`) future over a oneshot receiver.
/// Awaiting it yields `Err` if the sender was dropped without sending.
type Delayed<T> = Shared<oneshot::Receiver<T>>;
683
/// Transport protocol of the connection a request arrived on. `handle_connection` inserts it
/// as a request extension.
#[derive(Clone)]
enum ConnProtocol {
    /// Plain-text HTTP.
    Http,
    /// HTTP over TLS.
    Https,
}
689
/// An authenticated user, attached to the request as an extension by the authentication
/// middleware and consumed when an adapter session is created.
#[derive(Clone, Debug)]
pub struct AuthedUser {
    /// Name of the user the session runs as.
    name: String,
    /// Externally provided user metadata (set for Frontegg-authenticated users), if any.
    external_metadata_rx: Option<watch::Receiver<ExternalUserMetadata>>,
    /// Authentication marker passed to `Client::new_session`.
    authenticated: Authenticated,
    /// Mechanism that authenticated this user; copied into the session configuration.
    authenticator_kind: mz_auth::AuthenticatorKind,
}
697
/// A started adapter session for an authenticated user, together with its connection slot.
pub struct AuthedClient {
    /// Client for the started adapter session.
    pub client: SessionClient,
    /// Slot returned by `ConnectionCounter::allocate_connection`, if any.
    /// NOTE(review): presumably released on drop — confirm in `mz_pgwire_common`.
    pub connection_guard: Option<ConnectionHandle>,
}
702
703impl AuthedClient {
704 async fn new<F>(
705 adapter_client: &Client,
706 user: AuthedUser,
707 peer_addr: IpAddr,
708 active_connection_counter: ConnectionCounter,
709 helm_chart_version: Option<String>,
710 session_config: F,
711 options: BTreeMap<String, String>,
712 now: NowFn,
713 ) -> Result<Self, AdapterError>
714 where
715 F: FnOnce(&mut AdapterSession),
716 {
717 let conn_id = adapter_client.new_conn_id()?;
718 let mut session = adapter_client.new_session(
719 AdapterSessionConfig {
720 conn_id,
721 uuid: epoch_to_uuid_v7(&(now)()),
722 user: user.name,
723 client_ip: Some(peer_addr),
724 external_metadata_rx: user.external_metadata_rx,
725 helm_chart_version,
726 authenticator_kind: user.authenticator_kind,
727 },
728 user.authenticated,
729 );
730 let connection_guard = active_connection_counter.allocate_connection(session.user())?;
731
732 session_config(&mut session);
733 let system_vars = adapter_client.get_system_vars().await;
734 for (key, val) in options {
735 const LOCAL: bool = false;
736 if let Err(err) =
737 session
738 .vars_mut()
739 .set(&system_vars, &key, VarInput::Flat(&val), LOCAL)
740 {
741 session.add_notice(AdapterNotice::BadStartupSetting {
742 name: key.to_string(),
743 reason: err.to_string(),
744 })
745 }
746 }
747 let adapter_client = adapter_client.startup(session).await?;
748 Ok(AuthedClient {
749 client: adapter_client,
750 connection_guard,
751 })
752 }
753}
754
755impl<S> FromRequestParts<S> for AuthedClient
756where
757 S: Send + Sync,
758{
759 type Rejection = Response;
760
761 async fn from_request_parts(
762 req: &mut http::request::Parts,
763 state: &S,
764 ) -> Result<Self, Self::Rejection> {
765 #[derive(Debug, Default, Deserialize)]
766 struct Params {
767 #[serde(default)]
768 options: String,
769 }
770 let params: Query<Params> = Query::from_request_parts(req, state)
771 .await
772 .unwrap_or_default();
773
774 let peer_addr = req
775 .extensions
776 .get::<ConnectInfo<SocketAddr>>()
777 .expect("ConnectInfo extension guaranteed to exist")
778 .0
779 .ip();
780
781 let user = req.extensions.get::<AuthedUser>().unwrap();
782 let adapter_client = req
783 .extensions
784 .get::<Delayed<mz_adapter::Client>>()
785 .unwrap()
786 .clone();
787 let adapter_client = adapter_client.await.map_err(|_| {
788 (StatusCode::INTERNAL_SERVER_ERROR, "adapter client missing").into_response()
789 })?;
790 let active_connection_counter = req.extensions.get::<ConnectionCounter>().unwrap();
791 let helm_chart_version = None;
792
793 let options = if params.options.is_empty() {
794 BTreeMap::<String, String>::default()
797 } else {
798 match serde_json::from_str(¶ms.options) {
799 Ok(options) => options,
800 Err(_e) => {
801 let code = StatusCode::BAD_REQUEST;
803 let msg = format!("Failed to deserialize {} map", "options".quoted());
804 return Err((code, msg).into_response());
805 }
806 }
807 };
808
809 let client = AuthedClient::new(
810 &adapter_client,
811 user.clone(),
812 peer_addr,
813 active_connection_counter.clone(),
814 helm_chart_version,
815 |session| {
816 session
817 .vars_mut()
818 .set_default(WELCOME_MESSAGE.name(), VarInput::Flat(&false.format()))
819 .expect("known to exist")
820 },
821 options,
822 SYSTEM_TIME.clone(),
823 )
824 .await
825 .map_err(|e| {
826 let status = match e {
827 AdapterError::UserSessionsDisallowed | AdapterError::NetworkPolicyDenied(_) => {
828 StatusCode::FORBIDDEN
829 }
830 _ => StatusCode::INTERNAL_SERVER_ERROR,
831 };
832 (status, Json(SqlError::from(e))).into_response()
833 })?;
834
835 Ok(client)
836 }
837}
838
839#[derive(Debug, Error)]
840pub(crate) enum AuthError {
841 #[error("role dissallowed")]
842 RoleDisallowed(String),
843 #[error("{0}")]
844 Frontegg(#[from] FronteggError),
845 #[error("missing authorization header")]
846 MissingHttpAuthentication {
847 include_www_authenticate_header: bool,
848 },
849 #[error("{0}")]
850 MismatchedUser(String),
851 #[error("session expired")]
852 SessionExpired,
853 #[error("failed to update session")]
854 FailedToUpdateSession,
855 #[error("invalid credentials")]
856 InvalidCredentials,
857}
858
859impl IntoResponse for AuthError {
860 fn into_response(self) -> Response {
861 warn!("HTTP request failed authentication: {}", self);
862 let mut headers = HeaderMap::new();
863 match self {
864 AuthError::MissingHttpAuthentication {
865 include_www_authenticate_header,
866 } if include_www_authenticate_header => {
867 headers.insert(
868 http::header::WWW_AUTHENTICATE,
869 HeaderValue::from_static("Basic realm=Materialize"),
870 );
871 }
872 _ => {}
873 };
874 (StatusCode::UNAUTHORIZED, headers, "unauthorized").into_response()
877 }
878}
879
880pub async fn handle_login(
882 session: Option<Extension<TowerSession>>,
883 Extension(adapter_client_rx): Extension<Delayed<Client>>,
884 Json(LoginCredentials { username, password }): Json<LoginCredentials>,
885) -> impl IntoResponse {
886 let Ok(adapter_client) = adapter_client_rx.clone().await else {
887 return StatusCode::INTERNAL_SERVER_ERROR;
888 };
889 let authenticated = match adapter_client.authenticate(&username, &password).await {
890 Ok(authenticated) => authenticated,
891 Err(err) => {
892 warn!(?err, "HTTP login failed authentication");
893 return StatusCode::UNAUTHORIZED;
894 }
895 };
896 let session_data = TowerSessionData {
898 username,
899 created_at: SystemTime::now(),
900 last_activity: SystemTime::now(),
901 authenticated,
902 authenticator_kind: mz_auth::AuthenticatorKind::Password,
903 };
904 let session = session.and_then(|Extension(session)| Some(session));
906 let Some(session) = session else {
907 return StatusCode::INTERNAL_SERVER_ERROR;
908 };
909 match session.insert("data", &session_data).await {
910 Err(_) => StatusCode::INTERNAL_SERVER_ERROR,
911 Ok(_) => StatusCode::OK,
912 }
913}
914
915pub async fn handle_logout(session: Option<Extension<TowerSession>>) -> impl IntoResponse {
917 let session = session.and_then(|Extension(session)| Some(session));
918 let Some(session) = session else {
919 return StatusCode::INTERNAL_SERVER_ERROR;
920 };
921 match session.delete().await {
923 Err(_) => StatusCode::INTERNAL_SERVER_ERROR,
924 Ok(_) => StatusCode::OK,
925 }
926}
927
/// Authentication middleware shared by most HTTP routes.
///
/// In order:
/// 1. Parses `Authorization: Basic` or `Authorization: Bearer` credentials, if present.
/// 2. If no credentials were sent but the request's cookie session holds login data, refreshes
///    that session (rejecting it if expired) and runs the request as the session's user.
/// 3. If TLS is enabled but the request arrived over plain HTTP, responds with a permanent
///    redirect to the same URI using the `https` scheme.
/// 4. If an outer layer already attached an `AuthedUser`, passes the request through.
/// 5. Otherwise authenticates the credentials with the listener's authenticator and attaches
///    the resulting `AuthedUser` to the request.
///
/// NOTE(review): `allowed_roles` is only passed to `auth`, which the cookie-session path in
/// step 2 skips. Confirm whether role restrictions should also apply to cookie sessions.
///
/// # Errors
/// Returns an `AuthError` (rendered as `401 Unauthorized`) when session refresh or
/// authentication fails.
///
/// # Panics
/// Panics if the `ConnProtocol` extension is missing, or if TLS is disabled but the connection
/// is HTTPS.
async fn http_auth(
    mut req: Request,
    next: Next,
    tls_enabled: bool,
    authenticator_kind: listeners::AuthenticatorKind,
    frontegg: Option<mz_frontegg_auth::Authenticator>,
    oidc_rx: Delayed<mz_authenticator::GenericOidcAuthenticator>,
    adapter_client_rx: Delayed<Client>,
    allowed_roles: AllowedRoles,
) -> Result<impl IntoResponse, AuthError> {
    let creds = if let Some(basic) = req.headers().typed_get::<Authorization<Basic>>() {
        Some(Credentials::Password {
            username: basic.username().to_owned(),
            password: Password(basic.password().to_owned()),
        })
    } else if let Some(bearer) = req.headers().typed_get::<Authorization<Bearer>>() {
        Some(Credentials::Token {
            token: bearer.token().to_owned(),
        })
    } else {
        None
    };

    // The cookie session is only consulted when no `Authorization` header was sent, so explicit
    // credentials always take precedence over a logged-in session.
    if creds.is_none()
        && let Some((session, session_data)) =
            maybe_get_authenticated_session(req.extensions().get::<TowerSession>()).await
    {
        let user = ensure_session_unexpired(session, session_data).await?;
        req.extensions_mut().insert(user);
        return Ok(next.run(req).await);
    }

    let conn_protocol = req.extensions().get::<ConnProtocol>().unwrap();
    match (tls_enabled, &conn_protocol) {
        (false, ConnProtocol::Http) => {}
        (false, ConnProtocol::Https { .. }) => unreachable!(),
        // NOTE(review): HTTP/1 requests normally carry an origin-form target (path only, no
        // authority), and `Uri::from_parts` rejects a scheme without an authority, so this
        // `expect` would panic if the branch were ever reached. `handle_connection` always
        // TLS-accepts when TLS is configured, so the branch appears unreachable today. If it
        // becomes reachable, take the authority from the `Host` header.
        (true, ConnProtocol::Http) => {
            let mut parts = req.uri().clone().into_parts();
            parts.scheme = Some(Scheme::HTTPS);
            return Ok(Redirect::permanent(
                &Uri::from_parts(parts)
                    .expect("it was already a URI, just changed the scheme")
                    .to_string(),
            )
            .into_response());
        }
        (true, ConnProtocol::Https { .. }) => {}
    }
    // An outer layer (e.g. `x_materialize_user_header_auth`) already identified the user.
    if req.extensions().get::<AuthedUser>().is_some() {
        return Ok(next.run(req).await);
    }

    // Ask browsers for Basic credentials on the home page and profiling endpoints.
    let path = req.uri().path();
    let include_www_authenticate_header = path == "/"
        || PROFILING_API_ENDPOINTS
            .iter()
            .any(|prefix| path.starts_with(prefix));
    let authenticator = get_authenticator(
        authenticator_kind,
        creds.as_ref(),
        frontegg,
        &oidc_rx,
        &adapter_client_rx,
    )
    .await;

    let user = auth(
        &authenticator,
        creds,
        allowed_roles,
        include_www_authenticate_header,
    )
    .await?;

    req.extensions_mut().insert(user);

    Ok(next.run(req).await)
}
1017
1018async fn init_ws(
1019 WsState {
1020 frontegg,
1021 oidc_rx,
1022 authenticator_kind,
1023 adapter_client_rx,
1024 active_connection_counter,
1025 helm_chart_version,
1026 allowed_roles,
1027 }: WsState,
1028 existing_user: Option<ExistingUser>,
1029 peer_addr: IpAddr,
1030 ws: &mut WebSocket,
1031) -> Result<AuthedClient, anyhow::Error> {
1032 let ws_auth: WebSocketAuth = loop {
1035 let init_msg = ws.recv().await.ok_or_else(|| anyhow::anyhow!("closed"))??;
1036 match init_msg {
1037 Message::Text(data) => break serde_json::from_str(&data)?,
1038 Message::Binary(data) => break serde_json::from_slice(&data)?,
1039 Message::Ping(_) => {
1041 continue;
1042 }
1043 Message::Pong(_) => {
1044 continue;
1045 }
1046 Message::Close(_) => {
1047 anyhow::bail!("closed");
1048 }
1049 }
1050 };
1051
1052 let (creds, options) = match ws_auth {
1055 WebSocketAuth::Basic {
1056 user,
1057 password,
1058 options,
1059 } => {
1060 let creds = Credentials::Password {
1061 username: user,
1062 password,
1063 };
1064 (Some(creds), options)
1065 }
1066 WebSocketAuth::Bearer { token, options } => {
1067 let creds = Credentials::Token { token };
1068 (Some(creds), options)
1069 }
1070 WebSocketAuth::OptionsOnly { options } => (None, options),
1071 };
1072
1073 let user = match (existing_user, creds) {
1074 (Some(ExistingUser::XMaterializeUserHeader(_)), Some(_creds)) => {
1075 warn!("Unexpected bearer or basic auth provided when using user header");
1076 anyhow::bail!("unexpected")
1077 }
1078 (Some(ExistingUser::Session(user)), None)
1079 | (Some(ExistingUser::XMaterializeUserHeader(user)), None) => user,
1080 (_, Some(creds)) => {
1081 let authenticator = get_authenticator(
1082 authenticator_kind,
1083 Some(&creds),
1084 frontegg,
1085 &oidc_rx,
1086 &adapter_client_rx,
1087 )
1088 .await;
1089 let user = auth(&authenticator, Some(creds), allowed_roles, false).await?;
1090 user
1091 }
1092 (None, None) => anyhow::bail!("expected auth information"),
1093 };
1094
1095 let client = AuthedClient::new(
1096 &adapter_client_rx.clone().await?,
1097 user,
1098 peer_addr,
1099 active_connection_counter.clone(),
1100 helm_chart_version.clone(),
1101 |_session| (),
1102 options,
1103 SYSTEM_TIME.clone(),
1104 )
1105 .await?;
1106
1107 Ok(client)
1108}
1109
/// Credentials presented by a client, either in HTTP headers or in a WebSocket auth message.
enum Credentials {
    /// Username and password (HTTP Basic auth or WebSocket `Basic`).
    Password {
        username: String,
        password: Password,
    },
    /// Bearer token (HTTP `Authorization: Bearer` or WebSocket `Bearer`).
    Token {
        token: String,
    },
}
1119
1120async fn get_authenticator(
1121 kind: listeners::AuthenticatorKind,
1122 creds: Option<&Credentials>,
1123 frontegg: Option<mz_frontegg_auth::Authenticator>,
1124 oidc_rx: &Delayed<mz_authenticator::GenericOidcAuthenticator>,
1125 adapter_client_rx: &Delayed<Client>,
1126) -> Authenticator {
1127 match kind {
1128 listeners::AuthenticatorKind::Frontegg => Authenticator::Frontegg(frontegg.expect(
1129 "Frontegg authenticator should exist with listeners::AuthenticatorKind::Frontegg",
1130 )),
1131 listeners::AuthenticatorKind::Password | listeners::AuthenticatorKind::Sasl => {
1132 let client = adapter_client_rx.clone().await.expect("sender not dropped");
1133 Authenticator::Password(client)
1134 }
1135 listeners::AuthenticatorKind::Oidc => match creds {
1136 Some(Credentials::Password { .. }) => {
1138 let client = adapter_client_rx.clone().await.expect("sender not dropped");
1139 Authenticator::Password(client)
1140 }
1141 _ => Authenticator::Oidc(oidc_rx.clone().await.expect("sender not dropped")),
1142 },
1143 listeners::AuthenticatorKind::None => Authenticator::None,
1144 }
1145}
1146
1147pub(crate) async fn maybe_get_authenticated_session(
1151 session: Option<&TowerSession>,
1152) -> Option<(&TowerSession, TowerSessionData)> {
1153 if let Some(session) = session {
1154 if let Ok(Some(session_data)) = session.get::<TowerSessionData>("data").await {
1155 return Some((session, session_data));
1156 }
1157 }
1158 None
1159}
1160
1161pub(crate) async fn ensure_session_unexpired(
1164 session: &TowerSession,
1165 session_data: TowerSessionData,
1166) -> Result<AuthedUser, AuthError> {
1167 if session_data
1168 .last_activity
1169 .elapsed()
1170 .unwrap_or(Duration::MAX)
1171 > SESSION_DURATION
1172 {
1173 let _ = session.delete().await;
1174 return Err(AuthError::SessionExpired);
1175 }
1176 let mut updated_data = session_data.clone();
1177 updated_data.last_activity = SystemTime::now();
1178 session
1179 .insert("data", &updated_data)
1180 .await
1181 .map_err(|_| AuthError::FailedToUpdateSession)?;
1182
1183 Ok(AuthedUser {
1184 name: session_data.username,
1185 external_metadata_rx: None,
1186 authenticated: session_data.authenticated,
1187 authenticator_kind: session_data.authenticator_kind,
1188 })
1189}
1190
1191async fn auth(
1192 authenticator: &Authenticator,
1193 creds: Option<Credentials>,
1194 allowed_roles: AllowedRoles,
1195 include_www_authenticate_header: bool,
1196) -> Result<AuthedUser, AuthError> {
1197 let (name, external_metadata_rx, authenticated) = match authenticator {
1198 Authenticator::Frontegg(frontegg) => match creds {
1199 Some(Credentials::Password { username, password }) => {
1200 let (auth_session, authenticated) =
1201 frontegg.authenticate(&username, password.as_str()).await?;
1202 let name = auth_session.user().into();
1203 let external_metadata_rx = Some(auth_session.external_metadata_rx());
1204 (name, external_metadata_rx, authenticated)
1205 }
1206 Some(Credentials::Token { token }) => {
1207 let (claims, authenticated) = frontegg.validate_access_token(&token, None)?;
1208 let (_, external_metadata_rx) = watch::channel(ExternalUserMetadata {
1209 user_id: claims.user_id,
1210 admin: claims.is_admin,
1211 });
1212 (claims.user, Some(external_metadata_rx), authenticated)
1213 }
1214 None => {
1215 return Err(AuthError::MissingHttpAuthentication {
1216 include_www_authenticate_header,
1217 });
1218 }
1219 },
1220 Authenticator::Password(adapter_client) => match creds {
1221 Some(Credentials::Password { username, password }) => {
1222 let authenticated = adapter_client
1223 .authenticate(&username, &password)
1224 .await
1225 .map_err(|_| AuthError::InvalidCredentials)?;
1226 (username, None, authenticated)
1227 }
1228 _ => {
1229 return Err(AuthError::MissingHttpAuthentication {
1230 include_www_authenticate_header,
1231 });
1232 }
1233 },
1234 Authenticator::Sasl(_) => {
1235 return Err(AuthError::MissingHttpAuthentication {
1239 include_www_authenticate_header,
1240 });
1241 }
1242 Authenticator::Oidc(oidc) => match creds {
1243 Some(Credentials::Token { token }) => {
1244 let (mut claims, authenticated) = oidc
1246 .authenticate(&token, None)
1247 .await
1248 .map_err(|_| AuthError::InvalidCredentials)?;
1249 let name = std::mem::take(&mut claims.user);
1250 (name, None, authenticated)
1251 }
1252 _ => {
1253 return Err(AuthError::MissingHttpAuthentication {
1254 include_www_authenticate_header,
1255 });
1256 }
1257 },
1258 Authenticator::None => {
1259 let name = match creds {
1263 Some(Credentials::Password { username, .. }) => username,
1264 _ => HTTP_DEFAULT_USER.name.to_owned(),
1265 };
1266 (name, None, Authenticated)
1267 }
1268 };
1269
1270 check_role_allowed(&name, allowed_roles)?;
1271
1272 Ok(AuthedUser {
1273 name,
1274 external_metadata_rx,
1275 authenticated,
1276 authenticator_kind: authenticator.kind(),
1277 })
1278}
1279
1280fn check_role_allowed(name: &str, allowed_roles: AllowedRoles) -> Result<(), AuthError> {
1282 let is_internal_user = INTERNAL_USER_NAMES.contains(name);
1283 let is_reserved_user = mz_adapter::catalog::is_reserved_role_name(name);
1285 let role_allowed = match allowed_roles {
1286 AllowedRoles::Normal => !is_reserved_user,
1287 AllowedRoles::Internal => is_internal_user,
1288 AllowedRoles::NormalAndInternal => !is_reserved_user || is_internal_user,
1289 };
1290 if role_allowed {
1291 Ok(())
1292 } else {
1293 Err(AuthError::RoleDisallowed(name.to_owned()))
1294 }
1295}
1296
/// Extension trait for applying a standard set of middleware layers.
trait DefaultLayers {
    /// Returns `self` wrapped in the default layers. In the `Router` impl,
    /// these are a request body size limit and a metrics layer built from
    /// `source` and `metrics`.
    fn apply_default_layers(self, source: &'static str, metrics: Metrics) -> Self;
}
1302
1303impl DefaultLayers for Router {
1304 fn apply_default_layers(self, source: &'static str, metrics: Metrics) -> Self {
1305 self.layer(DefaultBodyLimit::max(MAX_REQUEST_SIZE))
1306 .layer(metrics::PrometheusLayer::new(source, metrics))
1307 }
1308}
1309
1310async fn handle_load_error(error: tower::BoxError) -> impl IntoResponse {
1317 if error.is::<tower::load_shed::error::Overloaded>() {
1318 return (
1319 StatusCode::TOO_MANY_REQUESTS,
1320 Cow::from("too many requests, try again later"),
1321 );
1322 }
1323
1324 (
1327 StatusCode::INTERNAL_SERVER_ERROR,
1328 Cow::from(format!("Unhandled internal error: {}", error)),
1329 )
1330}
1331
/// Username and password submitted to a login endpoint.
///
/// NOTE(review): presumably deserialized from a login request body; confirm
/// against the handler that consumes it.
#[derive(Debug, Deserialize, Serialize, PartialEq)]
pub struct LoginCredentials {
    username: String,
    password: Password,
}
1337
/// Authenticated-user record stored in a web session under the `"data"` key.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TowerSessionData {
    /// Role name of the authenticated user.
    username: String,
    /// When the session record was created. Not consulted by the expiry check
    /// in `ensure_session_unexpired`.
    created_at: SystemTime,
    /// Last time the session was used. Compared against `SESSION_DURATION`
    /// and refreshed on every successful `ensure_session_unexpired` call.
    last_activity: SystemTime,
    /// Proof of authentication carried over into `AuthedUser`.
    authenticated: Authenticated,
    /// Which authenticator established this session.
    authenticator_kind: mz_auth::AuthenticatorKind,
}
1346
#[cfg(test)]
mod tests {
    use super::{AllowedRoles, check_role_allowed};

    #[mz_ore::test]
    fn test_check_role_allowed() {
        // Columns: role name, then whether that role is accepted under
        // `Internal`, `NormalAndInternal`, and `Normal`, in that order.
        let cases = [
            ("mz_system", true, true, false),
            ("mz_support", true, true, false),
            ("mz_analytics", true, true, false),
            ("materialize", false, true, true),
            ("anonymous_http_user", false, true, true),
            ("alex", false, true, true),
            ("external_asdf", false, false, false),
            ("pg_somebody", false, false, false),
            ("mz_unknown", false, false, false),
            ("PUBLIC", false, false, false),
        ];

        for (name, internal, normal_and_internal, normal) in cases {
            let checks = [
                ("Internal", AllowedRoles::Internal, internal),
                (
                    "NormalAndInternal",
                    AllowedRoles::NormalAndInternal,
                    normal_and_internal,
                ),
                ("Normal", AllowedRoles::Normal, normal),
            ];
            for (label, allowed_roles, expected) in checks {
                assert_eq!(
                    check_role_allowed(name, allowed_roles).is_ok(),
                    expected,
                    "role {name:?} under AllowedRoles::{label}",
                );
            }
        }
    }
}