1use std::any::Any;
22use std::error::Error;
23use std::fmt::{self, Debug};
24use std::future::Future;
25use std::marker::PhantomData;
26use std::panic::UnwindSafe;
27use std::pin::Pin;
28use std::task::{Context, Poll};
29
30use async_trait::async_trait;
31use futures::Stream;
32use futures::future::{CatchUnwind, FutureExt};
33use futures::sink::Sink;
34use pin_project::pin_project;
35use tokio::task::futures::TaskLocalFuture;
36use tokio::time::{self, Duration, Instant};
37
38use crate::task::{self, JoinHandleExt};
39
/// Selects how [`OreFutureExt::run_in_task_if`] drives a future.
#[derive(Debug, Clone, Copy)]
pub enum InTask {
    /// Run the future via [`OreFutureExt::run_in_task`], i.e. on a task created by
    /// `task::spawn`, and await that task.
    Yes,
    /// Await the future directly on the current task.
    No,
}
48
/// Extension methods for [`Future`]s.
#[async_trait::async_trait]
pub trait OreFutureExt {
    /// Wraps the future so that, if the wrapper is dropped before the future completes,
    /// the still-pending future is handed to `task::spawn` rather than being dropped.
    ///
    /// `nc` lazily produces a name; the spawned task is named
    /// `spawn_if_canceled:{name}`. It is only consulted on that drop path.
    fn spawn_if_canceled<Name, NameClosure>(
        self,
        nc: NameClosure,
    ) -> SpawnIfCanceled<Self::Output, Name, NameClosure>
    where
        Name: AsRef<str>,
        NameClosure: FnOnce() -> Name + Unpin,
        Self: Future + Send + 'static,
        Self::Output: Send + 'static;

    /// Runs the future on a task created by `task::spawn` (named by `nc`) and awaits the
    /// result via `wait_and_assert_finished`.
    ///
    /// NOTE(review): the exact failure behavior (e.g. if the task is aborted) is defined by
    /// `JoinHandleExt::wait_and_assert_finished`, which is outside this file — confirm there.
    async fn run_in_task<Name, NameClosure>(self, nc: NameClosure) -> Self::Output
    where
        Name: AsRef<str>,
        NameClosure: FnOnce() -> Name + Unpin + Send,
        Self: Future + Send + 'static,
        Self::Output: Send + 'static;

    /// Runs the future via [`OreFutureExt::run_in_task`] when `in_task` is
    /// [`InTask::Yes`], otherwise awaits it directly on the current task.
    async fn run_in_task_if<Name, NameClosure>(
        self,
        in_task: InTask,
        nc: NameClosure,
    ) -> Self::Output
    where
        Name: AsRef<str>,
        NameClosure: FnOnce() -> Name + Unpin + Send,
        Self: Future + Send + 'static,
        Self::Output: Send + 'static;

    /// Like `FutureExt::catch_unwind`, but the wrapped future is polled with the
    /// `CATCHING_UNWIND_ASYNC` task-local set to `true`.
    ///
    /// Only available with the `panic` feature.
    #[cfg(feature = "panic")]
    fn ore_catch_unwind(self) -> OreCatchUnwind<Self>
    where
        Self: Sized + UnwindSafe;
}
100
101#[async_trait::async_trait]
102impl<T> OreFutureExt for T
103where
104 T: Future,
105{
106 fn spawn_if_canceled<Name, NameClosure>(
107 self,
108 nc: NameClosure,
109 ) -> SpawnIfCanceled<T::Output, Name, NameClosure>
110 where
111 Name: AsRef<str>,
112 NameClosure: FnOnce() -> Name + Unpin,
113 T: Send + 'static,
114 T::Output: Send + 'static,
115 {
116 SpawnIfCanceled {
117 inner: Some(Box::pin(self)),
118 nc: Some(nc),
119 }
120 }
121
122 async fn run_in_task<Name, NameClosure>(self, nc: NameClosure) -> T::Output
123 where
124 Name: AsRef<str>,
125 NameClosure: FnOnce() -> Name + Unpin + Send,
126 T: Send + 'static,
127 T::Output: Send + 'static,
128 {
129 task::spawn(nc, self).wait_and_assert_finished().await
130 }
131
132 async fn run_in_task_if<Name, NameClosure>(self, in_task: InTask, nc: NameClosure) -> T::Output
133 where
134 Name: AsRef<str>,
135 NameClosure: FnOnce() -> Name + Unpin + Send,
136 Self: Future + Send + 'static,
137 T::Output: Send + 'static,
138 {
139 if let InTask::Yes = in_task {
140 self.run_in_task(nc).await
141 } else {
142 self.await
143 }
144 }
145
146 #[cfg(feature = "panic")]
147 fn ore_catch_unwind(self) -> OreCatchUnwind<Self>
148 where
149 Self: UnwindSafe,
150 {
151 use crate::panic::CATCHING_UNWIND_ASYNC;
152
153 OreCatchUnwind {
154 #[allow(clippy::disallowed_methods)]
155 inner: CATCHING_UNWIND_ASYNC.scope(true, FutureExt::catch_unwind(self)),
156 }
157 }
158}
159
/// Future returned by [`OreFutureExt::spawn_if_canceled`].
///
/// Polls the boxed inner future in place. If this wrapper is dropped while the inner
/// future is still pending, its `Drop` impl hands the future to `task::spawn` instead of
/// dropping it.
// The bounds live on the struct because the `Drop` impl must repeat exactly the struct's
// bounds (Rust error E0367); keep the two in sync.
pub struct SpawnIfCanceled<T, Name, NameClosure>
where
    Name: AsRef<str>,
    NameClosure: FnOnce() -> Name + Unpin,
    T: Send + 'static,
{
    /// The wrapped future; `None` once it has completed or has been handed off in `Drop`.
    inner: Option<Pin<Box<dyn Future<Output = T> + Send>>>,
    /// Produces the name for the task spawned in `Drop`; only taken on that path.
    nc: Option<NameClosure>,
}
170
171impl<T, Name, NameClosure> Future for SpawnIfCanceled<T, Name, NameClosure>
172where
173 Name: AsRef<str>,
174 NameClosure: FnOnce() -> Name + Unpin,
175 T: Send + 'static,
176{
177 type Output = T;
178
179 fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<T> {
180 match &mut self.inner {
181 None => panic!("SpawnIfCanceled polled after completion"),
182 Some(f) => match f.as_mut().poll(cx) {
183 Poll::Pending => Poll::Pending,
184 Poll::Ready(res) => {
185 self.inner = None;
186 Poll::Ready(res)
187 }
188 },
189 }
190 }
191}
192
impl<T, Name, NameClosure> Drop for SpawnIfCanceled<T, Name, NameClosure>
where
    Name: AsRef<str>,
    NameClosure: FnOnce() -> Name + Unpin,
    T: Send + 'static,
{
    /// Hands a still-pending inner future to `task::spawn` instead of dropping it.
    fn drop(&mut self) {
        // `poll` clears `inner` on completion, so `Some` here means the wrapper was dropped
        // before the future finished.
        if let Some(f) = self.inner.take() {
            // `nc` is set together with `inner` in `spawn_if_canceled` and is only taken
            // here, so the `unwrap` is not expected to fail. The name closure is invoked
            // lazily, from within the closure passed to `task::spawn`.
            // NOTE(review): the returned handle is discarded — presumably this detaches the
            // task; confirm against `task::spawn`'s contract.
            task::spawn(
                || format!("spawn_if_canceled:{}", (self.nc).take().unwrap()().as_ref()),
                f,
            );
        }
    }
}
208
209impl<T, Name, NameClosure> fmt::Debug for SpawnIfCanceled<T, Name, NameClosure>
210where
211 Name: AsRef<str>,
212 NameClosure: FnOnce() -> Name + Unpin,
213 T: Send + 'static,
214{
215 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
216 f.debug_struct("SpawnIfCanceled")
217 .field(
218 "inner",
219 match &self.inner {
220 None => &"None",
221 Some(_) => &"Some(<future>)",
222 },
223 )
224 .finish()
225 }
226}
227
/// Future returned by [`OreFutureExt::ore_catch_unwind`].
///
/// Resolves to `Ok(output)` if the wrapped future completes normally, or `Err(payload)`
/// with the panic payload if polling it panics (via `CatchUnwind`). The wrapped future is
/// polled inside a `TaskLocalFuture`, which `ore_catch_unwind` builds as a scope of
/// `CATCHING_UNWIND_ASYNC` set to `true`.
#[derive(Debug)]
#[pin_project]
pub struct OreCatchUnwind<Fut> {
    /// Structurally pinned; polled through `project()` in the `Future` impl.
    #[pin]
    inner: TaskLocalFuture<bool, CatchUnwind<Fut>>,
}
235
236impl<Fut> Future for OreCatchUnwind<Fut>
237where
238 Fut: Future + UnwindSafe,
239{
240 type Output = Result<Fut::Output, Box<dyn Any + Send>>;
241
242 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
243 self.project().inner.poll(cx)
244 }
245}
246
/// The error produced by [`timeout`] and [`timeout_at`].
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub enum TimeoutError<E> {
    /// The deadline passed before the wrapped future completed.
    DeadlineElapsed,
    /// The wrapped future completed before the deadline, but with this error.
    Inner(E),
}

impl<E> fmt::Display for TimeoutError<E>
where
    E: fmt::Display,
{
    /// Prints `"deadline has elapsed"` for [`TimeoutError::DeadlineElapsed`] and the inner
    /// error's own `Display` output for [`TimeoutError::Inner`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            TimeoutError::DeadlineElapsed => f.write_str("deadline has elapsed"),
            // Format the payload, not `self`. Formatting `self` here would either recurse
            // forever or pick `Debug::fmt`, which needs an `E: Debug` bound this impl does
            // not have. The fully qualified call avoids ambiguity with `Debug::fmt`.
            TimeoutError::Inner(e) => fmt::Display::fmt(e, f),
        }
    }
}

impl<E> Error for TimeoutError<E>
where
    E: Error + 'static,
{
    /// Returns the inner error for [`TimeoutError::Inner`] and `None` for a timeout.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            TimeoutError::DeadlineElapsed => None,
            TimeoutError::Inner(e) => Some(e),
        }
    }
}
279
280pub async fn timeout<F, T, E>(duration: Duration, future: F) -> Result<T, TimeoutError<E>>
309where
310 F: Future<Output = Result<T, E>>,
311{
312 match time::timeout(duration, future).await {
313 Ok(Ok(t)) => Ok(t),
314 Ok(Err(e)) => Err(TimeoutError::Inner(e)),
315 Err(_) => Err(TimeoutError::DeadlineElapsed),
316 }
317}
318
319pub async fn timeout_at<F, T, E>(deadline: Instant, future: F) -> Result<T, TimeoutError<E>>
348where
349 F: Future<Output = Result<T, E>>,
350{
351 match time::timeout_at(deadline, future).await {
352 Ok(Ok(t)) => Ok(t),
353 Ok(Err(e)) => Err(TimeoutError::Inner(e)),
354 Err(_) => Err(TimeoutError::DeadlineElapsed),
355 }
356}
357
358pub trait OreSinkExt<T>: Sink<T> {
360 fn boxed(self) -> Box<dyn Sink<T, Error = Self::Error> + Send>
362 where
363 Self: Sized + Send + 'static,
364 {
365 Box::new(self)
366 }
367
368 fn enqueue(&mut self, item: T) -> Enqueue<Self, T> {
371 Enqueue {
372 sink: self,
373 item: Some(item),
374 }
375 }
376}
377
378impl<S, T> OreSinkExt<T> for S where S: Sink<T> {}
379
/// Future returned by [`OreSinkExt::enqueue`].
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Enqueue<'a, Si: ?Sized, Item> {
    /// The sink the item is sent into.
    sink: &'a mut Si,
    /// The item still to be sent; `None` once it has been passed to `start_send`.
    item: Option<Item>,
}
390
391impl<Si, Item> Future for Enqueue<'_, Si, Item>
392where
393 Si: Sink<Item> + Unpin + ?Sized,
394 Item: Unpin,
395{
396 type Output = Result<(), Si::Error>;
397
398 fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
399 let this = &mut *self;
400 if let Some(item) = this.item.take() {
401 let mut sink = Pin::new(&mut this.sink);
402 match sink.as_mut().poll_ready(cx)? {
403 Poll::Ready(()) => sink.as_mut().start_send(item)?,
404 Poll::Pending => {
405 this.item = Some(item);
406 return Poll::Pending;
407 }
408 }
409 }
410 Poll::Ready(Ok(()))
411 }
412}
413
414pub fn dev_null<T, E>() -> DevNull<T, E> {
416 DevNull(PhantomData, PhantomData)
417}
418
/// A [`Sink`] that accepts and discards every item; construct it with [`dev_null`].
///
/// `T` is the item type and `E` the sink's error type. Every sink operation succeeds, so no
/// `E` is ever produced. Both types are carried only as `PhantomData`.
// NOTE: `derive(Debug)` adds `T: Debug` and `E: Debug` bounds to the generated impl, even
// though `PhantomData` itself is always `Debug`.
#[derive(Debug)]
pub struct DevNull<T, E>(PhantomData<T>, PhantomData<E>);
425
426impl<T, E> Sink<T> for DevNull<T, E> {
427 type Error = E;
428
429 fn poll_ready(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
430 Poll::Ready(Ok(()))
431 }
432
433 fn start_send(self: Pin<&mut Self>, _: T) -> Result<(), Self::Error> {
434 Ok(())
435 }
436
437 fn poll_flush(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
438 Poll::Ready(Ok(()))
439 }
440
441 fn poll_close(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
442 Poll::Ready(Ok(()))
443 }
444}
445
446#[async_trait]
448pub trait OreStreamExt: Stream {
449 async fn recv_many(&mut self, max: usize) -> Option<Vec<Self::Item>>;
464}
465
466#[async_trait]
467impl<T> OreStreamExt for T
468where
469 T: futures::stream::Stream + futures::StreamExt + Send + Unpin,
470{
471 async fn recv_many(&mut self, max: usize) -> Option<Vec<Self::Item>> {
472 let first = self.next().await?;
474 let mut buffer = Vec::from([first]);
475
476 while let Some(v) = self.next().now_or_never().and_then(|e| e) {
481 buffer.push(v);
482
483 if buffer.len() >= max {
485 break;
486 }
487 }
488
489 Some(buffer)
490 }
491}