mz_ore/
future.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License in the LICENSE file at the
6// root of this repository, or online at
7//
8//     http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16//! Future and stream utilities.
17//!
18//! This module provides future and stream combinators that are missing from
19//! the [`futures`] crate.
20
21use std::any::Any;
22use std::error::Error;
23use std::fmt::{self, Debug};
24use std::future::Future;
25use std::marker::PhantomData;
26use std::panic::UnwindSafe;
27use std::pin::Pin;
28use std::task::{Context, Poll};
29
30use async_trait::async_trait;
31use futures::Stream;
32use futures::future::{CatchUnwind, FutureExt};
33use futures::sink::Sink;
34use pin_project::pin_project;
35use tokio::task::futures::TaskLocalFuture;
36use tokio::time::{self, Duration, Instant};
37
38use crate::task::{self, JoinHandleExt};
39
/// Selects how `run_in_task_if` drives its future.
///
/// [`InTask::Yes`] hands the future to a freshly spawned Tokio task, while [`InTask::No`]
/// polls it in place as part of the caller's own future.
#[derive(Debug, Clone, Copy)]
pub enum InTask {
    /// Spawn a Tokio task and poll the future there.
    Yes,
    /// Poll the future directly, without spawning anything.
    No,
}
48
49/// Extension methods for futures.
50#[async_trait::async_trait]
51pub trait OreFutureExt {
52    /// Wraps a future in a [`SpawnIfCanceled`] future, which will spawn a
53    /// task to poll the inner future to completion if it is dropped.
54    fn spawn_if_canceled<Name, NameClosure>(
55        self,
56        nc: NameClosure,
57    ) -> SpawnIfCanceled<Self::Output, Name, NameClosure>
58    where
59        Name: AsRef<str>,
60        NameClosure: FnOnce() -> Name + Unpin,
61        Self: Future + Send + 'static,
62        Self::Output: Send + 'static;
63
64    /// Run a `'static` future in a Tokio task, naming that task, using a convenient
65    /// postfix call notation.
66    ///
67    /// Useful in contexts where futures may be starved and cause inadvertent
68    /// failures in I/O-sensitive operations, such as when called within timely
69    /// operators.
70    async fn run_in_task<Name, NameClosure>(self, nc: NameClosure) -> Self::Output
71    where
72        Name: AsRef<str>,
73        NameClosure: FnOnce() -> Name + Unpin + Send,
74        Self: Future + Send + 'static,
75        Self::Output: Send + 'static;
76
77    /// The same as `run_in_task`, but allows the callee to dynamically choose whether or
78    /// not the future is polled into a Tokio task.
79    // This is not currently a provided method because rust-analyzer fails inference if it is :(.
80    async fn run_in_task_if<Name, NameClosure>(
81        self,
82        in_task: InTask,
83        nc: NameClosure,
84    ) -> Self::Output
85    where
86        Name: AsRef<str>,
87        NameClosure: FnOnce() -> Name + Unpin + Send,
88        Self: Future + Send + 'static,
89        Self::Output: Send + 'static;
90
91    /// Like [`FutureExt::catch_unwind`], but can unwind panics even if
92    /// [`panic::install_enhanced_handler`] has been called.
93    ///
94    /// [`panic::install_enhanced_handler`]: crate::panic::install_enhanced_handler
95    #[cfg(feature = "panic")]
96    fn ore_catch_unwind(self) -> OreCatchUnwind<Self>
97    where
98        Self: Sized + UnwindSafe;
99}
100
101#[async_trait::async_trait]
102impl<T> OreFutureExt for T
103where
104    T: Future,
105{
106    fn spawn_if_canceled<Name, NameClosure>(
107        self,
108        nc: NameClosure,
109    ) -> SpawnIfCanceled<T::Output, Name, NameClosure>
110    where
111        Name: AsRef<str>,
112        NameClosure: FnOnce() -> Name + Unpin,
113        T: Send + 'static,
114        T::Output: Send + 'static,
115    {
116        SpawnIfCanceled {
117            inner: Some(Box::pin(self)),
118            nc: Some(nc),
119        }
120    }
121
122    async fn run_in_task<Name, NameClosure>(self, nc: NameClosure) -> T::Output
123    where
124        Name: AsRef<str>,
125        NameClosure: FnOnce() -> Name + Unpin + Send,
126        T: Send + 'static,
127        T::Output: Send + 'static,
128    {
129        task::spawn(nc, self).wait_and_assert_finished().await
130    }
131
132    async fn run_in_task_if<Name, NameClosure>(self, in_task: InTask, nc: NameClosure) -> T::Output
133    where
134        Name: AsRef<str>,
135        NameClosure: FnOnce() -> Name + Unpin + Send,
136        Self: Future + Send + 'static,
137        T::Output: Send + 'static,
138    {
139        if let InTask::Yes = in_task {
140            self.run_in_task(nc).await
141        } else {
142            self.await
143        }
144    }
145
146    #[cfg(feature = "panic")]
147    fn ore_catch_unwind(self) -> OreCatchUnwind<Self>
148    where
149        Self: UnwindSafe,
150    {
151        use crate::panic::CATCHING_UNWIND_ASYNC;
152
153        OreCatchUnwind {
154            #[allow(clippy::disallowed_methods)]
155            inner: CATCHING_UNWIND_ASYNC.scope(true, FutureExt::catch_unwind(self)),
156        }
157    }
158}
159
/// The future returned by [`OreFutureExt::spawn_if_canceled`].
///
/// While it is polled, it simply forwards to the wrapped future. If it is dropped while the
/// wrapped future is still held, its `Drop` implementation spawns that future onto a new task.
pub struct SpawnIfCanceled<T, Name, NameClosure>
where
    T: Send + 'static,
    Name: AsRef<str>,
    NameClosure: FnOnce() -> Name + Unpin,
{
    /// The wrapped future, or `None` once it has produced its output.
    inner: Option<Pin<Box<dyn Future<Output = T> + Send>>>,
    /// Produces the name of the rescue task; only consumed if that task is spawned.
    nc: Option<NameClosure>,
}

impl<T, Name, NameClosure> Future for SpawnIfCanceled<T, Name, NameClosure>
where
    T: Send + 'static,
    Name: AsRef<str>,
    NameClosure: FnOnce() -> Name + Unpin,
{
    type Output = T;

    /// Polls the wrapped future.
    ///
    /// # Panics
    ///
    /// Panics if called again after the wrapped future has already completed.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<T> {
        let output = match self.inner.as_mut() {
            None => panic!("SpawnIfCanceled polled after completion"),
            Some(fut) => match fut.as_mut().poll(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(output) => output,
            },
        };
        // Forget the finished future so that dropping `self` later has nothing to spawn.
        self.inner = None;
        Poll::Ready(output)
    }
}
192
193impl<T, Name, NameClosure> Drop for SpawnIfCanceled<T, Name, NameClosure>
194where
195    Name: AsRef<str>,
196    NameClosure: FnOnce() -> Name + Unpin,
197    T: Send + 'static,
198{
199    fn drop(&mut self) {
200        if let Some(f) = self.inner.take() {
201            task::spawn(
202                || format!("spawn_if_canceled:{}", (self.nc).take().unwrap()().as_ref()),
203                f,
204            );
205        }
206    }
207}
208
209impl<T, Name, NameClosure> fmt::Debug for SpawnIfCanceled<T, Name, NameClosure>
210where
211    Name: AsRef<str>,
212    NameClosure: FnOnce() -> Name + Unpin,
213    T: Send + 'static,
214{
215    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
216        f.debug_struct("SpawnIfCanceled")
217            .field(
218                "inner",
219                match &self.inner {
220                    None => &"None",
221                    Some(_) => &"Some(<future>)",
222                },
223            )
224            .finish()
225    }
226}
227
228/// The future returned by [`OreFutureExt::ore_catch_unwind`].
229#[derive(Debug)]
230#[pin_project]
231pub struct OreCatchUnwind<Fut> {
232    #[pin]
233    inner: TaskLocalFuture<bool, CatchUnwind<Fut>>,
234}
235
236impl<Fut> Future for OreCatchUnwind<Fut>
237where
238    Fut: Future + UnwindSafe,
239{
240    type Output = Result<Fut::Output, Box<dyn Any + Send>>;
241
242    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
243        self.project().inner.poll(cx)
244    }
245}
246
/// The error returned by [`timeout`] and [`timeout_at`].
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub enum TimeoutError<E> {
    /// The timeout deadline has elapsed.
    DeadlineElapsed,
    /// The underlying operation failed.
    Inner(E),
}

/// Displays `deadline has elapsed` for [`TimeoutError::DeadlineElapsed`], and the inner
/// error's own `Display` output for [`TimeoutError::Inner`].
impl<E> fmt::Display for TimeoutError<E>
where
    E: fmt::Display,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            TimeoutError::DeadlineElapsed => f.write_str("deadline has elapsed"),
            // Format the *inner* error. Naming the trait explicitly keeps the call from
            // resolving to `Debug::fmt` (imported in this module) or recursing into this impl.
            TimeoutError::Inner(e) => fmt::Display::fmt(e, f),
        }
    }
}

impl<E> Error for TimeoutError<E>
where
    E: Error + 'static,
{
    /// Returns the inner error for [`TimeoutError::Inner`], and `None` when the deadline
    /// elapsed.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            TimeoutError::DeadlineElapsed => None,
            TimeoutError::Inner(e) => Some(e),
        }
    }
}
279
280/// Applies a maximum duration to a [`Result`]-returning future.
281///
282/// Whether the maximum duration was reached is indicated via the error type.
283/// Specifically:
284///
285///   * If `future` does not complete within `duration`, returns
286///     [`TimeoutError::DeadlineElapsed`].
287///   * If `future` completes with `Ok(t)`, returns `Ok(t)`.
288///   * If `future` completes with `Err(e)`, returns
289///     [`TimeoutError::Inner(e)`](TimeoutError::Inner).
290///
291/// Using this function can be considerably more readable than
292/// [`tokio::time::timeout`] when the inner future returns a `Result`.
293///
294/// # Examples
295///
296/// ```
297/// # use tokio::time::Duration;
298/// use mz_ore::future::TimeoutError;
299/// # tokio_test::block_on(async {
300/// let slow_op = async {
301///     tokio::time::sleep(Duration::from_secs(1)).await;
302///     Ok::<_, String>(())
303/// };
304/// let res = mz_ore::future::timeout(Duration::from_millis(1), slow_op).await;
305/// assert_eq!(res, Err(TimeoutError::DeadlineElapsed));
306/// # });
307/// ```
308pub async fn timeout<F, T, E>(duration: Duration, future: F) -> Result<T, TimeoutError<E>>
309where
310    F: Future<Output = Result<T, E>>,
311{
312    match time::timeout(duration, future).await {
313        Ok(Ok(t)) => Ok(t),
314        Ok(Err(e)) => Err(TimeoutError::Inner(e)),
315        Err(_) => Err(TimeoutError::DeadlineElapsed),
316    }
317}
318
319/// Applies a deadline to a [`Result`]-returning future.
320///
321/// Whether the deadline elapsed is indicated via the error type. Specifically:
322///
323///   * If `future` does not complete by `deadline`, returns
324///     [`TimeoutError::DeadlineElapsed`].
325///   * If `future` completes with `Ok(t)`, returns `Ok(t)`.
326///   * If `future` completes with `Err(e)`, returns
327///     [`TimeoutError::Inner(e)`](TimeoutError::Inner).
328///
329/// Using this function can be considerably more readable than
330/// [`tokio::time::timeout_at`] when the inner future returns a `Result`.
331///
332/// # Examples
333///
334/// ```
335/// # use tokio::time::{Duration, Instant};
336/// use mz_ore::future::TimeoutError;
337/// # tokio_test::block_on(async {
338/// let slow_op = async {
339///     tokio::time::sleep(Duration::from_secs(1)).await;
340///     Ok::<_, String>(())
341/// };
342/// let deadline = Instant::now() + Duration::from_millis(1);
343/// let res = mz_ore::future::timeout_at(deadline, slow_op).await;
344/// assert_eq!(res, Err(TimeoutError::DeadlineElapsed));
345/// # });
346/// ```
347pub async fn timeout_at<F, T, E>(deadline: Instant, future: F) -> Result<T, TimeoutError<E>>
348where
349    F: Future<Output = Result<T, E>>,
350{
351    match time::timeout_at(deadline, future).await {
352        Ok(Ok(t)) => Ok(t),
353        Ok(Err(e)) => Err(TimeoutError::Inner(e)),
354        Err(_) => Err(TimeoutError::DeadlineElapsed),
355    }
356}
357
358/// Extension methods for sinks.
359pub trait OreSinkExt<T>: Sink<T> {
360    /// Boxes this sink.
361    fn boxed(self) -> Box<dyn Sink<T, Error = Self::Error> + Send>
362    where
363        Self: Sized + Send + 'static,
364    {
365        Box::new(self)
366    }
367
368    /// Like [`futures::sink::SinkExt::send`], but does not flush the sink after enqueuing
369    /// `item`.
370    fn enqueue(&mut self, item: T) -> Enqueue<Self, T> {
371        Enqueue {
372            sink: self,
373            item: Some(item),
374        }
375    }
376}
377
378impl<S, T> OreSinkExt<T> for S where S: Sink<T> {}
379
380/// Future for the [`enqueue`](OreSinkExt::enqueue) method.
381#[derive(Debug)]
382#[must_use = "futures do nothing unless you `.await` or poll them"]
383pub struct Enqueue<'a, Si, Item>
384where
385    Si: ?Sized,
386{
387    sink: &'a mut Si,
388    item: Option<Item>,
389}
390
391impl<Si, Item> Future for Enqueue<'_, Si, Item>
392where
393    Si: Sink<Item> + Unpin + ?Sized,
394    Item: Unpin,
395{
396    type Output = Result<(), Si::Error>;
397
398    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
399        let this = &mut *self;
400        if let Some(item) = this.item.take() {
401            let mut sink = Pin::new(&mut this.sink);
402            match sink.as_mut().poll_ready(cx)? {
403                Poll::Ready(()) => sink.as_mut().start_send(item)?,
404                Poll::Pending => {
405                    this.item = Some(item);
406                    return Poll::Pending;
407                }
408            }
409        }
410        Poll::Ready(Ok(()))
411    }
412}
413
414/// Constructs a sink that consumes its input and sends it nowhere.
415pub fn dev_null<T, E>() -> DevNull<T, E> {
416    DevNull(PhantomData, PhantomData)
417}
418
419/// A sink that consumes its input and sends it nowhere.
420///
421/// Primarily useful as a base sink when folding multiple sinks into one using
422/// [`futures::sink::SinkExt::fanout`].
423#[derive(Debug)]
424pub struct DevNull<T, E>(PhantomData<T>, PhantomData<E>);
425
426impl<T, E> Sink<T> for DevNull<T, E> {
427    type Error = E;
428
429    fn poll_ready(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
430        Poll::Ready(Ok(()))
431    }
432
433    fn start_send(self: Pin<&mut Self>, _: T) -> Result<(), Self::Error> {
434        Ok(())
435    }
436
437    fn poll_flush(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
438        Poll::Ready(Ok(()))
439    }
440
441    fn poll_close(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
442        Poll::Ready(Ok(()))
443    }
444}
445
446/// Extension methods for streams.
447#[async_trait]
448pub trait OreStreamExt: Stream {
449    /// Awaits the stream for an event to be available and returns all currently buffered
450    /// events on the stream up to some `max`.
451    ///
452    /// This method returns `None` if the stream has ended.
453    ///
454    /// If there are no events ready on the stream this method will sleep until an event is
455    /// sent or the stream is closed. When woken it will return up to `max` currently buffered
456    /// events.
457    ///
458    /// # Cancel safety
459    ///
460    /// This method is cancel safe. If `recv_many` is used as the event in a `select!` statement
461    /// and some other branch completes first, it is guaranteed that no messages were received on
462    /// this channel.
463    async fn recv_many(&mut self, max: usize) -> Option<Vec<Self::Item>>;
464}
465
466#[async_trait]
467impl<T> OreStreamExt for T
468where
469    T: futures::stream::Stream + futures::StreamExt + Send + Unpin,
470{
471    async fn recv_many(&mut self, max: usize) -> Option<Vec<Self::Item>> {
472        // Wait for an event to be ready on the stream
473        let first = self.next().await?;
474        let mut buffer = Vec::from([first]);
475
476        // Note(parkmycar): It's very important for cancelation safety that we don't add any more
477        // .await points other than the initial one.
478
479        // Pull all other ready events off the stream, up to the max
480        while let Some(v) = self.next().now_or_never().and_then(|e| e) {
481            buffer.push(v);
482
483            // Break so we don't loop here continuously.
484            if buffer.len() >= max {
485                break;
486            }
487        }
488
489        Some(buffer)
490    }
491}