1#![allow(clippy::unused_async)]
51
52use std::borrow::Cow;
53use std::collections::BTreeMap;
54use std::fmt::Debug;
55use std::net::{IpAddr, SocketAddr};
56use std::pin::Pin;
57use std::sync::Arc;
58use std::time::{Duration, SystemTime};
59
60use anyhow::Context;
61use axum::error_handling::HandleErrorLayer;
62use axum::extract::ws::{Message, WebSocket};
63use axum::extract::{ConnectInfo, DefaultBodyLimit, FromRequestParts, Query, Request, State};
64use axum::middleware::{self, Next};
65use axum::response::{IntoResponse, Redirect, Response};
66use axum::{Extension, Json, Router, routing};
67use futures::future::{Shared, TryFutureExt};
68use headers::authorization::{Authorization, Basic, Bearer};
69use headers::{HeaderMapExt, HeaderName};
70use http::header::{AUTHORIZATION, CONTENT_TYPE};
71use http::uri::Scheme;
72use http::{HeaderMap, HeaderValue, Method, StatusCode, Uri};
73use hyper_openssl::SslStream;
74use hyper_openssl::client::legacy::MaybeHttpsStream;
75use hyper_util::rt::TokioIo;
76use mz_adapter::session::{Session as AdapterSession, SessionConfig as AdapterSessionConfig};
77use mz_adapter::{AdapterError, AdapterNotice, Client, SessionClient, WebhookAppenderCache};
78use mz_auth::Authenticated;
79use mz_auth::password::Password;
80use mz_authenticator::Authenticator;
81use mz_controller::ReplicaHttpLocator;
82use mz_frontegg_auth::Error as FronteggError;
83use mz_http_util::DynamicFilterTarget;
84use mz_ore::cast::u64_to_usize;
85use mz_ore::metrics::MetricsRegistry;
86use mz_ore::now::{NowFn, SYSTEM_TIME, epoch_to_uuid_v7};
87use mz_ore::str::StrExt;
88use mz_pgwire_common::{ConnectionCounter, ConnectionHandle};
89use mz_repr::user::ExternalUserMetadata;
90use mz_server_core::listeners::{self, AllowedRoles, HttpRoutesEnabled};
91use mz_server_core::{Connection, ConnectionHandler, ReloadingSslContext, Server};
92use mz_sql::session::metadata::SessionMetadata;
93use mz_sql::session::user::{
94 HTTP_DEFAULT_USER, INTERNAL_USER_NAMES, SUPPORT_USER_NAME, SYSTEM_USER_NAME,
95};
96use mz_sql::session::vars::{Value, Var, VarInput, WELCOME_MESSAGE};
97use openssl::ssl::Ssl;
98use prometheus::{
99 COMPUTE_METRIC_QUERIES, FRONTIER_METRIC_QUERIES, STORAGE_METRIC_QUERIES, USAGE_METRIC_QUERIES,
100};
101use serde::{Deserialize, Serialize};
102use serde_json::json;
103use thiserror::Error;
104use tokio::io::AsyncWriteExt;
105use tokio::sync::oneshot::Receiver;
106use tokio::sync::{oneshot, watch};
107use tokio_metrics::TaskMetrics;
108use tower::limit::GlobalConcurrencyLimitLayer;
109use tower::{Service, ServiceBuilder};
110use tower_http::cors::{AllowOrigin, Any, CorsLayer};
111use tower_sessions::{
112 MemoryStore as TowerSessionMemoryStore, Session as TowerSession,
113 SessionManagerLayer as TowerSessionManagerLayer,
114};
115use tracing::warn;
116
117use crate::BUILD_INFO;
118use crate::deployment::state::DeploymentStateHandle;
119use crate::http::sql::{ExistingUser, SqlError};
120
121mod catalog;
122mod cluster;
123mod console;
124mod mcp;
125mod memory;
126mod metrics;
127mod metrics_viz;
128mod probe;
129mod prometheus;
130mod root;
131mod sql;
132mod webhook;
133
134pub use metrics::Metrics;
135pub use sql::{SqlResponse, WebSocketAuth, WebSocketResponse};
136
/// Maximum request size, in bytes: 5 MiB converted to `usize`.
pub const MAX_REQUEST_SIZE: usize = u64_to_usize(5 * bytesize::MIB);
139
/// How long a cookie session may sit idle (measured from its last activity)
/// before it is considered expired: eight hours.
const SESSION_DURATION: Duration = Duration::from_secs(8 * 60 * 60);

/// Path prefixes of the profiling endpoints. Unauthenticated requests to these
/// (and to `/`) receive a `WWW-Authenticate` header so browsers prompt for
/// credentials.
const PROFILING_API_ENDPOINTS: &[&str] = &["/memory", "/hierarchical-memory", "/prof/"];
143
/// Everything [`HttpServer::new`] needs to build the HTTP router.
#[derive(Debug)]
pub struct HttpConfig {
    /// Label passed, together with `metrics`, to `apply_default_layers`.
    pub source: &'static str,
    /// TLS context. When set, every accepted connection performs a TLS
    /// handshake and the session cookie is marked `Secure`.
    pub tls: Option<ReloadingSslContext>,
    /// Authenticator used by this listener. It selects how credentials are
    /// checked and whether cookie sessions or the `x-materialize-user` header
    /// are supported.
    pub authenticator_kind: listeners::AuthenticatorKind,
    /// Frontegg authenticator; must be `Some` when `authenticator_kind` is Frontegg.
    pub frontegg: Option<mz_frontegg_auth::Authenticator>,
    /// OIDC authenticator, delivered once it becomes available.
    pub oidc_rx: Delayed<mz_authenticator::GenericOidcAuthenticator>,
    /// Adapter client, delivered once it becomes available.
    pub adapter_client_rx: Shared<Receiver<Client>>,
    /// CORS origins allowed on the base routes.
    pub allowed_origin: AllowOrigin,
    /// Counter from which each new adapter session allocates a connection slot.
    pub active_connection_counter: ConnectionCounter,
    /// Helm chart version, forwarded to WebSocket sessions.
    pub helm_chart_version: Option<String>,
    /// Concurrency limit for webhook requests; requests beyond it are shed.
    pub concurrent_webhook_req: Arc<tokio::sync::Semaphore>,
    /// HTTP metrics handed to `apply_default_layers`.
    pub metrics: Metrics,
    /// Registry exposed at `/metrics`.
    pub metrics_registry: MetricsRegistry,
    /// Roles permitted to authenticate on this listener.
    pub allowed_roles: AllowedRoles,
    /// Configuration for the internal-only routes.
    pub internal_route_config: Arc<InternalRouteConfig>,
    /// Which route groups to mount.
    pub routes_enabled: HttpRoutesEnabled,
    /// Locates replica HTTP endpoints for the `/api/cluster/...` proxy routes.
    pub replica_http_locator: Arc<ReplicaHttpLocator>,
}
164
/// Settings used only by the internal route group.
#[derive(Debug, Clone)]
pub struct InternalRouteConfig {
    /// State for the `/api/leader/*` endpoints.
    pub deployment_state_handle: DeploymentStateHandle,
    /// Target handed to the `/internal-console` proxy configuration, if any.
    pub internal_console_redirect_url: Option<String>,
}
170
/// Router state for the SQL-over-WebSocket endpoint.
///
/// The WebSocket route is not wrapped in the `http_auth` middleware, so it
/// carries everything needed to authenticate in-band (see `init_ws`).
#[derive(Clone)]
pub struct WsState {
    frontegg: Option<mz_frontegg_auth::Authenticator>,
    oidc_rx: Delayed<mz_authenticator::GenericOidcAuthenticator>,
    authenticator_kind: listeners::AuthenticatorKind,
    adapter_client_rx: Delayed<mz_adapter::Client>,
    active_connection_counter: ConnectionCounter,
    helm_chart_version: Option<String>,
    allowed_roles: AllowedRoles,
}
181
/// Router state for the webhook ingestion endpoint.
#[derive(Clone)]
pub struct WebhookState {
    adapter_client_rx: Delayed<mz_adapter::Client>,
    /// Shared cache of webhook appenders, created once per server.
    webhook_cache: WebhookAppenderCache,
}
187
/// An HTTP(S) server: a fully assembled router plus an optional TLS context.
#[derive(Debug)]
pub struct HttpServer {
    /// When `Some`, every connection is wrapped in TLS before serving.
    tls: Option<ReloadingSslContext>,
    router: Router,
}
193
impl HttpServer {
    /// Assembles the router for one HTTP listener.
    ///
    /// Route groups are mounted according to `routes_enabled` (base,
    /// profiling, webhook, internal, metrics, console config, MCP).
    ///
    /// Authentication:
    /// - The base, leader, metrics and MCP groups are wrapped in the shared
    ///   `http_auth` middleware.
    /// - The console-config router is not.
    /// - Password and OIDC listeners additionally get a cookie session layer
    ///   and `/api/login` / `/api/logout`.
    /// - Listeners without an authenticator install the `x-materialize-user`
    ///   header middleware instead.
    pub fn new(
        HttpConfig {
            source,
            tls,
            authenticator_kind,
            frontegg,
            oidc_rx,
            adapter_client_rx,
            allowed_origin,
            active_connection_counter,
            helm_chart_version,
            concurrent_webhook_req,
            metrics,
            metrics_registry,
            allowed_roles,
            internal_route_config,
            routes_enabled,
            replica_http_locator,
        }: HttpConfig,
    ) -> HttpServer {
        let tls_enabled = tls.is_some();
        let webhook_cache = WebhookAppenderCache::new();

        // Cookie sessions are kept in tower-sessions' `MemoryStore`.
        let session_store = TowerSessionMemoryStore::default();
        // The cookie is only marked `Secure` when TLS is enabled; browsers do
        // not send `Secure` cookies over plain HTTP.
        let session_layer = TowerSessionManagerLayer::new(session_store)
            .with_secure(tls_enabled)
            .with_same_site(tower_sessions::cookie::SameSite::Strict)
            .with_http_only(true)
            .with_name("mz_session")
            .with_path("/");
        // Separate clones for the middleware closure, since the originals are
        // still moved into `WsState` below.
        let frontegg_middleware = frontegg.clone();
        let oidc_middleware_rx = oidc_rx.clone();
        let adapter_client_middleware_rx = adapter_client_rx.clone();
        // Per request: clone the captured handles and run `http_auth`.
        let auth_middleware = middleware::from_fn(move |req, next| {
            let frontegg = frontegg_middleware.clone();
            let oidc_rx = oidc_middleware_rx.clone();
            let adapter_client_rx = adapter_client_middleware_rx.clone();
            async move {
                http_auth(
                    req,
                    next,
                    tls_enabled,
                    authenticator_kind,
                    frontegg,
                    oidc_rx,
                    adapter_client_rx,
                    allowed_roles,
                )
                .await
            }
        });

        // `router` collects groups that carry their own layers; `base_router`
        // collects routes that share the auth/CORS/session layers applied at
        // the end.
        let mut router = Router::new();
        let mut base_router = Router::new();
        if routes_enabled.base {
            base_router = base_router
                .route(
                    "/",
                    routing::get(move || async move { root::handle_home(routes_enabled).await }),
                )
                .route("/api/sql", routing::post(sql::handle_sql))
                .route("/memory", routing::get(memory::handle_memory))
                .route(
                    "/hierarchical-memory",
                    routing::get(memory::handle_hierarchical_memory),
                )
                .route(
                    "/metrics-viz",
                    routing::get(metrics_viz::handle_metrics_viz),
                )
                .route("/static/{*path}", routing::get(root::handle_static));

            // The WebSocket endpoint is not behind `auth_middleware`;
            // presumably its handler authenticates in-band from this state
            // (see `init_ws`).
            let mut ws_router = Router::new()
                .route("/api/experimental/sql", routing::get(sql::handle_sql_ws))
                .with_state(WsState {
                    frontegg,
                    oidc_rx: oidc_rx.clone(),
                    authenticator_kind,
                    adapter_client_rx: adapter_client_rx.clone(),
                    active_connection_counter: active_connection_counter.clone(),
                    helm_chart_version,
                    allowed_roles,
                });
            if let listeners::AuthenticatorKind::None = authenticator_kind {
                ws_router = ws_router.layer(middleware::from_fn(x_materialize_user_header_auth));
            }
            router = router.merge(ws_router);
        }
        if routes_enabled.profiling {
            base_router = base_router.nest("/prof/", mz_prof_http::router(&BUILD_INFO));
        }

        if routes_enabled.webhook {
            // Webhooks accept compressed bodies and any origin (mirrored). They
            // are limited to `concurrent_webhook_req` in-flight requests; with
            // `load_shed`, excess requests fail immediately instead of queueing,
            // and `handle_load_error` (defined elsewhere) turns that failure
            // into a response.
            let webhook_router = Router::new()
                .route(
                    "/api/webhook/{:database}/{:schema}/{:id}",
                    routing::post(webhook::handle_webhook),
                )
                .with_state(WebhookState {
                    adapter_client_rx: adapter_client_rx.clone(),
                    webhook_cache,
                })
                .layer(
                    tower_http::decompression::RequestDecompressionLayer::new()
                        .gzip(true)
                        .deflate(true)
                        .br(true)
                        .zstd(true),
                )
                .layer(
                    CorsLayer::new()
                        .allow_methods(Method::POST)
                        .allow_origin(AllowOrigin::mirror_request())
                        .allow_headers(Any),
                )
                .layer(
                    ServiceBuilder::new()
                        .layer(HandleErrorLayer::new(handle_load_error))
                        .load_shed()
                        .layer(GlobalConcurrencyLimitLayer::with_semaphore(
                            concurrent_webhook_req,
                        )),
                );
            router = router.merge(webhook_router);
        }

        if routes_enabled.internal {
            let console_config = Arc::new(console::ConsoleProxyConfig::new(
                internal_route_config.internal_console_redirect_url.clone(),
                "/internal-console".to_string(),
            ));
            base_router = base_router
                // Retired endpoints: kept so callers get a 400 that points them
                // at the replacement system variable.
                .route(
                    "/api/opentelemetry/config",
                    routing::put({
                        move |_: axum::Json<DynamicFilterTarget>| async {
                            (
                                StatusCode::BAD_REQUEST,
                                "This endpoint has been replaced. \
                            Use the `opentelemetry_filter` system variable."
                                    .to_string(),
                            )
                        }
                    }),
                )
                .route(
                    "/api/stderr/config",
                    routing::put({
                        move |_: axum::Json<DynamicFilterTarget>| async {
                            (
                                StatusCode::BAD_REQUEST,
                                "This endpoint has been replaced. \
                            Use the `log_filter` system variable."
                                    .to_string(),
                            )
                        }
                    }),
                )
                .route("/api/tracing", routing::get(mz_http_util::handle_tracing))
                .route(
                    "/api/catalog/dump",
                    routing::get(catalog::handle_catalog_dump),
                )
                .route(
                    "/api/catalog/check",
                    routing::get(catalog::handle_catalog_check),
                )
                .route(
                    "/api/catalog/inject-audit-events",
                    routing::post(catalog::handle_inject_audit_events),
                )
                .route(
                    "/api/coordinator/check",
                    routing::get(catalog::handle_coordinator_check),
                )
                .route(
                    "/api/coordinator/dump",
                    routing::get(catalog::handle_coordinator_dump),
                )
                .route(
                    "/internal-console",
                    routing::get(|| async { Redirect::temporary("/internal-console/") }),
                )
                .route(
                    "/internal-console/{*path}",
                    routing::get(console::handle_internal_console),
                )
                .route(
                    "/internal-console/",
                    routing::get(console::handle_internal_console),
                )
                .layer(Extension(console_config));

            let cluster_proxy_config = Arc::new(cluster::ClusterProxyConfig::new(Arc::clone(
                &replica_http_locator,
            )));
            base_router = base_router
                .route("/clusters", routing::get(cluster::handle_clusters))
                .route(
                    "/api/cluster/{:cluster_id}/replica/{:replica_id}/process/{:process}/",
                    routing::any(cluster::handle_cluster_proxy_root),
                )
                .route(
                    "/api/cluster/{:cluster_id}/replica/{:replica_id}/process/{:process}/{*path}",
                    routing::any(cluster::handle_cluster_proxy),
                )
                .layer(Extension(cluster_proxy_config));

            // Leader-election endpoints use the deployment state as router
            // state and carry their own copy of the auth middleware.
            let leader_router = Router::new()
                .route("/api/leader/status", routing::get(handle_leader_status))
                .route("/api/leader/promote", routing::post(handle_leader_promote))
                .route(
                    "/api/leader/skip-catchup",
                    routing::post(handle_leader_skip_catchup),
                )
                .layer(auth_middleware.clone())
                .with_state(internal_route_config.deployment_state_handle.clone());
            router = router.merge(leader_router);
        }

        if routes_enabled.metrics {
            // `/metrics` serves the process registry. The `/metrics/mz_*`
            // endpoints run a fixed set of SQL queries as the authenticated
            // user and render the results as Prometheus metrics.
            let metrics_router = Router::new()
                .route(
                    "/metrics",
                    routing::get(move || async move {
                        mz_http_util::handle_prometheus(&metrics_registry).await
                    }),
                )
                .route(
                    "/metrics/mz_usage",
                    routing::get(|client: AuthedClient| async move {
                        let registry = sql::handle_promsql(client, USAGE_METRIC_QUERIES).await;
                        mz_http_util::handle_prometheus(&registry).await
                    }),
                )
                .route(
                    "/metrics/mz_frontier",
                    routing::get(|client: AuthedClient| async move {
                        let registry = sql::handle_promsql(client, FRONTIER_METRIC_QUERIES).await;
                        mz_http_util::handle_prometheus(&registry).await
                    }),
                )
                .route(
                    "/metrics/mz_compute",
                    routing::get(|client: AuthedClient| async move {
                        let registry = sql::handle_promsql(client, COMPUTE_METRIC_QUERIES).await;
                        mz_http_util::handle_prometheus(&registry).await
                    }),
                )
                .route(
                    "/metrics/mz_storage",
                    routing::get(|client: AuthedClient| async move {
                        let registry = sql::handle_promsql(client, STORAGE_METRIC_QUERIES).await;
                        mz_http_util::handle_prometheus(&registry).await
                    }),
                )
                .route(
                    "/api/livez",
                    routing::get(mz_http_util::handle_liveness_check),
                )
                .route("/api/readyz", routing::get(probe::handle_ready))
                .layer(auth_middleware.clone())
                .layer(Extension(adapter_client_rx.clone()))
                .layer(Extension(active_connection_counter.clone()));
            router = router.merge(metrics_router);
        }

        if routes_enabled.console_config {
            // Note: this group is not wrapped in `auth_middleware`.
            let console_config_router = Router::new()
                .route(
                    "/api/console/config",
                    routing::get(console::handle_console_config),
                )
                .layer(Extension(adapter_client_rx.clone()))
                .layer(Extension(active_connection_counter.clone()));
            router = router.merge(console_config_router);
        }

        if routes_enabled.mcp_agents || routes_enabled.mcp_observatory {
            use tracing::info;

            let mut mcp_router = Router::new();

            // Each MCP endpoint accepts POST; GET is routed to an explicit
            // method-not-allowed handler.
            if routes_enabled.mcp_agents {
                info!("Enabling MCP agents endpoint: /api/mcp/agents");
                mcp_router = mcp_router.route(
                    "/api/mcp/agents",
                    routing::post(mcp::handle_mcp_agents).get(mcp::handle_mcp_method_not_allowed),
                );
            }

            if routes_enabled.mcp_observatory {
                info!("Enabling MCP observatory endpoint: /api/mcp/observatory");
                mcp_router = mcp_router.route(
                    "/api/mcp/observatory",
                    routing::post(mcp::handle_mcp_observatory)
                        .get(mcp::handle_mcp_method_not_allowed),
                );
            }

            mcp_router = mcp_router
                .layer(auth_middleware.clone())
                .layer(Extension(adapter_client_rx.clone()))
                .layer(Extension(active_connection_counter.clone()))
                .layer(
                    CorsLayer::new()
                        .allow_methods(Method::POST)
                        .allow_origin(AllowOrigin::mirror_request())
                        .allow_headers(Any),
                );
            router = router.merge(mcp_router);
        }

        // Shared layers for every base route. Layers added later wrap the
        // earlier ones, so CORS sees the request before `auth_middleware`.
        base_router = base_router
            .layer(auth_middleware.clone())
            .layer(Extension(adapter_client_rx.clone()))
            .layer(Extension(active_connection_counter.clone()))
            .layer(
                CorsLayer::new()
                    .allow_credentials(false)
                    .allow_headers([
                        AUTHORIZATION,
                        CONTENT_TYPE,
                        HeaderName::from_static("x-materialize-version"),
                    ])
                    .allow_methods(Any)
                    .allow_origin(allowed_origin)
                    .expose_headers(Any)
                    .max_age(Duration::from_secs(60) * 60),
            );

        match authenticator_kind {
            listeners::AuthenticatorKind::Password | listeners::AuthenticatorKind::Oidc => {
                base_router = base_router.layer(session_layer.clone());

                // The session layer wraps `router` as assembled so far (WebSocket,
                // webhook, leader, metrics, console-config and MCP groups, when
                // enabled) plus the login/logout routes.
                let login_router = Router::new()
                    .route("/api/login", routing::post(handle_login))
                    .route("/api/logout", routing::post(handle_logout))
                    .layer(Extension(adapter_client_rx));
                router = router.merge(login_router).layer(session_layer);
            }
            listeners::AuthenticatorKind::None => {
                base_router =
                    base_router.layer(middleware::from_fn(x_materialize_user_header_auth));
            }
            // Frontegg and SASL listeners need neither cookie sessions nor the
            // user-header middleware.
            _ => {}
        }

        router = router
            .merge(base_router)
            .apply_default_layers(source, metrics);

        HttpServer { tls, router }
    }
}
555
556impl Server for HttpServer {
557 const NAME: &'static str = "http";
558
559 fn handle_connection(
560 &self,
561 conn: Connection,
562 _tokio_metrics_intervals: impl Iterator<Item = TaskMetrics> + Send + 'static,
563 ) -> ConnectionHandler {
564 let router = self.router.clone();
565 let tls_context = self.tls.clone();
566 let mut conn = TokioIo::new(conn);
567
568 Box::pin(async {
569 let direct_peer_addr = conn.inner().peer_addr().context("fetching peer addr")?;
570 let peer_addr = conn
571 .inner_mut()
572 .take_proxy_header_address()
573 .await
574 .map(|a| a.source)
575 .unwrap_or(direct_peer_addr);
576
577 let (conn, conn_protocol) = match tls_context {
578 Some(tls_context) => {
579 let mut ssl_stream = SslStream::new(Ssl::new(&tls_context.get())?, conn)?;
580 if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
581 let _ = ssl_stream.get_mut().inner_mut().shutdown().await;
582 return Err(e.into());
583 }
584 (MaybeHttpsStream::Https(ssl_stream), ConnProtocol::Https)
585 }
586 _ => (MaybeHttpsStream::Http(conn), ConnProtocol::Http),
587 };
588 let mut make_tower_svc = router
589 .layer(Extension(conn_protocol))
590 .into_make_service_with_connect_info::<SocketAddr>();
591 let tower_svc = make_tower_svc.call(peer_addr).await.unwrap();
592 let hyper_svc = hyper::service::service_fn(|req| tower_svc.clone().call(req));
593 let http = hyper::server::conn::http1::Builder::new();
594 http.serve_connection(conn, hyper_svc)
595 .with_upgrades()
596 .err_into()
597 .await
598 })
599 }
600}
601
602pub async fn handle_leader_status(
603 State(deployment_state_handle): State<DeploymentStateHandle>,
604) -> impl IntoResponse {
605 let status = deployment_state_handle.status();
606 (StatusCode::OK, Json(json!({ "status": status })))
607}
608
609pub async fn handle_leader_promote(
610 State(deployment_state_handle): State<DeploymentStateHandle>,
611) -> impl IntoResponse {
612 match deployment_state_handle.try_promote() {
613 Ok(()) => {
614 let status = StatusCode::OK;
617 let body = Json(json!({
618 "result": "Success",
619 }));
620 (status, body)
621 }
622 Err(()) => {
623 let status = StatusCode::BAD_REQUEST;
626 let body = Json(json!({
627 "result": {"Failure": {"message": "cannot promote leader while initializing"}},
628 }));
629 (status, body)
630 }
631 }
632}
633
634pub async fn handle_leader_skip_catchup(
635 State(deployment_state_handle): State<DeploymentStateHandle>,
636) -> impl IntoResponse {
637 match deployment_state_handle.try_skip_catchup() {
638 Ok(()) => StatusCode::NO_CONTENT.into_response(),
639 Err(()) => {
640 let status = StatusCode::BAD_REQUEST;
641 let body = Json(json!({
642 "message": "cannot skip catchup in this phase of initialization; try again later",
643 }));
644 (status, body).into_response()
645 }
646 }
647}
648
649async fn x_materialize_user_header_auth(mut req: Request, next: Next) -> impl IntoResponse {
650 if let Some(username) = req.headers().get("x-materialize-user").map(|h| h.to_str()) {
652 let username = match username {
653 Ok(name @ (SUPPORT_USER_NAME | SYSTEM_USER_NAME)) => name.to_string(),
654 _ => {
655 return Err(AuthError::MismatchedUser(format!(
656 "user specified in x-materialize-user must be {SUPPORT_USER_NAME} or {SYSTEM_USER_NAME}"
657 )));
658 }
659 };
660 req.extensions_mut().insert(AuthedUser {
661 name: username,
662 external_metadata_rx: None,
663 authenticated: Authenticated,
664 authenticator_kind: mz_auth::AuthenticatorKind::None,
665 });
666 }
667 Ok(next.run(req).await)
668}
669
/// A value delivered later through a `oneshot` channel. `Shared` makes the
/// receiver cloneable, so many consumers can await the same value.
type Delayed<T> = Shared<oneshot::Receiver<T>>;

/// Transport a request arrived on. `HttpServer::handle_connection` inserts it
/// as a request extension, and `http_auth` checks it.
#[derive(Clone)]
enum ConnProtocol {
    /// Plain-text HTTP.
    Http,
    /// HTTP over TLS.
    Https,
}
677
/// A user whose identity was established by one of the authentication paths
/// (credentials, cookie session, or the `x-materialize-user` header). It is
/// stored in the request extensions for downstream extractors.
#[derive(Clone, Debug)]
pub struct AuthedUser {
    /// User/role name the adapter session is opened as.
    name: String,
    /// External user metadata. `Some` for Frontegg-authenticated users;
    /// `None` for header- and cookie-session users.
    external_metadata_rx: Option<watch::Receiver<ExternalUserMetadata>>,
    /// Marker produced by the authenticator; presumably proof that
    /// authentication succeeded.
    authenticated: Authenticated,
    /// Which kind of authenticator vouched for this user.
    authenticator_kind: mz_auth::AuthenticatorKind,
}
685
/// An adapter session opened on behalf of an [`AuthedUser`].
pub struct AuthedClient {
    /// The started adapter session.
    pub client: SessionClient,
    /// Slot allocated from the active-connection counter; presumably released
    /// when dropped.
    pub connection_guard: Option<ConnectionHandle>,
}
690
691impl AuthedClient {
692 async fn new<F>(
693 adapter_client: &Client,
694 user: AuthedUser,
695 peer_addr: IpAddr,
696 active_connection_counter: ConnectionCounter,
697 helm_chart_version: Option<String>,
698 session_config: F,
699 options: BTreeMap<String, String>,
700 now: NowFn,
701 ) -> Result<Self, AdapterError>
702 where
703 F: FnOnce(&mut AdapterSession),
704 {
705 let conn_id = adapter_client.new_conn_id()?;
706 let mut session = adapter_client.new_session(
707 AdapterSessionConfig {
708 conn_id,
709 uuid: epoch_to_uuid_v7(&(now)()),
710 user: user.name,
711 client_ip: Some(peer_addr),
712 external_metadata_rx: user.external_metadata_rx,
713 helm_chart_version,
714 authenticator_kind: user.authenticator_kind,
715 },
716 user.authenticated,
717 );
718 let connection_guard = active_connection_counter.allocate_connection(session.user())?;
719
720 session_config(&mut session);
721 let system_vars = adapter_client.get_system_vars().await;
722 for (key, val) in options {
723 const LOCAL: bool = false;
724 if let Err(err) =
725 session
726 .vars_mut()
727 .set(&system_vars, &key, VarInput::Flat(&val), LOCAL)
728 {
729 session.add_notice(AdapterNotice::BadStartupSetting {
730 name: key.to_string(),
731 reason: err.to_string(),
732 })
733 }
734 }
735 let adapter_client = adapter_client.startup(session).await?;
736 Ok(AuthedClient {
737 client: adapter_client,
738 connection_guard,
739 })
740 }
741}
742
743impl<S> FromRequestParts<S> for AuthedClient
744where
745 S: Send + Sync,
746{
747 type Rejection = Response;
748
749 async fn from_request_parts(
750 req: &mut http::request::Parts,
751 state: &S,
752 ) -> Result<Self, Self::Rejection> {
753 #[derive(Debug, Default, Deserialize)]
754 struct Params {
755 #[serde(default)]
756 options: String,
757 }
758 let params: Query<Params> = Query::from_request_parts(req, state)
759 .await
760 .unwrap_or_default();
761
762 let peer_addr = req
763 .extensions
764 .get::<ConnectInfo<SocketAddr>>()
765 .expect("ConnectInfo extension guaranteed to exist")
766 .0
767 .ip();
768
769 let user = req.extensions.get::<AuthedUser>().unwrap();
770 let adapter_client = req
771 .extensions
772 .get::<Delayed<mz_adapter::Client>>()
773 .unwrap()
774 .clone();
775 let adapter_client = adapter_client.await.map_err(|_| {
776 (StatusCode::INTERNAL_SERVER_ERROR, "adapter client missing").into_response()
777 })?;
778 let active_connection_counter = req.extensions.get::<ConnectionCounter>().unwrap();
779 let helm_chart_version = None;
780
781 let options = if params.options.is_empty() {
782 BTreeMap::<String, String>::default()
785 } else {
786 match serde_json::from_str(¶ms.options) {
787 Ok(options) => options,
788 Err(_e) => {
789 let code = StatusCode::BAD_REQUEST;
791 let msg = format!("Failed to deserialize {} map", "options".quoted());
792 return Err((code, msg).into_response());
793 }
794 }
795 };
796
797 let client = AuthedClient::new(
798 &adapter_client,
799 user.clone(),
800 peer_addr,
801 active_connection_counter.clone(),
802 helm_chart_version,
803 |session| {
804 session
805 .vars_mut()
806 .set_default(WELCOME_MESSAGE.name(), VarInput::Flat(&false.format()))
807 .expect("known to exist")
808 },
809 options,
810 SYSTEM_TIME.clone(),
811 )
812 .await
813 .map_err(|e| {
814 let status = match e {
815 AdapterError::UserSessionsDisallowed | AdapterError::NetworkPolicyDenied(_) => {
816 StatusCode::FORBIDDEN
817 }
818 _ => StatusCode::INTERNAL_SERVER_ERROR,
819 };
820 (status, Json(SqlError::from(e))).into_response()
821 })?;
822
823 Ok(client)
824 }
825}
826
827#[derive(Debug, Error)]
828pub(crate) enum AuthError {
829 #[error("role dissallowed")]
830 RoleDisallowed(String),
831 #[error("{0}")]
832 Frontegg(#[from] FronteggError),
833 #[error("missing authorization header")]
834 MissingHttpAuthentication {
835 include_www_authenticate_header: bool,
836 },
837 #[error("{0}")]
838 MismatchedUser(String),
839 #[error("session expired")]
840 SessionExpired,
841 #[error("failed to update session")]
842 FailedToUpdateSession,
843 #[error("invalid credentials")]
844 InvalidCredentials,
845}
846
847impl IntoResponse for AuthError {
848 fn into_response(self) -> Response {
849 warn!("HTTP request failed authentication: {}", self);
850 let mut headers = HeaderMap::new();
851 match self {
852 AuthError::MissingHttpAuthentication {
853 include_www_authenticate_header,
854 } if include_www_authenticate_header => {
855 headers.insert(
856 http::header::WWW_AUTHENTICATE,
857 HeaderValue::from_static("Basic realm=Materialize"),
858 );
859 }
860 _ => {}
861 };
862 (StatusCode::UNAUTHORIZED, headers, "unauthorized").into_response()
865 }
866}
867
868pub async fn handle_login(
870 session: Option<Extension<TowerSession>>,
871 Extension(adapter_client_rx): Extension<Delayed<Client>>,
872 Json(LoginCredentials { username, password }): Json<LoginCredentials>,
873) -> impl IntoResponse {
874 let Ok(adapter_client) = adapter_client_rx.clone().await else {
875 return StatusCode::INTERNAL_SERVER_ERROR;
876 };
877 let authenticated = match adapter_client.authenticate(&username, &password).await {
878 Ok(authenticated) => authenticated,
879 Err(err) => {
880 warn!(?err, "HTTP login failed authentication");
881 return StatusCode::UNAUTHORIZED;
882 }
883 };
884 let session_data = TowerSessionData {
886 username,
887 created_at: SystemTime::now(),
888 last_activity: SystemTime::now(),
889 authenticated,
890 authenticator_kind: mz_auth::AuthenticatorKind::Password,
891 };
892 let session = session.and_then(|Extension(session)| Some(session));
894 let Some(session) = session else {
895 return StatusCode::INTERNAL_SERVER_ERROR;
896 };
897 match session.insert("data", &session_data).await {
898 Err(_) => StatusCode::INTERNAL_SERVER_ERROR,
899 Ok(_) => StatusCode::OK,
900 }
901}
902
903pub async fn handle_logout(session: Option<Extension<TowerSession>>) -> impl IntoResponse {
905 let session = session.and_then(|Extension(session)| Some(session));
906 let Some(session) = session else {
907 return StatusCode::INTERNAL_SERVER_ERROR;
908 };
909 match session.delete().await {
911 Err(_) => StatusCode::INTERNAL_SERVER_ERROR,
912 Ok(_) => StatusCode::OK,
913 }
914}
915
916async fn http_auth(
917 mut req: Request,
918 next: Next,
919 tls_enabled: bool,
920 authenticator_kind: listeners::AuthenticatorKind,
921 frontegg: Option<mz_frontegg_auth::Authenticator>,
922 oidc_rx: Delayed<mz_authenticator::GenericOidcAuthenticator>,
923 adapter_client_rx: Delayed<Client>,
924 allowed_roles: AllowedRoles,
925) -> Result<impl IntoResponse, AuthError> {
926 let creds = if let Some(basic) = req.headers().typed_get::<Authorization<Basic>>() {
927 Some(Credentials::Password {
928 username: basic.username().to_owned(),
929 password: Password(basic.password().to_owned()),
930 })
931 } else if let Some(bearer) = req.headers().typed_get::<Authorization<Bearer>>() {
932 Some(Credentials::Token {
933 token: bearer.token().to_owned(),
934 })
935 } else {
936 None
937 };
938
939 if creds.is_none()
943 && let Some((session, session_data)) =
944 maybe_get_authenticated_session(req.extensions().get::<TowerSession>()).await
945 {
946 let user = ensure_session_unexpired(session, session_data).await?;
947 req.extensions_mut().insert(user);
949 return Ok(next.run(req).await);
950 }
951
952 let conn_protocol = req.extensions().get::<ConnProtocol>().unwrap();
956 match (tls_enabled, &conn_protocol) {
957 (false, ConnProtocol::Http) => {}
958 (false, ConnProtocol::Https { .. }) => unreachable!(),
959 (true, ConnProtocol::Http) => {
960 let mut parts = req.uri().clone().into_parts();
961 parts.scheme = Some(Scheme::HTTPS);
962 return Ok(Redirect::permanent(
963 &Uri::from_parts(parts)
964 .expect("it was already a URI, just changed the scheme")
965 .to_string(),
966 )
967 .into_response());
968 }
969 (true, ConnProtocol::Https { .. }) => {}
970 }
971 if req.extensions().get::<AuthedUser>().is_some() {
973 return Ok(next.run(req).await);
974 }
975
976 let path = req.uri().path();
977 let include_www_authenticate_header = path == "/"
978 || PROFILING_API_ENDPOINTS
979 .iter()
980 .any(|prefix| path.starts_with(prefix));
981 let authenticator = get_authenticator(
982 authenticator_kind,
983 creds.as_ref(),
984 frontegg,
985 &oidc_rx,
986 &adapter_client_rx,
987 )
988 .await;
989
990 let user = auth(
991 &authenticator,
992 creds,
993 allowed_roles,
994 include_www_authenticate_header,
995 )
996 .await?;
997
998 req.extensions_mut().insert(user);
1001
1002 Ok(next.run(req).await)
1004}
1005
1006async fn init_ws(
1007 WsState {
1008 frontegg,
1009 oidc_rx,
1010 authenticator_kind,
1011 adapter_client_rx,
1012 active_connection_counter,
1013 helm_chart_version,
1014 allowed_roles,
1015 }: WsState,
1016 existing_user: Option<ExistingUser>,
1017 peer_addr: IpAddr,
1018 ws: &mut WebSocket,
1019) -> Result<AuthedClient, anyhow::Error> {
1020 let ws_auth: WebSocketAuth = loop {
1023 let init_msg = ws.recv().await.ok_or_else(|| anyhow::anyhow!("closed"))??;
1024 match init_msg {
1025 Message::Text(data) => break serde_json::from_str(&data)?,
1026 Message::Binary(data) => break serde_json::from_slice(&data)?,
1027 Message::Ping(_) => {
1029 continue;
1030 }
1031 Message::Pong(_) => {
1032 continue;
1033 }
1034 Message::Close(_) => {
1035 anyhow::bail!("closed");
1036 }
1037 }
1038 };
1039
1040 let (creds, options) = match ws_auth {
1043 WebSocketAuth::Basic {
1044 user,
1045 password,
1046 options,
1047 } => {
1048 let creds = Credentials::Password {
1049 username: user,
1050 password,
1051 };
1052 (Some(creds), options)
1053 }
1054 WebSocketAuth::Bearer { token, options } => {
1055 let creds = Credentials::Token { token };
1056 (Some(creds), options)
1057 }
1058 WebSocketAuth::OptionsOnly { options } => (None, options),
1059 };
1060
1061 let user = match (existing_user, creds) {
1062 (Some(ExistingUser::XMaterializeUserHeader(_)), Some(_creds)) => {
1063 warn!("Unexpected bearer or basic auth provided when using user header");
1064 anyhow::bail!("unexpected")
1065 }
1066 (Some(ExistingUser::Session(user)), None)
1067 | (Some(ExistingUser::XMaterializeUserHeader(user)), None) => user,
1068 (_, Some(creds)) => {
1069 let authenticator = get_authenticator(
1070 authenticator_kind,
1071 Some(&creds),
1072 frontegg,
1073 &oidc_rx,
1074 &adapter_client_rx,
1075 )
1076 .await;
1077 let user = auth(&authenticator, Some(creds), allowed_roles, false).await?;
1078 user
1079 }
1080 (None, None) => anyhow::bail!("expected auth information"),
1081 };
1082
1083 let client = AuthedClient::new(
1084 &adapter_client_rx.clone().await?,
1085 user,
1086 peer_addr,
1087 active_connection_counter.clone(),
1088 helm_chart_version.clone(),
1089 |_session| (),
1090 options,
1091 SYSTEM_TIME.clone(),
1092 )
1093 .await?;
1094
1095 Ok(client)
1096}
1097
/// Credentials presented by a client, taken either from an HTTP
/// `Authorization` header or from the WebSocket auth message.
// No `Debug` derive: the variants carry secrets.
enum Credentials {
    /// Username/password pair (HTTP Basic auth or `WebSocketAuth::Basic`).
    Password {
        username: String,
        password: Password,
    },
    /// Bearer token (HTTP Bearer auth or `WebSocketAuth::Bearer`).
    Token {
        token: String,
    },
}
1107
1108async fn get_authenticator(
1109 kind: listeners::AuthenticatorKind,
1110 creds: Option<&Credentials>,
1111 frontegg: Option<mz_frontegg_auth::Authenticator>,
1112 oidc_rx: &Delayed<mz_authenticator::GenericOidcAuthenticator>,
1113 adapter_client_rx: &Delayed<Client>,
1114) -> Authenticator {
1115 match kind {
1116 listeners::AuthenticatorKind::Frontegg => Authenticator::Frontegg(frontegg.expect(
1117 "Frontegg authenticator should exist with listeners::AuthenticatorKind::Frontegg",
1118 )),
1119 listeners::AuthenticatorKind::Password | listeners::AuthenticatorKind::Sasl => {
1120 let client = adapter_client_rx.clone().await.expect("sender not dropped");
1121 Authenticator::Password(client)
1122 }
1123 listeners::AuthenticatorKind::Oidc => match creds {
1124 Some(Credentials::Password { .. }) => {
1126 let client = adapter_client_rx.clone().await.expect("sender not dropped");
1127 Authenticator::Password(client)
1128 }
1129 _ => Authenticator::Oidc(oidc_rx.clone().await.expect("sender not dropped")),
1130 },
1131 listeners::AuthenticatorKind::None => Authenticator::None,
1132 }
1133}
1134
1135pub(crate) async fn maybe_get_authenticated_session(
1139 session: Option<&TowerSession>,
1140) -> Option<(&TowerSession, TowerSessionData)> {
1141 if let Some(session) = session {
1142 if let Ok(Some(session_data)) = session.get::<TowerSessionData>("data").await {
1143 return Some((session, session_data));
1144 }
1145 }
1146 None
1147}
1148
1149pub(crate) async fn ensure_session_unexpired(
1152 session: &TowerSession,
1153 session_data: TowerSessionData,
1154) -> Result<AuthedUser, AuthError> {
1155 if session_data
1156 .last_activity
1157 .elapsed()
1158 .unwrap_or(Duration::MAX)
1159 > SESSION_DURATION
1160 {
1161 let _ = session.delete().await;
1162 return Err(AuthError::SessionExpired);
1163 }
1164 let mut updated_data = session_data.clone();
1165 updated_data.last_activity = SystemTime::now();
1166 session
1167 .insert("data", &updated_data)
1168 .await
1169 .map_err(|_| AuthError::FailedToUpdateSession)?;
1170
1171 Ok(AuthedUser {
1172 name: session_data.username,
1173 external_metadata_rx: None,
1174 authenticated: session_data.authenticated,
1175 authenticator_kind: session_data.authenticator_kind,
1176 })
1177}
1178
1179async fn auth(
1180 authenticator: &Authenticator,
1181 creds: Option<Credentials>,
1182 allowed_roles: AllowedRoles,
1183 include_www_authenticate_header: bool,
1184) -> Result<AuthedUser, AuthError> {
1185 let (name, external_metadata_rx, authenticated) = match authenticator {
1186 Authenticator::Frontegg(frontegg) => match creds {
1187 Some(Credentials::Password { username, password }) => {
1188 let (auth_session, authenticated) =
1189 frontegg.authenticate(&username, password.as_str()).await?;
1190 let name = auth_session.user().into();
1191 let external_metadata_rx = Some(auth_session.external_metadata_rx());
1192 (name, external_metadata_rx, authenticated)
1193 }
1194 Some(Credentials::Token { token }) => {
1195 let (claims, authenticated) = frontegg.validate_access_token(&token, None)?;
1196 let (_, external_metadata_rx) = watch::channel(ExternalUserMetadata {
1197 user_id: claims.user_id,
1198 admin: claims.is_admin,
1199 });
1200 (claims.user, Some(external_metadata_rx), authenticated)
1201 }
1202 None => {
1203 return Err(AuthError::MissingHttpAuthentication {
1204 include_www_authenticate_header,
1205 });
1206 }
1207 },
1208 Authenticator::Password(adapter_client) => match creds {
1209 Some(Credentials::Password { username, password }) => {
1210 let authenticated = adapter_client
1211 .authenticate(&username, &password)
1212 .await
1213 .map_err(|_| AuthError::InvalidCredentials)?;
1214 (username, None, authenticated)
1215 }
1216 _ => {
1217 return Err(AuthError::MissingHttpAuthentication {
1218 include_www_authenticate_header,
1219 });
1220 }
1221 },
1222 Authenticator::Sasl(_) => {
1223 return Err(AuthError::MissingHttpAuthentication {
1227 include_www_authenticate_header,
1228 });
1229 }
1230 Authenticator::Oidc(oidc) => match creds {
1231 Some(Credentials::Token { token }) => {
1232 let (claims, authenticated) = oidc
1234 .authenticate(&token, None)
1235 .await
1236 .map_err(|_| AuthError::InvalidCredentials)?;
1237 let name = claims.user;
1238 (name, None, authenticated)
1239 }
1240 _ => {
1241 return Err(AuthError::MissingHttpAuthentication {
1242 include_www_authenticate_header,
1243 });
1244 }
1245 },
1246 Authenticator::None => {
1247 let name = match creds {
1251 Some(Credentials::Password { username, .. }) => username,
1252 _ => HTTP_DEFAULT_USER.name.to_owned(),
1253 };
1254 (name, None, Authenticated)
1255 }
1256 };
1257
1258 check_role_allowed(&name, allowed_roles)?;
1259
1260 Ok(AuthedUser {
1261 name,
1262 external_metadata_rx,
1263 authenticated,
1264 authenticator_kind: authenticator.kind(),
1265 })
1266}
1267
1268fn check_role_allowed(name: &str, allowed_roles: AllowedRoles) -> Result<(), AuthError> {
1270 let is_internal_user = INTERNAL_USER_NAMES.contains(name);
1271 let is_reserved_user = mz_adapter::catalog::is_reserved_role_name(name);
1273 let role_allowed = match allowed_roles {
1274 AllowedRoles::Normal => !is_reserved_user,
1275 AllowedRoles::Internal => is_internal_user,
1276 AllowedRoles::NormalAndInternal => !is_reserved_user || is_internal_user,
1277 };
1278 if role_allowed {
1279 Ok(())
1280 } else {
1281 Err(AuthError::RoleDisallowed(name.to_owned()))
1282 }
1283}
1284
/// Extension trait for applying a standard set of middleware layers to a
/// router.
trait DefaultLayers {
    /// Returns `self` wrapped in the default layers. The `Router`
    /// implementation caps request bodies at `MAX_REQUEST_SIZE` and adds a
    /// `metrics::PrometheusLayer` built from `source` and `metrics`.
    fn apply_default_layers(self, source: &'static str, metrics: Metrics) -> Self;
}
1290
1291impl DefaultLayers for Router {
1292 fn apply_default_layers(self, source: &'static str, metrics: Metrics) -> Self {
1293 self.layer(DefaultBodyLimit::max(MAX_REQUEST_SIZE))
1294 .layer(metrics::PrometheusLayer::new(source, metrics))
1295 }
1296}
1297
1298async fn handle_load_error(error: tower::BoxError) -> impl IntoResponse {
1305 if error.is::<tower::load_shed::error::Overloaded>() {
1306 return (
1307 StatusCode::TOO_MANY_REQUESTS,
1308 Cow::from("too many requests, try again later"),
1309 );
1310 }
1311
1312 (
1315 StatusCode::INTERNAL_SERVER_ERROR,
1316 Cow::from(format!("Unhandled internal error: {}", error)),
1317 )
1318}
1319
/// A username/password pair (de)serialized with serde. This is presumably the
/// body of a login request; confirm against the handler that deserializes it.
///
/// NOTE(review): `Debug` is derived, so keeping the password out of logs relies
/// on `Password`'s own `Debug` impl redacting it. Confirm.
#[derive(Debug, Deserialize, Serialize, PartialEq)]
pub struct LoginCredentials {
    username: String,
    password: Password,
}
1325
/// Per-user state stored under the `"data"` key of a `TowerSession`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TowerSessionData {
    // Role name; becomes `AuthedUser::name` when the session is resumed.
    username: String,
    // Presumably set when the session is created. It is not consulted by the
    // idle-expiry check in `ensure_session_unexpired`.
    created_at: SystemTime,
    // Compared against `SESSION_DURATION` for idle expiry, and refreshed on
    // every successful `ensure_session_unexpired` call.
    last_activity: SystemTime,
    // Copied into `AuthedUser::authenticated` when the session is resumed.
    authenticated: Authenticated,
    // Copied into `AuthedUser::authenticator_kind` when the session is resumed.
    authenticator_kind: mz_auth::AuthenticatorKind,
}
1334
#[cfg(test)]
mod tests {
    use super::{AllowedRoles, check_role_allowed};

    /// Table-driven coverage of [`check_role_allowed`]. Each row gives a role
    /// name and whether it should be admitted by `Internal`,
    /// `NormalAndInternal`, and `Normal` listeners, in that order.
    #[mz_ore::test]
    fn test_check_role_allowed() {
        let cases: [(&str, bool, bool, bool); 10] = [
            // Internal users: only internal-capable listeners admit them.
            ("mz_system", true, true, false),
            ("mz_support", true, true, false),
            ("mz_analytics", true, true, false),
            // Ordinary users: admitted anywhere except internal-only listeners.
            ("materialize", false, true, true),
            ("anonymous_http_user", false, true, true),
            ("alex", false, true, true),
            // Reserved names that are not internal users: admitted nowhere.
            ("external_asdf", false, false, false),
            ("pg_somebody", false, false, false),
            ("mz_unknown", false, false, false),
            ("PUBLIC", false, false, false),
        ];

        for (name, internal, normal_and_internal, normal) in cases {
            assert_eq!(
                check_role_allowed(name, AllowedRoles::Internal).is_ok(),
                internal,
                "{name} with AllowedRoles::Internal",
            );
            assert_eq!(
                check_role_allowed(name, AllowedRoles::NormalAndInternal).is_ok(),
                normal_and_internal,
                "{name} with AllowedRoles::NormalAndInternal",
            );
            assert_eq!(
                check_role_allowed(name, AllowedRoles::Normal).is_ok(),
                normal,
                "{name} with AllowedRoles::Normal",
            );
        }
    }
}