1use std::future::Future;
44use std::pin::Pin;
45use std::sync::Arc;
46use std::task::{Context, Poll};
47
48use futures::FutureExt;
49use tokio::runtime::{Handle, Runtime};
50use tokio::task::{self, JoinError, JoinHandle as TokioJoinHandle};
51
/// A handle to a spawned task that aborts the task when the handle is dropped.
///
/// Obtained from [`JoinHandle::abort_on_drop`]. Awaiting it yields the same
/// `Result<T, JoinError>` as awaiting the wrapped tokio handle.
#[derive(Debug)]
pub struct AbortOnDropHandle<T>(TokioJoinHandle<T>);
55
56impl<T> AbortOnDropHandle<T> {
57 pub fn is_finished(&self) -> bool {
59 self.0.is_finished()
60 }
61
62 }
64
65impl<T> Drop for AbortOnDropHandle<T> {
66 fn drop(&mut self) {
67 self.0.abort();
68 }
69}
70
71impl<T> Future for AbortOnDropHandle<T> {
72 type Output = Result<T, JoinError>;
73
74 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
75 self.0.poll_unpin(cx)
76 }
77}
78
/// A handle to a task spawned through this module's `spawn`/`spawn_blocking`
/// helpers.
///
/// A thin wrapper over tokio's `JoinHandle`. Awaiting it yields the task's
/// `Result<T, JoinError>`. Use [`JoinHandle::abort_on_drop`] to get a handle
/// that aborts the task when dropped.
#[derive(Debug)]
pub struct JoinHandle<T>(TokioJoinHandle<T>);
89
90impl<T> Future for JoinHandle<T> {
91 type Output = Result<T, JoinError>;
92
93 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
94 self.0.poll_unpin(cx)
95 }
96}
97
98impl<T> JoinHandle<T> {
99 pub fn abort_on_drop(self) -> AbortOnDropHandle<T> {
101 AbortOnDropHandle(self.0)
102 }
103
104 pub fn is_finished(&self) -> bool {
106 self.0.is_finished()
107 }
108
109 pub fn into_tokio_handle(self) -> TokioJoinHandle<T> {
111 self.0
112 }
113
114 }
116
/// Extension methods for task handles whose output is `Result<T, JoinError>`.
///
/// Implemented in this module for [`JoinHandle`], [`AbortOnDropHandle`], and
/// `tracing`-instrumented wrappers around any implementor.
#[async_trait::async_trait]
pub trait JoinHandleExt<T>: Future<Output = Result<T, JoinError>> {
    /// Waits for the task and returns its output.
    ///
    /// The implementations in this module re-raise a panic from the task on
    /// the caller. If the task was cancelled instead, they never resolve.
    async fn wait_and_assert_finished(self) -> T;

    /// Aborts the task, then waits for it to stop, discarding its result.
    async fn abort_and_wait(self);
}
130
131async fn unpack_join_result<T>(res: Result<T, JoinError>) -> T {
132 match res {
133 Ok(val) => val,
134 Err(err) => match err.try_into_panic() {
135 Ok(panic) => std::panic::resume_unwind(panic),
136 Err(_) => {
137 std::future::pending().await
149 }
150 },
151 }
152}
153
154#[async_trait::async_trait]
155impl<T: Send> JoinHandleExt<T> for JoinHandle<T> {
156 async fn wait_and_assert_finished(self) -> T {
157 unpack_join_result(self.await).await
158 }
159
160 async fn abort_and_wait(self) {
161 self.0.abort();
162 let _ = self.await;
163 }
164}
165
166#[async_trait::async_trait]
167impl<T: Send, J: JoinHandleExt<T> + Send> JoinHandleExt<T>
168 for tracing::instrument::Instrumented<J>
169{
170 async fn wait_and_assert_finished(self) -> T {
171 unpack_join_result(self.await).await
172 }
173
174 async fn abort_and_wait(self) {
175 self.abort_and_wait().await
176 }
177}
178
179#[async_trait::async_trait]
180impl<T: Send> JoinHandleExt<T> for AbortOnDropHandle<T> {
181 async fn wait_and_assert_finished(self) -> T {
185 unpack_join_result(self.await).await
186 }
187
188 async fn abort_and_wait(self) {
189 self.0.abort();
190 let _ = self.await;
191 }
192}
193
194#[cfg(not(tokio_unstable))]
199#[track_caller]
200pub fn spawn<Fut, Name, NameClosure>(_nc: NameClosure, future: Fut) -> JoinHandle<Fut::Output>
201where
202 Name: AsRef<str>,
203 NameClosure: FnOnce() -> Name,
204 Fut: Future + Send + 'static,
205 Fut::Output: Send + 'static,
206{
207 #[allow(clippy::disallowed_methods)]
208 JoinHandle(tokio::spawn(future))
209}
210
/// Spawns `future` onto the current tokio runtime under a task name.
///
/// The name is `"<runtime id>:<name>"`, where `<name>` comes from calling `nc`.
///
/// # Panics
///
/// Panics when called outside a tokio runtime context (via `Handle::current`).
#[cfg(tokio_unstable)]
#[track_caller]
pub fn spawn<Fut, Name, NameClosure>(nc: NameClosure, future: Fut) -> JoinHandle<Fut::Output>
where
    Name: AsRef<str>,
    NameClosure: FnOnce() -> Name,
    Fut: Future + Send + 'static,
    Fut::Output: Send + 'static,
{
    // Prefixing the runtime id presumably keeps same-named tasks on different
    // runtimes distinguishable.
    let task_name = format!("{}:{}", Handle::current().id(), nc().as_ref());
    #[allow(clippy::disallowed_methods)]
    let task = task::Builder::new()
        .name(&task_name)
        .spawn(future)
        .expect("task spawning cannot fail");
    JoinHandle(task)
}
232
233#[cfg(not(tokio_unstable))]
239#[track_caller]
240#[allow(clippy::disallowed_methods)]
241pub fn spawn_blocking<Function, Output, Name, NameClosure>(
242 _nc: NameClosure,
243 function: Function,
244) -> JoinHandle<Output>
245where
246 Name: AsRef<str>,
247 NameClosure: FnOnce() -> Name,
248 Function: FnOnce() -> Output + Send + 'static,
249 Output: Send + 'static,
250{
251 JoinHandle(task::spawn_blocking(function))
252}
253
/// Runs `function` on the current tokio runtime's blocking thread pool under a
/// task name.
///
/// The name is `"<runtime id>:<name>"`, where `<name>` comes from calling `nc`.
///
/// # Panics
///
/// Panics when called outside a tokio runtime context (via `Handle::current`).
#[cfg(tokio_unstable)]
#[track_caller]
#[allow(clippy::disallowed_methods)]
pub fn spawn_blocking<Function, Output, Name, NameClosure>(
    nc: NameClosure,
    function: Function,
) -> JoinHandle<Output>
where
    Name: AsRef<str>,
    NameClosure: FnOnce() -> Name,
    Function: FnOnce() -> Output + Send + 'static,
    Output: Send + 'static,
{
    let task_name = format!("{}:{}", Handle::current().id(), nc().as_ref());
    let task = task::Builder::new()
        .name(&task_name)
        .spawn_blocking(function)
        .expect("task spawning cannot fail");
    JoinHandle(task)
}
279
280pub trait RuntimeExt {
284 #[track_caller]
290 fn spawn_blocking_named<Function, Output, Name, NameClosure>(
291 &self,
292 nc: NameClosure,
293 function: Function,
294 ) -> JoinHandle<Output>
295 where
296 Name: AsRef<str>,
297 NameClosure: FnOnce() -> Name,
298 Function: FnOnce() -> Output + Send + 'static,
299 Output: Send + 'static;
300
301 #[track_caller]
306 fn spawn_named<Fut, Name, NameClosure>(
307 &self,
308 _nc: NameClosure,
309 future: Fut,
310 ) -> JoinHandle<Fut::Output>
311 where
312 Name: AsRef<str>,
313 NameClosure: FnOnce() -> Name,
314 Fut: Future + Send + 'static,
315 Fut::Output: Send + 'static;
316}
317
318impl RuntimeExt for &Runtime {
319 fn spawn_blocking_named<Function, Output, Name, NameClosure>(
320 &self,
321 nc: NameClosure,
322 function: Function,
323 ) -> JoinHandle<Output>
324 where
325 Name: AsRef<str>,
326 NameClosure: FnOnce() -> Name,
327 Function: FnOnce() -> Output + Send + 'static,
328 Output: Send + 'static,
329 {
330 let _g = self.enter();
331 spawn_blocking(nc, function)
332 }
333
334 fn spawn_named<Fut, Name, NameClosure>(
335 &self,
336 nc: NameClosure,
337 future: Fut,
338 ) -> JoinHandle<Fut::Output>
339 where
340 Name: AsRef<str>,
341 NameClosure: FnOnce() -> Name,
342 Fut: Future + Send + 'static,
343 Fut::Output: Send + 'static,
344 {
345 let _g = self.enter();
346 spawn(nc, future)
347 }
348}
349
350impl RuntimeExt for Arc<Runtime> {
351 fn spawn_blocking_named<Function, Output, Name, NameClosure>(
352 &self,
353 nc: NameClosure,
354 function: Function,
355 ) -> JoinHandle<Output>
356 where
357 Name: AsRef<str>,
358 NameClosure: FnOnce() -> Name,
359 Function: FnOnce() -> Output + Send + 'static,
360 Output: Send + 'static,
361 {
362 (&**self).spawn_blocking_named(nc, function)
363 }
364
365 fn spawn_named<Fut, Name, NameClosure>(
366 &self,
367 nc: NameClosure,
368 future: Fut,
369 ) -> JoinHandle<Fut::Output>
370 where
371 Name: AsRef<str>,
372 NameClosure: FnOnce() -> Name,
373 Fut: Future + Send + 'static,
374 Fut::Output: Send + 'static,
375 {
376 (&**self).spawn_named(nc, future)
377 }
378}
379
380impl RuntimeExt for Handle {
381 fn spawn_blocking_named<Function, Output, Name, NameClosure>(
382 &self,
383 nc: NameClosure,
384 function: Function,
385 ) -> JoinHandle<Output>
386 where
387 Name: AsRef<str>,
388 NameClosure: FnOnce() -> Name,
389 Function: FnOnce() -> Output + Send + 'static,
390 Output: Send + 'static,
391 {
392 let _g = self.enter();
393 spawn_blocking(nc, function)
394 }
395
396 fn spawn_named<Fut, Name, NameClosure>(
397 &self,
398 nc: NameClosure,
399 future: Fut,
400 ) -> JoinHandle<Fut::Output>
401 where
402 Name: AsRef<str>,
403 NameClosure: FnOnce() -> Name,
404 Fut: Future + Send + 'static,
405 Fut::Output: Send + 'static,
406 {
407 let _g = self.enter();
408 spawn(nc, future)
409 }
410}
411
/// Extension methods for tokio's `JoinSet`.
pub trait JoinSetExt<T> {
    /// Spawns `future` into this set and returns an `AbortHandle` for it.
    ///
    /// `nc` produces the task name; the implementation in this module only
    /// evaluates it when compiled with `tokio_unstable`.
    #[track_caller]
    fn spawn_named<Fut, Name, NameClosure>(
        &mut self,
        nc: NameClosure,
        future: Fut,
    ) -> tokio::task::AbortHandle
    where
        Name: AsRef<str>,
        NameClosure: FnOnce() -> Name,
        Fut: Future<Output = T> + Send + 'static,
        T: Send + 'static;
}
432
433impl<T> JoinSetExt<T> for tokio::task::JoinSet<T> {
434 #[allow(unused_variables)]
436 fn spawn_named<Fut, Name, NameClosure>(
437 &mut self,
438 nc: NameClosure,
439 future: Fut,
440 ) -> tokio::task::AbortHandle
441 where
442 Name: AsRef<str>,
443 NameClosure: FnOnce() -> Name,
444 Fut: Future<Output = T> + Send + 'static,
445 T: Send + 'static,
446 {
447 #[cfg(tokio_unstable)]
448 #[allow(clippy::disallowed_methods)]
449 {
450 self.build_task()
451 .name(&format!("{}:{}", Handle::current().id(), nc().as_ref()))
452 .spawn(future)
453 .expect("task spawning cannot fail")
454 }
455 #[cfg(not(tokio_unstable))]
456 #[allow(clippy::disallowed_methods)]
457 {
458 self.spawn(future)
459 }
460 }
461}