mz_ore/
task.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License in the LICENSE file at the
6// root of this repository, or online at
7//
8//     http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16//! Tokio task utilities.
17//!
18//! ## Named task spawning
19//!
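//! The [`spawn`] and [`spawn_blocking`] functions are wrappers around
//! [`tokio::task::spawn`] and [`tokio::task::spawn_blocking`] that attach a
//! name to the spawned task. (The name is only recorded when building with
//! `--cfg tokio_unstable`; otherwise the naming closure is never called.)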
23//!
24//! If Clippy sent you here, replace:
25//!
26//! ```ignore
27//! tokio::task::spawn(my_future)
28//! tokio::task::spawn_blocking(my_blocking_closure)
29//! ```
30//!
31//! with:
32//!
33//! ```ignore
34//! mz_ore::task::spawn(|| format!("taskname:{}", info), my_future)
35//! mz_ore::task::spawn_blocking(|| format!("name:{}", info), my_blocking_closure)
36//! ```
37//!
38//! If you are using methods of the same names on a [`Runtime`] or [`Handle`],
39//! import [`RuntimeExt`] and replace `spawn` with [`RuntimeExt::spawn_named`]
40//! and `spawn_blocking` with [`RuntimeExt::spawn_blocking_named`], adding
41//! naming closures like above.
42
43use std::future::Future;
44use std::pin::Pin;
45use std::sync::Arc;
46use std::task::{Context, Poll};
47
48use futures::FutureExt;
49use tokio::runtime::{Handle, Runtime};
50use tokio::task::{self, JoinHandle as TokioJoinHandle};
51
52/// Wraps a [`JoinHandle`] to abort the underlying task when dropped.
53#[derive(Debug)]
54pub struct AbortOnDropHandle<T>(JoinHandle<T>);
55
56impl<T> AbortOnDropHandle<T> {
57    /// Checks if the task associated with this [`AbortOnDropHandle`] has finished.a
58    pub fn is_finished(&self) -> bool {
59        self.0.inner.is_finished()
60    }
61
62    // Note: adding an `abort(&self)` method here is incorrect; see the comment in JoinHandle::poll.
63}
64
65impl<T> Drop for AbortOnDropHandle<T> {
66    fn drop(&mut self) {
67        self.0.inner.abort();
68    }
69}
70
71impl<T> Future for AbortOnDropHandle<T> {
72    type Output = T;
73
74    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
75        self.0.poll_unpin(cx)
76    }
77}
78
79/// Wraps a tokio `JoinHandle` that has never been cancelled.
80/// This allows it to have an infallible implementation of [Future],
81/// and provides some exclusive (i.e. they take `self` ownership)
82/// operations:
83///
84/// - `abort_on_drop`: create an `AbortOnDropHandle` that will automatically abort the task
85/// when the handle is dropped.
86/// - `JoinHandleExt::abort_and_wait`: abort the task and wait for it to be finished.
87/// - `into_tokio_handle`: turn it into an ordinary tokio `JoinHandle`.
88#[derive(Debug)]
89pub struct JoinHandle<T> {
90    inner: TokioJoinHandle<T>,
91    runtime_shutting_down: bool,
92}
93
94impl<T> JoinHandle<T> {
95    /// Wrap a tokio join handle. This is intentionally private, so we can statically guarantee
96    /// that the inner join handle has not been aborted.
97    fn new(handle: TokioJoinHandle<T>) -> Self {
98        Self {
99            inner: handle,
100            runtime_shutting_down: false,
101        }
102    }
103}
104
105impl<T> Future for JoinHandle<T> {
106    type Output = T;
107
108    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
109        if self.runtime_shutting_down {
110            return Poll::Pending;
111        }
112        match self.inner.poll_unpin(cx) {
113            Poll::Ready(Ok(res)) => Poll::Ready(res),
114            Poll::Ready(Err(err)) => {
115                match err.try_into_panic() {
116                    Ok(panic) => std::panic::resume_unwind(panic),
117                    Err(err) => {
118                        assert!(
119                            err.is_cancelled(),
120                            "join errors are either cancellations or panics"
121                        );
122                        // Because `JoinHandle` and `AbortOnDropHandle` don't
123                        // offer an `abort` method, this can only happen if the runtime is
124                        // shutting down, which means this `pending` won't cause a deadlock
125                        // because Tokio drops all outstanding futures on shutdown.
126                        // (In multi-threaded runtimes, not all threads drop futures simultaneously,
127                        // so it is possible for a future on one thread to observe the drop of a future
128                        // on another thread, before it itself is dropped.)
129                        self.runtime_shutting_down = true;
130                        Poll::Pending
131                    }
132                }
133            }
134            Poll::Pending => Poll::Pending,
135        }
136    }
137}
138
139impl<T> JoinHandle<T> {
140    /// Create an [`AbortOnDropHandle`] from this [`JoinHandle`].
141    pub fn abort_on_drop(self) -> AbortOnDropHandle<T> {
142        AbortOnDropHandle(self)
143    }
144
145    /// Checks if the task associated with this [`JoinHandle`] has finished.
146    pub fn is_finished(&self) -> bool {
147        self.inner.is_finished()
148    }
149
150    /// Aborts the task, then waits for it to complete.
151    pub async fn abort_and_wait(self) {
152        self.inner.abort();
153        let _ = self.inner.await;
154    }
155
156    /// Unwrap this handle into a standard [tokio::task::JoinHandle].
157    pub fn into_tokio_handle(self) -> TokioJoinHandle<T> {
158        self.inner
159    }
160
161    // Note: adding an `abort(&self)` method here is incorrect; see the comment in JoinHandle::poll.
162}
163
164/// Spawns a new asynchronous task with a name.
165///
166/// See [`tokio::task::spawn`] and the [module][`self`] docs for more
167/// information.
168#[cfg(not(tokio_unstable))]
169#[track_caller]
170pub fn spawn<Fut, Name, NameClosure>(_nc: NameClosure, future: Fut) -> JoinHandle<Fut::Output>
171where
172    Name: AsRef<str>,
173    NameClosure: FnOnce() -> Name,
174    Fut: Future + Send + 'static,
175    Fut::Output: Send + 'static,
176{
177    #[allow(clippy::disallowed_methods)]
178    JoinHandle::new(tokio::spawn(future))
179}
180
/// Spawns a new asynchronous task with a name.
///
/// The task is registered under `"<runtime id>:<name>"`, where `<name>` comes
/// from calling `nc` once.
///
/// See [`tokio::task::spawn`] and the [module][`self`] docs for more
/// information.
#[cfg(tokio_unstable)]
#[track_caller]
pub fn spawn<Fut, Name, NameClosure>(nc: NameClosure, future: Fut) -> JoinHandle<Fut::Output>
where
    Name: AsRef<str>,
    NameClosure: FnOnce() -> Name,
    Fut: Future + Send + 'static,
    Fut::Output: Send + 'static,
{
    let task_name = format!("{}:{}", Handle::current().id(), nc().as_ref());
    #[allow(clippy::disallowed_methods)]
    let handle = task::Builder::new()
        .name(&task_name)
        .spawn(future)
        .expect("task spawning cannot fail");
    JoinHandle::new(handle)
}
202
203/// Runs the provided closure with a name on a thread where blocking is
204/// acceptable.
205///
206/// See [`tokio::task::spawn_blocking`] and the [module][`self`] docs for more
207/// information.
208#[cfg(not(tokio_unstable))]
209#[track_caller]
210#[allow(clippy::disallowed_methods)]
211pub fn spawn_blocking<Function, Output, Name, NameClosure>(
212    _nc: NameClosure,
213    function: Function,
214) -> JoinHandle<Output>
215where
216    Name: AsRef<str>,
217    NameClosure: FnOnce() -> Name,
218    Function: FnOnce() -> Output + Send + 'static,
219    Output: Send + 'static,
220{
221    JoinHandle::new(task::spawn_blocking(function))
222}
223
/// Runs the provided closure with a name on a thread where blocking is
/// acceptable.
///
/// The task is registered under `"<runtime id>:<name>"`, where `<name>` comes
/// from calling `nc` once.
///
/// See [`tokio::task::spawn_blocking`] and the [module][`self`] docs for more
/// information.
#[cfg(tokio_unstable)]
#[track_caller]
#[allow(clippy::disallowed_methods)]
pub fn spawn_blocking<Function, Output, Name, NameClosure>(
    nc: NameClosure,
    function: Function,
) -> JoinHandle<Output>
where
    Name: AsRef<str>,
    NameClosure: FnOnce() -> Name,
    Function: FnOnce() -> Output + Send + 'static,
    Output: Send + 'static,
{
    let task_name = format!("{}:{}", Handle::current().id(), nc().as_ref());
    let handle = task::Builder::new()
        .name(&task_name)
        .spawn_blocking(function)
        .expect("task spawning cannot fail");
    JoinHandle::new(handle)
}
249
250/// Extension methods for [`Runtime`] and [`Handle`].
251///
252/// See the [module][`self`] docs for more information.
253pub trait RuntimeExt {
254    /// Runs the provided closure with a name on a thread where blocking is
255    /// acceptable.
256    ///
257    /// See [`tokio::task::spawn_blocking`] and the [module][`self`] docs for more
258    /// information.
259    #[track_caller]
260    fn spawn_blocking_named<Function, Output, Name, NameClosure>(
261        &self,
262        nc: NameClosure,
263        function: Function,
264    ) -> JoinHandle<Output>
265    where
266        Name: AsRef<str>,
267        NameClosure: FnOnce() -> Name,
268        Function: FnOnce() -> Output + Send + 'static,
269        Output: Send + 'static;
270
271    /// Spawns a new asynchronous task with a name.
272    ///
273    /// See [`tokio::task::spawn`] and the [module][`self`] docs for more
274    /// information.
275    #[track_caller]
276    fn spawn_named<Fut, Name, NameClosure>(
277        &self,
278        _nc: NameClosure,
279        future: Fut,
280    ) -> JoinHandle<Fut::Output>
281    where
282        Name: AsRef<str>,
283        NameClosure: FnOnce() -> Name,
284        Fut: Future + Send + 'static,
285        Fut::Output: Send + 'static;
286}
287
288impl RuntimeExt for &Runtime {
289    fn spawn_blocking_named<Function, Output, Name, NameClosure>(
290        &self,
291        nc: NameClosure,
292        function: Function,
293    ) -> JoinHandle<Output>
294    where
295        Name: AsRef<str>,
296        NameClosure: FnOnce() -> Name,
297        Function: FnOnce() -> Output + Send + 'static,
298        Output: Send + 'static,
299    {
300        let _g = self.enter();
301        spawn_blocking(nc, function)
302    }
303
304    fn spawn_named<Fut, Name, NameClosure>(
305        &self,
306        nc: NameClosure,
307        future: Fut,
308    ) -> JoinHandle<Fut::Output>
309    where
310        Name: AsRef<str>,
311        NameClosure: FnOnce() -> Name,
312        Fut: Future + Send + 'static,
313        Fut::Output: Send + 'static,
314    {
315        let _g = self.enter();
316        spawn(nc, future)
317    }
318}
319
320impl RuntimeExt for Arc<Runtime> {
321    fn spawn_blocking_named<Function, Output, Name, NameClosure>(
322        &self,
323        nc: NameClosure,
324        function: Function,
325    ) -> JoinHandle<Output>
326    where
327        Name: AsRef<str>,
328        NameClosure: FnOnce() -> Name,
329        Function: FnOnce() -> Output + Send + 'static,
330        Output: Send + 'static,
331    {
332        (&**self).spawn_blocking_named(nc, function)
333    }
334
335    fn spawn_named<Fut, Name, NameClosure>(
336        &self,
337        nc: NameClosure,
338        future: Fut,
339    ) -> JoinHandle<Fut::Output>
340    where
341        Name: AsRef<str>,
342        NameClosure: FnOnce() -> Name,
343        Fut: Future + Send + 'static,
344        Fut::Output: Send + 'static,
345    {
346        (&**self).spawn_named(nc, future)
347    }
348}
349
350impl RuntimeExt for Handle {
351    fn spawn_blocking_named<Function, Output, Name, NameClosure>(
352        &self,
353        nc: NameClosure,
354        function: Function,
355    ) -> JoinHandle<Output>
356    where
357        Name: AsRef<str>,
358        NameClosure: FnOnce() -> Name,
359        Function: FnOnce() -> Output + Send + 'static,
360        Output: Send + 'static,
361    {
362        let _g = self.enter();
363        spawn_blocking(nc, function)
364    }
365
366    fn spawn_named<Fut, Name, NameClosure>(
367        &self,
368        nc: NameClosure,
369        future: Fut,
370    ) -> JoinHandle<Fut::Output>
371    where
372        Name: AsRef<str>,
373        NameClosure: FnOnce() -> Name,
374        Fut: Future + Send + 'static,
375        Fut::Output: Send + 'static,
376    {
377        let _g = self.enter();
378        spawn(nc, future)
379    }
380}
381
382/// Extension methods for [`tokio::task::JoinSet`].
383///
384/// See the [module][`self`] docs for more information.
385pub trait JoinSetExt<T> {
386    /// Spawns a new asynchronous task with a name.
387    ///
388    /// See [`tokio::task::spawn`] and the [module][`self`] docs for more
389    /// information.
390    #[track_caller]
391    fn spawn_named<Fut, Name, NameClosure>(
392        &mut self,
393        nc: NameClosure,
394        future: Fut,
395    ) -> tokio::task::AbortHandle
396    where
397        Name: AsRef<str>,
398        NameClosure: FnOnce() -> Name,
399        Fut: Future<Output = T> + Send + 'static,
400        T: Send + 'static;
401}
402
403impl<T> JoinSetExt<T> for tokio::task::JoinSet<T> {
404    // Allow unused variables until everything in ci uses `tokio_unstable`.
405    #[allow(unused_variables)]
406    fn spawn_named<Fut, Name, NameClosure>(
407        &mut self,
408        nc: NameClosure,
409        future: Fut,
410    ) -> tokio::task::AbortHandle
411    where
412        Name: AsRef<str>,
413        NameClosure: FnOnce() -> Name,
414        Fut: Future<Output = T> + Send + 'static,
415        T: Send + 'static,
416    {
417        #[cfg(tokio_unstable)]
418        #[allow(clippy::disallowed_methods)]
419        {
420            self.build_task()
421                .name(&format!("{}:{}", Handle::current().id(), nc().as_ref()))
422                .spawn(future)
423                .expect("task spawning cannot fail")
424        }
425        #[cfg(not(tokio_unstable))]
426        #[allow(clippy::disallowed_methods)]
427        {
428            self.spawn(future)
429        }
430    }
431}