Skip to main content

mz_ore/
iter.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License in the LICENSE file at the
6// root of this repository, or online at
7//
8//     http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16//! Iterator utilities.
17
18use std::cmp::{Ordering, Reverse};
19use std::collections::BinaryHeap;
20use std::collections::binary_heap::PeekMut;
21use std::fmt::Debug;
22use std::iter::{self, Chain, Once, Peekable};
23use std::rc::Rc;
24
25/// Extension methods for iterators.
26pub trait IteratorExt
27where
28    Self: Iterator + Sized,
29{
30    /// Chains a single `item` onto the end of this iterator.
31    ///
32    /// Equivalent to `self.chain(iter::once(item))`.
33    fn chain_one(self, item: Self::Item) -> Chain<Self, Once<Self::Item>> {
34        self.chain(iter::once(item))
35    }
36
37    /// Reports whether all the elements of the iterator are the same.
38    ///
39    /// This condition is trivially true for iterators with zero or one elements.
40    fn all_equal(mut self) -> bool
41    where
42        Self::Item: PartialEq,
43    {
44        match self.next() {
45            None => true,
46            Some(v1) => self.all(|v2| v1 == v2),
47        }
48    }
49
50    /// Converts the the iterator into an `ExactSizeIterator` reporting the given size.
51    ///
52    /// The caller is responsible for providing the correct size of the iterator! Providing an
53    /// incorrect size value will lead to panics and/or incorrect responses to size queries.
54    ///
55    /// # Panics
56    ///
57    /// Panics if the given length is not consistent with this iterator's `size_hint`.
58    fn exact_size(self, len: usize) -> ExactSize<Self> {
59        let (lower, upper) = self.size_hint();
60        assert!(
61            lower <= len && upper.map_or(true, |upper| upper >= len),
62            "provided length {len} inconsistent with `size_hint`: {:?}",
63            (lower, upper)
64        );
65
66        ExactSize { inner: self, len }
67    }
68
69    /// Wrap this iterator with one that yields a tuple of the iterator element and the extra
70    /// value on each iteration. The extra value is cloned for each but the last `Some` element
71    /// returned.
72    ///
73    /// This is useful to provide an owned extra value to each iteration, but only clone it
74    /// when necessary.
75    ///
76    /// NOTE: Once the iterator starts producing `None` values, the extra value will be consumed
77    /// and no longer be available. This should not be used for iterators that may produce
78    /// `Some` values after producing `None`.
79    fn repeat_clone<A: Clone>(self, extra_val: A) -> RepeatClone<Self, A> {
80        RepeatClone {
81            iter: self.peekable(),
82            extra_val: Some(extra_val),
83        }
84    }
85}
86
87impl<I> IteratorExt for I where I: Iterator {}
88
/// Eagerly merges runs of adjacent elements that carry the same data.
///
/// When the input is sorted by `D`, the output is both sorted and fully consolidated. When it
/// is not, the output will probably be neither... but it is still equivalent from a
/// differential perspective and slightly smaller.
#[cfg(feature = "differential-dataflow")]
pub fn consolidate_iter<D: PartialEq, R: differential_dataflow::difference::Semigroup>(
    iter: impl Iterator<Item = (D, R)>,
) -> impl Iterator<Item = (D, R)> {
    let mut updates = iter.peekable();
    iter::from_fn(move || {
        loop {
            let (data, mut diff) = updates.next()?;
            // Fold every immediately following update for the same data into `diff`.
            while let Some((_, next_diff)) = updates.next_if(|(next_data, _)| data == *next_data)
            {
                diff.plus_equals(&next_diff);
            }
            // Runs whose diffs cancel out entirely produce no output.
            if !diff.is_zero() {
                break Some((data, diff));
            }
        }
    })
}
115
/// Eagerly merges runs of adjacent elements in an iterator of `(data, time, diff)` triples.
///
/// When the input is sorted by `D` and `T`, the output is both sorted and fully consolidated.
/// When it is not, the output will probably be neither... but it is still equivalent from a
/// differential perspective and slightly smaller.
#[cfg(feature = "differential-dataflow")]
pub fn consolidate_update_iter<D, T, R>(
    iter: impl Iterator<Item = (D, T, R)>,
) -> impl Iterator<Item = (D, T, R)>
where
    D: PartialEq,
    T: PartialEq,
    R: differential_dataflow::difference::Semigroup,
{
    // Treat `(data, time)` as the key so the pair-based consolidation can be reused.
    let keyed = iter.map(|(data, time, diff)| ((data, time), diff));
    consolidate_iter(keyed).map(|((data, time), diff)| (data, time, diff))
}
131
/// Merges several iterators into a single iterator, ordering elements with the provided
/// comparison function.
///
/// If every input iterator is sorted according to `merge_by`, the resulting iterator is sorted
/// according to `merge_by` as well.
pub fn merge_iters_by<I: Iterator, F: Fn(&I::Item, &I::Item) -> Ordering>(
    iters: impl IntoIterator<Item = I>,
    merge_by: F,
) -> impl Iterator<Item = I::Item> {
    /// One input iterator paired with its next not-yet-emitted element.
    ///
    /// Runs are ordered by applying the shared comparison function to their `head` elements,
    /// so a heap of runs always exposes the run holding the overall next element.
    struct Run<I: Iterator, F> {
        rest: I,
        head: I::Item,
        order: Rc<F>,
    }

    impl<I, F> PartialEq for Run<I, F>
    where
        I: Iterator,
        F: Fn(&I::Item, &I::Item) -> Ordering,
    {
        fn eq(&self, other: &Self) -> bool {
            self.cmp(other) == Ordering::Equal
        }
    }

    impl<I, F> Eq for Run<I, F>
    where
        I: Iterator,
        F: Fn(&I::Item, &I::Item) -> Ordering,
    {
    }

    impl<I, F> PartialOrd for Run<I, F>
    where
        I: Iterator,
        F: Fn(&I::Item, &I::Item) -> Ordering,
    {
        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
            Some(self.cmp(other))
        }
    }

    impl<I, F> Ord for Run<I, F>
    where
        I: Iterator,
        F: Fn(&I::Item, &I::Item) -> Ordering,
    {
        fn cmp(&self, other: &Self) -> Ordering {
            (self.order)(&self.head, &other.head)
        }
    }

    let runs = iters.into_iter();
    let (min_runs, _) = runs.size_hint();
    // `BinaryHeap` is a max-heap; `Reverse` makes the smallest head surface first.
    let mut heap = BinaryHeap::with_capacity(min_runs);
    let order = Rc::new(merge_by);
    for mut rest in runs {
        // Inputs that are empty from the start contribute nothing to the merge.
        if let Some(head) = rest.next() {
            heap.push(Reverse(Run {
                rest,
                head,
                order: Rc::clone(&order),
            }));
        }
    }

    iter::from_fn(move || {
        let mut smallest = heap.peek_mut()?;
        // An exhausted run is removed from the heap and gives up its final element.
        let Some(refill) = smallest.0.rest.next() else {
            return Some(PeekMut::pop(smallest).0.head);
        };
        // Otherwise emit the current head and install the run's next element in its place;
        // dropping `smallest` restores the heap order.
        let emitted = std::mem::replace(&mut smallest.0.head, refill);
        Some(emitted)
    })
}
189
/// Iterator type returned by [`IteratorExt::exact_size`].
#[derive(Debug)]
pub struct ExactSize<I> {
    /// The wrapped iterator that actually produces the elements.
    inner: I,
    /// The number of elements this iterator still claims to have left.
    len: usize,
}

impl<I> Iterator for ExactSize<I>
where
    I: Iterator,
{
    type Item = I::Item;

    fn next(&mut self) -> Option<I::Item> {
        // Count down on every call, but never below zero, even if the caller-provided
        // length was too small or `next` is called again after exhaustion.
        if self.len > 0 {
            self.len -= 1;
        }
        self.inner.next()
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // Lower and upper bound coincide, as `ExactSizeIterator` requires.
        let remaining = self.len;
        (remaining, Some(remaining))
    }
}

impl<I> ExactSizeIterator for ExactSize<I> where I: Iterator {}
211
/// Iterator type returned by [`IteratorExt::repeat_clone`].
///
/// Yields every element of the wrapped iterator paired with the extra value. The extra value
/// is cloned for each element except the final one, which receives the original.
pub struct RepeatClone<I: Iterator, A> {
    /// Wrapped iterator; peekable so `next` can tell whether another element follows.
    iter: Peekable<I>,
    /// The extra value; `Some` until it is moved out alongside the final element.
    extra_val: Option<A>,
}

impl<I: Iterator, A: Clone> Iterator for RepeatClone<I, A> {
    type Item = (I::Item, A);

    fn next(&mut self) -> Option<Self::Item> {
        let next = self.iter.next()?;

        // Clone the extra_val only if there is an item to return on the next call to `next`.
        let val = match self.iter.peek() {
            Some(_) => self.extra_val.clone(),
            None => self.extra_val.take(),
        };

        // We should always return a value if there is a current element.
        Some((next, val.expect("RepeatClone invariant violated")))
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // Exactly one pair is produced per element of the wrapped iterator, so its hint
        // (which already accounts for a peeked element) carries over unchanged.
        self.iter.size_hint()
    }
}

impl<I: Iterator<Item: Debug> + Debug, A: Debug> Debug for RepeatClone<I, A> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RepeatClone")
            .field("iter", &self.iter)
            .field("extra_val", &self.extra_val)
            .finish()
    }
}
243
#[cfg(test)]
mod tests {
    use proptest::collection::vec;
    use proptest::prelude::*;

    use super::*;

    /// `all_equal` holds for empty, single-element and uniform inputs, and fails otherwise.
    #[crate::test]
    fn test_all_equal() {
        let empty: [i64; 0] = [];
        assert!(empty.iter().all_equal());
        assert!([1].iter().all_equal());
        assert!([1, 1].iter().all_equal());
        assert!(![1, 2].iter().all_equal());
    }

    /// On sorted input, streaming consolidation matches differential's in-place consolidation.
    #[crate::test]
    #[cfg(feature = "differential-dataflow")]
    fn test_consolidate_sorted() {
        proptest!(|(mut data in any::<Vec<(u64, i64)>>())| {
            data.sort();
            let streamed: Vec<_> = consolidate_iter(data.iter().copied()).collect();
            differential_dataflow::consolidation::consolidate(&mut data);
            assert_eq!(data, streamed);
        });
    }

    /// On arbitrary input, each run of adjacent equal keys yields at most one output.
    #[crate::test]
    #[cfg(feature = "differential-dataflow")]
    fn test_consolidate() {
        proptest!(|(mut data in any::<Vec<(u64, i64)>>())| {
            let streamed: Vec<_> = consolidate_iter(data.iter().copied()).collect();
            // After deduplication `data.len()` is the number of runs of adjacent equal keys.
            data.dedup_by_key(|t| t.0);
            // `consolidate_iter` skips runs whose accumulated diff is zero, so the output
            // may be shorter than the number of runs but can never be longer.
            assert!(streamed.len() <= data.len());
            // Whatever is emitted must carry a non-zero diff.
            assert!(streamed.iter().all(|(_, diff)| *diff != 0));
        });
    }

    /// Merging individually sorted series yields the globally sorted sequence.
    #[crate::test]
    fn test_merge() {
        proptest!(|(mut data in vec(vec(0usize..100usize, 0..10), 0..10))| {
            let mut expected: Vec<_> = data.iter().flatten().copied().collect();
            expected.sort();

            for series in &mut data {
                series.sort();
            }
            let merged: Vec<_> = merge_iters_by(
                data.into_iter().map(|i| i.into_iter()),
                |a, b| a.cmp(b)
            ).collect();
            assert_eq!(expected, merged);
        });
    }
}