1use std::convert::Infallible;
13use std::pin::Pin;
14use std::task::{Context, Poll};
15
16use axum::extract::{MatchedPath, Request};
17use axum::response::IntoResponse;
18use futures::Future;
19use mz_ore::metric;
20use mz_ore::metrics::MetricsRegistry;
21use mz_ore::result::ResultExt;
22use mz_ore::stats::histogram_seconds_buckets;
23use pin_project::{pin_project, pinned_drop};
24use prometheus::{HistogramTimer, HistogramVec, IntCounterVec, IntGaugeVec};
25use tower::Layer;
26use tower::Service;
27
28#[derive(Debug, Clone)]
29pub struct Metrics {
30    pub requests: IntCounterVec,
32    pub requests_active: IntGaugeVec,
34    pub request_duration: HistogramVec,
36}
37
38impl Metrics {
39    pub(crate) fn register_into(registry: &MetricsRegistry, component: &'static str) -> Self {
40        Self {
41            requests: registry.register(metric!(
42                name: "requests_total",
43                help: "Total number of http requests received since process start.",
44                subsystem: component,
45                var_labels: ["source", "path", "status"],
46            )),
47            requests_active: registry.register(metric!(
48                name: "requests_active",
49                help: "Number of currently active/open http requests.",
50                subsystem: component,
51                var_labels: ["source", "path"],
52            )),
53            request_duration: registry.register(metric!(
54                name: "request_duration_seconds",
55                help: "How long it takes for a request to complete in seconds.",
56                subsystem: component,
57                var_labels: ["source", "path"],
58                buckets: histogram_seconds_buckets(0.000_128, 8.0)
59            )),
60        }
61    }
62}
63
64#[derive(Clone)]
65pub struct PrometheusLayer {
66    metrics: Metrics,
67    source: &'static str,
68}
69
70impl PrometheusLayer {
71    pub fn new(source: &'static str, metrics: Metrics) -> Self {
72        PrometheusLayer { source, metrics }
73    }
74}
75
76impl<S> Layer<S> for PrometheusLayer {
77    type Service = PrometheusService<S>;
78
79    fn layer(&self, service: S) -> Self::Service {
80        PrometheusService {
81            source: self.source,
82            metrics: self.metrics.clone(),
83            service,
84        }
85    }
86}
87
88#[derive(Clone)]
89pub struct PrometheusService<S> {
90    source: &'static str,
91    metrics: Metrics,
92    service: S,
93}
94
95impl<S> Service<Request> for PrometheusService<S>
96where
97    S: Service<Request>,
98    S::Response: IntoResponse,
99    S::Error: Into<Infallible>,
100    S::Future: Send,
101{
102    type Error = S::Error;
103    type Response = axum::response::Response;
104    type Future = PrometheusFuture<S::Future>;
105
106    fn poll_ready(
107        &mut self,
108        cx: &mut std::task::Context<'_>,
109    ) -> std::task::Poll<Result<(), Self::Error>> {
110        self.service.poll_ready(cx)
111    }
112
113    fn call(&mut self, req: Request) -> Self::Future {
114        let path = req
115            .extensions()
116            .get::<MatchedPath>()
117            .map(|path| path.as_str().to_string())
118            .unwrap_or_else(|| "unknown".to_string());
119        let fut = self.service.call(req);
120        PrometheusFuture::new(self.source, fut, path, self.metrics.clone())
121    }
122}
123
124#[pin_project(PinnedDrop)]
125pub struct PrometheusFuture<F> {
126    source: &'static str,
128    path: String,
130    timer: Option<HistogramTimer>,
132    metrics: Metrics,
134    #[pin]
136    fut: F,
137}
138
139impl<F> PrometheusFuture<F> {
140    pub fn new(source: &'static str, fut: F, path: String, metrics: Metrics) -> Self {
141        PrometheusFuture {
142            source,
143            path,
144            timer: None,
145            metrics,
146            fut,
147        }
148    }
149}
150
151impl<F, R, E> Future for PrometheusFuture<F>
152where
153    R: IntoResponse,
154    E: Into<Infallible>,
155    F: Future<Output = Result<R, E>>,
156{
157    type Output = Result<axum::response::Response, E>;
158
159    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
160        let this = self.project();
161
162        if this.timer.is_none() {
163            let duration_metric = this
165                .metrics
166                .request_duration
167                .with_label_values::<&str>(&[this.source, this.path]);
168            *this.timer = Some(duration_metric.start_timer());
169
170            this.metrics
172                .requests_active
173                .with_label_values::<&str>(&[this.source, this.path])
174                .inc();
175        }
176
177        match this.fut.poll(cx) {
179            Poll::Ready(resp) => {
180                let ok = resp.infallible_unwrap();
181                let resp = ok.into_response();
182                let status = resp.status();
183
184                this.metrics
186                    .requests
187                    .with_label_values::<&str>(&[this.source, this.path, status.as_str()])
188                    .inc();
189
190                if let Some(timer) = this.timer.take() {
192                    timer.observe_duration();
193                }
194
195                this.metrics
197                    .requests_active
198                    .with_label_values::<&str>(&[this.source, this.path])
199                    .dec();
200
201                Poll::Ready(Ok(resp))
202            }
203            Poll::Pending => Poll::Pending,
204        }
205    }
206}
207
208#[pinned_drop]
209impl<F> PinnedDrop for PrometheusFuture<F> {
210    fn drop(self: Pin<&mut Self>) {
211        let this = self.project();
212
213        if let Some(timer) = this.timer.take() {
214            this.metrics
216                .requests_active
217                .with_label_values::<&str>(&[this.source, this.path])
218                .dec();
219
220            timer.stop_and_discard();
222        }
223    }
224}
225
#[cfg(test)]
mod test {
    use futures::Future;
    use http::StatusCode;
    use mz_ore::metrics::MetricsRegistry;
    use std::convert::Infallible;
    use std::pin::Pin;

    use super::{Metrics, PrometheusFuture};

    /// A request that is polled once (left pending) and then dropped must leave
    /// the active gauge, with no completed request and no duration sample recorded.
    #[mz_ore::test]
    fn test_metrics_future_on_drop() {
        let registry = MetricsRegistry::new();
        let metrics = Metrics::register_into(&registry, "test");
        let waker = futures::task::noop_waker_ref();
        let mut cx = std::task::Context::from_waker(waker);

        let request_future = futures::future::pending::<Result<(StatusCode, String), Infallible>>();
        let mut future =
            PrometheusFuture::new("test", request_future, "/future/test".to_string(), metrics);

        assert!(Pin::new(&mut future).poll(&mut cx).is_pending());

        let metrics = registry.gather();

        // Nothing has completed yet, so no `requests_total` series exists.
        let total_requests_exists = metrics
            .iter()
            .any(|metric| metric.name().contains("requests_total"));
        assert!(!total_requests_exists);

        let active_requests = metrics
            .iter()
            .find(|metric| metric.name().contains("requests_active"))
            .unwrap();
        assert_eq!(active_requests.get_metric()[0].get_gauge().get_value(), 1.0);

        drop(future);

        let metrics = registry.gather();

        let active_requests = metrics
            .iter()
            .find(|metric| metric.name().contains("requests_active"))
            .unwrap();
        assert_eq!(active_requests.get_metric()[0].get_gauge().get_value(), 0.0);

        // The partial duration was discarded, not observed.
        let request_duration = metrics
            .iter()
            .find(|metric| metric.name().contains("request_duration_seconds"))
            .unwrap();
        assert_eq!(
            request_duration.get_metric()[0]
                .get_histogram()
                .get_sample_count(),
            0
        );
    }

    /// A request that completes must be counted once, observed once in the
    /// duration histogram, and leave the active gauge. A later drop must not
    /// decrement the gauge a second time.
    #[mz_ore::test]
    fn test_metrics_future_on_ready() {
        let registry = MetricsRegistry::new();
        let metrics = Metrics::register_into(&registry, "test");
        let waker = futures::task::noop_waker_ref();
        let mut cx = std::task::Context::from_waker(waker);

        let request_future =
            futures::future::ready(Ok::<_, Infallible>((StatusCode::OK, "ok".to_string())));
        let mut future =
            PrometheusFuture::new("test", request_future, "/future/test".to_string(), metrics);

        let std::task::Poll::Ready(result) = Pin::new(&mut future).poll(&mut cx) else {
            panic!("an immediately-ready inner future must complete on the first poll");
        };
        assert_eq!(result.unwrap().status(), StatusCode::OK);

        drop(future);

        let metrics = registry.gather();

        let total_requests = metrics
            .iter()
            .find(|metric| metric.name().contains("requests_total"))
            .unwrap();
        assert_eq!(total_requests.get_metric()[0].get_counter().get_value(), 1.0);

        let active_requests = metrics
            .iter()
            .find(|metric| metric.name().contains("requests_active"))
            .unwrap();
        assert_eq!(active_requests.get_metric()[0].get_gauge().get_value(), 0.0);

        let request_duration = metrics
            .iter()
            .find(|metric| metric.name().contains("request_duration_seconds"))
            .unwrap();
        assert_eq!(
            request_duration.get_metric()[0]
                .get_histogram()
                .get_sample_count(),
            1
        );
    }
}