Skip to main content

mz_http_util/
lib.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! HTTP utilities.
11
12use askama::Template;
13use axum::Json;
14use axum::http::HeaderValue;
15use axum::http::Uri;
16use axum::http::status::StatusCode;
17use axum::response::{Html, IntoResponse};
18use axum_extra::TypedHeader;
19use headers::ContentType;
20use mz_ore::metrics::MetricsRegistry;
21use mz_ore::tracing::TracingHandle;
22use prometheus::Encoder;
23use serde::{Deserialize, Serialize};
24use serde_json::json;
25use tower_http::cors::AllowOrigin;
26use tracing_subscriber::EnvFilter;
27
28/// Renders a template into an HTTP response.
29pub fn template_response<T>(template: T) -> Html<String>
30where
31    T: Template,
32{
33    Html(template.render().expect("template rendering cannot fail"))
34}
35
/// Generates a `handle_static` function that serves static content for HTTP servers.
///
/// Arguments are keyword-style and must be given in this order:
///
/// * `dir_1`: an `include_dir::Dir` expression holding the embedded static content.
/// * `dir_2` (optional): a second `include_dir::Dir` that is consulted when a file is
///   not present in `dir_1`.
/// * `prod_base_path` and `dev_base_path`: strings representing the (crate-local)
///   paths to the production and development static files.
///
/// Without the `dev-web` feature, files are looked up in the embedded directories.
/// With the `dev-web` feature, files are read from disk on every request, preferring
/// `dev_base_path` over `prod_base_path`; enabling that feature without debug
/// assertions is a compile error.
///
/// Only `.js` and `.css` files get an explicit `Content-Type` header. Unknown paths
/// produce a `404 Not Found` response.
#[macro_export]
macro_rules! make_handle_static {
    (
        dir_1: $dir_1:expr,
        $(dir_2: $dir_2:expr,)?
        prod_base_path: $prod_base_path:expr,
        dev_base_path: $dev_base_path:expr$(,)?
    ) => {
        #[allow(clippy::unused_async)]
        pub async fn handle_static(
            path: ::axum::extract::Path<String>,
        ) -> impl ::axum::response::IntoResponse {
            #[cfg(not(feature = "dev-web"))]
            const DIR_1: ::include_dir::Dir = $dir_1;
            $(
                #[cfg(not(feature = "dev-web"))]
                const DIR_2: ::include_dir::Dir = $dir_2;
            )?

            #[cfg(not(feature = "dev-web"))]
            fn get_static_file(path: &str) -> Option<&'static [u8]> {
                let file = DIR_1.get_file(path);
                $(
                    // `DIR_2` only exists when `dir_2` was supplied, so the fallback
                    // lookup has to live inside the same optional repetition. A
                    // repetition must mention its metavariable; `stringify!` does so
                    // without evaluating the directory expression a second time.
                    let _ = stringify!($dir_2);
                    let file = file.or_else(|| DIR_2.get_file(path));
                )?
                file.map(|f| f.contents())
            }

            #[cfg(feature = "dev-web")]
            fn get_static_file(path: &str) -> Option<Vec<u8>> {
                use ::std::fs;

                #[cfg(not(debug_assertions))]
                compile_error!("cannot enable insecure `dev-web` feature in release mode");

                // Prefer the unminified files in static-dev, if they exist.
                let dev_path =
                    format!("{}/{}/{}", env!("CARGO_MANIFEST_DIR"), $dev_base_path, path);
                let prod_path = format!(
                    "{}/{}/{}",
                    env!("CARGO_MANIFEST_DIR"),
                    $prod_base_path,
                    path
                );
                match fs::read(dev_path).or_else(|_| fs::read(prod_path)) {
                    Ok(contents) => Some(contents),
                    Err(e) => {
                        ::tracing::debug!("dev-web failed to load static file: {}: {}", path, e);
                        None
                    }
                }
            }

            // Lookups are relative, so drop a single leading slash if present.
            let path = path.strip_prefix('/').unwrap_or(&path);
            // Derive the content type from the file extension; anything other than
            // JavaScript or CSS is served without an explicit `Content-Type`.
            let content_type = match ::std::path::Path::new(path)
                .extension()
                .and_then(|e| e.to_str())
            {
                Some("js") => Some(::axum_extra::TypedHeader(::headers::ContentType::from(
                    ::mime::TEXT_JAVASCRIPT,
                ))),
                Some("css") => Some(::axum_extra::TypedHeader(::headers::ContentType::from(
                    ::mime::TEXT_CSS,
                ))),
                None | Some(_) => None,
            };
            match get_static_file(path) {
                Some(body) => Ok((content_type, body)),
                None => Err((::http::StatusCode::NOT_FOUND, "not found")),
            }
        }
    };
}
109
110/// Serves a basic liveness check response
111#[allow(clippy::unused_async)]
112pub async fn handle_liveness_check() -> impl IntoResponse {
113    (StatusCode::OK, "Liveness check successful!")
114}
115
116/// Serves metrics from the selected metrics registry variant.
117#[allow(clippy::unused_async)]
118pub async fn handle_prometheus(registry: &MetricsRegistry) -> impl IntoResponse + use<> {
119    let mut buffer = Vec::new();
120    let encoder = prometheus::TextEncoder::new();
121    encoder
122        .encode(&registry.gather(), &mut buffer)
123        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
124    Ok::<_, (StatusCode, String)>((TypedHeader(ContentType::text()), buffer))
125}
126
127#[derive(Serialize, Deserialize)]
128pub struct DynamicFilterTarget {
129    targets: String,
130}
131
132/// Dynamically reloads a filter for a tracing layer.
133#[allow(clippy::unused_async)]
134pub async fn handle_reload_tracing_filter(
135    handle: &TracingHandle,
136    reload: fn(&TracingHandle, EnvFilter) -> Result<(), anyhow::Error>,
137    Json(cfg): Json<DynamicFilterTarget>,
138) -> impl IntoResponse {
139    match cfg.targets.parse::<EnvFilter>() {
140        Ok(targets) => match reload(handle, targets) {
141            Ok(()) => (StatusCode::OK, cfg.targets.to_string()),
142            Err(e) => (StatusCode::BAD_REQUEST, e.to_string()),
143        },
144        Err(e) => (StatusCode::BAD_REQUEST, e.to_string()),
145    }
146}
147
148/// Returns information about the current status of tracing.
149#[allow(clippy::unused_async)]
150pub async fn handle_tracing() -> impl IntoResponse {
151    (
152        StatusCode::OK,
153        Json(json!({
154            "current_level_filter": tracing::level_filters::LevelFilter::current().to_string()
155        })),
156    )
157}
158
159/// Construct a CORS policy to allow origins to query us via HTTP. If any bare
160/// '*' is passed, this allows any origin; otherwise, allows a list of origins,
161/// which can include wildcard subdomains. If the allowed origin starts with a
162/// '*', allow anything from that glob. Otherwise check for an exact match.
163pub fn build_cors_allowed_origin<'a, I>(allowed: I) -> AllowOrigin
164where
165    I: IntoIterator<Item = &'a HeaderValue>,
166{
167    fn wildcard_origin_match(origin: &HeaderValue, wildcard: &[u8]) -> bool {
168        let Some(origin) = origin.to_str().ok() else {
169            return false;
170        };
171        let Ok(origin) = origin.parse::<Uri>() else {
172            return false;
173        };
174        let Some(host) = origin.host() else {
175            return false;
176        };
177
178        host.as_bytes().ends_with(wildcard)
179    }
180
181    let allowed = allowed.into_iter().cloned().collect::<Vec<HeaderValue>>();
182    if allowed.iter().any(|o| o.as_bytes() == b"*") {
183        AllowOrigin::any()
184    } else {
185        AllowOrigin::predicate(move |origin: &HeaderValue, _request_parts: _| {
186            for val in &allowed {
187                if (val.as_bytes().starts_with(b"*.")
188                    && wildcard_origin_match(origin, &val.as_bytes()[1..]))
189                    || origin == val
190                {
191                    return true;
192                }
193            }
194            false
195        })
196    }
197}
198
#[cfg(test)]
mod tests {
    use http::header::{ACCESS_CONTROL_ALLOW_ORIGIN, ORIGIN};
    use http::{HeaderValue, Method, Request, Response};
    use tower::{Service, ServiceBuilder, ServiceExt};
    use tower_http::cors::CorsLayer;

    /// Exercises `build_cors_allowed_origin` through a real `CorsLayer`, checking
    /// the `Access-Control-Allow-Origin` response header for exact, wildcard
    /// subdomain, and allow-any configurations.
    #[mz_ore::test(tokio::test)]
    async fn test_cors() {
        /// Sends a request carrying `origin` through a service wrapped in `cors`
        /// and returns the `Access-Control-Allow-Origin` response header, if any.
        async fn test_request(cors: &CorsLayer, origin: &HeaderValue) -> Option<HeaderValue> {
            let mut service = ServiceBuilder::new()
                .layer(cors)
                .service_fn(|_| async { Ok::<_, anyhow::Error>(Response::new("")) });
            let request = Request::builder().header(ORIGIN, origin).body("").unwrap();
            let response = service.ready().await.unwrap().call(request).await.unwrap();
            response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN).cloned()
        }

        #[derive(Default)]
        struct TestCase {
            /// The allowed origins to provide as input.
            allowed_origins: Vec<HeaderValue>,
            /// Request origins that are expected to be mirrored back in the
            /// response.
            mirrored_origins: Vec<HeaderValue>,
            /// Request origins that are expected to be allowed via a `*`
            /// response.
            wildcard_origins: Vec<HeaderValue>,
            /// Request origins that are expected to be rejected.
            invalid_origins: Vec<HeaderValue>,
        }

        let test_cases = [
            // A single exact origin.
            TestCase {
                allowed_origins: vec![HeaderValue::from_static("https://example.org")],
                mirrored_origins: vec![HeaderValue::from_static("https://example.org")],
                invalid_origins: vec![HeaderValue::from_static("https://wrong.com")],
                wildcard_origins: vec![],
            },
            // A wildcard subdomain: matches any depth of subdomain and any port,
            // but not the bare domain itself.
            TestCase {
                allowed_origins: vec![HeaderValue::from_static("*.example.org")],
                mirrored_origins: vec![
                    HeaderValue::from_static("https://foo.example.org"),
                    HeaderValue::from_static("https://foo.example.org:8443"),
                    HeaderValue::from_static("https://bar.example.org"),
                    HeaderValue::from_static("https://baz.bar.foo.example.org"),
                ],
                wildcard_origins: vec![],
                invalid_origins: vec![
                    HeaderValue::from_static("https://example.org"),
                    HeaderValue::from_static("https://wrong.com"),
                    HeaderValue::from_static("https://wrong.com:8443"),
                ],
            },
            // A wildcard subdomain combined with an exact origin.
            TestCase {
                allowed_origins: vec![
                    HeaderValue::from_static("*.example.org"),
                    HeaderValue::from_static("https://other.com"),
                ],
                mirrored_origins: vec![
                    HeaderValue::from_static("https://foo.example.org"),
                    HeaderValue::from_static("https://foo.example.org:8443"),
                    HeaderValue::from_static("https://bar.example.org"),
                    HeaderValue::from_static("https://baz.bar.foo.example.org"),
                    HeaderValue::from_static("https://other.com"),
                ],
                wildcard_origins: vec![],
                invalid_origins: vec![HeaderValue::from_static("https://example.org")],
            },
            // A bare `*` allows anything.
            TestCase {
                allowed_origins: vec![HeaderValue::from_static("*")],
                mirrored_origins: vec![],
                wildcard_origins: vec![
                    HeaderValue::from_static("literally"),
                    HeaderValue::from_static("https://anything.com"),
                ],
                invalid_origins: vec![],
            },
            // A bare `*` takes precedence over any other entry.
            TestCase {
                allowed_origins: vec![
                    HeaderValue::from_static("*"),
                    HeaderValue::from_static("https://iwillbeignored.com"),
                ],
                mirrored_origins: vec![],
                wildcard_origins: vec![
                    HeaderValue::from_static("literally"),
                    HeaderValue::from_static("https://anything.com"),
                ],
                invalid_origins: vec![],
            },
        ];

        for test_case in test_cases {
            let allowed_origins = &test_case.allowed_origins;
            let cors = CorsLayer::new()
                .allow_methods([Method::GET])
                .allow_origin(super::build_cors_allowed_origin(allowed_origins));
            for valid in &test_case.mirrored_origins {
                let header = test_request(&cors, valid).await;
                assert_eq!(
                    header.as_ref(),
                    Some(valid),
                    "origin {valid:?} unexpectedly not mirrored\n\
                     allowed_origins={allowed_origins:?}",
                );
            }
            for valid in &test_case.wildcard_origins {
                let header = test_request(&cors, valid).await;
                assert_eq!(
                    header.as_ref(),
                    Some(&HeaderValue::from_static("*")),
                    "origin {valid:?} unexpectedly not allowed\n\
                     allowed_origins={allowed_origins:?}",
                );
            }
            for invalid in &test_case.invalid_origins {
                let header = test_request(&cors, invalid).await;
                // This assertion fails when a header WAS returned, i.e. when an
                // origin that should have been rejected was allowed.
                assert_eq!(
                    header, None,
                    "origin {invalid:?} unexpectedly allowed\n\
                     allowed_origins={allowed_origins:?}",
                );
            }
        }
    }
}