Skip to main content

mz_http_util/
lib.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

//! HTTP utilities: Prometheus metrics serving, tracing controls, static file
//! handlers, and CORS helpers.

12use std::io::{Read, Write};
13
14use askama::Template;
15use axum::Json;
16use axum::http::HeaderMap;
17use axum::http::HeaderValue;
18use axum::http::Uri;
19use axum::http::status::StatusCode;
20use axum::response::{Html, IntoResponse, Response};
21use axum_extra::TypedHeader;
22use base64::prelude::*;
23use flate2::Compression;
24use flate2::read::GzDecoder;
25use flate2::write::GzEncoder;
26use headers::ContentType;
27use mz_ore::metrics::MetricsRegistry;
28use mz_ore::tracing::TracingHandle;
29use prometheus::Encoder;
30use serde::de::DeserializeOwned;
31use serde::{Deserialize, Serialize};
32use serde_json::json;
33use tower_http::cors::AllowOrigin;
34use tracing_subscriber::EnvFilter;
35
/// Content type of the Prometheus protobuf scrape format: a length-delimited
/// stream of `io.prometheus.client.MetricFamily` messages.
///
/// See <https://prometheus.io/docs/instrumenting/content_negotiation/#protocol-headers>.
pub const PROMETHEUS_PROTOBUF_CONTENT_TYPE: &str = concat!(
    "application/vnd.google.protobuf; ",
    "proto=io.prometheus.client.MetricFamily; ",
    "encoding=delimited"
);
41
/// Name of the request header a caller sends to opt in to receiving
/// [`MATERIALIZE_ENRICH_RULES_HEADER`] on a metrics response.
///
/// Only the header's presence matters; its value is ignored.
pub const MATERIALIZE_ACCEPT_ENRICH_RULES_HEADER: &str = "x-materialize-accept-enrich-rules";
45
/// Name of the response header carrying the [`mz_ore::metrics::Rule`]s
/// registered on the metrics registry, grouped by metric.
///
/// The value is JSON that has been gzipped and then base64-encoded; see
/// [`encode_enrich_rules`] and [`decode_enrich_rules`]. [`handle_prometheus`]
/// only emits it when the request carries
/// [`MATERIALIZE_ACCEPT_ENRICH_RULES_HEADER`] and at least one rule exists.
pub const MATERIALIZE_ENRICH_RULES_HEADER: &str = "x-materialize-enrich-rules";
51
52fn wants_prometheus_protobuf(headers: &HeaderMap) -> bool {
53    headers
54        .get_all(axum::http::header::ACCEPT)
55        .iter()
56        .filter_map(|v| v.to_str().ok())
57        .any(|v| v.contains(PROMETHEUS_PROTOBUF_CONTENT_TYPE))
58}
59
60fn wants_enrich_rules(headers: &HeaderMap) -> bool {
61    headers.contains_key(MATERIALIZE_ACCEPT_ENRICH_RULES_HEADER)
62}
63
64/// Serializes `value` as JSON, gzips it, and base64-encodes the result for
65/// transport in [`MATERIALIZE_ENRICH_RULES_HEADER`].
66///
67/// The same handful of enrichment rules repeat across nearly every metric, so
68/// the JSON is highly compressible. Gzipping keeps the header well under the
69/// typical 8-16KB header-size limit even with hundreds of metrics; we then
70/// base64-encode because HTTP header values must be printable ASCII.
71pub fn encode_enrich_rules<T: Serialize>(value: &T) -> anyhow::Result<String> {
72    let json = serde_json::to_vec(value)?;
73    let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
74    encoder.write_all(&json)?;
75    Ok(BASE64_STANDARD.encode(encoder.finish()?))
76}
77
78pub fn decode_enrich_rules<T: DeserializeOwned>(value: &str) -> anyhow::Result<T> {
79    let compressed = BASE64_STANDARD.decode(value)?;
80    let mut json = Vec::new();
81    GzDecoder::new(&compressed[..]).read_to_end(&mut json)?;
82    Ok(serde_json::from_slice(&json)?)
83}
84
85/// Renders a template into an HTTP response.
86pub fn template_response<T>(template: T) -> Html<String>
87where
88    T: Template,
89{
90    Html(template.render().expect("template rendering cannot fail"))
91}
92
/// Generates a `handle_static` function that serves static content for HTTP servers.
///
/// Arguments, by name and in this order:
///
/// * `dir_1`: an `include_dir::Dir` expression whose files are embedded in the
///   binary and looked up first.
/// * `dir_2` (optional): a second `include_dir::Dir`, consulted when a file is
///   not found in `dir_1`.
/// * `prod_base_path`, `dev_base_path`: crate-local paths to the production and
///   development static files. They are only read (at runtime) when the
///   insecure `dev-web` feature is enabled. In that mode the development copy
///   is preferred and the embedded directories are not used. `dev-web` refuses
///   to compile without `debug_assertions`.
///
/// The generated handler sets a `Content-Type` for `.js` and `.css` files and
/// responds `404 Not Found` when the file does not exist.
#[macro_export]
macro_rules! make_handle_static {
    (
        dir_1: $dir_1:expr,
        $(dir_2: $dir_2:expr,)?
        prod_base_path: $prod_base_path:expr,
        dev_base_path: $dev_base_path:expr$(,)?
    ) => {
        #[allow(clippy::unused_async)]
        pub async fn handle_static(
            path: ::axum::extract::Path<String>,
        ) -> impl ::axum::response::IntoResponse {
            #[cfg(not(feature = "dev-web"))]
            fn get_static_file(path: &str) -> Option<&'static [u8]> {
                const DIR_1: ::include_dir::Dir = $dir_1;
                let file = DIR_1.get_file(path);
                // Fall back to the second directory only when one was supplied.
                // Naming `DIR_2` outside this repetition would break every
                // invocation that omits the optional `dir_2` argument.
                $(
                    const DIR_2: ::include_dir::Dir = $dir_2;
                    let file = file.or_else(|| DIR_2.get_file(path));
                )?
                file.map(|f| f.contents())
            }

            #[cfg(feature = "dev-web")]
            fn get_static_file(path: &str) -> Option<Vec<u8>> {
                use ::std::fs;
                use ::std::path::{Component, Path};

                #[cfg(not(debug_assertions))]
                compile_error!("cannot enable insecure `dev-web` feature in release mode");

                // `path` comes from the request URL and is joined onto
                // filesystem paths below. Refuse `..`, absolute paths, and other
                // non-plain components so a request cannot escape the static
                // directories.
                if !Path::new(path)
                    .components()
                    .all(|c| matches!(c, Component::Normal(_)))
                {
                    ::tracing::debug!("dev-web refused to load static file: {}", path);
                    return None;
                }

                // Prefer the unminified files in static-dev, if they exist.
                let dev_path =
                    format!("{}/{}/{}", env!("CARGO_MANIFEST_DIR"), $dev_base_path, path);
                let prod_path = format!(
                    "{}/{}/{}",
                    env!("CARGO_MANIFEST_DIR"),
                    $prod_base_path,
                    path
                );
                match fs::read(dev_path).or_else(|_| fs::read(prod_path)) {
                    Ok(contents) => Some(contents),
                    Err(e) => {
                        ::tracing::debug!("dev-web failed to load static file: {}: {}", path, e);
                        None
                    }
                }
            }

            let path = path.strip_prefix('/').unwrap_or(&path);
            let content_type = match ::std::path::Path::new(path)
                .extension()
                .and_then(|e| e.to_str())
            {
                Some("js") => Some(::axum_extra::TypedHeader(::headers::ContentType::from(
                    ::mime::TEXT_JAVASCRIPT,
                ))),
                Some("css") => Some(::axum_extra::TypedHeader(::headers::ContentType::from(
                    ::mime::TEXT_CSS,
                ))),
                None | Some(_) => None,
            };
            match get_static_file(path) {
                Some(body) => Ok((content_type, body)),
                None => Err((::http::StatusCode::NOT_FOUND, "not found")),
            }
        }
    };
}
166
167/// Serves a basic liveness check response
168#[allow(clippy::unused_async)]
169pub async fn handle_liveness_check() -> impl IntoResponse {
170    (StatusCode::OK, "Liveness check successful!")
171}
172
173/// Serves metrics from the selected metrics registry variant.
174///
175/// If the caller's `Accept` header advertises support for the Prometheus
176/// protobuf format (`application/vnd.google.protobuf`), the response is a
177/// length-delimited stream of `io.prometheus.client.MetricFamily` messages.
178/// Otherwise the standard Prometheus text format is returned.
179#[allow(clippy::unused_async)]
180pub async fn handle_prometheus(
181    registry: &MetricsRegistry,
182    headers: HeaderMap,
183) -> Result<Response, (StatusCode, String)> {
184    let families = registry.gather();
185    let mut buf = Vec::new();
186    let content_type = if wants_prometheus_protobuf(&headers) {
187        prometheus::ProtobufEncoder::new()
188            .encode(&families, &mut buf)
189            .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
190        ContentType::from(
191            PROMETHEUS_PROTOBUF_CONTENT_TYPE
192                .parse::<mime::Mime>()
193                .map_err(|e: mime::FromStrError| {
194                    (StatusCode::INTERNAL_SERVER_ERROR, e.to_string())
195                })?,
196        )
197    } else {
198        prometheus::TextEncoder::new()
199            .encode(&families, &mut buf)
200            .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
201        ContentType::text()
202    };
203
204    let mut resp = (TypedHeader(content_type), buf).into_response();
205    if wants_enrich_rules(&headers) {
206        let rules_by_metric = registry.rules_by_metric();
207        if !rules_by_metric.is_empty() {
208            let encoded = encode_enrich_rules(&rules_by_metric)
209                .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
210            resp.headers_mut().insert(
211                MATERIALIZE_ENRICH_RULES_HEADER,
212                HeaderValue::from_str(&encoded)
213                    .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?,
214            );
215        }
216    }
217    Ok(resp)
218}
219
220#[derive(Serialize, Deserialize)]
221pub struct DynamicFilterTarget {
222    targets: String,
223}
224
225/// Dynamically reloads a filter for a tracing layer.
226#[allow(clippy::unused_async)]
227pub async fn handle_reload_tracing_filter(
228    handle: &TracingHandle,
229    reload: fn(&TracingHandle, EnvFilter) -> Result<(), anyhow::Error>,
230    Json(cfg): Json<DynamicFilterTarget>,
231) -> impl IntoResponse {
232    match cfg.targets.parse::<EnvFilter>() {
233        Ok(targets) => match reload(handle, targets) {
234            Ok(()) => (StatusCode::OK, cfg.targets.to_string()),
235            Err(e) => (StatusCode::BAD_REQUEST, e.to_string()),
236        },
237        Err(e) => (StatusCode::BAD_REQUEST, e.to_string()),
238    }
239}
240
241/// Returns information about the current status of tracing.
242#[allow(clippy::unused_async)]
243pub async fn handle_tracing() -> impl IntoResponse {
244    (
245        StatusCode::OK,
246        Json(json!({
247            "current_level_filter": tracing::level_filters::LevelFilter::current().to_string()
248        })),
249    )
250}
251
252/// Returns true if `origin` matches any entry in `allowed`. Supports bare `*`
253/// (any origin), exact match, and wildcard subdomains (`*.example.com`).
254pub fn origin_is_allowed(origin: &HeaderValue, allowed: &[HeaderValue]) -> bool {
255    fn wildcard_origin_match(origin: &HeaderValue, wildcard: &[u8]) -> bool {
256        let Some(origin) = origin.to_str().ok() else {
257            return false;
258        };
259        let Ok(origin) = origin.parse::<Uri>() else {
260            return false;
261        };
262        let Some(host) = origin.host() else {
263            return false;
264        };
265
266        host.as_bytes().ends_with(wildcard)
267    }
268
269    if allowed.iter().any(|o| o.as_bytes() == b"*") {
270        return true;
271    }
272    for val in allowed {
273        if (val.as_bytes().starts_with(b"*.")
274            && wildcard_origin_match(origin, &val.as_bytes()[1..]))
275            || origin == val
276        {
277            return true;
278        }
279    }
280    false
281}
282
283/// Construct a CORS policy to allow origins to query us via HTTP. If any bare
284/// '*' is passed, this allows any origin; otherwise, allows a list of origins,
285/// which can include wildcard subdomains. If the allowed origin starts with a
286/// '*', allow anything from that glob. Otherwise check for an exact match.
287pub fn build_cors_allowed_origin<'a, I>(allowed: I) -> AllowOrigin
288where
289    I: IntoIterator<Item = &'a HeaderValue>,
290{
291    fn wildcard_origin_match(origin: &HeaderValue, wildcard: &[u8]) -> bool {
292        let Some(origin) = origin.to_str().ok() else {
293            return false;
294        };
295        let Ok(origin) = origin.parse::<Uri>() else {
296            return false;
297        };
298        let Some(host) = origin.host() else {
299            return false;
300        };
301
302        host.as_bytes().ends_with(wildcard)
303    }
304
305    let allowed = allowed.into_iter().cloned().collect::<Vec<HeaderValue>>();
306    if allowed.iter().any(|o| o.as_bytes() == b"*") {
307        AllowOrigin::any()
308    } else {
309        AllowOrigin::predicate(move |origin: &HeaderValue, _request_parts: _| {
310            for val in &allowed {
311                if (val.as_bytes().starts_with(b"*.")
312                    && wildcard_origin_match(origin, &val.as_bytes()[1..]))
313                    || origin == val
314                {
315                    return true;
316                }
317            }
318            false
319        })
320    }
321}
322
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use axum::http::HeaderMap;
    use http::header::{ACCEPT, ACCESS_CONTROL_ALLOW_ORIGIN, CONTENT_TYPE, ORIGIN};
    use http::{HeaderValue, Method, Request, Response};
    use mz_ore::metric;
    use mz_ore::metrics::{MetricsRegistry, Rule};
    use tower::{Service, ServiceBuilder, ServiceExt};
    use tower_http::cors::CorsLayer;

    use super::{
        MATERIALIZE_ACCEPT_ENRICH_RULES_HEADER, MATERIALIZE_ENRICH_RULES_HEADER,
        PROMETHEUS_PROTOBUF_CONTENT_TYPE, handle_prometheus,
    };

    /// Builds a registry holding one counter that carries one enrichment rule.
    fn registry_with_rules() -> MetricsRegistry {
        let registry = MetricsRegistry::new();
        let _: prometheus::IntCounter = registry.register(metric!(
            name: "mz_test_handle_prometheus_metric",
            help: "test metric carrying a per-metric enrichment rule",
            rules: [
                Rule::ClusterNameLookup {
                    cluster_id_label: "cluster_id".into(),
                    output_label: "cluster_name".into(),
                },
            ],
        ));
        registry
    }

    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function on OS `linux`
    async fn handle_prometheus_emits_rules_header_when_opted_in() {
        let registry = registry_with_rules();
        let mut headers = HeaderMap::new();
        headers.insert(
            MATERIALIZE_ACCEPT_ENRICH_RULES_HEADER,
            HeaderValue::from_static("1"),
        );
        let resp = handle_prometheus(&registry, headers).await.unwrap();
        let value = resp
            .headers()
            .get(MATERIALIZE_ENRICH_RULES_HEADER)
            .expect("rules header present");
        let parsed: BTreeMap<String, Vec<Rule>> =
            super::decode_enrich_rules(value.to_str().unwrap()).unwrap();
        assert_eq!(parsed, registry.rules_by_metric());
    }

    #[mz_ore::test(tokio::test)]
    async fn handle_prometheus_omits_header_without_opt_in() {
        let registry = registry_with_rules();
        let resp = handle_prometheus(&registry, HeaderMap::new())
            .await
            .unwrap();
        assert!(
            resp.headers()
                .get(MATERIALIZE_ENRICH_RULES_HEADER)
                .is_none()
        );
    }

    #[mz_ore::test(tokio::test)]
    async fn handle_prometheus_omits_header_when_no_rules() {
        let registry = MetricsRegistry::new();
        let mut headers = HeaderMap::new();
        headers.insert(
            MATERIALIZE_ACCEPT_ENRICH_RULES_HEADER,
            HeaderValue::from_static("1"),
        );
        let resp = handle_prometheus(&registry, headers).await.unwrap();
        assert!(
            resp.headers()
                .get(MATERIALIZE_ENRICH_RULES_HEADER)
                .is_none()
        );
    }

    /// A request advertising the protobuf format gets a protobuf response.
    #[mz_ore::test(tokio::test)]
    async fn handle_prometheus_negotiates_protobuf() {
        let registry = registry_with_rules();
        let mut headers = HeaderMap::new();
        headers.insert(
            ACCEPT,
            HeaderValue::from_static(PROMETHEUS_PROTOBUF_CONTENT_TYPE),
        );
        let resp = handle_prometheus(&registry, headers).await.unwrap();
        let content_type = resp
            .headers()
            .get(CONTENT_TYPE)
            .expect("content type present")
            .to_str()
            .unwrap();
        assert!(
            content_type.starts_with("application/vnd.google.protobuf"),
            "unexpected content type {content_type:?}",
        );
    }

    #[mz_ore::test(tokio::test)]
    async fn test_cors() {
        async fn test_request(cors: &CorsLayer, origin: &HeaderValue) -> Option<HeaderValue> {
            let mut service = ServiceBuilder::new()
                .layer(cors)
                .service_fn(|_| async { Ok::<_, anyhow::Error>(Response::new("")) });
            let request = Request::builder().header(ORIGIN, origin).body("").unwrap();
            let response = service.ready().await.unwrap().call(request).await.unwrap();
            response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN).cloned()
        }

        #[derive(Default)]
        struct TestCase {
            /// The allowed origins to provide as input.
            allowed_origins: Vec<HeaderValue>,
            /// Request origins that are expected to be mirrored back in the
            /// response.
            mirrored_origins: Vec<HeaderValue>,
            /// Request origins that are expected to be allowed via a `*`
            /// response.
            wildcard_origins: Vec<HeaderValue>,
            /// Request origins that are expected to be rejected.
            invalid_origins: Vec<HeaderValue>,
        }

        let test_cases = [
            TestCase {
                allowed_origins: vec![HeaderValue::from_static("https://example.org")],
                mirrored_origins: vec![HeaderValue::from_static("https://example.org")],
                invalid_origins: vec![HeaderValue::from_static("https://wrong.com")],
                wildcard_origins: vec![],
            },
            TestCase {
                allowed_origins: vec![HeaderValue::from_static("*.example.org")],
                mirrored_origins: vec![
                    HeaderValue::from_static("https://foo.example.org"),
                    HeaderValue::from_static("https://foo.example.org:8443"),
                    HeaderValue::from_static("https://bar.example.org"),
                    HeaderValue::from_static("https://baz.bar.foo.example.org"),
                ],
                wildcard_origins: vec![],
                invalid_origins: vec![
                    HeaderValue::from_static("https://example.org"),
                    HeaderValue::from_static("https://wrong.com"),
                    HeaderValue::from_static("https://wrong.com:8443"),
                ],
            },
            TestCase {
                allowed_origins: vec![
                    HeaderValue::from_static("*.example.org"),
                    HeaderValue::from_static("https://other.com"),
                ],
                mirrored_origins: vec![
                    HeaderValue::from_static("https://foo.example.org"),
                    HeaderValue::from_static("https://foo.example.org:8443"),
                    HeaderValue::from_static("https://bar.example.org"),
                    HeaderValue::from_static("https://baz.bar.foo.example.org"),
                    HeaderValue::from_static("https://other.com"),
                ],
                wildcard_origins: vec![],
                invalid_origins: vec![HeaderValue::from_static("https://example.org")],
            },
            TestCase {
                allowed_origins: vec![HeaderValue::from_static("*")],
                mirrored_origins: vec![],
                wildcard_origins: vec![
                    HeaderValue::from_static("literally"),
                    HeaderValue::from_static("https://anything.com"),
                ],
                invalid_origins: vec![],
            },
            TestCase {
                allowed_origins: vec![
                    HeaderValue::from_static("*"),
                    HeaderValue::from_static("https://iwillbeignored.com"),
                ],
                mirrored_origins: vec![],
                wildcard_origins: vec![
                    HeaderValue::from_static("literally"),
                    HeaderValue::from_static("https://anything.com"),
                ],
                invalid_origins: vec![],
            },
        ];

        for test_case in test_cases {
            let allowed_origins = &test_case.allowed_origins;
            let cors = CorsLayer::new()
                .allow_methods([Method::GET])
                .allow_origin(super::build_cors_allowed_origin(allowed_origins));
            for valid in &test_case.mirrored_origins {
                let header = test_request(&cors, valid).await;
                assert_eq!(
                    header.as_ref(),
                    Some(valid),
                    "origin {valid:?} unexpectedly not mirrored\n\
                     allowed_origins={allowed_origins:?}",
                );
            }
            for valid in &test_case.wildcard_origins {
                let header = test_request(&cors, valid).await;
                assert_eq!(
                    header.as_ref(),
                    Some(&HeaderValue::from_static("*")),
                    "origin {valid:?} unexpectedly not allowed\n\
                     allowed_origins={allowed_origins:?}",
                );
            }
            for invalid in &test_case.invalid_origins {
                let header = test_request(&cors, invalid).await;
                assert_eq!(
                    header, None,
                    "origin {invalid:?} unexpectedly allowed\n\
                     allowed_origins={allowed_origins:?}",
                );
            }
        }
    }
}