1use askama::Template;
13use axum::Json;
14use axum::http::HeaderValue;
15use axum::http::Uri;
16use axum::http::status::StatusCode;
17use axum::response::{Html, IntoResponse};
18use axum_extra::TypedHeader;
19use headers::ContentType;
20use mz_ore::metrics::MetricsRegistry;
21use mz_ore::tracing::TracingHandle;
22use prometheus::Encoder;
23use serde::{Deserialize, Serialize};
24use serde_json::json;
25use tower_http::cors::AllowOrigin;
26use tracing_subscriber::EnvFilter;
27
28pub fn template_response<T>(template: T) -> Html<String>
30where
31 T: Template,
32{
33 Html(template.render().expect("template rendering cannot fail"))
34}
35
/// Defines a `handle_static` axum handler that serves files by path.
///
/// * Without the `dev-web` feature (evaluated in the invoking crate), files are looked up
///   in the `include_dir::Dir` given as `dir_1` and then, if supplied, in `dir_2`.
/// * With `dev-web`, files are read from disk at request time. The handler tries
///   `$CARGO_MANIFEST_DIR/<dev_base_path>/<path>` first and then
///   `$CARGO_MANIFEST_DIR/<prod_base_path>/<path>`. Enabling `dev-web` without
///   `debug_assertions` is a compile error.
///
/// `.js` and `.css` files get an explicit `Content-Type`. A missing file yields
/// `404 not found`.
///
/// The expansion names `::axum`, `::axum_extra`, `::headers`, `::mime`, `::http`,
/// `::include_dir` and `::tracing` by absolute path, so the invoking crate must depend on
/// them.
#[macro_export]
macro_rules! make_handle_static {
    (
        dir_1: $dir_1:expr,
        $(dir_2: $dir_2:expr,)?
        prod_base_path: $prod_base_path:expr,
        dev_base_path: $dev_base_path:expr$(,)?
    ) => {
        #[allow(clippy::unused_async)]
        pub async fn handle_static(
            path: ::axum::extract::Path<String>,
        ) -> impl ::axum::response::IntoResponse {
            #[cfg(not(feature = "dev-web"))]
            fn get_static_file(path: &str) -> Option<&'static [u8]> {
                const DIR_1: ::include_dir::Dir = $dir_1;
                let file = DIR_1.get_file(path);
                // The fallback must live inside a repetition driven by `$dir_2`. Referring
                // to a second directory unconditionally would break every invocation that
                // omits the optional `dir_2` argument.
                $(
                    const DIR_2: ::include_dir::Dir = $dir_2;
                    let file = file.or_else(|| DIR_2.get_file(path));
                )?
                file.map(|f| f.contents())
            }

            #[cfg(feature = "dev-web")]
            fn get_static_file(path: &str) -> Option<Vec<u8>> {
                use ::std::fs;

                #[cfg(not(debug_assertions))]
                compile_error!("cannot enable insecure `dev-web` feature in release mode");

                // NOTE(review): `path` is joined onto the base paths without any
                // sanitization, and the read is blocking. That is acceptable only because
                // this mode is debug-only.
                let dev_path =
                    format!("{}/{}/{}", env!("CARGO_MANIFEST_DIR"), $dev_base_path, path);
                let prod_path = format!(
                    "{}/{}/{}",
                    env!("CARGO_MANIFEST_DIR"),
                    $prod_base_path,
                    path
                );
                match fs::read(dev_path).or_else(|_| fs::read(prod_path)) {
                    Ok(contents) => Some(contents),
                    Err(e) => {
                        ::tracing::debug!("dev-web failed to load static file: {}: {}", path, e);
                        None
                    }
                }
            }

            let path = path.strip_prefix('/').unwrap_or(&path);
            // Only script and stylesheet types are set explicitly; anything else is sent
            // without a `Content-Type` header from this handler.
            let content_type = match ::std::path::Path::new(path)
                .extension()
                .and_then(|e| e.to_str())
            {
                Some("js") => Some(::axum_extra::TypedHeader(::headers::ContentType::from(
                    ::mime::TEXT_JAVASCRIPT,
                ))),
                Some("css") => Some(::axum_extra::TypedHeader(::headers::ContentType::from(
                    ::mime::TEXT_CSS,
                ))),
                None | Some(_) => None,
            };
            match get_static_file(path) {
                Some(body) => Ok((content_type, body)),
                None => Err((::http::StatusCode::NOT_FOUND, "not found")),
            }
        }
    };
}
109
110#[allow(clippy::unused_async)]
112pub async fn handle_liveness_check() -> impl IntoResponse {
113 (StatusCode::OK, "Liveness check successful!")
114}
115
116#[allow(clippy::unused_async)]
118pub async fn handle_prometheus(registry: &MetricsRegistry) -> impl IntoResponse + use<> {
119 let mut buffer = Vec::new();
120 let encoder = prometheus::TextEncoder::new();
121 encoder
122 .encode(®istry.gather(), &mut buffer)
123 .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
124 Ok::<_, (StatusCode, String)>((TypedHeader(ContentType::text()), buffer))
125}
126
127#[derive(Serialize, Deserialize)]
128pub struct DynamicFilterTarget {
129 targets: String,
130}
131
132#[allow(clippy::unused_async)]
134pub async fn handle_reload_tracing_filter(
135 handle: &TracingHandle,
136 reload: fn(&TracingHandle, EnvFilter) -> Result<(), anyhow::Error>,
137 Json(cfg): Json<DynamicFilterTarget>,
138) -> impl IntoResponse {
139 match cfg.targets.parse::<EnvFilter>() {
140 Ok(targets) => match reload(handle, targets) {
141 Ok(()) => (StatusCode::OK, cfg.targets.to_string()),
142 Err(e) => (StatusCode::BAD_REQUEST, e.to_string()),
143 },
144 Err(e) => (StatusCode::BAD_REQUEST, e.to_string()),
145 }
146}
147
148#[allow(clippy::unused_async)]
150pub async fn handle_tracing() -> impl IntoResponse {
151 (
152 StatusCode::OK,
153 Json(json!({
154 "current_level_filter": tracing::level_filters::LevelFilter::current().to_string()
155 })),
156 )
157}
158
159pub fn build_cors_allowed_origin<'a, I>(allowed: I) -> AllowOrigin
164where
165 I: IntoIterator<Item = &'a HeaderValue>,
166{
167 fn wildcard_origin_match(origin: &HeaderValue, wildcard: &[u8]) -> bool {
168 let Some(origin) = origin.to_str().ok() else {
169 return false;
170 };
171 let Ok(origin) = origin.parse::<Uri>() else {
172 return false;
173 };
174 let Some(host) = origin.host() else {
175 return false;
176 };
177
178 host.as_bytes().ends_with(wildcard)
179 }
180
181 let allowed = allowed.into_iter().cloned().collect::<Vec<HeaderValue>>();
182 if allowed.iter().any(|o| o.as_bytes() == b"*") {
183 AllowOrigin::any()
184 } else {
185 AllowOrigin::predicate(move |origin: &HeaderValue, _request_parts: _| {
186 for val in &allowed {
187 if (val.as_bytes().starts_with(b"*.")
188 && wildcard_origin_match(origin, &val.as_bytes()[1..]))
189 || origin == val
190 {
191 return true;
192 }
193 }
194 false
195 })
196 }
197}
198
#[cfg(test)]
mod tests {
    use http::header::{ACCESS_CONTROL_ALLOW_ORIGIN, ORIGIN};
    use http::{HeaderValue, Method, Request, Response};
    use tower::{Service, ServiceBuilder, ServiceExt};
    use tower_http::cors::CorsLayer;

    #[mz_ore::test(tokio::test)]
    async fn test_cors() {
        // Sends one request carrying `origin` through a trivial service wrapped in `cors`
        // and returns the response's `Access-Control-Allow-Origin` header, if any.
        async fn test_request(cors: &CorsLayer, origin: &HeaderValue) -> Option<HeaderValue> {
            let mut service = ServiceBuilder::new()
                .layer(cors)
                .service_fn(|_| async { Ok::<_, anyhow::Error>(Response::new("")) });
            let request = Request::builder().header(ORIGIN, origin).body("").unwrap();
            let response = service.ready().await.unwrap().call(request).await.unwrap();
            response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN).cloned()
        }

        #[derive(Default)]
        struct TestCase {
            // Patterns passed to `build_cors_allowed_origin`.
            allowed_origins: Vec<HeaderValue>,
            // Origins expected to be echoed back verbatim in
            // `Access-Control-Allow-Origin`.
            mirrored_origins: Vec<HeaderValue>,
            // Origins expected to be allowed with a literal `*` header.
            wildcard_origins: Vec<HeaderValue>,
            // Origins expected to receive no `Access-Control-Allow-Origin` header.
            invalid_origins: Vec<HeaderValue>,
        }

        let test_cases = [
            TestCase {
                allowed_origins: vec![HeaderValue::from_static("https://example.org")],
                mirrored_origins: vec![HeaderValue::from_static("https://example.org")],
                invalid_origins: vec![HeaderValue::from_static("https://wrong.com")],
                wildcard_origins: vec![],
            },
            TestCase {
                allowed_origins: vec![HeaderValue::from_static("*.example.org")],
                mirrored_origins: vec![
                    HeaderValue::from_static("https://foo.example.org"),
                    HeaderValue::from_static("https://foo.example.org:8443"),
                    HeaderValue::from_static("https://bar.example.org"),
                    HeaderValue::from_static("https://baz.bar.foo.example.org"),
                ],
                wildcard_origins: vec![],
                invalid_origins: vec![
                    HeaderValue::from_static("https://example.org"),
                    HeaderValue::from_static("https://wrong.com"),
                    HeaderValue::from_static("https://wrong.com:8443"),
                ],
            },
            TestCase {
                allowed_origins: vec![
                    HeaderValue::from_static("*.example.org"),
                    HeaderValue::from_static("https://other.com"),
                ],
                mirrored_origins: vec![
                    HeaderValue::from_static("https://foo.example.org"),
                    HeaderValue::from_static("https://foo.example.org:8443"),
                    HeaderValue::from_static("https://bar.example.org"),
                    HeaderValue::from_static("https://baz.bar.foo.example.org"),
                    HeaderValue::from_static("https://other.com"),
                ],
                wildcard_origins: vec![],
                invalid_origins: vec![HeaderValue::from_static("https://example.org")],
            },
            TestCase {
                allowed_origins: vec![HeaderValue::from_static("*")],
                mirrored_origins: vec![],
                wildcard_origins: vec![
                    HeaderValue::from_static("literally"),
                    HeaderValue::from_static("https://anything.com"),
                ],
                invalid_origins: vec![],
            },
            TestCase {
                allowed_origins: vec![
                    HeaderValue::from_static("*"),
                    HeaderValue::from_static("https://iwillbeignored.com"),
                ],
                mirrored_origins: vec![],
                wildcard_origins: vec![
                    HeaderValue::from_static("literally"),
                    HeaderValue::from_static("https://anything.com"),
                ],
                invalid_origins: vec![],
            },
        ];

        for test_case in test_cases {
            let allowed_origins = &test_case.allowed_origins;
            let cors = CorsLayer::new()
                .allow_methods([Method::GET])
                .allow_origin(super::build_cors_allowed_origin(allowed_origins));
            for valid in &test_case.mirrored_origins {
                let header = test_request(&cors, valid).await;
                assert_eq!(
                    header.as_ref(),
                    Some(valid),
                    "origin {valid:?} unexpectedly not mirrored\n\
                     allowed_origins={allowed_origins:?}",
                );
            }
            for valid in &test_case.wildcard_origins {
                let header = test_request(&cors, valid).await;
                assert_eq!(
                    header.as_ref(),
                    Some(&HeaderValue::from_static("*")),
                    "origin {valid:?} unexpectedly not allowed\n\
                     allowed_origins={allowed_origins:?}",
                );
            }
            for invalid in &test_case.invalid_origins {
                let header = test_request(&cors, invalid).await;
                // This fires when an origin that should be rejected *was* allowed.
                assert_eq!(
                    header, None,
                    "origin {invalid:?} unexpectedly allowed\n\
                     allowed_origins={allowed_origins:?}",
                );
            }
        }
    }
}