keyed_priority_queue/
editable_binary_heap.rs

1use std::cmp::{Ord, Ordering};
2use std::fmt::Debug;
3use std::hash::BuildHasher;
4use std::vec::Vec;
5
6use crate::mediator::MediatorIndex;
7
/// Position of an entry inside the storage of `BinaryHeap`.
///
/// Using a dedicated newtype instead of a bare `usize` lets the compiler
/// reject any attempt to index the heap with an index that belongs to a
/// different collection (e.g. the mediator).
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) struct HeapIndex(usize);
13
14#[derive(Copy, Clone)]
15struct HeapEntry<TPriority> {
16    outer_pos: MediatorIndex,
17    priority: TPriority,
18}
19
20impl<TPriority> HeapEntry<TPriority> {
21    // For usings as HeapEntry::as_pair instead of closures in map
22
23    #[inline(always)]
24    fn conv_pair(self) -> (MediatorIndex, TPriority) {
25        (self.outer_pos, self.priority)
26    }
27
28    #[inline(always)]
29    fn to_pair_ref(&self) -> (MediatorIndex, &TPriority) {
30        (self.outer_pos, &self.priority)
31    }
32
33    #[inline(always)]
34    fn to_outer(&self) -> MediatorIndex {
35        self.outer_pos
36    }
37}
38
39#[derive(Clone)]
40pub(crate) struct BinaryHeap<TPriority>
41where
42    TPriority: Ord,
43{
44    data: Vec<HeapEntry<TPriority>>,
45}
46
47impl<TPriority: Ord> BinaryHeap<TPriority> {
48    #[inline]
49    pub(crate) fn with_capacity(capacity: usize) -> Self {
50        Self {
51            data: Vec::with_capacity(capacity),
52        }
53    }
54
55    #[inline]
56    pub fn reserve(&mut self, additional: usize) {
57        self.data.reserve(additional);
58    }
59
60    /// Puts outer index and priority in queue
61    /// outer_pos is assumed to be unique but not validated
62    /// because validation too expensive
63    /// Calls change_handler for every move of old values
64    #[inline]
65    pub(crate) fn push<TChangeHandler: std::ops::FnMut(MediatorIndex, HeapIndex)>(
66        &mut self,
67        outer_pos: MediatorIndex,
68        priority: TPriority,
69        mut change_handler: TChangeHandler,
70    ) {
71        self.data.push(HeapEntry {
72            outer_pos,
73            priority,
74        });
75        self.heapify_up(HeapIndex(self.data.len() - 1), &mut change_handler);
76    }
77
78    /// Removes item at position and returns it
79    /// Time complexity - O(log n) swaps and change_handler calls
80    pub(crate) fn remove<TChangeHandler: std::ops::FnMut(MediatorIndex, HeapIndex)>(
81        &mut self,
82        position: HeapIndex,
83        mut change_handler: TChangeHandler,
84    ) -> Option<(MediatorIndex, TPriority)> {
85        if position >= self.len() {
86            return None;
87        }
88        if position.0 + 1 == self.len().0 {
89            let result = self.data.pop().expect("At least 1 item");
90            return Some(result.conv_pair());
91        }
92
93        let result = self.data.swap_remove(position.0);
94        self.heapify_down(position, &mut change_handler);
95        if position.0 > 0 {
96            self.heapify_up(position, &mut change_handler);
97        }
98        Some(result.conv_pair())
99    }
100
101    #[inline]
102    pub(crate) fn look_into(&self, position: HeapIndex) -> Option<(MediatorIndex, &TPriority)> {
103        self.data.get(position.0).map(HeapEntry::to_pair_ref)
104    }
105
106    /// Changes priority of queue item
107    /// Returns old priority
108    pub(crate) fn change_priority<TChangeHandler: std::ops::FnMut(MediatorIndex, HeapIndex)>(
109        &mut self,
110        position: HeapIndex,
111        updated: TPriority,
112        mut change_handler: TChangeHandler,
113    ) -> TPriority {
114        debug_assert!(
115            position < self.len(),
116            "Out of index during changing priority"
117        );
118
119        let old = std::mem::replace(&mut self.data[position.0].priority, updated);
120        match old.cmp(&self.data[position.0].priority) {
121            Ordering::Less => {
122                self.heapify_up(position, &mut change_handler);
123            }
124            Ordering::Equal => {}
125            Ordering::Greater => {
126                self.heapify_down(position, &mut change_handler);
127            }
128        }
129        old
130    }
131
132    // Changes outer index for element and return old index
133    pub(crate) fn change_outer_pos(
134        &mut self,
135        outer_pos: MediatorIndex,
136        position: HeapIndex,
137    ) -> MediatorIndex {
138        debug_assert!(position < self.len(), "Out of index during changing key");
139
140        let old_pos = self.data[position.0].outer_pos;
141        self.data[position.0].outer_pos = outer_pos;
142        old_pos
143    }
144
145    #[inline]
146    pub(crate) fn most_prioritized_idx(&self) -> Option<(MediatorIndex, HeapIndex)> {
147        self.data.get(0).map(|x| (x.outer_pos, HeapIndex(0)))
148    }
149
150    #[inline]
151    pub(crate) fn len(&self) -> HeapIndex {
152        HeapIndex(self.data.len())
153    }
154
155    #[inline]
156    pub(crate) fn usize_len(&self) -> usize {
157        self.data.len()
158    }
159
160    #[inline]
161    pub(crate) fn is_empty(&self) -> bool {
162        self.data.is_empty()
163    }
164
165    #[inline]
166    pub(crate) fn clear(&mut self) {
167        self.data.clear()
168    }
169
170    #[inline]
171    pub(crate) fn iter(&self) -> BinaryHeapIterator<TPriority> {
172        BinaryHeapIterator {
173            inner: self.data.iter(),
174        }
175    }
176
177    pub(crate) fn produce_from_iter_hash<TKey, TIter, S>(
178        iter: TIter,
179    ) -> (Self, crate::mediator::Mediator<TKey, S>)
180    where
181        TKey: std::hash::Hash + Eq,
182        TIter: IntoIterator<Item = (TKey, TPriority)>,
183        S: BuildHasher + Default,
184    {
185        use crate::mediator::{Mediator, MediatorEntry};
186
187        let iter = iter.into_iter();
188        let (min_size, _) = iter.size_hint();
189
190        let mut heap_base: Vec<HeapEntry<TPriority>> = Vec::with_capacity(min_size);
191        let mut map: Mediator<TKey, S> = Mediator::with_capacity_and_hasher(min_size, S::default());
192
193        for (key, priority) in iter {
194            match map.entry(key) {
195                MediatorEntry::Vacant(entry) => {
196                    let outer_pos = entry.index();
197                    unsafe {
198                        // Safety: resulting reference never used
199                        entry.insert(HeapIndex(heap_base.len()));
200                    }
201                    heap_base.push(HeapEntry {
202                        outer_pos,
203                        priority,
204                    });
205                }
206                MediatorEntry::Occupied(entry) => {
207                    let HeapIndex(heap_pos) = entry.get_heap_idx();
208                    heap_base[heap_pos].priority = priority;
209                }
210            }
211        }
212
213        let heapify_start = std::cmp::min(heap_base.len() / 2 + 2, heap_base.len());
214        let mut heap = BinaryHeap { data: heap_base };
215        for pos in (0..heapify_start).rev().map(HeapIndex) {
216            heap.heapify_down(pos, &mut |_, _| {});
217        }
218
219        for (i, pos) in heap.data.iter().map(HeapEntry::to_outer).enumerate() {
220            let heap_idx = map.get_index_mut(pos);
221            *heap_idx = HeapIndex(i);
222        }
223
224        (heap, map)
225    }
226
227    fn heapify_up<TChangeHandler: std::ops::FnMut(MediatorIndex, HeapIndex)>(
228        &mut self,
229        position: HeapIndex,
230        change_handler: &mut TChangeHandler,
231    ) {
232        debug_assert!(position < self.len(), "Out of index in heapify_up");
233        let HeapIndex(mut position) = position;
234        while position > 0 {
235            let parent_pos = (position - 1) / 2;
236            if self.data[parent_pos].priority >= self.data[position].priority {
237                break;
238            }
239            self.data.swap(parent_pos, position);
240            change_handler(self.data[position].outer_pos, HeapIndex(position));
241            position = parent_pos;
242        }
243        change_handler(self.data[position].outer_pos, HeapIndex(position));
244    }
245
246    fn heapify_down<TChangeHandler: std::ops::FnMut(MediatorIndex, HeapIndex)>(
247        &mut self,
248        position: HeapIndex,
249        change_handler: &mut TChangeHandler,
250    ) {
251        debug_assert!(position < self.len(), "Out of index in heapify_down");
252        let HeapIndex(mut position) = position;
253        loop {
254            let max_child_idx = {
255                let child1 = position * 2 + 1;
256                let child2 = child1 + 1;
257                if child1 >= self.data.len() {
258                    break;
259                }
260                if child2 < self.data.len()
261                    && self.data[child1].priority <= self.data[child2].priority
262                {
263                    child2
264                } else {
265                    child1
266                }
267            };
268
269            if self.data[position].priority >= self.data[max_child_idx].priority {
270                break;
271            }
272            self.data.swap(position, max_child_idx);
273            change_handler(self.data[position].outer_pos, HeapIndex(position));
274            position = max_child_idx;
275        }
276        change_handler(self.data[position].outer_pos, HeapIndex(position));
277    }
278}
279
280/// Useful to create iterator for outer struct
281/// Does NOT guarantee any particular order
282pub(crate) struct BinaryHeapIterator<'a, TPriority> {
283    inner: std::slice::Iter<'a, HeapEntry<TPriority>>,
284}
285
286impl<'a, TPriority> Iterator for BinaryHeapIterator<'a, TPriority> {
287    type Item = (MediatorIndex, &'a TPriority);
288
289    #[inline]
290    fn next(&mut self) -> Option<Self::Item> {
291        self.inner
292            .next()
293            .map(|entry: &'a HeapEntry<TPriority>| (entry.outer_pos, &entry.priority))
294    }
295
296    #[inline]
297    fn size_hint(&self) -> (usize, Option<usize>) {
298        self.inner.size_hint()
299    }
300
301    #[inline]
302    fn count(self) -> usize
303    where
304        Self: Sized,
305    {
306        self.inner.count()
307    }
308}
309
310// Default implementations
311
312impl<TPriority: Debug> Debug for HeapEntry<TPriority> {
313    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
314        write!(
315            f,
316            "{{outer: {:?}, priority: {:?}}}",
317            &self.outer_pos, &self.priority
318        )
319    }
320}
321
322impl<TPriority: Debug + Ord> Debug for BinaryHeap<TPriority> {
323    #[inline]
324    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
325        self.data.fmt(f)
326    }
327}
328
329#[cfg(test)]
330mod tests {
331    use crate::mediator::Mediator;
332
333    use super::*;
334    use std::cmp::Reverse;
335    use std::collections::hash_map::RandomState;
336    use std::collections::{HashMap, HashSet};
337
338    fn is_valid_heap<TP: Ord>(heap: &BinaryHeap<TP>) -> bool {
339        for (i, current) in heap.data.iter().enumerate().skip(1) {
340            let parent = &heap.data[(i - 1) / 2];
341            if parent.priority < current.priority {
342                return false;
343            }
344        }
345        true
346    }
347
348    #[test]
349    fn test_heap_fill() {
350        let items = [
351            70, 50, 0, 1, 2, 4, 6, 7, 9, 72, 4, 4, 87, 78, 72, 6, 7, 9, 2, -50, -72, -50, -42, -1,
352            -3, -13,
353        ];
354        let mut maximum = std::i32::MIN;
355        let mut heap = BinaryHeap::<i32>::with_capacity(0);
356        assert!(heap.look_into(HeapIndex(0)).is_none());
357        assert!(is_valid_heap(&heap), "Heap state is invalid");
358        for (key, x) in items
359            .iter()
360            .enumerate()
361            .map(|(i, &x)| (MediatorIndex(i), x))
362        {
363            if x > maximum {
364                maximum = x;
365            }
366            heap.push(key, x, |_, _| {});
367            assert!(
368                is_valid_heap(&heap),
369                "Heap state is invalid after pushing {}",
370                x
371            );
372            assert!(heap.look_into(HeapIndex(0)).is_some());
373            let (_, &heap_max) = heap.look_into(HeapIndex(0)).unwrap();
374            assert_eq!(maximum, heap_max)
375        }
376    }
377
378    #[test]
379    fn test_change_logger() {
380        let items = [
381            2, 3, 21, 22, 25, 29, 36, 90, 89, 88, 87, 83, 48, 50, 52, 69, 65, 55, 73, 75, 76, -53,
382            78, 81, -45, -41, 91, -34, -33, -31, -27, -22, -19, -8, -5, -3,
383        ];
384        let mut last_positions = HashMap::<MediatorIndex, HeapIndex>::new();
385        let mut heap = BinaryHeap::<i32>::with_capacity(0);
386        let mut on_pos_change = |outer_pos: MediatorIndex, position: HeapIndex| {
387            last_positions.insert(outer_pos, position);
388        };
389        for (i, &x) in items.iter().enumerate() {
390            heap.push(MediatorIndex(i), x, &mut on_pos_change);
391        }
392        assert_eq!(heap.usize_len(), last_positions.len());
393        for i in 0..items.len() {
394            let rem_idx = MediatorIndex(i);
395            assert!(
396                last_positions.contains_key(&rem_idx),
397                "Not for all items change_handler called"
398            );
399            let position = last_positions[&rem_idx];
400            assert_eq!(
401                items[(heap.look_into(position).unwrap().0).0],
402                *heap.look_into(position).unwrap().1
403            );
404            assert_eq!(heap.look_into(position).unwrap().0, rem_idx);
405        }
406
407        let mut removed = HashSet::<MediatorIndex>::new();
408        loop {
409            let mut on_pos_change = |key: MediatorIndex, position: HeapIndex| {
410                last_positions.insert(key, position);
411            };
412            let popped = heap.remove(HeapIndex(0), &mut on_pos_change);
413            if popped.is_none() {
414                break;
415            }
416            let (key, _) = popped.unwrap();
417            last_positions.remove(&key);
418            removed.insert(key);
419            assert_eq!(heap.usize_len(), last_positions.len());
420            for i in (0..items.len())
421                .into_iter()
422                .filter(|i| !removed.contains(&MediatorIndex(*i)))
423            {
424                let rem_idx = MediatorIndex(i);
425                assert!(
426                    last_positions.contains_key(&rem_idx),
427                    "Not for all items change_handler called"
428                );
429                let position = last_positions[&rem_idx];
430                assert_eq!(
431                    items[(heap.look_into(position).unwrap().0).0],
432                    *heap.look_into(position).unwrap().1
433                );
434                assert_eq!(heap.look_into(position).unwrap().0, rem_idx);
435            }
436        }
437    }
438
439    #[test]
440    fn test_pop() {
441        let items = [
442            -16, 5, 11, -1, -34, -42, -5, -6, 25, -35, 11, 35, -2, 40, 42, 40, -45, -48, 48, -38,
443            -28, -33, -31, 34, -18, 25, 16, -33, -11, -6, -35, -38, 35, -41, -38, 31, -38, -23, 26,
444            44, 38, 11, -49, 30, 7, 13, 12, -4, -11, -24, -49, 26, 42, 46, -25, -22, -6, -42, 28,
445            45, -47, 8, 8, 21, 49, -12, -5, -33, -37, 24, -3, -26, 6, -13, 16, -40, -14, -39, -26,
446            12, -44, 47, 45, -41, -22, -11, 20, 43, -44, 24, 47, 40, 43, 9, 19, 12, -17, 30, -36,
447            -50, 24, -2, 1, 1, 5, -19, 21, -38, 47, 34, -14, 12, -30, 24, -2, -32, -10, 40, 34, 2,
448            -33, 9, -31, -3, -15, 28, 50, -37, 35, 19, 35, 13, -2, 46, 28, 35, -40, -19, -1, -33,
449            -42, -35, -12, 19, 29, 10, -31, -4, -9, 24, 15, -27, 13, 20, 15, 19, -40, -41, 40, -25,
450            45, -11, -7, -19, 11, -44, -37, 35, 2, -49, 11, -37, -14, 13, 41, 10, 3, 19, -32, -12,
451            -12, 33, -26, -49, -45, 24, 47, -29, -25, -45, -36, 40, 24, -29, 15, 36, 0, 47, 3, -45,
452        ];
453
454        let mut heap = BinaryHeap::<i32>::with_capacity(0);
455        for (i, &x) in items.iter().enumerate() {
456            heap.push(MediatorIndex(i), x, |_, _| {});
457        }
458        assert!(is_valid_heap(&heap), "Heap is invalid before pops");
459
460        let mut sorted_items = items;
461        sorted_items.sort_unstable_by_key(|&x| Reverse(x));
462        for &x in sorted_items.iter() {
463            let pop_res = heap.remove(HeapIndex(0), |_, _| {});
464            assert!(pop_res.is_some());
465            let (rem_idx, val) = pop_res.unwrap();
466            assert_eq!(val, x);
467            assert_eq!(items[rem_idx.0], val);
468            assert!(is_valid_heap(&heap), "Heap is invalid after {}", x);
469        }
470
471        assert_eq!(heap.remove(HeapIndex(0), |_, _| {}), None);
472    }
473
474    #[test]
475    fn test_remove() {
476        let mut heap = BinaryHeap::with_capacity(16);
477        for i in 0..16 {
478            heap.push(MediatorIndex(i), i, |_, _| {});
479        }
480        assert!(is_valid_heap(&heap));
481        for _ in 0..5 {
482            heap.remove(HeapIndex(5), |_, _| {});
483            assert!(is_valid_heap(&heap));
484        }
485    }
486
487    #[test]
488    fn test_change_priority() {
489        let pairs = [
490            (MediatorIndex(0), 0),
491            (MediatorIndex(1), 1),
492            (MediatorIndex(2), 2),
493            (MediatorIndex(3), 3),
494            (MediatorIndex(4), 4),
495        ];
496
497        let mut heap = BinaryHeap::with_capacity(0);
498        for (key, priority) in pairs.iter().cloned() {
499            heap.push(key, priority, |_, _| {});
500        }
501        assert!(is_valid_heap(&heap), "Invalid before change");
502        heap.change_priority(HeapIndex(3), 10, |_, _| {});
503        assert!(is_valid_heap(&heap), "Invalid after upping");
504        heap.change_priority(HeapIndex(2), -10, |_, _| {});
505        assert!(is_valid_heap(&heap), "Invalid after lowering");
506    }
507
508    #[test]
509    fn create_heap_hash_test() {
510        let priorities = [
511            16i32, 16, 5, 20, 10, 12, 10, 8, 12, 2, 20, -1, -18, 5, -16, 1, 7, 3, 17, -20, -4, 3,
512            -7, -5, -8, 19, -19, -16, 3, 4, 17, 13, 3, 11, -9, 0, -10, -2, 16, 19, -12, -4, 19, 7,
513            16, -19, -9, -17, 6, -16, -3, 11, -14, -15, -10, 13, 11, -14, 18, -8, -9, -4, 5, -4,
514            17, 6, -16, -5, 12, 12, -3, 8, 5, -4, 7, 10, 7, -11, 18, -16, 18, 4, -15, -4, -13, 7,
515            -14, -16, -18, -10, 13, -1, -9, 0, -18, -4, -13, 16, 10, -20, 19, 20, 0, -9, -7, 14,
516            19, -8, -18, -1, -17, -11, 13, 12, -15, 0, -18, 6, -13, -17, -3, 18, 2, 12, 12, 4, -14,
517            -11, -10, -9, 3, 14, 8, 7, 13, 13, -17, -9, -4, -19, -6, 1, 9, 5, 20, -9, -19, -20,
518            -18, -8, 7,
519        ];
520        let (heap, key_to_pos): (_, Mediator<_, RandomState>) =
521            BinaryHeap::produce_from_iter_hash(priorities.iter().cloned().map(|x| (x, x)));
522        assert!(is_valid_heap(&heap), "Must be valid heap");
523        for (map_idx, (key, heap_idx)) in key_to_pos.iter().enumerate() {
524            assert_eq!(
525                Some((MediatorIndex(map_idx), key)),
526                heap.look_into(heap_idx)
527            );
528        }
529    }
530
531    #[test]
532    fn test_clear() {
533        let mut heap = BinaryHeap::with_capacity(0);
534        for x in 0..5 {
535            heap.push(MediatorIndex(x), x, |_, _| {});
536        }
537        assert!(!heap.is_empty(), "Heap must be non empty");
538        heap.data.clear();
539        assert!(heap.is_empty(), "Heap must be empty");
540        assert_eq!(heap.remove(HeapIndex(0), |_, _| {}), None);
541    }
542
543    #[test]
544    fn test_change_change_outer_pos() {
545        let mut heap = BinaryHeap::with_capacity(0);
546        for x in 0..5 {
547            heap.push(MediatorIndex(x), x, |_, _| {});
548        }
549        assert_eq!(heap.look_into(HeapIndex(0)), Some((MediatorIndex(4), &4)));
550        assert_eq!(
551            heap.change_outer_pos(MediatorIndex(10), HeapIndex(0)),
552            MediatorIndex(4)
553        );
554        assert_eq!(heap.look_into(HeapIndex(0)), Some((MediatorIndex(10), &4)));
555    }
556}