1use askama::Template;
13use axum::Json;
14use axum::http::HeaderValue;
15use axum::http::status::StatusCode;
16use axum::response::{Html, IntoResponse};
17use axum_extra::TypedHeader;
18use headers::ContentType;
19use mz_ore::metrics::MetricsRegistry;
20use mz_ore::tracing::TracingHandle;
21use prometheus::Encoder;
22use serde::{Deserialize, Serialize};
23use serde_json::json;
24use tower_http::cors::AllowOrigin;
25use tracing_subscriber::EnvFilter;
26
27pub fn template_response<T>(template: T) -> Html<String>
29where
30 T: Template,
31{
32 Html(template.render().expect("template rendering cannot fail"))
33}
34
/// Generates an axum handler named `handle_static` that serves static assets.
///
/// Without the `dev-web` feature (evaluated in the *calling* crate), assets are
/// embedded in the binary via the `include_dir` expressions `dir_1` and the
/// optional `dir_2`; a requested file is looked up in `dir_1` first and then in
/// `dir_2`, if one was supplied.
///
/// With `dev-web` enabled, files are read from disk on every request, first
/// from `$CARGO_MANIFEST_DIR/<dev_base_path>/<path>` and then from
/// `$CARGO_MANIFEST_DIR/<prod_base_path>/<path>`. That mode is rejected at
/// compile time in builds without `debug_assertions`.
///
/// `.js` and `.css` responses get an explicit `Content-Type`; other files are
/// sent without one. A missing file yields `404 Not Found`.
///
/// The expansion refers to `axum`, `axum_extra`, `headers`, `http`, `mime`,
/// `include_dir`, and `tracing` by absolute path, so the calling crate must
/// depend on all of them.
#[macro_export]
macro_rules! make_handle_static {
    (
        dir_1: $dir_1:expr,
        $(dir_2: $dir_2:expr,)?
        prod_base_path: $prod_base_path:expr,
        dev_base_path: $dev_base_path:expr$(,)?
    ) => {
        // Axum handlers must be `async` even though this one never awaits.
        #[allow(clippy::unused_async)]
        pub async fn handle_static(
            path: ::axum::extract::Path<String>,
        ) -> impl ::axum::response::IntoResponse {
            // Looks `path` up in the embedded `dir_1`, falling back to `dir_2`
            // when the macro caller supplied one.
            #[cfg(not(feature = "dev-web"))]
            fn get_static_file(path: &str) -> Option<&'static [u8]> {
                const DIR_1: ::include_dir::Dir = $dir_1;
                let contents = DIR_1.get_file(path).map(|f| f.contents());
                // `dir_2` is optional, so the fallback lookup (and the constant
                // it reads) is emitted only when it was passed. Referencing a
                // `DIR_2` constant unconditionally would fail to compile for
                // single-directory invocations.
                $(
                    let contents = contents.or_else(|| {
                        const DIR_2: ::include_dir::Dir = $dir_2;
                        DIR_2.get_file(path).map(|f| f.contents())
                    });
                )?
                contents
            }

            // Reads `path` from disk, preferring the dev tree over the prod tree.
            #[cfg(feature = "dev-web")]
            fn get_static_file(path: &str) -> Option<Vec<u8>> {
                use ::std::fs;

                #[cfg(not(debug_assertions))]
                compile_error!("cannot enable insecure `dev-web` feature in release mode");

                // NOTE: `path` comes from the request URL and is joined onto the
                // base directory without normalization; this mode must never be
                // exposed to untrusted clients.
                let dev_path =
                    format!("{}/{}/{}", env!("CARGO_MANIFEST_DIR"), $dev_base_path, path);
                let prod_path = format!(
                    "{}/{}/{}",
                    env!("CARGO_MANIFEST_DIR"),
                    $prod_base_path,
                    path
                );
                match fs::read(dev_path).or_else(|_| fs::read(prod_path)) {
                    Ok(contents) => Some(contents),
                    Err(e) => {
                        ::tracing::debug!("dev-web failed to load static file: {}: {}", path, e);
                        None
                    }
                }
            }

            // Lookups are relative to the asset root, so drop a leading `/`.
            let path = path.strip_prefix('/').unwrap_or(&path);
            let content_type = match ::std::path::Path::new(path)
                .extension()
                .and_then(|e| e.to_str())
            {
                Some("js") => Some(::axum_extra::TypedHeader(::headers::ContentType::from(
                    ::mime::TEXT_JAVASCRIPT,
                ))),
                Some("css") => Some(::axum_extra::TypedHeader(::headers::ContentType::from(
                    ::mime::TEXT_CSS,
                ))),
                None | Some(_) => None,
            };
            match get_static_file(path) {
                Some(body) => Ok((content_type, body)),
                None => Err((::http::StatusCode::NOT_FOUND, "not found")),
            }
        }
    };
}
108
109#[allow(clippy::unused_async)]
111pub async fn handle_liveness_check() -> impl IntoResponse {
112 (StatusCode::OK, "Liveness check successful!")
113}
114
115#[allow(clippy::unused_async)]
117pub async fn handle_prometheus(registry: &MetricsRegistry) -> impl IntoResponse + use<> {
118 let mut buffer = Vec::new();
119 let encoder = prometheus::TextEncoder::new();
120 encoder
121 .encode(®istry.gather(), &mut buffer)
122 .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
123 Ok::<_, (StatusCode, String)>((TypedHeader(ContentType::text()), buffer))
124}
125
126#[derive(Serialize, Deserialize)]
127pub struct DynamicFilterTarget {
128 targets: String,
129}
130
131#[allow(clippy::unused_async)]
133pub async fn handle_reload_tracing_filter(
134 handle: &TracingHandle,
135 reload: fn(&TracingHandle, EnvFilter) -> Result<(), anyhow::Error>,
136 Json(cfg): Json<DynamicFilterTarget>,
137) -> impl IntoResponse {
138 match cfg.targets.parse::<EnvFilter>() {
139 Ok(targets) => match reload(handle, targets) {
140 Ok(()) => (StatusCode::OK, cfg.targets.to_string()),
141 Err(e) => (StatusCode::BAD_REQUEST, e.to_string()),
142 },
143 Err(e) => (StatusCode::BAD_REQUEST, e.to_string()),
144 }
145}
146
147#[allow(clippy::unused_async)]
149pub async fn handle_tracing() -> impl IntoResponse {
150 (
151 StatusCode::OK,
152 Json(json!({
153 "current_level_filter": tracing::level_filters::LevelFilter::current().to_string()
154 })),
155 )
156}
157
158pub fn build_cors_allowed_origin<'a, I>(allowed: I) -> AllowOrigin
163where
164 I: IntoIterator<Item = &'a HeaderValue>,
165{
166 let allowed = allowed.into_iter().cloned().collect::<Vec<HeaderValue>>();
167 if allowed.iter().any(|o| o.as_bytes() == b"*") {
168 AllowOrigin::any()
169 } else {
170 AllowOrigin::predicate(move |origin: &HeaderValue, _request_parts: _| {
171 for val in &allowed {
172 if (val.as_bytes().starts_with(b"*.")
173 && origin.as_bytes().ends_with(&val.as_bytes()[1..]))
174 || origin == val
175 {
176 return true;
177 }
178 }
179 false
180 })
181 }
182}
183
#[cfg(test)]
mod tests {
    use http::header::{ACCESS_CONTROL_ALLOW_ORIGIN, ORIGIN};
    use http::{HeaderValue, Method, Request, Response};
    use tower::{Service, ServiceBuilder, ServiceExt};
    use tower_http::cors::CorsLayer;

    #[mz_ore::test(tokio::test)]
    async fn test_cors() {
        /// Sends a request with the given `Origin` through a trivial service
        /// wrapped in `cors`, and returns the response's
        /// `Access-Control-Allow-Origin` header, if any.
        async fn test_request(cors: &CorsLayer, origin: &HeaderValue) -> Option<HeaderValue> {
            let mut service = ServiceBuilder::new()
                .layer(cors)
                .service_fn(|_| async { Ok::<_, anyhow::Error>(Response::new("")) });
            let request = Request::builder().header(ORIGIN, origin).body("").unwrap();
            let response = service.ready().await.unwrap().call(request).await.unwrap();
            response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN).cloned()
        }

        #[derive(Default)]
        struct TestCase {
            /// Origin patterns passed to `build_cors_allowed_origin`.
            allowed_origins: Vec<HeaderValue>,
            /// Origins that must be allowed and echoed back verbatim.
            mirrored_origins: Vec<HeaderValue>,
            /// Origins that must be allowed via a literal `*` response header.
            wildcard_origins: Vec<HeaderValue>,
            /// Origins that must be rejected (no allow-origin header).
            invalid_origins: Vec<HeaderValue>,
        }

        let test_cases = [
            TestCase {
                allowed_origins: vec![HeaderValue::from_static("https://example.org")],
                mirrored_origins: vec![HeaderValue::from_static("https://example.org")],
                invalid_origins: vec![HeaderValue::from_static("https://wrong.com")],
                wildcard_origins: vec![],
            },
            TestCase {
                allowed_origins: vec![HeaderValue::from_static("*.example.org")],
                mirrored_origins: vec![
                    HeaderValue::from_static("https://foo.example.org"),
                    HeaderValue::from_static("https://bar.example.org"),
                    HeaderValue::from_static("https://baz.bar.foo.example.org"),
                ],
                wildcard_origins: vec![],
                invalid_origins: vec![
                    HeaderValue::from_static("https://example.org"),
                    HeaderValue::from_static("https://wrong.com"),
                ],
            },
            TestCase {
                allowed_origins: vec![
                    HeaderValue::from_static("*.example.org"),
                    HeaderValue::from_static("https://other.com"),
                ],
                mirrored_origins: vec![
                    HeaderValue::from_static("https://foo.example.org"),
                    HeaderValue::from_static("https://bar.example.org"),
                    HeaderValue::from_static("https://baz.bar.foo.example.org"),
                    HeaderValue::from_static("https://other.com"),
                ],
                wildcard_origins: vec![],
                invalid_origins: vec![HeaderValue::from_static("https://example.org")],
            },
            TestCase {
                allowed_origins: vec![HeaderValue::from_static("*")],
                mirrored_origins: vec![],
                wildcard_origins: vec![
                    HeaderValue::from_static("literally"),
                    HeaderValue::from_static("https://anything.com"),
                ],
                invalid_origins: vec![],
            },
            TestCase {
                allowed_origins: vec![
                    HeaderValue::from_static("*"),
                    HeaderValue::from_static("https://iwillbeignored.com"),
                ],
                mirrored_origins: vec![],
                wildcard_origins: vec![
                    HeaderValue::from_static("literally"),
                    HeaderValue::from_static("https://anything.com"),
                ],
                invalid_origins: vec![],
            },
        ];

        for test_case in test_cases {
            let allowed_origins = &test_case.allowed_origins;
            let cors = CorsLayer::new()
                .allow_methods([Method::GET])
                .allow_origin(super::build_cors_allowed_origin(allowed_origins));
            for valid in &test_case.mirrored_origins {
                let header = test_request(&cors, valid).await;
                assert_eq!(
                    header.as_ref(),
                    Some(valid),
                    "origin {valid:?} unexpectedly not mirrored\n\
                     allowed_origins={allowed_origins:?}",
                );
            }
            for valid in &test_case.wildcard_origins {
                let header = test_request(&cors, valid).await;
                assert_eq!(
                    header.as_ref(),
                    Some(&HeaderValue::from_static("*")),
                    "origin {valid:?} unexpectedly not allowed\n\
                     allowed_origins={allowed_origins:?}",
                );
            }
            for invalid in &test_case.invalid_origins {
                let header = test_request(&cors, invalid).await;
                // A header here means the origin was wrongly allowed.
                assert_eq!(
                    header, None,
                    "origin {invalid:?} unexpectedly allowed\n\
                     allowed_origins={allowed_origins:?}",
                );
            }
        }
    }
}