1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
910use mz_ore::task::{JoinHandle, JoinHandleExt};
11use std::fmt::{Debug, Formatter};
12use std::mem;
13use std::ops::{Deref, DerefMut};
/// A merge tree.
///
/// Invariants and guarantees:
/// - This structure preserves the order in which elements are `push`ed.
/// - Merging also preserves order: only adjacent elements will be merged together,
///   and the result will have the same place in the ordering as the input did.
/// - The tree will store at most `O(K log N)` elements at once, where `K` is the provided max len
///   and `N` is the number of elements pushed.
/// - `finish` will return at most `K` elements.
/// - The "depth" of the merge tree - the number of merges any particular element may undergo -
///   is `O(log N)`.
pub struct MergeTree<T> {
    /// Configuration: the largest any level in the tree is allowed to grow.
    max_level_len: usize,
    /// How many parts each level currently holds, indexed from the shallowest level (0) to the
    /// deepest. Starts as `[0]` and only ever grows, so it is never empty while the tree is live.
    level_lens: Vec<usize>,
    /// All parts held by the tree, flattened. Parts run from earliest / deepest at the front to
    /// newest / shallowest at the back, so the shallowest level is always the suffix.
    data: Vec<T>,
    /// Combines a run of adjacent parts into a single part.
    merge_fn: Box<dyn Fn(Vec<T>) -> T + Sync + Send>,
}

impl<T: Debug> Debug for MergeTree<T> {
    /// Prints the configuration, level sizes and parts. The merge function has no useful
    /// representation, so the output is marked non-exhaustive.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field to the struct breaks this pattern, which
        // forces this impl to be revisited.
        let MergeTree {
            max_level_len,
            level_lens,
            data,
            merge_fn: _,
        } = self;
        let mut out = f.debug_struct("MergeTree");
        out.field("max_level_len", max_level_len)
            .field("level_lens", level_lens)
            .field("data", data);
        out.finish_non_exhaustive()
    }
}

impl<T> MergeTree<T> {
    /// Creates an empty merge tree.
    ///
    /// `max_len` bounds both the number of parts kept at any one level of the tree and the
    /// number of parts [`Self::finish`] returns. When that bound would be exceeded, `merge_fn`
    /// is called to combine a run of adjacent parts into one.
    ///
    /// # Panics
    ///
    /// Panics if `max_len` is less than 2.
    pub fn new(max_len: usize, merge_fn: impl Fn(Vec<T>) -> T + Send + Sync + 'static) -> Self {
        let tree = Self {
            max_level_len: max_len,
            level_lens: vec![0],
            data: Vec::new(),
            merge_fn: Box::new(merge_fn),
        };
        tree.assert_invariants();
        tree
    }

    /// Replaces the final `count` parts of `data` with the single part `merge_fn` builds from
    /// them.
    fn merge_last(&mut self, count: usize) {
        let tail = self.data.split_off(self.data.len() - count);
        let merged = (self.merge_fn)(tail);
        self.data.push(merged);
    }

    /// Appends `part` to the end of the tree, merging parts as needed to keep levels bounded.
    pub fn push(&mut self, part: T) {
        let max = self.max_level_len;

        // Every level except the deepest holds strictly fewer than `max` parts. The deepest
        // level may hold exactly `max`, which avoids a merge when exactly `max` parts are
        // pushed. If it is full now, collapse it into one part that forms a new, deeper level.
        if let Some(deepest) = self.level_lens.last_mut().filter(|len| **len == max) {
            let full = mem::replace(deepest, 0);
            self.merge_last(full);
            self.level_lens.push(1);
        }

        // Now every level has room. Add the part to the shallowest level, then cascade
        // downward: each non-deepest level that fills up is merged into a single part, and that
        // part is counted at the next level. Stop once a level still has room or the deepest
        // level has been reached.
        self.data.push(part);
        let last_depth = self.level_lens.len() - 1;
        let mut depth = 0;
        loop {
            self.level_lens[depth] += 1;
            let len = self.level_lens[depth];
            if len < max || depth == last_depth {
                break;
            }
            self.level_lens[depth] = 0;
            self.merge_last(len);
            depth += 1;
        }
    }

    /// Consumes the tree and returns its parts, flattened into at most `max_len` parts.
    ///
    /// Levels are folded in starting from the shallowest (the end of `data`). Parts accumulate
    /// into a trailing group while the group fits within `max_len`. When the next level would
    /// push the group over that limit, the group is merged into a single part, and that part
    /// starts the next group along with the level's parts.
    pub fn finish(mut self) -> Vec<T> {
        let levels = mem::take(&mut self.level_lens);
        // Number of trailing parts of `data` grouped together but not yet merged.
        let mut pending = 0;
        for len in levels {
            if pending + len > self.max_level_len {
                self.merge_last(pending);
                pending = 1 + len;
            } else {
                pending += len;
            }
        }
        assert!(self.data.len() <= self.max_level_len);
        self.data
    }

    /// Panics if the tree's structural invariants do not hold:
    /// - `max_len` is at least 2;
    /// - the level sizes add up to the number of stored parts;
    /// - every level but the deepest has fewer than `max_len` parts;
    /// - the deepest level has at most `max_len` parts.
    pub(crate) fn assert_invariants(&self) {
        let max = self.max_level_len;
        assert!(max >= 2, "max_len must be at least 2");

        let total: usize = self.level_lens.iter().sum();
        assert_eq!(
            self.data.len(),
            total,
            "level sizes should sum to overall len"
        );

        let (deepest_len, shallower) = self
            .level_lens
            .split_last()
            .expect("non-empty level array");
        shallower.iter().enumerate().for_each(|(depth, len)| {
            assert!(
                *len < max,
                "strictly less than max elements at level {depth}"
            );
        });
        assert!(*deepest_len <= max, "at most max elements at deepest level");
    }
}

impl<T> Deref for MergeTree<T> {
    type Target = [T];

    /// Views the tree's current parts, from earliest to newest, as a slice.
    fn deref(&self) -> &Self::Target {
        self.data.as_slice()
    }
}

impl<T> DerefMut for MergeTree<T> {
    /// Mutable slice view of the current parts. Individual parts can be modified, but a slice
    /// cannot change the number of parts, so the level bookkeeping stays valid.
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.data.as_mut_slice()
    }
}
161162/// Either a handle to a task that returns a value or the value itself.
163#[derive(Debug)]
164pub enum Pending<T> {
165 Writing(JoinHandle<T>),
166 Blocking,
167 Finished(T),
168}
169170impl<T: Send + 'static> Pending<T> {
171pub fn new(handle: JoinHandle<T>) -> Self {
172Self::Writing(handle)
173 }
174175pub fn is_finished(&self) -> bool {
176matches!(self, Self::Finished(_))
177 }
178179pub async fn into_result(self) -> T {
180match self {
181 Pending::Writing(h) => h.wait_and_assert_finished().await,
182 Pending::Blocking => panic!("block_until_ready cancelled?"),
183 Pending::Finished(t) => t,
184 }
185 }
186187pub async fn block_until_ready(&mut self) {
188let pending = mem::replace(self, Self::Blocking);
189let value = pending.into_result().await;
190*self = Pending::Finished(value);
191 }
192}
#[cfg(test)]
mod tests {
    use super::*;
    use mz_ore::cast::CastLossy;

    /// A test part: the pushed values it contains, plus the length of the longest chain of
    /// merges that produced it.
    struct Part {
        merge_depth: usize,
        elements: Vec<i64>,
    }

    /// Merge function under test: concatenates the inputs in order. The result is one merge
    /// deeper than the deepest input.
    fn concat_parts(parts: Vec<Part>) -> Part {
        let merge_depth = 1 + parts.iter().map(|p| p.merge_depth).max().unwrap_or(0);
        let elements = parts.into_iter().flat_map(|p| p.elements).collect();
        Part {
            merge_depth,
            elements,
        }
    }

    /// Yields every pushed value held by `parts`, in order.
    fn flatten(parts: &[Part]) -> impl Iterator<Item = i64> + '_ {
        parts.iter().flat_map(|p| p.elements.iter().copied())
    }

    /// Pushes `items` single-value parts into a fresh tree with the given `max_len`. Checks
    /// the tree's invariants after every push, then checks the shape of what `finish` returns.
    fn check_tree(max_len: usize, items: i64) {
        let mut merge_tree = MergeTree::new(max_len, concat_parts);
        for i in 0..items {
            merge_tree.push(Part {
                merge_depth: 0,
                elements: vec![i],
            });
            assert!(flatten(&merge_tree).eq(0..=i), "no parts should be lost");
            merge_tree.assert_invariants();
        }

        let parts = merge_tree.finish();
        assert!(
            parts.len() <= max_len,
            "no more than {max_len} finished parts"
        );

        // The merged tree should be "balanced". A binary tree holding 2^N elements should have
        // depth N, and more generally a K-ary tree holding K^N elements should have depth N.
        // So a tree holding `items` elements should have depth at most floor(log_K(items)).
        let expected_merge_depth =
            usize::cast_lossy(f64::cast_lossy(items).log(f64::cast_lossy(max_len)).floor());
        for part in &parts {
            assert!(
                part.merge_depth <= expected_merge_depth,
                "expected at most {expected_merge_depth} merges for a tree \
                 with max len {max_len} and {items} elements, but got {}",
                part.merge_depth
            );
        }
        assert!(flatten(&parts).eq(0..items), "no parts lost");
    }

    #[mz_ore::test]
    #[cfg_attr(miri, ignore)] // too slow
    fn test_merge_tree() {
        // Exhaustively test the merge tree for small sizes.
        for max_len in 2..8 {
            for items in 0..100 {
                check_tree(max_len, items);
            }
        }
    }
}