1use std::convert::Infallible;
13use std::pin::Pin;
14use std::task::{Context, Poll};
15
16use axum::extract::{MatchedPath, Request};
17use axum::response::IntoResponse;
18use futures::Future;
19use mz_ore::metric;
20use mz_ore::metrics::MetricsRegistry;
21use mz_ore::result::ResultExt;
22use mz_ore::stats::histogram_seconds_buckets;
23use pin_project::{pin_project, pinned_drop};
24use prometheus::{HistogramTimer, HistogramVec, IntCounterVec, IntGaugeVec};
25use tower::Layer;
26use tower::Service;
27
28#[derive(Debug, Clone)]
29pub struct Metrics {
30 pub requests: IntCounterVec,
32 pub requests_active: IntGaugeVec,
34 pub request_duration: HistogramVec,
36}
37
38impl Metrics {
39 pub(crate) fn register_into(registry: &MetricsRegistry, component: &'static str) -> Self {
40 Self {
41 requests: registry.register(metric!(
42 name: "requests_total",
43 help: "Total number of http requests received since process start.",
44 subsystem: component,
45 var_labels: ["source", "path", "status"],
46 )),
47 requests_active: registry.register(metric!(
48 name: "requests_active",
49 help: "Number of currently active/open http requests.",
50 subsystem: component,
51 var_labels: ["source", "path"],
52 )),
53 request_duration: registry.register(metric!(
54 name: "request_duration_seconds",
55 help: "How long it takes for a request to complete in seconds.",
56 subsystem: component,
57 var_labels: ["source", "path"],
58 buckets: histogram_seconds_buckets(0.000_128, 8.0)
59 )),
60 }
61 }
62}
63
64#[derive(Clone)]
65pub struct PrometheusLayer {
66 metrics: Metrics,
67 source: &'static str,
68}
69
70impl PrometheusLayer {
71 pub fn new(source: &'static str, metrics: Metrics) -> Self {
72 PrometheusLayer { source, metrics }
73 }
74}
75
76impl<S> Layer<S> for PrometheusLayer {
77 type Service = PrometheusService<S>;
78
79 fn layer(&self, service: S) -> Self::Service {
80 PrometheusService {
81 source: self.source,
82 metrics: self.metrics.clone(),
83 service,
84 }
85 }
86}
87
88#[derive(Clone)]
89pub struct PrometheusService<S> {
90 source: &'static str,
91 metrics: Metrics,
92 service: S,
93}
94
95impl<S> Service<Request> for PrometheusService<S>
96where
97 S: Service<Request>,
98 S::Response: IntoResponse,
99 S::Error: Into<Infallible>,
100 S::Future: Send,
101{
102 type Error = S::Error;
103 type Response = axum::response::Response;
104 type Future = PrometheusFuture<S::Future>;
105
106 fn poll_ready(
107 &mut self,
108 cx: &mut std::task::Context<'_>,
109 ) -> std::task::Poll<Result<(), Self::Error>> {
110 self.service.poll_ready(cx)
111 }
112
113 fn call(&mut self, req: Request) -> Self::Future {
114 let path = req
115 .extensions()
116 .get::<MatchedPath>()
117 .map(|path| path.as_str().to_string())
118 .unwrap_or_else(|| "unknown".to_string());
119 let fut = self.service.call(req);
120 PrometheusFuture::new(self.source, fut, path, self.metrics.clone())
121 }
122}
123
124#[pin_project(PinnedDrop)]
125pub struct PrometheusFuture<F> {
126 source: &'static str,
128 path: String,
130 timer: Option<HistogramTimer>,
132 metrics: Metrics,
134 #[pin]
136 fut: F,
137}
138
139impl<F> PrometheusFuture<F> {
140 pub fn new(source: &'static str, fut: F, path: String, metrics: Metrics) -> Self {
141 PrometheusFuture {
142 source,
143 path,
144 timer: None,
145 metrics,
146 fut,
147 }
148 }
149}
150
151impl<F, R, E> Future for PrometheusFuture<F>
152where
153 R: IntoResponse,
154 E: Into<Infallible>,
155 F: Future<Output = Result<R, E>>,
156{
157 type Output = Result<axum::response::Response, E>;
158
159 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
160 let this = self.project();
161
162 if this.timer.is_none() {
163 let duration_metric = this
165 .metrics
166 .request_duration
167 .with_label_values(&[this.source, this.path]);
168 *this.timer = Some(duration_metric.start_timer());
169
170 this.metrics
172 .requests_active
173 .with_label_values(&[this.source, this.path])
174 .inc();
175 }
176
177 match this.fut.poll(cx) {
179 Poll::Ready(resp) => {
180 let ok = resp.infallible_unwrap();
181 let resp = ok.into_response();
182 let status = resp.status();
183
184 this.metrics
186 .requests
187 .with_label_values(&[this.source, this.path, status.as_str()])
188 .inc();
189
190 if let Some(timer) = this.timer.take() {
192 timer.observe_duration();
193 }
194
195 this.metrics
197 .requests_active
198 .with_label_values(&[this.source, this.path])
199 .dec();
200
201 Poll::Ready(Ok(resp))
202 }
203 Poll::Pending => Poll::Pending,
204 }
205 }
206}
207
208#[pinned_drop]
209impl<F> PinnedDrop for PrometheusFuture<F> {
210 fn drop(self: Pin<&mut Self>) {
211 let this = self.project();
212
213 if let Some(timer) = this.timer.take() {
214 this.metrics
216 .requests_active
217 .with_label_values(&[this.source, this.path])
218 .dec();
219
220 timer.stop_and_discard();
222 }
223 }
224}
225
#[cfg(test)]
mod test {
    use futures::Future;
    use http::StatusCode;
    use mz_ore::metrics::MetricsRegistry;
    use std::convert::Infallible;
    use std::pin::Pin;

    use super::{Metrics, PrometheusFuture};

    /// Checks what happens when a request future is polled and then dropped
    /// before it completes:
    /// - `requests_active` returns to zero;
    /// - no `requests_total` sample is recorded;
    /// - no `request_duration_seconds` observation is recorded.
    #[mz_ore::test]
    fn test_metrics_future_on_drop() {
        let registry = MetricsRegistry::new();
        let metrics = Metrics::register_into(&registry, "test");
        let waker = futures::task::noop_waker_ref();
        let mut cx = std::task::Context::from_waker(waker);

        // A request handler that never finishes.
        let request_future = futures::future::pending::<Result<(StatusCode, String), Infallible>>();
        let mut future =
            PrometheusFuture::new("test", request_future, "/future/test".to_string(), metrics);

        // The first poll starts the timer and marks the request as active.
        assert!(Pin::new(&mut future).poll(&mut cx).is_pending());

        let families = registry.gather();

        // No request has completed, so `requests_total` has no samples and
        // does not appear in the gathered output.
        let total_requests_exists = families
            .iter()
            .any(|family| family.get_name().contains("requests_total"));
        assert!(!total_requests_exists);

        let active_requests = families
            .iter()
            .find(|family| family.get_name().contains("requests_active"))
            .expect("requests_active should be gathered while a request is in flight");
        assert_eq!(active_requests.get_metric()[0].get_gauge().get_value(), 1.0);

        drop(future);

        let families = registry.gather();

        let active_requests = families
            .iter()
            .find(|family| family.get_name().contains("requests_active"))
            .expect("requests_active should still be gathered after the drop");
        assert_eq!(active_requests.get_metric()[0].get_gauge().get_value(), 0.0);

        // The timer was discarded rather than observed.
        let request_duration = families
            .iter()
            .find(|family| family.get_name().contains("request_duration_seconds"))
            .expect("request_duration_seconds should be gathered once its child exists");
        assert_eq!(
            request_duration.get_metric()[0]
                .get_histogram()
                .get_sample_count(),
            0
        );
    }
}