indicatif/
iter.rs

1use std::borrow::Cow;
2use std::io::{self, IoSliceMut};
3use std::iter::FusedIterator;
4#[cfg(feature = "tokio")]
5use std::pin::Pin;
6#[cfg(feature = "tokio")]
7use std::task::{Context, Poll};
8use std::time::Duration;
9
10#[cfg(feature = "tokio")]
11use tokio::io::{ReadBuf, SeekFrom};
12
13use crate::progress_bar::ProgressBar;
14use crate::state::ProgressFinish;
15use crate::style::ProgressStyle;
16
17/// Wraps an iterator to display its progress.
18pub trait ProgressIterator
19where
20    Self: Sized + Iterator,
21{
22    /// Wrap an iterator with default styling. Uses [`Iterator::size_hint()`] to get length.
23    /// Returns `Some(..)` only if `size_hint.1` is [`Some`]. If you want to create a progress bar
24    /// even if `size_hint.1` returns [`None`] use [`progress_count()`](ProgressIterator::progress_count)
25    /// or [`progress_with()`](ProgressIterator::progress_with) instead.
26    fn try_progress(self) -> Option<ProgressBarIter<Self>> {
27        self.size_hint()
28            .1
29            .map(|len| self.progress_count(u64::try_from(len).unwrap()))
30    }
31
32    /// Wrap an iterator with default styling.
33    fn progress(self) -> ProgressBarIter<Self>
34    where
35        Self: ExactSizeIterator,
36    {
37        let len = u64::try_from(self.len()).unwrap();
38        self.progress_count(len)
39    }
40
41    /// Wrap an iterator with an explicit element count.
42    fn progress_count(self, len: u64) -> ProgressBarIter<Self> {
43        self.progress_with(ProgressBar::new(len))
44    }
45
46    /// Wrap an iterator with a custom progress bar.
47    fn progress_with(self, progress: ProgressBar) -> ProgressBarIter<Self>;
48
49    /// Wrap an iterator with a progress bar and style it.
50    fn progress_with_style(self, style: crate::ProgressStyle) -> ProgressBarIter<Self>
51    where
52        Self: ExactSizeIterator,
53    {
54        let len = u64::try_from(self.len()).unwrap();
55        let bar = ProgressBar::new(len).with_style(style);
56        self.progress_with(bar)
57    }
58}
59
60/// Wraps an iterator to display its progress.
61#[derive(Debug)]
62pub struct ProgressBarIter<T> {
63    pub(crate) it: T,
64    pub progress: ProgressBar,
65}
66
67impl<T> ProgressBarIter<T> {
68    /// Builder-like function for setting underlying progress bar's style.
69    ///
70    /// See [`ProgressBar::with_style()`].
71    pub fn with_style(mut self, style: ProgressStyle) -> Self {
72        self.progress = self.progress.with_style(style);
73        self
74    }
75
76    /// Builder-like function for setting underlying progress bar's prefix.
77    ///
78    /// See [`ProgressBar::with_prefix()`].
79    pub fn with_prefix(mut self, prefix: impl Into<Cow<'static, str>>) -> Self {
80        self.progress = self.progress.with_prefix(prefix);
81        self
82    }
83
84    /// Builder-like function for setting underlying progress bar's message.
85    ///
86    /// See [`ProgressBar::with_message()`].
87    pub fn with_message(mut self, message: impl Into<Cow<'static, str>>) -> Self {
88        self.progress = self.progress.with_message(message);
89        self
90    }
91
92    /// Builder-like function for setting underlying progress bar's position.
93    ///
94    /// See [`ProgressBar::with_position()`].
95    pub fn with_position(mut self, position: u64) -> Self {
96        self.progress = self.progress.with_position(position);
97        self
98    }
99
100    /// Builder-like function for setting underlying progress bar's elapsed time.
101    ///
102    /// See [`ProgressBar::with_elapsed()`].
103    pub fn with_elapsed(mut self, elapsed: Duration) -> Self {
104        self.progress = self.progress.with_elapsed(elapsed);
105        self
106    }
107
108    /// Builder-like function for setting underlying progress bar's finish behavior.
109    ///
110    /// See [`ProgressBar::with_finish()`].
111    pub fn with_finish(mut self, finish: ProgressFinish) -> Self {
112        self.progress = self.progress.with_finish(finish);
113        self
114    }
115}
116
117impl<S, T: Iterator<Item = S>> Iterator for ProgressBarIter<T> {
118    type Item = S;
119
120    fn next(&mut self) -> Option<Self::Item> {
121        let item = self.it.next();
122
123        if item.is_some() {
124            self.progress.inc(1);
125        } else if !self.progress.is_finished() {
126            self.progress.finish_using_style();
127        }
128
129        item
130    }
131}
132
133impl<T: ExactSizeIterator> ExactSizeIterator for ProgressBarIter<T> {
134    fn len(&self) -> usize {
135        self.it.len()
136    }
137}
138
139impl<T: DoubleEndedIterator> DoubleEndedIterator for ProgressBarIter<T> {
140    fn next_back(&mut self) -> Option<Self::Item> {
141        let item = self.it.next_back();
142
143        if item.is_some() {
144            self.progress.inc(1);
145        } else if !self.progress.is_finished() {
146            self.progress.finish_using_style();
147        }
148
149        item
150    }
151}
152
153impl<T: FusedIterator> FusedIterator for ProgressBarIter<T> {}
154
155impl<R: io::Read> io::Read for ProgressBarIter<R> {
156    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
157        let inc = self.it.read(buf)?;
158        self.progress.inc(inc as u64);
159        Ok(inc)
160    }
161
162    fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
163        let inc = self.it.read_vectored(bufs)?;
164        self.progress.inc(inc as u64);
165        Ok(inc)
166    }
167
168    fn read_to_string(&mut self, buf: &mut String) -> io::Result<usize> {
169        let inc = self.it.read_to_string(buf)?;
170        self.progress.inc(inc as u64);
171        Ok(inc)
172    }
173
174    fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
175        self.it.read_exact(buf)?;
176        self.progress.inc(buf.len() as u64);
177        Ok(())
178    }
179}
180
181impl<R: io::BufRead> io::BufRead for ProgressBarIter<R> {
182    fn fill_buf(&mut self) -> io::Result<&[u8]> {
183        self.it.fill_buf()
184    }
185
186    fn consume(&mut self, amt: usize) {
187        self.it.consume(amt);
188        self.progress.inc(amt as u64);
189    }
190}
191
192impl<S: io::Seek> io::Seek for ProgressBarIter<S> {
193    fn seek(&mut self, f: io::SeekFrom) -> io::Result<u64> {
194        self.it.seek(f).map(|pos| {
195            self.progress.set_position(pos);
196            pos
197        })
198    }
199    // Pass this through to preserve optimizations that the inner I/O object may use here
200    // Also avoid sending a set_position update when the position hasn't changed
201    fn stream_position(&mut self) -> io::Result<u64> {
202        self.it.stream_position()
203    }
204}
205
/// Async writes advance the bar by the number of bytes the inner writer accepted.
#[cfg(feature = "tokio")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
impl<W: tokio::io::AsyncWrite + Unpin> tokio::io::AsyncWrite for ProgressBarIter<W> {
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        Pin::new(&mut self.it).poll_write(cx, buf).map(|poll| {
            poll.map(|inc| {
                self.progress.inc(inc as u64);
                inc
            })
        })
    }

    /// Forwarded (like the sync `write_vectored`) so the inner writer's vectored path is used
    /// instead of the default, which only writes the first non-empty buffer.
    fn poll_write_vectored(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[io::IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        Pin::new(&mut self.it)
            .poll_write_vectored(cx, bufs)
            .map(|poll| {
                poll.map(|inc| {
                    self.progress.inc(inc as u64);
                    inc
                })
            })
    }

    /// Reports the inner writer's vectored-write support so callers can batch buffers.
    fn is_write_vectored(&self) -> bool {
        self.it.is_write_vectored()
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.it).poll_flush(cx)
    }

    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.it).poll_shutdown(cx)
    }
}
230
/// Async reads advance the bar by however much the inner reader grew the filled region of `buf`.
#[cfg(feature = "tokio")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
impl<W: tokio::io::AsyncRead + Unpin> tokio::io::AsyncRead for ProgressBarIter<W> {
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let filled_before = buf.filled().len() as u64;
        let outcome = Pin::new(&mut self.it).poll_read(cx, buf);
        if outcome.is_ready() {
            let filled_after = buf.filled().len() as u64;
            self.progress.inc(filled_after - filled_before);
        }
        outcome
    }
}
248
/// Async seeking; mirrors the sync `io::Seek` impl by moving the bar to the position reported
/// when a seek completes.
#[cfg(feature = "tokio")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
impl<W: tokio::io::AsyncSeek + Unpin> tokio::io::AsyncSeek for ProgressBarIter<W> {
    fn start_seek(mut self: Pin<&mut Self>, position: SeekFrom) -> io::Result<()> {
        Pin::new(&mut self.it).start_seek(position)
    }

    fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
        let poll = Pin::new(&mut self.it).poll_complete(cx);
        // A successful completion yields the new absolute stream position; keep the bar in
        // sync with it, as the blocking `seek` does.
        if let Poll::Ready(Ok(pos)) = &poll {
            self.progress.set_position(*pos);
        }
        poll
    }
}
260
/// Async buffered reading. As with the sync `io::BufRead` impl, the bar advances in `consume`
/// rather than `poll_fill_buf`: `poll_fill_buf` may hand back the same unconsumed bytes many
/// times (e.g. while scanning for a delimiter), and counting there double-counts them.
#[cfg(feature = "tokio")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
impl<W: tokio::io::AsyncBufRead + Unpin + tokio::io::AsyncRead> tokio::io::AsyncBufRead
    for ProgressBarIter<W>
{
    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
        Pin::new(&mut self.get_mut().it).poll_fill_buf(cx)
    }

    fn consume(mut self: Pin<&mut Self>, amt: usize) {
        Pin::new(&mut self.it).consume(amt);
        self.progress.inc(amt as u64);
    }
}
279
/// Streams advance the bar by one per item and finish it (according to its style) the first
/// time the stream ends, matching the `Iterator` impl.
#[cfg(feature = "futures")]
#[cfg_attr(docsrs, doc(cfg(feature = "futures")))]
impl<S: futures_core::Stream + Unpin> futures_core::Stream for ProgressBarIter<S> {
    type Item = S::Item;

    fn poll_next(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        let this = self.get_mut();
        let item = std::pin::Pin::new(&mut this.it).poll_next(cx);
        match &item {
            std::task::Poll::Ready(Some(_)) => this.progress.inc(1),
            std::task::Poll::Ready(None) => {
                // Polling an ended stream again must not re-finish an already finished bar
                // (which could override an explicit finish/abandon by the caller).
                if !this.progress.is_finished() {
                    this.progress.finish_using_style();
                }
            }
            std::task::Poll::Pending => {}
        }
        item
    }

    /// Forwards the inner stream's bounds instead of the default `(0, None)`.
    fn size_hint(&self) -> (usize, Option<usize>) {
        futures_core::Stream::size_hint(&self.it)
    }
}
299
300impl<W: io::Write> io::Write for ProgressBarIter<W> {
301    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
302        self.it.write(buf).map(|inc| {
303            self.progress.inc(inc as u64);
304            inc
305        })
306    }
307
308    fn write_vectored(&mut self, bufs: &[io::IoSlice]) -> io::Result<usize> {
309        self.it.write_vectored(bufs).map(|inc| {
310            self.progress.inc(inc as u64);
311            inc
312        })
313    }
314
315    fn flush(&mut self) -> io::Result<()> {
316        self.it.flush()
317    }
318
319    // write_fmt can not be captured with reasonable effort.
320    // as it uses write_all internally by default that should not be a problem.
321    // fn write_fmt(&mut self, fmt: fmt::Arguments) -> io::Result<()>;
322}
323
324impl<S, T: Iterator<Item = S>> ProgressIterator for T {
325    fn progress_with(self, progress: ProgressBar) -> ProgressBarIter<Self> {
326        ProgressBarIter { it: self, progress }
327    }
328}
329
#[cfg(test)]
mod test {
    use crate::iter::{ProgressBarIter, ProgressIterator};
    use crate::progress_bar::ProgressBar;
    use crate::ProgressStyle;

    #[test]
    fn it_can_wrap_an_iterator() {
        let v = [1, 2, 3];
        let wrap = |it: ProgressBarIter<_>| {
            assert_eq!(it.map(|x| x * 2).collect::<Vec<_>>(), vec![2, 4, 6]);
        };

        wrap(v.iter().progress());
        wrap(v.iter().progress_count(3));
        wrap({
            let pb = ProgressBar::new(v.len() as u64);
            v.iter().progress_with(pb)
        });
        wrap({
            let style = ProgressStyle::default_bar()
                .template("{wide_bar:.red} {percent}/100%")
                .unwrap();
            v.iter().progress_with_style(style)
        });
    }

    /// `try_progress` only builds a bar when the iterator reports an upper bound.
    #[test]
    fn try_progress_requires_an_upper_bound() {
        let v = [1, 2, 3];
        let wrapped = v
            .iter()
            .try_progress()
            .expect("slice iterators report an upper bound");
        assert_eq!(wrapped.copied().sum::<i32>(), 6);

        // `repeat` is unbounded, so there is no length to size a bar with.
        assert!(std::iter::repeat(1).try_progress().is_none());
    }

    /// Exhausting the wrapped iterator finishes the bar.
    #[test]
    fn it_finishes_the_bar_when_exhausted() {
        let v = [1, 2, 3];
        let mut it = v.iter().progress();
        assert!(!it.progress.is_finished());
        assert_eq!(it.by_ref().count(), 3);
        assert!(it.progress.is_finished());
    }
}
356}