mz_http_util/
lib.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! HTTP utilities.
11
12use askama::Template;
13use axum::Json;
14use axum::http::HeaderValue;
15use axum::http::status::StatusCode;
16use axum::response::{Html, IntoResponse};
17use axum_extra::TypedHeader;
18use headers::ContentType;
19use mz_ore::metrics::MetricsRegistry;
20use mz_ore::tracing::TracingHandle;
21use prometheus::Encoder;
22use serde::{Deserialize, Serialize};
23use serde_json::json;
24use tower_http::cors::AllowOrigin;
25use tracing_subscriber::EnvFilter;
26
27/// Renders a template into an HTTP response.
28pub fn template_response<T>(template: T) -> Html<String>
29where
30    T: Template,
31{
32    Html(template.render().expect("template rendering cannot fail"))
33}
34
#[macro_export]
/// Generates a `handle_static` function that serves static content for HTTP servers.
///
/// Takes named arguments:
/// - `dir_1`: an `include_dir::Dir` containing embedded static content; consulted first.
/// - `dir_2` (optional): a second `include_dir::Dir`, consulted only when a file is not
///   found in `dir_1`.
/// - `prod_base_path` / `dev_base_path`: the crate-local paths (relative to
///   `CARGO_MANIFEST_DIR`) of the production and development static files.
///
/// Without the `dev-web` feature, files are served from the embedded directories. With
/// `dev-web`, files are read from disk on each request, preferring `dev_base_path` over
/// `prod_base_path`; that feature refuses to compile without `debug_assertions`.
///
/// `.js` and `.css` files get an explicit `Content-Type`; other files are served without
/// one. Files that cannot be found produce `404 Not Found`.
macro_rules! make_handle_static {
    (
        dir_1: $dir_1:expr,
        $(dir_2: $dir_2:expr,)?
        prod_base_path: $prod_base_path:expr,
        dev_base_path: $dev_base_path:expr$(,)?
    ) => {
        #[allow(clippy::unused_async)]
        pub async fn handle_static(
            path: ::axum::extract::Path<String>,
        ) -> impl ::axum::response::IntoResponse {
            #[cfg(not(feature = "dev-web"))]
            fn get_static_file(path: &str) -> Option<&'static [u8]> {
                const DIR_1: ::include_dir::Dir = $dir_1;
                let file = DIR_1.get_file(path);
                // `dir_2` is optional: its fallback lookup (and the `DIR_2` const it needs)
                // is only emitted when the caller supplied it. Referencing `DIR_2`
                // unconditionally would fail to compile when `dir_2` is omitted.
                $(
                    let file = file.or_else(|| {
                        const DIR_2: ::include_dir::Dir = $dir_2;
                        DIR_2.get_file(path)
                    });
                )?
                file.map(|f| f.contents())
            }

            #[cfg(feature = "dev-web")]
            fn get_static_file(path: &str) -> Option<Vec<u8>> {
                use ::std::fs;

                #[cfg(not(debug_assertions))]
                compile_error!("cannot enable insecure `dev-web` feature in release mode");

                // NOTE(review): `path` is joined onto the base directories without
                // sanitization, so `..` segments can escape them. This is tolerated only
                // because `dev-web` is restricted to debug builds (see above).

                // Prefer the unminified files in static-dev, if they exist.
                let dev_path =
                    format!("{}/{}/{}", env!("CARGO_MANIFEST_DIR"), $dev_base_path, path);
                let prod_path = format!(
                    "{}/{}/{}",
                    env!("CARGO_MANIFEST_DIR"),
                    $prod_base_path,
                    path
                );
                match fs::read(dev_path).or_else(|_| fs::read(prod_path)) {
                    Ok(contents) => Some(contents),
                    Err(e) => {
                        ::tracing::debug!("dev-web failed to load static file: {}: {}", path, e);
                        None
                    }
                }
            }

            // Lookups are relative to the static roots, so drop a leading slash.
            let path = path.strip_prefix('/').unwrap_or(&path);
            let content_type = match ::std::path::Path::new(path)
                .extension()
                .and_then(|e| e.to_str())
            {
                Some("js") => Some(::axum_extra::TypedHeader(::headers::ContentType::from(
                    ::mime::TEXT_JAVASCRIPT,
                ))),
                Some("css") => Some(::axum_extra::TypedHeader(::headers::ContentType::from(
                    ::mime::TEXT_CSS,
                ))),
                None | Some(_) => None,
            };
            match get_static_file(path) {
                Some(body) => Ok((content_type, body)),
                None => Err((::http::StatusCode::NOT_FOUND, "not found")),
            }
        }
    };
}
108
109/// Serves a basic liveness check response
110#[allow(clippy::unused_async)]
111pub async fn handle_liveness_check() -> impl IntoResponse {
112    (StatusCode::OK, "Liveness check successful!")
113}
114
115/// Serves metrics from the selected metrics registry variant.
116#[allow(clippy::unused_async)]
117pub async fn handle_prometheus(registry: &MetricsRegistry) -> impl IntoResponse + use<> {
118    let mut buffer = Vec::new();
119    let encoder = prometheus::TextEncoder::new();
120    encoder
121        .encode(&registry.gather(), &mut buffer)
122        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
123    Ok::<_, (StatusCode, String)>((TypedHeader(ContentType::text()), buffer))
124}
125
126#[derive(Serialize, Deserialize)]
127pub struct DynamicFilterTarget {
128    targets: String,
129}
130
131/// Dynamically reloads a filter for a tracing layer.
132#[allow(clippy::unused_async)]
133pub async fn handle_reload_tracing_filter(
134    handle: &TracingHandle,
135    reload: fn(&TracingHandle, EnvFilter) -> Result<(), anyhow::Error>,
136    Json(cfg): Json<DynamicFilterTarget>,
137) -> impl IntoResponse {
138    match cfg.targets.parse::<EnvFilter>() {
139        Ok(targets) => match reload(handle, targets) {
140            Ok(()) => (StatusCode::OK, cfg.targets.to_string()),
141            Err(e) => (StatusCode::BAD_REQUEST, e.to_string()),
142        },
143        Err(e) => (StatusCode::BAD_REQUEST, e.to_string()),
144    }
145}
146
147/// Returns information about the current status of tracing.
148#[allow(clippy::unused_async)]
149pub async fn handle_tracing() -> impl IntoResponse {
150    (
151        StatusCode::OK,
152        Json(json!({
153            "current_level_filter": tracing::level_filters::LevelFilter::current().to_string()
154        })),
155    )
156}
157
158/// Construct a CORS policy to allow origins to query us via HTTP. If any bare
159/// '*' is passed, this allows any origin; otherwise, allows a list of origins,
160/// which can include wildcard subdomains. If the allowed origin starts with a
161/// '*', allow anything from that glob. Otherwise check for an exact match.
162pub fn build_cors_allowed_origin<'a, I>(allowed: I) -> AllowOrigin
163where
164    I: IntoIterator<Item = &'a HeaderValue>,
165{
166    let allowed = allowed.into_iter().cloned().collect::<Vec<HeaderValue>>();
167    if allowed.iter().any(|o| o.as_bytes() == b"*") {
168        AllowOrigin::any()
169    } else {
170        AllowOrigin::predicate(move |origin: &HeaderValue, _request_parts: _| {
171            for val in &allowed {
172                if (val.as_bytes().starts_with(b"*.")
173                    && origin.as_bytes().ends_with(&val.as_bytes()[1..]))
174                    || origin == val
175                {
176                    return true;
177                }
178            }
179            false
180        })
181    }
182}
183
#[cfg(test)]
mod tests {
    use http::header::{ACCESS_CONTROL_ALLOW_ORIGIN, ORIGIN};
    use http::{HeaderValue, Method, Request, Response};
    use tower::{Service, ServiceBuilder, ServiceExt};
    use tower_http::cors::CorsLayer;

    #[mz_ore::test(tokio::test)]
    async fn test_cors() {
        /// Sends a request with the given `Origin` through a trivial service wrapped in
        /// `cors` and returns the response's `Access-Control-Allow-Origin` header, if any.
        async fn test_request(cors: &CorsLayer, origin: &HeaderValue) -> Option<HeaderValue> {
            let mut service = ServiceBuilder::new()
                .layer(cors)
                .service_fn(|_| async { Ok::<_, anyhow::Error>(Response::new("")) });
            let request = Request::builder().header(ORIGIN, origin).body("").unwrap();
            let response = service.ready().await.unwrap().call(request).await.unwrap();
            response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN).cloned()
        }

        #[derive(Default)]
        struct TestCase {
            /// The allowed origins to provide as input.
            allowed_origins: Vec<HeaderValue>,
            /// Request origins that are expected to be mirrored back in the
            /// response.
            mirrored_origins: Vec<HeaderValue>,
            /// Request origins that are expected to be allowed via a `*`
            /// response.
            wildcard_origins: Vec<HeaderValue>,
            /// Request origins that are expected to be rejected.
            invalid_origins: Vec<HeaderValue>,
        }

        let test_cases = [
            TestCase {
                allowed_origins: vec![HeaderValue::from_static("https://example.org")],
                mirrored_origins: vec![HeaderValue::from_static("https://example.org")],
                invalid_origins: vec![HeaderValue::from_static("https://wrong.com")],
                wildcard_origins: vec![],
            },
            TestCase {
                allowed_origins: vec![HeaderValue::from_static("*.example.org")],
                mirrored_origins: vec![
                    HeaderValue::from_static("https://foo.example.org"),
                    HeaderValue::from_static("https://bar.example.org"),
                    HeaderValue::from_static("https://baz.bar.foo.example.org"),
                ],
                wildcard_origins: vec![],
                invalid_origins: vec![
                    HeaderValue::from_static("https://example.org"),
                    HeaderValue::from_static("https://wrong.com"),
                    // A wildcard suffix must match at a `.` boundary.
                    HeaderValue::from_static("https://notexample.org"),
                ],
            },
            TestCase {
                allowed_origins: vec![
                    HeaderValue::from_static("*.example.org"),
                    HeaderValue::from_static("https://other.com"),
                ],
                mirrored_origins: vec![
                    HeaderValue::from_static("https://foo.example.org"),
                    HeaderValue::from_static("https://bar.example.org"),
                    HeaderValue::from_static("https://baz.bar.foo.example.org"),
                    HeaderValue::from_static("https://other.com"),
                ],
                wildcard_origins: vec![],
                invalid_origins: vec![HeaderValue::from_static("https://example.org")],
            },
            TestCase {
                allowed_origins: vec![HeaderValue::from_static("*")],
                mirrored_origins: vec![],
                wildcard_origins: vec![
                    HeaderValue::from_static("literally"),
                    HeaderValue::from_static("https://anything.com"),
                ],
                invalid_origins: vec![],
            },
            TestCase {
                allowed_origins: vec![
                    HeaderValue::from_static("*"),
                    HeaderValue::from_static("https://iwillbeignored.com"),
                ],
                mirrored_origins: vec![],
                wildcard_origins: vec![
                    HeaderValue::from_static("literally"),
                    HeaderValue::from_static("https://anything.com"),
                ],
                invalid_origins: vec![],
            },
        ];

        for test_case in test_cases {
            let allowed_origins = &test_case.allowed_origins;
            let cors = CorsLayer::new()
                .allow_methods([Method::GET])
                .allow_origin(super::build_cors_allowed_origin(allowed_origins));
            for valid in &test_case.mirrored_origins {
                let header = test_request(&cors, valid).await;
                assert_eq!(
                    header.as_ref(),
                    Some(valid),
                    "origin {valid:?} unexpectedly not mirrored\n\
                     allowed_origins={allowed_origins:?}",
                );
            }
            for valid in &test_case.wildcard_origins {
                let header = test_request(&cors, valid).await;
                assert_eq!(
                    header.as_ref(),
                    Some(&HeaderValue::from_static("*")),
                    "origin {valid:?} unexpectedly not allowed\n\
                     allowed_origins={allowed_origins:?}",
                );
            }
            for invalid in &test_case.invalid_origins {
                let header = test_request(&cors, invalid).await;
                // A rejected origin must not receive an `Access-Control-Allow-Origin` header.
                assert_eq!(
                    header, None,
                    "origin {invalid:?} unexpectedly allowed\n\
                     allowed_origins={allowed_origins:?}",
                );
            }
        }
    }
}