1#![allow(clippy::unused_async)]
18
19use std::borrow::Cow;
20use std::collections::BTreeMap;
21use std::fmt::Debug;
22use std::net::{IpAddr, SocketAddr};
23use std::pin::Pin;
24use std::sync::Arc;
25use std::time::{Duration, SystemTime};
26
27use anyhow::Context;
28use async_trait::async_trait;
29use axum::error_handling::HandleErrorLayer;
30use axum::extract::ws::{Message, WebSocket};
31use axum::extract::{ConnectInfo, DefaultBodyLimit, FromRequestParts, Query, Request, State};
32use axum::middleware::{self, Next};
33use axum::response::{IntoResponse, Redirect, Response};
34use axum::{Extension, Json, Router, routing};
35use futures::future::{Shared, TryFutureExt};
36use headers::authorization::{Authorization, Basic, Bearer};
37use headers::{HeaderMapExt, HeaderName};
38use http::header::{AUTHORIZATION, CONTENT_TYPE};
39use http::uri::Scheme;
40use http::{HeaderMap, HeaderValue, Method, StatusCode, Uri};
41use hyper_openssl::SslStream;
42use hyper_openssl::client::legacy::MaybeHttpsStream;
43use hyper_util::rt::TokioIo;
44use mz_adapter::session::{Session as AdapterSession, SessionConfig as AdapterSessionConfig};
45use mz_adapter::{AdapterError, AdapterNotice, Client, SessionClient, WebhookAppenderCache};
46use mz_auth::password::Password;
47use mz_authenticator::Authenticator;
48use mz_frontegg_auth::Error as FronteggError;
49use mz_http_util::DynamicFilterTarget;
50use mz_ore::cast::u64_to_usize;
51use mz_ore::metrics::MetricsRegistry;
52use mz_ore::now::{NowFn, SYSTEM_TIME, epoch_to_uuid_v7};
53use mz_ore::str::StrExt;
54use mz_pgwire_common::{ConnectionCounter, ConnectionHandle};
55use mz_repr::user::{ExternalUserMetadata, InternalUserMetadata};
56use mz_server_core::listeners::{AllowedRoles, AuthenticatorKind, HttpRoutesEnabled};
57use mz_server_core::{Connection, ConnectionHandler, ReloadingSslContext, Server};
58use mz_sql::session::metadata::SessionMetadata;
59use mz_sql::session::user::{
60 HTTP_DEFAULT_USER, INTERNAL_USER_NAMES, SUPPORT_USER_NAME, SYSTEM_USER_NAME,
61};
62use mz_sql::session::vars::{Value, Var, VarInput, WELCOME_MESSAGE};
63use openssl::ssl::Ssl;
64use prometheus::{
65 COMPUTE_METRIC_QUERIES, FRONTIER_METRIC_QUERIES, STORAGE_METRIC_QUERIES, USAGE_METRIC_QUERIES,
66};
67use serde::{Deserialize, Serialize};
68use serde_json::json;
69use thiserror::Error;
70use tokio::io::AsyncWriteExt;
71use tokio::sync::oneshot::Receiver;
72use tokio::sync::{oneshot, watch};
73use tokio_metrics::TaskMetrics;
74use tower::limit::GlobalConcurrencyLimitLayer;
75use tower::{Service, ServiceBuilder};
76use tower_http::cors::{AllowOrigin, Any, CorsLayer};
77use tower_sessions::{
78 MemoryStore as TowerSessionMemoryStore, Session as TowerSession,
79 SessionManagerLayer as TowerSessionManagerLayer,
80};
81use tracing::{error, warn};
82
83use crate::BUILD_INFO;
84use crate::deployment::state::DeploymentStateHandle;
85use crate::http::sql::SqlError;
86
87mod catalog;
88mod console;
89mod memory;
90mod metrics;
91mod probe;
92mod prometheus;
93mod root;
94mod sql;
95mod webhook;
96
97pub use metrics::Metrics;
98pub use sql::{SqlResponse, WebSocketAuth, WebSocketResponse};
99
100pub const MAX_REQUEST_SIZE: usize = u64_to_usize(5 * bytesize::MIB);
102
103const SESSION_DURATION: Duration = Duration::from_secs(8 * 3600); const PROFILING_API_ENDPOINTS: &[&str] = &["/memory", "/hierarchical-memory", "/prof/"];
106
/// Configuration for a single HTTP listener, consumed by [`HttpServer::new`].
#[derive(Debug)]
pub struct HttpConfig {
    /// Label for this listener, handed to the request-metrics layer.
    pub source: &'static str,
    /// TLS context. When set, every connection must complete a TLS handshake
    /// and the login-session cookie is marked `Secure`.
    pub tls: Option<ReloadingSslContext>,
    /// Authentication scheme. `Password` adds cookie sessions and the
    /// `/api/login` / `/api/logout` routes; `None` enables the
    /// `x-materialize-user` header middleware.
    pub authenticator_kind: AuthenticatorKind,
    /// Resolves to the authenticator once it is available; awaited per request.
    pub authenticator_rx: Shared<Receiver<Arc<Authenticator>>>,
    /// Resolves to the adapter client once it is available.
    pub adapter_client_rx: Shared<Receiver<Client>>,
    /// CORS allowed origins for the base routes.
    pub allowed_origin: AllowOrigin,
    /// Used to allocate a connection slot for each adapter session opened over HTTP.
    pub active_connection_counter: ConnectionCounter,
    /// Passed into adapter sessions opened for WebSocket connections.
    pub helm_chart_version: Option<String>,
    /// Bounds concurrent webhook requests; requests beyond the limit are shed
    /// with `429 Too Many Requests`.
    pub concurrent_webhook_req: Arc<tokio::sync::Semaphore>,
    /// Request metrics applied to every route.
    pub metrics: Metrics,
    /// Registry exposed at `/metrics`.
    pub metrics_registry: MetricsRegistry,
    /// Which roles may authenticate on this listener.
    pub allowed_roles: AllowedRoles,
    /// Settings for the internal routes.
    pub internal_route_config: Arc<InternalRouteConfig>,
    /// Which groups of routes to mount.
    pub routes_enabled: HttpRoutesEnabled,
}
124
/// Settings used by the routes mounted when `routes_enabled.internal` is set.
#[derive(Debug, Clone)]
pub struct InternalRouteConfig {
    /// State behind the `/api/leader/*` endpoints.
    pub deployment_state_handle: DeploymentStateHandle,
    /// Passed to `console::ConsoleProxyConfig` for the `/internal-console` routes.
    pub internal_console_redirect_url: Option<String>,
}
130
/// Shared state for the `/api/experimental/sql` WebSocket route.
#[derive(Clone)]
pub struct WsState {
    /// Resolves to the authenticator used for credentials sent over the socket.
    authenticator_rx: Delayed<Arc<Authenticator>>,
    /// Resolves to the adapter client used to open sessions.
    adapter_client_rx: Delayed<mz_adapter::Client>,
    /// Allocates a connection slot for each WebSocket session.
    active_connection_counter: ConnectionCounter,
    /// Passed into the adapter session config.
    helm_chart_version: Option<String>,
    /// Roles admitted on this listener.
    allowed_roles: AllowedRoles,
}
139
/// Shared state for the `/api/webhook/:database/:schema/:id` route.
#[derive(Clone)]
pub struct WebhookState {
    /// Resolves to the adapter client.
    adapter_client_rx: Delayed<mz_adapter::Client>,
    /// Appender cache, created once per server in `HttpServer::new`.
    webhook_cache: WebhookAppenderCache,
}
145
/// An HTTP listener: the fully assembled router plus optional TLS.
#[derive(Debug)]
pub struct HttpServer {
    /// When set, each connection completes a TLS handshake before being served.
    tls: Option<ReloadingSslContext>,
    /// Router built by `HttpServer::new`; cloned for every connection.
    router: Router,
}
151
152impl HttpServer {
153 pub fn new(
154 HttpConfig {
155 source,
156 tls,
157 authenticator_kind,
158 authenticator_rx,
159 adapter_client_rx,
160 allowed_origin,
161 active_connection_counter,
162 helm_chart_version,
163 concurrent_webhook_req,
164 metrics,
165 metrics_registry,
166 allowed_roles,
167 internal_route_config,
168 routes_enabled,
169 }: HttpConfig,
170 ) -> HttpServer {
171 let tls_enabled = tls.is_some();
172 let webhook_cache = WebhookAppenderCache::new();
173
174 let session_store = TowerSessionMemoryStore::default();
176 let session_layer = TowerSessionManagerLayer::new(session_store)
177 .with_secure(tls_enabled) .with_same_site(tower_sessions::cookie::SameSite::Strict) .with_http_only(true) .with_name("mz_session") .with_path("/"); let auth_middleware_authenticator_rx = authenticator_rx.clone();
184 let auth_middleware = middleware::from_fn(move |req, next| {
185 let authenticator_rx = auth_middleware_authenticator_rx.clone();
186 async move {
187 let authenticator = authenticator_rx
188 .await
189 .expect("sender not dropped before sending once");
190 http_auth(req, next, tls_enabled, authenticator, allowed_roles).await
191 }
192 });
193
194 let mut router = Router::new();
195 let mut base_router = Router::new();
196 if routes_enabled.base {
197 base_router = base_router
198 .route(
199 "/",
200 routing::get(move || async move {
201 root::handle_home(routes_enabled.profiling).await
202 }),
203 )
204 .route("/api/sql", routing::post(sql::handle_sql))
205 .route("/memory", routing::get(memory::handle_memory))
206 .route(
207 "/hierarchical-memory",
208 routing::get(memory::handle_hierarchical_memory),
209 )
210 .route("/static/*path", routing::get(root::handle_static));
211
212 let mut ws_router = Router::new()
213 .route("/api/experimental/sql", routing::get(sql::handle_sql_ws))
214 .with_state(WsState {
215 authenticator_rx: authenticator_rx.clone(),
216 adapter_client_rx: adapter_client_rx.clone(),
217 active_connection_counter: active_connection_counter.clone(),
218 helm_chart_version,
219 allowed_roles,
220 });
221 if let AuthenticatorKind::None = authenticator_kind {
222 ws_router = ws_router.layer(middleware::from_fn(x_materialize_user_header_auth));
223 }
224 router = router.merge(ws_router);
225 }
226 if routes_enabled.profiling {
227 base_router = base_router.nest("/prof/", mz_prof_http::router(&BUILD_INFO));
228 }
229
230 if routes_enabled.webhook {
231 let webhook_router = Router::new()
232 .route(
233 "/api/webhook/:database/:schema/:id",
234 routing::post(webhook::handle_webhook),
235 )
236 .with_state(WebhookState {
237 adapter_client_rx: adapter_client_rx.clone(),
238 webhook_cache,
239 })
240 .layer(
241 tower_http::decompression::RequestDecompressionLayer::new()
242 .gzip(true)
243 .deflate(true)
244 .br(true)
245 .zstd(true),
246 )
247 .layer(
248 CorsLayer::new()
249 .allow_methods(Method::POST)
250 .allow_origin(AllowOrigin::mirror_request())
251 .allow_headers(Any),
252 )
253 .layer(
254 ServiceBuilder::new()
255 .layer(HandleErrorLayer::new(handle_load_error))
256 .load_shed()
257 .layer(GlobalConcurrencyLimitLayer::with_semaphore(
258 concurrent_webhook_req,
259 )),
260 );
261 router = router.merge(webhook_router);
262 }
263
264 if routes_enabled.internal {
265 let console_config = Arc::new(console::ConsoleProxyConfig::new(
266 internal_route_config.internal_console_redirect_url.clone(),
267 "/internal-console".to_string(),
268 ));
269 base_router = base_router
270 .route(
271 "/api/opentelemetry/config",
272 routing::put({
273 move |_: axum::Json<DynamicFilterTarget>| async {
274 (
275 StatusCode::BAD_REQUEST,
276 "This endpoint has been replaced. \
277 Use the `opentelemetry_filter` system variable."
278 .to_string(),
279 )
280 }
281 }),
282 )
283 .route(
284 "/api/stderr/config",
285 routing::put({
286 move |_: axum::Json<DynamicFilterTarget>| async {
287 (
288 StatusCode::BAD_REQUEST,
289 "This endpoint has been replaced. \
290 Use the `log_filter` system variable."
291 .to_string(),
292 )
293 }
294 }),
295 )
296 .route("/api/tracing", routing::get(mz_http_util::handle_tracing))
297 .route(
298 "/api/catalog/dump",
299 routing::get(catalog::handle_catalog_dump),
300 )
301 .route(
302 "/api/catalog/check",
303 routing::get(catalog::handle_catalog_check),
304 )
305 .route(
306 "/api/coordinator/check",
307 routing::get(catalog::handle_coordinator_check),
308 )
309 .route(
310 "/api/coordinator/dump",
311 routing::get(catalog::handle_coordinator_dump),
312 )
313 .route(
314 "/internal-console",
315 routing::get(|| async { Redirect::temporary("/internal-console/") }),
316 )
317 .route(
318 "/internal-console/*path",
319 routing::get(console::handle_internal_console),
320 )
321 .route(
322 "/internal-console/",
323 routing::get(console::handle_internal_console),
324 )
325 .layer(Extension(console_config));
326 let leader_router = Router::new()
327 .route("/api/leader/status", routing::get(handle_leader_status))
328 .route("/api/leader/promote", routing::post(handle_leader_promote))
329 .route(
330 "/api/leader/skip-catchup",
331 routing::post(handle_leader_skip_catchup),
332 )
333 .layer(auth_middleware.clone())
334 .with_state(internal_route_config.deployment_state_handle.clone());
335 router = router.merge(leader_router);
336 }
337
338 if routes_enabled.metrics {
339 let metrics_router = Router::new()
340 .route(
341 "/metrics",
342 routing::get(move || async move {
343 mz_http_util::handle_prometheus(&metrics_registry).await
344 }),
345 )
346 .route(
347 "/metrics/mz_usage",
348 routing::get(|client: AuthedClient| async move {
349 let registry = sql::handle_promsql(client, USAGE_METRIC_QUERIES).await;
350 mz_http_util::handle_prometheus(®istry).await
351 }),
352 )
353 .route(
354 "/metrics/mz_frontier",
355 routing::get(|client: AuthedClient| async move {
356 let registry = sql::handle_promsql(client, FRONTIER_METRIC_QUERIES).await;
357 mz_http_util::handle_prometheus(®istry).await
358 }),
359 )
360 .route(
361 "/metrics/mz_compute",
362 routing::get(|client: AuthedClient| async move {
363 let registry = sql::handle_promsql(client, COMPUTE_METRIC_QUERIES).await;
364 mz_http_util::handle_prometheus(®istry).await
365 }),
366 )
367 .route(
368 "/metrics/mz_storage",
369 routing::get(|client: AuthedClient| async move {
370 let registry = sql::handle_promsql(client, STORAGE_METRIC_QUERIES).await;
371 mz_http_util::handle_prometheus(®istry).await
372 }),
373 )
374 .route(
375 "/api/livez",
376 routing::get(mz_http_util::handle_liveness_check),
377 )
378 .route("/api/readyz", routing::get(probe::handle_ready))
379 .layer(auth_middleware.clone())
380 .layer(Extension(adapter_client_rx.clone()))
381 .layer(Extension(active_connection_counter.clone()));
382 router = router.merge(metrics_router);
383 }
384
385 base_router = base_router
386 .layer(auth_middleware.clone())
387 .layer(Extension(adapter_client_rx.clone()))
388 .layer(Extension(active_connection_counter.clone()))
389 .layer(
390 CorsLayer::new()
391 .allow_credentials(false)
392 .allow_headers([
393 AUTHORIZATION,
394 CONTENT_TYPE,
395 HeaderName::from_static("x-materialize-version"),
396 ])
397 .allow_methods(Any)
398 .allow_origin(allowed_origin)
399 .expose_headers(Any)
400 .max_age(Duration::from_secs(60) * 60),
401 );
402
403 match authenticator_kind {
404 AuthenticatorKind::Password => {
405 base_router = base_router.layer(session_layer.clone());
406
407 let login_router = Router::new()
408 .route("/api/login", routing::post(handle_login))
409 .route("/api/logout", routing::post(handle_logout))
410 .layer(Extension(adapter_client_rx));
411 router = router.merge(login_router).layer(session_layer);
412 }
413 AuthenticatorKind::None => {
414 base_router =
415 base_router.layer(middleware::from_fn(x_materialize_user_header_auth));
416 }
417 _ => {}
418 }
419
420 router = router
421 .merge(base_router)
422 .apply_default_layers(source, metrics);
423
424 HttpServer { tls, router }
425 }
426}
427
428impl Server for HttpServer {
429 const NAME: &'static str = "http";
430
431 fn handle_connection(
432 &self,
433 conn: Connection,
434 _tokio_metrics_intervals: impl Iterator<Item = TaskMetrics> + Send + 'static,
435 ) -> ConnectionHandler {
436 let router = self.router.clone();
437 let tls_context = self.tls.clone();
438 let mut conn = TokioIo::new(conn);
439
440 Box::pin(async {
441 let direct_peer_addr = conn.inner().peer_addr().context("fetching peer addr")?;
442 let peer_addr = conn
443 .inner_mut()
444 .take_proxy_header_address()
445 .await
446 .map(|a| a.source)
447 .unwrap_or(direct_peer_addr);
448
449 let (conn, conn_protocol) = match tls_context {
450 Some(tls_context) => {
451 let mut ssl_stream = SslStream::new(Ssl::new(&tls_context.get())?, conn)?;
452 if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
453 let _ = ssl_stream.get_mut().inner_mut().shutdown().await;
454 return Err(e.into());
455 }
456 (MaybeHttpsStream::Https(ssl_stream), ConnProtocol::Https)
457 }
458 _ => (MaybeHttpsStream::Http(conn), ConnProtocol::Http),
459 };
460 let mut make_tower_svc = router
461 .layer(Extension(conn_protocol))
462 .into_make_service_with_connect_info::<SocketAddr>();
463 let tower_svc = make_tower_svc.call(peer_addr).await.unwrap();
464 let hyper_svc = hyper::service::service_fn(|req| tower_svc.clone().call(req));
465 let http = hyper::server::conn::http1::Builder::new();
466 http.serve_connection(conn, hyper_svc)
467 .with_upgrades()
468 .err_into()
469 .await
470 })
471 }
472}
473
474pub async fn handle_leader_status(
475 State(deployment_state_handle): State<DeploymentStateHandle>,
476) -> impl IntoResponse {
477 let status = deployment_state_handle.status();
478 (StatusCode::OK, Json(json!({ "status": status })))
479}
480
481pub async fn handle_leader_promote(
482 State(deployment_state_handle): State<DeploymentStateHandle>,
483) -> impl IntoResponse {
484 match deployment_state_handle.try_promote() {
485 Ok(()) => {
486 let status = StatusCode::OK;
489 let body = Json(json!({
490 "result": "Success",
491 }));
492 (status, body)
493 }
494 Err(()) => {
495 let status = StatusCode::BAD_REQUEST;
498 let body = Json(json!({
499 "result": {"Failure": {"message": "cannot promote leader while initializing"}},
500 }));
501 (status, body)
502 }
503 }
504}
505
506pub async fn handle_leader_skip_catchup(
507 State(deployment_state_handle): State<DeploymentStateHandle>,
508) -> impl IntoResponse {
509 match deployment_state_handle.try_skip_catchup() {
510 Ok(()) => StatusCode::NO_CONTENT.into_response(),
511 Err(()) => {
512 let status = StatusCode::BAD_REQUEST;
513 let body = Json(json!({
514 "message": "cannot skip catchup in this phase of initialization; try again later",
515 }));
516 (status, body).into_response()
517 }
518 }
519}
520
521async fn x_materialize_user_header_auth(mut req: Request, next: Next) -> impl IntoResponse {
522 if let Some(username) = req.headers().get("x-materialize-user").map(|h| h.to_str()) {
524 let username = match username {
525 Ok(name @ (SUPPORT_USER_NAME | SYSTEM_USER_NAME)) => name.to_string(),
526 _ => {
527 return Err(AuthError::MismatchedUser(format!(
528 "user specified in x-materialize-user must be {SUPPORT_USER_NAME} or {SYSTEM_USER_NAME}"
529 )));
530 }
531 };
532 let superuser = matches!(username.as_str(), SYSTEM_USER_NAME);
533 req.extensions_mut().insert(AuthedUser {
534 name: username,
535 external_metadata_rx: None,
536 internal_metadata: Some(InternalUserMetadata { superuser }),
537 });
538 }
539 Ok(next.run(req).await)
540}
541
542type Delayed<T> = Shared<oneshot::Receiver<T>>;
543
/// Transport a request arrived over. `handle_connection` attaches it as a
/// request extension, and `http_auth` checks it against the listener's TLS
/// configuration.
///
/// This is a fieldless marker, so it is `Copy` and comparable; `Debug` makes
/// it printable in diagnostics.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ConnProtocol {
    /// Plain-text HTTP.
    Http,
    /// HTTP over TLS.
    Https,
}
549
/// The identity a request was authenticated as, stored as a request extension.
#[derive(Clone, Debug)]
pub struct AuthedUser {
    /// Role name the adapter session is opened for.
    name: String,
    /// External metadata; set for Frontegg-authenticated users.
    external_metadata_rx: Option<watch::Receiver<ExternalUserMetadata>>,
    /// Internal metadata (the superuser flag). Set by the password
    /// authenticator, by login sessions, and by the `x-materialize-user` header.
    internal_metadata: Option<InternalUserMetadata>,
}
556
/// An adapter session opened on behalf of an [`AuthedUser`].
///
/// Fields drop in declaration order, so the session client is dropped before
/// the connection slot.
pub struct AuthedClient {
    /// The started adapter session.
    pub client: SessionClient,
    /// Slot allocated from the active-connection counter for this session.
    pub connection_guard: Option<ConnectionHandle>,
}
561
562impl AuthedClient {
563 async fn new<F>(
564 adapter_client: &Client,
565 user: AuthedUser,
566 peer_addr: IpAddr,
567 active_connection_counter: ConnectionCounter,
568 helm_chart_version: Option<String>,
569 session_config: F,
570 options: BTreeMap<String, String>,
571 now: NowFn,
572 ) -> Result<Self, AdapterError>
573 where
574 F: FnOnce(&mut AdapterSession),
575 {
576 let conn_id = adapter_client.new_conn_id()?;
577 let mut session = adapter_client.new_session(AdapterSessionConfig {
578 conn_id,
579 uuid: epoch_to_uuid_v7(&(now)()),
580 user: user.name,
581 client_ip: Some(peer_addr),
582 external_metadata_rx: user.external_metadata_rx,
583 internal_user_metadata: user.internal_metadata,
584 helm_chart_version,
585 });
586 let connection_guard = active_connection_counter.allocate_connection(session.user())?;
587
588 session_config(&mut session);
589 let system_vars = adapter_client.get_system_vars().await;
590 for (key, val) in options {
591 const LOCAL: bool = false;
592 if let Err(err) =
593 session
594 .vars_mut()
595 .set(&system_vars, &key, VarInput::Flat(&val), LOCAL)
596 {
597 session.add_notice(AdapterNotice::BadStartupSetting {
598 name: key.to_string(),
599 reason: err.to_string(),
600 })
601 }
602 }
603 let adapter_client = adapter_client.startup(session).await?;
604 Ok(AuthedClient {
605 client: adapter_client,
606 connection_guard,
607 })
608 }
609}
610
611#[async_trait]
612impl<S> FromRequestParts<S> for AuthedClient
613where
614 S: Send + Sync,
615{
616 type Rejection = Response;
617
618 async fn from_request_parts(
619 req: &mut http::request::Parts,
620 state: &S,
621 ) -> Result<Self, Self::Rejection> {
622 #[derive(Debug, Default, Deserialize)]
623 struct Params {
624 #[serde(default)]
625 options: String,
626 }
627 let params: Query<Params> = Query::from_request_parts(req, state)
628 .await
629 .unwrap_or_default();
630
631 let peer_addr = req
632 .extensions
633 .get::<ConnectInfo<SocketAddr>>()
634 .expect("ConnectInfo extension guaranteed to exist")
635 .0
636 .ip();
637
638 let user = req.extensions.get::<AuthedUser>().unwrap();
639 let adapter_client = req
640 .extensions
641 .get::<Delayed<mz_adapter::Client>>()
642 .unwrap()
643 .clone();
644 let adapter_client = adapter_client.await.map_err(|_| {
645 (StatusCode::INTERNAL_SERVER_ERROR, "adapter client missing").into_response()
646 })?;
647 let active_connection_counter = req.extensions.get::<ConnectionCounter>().unwrap();
648 let helm_chart_version = None;
649
650 let options = if params.options.is_empty() {
651 BTreeMap::<String, String>::default()
654 } else {
655 match serde_json::from_str(¶ms.options) {
656 Ok(options) => options,
657 Err(_e) => {
658 let code = StatusCode::BAD_REQUEST;
660 let msg = format!("Failed to deserialize {} map", "options".quoted());
661 return Err((code, msg).into_response());
662 }
663 }
664 };
665
666 let client = AuthedClient::new(
667 &adapter_client,
668 user.clone(),
669 peer_addr,
670 active_connection_counter.clone(),
671 helm_chart_version,
672 |session| {
673 session
674 .vars_mut()
675 .set_default(WELCOME_MESSAGE.name(), VarInput::Flat(&false.format()))
676 .expect("known to exist")
677 },
678 options,
679 SYSTEM_TIME.clone(),
680 )
681 .await
682 .map_err(|e| {
683 let status = match e {
684 AdapterError::UserSessionsDisallowed | AdapterError::NetworkPolicyDenied(_) => {
685 StatusCode::FORBIDDEN
686 }
687 _ => StatusCode::INTERNAL_SERVER_ERROR,
688 };
689 (status, Json(SqlError::from(e))).into_response()
690 })?;
691
692 Ok(client)
693 }
694}
695
696#[derive(Debug, Error)]
697enum AuthError {
698 #[error("role dissallowed")]
699 RoleDisallowed(String),
700 #[error("{0}")]
701 Frontegg(#[from] FronteggError),
702 #[error("missing authorization header")]
703 MissingHttpAuthentication {
704 include_www_authenticate_header: bool,
705 },
706 #[error("{0}")]
707 MismatchedUser(String),
708 #[error("session expired")]
709 SessionExpired,
710 #[error("failed to update session")]
711 FailedToUpdateSession,
712 #[error("invalid credentials")]
713 InvalidCredentials,
714}
715
716impl IntoResponse for AuthError {
717 fn into_response(self) -> Response {
718 warn!("HTTP request failed authentication: {}", self);
719 let mut headers = HeaderMap::new();
720 match self {
721 AuthError::MissingHttpAuthentication {
722 include_www_authenticate_header,
723 } if include_www_authenticate_header => {
724 headers.insert(
725 http::header::WWW_AUTHENTICATE,
726 HeaderValue::from_static("Basic realm=Materialize"),
727 );
728 }
729 _ => {}
730 };
731 (StatusCode::UNAUTHORIZED, headers, "unauthorized").into_response()
734 }
735}
736
737pub async fn handle_login(
739 session: Option<Extension<TowerSession>>,
740 Extension(adapter_client_rx): Extension<Delayed<Client>>,
741 Json(LoginCredentials { username, password }): Json<LoginCredentials>,
742) -> impl IntoResponse {
743 let Ok(adapter_client) = adapter_client_rx.clone().await else {
744 return StatusCode::INTERNAL_SERVER_ERROR;
745 };
746 let auth_response = match adapter_client.authenticate(&username, &password).await {
747 Ok(auth_response) => auth_response,
748 Err(err) => {
749 warn!(?err, "HTTP login failed authentication");
750 return StatusCode::UNAUTHORIZED;
751 }
752 };
753
754 let session_data = TowerSessionData {
756 username,
757 created_at: SystemTime::now(),
758 last_activity: SystemTime::now(),
759 internal_metadata: InternalUserMetadata {
760 superuser: auth_response.superuser,
761 },
762 };
763 let session = session.and_then(|Extension(session)| Some(session));
765 let Some(session) = session else {
766 return StatusCode::INTERNAL_SERVER_ERROR;
767 };
768 match session.insert("data", &session_data).await {
769 Err(_) => StatusCode::INTERNAL_SERVER_ERROR,
770 Ok(_) => StatusCode::OK,
771 }
772}
773
774pub async fn handle_logout(session: Option<Extension<TowerSession>>) -> impl IntoResponse {
776 let session = session.and_then(|Extension(session)| Some(session));
777 let Some(session) = session else {
778 return StatusCode::INTERNAL_SERVER_ERROR;
779 };
780 match session.delete().await {
782 Err(_) => StatusCode::INTERNAL_SERVER_ERROR,
783 Ok(_) => StatusCode::OK,
784 }
785}
786
787async fn http_auth(
788 mut req: Request,
789 next: Next,
790 tls_enabled: bool,
791 authenticator: Arc<Authenticator>,
792 allowed_roles: AllowedRoles,
793) -> impl IntoResponse + use<> {
794 if let Some(session) = req.extensions().get::<TowerSession>() {
796 if let Ok(Some(session_data)) = session.get::<TowerSessionData>("data").await {
797 if session_data
799 .last_activity
800 .elapsed()
801 .unwrap_or(Duration::MAX)
802 > SESSION_DURATION
803 {
804 let _ = session.delete().await;
805 return Err(AuthError::SessionExpired);
806 }
807 let mut updated_data = session_data.clone();
809 updated_data.last_activity = SystemTime::now();
810 session
811 .insert("data", &updated_data)
812 .await
813 .map_err(|_| AuthError::FailedToUpdateSession)?;
814 req.extensions_mut().insert(AuthedUser {
816 name: session_data.username,
817 external_metadata_rx: None,
818 internal_metadata: Some(session_data.internal_metadata),
819 });
820 return Ok(next.run(req).await);
821 }
822 }
823
824 let conn_protocol = req.extensions().get::<ConnProtocol>().unwrap();
828 match (tls_enabled, &conn_protocol) {
829 (false, ConnProtocol::Http) => {}
830 (false, ConnProtocol::Https { .. }) => unreachable!(),
831 (true, ConnProtocol::Http) => {
832 let mut parts = req.uri().clone().into_parts();
833 parts.scheme = Some(Scheme::HTTPS);
834 return Ok(Redirect::permanent(
835 &Uri::from_parts(parts)
836 .expect("it was already a URI, just changed the scheme")
837 .to_string(),
838 )
839 .into_response());
840 }
841 (true, ConnProtocol::Https { .. }) => {}
842 }
843 if req.extensions().get::<AuthedUser>().is_some() {
845 return Ok(next.run(req).await);
846 }
847 let creds = if let Some(basic) = req.headers().typed_get::<Authorization<Basic>>() {
848 Some(Credentials::Password {
849 username: basic.username().to_owned(),
850 password: Password(basic.password().to_owned()),
851 })
852 } else if let Some(bearer) = req.headers().typed_get::<Authorization<Bearer>>() {
853 Some(Credentials::Token {
854 token: bearer.token().to_owned(),
855 })
856 } else {
857 None
858 };
859
860 let path = req.uri().path();
861 let include_www_authenticate_header = path == "/"
862 || PROFILING_API_ENDPOINTS
863 .iter()
864 .any(|prefix| path.starts_with(prefix));
865 let user = auth(
866 &authenticator,
867 creds,
868 allowed_roles,
869 include_www_authenticate_header,
870 )
871 .await?;
872
873 req.extensions_mut().insert(user);
876
877 Ok(next.run(req).await)
879}
880
881async fn init_ws(
882 WsState {
883 authenticator_rx,
884 adapter_client_rx,
885 active_connection_counter,
886 helm_chart_version,
887 allowed_roles,
888 }: &WsState,
889 existing_user: Option<AuthedUser>,
890 peer_addr: IpAddr,
891 ws: &mut WebSocket,
892) -> Result<AuthedClient, anyhow::Error> {
893 let authenticator = authenticator_rx.clone().await.expect("sender not dropped");
894 let ws_auth: WebSocketAuth = loop {
897 let init_msg = ws.recv().await.ok_or_else(|| anyhow::anyhow!("closed"))??;
898 match init_msg {
899 Message::Text(data) => break serde_json::from_str(&data)?,
900 Message::Binary(data) => break serde_json::from_slice(&data)?,
901 Message::Ping(_) => {
903 continue;
904 }
905 Message::Pong(_) => {
906 continue;
907 }
908 Message::Close(_) => {
909 anyhow::bail!("closed");
910 }
911 }
912 };
913
914 let (user, options) = if let Some(existing_user) = existing_user {
915 match ws_auth {
916 WebSocketAuth::OptionsOnly { options } => (existing_user, options),
917 _ => {
918 warn!("Unexpected bearer or basic auth provided when using user header");
919 anyhow::bail!("unexpected")
920 }
921 }
922 } else {
923 let (creds, options) = match ws_auth {
924 WebSocketAuth::Basic {
925 user,
926 password,
927 options,
928 } => {
929 let creds = Credentials::Password {
930 username: user,
931 password,
932 };
933 (creds, options)
934 }
935 WebSocketAuth::Bearer { token, options } => {
936 let creds = Credentials::Token { token };
937 (creds, options)
938 }
939 WebSocketAuth::OptionsOnly { .. } => {
940 anyhow::bail!("expected auth information");
941 }
942 };
943 let user = auth(&authenticator, Some(creds), *allowed_roles, false).await?;
944 (user, options)
945 };
946
947 let client = AuthedClient::new(
948 &adapter_client_rx.clone().await?,
949 user,
950 peer_addr,
951 active_connection_counter.clone(),
952 helm_chart_version.clone(),
953 |_session| (),
954 options,
955 SYSTEM_TIME.clone(),
956 )
957 .await?;
958
959 Ok(client)
960}
961
/// Credentials presented by an HTTP or WebSocket client.
enum Credentials {
    /// Username and password, from HTTP basic auth or a WebSocket `Basic`
    /// message.
    Password {
        username: String,
        password: Password,
    },
    /// Bearer token, from `Authorization: Bearer` or a WebSocket `Bearer`
    /// message.
    Token {
        token: String,
    },
}
971
/// Validates `creds` with `authenticator` and returns the authenticated user.
///
/// Behavior by authenticator:
/// - `Frontegg`: password credentials are authenticated with Frontegg, whose
///   session supplies the user name and an external-metadata receiver. A
///   bearer token is validated as an access token, and its claims become
///   fixed external metadata.
/// - `Password`: only password credentials are accepted. They are checked by
///   the adapter client, and its `superuser` flag becomes internal metadata.
/// - `Sasl`: always rejected over HTTP.
/// - `None`: no verification. The name from password credentials is used if
///   present, otherwise `HTTP_DEFAULT_USER`.
///
/// Missing or unsupported credentials yield
/// `AuthError::MissingHttpAuthentication`, carrying
/// `include_www_authenticate_header` through. On success, the name must still
/// pass [`check_role_allowed`] for `allowed_roles`.
async fn auth(
    authenticator: &Authenticator,
    creds: Option<Credentials>,
    allowed_roles: AllowedRoles,
    include_www_authenticate_header: bool,
) -> Result<AuthedUser, AuthError> {
    let (name, external_metadata_rx, internal_metadata) = match authenticator {
        Authenticator::Frontegg(frontegg) => match creds {
            Some(Credentials::Password { username, password }) => {
                let auth_session = frontegg.authenticate(&username, &password.0).await?;
                let name = auth_session.user().into();
                let external_metadata_rx = Some(auth_session.external_metadata_rx());
                (name, external_metadata_rx, None)
            }
            Some(Credentials::Token { token }) => {
                let claims = frontegg.validate_access_token(&token, None)?;
                // The sender half is dropped immediately, so this metadata
                // stays fixed at the token's claims.
                let (_, external_metadata_rx) = watch::channel(ExternalUserMetadata {
                    user_id: claims.user_id,
                    admin: claims.is_admin,
                });
                (claims.user, Some(external_metadata_rx), None)
            }
            None => {
                return Err(AuthError::MissingHttpAuthentication {
                    include_www_authenticate_header,
                });
            }
        },
        Authenticator::Password(adapter_client) => match creds {
            Some(Credentials::Password { username, password }) => {
                // Any adapter error is reported uniformly as invalid credentials.
                let auth_response = adapter_client
                    .authenticate(&username, &password)
                    .await
                    .map_err(|_| AuthError::InvalidCredentials)?;
                let internal_metadata = InternalUserMetadata {
                    superuser: auth_response.superuser,
                };
                (username, None, Some(internal_metadata))
            }
            // Bearer tokens and absent credentials are both treated as
            // missing authentication.
            _ => {
                return Err(AuthError::MissingHttpAuthentication {
                    include_www_authenticate_header,
                });
            }
        },
        // SASL is not accepted over HTTP; treat it like missing credentials.
        Authenticator::Sasl(_) => {
            return Err(AuthError::MissingHttpAuthentication {
                include_www_authenticate_header,
            });
        }
        // No verification: use the basic-auth username if one was sent,
        // otherwise the default HTTP user.
        Authenticator::None => {
            let name = match creds {
                Some(Credentials::Password { username, .. }) => username,
                _ => HTTP_DEFAULT_USER.name.to_owned(),
            };
            (name, None, None)
        }
    };

    check_role_allowed(&name, allowed_roles)?;

    Ok(AuthedUser {
        name,
        external_metadata_rx,
        internal_metadata,
    })
}
1045
1046fn check_role_allowed(name: &str, allowed_roles: AllowedRoles) -> Result<(), AuthError> {
1048 let is_internal_user = INTERNAL_USER_NAMES.contains(name);
1049 let is_reserved_user = mz_adapter::catalog::is_reserved_role_name(name);
1051 let role_allowed = match allowed_roles {
1052 AllowedRoles::Normal => !is_reserved_user,
1053 AllowedRoles::Internal => is_internal_user,
1054 AllowedRoles::NormalAndInternal => !is_reserved_user || is_internal_user,
1055 };
1056 if role_allowed {
1057 Ok(())
1058 } else {
1059 Err(AuthError::RoleDisallowed(name.to_owned()))
1060 }
1061}
1062
1063trait DefaultLayers {
1066 fn apply_default_layers(self, source: &'static str, metrics: Metrics) -> Self;
1067}
1068
1069impl DefaultLayers for Router {
1070 fn apply_default_layers(self, source: &'static str, metrics: Metrics) -> Self {
1071 self.layer(DefaultBodyLimit::max(MAX_REQUEST_SIZE))
1072 .layer(metrics::PrometheusLayer::new(source, metrics))
1073 }
1074}
1075
1076async fn handle_load_error(error: tower::BoxError) -> impl IntoResponse {
1083 if error.is::<tower::load_shed::error::Overloaded>() {
1084 return (
1085 StatusCode::TOO_MANY_REQUESTS,
1086 Cow::from("too many requests, try again later"),
1087 );
1088 }
1089
1090 (
1093 StatusCode::INTERNAL_SERVER_ERROR,
1094 Cow::from(format!("Unhandled internal error: {}", error)),
1095 )
1096}
1097
/// JSON body of `POST /api/login`.
#[derive(Debug, Deserialize, Serialize, PartialEq)]
pub struct LoginCredentials {
    /// Role name to log in as.
    username: String,
    /// Password checked by the adapter client.
    password: Password,
}
1103
/// Record stored under the `"data"` key of a login session.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TowerSessionData {
    /// Authenticated role name.
    username: String,
    /// When `handle_login` created the record.
    created_at: SystemTime,
    /// Time of the last authenticated request; compared against
    /// `SESSION_DURATION` to expire idle sessions.
    last_activity: SystemTime,
    /// Metadata (the superuser flag) captured at login.
    internal_metadata: InternalUserMetadata,
}
1111
#[cfg(test)]
mod tests {
    use super::{AllowedRoles, check_role_allowed};

    /// Checks the role policy of every listener kind against a table of
    /// representative role names.
    #[mz_ore::test]
    fn test_check_role_allowed() {
        // Columns: role, admitted by `Internal`, by `NormalAndInternal`, by `Normal`.
        const CASES: &[(&str, bool, bool, bool)] = &[
            ("mz_system", true, true, false),
            ("mz_support", true, true, false),
            ("mz_analytics", true, true, false),
            ("materialize", false, true, true),
            ("anonymous_http_user", false, true, true),
            ("alex", false, true, true),
            ("external_asdf", false, false, false),
            ("pg_somebody", false, false, false),
            ("mz_unknown", false, false, false),
            ("PUBLIC", false, false, false),
        ];

        for &(role, internal, normal_and_internal, normal) in CASES {
            let expectations = [
                ("Internal", AllowedRoles::Internal, internal),
                (
                    "NormalAndInternal",
                    AllowedRoles::NormalAndInternal,
                    normal_and_internal,
                ),
                ("Normal", AllowedRoles::Normal, normal),
            ];
            for (label, allowed_roles, expected) in expectations {
                assert_eq!(
                    check_role_allowed(role, allowed_roles).is_ok(),
                    expected,
                    "role {role:?} on a {label} listener",
                );
            }
        }
    }
}