mz_persist_client/internal/
merge.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use mz_ore::task::{JoinHandle, JoinHandleExt};
11use std::fmt::{Debug, Formatter};
12use std::mem;
13use std::ops::{Deref, DerefMut};
14
/// A merge tree: an order-preserving buffer that bounds its size by merging adjacent elements.
///
/// Invariants and guarantees:
/// - Elements are kept in the order in which they were `push`ed.
/// - Merging preserves that order too: only adjacent elements are ever merged together, and the
///   merged result occupies the position its inputs did.
/// - With `K` the configured max len and `N` the number of elements pushed, the tree holds at
///   most `O(K log N)` elements at any one time.
/// - `finish` returns at most `K` elements.
/// - Any particular element passes through at most `O(log N)` merges (the "depth" of the tree).
pub struct MergeTree<T> {
    /// Configuration: the largest size any single level of the tree may grow to.
    max_level_len: usize,
    /// Element count of every level, shallowest level first and deepest level last.
    level_lens: Vec<usize>,
    /// All elements of the tree, flattened: the earliest / deepest level comes first and the
    /// newest / shallowest level comes last.
    data: Vec<T>,
    /// Combines a run of adjacent elements into a single element.
    merge_fn: Box<dyn Fn(Vec<T>) -> T + Send + Sync>,
}

impl<T: Debug> Debug for MergeTree<T> {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Destructure exhaustively so that adding a field forces this impl to be revisited.
        // `merge_fn` is an opaque closure, so it is left out and the output is marked
        // non-exhaustive instead.
        let MergeTree {
            max_level_len,
            level_lens,
            data,
            merge_fn: _,
        } = self;
        let mut out = f.debug_struct("MergeTree");
        out.field("max_level_len", max_level_len)
            .field("level_lens", level_lens)
            .field("data", data);
        out.finish_non_exhaustive()
    }
}

impl<T> MergeTree<T> {
    /// Builds an empty merge tree.
    ///
    /// `max_len` caps both how many parts each level of the tree may hold and how many parts
    /// [`Self::finish`] may return. Whenever a cap would be exceeded, `merge_fn` is invoked to
    /// collapse a run of adjacent parts into a single part.
    ///
    /// # Panics
    ///
    /// Panics if `max_len` is less than 2.
    pub fn new(max_len: usize, merge_fn: impl Fn(Vec<T>) -> T + Send + Sync + 'static) -> Self {
        let tree = MergeTree {
            max_level_len: max_len,
            level_lens: vec![0],
            data: Vec::new(),
            merge_fn: Box::new(merge_fn),
        };
        tree.assert_invariants();
        tree
    }

    /// Replaces the final `count` elements of `data` with the single result of `merge_fn`.
    fn merge_last(&mut self, count: usize) {
        let start = self.data.len() - count;
        let run = self.data.split_off(start);
        let merged = (self.merge_fn)(run);
        self.data.push(merged);
    }

    /// Appends `part` to the end of the tree, merging adjacent parts as needed to stay in bounds.
    pub fn push(&mut self, part: T) {
        let max = self.max_level_len;

        // Shallower levels always stay strictly below `max`, but the deepest level may sit at
        // exactly `max`. Postponing its merge until another part arrives saves an unnecessary
        // merge in some cases, e.g. when precisely `max` parts are pushed in total. If it is
        // full now, merge it into one part that seeds a brand-new, deeper level.
        if let Some(deepest) = self.level_lens.last_mut().filter(|len| **len == max) {
            let full_len = mem::replace(deepest, 0);
            self.merge_last(full_len);
            self.level_lens.push(1);
        }

        // Every level has room now. Append the new part, then carry upward: a level that fills
        // up is merged into a single part, which counts toward the next-deeper level. The walk
        // stops at the first level with spare room, or at the deepest level.
        self.data.push(part);
        let deepest = self.level_lens.len() - 1;
        let mut depth = 0;
        loop {
            let len = &mut self.level_lens[depth];
            *len += 1;
            let filled = *len;
            if filled < max || depth == deepest {
                break;
            }
            *len = 0;
            self.merge_last(filled);
            depth += 1;
        }
    }

    /// Consumes the tree and returns its contents flattened into at most `max_len` parts.
    pub fn finish(mut self) -> Vec<T> {
        // Walk the levels from shallowest to deepest, i.e. back-to-front through `data`, while
        // tracking how many trailing parts have been gathered so far.
        let levels = mem::take(&mut self.level_lens);
        let mut tail = 0;
        for len in levels {
            tail = if tail + len > self.max_level_len {
                // Absorbing this level would overshoot the limit. Merge the gathered tail into
                // one part, which now directly follows this level's parts.
                self.merge_last(tail);
                len + 1
            } else {
                // This level fits alongside the tail without merging anything.
                tail + len
            };
        }
        assert!(self.data.len() <= self.max_level_len);
        self.data
    }

    /// Panics with a descriptive message if any structural invariant of the tree is violated.
    pub(crate) fn assert_invariants(&self) {
        assert!(self.max_level_len >= 2, "max_len must be at least 2");

        let total: usize = self.level_lens.iter().sum();
        assert_eq!(
            self.data.len(),
            total,
            "level sizes should sum to overall len"
        );

        let (deepest_len, shallow) = self.level_lens.split_last().expect("non-empty level array");
        // Every level but the deepest must have room for at least one more part.
        for (depth, level_len) in shallow.iter().enumerate() {
            assert!(
                *level_len < self.max_level_len,
                "strictly less than max elements at level {depth}"
            );
        }
        // The deepest level is allowed to be exactly full.
        assert!(
            *deepest_len <= self.max_level_len,
            "at most max elements at deepest level"
        );
    }
}

/// Read access to the current parts, earliest first.
impl<T> Deref for MergeTree<T> {
    type Target = [T];

    fn deref(&self) -> &[T] {
        self.data.as_slice()
    }
}

/// In-place access to the current parts; the number of parts cannot change through this view.
impl<T> DerefMut for MergeTree<T> {
    fn deref_mut(&mut self) -> &mut [T] {
        self.data.as_mut_slice()
    }
}
161
162/// Either a handle to a task that returns a value or the value itself.
163#[derive(Debug)]
164pub enum Pending<T> {
165    Writing(JoinHandle<T>),
166    Blocking,
167    Finished(T),
168}
169
170impl<T: Send + 'static> Pending<T> {
171    pub fn new(handle: JoinHandle<T>) -> Self {
172        Self::Writing(handle)
173    }
174
175    pub fn is_finished(&self) -> bool {
176        matches!(self, Self::Finished(_))
177    }
178
179    pub async fn into_result(self) -> T {
180        match self {
181            Pending::Writing(h) => h.wait_and_assert_finished().await,
182            Pending::Blocking => panic!("block_until_ready cancelled?"),
183            Pending::Finished(t) => t,
184        }
185    }
186
187    pub async fn block_until_ready(&mut self) {
188        let pending = mem::replace(self, Self::Blocking);
189        let value = pending.into_result().await;
190        *self = Pending::Finished(value);
191    }
192}
193
#[cfg(test)]
mod tests {
    use super::*;
    use mz_ore::cast::CastLossy;

    /// Test payload: the pushed integers it covers, plus how many rounds of merging produced it.
    struct Value {
        merge_depth: usize,
        elements: Vec<i64>,
    }

    /// Merge function under test. It merges sequences by concatenation, in order, and records
    /// one more merge round than the deepest of its inputs.
    fn merge_values(vals: Vec<Value>) -> Value {
        let merge_depth = 1 + vals.iter().map(|v| v.merge_depth).max().unwrap_or(0);
        let elements = vals.into_iter().flat_map(|v| v.elements).collect();
        Value {
            merge_depth,
            elements,
        }
    }

    /// Iterates over every integer held by `values`, in order.
    fn flattened(values: &[Value]) -> impl Iterator<Item = i64> + '_ {
        values.iter().flat_map(|v| v.elements.iter().copied())
    }

    /// Pushes `0..items` into a fresh tree with the given `max_len`, checking the tree after
    /// every push, and then checks what `finish` returns.
    fn check_tree(max_len: usize, items: i64) {
        let mut merge_tree = MergeTree::new(max_len, merge_values);
        for i in 0..items {
            merge_tree.push(Value {
                merge_depth: 0,
                elements: vec![i],
            });
            assert!(flattened(&merge_tree).eq(0..=i), "no parts should be lost");
            merge_tree.assert_invariants();
        }

        let parts = merge_tree.finish();
        assert!(
            parts.len() <= max_len,
            "no more than {max_len} finished parts"
        );

        // The merged tree should be "balanced". A binary tree holding 2^N elements should have
        // depth N, and more generally a K-ary tree holding K^N elements should have depth N. In
        // other words, a tree with N elements should have depth log_K N.
        let expected_merge_depth =
            usize::cast_lossy(f64::cast_lossy(items).log(f64::cast_lossy(max_len)).floor());
        for part in &parts {
            assert!(
                part.merge_depth <= expected_merge_depth,
                "expected at most {expected_merge_depth} merges for a tree \
                with max len {max_len} and {items} elements, but got {}",
                part.merge_depth
            );
        }
        assert!(flattened(&parts).eq(0..items), "no parts lost");
    }

    #[mz_ore::test]
    #[cfg_attr(miri, ignore)] // too slow
    fn test_merge_tree() {
        // Exhaustively test the merge tree for small sizes.
        for max_len in 2..8 {
            for items in 0..100 {
                check_tree(max_len, items);
            }
        }
    }
}