1use std::io::{Read, Write};
13
14use askama::Template;
15use axum::Json;
16use axum::http::HeaderMap;
17use axum::http::HeaderValue;
18use axum::http::Uri;
19use axum::http::status::StatusCode;
20use axum::response::{Html, IntoResponse, Response};
21use axum_extra::TypedHeader;
22use base64::prelude::*;
23use flate2::Compression;
24use flate2::read::GzDecoder;
25use flate2::write::GzEncoder;
26use headers::ContentType;
27use mz_ore::metrics::MetricsRegistry;
28use mz_ore::tracing::TracingHandle;
29use prometheus::Encoder;
30use serde::de::DeserializeOwned;
31use serde::{Deserialize, Serialize};
32use serde_json::json;
33use tower_http::cors::AllowOrigin;
34use tracing_subscriber::EnvFilter;
35
/// Content type used both to detect a protobuf request in the `Accept` header
/// and to label a protobuf-encoded Prometheus metrics response.
pub const PROMETHEUS_PROTOBUF_CONTENT_TYPE: &str = concat!(
    "application/vnd.google.protobuf; ",
    "proto=io.prometheus.client.MetricFamily; ",
    "encoding=delimited",
);
41
/// Request header whose mere presence (any value) asks the Prometheus handler
/// to attach the metric enrichment rules to its response.
pub const MATERIALIZE_ACCEPT_ENRICH_RULES_HEADER: &str = "x-materialize-accept-enrich-rules";
45
/// Response header carrying the enrichment rules, encoded with
/// `encode_enrich_rules` (JSON, then gzip, then standard base64).
pub const MATERIALIZE_ENRICH_RULES_HEADER: &str = "x-materialize-enrich-rules";
51
52fn wants_prometheus_protobuf(headers: &HeaderMap) -> bool {
53 headers
54 .get_all(axum::http::header::ACCEPT)
55 .iter()
56 .filter_map(|v| v.to_str().ok())
57 .any(|v| v.contains(PROMETHEUS_PROTOBUF_CONTENT_TYPE))
58}
59
60fn wants_enrich_rules(headers: &HeaderMap) -> bool {
61 headers.contains_key(MATERIALIZE_ACCEPT_ENRICH_RULES_HEADER)
62}
63
64pub fn encode_enrich_rules<T: Serialize>(value: &T) -> anyhow::Result<String> {
72 let json = serde_json::to_vec(value)?;
73 let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
74 encoder.write_all(&json)?;
75 Ok(BASE64_STANDARD.encode(encoder.finish()?))
76}
77
78pub fn decode_enrich_rules<T: DeserializeOwned>(value: &str) -> anyhow::Result<T> {
79 let compressed = BASE64_STANDARD.decode(value)?;
80 let mut json = Vec::new();
81 GzDecoder::new(&compressed[..]).read_to_end(&mut json)?;
82 Ok(serde_json::from_slice(&json)?)
83}
84
85pub fn template_response<T>(template: T) -> Html<String>
87where
88 T: Template,
89{
90 Html(template.render().expect("template rendering cannot fail"))
91}
92
/// Generates a `handle_static` axum handler that serves static files.
///
/// Arguments:
///
/// * `dir_1`: an `include_dir::Dir` expression searched first.
/// * `dir_2` (optional): a second `include_dir::Dir` expression, consulted
///   only when the file is not found in `dir_1`.
/// * `prod_base_path` / `dev_base_path`: directories relative to
///   `CARGO_MANIFEST_DIR` that are read from disk when the `dev-web` feature
///   is enabled. The dev path is tried first, then the prod path.
///
/// The generated handler strips one leading `/` from the requested path, sets
/// a `Content-Type` for `.js` and `.css` files (no header for anything else),
/// and responds `404 not found` when the file does not exist.
#[macro_export]
macro_rules! make_handle_static {
    (
        dir_1: $dir_1:expr,
        $(dir_2: $dir_2:expr,)?
        prod_base_path: $prod_base_path:expr,
        dev_base_path: $dev_base_path:expr$(,)?
    ) => {
        #[allow(clippy::unused_async)]
        pub async fn handle_static(
            path: ::axum::extract::Path<String>,
        ) -> impl ::axum::response::IntoResponse {
            // Embedded lookup: files are compiled into the binary.
            #[cfg(not(feature = "dev-web"))]
            fn get_static_file(path: &str) -> Option<&'static [u8]> {
                const DIR_1: ::include_dir::Dir = $dir_1;
                let file = DIR_1.get_file(path);
                // `dir_2` is optional, so the fallback lookup must live inside
                // the optional repetition; referencing `DIR_2` unconditionally
                // fails to compile when the caller omits `dir_2`.
                $(
                    const DIR_2: ::include_dir::Dir = $dir_2;
                    let file = file.or_else(|| DIR_2.get_file(path));
                )?
                file.map(|f| f.contents())
            }

            // Development lookup: files are read from disk on every request.
            #[cfg(feature = "dev-web")]
            fn get_static_file(path: &str) -> Option<Vec<u8>> {
                use ::std::fs;

                #[cfg(not(debug_assertions))]
                compile_error!("cannot enable insecure `dev-web` feature in release mode");

                let dev_path =
                    format!("{}/{}/{}", env!("CARGO_MANIFEST_DIR"), $dev_base_path, path);
                let prod_path = format!(
                    "{}/{}/{}",
                    env!("CARGO_MANIFEST_DIR"),
                    $prod_base_path,
                    path
                );
                match fs::read(dev_path).or_else(|_| fs::read(prod_path)) {
                    Ok(contents) => Some(contents),
                    Err(e) => {
                        ::tracing::debug!("dev-web failed to load static file: {}: {}", path, e);
                        None
                    }
                }
            }

            let path = path.strip_prefix('/').unwrap_or(&path);
            // Only JavaScript and CSS get an explicit content type.
            let content_type = match ::std::path::Path::new(path)
                .extension()
                .and_then(|e| e.to_str())
            {
                Some("js") => Some(::axum_extra::TypedHeader(::headers::ContentType::from(
                    ::mime::TEXT_JAVASCRIPT,
                ))),
                Some("css") => Some(::axum_extra::TypedHeader(::headers::ContentType::from(
                    ::mime::TEXT_CSS,
                ))),
                None | Some(_) => None,
            };
            match get_static_file(path) {
                Some(body) => Ok((content_type, body)),
                None => Err((::http::StatusCode::NOT_FOUND, "not found")),
            }
        }
    };
}
166
167#[allow(clippy::unused_async)]
169pub async fn handle_liveness_check() -> impl IntoResponse {
170 (StatusCode::OK, "Liveness check successful!")
171}
172
173#[allow(clippy::unused_async)]
180pub async fn handle_prometheus(
181 registry: &MetricsRegistry,
182 headers: HeaderMap,
183) -> Result<Response, (StatusCode, String)> {
184 let families = registry.gather();
185 let mut buf = Vec::new();
186 let content_type = if wants_prometheus_protobuf(&headers) {
187 prometheus::ProtobufEncoder::new()
188 .encode(&families, &mut buf)
189 .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
190 ContentType::from(
191 PROMETHEUS_PROTOBUF_CONTENT_TYPE
192 .parse::<mime::Mime>()
193 .map_err(|e: mime::FromStrError| {
194 (StatusCode::INTERNAL_SERVER_ERROR, e.to_string())
195 })?,
196 )
197 } else {
198 prometheus::TextEncoder::new()
199 .encode(&families, &mut buf)
200 .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
201 ContentType::text()
202 };
203
204 let mut resp = (TypedHeader(content_type), buf).into_response();
205 if wants_enrich_rules(&headers) {
206 let rules_by_metric = registry.rules_by_metric();
207 if !rules_by_metric.is_empty() {
208 let encoded = encode_enrich_rules(&rules_by_metric)
209 .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
210 resp.headers_mut().insert(
211 MATERIALIZE_ENRICH_RULES_HEADER,
212 HeaderValue::from_str(&encoded)
213 .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?,
214 );
215 }
216 }
217 Ok(resp)
218}
219
/// JSON request body accepted by `handle_reload_tracing_filter`.
#[derive(Serialize, Deserialize)]
pub struct DynamicFilterTarget {
    /// Filter directives; the handler parses this string as an `EnvFilter`.
    targets: String,
}
224
225#[allow(clippy::unused_async)]
227pub async fn handle_reload_tracing_filter(
228 handle: &TracingHandle,
229 reload: fn(&TracingHandle, EnvFilter) -> Result<(), anyhow::Error>,
230 Json(cfg): Json<DynamicFilterTarget>,
231) -> impl IntoResponse {
232 match cfg.targets.parse::<EnvFilter>() {
233 Ok(targets) => match reload(handle, targets) {
234 Ok(()) => (StatusCode::OK, cfg.targets.to_string()),
235 Err(e) => (StatusCode::BAD_REQUEST, e.to_string()),
236 },
237 Err(e) => (StatusCode::BAD_REQUEST, e.to_string()),
238 }
239}
240
241#[allow(clippy::unused_async)]
243pub async fn handle_tracing() -> impl IntoResponse {
244 (
245 StatusCode::OK,
246 Json(json!({
247 "current_level_filter": tracing::level_filters::LevelFilter::current().to_string()
248 })),
249 )
250}
251
252pub fn origin_is_allowed(origin: &HeaderValue, allowed: &[HeaderValue]) -> bool {
255 fn wildcard_origin_match(origin: &HeaderValue, wildcard: &[u8]) -> bool {
256 let Some(origin) = origin.to_str().ok() else {
257 return false;
258 };
259 let Ok(origin) = origin.parse::<Uri>() else {
260 return false;
261 };
262 let Some(host) = origin.host() else {
263 return false;
264 };
265
266 host.as_bytes().ends_with(wildcard)
267 }
268
269 if allowed.iter().any(|o| o.as_bytes() == b"*") {
270 return true;
271 }
272 for val in allowed {
273 if (val.as_bytes().starts_with(b"*.")
274 && wildcard_origin_match(origin, &val.as_bytes()[1..]))
275 || origin == val
276 {
277 return true;
278 }
279 }
280 false
281}
282
283pub fn build_cors_allowed_origin<'a, I>(allowed: I) -> AllowOrigin
288where
289 I: IntoIterator<Item = &'a HeaderValue>,
290{
291 fn wildcard_origin_match(origin: &HeaderValue, wildcard: &[u8]) -> bool {
292 let Some(origin) = origin.to_str().ok() else {
293 return false;
294 };
295 let Ok(origin) = origin.parse::<Uri>() else {
296 return false;
297 };
298 let Some(host) = origin.host() else {
299 return false;
300 };
301
302 host.as_bytes().ends_with(wildcard)
303 }
304
305 let allowed = allowed.into_iter().cloned().collect::<Vec<HeaderValue>>();
306 if allowed.iter().any(|o| o.as_bytes() == b"*") {
307 AllowOrigin::any()
308 } else {
309 AllowOrigin::predicate(move |origin: &HeaderValue, _request_parts: _| {
310 for val in &allowed {
311 if (val.as_bytes().starts_with(b"*.")
312 && wildcard_origin_match(origin, &val.as_bytes()[1..]))
313 || origin == val
314 {
315 return true;
316 }
317 }
318 false
319 })
320 }
321}
322
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use axum::http::HeaderMap;
    use http::header::{ACCESS_CONTROL_ALLOW_ORIGIN, ORIGIN};
    use http::{HeaderValue, Method, Request, Response};
    use mz_ore::metric;
    use mz_ore::metrics::{MetricsRegistry, Rule};
    use tower::{Service, ServiceBuilder, ServiceExt};
    use tower_http::cors::CorsLayer;

    use super::{
        MATERIALIZE_ACCEPT_ENRICH_RULES_HEADER, MATERIALIZE_ENRICH_RULES_HEADER, handle_prometheus,
    };

    /// Builds a registry holding one counter that carries an enrichment rule.
    fn registry_with_rules() -> MetricsRegistry {
        let registry = MetricsRegistry::new();
        let _: prometheus::IntCounter = registry.register(metric!(
            name: "mz_test_handle_prometheus_metric",
            help: "test metric carrying a per-metric enrichment rule",
            rules: [
                Rule::ClusterNameLookup {
                    cluster_id_label: "cluster_id".into(),
                    output_label: "cluster_name".into(),
                },
            ],
        ));
        registry
    }

    /// With the opt-in header present, the response carries the rules header
    /// and it round-trips through `decode_enrich_rules`.
    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn handle_prometheus_emits_rules_header_when_opted_in() {
        let registry = registry_with_rules();
        let mut headers = HeaderMap::new();
        headers.insert(
            MATERIALIZE_ACCEPT_ENRICH_RULES_HEADER,
            HeaderValue::from_static("1"),
        );
        let resp = handle_prometheus(&registry, headers).await.unwrap();
        let value = resp
            .headers()
            .get(MATERIALIZE_ENRICH_RULES_HEADER)
            .expect("rules header present");
        let parsed: BTreeMap<String, Vec<Rule>> =
            super::decode_enrich_rules(value.to_str().unwrap()).unwrap();
        assert_eq!(parsed, registry.rules_by_metric());
    }

    /// Without the opt-in header, no rules header is emitted.
    #[mz_ore::test(tokio::test)]
    async fn handle_prometheus_omits_header_without_opt_in() {
        let registry = registry_with_rules();
        let resp = handle_prometheus(&registry, HeaderMap::new())
            .await
            .unwrap();
        assert!(
            resp.headers()
                .get(MATERIALIZE_ENRICH_RULES_HEADER)
                .is_none()
        );
    }

    /// Even with the opt-in header, an empty rule set yields no rules header.
    #[mz_ore::test(tokio::test)]
    async fn handle_prometheus_omits_header_when_no_rules() {
        let registry = MetricsRegistry::new();
        let mut headers = HeaderMap::new();
        headers.insert(
            MATERIALIZE_ACCEPT_ENRICH_RULES_HEADER,
            HeaderValue::from_static("1"),
        );
        let resp = handle_prometheus(&registry, headers).await.unwrap();
        assert!(
            resp.headers()
                .get(MATERIALIZE_ENRICH_RULES_HEADER)
                .is_none()
        );
    }

    /// Drives a `CorsLayer` built from `build_cors_allowed_origin` and checks
    /// the `Access-Control-Allow-Origin` header for a range of origins.
    #[mz_ore::test(tokio::test)]
    async fn test_cors() {
        /// Sends one request with the given `Origin` through the CORS layer
        /// and returns the resulting `Access-Control-Allow-Origin` header.
        async fn test_request(cors: &CorsLayer, origin: &HeaderValue) -> Option<HeaderValue> {
            let mut service = ServiceBuilder::new()
                .layer(cors)
                .service_fn(|_| async { Ok::<_, anyhow::Error>(Response::new("")) });
            let request = Request::builder().header(ORIGIN, origin).body("").unwrap();
            let response = service.ready().await.unwrap().call(request).await.unwrap();
            response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN).cloned()
        }

        #[derive(Default)]
        struct TestCase {
            /// The allow-list handed to `build_cors_allowed_origin`.
            allowed_origins: Vec<HeaderValue>,
            /// Origins expected to be echoed back verbatim.
            mirrored_origins: Vec<HeaderValue>,
            /// Origins expected to be answered with `*`.
            wildcard_origins: Vec<HeaderValue>,
            /// Origins expected to receive no allow-origin header at all.
            invalid_origins: Vec<HeaderValue>,
        }

        let test_cases = [
            TestCase {
                allowed_origins: vec![HeaderValue::from_static("https://example.org")],
                mirrored_origins: vec![HeaderValue::from_static("https://example.org")],
                invalid_origins: vec![HeaderValue::from_static("https://wrong.com")],
                wildcard_origins: vec![],
            },
            TestCase {
                allowed_origins: vec![HeaderValue::from_static("*.example.org")],
                mirrored_origins: vec![
                    HeaderValue::from_static("https://foo.example.org"),
                    HeaderValue::from_static("https://foo.example.org:8443"),
                    HeaderValue::from_static("https://bar.example.org"),
                    HeaderValue::from_static("https://baz.bar.foo.example.org"),
                ],
                wildcard_origins: vec![],
                invalid_origins: vec![
                    HeaderValue::from_static("https://example.org"),
                    HeaderValue::from_static("https://wrong.com"),
                    HeaderValue::from_static("https://wrong.com:8443"),
                ],
            },
            TestCase {
                allowed_origins: vec![
                    HeaderValue::from_static("*.example.org"),
                    HeaderValue::from_static("https://other.com"),
                ],
                mirrored_origins: vec![
                    HeaderValue::from_static("https://foo.example.org"),
                    HeaderValue::from_static("https://foo.example.org:8443"),
                    HeaderValue::from_static("https://bar.example.org"),
                    HeaderValue::from_static("https://baz.bar.foo.example.org"),
                    HeaderValue::from_static("https://other.com"),
                ],
                wildcard_origins: vec![],
                invalid_origins: vec![HeaderValue::from_static("https://example.org")],
            },
            TestCase {
                allowed_origins: vec![HeaderValue::from_static("*")],
                mirrored_origins: vec![],
                wildcard_origins: vec![
                    HeaderValue::from_static("literally"),
                    HeaderValue::from_static("https://anything.com"),
                ],
                invalid_origins: vec![],
            },
            TestCase {
                allowed_origins: vec![
                    HeaderValue::from_static("*"),
                    HeaderValue::from_static("https://iwillbeignored.com"),
                ],
                mirrored_origins: vec![],
                wildcard_origins: vec![
                    HeaderValue::from_static("literally"),
                    HeaderValue::from_static("https://anything.com"),
                ],
                invalid_origins: vec![],
            },
        ];

        for test_case in test_cases {
            let allowed_origins = &test_case.allowed_origins;
            let cors = CorsLayer::new()
                .allow_methods([Method::GET])
                .allow_origin(super::build_cors_allowed_origin(allowed_origins));
            for valid in &test_case.mirrored_origins {
                let header = test_request(&cors, valid).await;
                assert_eq!(
                    header.as_ref(),
                    Some(valid),
                    "origin {valid:?} unexpectedly not mirrored\n\
                     allowed_origins={allowed_origins:?}",
                );
            }
            for valid in &test_case.wildcard_origins {
                let header = test_request(&cors, valid).await;
                assert_eq!(
                    header.as_ref(),
                    Some(&HeaderValue::from_static("*")),
                    "origin {valid:?} unexpectedly not allowed\n\
                     allowed_origins={allowed_origins:?}",
                );
            }
            for invalid in &test_case.invalid_origins {
                let header = test_request(&cors, invalid).await;
                // A header here means a forbidden origin was let through.
                assert_eq!(
                    header, None,
                    "origin {invalid:?} unexpectedly allowed\n\
                     allowed_origins={allowed_origins:?}",
                );
            }
        }
    }
}