Skip to main content

mz_http_util/
lib.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! HTTP utilities.
11
12use askama::Template;
13use axum::Json;
14use axum::http::HeaderMap;
15use axum::http::HeaderValue;
16use axum::http::Uri;
17use axum::http::status::StatusCode;
18use axum::response::{Html, IntoResponse, Response};
19use axum_extra::TypedHeader;
20use headers::ContentType;
21use mz_ore::metrics::MetricsRegistry;
22use mz_ore::tracing::TracingHandle;
23use prometheus::Encoder;
24use serde::{Deserialize, Serialize};
25use serde_json::json;
26use tower_http::cors::AllowOrigin;
27use tracing_subscriber::EnvFilter;
28
/// `Content-Type` sent with metrics encoded in the Prometheus protobuf
/// (length-delimited `io.prometheus.client.MetricFamily`) scrape format.
/// <https://prometheus.io/docs/instrumenting/content_negotiation/#protocol-headers>
pub const PROMETHEUS_PROTOBUF_CONTENT_TYPE: &str = concat!(
    "application/vnd.google.protobuf; ",
    "proto=io.prometheus.client.MetricFamily; ",
    "encoding=delimited"
);
34
35fn wants_prometheus_protobuf(headers: &HeaderMap) -> bool {
36    headers
37        .get_all(axum::http::header::ACCEPT)
38        .iter()
39        .filter_map(|v| v.to_str().ok())
40        .any(|v| v.contains(PROMETHEUS_PROTOBUF_CONTENT_TYPE))
41}
42
43/// Renders a template into an HTTP response.
44pub fn template_response<T>(template: T) -> Html<String>
45where
46    T: Template,
47{
48    Html(template.render().expect("template rendering cannot fail"))
49}
50
#[macro_export]
/// Generates a `handle_static` function that serves static content for HTTP servers.
///
/// Keyed arguments, in this order:
/// * `dir_1`: an `include_dir::Dir` that is searched first;
/// * `dir_2` (optional): a second `include_dir::Dir`, searched when the file is
///   not found in `dir_1`;
/// * `prod_base_path` and `dev_base_path`: strings naming the (crate-local)
///   paths of the production and development static files.
///
/// Without the `dev-web` feature, files are served from the embedded
/// directories. With `dev-web` (a compile error in release builds), each
/// request reads from `dev_base_path` on disk first and falls back to
/// `prod_base_path`. `.js` and `.css` files are served with an explicit
/// `Content-Type`. Missing files produce `404 not found`.
macro_rules! make_handle_static {
    (
        dir_1: $dir_1:expr,
        $(dir_2: $dir_2:expr,)?
        prod_base_path: $prod_base_path:expr,
        dev_base_path: $dev_base_path:expr$(,)?
    ) => {
        #[allow(clippy::unused_async)]
        pub async fn handle_static(
            path: ::axum::extract::Path<String>,
        ) -> impl ::axum::response::IntoResponse {
            #[cfg(not(feature = "dev-web"))]
            fn get_static_file(path: &str) -> Option<&'static [u8]> {
                const DIR_1: ::include_dir::Dir = $dir_1;
                if let Some(file) = DIR_1.get_file(path) {
                    return Some(file.contents());
                }
                // `dir_2` is optional. The fallback lookup, and the `DIR_2`
                // it reads, are emitted only when the caller supplied one.
                // Referencing `DIR_2` unconditionally fails to compile for
                // single-directory invocations.
                $(
                    const DIR_2: ::include_dir::Dir = $dir_2;
                    if let Some(file) = DIR_2.get_file(path) {
                        return Some(file.contents());
                    }
                )?
                None
            }

            #[cfg(feature = "dev-web")]
            fn get_static_file(path: &str) -> Option<Vec<u8>> {
                use ::std::fs;

                #[cfg(not(debug_assertions))]
                compile_error!("cannot enable insecure `dev-web` feature in release mode");

                // Prefer the unminified files in static-dev, if they exist.
                let dev_path =
                    format!("{}/{}/{}", env!("CARGO_MANIFEST_DIR"), $dev_base_path, path);
                let prod_path = format!(
                    "{}/{}/{}",
                    env!("CARGO_MANIFEST_DIR"),
                    $prod_base_path,
                    path
                );
                match fs::read(dev_path).or_else(|_| fs::read(prod_path)) {
                    Ok(contents) => Some(contents),
                    Err(e) => {
                        ::tracing::debug!("dev-web failed to load static file: {}: {}", path, e);
                        None
                    }
                }
            }

            // Strip a single leading `/` before looking the file up.
            let path = path.strip_prefix('/').unwrap_or(&path);
            let content_type = match ::std::path::Path::new(path)
                .extension()
                .and_then(|e| e.to_str())
            {
                Some("js") => Some(::axum_extra::TypedHeader(::headers::ContentType::from(
                    ::mime::TEXT_JAVASCRIPT,
                ))),
                Some("css") => Some(::axum_extra::TypedHeader(::headers::ContentType::from(
                    ::mime::TEXT_CSS,
                ))),
                None | Some(_) => None,
            };
            match get_static_file(path) {
                Some(body) => Ok((content_type, body)),
                // `axum::http` re-exports the `http` crate, so callers of this
                // macro do not need a direct `http` dependency.
                None => Err((::axum::http::StatusCode::NOT_FOUND, "not found")),
            }
        }
    };
}
124
125/// Serves a basic liveness check response
126#[allow(clippy::unused_async)]
127pub async fn handle_liveness_check() -> impl IntoResponse {
128    (StatusCode::OK, "Liveness check successful!")
129}
130
131/// Serves metrics from the selected metrics registry variant.
132///
133/// If the caller's `Accept` header advertises support for the Prometheus
134/// protobuf format (`application/vnd.google.protobuf`), the response is a
135/// length-delimited stream of `io.prometheus.client.MetricFamily` messages.
136/// Otherwise the standard Prometheus text format is returned.
137#[allow(clippy::unused_async)]
138pub async fn handle_prometheus(
139    registry: &MetricsRegistry,
140    headers: HeaderMap,
141) -> Result<Response, (StatusCode, String)> {
142    let families = registry.gather();
143    let mut buf = Vec::new();
144    let content_type = if wants_prometheus_protobuf(&headers) {
145        prometheus::ProtobufEncoder::new()
146            .encode(&families, &mut buf)
147            .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
148        ContentType::from(
149            PROMETHEUS_PROTOBUF_CONTENT_TYPE
150                .parse::<mime::Mime>()
151                .map_err(|e: mime::FromStrError| {
152                    (StatusCode::INTERNAL_SERVER_ERROR, e.to_string())
153                })?,
154        )
155    } else {
156        prometheus::TextEncoder::new()
157            .encode(&families, &mut buf)
158            .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
159        ContentType::text()
160    };
161
162    Ok((TypedHeader(content_type), buf).into_response())
163}
164
165#[derive(Serialize, Deserialize)]
166pub struct DynamicFilterTarget {
167    targets: String,
168}
169
170/// Dynamically reloads a filter for a tracing layer.
171#[allow(clippy::unused_async)]
172pub async fn handle_reload_tracing_filter(
173    handle: &TracingHandle,
174    reload: fn(&TracingHandle, EnvFilter) -> Result<(), anyhow::Error>,
175    Json(cfg): Json<DynamicFilterTarget>,
176) -> impl IntoResponse {
177    match cfg.targets.parse::<EnvFilter>() {
178        Ok(targets) => match reload(handle, targets) {
179            Ok(()) => (StatusCode::OK, cfg.targets.to_string()),
180            Err(e) => (StatusCode::BAD_REQUEST, e.to_string()),
181        },
182        Err(e) => (StatusCode::BAD_REQUEST, e.to_string()),
183    }
184}
185
186/// Returns information about the current status of tracing.
187#[allow(clippy::unused_async)]
188pub async fn handle_tracing() -> impl IntoResponse {
189    (
190        StatusCode::OK,
191        Json(json!({
192            "current_level_filter": tracing::level_filters::LevelFilter::current().to_string()
193        })),
194    )
195}
196
197/// Returns true if `origin` matches any entry in `allowed`. Supports bare `*`
198/// (any origin), exact match, and wildcard subdomains (`*.example.com`).
199pub fn origin_is_allowed(origin: &HeaderValue, allowed: &[HeaderValue]) -> bool {
200    fn wildcard_origin_match(origin: &HeaderValue, wildcard: &[u8]) -> bool {
201        let Some(origin) = origin.to_str().ok() else {
202            return false;
203        };
204        let Ok(origin) = origin.parse::<Uri>() else {
205            return false;
206        };
207        let Some(host) = origin.host() else {
208            return false;
209        };
210
211        host.as_bytes().ends_with(wildcard)
212    }
213
214    if allowed.iter().any(|o| o.as_bytes() == b"*") {
215        return true;
216    }
217    for val in allowed {
218        if (val.as_bytes().starts_with(b"*.")
219            && wildcard_origin_match(origin, &val.as_bytes()[1..]))
220            || origin == val
221        {
222            return true;
223        }
224    }
225    false
226}
227
228/// Construct a CORS policy to allow origins to query us via HTTP. If any bare
229/// '*' is passed, this allows any origin; otherwise, allows a list of origins,
230/// which can include wildcard subdomains. If the allowed origin starts with a
231/// '*', allow anything from that glob. Otherwise check for an exact match.
232pub fn build_cors_allowed_origin<'a, I>(allowed: I) -> AllowOrigin
233where
234    I: IntoIterator<Item = &'a HeaderValue>,
235{
236    fn wildcard_origin_match(origin: &HeaderValue, wildcard: &[u8]) -> bool {
237        let Some(origin) = origin.to_str().ok() else {
238            return false;
239        };
240        let Ok(origin) = origin.parse::<Uri>() else {
241            return false;
242        };
243        let Some(host) = origin.host() else {
244            return false;
245        };
246
247        host.as_bytes().ends_with(wildcard)
248    }
249
250    let allowed = allowed.into_iter().cloned().collect::<Vec<HeaderValue>>();
251    if allowed.iter().any(|o| o.as_bytes() == b"*") {
252        AllowOrigin::any()
253    } else {
254        AllowOrigin::predicate(move |origin: &HeaderValue, _request_parts: _| {
255            for val in &allowed {
256                if (val.as_bytes().starts_with(b"*.")
257                    && wildcard_origin_match(origin, &val.as_bytes()[1..]))
258                    || origin == val
259                {
260                    return true;
261                }
262            }
263            false
264        })
265    }
266}
267
#[cfg(test)]
mod tests {
    use http::header::{ACCESS_CONTROL_ALLOW_ORIGIN, ORIGIN};
    use http::{HeaderValue, Method, Request, Response};
    use tower::{Service, ServiceBuilder, ServiceExt};
    use tower_http::cors::CorsLayer;

    /// Checks `build_cors_allowed_origin` end-to-end through a `CorsLayer`,
    /// and checks that `origin_is_allowed` agrees with it on every case.
    #[mz_ore::test(tokio::test)]
    async fn test_cors() {
        /// Sends one request with the given `Origin` through `cors` and
        /// returns the resulting `Access-Control-Allow-Origin` header, if any.
        async fn test_request(cors: &CorsLayer, origin: &HeaderValue) -> Option<HeaderValue> {
            let mut service = ServiceBuilder::new()
                .layer(cors)
                .service_fn(|_| async { Ok::<_, anyhow::Error>(Response::new("")) });
            let request = Request::builder().header(ORIGIN, origin).body("").unwrap();
            let response = service.ready().await.unwrap().call(request).await.unwrap();
            response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN).cloned()
        }

        #[derive(Default)]
        struct TestCase {
            /// The allowed origins to provide as input.
            allowed_origins: Vec<HeaderValue>,
            /// Request origins that are expected to be mirrored back in the
            /// response.
            mirrored_origins: Vec<HeaderValue>,
            /// Request origins that are expected to be allowed via a `*`
            /// response.
            wildcard_origins: Vec<HeaderValue>,
            /// Request origins that are expected to be rejected.
            invalid_origins: Vec<HeaderValue>,
        }

        let test_cases = [
            TestCase {
                allowed_origins: vec![HeaderValue::from_static("https://example.org")],
                mirrored_origins: vec![HeaderValue::from_static("https://example.org")],
                invalid_origins: vec![HeaderValue::from_static("https://wrong.com")],
                wildcard_origins: vec![],
            },
            TestCase {
                allowed_origins: vec![HeaderValue::from_static("*.example.org")],
                mirrored_origins: vec![
                    HeaderValue::from_static("https://foo.example.org"),
                    HeaderValue::from_static("https://foo.example.org:8443"),
                    HeaderValue::from_static("https://bar.example.org"),
                    HeaderValue::from_static("https://baz.bar.foo.example.org"),
                ],
                wildcard_origins: vec![],
                invalid_origins: vec![
                    HeaderValue::from_static("https://example.org"),
                    HeaderValue::from_static("https://wrong.com"),
                    HeaderValue::from_static("https://wrong.com:8443"),
                ],
            },
            TestCase {
                allowed_origins: vec![
                    HeaderValue::from_static("*.example.org"),
                    HeaderValue::from_static("https://other.com"),
                ],
                mirrored_origins: vec![
                    HeaderValue::from_static("https://foo.example.org"),
                    HeaderValue::from_static("https://foo.example.org:8443"),
                    HeaderValue::from_static("https://bar.example.org"),
                    HeaderValue::from_static("https://baz.bar.foo.example.org"),
                    HeaderValue::from_static("https://other.com"),
                ],
                wildcard_origins: vec![],
                invalid_origins: vec![HeaderValue::from_static("https://example.org")],
            },
            TestCase {
                allowed_origins: vec![HeaderValue::from_static("*")],
                mirrored_origins: vec![],
                wildcard_origins: vec![
                    HeaderValue::from_static("literally"),
                    HeaderValue::from_static("https://anything.com"),
                ],
                invalid_origins: vec![],
            },
            TestCase {
                allowed_origins: vec![
                    HeaderValue::from_static("*"),
                    HeaderValue::from_static("https://iwillbeignored.com"),
                ],
                mirrored_origins: vec![],
                wildcard_origins: vec![
                    HeaderValue::from_static("literally"),
                    HeaderValue::from_static("https://anything.com"),
                ],
                invalid_origins: vec![],
            },
        ];

        for test_case in test_cases {
            let allowed_origins = &test_case.allowed_origins;
            let cors = CorsLayer::new()
                .allow_methods([Method::GET])
                .allow_origin(super::build_cors_allowed_origin(allowed_origins));
            for valid in &test_case.mirrored_origins {
                let header = test_request(&cors, valid).await;
                assert_eq!(
                    header.as_ref(),
                    Some(valid),
                    "origin {valid:?} unexpectedly not mirrored\n\
                     allowed_origins={allowed_origins:?}",
                );
                assert!(
                    super::origin_is_allowed(valid, allowed_origins),
                    "origin {valid:?} unexpectedly rejected by origin_is_allowed\n\
                     allowed_origins={allowed_origins:?}",
                );
            }
            for valid in &test_case.wildcard_origins {
                let header = test_request(&cors, valid).await;
                assert_eq!(
                    header.as_ref(),
                    Some(&HeaderValue::from_static("*")),
                    "origin {valid:?} unexpectedly not allowed\n\
                     allowed_origins={allowed_origins:?}",
                );
                assert!(
                    super::origin_is_allowed(valid, allowed_origins),
                    "origin {valid:?} unexpectedly rejected by origin_is_allowed\n\
                     allowed_origins={allowed_origins:?}",
                );
            }
            for invalid in &test_case.invalid_origins {
                let header = test_request(&cors, invalid).await;
                assert_eq!(
                    header, None,
                    "origin {invalid:?} unexpectedly allowed\n\
                     allowed_origins={allowed_origins:?}",
                );
                assert!(
                    !super::origin_is_allowed(invalid, allowed_origins),
                    "origin {invalid:?} unexpectedly allowed by origin_is_allowed\n\
                     allowed_origins={allowed_origins:?}",
                );
            }
        }
    }
}