mz_transform/reduction_pushdown.rs
1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Tries to convert a reduce around a join to a join of reduces.
11//! Also absorbs Map operators into Reduce operators.
12//!
13//! In a traditional DB, this transformation has a potential benefit of reducing
14//! the size of the join. In our streaming system built on top of Timely
//! Dataflow and Differential Dataflow, there are three other potential benefits:
16//! 1) Reducing data skew in the arrangements constructed for a join.
17//! 2) The join can potentially reuse the final arrangement constructed for the
18//! reduce and not have to construct its own arrangement.
19//! 3) Reducing the frequency with which we have to recalculate the result of a join.
20//!
21//! Suppose there are two inputs R and S being joined. According to
22//! [Galindo-Legaria and Joshi (2001)](http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.563.8492&rep=rep1&type=pdf),
23//! a full reduction pushdown to R can be done if and only if:
24//! 1) Columns from R involved in join constraints are a subset of the group by keys.
25//! 2) The key of S is a subset of the group by keys.
26//! 3) The columns involved in the aggregation all belong to R.
27//!
28//! In our current implementation:
29//! * We abide by condition 1 to the letter.
30//! * We work around condition 2 by rewriting the reduce around a join of R to
31//! S with an equivalent relational expression involving a join of R to
32//! ```ignore
33//! select <columns involved in join constraints>, count(true)
34//! from S
35//! group by <columns involved in join constraints>
36//! ```
37//! * TODO: We work around condition 3 in some cases by noting that `sum(R.a * S.a)`
38//! is equivalent to `sum(R.a) * sum(S.a)`.
39//!
40//! Full documentation with examples can be found
41//! [here](https://docs.google.com/document/d/1xrBJGGDkkiGBKRSNYR2W-nKba96ZOdC2mVbLqMLjJY0/edit)
42//!
43//! The current implementation is chosen so that reduction pushdown kicks in
//! only in the subset of cases most likely to help users. In the future, we
45//! may allow the user to toggle the aggressiveness of reduction pushdown. A
46//! more aggressive reduction pushdown implementation may, for example, try to
47//! work around condition 1 by pushing down an inner reduce through the join
48//! while retaining the original outer reduce.
49
50use std::collections::{BTreeMap, BTreeSet};
51use std::iter::FromIterator;
52
53use mz_expr::visit::Visit;
54use mz_expr::{AggregateExpr, Columns, JoinInputMapper, MirRelationExpr, MirScalarExpr};
55
56use crate::TransformCtx;
57use crate::analysis::equivalences::EquivalenceClasses;
58
59/// Pushes Reduce operators toward sources.
60#[derive(Debug)]
61pub struct ReductionPushdown;
62
63impl crate::Transform for ReductionPushdown {
64 fn name(&self) -> &'static str {
65 "ReductionPushdown"
66 }
67
68 #[mz_ore::instrument(
69 target = "optimizer",
70 level = "debug",
71 fields(path.segment = "reduction_pushdown")
72 )]
73 fn actually_perform_transform(
74 &self,
75 relation: &mut MirRelationExpr,
76 _: &mut TransformCtx,
77 ) -> Result<(), crate::TransformError> {
78 // `try_visit_mut_pre` is used here because after pushing down a reduction,
79 // we want to see if we can push the same reduction further down.
80 relation.visit_mut_pre(&mut |e| self.action(e));
81 mz_repr::explain::trace_plan(&*relation);
82 Ok(())
83 }
84}
85
86impl ReductionPushdown {
87 /// Pushes Reduce operators toward sources.
88 ///
89 /// A join can be thought of as a multigraph where vertices are inputs and
90 /// edges are join constraints. After removing constraints containing a
91 /// GroupBy, the reduce will be pushed down to all connected components. If
92 /// there is only one connected component, this method is a no-op.
93 pub fn action(&self, relation: &mut MirRelationExpr) {
94 if let MirRelationExpr::Reduce {
95 input,
96 group_key,
97 aggregates,
98 monotonic,
99 expected_group_size,
100 } = relation
101 {
102 // Map expressions can be absorbed into the Reduce at no cost.
103 if let MirRelationExpr::Map {
104 input: inner,
105 scalars,
106 } = &mut **input
107 {
108 let arity = inner.arity();
109
110 // Normalize the scalars to not be self-referential.
111 let mut scalars = scalars.clone();
112 for index in 0..scalars.len() {
113 let (lower, upper) = scalars.split_at_mut(index);
114 upper[0].visit_mut_post(&mut |e| {
115 if let mz_expr::MirScalarExpr::Column(c, _) = e {
116 if *c >= arity {
117 *e = lower[*c - arity].clone();
118 }
119 }
120 });
121 }
122 for key in group_key.iter_mut() {
123 key.visit_mut_post(&mut |e| {
124 if let mz_expr::MirScalarExpr::Column(c, _) = e {
125 if *c >= arity {
126 *e = scalars[*c - arity].clone();
127 }
128 }
129 });
130 }
131 for agg in aggregates.iter_mut() {
132 agg.expr.visit_mut_post(&mut |e| {
133 if let mz_expr::MirScalarExpr::Column(c, _) = e {
134 if *c >= arity {
135 *e = scalars[*c - arity].clone();
136 }
137 }
138 });
139 }
140
141 **input = inner.take_dangerous()
142 }
143 if let MirRelationExpr::Join {
144 inputs,
145 equivalences,
146 implementation: _,
147 } = &mut **input
148 {
149 if let Some(new_relation_expr) = try_push_reduce_through_join(
150 inputs,
151 equivalences,
152 group_key,
153 aggregates,
154 *monotonic,
155 *expected_group_size,
156 ) {
157 *relation = new_relation_expr;
158 }
159 }
160 }
161 }
162}
163
164fn try_push_reduce_through_join(
165 inputs: &Vec<MirRelationExpr>,
166 equivalences: &Vec<Vec<MirScalarExpr>>,
167 group_key: &Vec<MirScalarExpr>,
168 aggregates: &Vec<AggregateExpr>,
169 monotonic: bool,
170 expected_group_size: Option<u64>,
171) -> Option<MirRelationExpr> {
172 // Variable name details:
173 // The goal is to turn `old` (`Reduce { Join { <inputs> }}`) into
174 // `new`, which looks like:
175 // ```
176 // Project {
177 // Join {
178 // Reduce { <component> }, ... , Reduce { <component> }
179 // }
180 // }
181 // ```
182 //
183 // `<component>` is either `Join {<subset of inputs>}` or
184 // `<element of inputs>`.
185
186 // 0) Make sure that `equivalences` is a proper equivalence relation. Later, in 3a)/i), we'll
187 // rely on expressions appearing in at most one equivalence class.
188 let mut eq_classes = EquivalenceClasses::default();
189 eq_classes.classes = equivalences.clone();
190 eq_classes.minimize(None);
191 let equivalences = eq_classes.classes;
192
193 let old_join_mapper = JoinInputMapper::new(inputs.as_slice());
194 // 1) Partition the join constraints into constraints containing a group
195 // key and constraints that don't.
196 let (new_join_equivalences, component_equivalences): (Vec<_>, Vec<_>) = equivalences
197 .iter()
198 .cloned()
199 .partition(|cls| cls.iter().any(|expr| group_key.contains(expr)));
200
201 // 2) Find the connected components that remain after removing constraints
202 // containing the group_key. Also, track the set of constraints that
203 // connect the inputs in each component.
204 let mut components = (0..inputs.len()).map(Component::new).collect::<Vec<_>>();
205 for equivalence in component_equivalences {
206 // a) Find the inputs referenced by the constraint.
207 let inputs_to_connect = BTreeSet::<usize>::from_iter(
208 equivalence
209 .iter()
210 .flat_map(|expr| old_join_mapper.lookup_inputs(expr)),
211 );
212 // b) Extract the set of components that covers the inputs.
213 let (mut components_to_connect, other): (Vec<_>, Vec<_>) = components
214 .into_iter()
215 .partition(|c| c.inputs.iter().any(|i| inputs_to_connect.contains(i)));
216 components = other;
217 // c) Connect the components and push the result back into the list of
218 // components.
219 if let Some(mut connected_component) = components_to_connect.pop() {
220 connected_component.connect(components_to_connect, equivalence);
221 components.push(connected_component);
222 }
223 // d) Abort reduction pushdown if there are less than two connected components.
224 if components.len() < 2 {
225 return None;
226 }
227 }
228 components.sort();
229 // TODO: Connect components referenced by the same multi-input expression
230 // contained in a constraint containing a GroupBy key.
231 // For the example query below, there should be two components `{foo, bar}`
232 // and `baz`.
233 // ```
234 // select sum(foo.b) from foo, bar, baz
235 // where foo.a * bar.a = 24 group by foo.a * bar.a
236 // ```
237
238 // Maps (input idxs from old join) -> (idx of component it belongs to)
239 let input_component_map = BTreeMap::from_iter(
240 components
241 .iter()
242 .enumerate()
243 .flat_map(|(c_idx, c)| c.inputs.iter().map(move |i| (*i, c_idx))),
244 );
245
246 // 3) Construct a reduce to push to each input
247 let mut new_reduces = components
248 .into_iter()
249 .map(|component| ReduceBuilder::new(component, inputs, &old_join_mapper))
250 .collect::<Vec<_>>();
251
252 // The new projection and new join equivalences will reference columns
253 // produced by the new reduces, but we don't know the arities of the new
254 // reduces yet. Thus, they are temporarily stored as
255 // `(component_idx, column_idx_relative_to_new_reduce)`.
256 let mut new_projection = Vec::with_capacity(group_key.len());
257 let mut new_join_equivalences_by_component = Vec::new();
258
259 // 3a) Calculate the group key for each new reduce. We must make sure that
260 // the union of group keys across the new reduces can produce:
261 // (1) the group keys of the old reduce.
262 // (2) every expression in the equivalences of the new join.
263 for key in group_key {
264 // i) Find the equivalence class that the key is in.
265 // This relies on the expression appearing in at most one equivalence class. This
266 // invariant is ensured in step 0).
267 if let Some(cls) = new_join_equivalences
268 .iter()
269 .find(|cls| cls.iter().any(|expr| expr == key))
270 {
271 // ii) Rewrite the join equivalence in terms of columns produced by
272 // the pushed down reduction.
273 let mut new_join_cls = Vec::new();
274 for expr in cls {
275 if let Some(component) =
276 lookup_corresponding_component(expr, &old_join_mapper, &input_component_map)
277 {
278 if key == expr {
279 new_projection.push((component, new_reduces[component].arity()));
280 }
281 new_join_cls.push((component, new_reduces[component].arity()));
282 new_reduces[component].add_group_key(expr.clone());
283 } else {
284 // Abort reduction pushdown if the expression does not
285 // refer to exactly one component.
286 return None;
287 }
288 }
289 new_join_equivalences_by_component.push(new_join_cls);
290 } else {
291 // If GroupBy key does not belong in an equivalence class,
292 // add the key to new projection + add it as a GroupBy key to
293 // the new component
294 if let Some(component) =
295 lookup_corresponding_component(key, &old_join_mapper, &input_component_map)
296 {
297 new_projection.push((component, new_reduces[component].arity()));
298 new_reduces[component].add_group_key(key.clone())
299 } else {
300 // Abort reduction pushdown if the expression does not
301 // refer to exactly one component.
302 return None;
303 }
304 }
305 }
306
307 // 3b) Deduce the aggregates that each reduce needs to calculate in order to
308 // reconstruct each aggregate in the old reduce.
309 for agg in aggregates {
310 if let Some(component) =
311 lookup_corresponding_component(&agg.expr, &old_join_mapper, &input_component_map)
312 {
313 if !agg.distinct {
314 // TODO: support non-distinct aggs.
315 // For more details, see https://github.com/MaterializeInc/database-issues/issues/2924
316 return None;
317 }
318 new_projection.push((component, new_reduces[component].arity()));
319 new_reduces[component].add_aggregate(agg.clone());
320 } else {
321 // TODO: support multi- and zero- component aggs
322 // For more details, see https://github.com/MaterializeInc/database-issues/issues/2924
323 return None;
324 }
325 }
326
327 // 4) Construct the new `MirRelationExpr`.
328 let new_join_mapper =
329 JoinInputMapper::new_from_input_arities(new_reduces.iter().map(|builder| builder.arity()));
330
331 let new_inputs = new_reduces
332 .into_iter()
333 .map(|builder| builder.construct_reduce(monotonic, expected_group_size))
334 .collect::<Vec<_>>();
335
336 let new_equivalences = new_join_equivalences_by_component
337 .into_iter()
338 .map(|cls| {
339 cls.into_iter()
340 .map(|(idx, col)| {
341 MirScalarExpr::column(new_join_mapper.map_column_to_global(col, idx))
342 })
343 .collect::<Vec<_>>()
344 })
345 .collect::<Vec<_>>();
346
347 let new_projection = new_projection
348 .into_iter()
349 .map(|(idx, col)| new_join_mapper.map_column_to_global(col, idx))
350 .collect::<Vec<_>>();
351
352 Some(MirRelationExpr::join_scalars(new_inputs, new_equivalences).project(new_projection))
353}
354
355/// Returns None if `expr` does not belong to exactly one component.
356fn lookup_corresponding_component(
357 expr: &MirScalarExpr,
358 old_join_mapper: &JoinInputMapper,
359 input_component_map: &BTreeMap<usize, usize>,
360) -> Option<usize> {
361 let mut dedupped = old_join_mapper
362 .lookup_inputs(expr)
363 .map(|i| input_component_map[&i])
364 .collect::<BTreeSet<_>>();
365 if dedupped.len() == 1 {
366 dedupped.pop_first()
367 } else {
368 None
369 }
370}
371
372/// A subjoin represented as a multigraph.
373#[derive(Eq, Ord, PartialEq, PartialOrd)]
374struct Component {
375 /// Index numbers of the inputs in the subjoin.
376 /// Are the vertices in the multigraph.
377 inputs: Vec<usize>,
378 /// The edges in the multigraph.
379 constraints: Vec<Vec<MirScalarExpr>>,
380}
381
382impl Component {
383 /// Create a new component that contains only one input.
384 fn new(i: usize) -> Self {
385 Component {
386 inputs: vec![i],
387 constraints: Vec::new(),
388 }
389 }
390
391 /// Connect `self` with `others` using the edge `connecting_constraint`.
392 fn connect(&mut self, others: Vec<Component>, connecting_constraint: Vec<MirScalarExpr>) {
393 self.constraints.push(connecting_constraint);
394 for mut other in others {
395 self.inputs.append(&mut other.inputs);
396 self.constraints.append(&mut other.constraints);
397 }
398 self.inputs.sort();
399 self.inputs.dedup();
400 }
401}
402
403/// Constructs a Reduce around a component, localizing column references.
404struct ReduceBuilder {
405 input: MirRelationExpr,
406 group_key: Vec<MirScalarExpr>,
407 aggregates: Vec<AggregateExpr>,
408 /// Maps (global column relative to old join) -> (local column relative to `input`)
409 localize_map: BTreeMap<usize, usize>,
410}
411
412impl ReduceBuilder {
413 fn new(
414 mut component: Component,
415 inputs: &Vec<MirRelationExpr>,
416 old_join_mapper: &JoinInputMapper,
417 ) -> Self {
418 let localize_map = component
419 .inputs
420 .iter()
421 .flat_map(|i| old_join_mapper.global_columns(*i))
422 .enumerate()
423 .map(|(local, global)| (global, local))
424 .collect::<BTreeMap<_, _>>();
425 // Convert the subjoin from the `Component` representation to a
426 // `MirRelationExpr` representation.
427 let mut inputs = component
428 .inputs
429 .iter()
430 .map(|i| inputs[*i].clone())
431 .collect::<Vec<_>>();
432 // Constraints need to be localized to the subjoin.
433 for constraint in component.constraints.iter_mut() {
434 for expr in constraint.iter_mut() {
435 expr.permute_map(&localize_map)
436 }
437 }
438 let input = if inputs.len() == 1 {
439 let mut predicates = Vec::new();
440 for class in component.constraints {
441 for expr in class[1..].iter() {
442 predicates.push(
443 class[0]
444 .clone()
445 .call_binary(expr.clone(), mz_expr::func::Eq)
446 .or(class[0]
447 .clone()
448 .call_is_null()
449 .and(expr.clone().call_is_null())),
450 );
451 }
452 }
453 inputs.pop().unwrap().filter(predicates)
454 } else {
455 MirRelationExpr::join_scalars(inputs, component.constraints)
456 };
457 Self {
458 input,
459 group_key: Vec::new(),
460 aggregates: Vec::new(),
461 localize_map,
462 }
463 }
464
465 fn add_group_key(&mut self, mut key: MirScalarExpr) {
466 key.permute_map(&self.localize_map);
467 self.group_key.push(key);
468 }
469
470 fn add_aggregate(&mut self, mut agg: AggregateExpr) {
471 agg.expr.permute_map(&self.localize_map);
472 self.aggregates.push(agg);
473 }
474
475 fn arity(&self) -> usize {
476 self.group_key.len() + self.aggregates.len()
477 }
478
479 fn construct_reduce(
480 self,
481 monotonic: bool,
482 expected_group_size: Option<u64>,
483 ) -> MirRelationExpr {
484 MirRelationExpr::Reduce {
485 input: Box::new(self.input),
486 group_key: self.group_key,
487 aggregates: self.aggregates,
488 monotonic,
489 expected_group_size,
490 }
491 }
492}