axum/
util.rs

1use pin_project_lite::pin_project;
2use std::{ops::Deref, sync::Arc};
3
4pub(crate) use self::mutex::*;
5
6#[derive(Clone, Debug, PartialEq, Eq, Hash)]
7pub(crate) struct PercentDecodedStr(Arc<str>);
8
9impl PercentDecodedStr {
10    pub(crate) fn new<S>(s: S) -> Option<Self>
11    where
12        S: AsRef<str>,
13    {
14        percent_encoding::percent_decode(s.as_ref().as_bytes())
15            .decode_utf8()
16            .ok()
17            .map(|decoded| Self(decoded.as_ref().into()))
18    }
19
20    pub(crate) fn as_str(&self) -> &str {
21        &self.0
22    }
23
24    pub(crate) fn into_inner(self) -> Arc<str> {
25        self.0
26    }
27}
28
29impl Deref for PercentDecodedStr {
30    type Target = str;
31
32    #[inline]
33    fn deref(&self) -> &Self::Target {
34        self.as_str()
35    }
36}
37
38pin_project! {
39    #[project = EitherProj]
40    pub(crate) enum Either<A, B> {
41        A { #[pin] inner: A },
42        B { #[pin] inner: B },
43    }
44}
45
/// Attempts to convert `k` into a `T` without copying it.
///
/// Returns `Ok` with the value when `K` and `T` are the same type, and
/// otherwise gives the original value back unchanged in `Err`.
pub(crate) fn try_downcast<T, K>(k: K) -> Result<T, K>
where
    T: 'static,
    K: Send + 'static,
{
    use std::any::Any;

    // Park the value in an `Option` so that, on a type match, it can be moved
    // out through the `&mut` reference that `downcast_mut` hands back.
    let mut slot = Some(k);
    let as_any: &mut dyn Any = &mut slot;
    match as_any.downcast_mut::<Option<T>>() {
        // Same type: `slot` was filled just above, so `take` yields `Some`.
        Some(typed) => Ok(typed.take().unwrap()),
        // Different type: `slot` is untouched and still holds the original.
        None => Err(slot.unwrap()),
    }
}
58
#[test]
fn test_try_downcast() {
    // Source and target types differ: the original value comes back in `Err`.
    let mismatched: Result<i32, u32> = try_downcast(5_u32);
    assert_eq!(mismatched, Err(5_u32));

    // Source and target types are the same: the value is handed over in `Ok`.
    let matched: Result<i32, i32> = try_downcast(5_i32);
    assert_eq!(matched, Ok(5_i32));
}
64
65// `AxumMutex` is a wrapper around `std::sync::Mutex` which, in test mode, tracks the number of
66// times it's been locked on the current task. That way we can write a test to ensure we don't
67// accidentally introduce more locking.
68//
69// When not in test mode, it is just a type alias for `std::sync::Mutex`.
#[cfg(not(test))]
#[allow(clippy::disallowed_types)]
mod mutex {
    /// Outside of test builds `AxumMutex` is just `std::sync::Mutex`, with no
    /// lock counting.
    pub(crate) type AxumMutex<T> = std::sync::Mutex<T>;
}
75
#[cfg(test)]
#[allow(clippy::disallowed_types)]
mod mutex {
    use std::sync::{
        atomic::{AtomicUsize, Ordering},
        LockResult, Mutex, MutexGuard,
    };

    // Task-local counter of `AxumMutex::lock` calls. It only exists inside a
    // `mutex_num_locked` scope.
    tokio::task_local! {
        pub(crate) static NUM_LOCKED: AtomicUsize;
    }

    /// Awaits the future produced by `f` inside a fresh `NUM_LOCKED` scope that
    /// starts at zero.
    ///
    /// Returns the number of `AxumMutex::lock` calls counted in that scope,
    /// paired with the future's output.
    pub(crate) async fn mutex_num_locked<F, Fut>(f: F) -> (usize, Fut::Output)
    where
        F: FnOnce() -> Fut,
        Fut: std::future::IntoFuture,
    {
        // `f` is invoked inside the async block so that it already runs within
        // the task-local scope.
        let counted = async move {
            let output = f().await;
            let locks = NUM_LOCKED.with(|counter| counter.load(Ordering::SeqCst));
            (locks, output)
        };

        NUM_LOCKED.scope(AtomicUsize::new(0), counted).await
    }

    /// Test-build stand-in for `std::sync::Mutex` that records every `lock`
    /// call in `NUM_LOCKED`.
    pub(crate) struct AxumMutex<T> {
        inner: Mutex<T>,
    }

    impl<T> AxumMutex<T> {
        /// Wraps `value` in a new mutex.
        pub(crate) fn new(value: T) -> Self {
            Self {
                inner: Mutex::new(value),
            }
        }

        /// Forwards to `Mutex::get_mut`; this call is not counted.
        pub(crate) fn get_mut(&mut self) -> LockResult<&mut T> {
            self.inner.get_mut()
        }

        /// Forwards to `Mutex::into_inner`; this call is not counted.
        pub(crate) fn into_inner(self) -> LockResult<T> {
            self.inner.into_inner()
        }

        /// Bumps `NUM_LOCKED` when it is set for the current task, then locks
        /// the wrapped mutex.
        pub(crate) fn lock(&self) -> LockResult<MutexGuard<'_, T>> {
            // Best effort: outside a `mutex_num_locked` scope there is no
            // counter, so the access error is deliberately discarded.
            let _ = NUM_LOCKED.try_with(|counter| counter.fetch_add(1, Ordering::SeqCst));
            self.inner.lock()
        }
    }
}