Skip to main content

mz_ore/
task.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License in the LICENSE file at the
6// root of this repository, or online at
7//
8//     http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16//! Tokio task utilities.
17//!
18//! ## Named task spawning
19//!
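//! The [`spawn`] and [`spawn_blocking`] functions are wrappers around
//! [`tokio::task::spawn`] and [`tokio::task::spawn_blocking`] that attach a
//! name to the spawned task. The name is only recorded when building with
//! `--cfg tokio_unstable`; otherwise the naming closure is never called.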
23//!
24//! If Clippy sent you here, replace:
25//!
26//! ```ignore
27//! tokio::task::spawn(my_future)
28//! tokio::task::spawn_blocking(my_blocking_closure)
29//! ```
30//!
31//! with:
32//!
33//! ```ignore
34//! mz_ore::task::spawn(|| format!("taskname:{}", info), my_future)
35//! mz_ore::task::spawn_blocking(|| format!("name:{}", info), my_blocking_closure)
36//! ```
37//!
38//! If you are using methods of the same names on a [`Runtime`] or [`Handle`],
39//! import [`RuntimeExt`] and replace `spawn` with [`RuntimeExt::spawn_named`]
40//! and `spawn_blocking` with [`RuntimeExt::spawn_blocking_named`], adding
41//! naming closures like above.
42
43use std::future::Future;
44use std::pin::Pin;
45use std::sync::Arc;
46use std::task::{Context, Poll};
47
48use futures::FutureExt;
49use tokio::runtime::{Handle, Runtime};
50use tokio::task::{self, JoinHandle as TokioJoinHandle};
51
52/// Wraps a [`JoinHandle`] to abort the underlying task when dropped.
53#[derive(Debug)]
54pub struct AbortOnDropHandle<T>(JoinHandle<T>);
55
56impl<T> AbortOnDropHandle<T> {
57    /// Checks if the task associated with this [`AbortOnDropHandle`] has finished.a
58    pub fn is_finished(&self) -> bool {
59        self.0.inner.is_finished()
60    }
61
62    // Note: adding an `abort(&self)` method here is incorrect; see the comment in JoinHandle::poll.
63}
64
65impl<T> Drop for AbortOnDropHandle<T> {
66    fn drop(&mut self) {
67        self.0.inner.abort();
68    }
69}
70
71impl<T> Future for AbortOnDropHandle<T> {
72    type Output = T;
73
74    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
75        self.0.poll_unpin(cx)
76    }
77}
78
79/// Wraps a tokio `JoinHandle` that has never been cancelled.
80/// This allows it to have an infallible implementation of [Future],
81/// and provides some exclusive (i.e. they take `self` ownership)
82/// operations:
83///
84/// - `abort_on_drop`: create an `AbortOnDropHandle` that will automatically abort the task
85/// when the handle is dropped.
86/// - `JoinHandleExt::abort_and_wait`: abort the task and wait for it to be finished.
87/// - `into_tokio_handle`: turn it into an ordinary tokio `JoinHandle`.
88#[derive(Debug)]
89pub struct JoinHandle<T> {
90    inner: TokioJoinHandle<T>,
91    runtime_shutting_down: bool,
92}
93
94impl<T> JoinHandle<T> {
95    /// Wrap a tokio join handle. This is intentionally private, so we can statically guarantee
96    /// that the inner join handle has not been aborted.
97    fn new(handle: TokioJoinHandle<T>) -> Self {
98        Self {
99            inner: handle,
100            runtime_shutting_down: false,
101        }
102    }
103}
104
105impl<T> Future for JoinHandle<T> {
106    type Output = T;
107
108    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
109        if self.runtime_shutting_down {
110            return Poll::Pending;
111        }
112        match self.inner.poll_unpin(cx) {
113            Poll::Ready(Ok(res)) => Poll::Ready(res),
114            Poll::Ready(Err(err)) => {
115                match err.try_into_panic() {
116                    Ok(panic) => std::panic::resume_unwind(panic),
117                    Err(err) => {
118                        assert!(
119                            err.is_cancelled(),
120                            "join errors are either cancellations or panics"
121                        );
122                        // Because `JoinHandle` and `AbortOnDropHandle` don't
123                        // offer an `abort` method, this can only happen if the runtime is
124                        // shutting down, which means this `pending` won't cause a deadlock
125                        // because Tokio drops all outstanding futures on shutdown.
126                        // (In multi-threaded runtimes, not all threads drop futures simultaneously,
127                        // so it is possible for a future on one thread to observe the drop of a future
128                        // on another thread, before it itself is dropped.)
129                        self.runtime_shutting_down = true;
130                        Poll::Pending
131                    }
132                }
133            }
134            Poll::Pending => Poll::Pending,
135        }
136    }
137}
138
139impl<T> JoinHandle<T> {
140    /// Create an [`AbortOnDropHandle`] from this [`JoinHandle`].
141    pub fn abort_on_drop(self) -> AbortOnDropHandle<T> {
142        AbortOnDropHandle(self)
143    }
144
145    /// Checks if the task associated with this [`JoinHandle`] has finished.
146    pub fn is_finished(&self) -> bool {
147        self.inner.is_finished()
148    }
149
150    /// Aborts the task, then waits for it to complete.
151    pub async fn abort_and_wait(self) {
152        self.inner.abort();
153        let _ = self.inner.await;
154    }
155
156    /// Unwrap this handle into a standard [tokio::task::JoinHandle].
157    pub fn into_tokio_handle(self) -> TokioJoinHandle<T> {
158        self.inner
159    }
160
161    // Note: adding an `abort(&self)` method here is incorrect; see the comment in JoinHandle::poll.
162}
163
164/// Spawns a new asynchronous task with a name.
165///
166/// See [`tokio::task::spawn`] and the [module][`self`] docs for more
167/// information.
168#[cfg(not(tokio_unstable))]
169#[track_caller]
170pub fn spawn<Fut, Name, NameClosure>(_nc: NameClosure, future: Fut) -> JoinHandle<Fut::Output>
171where
172    Name: AsRef<str>,
173    NameClosure: FnOnce() -> Name,
174    Fut: Future + Send + 'static,
175    Fut::Output: Send + 'static,
176{
177    // Box the future so tokio's task machinery is monomorphized over
178    // `Pin<Box<dyn Future<Output = Fut::Output> + Send>>` per output type,
179    // rather than over every distinct future type at every call site.
180    let future: Pin<Box<dyn Future<Output = Fut::Output> + Send>> = Box::pin(future);
181    #[allow(clippy::disallowed_methods)]
182    JoinHandle::new(tokio::spawn(future))
183}
184
/// Spawns a new asynchronous task with a name.
///
/// The recorded task name is `"<runtime id>:<name>"`, where `<name>` comes
/// from calling `nc`. See [`tokio::task::spawn`] and the [module][`self`] docs
/// for more information.
#[cfg(tokio_unstable)]
#[track_caller]
pub fn spawn<Fut, Name, NameClosure>(nc: NameClosure, future: Fut) -> JoinHandle<Fut::Output>
where
    Name: AsRef<str>,
    NameClosure: FnOnce() -> Name,
    Fut: Future + Send + 'static,
    Fut::Output: Send + 'static,
{
    // Type-erase the future. Tokio's task machinery is then instantiated once
    // per output type rather than once per concrete future type per call site.
    let erased: Pin<Box<dyn Future<Output = Fut::Output> + Send>> = Box::pin(future);
    let task_name = format!("{}:{}", Handle::current().id(), nc().as_ref());
    #[allow(clippy::disallowed_methods)]
    let tokio_handle = task::Builder::new()
        .name(&task_name)
        .spawn(erased)
        .expect("task spawning cannot fail");
    JoinHandle::new(tokio_handle)
}
210
211/// Runs the provided closure with a name on a thread where blocking is
212/// acceptable.
213///
214/// See [`tokio::task::spawn_blocking`] and the [module][`self`] docs for more
215/// information.
216#[cfg(not(tokio_unstable))]
217#[track_caller]
218#[allow(clippy::disallowed_methods)]
219pub fn spawn_blocking<Function, Output, Name, NameClosure>(
220    _nc: NameClosure,
221    function: Function,
222) -> JoinHandle<Output>
223where
224    Name: AsRef<str>,
225    NameClosure: FnOnce() -> Name,
226    Function: FnOnce() -> Output + Send + 'static,
227    Output: Send + 'static,
228{
229    JoinHandle::new(task::spawn_blocking(function))
230}
231
/// Runs the provided closure with a name on a thread where blocking is
/// acceptable.
///
/// The recorded task name is `"<runtime id>:<name>"`, where `<name>` comes
/// from calling `nc`. See [`tokio::task::spawn_blocking`] and the
/// [module][`self`] docs for more information.
#[cfg(tokio_unstable)]
#[track_caller]
#[allow(clippy::disallowed_methods)]
pub fn spawn_blocking<Function, Output, Name, NameClosure>(
    nc: NameClosure,
    function: Function,
) -> JoinHandle<Output>
where
    Name: AsRef<str>,
    NameClosure: FnOnce() -> Name,
    Function: FnOnce() -> Output + Send + 'static,
    Output: Send + 'static,
{
    let task_name = format!("{}:{}", Handle::current().id(), nc().as_ref());
    let tokio_handle = task::Builder::new()
        .name(&task_name)
        .spawn_blocking(function)
        .expect("task spawning cannot fail");
    JoinHandle::new(tokio_handle)
}
257
258/// Extension methods for [`Runtime`] and [`Handle`].
259///
260/// See the [module][`self`] docs for more information.
261pub trait RuntimeExt {
262    /// Runs the provided closure with a name on a thread where blocking is
263    /// acceptable.
264    ///
265    /// See [`tokio::task::spawn_blocking`] and the [module][`self`] docs for more
266    /// information.
267    #[track_caller]
268    fn spawn_blocking_named<Function, Output, Name, NameClosure>(
269        &self,
270        nc: NameClosure,
271        function: Function,
272    ) -> JoinHandle<Output>
273    where
274        Name: AsRef<str>,
275        NameClosure: FnOnce() -> Name,
276        Function: FnOnce() -> Output + Send + 'static,
277        Output: Send + 'static;
278
279    /// Spawns a new asynchronous task with a name.
280    ///
281    /// See [`tokio::task::spawn`] and the [module][`self`] docs for more
282    /// information.
283    #[track_caller]
284    fn spawn_named<Fut, Name, NameClosure>(
285        &self,
286        _nc: NameClosure,
287        future: Fut,
288    ) -> JoinHandle<Fut::Output>
289    where
290        Name: AsRef<str>,
291        NameClosure: FnOnce() -> Name,
292        Fut: Future + Send + 'static,
293        Fut::Output: Send + 'static;
294}
295
296impl RuntimeExt for &Runtime {
297    fn spawn_blocking_named<Function, Output, Name, NameClosure>(
298        &self,
299        nc: NameClosure,
300        function: Function,
301    ) -> JoinHandle<Output>
302    where
303        Name: AsRef<str>,
304        NameClosure: FnOnce() -> Name,
305        Function: FnOnce() -> Output + Send + 'static,
306        Output: Send + 'static,
307    {
308        let _g = self.enter();
309        spawn_blocking(nc, function)
310    }
311
312    fn spawn_named<Fut, Name, NameClosure>(
313        &self,
314        nc: NameClosure,
315        future: Fut,
316    ) -> JoinHandle<Fut::Output>
317    where
318        Name: AsRef<str>,
319        NameClosure: FnOnce() -> Name,
320        Fut: Future + Send + 'static,
321        Fut::Output: Send + 'static,
322    {
323        let _g = self.enter();
324        spawn(nc, future)
325    }
326}
327
328impl RuntimeExt for Arc<Runtime> {
329    fn spawn_blocking_named<Function, Output, Name, NameClosure>(
330        &self,
331        nc: NameClosure,
332        function: Function,
333    ) -> JoinHandle<Output>
334    where
335        Name: AsRef<str>,
336        NameClosure: FnOnce() -> Name,
337        Function: FnOnce() -> Output + Send + 'static,
338        Output: Send + 'static,
339    {
340        (&**self).spawn_blocking_named(nc, function)
341    }
342
343    fn spawn_named<Fut, Name, NameClosure>(
344        &self,
345        nc: NameClosure,
346        future: Fut,
347    ) -> JoinHandle<Fut::Output>
348    where
349        Name: AsRef<str>,
350        NameClosure: FnOnce() -> Name,
351        Fut: Future + Send + 'static,
352        Fut::Output: Send + 'static,
353    {
354        (&**self).spawn_named(nc, future)
355    }
356}
357
358impl RuntimeExt for Handle {
359    fn spawn_blocking_named<Function, Output, Name, NameClosure>(
360        &self,
361        nc: NameClosure,
362        function: Function,
363    ) -> JoinHandle<Output>
364    where
365        Name: AsRef<str>,
366        NameClosure: FnOnce() -> Name,
367        Function: FnOnce() -> Output + Send + 'static,
368        Output: Send + 'static,
369    {
370        let _g = self.enter();
371        spawn_blocking(nc, function)
372    }
373
374    fn spawn_named<Fut, Name, NameClosure>(
375        &self,
376        nc: NameClosure,
377        future: Fut,
378    ) -> JoinHandle<Fut::Output>
379    where
380        Name: AsRef<str>,
381        NameClosure: FnOnce() -> Name,
382        Fut: Future + Send + 'static,
383        Fut::Output: Send + 'static,
384    {
385        let _g = self.enter();
386        spawn(nc, future)
387    }
388}
389
390/// Extension methods for [`tokio::task::JoinSet`].
391///
392/// See the [module][`self`] docs for more information.
393pub trait JoinSetExt<T> {
394    /// Spawns a new asynchronous task with a name.
395    ///
396    /// See [`tokio::task::spawn`] and the [module][`self`] docs for more
397    /// information.
398    #[track_caller]
399    fn spawn_named<Fut, Name, NameClosure>(
400        &mut self,
401        nc: NameClosure,
402        future: Fut,
403    ) -> tokio::task::AbortHandle
404    where
405        Name: AsRef<str>,
406        NameClosure: FnOnce() -> Name,
407        Fut: Future<Output = T> + Send + 'static,
408        T: Send + 'static;
409}
410
411impl<T> JoinSetExt<T> for tokio::task::JoinSet<T> {
412    // Allow unused variables until everything in ci uses `tokio_unstable`.
413    #[allow(unused_variables)]
414    fn spawn_named<Fut, Name, NameClosure>(
415        &mut self,
416        nc: NameClosure,
417        future: Fut,
418    ) -> tokio::task::AbortHandle
419    where
420        Name: AsRef<str>,
421        NameClosure: FnOnce() -> Name,
422        Fut: Future<Output = T> + Send + 'static,
423        T: Send + 'static,
424    {
425        // Box the future so tokio's task machinery is monomorphized over
426        // `Pin<Box<dyn Future<Output = T> + Send>>` per output type, rather
427        // than over every distinct future type at every call site. See the
428        // analogous comment on the top-level `spawn` for context.
429        let future: Pin<Box<dyn Future<Output = T> + Send>> = Box::pin(future);
430        #[cfg(tokio_unstable)]
431        #[allow(clippy::disallowed_methods)]
432        {
433            self.build_task()
434                .name(&format!("{}:{}", Handle::current().id(), nc().as_ref()))
435                .spawn(future)
436                .expect("task spawning cannot fail")
437        }
438        #[cfg(not(tokio_unstable))]
439        #[allow(clippy::disallowed_methods)]
440        {
441            self.spawn(future)
442        }
443    }
444}