1#![allow(clippy::unused_async)]
18
19use std::borrow::Cow;
20use std::collections::BTreeMap;
21use std::fmt::Debug;
22use std::net::{IpAddr, SocketAddr};
23use std::pin::Pin;
24use std::sync::Arc;
25use std::time::{Duration, SystemTime};
26
27use anyhow::Context;
28use async_trait::async_trait;
29use axum::error_handling::HandleErrorLayer;
30use axum::extract::ws::{Message, WebSocket};
31use axum::extract::{ConnectInfo, DefaultBodyLimit, FromRequestParts, Query, Request, State};
32use axum::middleware::{self, Next};
33use axum::response::{IntoResponse, Redirect, Response};
34use axum::{Extension, Json, Router, routing};
35use futures::future::{Shared, TryFutureExt};
36use headers::authorization::{Authorization, Basic, Bearer};
37use headers::{HeaderMapExt, HeaderName};
38use http::header::{AUTHORIZATION, CONTENT_TYPE};
39use http::uri::Scheme;
40use http::{HeaderMap, HeaderValue, Method, StatusCode, Uri};
41use hyper_openssl::SslStream;
42use hyper_openssl::client::legacy::MaybeHttpsStream;
43use hyper_util::rt::TokioIo;
44use mz_adapter::session::{Session as AdapterSession, SessionConfig as AdapterSessionConfig};
45use mz_adapter::{AdapterError, AdapterNotice, Client, SessionClient, WebhookAppenderCache};
46use mz_auth::password::Password;
47use mz_authenticator::Authenticator;
48use mz_frontegg_auth::Error as FronteggError;
49use mz_http_util::DynamicFilterTarget;
50use mz_ore::cast::u64_to_usize;
51use mz_ore::metrics::MetricsRegistry;
52use mz_ore::now::{NowFn, SYSTEM_TIME, epoch_to_uuid_v7};
53use mz_ore::str::StrExt;
54use mz_pgwire_common::{ConnectionCounter, ConnectionHandle};
55use mz_repr::user::ExternalUserMetadata;
56use mz_server_core::listeners::{AllowedRoles, AuthenticatorKind, HttpRoutesEnabled};
57use mz_server_core::{Connection, ConnectionHandler, ReloadingSslContext, Server};
58use mz_sql::session::metadata::SessionMetadata;
59use mz_sql::session::user::{
60 HTTP_DEFAULT_USER, INTERNAL_USER_NAMES, SUPPORT_USER_NAME, SYSTEM_USER_NAME,
61};
62use mz_sql::session::vars::{Value, Var, VarInput, WELCOME_MESSAGE};
63use openssl::ssl::Ssl;
64use prometheus::{
65 COMPUTE_METRIC_QUERIES, FRONTIER_METRIC_QUERIES, STORAGE_METRIC_QUERIES, USAGE_METRIC_QUERIES,
66};
67use serde::{Deserialize, Serialize};
68use serde_json::json;
69use thiserror::Error;
70use tokio::io::AsyncWriteExt;
71use tokio::sync::oneshot::Receiver;
72use tokio::sync::{oneshot, watch};
73use tokio_metrics::TaskMetrics;
74use tower::limit::GlobalConcurrencyLimitLayer;
75use tower::{Service, ServiceBuilder};
76use tower_http::cors::{AllowOrigin, Any, CorsLayer};
77use tower_sessions::{
78 MemoryStore as TowerSessionMemoryStore, Session as TowerSession,
79 SessionManagerLayer as TowerSessionManagerLayer,
80};
81use tracing::{error, warn};
82
83use crate::BUILD_INFO;
84use crate::deployment::state::DeploymentStateHandle;
85use crate::http::sql::SqlError;
86
87mod catalog;
88mod console;
89mod memory;
90mod metrics;
91mod probe;
92mod prometheus;
93mod root;
94mod sql;
95mod webhook;
96
97pub use metrics::Metrics;
98pub use sql::{SqlResponse, WebSocketAuth, WebSocketResponse};
99
100pub const MAX_REQUEST_SIZE: usize = u64_to_usize(5 * bytesize::MIB);
102
103const SESSION_DURATION: Duration = Duration::from_secs(3600); const PROFILING_API_ENDPOINTS: &[&str] = &["/memory", "/hierarchical-memory", "/prof/"];
106
107#[derive(Debug)]
108pub struct HttpConfig {
109 pub source: &'static str,
110 pub tls: Option<ReloadingSslContext>,
111 pub authenticator_kind: AuthenticatorKind,
112 pub authenticator_rx: Shared<Receiver<Arc<Authenticator>>>,
113 pub adapter_client_rx: Shared<Receiver<Client>>,
114 pub allowed_origin: AllowOrigin,
115 pub active_connection_counter: ConnectionCounter,
116 pub helm_chart_version: Option<String>,
117 pub concurrent_webhook_req: Arc<tokio::sync::Semaphore>,
118 pub metrics: Metrics,
119 pub metrics_registry: MetricsRegistry,
120 pub allowed_roles: AllowedRoles,
121 pub internal_route_config: Arc<InternalRouteConfig>,
122 pub routes_enabled: HttpRoutesEnabled,
123}
124
125#[derive(Debug, Clone)]
126pub struct InternalRouteConfig {
127 pub deployment_state_handle: DeploymentStateHandle,
128 pub internal_console_redirect_url: Option<String>,
129}
130
131#[derive(Clone)]
132pub struct WsState {
133 authenticator_rx: Delayed<Arc<Authenticator>>,
134 adapter_client_rx: Delayed<mz_adapter::Client>,
135 active_connection_counter: ConnectionCounter,
136 helm_chart_version: Option<String>,
137 allowed_roles: AllowedRoles,
138}
139
140#[derive(Clone)]
141pub struct WebhookState {
142 adapter_client_rx: Delayed<mz_adapter::Client>,
143 webhook_cache: WebhookAppenderCache,
144}
145
146#[derive(Debug)]
147pub struct HttpServer {
148 tls: Option<ReloadingSslContext>,
149 router: Router,
150}
151
152impl HttpServer {
153 pub fn new(
154 HttpConfig {
155 source,
156 tls,
157 authenticator_kind,
158 authenticator_rx,
159 adapter_client_rx,
160 allowed_origin,
161 active_connection_counter,
162 helm_chart_version,
163 concurrent_webhook_req,
164 metrics,
165 metrics_registry,
166 allowed_roles,
167 internal_route_config,
168 routes_enabled,
169 }: HttpConfig,
170 ) -> HttpServer {
171 let tls_enabled = tls.is_some();
172 let webhook_cache = WebhookAppenderCache::new();
173
174 let session_store = TowerSessionMemoryStore::default();
176 let session_layer = TowerSessionManagerLayer::new(session_store)
177 .with_secure(tls_enabled) .with_same_site(tower_sessions::cookie::SameSite::Strict) .with_http_only(true) .with_name("mz_session") .with_path("/"); let auth_middleware_authenticator_rx = authenticator_rx.clone();
184 let auth_middleware = middleware::from_fn(move |req, next| {
185 let authenticator_rx = auth_middleware_authenticator_rx.clone();
186 async move {
187 let authenticator = authenticator_rx
188 .await
189 .expect("sender not dropped before sending once");
190 http_auth(req, next, tls_enabled, authenticator, allowed_roles).await
191 }
192 });
193
194 let mut router = Router::new();
195 let mut base_router = Router::new();
196 if routes_enabled.base {
197 base_router = base_router
198 .route(
199 "/",
200 routing::get(move || async move {
201 root::handle_home(routes_enabled.profiling).await
202 }),
203 )
204 .route("/api/sql", routing::post(sql::handle_sql))
205 .route("/memory", routing::get(memory::handle_memory))
206 .route(
207 "/hierarchical-memory",
208 routing::get(memory::handle_hierarchical_memory),
209 )
210 .route("/static/*path", routing::get(root::handle_static));
211
212 let mut ws_router = Router::new()
213 .route("/api/experimental/sql", routing::get(sql::handle_sql_ws))
214 .with_state(WsState {
215 authenticator_rx: authenticator_rx.clone(),
216 adapter_client_rx: adapter_client_rx.clone(),
217 active_connection_counter: active_connection_counter.clone(),
218 helm_chart_version,
219 allowed_roles,
220 });
221 if let AuthenticatorKind::None = authenticator_kind {
222 ws_router = ws_router.layer(middleware::from_fn(x_materialize_user_header_auth));
223 }
224 router = router.merge(ws_router);
225 }
226 if routes_enabled.profiling {
227 base_router = base_router.nest("/prof/", mz_prof_http::router(&BUILD_INFO));
228 }
229
230 if routes_enabled.webhook {
231 let webhook_router = Router::new()
232 .route(
233 "/api/webhook/:database/:schema/:id",
234 routing::post(webhook::handle_webhook),
235 )
236 .with_state(WebhookState {
237 adapter_client_rx: adapter_client_rx.clone(),
238 webhook_cache,
239 })
240 .layer(
241 tower_http::decompression::RequestDecompressionLayer::new()
242 .gzip(true)
243 .deflate(true)
244 .br(true)
245 .zstd(true),
246 )
247 .layer(
248 CorsLayer::new()
249 .allow_methods(Method::POST)
250 .allow_origin(AllowOrigin::mirror_request())
251 .allow_headers(Any),
252 )
253 .layer(
254 ServiceBuilder::new()
255 .layer(HandleErrorLayer::new(handle_load_error))
256 .load_shed()
257 .layer(GlobalConcurrencyLimitLayer::with_semaphore(
258 concurrent_webhook_req,
259 )),
260 );
261 router = router.merge(webhook_router);
262 }
263
264 if routes_enabled.internal {
265 let console_config = Arc::new(console::ConsoleProxyConfig::new(
266 internal_route_config.internal_console_redirect_url.clone(),
267 "/internal-console".to_string(),
268 ));
269 base_router = base_router
270 .route(
271 "/api/opentelemetry/config",
272 routing::put({
273 move |_: axum::Json<DynamicFilterTarget>| async {
274 (
275 StatusCode::BAD_REQUEST,
276 "This endpoint has been replaced. \
277 Use the `opentelemetry_filter` system variable."
278 .to_string(),
279 )
280 }
281 }),
282 )
283 .route(
284 "/api/stderr/config",
285 routing::put({
286 move |_: axum::Json<DynamicFilterTarget>| async {
287 (
288 StatusCode::BAD_REQUEST,
289 "This endpoint has been replaced. \
290 Use the `log_filter` system variable."
291 .to_string(),
292 )
293 }
294 }),
295 )
296 .route("/api/tracing", routing::get(mz_http_util::handle_tracing))
297 .route(
298 "/api/catalog/dump",
299 routing::get(catalog::handle_catalog_dump),
300 )
301 .route(
302 "/api/catalog/check",
303 routing::get(catalog::handle_catalog_check),
304 )
305 .route(
306 "/api/coordinator/check",
307 routing::get(catalog::handle_coordinator_check),
308 )
309 .route(
310 "/api/coordinator/dump",
311 routing::get(catalog::handle_coordinator_dump),
312 )
313 .route(
314 "/internal-console",
315 routing::get(|| async { Redirect::temporary("/internal-console/") }),
316 )
317 .route(
318 "/internal-console/*path",
319 routing::get(console::handle_internal_console),
320 )
321 .route(
322 "/internal-console/",
323 routing::get(console::handle_internal_console),
324 )
325 .layer(Extension(console_config));
326 let leader_router = Router::new()
327 .route("/api/leader/status", routing::get(handle_leader_status))
328 .route("/api/leader/promote", routing::post(handle_leader_promote))
329 .route(
330 "/api/leader/skip-catchup",
331 routing::post(handle_leader_skip_catchup),
332 )
333 .layer(auth_middleware.clone())
334 .with_state(internal_route_config.deployment_state_handle.clone());
335 router = router.merge(leader_router);
336 }
337
338 if routes_enabled.metrics {
339 let metrics_router = Router::new()
340 .route(
341 "/metrics",
342 routing::get(move || async move {
343 mz_http_util::handle_prometheus(&metrics_registry).await
344 }),
345 )
346 .route(
347 "/metrics/mz_usage",
348 routing::get(|client: AuthedClient| async move {
349 let registry = sql::handle_promsql(client, USAGE_METRIC_QUERIES).await;
350 mz_http_util::handle_prometheus(®istry).await
351 }),
352 )
353 .route(
354 "/metrics/mz_frontier",
355 routing::get(|client: AuthedClient| async move {
356 let registry = sql::handle_promsql(client, FRONTIER_METRIC_QUERIES).await;
357 mz_http_util::handle_prometheus(®istry).await
358 }),
359 )
360 .route(
361 "/metrics/mz_compute",
362 routing::get(|client: AuthedClient| async move {
363 let registry = sql::handle_promsql(client, COMPUTE_METRIC_QUERIES).await;
364 mz_http_util::handle_prometheus(®istry).await
365 }),
366 )
367 .route(
368 "/metrics/mz_storage",
369 routing::get(|client: AuthedClient| async move {
370 let registry = sql::handle_promsql(client, STORAGE_METRIC_QUERIES).await;
371 mz_http_util::handle_prometheus(®istry).await
372 }),
373 )
374 .route(
375 "/api/livez",
376 routing::get(mz_http_util::handle_liveness_check),
377 )
378 .route("/api/readyz", routing::get(probe::handle_ready))
379 .layer(auth_middleware.clone())
380 .layer(Extension(adapter_client_rx.clone()))
381 .layer(Extension(active_connection_counter.clone()));
382 router = router.merge(metrics_router);
383 }
384
385 base_router = base_router
386 .layer(auth_middleware.clone())
387 .layer(Extension(adapter_client_rx.clone()))
388 .layer(Extension(active_connection_counter.clone()))
389 .layer(
390 CorsLayer::new()
391 .allow_credentials(false)
392 .allow_headers([
393 AUTHORIZATION,
394 CONTENT_TYPE,
395 HeaderName::from_static("x-materialize-version"),
396 ])
397 .allow_methods(Any)
398 .allow_origin(allowed_origin)
399 .expose_headers(Any)
400 .max_age(Duration::from_secs(60) * 60),
401 );
402
403 match authenticator_kind {
404 AuthenticatorKind::Password => {
405 base_router = base_router.layer(session_layer.clone());
406
407 let login_router = Router::new()
408 .route("/api/login", routing::post(handle_login))
409 .route("/api/logout", routing::post(handle_logout))
410 .layer(Extension(adapter_client_rx));
411 router = router.merge(login_router).layer(session_layer);
412 }
413 AuthenticatorKind::None => {
414 base_router =
415 base_router.layer(middleware::from_fn(x_materialize_user_header_auth));
416 }
417 _ => {}
418 }
419
420 router = router
421 .merge(base_router)
422 .apply_default_layers(source, metrics);
423
424 HttpServer { tls, router }
425 }
426}
427
428impl Server for HttpServer {
429 const NAME: &'static str = "http";
430
431 fn handle_connection(
432 &self,
433 conn: Connection,
434 _tokio_metrics_intervals: impl Iterator<Item = TaskMetrics> + Send + 'static,
435 ) -> ConnectionHandler {
436 let router = self.router.clone();
437 let tls_context = self.tls.clone();
438 let mut conn = TokioIo::new(conn);
439
440 Box::pin(async {
441 let direct_peer_addr = conn.inner().peer_addr().context("fetching peer addr")?;
442 let peer_addr = conn
443 .inner_mut()
444 .take_proxy_header_address()
445 .await
446 .map(|a| a.source)
447 .unwrap_or(direct_peer_addr);
448
449 let (conn, conn_protocol) = match tls_context {
450 Some(tls_context) => {
451 let mut ssl_stream = SslStream::new(Ssl::new(&tls_context.get())?, conn)?;
452 if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
453 let _ = ssl_stream.get_mut().inner_mut().shutdown().await;
454 return Err(e.into());
455 }
456 (MaybeHttpsStream::Https(ssl_stream), ConnProtocol::Https)
457 }
458 _ => (MaybeHttpsStream::Http(conn), ConnProtocol::Http),
459 };
460 let mut make_tower_svc = router
461 .layer(Extension(conn_protocol))
462 .into_make_service_with_connect_info::<SocketAddr>();
463 let tower_svc = make_tower_svc.call(peer_addr).await.unwrap();
464 let hyper_svc = hyper::service::service_fn(|req| tower_svc.clone().call(req));
465 let http = hyper::server::conn::http1::Builder::new();
466 http.serve_connection(conn, hyper_svc)
467 .with_upgrades()
468 .err_into()
469 .await
470 })
471 }
472}
473
474pub async fn handle_leader_status(
475 State(deployment_state_handle): State<DeploymentStateHandle>,
476) -> impl IntoResponse {
477 let status = deployment_state_handle.status();
478 (StatusCode::OK, Json(json!({ "status": status })))
479}
480
481pub async fn handle_leader_promote(
482 State(deployment_state_handle): State<DeploymentStateHandle>,
483) -> impl IntoResponse {
484 match deployment_state_handle.try_promote() {
485 Ok(()) => {
486 let status = StatusCode::OK;
489 let body = Json(json!({
490 "result": "Success",
491 }));
492 (status, body)
493 }
494 Err(()) => {
495 let status = StatusCode::BAD_REQUEST;
498 let body = Json(json!({
499 "result": {"Failure": {"message": "cannot promote leader while initializing"}},
500 }));
501 (status, body)
502 }
503 }
504}
505
506pub async fn handle_leader_skip_catchup(
507 State(deployment_state_handle): State<DeploymentStateHandle>,
508) -> impl IntoResponse {
509 match deployment_state_handle.try_skip_catchup() {
510 Ok(()) => StatusCode::NO_CONTENT.into_response(),
511 Err(()) => {
512 let status = StatusCode::BAD_REQUEST;
513 let body = Json(json!({
514 "message": "cannot skip catchup in this phase of initialization; try again later",
515 }));
516 (status, body).into_response()
517 }
518 }
519}
520
521async fn x_materialize_user_header_auth(mut req: Request, next: Next) -> impl IntoResponse {
522 if let Some(username) = req.headers().get("x-materialize-user").map(|h| h.to_str()) {
524 let username = match username {
525 Ok(name @ (SUPPORT_USER_NAME | SYSTEM_USER_NAME)) => name.to_string(),
526 _ => {
527 return Err(AuthError::MismatchedUser(format!(
528 "user specified in x-materialize-user must be {SUPPORT_USER_NAME} or {SYSTEM_USER_NAME}"
529 )));
530 }
531 };
532 req.extensions_mut().insert(AuthedUser {
533 name: username,
534 external_metadata_rx: None,
535 });
536 }
537 Ok(next.run(req).await)
538}
539
540type Delayed<T> = Shared<oneshot::Receiver<T>>;
541
/// Transport a connection was accepted on, attached to every request as an
/// extension so authentication can redirect plain HTTP when TLS is enabled.
#[derive(Clone)]
enum ConnProtocol {
    /// Plain, unencrypted HTTP.
    Http,
    /// HTTP over a completed TLS handshake.
    Https,
}
547
548#[derive(Clone, Debug)]
549pub struct AuthedUser {
550 name: String,
551 external_metadata_rx: Option<watch::Receiver<ExternalUserMetadata>>,
552}
553
554pub struct AuthedClient {
555 pub client: SessionClient,
556 pub connection_guard: Option<ConnectionHandle>,
557}
558
559impl AuthedClient {
560 async fn new<F>(
561 adapter_client: &Client,
562 user: AuthedUser,
563 peer_addr: IpAddr,
564 active_connection_counter: ConnectionCounter,
565 helm_chart_version: Option<String>,
566 session_config: F,
567 options: BTreeMap<String, String>,
568 now: NowFn,
569 ) -> Result<Self, AdapterError>
570 where
571 F: FnOnce(&mut AdapterSession),
572 {
573 let conn_id = adapter_client.new_conn_id()?;
574 let mut session = adapter_client.new_session(AdapterSessionConfig {
575 conn_id,
576 uuid: epoch_to_uuid_v7(&(now)()),
577 user: user.name,
578 client_ip: Some(peer_addr),
579 external_metadata_rx: user.external_metadata_rx,
580 internal_user_metadata: None,
582 helm_chart_version,
583 });
584 let connection_guard = active_connection_counter.allocate_connection(session.user())?;
585
586 session_config(&mut session);
587 let system_vars = adapter_client.get_system_vars().await;
588 for (key, val) in options {
589 const LOCAL: bool = false;
590 if let Err(err) =
591 session
592 .vars_mut()
593 .set(&system_vars, &key, VarInput::Flat(&val), LOCAL)
594 {
595 session.add_notice(AdapterNotice::BadStartupSetting {
596 name: key.to_string(),
597 reason: err.to_string(),
598 })
599 }
600 }
601 let adapter_client = adapter_client.startup(session).await?;
602 Ok(AuthedClient {
603 client: adapter_client,
604 connection_guard,
605 })
606 }
607}
608
609#[async_trait]
610impl<S> FromRequestParts<S> for AuthedClient
611where
612 S: Send + Sync,
613{
614 type Rejection = Response;
615
616 async fn from_request_parts(
617 req: &mut http::request::Parts,
618 state: &S,
619 ) -> Result<Self, Self::Rejection> {
620 #[derive(Debug, Default, Deserialize)]
621 struct Params {
622 #[serde(default)]
623 options: String,
624 }
625 let params: Query<Params> = Query::from_request_parts(req, state)
626 .await
627 .unwrap_or_default();
628
629 let peer_addr = req
630 .extensions
631 .get::<ConnectInfo<SocketAddr>>()
632 .expect("ConnectInfo extension guaranteed to exist")
633 .0
634 .ip();
635
636 let user = req.extensions.get::<AuthedUser>().unwrap();
637 let adapter_client = req
638 .extensions
639 .get::<Delayed<mz_adapter::Client>>()
640 .unwrap()
641 .clone();
642 let adapter_client = adapter_client.await.map_err(|_| {
643 (StatusCode::INTERNAL_SERVER_ERROR, "adapter client missing").into_response()
644 })?;
645 let active_connection_counter = req.extensions.get::<ConnectionCounter>().unwrap();
646 let helm_chart_version = None;
647
648 let options = if params.options.is_empty() {
649 BTreeMap::<String, String>::default()
652 } else {
653 match serde_json::from_str(¶ms.options) {
654 Ok(options) => options,
655 Err(_e) => {
656 let code = StatusCode::BAD_REQUEST;
658 let msg = format!("Failed to deserialize {} map", "options".quoted());
659 return Err((code, msg).into_response());
660 }
661 }
662 };
663
664 let client = AuthedClient::new(
665 &adapter_client,
666 user.clone(),
667 peer_addr,
668 active_connection_counter.clone(),
669 helm_chart_version,
670 |session| {
671 session
672 .vars_mut()
673 .set_default(WELCOME_MESSAGE.name(), VarInput::Flat(&false.format()))
674 .expect("known to exist")
675 },
676 options,
677 SYSTEM_TIME.clone(),
678 )
679 .await
680 .map_err(|e| {
681 let status = match e {
682 AdapterError::UserSessionsDisallowed | AdapterError::NetworkPolicyDenied(_) => {
683 StatusCode::FORBIDDEN
684 }
685 _ => StatusCode::INTERNAL_SERVER_ERROR,
686 };
687 (status, Json(SqlError::from(e))).into_response()
688 })?;
689
690 Ok(client)
691 }
692}
693
694#[derive(Debug, Error)]
695enum AuthError {
696 #[error("role dissallowed")]
697 RoleDisallowed(String),
698 #[error("{0}")]
699 Frontegg(#[from] FronteggError),
700 #[error("missing authorization header")]
701 MissingHttpAuthentication {
702 include_www_authenticate_header: bool,
703 },
704 #[error("{0}")]
705 MismatchedUser(String),
706 #[error("session expired")]
707 SessionExpired,
708 #[error("failed to update session")]
709 FailedToUpdateSession,
710 #[error("invalid credentials")]
711 InvalidCredentials,
712}
713
714impl IntoResponse for AuthError {
715 fn into_response(self) -> Response {
716 warn!("HTTP request failed authentication: {}", self);
717 let mut headers = HeaderMap::new();
718 match self {
719 AuthError::MissingHttpAuthentication {
720 include_www_authenticate_header,
721 } if include_www_authenticate_header => {
722 headers.insert(
723 http::header::WWW_AUTHENTICATE,
724 HeaderValue::from_static("Basic realm=Materialize"),
725 );
726 }
727 _ => {}
728 };
729 (StatusCode::UNAUTHORIZED, headers, "unauthorized").into_response()
732 }
733}
734
735pub async fn handle_login(
737 session: Option<Extension<TowerSession>>,
738 Extension(adapter_client_rx): Extension<Delayed<Client>>,
739 Json(LoginCredentials { username, password }): Json<LoginCredentials>,
740) -> impl IntoResponse {
741 let Ok(adapter_client) = adapter_client_rx.clone().await else {
742 return StatusCode::INTERNAL_SERVER_ERROR;
743 };
744 if let Err(err) = adapter_client.authenticate(&username, &password).await {
745 warn!(?err, "HTTP login failed authentication");
746 return StatusCode::UNAUTHORIZED;
747 };
748
749 let session_data = TowerSessionData {
751 username,
752 created_at: SystemTime::now(),
753 last_activity: SystemTime::now(),
754 };
755 let session = session.and_then(|Extension(session)| Some(session));
757 let Some(session) = session else {
758 return StatusCode::INTERNAL_SERVER_ERROR;
759 };
760 match session.insert("data", &session_data).await {
761 Err(_) => StatusCode::INTERNAL_SERVER_ERROR,
762 Ok(_) => StatusCode::OK,
763 }
764}
765
766pub async fn handle_logout(session: Option<Extension<TowerSession>>) -> impl IntoResponse {
768 let session = session.and_then(|Extension(session)| Some(session));
769 let Some(session) = session else {
770 return StatusCode::INTERNAL_SERVER_ERROR;
771 };
772 match session.delete().await {
774 Err(_) => StatusCode::INTERNAL_SERVER_ERROR,
775 Ok(_) => StatusCode::OK,
776 }
777}
778
779async fn http_auth(
780 mut req: Request,
781 next: Next,
782 tls_enabled: bool,
783 authenticator: Arc<Authenticator>,
784 allowed_roles: AllowedRoles,
785) -> impl IntoResponse + use<> {
786 if let Some(session) = req.extensions().get::<TowerSession>() {
788 if let Ok(Some(session_data)) = session.get::<TowerSessionData>("data").await {
789 if session_data
791 .last_activity
792 .elapsed()
793 .unwrap_or(Duration::MAX)
794 > SESSION_DURATION
795 {
796 let _ = session.delete().await;
797 return Err(AuthError::SessionExpired);
798 }
799 let mut updated_data = session_data.clone();
801 updated_data.last_activity = SystemTime::now();
802 session
803 .insert("data", &updated_data)
804 .await
805 .map_err(|_| AuthError::FailedToUpdateSession)?;
806 req.extensions_mut().insert(AuthedUser {
808 name: session_data.username,
809 external_metadata_rx: None,
810 });
811 return Ok(next.run(req).await);
812 }
813 }
814
815 let conn_protocol = req.extensions().get::<ConnProtocol>().unwrap();
819 match (tls_enabled, &conn_protocol) {
820 (false, ConnProtocol::Http) => {}
821 (false, ConnProtocol::Https { .. }) => unreachable!(),
822 (true, ConnProtocol::Http) => {
823 let mut parts = req.uri().clone().into_parts();
824 parts.scheme = Some(Scheme::HTTPS);
825 return Ok(Redirect::permanent(
826 &Uri::from_parts(parts)
827 .expect("it was already a URI, just changed the scheme")
828 .to_string(),
829 )
830 .into_response());
831 }
832 (true, ConnProtocol::Https { .. }) => {}
833 }
834 if req.extensions().get::<AuthedUser>().is_some() {
836 return Ok(next.run(req).await);
837 }
838 let creds = if let Some(basic) = req.headers().typed_get::<Authorization<Basic>>() {
839 Some(Credentials::Password {
840 username: basic.username().to_owned(),
841 password: Password(basic.password().to_owned()),
842 })
843 } else if let Some(bearer) = req.headers().typed_get::<Authorization<Bearer>>() {
844 Some(Credentials::Token {
845 token: bearer.token().to_owned(),
846 })
847 } else {
848 None
849 };
850
851 let path = req.uri().path();
852 let include_www_authenticate_header = path == "/"
853 || PROFILING_API_ENDPOINTS
854 .iter()
855 .any(|prefix| path.starts_with(prefix));
856 let user = auth(
857 &authenticator,
858 creds,
859 allowed_roles,
860 include_www_authenticate_header,
861 )
862 .await?;
863
864 req.extensions_mut().insert(user);
867
868 Ok(next.run(req).await)
870}
871
872async fn init_ws(
873 WsState {
874 authenticator_rx,
875 adapter_client_rx,
876 active_connection_counter,
877 helm_chart_version,
878 allowed_roles,
879 }: &WsState,
880 existing_user: Option<AuthedUser>,
881 peer_addr: IpAddr,
882 ws: &mut WebSocket,
883) -> Result<AuthedClient, anyhow::Error> {
884 let authenticator = authenticator_rx.clone().await.expect("sender not dropped");
885 let ws_auth: WebSocketAuth = loop {
888 let init_msg = ws.recv().await.ok_or_else(|| anyhow::anyhow!("closed"))??;
889 match init_msg {
890 Message::Text(data) => break serde_json::from_str(&data)?,
891 Message::Binary(data) => break serde_json::from_slice(&data)?,
892 Message::Ping(_) => {
894 continue;
895 }
896 Message::Pong(_) => {
897 continue;
898 }
899 Message::Close(_) => {
900 anyhow::bail!("closed");
901 }
902 }
903 };
904
905 let (user, options) = if let Some(existing_user) = existing_user {
906 match ws_auth {
907 WebSocketAuth::OptionsOnly { options } => (existing_user, options),
908 _ => {
909 warn!("Unexpected bearer or basic auth provided when using user header");
910 anyhow::bail!("unexpected")
911 }
912 }
913 } else {
914 let (creds, options) = match ws_auth {
915 WebSocketAuth::Basic {
916 user,
917 password,
918 options,
919 } => {
920 let creds = Credentials::Password {
921 username: user,
922 password,
923 };
924 (creds, options)
925 }
926 WebSocketAuth::Bearer { token, options } => {
927 let creds = Credentials::Token { token };
928 (creds, options)
929 }
930 WebSocketAuth::OptionsOnly { .. } => {
931 anyhow::bail!("expected auth information");
932 }
933 };
934 let user = auth(&authenticator, Some(creds), *allowed_roles, false).await?;
935 (user, options)
936 };
937
938 let client = AuthedClient::new(
939 &adapter_client_rx.clone().await?,
940 user,
941 peer_addr,
942 active_connection_counter.clone(),
943 helm_chart_version.clone(),
944 |_session| (),
945 options,
946 SYSTEM_TIME.clone(),
947 )
948 .await?;
949
950 Ok(client)
951}
952
953enum Credentials {
954 Password {
955 username: String,
956 password: Password,
957 },
958 Token {
959 token: String,
960 },
961}
962
963async fn auth(
964 authenticator: &Authenticator,
965 creds: Option<Credentials>,
966 allowed_roles: AllowedRoles,
967 include_www_authenticate_header: bool,
968) -> Result<AuthedUser, AuthError> {
969 let (name, external_metadata_rx) = match authenticator {
971 Authenticator::Frontegg(frontegg) => match creds {
972 Some(Credentials::Password { username, password }) => {
973 let auth_session = frontegg.authenticate(&username, &password.0).await?;
974 let name = auth_session.user().into();
975 let external_metadata_rx = Some(auth_session.external_metadata_rx());
976 (name, external_metadata_rx)
977 }
978 Some(Credentials::Token { token }) => {
979 let claims = frontegg.validate_access_token(&token, None)?;
980 let (_, external_metadata_rx) = watch::channel(ExternalUserMetadata {
981 user_id: claims.user_id,
982 admin: claims.is_admin,
983 });
984 (claims.user, Some(external_metadata_rx))
985 }
986 None => {
987 return Err(AuthError::MissingHttpAuthentication {
988 include_www_authenticate_header,
989 });
990 }
991 },
992 Authenticator::Password(adapter_client) => match creds {
993 Some(Credentials::Password { username, password }) => {
994 if let Err(_) = adapter_client.authenticate(&username, &password).await {
995 return Err(AuthError::InvalidCredentials);
996 }
997 (username, None)
998 }
999 _ => {
1000 return Err(AuthError::MissingHttpAuthentication {
1001 include_www_authenticate_header,
1002 });
1003 }
1004 },
1005 Authenticator::None => {
1006 let name = match creds {
1010 Some(Credentials::Password { username, .. }) => username,
1011 _ => HTTP_DEFAULT_USER.name.to_owned(),
1012 };
1013 (name, None)
1014 }
1015 };
1016
1017 check_role_allowed(&name, allowed_roles)?;
1018
1019 Ok(AuthedUser {
1020 name,
1021 external_metadata_rx,
1022 })
1023}
1024
1025fn check_role_allowed(name: &str, allowed_roles: AllowedRoles) -> Result<(), AuthError> {
1027 let is_internal_user = INTERNAL_USER_NAMES.contains(name);
1028 let is_reserved_user = mz_adapter::catalog::is_reserved_role_name(name);
1030 let role_allowed = match allowed_roles {
1031 AllowedRoles::Normal => !is_reserved_user,
1032 AllowedRoles::Internal => is_internal_user,
1033 AllowedRoles::NormalAndInternal => !is_reserved_user || is_internal_user,
1034 };
1035 if role_allowed {
1036 Ok(())
1037 } else {
1038 Err(AuthError::RoleDisallowed(name.to_owned()))
1039 }
1040}
1041
1042trait DefaultLayers {
1045 fn apply_default_layers(self, source: &'static str, metrics: Metrics) -> Self;
1046}
1047
1048impl DefaultLayers for Router {
1049 fn apply_default_layers(self, source: &'static str, metrics: Metrics) -> Self {
1050 self.layer(DefaultBodyLimit::max(MAX_REQUEST_SIZE))
1051 .layer(metrics::PrometheusLayer::new(source, metrics))
1052 }
1053}
1054
1055async fn handle_load_error(error: tower::BoxError) -> impl IntoResponse {
1062 if error.is::<tower::load_shed::error::Overloaded>() {
1063 return (
1064 StatusCode::TOO_MANY_REQUESTS,
1065 Cow::from("too many requests, try again later"),
1066 );
1067 }
1068
1069 (
1072 StatusCode::INTERNAL_SERVER_ERROR,
1073 Cow::from(format!("Unhandled internal error: {}", error)),
1074 )
1075}
1076
1077#[derive(Debug, Deserialize, Serialize, PartialEq)]
1078pub struct LoginCredentials {
1079 username: String,
1080 password: Password,
1081}
1082
1083#[derive(Debug, Clone, Serialize, Deserialize)]
1084pub struct TowerSessionData {
1085 username: String,
1086 created_at: SystemTime,
1087 last_activity: SystemTime,
1088}
1089
#[cfg(test)]
mod tests {
    use super::{AllowedRoles, check_role_allowed};

    #[mz_ore::test]
    fn test_check_role_allowed() {
        // (role, allowed under Internal, under NormalAndInternal, under Normal)
        let cases: &[(&str, bool, bool, bool)] = &[
            ("mz_system", true, true, false),
            ("mz_support", true, true, false),
            ("mz_analytics", true, true, false),
            ("materialize", false, true, true),
            ("anonymous_http_user", false, true, true),
            ("alex", false, true, true),
            ("external_asdf", false, false, false),
            ("pg_somebody", false, false, false),
            ("mz_unknown", false, false, false),
            ("PUBLIC", false, false, false),
        ];
        for &(role, internal, normal_and_internal, normal) in cases {
            let expectations = [
                ("Internal", AllowedRoles::Internal, internal),
                (
                    "NormalAndInternal",
                    AllowedRoles::NormalAndInternal,
                    normal_and_internal,
                ),
                ("Normal", AllowedRoles::Normal, normal),
            ];
            for (label, allowed_roles, expected) in expectations {
                assert_eq!(
                    check_role_allowed(role, allowed_roles).is_ok(),
                    expected,
                    "role {role} under {label}"
                );
            }
        }
    }
}