1use std::future::Future;
44use std::pin::Pin;
45use std::sync::Arc;
46use std::task::{Context, Poll};
47
48use futures::FutureExt;
49use tokio::runtime::{Handle, Runtime};
50use tokio::task::{self, JoinHandle as TokioJoinHandle};
51
52#[derive(Debug)]
54pub struct AbortOnDropHandle<T>(JoinHandle<T>);
55
56impl<T> AbortOnDropHandle<T> {
57 pub fn is_finished(&self) -> bool {
59 self.0.inner.is_finished()
60 }
61
62 }
64
65impl<T> Drop for AbortOnDropHandle<T> {
66 fn drop(&mut self) {
67 self.0.inner.abort();
68 }
69}
70
71impl<T> Future for AbortOnDropHandle<T> {
72 type Output = T;
73
74 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
75 self.0.poll_unpin(cx)
76 }
77}
78
/// Wrapper around a tokio join handle whose [`Future`] implementation yields
/// the task's output `T` directly instead of a `Result`.
///
/// A panic in the task is re-raised in the awaiting task; see the `poll`
/// implementation for how a cancellation is handled.
#[derive(Debug)]
pub struct JoinHandle<T> {
    /// The wrapped tokio join handle.
    inner: TokioJoinHandle<T>,
    /// Set by `poll` once the inner handle has reported a cancellation. While
    /// set, every later poll returns `Poll::Pending` without polling `inner`.
    // NOTE(review): the name suggests a cancellation is only expected while
    // the runtime is shutting down — confirm.
    runtime_shutting_down: bool,
}
93
94impl<T> JoinHandle<T> {
95 fn new(handle: TokioJoinHandle<T>) -> Self {
98 Self {
99 inner: handle,
100 runtime_shutting_down: false,
101 }
102 }
103}
104
105impl<T> Future for JoinHandle<T> {
106 type Output = T;
107
108 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
109 if self.runtime_shutting_down {
110 return Poll::Pending;
111 }
112 match self.inner.poll_unpin(cx) {
113 Poll::Ready(Ok(res)) => Poll::Ready(res),
114 Poll::Ready(Err(err)) => {
115 match err.try_into_panic() {
116 Ok(panic) => std::panic::resume_unwind(panic),
117 Err(err) => {
118 assert!(
119 err.is_cancelled(),
120 "join errors are either cancellations or panics"
121 );
122 self.runtime_shutting_down = true;
130 Poll::Pending
131 }
132 }
133 }
134 Poll::Pending => Poll::Pending,
135 }
136 }
137}
138
139impl<T> JoinHandle<T> {
140 pub fn abort_on_drop(self) -> AbortOnDropHandle<T> {
142 AbortOnDropHandle(self)
143 }
144
145 pub fn is_finished(&self) -> bool {
147 self.inner.is_finished()
148 }
149
150 pub async fn abort_and_wait(self) {
152 self.inner.abort();
153 let _ = self.inner.await;
154 }
155
156 pub fn into_tokio_handle(self) -> TokioJoinHandle<T> {
158 self.inner
159 }
160
161 }
163
164#[cfg(not(tokio_unstable))]
169#[track_caller]
170pub fn spawn<Fut, Name, NameClosure>(_nc: NameClosure, future: Fut) -> JoinHandle<Fut::Output>
171where
172 Name: AsRef<str>,
173 NameClosure: FnOnce() -> Name,
174 Fut: Future + Send + 'static,
175 Fut::Output: Send + 'static,
176{
177 #[allow(clippy::disallowed_methods)]
178 JoinHandle::new(tokio::spawn(future))
179}
180
/// Spawns `future` as a named tokio task and returns a [`JoinHandle`] for it.
///
/// This is the variant compiled when `tokio_unstable` is set. The task name
/// is `"<runtime id>:<name>"`, where the runtime id comes from
/// `Handle::current()` and the name from calling `nc`.
///
/// # Panics
///
/// Panics with "task spawning cannot fail" if the task builder returns an
/// error.
// NOTE(review): the clippy allowance presumably exists because the project
// disallows spawning tokio tasks directly elsewhere — confirm.
#[cfg(tokio_unstable)]
#[track_caller]
#[allow(clippy::disallowed_methods)]
pub fn spawn<Fut, Name, NameClosure>(nc: NameClosure, future: Fut) -> JoinHandle<Fut::Output>
where
    Name: AsRef<str>,
    NameClosure: FnOnce() -> Name,
    Fut: Future + Send + 'static,
    Fut::Output: Send + 'static,
{
    let task_name = format!("{}:{}", Handle::current().id(), nc().as_ref());
    let tokio_handle = task::Builder::new()
        .name(&task_name)
        .spawn(future)
        .expect("task spawning cannot fail");
    JoinHandle::new(tokio_handle)
}
202
203#[cfg(not(tokio_unstable))]
209#[track_caller]
210#[allow(clippy::disallowed_methods)]
211pub fn spawn_blocking<Function, Output, Name, NameClosure>(
212 _nc: NameClosure,
213 function: Function,
214) -> JoinHandle<Output>
215where
216 Name: AsRef<str>,
217 NameClosure: FnOnce() -> Name,
218 Function: FnOnce() -> Output + Send + 'static,
219 Output: Send + 'static,
220{
221 JoinHandle::new(task::spawn_blocking(function))
222}
223
/// Runs the closure `function` as a named blocking tokio task and returns a
/// [`JoinHandle`] for its result.
///
/// This is the variant compiled when `tokio_unstable` is set. The task name
/// is `"<runtime id>:<name>"`, where the runtime id comes from
/// `Handle::current()` and the name from calling `nc`.
///
/// # Panics
///
/// Panics with "task spawning cannot fail" if the task builder returns an
/// error.
#[cfg(tokio_unstable)]
#[track_caller]
#[allow(clippy::disallowed_methods)]
pub fn spawn_blocking<Function, Output, Name, NameClosure>(
    nc: NameClosure,
    function: Function,
) -> JoinHandle<Output>
where
    Name: AsRef<str>,
    NameClosure: FnOnce() -> Name,
    Function: FnOnce() -> Output + Send + 'static,
    Output: Send + 'static,
{
    let task_name = format!("{}:{}", Handle::current().id(), nc().as_ref());
    let tokio_handle = task::Builder::new()
        .name(&task_name)
        .spawn_blocking(function)
        .expect("task spawning cannot fail");
    JoinHandle::new(tokio_handle)
}
249
250pub trait RuntimeExt {
254 #[track_caller]
260 fn spawn_blocking_named<Function, Output, Name, NameClosure>(
261 &self,
262 nc: NameClosure,
263 function: Function,
264 ) -> JoinHandle<Output>
265 where
266 Name: AsRef<str>,
267 NameClosure: FnOnce() -> Name,
268 Function: FnOnce() -> Output + Send + 'static,
269 Output: Send + 'static;
270
271 #[track_caller]
276 fn spawn_named<Fut, Name, NameClosure>(
277 &self,
278 _nc: NameClosure,
279 future: Fut,
280 ) -> JoinHandle<Fut::Output>
281 where
282 Name: AsRef<str>,
283 NameClosure: FnOnce() -> Name,
284 Fut: Future + Send + 'static,
285 Fut::Output: Send + 'static;
286}
287
288impl RuntimeExt for &Runtime {
289 fn spawn_blocking_named<Function, Output, Name, NameClosure>(
290 &self,
291 nc: NameClosure,
292 function: Function,
293 ) -> JoinHandle<Output>
294 where
295 Name: AsRef<str>,
296 NameClosure: FnOnce() -> Name,
297 Function: FnOnce() -> Output + Send + 'static,
298 Output: Send + 'static,
299 {
300 let _g = self.enter();
301 spawn_blocking(nc, function)
302 }
303
304 fn spawn_named<Fut, Name, NameClosure>(
305 &self,
306 nc: NameClosure,
307 future: Fut,
308 ) -> JoinHandle<Fut::Output>
309 where
310 Name: AsRef<str>,
311 NameClosure: FnOnce() -> Name,
312 Fut: Future + Send + 'static,
313 Fut::Output: Send + 'static,
314 {
315 let _g = self.enter();
316 spawn(nc, future)
317 }
318}
319
320impl RuntimeExt for Arc<Runtime> {
321 fn spawn_blocking_named<Function, Output, Name, NameClosure>(
322 &self,
323 nc: NameClosure,
324 function: Function,
325 ) -> JoinHandle<Output>
326 where
327 Name: AsRef<str>,
328 NameClosure: FnOnce() -> Name,
329 Function: FnOnce() -> Output + Send + 'static,
330 Output: Send + 'static,
331 {
332 (&**self).spawn_blocking_named(nc, function)
333 }
334
335 fn spawn_named<Fut, Name, NameClosure>(
336 &self,
337 nc: NameClosure,
338 future: Fut,
339 ) -> JoinHandle<Fut::Output>
340 where
341 Name: AsRef<str>,
342 NameClosure: FnOnce() -> Name,
343 Fut: Future + Send + 'static,
344 Fut::Output: Send + 'static,
345 {
346 (&**self).spawn_named(nc, future)
347 }
348}
349
350impl RuntimeExt for Handle {
351 fn spawn_blocking_named<Function, Output, Name, NameClosure>(
352 &self,
353 nc: NameClosure,
354 function: Function,
355 ) -> JoinHandle<Output>
356 where
357 Name: AsRef<str>,
358 NameClosure: FnOnce() -> Name,
359 Function: FnOnce() -> Output + Send + 'static,
360 Output: Send + 'static,
361 {
362 let _g = self.enter();
363 spawn_blocking(nc, function)
364 }
365
366 fn spawn_named<Fut, Name, NameClosure>(
367 &self,
368 nc: NameClosure,
369 future: Fut,
370 ) -> JoinHandle<Fut::Output>
371 where
372 Name: AsRef<str>,
373 NameClosure: FnOnce() -> Name,
374 Fut: Future + Send + 'static,
375 Fut::Output: Send + 'static,
376 {
377 let _g = self.enter();
378 spawn(nc, future)
379 }
380}
381
/// Extension trait adding a named-spawn method to a set of tasks whose
/// output type is `T`.
pub trait JoinSetExt<T> {
    /// Spawns `future` into the set and returns the tokio `AbortHandle` for
    /// the new task.
    ///
    /// `nc` is a closure producing the task's name.
    #[track_caller]
    fn spawn_named<Fut, Name, NameClosure>(
        &mut self,
        nc: NameClosure,
        future: Fut,
    ) -> tokio::task::AbortHandle
    where
        Name: AsRef<str>,
        NameClosure: FnOnce() -> Name,
        Fut: Future<Output = T> + Send + 'static,
        T: Send + 'static;
}
402
403impl<T> JoinSetExt<T> for tokio::task::JoinSet<T> {
404 #[allow(unused_variables)]
406 fn spawn_named<Fut, Name, NameClosure>(
407 &mut self,
408 nc: NameClosure,
409 future: Fut,
410 ) -> tokio::task::AbortHandle
411 where
412 Name: AsRef<str>,
413 NameClosure: FnOnce() -> Name,
414 Fut: Future<Output = T> + Send + 'static,
415 T: Send + 'static,
416 {
417 #[cfg(tokio_unstable)]
418 #[allow(clippy::disallowed_methods)]
419 {
420 self.build_task()
421 .name(&format!("{}:{}", Handle::current().id(), nc().as_ref()))
422 .spawn(future)
423 .expect("task spawning cannot fail")
424 }
425 #[cfg(not(tokio_unstable))]
426 #[allow(clippy::disallowed_methods)]
427 {
428 self.spawn(future)
429 }
430 }
431}