1#![allow(clippy::unused_async)]
51
52use std::borrow::Cow;
53use std::collections::BTreeMap;
54use std::fmt::Debug;
55use std::net::{IpAddr, SocketAddr};
56use std::pin::Pin;
57use std::sync::Arc;
58use std::time::{Duration, SystemTime};
59
60use anyhow::Context;
61use axum::error_handling::HandleErrorLayer;
62use axum::extract::ws::{Message, WebSocket};
63use axum::extract::{ConnectInfo, DefaultBodyLimit, FromRequestParts, Query, Request, State};
64use axum::middleware::{self, Next};
65use axum::response::{IntoResponse, Redirect, Response};
66use axum::{Extension, Json, Router, routing};
67use futures::future::{Shared, TryFutureExt};
68use headers::authorization::{Authorization, Basic, Bearer};
69use headers::{HeaderMapExt, HeaderName};
70use http::header::{AUTHORIZATION, CONTENT_TYPE};
71use http::uri::Scheme;
72use http::{HeaderMap, HeaderValue, Method, StatusCode, Uri};
73use hyper_openssl::SslStream;
74use hyper_openssl::client::legacy::MaybeHttpsStream;
75use hyper_util::rt::TokioIo;
76use mz_adapter::session::{Session as AdapterSession, SessionConfig as AdapterSessionConfig};
77use mz_adapter::{AdapterError, AdapterNotice, Client, SessionClient, WebhookAppenderCache};
78use mz_auth::Authenticated;
79use mz_auth::password::Password;
80use mz_authenticator::Authenticator;
81use mz_controller::ReplicaHttpLocator;
82use mz_frontegg_auth::Error as FronteggError;
83use mz_http_util::DynamicFilterTarget;
84use mz_ore::cast::u64_to_usize;
85use mz_ore::metrics::MetricsRegistry;
86use mz_ore::now::{NowFn, SYSTEM_TIME, epoch_to_uuid_v7};
87use mz_ore::str::StrExt;
88use mz_pgwire_common::{ConnectionCounter, ConnectionHandle};
89use mz_repr::user::ExternalUserMetadata;
90use mz_server_core::listeners::{AllowedRoles, AuthenticatorKind, HttpRoutesEnabled};
91use mz_server_core::{Connection, ConnectionHandler, ReloadingSslContext, Server};
92use mz_sql::session::metadata::SessionMetadata;
93use mz_sql::session::user::{
94 HTTP_DEFAULT_USER, INTERNAL_USER_NAMES, SUPPORT_USER_NAME, SYSTEM_USER_NAME,
95};
96use mz_sql::session::vars::{Value, Var, VarInput, WELCOME_MESSAGE};
97use openssl::ssl::Ssl;
98use prometheus::{
99 COMPUTE_METRIC_QUERIES, FRONTIER_METRIC_QUERIES, STORAGE_METRIC_QUERIES, USAGE_METRIC_QUERIES,
100};
101use serde::{Deserialize, Serialize};
102use serde_json::json;
103use thiserror::Error;
104use tokio::io::AsyncWriteExt;
105use tokio::sync::oneshot::Receiver;
106use tokio::sync::{oneshot, watch};
107use tokio_metrics::TaskMetrics;
108use tower::limit::GlobalConcurrencyLimitLayer;
109use tower::{Service, ServiceBuilder};
110use tower_http::cors::{AllowOrigin, Any, CorsLayer};
111use tower_sessions::{
112 MemoryStore as TowerSessionMemoryStore, Session as TowerSession,
113 SessionManagerLayer as TowerSessionManagerLayer,
114};
115use tracing::warn;
116
117use crate::BUILD_INFO;
118use crate::deployment::state::DeploymentStateHandle;
119use crate::http::sql::{ExistingUser, SqlError};
120
121mod catalog;
122mod cluster;
123mod console;
124mod mcp;
125mod memory;
126mod metrics;
127mod metrics_viz;
128mod probe;
129mod prometheus;
130mod root;
131mod sql;
132mod webhook;
133
134pub use metrics::Metrics;
135pub use sql::{SqlResponse, WebSocketAuth, WebSocketResponse};
136
/// Size limit of 5 MiB, expressed in bytes.
///
/// NOTE(review): presumably the maximum accepted request body size; the code
/// that enforces it is not in this part of the file — confirm at the use site.
pub const MAX_REQUEST_SIZE: usize = u64_to_usize(5 * bytesize::MIB);
139
/// Idle timeout for login sessions (8 hours).
///
/// `ensure_session_unexpired` compares the time elapsed since a session's
/// `last_activity` against this value.
const SESSION_DURATION: Duration = Duration::from_secs(8 * 3600);

/// Path prefixes for which `http_auth` asks for a `WWW-Authenticate` header
/// to be included when credentials are missing (the root path `/` is handled
/// separately there).
const PROFILING_API_ENDPOINTS: &[&str] = &["/memory", "/hierarchical-memory", "/prof/"];
143
/// Configuration consumed by [`HttpServer::new`].
#[derive(Debug)]
pub struct HttpConfig {
    /// Label passed to `apply_default_layers` together with `metrics`.
    pub source: &'static str,
    /// TLS context. When `Some`, every accepted connection goes through a TLS
    /// handshake before HTTP is served on it.
    pub tls: Option<ReloadingSslContext>,
    /// Selects which authenticator is used for incoming requests.
    pub authenticator_kind: AuthenticatorKind,
    /// Frontegg authenticator; expected to be `Some` when
    /// `authenticator_kind` is `AuthenticatorKind::Frontegg`.
    pub frontegg: Option<mz_frontegg_auth::Authenticator>,
    /// Shared receiver that resolves to the OIDC authenticator.
    pub oidc_rx: Delayed<mz_authenticator::GenericOidcAuthenticator>,
    /// Shared receiver that resolves to the adapter client.
    pub adapter_client_rx: Shared<Receiver<Client>>,
    /// Origin policy for the CORS layer applied to the base routes.
    pub allowed_origin: AllowOrigin,
    /// Counter from which each `AuthedClient` allocates a connection handle.
    pub active_connection_counter: ConnectionCounter,
    /// Helm chart version forwarded to sessions opened over the WebSocket route.
    pub helm_chart_version: Option<String>,
    /// Semaphore backing the global concurrency limit on the webhook route.
    pub concurrent_webhook_req: Arc<tokio::sync::Semaphore>,
    /// Metrics passed to `apply_default_layers`.
    pub metrics: Metrics,
    /// Registry rendered by the `/metrics` route.
    pub metrics_registry: MetricsRegistry,
    /// Role restriction handed to the authentication routines.
    pub allowed_roles: AllowedRoles,
    /// Settings used when the internal routes are enabled.
    pub internal_route_config: Arc<InternalRouteConfig>,
    /// Flags selecting which groups of routes are registered.
    pub routes_enabled: HttpRoutesEnabled,
    /// Locator handed to the cluster proxy routes.
    pub replica_http_locator: Arc<ReplicaHttpLocator>,
}
164
/// Settings that only matter when the internal routes are enabled.
#[derive(Debug, Clone)]
pub struct InternalRouteConfig {
    /// State shared with the `/api/leader/*` handlers.
    pub deployment_state_handle: DeploymentStateHandle,
    /// Passed to `console::ConsoleProxyConfig::new` for the
    /// `/internal-console` routes.
    pub internal_console_redirect_url: Option<String>,
}
170
/// Router state for the `/api/experimental/sql` WebSocket route; `init_ws`
/// destructures it to authenticate the connection and open a session.
#[derive(Clone)]
pub struct WsState {
    /// Frontegg authenticator, if one was configured.
    frontegg: Option<mz_frontegg_auth::Authenticator>,
    /// Shared receiver that resolves to the OIDC authenticator.
    oidc_rx: Delayed<mz_authenticator::GenericOidcAuthenticator>,
    /// Selects which authenticator is used for credentials sent over the socket.
    authenticator_kind: AuthenticatorKind,
    /// Shared receiver that resolves to the adapter client.
    adapter_client_rx: Delayed<mz_adapter::Client>,
    /// Counter from which the session's connection handle is allocated.
    active_connection_counter: ConnectionCounter,
    /// Helm chart version forwarded into the session configuration.
    helm_chart_version: Option<String>,
    /// Role restriction handed to `auth`.
    allowed_roles: AllowedRoles,
}
181
/// Router state for the `/api/webhook/...` route.
#[derive(Clone)]
pub struct WebhookState {
    /// Shared receiver that resolves to the adapter client.
    adapter_client_rx: Delayed<mz_adapter::Client>,
    /// Cache created once in `HttpServer::new` and shared by webhook requests.
    webhook_cache: WebhookAppenderCache,
}
187
/// HTTP server: a fully assembled router plus the optional TLS context used
/// when accepting connections.
#[derive(Debug)]
pub struct HttpServer {
    /// TLS context; `None` means connections are served as plain HTTP.
    tls: Option<ReloadingSslContext>,
    /// Router built by `HttpServer::new`; cloned for every connection.
    router: Router,
}
193
194impl HttpServer {
195 pub fn new(
196 HttpConfig {
197 source,
198 tls,
199 authenticator_kind,
200 frontegg,
201 oidc_rx,
202 adapter_client_rx,
203 allowed_origin,
204 active_connection_counter,
205 helm_chart_version,
206 concurrent_webhook_req,
207 metrics,
208 metrics_registry,
209 allowed_roles,
210 internal_route_config,
211 routes_enabled,
212 replica_http_locator,
213 }: HttpConfig,
214 ) -> HttpServer {
215 let tls_enabled = tls.is_some();
216 let webhook_cache = WebhookAppenderCache::new();
217
218 let session_store = TowerSessionMemoryStore::default();
220 let session_layer = TowerSessionManagerLayer::new(session_store)
221 .with_secure(tls_enabled) .with_same_site(tower_sessions::cookie::SameSite::Strict) .with_http_only(true) .with_name("mz_session") .with_path("/"); let frontegg_middleware = frontegg.clone();
228 let oidc_middleware_rx = oidc_rx.clone();
229 let adapter_client_middleware_rx = adapter_client_rx.clone();
230 let auth_middleware = middleware::from_fn(move |req, next| {
231 let frontegg = frontegg_middleware.clone();
232 let oidc_rx = oidc_middleware_rx.clone();
233 let adapter_client_rx = adapter_client_middleware_rx.clone();
234 async move {
235 http_auth(
236 req,
237 next,
238 tls_enabled,
239 authenticator_kind,
240 frontegg,
241 oidc_rx,
242 adapter_client_rx,
243 allowed_roles,
244 )
245 .await
246 }
247 });
248
249 let mut router = Router::new();
250 let mut base_router = Router::new();
251 if routes_enabled.base {
252 base_router = base_router
253 .route(
254 "/",
255 routing::get(move || async move { root::handle_home(routes_enabled).await }),
256 )
257 .route("/api/sql", routing::post(sql::handle_sql))
258 .route("/memory", routing::get(memory::handle_memory))
259 .route(
260 "/hierarchical-memory",
261 routing::get(memory::handle_hierarchical_memory),
262 )
263 .route(
264 "/metrics-viz",
265 routing::get(metrics_viz::handle_metrics_viz),
266 )
267 .route("/static/{*path}", routing::get(root::handle_static));
268
269 let mut ws_router = Router::new()
270 .route("/api/experimental/sql", routing::get(sql::handle_sql_ws))
271 .with_state(WsState {
272 frontegg,
273 oidc_rx: oidc_rx.clone(),
274 authenticator_kind,
275 adapter_client_rx: adapter_client_rx.clone(),
276 active_connection_counter: active_connection_counter.clone(),
277 helm_chart_version,
278 allowed_roles,
279 });
280 if let AuthenticatorKind::None = authenticator_kind {
281 ws_router = ws_router.layer(middleware::from_fn(x_materialize_user_header_auth));
282 }
283 router = router.merge(ws_router);
284 }
285 if routes_enabled.profiling {
286 base_router = base_router.nest("/prof/", mz_prof_http::router(&BUILD_INFO));
287 }
288
289 if routes_enabled.webhook {
290 let webhook_router = Router::new()
291 .route(
292 "/api/webhook/{:database}/{:schema}/{:id}",
293 routing::post(webhook::handle_webhook),
294 )
295 .with_state(WebhookState {
296 adapter_client_rx: adapter_client_rx.clone(),
297 webhook_cache,
298 })
299 .layer(
300 tower_http::decompression::RequestDecompressionLayer::new()
301 .gzip(true)
302 .deflate(true)
303 .br(true)
304 .zstd(true),
305 )
306 .layer(
307 CorsLayer::new()
308 .allow_methods(Method::POST)
309 .allow_origin(AllowOrigin::mirror_request())
310 .allow_headers(Any),
311 )
312 .layer(
313 ServiceBuilder::new()
314 .layer(HandleErrorLayer::new(handle_load_error))
315 .load_shed()
316 .layer(GlobalConcurrencyLimitLayer::with_semaphore(
317 concurrent_webhook_req,
318 )),
319 );
320 router = router.merge(webhook_router);
321 }
322
323 if routes_enabled.internal {
324 let console_config = Arc::new(console::ConsoleProxyConfig::new(
325 internal_route_config.internal_console_redirect_url.clone(),
326 "/internal-console".to_string(),
327 ));
328 base_router = base_router
329 .route(
330 "/api/opentelemetry/config",
331 routing::put({
332 move |_: axum::Json<DynamicFilterTarget>| async {
333 (
334 StatusCode::BAD_REQUEST,
335 "This endpoint has been replaced. \
336 Use the `opentelemetry_filter` system variable."
337 .to_string(),
338 )
339 }
340 }),
341 )
342 .route(
343 "/api/stderr/config",
344 routing::put({
345 move |_: axum::Json<DynamicFilterTarget>| async {
346 (
347 StatusCode::BAD_REQUEST,
348 "This endpoint has been replaced. \
349 Use the `log_filter` system variable."
350 .to_string(),
351 )
352 }
353 }),
354 )
355 .route("/api/tracing", routing::get(mz_http_util::handle_tracing))
356 .route(
357 "/api/catalog/dump",
358 routing::get(catalog::handle_catalog_dump),
359 )
360 .route(
361 "/api/catalog/check",
362 routing::get(catalog::handle_catalog_check),
363 )
364 .route(
365 "/api/coordinator/check",
366 routing::get(catalog::handle_coordinator_check),
367 )
368 .route(
369 "/api/coordinator/dump",
370 routing::get(catalog::handle_coordinator_dump),
371 )
372 .route(
373 "/internal-console",
374 routing::get(|| async { Redirect::temporary("/internal-console/") }),
375 )
376 .route(
377 "/internal-console/{*path}",
378 routing::get(console::handle_internal_console),
379 )
380 .route(
381 "/internal-console/",
382 routing::get(console::handle_internal_console),
383 )
384 .layer(Extension(console_config));
385
386 let cluster_proxy_config = Arc::new(cluster::ClusterProxyConfig::new(Arc::clone(
388 &replica_http_locator,
389 )));
390 base_router = base_router
391 .route("/clusters", routing::get(cluster::handle_clusters))
392 .route(
393 "/api/cluster/{:cluster_id}/replica/{:replica_id}/process/{:process}/",
394 routing::any(cluster::handle_cluster_proxy_root),
395 )
396 .route(
397 "/api/cluster/{:cluster_id}/replica/{:replica_id}/process/{:process}/{*path}",
398 routing::any(cluster::handle_cluster_proxy),
399 )
400 .layer(Extension(cluster_proxy_config));
401
402 let leader_router = Router::new()
403 .route("/api/leader/status", routing::get(handle_leader_status))
404 .route("/api/leader/promote", routing::post(handle_leader_promote))
405 .route(
406 "/api/leader/skip-catchup",
407 routing::post(handle_leader_skip_catchup),
408 )
409 .layer(auth_middleware.clone())
410 .with_state(internal_route_config.deployment_state_handle.clone());
411 router = router.merge(leader_router);
412 }
413
414 if routes_enabled.metrics {
415 let metrics_router = Router::new()
416 .route(
417 "/metrics",
418 routing::get(move || async move {
419 mz_http_util::handle_prometheus(&metrics_registry).await
420 }),
421 )
422 .route(
423 "/metrics/mz_usage",
424 routing::get(|client: AuthedClient| async move {
425 let registry = sql::handle_promsql(client, USAGE_METRIC_QUERIES).await;
426 mz_http_util::handle_prometheus(®istry).await
427 }),
428 )
429 .route(
430 "/metrics/mz_frontier",
431 routing::get(|client: AuthedClient| async move {
432 let registry = sql::handle_promsql(client, FRONTIER_METRIC_QUERIES).await;
433 mz_http_util::handle_prometheus(®istry).await
434 }),
435 )
436 .route(
437 "/metrics/mz_compute",
438 routing::get(|client: AuthedClient| async move {
439 let registry = sql::handle_promsql(client, COMPUTE_METRIC_QUERIES).await;
440 mz_http_util::handle_prometheus(®istry).await
441 }),
442 )
443 .route(
444 "/metrics/mz_storage",
445 routing::get(|client: AuthedClient| async move {
446 let registry = sql::handle_promsql(client, STORAGE_METRIC_QUERIES).await;
447 mz_http_util::handle_prometheus(®istry).await
448 }),
449 )
450 .route(
451 "/api/livez",
452 routing::get(mz_http_util::handle_liveness_check),
453 )
454 .route("/api/readyz", routing::get(probe::handle_ready))
455 .layer(auth_middleware.clone())
456 .layer(Extension(adapter_client_rx.clone()))
457 .layer(Extension(active_connection_counter.clone()));
458 router = router.merge(metrics_router);
459 }
460
461 if routes_enabled.mcp_agents || routes_enabled.mcp_observatory {
464 use tracing::info;
465
466 let mut mcp_router = Router::new();
467
468 if routes_enabled.mcp_agents {
469 info!("Enabling MCP agents endpoint: /api/mcp/agents");
470 mcp_router =
471 mcp_router.route("/api/mcp/agents", routing::post(mcp::handle_mcp_agents));
472 }
473
474 if routes_enabled.mcp_observatory {
475 info!("Enabling MCP observatory endpoint: /api/mcp/observatory");
476 mcp_router = mcp_router.route(
477 "/api/mcp/observatory",
478 routing::post(mcp::handle_mcp_observatory),
479 );
480 }
481
482 mcp_router = mcp_router
483 .layer(auth_middleware.clone())
484 .layer(Extension(adapter_client_rx.clone()))
485 .layer(Extension(active_connection_counter.clone()))
486 .layer(
487 CorsLayer::new()
488 .allow_methods(Method::POST)
489 .allow_origin(AllowOrigin::mirror_request())
490 .allow_headers(Any),
491 );
492 router = router.merge(mcp_router);
493 }
494
495 base_router = base_router
496 .layer(auth_middleware.clone())
497 .layer(Extension(adapter_client_rx.clone()))
498 .layer(Extension(active_connection_counter.clone()))
499 .layer(
500 CorsLayer::new()
501 .allow_credentials(false)
502 .allow_headers([
503 AUTHORIZATION,
504 CONTENT_TYPE,
505 HeaderName::from_static("x-materialize-version"),
506 ])
507 .allow_methods(Any)
508 .allow_origin(allowed_origin)
509 .expose_headers(Any)
510 .max_age(Duration::from_secs(60) * 60),
511 );
512
513 match authenticator_kind {
514 AuthenticatorKind::Password | AuthenticatorKind::Oidc => {
515 base_router = base_router.layer(session_layer.clone());
516
517 let login_router = Router::new()
518 .route("/api/login", routing::post(handle_login))
519 .route("/api/logout", routing::post(handle_logout))
520 .layer(Extension(adapter_client_rx));
521 router = router.merge(login_router).layer(session_layer);
522 }
523 AuthenticatorKind::None => {
524 base_router =
525 base_router.layer(middleware::from_fn(x_materialize_user_header_auth));
526 }
527 _ => {}
528 }
529
530 router = router
531 .merge(base_router)
532 .apply_default_layers(source, metrics);
533
534 HttpServer { tls, router }
535 }
536}
537
538impl Server for HttpServer {
539 const NAME: &'static str = "http";
540
541 fn handle_connection(
542 &self,
543 conn: Connection,
544 _tokio_metrics_intervals: impl Iterator<Item = TaskMetrics> + Send + 'static,
545 ) -> ConnectionHandler {
546 let router = self.router.clone();
547 let tls_context = self.tls.clone();
548 let mut conn = TokioIo::new(conn);
549
550 Box::pin(async {
551 let direct_peer_addr = conn.inner().peer_addr().context("fetching peer addr")?;
552 let peer_addr = conn
553 .inner_mut()
554 .take_proxy_header_address()
555 .await
556 .map(|a| a.source)
557 .unwrap_or(direct_peer_addr);
558
559 let (conn, conn_protocol) = match tls_context {
560 Some(tls_context) => {
561 let mut ssl_stream = SslStream::new(Ssl::new(&tls_context.get())?, conn)?;
562 if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
563 let _ = ssl_stream.get_mut().inner_mut().shutdown().await;
564 return Err(e.into());
565 }
566 (MaybeHttpsStream::Https(ssl_stream), ConnProtocol::Https)
567 }
568 _ => (MaybeHttpsStream::Http(conn), ConnProtocol::Http),
569 };
570 let mut make_tower_svc = router
571 .layer(Extension(conn_protocol))
572 .into_make_service_with_connect_info::<SocketAddr>();
573 let tower_svc = make_tower_svc.call(peer_addr).await.unwrap();
574 let hyper_svc = hyper::service::service_fn(|req| tower_svc.clone().call(req));
575 let http = hyper::server::conn::http1::Builder::new();
576 http.serve_connection(conn, hyper_svc)
577 .with_upgrades()
578 .err_into()
579 .await
580 })
581 }
582}
583
584pub async fn handle_leader_status(
585 State(deployment_state_handle): State<DeploymentStateHandle>,
586) -> impl IntoResponse {
587 let status = deployment_state_handle.status();
588 (StatusCode::OK, Json(json!({ "status": status })))
589}
590
591pub async fn handle_leader_promote(
592 State(deployment_state_handle): State<DeploymentStateHandle>,
593) -> impl IntoResponse {
594 match deployment_state_handle.try_promote() {
595 Ok(()) => {
596 let status = StatusCode::OK;
599 let body = Json(json!({
600 "result": "Success",
601 }));
602 (status, body)
603 }
604 Err(()) => {
605 let status = StatusCode::BAD_REQUEST;
608 let body = Json(json!({
609 "result": {"Failure": {"message": "cannot promote leader while initializing"}},
610 }));
611 (status, body)
612 }
613 }
614}
615
616pub async fn handle_leader_skip_catchup(
617 State(deployment_state_handle): State<DeploymentStateHandle>,
618) -> impl IntoResponse {
619 match deployment_state_handle.try_skip_catchup() {
620 Ok(()) => StatusCode::NO_CONTENT.into_response(),
621 Err(()) => {
622 let status = StatusCode::BAD_REQUEST;
623 let body = Json(json!({
624 "message": "cannot skip catchup in this phase of initialization; try again later",
625 }));
626 (status, body).into_response()
627 }
628 }
629}
630
631async fn x_materialize_user_header_auth(mut req: Request, next: Next) -> impl IntoResponse {
632 if let Some(username) = req.headers().get("x-materialize-user").map(|h| h.to_str()) {
634 let username = match username {
635 Ok(name @ (SUPPORT_USER_NAME | SYSTEM_USER_NAME)) => name.to_string(),
636 _ => {
637 return Err(AuthError::MismatchedUser(format!(
638 "user specified in x-materialize-user must be {SUPPORT_USER_NAME} or {SYSTEM_USER_NAME}"
639 )));
640 }
641 };
642 req.extensions_mut().insert(AuthedUser {
643 name: username,
644 external_metadata_rx: None,
645 authenticated: Authenticated,
646 });
647 }
648 Ok(next.run(req).await)
649}
650
/// A value delivered over a oneshot channel, wrapped in `Shared` so the
/// receiver can be cloned and awaited from several places.
type Delayed<T> = Shared<oneshot::Receiver<T>>;
652
/// Protocol of the underlying connection. `handle_connection` attaches it to
/// every request as an extension and `http_auth` reads it back.
#[derive(Clone)]
enum ConnProtocol {
    /// Plain-text connection (no TLS context configured).
    Http,
    /// Connection that completed a TLS handshake.
    Https,
}
658
/// An authenticated user, stored as a request extension by the
/// authentication middleware and consumed when building an `AuthedClient`.
#[derive(Clone, Debug)]
pub struct AuthedUser {
    /// User name the session is opened as.
    name: String,
    /// Optional channel of externally sourced user metadata, forwarded into
    /// the adapter session configuration.
    external_metadata_rx: Option<watch::Receiver<ExternalUserMetadata>>,
    /// Authentication marker passed to `Client::new_session`.
    authenticated: Authenticated,
}
665
/// A started adapter session for an authenticated user.
pub struct AuthedClient {
    /// Session client returned by `Client::startup`.
    pub client: SessionClient,
    /// Handle allocated from the active connection counter when the client
    /// was created.
    pub connection_guard: Option<ConnectionHandle>,
}
670
671impl AuthedClient {
672 async fn new<F>(
673 adapter_client: &Client,
674 user: AuthedUser,
675 peer_addr: IpAddr,
676 active_connection_counter: ConnectionCounter,
677 helm_chart_version: Option<String>,
678 session_config: F,
679 options: BTreeMap<String, String>,
680 now: NowFn,
681 ) -> Result<Self, AdapterError>
682 where
683 F: FnOnce(&mut AdapterSession),
684 {
685 let conn_id = adapter_client.new_conn_id()?;
686 let mut session = adapter_client.new_session(
687 AdapterSessionConfig {
688 conn_id,
689 uuid: epoch_to_uuid_v7(&(now)()),
690 user: user.name,
691 client_ip: Some(peer_addr),
692 external_metadata_rx: user.external_metadata_rx,
693 helm_chart_version,
694 },
695 user.authenticated,
696 );
697 let connection_guard = active_connection_counter.allocate_connection(session.user())?;
698
699 session_config(&mut session);
700 let system_vars = adapter_client.get_system_vars().await;
701 for (key, val) in options {
702 const LOCAL: bool = false;
703 if let Err(err) =
704 session
705 .vars_mut()
706 .set(&system_vars, &key, VarInput::Flat(&val), LOCAL)
707 {
708 session.add_notice(AdapterNotice::BadStartupSetting {
709 name: key.to_string(),
710 reason: err.to_string(),
711 })
712 }
713 }
714 let adapter_client = adapter_client.startup(session).await?;
715 Ok(AuthedClient {
716 client: adapter_client,
717 connection_guard,
718 })
719 }
720}
721
722impl<S> FromRequestParts<S> for AuthedClient
723where
724 S: Send + Sync,
725{
726 type Rejection = Response;
727
728 async fn from_request_parts(
729 req: &mut http::request::Parts,
730 state: &S,
731 ) -> Result<Self, Self::Rejection> {
732 #[derive(Debug, Default, Deserialize)]
733 struct Params {
734 #[serde(default)]
735 options: String,
736 }
737 let params: Query<Params> = Query::from_request_parts(req, state)
738 .await
739 .unwrap_or_default();
740
741 let peer_addr = req
742 .extensions
743 .get::<ConnectInfo<SocketAddr>>()
744 .expect("ConnectInfo extension guaranteed to exist")
745 .0
746 .ip();
747
748 let user = req.extensions.get::<AuthedUser>().unwrap();
749 let adapter_client = req
750 .extensions
751 .get::<Delayed<mz_adapter::Client>>()
752 .unwrap()
753 .clone();
754 let adapter_client = adapter_client.await.map_err(|_| {
755 (StatusCode::INTERNAL_SERVER_ERROR, "adapter client missing").into_response()
756 })?;
757 let active_connection_counter = req.extensions.get::<ConnectionCounter>().unwrap();
758 let helm_chart_version = None;
759
760 let options = if params.options.is_empty() {
761 BTreeMap::<String, String>::default()
764 } else {
765 match serde_json::from_str(¶ms.options) {
766 Ok(options) => options,
767 Err(_e) => {
768 let code = StatusCode::BAD_REQUEST;
770 let msg = format!("Failed to deserialize {} map", "options".quoted());
771 return Err((code, msg).into_response());
772 }
773 }
774 };
775
776 let client = AuthedClient::new(
777 &adapter_client,
778 user.clone(),
779 peer_addr,
780 active_connection_counter.clone(),
781 helm_chart_version,
782 |session| {
783 session
784 .vars_mut()
785 .set_default(WELCOME_MESSAGE.name(), VarInput::Flat(&false.format()))
786 .expect("known to exist")
787 },
788 options,
789 SYSTEM_TIME.clone(),
790 )
791 .await
792 .map_err(|e| {
793 let status = match e {
794 AdapterError::UserSessionsDisallowed | AdapterError::NetworkPolicyDenied(_) => {
795 StatusCode::FORBIDDEN
796 }
797 _ => StatusCode::INTERNAL_SERVER_ERROR,
798 };
799 (status, Json(SqlError::from(e))).into_response()
800 })?;
801
802 Ok(client)
803 }
804}
805
806#[derive(Debug, Error)]
807pub(crate) enum AuthError {
808 #[error("role dissallowed")]
809 RoleDisallowed(String),
810 #[error("{0}")]
811 Frontegg(#[from] FronteggError),
812 #[error("missing authorization header")]
813 MissingHttpAuthentication {
814 include_www_authenticate_header: bool,
815 },
816 #[error("{0}")]
817 MismatchedUser(String),
818 #[error("session expired")]
819 SessionExpired,
820 #[error("failed to update session")]
821 FailedToUpdateSession,
822 #[error("invalid credentials")]
823 InvalidCredentials,
824}
825
826impl IntoResponse for AuthError {
827 fn into_response(self) -> Response {
828 warn!("HTTP request failed authentication: {}", self);
829 let mut headers = HeaderMap::new();
830 match self {
831 AuthError::MissingHttpAuthentication {
832 include_www_authenticate_header,
833 } if include_www_authenticate_header => {
834 headers.insert(
835 http::header::WWW_AUTHENTICATE,
836 HeaderValue::from_static("Basic realm=Materialize"),
837 );
838 }
839 _ => {}
840 };
841 (StatusCode::UNAUTHORIZED, headers, "unauthorized").into_response()
844 }
845}
846
847pub async fn handle_login(
849 session: Option<Extension<TowerSession>>,
850 Extension(adapter_client_rx): Extension<Delayed<Client>>,
851 Json(LoginCredentials { username, password }): Json<LoginCredentials>,
852) -> impl IntoResponse {
853 let Ok(adapter_client) = adapter_client_rx.clone().await else {
854 return StatusCode::INTERNAL_SERVER_ERROR;
855 };
856 let authenticated = match adapter_client.authenticate(&username, &password).await {
857 Ok(authenticated) => authenticated,
858 Err(err) => {
859 warn!(?err, "HTTP login failed authentication");
860 return StatusCode::UNAUTHORIZED;
861 }
862 };
863 let session_data = TowerSessionData {
865 username,
866 created_at: SystemTime::now(),
867 last_activity: SystemTime::now(),
868 authenticated,
869 };
870 let session = session.and_then(|Extension(session)| Some(session));
872 let Some(session) = session else {
873 return StatusCode::INTERNAL_SERVER_ERROR;
874 };
875 match session.insert("data", &session_data).await {
876 Err(_) => StatusCode::INTERNAL_SERVER_ERROR,
877 Ok(_) => StatusCode::OK,
878 }
879}
880
881pub async fn handle_logout(session: Option<Extension<TowerSession>>) -> impl IntoResponse {
883 let session = session.and_then(|Extension(session)| Some(session));
884 let Some(session) = session else {
885 return StatusCode::INTERNAL_SERVER_ERROR;
886 };
887 match session.delete().await {
889 Err(_) => StatusCode::INTERNAL_SERVER_ERROR,
890 Ok(_) => StatusCode::OK,
891 }
892}
893
894async fn http_auth(
895 mut req: Request,
896 next: Next,
897 tls_enabled: bool,
898 authenticator_kind: AuthenticatorKind,
899 frontegg: Option<mz_frontegg_auth::Authenticator>,
900 oidc_rx: Delayed<mz_authenticator::GenericOidcAuthenticator>,
901 adapter_client_rx: Delayed<Client>,
902 allowed_roles: AllowedRoles,
903) -> Result<impl IntoResponse, AuthError> {
904 let creds = if let Some(basic) = req.headers().typed_get::<Authorization<Basic>>() {
905 Some(Credentials::Password {
906 username: basic.username().to_owned(),
907 password: Password(basic.password().to_owned()),
908 })
909 } else if let Some(bearer) = req.headers().typed_get::<Authorization<Bearer>>() {
910 Some(Credentials::Token {
911 token: bearer.token().to_owned(),
912 })
913 } else {
914 None
915 };
916
917 if creds.is_none()
921 && let Some((session, session_data)) =
922 maybe_get_authenticated_session(req.extensions().get::<TowerSession>()).await
923 {
924 let user = ensure_session_unexpired(session, session_data).await?;
925 req.extensions_mut().insert(user);
927 return Ok(next.run(req).await);
928 }
929
930 let conn_protocol = req.extensions().get::<ConnProtocol>().unwrap();
934 match (tls_enabled, &conn_protocol) {
935 (false, ConnProtocol::Http) => {}
936 (false, ConnProtocol::Https { .. }) => unreachable!(),
937 (true, ConnProtocol::Http) => {
938 let mut parts = req.uri().clone().into_parts();
939 parts.scheme = Some(Scheme::HTTPS);
940 return Ok(Redirect::permanent(
941 &Uri::from_parts(parts)
942 .expect("it was already a URI, just changed the scheme")
943 .to_string(),
944 )
945 .into_response());
946 }
947 (true, ConnProtocol::Https { .. }) => {}
948 }
949 if req.extensions().get::<AuthedUser>().is_some() {
951 return Ok(next.run(req).await);
952 }
953
954 let path = req.uri().path();
955 let include_www_authenticate_header = path == "/"
956 || PROFILING_API_ENDPOINTS
957 .iter()
958 .any(|prefix| path.starts_with(prefix));
959 let authenticator = get_authenticator(
960 authenticator_kind,
961 creds.as_ref(),
962 frontegg,
963 &oidc_rx,
964 &adapter_client_rx,
965 )
966 .await;
967
968 let user = auth(
969 &authenticator,
970 creds,
971 allowed_roles,
972 include_www_authenticate_header,
973 )
974 .await?;
975
976 req.extensions_mut().insert(user);
979
980 Ok(next.run(req).await)
982}
983
984async fn init_ws(
985 WsState {
986 frontegg,
987 oidc_rx,
988 authenticator_kind,
989 adapter_client_rx,
990 active_connection_counter,
991 helm_chart_version,
992 allowed_roles,
993 }: WsState,
994 existing_user: Option<ExistingUser>,
995 peer_addr: IpAddr,
996 ws: &mut WebSocket,
997) -> Result<AuthedClient, anyhow::Error> {
998 let ws_auth: WebSocketAuth = loop {
1001 let init_msg = ws.recv().await.ok_or_else(|| anyhow::anyhow!("closed"))??;
1002 match init_msg {
1003 Message::Text(data) => break serde_json::from_str(&data)?,
1004 Message::Binary(data) => break serde_json::from_slice(&data)?,
1005 Message::Ping(_) => {
1007 continue;
1008 }
1009 Message::Pong(_) => {
1010 continue;
1011 }
1012 Message::Close(_) => {
1013 anyhow::bail!("closed");
1014 }
1015 }
1016 };
1017
1018 let (creds, options) = match ws_auth {
1021 WebSocketAuth::Basic {
1022 user,
1023 password,
1024 options,
1025 } => {
1026 let creds = Credentials::Password {
1027 username: user,
1028 password,
1029 };
1030 (Some(creds), options)
1031 }
1032 WebSocketAuth::Bearer { token, options } => {
1033 let creds = Credentials::Token { token };
1034 (Some(creds), options)
1035 }
1036 WebSocketAuth::OptionsOnly { options } => (None, options),
1037 };
1038
1039 let user = match (existing_user, creds) {
1040 (Some(ExistingUser::XMaterializeUserHeader(_)), Some(_creds)) => {
1041 warn!("Unexpected bearer or basic auth provided when using user header");
1042 anyhow::bail!("unexpected")
1043 }
1044 (Some(ExistingUser::Session(user)), None)
1045 | (Some(ExistingUser::XMaterializeUserHeader(user)), None) => user,
1046 (_, Some(creds)) => {
1047 let authenticator = get_authenticator(
1048 authenticator_kind,
1049 Some(&creds),
1050 frontegg,
1051 &oidc_rx,
1052 &adapter_client_rx,
1053 )
1054 .await;
1055 let user = auth(&authenticator, Some(creds), allowed_roles, false).await?;
1056 user
1057 }
1058 (None, None) => anyhow::bail!("expected auth information"),
1059 };
1060
1061 let client = AuthedClient::new(
1062 &adapter_client_rx.clone().await?,
1063 user,
1064 peer_addr,
1065 active_connection_counter.clone(),
1066 helm_chart_version.clone(),
1067 |_session| (),
1068 options,
1069 SYSTEM_TIME.clone(),
1070 )
1071 .await?;
1072
1073 Ok(client)
1074}
1075
/// Credentials presented by a client, taken either from an `Authorization`
/// header or from the initial WebSocket auth message.
enum Credentials {
    /// Username and password (HTTP Basic, or WebSocket `Basic` auth).
    Password {
        username: String,
        password: Password,
    },
    /// Opaque token (HTTP Bearer, or WebSocket `Bearer` auth).
    Token {
        token: String,
    },
}
1085
1086async fn get_authenticator(
1087 kind: AuthenticatorKind,
1088 creds: Option<&Credentials>,
1089 frontegg: Option<mz_frontegg_auth::Authenticator>,
1090 oidc_rx: &Delayed<mz_authenticator::GenericOidcAuthenticator>,
1091 adapter_client_rx: &Delayed<Client>,
1092) -> Authenticator {
1093 match kind {
1094 AuthenticatorKind::Frontegg => Authenticator::Frontegg(
1095 frontegg.expect("Frontegg authenticator should exist with AuthenticatorKind::Frontegg"),
1096 ),
1097 AuthenticatorKind::Password | AuthenticatorKind::Sasl => {
1098 let client = adapter_client_rx.clone().await.expect("sender not dropped");
1099 Authenticator::Password(client)
1100 }
1101 AuthenticatorKind::Oidc => match creds {
1102 Some(Credentials::Password { .. }) => {
1104 let client = adapter_client_rx.clone().await.expect("sender not dropped");
1105 Authenticator::Password(client)
1106 }
1107 _ => Authenticator::Oidc(oidc_rx.clone().await.expect("sender not dropped")),
1108 },
1109 AuthenticatorKind::None => Authenticator::None,
1110 }
1111}
1112
1113pub(crate) async fn maybe_get_authenticated_session(
1117 session: Option<&TowerSession>,
1118) -> Option<(&TowerSession, TowerSessionData)> {
1119 if let Some(session) = session {
1120 if let Ok(Some(session_data)) = session.get::<TowerSessionData>("data").await {
1121 return Some((session, session_data));
1122 }
1123 }
1124 None
1125}
1126
1127pub(crate) async fn ensure_session_unexpired(
1130 session: &TowerSession,
1131 session_data: TowerSessionData,
1132) -> Result<AuthedUser, AuthError> {
1133 if session_data
1134 .last_activity
1135 .elapsed()
1136 .unwrap_or(Duration::MAX)
1137 > SESSION_DURATION
1138 {
1139 let _ = session.delete().await;
1140 return Err(AuthError::SessionExpired);
1141 }
1142 let mut updated_data = session_data.clone();
1143 updated_data.last_activity = SystemTime::now();
1144 session
1145 .insert("data", &updated_data)
1146 .await
1147 .map_err(|_| AuthError::FailedToUpdateSession)?;
1148
1149 Ok(AuthedUser {
1150 name: session_data.username,
1151 external_metadata_rx: None,
1152 authenticated: session_data.authenticated,
1153 })
1154}
1155
1156async fn auth(
1157 authenticator: &Authenticator,
1158 creds: Option<Credentials>,
1159 allowed_roles: AllowedRoles,
1160 include_www_authenticate_header: bool,
1161) -> Result<AuthedUser, AuthError> {
1162 let (name, external_metadata_rx, authenticated) = match authenticator {
1163 Authenticator::Frontegg(frontegg) => match creds {
1164 Some(Credentials::Password { username, password }) => {
1165 let (auth_session, authenticated) =
1166 frontegg.authenticate(&username, &password.0).await?;
1167 let name = auth_session.user().into();
1168 let external_metadata_rx = Some(auth_session.external_metadata_rx());
1169 (name, external_metadata_rx, authenticated)
1170 }
1171 Some(Credentials::Token { token }) => {
1172 let (claims, authenticated) = frontegg.validate_access_token(&token, None)?;
1173 let (_, external_metadata_rx) = watch::channel(ExternalUserMetadata {
1174 user_id: claims.user_id,
1175 admin: claims.is_admin,
1176 });
1177 (claims.user, Some(external_metadata_rx), authenticated)
1178 }
1179 None => {
1180 return Err(AuthError::MissingHttpAuthentication {
1181 include_www_authenticate_header,
1182 });
1183 }
1184 },
1185 Authenticator::Password(adapter_client) => match creds {
1186 Some(Credentials::Password { username, password }) => {
1187 let authenticated = adapter_client
1188 .authenticate(&username, &password)
1189 .await
1190 .map_err(|_| AuthError::InvalidCredentials)?;
1191 (username, None, authenticated)
1192 }
1193 _ => {
1194 return Err(AuthError::MissingHttpAuthentication {
1195 include_www_authenticate_header,
1196 });
1197 }
1198 },
1199 Authenticator::Sasl(_) => {
1200 return Err(AuthError::MissingHttpAuthentication {
1204 include_www_authenticate_header,
1205 });
1206 }
1207 Authenticator::Oidc(oidc) => match creds {
1208 Some(Credentials::Token { token }) => {
1209 let (claims, authenticated) = oidc
1211 .authenticate(&token, None)
1212 .await
1213 .map_err(|_| AuthError::InvalidCredentials)?;
1214 let name = claims.user;
1215 (name, None, authenticated)
1216 }
1217 _ => {
1218 return Err(AuthError::MissingHttpAuthentication {
1219 include_www_authenticate_header,
1220 });
1221 }
1222 },
1223 Authenticator::None => {
1224 let name = match creds {
1228 Some(Credentials::Password { username, .. }) => username,
1229 _ => HTTP_DEFAULT_USER.name.to_owned(),
1230 };
1231 (name, None, Authenticated)
1232 }
1233 };
1234
1235 check_role_allowed(&name, allowed_roles)?;
1236
1237 Ok(AuthedUser {
1238 name,
1239 external_metadata_rx,
1240 authenticated,
1241 })
1242}
1243
1244fn check_role_allowed(name: &str, allowed_roles: AllowedRoles) -> Result<(), AuthError> {
1246 let is_internal_user = INTERNAL_USER_NAMES.contains(name);
1247 let is_reserved_user = mz_adapter::catalog::is_reserved_role_name(name);
1249 let role_allowed = match allowed_roles {
1250 AllowedRoles::Normal => !is_reserved_user,
1251 AllowedRoles::Internal => is_internal_user,
1252 AllowedRoles::NormalAndInternal => !is_reserved_user || is_internal_user,
1253 };
1254 if role_allowed {
1255 Ok(())
1256 } else {
1257 Err(AuthError::RoleDisallowed(name.to_owned()))
1258 }
1259}
1260
/// Extension trait for attaching a shared set of layers to a router-like type.
trait DefaultLayers {
    /// Consumes `self` and returns it with the default layers applied.
    ///
    /// `source` and `metrics` are handed to the metrics layer by the `Router`
    /// implementation in this file.
    fn apply_default_layers(self, source: &'static str, metrics: Metrics) -> Self;
}
1266
1267impl DefaultLayers for Router {
1268 fn apply_default_layers(self, source: &'static str, metrics: Metrics) -> Self {
1269 self.layer(DefaultBodyLimit::max(MAX_REQUEST_SIZE))
1270 .layer(metrics::PrometheusLayer::new(source, metrics))
1271 }
1272}
1273
1274async fn handle_load_error(error: tower::BoxError) -> impl IntoResponse {
1281 if error.is::<tower::load_shed::error::Overloaded>() {
1282 return (
1283 StatusCode::TOO_MANY_REQUESTS,
1284 Cow::from("too many requests, try again later"),
1285 );
1286 }
1287
1288 (
1291 StatusCode::INTERNAL_SERVER_ERROR,
1292 Cow::from(format!("Unhandled internal error: {}", error)),
1293 )
1294}
1295
/// A username/password pair that can be serialized and deserialized.
// NOTE(review): presumably the payload of the login endpoint — confirm at the
// call site, which is outside this part of the file.
#[derive(Debug, Deserialize, Serialize, PartialEq)]
pub struct LoginCredentials {
    /// Name of the user logging in.
    username: String,
    /// The user's password.
    password: Password,
}
1301
/// Per-session state stored in the tower session under the `"data"` key.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TowerSessionData {
    /// Name reported as the authenticated user for requests on this session.
    username: String,
    /// Presumably the time the session was first established — confirm where
    /// the session is created.
    created_at: SystemTime,
    /// Time of the most recent validated use; compared against
    /// `SESSION_DURATION` and refreshed by `ensure_session_unexpired`.
    last_activity: SystemTime,
    /// Authentication marker carried over into the resulting `AuthedUser`.
    authenticated: Authenticated,
}
1309
#[cfg(test)]
mod tests {
    use super::{AllowedRoles, check_role_allowed};

    /// Exercises every role-name category against every `AllowedRoles` policy.
    #[mz_ore::test]
    fn test_check_role_allowed() {
        // (role name, ok under Internal, ok under NormalAndInternal, ok under Normal)
        let cases = [
            // Internal users.
            ("mz_system", true, true, false),
            ("mz_support", true, true, false),
            ("mz_analytics", true, true, false),
            // Ordinary users.
            ("materialize", false, true, true),
            ("anonymous_http_user", false, true, true),
            ("alex", false, true, true),
            // Reserved names that are not internal users.
            ("external_asdf", false, false, false),
            ("pg_somebody", false, false, false),
            ("mz_unknown", false, false, false),
            ("PUBLIC", false, false, false),
        ];

        for &(name, internal, normal_and_internal, normal) in cases.iter() {
            assert_eq!(
                check_role_allowed(name, AllowedRoles::Internal).is_ok(),
                internal,
                "role {name:?} under AllowedRoles::Internal",
            );
            assert_eq!(
                check_role_allowed(name, AllowedRoles::NormalAndInternal).is_ok(),
                normal_and_internal,
                "role {name:?} under AllowedRoles::NormalAndInternal",
            );
            assert_eq!(
                check_role_allowed(name, AllowedRoles::Normal).is_ok(),
                normal,
                "role {name:?} under AllowedRoles::Normal",
            );
        }
    }
}