1use std::cmp::{Ordering, Reverse};
19use std::collections::BinaryHeap;
20use std::collections::binary_heap::PeekMut;
21use std::fmt::Debug;
22use std::iter::{self, Chain, Once, Peekable};
23use std::rc::Rc;
24
25pub trait IteratorExt
27where
28 Self: Iterator + Sized,
29{
30 fn chain_one(self, item: Self::Item) -> Chain<Self, Once<Self::Item>> {
34 self.chain(iter::once(item))
35 }
36
37 fn all_equal(mut self) -> bool
41 where
42 Self::Item: PartialEq,
43 {
44 match self.next() {
45 None => true,
46 Some(v1) => self.all(|v2| v1 == v2),
47 }
48 }
49
50 fn exact_size(self, len: usize) -> ExactSize<Self> {
59 let (lower, upper) = self.size_hint();
60 assert!(
61 lower <= len && upper.map_or(true, |upper| upper >= len),
62 "provided length {len} inconsistent with `size_hint`: {:?}",
63 (lower, upper)
64 );
65
66 ExactSize { inner: self, len }
67 }
68
69 fn repeat_clone<A: Clone>(self, extra_val: A) -> RepeatClone<Self, A> {
80 RepeatClone {
81 iter: self.peekable(),
82 extra_val: Some(extra_val),
83 }
84 }
85}
86
87impl<I> IteratorExt for I where I: Iterator {}
88
#[cfg(feature = "differential-dataflow")]
/// Consolidates runs of adjacent `(data, diff)` pairs that share equal data.
///
/// The diffs of each run are accumulated with `plus_equals`, and a run whose
/// accumulated diff `is_zero` is dropped from the output. Only *adjacent*
/// equal entries are combined, so the input must be grouped (for example,
/// sorted) by data for the result to be fully consolidated.
pub fn consolidate_iter<D: PartialEq, R: differential_dataflow::difference::Semigroup>(
    iter: impl Iterator<Item = (D, R)>,
) -> impl Iterator<Item = (D, R)> {
    let mut pending = iter.peekable();
    iter::from_fn(move || {
        loop {
            let (key, mut diff) = pending.next()?;
            // Absorb every immediately following entry that carries the same key.
            while let Some((_, more)) = pending.next_if(|(next_key, _)| key == *next_key) {
                diff.plus_equals(&more);
            }
            // Runs that cancel out produce no output; move on to the next run.
            if !diff.is_zero() {
                return Some((key, diff));
            }
        }
    })
}
115
#[cfg(feature = "differential-dataflow")]
/// Consolidates runs of adjacent `(data, time, diff)` updates that share equal
/// data and time.
///
/// Delegates to [`consolidate_iter`] with `(data, time)` as the key, so the
/// same rules apply: only adjacent equal entries are combined, and runs whose
/// accumulated diff is zero are dropped.
pub fn consolidate_update_iter<
    D: PartialEq,
    T: PartialEq,
    R: differential_dataflow::difference::Semigroup,
>(
    iter: impl Iterator<Item = (D, T, R)>,
) -> impl Iterator<Item = (D, T, R)> {
    let keyed = iter.map(|(data, time, diff)| ((data, time), diff));
    consolidate_iter(keyed).map(|((data, time), diff)| (data, time, diff))
}
131
/// Merges several iterators into a single iterator ordered by `merge_by`.
///
/// Each step yields the smallest (according to `merge_by`) of the elements
/// currently at the front of the inputs, so when every input is itself ordered
/// by `merge_by` the output is fully ordered. Inputs that are empty are
/// skipped.
pub fn merge_iters_by<I: Iterator, F: Fn(&I::Item, &I::Item) -> Ordering>(
    iters: impl IntoIterator<Item = I>,
    merge_by: F,
) -> impl Iterator<Item = I::Item> {
    /// One non-exhausted input: its already-fetched front element plus the
    /// remainder of the iterator. Runs compare by their front elements.
    struct Run<I: Iterator, F> {
        iter: I,
        peek: I::Item,
        merge_by: Rc<F>,
    }

    impl<I, F> PartialEq for Run<I, F>
    where
        I: Iterator,
        F: Fn(&I::Item, &I::Item) -> Ordering,
    {
        fn eq(&self, other: &Self) -> bool {
            matches!(self.cmp(other), Ordering::Equal)
        }
    }

    impl<I, F> Eq for Run<I, F>
    where
        I: Iterator,
        F: Fn(&I::Item, &I::Item) -> Ordering,
    {
    }

    impl<I, F> PartialOrd for Run<I, F>
    where
        I: Iterator,
        F: Fn(&I::Item, &I::Item) -> Ordering,
    {
        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
            Some(self.cmp(other))
        }
    }

    impl<I, F> Ord for Run<I, F>
    where
        I: Iterator,
        F: Fn(&I::Item, &I::Item) -> Ordering,
    {
        fn cmp(&self, other: &Self) -> Ordering {
            let compare: &F = &self.merge_by;
            compare(&self.peek, &other.peek)
        }
    }

    // The comparator is shared by every run through a reference count.
    let merge_by = Rc::new(merge_by);
    let sources = iters.into_iter();
    // `BinaryHeap` is a max-heap; `Reverse` makes the smallest front element surface first.
    let mut heap = BinaryHeap::with_capacity(sources.size_hint().0);
    for mut source in sources {
        if let Some(first) = source.next() {
            heap.push(Reverse(Run {
                iter: source,
                peek: first,
                merge_by: Rc::clone(&merge_by),
            }));
        }
    }

    iter::from_fn(move || {
        let mut smallest = heap.peek_mut()?;
        let yielded = match smallest.0.iter.next() {
            // The run continues: swap in its next element; the heap is
            // re-ordered when `smallest` is dropped.
            Some(replacement) => std::mem::replace(&mut smallest.0.peek, replacement),
            // The run is exhausted: remove it and hand out its last element.
            None => PeekMut::pop(smallest).0.peek,
        };
        Some(yielded)
    })
}
189
/// Iterator adapter that reports an externally supplied exact length.
///
/// `len` is the number of elements the wrapped iterator is expected to still
/// yield; it is what `size_hint` (and therefore `ExactSizeIterator::len`)
/// reports.
#[derive(Debug)]
pub struct ExactSize<I> {
    inner: I,
    len: usize,
}

impl<I: Iterator> Iterator for ExactSize<I> {
    type Item = I::Item;

    /// Forwards to the wrapped iterator and keeps the remaining length in sync.
    fn next(&mut self) -> Option<Self::Item> {
        let item = self.inner.next();
        self.len = if item.is_some() {
            // Saturate so that an underestimated `len` can never underflow.
            self.len.saturating_sub(1)
        } else {
            // The wrapped iterator is exhausted, so nothing remains no matter
            // what length was promised up front.
            0
        };
        item
    }

    /// Reports the tracked remaining length as both lower and upper bound.
    fn size_hint(&self) -> (usize, Option<usize>) {
        (self.len, Some(self.len))
    }
}

impl<I: Iterator> ExactSizeIterator for ExactSize<I> {}
211
/// Iterator adapter that pairs every element of `iter` with a value of type `A`.
///
/// The value is cloned for every element that has a successor and moved into
/// the pair produced for the final element, so one clone is saved compared to
/// cloning for every element.
pub struct RepeatClone<I: Iterator, A> {
    iter: Peekable<I>,
    /// The value to hand out; `None` once it has been moved into the last pair.
    extra_val: Option<A>,
}

impl<I: Iterator, A: Clone> Iterator for RepeatClone<I, A> {
    type Item = (I::Item, A);

    fn next(&mut self) -> Option<Self::Item> {
        // Once the value has been moved out the iteration is over. Without
        // this guard a non-fused inner iterator that resumes after returning
        // `None` would reach the `expect` below with nothing left to hand out.
        if self.extra_val.is_none() {
            return None;
        }

        let next = self.iter.next()?;

        let val = match self.iter.peek() {
            // More elements follow, so the value is still needed afterwards.
            Some(_) => self.extra_val.clone(),
            // This is the last element: move the value instead of cloning it.
            None => self.extra_val.take(),
        };

        // The guard above ensures the value is still present at this point.
        Some((next, val.expect("RepeatClone invariant violated")))
    }

    /// Every inner element produces exactly one pair, so the inner hint
    /// applies unchanged until the value has been handed out.
    fn size_hint(&self) -> (usize, Option<usize>) {
        if self.extra_val.is_some() {
            self.iter.size_hint()
        } else {
            (0, Some(0))
        }
    }
}

// Implemented by hand because formatting `Peekable<I>` also requires
// `I::Item: Debug`, a bound that `#[derive(Debug)]` would not add.
impl<I: Iterator<Item: Debug> + Debug, A: Debug> Debug for RepeatClone<I, A> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RepeatClone")
            .field("iter", &self.iter)
            .field("extra_val", &self.extra_val)
            .finish()
    }
}
243
#[cfg(test)]
mod tests {
    use proptest::collection::vec;
    use proptest::prelude::*;

    use super::*;

    /// `all_equal` accepts empty, single-element and uniform inputs and
    /// rejects an input containing two different values.
    #[crate::test]
    fn test_all_equal() {
        let cases: [(Vec<i64>, bool); 4] = [
            (vec![], true),
            (vec![1], true),
            (vec![1, 1], true),
            (vec![1, 2], false),
        ];
        for (input, expected) in cases {
            assert_eq!(input.iter().all_equal(), expected);
        }
    }

    /// On sorted input the streaming consolidation must agree with
    /// differential dataflow's in-place `consolidate`.
    #[crate::test]
    #[cfg(feature = "differential-dataflow")]
    fn test_consolidate_sorted() {
        proptest!(|(mut updates in any::<Vec<(u64, i64)>>())| {
            updates.sort();
            let streamed = consolidate_iter(updates.iter().copied()).collect::<Vec<_>>();
            differential_dataflow::consolidation::consolidate(&mut updates);
            assert_eq!(updates, streamed);
        });
    }

    /// On unsorted input the number of emitted entries must match the number
    /// of runs of adjacent equal keys.
    #[crate::test]
    #[cfg(feature = "differential-dataflow")]
    fn test_consolidate() {
        proptest!(|(mut updates in any::<Vec<(u64, i64)>>())| {
            let streamed_len = consolidate_iter(updates.iter().copied()).count();
            updates.dedup_by_key(|update| update.0);
            assert_eq!(updates.len(), streamed_len);
        });
    }

    /// Merging individually sorted runs must equal sorting all their elements.
    #[crate::test]
    fn test_merge() {
        proptest!(|(mut runs in vec(vec(0usize..100usize, 0..10), 0..10))| {
            let mut expected: Vec<_> = runs.iter().flatten().copied().collect();
            expected.sort();

            runs.iter_mut().for_each(|run| run.sort());
            let merged: Vec<_> = merge_iters_by(
                runs.into_iter().map(|run| run.into_iter()),
                |a, b| a.cmp(b)
            ).collect();
            assert_eq!(expected, merged);
        });
    }
}