1#![allow(clippy::unused_async)]
18
19use std::borrow::Cow;
20use std::collections::BTreeMap;
21use std::fmt::Debug;
22use std::net::{IpAddr, SocketAddr};
23use std::pin::Pin;
24use std::sync::Arc;
25use std::time::Duration;
26
27use anyhow::Context;
28use async_trait::async_trait;
29use axum::error_handling::HandleErrorLayer;
30use axum::extract::ws::{Message, WebSocket};
31use axum::extract::{ConnectInfo, DefaultBodyLimit, FromRequestParts, Query, Request, State};
32use axum::middleware::{self, Next};
33use axum::response::{IntoResponse, Redirect, Response};
34use axum::{Extension, Json, Router, routing};
35use futures::future::{FutureExt, Shared, TryFutureExt};
36use headers::authorization::{Authorization, Basic, Bearer};
37use headers::{HeaderMapExt, HeaderName};
38use http::header::{AUTHORIZATION, CONTENT_TYPE};
39use http::{Method, StatusCode};
40use hyper_openssl::SslStream;
41use hyper_openssl::client::legacy::MaybeHttpsStream;
42use hyper_util::rt::TokioIo;
43use mz_adapter::session::{Session, SessionConfig};
44use mz_adapter::{AdapterError, AdapterNotice, Client, SessionClient, WebhookAppenderCache};
45use mz_frontegg_auth::{Authenticator as FronteggAuthentication, Error as FronteggError};
46use mz_http_util::DynamicFilterTarget;
47use mz_ore::cast::u64_to_usize;
48use mz_ore::metrics::MetricsRegistry;
49use mz_ore::now::{NowFn, SYSTEM_TIME, epoch_to_uuid_v7};
50use mz_ore::str::StrExt;
51use mz_pgwire_common::{ConnectionCounter, ConnectionHandle};
52use mz_repr::user::ExternalUserMetadata;
53use mz_server_core::{Connection, ConnectionHandler, ReloadingSslContext, Server};
54use mz_sql::session::metadata::SessionMetadata;
55use mz_sql::session::user::{HTTP_DEFAULT_USER, SUPPORT_USER_NAME, SYSTEM_USER_NAME};
56use mz_sql::session::vars::{Value, Var, VarInput, WELCOME_MESSAGE};
57use openssl::ssl::Ssl;
58use prometheus::{
59 COMPUTE_METRIC_QUERIES, FRONTIER_METRIC_QUERIES, STORAGE_METRIC_QUERIES, USAGE_METRIC_QUERIES,
60};
61use serde::Deserialize;
62use serde_json::json;
63use thiserror::Error;
64use tokio::io::AsyncWriteExt;
65use tokio::sync::{oneshot, watch};
66use tower::limit::GlobalConcurrencyLimitLayer;
67use tower::{Service, ServiceBuilder};
68use tower_http::cors::{AllowOrigin, Any, CorsLayer};
69use tracing::{error, warn};
70
71use crate::BUILD_INFO;
72use crate::deployment::state::DeploymentStateHandle;
73use crate::http::sql::SqlError;
74
75mod catalog;
76mod console;
77mod memory;
78mod metrics;
79mod probe;
80mod prometheus;
81mod root;
82mod sql;
83mod webhook;
84
85pub use metrics::Metrics;
86pub use sql::{SqlResponse, WebSocketAuth, WebSocketResponse};
87
/// Maximum accepted HTTP request body size, in bytes: two `bytesize::MB`.
///
/// Enforced on every router in this module through `DefaultBodyLimit` in
/// `apply_default_layers`.
pub const MAX_REQUEST_SIZE: usize = u64_to_usize(2 * bytesize::MB);
90
/// Configuration for the public-facing [`HttpServer`].
#[derive(Debug, Clone)]
pub struct HttpConfig {
    /// Label passed to the Prometheus request-metrics layer for this server.
    pub source: &'static str,
    /// TLS settings; `None` is treated as [`TlsMode::Disable`].
    pub tls: Option<ReloadingTlsConfig>,
    /// Frontegg authenticator. When `None`, a Basic-auth username is trusted as given
    /// (the password is not checked), and requests without credentials use the HTTP default user.
    pub frontegg: Option<FronteggAuthentication>,
    /// Adapter client used to open SQL sessions and to serve webhooks.
    pub adapter_client: mz_adapter::Client,
    /// Origins allowed by the CORS layer on the base routes.
    pub allowed_origin: AllowOrigin,
    /// Counter from which each opened session allocates a connection slot.
    pub active_connection_counter: ConnectionCounter,
    /// Helm chart version, forwarded to sessions opened over the SQL WebSocket.
    pub helm_chart_version: Option<String>,
    /// Semaphore bounding concurrent webhook requests; requests beyond the limit are shed.
    pub concurrent_webhook_req: Arc<tokio::sync::Semaphore>,
    /// Request metrics for this server.
    pub metrics: Metrics,
}
103
/// TLS configuration for [`HttpServer`].
#[derive(Debug, Clone)]
pub struct ReloadingTlsConfig {
    /// SSL context; `get()` is called for every new connection, so a reloaded context is
    /// picked up by subsequent connections.
    pub context: ReloadingSslContext,
    /// Whether TLS is required of clients.
    pub mode: TlsMode,
}
109
/// TLS policy enforced by the public HTTP server's auth middleware.
#[derive(Debug, Clone, Copy)]
pub enum TlsMode {
    /// Plaintext HTTP only.
    Disable,
    /// Plaintext requests are rejected with an "HTTPS is required" error.
    Require,
}
115
/// Shared state for the SQL-over-WebSocket route.
#[derive(Clone)]
pub struct WsState {
    /// Frontegg authenticator, if any; `Arc`-wrapped so cloning the state is cheap.
    frontegg: Arc<Option<FronteggAuthentication>>,
    /// Adapter client, possibly not yet available when the server is built.
    adapter_client_rx: Delayed<mz_adapter::Client>,
    /// Counter from which each WebSocket session allocates a connection slot.
    active_connection_counter: ConnectionCounter,
    /// Helm chart version forwarded to new sessions.
    helm_chart_version: Option<String>,
}
123
/// Shared state for the webhook ingestion route.
#[derive(Clone)]
pub struct WebhookState {
    /// Adapter client used to resolve webhook sources.
    adapter_client: mz_adapter::Client,
    /// Cache of webhook appenders; one instance is shared by all webhook requests of a server.
    webhook_cache: WebhookAppenderCache,
}
129
/// The public-facing HTTP server (SQL over HTTP/WebSocket, webhooks, static pages).
#[derive(Debug)]
pub struct HttpServer {
    /// TLS settings applied to every accepted connection; `None` serves plaintext.
    tls: Option<ReloadingTlsConfig>,
    /// Fully layered router, cloned for each connection.
    router: Router,
}
135
impl HttpServer {
    /// Builds the public HTTP server from `config`.
    ///
    /// The router merges three sub-routers:
    /// * the base routes (see `base_router`, without profiling), guarded by `http_auth` and
    ///   a CORS layer restricted to `allowed_origin`;
    /// * the SQL WebSocket route `/api/experimental/sql`, which is *not* wrapped in
    ///   `http_auth` and instead receives a [`WsState`];
    /// * the webhook route, which accepts cross-origin POSTs from any origin and is limited by
    ///   a global concurrency limit backed by `concurrent_webhook_req`.
    ///
    /// The default layers (body size limit, Prometheus metrics labelled `source`) are applied
    /// to the merged router.
    pub fn new(
        HttpConfig {
            source,
            tls,
            frontegg,
            adapter_client,
            allowed_origin,
            active_connection_counter,
            helm_chart_version,
            concurrent_webhook_req,
            metrics,
        }: HttpConfig,
    ) -> HttpServer {
        let tls_mode = tls.as_ref().map(|tls| tls.mode).unwrap_or(TlsMode::Disable);
        let frontegg = Arc::new(frontegg);
        let base_frontegg = Arc::clone(&frontegg);
        // The client is already available, so the channel is filled immediately. Going through
        // a `Delayed` lets handlers share one extractor type with the internal server, whose
        // client is delivered later through a oneshot channel.
        let (adapter_client_tx, adapter_client_rx) = oneshot::channel();
        adapter_client_tx
            .send(adapter_client.clone())
            .expect("rx known to be live");
        let adapter_client_rx = adapter_client_rx.shared();
        let webhook_cache = WebhookAppenderCache::new();

        // Base routes: authenticated by `http_auth` and restricted by CORS.
        let base_router = base_router(BaseRouterConfig { profiling: false })
            .layer(middleware::from_fn(move |req, next| {
                let base_frontegg = Arc::clone(&base_frontegg);
                async move { http_auth(req, next, tls_mode, base_frontegg.as_ref().as_ref()).await }
            }))
            .layer(Extension(adapter_client_rx.clone()))
            .layer(Extension(active_connection_counter.clone()))
            .layer(
                CorsLayer::new()
                    .allow_credentials(false)
                    .allow_headers([
                        AUTHORIZATION,
                        CONTENT_TYPE,
                        HeaderName::from_static("x-materialize-version"),
                    ])
                    .allow_methods(Any)
                    .allow_origin(allowed_origin)
                    .expose_headers(Any)
                    // Preflight results may be cached for one hour.
                    .max_age(Duration::from_secs(60) * 60),
            );
        // WebSocket SQL route: no `http_auth` layer. Credentials presumably arrive over the
        // socket itself (see `init_ws`), using the state provided here.
        let ws_router = Router::new()
            .route("/api/experimental/sql", routing::get(sql::handle_sql_ws))
            .with_state(WsState {
                frontegg,
                adapter_client_rx,
                active_connection_counter,
                helm_chart_version,
            });

        // Webhook route: permissive CORS for POST. Once the semaphore's permits are exhausted,
        // `load_shed` rejects requests, and `handle_load_error` maps that rejection to
        // 429 Too Many Requests.
        // NOTE(review): the `{:name}` captures keep a leading `:` in their names; this is only
        // harmless if `handle_webhook` extracts the path positionally. Confirm.
        let webhook_router = Router::new()
            .route(
                "/api/webhook/{:database}/{:schema}/{:id}",
                routing::post(webhook::handle_webhook),
            )
            .with_state(WebhookState {
                adapter_client,
                webhook_cache,
            })
            .layer(
                CorsLayer::new()
                    .allow_methods(Method::POST)
                    .allow_origin(AllowOrigin::mirror_request())
                    .allow_headers(Any),
            )
            .layer(
                ServiceBuilder::new()
                    .layer(HandleErrorLayer::new(handle_load_error))
                    .load_shed()
                    .layer(GlobalConcurrencyLimitLayer::with_semaphore(
                        concurrent_webhook_req,
                    )),
            );

        let router = Router::new()
            .merge(base_router)
            .merge(ws_router)
            .merge(webhook_router)
            .apply_default_layers(source, metrics);

        HttpServer { tls, router }
    }
}
222
223impl Server for HttpServer {
224 const NAME: &'static str = "http";
225
226 fn handle_connection(&self, conn: Connection) -> ConnectionHandler {
227 let router = self.router.clone();
228 let tls_config = self.tls.clone();
229 let mut conn = TokioIo::new(conn);
230 Box::pin(async {
231 let direct_peer_addr = conn.inner().peer_addr().context("fetching peer addr")?;
232 let peer_addr = conn
233 .inner_mut()
234 .take_proxy_header_address()
235 .await
236 .map(|a| a.source)
237 .unwrap_or(direct_peer_addr);
238
239 let (conn, conn_protocol) = match tls_config {
240 Some(tls_config) => {
241 let mut ssl_stream =
242 SslStream::new(Ssl::new(&tls_config.context.get())?, conn)?;
243 if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
244 let _ = ssl_stream.get_mut().inner_mut().shutdown().await;
245 return Err(e.into());
246 }
247 (MaybeHttpsStream::Https(ssl_stream), ConnProtocol::Https)
248 }
249 _ => (MaybeHttpsStream::Http(conn), ConnProtocol::Http),
250 };
251 let mut make_tower_svc = router
252 .layer(Extension(conn_protocol))
253 .into_make_service_with_connect_info::<SocketAddr>();
254 let tower_svc = make_tower_svc.call(peer_addr).await.unwrap();
255 let hyper_svc = hyper::service::service_fn(|req| tower_svc.clone().call(req));
256 let http = hyper::server::conn::http1::Builder::new();
257 http.serve_connection(conn, hyper_svc)
258 .with_upgrades()
259 .err_into()
260 .await
261 })
262 }
263}
264
/// Configuration for the [`InternalHttpServer`].
pub struct InternalHttpConfig {
    /// Registry exported at `/metrics`, also used to register this server's request metrics.
    pub metrics_registry: MetricsRegistry,
    /// Receives the adapter client once it becomes available.
    pub adapter_client_rx: oneshot::Receiver<mz_adapter::Client>,
    /// Counter from which each opened session allocates a connection slot.
    pub active_connection_counter: ConnectionCounter,
    /// Helm chart version forwarded to sessions opened over the SQL WebSocket.
    pub helm_chart_version: Option<String>,
    /// Handle backing the `/api/leader/*` endpoints.
    pub deployment_state_handle: DeploymentStateHandle,
    /// Optional upstream URL for the internal console proxy.
    pub internal_console_redirect_url: Option<String>,
}
273
/// The internal HTTP server (metrics, probes, debugging and leader-promotion endpoints).
pub struct InternalHttpServer {
    /// Fully layered router, cloned for each connection.
    router: Router,
}
277
278pub async fn handle_leader_status(
279 State(deployment_state_handle): State<DeploymentStateHandle>,
280) -> impl IntoResponse {
281 let status = deployment_state_handle.status();
282 (StatusCode::OK, Json(json!({ "status": status })))
283}
284
285pub async fn handle_leader_promote(
286 State(deployment_state_handle): State<DeploymentStateHandle>,
287) -> impl IntoResponse {
288 match deployment_state_handle.try_promote() {
289 Ok(()) => {
290 let status = StatusCode::OK;
293 let body = Json(json!({
294 "result": "Success",
295 }));
296 (status, body)
297 }
298 Err(()) => {
299 let status = StatusCode::BAD_REQUEST;
302 let body = Json(json!({
303 "result": {"Failure": {"message": "cannot promote leader while initializing"}},
304 }));
305 (status, body)
306 }
307 }
308}
309
310pub async fn handle_leader_skip_catchup(
311 State(deployment_state_handle): State<DeploymentStateHandle>,
312) -> impl IntoResponse {
313 match deployment_state_handle.try_skip_catchup() {
314 Ok(()) => StatusCode::NO_CONTENT.into_response(),
315 Err(()) => {
316 let status = StatusCode::BAD_REQUEST;
317 let body = Json(json!({
318 "message": "cannot skip catchup in this phase of initialization; try again later",
319 }));
320 (status, body).into_response()
321 }
322 }
323}
324
325impl InternalHttpServer {
326 pub fn new(
327 InternalHttpConfig {
328 metrics_registry,
329 adapter_client_rx,
330 active_connection_counter,
331 helm_chart_version,
332 deployment_state_handle,
333 internal_console_redirect_url,
334 }: InternalHttpConfig,
335 ) -> InternalHttpServer {
336 let metrics = Metrics::register_into(&metrics_registry, "mz_internal_http");
337 let console_config = Arc::new(console::ConsoleProxyConfig::new(
338 internal_console_redirect_url,
339 "/internal-console".to_string(),
340 ));
341
342 let adapter_client_rx = adapter_client_rx.shared();
343 let router = base_router(BaseRouterConfig { profiling: true })
344 .route(
345 "/metrics",
346 routing::get(move || async move {
347 mz_http_util::handle_prometheus(&metrics_registry).await
348 }),
349 )
350 .route(
351 "/metrics/mz_usage",
352 routing::get(|client: AuthedClient| async move {
353 let registry = sql::handle_promsql(client, USAGE_METRIC_QUERIES).await;
354 mz_http_util::handle_prometheus(®istry).await
355 }),
356 )
357 .route(
358 "/metrics/mz_frontier",
359 routing::get(|client: AuthedClient| async move {
360 let registry = sql::handle_promsql(client, FRONTIER_METRIC_QUERIES).await;
361 mz_http_util::handle_prometheus(®istry).await
362 }),
363 )
364 .route(
365 "/metrics/mz_compute",
366 routing::get(|client: AuthedClient| async move {
367 let registry = sql::handle_promsql(client, COMPUTE_METRIC_QUERIES).await;
368 mz_http_util::handle_prometheus(®istry).await
369 }),
370 )
371 .route(
372 "/metrics/mz_storage",
373 routing::get(|client: AuthedClient| async move {
374 let registry = sql::handle_promsql(client, STORAGE_METRIC_QUERIES).await;
375 mz_http_util::handle_prometheus(®istry).await
376 }),
377 )
378 .route(
379 "/api/livez",
380 routing::get(mz_http_util::handle_liveness_check),
381 )
382 .route("/api/readyz", routing::get(probe::handle_ready))
383 .route(
384 "/api/opentelemetry/config",
385 routing::put({
386 move |_: axum::Json<DynamicFilterTarget>| async {
387 (
388 StatusCode::BAD_REQUEST,
389 "This endpoint has been replaced. \
390 Use the `opentelemetry_filter` system variable."
391 .to_string(),
392 )
393 }
394 }),
395 )
396 .route(
397 "/api/stderr/config",
398 routing::put({
399 move |_: axum::Json<DynamicFilterTarget>| async {
400 (
401 StatusCode::BAD_REQUEST,
402 "This endpoint has been replaced. \
403 Use the `log_filter` system variable."
404 .to_string(),
405 )
406 }
407 }),
408 )
409 .route("/api/tracing", routing::get(mz_http_util::handle_tracing))
410 .route(
411 "/api/catalog/dump",
412 routing::get(catalog::handle_catalog_dump),
413 )
414 .route(
415 "/api/catalog/check",
416 routing::get(catalog::handle_catalog_check),
417 )
418 .route(
419 "/api/coordinator/check",
420 routing::get(catalog::handle_coordinator_check),
421 )
422 .route(
423 "/api/coordinator/dump",
424 routing::get(catalog::handle_coordinator_dump),
425 )
426 .route(
427 "/internal-console",
428 routing::get(|| async { Redirect::temporary("/internal-console/") }),
429 )
430 .route(
431 "/internal-console/{*path}",
432 routing::get(console::handle_internal_console),
433 )
434 .route(
435 "/internal-console/",
436 routing::get(console::handle_internal_console),
437 )
438 .layer(middleware::from_fn(internal_http_auth))
439 .layer(Extension(adapter_client_rx.clone()))
440 .layer(Extension(console_config))
441 .layer(Extension(active_connection_counter.clone()));
442
443 let ws_router = Router::new()
444 .route("/api/experimental/sql", routing::get(sql::handle_sql_ws))
445 .layer(middleware::from_fn(internal_http_auth))
450 .with_state(WsState {
451 frontegg: Arc::new(None),
452 adapter_client_rx,
453 active_connection_counter,
454 helm_chart_version: helm_chart_version.clone(),
455 });
456
457 let leader_router = Router::new()
458 .route("/api/leader/status", routing::get(handle_leader_status))
459 .route("/api/leader/promote", routing::post(handle_leader_promote))
460 .route(
461 "/api/leader/skip-catchup",
462 routing::post(handle_leader_skip_catchup),
463 )
464 .with_state(deployment_state_handle);
465
466 let router = router
467 .merge(ws_router)
468 .merge(leader_router)
469 .apply_default_layers("internal", metrics);
470
471 InternalHttpServer { router }
472 }
473}
474
475async fn internal_http_auth(mut req: Request, next: Next) -> impl IntoResponse {
476 let user_name = req
477 .headers()
478 .get("x-materialize-user")
479 .map(|h| h.to_str())
480 .unwrap_or_else(|| Ok(SYSTEM_USER_NAME));
481 let user_name = match user_name {
482 Ok(name @ (SUPPORT_USER_NAME | SYSTEM_USER_NAME)) => name.to_string(),
483 _ => {
484 return Err(AuthError::MismatchedUser(format!(
485 "user specified in x-materialize-user must be {SUPPORT_USER_NAME} or {SYSTEM_USER_NAME}"
486 )));
487 }
488 };
489 req.extensions_mut().insert(AuthedUser {
490 name: user_name,
491 external_metadata_rx: None,
492 });
493 Ok(next.run(req).await)
494}
495
496#[async_trait]
497impl Server for InternalHttpServer {
498 const NAME: &'static str = "internal_http";
499
500 fn handle_connection(&self, mut conn: Connection) -> ConnectionHandler {
501 let router = self.router.clone();
502
503 Box::pin(async {
504 let mut make_tower_svc = router.into_make_service_with_connect_info::<SocketAddr>();
505 let direct_peer_addr = conn.peer_addr().context("fetching peer addr")?;
506 let peer_addr = conn
507 .take_proxy_header_address()
508 .await
509 .map(|a| a.source)
510 .unwrap_or(direct_peer_addr);
511 let tower_svc = make_tower_svc.call(peer_addr).await?;
512 let service = hyper::service::service_fn(|req| tower_svc.clone().call(req));
513 let http = hyper::server::conn::http1::Builder::new();
514 http.serve_connection(TokioIo::new(conn), service)
515 .with_upgrades()
516 .err_into()
517 .await
518 })
519 }
520}
521
/// A `T` delivered through a oneshot channel, wrapped in [`Shared`] so the receiver can be
/// cloned and awaited by many handlers; each clone yields the same value once it is sent.
type Delayed<T> = Shared<oneshot::Receiver<T>>;
523
/// Transport over which a request reached [`HttpServer`].
///
/// Inserted as a request extension by [`HttpServer`]'s connection handler and checked
/// against the configured [`TlsMode`] by `http_auth`.
#[derive(Clone)]
enum ConnProtocol {
    /// Plaintext HTTP.
    Http,
    /// HTTP over TLS.
    Https,
}
529
/// Identity established by an auth middleware and stored in the request extensions.
#[derive(Clone, Debug)]
pub struct AuthedUser {
    /// Role name the session will run as.
    name: String,
    /// External user metadata (user id, admin flag). This is set only for
    /// Frontegg-authenticated users; it is `None` otherwise.
    external_metadata_rx: Option<watch::Receiver<ExternalUserMetadata>>,
}
535
/// An adapter session opened on behalf of an [`AuthedUser`].
pub struct AuthedClient {
    /// The started session.
    pub client: SessionClient,
    /// Slot allocated from the active-connection counter. The guard is held for as long as
    /// this client lives.
    pub connection_guard: Option<ConnectionHandle>,
}
540
541impl AuthedClient {
542 async fn new<F>(
543 adapter_client: &Client,
544 user: AuthedUser,
545 peer_addr: IpAddr,
546 active_connection_counter: ConnectionCounter,
547 helm_chart_version: Option<String>,
548 session_config: F,
549 options: BTreeMap<String, String>,
550 now: NowFn,
551 ) -> Result<Self, AdapterError>
552 where
553 F: FnOnce(&mut Session),
554 {
555 let conn_id = adapter_client.new_conn_id()?;
556 let mut session = adapter_client.new_session(SessionConfig {
557 conn_id,
558 uuid: epoch_to_uuid_v7(&(now)()),
559 user: user.name,
560 client_ip: Some(peer_addr),
561 external_metadata_rx: user.external_metadata_rx,
562 internal_user_metadata: None,
564 helm_chart_version,
565 });
566 let connection_guard = active_connection_counter.allocate_connection(session.user())?;
567
568 session_config(&mut session);
569 let system_vars = adapter_client.get_system_vars().await;
570 for (key, val) in options {
571 const LOCAL: bool = false;
572 if let Err(err) =
573 session
574 .vars_mut()
575 .set(&system_vars, &key, VarInput::Flat(&val), LOCAL)
576 {
577 session.add_notice(AdapterNotice::BadStartupSetting {
578 name: key.to_string(),
579 reason: err.to_string(),
580 })
581 }
582 }
583 let adapter_client = adapter_client.startup(session).await?;
584 Ok(AuthedClient {
585 client: adapter_client,
586 connection_guard,
587 })
588 }
589}
590
591impl<S> FromRequestParts<S> for AuthedClient
592where
593 S: Send + Sync,
594{
595 type Rejection = Response;
596
597 fn from_request_parts(
598 req: &mut http::request::Parts,
599 state: &S,
600 ) -> impl Future<Output = Result<Self, Self::Rejection>> + Send {
601 Box::pin(async move {
602 #[derive(Debug, Default, Deserialize)]
603 struct Params {
604 #[serde(default)]
605 options: String,
606 }
607 let params: Query<Params> = Query::from_request_parts(req, state)
608 .await
609 .unwrap_or_default();
610
611 let peer_addr = req
612 .extensions
613 .get::<ConnectInfo<SocketAddr>>()
614 .expect("ConnectInfo extension guaranteed to exist")
615 .0
616 .ip();
617
618 let user = req.extensions.get::<AuthedUser>().unwrap();
619 let adapter_client = req
620 .extensions
621 .get::<Delayed<mz_adapter::Client>>()
622 .unwrap()
623 .clone();
624 let adapter_client = adapter_client.await.map_err(|_| {
625 (StatusCode::INTERNAL_SERVER_ERROR, "adapter client missing").into_response()
626 })?;
627 let active_connection_counter = req.extensions.get::<ConnectionCounter>().unwrap();
628 let helm_chart_version = None;
629
630 let options = if params.options.is_empty() {
631 BTreeMap::<String, String>::default()
634 } else {
635 match serde_json::from_str(¶ms.options) {
636 Ok(options) => options,
637 Err(_e) => {
638 let code = StatusCode::BAD_REQUEST;
640 let msg = format!("Failed to deserialize {} map", "options".quoted());
641 return Err((code, msg).into_response());
642 }
643 }
644 };
645
646 let client = AuthedClient::new(
647 &adapter_client,
648 user.clone(),
649 peer_addr,
650 active_connection_counter.clone(),
651 helm_chart_version,
652 |session| {
653 session
654 .vars_mut()
655 .set_default(WELCOME_MESSAGE.name(), VarInput::Flat(&false.format()))
656 .expect("known to exist")
657 },
658 options,
659 SYSTEM_TIME.clone(),
660 )
661 .await
662 .map_err(|e| {
663 let status = match e {
664 AdapterError::UserSessionsDisallowed | AdapterError::NetworkPolicyDenied(_) => {
665 StatusCode::FORBIDDEN
666 }
667 _ => StatusCode::INTERNAL_SERVER_ERROR,
668 };
669 (status, Json(SqlError::from(e))).into_response()
670 })?;
671
672 Ok(client)
673 })
674 }
675}
676
677#[derive(Debug, Error)]
678enum AuthError {
679 #[error("HTTPS is required")]
680 HttpsRequired,
681 #[error("invalid username in client certificate")]
682 InvalidLogin(String),
683 #[error("{0}")]
684 Frontegg(#[from] FronteggError),
685 #[error("missing authorization header")]
686 MissingHttpAuthentication,
687 #[error("{0}")]
688 MismatchedUser(String),
689 #[error("unexpected credentials")]
690 UnexpectedCredentials,
691}
692
693impl IntoResponse for AuthError {
694 fn into_response(self) -> Response {
695 warn!("HTTP request failed authentication: {}", self);
696 let message = match self {
699 AuthError::HttpsRequired => self.to_string(),
700 _ => "unauthorized".into(),
701 };
702 (
703 StatusCode::UNAUTHORIZED,
704 [(http::header::WWW_AUTHENTICATE, "Basic realm=Materialize")],
705 message,
706 )
707 .into_response()
708 }
709}
710
711async fn http_auth(
712 mut req: Request,
713 next: Next,
714 tls_mode: TlsMode,
715 frontegg: Option<&FronteggAuthentication>,
716) -> impl IntoResponse + use<> {
717 let conn_protocol = req.extensions().get::<ConnProtocol>().unwrap();
720 match (tls_mode, &conn_protocol) {
721 (TlsMode::Disable, ConnProtocol::Http) => {}
722 (TlsMode::Disable, ConnProtocol::Https { .. }) => unreachable!(),
723 (TlsMode::Require, ConnProtocol::Http) => return Err(AuthError::HttpsRequired),
724 (TlsMode::Require, ConnProtocol::Https { .. }) => {}
725 }
726 let creds = match frontegg {
727 None => {
728 if let Some(basic) = req.headers().typed_get::<Authorization<Basic>>() {
732 Credentials::User(basic.username().to_string())
733 } else {
734 Credentials::DefaultUser
735 }
736 }
737 Some(_) => {
738 if let Some(basic) = req.headers().typed_get::<Authorization<Basic>>() {
739 Credentials::Password {
740 username: basic.username().to_string(),
741 password: basic.password().to_string(),
742 }
743 } else if let Some(bearer) = req.headers().typed_get::<Authorization<Bearer>>() {
744 Credentials::Token {
745 token: bearer.token().to_string(),
746 }
747 } else {
748 return Err(AuthError::MissingHttpAuthentication);
749 }
750 }
751 };
752
753 let user = auth(frontegg, creds).await?;
754
755 req.extensions_mut().insert(user);
758
759 Ok(next.run(req).await)
761}
762
763async fn init_ws(
764 WsState {
765 frontegg,
766 adapter_client_rx,
767 active_connection_counter,
768 helm_chart_version,
769 }: &WsState,
770 existing_user: Option<AuthedUser>,
771 peer_addr: IpAddr,
772 ws: &mut WebSocket,
773) -> Result<AuthedClient, anyhow::Error> {
774 let init_msg = ws.recv().await.ok_or_else(|| anyhow::anyhow!("closed"))??;
777 let ws_auth: WebSocketAuth = loop {
778 match init_msg {
779 Message::Text(data) => break serde_json::from_str(&data)?,
780 Message::Binary(data) => break serde_json::from_slice(&data)?,
781 Message::Ping(_) => {
783 continue;
784 }
785 Message::Pong(_) => {
786 continue;
787 }
788 Message::Close(_) => {
789 anyhow::bail!("closed");
790 }
791 }
792 };
793 let (user, options) = match (frontegg.as_ref(), existing_user, ws_auth) {
794 (Some(frontegg), None, ws_auth) => {
795 let (creds, options) = match ws_auth {
796 WebSocketAuth::Basic {
797 user,
798 password,
799 options,
800 } => {
801 let creds = Credentials::Password {
802 username: user,
803 password,
804 };
805 (creds, options)
806 }
807 WebSocketAuth::Bearer { token, options } => {
808 let creds = Credentials::Token { token };
809 (creds, options)
810 }
811 WebSocketAuth::OptionsOnly { options: _ } => {
812 anyhow::bail!("expected auth information");
813 }
814 };
815 (auth(Some(frontegg), creds).await?, options)
816 }
817 (
818 None,
819 None,
820 WebSocketAuth::Basic {
821 user,
822 password: _,
823 options,
824 },
825 ) => (auth(None, Credentials::User(user)).await?, options),
826 (None, Some(existing_user), WebSocketAuth::OptionsOnly { options }) => {
828 (existing_user, options)
829 }
830 (None, Some(_), WebSocketAuth::Basic { .. } | WebSocketAuth::Bearer { .. }) => {
832 warn!("Unexpected bearer or basic auth provided when using user header");
833 anyhow::bail!("unexpected")
834 }
835 (Some(_), Some(_), _) => anyhow::bail!("unexpected"),
837 (None, None, WebSocketAuth::Bearer { .. } | WebSocketAuth::OptionsOnly { .. }) => {
839 warn!("Unexpected auth type when not using frontegg or user header");
840 anyhow::bail!("unexpected")
841 }
842 };
843
844 let client = AuthedClient::new(
845 &adapter_client_rx.clone().await?,
846 user,
847 peer_addr,
848 active_connection_counter.clone(),
849 helm_chart_version.clone(),
850 |_session| (),
851 options,
852 SYSTEM_TIME.clone(),
853 )
854 .await?;
855
856 Ok(client)
857}
858
/// Credentials extracted from a request or a WebSocket auth message, resolved by `auth`.
enum Credentials {
    /// A username trusted as given. Accepted only when Frontegg is not configured.
    User(String),
    /// No credentials at all. Without Frontegg, this maps to the HTTP default user.
    DefaultUser,
    /// Username and password to validate with Frontegg.
    Password { username: String, password: String },
    /// Access token to validate with Frontegg.
    Token { token: String },
}
865
866async fn auth(
867 frontegg: Option<&FronteggAuthentication>,
868 creds: Credentials,
869) -> Result<AuthedUser, AuthError> {
870 let (name, external_metadata_rx) = match (frontegg, creds) {
881 (None, Credentials::DefaultUser) => (HTTP_DEFAULT_USER.name.to_string(), None),
883 (None, Credentials::User(name)) => (name, None),
885 (None, _) => return Err(AuthError::UnexpectedCredentials),
887 (Some(frontegg), creds) => match creds {
893 Credentials::Password { username, password } => {
894 let auth_session = frontegg.authenticate(&username, &password).await?;
895 let user = auth_session.user().into();
896 let external_metadata_rx = Some(auth_session.external_metadata_rx());
897 (user, external_metadata_rx)
898 }
899 Credentials::Token { token } => {
900 let claims = frontegg.validate_access_token(&token, None)?;
901 let (_, external_metadata_rx) = watch::channel(ExternalUserMetadata {
902 user_id: claims.user_id,
903 admin: claims.is_admin,
904 });
905 (claims.user, Some(external_metadata_rx))
906 }
907 Credentials::DefaultUser | Credentials::User(_) => {
908 return Err(AuthError::MissingHttpAuthentication);
909 }
910 },
911 };
912
913 if mz_adapter::catalog::is_reserved_role_name(name.as_str()) {
914 return Err(AuthError::InvalidLogin(name));
915 }
916 Ok(AuthedUser {
917 name,
918 external_metadata_rx,
919 })
920}
921
/// Options for `base_router`.
struct BaseRouterConfig {
    /// When true, the `mz_prof_http` router is nested under `/prof/`. The flag is also passed
    /// to the home-page handler.
    profiling: bool,
}
927
928fn base_router(BaseRouterConfig { profiling }: BaseRouterConfig) -> Router {
931 let mut router = Router::new()
934 .route(
935 "/",
936 routing::get(move || async move { root::handle_home(profiling).await }),
937 )
938 .route("/api/sql", routing::post(sql::handle_sql))
939 .route("/memory", routing::get(memory::handle_memory))
940 .route(
941 "/hierarchical-memory",
942 routing::get(memory::handle_hierarchical_memory),
943 )
944 .route("/static/{*path}", routing::get(root::handle_static));
945
946 if profiling {
947 router = router.nest("/prof/", mz_prof_http::router(&BUILD_INFO));
948 }
949
950 router
951}
952
/// Extension trait that applies the layers common to every router served by this module.
trait DefaultLayers {
    /// Wraps `self` in the shared layers. The [`Router`] implementation applies two: a
    /// request body limit of [`MAX_REQUEST_SIZE`], and a Prometheus request-metrics layer
    /// labelled with `source`.
    fn apply_default_layers(self, source: &'static str, metrics: Metrics) -> Self;
}
958
959impl DefaultLayers for Router {
960 fn apply_default_layers(self, source: &'static str, metrics: Metrics) -> Self {
961 self.layer(DefaultBodyLimit::max(MAX_REQUEST_SIZE))
962 .layer(metrics::PrometheusLayer::new(source, metrics))
963 }
964}
965
966async fn handle_load_error(error: tower::BoxError) -> impl IntoResponse {
973 if error.is::<tower::load_shed::error::Overloaded>() {
974 return (
975 StatusCode::TOO_MANY_REQUESTS,
976 Cow::from("too many requests, try again later"),
977 );
978 }
979
980 (
983 StatusCode::INTERNAL_SERVER_ERROR,
984 Cow::from(format!("Unhandled internal error: {}", error)),
985 )
986}