1use askama::Template;
13use axum::Json;
14use axum::http::HeaderValue;
15use axum::http::Uri;
16use axum::http::status::StatusCode;
17use axum::response::{Html, IntoResponse};
18use axum_extra::TypedHeader;
19use headers::ContentType;
20use mz_ore::metrics::MetricsRegistry;
21use mz_ore::tracing::TracingHandle;
22use prometheus::Encoder;
23use serde::{Deserialize, Serialize};
24use serde_json::json;
25use tower_http::cors::AllowOrigin;
26use tracing_subscriber::EnvFilter;
27
28pub fn template_response<T>(template: T) -> Html<String>
30where
31 T: Template,
32{
33 Html(template.render().expect("template rendering cannot fail"))
34}
35
/// Defines an axum handler named `handle_static` that serves static assets.
///
/// In builds without the `dev-web` feature, files are looked up in the
/// `include_dir::Dir` value `dir_1` and, when `dir_2` is supplied, in `dir_2`
/// as a fallback.
///
/// With the `dev-web` feature (rejected at compile time in release builds),
/// files are read from disk on every request: first from
/// `$CARGO_MANIFEST_DIR/<dev_base_path>/<path>`, then from
/// `$CARGO_MANIFEST_DIR/<prod_base_path>/<path>`, where `CARGO_MANIFEST_DIR`
/// is that of the crate invoking the macro.
///
/// `.js` and `.css` files get an explicit `Content-Type` header; other files
/// get none. A missing file yields `404 Not Found`.
#[macro_export]
macro_rules! make_handle_static {
    (
        dir_1: $dir_1:expr,
        $(dir_2: $dir_2:expr,)?
        prod_base_path: $prod_base_path:expr,
        dev_base_path: $dev_base_path:expr$(,)?
    ) => {
        #[allow(clippy::unused_async)]
        pub async fn handle_static(
            path: ::axum::extract::Path<String>,
        ) -> impl ::axum::response::IntoResponse {
            #[cfg(not(feature = "dev-web"))]
            fn get_static_file(path: &str) -> Option<&'static [u8]> {
                const DIR_1: ::include_dir::Dir = $dir_1;
                let file = DIR_1.get_file(path);
                // `dir_2` is optional: both its constant and the fallback
                // lookup live inside the repetition so that invocations which
                // omit `dir_2` still compile.
                $(
                    const DIR_2: ::include_dir::Dir = $dir_2;
                    let file = file.or_else(|| DIR_2.get_file(path));
                )?
                file.map(|f| f.contents())
            }

            #[cfg(feature = "dev-web")]
            fn get_static_file(path: &str) -> Option<Vec<u8>> {
                use ::std::fs;

                #[cfg(not(debug_assertions))]
                compile_error!("cannot enable insecure `dev-web` feature in release mode");

                // `path` comes straight from the request URL and is joined onto
                // real filesystem directories below; refuse `..` components so
                // a request cannot read files outside the asset directories.
                if ::std::path::Path::new(path)
                    .components()
                    .any(|c| c == ::std::path::Component::ParentDir)
                {
                    ::tracing::debug!("dev-web refusing static file path containing `..`: {}", path);
                    return None;
                }

                let dev_path =
                    format!("{}/{}/{}", env!("CARGO_MANIFEST_DIR"), $dev_base_path, path);
                let prod_path = format!(
                    "{}/{}/{}",
                    env!("CARGO_MANIFEST_DIR"),
                    $prod_base_path,
                    path
                );
                // Try the dev location first, then fall back to the prod one.
                match fs::read(dev_path).or_else(|_| fs::read(prod_path)) {
                    Ok(contents) => Some(contents),
                    Err(e) => {
                        ::tracing::debug!("dev-web failed to load static file: {}: {}", path, e);
                        None
                    }
                }
            }

            let path = path.strip_prefix('/').unwrap_or(&path);
            let content_type = match ::std::path::Path::new(path)
                .extension()
                .and_then(|e| e.to_str())
            {
                Some("js") => Some(::axum_extra::TypedHeader(::headers::ContentType::from(
                    ::mime::TEXT_JAVASCRIPT,
                ))),
                Some("css") => Some(::axum_extra::TypedHeader(::headers::ContentType::from(
                    ::mime::TEXT_CSS,
                ))),
                None | Some(_) => None,
            };
            match get_static_file(path) {
                Some(body) => Ok((content_type, body)),
                None => Err((::http::StatusCode::NOT_FOUND, "not found")),
            }
        }
    };
}
109
110#[allow(clippy::unused_async)]
112pub async fn handle_liveness_check() -> impl IntoResponse {
113 (StatusCode::OK, "Liveness check successful!")
114}
115
116#[allow(clippy::unused_async)]
118pub async fn handle_prometheus(registry: &MetricsRegistry) -> impl IntoResponse + use<> {
119 let mut buffer = Vec::new();
120 let encoder = prometheus::TextEncoder::new();
121 encoder
122 .encode(®istry.gather(), &mut buffer)
123 .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
124 Ok::<_, (StatusCode, String)>((TypedHeader(ContentType::text()), buffer))
125}
126
127#[derive(Serialize, Deserialize)]
128pub struct DynamicFilterTarget {
129 targets: String,
130}
131
132#[allow(clippy::unused_async)]
134pub async fn handle_reload_tracing_filter(
135 handle: &TracingHandle,
136 reload: fn(&TracingHandle, EnvFilter) -> Result<(), anyhow::Error>,
137 Json(cfg): Json<DynamicFilterTarget>,
138) -> impl IntoResponse {
139 match cfg.targets.parse::<EnvFilter>() {
140 Ok(targets) => match reload(handle, targets) {
141 Ok(()) => (StatusCode::OK, cfg.targets.to_string()),
142 Err(e) => (StatusCode::BAD_REQUEST, e.to_string()),
143 },
144 Err(e) => (StatusCode::BAD_REQUEST, e.to_string()),
145 }
146}
147
148#[allow(clippy::unused_async)]
150pub async fn handle_tracing() -> impl IntoResponse {
151 (
152 StatusCode::OK,
153 Json(json!({
154 "current_level_filter": tracing::level_filters::LevelFilter::current().to_string()
155 })),
156 )
157}
158
159pub fn origin_is_allowed(origin: &HeaderValue, allowed: &[HeaderValue]) -> bool {
162 fn wildcard_origin_match(origin: &HeaderValue, wildcard: &[u8]) -> bool {
163 let Some(origin) = origin.to_str().ok() else {
164 return false;
165 };
166 let Ok(origin) = origin.parse::<Uri>() else {
167 return false;
168 };
169 let Some(host) = origin.host() else {
170 return false;
171 };
172
173 host.as_bytes().ends_with(wildcard)
174 }
175
176 if allowed.iter().any(|o| o.as_bytes() == b"*") {
177 return true;
178 }
179 for val in allowed {
180 if (val.as_bytes().starts_with(b"*.")
181 && wildcard_origin_match(origin, &val.as_bytes()[1..]))
182 || origin == val
183 {
184 return true;
185 }
186 }
187 false
188}
189
190pub fn build_cors_allowed_origin<'a, I>(allowed: I) -> AllowOrigin
195where
196 I: IntoIterator<Item = &'a HeaderValue>,
197{
198 fn wildcard_origin_match(origin: &HeaderValue, wildcard: &[u8]) -> bool {
199 let Some(origin) = origin.to_str().ok() else {
200 return false;
201 };
202 let Ok(origin) = origin.parse::<Uri>() else {
203 return false;
204 };
205 let Some(host) = origin.host() else {
206 return false;
207 };
208
209 host.as_bytes().ends_with(wildcard)
210 }
211
212 let allowed = allowed.into_iter().cloned().collect::<Vec<HeaderValue>>();
213 if allowed.iter().any(|o| o.as_bytes() == b"*") {
214 AllowOrigin::any()
215 } else {
216 AllowOrigin::predicate(move |origin: &HeaderValue, _request_parts: _| {
217 for val in &allowed {
218 if (val.as_bytes().starts_with(b"*.")
219 && wildcard_origin_match(origin, &val.as_bytes()[1..]))
220 || origin == val
221 {
222 return true;
223 }
224 }
225 false
226 })
227 }
228}
229
#[cfg(test)]
mod tests {
    use http::header::{ACCESS_CONTROL_ALLOW_ORIGIN, ORIGIN};
    use http::{HeaderValue, Method, Request, Response};
    use tower::{Service, ServiceBuilder, ServiceExt};
    use tower_http::cors::CorsLayer;

    #[mz_ore::test(tokio::test)]
    async fn test_cors() {
        /// Sends a request carrying `origin` through a trivial service wrapped in
        /// `cors` and returns the response's `Access-Control-Allow-Origin` header.
        async fn test_request(cors: &CorsLayer, origin: &HeaderValue) -> Option<HeaderValue> {
            let mut service = ServiceBuilder::new()
                .layer(cors)
                .service_fn(|_| async { Ok::<_, anyhow::Error>(Response::new("")) });
            let request = Request::builder().header(ORIGIN, origin).body("").unwrap();
            let response = service.ready().await.unwrap().call(request).await.unwrap();
            response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN).cloned()
        }

        #[derive(Default)]
        struct TestCase {
            /// Origins passed to `build_cors_allowed_origin` / `origin_is_allowed`.
            allowed_origins: Vec<HeaderValue>,
            /// Origins that must be accepted and echoed back verbatim.
            mirrored_origins: Vec<HeaderValue>,
            /// Origins that must be accepted with a literal `*` response.
            wildcard_origins: Vec<HeaderValue>,
            /// Origins that must be rejected (no allow-origin header).
            invalid_origins: Vec<HeaderValue>,
        }

        let test_cases = [
            TestCase {
                allowed_origins: vec![HeaderValue::from_static("https://example.org")],
                mirrored_origins: vec![HeaderValue::from_static("https://example.org")],
                invalid_origins: vec![HeaderValue::from_static("https://wrong.com")],
                wildcard_origins: vec![],
            },
            TestCase {
                allowed_origins: vec![HeaderValue::from_static("*.example.org")],
                mirrored_origins: vec![
                    HeaderValue::from_static("https://foo.example.org"),
                    HeaderValue::from_static("https://foo.example.org:8443"),
                    HeaderValue::from_static("https://bar.example.org"),
                    HeaderValue::from_static("https://baz.bar.foo.example.org"),
                ],
                wildcard_origins: vec![],
                invalid_origins: vec![
                    HeaderValue::from_static("https://example.org"),
                    HeaderValue::from_static("https://wrong.com"),
                    HeaderValue::from_static("https://wrong.com:8443"),
                ],
            },
            TestCase {
                allowed_origins: vec![
                    HeaderValue::from_static("*.example.org"),
                    HeaderValue::from_static("https://other.com"),
                ],
                mirrored_origins: vec![
                    HeaderValue::from_static("https://foo.example.org"),
                    HeaderValue::from_static("https://foo.example.org:8443"),
                    HeaderValue::from_static("https://bar.example.org"),
                    HeaderValue::from_static("https://baz.bar.foo.example.org"),
                    HeaderValue::from_static("https://other.com"),
                ],
                wildcard_origins: vec![],
                invalid_origins: vec![HeaderValue::from_static("https://example.org")],
            },
            TestCase {
                allowed_origins: vec![HeaderValue::from_static("*")],
                mirrored_origins: vec![],
                wildcard_origins: vec![
                    HeaderValue::from_static("literally"),
                    HeaderValue::from_static("https://anything.com"),
                ],
                invalid_origins: vec![],
            },
            TestCase {
                allowed_origins: vec![
                    HeaderValue::from_static("*"),
                    HeaderValue::from_static("https://iwillbeignored.com"),
                ],
                mirrored_origins: vec![],
                wildcard_origins: vec![
                    HeaderValue::from_static("literally"),
                    HeaderValue::from_static("https://anything.com"),
                ],
                invalid_origins: vec![],
            },
        ];

        for test_case in test_cases {
            let allowed_origins = &test_case.allowed_origins;
            let cors = CorsLayer::new()
                .allow_methods([Method::GET])
                .allow_origin(super::build_cors_allowed_origin(allowed_origins));
            for valid in &test_case.mirrored_origins {
                let header = test_request(&cors, valid).await;
                assert_eq!(
                    header.as_ref(),
                    Some(valid),
                    "origin {valid:?} unexpectedly not mirrored\n\
                     allowed_origins={allowed_origins:?}",
                );
                assert!(
                    super::origin_is_allowed(valid, allowed_origins),
                    "origin {valid:?} unexpectedly rejected by origin_is_allowed\n\
                     allowed_origins={allowed_origins:?}",
                );
            }
            for valid in &test_case.wildcard_origins {
                let header = test_request(&cors, valid).await;
                assert_eq!(
                    header.as_ref(),
                    Some(&HeaderValue::from_static("*")),
                    "origin {valid:?} unexpectedly not allowed\n\
                     allowed_origins={allowed_origins:?}",
                );
                assert!(
                    super::origin_is_allowed(valid, allowed_origins),
                    "origin {valid:?} unexpectedly rejected by origin_is_allowed\n\
                     allowed_origins={allowed_origins:?}",
                );
            }
            for invalid in &test_case.invalid_origins {
                let header = test_request(&cors, invalid).await;
                assert_eq!(
                    header, None,
                    "origin {invalid:?} unexpectedly allowed\n\
                     allowed_origins={allowed_origins:?}",
                );
                assert!(
                    !super::origin_is_allowed(invalid, allowed_origins),
                    "origin {invalid:?} unexpectedly accepted by origin_is_allowed\n\
                     allowed_origins={allowed_origins:?}",
                );
            }
        }
    }
}