axum/routing/method_filter.rs

1use http::Method;
2use std::{
3    fmt,
4    fmt::{Debug, Formatter},
5};
6
/// A filter that matches one or more HTTP methods.
///
/// Each supported method is represented by its own bit, so several filters can be
/// combined with [`MethodFilter::or`] into a single filter matching all of them.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub struct MethodFilter(u16);

impl MethodFilter {
    // NOTE(review): bit 0 (`0b0_0000_0001`) is not assigned to any method here;
    // presumably reserved — confirm before reusing it.

    /// Match `DELETE` requests.
    pub const DELETE: Self = Self::from_bits(0b0_0000_0010);
    /// Match `GET` requests.
    pub const GET: Self = Self::from_bits(0b0_0000_0100);
    /// Match `HEAD` requests.
    pub const HEAD: Self = Self::from_bits(0b0_0000_1000);
    /// Match `OPTIONS` requests.
    pub const OPTIONS: Self = Self::from_bits(0b0_0001_0000);
    /// Match `PATCH` requests.
    pub const PATCH: Self = Self::from_bits(0b0_0010_0000);
    /// Match `POST` requests.
    pub const POST: Self = Self::from_bits(0b0_0100_0000);
    /// Match `PUT` requests.
    pub const PUT: Self = Self::from_bits(0b0_1000_0000);
    /// Match `TRACE` requests.
    pub const TRACE: Self = Self::from_bits(0b1_0000_0000);

    /// Returns the raw bit set backing this filter.
    const fn bits(&self) -> u16 {
        self.0
    }

    /// Wraps a raw bit set as a filter, without any validation of the bits.
    const fn from_bits(bits: u16) -> Self {
        Self(bits)
    }

    /// Returns `true` if every method matched by `other` is also matched by `self`,
    /// i.e. all bits set in `other` are also set in `self`.
    pub(crate) const fn contains(&self, other: Self) -> bool {
        self.bits() & other.bits() == other.bits()
    }

    /// Performs the OR operation between the [`MethodFilter`] in `self` with `other`.
    ///
    /// The result matches every method matched by either operand. Neither operand is
    /// modified, so discarding the returned value is always a mistake.
    #[must_use]
    pub const fn or(self, other: Self) -> Self {
        Self(self.bits() | other.bits())
    }
}
47
48/// Error type used when converting a [`Method`] to a [`MethodFilter`] fails.
49#[derive(Debug)]
50pub struct NoMatchingMethodFilter {
51    method: Method,
52}
53
54impl NoMatchingMethodFilter {
55    /// Get the [`Method`] that couldn't be converted to a [`MethodFilter`].
56    pub fn method(&self) -> &Method {
57        &self.method
58    }
59}
60
61impl fmt::Display for NoMatchingMethodFilter {
62    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
63        write!(f, "no `MethodFilter` for `{}`", self.method.as_str())
64    }
65}
66
67impl std::error::Error for NoMatchingMethodFilter {}
68
69impl TryFrom<Method> for MethodFilter {
70    type Error = NoMatchingMethodFilter;
71
72    fn try_from(m: Method) -> Result<Self, NoMatchingMethodFilter> {
73        match m {
74            Method::DELETE => Ok(MethodFilter::DELETE),
75            Method::GET => Ok(MethodFilter::GET),
76            Method::HEAD => Ok(MethodFilter::HEAD),
77            Method::OPTIONS => Ok(MethodFilter::OPTIONS),
78            Method::PATCH => Ok(MethodFilter::PATCH),
79            Method::POST => Ok(MethodFilter::POST),
80            Method::PUT => Ok(MethodFilter::PUT),
81            Method::TRACE => Ok(MethodFilter::TRACE),
82            other => Err(NoMatchingMethodFilter { method: other }),
83        }
84    }
85}
86
#[cfg(test)]
mod tests {
    use super::*;

    /// Every supported `http::Method` maps to its own filter; anything else errors.
    #[test]
    fn from_http_method() {
        let cases = [
            (Method::DELETE, MethodFilter::DELETE),
            (Method::GET, MethodFilter::GET),
            (Method::HEAD, MethodFilter::HEAD),
            (Method::OPTIONS, MethodFilter::OPTIONS),
            (Method::PATCH, MethodFilter::PATCH),
            (Method::POST, MethodFilter::POST),
            (Method::PUT, MethodFilter::PUT),
            (Method::TRACE, MethodFilter::TRACE),
        ];

        for (method, expected) in cases {
            // Capture the name before `try_from` consumes the method, so a failing
            // case reports which mapping broke.
            let label = method.as_str().to_owned();
            assert_eq!(MethodFilter::try_from(method).unwrap(), expected, "{}", label);
        }

        let err = MethodFilter::try_from(http::Method::CONNECT).unwrap_err();
        assert_eq!(err.method(), &Method::CONNECT);
        assert!(err.to_string().contains("CONNECT"));
    }

    /// `or` unions filters and `contains` checks for a subset of methods.
    #[test]
    fn or_combines_and_contains_checks_subsets() {
        let get_or_post = MethodFilter::GET.or(MethodFilter::POST);

        assert!(get_or_post.contains(MethodFilter::GET));
        assert!(get_or_post.contains(MethodFilter::POST));
        assert!(get_or_post.contains(get_or_post));
        assert!(!get_or_post.contains(MethodFilter::PUT));
        assert!(!MethodFilter::GET.contains(get_or_post));
    }

    /// No two method constants share a bit, so single-method filters never overlap.
    #[test]
    fn method_bits_are_disjoint() {
        let all = [
            MethodFilter::DELETE,
            MethodFilter::GET,
            MethodFilter::HEAD,
            MethodFilter::OPTIONS,
            MethodFilter::PATCH,
            MethodFilter::POST,
            MethodFilter::PUT,
            MethodFilter::TRACE,
        ];

        for (i, a) in all.iter().enumerate() {
            for b in &all[i + 1..] {
                assert!(!a.contains(*b), "{:?} overlaps {:?}", a, b);
                assert!(!b.contains(*a), "{:?} overlaps {:?}", b, a);
            }
        }
    }
}
138}