mz_ore/
task.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License in the LICENSE file at the
6// root of this repository, or online at
7//
8//     http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16//! Tokio task utilities.
17//!
18//! ## Named task spawning
19//!
//! The [`spawn`] and [`spawn_blocking`] functions are wrappers around
//! [`tokio::task::spawn`] and [`tokio::task::spawn_blocking`] that attach a
//! name to the spawned task.
23//!
24//! If Clippy sent you here, replace:
25//!
26//! ```ignore
27//! tokio::task::spawn(my_future)
28//! tokio::task::spawn_blocking(my_blocking_closure)
29//! ```
30//!
31//! with:
32//!
33//! ```ignore
34//! mz_ore::task::spawn(|| format!("taskname:{}", info), my_future)
35//! mz_ore::task::spawn_blocking(|| format!("name:{}", info), my_blocking_closure)
36//! ```
37//!
38//! If you are using methods of the same names on a [`Runtime`] or [`Handle`],
39//! import [`RuntimeExt`] and replace `spawn` with [`RuntimeExt::spawn_named`]
40//! and `spawn_blocking` with [`RuntimeExt::spawn_blocking_named`], adding
41//! naming closures like above.
42
43use std::future::Future;
44use std::pin::Pin;
45use std::sync::Arc;
46use std::task::{Context, Poll};
47
48use futures::FutureExt;
49use tokio::runtime::{Handle, Runtime};
50use tokio::task::{self, JoinError, JoinHandle as TokioJoinHandle};
51
52/// Wraps a [`JoinHandle`] to abort the underlying task when dropped.
53#[derive(Debug)]
54pub struct AbortOnDropHandle<T>(TokioJoinHandle<T>);
55
56impl<T> AbortOnDropHandle<T> {
57    /// Checks if the task associated with this [`AbortOnDropHandle`] has finished.a
58    pub fn is_finished(&self) -> bool {
59        self.0.is_finished()
60    }
61
62    // Note: adding an `abort(&self)` method here is incorrect, please see `unpack_join_result`.
63}
64
65impl<T> Drop for AbortOnDropHandle<T> {
66    fn drop(&mut self) {
67        self.0.abort();
68    }
69}
70
71impl<T> Future for AbortOnDropHandle<T> {
72    type Output = Result<T, JoinError>;
73
74    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
75        self.0.poll_unpin(cx)
76    }
77}
78
79/// Wraps a tokio `JoinHandle` and provides 4 exclusive (i.e. they take `self` ownership)
80/// operations:
81///
82/// - `abort_on_drop`: create an `AbortOnDropHandle` that will automatically abort the task
83/// when the handle is dropped.
84/// - `JoinHandleExt::wait_and_assert_finished`: wait for the task to finish and return its return value.
85/// - `JoinHandleExt::abort_and_wait`: abort the task and wait for it to be finished.
86/// - `into_tokio_handle`: turn it into an ordinary tokio `JoinHandle`.
87#[derive(Debug)]
88pub struct JoinHandle<T>(TokioJoinHandle<T>);
89
90impl<T> Future for JoinHandle<T> {
91    type Output = Result<T, JoinError>;
92
93    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
94        self.0.poll_unpin(cx)
95    }
96}
97
98impl<T> JoinHandle<T> {
99    /// Create an [`AbortOnDropHandle`] from this [`JoinHandle`].
100    pub fn abort_on_drop(self) -> AbortOnDropHandle<T> {
101        AbortOnDropHandle(self.0)
102    }
103
104    /// Checks if the task associated with this [`JoinHandle`] has finished.a
105    pub fn is_finished(&self) -> bool {
106        self.0.is_finished()
107    }
108
109    /// Checks if the task associated with this [`JoinHandle`] has finished.a
110    pub fn into_tokio_handle(self) -> TokioJoinHandle<T> {
111        self.0
112    }
113
114    // Note: adding an `abort(&self)` method here is incorrect, please see `unpack_join_result`.
115}
116
117/// Extension methods for [`JoinHandle`] and [`AbortOnDropHandle`].
118#[async_trait::async_trait]
119pub trait JoinHandleExt<T>: Future<Output = Result<T, JoinError>> {
120    /// Waits for the task to finish, resuming the unwind if the task panicked.
121    ///
122    /// Because this takes ownership of `self`, and [`JoinHandle`] and
123    /// [`AbortOnDropHandle`] don't offer `abort` methods, this can avoid
124    /// worrying about aborted tasks.
125    async fn wait_and_assert_finished(self) -> T;
126
127    /// Aborts the task, then waits for it to complete.
128    async fn abort_and_wait(self);
129}
130
131async fn unpack_join_result<T>(res: Result<T, JoinError>) -> T {
132    match res {
133        Ok(val) => val,
134        Err(err) => match err.try_into_panic() {
135            Ok(panic) => std::panic::resume_unwind(panic),
136            Err(_) => {
137                // Because `JoinHandle` and `AbortOnDropHandle` don't
138                // offer `abort` method, this can only happen if the runtime is
139                // shutting down, which means this `pending` won't cause a deadlock
140                // because Tokio drops all outstanding futures on shutdown.
141                // (In multi-threaded runtimes, not all threads drop futures simultaneously,
142                // so it is possible for a future on one thread to observe the drop of a future
143                // on another thread, before it itself is dropped.)
144                //
145                // Instead, we yield to tokio runtime. A single `yield_now` is not
146                // sufficient as a `select!` or `FuturesUnordered` may
147                // poll this multiple times during shutdown.
148                std::future::pending().await
149            }
150        },
151    }
152}
153
154#[async_trait::async_trait]
155impl<T: Send> JoinHandleExt<T> for JoinHandle<T> {
156    async fn wait_and_assert_finished(self) -> T {
157        unpack_join_result(self.await).await
158    }
159
160    async fn abort_and_wait(self) {
161        self.0.abort();
162        let _ = self.await;
163    }
164}
165
166#[async_trait::async_trait]
167impl<T: Send, J: JoinHandleExt<T> + Send> JoinHandleExt<T>
168    for tracing::instrument::Instrumented<J>
169{
170    async fn wait_and_assert_finished(self) -> T {
171        unpack_join_result(self.await).await
172    }
173
174    async fn abort_and_wait(self) {
175        self.abort_and_wait().await
176    }
177}
178
179#[async_trait::async_trait]
180impl<T: Send> JoinHandleExt<T> for AbortOnDropHandle<T> {
181    // Because we are sure the `AbortOnDropHandle` still exists when we call
182    // `unpack_join_result` is called, we know `abort` hasn't been called, so its
183    // safe to call.
184    async fn wait_and_assert_finished(self) -> T {
185        unpack_join_result(self.await).await
186    }
187
188    async fn abort_and_wait(self) {
189        self.0.abort();
190        let _ = self.await;
191    }
192}
193
194/// Spawns a new asynchronous task with a name.
195///
196/// See [`tokio::task::spawn`] and the [module][`self`] docs for more
197/// information.
198#[cfg(not(tokio_unstable))]
199#[track_caller]
200pub fn spawn<Fut, Name, NameClosure>(_nc: NameClosure, future: Fut) -> JoinHandle<Fut::Output>
201where
202    Name: AsRef<str>,
203    NameClosure: FnOnce() -> Name,
204    Fut: Future + Send + 'static,
205    Fut::Output: Send + 'static,
206{
207    #[allow(clippy::disallowed_methods)]
208    JoinHandle(tokio::spawn(future))
209}
210
/// Spawns a new asynchronous task with a name.
///
/// The task is named `"<runtime id>:<name>"`. The runtime id comes from
/// [`Handle::current`], and `<name>` is the string produced by calling `nc`
/// exactly once.
///
/// See [`tokio::task::spawn`] and the [module][`self`] docs for more
/// information.
#[cfg(tokio_unstable)]
#[track_caller]
#[allow(clippy::disallowed_methods)]
pub fn spawn<Fut, Name, NameClosure>(nc: NameClosure, future: Fut) -> JoinHandle<Fut::Output>
where
    Name: AsRef<str>,
    NameClosure: FnOnce() -> Name,
    Fut: Future + Send + 'static,
    Fut::Output: Send + 'static,
{
    let runtime_id = Handle::current().id();
    let task_name = format!("{}:{}", runtime_id, nc().as_ref());
    let inner = task::Builder::new()
        .name(&task_name)
        .spawn(future)
        .expect("task spawning cannot fail");
    JoinHandle(inner)
}
232
233/// Runs the provided closure with a name on a thread where blocking is
234/// acceptable.
235///
236/// See [`tokio::task::spawn_blocking`] and the [module][`self`] docs for more
237/// information.
238#[cfg(not(tokio_unstable))]
239#[track_caller]
240#[allow(clippy::disallowed_methods)]
241pub fn spawn_blocking<Function, Output, Name, NameClosure>(
242    _nc: NameClosure,
243    function: Function,
244) -> JoinHandle<Output>
245where
246    Name: AsRef<str>,
247    NameClosure: FnOnce() -> Name,
248    Function: FnOnce() -> Output + Send + 'static,
249    Output: Send + 'static,
250{
251    JoinHandle(task::spawn_blocking(function))
252}
253
/// Runs the provided closure with a name on a thread where blocking is
/// acceptable.
///
/// The task is named `"<runtime id>:<name>"`. The runtime id comes from
/// [`Handle::current`], and `<name>` is the string produced by calling `nc`
/// exactly once.
///
/// See [`tokio::task::spawn_blocking`] and the [module][`self`] docs for more
/// information.
#[cfg(tokio_unstable)]
#[track_caller]
#[allow(clippy::disallowed_methods)]
pub fn spawn_blocking<Function, Output, Name, NameClosure>(
    nc: NameClosure,
    function: Function,
) -> JoinHandle<Output>
where
    Name: AsRef<str>,
    NameClosure: FnOnce() -> Name,
    Function: FnOnce() -> Output + Send + 'static,
    Output: Send + 'static,
{
    let runtime_id = Handle::current().id();
    let task_name = format!("{}:{}", runtime_id, nc().as_ref());
    let inner = task::Builder::new()
        .name(&task_name)
        .spawn_blocking(function)
        .expect("task spawning cannot fail");
    JoinHandle(inner)
}
279
280/// Extension methods for [`Runtime`] and [`Handle`].
281///
282/// See the [module][`self`] docs for more information.
283pub trait RuntimeExt {
284    /// Runs the provided closure with a name on a thread where blocking is
285    /// acceptable.
286    ///
287    /// See [`tokio::task::spawn_blocking`] and the [module][`self`] docs for more
288    /// information.
289    #[track_caller]
290    fn spawn_blocking_named<Function, Output, Name, NameClosure>(
291        &self,
292        nc: NameClosure,
293        function: Function,
294    ) -> JoinHandle<Output>
295    where
296        Name: AsRef<str>,
297        NameClosure: FnOnce() -> Name,
298        Function: FnOnce() -> Output + Send + 'static,
299        Output: Send + 'static;
300
301    /// Spawns a new asynchronous task with a name.
302    ///
303    /// See [`tokio::task::spawn`] and the [module][`self`] docs for more
304    /// information.
305    #[track_caller]
306    fn spawn_named<Fut, Name, NameClosure>(
307        &self,
308        _nc: NameClosure,
309        future: Fut,
310    ) -> JoinHandle<Fut::Output>
311    where
312        Name: AsRef<str>,
313        NameClosure: FnOnce() -> Name,
314        Fut: Future + Send + 'static,
315        Fut::Output: Send + 'static;
316}
317
318impl RuntimeExt for &Runtime {
319    fn spawn_blocking_named<Function, Output, Name, NameClosure>(
320        &self,
321        nc: NameClosure,
322        function: Function,
323    ) -> JoinHandle<Output>
324    where
325        Name: AsRef<str>,
326        NameClosure: FnOnce() -> Name,
327        Function: FnOnce() -> Output + Send + 'static,
328        Output: Send + 'static,
329    {
330        let _g = self.enter();
331        spawn_blocking(nc, function)
332    }
333
334    fn spawn_named<Fut, Name, NameClosure>(
335        &self,
336        nc: NameClosure,
337        future: Fut,
338    ) -> JoinHandle<Fut::Output>
339    where
340        Name: AsRef<str>,
341        NameClosure: FnOnce() -> Name,
342        Fut: Future + Send + 'static,
343        Fut::Output: Send + 'static,
344    {
345        let _g = self.enter();
346        spawn(nc, future)
347    }
348}
349
350impl RuntimeExt for Arc<Runtime> {
351    fn spawn_blocking_named<Function, Output, Name, NameClosure>(
352        &self,
353        nc: NameClosure,
354        function: Function,
355    ) -> JoinHandle<Output>
356    where
357        Name: AsRef<str>,
358        NameClosure: FnOnce() -> Name,
359        Function: FnOnce() -> Output + Send + 'static,
360        Output: Send + 'static,
361    {
362        (&**self).spawn_blocking_named(nc, function)
363    }
364
365    fn spawn_named<Fut, Name, NameClosure>(
366        &self,
367        nc: NameClosure,
368        future: Fut,
369    ) -> JoinHandle<Fut::Output>
370    where
371        Name: AsRef<str>,
372        NameClosure: FnOnce() -> Name,
373        Fut: Future + Send + 'static,
374        Fut::Output: Send + 'static,
375    {
376        (&**self).spawn_named(nc, future)
377    }
378}
379
380impl RuntimeExt for Handle {
381    fn spawn_blocking_named<Function, Output, Name, NameClosure>(
382        &self,
383        nc: NameClosure,
384        function: Function,
385    ) -> JoinHandle<Output>
386    where
387        Name: AsRef<str>,
388        NameClosure: FnOnce() -> Name,
389        Function: FnOnce() -> Output + Send + 'static,
390        Output: Send + 'static,
391    {
392        let _g = self.enter();
393        spawn_blocking(nc, function)
394    }
395
396    fn spawn_named<Fut, Name, NameClosure>(
397        &self,
398        nc: NameClosure,
399        future: Fut,
400    ) -> JoinHandle<Fut::Output>
401    where
402        Name: AsRef<str>,
403        NameClosure: FnOnce() -> Name,
404        Fut: Future + Send + 'static,
405        Fut::Output: Send + 'static,
406    {
407        let _g = self.enter();
408        spawn(nc, future)
409    }
410}
411
412/// Extension methods for [`tokio::task::JoinSet`].
413///
414/// See the [module][`self`] docs for more information.
415pub trait JoinSetExt<T> {
416    /// Spawns a new asynchronous task with a name.
417    ///
418    /// See [`tokio::task::spawn`] and the [module][`self`] docs for more
419    /// information.
420    #[track_caller]
421    fn spawn_named<Fut, Name, NameClosure>(
422        &mut self,
423        nc: NameClosure,
424        future: Fut,
425    ) -> tokio::task::AbortHandle
426    where
427        Name: AsRef<str>,
428        NameClosure: FnOnce() -> Name,
429        Fut: Future<Output = T> + Send + 'static,
430        T: Send + 'static;
431}
432
433impl<T> JoinSetExt<T> for tokio::task::JoinSet<T> {
434    // Allow unused variables until everything in ci uses `tokio_unstable`.
435    #[allow(unused_variables)]
436    fn spawn_named<Fut, Name, NameClosure>(
437        &mut self,
438        nc: NameClosure,
439        future: Fut,
440    ) -> tokio::task::AbortHandle
441    where
442        Name: AsRef<str>,
443        NameClosure: FnOnce() -> Name,
444        Fut: Future<Output = T> + Send + 'static,
445        T: Send + 'static,
446    {
447        #[cfg(tokio_unstable)]
448        #[allow(clippy::disallowed_methods)]
449        {
450            self.build_task()
451                .name(&format!("{}:{}", Handle::current().id(), nc().as_ref()))
452                .spawn(future)
453                .expect("task spawning cannot fail")
454        }
455        #[cfg(not(tokio_unstable))]
456        #[allow(clippy::disallowed_methods)]
457        {
458            self.spawn(future)
459        }
460    }
461}