mz_transform/reduction_pushdown.rs
1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Tries to convert a reduce around a join to a join of reduces.
11//! Also absorbs Map operators into Reduce operators.
12//!
13//! In a traditional DB, this transformation has a potential benefit of reducing
14//! the size of the join. In our streaming system built on top of Timely
//! Dataflow and Differential Dataflow, there are three other potential benefits:
16//! 1) Reducing data skew in the arrangements constructed for a join.
17//! 2) The join can potentially reuse the final arrangement constructed for the
18//! reduce and not have to construct its own arrangement.
19//! 3) Reducing the frequency with which we have to recalculate the result of a join.
20//!
21//! Suppose there are two inputs R and S being joined. According to
22//! [Galindo-Legaria and Joshi (2001)](http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.563.8492&rep=rep1&type=pdf),
23//! a full reduction pushdown to R can be done if and only if:
24//! 1) Columns from R involved in join constraints are a subset of the group by keys.
25//! 2) The key of S is a subset of the group by keys.
26//! 3) The columns involved in the aggregation all belong to R.
27//!
28//! In our current implementation:
29//! * We abide by condition 1 to the letter.
30//! * We work around condition 2 by rewriting the reduce around a join of R to
31//! S with an equivalent relational expression involving a join of R to
32//! ```ignore
33//! select <columns involved in join constraints>, count(true)
34//! from S
35//! group by <columns involved in join constraints>
36//! ```
37//! * TODO: We work around condition 3 in some cases by noting that `sum(R.a * S.a)`
38//! is equivalent to `sum(R.a) * sum(S.a)`.
39//!
40//! Full documentation with examples can be found
41//! [here](https://docs.google.com/document/d/1xrBJGGDkkiGBKRSNYR2W-nKba96ZOdC2mVbLqMLjJY0/edit)
42//!
43//! The current implementation is chosen so that reduction pushdown kicks in
//! only in the subset of cases most likely to help users. In the future, we
45//! may allow the user to toggle the aggressiveness of reduction pushdown. A
46//! more aggressive reduction pushdown implementation may, for example, try to
47//! work around condition 1 by pushing down an inner reduce through the join
48//! while retaining the original outer reduce.
49
50use std::collections::{BTreeMap, BTreeSet};
51use std::iter::FromIterator;
52
53use mz_expr::visit::Visit;
54use mz_expr::{AggregateExpr, JoinInputMapper, MirRelationExpr, MirScalarExpr};
55
56use crate::TransformCtx;
57use crate::analysis::equivalences::EquivalenceClasses;
58
/// Optimizer transform that pushes `Reduce` operators toward sources.
///
/// The transform is stateless. When a `Reduce` sits directly on top of a
/// `Map`, the mapped scalars are inlined into the reduce's keys and
/// aggregates; when it sits on top of a `Join`, the transform tries to replace
/// it with a join of per-component reduces. See [`ReductionPushdown::action`]
/// for the per-node logic.
#[derive(Debug)]
pub struct ReductionPushdown;
62
63impl crate::Transform for ReductionPushdown {
64 fn name(&self) -> &'static str {
65 "ReductionPushdown"
66 }
67
68 #[mz_ore::instrument(
69 target = "optimizer",
70 level = "debug",
71 fields(path.segment = "reduction_pushdown")
72 )]
73 fn actually_perform_transform(
74 &self,
75 relation: &mut MirRelationExpr,
76 _: &mut TransformCtx,
77 ) -> Result<(), crate::TransformError> {
78 // `try_visit_mut_pre` is used here because after pushing down a reduction,
79 // we want to see if we can push the same reduction further down.
80 let result = relation.try_visit_mut_pre(&mut |e| self.action(e));
81 mz_repr::explain::trace_plan(&*relation);
82 result
83 }
84}
85
86impl ReductionPushdown {
87 /// Pushes Reduce operators toward sources.
88 ///
89 /// A join can be thought of as a multigraph where vertices are inputs and
90 /// edges are join constraints. After removing constraints containing a
91 /// GroupBy, the reduce will be pushed down to all connected components. If
92 /// there is only one connected component, this method is a no-op.
93 pub fn action(&self, relation: &mut MirRelationExpr) -> Result<(), crate::TransformError> {
94 if let MirRelationExpr::Reduce {
95 input,
96 group_key,
97 aggregates,
98 monotonic,
99 expected_group_size,
100 } = relation
101 {
102 // Map expressions can be absorbed into the Reduce at no cost.
103 if let MirRelationExpr::Map {
104 input: inner,
105 scalars,
106 } = &mut **input
107 {
108 let arity = inner.arity();
109
110 // Normalize the scalars to not be self-referential.
111 let mut scalars = scalars.clone();
112 for index in 0..scalars.len() {
113 let (lower, upper) = scalars.split_at_mut(index);
114 upper[0].visit_mut_post(&mut |e| {
115 if let mz_expr::MirScalarExpr::Column(c) = e {
116 if *c >= arity {
117 *e = lower[*c - arity].clone();
118 }
119 }
120 })?;
121 }
122 for key in group_key.iter_mut() {
123 key.visit_mut_post(&mut |e| {
124 if let mz_expr::MirScalarExpr::Column(c) = e {
125 if *c >= arity {
126 *e = scalars[*c - arity].clone();
127 }
128 }
129 })?;
130 }
131 for agg in aggregates.iter_mut() {
132 agg.expr.visit_mut_post(&mut |e| {
133 if let mz_expr::MirScalarExpr::Column(c) = e {
134 if *c >= arity {
135 *e = scalars[*c - arity].clone();
136 }
137 }
138 })?;
139 }
140
141 **input = inner.take_dangerous()
142 }
143 if let MirRelationExpr::Join {
144 inputs,
145 equivalences,
146 implementation: _,
147 } = &mut **input
148 {
149 if let Some(new_relation_expr) = try_push_reduce_through_join(
150 inputs,
151 equivalences,
152 group_key,
153 aggregates,
154 *monotonic,
155 *expected_group_size,
156 ) {
157 *relation = new_relation_expr;
158 }
159 }
160 }
161 Ok(())
162 }
163}
164
165fn try_push_reduce_through_join(
166 inputs: &Vec<MirRelationExpr>,
167 equivalences: &Vec<Vec<MirScalarExpr>>,
168 group_key: &Vec<MirScalarExpr>,
169 aggregates: &Vec<AggregateExpr>,
170 monotonic: bool,
171 expected_group_size: Option<u64>,
172) -> Option<MirRelationExpr> {
173 // Variable name details:
174 // The goal is to turn `old` (`Reduce { Join { <inputs> }}`) into
175 // `new`, which looks like:
176 // ```
177 // Project {
178 // Join {
179 // Reduce { <component> }, ... , Reduce { <component> }
180 // }
181 // }
182 // ```
183 //
184 // `<component>` is either `Join {<subset of inputs>}` or
185 // `<element of inputs>`.
186
187 // 0) Make sure that `equivalences` is a proper equivalence relation. Later, in 3a)/i), we'll
188 // rely on expressions appearing in at most one equivalence class.
189 let mut eq_classes = EquivalenceClasses::default();
190 eq_classes.classes = equivalences.clone();
191 eq_classes.minimize(None);
192 let equivalences = eq_classes.classes;
193
194 let old_join_mapper = JoinInputMapper::new(inputs.as_slice());
195 // 1) Partition the join constraints into constraints containing a group
196 // key and constraints that don't.
197 let (new_join_equivalences, component_equivalences): (Vec<_>, Vec<_>) = equivalences
198 .iter()
199 .cloned()
200 .partition(|cls| cls.iter().any(|expr| group_key.contains(expr)));
201
202 // 2) Find the connected components that remain after removing constraints
203 // containing the group_key. Also, track the set of constraints that
204 // connect the inputs in each component.
205 let mut components = (0..inputs.len()).map(Component::new).collect::<Vec<_>>();
206 for equivalence in component_equivalences {
207 // a) Find the inputs referenced by the constraint.
208 let inputs_to_connect = BTreeSet::<usize>::from_iter(
209 equivalence
210 .iter()
211 .flat_map(|expr| old_join_mapper.lookup_inputs(expr)),
212 );
213 // b) Extract the set of components that covers the inputs.
214 let (mut components_to_connect, other): (Vec<_>, Vec<_>) = components
215 .into_iter()
216 .partition(|c| c.inputs.iter().any(|i| inputs_to_connect.contains(i)));
217 components = other;
218 // c) Connect the components and push the result back into the list of
219 // components.
220 if let Some(mut connected_component) = components_to_connect.pop() {
221 connected_component.connect(components_to_connect, equivalence);
222 components.push(connected_component);
223 }
224 // d) Abort reduction pushdown if there are less than two connected components.
225 if components.len() < 2 {
226 return None;
227 }
228 }
229 components.sort();
230 // TODO: Connect components referenced by the same multi-input expression
231 // contained in a constraint containing a GroupBy key.
232 // For the example query below, there should be two components `{foo, bar}`
233 // and `baz`.
234 // ```
235 // select sum(foo.b) from foo, bar, baz
236 // where foo.a * bar.a = 24 group by foo.a * bar.a
237 // ```
238
239 // Maps (input idxs from old join) -> (idx of component it belongs to)
240 let input_component_map = BTreeMap::from_iter(
241 components
242 .iter()
243 .enumerate()
244 .flat_map(|(c_idx, c)| c.inputs.iter().map(move |i| (*i, c_idx))),
245 );
246
247 // 3) Construct a reduce to push to each input
248 let mut new_reduces = components
249 .into_iter()
250 .map(|component| ReduceBuilder::new(component, inputs, &old_join_mapper))
251 .collect::<Vec<_>>();
252
253 // The new projection and new join equivalences will reference columns
254 // produced by the new reduces, but we don't know the arities of the new
255 // reduces yet. Thus, they are temporarily stored as
256 // `(component_idx, column_idx_relative_to_new_reduce)`.
257 let mut new_projection = Vec::with_capacity(group_key.len());
258 let mut new_join_equivalences_by_component = Vec::new();
259
260 // 3a) Calculate the group key for each new reduce. We must make sure that
261 // the union of group keys across the new reduces can produce:
262 // (1) the group keys of the old reduce.
263 // (2) every expression in the equivalences of the new join.
264 for key in group_key {
265 // i) Find the equivalence class that the key is in.
266 // This relies on the expression appearing in at most one equivalence class. This
267 // invariant is ensured in step 0).
268 if let Some(cls) = new_join_equivalences
269 .iter()
270 .find(|cls| cls.iter().any(|expr| expr == key))
271 {
272 // ii) Rewrite the join equivalence in terms of columns produced by
273 // the pushed down reduction.
274 let mut new_join_cls = Vec::new();
275 for expr in cls {
276 if let Some(component) =
277 lookup_corresponding_component(expr, &old_join_mapper, &input_component_map)
278 {
279 if key == expr {
280 new_projection.push((component, new_reduces[component].arity()));
281 }
282 new_join_cls.push((component, new_reduces[component].arity()));
283 new_reduces[component].add_group_key(expr.clone());
284 } else {
285 // Abort reduction pushdown if the expression does not
286 // refer to exactly one component.
287 return None;
288 }
289 }
290 new_join_equivalences_by_component.push(new_join_cls);
291 } else {
292 // If GroupBy key does not belong in an equivalence class,
293 // add the key to new projection + add it as a GroupBy key to
294 // the new component
295 if let Some(component) =
296 lookup_corresponding_component(key, &old_join_mapper, &input_component_map)
297 {
298 new_projection.push((component, new_reduces[component].arity()));
299 new_reduces[component].add_group_key(key.clone())
300 } else {
301 // Abort reduction pushdown if the expression does not
302 // refer to exactly one component.
303 return None;
304 }
305 }
306 }
307
308 // 3b) Deduce the aggregates that each reduce needs to calculate in order to
309 // reconstruct each aggregate in the old reduce.
310 for agg in aggregates {
311 if let Some(component) =
312 lookup_corresponding_component(&agg.expr, &old_join_mapper, &input_component_map)
313 {
314 if !agg.distinct {
315 // TODO: support non-distinct aggs.
316 // For more details, see https://github.com/MaterializeInc/database-issues/issues/2924
317 return None;
318 }
319 new_projection.push((component, new_reduces[component].arity()));
320 new_reduces[component].add_aggregate(agg.clone());
321 } else {
322 // TODO: support multi- and zero- component aggs
323 // For more details, see https://github.com/MaterializeInc/database-issues/issues/2924
324 return None;
325 }
326 }
327
328 // 4) Construct the new `MirRelationExpr`.
329 let new_join_mapper =
330 JoinInputMapper::new_from_input_arities(new_reduces.iter().map(|builder| builder.arity()));
331
332 let new_inputs = new_reduces
333 .into_iter()
334 .map(|builder| builder.construct_reduce(monotonic, expected_group_size))
335 .collect::<Vec<_>>();
336
337 let new_equivalences = new_join_equivalences_by_component
338 .into_iter()
339 .map(|cls| {
340 cls.into_iter()
341 .map(|(idx, col)| {
342 MirScalarExpr::Column(new_join_mapper.map_column_to_global(col, idx))
343 })
344 .collect::<Vec<_>>()
345 })
346 .collect::<Vec<_>>();
347
348 let new_projection = new_projection
349 .into_iter()
350 .map(|(idx, col)| new_join_mapper.map_column_to_global(col, idx))
351 .collect::<Vec<_>>();
352
353 Some(MirRelationExpr::join_scalars(new_inputs, new_equivalences).project(new_projection))
354}
355
356/// Returns None if `expr` does not belong to exactly one component.
357fn lookup_corresponding_component(
358 expr: &MirScalarExpr,
359 old_join_mapper: &JoinInputMapper,
360 input_component_map: &BTreeMap<usize, usize>,
361) -> Option<usize> {
362 let mut dedupped = old_join_mapper
363 .lookup_inputs(expr)
364 .map(|i| input_component_map[&i])
365 .collect::<BTreeSet<_>>();
366 if dedupped.len() == 1 {
367 dedupped.pop_first()
368 } else {
369 None
370 }
371}
372
373/// A subjoin represented as a multigraph.
374#[derive(Eq, Ord, PartialEq, PartialOrd)]
375struct Component {
376 /// Index numbers of the inputs in the subjoin.
377 /// Are the vertices in the multigraph.
378 inputs: Vec<usize>,
379 /// The edges in the multigraph.
380 constraints: Vec<Vec<MirScalarExpr>>,
381}
382
383impl Component {
384 /// Create a new component that contains only one input.
385 fn new(i: usize) -> Self {
386 Component {
387 inputs: vec![i],
388 constraints: Vec::new(),
389 }
390 }
391
392 /// Connect `self` with `others` using the edge `connecting_constraint`.
393 fn connect(&mut self, others: Vec<Component>, connecting_constraint: Vec<MirScalarExpr>) {
394 self.constraints.push(connecting_constraint);
395 for mut other in others {
396 self.inputs.append(&mut other.inputs);
397 self.constraints.append(&mut other.constraints);
398 }
399 self.inputs.sort();
400 self.inputs.dedup();
401 }
402}
403
/// Constructs a Reduce around a component, localizing column references.
///
/// Group keys and aggregates are added one at a time; each is rewritten from
/// column numbers of the old join to column numbers of `input` as it is
/// added. The output columns of the eventual reduce are the group keys
/// followed by the aggregates.
struct ReduceBuilder {
    /// The relation the reduce will be placed on top of: either a single
    /// (filtered) join input or a subjoin of several inputs.
    input: MirRelationExpr,
    /// Group key expressions added so far, already localized to `input`.
    group_key: Vec<MirScalarExpr>,
    /// Aggregates added so far, with expressions already localized to `input`.
    aggregates: Vec<AggregateExpr>,
    /// Maps (global column relative to old join) -> (local column relative to `input`)
    localize_map: BTreeMap<usize, usize>,
}
412
413impl ReduceBuilder {
414 fn new(
415 mut component: Component,
416 inputs: &Vec<MirRelationExpr>,
417 old_join_mapper: &JoinInputMapper,
418 ) -> Self {
419 let localize_map = component
420 .inputs
421 .iter()
422 .flat_map(|i| old_join_mapper.global_columns(*i))
423 .enumerate()
424 .map(|(local, global)| (global, local))
425 .collect::<BTreeMap<_, _>>();
426 // Convert the subjoin from the `Component` representation to a
427 // `MirRelationExpr` representation.
428 let mut inputs = component
429 .inputs
430 .iter()
431 .map(|i| inputs[*i].clone())
432 .collect::<Vec<_>>();
433 // Constraints need to be localized to the subjoin.
434 for constraint in component.constraints.iter_mut() {
435 for expr in constraint.iter_mut() {
436 expr.permute_map(&localize_map)
437 }
438 }
439 let input = if inputs.len() == 1 {
440 let mut predicates = Vec::new();
441 for class in component.constraints {
442 for expr in class[1..].iter() {
443 predicates.push(
444 class[0]
445 .clone()
446 .call_binary(expr.clone(), mz_expr::BinaryFunc::Eq)
447 .or(class[0]
448 .clone()
449 .call_is_null()
450 .and(expr.clone().call_is_null())),
451 );
452 }
453 }
454 inputs.pop().unwrap().filter(predicates)
455 } else {
456 MirRelationExpr::join_scalars(inputs, component.constraints)
457 };
458 Self {
459 input,
460 group_key: Vec::new(),
461 aggregates: Vec::new(),
462 localize_map,
463 }
464 }
465
466 fn add_group_key(&mut self, mut key: MirScalarExpr) {
467 key.permute_map(&self.localize_map);
468 self.group_key.push(key);
469 }
470
471 fn add_aggregate(&mut self, mut agg: AggregateExpr) {
472 agg.expr.permute_map(&self.localize_map);
473 self.aggregates.push(agg);
474 }
475
476 fn arity(&self) -> usize {
477 self.group_key.len() + self.aggregates.len()
478 }
479
480 fn construct_reduce(
481 self,
482 monotonic: bool,
483 expected_group_size: Option<u64>,
484 ) -> MirRelationExpr {
485 MirRelationExpr::Reduce {
486 input: Box::new(self.input),
487 group_key: self.group_key,
488 aggregates: self.aggregates,
489 monotonic,
490 expected_group_size,
491 }
492 }
493}