1use askama::Template;
13use axum::Json;
14use axum::http::HeaderMap;
15use axum::http::HeaderValue;
16use axum::http::Uri;
17use axum::http::status::StatusCode;
18use axum::response::{Html, IntoResponse, Response};
19use axum_extra::TypedHeader;
20use headers::ContentType;
21use mz_ore::metrics::MetricsRegistry;
22use mz_ore::tracing::TracingHandle;
23use prometheus::Encoder;
24use serde::{Deserialize, Serialize};
25use serde_json::json;
26use tower_http::cors::AllowOrigin;
27use tracing_subscriber::EnvFilter;
28
/// `Content-Type` of the length-delimited Prometheus protobuf exposition
/// format: a stream of `io.prometheus.client.MetricFamily` messages.
///
/// Sent by `handle_prometheus` when the client requests protobuf output.
pub const PROMETHEUS_PROTOBUF_CONTENT_TYPE: &str = concat!(
    "application/vnd.google.protobuf; ",
    "proto=io.prometheus.client.MetricFamily; ",
    "encoding=delimited"
);
34
35fn wants_prometheus_protobuf(headers: &HeaderMap) -> bool {
36 headers
37 .get_all(axum::http::header::ACCEPT)
38 .iter()
39 .filter_map(|v| v.to_str().ok())
40 .any(|v| v.contains(PROMETHEUS_PROTOBUF_CONTENT_TYPE))
41}
42
43pub fn template_response<T>(template: T) -> Html<String>
45where
46 T: Template,
47{
48 Html(template.render().expect("template rendering cannot fail"))
49}
50
/// Defines `pub async fn handle_static(path: Path<String>)`, an axum handler
/// that serves static assets.
///
/// Arguments:
/// * `dir_1`: an `include_dir::Dir` expression, searched first.
/// * `dir_2` (optional): an `include_dir::Dir` expression, searched when the
///   file is not in `dir_1`.
/// * `prod_base_path` / `dev_base_path`: directories, relative to the invoking
///   crate's `CARGO_MANIFEST_DIR`, that are read from disk when the invoking
///   crate's `dev-web` feature is enabled (`dev_base_path` is tried first).
///
/// Without `dev-web`, files come from the embedded directories. With
/// `dev-web`, every request re-reads the file from disk. That mode does no
/// path sanitization and is rejected with a compile error outside debug
/// builds.
///
/// `.js` and `.css` files get a matching `Content-Type`; other files get none.
/// Missing files yield `404 Not Found`.
#[macro_export]
macro_rules! make_handle_static {
    (
        dir_1: $dir_1:expr,
        $(dir_2: $dir_2:expr,)?
        prod_base_path: $prod_base_path:expr,
        dev_base_path: $dev_base_path:expr$(,)?
    ) => {
        #[allow(clippy::unused_async)]
        pub async fn handle_static(
            path: ::axum::extract::Path<String>,
        ) -> impl ::axum::response::IntoResponse {
            // Looks `path` up in the directories embedded at compile time, in
            // order. `dir_2` is optional, so the list is built from whichever
            // directories were supplied; referencing a `DIR_2` that was never
            // declared would break invocations that omit it.
            #[cfg(not(feature = "dev-web"))]
            fn get_static_file(path: &str) -> Option<&'static [u8]> {
                const DIRS: &[::include_dir::Dir] = &[$dir_1 $(, $dir_2)?];
                DIRS.iter()
                    .find_map(|dir| dir.get_file(path))
                    .map(|f| f.contents())
            }

            // Reads `path` from disk on each request, trying the dev asset
            // directory before the prod one. NOTE: `path` is joined without
            // sanitization, which is why this mode is debug-only.
            #[cfg(feature = "dev-web")]
            fn get_static_file(path: &str) -> Option<Vec<u8>> {
                use ::std::fs;

                #[cfg(not(debug_assertions))]
                compile_error!("cannot enable insecure `dev-web` feature in release mode");

                let dev_path =
                    format!("{}/{}/{}", env!("CARGO_MANIFEST_DIR"), $dev_base_path, path);
                let prod_path = format!(
                    "{}/{}/{}",
                    env!("CARGO_MANIFEST_DIR"),
                    $prod_base_path,
                    path
                );
                match fs::read(dev_path).or_else(|_| fs::read(prod_path)) {
                    Ok(contents) => Some(contents),
                    Err(e) => {
                        ::tracing::debug!("dev-web failed to load static file: {}: {}", path, e);
                        None
                    }
                }
            }

            let path = path.strip_prefix('/').unwrap_or(&path);
            let content_type = match ::std::path::Path::new(path)
                .extension()
                .and_then(|e| e.to_str())
            {
                Some("js") => Some(::axum_extra::TypedHeader(::headers::ContentType::from(
                    ::mime::TEXT_JAVASCRIPT,
                ))),
                Some("css") => Some(::axum_extra::TypedHeader(::headers::ContentType::from(
                    ::mime::TEXT_CSS,
                ))),
                None | Some(_) => None,
            };
            match get_static_file(path) {
                Some(body) => Ok((content_type, body)),
                None => Err((::http::StatusCode::NOT_FOUND, "not found")),
            }
        }
    };
}
124
125#[allow(clippy::unused_async)]
127pub async fn handle_liveness_check() -> impl IntoResponse {
128 (StatusCode::OK, "Liveness check successful!")
129}
130
131#[allow(clippy::unused_async)]
138pub async fn handle_prometheus(
139 registry: &MetricsRegistry,
140 headers: HeaderMap,
141) -> Result<Response, (StatusCode, String)> {
142 let families = registry.gather();
143 let mut buf = Vec::new();
144 let content_type = if wants_prometheus_protobuf(&headers) {
145 prometheus::ProtobufEncoder::new()
146 .encode(&families, &mut buf)
147 .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
148 ContentType::from(
149 PROMETHEUS_PROTOBUF_CONTENT_TYPE
150 .parse::<mime::Mime>()
151 .map_err(|e: mime::FromStrError| {
152 (StatusCode::INTERNAL_SERVER_ERROR, e.to_string())
153 })?,
154 )
155 } else {
156 prometheus::TextEncoder::new()
157 .encode(&families, &mut buf)
158 .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
159 ContentType::text()
160 };
161
162 Ok((TypedHeader(content_type), buf).into_response())
163}
164
165#[derive(Serialize, Deserialize)]
166pub struct DynamicFilterTarget {
167 targets: String,
168}
169
170#[allow(clippy::unused_async)]
172pub async fn handle_reload_tracing_filter(
173 handle: &TracingHandle,
174 reload: fn(&TracingHandle, EnvFilter) -> Result<(), anyhow::Error>,
175 Json(cfg): Json<DynamicFilterTarget>,
176) -> impl IntoResponse {
177 match cfg.targets.parse::<EnvFilter>() {
178 Ok(targets) => match reload(handle, targets) {
179 Ok(()) => (StatusCode::OK, cfg.targets.to_string()),
180 Err(e) => (StatusCode::BAD_REQUEST, e.to_string()),
181 },
182 Err(e) => (StatusCode::BAD_REQUEST, e.to_string()),
183 }
184}
185
186#[allow(clippy::unused_async)]
188pub async fn handle_tracing() -> impl IntoResponse {
189 (
190 StatusCode::OK,
191 Json(json!({
192 "current_level_filter": tracing::level_filters::LevelFilter::current().to_string()
193 })),
194 )
195}
196
197pub fn origin_is_allowed(origin: &HeaderValue, allowed: &[HeaderValue]) -> bool {
200 fn wildcard_origin_match(origin: &HeaderValue, wildcard: &[u8]) -> bool {
201 let Some(origin) = origin.to_str().ok() else {
202 return false;
203 };
204 let Ok(origin) = origin.parse::<Uri>() else {
205 return false;
206 };
207 let Some(host) = origin.host() else {
208 return false;
209 };
210
211 host.as_bytes().ends_with(wildcard)
212 }
213
214 if allowed.iter().any(|o| o.as_bytes() == b"*") {
215 return true;
216 }
217 for val in allowed {
218 if (val.as_bytes().starts_with(b"*.")
219 && wildcard_origin_match(origin, &val.as_bytes()[1..]))
220 || origin == val
221 {
222 return true;
223 }
224 }
225 false
226}
227
228pub fn build_cors_allowed_origin<'a, I>(allowed: I) -> AllowOrigin
233where
234 I: IntoIterator<Item = &'a HeaderValue>,
235{
236 fn wildcard_origin_match(origin: &HeaderValue, wildcard: &[u8]) -> bool {
237 let Some(origin) = origin.to_str().ok() else {
238 return false;
239 };
240 let Ok(origin) = origin.parse::<Uri>() else {
241 return false;
242 };
243 let Some(host) = origin.host() else {
244 return false;
245 };
246
247 host.as_bytes().ends_with(wildcard)
248 }
249
250 let allowed = allowed.into_iter().cloned().collect::<Vec<HeaderValue>>();
251 if allowed.iter().any(|o| o.as_bytes() == b"*") {
252 AllowOrigin::any()
253 } else {
254 AllowOrigin::predicate(move |origin: &HeaderValue, _request_parts: _| {
255 for val in &allowed {
256 if (val.as_bytes().starts_with(b"*.")
257 && wildcard_origin_match(origin, &val.as_bytes()[1..]))
258 || origin == val
259 {
260 return true;
261 }
262 }
263 false
264 })
265 }
266}
267
#[cfg(test)]
mod tests {
    use http::header::{ACCESS_CONTROL_ALLOW_ORIGIN, ORIGIN};
    use http::{HeaderValue, Method, Request, Response};
    use tower::{Service, ServiceBuilder, ServiceExt};
    use tower_http::cors::CorsLayer;

    #[mz_ore::test(tokio::test)]
    async fn test_cors() {
        /// Sends a request with the given `Origin` through `cors` and returns
        /// the response's `Access-Control-Allow-Origin` header, if any.
        async fn test_request(cors: &CorsLayer, origin: &HeaderValue) -> Option<HeaderValue> {
            let mut service = ServiceBuilder::new()
                .layer(cors)
                .service_fn(|_| async { Ok::<_, anyhow::Error>(Response::new("")) });
            let request = Request::builder().header(ORIGIN, origin).body("").unwrap();
            let response = service.ready().await.unwrap().call(request).await.unwrap();
            response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN).cloned()
        }

        #[derive(Default)]
        struct TestCase {
            /// Allow-list passed to `build_cors_allowed_origin`.
            allowed_origins: Vec<HeaderValue>,
            /// Origins that must be allowed and echoed back verbatim in
            /// `Access-Control-Allow-Origin`.
            mirrored_origins: Vec<HeaderValue>,
            /// Origins that must be allowed with a literal `*` in
            /// `Access-Control-Allow-Origin`.
            wildcard_origins: Vec<HeaderValue>,
            /// Origins that must be rejected (no `Access-Control-Allow-Origin`).
            invalid_origins: Vec<HeaderValue>,
        }

        let test_cases = [
            TestCase {
                allowed_origins: vec![HeaderValue::from_static("https://example.org")],
                mirrored_origins: vec![HeaderValue::from_static("https://example.org")],
                invalid_origins: vec![HeaderValue::from_static("https://wrong.com")],
                wildcard_origins: vec![],
            },
            TestCase {
                allowed_origins: vec![HeaderValue::from_static("*.example.org")],
                mirrored_origins: vec![
                    HeaderValue::from_static("https://foo.example.org"),
                    HeaderValue::from_static("https://foo.example.org:8443"),
                    HeaderValue::from_static("https://bar.example.org"),
                    HeaderValue::from_static("https://baz.bar.foo.example.org"),
                ],
                wildcard_origins: vec![],
                invalid_origins: vec![
                    HeaderValue::from_static("https://example.org"),
                    HeaderValue::from_static("https://wrong.com"),
                    HeaderValue::from_static("https://wrong.com:8443"),
                ],
            },
            TestCase {
                allowed_origins: vec![
                    HeaderValue::from_static("*.example.org"),
                    HeaderValue::from_static("https://other.com"),
                ],
                mirrored_origins: vec![
                    HeaderValue::from_static("https://foo.example.org"),
                    HeaderValue::from_static("https://foo.example.org:8443"),
                    HeaderValue::from_static("https://bar.example.org"),
                    HeaderValue::from_static("https://baz.bar.foo.example.org"),
                    HeaderValue::from_static("https://other.com"),
                ],
                wildcard_origins: vec![],
                invalid_origins: vec![HeaderValue::from_static("https://example.org")],
            },
            TestCase {
                allowed_origins: vec![HeaderValue::from_static("*")],
                mirrored_origins: vec![],
                wildcard_origins: vec![
                    HeaderValue::from_static("literally"),
                    HeaderValue::from_static("https://anything.com"),
                ],
                invalid_origins: vec![],
            },
            TestCase {
                allowed_origins: vec![
                    HeaderValue::from_static("*"),
                    HeaderValue::from_static("https://iwillbeignored.com"),
                ],
                mirrored_origins: vec![],
                wildcard_origins: vec![
                    HeaderValue::from_static("literally"),
                    HeaderValue::from_static("https://anything.com"),
                ],
                invalid_origins: vec![],
            },
        ];

        for test_case in test_cases {
            let allowed_origins = &test_case.allowed_origins;
            let cors = CorsLayer::new()
                .allow_methods([Method::GET])
                .allow_origin(super::build_cors_allowed_origin(allowed_origins));
            for valid in &test_case.mirrored_origins {
                let header = test_request(&cors, valid).await;
                assert_eq!(
                    header.as_ref(),
                    Some(valid),
                    "origin {valid:?} unexpectedly not mirrored\n\
                     allowed_origins={allowed_origins:?}",
                );
                // The standalone matcher must agree with the CORS layer.
                assert!(
                    super::origin_is_allowed(valid, allowed_origins),
                    "origin {valid:?} unexpectedly rejected by origin_is_allowed\n\
                     allowed_origins={allowed_origins:?}",
                );
            }
            for valid in &test_case.wildcard_origins {
                let header = test_request(&cors, valid).await;
                assert_eq!(
                    header.as_ref(),
                    Some(&HeaderValue::from_static("*")),
                    "origin {valid:?} unexpectedly not allowed\n\
                     allowed_origins={allowed_origins:?}",
                );
                assert!(
                    super::origin_is_allowed(valid, allowed_origins),
                    "origin {valid:?} unexpectedly rejected by origin_is_allowed\n\
                     allowed_origins={allowed_origins:?}",
                );
            }
            for invalid in &test_case.invalid_origins {
                let header = test_request(&cors, invalid).await;
                assert_eq!(
                    header, None,
                    "origin {invalid:?} unexpectedly allowed\n\
                     allowed_origins={allowed_origins:?}",
                );
                assert!(
                    !super::origin_is_allowed(invalid, allowed_origins),
                    "origin {invalid:?} unexpectedly accepted by origin_is_allowed\n\
                     allowed_origins={allowed_origins:?}",
                );
            }
        }
    }
}