1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
use super::{ClassifiedResponse, ClassifyResponse, NeverClassifyEos, SharedClassifier};
use http::StatusCode;
use std::{fmt, ops::RangeInclusive};

/// Response classifier that considers responses with a status code within some range to be
/// failures.
///
/// The range is a [`RangeInclusive`], so both its start and its end status codes count as
/// failures. Responses whose status code lies outside the range are classified as successes.
///
/// # Example
///
/// A client with tracing where server errors _and_ client errors are considered failures.
///
/// ```no_run
/// use tower_http::{trace::TraceLayer, classify::StatusInRangeAsFailures};
/// use tower::{ServiceBuilder, Service, ServiceExt};
/// use http::{Request, Method};
/// use http_body_util::Full;
/// use bytes::Bytes;
/// use hyper_util::{rt::TokioExecutor, client::legacy::Client};
///
/// # async fn foo() -> Result<(), tower::BoxError> {
/// let classifier = StatusInRangeAsFailures::new(400..=599);
///
/// let client = Client::builder(TokioExecutor::new()).build_http();
/// let mut client = ServiceBuilder::new()
///     .layer(TraceLayer::new(classifier.into_make_classifier()))
///     .service(client);
///
/// let request = Request::builder()
///     .method(Method::GET)
///     .uri("https://example.com")
///     .body(Full::<Bytes>::default())
///     .unwrap();
///
/// let response = client.ready().await?.call(request).await?;
/// # Ok(())
/// # }
/// ```
#[derive(Debug, Clone)]
pub struct StatusInRangeAsFailures {
    // Inclusive range of numeric status codes that are treated as failures. Both bounds are
    // checked against `StatusCode::from_u16` when the value is built through `new`.
    range: RangeInclusive<u16>,
}

impl StatusInRangeAsFailures {
    /// Builds a classifier that treats every status code inside `range` (both ends included)
    /// as a failure.
    ///
    /// # Panics
    ///
    /// Panics if either bound of `range` is rejected by [`StatusCode::from_u16`], i.e. it is
    /// not a valid status code.
    ///
    /// [`StatusCode::from_u16`]: https://docs.rs/http/latest/http/status/struct.StatusCode.html#method.from_u16
    pub fn new(range: RangeInclusive<u16>) -> Self {
        // Both bounds go through the same validity check; only the panic message differs.
        let is_valid_status = |code: &u16| StatusCode::from_u16(*code).is_ok();

        assert!(
            is_valid_status(range.start()),
            "range start isn't a valid status code"
        );
        assert!(
            is_valid_status(range.end()),
            "range end isn't a valid status code"
        );

        Self { range }
    }

    /// Builds a classifier that treats both client errors and server errors as failures.
    ///
    /// Shorthand for `StatusInRangeAsFailures::new(400..=599)`.
    pub fn new_for_client_and_server_errors() -> Self {
        let client_and_server_errors = 400..=599;
        Self::new(client_and_server_errors)
    }

    /// Consumes this classifier and wraps it in a [`SharedClassifier`] so it can be used as a
    /// [`MakeClassifier`].
    ///
    /// [`MakeClassifier`]: super::MakeClassifier
    pub fn into_make_classifier(self) -> SharedClassifier<Self> {
        SharedClassifier::new(self)
    }
}

impl ClassifyResponse for StatusInRangeAsFailures {
    type FailureClass = StatusInRangeFailureClass;
    type ClassifyEos = NeverClassifyEos<Self::FailureClass>;

    fn classify_response<B>(
        self,
        res: &http::Response<B>,
    ) -> ClassifiedResponse<Self::FailureClass, Self::ClassifyEos> {
        if self.range.contains(&res.status().as_u16()) {
            let class = StatusInRangeFailureClass::StatusCode(res.status());
            ClassifiedResponse::Ready(Err(class))
        } else {
            ClassifiedResponse::Ready(Ok(()))
        }
    }

    fn classify_error<E>(self, error: &E) -> Self::FailureClass
    where
        E: std::fmt::Display + 'static,
    {
        StatusInRangeFailureClass::Error(error.to_string())
    }
}

/// The failure class for [`StatusInRangeAsFailures`].
///
/// Its [`fmt::Display`] output is `Status code: …` for [`StatusCode`](Self::StatusCode) and
/// `Error: …` for [`Error`](Self::Error).
#[derive(Debug)]
pub enum StatusInRangeFailureClass {
    /// A response was classified as a failure with the corresponding status.
    ///
    /// Produced by `classify_response` when the response's status code lies inside the
    /// configured range.
    StatusCode(StatusCode),
    /// An error was classified as a failure; holds the corresponding error description.
    ///
    /// Produced by `classify_error`; the string is the error's `Display` output.
    Error(String),
}

impl fmt::Display for StatusInRangeFailureClass {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Self::StatusCode(code) => write!(f, "Status code: {}", code),
            Self::Error(error) => write!(f, "Error: {}", error),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use http::Response;

    /// Builds an empty-bodied response carrying the given numeric status code.
    fn response_with_status(status: u16) -> Response<()> {
        Response::builder().status(status).body(()).unwrap()
    }

    #[test]
    fn basic() {
        let classifier = StatusInRangeAsFailures::new(400..=599);

        // The classification is only printed when an assertion fails, so passing runs stay quiet.
        let outcome = classifier
            .clone()
            .classify_response(&response_with_status(200));
        assert!(
            matches!(outcome, ClassifiedResponse::Ready(Ok(()))),
            "unexpected classification for 200: {:?}",
            outcome
        );

        let outcome = classifier
            .clone()
            .classify_response(&response_with_status(400));
        assert!(
            matches!(
                outcome,
                ClassifiedResponse::Ready(Err(StatusInRangeFailureClass::StatusCode(
                    StatusCode::BAD_REQUEST
                )))
            ),
            "unexpected classification for 400: {:?}",
            outcome
        );

        let outcome = classifier.classify_response(&response_with_status(500));
        assert!(
            matches!(
                outcome,
                ClassifiedResponse::Ready(Err(StatusInRangeFailureClass::StatusCode(
                    StatusCode::INTERNAL_SERVER_ERROR
                )))
            ),
            "unexpected classification for 500: {:?}",
            outcome
        );
    }

    /// Both ends of the range are failures; the codes just outside it are successes.
    #[test]
    fn range_bounds_are_inclusive() {
        let classifier = StatusInRangeAsFailures::new(400..=599);

        for status in [399u16, 600].iter().copied() {
            let outcome = classifier
                .clone()
                .classify_response(&response_with_status(status));
            assert!(
                matches!(outcome, ClassifiedResponse::Ready(Ok(()))),
                "status {} should be a success, got {:?}",
                status,
                outcome
            );
        }

        for status in [400u16, 599].iter().copied() {
            let outcome = classifier
                .clone()
                .classify_response(&response_with_status(status));
            assert!(
                matches!(
                    outcome,
                    ClassifiedResponse::Ready(Err(StatusInRangeFailureClass::StatusCode(code)))
                        if code.as_u16() == status
                ),
                "status {} should be a failure, got {:?}",
                status,
                outcome
            );
        }
    }

    #[test]
    #[should_panic(expected = "range start isn't a valid status code")]
    fn new_rejects_invalid_start() {
        let _ = StatusInRangeAsFailures::new(0..=599);
    }

    #[test]
    #[should_panic(expected = "range end isn't a valid status code")]
    fn new_rejects_invalid_end() {
        let _ = StatusInRangeAsFailures::new(400..=1000);
    }

    #[test]
    fn classify_error_captures_display_output() {
        let classifier = StatusInRangeAsFailures::new_for_client_and_server_errors();

        let class = classifier.classify_error(&"boom");
        assert!(
            matches!(&class, StatusInRangeFailureClass::Error(message) if message.as_str() == "boom"),
            "unexpected failure class: {:?}",
            class
        );
        assert_eq!(class.to_string(), "Error: boom");
    }

    #[test]
    fn display_prefixes_status_code() {
        let class = StatusInRangeFailureClass::StatusCode(StatusCode::NOT_FOUND);

        // Compare against `StatusCode`'s own `Display` so the test does not pin that crate's
        // exact reason-phrase formatting.
        assert_eq!(
            class.to_string(),
            format!("Status code: {}", StatusCode::NOT_FOUND)
        );
    }
}