1use std::cmp::{Ord, Ordering};
2use std::fmt::Debug;
3use std::hash::BuildHasher;
4use std::vec::Vec;
5
6use crate::mediator::MediatorIndex;
7
/// Position of an entry inside the heap's backing array.
///
/// Ordering and equality follow the wrapped index, so values can be compared
/// directly against `BinaryHeap::len()` for bounds checks.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub(crate) struct HeapIndex(usize);
13
14#[derive(Copy, Clone)]
15struct HeapEntry<TPriority> {
16 outer_pos: MediatorIndex,
17 priority: TPriority,
18}
19
20impl<TPriority> HeapEntry<TPriority> {
21 #[inline(always)]
24 fn conv_pair(self) -> (MediatorIndex, TPriority) {
25 (self.outer_pos, self.priority)
26 }
27
28 #[inline(always)]
29 fn to_pair_ref(&self) -> (MediatorIndex, &TPriority) {
30 (self.outer_pos, &self.priority)
31 }
32
33 #[inline(always)]
34 fn to_outer(&self) -> MediatorIndex {
35 self.outer_pos
36 }
37}
38
39#[derive(Clone)]
40pub(crate) struct BinaryHeap<TPriority>
41where
42 TPriority: Ord,
43{
44 data: Vec<HeapEntry<TPriority>>,
45}
46
47impl<TPriority: Ord> BinaryHeap<TPriority> {
48 #[inline]
49 pub(crate) fn with_capacity(capacity: usize) -> Self {
50 Self {
51 data: Vec::with_capacity(capacity),
52 }
53 }
54
55 #[inline]
56 pub fn reserve(&mut self, additional: usize) {
57 self.data.reserve(additional);
58 }
59
60 #[inline]
65 pub(crate) fn push<TChangeHandler: std::ops::FnMut(MediatorIndex, HeapIndex)>(
66 &mut self,
67 outer_pos: MediatorIndex,
68 priority: TPriority,
69 mut change_handler: TChangeHandler,
70 ) {
71 self.data.push(HeapEntry {
72 outer_pos,
73 priority,
74 });
75 self.heapify_up(HeapIndex(self.data.len() - 1), &mut change_handler);
76 }
77
78 pub(crate) fn remove<TChangeHandler: std::ops::FnMut(MediatorIndex, HeapIndex)>(
81 &mut self,
82 position: HeapIndex,
83 mut change_handler: TChangeHandler,
84 ) -> Option<(MediatorIndex, TPriority)> {
85 if position >= self.len() {
86 return None;
87 }
88 if position.0 + 1 == self.len().0 {
89 let result = self.data.pop().expect("At least 1 item");
90 return Some(result.conv_pair());
91 }
92
93 let result = self.data.swap_remove(position.0);
94 self.heapify_down(position, &mut change_handler);
95 if position.0 > 0 {
96 self.heapify_up(position, &mut change_handler);
97 }
98 Some(result.conv_pair())
99 }
100
101 #[inline]
102 pub(crate) fn look_into(&self, position: HeapIndex) -> Option<(MediatorIndex, &TPriority)> {
103 self.data.get(position.0).map(HeapEntry::to_pair_ref)
104 }
105
106 pub(crate) fn change_priority<TChangeHandler: std::ops::FnMut(MediatorIndex, HeapIndex)>(
109 &mut self,
110 position: HeapIndex,
111 updated: TPriority,
112 mut change_handler: TChangeHandler,
113 ) -> TPriority {
114 debug_assert!(
115 position < self.len(),
116 "Out of index during changing priority"
117 );
118
119 let old = std::mem::replace(&mut self.data[position.0].priority, updated);
120 match old.cmp(&self.data[position.0].priority) {
121 Ordering::Less => {
122 self.heapify_up(position, &mut change_handler);
123 }
124 Ordering::Equal => {}
125 Ordering::Greater => {
126 self.heapify_down(position, &mut change_handler);
127 }
128 }
129 old
130 }
131
132 pub(crate) fn change_outer_pos(
134 &mut self,
135 outer_pos: MediatorIndex,
136 position: HeapIndex,
137 ) -> MediatorIndex {
138 debug_assert!(position < self.len(), "Out of index during changing key");
139
140 let old_pos = self.data[position.0].outer_pos;
141 self.data[position.0].outer_pos = outer_pos;
142 old_pos
143 }
144
145 #[inline]
146 pub(crate) fn most_prioritized_idx(&self) -> Option<(MediatorIndex, HeapIndex)> {
147 self.data.get(0).map(|x| (x.outer_pos, HeapIndex(0)))
148 }
149
150 #[inline]
151 pub(crate) fn len(&self) -> HeapIndex {
152 HeapIndex(self.data.len())
153 }
154
155 #[inline]
156 pub(crate) fn usize_len(&self) -> usize {
157 self.data.len()
158 }
159
160 #[inline]
161 pub(crate) fn is_empty(&self) -> bool {
162 self.data.is_empty()
163 }
164
165 #[inline]
166 pub(crate) fn clear(&mut self) {
167 self.data.clear()
168 }
169
170 #[inline]
171 pub(crate) fn iter(&self) -> BinaryHeapIterator<TPriority> {
172 BinaryHeapIterator {
173 inner: self.data.iter(),
174 }
175 }
176
177 pub(crate) fn produce_from_iter_hash<TKey, TIter, S>(
178 iter: TIter,
179 ) -> (Self, crate::mediator::Mediator<TKey, S>)
180 where
181 TKey: std::hash::Hash + Eq,
182 TIter: IntoIterator<Item = (TKey, TPriority)>,
183 S: BuildHasher + Default,
184 {
185 use crate::mediator::{Mediator, MediatorEntry};
186
187 let iter = iter.into_iter();
188 let (min_size, _) = iter.size_hint();
189
190 let mut heap_base: Vec<HeapEntry<TPriority>> = Vec::with_capacity(min_size);
191 let mut map: Mediator<TKey, S> = Mediator::with_capacity_and_hasher(min_size, S::default());
192
193 for (key, priority) in iter {
194 match map.entry(key) {
195 MediatorEntry::Vacant(entry) => {
196 let outer_pos = entry.index();
197 unsafe {
198 entry.insert(HeapIndex(heap_base.len()));
200 }
201 heap_base.push(HeapEntry {
202 outer_pos,
203 priority,
204 });
205 }
206 MediatorEntry::Occupied(entry) => {
207 let HeapIndex(heap_pos) = entry.get_heap_idx();
208 heap_base[heap_pos].priority = priority;
209 }
210 }
211 }
212
213 let heapify_start = std::cmp::min(heap_base.len() / 2 + 2, heap_base.len());
214 let mut heap = BinaryHeap { data: heap_base };
215 for pos in (0..heapify_start).rev().map(HeapIndex) {
216 heap.heapify_down(pos, &mut |_, _| {});
217 }
218
219 for (i, pos) in heap.data.iter().map(HeapEntry::to_outer).enumerate() {
220 let heap_idx = map.get_index_mut(pos);
221 *heap_idx = HeapIndex(i);
222 }
223
224 (heap, map)
225 }
226
227 fn heapify_up<TChangeHandler: std::ops::FnMut(MediatorIndex, HeapIndex)>(
228 &mut self,
229 position: HeapIndex,
230 change_handler: &mut TChangeHandler,
231 ) {
232 debug_assert!(position < self.len(), "Out of index in heapify_up");
233 let HeapIndex(mut position) = position;
234 while position > 0 {
235 let parent_pos = (position - 1) / 2;
236 if self.data[parent_pos].priority >= self.data[position].priority {
237 break;
238 }
239 self.data.swap(parent_pos, position);
240 change_handler(self.data[position].outer_pos, HeapIndex(position));
241 position = parent_pos;
242 }
243 change_handler(self.data[position].outer_pos, HeapIndex(position));
244 }
245
246 fn heapify_down<TChangeHandler: std::ops::FnMut(MediatorIndex, HeapIndex)>(
247 &mut self,
248 position: HeapIndex,
249 change_handler: &mut TChangeHandler,
250 ) {
251 debug_assert!(position < self.len(), "Out of index in heapify_down");
252 let HeapIndex(mut position) = position;
253 loop {
254 let max_child_idx = {
255 let child1 = position * 2 + 1;
256 let child2 = child1 + 1;
257 if child1 >= self.data.len() {
258 break;
259 }
260 if child2 < self.data.len()
261 && self.data[child1].priority <= self.data[child2].priority
262 {
263 child2
264 } else {
265 child1
266 }
267 };
268
269 if self.data[position].priority >= self.data[max_child_idx].priority {
270 break;
271 }
272 self.data.swap(position, max_child_idx);
273 change_handler(self.data[position].outer_pos, HeapIndex(position));
274 position = max_child_idx;
275 }
276 change_handler(self.data[position].outer_pos, HeapIndex(position));
277 }
278}
279
280pub(crate) struct BinaryHeapIterator<'a, TPriority> {
283 inner: std::slice::Iter<'a, HeapEntry<TPriority>>,
284}
285
286impl<'a, TPriority> Iterator for BinaryHeapIterator<'a, TPriority> {
287 type Item = (MediatorIndex, &'a TPriority);
288
289 #[inline]
290 fn next(&mut self) -> Option<Self::Item> {
291 self.inner
292 .next()
293 .map(|entry: &'a HeapEntry<TPriority>| (entry.outer_pos, &entry.priority))
294 }
295
296 #[inline]
297 fn size_hint(&self) -> (usize, Option<usize>) {
298 self.inner.size_hint()
299 }
300
301 #[inline]
302 fn count(self) -> usize
303 where
304 Self: Sized,
305 {
306 self.inner.count()
307 }
308}
309
310impl<TPriority: Debug> Debug for HeapEntry<TPriority> {
313 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
314 write!(
315 f,
316 "{{outer: {:?}, priority: {:?}}}",
317 &self.outer_pos, &self.priority
318 )
319 }
320}
321
322impl<TPriority: Debug + Ord> Debug for BinaryHeap<TPriority> {
323 #[inline]
324 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
325 self.data.fmt(f)
326 }
327}
328
#[cfg(test)]
mod tests {
    use crate::mediator::Mediator;

    use super::*;
    use std::cmp::Reverse;
    use std::collections::hash_map::RandomState;
    use std::collections::{HashMap, HashSet};

    /// Returns `true` when no entry has a greater priority than its parent.
    fn is_valid_heap<TP: Ord>(heap: &BinaryHeap<TP>) -> bool {
        heap.data
            .iter()
            .enumerate()
            .skip(1)
            .all(|(i, current)| heap.data[(i - 1) / 2].priority >= current.priority)
    }

    /// Asserts that `last_positions` maps every key not yet removed to the heap
    /// slot that really holds it, and that the heap slot holds that key's priority.
    fn assert_positions_tracked(
        heap: &BinaryHeap<i32>,
        items: &[i32],
        last_positions: &HashMap<MediatorIndex, HeapIndex>,
        removed: &HashSet<MediatorIndex>,
    ) {
        assert_eq!(heap.usize_len(), last_positions.len());
        for key in (0..items.len())
            .map(MediatorIndex)
            .filter(|key| !removed.contains(key))
        {
            let position = *last_positions
                .get(&key)
                .expect("Not for all items change_handler called");
            let (outer, &priority) = heap
                .look_into(position)
                .expect("tracked position must be in range");
            assert_eq!(outer, key);
            assert_eq!(items[outer.0], priority);
        }
    }

    #[test]
    fn test_heap_fill() {
        let items = [
            70, 50, 0, 1, 2, 4, 6, 7, 9, 72, 4, 4, 87, 78, 72, 6, 7, 9, 2, -50, -72, -50, -42, -1,
            -3, -13,
        ];
        let mut maximum = i32::MIN;
        let mut heap = BinaryHeap::<i32>::with_capacity(0);
        assert!(heap.look_into(HeapIndex(0)).is_none());
        assert!(is_valid_heap(&heap), "Heap state is invalid");
        for (i, &x) in items.iter().enumerate() {
            maximum = maximum.max(x);
            heap.push(MediatorIndex(i), x, |_, _| {});
            assert!(
                is_valid_heap(&heap),
                "Heap state is invalid after pushing {}",
                x
            );
            let (_, &heap_max) = heap
                .look_into(HeapIndex(0))
                .expect("non-empty heap must have a root");
            assert_eq!(maximum, heap_max);
        }
    }

    #[test]
    fn test_change_logger() {
        let items = [
            2, 3, 21, 22, 25, 29, 36, 90, 89, 88, 87, 83, 48, 50, 52, 69, 65, 55, 73, 75, 76, -53,
            78, 81, -45, -41, 91, -34, -33, -31, -27, -22, -19, -8, -5, -3,
        ];
        let mut last_positions = HashMap::<MediatorIndex, HeapIndex>::new();
        let mut heap = BinaryHeap::<i32>::with_capacity(0);
        for (i, &x) in items.iter().enumerate() {
            heap.push(MediatorIndex(i), x, |key, position| {
                last_positions.insert(key, position);
            });
        }

        let mut removed = HashSet::<MediatorIndex>::new();
        assert_positions_tracked(&heap, &items, &last_positions, &removed);

        loop {
            let popped = heap.remove(HeapIndex(0), |key, position| {
                last_positions.insert(key, position);
            });
            let (key, _) = match popped {
                Some(entry) => entry,
                None => break,
            };
            last_positions.remove(&key);
            removed.insert(key);
            assert_positions_tracked(&heap, &items, &last_positions, &removed);
        }
        assert!(last_positions.is_empty());
    }

    #[test]
    fn test_pop() {
        let items = [
            -16, 5, 11, -1, -34, -42, -5, -6, 25, -35, 11, 35, -2, 40, 42, 40, -45, -48, 48, -38,
            -28, -33, -31, 34, -18, 25, 16, -33, -11, -6, -35, -38, 35, -41, -38, 31, -38, -23, 26,
            44, 38, 11, -49, 30, 7, 13, 12, -4, -11, -24, -49, 26, 42, 46, -25, -22, -6, -42, 28,
            45, -47, 8, 8, 21, 49, -12, -5, -33, -37, 24, -3, -26, 6, -13, 16, -40, -14, -39, -26,
            12, -44, 47, 45, -41, -22, -11, 20, 43, -44, 24, 47, 40, 43, 9, 19, 12, -17, 30, -36,
            -50, 24, -2, 1, 1, 5, -19, 21, -38, 47, 34, -14, 12, -30, 24, -2, -32, -10, 40, 34, 2,
            -33, 9, -31, -3, -15, 28, 50, -37, 35, 19, 35, 13, -2, 46, 28, 35, -40, -19, -1, -33,
            -42, -35, -12, 19, 29, 10, -31, -4, -9, 24, 15, -27, 13, 20, 15, 19, -40, -41, 40, -25,
            45, -11, -7, -19, 11, -44, -37, 35, 2, -49, 11, -37, -14, 13, 41, 10, 3, 19, -32, -12,
            -12, 33, -26, -49, -45, 24, 47, -29, -25, -45, -36, 40, 24, -29, 15, 36, 0, 47, 3, -45,
        ];

        let mut heap = BinaryHeap::<i32>::with_capacity(0);
        for (i, &x) in items.iter().enumerate() {
            heap.push(MediatorIndex(i), x, |_, _| {});
        }
        assert!(is_valid_heap(&heap), "Heap is invalid before pops");

        let mut sorted_items = items;
        sorted_items.sort_unstable_by_key(|&x| Reverse(x));
        for &x in sorted_items.iter() {
            let (rem_idx, val) = heap
                .remove(HeapIndex(0), |_, _| {})
                .expect("heap must still hold items");
            assert_eq!(val, x);
            assert_eq!(items[rem_idx.0], val);
            assert!(is_valid_heap(&heap), "Heap is invalid after {}", x);
        }

        assert_eq!(heap.remove(HeapIndex(0), |_, _| {}), None);
    }

    #[test]
    fn test_remove() {
        let mut heap = BinaryHeap::with_capacity(16);
        for i in 0..16 {
            heap.push(MediatorIndex(i), i, |_, _| {});
        }
        assert!(is_valid_heap(&heap));
        assert_eq!(heap.remove(HeapIndex(16), |_, _| {}), None);

        let mut seen: Vec<usize> = Vec::new();
        for removed_count in 1..=5usize {
            let (key, priority) = heap
                .remove(HeapIndex(5), |_, _| {})
                .expect("index 5 must be in range");
            assert_eq!(key.0, priority);
            seen.push(priority);
            assert_eq!(heap.usize_len(), 16 - removed_count);
            assert!(is_valid_heap(&heap));
        }

        // No entry may be lost or duplicated by removals from the middle.
        seen.extend(heap.iter().map(|(_, &priority)| priority));
        seen.sort_unstable();
        assert_eq!(seen, (0..16usize).collect::<Vec<usize>>());
    }

    #[test]
    fn test_change_priority() {
        let pairs = [
            (MediatorIndex(0), 0),
            (MediatorIndex(1), 1),
            (MediatorIndex(2), 2),
            (MediatorIndex(3), 3),
            (MediatorIndex(4), 4),
        ];

        let mut heap = BinaryHeap::with_capacity(0);
        for (key, priority) in pairs.iter().cloned() {
            heap.push(key, priority, |_, _| {});
        }
        assert!(is_valid_heap(&heap), "Invalid before change");
        // After pushing 0..=4 the array is [4, 3, 1, 0, 2].
        assert_eq!(heap.change_priority(HeapIndex(3), 10, |_, _| {}), 0);
        assert!(is_valid_heap(&heap), "Invalid after upping");
        assert_eq!(
            heap.look_into(HeapIndex(0)).map(|(_, &priority)| priority),
            Some(10)
        );
        // After raising index 3 to 10 the array is [10, 4, 1, 3, 2].
        assert_eq!(heap.change_priority(HeapIndex(2), -10, |_, _| {}), 1);
        assert!(is_valid_heap(&heap), "Invalid after lowering");
    }

    #[test]
    fn create_heap_hash_test() {
        let priorities = [
            16i32, 16, 5, 20, 10, 12, 10, 8, 12, 2, 20, -1, -18, 5, -16, 1, 7, 3, 17, -20, -4, 3,
            -7, -5, -8, 19, -19, -16, 3, 4, 17, 13, 3, 11, -9, 0, -10, -2, 16, 19, -12, -4, 19, 7,
            16, -19, -9, -17, 6, -16, -3, 11, -14, -15, -10, 13, 11, -14, 18, -8, -9, -4, 5, -4,
            17, 6, -16, -5, 12, 12, -3, 8, 5, -4, 7, 10, 7, -11, 18, -16, 18, 4, -15, -4, -13, 7,
            -14, -16, -18, -10, 13, -1, -9, 0, -18, -4, -13, 16, 10, -20, 19, 20, 0, -9, -7, 14,
            19, -8, -18, -1, -17, -11, 13, 12, -15, 0, -18, 6, -13, -17, -3, 18, 2, 12, 12, 4, -14,
            -11, -10, -9, 3, 14, 8, 7, 13, 13, -17, -9, -4, -19, -6, 1, 9, 5, 20, -9, -19, -20,
            -18, -8, 7,
        ];
        let (heap, key_to_pos): (_, Mediator<_, RandomState>) =
            BinaryHeap::produce_from_iter_hash(priorities.iter().cloned().map(|x| (x, x)));
        assert!(is_valid_heap(&heap), "Must be valid heap");
        for (map_idx, (key, heap_idx)) in key_to_pos.iter().enumerate() {
            assert_eq!(
                Some((MediatorIndex(map_idx), key)),
                heap.look_into(heap_idx)
            );
        }
    }

    #[test]
    fn create_heap_hash_keeps_last_priority_for_duplicate_keys() {
        let (heap, key_to_pos): (_, Mediator<_, RandomState>) =
            BinaryHeap::produce_from_iter_hash(vec![("a", 1), ("b", 2), ("a", 5)]);
        assert!(is_valid_heap(&heap), "Must be valid heap");
        assert_eq!(heap.usize_len(), 2);
        assert_eq!(
            heap.look_into(HeapIndex(0)).map(|(_, &priority)| priority),
            Some(5)
        );
        for (map_idx, (_, heap_idx)) in key_to_pos.iter().enumerate() {
            assert_eq!(
                heap.look_into(heap_idx).map(|(outer, _)| outer),
                Some(MediatorIndex(map_idx))
            );
        }
    }

    #[test]
    fn test_clear() {
        let mut heap = BinaryHeap::with_capacity(0);
        for x in 0..5 {
            heap.push(MediatorIndex(x), x, |_, _| {});
        }
        assert!(!heap.is_empty(), "Heap must be non empty");
        // Call the public `clear` method rather than clearing the backing vector directly.
        heap.clear();
        assert!(heap.is_empty(), "Heap must be empty");
        assert_eq!(heap.usize_len(), 0);
        assert_eq!(heap.remove(HeapIndex(0), |_, _| {}), None);
    }

    #[test]
    fn test_change_change_outer_pos() {
        let mut heap = BinaryHeap::with_capacity(0);
        for x in 0..5 {
            heap.push(MediatorIndex(x), x, |_, _| {});
        }
        assert_eq!(heap.look_into(HeapIndex(0)), Some((MediatorIndex(4), &4)));
        assert_eq!(
            heap.change_outer_pos(MediatorIndex(10), HeapIndex(0)),
            MediatorIndex(4)
        );
        assert_eq!(heap.look_into(HeapIndex(0)), Some((MediatorIndex(10), &4)));
    }
}