1use std::future::Future;
44use std::pin::Pin;
45use std::sync::Arc;
46use std::task::{Context, Poll};
47
48use futures::FutureExt;
49use tokio::runtime::{Handle, Runtime};
50use tokio::task::{self, JoinHandle as TokioJoinHandle};
51
52#[derive(Debug)]
54pub struct AbortOnDropHandle<T>(JoinHandle<T>);
55
56impl<T> AbortOnDropHandle<T> {
57 pub fn is_finished(&self) -> bool {
59 self.0.inner.is_finished()
60 }
61
62 }
64
65impl<T> Drop for AbortOnDropHandle<T> {
66 fn drop(&mut self) {
67 self.0.inner.abort();
68 }
69}
70
71impl<T> Future for AbortOnDropHandle<T> {
72 type Output = T;
73
74 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
75 self.0.poll_unpin(cx)
76 }
77}
78
79#[derive(Debug)]
89pub struct JoinHandle<T> {
90 inner: TokioJoinHandle<T>,
91 runtime_shutting_down: bool,
92}
93
94impl<T> JoinHandle<T> {
95 fn new(handle: TokioJoinHandle<T>) -> Self {
98 Self {
99 inner: handle,
100 runtime_shutting_down: false,
101 }
102 }
103}
104
105impl<T> Future for JoinHandle<T> {
106 type Output = T;
107
108 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
109 if self.runtime_shutting_down {
110 return Poll::Pending;
111 }
112 match self.inner.poll_unpin(cx) {
113 Poll::Ready(Ok(res)) => Poll::Ready(res),
114 Poll::Ready(Err(err)) => {
115 match err.try_into_panic() {
116 Ok(panic) => std::panic::resume_unwind(panic),
117 Err(err) => {
118 assert!(
119 err.is_cancelled(),
120 "join errors are either cancellations or panics"
121 );
122 self.runtime_shutting_down = true;
130 Poll::Pending
131 }
132 }
133 }
134 Poll::Pending => Poll::Pending,
135 }
136 }
137}
138
139impl<T> JoinHandle<T> {
140 pub fn abort_on_drop(self) -> AbortOnDropHandle<T> {
142 AbortOnDropHandle(self)
143 }
144
145 pub fn is_finished(&self) -> bool {
147 self.inner.is_finished()
148 }
149
150 pub async fn abort_and_wait(self) {
152 self.inner.abort();
153 let _ = self.inner.await;
154 }
155
156 pub fn into_tokio_handle(self) -> TokioJoinHandle<T> {
158 self.inner
159 }
160
161 }
163
164#[cfg(not(tokio_unstable))]
169#[track_caller]
170pub fn spawn<Fut, Name, NameClosure>(_nc: NameClosure, future: Fut) -> JoinHandle<Fut::Output>
171where
172 Name: AsRef<str>,
173 NameClosure: FnOnce() -> Name,
174 Fut: Future + Send + 'static,
175 Fut::Output: Send + 'static,
176{
177 let future: Pin<Box<dyn Future<Output = Fut::Output> + Send>> = Box::pin(future);
181 #[allow(clippy::disallowed_methods)]
182 JoinHandle::new(tokio::spawn(future))
183}
184
/// Spawns `future` on the ambient tokio runtime as a named task and returns a
/// [`JoinHandle`] for it.
///
/// The task name is `"<runtime id>:<name>"`, where `<name>` comes from `nc`.
#[cfg(tokio_unstable)]
#[track_caller]
pub fn spawn<Fut, Name, NameClosure>(nc: NameClosure, future: Fut) -> JoinHandle<Fut::Output>
where
    Name: AsRef<str>,
    NameClosure: FnOnce() -> Name,
    Fut: Future + Send + 'static,
    Fut::Output: Send + 'static,
{
    // Type-erase and heap-allocate the future before handing it to tokio.
    let boxed: Pin<Box<dyn Future<Output = Fut::Output> + Send>> = Box::pin(future);
    let task_name = format!("{}:{}", Handle::current().id(), nc().as_ref());
    #[allow(clippy::disallowed_methods)]
    let raw = task::Builder::new()
        .name(&task_name)
        .spawn(boxed)
        .expect("task spawning cannot fail");
    JoinHandle::new(raw)
}
210
211#[cfg(not(tokio_unstable))]
217#[track_caller]
218#[allow(clippy::disallowed_methods)]
219pub fn spawn_blocking<Function, Output, Name, NameClosure>(
220 _nc: NameClosure,
221 function: Function,
222) -> JoinHandle<Output>
223where
224 Name: AsRef<str>,
225 NameClosure: FnOnce() -> Name,
226 Function: FnOnce() -> Output + Send + 'static,
227 Output: Send + 'static,
228{
229 JoinHandle::new(task::spawn_blocking(function))
230}
231
/// Runs `function` as a named blocking task and returns a [`JoinHandle`] for
/// it.
///
/// The task name is `"<runtime id>:<name>"`, where `<name>` comes from `nc`.
#[cfg(tokio_unstable)]
#[track_caller]
#[allow(clippy::disallowed_methods)]
pub fn spawn_blocking<Function, Output, Name, NameClosure>(
    nc: NameClosure,
    function: Function,
) -> JoinHandle<Output>
where
    Name: AsRef<str>,
    NameClosure: FnOnce() -> Name,
    Function: FnOnce() -> Output + Send + 'static,
    Output: Send + 'static,
{
    let task_name = format!("{}:{}", Handle::current().id(), nc().as_ref());
    let raw = task::Builder::new()
        .name(&task_name)
        .spawn_blocking(function)
        .expect("task spawning cannot fail");
    JoinHandle::new(raw)
}
257
258pub trait RuntimeExt {
262 #[track_caller]
268 fn spawn_blocking_named<Function, Output, Name, NameClosure>(
269 &self,
270 nc: NameClosure,
271 function: Function,
272 ) -> JoinHandle<Output>
273 where
274 Name: AsRef<str>,
275 NameClosure: FnOnce() -> Name,
276 Function: FnOnce() -> Output + Send + 'static,
277 Output: Send + 'static;
278
279 #[track_caller]
284 fn spawn_named<Fut, Name, NameClosure>(
285 &self,
286 _nc: NameClosure,
287 future: Fut,
288 ) -> JoinHandle<Fut::Output>
289 where
290 Name: AsRef<str>,
291 NameClosure: FnOnce() -> Name,
292 Fut: Future + Send + 'static,
293 Fut::Output: Send + 'static;
294}
295
296impl RuntimeExt for &Runtime {
297 fn spawn_blocking_named<Function, Output, Name, NameClosure>(
298 &self,
299 nc: NameClosure,
300 function: Function,
301 ) -> JoinHandle<Output>
302 where
303 Name: AsRef<str>,
304 NameClosure: FnOnce() -> Name,
305 Function: FnOnce() -> Output + Send + 'static,
306 Output: Send + 'static,
307 {
308 let _g = self.enter();
309 spawn_blocking(nc, function)
310 }
311
312 fn spawn_named<Fut, Name, NameClosure>(
313 &self,
314 nc: NameClosure,
315 future: Fut,
316 ) -> JoinHandle<Fut::Output>
317 where
318 Name: AsRef<str>,
319 NameClosure: FnOnce() -> Name,
320 Fut: Future + Send + 'static,
321 Fut::Output: Send + 'static,
322 {
323 let _g = self.enter();
324 spawn(nc, future)
325 }
326}
327
328impl RuntimeExt for Arc<Runtime> {
329 fn spawn_blocking_named<Function, Output, Name, NameClosure>(
330 &self,
331 nc: NameClosure,
332 function: Function,
333 ) -> JoinHandle<Output>
334 where
335 Name: AsRef<str>,
336 NameClosure: FnOnce() -> Name,
337 Function: FnOnce() -> Output + Send + 'static,
338 Output: Send + 'static,
339 {
340 (&**self).spawn_blocking_named(nc, function)
341 }
342
343 fn spawn_named<Fut, Name, NameClosure>(
344 &self,
345 nc: NameClosure,
346 future: Fut,
347 ) -> JoinHandle<Fut::Output>
348 where
349 Name: AsRef<str>,
350 NameClosure: FnOnce() -> Name,
351 Fut: Future + Send + 'static,
352 Fut::Output: Send + 'static,
353 {
354 (&**self).spawn_named(nc, future)
355 }
356}
357
358impl RuntimeExt for Handle {
359 fn spawn_blocking_named<Function, Output, Name, NameClosure>(
360 &self,
361 nc: NameClosure,
362 function: Function,
363 ) -> JoinHandle<Output>
364 where
365 Name: AsRef<str>,
366 NameClosure: FnOnce() -> Name,
367 Function: FnOnce() -> Output + Send + 'static,
368 Output: Send + 'static,
369 {
370 let _g = self.enter();
371 spawn_blocking(nc, function)
372 }
373
374 fn spawn_named<Fut, Name, NameClosure>(
375 &self,
376 nc: NameClosure,
377 future: Fut,
378 ) -> JoinHandle<Fut::Output>
379 where
380 Name: AsRef<str>,
381 NameClosure: FnOnce() -> Name,
382 Fut: Future + Send + 'static,
383 Fut::Output: Send + 'static,
384 {
385 let _g = self.enter();
386 spawn(nc, future)
387 }
388}
389
390pub trait JoinSetExt<T> {
394 #[track_caller]
399 fn spawn_named<Fut, Name, NameClosure>(
400 &mut self,
401 nc: NameClosure,
402 future: Fut,
403 ) -> tokio::task::AbortHandle
404 where
405 Name: AsRef<str>,
406 NameClosure: FnOnce() -> Name,
407 Fut: Future<Output = T> + Send + 'static,
408 T: Send + 'static;
409}
410
411impl<T> JoinSetExt<T> for tokio::task::JoinSet<T> {
412 #[allow(unused_variables)]
414 fn spawn_named<Fut, Name, NameClosure>(
415 &mut self,
416 nc: NameClosure,
417 future: Fut,
418 ) -> tokio::task::AbortHandle
419 where
420 Name: AsRef<str>,
421 NameClosure: FnOnce() -> Name,
422 Fut: Future<Output = T> + Send + 'static,
423 T: Send + 'static,
424 {
425 let future: Pin<Box<dyn Future<Output = T> + Send>> = Box::pin(future);
430 #[cfg(tokio_unstable)]
431 #[allow(clippy::disallowed_methods)]
432 {
433 self.build_task()
434 .name(&format!("{}:{}", Handle::current().id(), nc().as_ref()))
435 .spawn(future)
436 .expect("task spawning cannot fail")
437 }
438 #[cfg(not(tokio_unstable))]
439 #[allow(clippy::disallowed_methods)]
440 {
441 self.spawn(future)
442 }
443 }
444}