async_task/
header.rs

1use core::cell::UnsafeCell;
2use core::fmt;
3use core::task::Waker;
4
5#[cfg(not(feature = "portable-atomic"))]
6use core::sync::atomic::AtomicUsize;
7use core::sync::atomic::Ordering;
8#[cfg(feature = "portable-atomic")]
9use portable_atomic::AtomicUsize;
10
11use crate::raw::TaskVTable;
12use crate::state::*;
13use crate::utils::abort_on_panic;
14
15/// The header of a task.
16///
17/// This header is stored in memory at the beginning of the heap-allocated task.
18pub(crate) struct Header<M> {
19    /// Current state of the task.
20    ///
21    /// Contains flags representing the current state and the reference count.
22    pub(crate) state: AtomicUsize,
23
24    /// The task that is blocked on the `Task` handle.
25    ///
26    /// This waker needs to be woken up once the task completes or is closed.
27    pub(crate) awaiter: UnsafeCell<Option<Waker>>,
28
29    /// The virtual table.
30    ///
31    /// In addition to the actual waker virtual table, it also contains pointers to several other
32    /// methods necessary for bookkeeping the heap-allocated task.
33    pub(crate) vtable: &'static TaskVTable,
34
35    /// Metadata associated with the task.
36    ///
37    /// This metadata may be provided to the user.
38    pub(crate) metadata: M,
39
40    /// Whether or not a panic that occurs in the task should be propagated.
41    #[cfg(feature = "std")]
42    pub(crate) propagate_panic: bool,
43}
44
45impl<M> Header<M> {
46    /// Notifies the awaiter blocked on this task.
47    ///
48    /// If the awaiter is the same as the current waker, it will not be notified.
49    #[inline]
50    pub(crate) fn notify(&self, current: Option<&Waker>) {
51        if let Some(w) = self.take(current) {
52            abort_on_panic(|| w.wake());
53        }
54    }
55
56    /// Takes the awaiter blocked on this task.
57    ///
58    /// If there is no awaiter or if it is the same as the current waker, returns `None`.
59    #[inline]
60    pub(crate) fn take(&self, current: Option<&Waker>) -> Option<Waker> {
61        // Set the bit indicating that the task is notifying its awaiter.
62        let state = self.state.fetch_or(NOTIFYING, Ordering::AcqRel);
63
64        // If the task was not notifying or registering an awaiter...
65        if state & (NOTIFYING | REGISTERING) == 0 {
66            // Take the waker out.
67            let waker = unsafe { (*self.awaiter.get()).take() };
68
69            // Unset the bit indicating that the task is notifying its awaiter.
70            self.state
71                .fetch_and(!NOTIFYING & !AWAITER, Ordering::Release);
72
73            // Finally, notify the waker if it's different from the current waker.
74            if let Some(w) = waker {
75                match current {
76                    None => return Some(w),
77                    Some(c) if !w.will_wake(c) => return Some(w),
78                    Some(_) => abort_on_panic(|| drop(w)),
79                }
80            }
81        }
82
83        None
84    }
85
86    /// Registers a new awaiter blocked on this task.
87    ///
88    /// This method is called when `Task` is polled and it has not yet completed.
89    #[inline]
90    pub(crate) fn register(&self, waker: &Waker) {
91        // Load the state and synchronize with it.
92        let mut state = self.state.fetch_or(0, Ordering::Acquire);
93
94        loop {
95            // There can't be two concurrent registrations because `Task` can only be polled
96            // by a unique pinned reference.
97            debug_assert!(state & REGISTERING == 0);
98
99            // If we're in the notifying state at this moment, just wake and return without
100            // registering.
101            if state & NOTIFYING != 0 {
102                abort_on_panic(|| waker.wake_by_ref());
103                return;
104            }
105
106            // Mark the state to let other threads know we're registering a new awaiter.
107            match self.state.compare_exchange_weak(
108                state,
109                state | REGISTERING,
110                Ordering::AcqRel,
111                Ordering::Acquire,
112            ) {
113                Ok(_) => {
114                    state |= REGISTERING;
115                    break;
116                }
117                Err(s) => state = s,
118            }
119        }
120
121        // Put the waker into the awaiter field.
122        unsafe {
123            abort_on_panic(|| (*self.awaiter.get()) = Some(waker.clone()));
124        }
125
126        // This variable will contain the newly registered waker if a notification comes in before
127        // we complete registration.
128        let mut waker = None;
129
130        loop {
131            // If there was a notification, take the waker out of the awaiter field.
132            if state & NOTIFYING != 0 {
133                if let Some(w) = unsafe { (*self.awaiter.get()).take() } {
134                    abort_on_panic(|| waker = Some(w));
135                }
136            }
137
138            // The new state is not being notified nor registered, but there might or might not be
139            // an awaiter depending on whether there was a concurrent notification.
140            let new = if waker.is_none() {
141                (state & !NOTIFYING & !REGISTERING) | AWAITER
142            } else {
143                state & !NOTIFYING & !REGISTERING & !AWAITER
144            };
145
146            match self
147                .state
148                .compare_exchange_weak(state, new, Ordering::AcqRel, Ordering::Acquire)
149            {
150                Ok(_) => break,
151                Err(s) => state = s,
152            }
153        }
154
155        // If there was a notification during registration, wake the awaiter now.
156        if let Some(w) = waker {
157            abort_on_panic(|| w.wake());
158        }
159    }
160}
161
162impl<M: fmt::Debug> fmt::Debug for Header<M> {
163    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
164        let state = self.state.load(Ordering::SeqCst);
165
166        f.debug_struct("Header")
167            .field("scheduled", &(state & SCHEDULED != 0))
168            .field("running", &(state & RUNNING != 0))
169            .field("completed", &(state & COMPLETED != 0))
170            .field("closed", &(state & CLOSED != 0))
171            .field("awaiter", &(state & AWAITER != 0))
172            .field("task", &(state & TASK != 0))
173            .field("ref_count", &(state / REFERENCE))
174            .field("metadata", &self.metadata)
175            .finish()
176    }
177}