Skip to main content

mz_http_util/
lib.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! HTTP utilities.
11
12use askama::Template;
13use axum::Json;
14use axum::http::HeaderValue;
15use axum::http::Uri;
16use axum::http::status::StatusCode;
17use axum::response::{Html, IntoResponse};
18use axum_extra::TypedHeader;
19use headers::ContentType;
20use mz_ore::metrics::MetricsRegistry;
21use mz_ore::tracing::TracingHandle;
22use prometheus::Encoder;
23use serde::{Deserialize, Serialize};
24use serde_json::json;
25use tower_http::cors::AllowOrigin;
26use tracing_subscriber::EnvFilter;
27
28/// Renders a template into an HTTP response.
29pub fn template_response<T>(template: T) -> Html<String>
30where
31    T: Template,
32{
33    Html(template.render().expect("template rendering cannot fail"))
34}
35
#[macro_export]
/// Generates a `handle_static` function that serves static content for HTTP servers.
///
/// The macro takes named arguments, in this order:
///
/// * `dir_1`: an `include_dir::Dir` expression with embedded static files; looked up first.
/// * `dir_2` (optional): a second `include_dir::Dir` expression, consulted only when a file
///   is not found in `dir_1`.
/// * `prod_base_path`: the crate-local path to the production static files.
/// * `dev_base_path`: the crate-local path to the development static files.
///
/// Without the `dev-web` feature, files are served from the embedded directories. With
/// `dev-web`, they are read from disk on every request, preferring `dev_base_path` over
/// `prod_base_path`; enabling `dev-web` without debug assertions is a compile error.
///
/// `.js` and `.css` files are served with a matching `Content-Type`; files that cannot be
/// found produce `404 Not Found`.
macro_rules! make_handle_static {
    (
        dir_1: $dir_1:expr,
        $(dir_2: $dir_2:expr,)?
        prod_base_path: $prod_base_path:expr,
        dev_base_path: $dev_base_path:expr$(,)?
    ) => {
        #[allow(clippy::unused_async)]
        pub async fn handle_static(
            path: ::axum::extract::Path<String>,
        ) -> impl ::axum::response::IntoResponse {
            // Looks `path` up in the embedded directories: `dir_1` first, then `dir_2` when
            // the invocation supplied one. The `DIR_2` constant and its lookup live inside the
            // optional repetition so that invocations without `dir_2` never reference an
            // undefined `DIR_2`.
            #[cfg(not(feature = "dev-web"))]
            fn get_static_file(path: &str) -> Option<&'static [u8]> {
                const DIR_1: ::include_dir::Dir = $dir_1;
                let file = DIR_1.get_file(path);
                $(
                    const DIR_2: ::include_dir::Dir = $dir_2;
                    let file = file.or_else(|| DIR_2.get_file(path));
                )?
                file.map(|f| f.contents())
            }

            #[cfg(feature = "dev-web")]
            fn get_static_file(path: &str) -> Option<Vec<u8>> {
                use ::std::fs;

                #[cfg(not(debug_assertions))]
                compile_error!("cannot enable insecure `dev-web` feature in release mode");

                // Prefer the unminified files in static-dev, if they exist.
                // NOTE(review): `path` comes from the request URL and is spliced into a
                // filesystem path without normalization; tolerable only because `dev-web`
                // is restricted to debug builds above.
                let dev_path =
                    format!("{}/{}/{}", env!("CARGO_MANIFEST_DIR"), $dev_base_path, path);
                let prod_path = format!(
                    "{}/{}/{}",
                    env!("CARGO_MANIFEST_DIR"),
                    $prod_base_path,
                    path
                );
                match fs::read(dev_path).or_else(|_| fs::read(prod_path)) {
                    Ok(contents) => Some(contents),
                    Err(e) => {
                        ::tracing::debug!("dev-web failed to load static file: {}: {}", path, e);
                        None
                    }
                }
            }

            // Tolerate a leading `/` so lookups are relative to the static roots.
            let path = path.strip_prefix('/').unwrap_or(&path);
            let content_type = match ::std::path::Path::new(path)
                .extension()
                .and_then(|e| e.to_str())
            {
                Some("js") => Some(::axum_extra::TypedHeader(::headers::ContentType::from(
                    ::mime::TEXT_JAVASCRIPT,
                ))),
                Some("css") => Some(::axum_extra::TypedHeader(::headers::ContentType::from(
                    ::mime::TEXT_CSS,
                ))),
                None | Some(_) => None,
            };
            match get_static_file(path) {
                Some(body) => Ok((content_type, body)),
                None => Err((::http::StatusCode::NOT_FOUND, "not found")),
            }
        }
    };
}
109
110/// Serves a basic liveness check response
111#[allow(clippy::unused_async)]
112pub async fn handle_liveness_check() -> impl IntoResponse {
113    (StatusCode::OK, "Liveness check successful!")
114}
115
116/// Serves metrics from the selected metrics registry variant.
117#[allow(clippy::unused_async)]
118pub async fn handle_prometheus(registry: &MetricsRegistry) -> impl IntoResponse + use<> {
119    let mut buffer = Vec::new();
120    let encoder = prometheus::TextEncoder::new();
121    encoder
122        .encode(&registry.gather(), &mut buffer)
123        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
124    Ok::<_, (StatusCode, String)>((TypedHeader(ContentType::text()), buffer))
125}
126
127#[derive(Serialize, Deserialize)]
128pub struct DynamicFilterTarget {
129    targets: String,
130}
131
132/// Dynamically reloads a filter for a tracing layer.
133#[allow(clippy::unused_async)]
134pub async fn handle_reload_tracing_filter(
135    handle: &TracingHandle,
136    reload: fn(&TracingHandle, EnvFilter) -> Result<(), anyhow::Error>,
137    Json(cfg): Json<DynamicFilterTarget>,
138) -> impl IntoResponse {
139    match cfg.targets.parse::<EnvFilter>() {
140        Ok(targets) => match reload(handle, targets) {
141            Ok(()) => (StatusCode::OK, cfg.targets.to_string()),
142            Err(e) => (StatusCode::BAD_REQUEST, e.to_string()),
143        },
144        Err(e) => (StatusCode::BAD_REQUEST, e.to_string()),
145    }
146}
147
148/// Returns information about the current status of tracing.
149#[allow(clippy::unused_async)]
150pub async fn handle_tracing() -> impl IntoResponse {
151    (
152        StatusCode::OK,
153        Json(json!({
154            "current_level_filter": tracing::level_filters::LevelFilter::current().to_string()
155        })),
156    )
157}
158
159/// Returns true if `origin` matches any entry in `allowed`. Supports bare `*`
160/// (any origin), exact match, and wildcard subdomains (`*.example.com`).
161pub fn origin_is_allowed(origin: &HeaderValue, allowed: &[HeaderValue]) -> bool {
162    fn wildcard_origin_match(origin: &HeaderValue, wildcard: &[u8]) -> bool {
163        let Some(origin) = origin.to_str().ok() else {
164            return false;
165        };
166        let Ok(origin) = origin.parse::<Uri>() else {
167            return false;
168        };
169        let Some(host) = origin.host() else {
170            return false;
171        };
172
173        host.as_bytes().ends_with(wildcard)
174    }
175
176    if allowed.iter().any(|o| o.as_bytes() == b"*") {
177        return true;
178    }
179    for val in allowed {
180        if (val.as_bytes().starts_with(b"*.")
181            && wildcard_origin_match(origin, &val.as_bytes()[1..]))
182            || origin == val
183        {
184            return true;
185        }
186    }
187    false
188}
189
190/// Construct a CORS policy to allow origins to query us via HTTP. If any bare
191/// '*' is passed, this allows any origin; otherwise, allows a list of origins,
192/// which can include wildcard subdomains. If the allowed origin starts with a
193/// '*', allow anything from that glob. Otherwise check for an exact match.
194pub fn build_cors_allowed_origin<'a, I>(allowed: I) -> AllowOrigin
195where
196    I: IntoIterator<Item = &'a HeaderValue>,
197{
198    fn wildcard_origin_match(origin: &HeaderValue, wildcard: &[u8]) -> bool {
199        let Some(origin) = origin.to_str().ok() else {
200            return false;
201        };
202        let Ok(origin) = origin.parse::<Uri>() else {
203            return false;
204        };
205        let Some(host) = origin.host() else {
206            return false;
207        };
208
209        host.as_bytes().ends_with(wildcard)
210    }
211
212    let allowed = allowed.into_iter().cloned().collect::<Vec<HeaderValue>>();
213    if allowed.iter().any(|o| o.as_bytes() == b"*") {
214        AllowOrigin::any()
215    } else {
216        AllowOrigin::predicate(move |origin: &HeaderValue, _request_parts: _| {
217            for val in &allowed {
218                if (val.as_bytes().starts_with(b"*.")
219                    && wildcard_origin_match(origin, &val.as_bytes()[1..]))
220                    || origin == val
221                {
222                    return true;
223                }
224            }
225            false
226        })
227    }
228}
229
#[cfg(test)]
mod tests {
    use http::header::{ACCESS_CONTROL_ALLOW_ORIGIN, ORIGIN};
    use http::{HeaderValue, Method, Request, Response};
    use tower::{Service, ServiceBuilder, ServiceExt};
    use tower_http::cors::CorsLayer;

    /// Exercises `origin_is_allowed` directly, without going through a CORS layer.
    #[mz_ore::test]
    fn test_origin_is_allowed() {
        let allowed = [
            HeaderValue::from_static("*.example.org"),
            HeaderValue::from_static("https://other.com"),
        ];
        for (origin, expected) in [
            ("https://foo.example.org", true),
            ("https://foo.example.org:8443", true),
            ("https://baz.bar.foo.example.org", true),
            ("https://other.com", true),
            ("https://example.org", false),
            // A bare suffix match without the leading dot must not be accepted.
            ("https://evilexample.org", false),
            ("https://other.com.evil.net", false),
            ("https://wrong.com", false),
        ] {
            let origin = HeaderValue::from_static(origin);
            assert_eq!(
                super::origin_is_allowed(&origin, &allowed),
                expected,
                "origin {origin:?} with allowed={allowed:?}",
            );
        }

        // A bare `*` admits anything, even values that are not URIs.
        let any = [HeaderValue::from_static("*")];
        assert!(super::origin_is_allowed(
            &HeaderValue::from_static("literally"),
            &any
        ));
        // An empty allowlist admits nothing.
        assert!(!super::origin_is_allowed(
            &HeaderValue::from_static("https://example.org"),
            &[]
        ));
    }

    #[mz_ore::test(tokio::test)]
    async fn test_cors() {
        async fn test_request(cors: &CorsLayer, origin: &HeaderValue) -> Option<HeaderValue> {
            let mut service = ServiceBuilder::new()
                .layer(cors)
                .service_fn(|_| async { Ok::<_, anyhow::Error>(Response::new("")) });
            let request = Request::builder().header(ORIGIN, origin).body("").unwrap();
            let response = service.ready().await.unwrap().call(request).await.unwrap();
            response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN).cloned()
        }

        #[derive(Default)]
        struct TestCase {
            /// The allowed origins to provide as input.
            allowed_origins: Vec<HeaderValue>,
            /// Request origins that are expected to be mirrored back in the
            /// response.
            mirrored_origins: Vec<HeaderValue>,
            /// Request origins that are expected to be allowed via a `*`
            /// response.
            wildcard_origins: Vec<HeaderValue>,
            /// Request origins that are expected to be rejected.
            invalid_origins: Vec<HeaderValue>,
        }

        let test_cases = [
            TestCase {
                allowed_origins: vec![HeaderValue::from_static("https://example.org")],
                mirrored_origins: vec![HeaderValue::from_static("https://example.org")],
                invalid_origins: vec![HeaderValue::from_static("https://wrong.com")],
                wildcard_origins: vec![],
            },
            TestCase {
                allowed_origins: vec![HeaderValue::from_static("*.example.org")],
                mirrored_origins: vec![
                    HeaderValue::from_static("https://foo.example.org"),
                    HeaderValue::from_static("https://foo.example.org:8443"),
                    HeaderValue::from_static("https://bar.example.org"),
                    HeaderValue::from_static("https://baz.bar.foo.example.org"),
                ],
                wildcard_origins: vec![],
                invalid_origins: vec![
                    HeaderValue::from_static("https://example.org"),
                    HeaderValue::from_static("https://evilexample.org"),
                    HeaderValue::from_static("https://wrong.com"),
                    HeaderValue::from_static("https://wrong.com:8443"),
                ],
            },
            TestCase {
                allowed_origins: vec![
                    HeaderValue::from_static("*.example.org"),
                    HeaderValue::from_static("https://other.com"),
                ],
                mirrored_origins: vec![
                    HeaderValue::from_static("https://foo.example.org"),
                    HeaderValue::from_static("https://foo.example.org:8443"),
                    HeaderValue::from_static("https://bar.example.org"),
                    HeaderValue::from_static("https://baz.bar.foo.example.org"),
                    HeaderValue::from_static("https://other.com"),
                ],
                wildcard_origins: vec![],
                invalid_origins: vec![
                    HeaderValue::from_static("https://example.org"),
                    HeaderValue::from_static("https://evilexample.org"),
                ],
            },
            TestCase {
                allowed_origins: vec![HeaderValue::from_static("*")],
                mirrored_origins: vec![],
                wildcard_origins: vec![
                    HeaderValue::from_static("literally"),
                    HeaderValue::from_static("https://anything.com"),
                ],
                invalid_origins: vec![],
            },
            TestCase {
                allowed_origins: vec![
                    HeaderValue::from_static("*"),
                    HeaderValue::from_static("https://iwillbeignored.com"),
                ],
                mirrored_origins: vec![],
                wildcard_origins: vec![
                    HeaderValue::from_static("literally"),
                    HeaderValue::from_static("https://anything.com"),
                ],
                invalid_origins: vec![],
            },
        ];

        for test_case in test_cases {
            let allowed_origins = &test_case.allowed_origins;
            let cors = CorsLayer::new()
                .allow_methods([Method::GET])
                .allow_origin(super::build_cors_allowed_origin(allowed_origins));
            for valid in &test_case.mirrored_origins {
                let header = test_request(&cors, valid).await;
                assert_eq!(
                    header.as_ref(),
                    Some(valid),
                    "origin {valid:?} unexpectedly not mirrored\n\
                     allowed_origins={allowed_origins:?}",
                );
            }
            for valid in &test_case.wildcard_origins {
                let header = test_request(&cors, valid).await;
                assert_eq!(
                    header.as_ref(),
                    Some(&HeaderValue::from_static("*")),
                    "origin {valid:?} unexpectedly not allowed\n\
                     allowed_origins={allowed_origins:?}",
                );
            }
            for invalid in &test_case.invalid_origins {
                let header = test_request(&cors, invalid).await;
                // Rejected origins must not receive an Access-Control-Allow-Origin header.
                assert_eq!(
                    header, None,
                    "origin {invalid:?} unexpectedly allowed\n\
                     allowed_origins={allowed_origins:?}",
                );
            }
        }
    }
}