use std::convert::Infallible;
use std::pin::Pin;
use std::task::{Context, Poll};
use axum::extract::{MatchedPath, Request};
use axum::response::IntoResponse;
use futures::Future;
use mz_ore::metric;
use mz_ore::metrics::MetricsRegistry;
use mz_ore::result::ResultExt;
use mz_ore::stats::histogram_seconds_buckets;
use pin_project::{pin_project, pinned_drop};
use prometheus::{HistogramTimer, HistogramVec, IntCounterVec, IntGaugeVec};
use tower::Layer;
use tower::Service;
/// Prometheus metrics recorded for HTTP requests that pass through a [`PrometheusLayer`].
///
/// Every metric is labelled with `source` and `path`; `requests` additionally carries the
/// response `status`.
#[derive(Clone, Debug)]
pub struct Metrics {
    /// Count of requests whose future produced a response, by `source`, `path` and `status`.
    pub requests: IntCounterVec,
    /// Requests that have been polled at least once and have neither completed nor been dropped.
    pub requests_active: IntGaugeVec,
    /// Seconds between a request future's first poll and the response it produces.
    pub request_duration: HistogramVec,
}

impl Metrics {
    /// Builds the HTTP request metrics and registers each of them with `registry`.
    ///
    /// `component` is passed as the Prometheus `subsystem` of every metric, so each HTTP
    /// server that calls this gets its own set of metric names.
    pub(crate) fn register_into(registry: &MetricsRegistry, component: &'static str) -> Self {
        let requests: IntCounterVec = registry.register(metric!(
            name: "requests_total",
            help: "Total number of http requests received since process start.",
            subsystem: component,
            var_labels: ["source", "path", "status"],
        ));
        let requests_active: IntGaugeVec = registry.register(metric!(
            name: "requests_active",
            help: "Number of currently active/open http requests.",
            subsystem: component,
            var_labels: ["source", "path"],
        ));
        let request_duration: HistogramVec = registry.register(metric!(
            name: "request_duration_seconds",
            help: "How long it takes for a request to complete in seconds.",
            subsystem: component,
            var_labels: ["source", "path"],
            buckets: histogram_seconds_buckets(0.000_128, 8.0)
        ));

        Metrics {
            requests,
            requests_active,
            request_duration,
        }
    }
}
/// A tower [`Layer`] that wraps a service so its requests are recorded in [`Metrics`].
#[derive(Clone)]
pub struct PrometheusLayer {
    /// Metrics shared by every service this layer wraps.
    metrics: Metrics,
    /// Value used for the `source` label on every recorded metric.
    source: &'static str,
}

impl PrometheusLayer {
    /// Creates a layer that labels its metrics with `source` and records them into `metrics`.
    pub fn new(source: &'static str, metrics: Metrics) -> Self {
        Self { metrics, source }
    }
}
impl<S> Layer<S> for PrometheusLayer {
    type Service = PrometheusService<S>;

    /// Wraps `service` in a [`PrometheusService`] that carries this layer's `source` label
    /// and a clone of its metric handles.
    fn layer(&self, service: S) -> Self::Service {
        let Self { metrics, source } = self;
        PrometheusService {
            service,
            metrics: metrics.clone(),
            source: *source,
        }
    }
}
#[derive(Clone)]
pub struct PrometheusService<S> {
source: &'static str,
metrics: Metrics,
service: S,
}
impl<S> Service<Request> for PrometheusService<S>
where
S: Service<Request>,
S::Response: IntoResponse,
S::Error: Into<Infallible>,
S::Future: Send,
{
type Error = S::Error;
type Response = axum::response::Response;
type Future = PrometheusFuture<S::Future>;
fn poll_ready(
&mut self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
self.service.poll_ready(cx)
}
fn call(&mut self, req: Request) -> Self::Future {
let path = req
.extensions()
.get::<MatchedPath>()
.map(|path| path.as_str().to_string())
.unwrap_or_else(|| "unknown".to_string());
let fut = self.service.call(req);
PrometheusFuture::new(self.source, fut, path, self.metrics.clone())
}
}
/// Future returned by [`PrometheusService`]: drives the inner request future and records
/// its metrics, including cleanup if it is dropped before completing.
#[pin_project(PinnedDrop)]
pub struct PrometheusFuture<F> {
    /// Value of the `source` label.
    source: &'static str,
    /// Value of the `path` label.
    path: String,
    /// Running duration timer; `Some` only between the first poll and completion or drop.
    timer: Option<HistogramTimer>,
    /// Metric handles to record into.
    metrics: Metrics,
    /// The wrapped request future.
    #[pin]
    fut: F,
}

impl<F> PrometheusFuture<F> {
    /// Wraps `fut` so its metrics are labelled with `source` and `path`.
    ///
    /// No metric is touched here; the timer and active-request gauge only start on the
    /// first poll.
    pub fn new(source: &'static str, fut: F, path: String, metrics: Metrics) -> Self {
        Self {
            fut,
            metrics,
            timer: None,
            path,
            source,
        }
    }
}
impl<F, R, E> Future for PrometheusFuture<F>
where
    R: IntoResponse,
    E: Into<Infallible>,
    F: Future<Output = Result<R, E>>,
{
    type Output = Result<axum::response::Response, E>;

    /// Polls the wrapped request future while maintaining the request metrics.
    ///
    /// On the first poll this starts the `request_duration` timer and increments
    /// `requests_active`. When the inner future resolves, the response status is counted
    /// in `requests`, the elapsed duration is observed, and `requests_active` is
    /// decremented again.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        let source: &'static str = *this.source;
        let path = this.path.as_str();

        // An empty timer means this is the first poll: begin measuring now.
        if this.timer.is_none() {
            let labels = [source, path];
            let timer = this
                .metrics
                .request_duration
                .with_label_values(&labels)
                .start_timer();
            *this.timer = Some(timer);
            this.metrics.requests_active.with_label_values(&labels).inc();
        }

        let result = match this.fut.poll(cx) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(result) => result,
        };

        // `E: Into<Infallible>`, so the inner future can never actually fail.
        let response = result.infallible_unwrap().into_response();
        let status = response.status();
        this.metrics
            .requests
            .with_label_values(&[source, path, status.as_str()])
            .inc();
        if let Some(timer) = this.timer.take() {
            timer.observe_duration();
        }
        this.metrics
            .requests_active
            .with_label_values(&[source, path])
            .dec();

        Poll::Ready(Ok(response))
    }
}
// Cleanup for a request future dropped before it produced a response.
//
// `timer` is still populated only if the future was polled at least once and never reached
// completion (completion takes the timer). In that case the `requests_active` increment made
// on the first poll is undone, and the timer is discarded so no duration sample is recorded.
#[pinned_drop]
impl<F> PinnedDrop for PrometheusFuture<F> {
    fn drop(self: Pin<&mut Self>) {
        let this = self.project();
        match this.timer.take() {
            None => {}
            Some(timer) => {
                this.metrics
                    .requests_active
                    .with_label_values(&[*this.source, this.path.as_str()])
                    .dec();
                timer.stop_and_discard();
            }
        }
    }
}
#[cfg(test)]
mod test {
    use futures::Future;
    use http::StatusCode;
    use mz_ore::metrics::MetricsRegistry;
    use std::convert::Infallible;
    use std::pin::Pin;

    use super::{Metrics, PrometheusFuture};

    /// A request future that is polled once and then dropped without completing must undo
    /// its `requests_active` increment and must not record a duration sample.
    #[mz_ore::test]
    fn test_metrics_future_on_drop() {
        let registry = MetricsRegistry::new();
        // Previously mangled to `®istry` (HTML entity `&reg`), which does not compile.
        let metrics = Metrics::register_into(&registry, "test");
        let waker = futures::task::noop_waker_ref();
        let mut cx = std::task::Context::from_waker(waker);

        // A request that never resolves, so the wrapped future stays in flight.
        let request_future = futures::future::pending::<Result<(StatusCode, String), Infallible>>();
        let mut future =
            PrometheusFuture::new("test", request_future, "/future/test".to_string(), metrics);

        // The first poll starts the timer and increments the active-request gauge.
        assert!(Pin::new(&mut future).poll(&mut cx).is_pending());

        let families = registry.gather();
        // Nothing has completed, so nothing has been recorded under `requests_total`.
        let total_requests_exists = families
            .iter()
            .any(|family| family.get_name().contains("requests_total"));
        assert!(!total_requests_exists);
        let active_requests = families
            .iter()
            .find(|family| family.get_name().contains("requests_active"))
            .expect("requests_active should be present after the first poll");
        assert_eq!(active_requests.get_metric()[0].get_gauge().get_value(), 1.0);

        // Dropping the still-pending future runs the `PinnedDrop` cleanup.
        drop(future);

        let families = registry.gather();
        let active_requests = families
            .iter()
            .find(|family| family.get_name().contains("requests_active"))
            .expect("requests_active should still be present after drop");
        assert_eq!(active_requests.get_metric()[0].get_gauge().get_value(), 0.0);
        let request_duration = families
            .iter()
            .find(|family| family.get_name().contains("request_duration_seconds"))
            .expect("request_duration_seconds should be present after the first poll");
        assert_eq!(
            request_duration.get_metric()[0]
                .get_histogram()
                .get_sample_count(),
            0
        );
    }
}