1#![allow(clippy::unused_async)]
18
19use std::borrow::Cow;
20use std::collections::BTreeMap;
21use std::fmt::Debug;
22use std::net::{IpAddr, SocketAddr};
23use std::pin::Pin;
24use std::sync::Arc;
25use std::time::{Duration, SystemTime};
26
27use anyhow::Context;
28use async_trait::async_trait;
29use axum::error_handling::HandleErrorLayer;
30use axum::extract::ws::{Message, WebSocket};
31use axum::extract::{ConnectInfo, DefaultBodyLimit, FromRequestParts, Query, Request, State};
32use axum::middleware::{self, Next};
33use axum::response::{IntoResponse, Redirect, Response};
34use axum::{Extension, Json, Router, routing};
35use futures::future::{Shared, TryFutureExt};
36use headers::authorization::{Authorization, Basic, Bearer};
37use headers::{HeaderMapExt, HeaderName};
38use http::header::{AUTHORIZATION, CONTENT_TYPE};
39use http::{Method, StatusCode};
40use hyper_openssl::SslStream;
41use hyper_openssl::client::legacy::MaybeHttpsStream;
42use hyper_util::rt::TokioIo;
43use mz_adapter::session::{Session as AdapterSession, SessionConfig as AdapterSessionConfig};
44use mz_adapter::{AdapterError, AdapterNotice, Client, SessionClient, WebhookAppenderCache};
45use mz_auth::password::Password;
46use mz_authenticator::Authenticator;
47use mz_frontegg_auth::Error as FronteggError;
48use mz_http_util::DynamicFilterTarget;
49use mz_ore::cast::u64_to_usize;
50use mz_ore::metrics::MetricsRegistry;
51use mz_ore::now::{NowFn, SYSTEM_TIME, epoch_to_uuid_v7};
52use mz_ore::str::StrExt;
53use mz_pgwire_common::{ConnectionCounter, ConnectionHandle};
54use mz_repr::user::ExternalUserMetadata;
55use mz_server_core::listeners::{AllowedRoles, AuthenticatorKind, HttpRoutesEnabled};
56use mz_server_core::{Connection, ConnectionHandler, ReloadingSslContext, Server};
57use mz_sql::session::metadata::SessionMetadata;
58use mz_sql::session::user::{
59 HTTP_DEFAULT_USER, INTERNAL_USER_NAMES, SUPPORT_USER_NAME, SYSTEM_USER_NAME,
60};
61use mz_sql::session::vars::{Value, Var, VarInput, WELCOME_MESSAGE};
62use openssl::ssl::Ssl;
63use prometheus::{
64 COMPUTE_METRIC_QUERIES, FRONTIER_METRIC_QUERIES, STORAGE_METRIC_QUERIES, USAGE_METRIC_QUERIES,
65};
66use serde::{Deserialize, Serialize};
67use serde_json::json;
68use thiserror::Error;
69use tokio::io::AsyncWriteExt;
70use tokio::sync::oneshot::Receiver;
71use tokio::sync::{oneshot, watch};
72use tower::limit::GlobalConcurrencyLimitLayer;
73use tower::{Service, ServiceBuilder};
74use tower_http::cors::{AllowOrigin, Any, CorsLayer};
75use tower_sessions::{
76 MemoryStore as TowerSessionMemoryStore, Session as TowerSession,
77 SessionManagerLayer as TowerSessionManagerLayer,
78};
79use tracing::{error, warn};
80
81use crate::BUILD_INFO;
82use crate::deployment::state::DeploymentStateHandle;
83use crate::http::sql::SqlError;
84
85mod catalog;
86mod console;
87mod memory;
88mod metrics;
89mod probe;
90mod prometheus;
91mod root;
92mod sql;
93mod webhook;
94
95pub use metrics::Metrics;
96pub use sql::{SqlResponse, WebSocketAuth, WebSocketResponse};
97
98pub const MAX_REQUEST_SIZE: usize = u64_to_usize(5 * bytesize::MIB);
100
101const SESSION_DURATION: Duration = Duration::from_secs(3600); #[derive(Debug)]
104pub struct HttpConfig {
105 pub source: &'static str,
106 pub tls: Option<ReloadingSslContext>,
107 pub authenticator_kind: AuthenticatorKind,
108 pub authenticator_rx: Shared<Receiver<Arc<Authenticator>>>,
109 pub adapter_client_rx: Shared<Receiver<Client>>,
110 pub allowed_origin: AllowOrigin,
111 pub active_connection_counter: ConnectionCounter,
112 pub helm_chart_version: Option<String>,
113 pub concurrent_webhook_req: Arc<tokio::sync::Semaphore>,
114 pub metrics: Metrics,
115 pub metrics_registry: MetricsRegistry,
116 pub allowed_roles: AllowedRoles,
117 pub internal_route_config: Arc<InternalRouteConfig>,
118 pub routes_enabled: HttpRoutesEnabled,
119}
120
121#[derive(Debug, Clone)]
122pub struct InternalRouteConfig {
123 pub deployment_state_handle: DeploymentStateHandle,
124 pub internal_console_redirect_url: Option<String>,
125}
126
127#[derive(Clone)]
128pub struct WsState {
129 authenticator_rx: Delayed<Arc<Authenticator>>,
130 adapter_client_rx: Delayed<mz_adapter::Client>,
131 active_connection_counter: ConnectionCounter,
132 helm_chart_version: Option<String>,
133 allowed_roles: AllowedRoles,
134}
135
136#[derive(Clone)]
137pub struct WebhookState {
138 adapter_client_rx: Delayed<mz_adapter::Client>,
139 webhook_cache: WebhookAppenderCache,
140}
141
142#[derive(Debug)]
143pub struct HttpServer {
144 tls: Option<ReloadingSslContext>,
145 router: Router,
146}
147
148impl HttpServer {
149 pub fn new(
150 HttpConfig {
151 source,
152 tls,
153 authenticator_kind,
154 authenticator_rx,
155 adapter_client_rx,
156 allowed_origin,
157 active_connection_counter,
158 helm_chart_version,
159 concurrent_webhook_req,
160 metrics,
161 metrics_registry,
162 allowed_roles,
163 internal_route_config,
164 routes_enabled,
165 }: HttpConfig,
166 ) -> HttpServer {
167 let tls_enabled = tls.is_some();
168 let webhook_cache = WebhookAppenderCache::new();
169
170 let session_store = TowerSessionMemoryStore::default();
172 let session_layer = TowerSessionManagerLayer::new(session_store)
173 .with_secure(tls_enabled) .with_same_site(tower_sessions::cookie::SameSite::Strict) .with_http_only(true) .with_name("mz_session") .with_path("/"); let auth_middleware_authenticator_rx = authenticator_rx.clone();
180 let auth_middleware = middleware::from_fn(move |req, next| {
181 let authenticator_rx = auth_middleware_authenticator_rx.clone();
182 async move {
183 let authenticator = authenticator_rx
184 .await
185 .expect("sender not dropped before sending once");
186 http_auth(req, next, tls_enabled, authenticator, allowed_roles).await
187 }
188 });
189
190 let mut router = Router::new();
191 let mut base_router = Router::new();
192 if routes_enabled.base {
193 base_router = base_router
194 .route(
195 "/",
196 routing::get(move || async move {
197 root::handle_home(routes_enabled.profiling).await
198 }),
199 )
200 .route("/api/sql", routing::post(sql::handle_sql))
201 .route("/memory", routing::get(memory::handle_memory))
202 .route(
203 "/hierarchical-memory",
204 routing::get(memory::handle_hierarchical_memory),
205 )
206 .route("/static/*path", routing::get(root::handle_static));
207
208 let mut ws_router = Router::new()
209 .route("/api/experimental/sql", routing::get(sql::handle_sql_ws))
210 .with_state(WsState {
211 authenticator_rx: authenticator_rx.clone(),
212 adapter_client_rx: adapter_client_rx.clone(),
213 active_connection_counter: active_connection_counter.clone(),
214 helm_chart_version,
215 allowed_roles,
216 });
217 if let AuthenticatorKind::None = authenticator_kind {
218 ws_router = ws_router.layer(middleware::from_fn(x_materialize_user_header_auth));
219 }
220 router = router.merge(ws_router);
221 }
222 if routes_enabled.profiling {
223 base_router = base_router.nest("/prof/", mz_prof_http::router(&BUILD_INFO));
224 }
225
226 if routes_enabled.webhook {
227 let webhook_router = Router::new()
228 .route(
229 "/api/webhook/:database/:schema/:id",
230 routing::post(webhook::handle_webhook),
231 )
232 .with_state(WebhookState {
233 adapter_client_rx: adapter_client_rx.clone(),
234 webhook_cache,
235 })
236 .layer(
237 tower_http::decompression::RequestDecompressionLayer::new()
238 .gzip(true)
239 .deflate(true)
240 .br(true)
241 .zstd(true),
242 )
243 .layer(
244 CorsLayer::new()
245 .allow_methods(Method::POST)
246 .allow_origin(AllowOrigin::mirror_request())
247 .allow_headers(Any),
248 )
249 .layer(
250 ServiceBuilder::new()
251 .layer(HandleErrorLayer::new(handle_load_error))
252 .load_shed()
253 .layer(GlobalConcurrencyLimitLayer::with_semaphore(
254 concurrent_webhook_req,
255 )),
256 );
257 router = router.merge(webhook_router);
258 }
259
260 if routes_enabled.internal {
261 let console_config = Arc::new(console::ConsoleProxyConfig::new(
262 internal_route_config.internal_console_redirect_url.clone(),
263 "/internal-console".to_string(),
264 ));
265 base_router = base_router
266 .route(
267 "/api/opentelemetry/config",
268 routing::put({
269 move |_: axum::Json<DynamicFilterTarget>| async {
270 (
271 StatusCode::BAD_REQUEST,
272 "This endpoint has been replaced. \
273 Use the `opentelemetry_filter` system variable."
274 .to_string(),
275 )
276 }
277 }),
278 )
279 .route(
280 "/api/stderr/config",
281 routing::put({
282 move |_: axum::Json<DynamicFilterTarget>| async {
283 (
284 StatusCode::BAD_REQUEST,
285 "This endpoint has been replaced. \
286 Use the `log_filter` system variable."
287 .to_string(),
288 )
289 }
290 }),
291 )
292 .route("/api/tracing", routing::get(mz_http_util::handle_tracing))
293 .route(
294 "/api/catalog/dump",
295 routing::get(catalog::handle_catalog_dump),
296 )
297 .route(
298 "/api/catalog/check",
299 routing::get(catalog::handle_catalog_check),
300 )
301 .route(
302 "/api/coordinator/check",
303 routing::get(catalog::handle_coordinator_check),
304 )
305 .route(
306 "/api/coordinator/dump",
307 routing::get(catalog::handle_coordinator_dump),
308 )
309 .route(
310 "/internal-console",
311 routing::get(|| async { Redirect::temporary("/internal-console/") }),
312 )
313 .route(
314 "/internal-console/*path",
315 routing::get(console::handle_internal_console),
316 )
317 .route(
318 "/internal-console/",
319 routing::get(console::handle_internal_console),
320 )
321 .layer(Extension(console_config));
322 let leader_router = Router::new()
323 .route("/api/leader/status", routing::get(handle_leader_status))
324 .route("/api/leader/promote", routing::post(handle_leader_promote))
325 .route(
326 "/api/leader/skip-catchup",
327 routing::post(handle_leader_skip_catchup),
328 )
329 .layer(auth_middleware.clone())
330 .with_state(internal_route_config.deployment_state_handle.clone());
331 router = router.merge(leader_router);
332 }
333
334 if routes_enabled.metrics {
335 let metrics_router = Router::new()
336 .route(
337 "/metrics",
338 routing::get(move || async move {
339 mz_http_util::handle_prometheus(&metrics_registry).await
340 }),
341 )
342 .route(
343 "/metrics/mz_usage",
344 routing::get(|client: AuthedClient| async move {
345 let registry = sql::handle_promsql(client, USAGE_METRIC_QUERIES).await;
346 mz_http_util::handle_prometheus(®istry).await
347 }),
348 )
349 .route(
350 "/metrics/mz_frontier",
351 routing::get(|client: AuthedClient| async move {
352 let registry = sql::handle_promsql(client, FRONTIER_METRIC_QUERIES).await;
353 mz_http_util::handle_prometheus(®istry).await
354 }),
355 )
356 .route(
357 "/metrics/mz_compute",
358 routing::get(|client: AuthedClient| async move {
359 let registry = sql::handle_promsql(client, COMPUTE_METRIC_QUERIES).await;
360 mz_http_util::handle_prometheus(®istry).await
361 }),
362 )
363 .route(
364 "/metrics/mz_storage",
365 routing::get(|client: AuthedClient| async move {
366 let registry = sql::handle_promsql(client, STORAGE_METRIC_QUERIES).await;
367 mz_http_util::handle_prometheus(®istry).await
368 }),
369 )
370 .route(
371 "/api/livez",
372 routing::get(mz_http_util::handle_liveness_check),
373 )
374 .route("/api/readyz", routing::get(probe::handle_ready))
375 .layer(auth_middleware.clone())
376 .layer(Extension(adapter_client_rx.clone()))
377 .layer(Extension(active_connection_counter.clone()));
378 router = router.merge(metrics_router);
379 }
380
381 base_router = base_router
382 .layer(auth_middleware.clone())
383 .layer(Extension(adapter_client_rx.clone()))
384 .layer(Extension(active_connection_counter.clone()))
385 .layer(
386 CorsLayer::new()
387 .allow_credentials(false)
388 .allow_headers([
389 AUTHORIZATION,
390 CONTENT_TYPE,
391 HeaderName::from_static("x-materialize-version"),
392 ])
393 .allow_methods(Any)
394 .allow_origin(allowed_origin)
395 .expose_headers(Any)
396 .max_age(Duration::from_secs(60) * 60),
397 );
398
399 match authenticator_kind {
400 AuthenticatorKind::Password => {
401 base_router = base_router.layer(session_layer.clone());
402
403 let login_router = Router::new()
404 .route("/api/login", routing::post(handle_login))
405 .route("/api/logout", routing::post(handle_logout))
406 .layer(Extension(adapter_client_rx));
407 router = router.merge(login_router).layer(session_layer);
408 }
409 AuthenticatorKind::None => {
410 base_router =
411 base_router.layer(middleware::from_fn(x_materialize_user_header_auth));
412 }
413 _ => {}
414 }
415
416 router = router
417 .merge(base_router)
418 .apply_default_layers(source, metrics);
419
420 HttpServer { tls, router }
421 }
422}
423
424impl Server for HttpServer {
425 const NAME: &'static str = "http";
426
427 fn handle_connection(&self, conn: Connection) -> ConnectionHandler {
428 let router = self.router.clone();
429 let tls_context = self.tls.clone();
430 let mut conn = TokioIo::new(conn);
431
432 Box::pin(async {
433 let direct_peer_addr = conn.inner().peer_addr().context("fetching peer addr")?;
434 let peer_addr = conn
435 .inner_mut()
436 .take_proxy_header_address()
437 .await
438 .map(|a| a.source)
439 .unwrap_or(direct_peer_addr);
440
441 let (conn, conn_protocol) = match tls_context {
442 Some(tls_context) => {
443 let mut ssl_stream = SslStream::new(Ssl::new(&tls_context.get())?, conn)?;
444 if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
445 let _ = ssl_stream.get_mut().inner_mut().shutdown().await;
446 return Err(e.into());
447 }
448 (MaybeHttpsStream::Https(ssl_stream), ConnProtocol::Https)
449 }
450 _ => (MaybeHttpsStream::Http(conn), ConnProtocol::Http),
451 };
452 let mut make_tower_svc = router
453 .layer(Extension(conn_protocol))
454 .into_make_service_with_connect_info::<SocketAddr>();
455 let tower_svc = make_tower_svc.call(peer_addr).await.unwrap();
456 let hyper_svc = hyper::service::service_fn(|req| tower_svc.clone().call(req));
457 let http = hyper::server::conn::http1::Builder::new();
458 http.serve_connection(conn, hyper_svc)
459 .with_upgrades()
460 .err_into()
461 .await
462 })
463 }
464}
465
466pub async fn handle_leader_status(
467 State(deployment_state_handle): State<DeploymentStateHandle>,
468) -> impl IntoResponse {
469 let status = deployment_state_handle.status();
470 (StatusCode::OK, Json(json!({ "status": status })))
471}
472
473pub async fn handle_leader_promote(
474 State(deployment_state_handle): State<DeploymentStateHandle>,
475) -> impl IntoResponse {
476 match deployment_state_handle.try_promote() {
477 Ok(()) => {
478 let status = StatusCode::OK;
481 let body = Json(json!({
482 "result": "Success",
483 }));
484 (status, body)
485 }
486 Err(()) => {
487 let status = StatusCode::BAD_REQUEST;
490 let body = Json(json!({
491 "result": {"Failure": {"message": "cannot promote leader while initializing"}},
492 }));
493 (status, body)
494 }
495 }
496}
497
498pub async fn handle_leader_skip_catchup(
499 State(deployment_state_handle): State<DeploymentStateHandle>,
500) -> impl IntoResponse {
501 match deployment_state_handle.try_skip_catchup() {
502 Ok(()) => StatusCode::NO_CONTENT.into_response(),
503 Err(()) => {
504 let status = StatusCode::BAD_REQUEST;
505 let body = Json(json!({
506 "message": "cannot skip catchup in this phase of initialization; try again later",
507 }));
508 (status, body).into_response()
509 }
510 }
511}
512
513async fn x_materialize_user_header_auth(mut req: Request, next: Next) -> impl IntoResponse {
514 if let Some(username) = req.headers().get("x-materialize-user").map(|h| h.to_str()) {
516 let username = match username {
517 Ok(name @ (SUPPORT_USER_NAME | SYSTEM_USER_NAME)) => name.to_string(),
518 _ => {
519 return Err(AuthError::MismatchedUser(format!(
520 "user specified in x-materialize-user must be {SUPPORT_USER_NAME} or {SYSTEM_USER_NAME}"
521 )));
522 }
523 };
524 req.extensions_mut().insert(AuthedUser {
525 name: username,
526 external_metadata_rx: None,
527 });
528 }
529 Ok(next.run(req).await)
530}
531
532type Delayed<T> = Shared<oneshot::Receiver<T>>;
533
/// Transport a connection arrived on; inserted as a request extension by
/// `handle_connection` and checked by `http_auth`.
#[derive(Clone)]
enum ConnProtocol {
    /// Plain-text HTTP.
    Http,
    /// HTTP over a completed TLS handshake.
    Https,
}
539
540#[derive(Clone, Debug)]
541pub struct AuthedUser {
542 name: String,
543 external_metadata_rx: Option<watch::Receiver<ExternalUserMetadata>>,
544}
545
546pub struct AuthedClient {
547 pub client: SessionClient,
548 pub connection_guard: Option<ConnectionHandle>,
549}
550
551impl AuthedClient {
552 async fn new<F>(
553 adapter_client: &Client,
554 user: AuthedUser,
555 peer_addr: IpAddr,
556 active_connection_counter: ConnectionCounter,
557 helm_chart_version: Option<String>,
558 session_config: F,
559 options: BTreeMap<String, String>,
560 now: NowFn,
561 ) -> Result<Self, AdapterError>
562 where
563 F: FnOnce(&mut AdapterSession),
564 {
565 let conn_id = adapter_client.new_conn_id()?;
566 let mut session = adapter_client.new_session(AdapterSessionConfig {
567 conn_id,
568 uuid: epoch_to_uuid_v7(&(now)()),
569 user: user.name,
570 client_ip: Some(peer_addr),
571 external_metadata_rx: user.external_metadata_rx,
572 internal_user_metadata: None,
574 helm_chart_version,
575 });
576 let connection_guard = active_connection_counter.allocate_connection(session.user())?;
577
578 session_config(&mut session);
579 let system_vars = adapter_client.get_system_vars().await;
580 for (key, val) in options {
581 const LOCAL: bool = false;
582 if let Err(err) =
583 session
584 .vars_mut()
585 .set(&system_vars, &key, VarInput::Flat(&val), LOCAL)
586 {
587 session.add_notice(AdapterNotice::BadStartupSetting {
588 name: key.to_string(),
589 reason: err.to_string(),
590 })
591 }
592 }
593 let adapter_client = adapter_client.startup(session).await?;
594 Ok(AuthedClient {
595 client: adapter_client,
596 connection_guard,
597 })
598 }
599}
600
601#[async_trait]
602impl<S> FromRequestParts<S> for AuthedClient
603where
604 S: Send + Sync,
605{
606 type Rejection = Response;
607
608 async fn from_request_parts(
609 req: &mut http::request::Parts,
610 state: &S,
611 ) -> Result<Self, Self::Rejection> {
612 #[derive(Debug, Default, Deserialize)]
613 struct Params {
614 #[serde(default)]
615 options: String,
616 }
617 let params: Query<Params> = Query::from_request_parts(req, state)
618 .await
619 .unwrap_or_default();
620
621 let peer_addr = req
622 .extensions
623 .get::<ConnectInfo<SocketAddr>>()
624 .expect("ConnectInfo extension guaranteed to exist")
625 .0
626 .ip();
627
628 let user = req.extensions.get::<AuthedUser>().unwrap();
629 let adapter_client = req
630 .extensions
631 .get::<Delayed<mz_adapter::Client>>()
632 .unwrap()
633 .clone();
634 let adapter_client = adapter_client.await.map_err(|_| {
635 (StatusCode::INTERNAL_SERVER_ERROR, "adapter client missing").into_response()
636 })?;
637 let active_connection_counter = req.extensions.get::<ConnectionCounter>().unwrap();
638 let helm_chart_version = None;
639
640 let options = if params.options.is_empty() {
641 BTreeMap::<String, String>::default()
644 } else {
645 match serde_json::from_str(¶ms.options) {
646 Ok(options) => options,
647 Err(_e) => {
648 let code = StatusCode::BAD_REQUEST;
650 let msg = format!("Failed to deserialize {} map", "options".quoted());
651 return Err((code, msg).into_response());
652 }
653 }
654 };
655
656 let client = AuthedClient::new(
657 &adapter_client,
658 user.clone(),
659 peer_addr,
660 active_connection_counter.clone(),
661 helm_chart_version,
662 |session| {
663 session
664 .vars_mut()
665 .set_default(WELCOME_MESSAGE.name(), VarInput::Flat(&false.format()))
666 .expect("known to exist")
667 },
668 options,
669 SYSTEM_TIME.clone(),
670 )
671 .await
672 .map_err(|e| {
673 let status = match e {
674 AdapterError::UserSessionsDisallowed | AdapterError::NetworkPolicyDenied(_) => {
675 StatusCode::FORBIDDEN
676 }
677 _ => StatusCode::INTERNAL_SERVER_ERROR,
678 };
679 (status, Json(SqlError::from(e))).into_response()
680 })?;
681
682 Ok(client)
683 }
684}
685
686#[derive(Debug, Error)]
687enum AuthError {
688 #[error("HTTPS is required")]
689 HttpsRequired,
690 #[error("invalid username in client certificate")]
691 InvalidLogin(String),
692 #[error("{0}")]
693 Frontegg(#[from] FronteggError),
694 #[error("missing authorization header")]
695 MissingHttpAuthentication,
696 #[error("{0}")]
697 MismatchedUser(String),
698 #[error("session expired")]
699 SessionExpired,
700 #[error("failed to update session")]
701 FailedToUpdateSession,
702}
703
704impl IntoResponse for AuthError {
705 fn into_response(self) -> Response {
706 warn!("HTTP request failed authentication: {}", self);
707 let message = match self {
710 AuthError::HttpsRequired => self.to_string(),
711 _ => "unauthorized".into(),
712 };
713 (
714 StatusCode::UNAUTHORIZED,
715 [(http::header::WWW_AUTHENTICATE, "Basic realm=Materialize")],
716 message,
717 )
718 .into_response()
719 }
720}
721
722pub async fn handle_login(
724 session: Option<Extension<TowerSession>>,
725 Extension(adapter_client_rx): Extension<Delayed<Client>>,
726 Json(LoginCredentials { username, password }): Json<LoginCredentials>,
727) -> impl IntoResponse {
728 let Ok(adapter_client) = adapter_client_rx.clone().await else {
729 return StatusCode::INTERNAL_SERVER_ERROR;
730 };
731 if let Err(err) = adapter_client.authenticate(&username, &password).await {
732 warn!(?err, "HTTP login failed authentication");
733 return StatusCode::UNAUTHORIZED;
734 };
735
736 let session_data = TowerSessionData {
738 username,
739 created_at: SystemTime::now(),
740 last_activity: SystemTime::now(),
741 };
742 let session = session.and_then(|Extension(session)| Some(session));
744 let Some(session) = session else {
745 return StatusCode::INTERNAL_SERVER_ERROR;
746 };
747 match session.insert("data", &session_data).await {
748 Err(_) => StatusCode::INTERNAL_SERVER_ERROR,
749 Ok(_) => StatusCode::OK,
750 }
751}
752
753pub async fn handle_logout(session: Option<Extension<TowerSession>>) -> impl IntoResponse {
755 let session = session.and_then(|Extension(session)| Some(session));
756 let Some(session) = session else {
757 return StatusCode::INTERNAL_SERVER_ERROR;
758 };
759 match session.delete().await {
761 Err(_) => StatusCode::INTERNAL_SERVER_ERROR,
762 Ok(_) => StatusCode::OK,
763 }
764}
765
766async fn http_auth(
767 mut req: Request,
768 next: Next,
769 tls_enabled: bool,
770 authenticator: Arc<Authenticator>,
771 allowed_roles: AllowedRoles,
772) -> impl IntoResponse + use<> {
773 if let Some(session) = req.extensions().get::<TowerSession>() {
775 if let Ok(Some(session_data)) = session.get::<TowerSessionData>("data").await {
776 if session_data
778 .last_activity
779 .elapsed()
780 .unwrap_or(Duration::MAX)
781 > SESSION_DURATION
782 {
783 let _ = session.delete().await;
784 return Err(AuthError::SessionExpired);
785 }
786 let mut updated_data = session_data.clone();
788 updated_data.last_activity = SystemTime::now();
789 session
790 .insert("data", &updated_data)
791 .await
792 .map_err(|_| AuthError::FailedToUpdateSession)?;
793 req.extensions_mut().insert(AuthedUser {
795 name: session_data.username,
796 external_metadata_rx: None,
797 });
798 return Ok(next.run(req).await);
799 }
800 }
801
802 let conn_protocol = req.extensions().get::<ConnProtocol>().unwrap();
806 match (tls_enabled, &conn_protocol) {
807 (false, ConnProtocol::Http) => {}
808 (false, ConnProtocol::Https { .. }) => unreachable!(),
809 (true, ConnProtocol::Http) => return Err(AuthError::HttpsRequired),
810 (true, ConnProtocol::Https { .. }) => {}
811 }
812 if req.extensions().get::<AuthedUser>().is_some() {
814 return Ok(next.run(req).await);
815 }
816 let creds = if let Some(basic) = req.headers().typed_get::<Authorization<Basic>>() {
817 Some(Credentials::Password {
818 username: basic.username().to_owned(),
819 password: Password(basic.password().to_owned()),
820 })
821 } else if let Some(bearer) = req.headers().typed_get::<Authorization<Bearer>>() {
822 Some(Credentials::Token {
823 token: bearer.token().to_owned(),
824 })
825 } else {
826 None
827 };
828
829 let user = auth(&authenticator, creds, allowed_roles).await?;
830
831 req.extensions_mut().insert(user);
834
835 Ok(next.run(req).await)
837}
838
839async fn init_ws(
840 WsState {
841 authenticator_rx,
842 adapter_client_rx,
843 active_connection_counter,
844 helm_chart_version,
845 allowed_roles,
846 }: &WsState,
847 existing_user: Option<AuthedUser>,
848 peer_addr: IpAddr,
849 ws: &mut WebSocket,
850) -> Result<AuthedClient, anyhow::Error> {
851 let authenticator = authenticator_rx.clone().await.expect("sender not dropped");
852 let init_msg = ws.recv().await.ok_or_else(|| anyhow::anyhow!("closed"))??;
855 let ws_auth: WebSocketAuth = loop {
856 match init_msg {
857 Message::Text(data) => break serde_json::from_str(&data)?,
858 Message::Binary(data) => break serde_json::from_slice(&data)?,
859 Message::Ping(_) => {
861 continue;
862 }
863 Message::Pong(_) => {
864 continue;
865 }
866 Message::Close(_) => {
867 anyhow::bail!("closed");
868 }
869 }
870 };
871
872 let (user, options) = if let Some(existing_user) = existing_user {
873 match ws_auth {
874 WebSocketAuth::OptionsOnly { options } => (existing_user, options),
875 _ => {
876 warn!("Unexpected bearer or basic auth provided when using user header");
877 anyhow::bail!("unexpected")
878 }
879 }
880 } else {
881 let (creds, options) = match ws_auth {
882 WebSocketAuth::Basic {
883 user,
884 password,
885 options,
886 } => {
887 let creds = Credentials::Password {
888 username: user,
889 password,
890 };
891 (creds, options)
892 }
893 WebSocketAuth::Bearer { token, options } => {
894 let creds = Credentials::Token { token };
895 (creds, options)
896 }
897 WebSocketAuth::OptionsOnly { .. } => {
898 anyhow::bail!("expected auth information");
899 }
900 };
901 let user = auth(&authenticator, Some(creds), *allowed_roles).await?;
902 (user, options)
903 };
904
905 let client = AuthedClient::new(
906 &adapter_client_rx.clone().await?,
907 user,
908 peer_addr,
909 active_connection_counter.clone(),
910 helm_chart_version.clone(),
911 |_session| (),
912 options,
913 SYSTEM_TIME.clone(),
914 )
915 .await?;
916
917 Ok(client)
918}
919
920enum Credentials {
921 Password {
922 username: String,
923 password: Password,
924 },
925 Token {
926 token: String,
927 },
928}
929
930async fn auth(
931 authenticator: &Authenticator,
932 creds: Option<Credentials>,
933 allowed_roles: AllowedRoles,
934) -> Result<AuthedUser, AuthError> {
935 let (name, external_metadata_rx) = match authenticator {
937 Authenticator::Frontegg(frontegg) => match creds {
938 Some(Credentials::Password { username, password }) => {
939 let auth_session = frontegg.authenticate(&username, &password.0).await?;
940 let name = auth_session.user().into();
941 let external_metadata_rx = Some(auth_session.external_metadata_rx());
942 (name, external_metadata_rx)
943 }
944 Some(Credentials::Token { token }) => {
945 let claims = frontegg.validate_access_token(&token, None)?;
946 let (_, external_metadata_rx) = watch::channel(ExternalUserMetadata {
947 user_id: claims.user_id,
948 admin: claims.is_admin,
949 });
950 (claims.user, Some(external_metadata_rx))
951 }
952 None => return Err(AuthError::MissingHttpAuthentication),
953 },
954 Authenticator::Password(_) => {
955 warn!("self hosted auth only, but somehow missing session data");
956 return Err(AuthError::MissingHttpAuthentication);
958 }
959 Authenticator::None => {
960 let name = match creds {
964 Some(Credentials::Password { username, .. }) => username,
965 _ => HTTP_DEFAULT_USER.name.to_owned(),
966 };
967 (name, None)
968 }
969 };
970
971 check_role_allowed(&name, allowed_roles)?;
972
973 Ok(AuthedUser {
974 name,
975 external_metadata_rx,
976 })
977}
978
979fn check_role_allowed(name: &str, allowed_roles: AllowedRoles) -> Result<(), AuthError> {
981 let is_internal_user = INTERNAL_USER_NAMES.contains(name);
982 let is_reserved_user = mz_adapter::catalog::is_reserved_role_name(name);
984 let role_allowed = match allowed_roles {
985 AllowedRoles::Normal => !is_reserved_user,
986 AllowedRoles::Internal => is_internal_user,
987 AllowedRoles::NormalAndInternal => !is_reserved_user || is_internal_user,
988 };
989 if role_allowed {
990 Ok(())
991 } else {
992 Err(AuthError::InvalidLogin(name.to_owned()))
993 }
994}
995
996trait DefaultLayers {
999 fn apply_default_layers(self, source: &'static str, metrics: Metrics) -> Self;
1000}
1001
1002impl DefaultLayers for Router {
1003 fn apply_default_layers(self, source: &'static str, metrics: Metrics) -> Self {
1004 self.layer(DefaultBodyLimit::max(MAX_REQUEST_SIZE))
1005 .layer(metrics::PrometheusLayer::new(source, metrics))
1006 }
1007}
1008
1009async fn handle_load_error(error: tower::BoxError) -> impl IntoResponse {
1016 if error.is::<tower::load_shed::error::Overloaded>() {
1017 return (
1018 StatusCode::TOO_MANY_REQUESTS,
1019 Cow::from("too many requests, try again later"),
1020 );
1021 }
1022
1023 (
1026 StatusCode::INTERNAL_SERVER_ERROR,
1027 Cow::from(format!("Unhandled internal error: {}", error)),
1028 )
1029}
1030
1031#[derive(Debug, Deserialize, Serialize, PartialEq)]
1032pub struct LoginCredentials {
1033 username: String,
1034 password: Password,
1035}
1036
1037#[derive(Debug, Clone, Serialize, Deserialize)]
1038pub struct TowerSessionData {
1039 username: String,
1040 created_at: SystemTime,
1041 last_activity: SystemTime,
1042}
1043
#[cfg(test)]
mod tests {
    use super::{AllowedRoles, check_role_allowed};

    #[mz_ore::test]
    fn test_check_role_allowed() {
        // (role name, allowed under Internal, NormalAndInternal, Normal)
        let cases = [
            ("mz_system", true, true, false),
            ("mz_support", true, true, false),
            ("mz_analytics", true, true, false),
            ("materialize", false, true, true),
            ("anonymous_http_user", false, true, true),
            ("alex", false, true, true),
            ("external_asdf", false, false, false),
            ("pg_somebody", false, false, false),
            ("mz_unknown", false, false, false),
            ("PUBLIC", false, false, false),
        ];
        for (name, internal, normal_and_internal, normal) in cases {
            assert_eq!(
                check_role_allowed(name, AllowedRoles::Internal).is_ok(),
                internal,
                "{name} with AllowedRoles::Internal"
            );
            assert_eq!(
                check_role_allowed(name, AllowedRoles::NormalAndInternal).is_ok(),
                normal_and_internal,
                "{name} with AllowedRoles::NormalAndInternal"
            );
            assert_eq!(
                check_role_allowed(name, AllowedRoles::Normal).is_ok(),
                normal,
                "{name} with AllowedRoles::Normal"
            );
        }
    }
}