task_local_extensions/
task_local.rs

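// Clippy wrongly flags the expansion of the `task_local!` macro with
// `declare_interior_mutable_const`; a fix has been merged into clippy but has not
// been released yet. NOTE(review): this module now uses `thread_local!`, so this
// allow may no longer be needed — verify and remove if so.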
3#![allow(clippy::declare_interior_mutable_const)]
4
5use crate::Extensions;
6use std::cell::RefCell;
7use std::future::Future;
8
thread_local! {
    /// The extensions that `get_local_item` / `set_local_item` operate on for this thread.
    ///
    /// `with_extensions` swaps a task's extensions into this slot for the duration
    /// of each poll of its future and swaps them back out afterwards. Outside of
    /// that, the slot holds this thread's own `Extensions`, which start out empty.
    static EXTENSIONS: RefCell<Extensions> = RefCell::new(Extensions::new());
}
12
13/// Sets a task local to `Extensions` before `fut` is run,
14/// and fetches the contents of the task local Extensions after completion
15/// and returns it.
16pub async fn with_extensions<T>(
17    mut extensions: Extensions,
18    fut: impl Future<Output = T>,
19) -> (Extensions, T) {
20    pin_utils::pin_mut!(fut);
21    let res = std::future::poll_fn(|cx| {
22        EXTENSIONS.with(|ext| {
23            // swap in the extensions
24            std::mem::swap(&mut extensions, &mut *ext.borrow_mut());
25
26            let res = fut.as_mut().poll(cx);
27
28            // swap back
29            std::mem::swap(&mut extensions, &mut *ext.borrow_mut());
30
31            res
32        })
33    })
34    .await;
35
36    (extensions, res)
37}
38
39/// Retrieve any item from task-local storage.
40// TODO: doesn't need to be async?
41pub async fn get_local_item<T: Send + Sync + Clone + 'static>() -> Option<T> {
42    EXTENSIONS
43        .try_with(|e| e.borrow().get::<T>().cloned())
44        .ok()
45        .flatten()
46}
47
48/// Set an item in task-local storage.
49// TODO: doesn't need to be async?
50pub async fn set_local_item<T: Send + Sync + 'static>(item: T) {
51    EXTENSIONS
52        .try_with(|e| e.borrow_mut().insert(item))
53        .expect("Failed to set local item.");
54}
55
#[cfg(test)]
mod tests {
    use crate::{get_local_item, set_local_item, with_extensions, Extensions};

    #[derive(Clone)]
    struct A;

    #[derive(Clone)]
    struct B;

    #[derive(Clone)]
    struct C;

    /// Nested `with_extensions` scopes see only their own items. Each scope hands
    /// back everything stored in it, including items added via `set_local_item`.
    #[tokio::test]
    async fn works() {
        let mut outer_ext = Extensions::new();
        outer_ext.insert(A);

        let outer_task = async {
            let mut inner_ext = Extensions::new();
            inner_ext.insert(B);

            let inner_task = async {
                // Only the inner scope's items are visible while it runs.
                assert!(get_local_item::<A>().await.is_none());
                assert!(get_local_item::<B>().await.is_some());
                set_local_item(C).await;
            };
            let (inner_ext, ()) = with_extensions(inner_ext, inner_task).await;

            // The inner scope returns its original item plus the one set inside it.
            assert!(inner_ext.get::<B>().is_some());
            assert!(inner_ext.get::<C>().is_some());

            // Back in the outer scope: `A` is visible again and nothing leaked out.
            assert!(get_local_item::<A>().await.is_some());
            assert!(get_local_item::<B>().await.is_none());
            assert!(get_local_item::<C>().await.is_none());
        };
        let (outer_ext, ()) = with_extensions(outer_ext, outer_task).await;

        // The outer scope's extensions come back intact.
        assert!(outer_ext.get::<A>().is_some());
    }
}