// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.
use mz_ore::task::{JoinHandle, JoinHandleExt};
use std::fmt::{Debug, Formatter};
use std::mem;
/// An order-preserving merge tree.
///
/// Parts are appended with `push`. Whenever a level of the tree fills up, the adjacent parts
/// on that level are combined by the user-supplied merge function and the result is moved one
/// level deeper.
///
/// Invariants and guarantees:
/// - Parts keep the order in which they were `push`ed.
/// - Merging keeps that order too: only adjacent parts are merged together, and the merged
///   part takes the position its inputs occupied.
/// - The tree stores at most `O(K log N)` parts at once, where `K` is the provided max len
///   and `N` is the number of parts pushed.
/// - `finish` returns at most `K` parts.
/// - The "depth" of the merge tree - the number of merges any particular part may undergo -
///   is `O(log N)`.
pub struct MergeTree<T> {
    /// Capacity of a single level, and the upper bound on what `finish` returns.
    pub(crate) max_len: usize,
    /// Parts grouped by level. Index 0 receives freshly pushed parts; the last index holds
    /// the oldest ones. Never empty.
    pub(crate) levels: Vec<Vec<T>>,
    /// Combines a run of adjacent parts into a single part.
    merge_fn: Box<dyn Fn(Vec<T>) -> T + Sync + Send>,
}

impl<T: Debug> Debug for MergeTree<T> {
    /// Prints `max_len` and `levels`; the merge function is left out and shown as `..`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("MergeTree");
        out.field("max_len", &self.max_len);
        out.field("levels", &self.levels);
        out.finish_non_exhaustive()
    }
}

impl<T> MergeTree<T> {
    /// Builds an empty merge tree.
    ///
    /// `max_len` bounds both the number of parts kept at each level and the number of parts
    /// that `Self::finish` returns. When a level would exceed that bound, `merge_fn` is called
    /// to combine adjacent parts into one.
    ///
    /// # Panics
    ///
    /// Panics if `max_len` is smaller than 2.
    pub fn new(max_len: usize, merge_fn: impl Fn(Vec<T>) -> T + Send + Sync + 'static) -> Self {
        let tree = Self {
            max_len,
            // A single empty level, so that "deepest level" is always well defined.
            levels: vec![Vec::new()],
            merge_fn: Box::new(merge_fn),
        };
        tree.assert_invariants();
        tree
    }

    /// Borrows every part in the tree, oldest first.
    #[allow(unused)]
    pub fn iter(&self) -> impl Iterator<Item = &T> + DoubleEndedIterator {
        // Deeper levels hold older parts, so walk the levels back to front.
        self.levels.iter().rev().flatten()
    }

    /// Mutably borrows every part in the tree, oldest first.
    pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut T> + DoubleEndedIterator {
        self.levels.iter_mut().rev().flatten()
    }

    /// Appends `part` after everything already in the tree, merging full levels as needed.
    pub fn push(&mut self, part: T) {
        // Every level normally holds strictly fewer than `max_len` parts. The deepest level
        // alone may sit at exactly `max_len`, which avoids a pointless merge when precisely
        // `max_len` parts are pushed in total. If it is full now, collapse it into a brand
        // new, deeper level so that every level has room again.
        match self.levels.last_mut() {
            Some(deepest) if deepest.len() == self.max_len => {
                let merged = (self.merge_fn)(mem::take(deepest));
                self.levels.push(vec![merged]);
            }
            _ => {}
        }

        // Drop the part into the shallowest level. While a level fills up, merge its contents
        // and carry the result one level deeper; the deepest level absorbs whatever reaches it.
        let deepest_idx = self.levels.len() - 1;
        let mut carry = part;
        for (depth, level) in self.levels.iter_mut().enumerate() {
            level.push(carry);
            if depth == deepest_idx || level.len() < self.max_len {
                return;
            }
            carry = (self.merge_fn)(mem::take(level));
        }
    }

    /// Consumes the tree and returns its contents, oldest first, as at most `max_len` parts.
    pub fn finish(self) -> Vec<T> {
        let Self {
            max_len,
            levels,
            merge_fn,
        } = self;

        // Fold from the shallowest level towards the deepest, carrying the newer parts along.
        let mut remaining = levels.into_iter();
        let mut carried = remaining.next().expect("non-empty level array");
        for mut deeper in remaining {
            if carried.len() + deeper.len() <= max_len {
                // Everything carried so far fits next to the older parts: move it over as-is.
                deeper.append(&mut carried);
            } else {
                // No room: squash the carried parts into one, as if their level had filled up.
                deeper.push(merge_fn(carried));
            }
            carried = deeper;
        }
        carried
    }

    /// Panics unless the structural invariants hold: `max_len` is at least 2, there is at
    /// least one level, every level but the deepest holds fewer than `max_len` parts, and the
    /// deepest holds at most `max_len`.
    pub(crate) fn assert_invariants(&self) {
        let cap = self.max_len;
        assert!(cap >= 2, "max_len must be at least 2");

        let (deepest, shallower) = self.levels.split_last().expect("non-empty level array");
        shallower.iter().enumerate().for_each(|(depth, level)| {
            assert!(
                level.len() < cap,
                "strictly less than max elements at level {depth}"
            );
        });
        assert!(deepest.len() <= cap, "at most max elements at deepest level");
    }
}
/// A value that is either still owed by a task or already in hand.
#[derive(Debug)]
pub enum Pending<T> {
    /// Holds the handle of the task that yields the value.
    Writing(JoinHandle<T>),
    /// Placeholder stored in `self` while `block_until_ready` is awaiting the task.
    Blocking,
    /// The value itself.
    Finished(T),
}

impl<T: Send + 'static> Pending<T> {
    /// Wraps a task handle; the new value starts out in the `Writing` state.
    pub fn new(handle: JoinHandle<T>) -> Self {
        Pending::Writing(handle)
    }

    /// Returns `true` only when the value is already stored in `self`.
    pub fn is_finished(&self) -> bool {
        match self {
            Pending::Finished(_) => true,
            Pending::Writing(_) | Pending::Blocking => false,
        }
    }

    /// Consumes `self` and yields the value, awaiting the task handle when there is one.
    ///
    /// # Panics
    ///
    /// Panics if `self` is in the `Blocking` state.
    /// NOTE(review): `wait_and_assert_finished` is defined outside this file and, going by its
    /// name, may panic as well - confirm against `mz_ore::task`.
    pub async fn into_result(self) -> T {
        match self {
            Self::Finished(value) => value,
            Self::Writing(handle) => handle.wait_and_assert_finished().await,
            Self::Blocking => panic!("block_until_ready cancelled?"),
        }
    }

    /// Waits for the value and stores it in place, leaving `self` as `Finished`.
    ///
    /// `self` is switched to `Blocking` before the wait starts, so if the returned future is
    /// dropped before it completes, `self` stays `Blocking` and a later `into_result` panics.
    pub async fn block_until_ready(&mut self) {
        // Take ownership of the current state; `into_result` needs it by value.
        let taken = mem::replace(self, Pending::Blocking);
        *self = Self::Finished(taken.into_result().await);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the merge tree through every small capacity and push count, checking after each
    /// push that no value was lost or reordered and that the invariants still hold.
    #[mz_ore::test]
    fn test_merge_tree() {
        // Merge sequences by concatenation, so every pushed value stays visible in the output.
        let concat = |vals: Vec<Vec<usize>>| -> Vec<usize> { vals.into_iter().flatten().collect() };

        // Exhaustively test the merge tree for small sizes.
        for max_len in 2..8 {
            for items in 0..100 {
                let mut tree = MergeTree::new(max_len, concat);
                for pushed in 0..items {
                    tree.push(vec![pushed]);
                    let seen = tree.iter().flatten().copied();
                    assert!(seen.eq(0..=pushed), "no parts should be lost");
                    tree.assert_invariants();
                }

                let finished = tree.finish();
                assert!(
                    finished.len() <= max_len,
                    "no more than {max_len} finished parts"
                );
                let flattened = finished.into_iter().flatten();
                assert!(flattened.eq(0..items), "no parts lost");
            }
        }
    }
}