mz_persist_client/internal/
merge.rs1use mz_ore::task::{JoinHandle, JoinHandleExt};
11use std::fmt::{Debug, Formatter};
12use std::mem;
13use std::ops::{Deref, DerefMut};
14
/// A bounded-fanout merge tree over an append-only sequence of values.
///
/// Values are appended with [`MergeTree::push`]. Whenever a non-deepest level
/// accumulates `max_level_len` entries, those entries are combined into one via
/// the merge function and carried into the next-deeper level. When the deepest
/// level is full at the time of a push, it is collapsed into a single entry that
/// starts a new, deeper level. [`MergeTree::finish`] collapses whatever remains
/// into at most `max_level_len` entries, preserving order.
pub struct MergeTree<T> {
    /// The most entries any single level may hold (must be at least 2).
    max_level_len: usize,
    /// Entry count per level, shallowest first. Level 0 holds the newest
    /// entries (the end of `data`); the last level holds the oldest (the front).
    level_lens: Vec<usize>,
    /// Every entry in the tree, flattened, oldest first.
    data: Vec<T>,
    /// Combines a run of adjacent entries (in order) into a single entry.
    /// Stored as a boxed trait object, so `MergeTree<T>` carries no type
    /// parameter for the closure type.
    merge_fn: Box<dyn Fn(Vec<T>) -> T + Sync + Send>,
}

impl<T: Debug> Debug for MergeTree<T> {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // `merge_fn` is an opaque closure, so it is left out and the output is
        // marked non-exhaustive.
        f.debug_struct("MergeTree")
            .field("max_level_len", &self.max_level_len)
            .field("level_lens", &self.level_lens)
            .field("data", &self.data)
            .finish_non_exhaustive()
    }
}

impl<T> MergeTree<T> {
    /// Creates an empty tree in which each level holds at most `max_len` entries.
    ///
    /// # Panics
    ///
    /// Panics if `max_len` is less than 2.
    pub fn new(max_len: usize, merge_fn: impl Fn(Vec<T>) -> T + Send + Sync + 'static) -> Self {
        let tree = Self {
            max_level_len: max_len,
            level_lens: vec![0],
            data: Vec::new(),
            merge_fn: Box::new(merge_fn),
        };
        tree.assert_invariants();
        tree
    }

    /// Replaces the final `count` entries of `data` with their merged result.
    fn merge_last(&mut self, count: usize) {
        let start = self.data.len() - count;
        let tail: Vec<T> = self.data.drain(start..).collect();
        let merged = (self.merge_fn)(tail);
        self.data.push(merged);
    }

    /// Appends `part` as the newest entry, merging any levels that fill up.
    pub fn push(&mut self, part: T) {
        let cap = self.max_level_len;

        // A full deepest level is collapsed into one entry that seeds a new,
        // deeper level. The deepest level only grows right after every shallower
        // level has been merged away, so here it is exactly the last `cap`
        // entries of `data`.
        if self.level_lens.last() == Some(&cap) {
            let last = self.level_lens.len() - 1;
            self.level_lens[last] = 0;
            self.merge_last(cap);
            self.level_lens.push(1);
        }

        self.data.push(part);

        // Count the new entry at level 0, then cascade: a non-deepest level that
        // reaches `cap` is merged into one entry counted at the next level down.
        let deepest = self.level_lens.len() - 1;
        let mut depth = 0;
        loop {
            self.level_lens[depth] += 1;
            let filled = self.level_lens[depth];
            if filled < cap || depth == deepest {
                break;
            }
            self.level_lens[depth] = 0;
            self.merge_last(filled);
            depth += 1;
        }
    }

    /// Consumes the tree and returns its entries, in order, collapsed into at
    /// most `max_level_len` parts.
    pub fn finish(mut self) -> Vec<T> {
        let cap = self.max_level_len;
        let levels = mem::take(&mut self.level_lens);

        // `run` counts the entries at the end of `data` that have been
        // accounted for so far, walking levels from shallowest to deepest.
        let mut run = 0;
        for len in levels {
            if run + len > cap {
                // Folding the run into one entry makes room for this level.
                self.merge_last(run);
                run = len + 1;
            } else {
                run += len;
            }
        }

        assert!(self.data.len() <= self.max_level_len);
        self.data
    }

    /// Panics if the tree's structural invariants do not hold.
    pub(crate) fn assert_invariants(&self) {
        assert!(self.max_level_len >= 2, "max_len must be at least 2");

        let total: usize = self.level_lens.iter().sum();
        assert_eq!(
            self.data.len(),
            total,
            "level sizes should sum to overall len"
        );

        let (deepest_len, shallow) = self
            .level_lens
            .split_last()
            .expect("non-empty level array");
        for (depth, &len) in shallow.iter().enumerate() {
            assert!(
                len < self.max_level_len,
                "strictly less than max elements at level {depth}"
            );
        }
        assert!(
            *deepest_len <= self.max_level_len,
            "at most max elements at deepest level"
        );
    }
}

impl<T> Deref for MergeTree<T> {
    type Target = [T];

    /// Exposes the current entries, oldest first.
    fn deref(&self) -> &Self::Target {
        self.data.as_slice()
    }
}

impl<T> DerefMut for MergeTree<T> {
    /// Mutable access to the current entries; the slice cannot change length,
    /// so the level bookkeeping stays valid.
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.data.as_mut_slice()
    }
}
161
162#[derive(Debug)]
164pub enum Pending<T> {
165 Writing(JoinHandle<T>),
166 Blocking,
167 Finished(T),
168}
169
170impl<T: Send + 'static> Pending<T> {
171 pub fn new(handle: JoinHandle<T>) -> Self {
172 Self::Writing(handle)
173 }
174
175 pub fn is_finished(&self) -> bool {
176 matches!(self, Self::Finished(_))
177 }
178
179 pub async fn into_result(self) -> T {
180 match self {
181 Pending::Writing(h) => h.wait_and_assert_finished().await,
182 Pending::Blocking => panic!("block_until_ready cancelled?"),
183 Pending::Finished(t) => t,
184 }
185 }
186
187 pub async fn block_until_ready(&mut self) {
188 let pending = mem::replace(self, Self::Blocking);
189 let value = pending.into_result().await;
190 *self = Pending::Finished(value);
191 }
192}
193
#[cfg(test)]
mod tests {
    use super::*;
    use mz_ore::cast::CastLossy;

    #[mz_ore::test]
    #[cfg_attr(miri, ignore)]
    fn test_merge_tree() {
        // Test payload: the original elements it covers, in order, plus the
        // length of the longest chain of merges that produced it.
        struct Value {
            merge_depth: usize,
            elements: Vec<i64>,
        }

        // Concatenates the inputs' elements and records one more level of merging.
        fn merge_values(vals: Vec<Value>) -> Value {
            let deepest = vals.iter().map(|v| v.merge_depth).max().unwrap_or(0);
            Value {
                merge_depth: deepest + 1,
                elements: vals.into_iter().flat_map(|v| v.elements).collect(),
            }
        }

        for max_len in 2..8 {
            for items in 0..100 {
                let mut tree = MergeTree::new(max_len, merge_values);
                for i in 0..items {
                    tree.push(Value {
                        merge_depth: 0,
                        elements: vec![i],
                    });
                    let intact = tree
                        .iter()
                        .flat_map(|v| v.elements.iter())
                        .copied()
                        .eq(0..=i);
                    assert!(intact, "no parts should be lost");
                    tree.assert_invariants();
                }

                let parts = tree.finish();
                assert!(
                    parts.len() <= max_len,
                    "no more than {max_len} finished parts"
                );

                // Merge depth is expected to be at most log base `max_len` of the
                // number of items pushed.
                let log_items = f64::cast_lossy(items).log(f64::cast_lossy(max_len));
                let expected_merge_depth = usize::cast_lossy(log_items.floor());
                for part in parts.iter() {
                    assert!(
                        part.merge_depth <= expected_merge_depth,
                        "expected at most {expected_merge_depth} merges for a tree \
                         with max len {max_len} and {items} elements, but got {}",
                        part.merge_depth
                    );
                }

                let flattened = parts.iter().flat_map(|v| v.elements.iter()).copied();
                assert!(flattened.eq(0..items), "no parts lost");
            }
        }
    }
}