mz_transform/reduction_pushdown.rs
1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Tries to convert a reduce around a join to a join of reduces.
11//! Also absorbs Map operators into Reduce operators.
12//!
13//! In a traditional DB, this transformation has a potential benefit of reducing
14//! the size of the join. In our streaming system built on top of Timely
//! Dataflow and Differential Dataflow, there are three other potential benefits:
16//! 1) Reducing data skew in the arrangements constructed for a join.
17//! 2) The join can potentially reuse the final arrangement constructed for the
18//!    reduce and not have to construct its own arrangement.
19//! 3) Reducing the frequency with which we have to recalculate the result of a join.
20//!
21//! Suppose there are two inputs R and S being joined. According to
22//! [Galindo-Legaria and Joshi (2001)](http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.563.8492&rep=rep1&type=pdf),
23//! a full reduction pushdown to R can be done if and only if:
24//! 1) Columns from R involved in join constraints are a subset of the group by keys.
25//! 2) The key of S is a subset of the group by keys.
26//! 3) The columns involved in the aggregation all belong to R.
27//!
28//! In our current implementation:
29//! * We abide by condition 1 to the letter.
30//! * We work around condition 2 by rewriting the reduce around a join of R to
31//!   S with an equivalent relational expression involving a join of R to
32//!   ```ignore
33//!   select <columns involved in join constraints>, count(true)
34//!   from S
35//!   group by <columns involved in join constraints>
36//!   ```
37//! * TODO: We work around condition 3 in some cases by noting that `sum(R.a * S.a)`
38//!   is equivalent to `sum(R.a) * sum(S.a)`.
39//!
40//! Full documentation with examples can be found
41//! [here](https://docs.google.com/document/d/1xrBJGGDkkiGBKRSNYR2W-nKba96ZOdC2mVbLqMLjJY0/edit)
42//!
43//! The current implementation is chosen so that reduction pushdown kicks in
//! only in the subset of cases most likely to help users. In the future, we
45//! may allow the user to toggle the aggressiveness of reduction pushdown. A
46//! more aggressive reduction pushdown implementation may, for example, try to
47//! work around condition 1 by pushing down an inner reduce through the join
48//! while retaining the original outer reduce.
49
50use std::collections::{BTreeMap, BTreeSet};
51use std::iter::FromIterator;
52
53use mz_expr::visit::Visit;
54use mz_expr::{AggregateExpr, JoinInputMapper, MirRelationExpr, MirScalarExpr};
55
56use crate::TransformCtx;
57use crate::analysis::equivalences::EquivalenceClasses;
58
59/// Pushes Reduce operators toward sources.
60#[derive(Debug)]
61pub struct ReductionPushdown;
62
63impl crate::Transform for ReductionPushdown {
64    fn name(&self) -> &'static str {
65        "ReductionPushdown"
66    }
67
68    #[mz_ore::instrument(
69        target = "optimizer",
70        level = "debug",
71        fields(path.segment = "reduction_pushdown")
72    )]
73    fn actually_perform_transform(
74        &self,
75        relation: &mut MirRelationExpr,
76        _: &mut TransformCtx,
77    ) -> Result<(), crate::TransformError> {
78        // `try_visit_mut_pre` is used here because after pushing down a reduction,
79        // we want to see if we can push the same reduction further down.
80        let result = relation.try_visit_mut_pre(&mut |e| self.action(e));
81        mz_repr::explain::trace_plan(&*relation);
82        result
83    }
84}
85
86impl ReductionPushdown {
87    /// Pushes Reduce operators toward sources.
88    ///
89    /// A join can be thought of as a multigraph where vertices are inputs and
90    /// edges are join constraints. After removing constraints containing a
91    /// GroupBy, the reduce will be pushed down to all connected components. If
92    /// there is only one connected component, this method is a no-op.
93    pub fn action(&self, relation: &mut MirRelationExpr) -> Result<(), crate::TransformError> {
94        if let MirRelationExpr::Reduce {
95            input,
96            group_key,
97            aggregates,
98            monotonic,
99            expected_group_size,
100        } = relation
101        {
102            // Map expressions can be absorbed into the Reduce at no cost.
103            if let MirRelationExpr::Map {
104                input: inner,
105                scalars,
106            } = &mut **input
107            {
108                let arity = inner.arity();
109
110                // Normalize the scalars to not be self-referential.
111                let mut scalars = scalars.clone();
112                for index in 0..scalars.len() {
113                    let (lower, upper) = scalars.split_at_mut(index);
114                    upper[0].visit_mut_post(&mut |e| {
115                        if let mz_expr::MirScalarExpr::Column(c, _) = e {
116                            if *c >= arity {
117                                *e = lower[*c - arity].clone();
118                            }
119                        }
120                    })?;
121                }
122                for key in group_key.iter_mut() {
123                    key.visit_mut_post(&mut |e| {
124                        if let mz_expr::MirScalarExpr::Column(c, _) = e {
125                            if *c >= arity {
126                                *e = scalars[*c - arity].clone();
127                            }
128                        }
129                    })?;
130                }
131                for agg in aggregates.iter_mut() {
132                    agg.expr.visit_mut_post(&mut |e| {
133                        if let mz_expr::MirScalarExpr::Column(c, _) = e {
134                            if *c >= arity {
135                                *e = scalars[*c - arity].clone();
136                            }
137                        }
138                    })?;
139                }
140
141                **input = inner.take_dangerous()
142            }
143            if let MirRelationExpr::Join {
144                inputs,
145                equivalences,
146                implementation: _,
147            } = &mut **input
148            {
149                if let Some(new_relation_expr) = try_push_reduce_through_join(
150                    inputs,
151                    equivalences,
152                    group_key,
153                    aggregates,
154                    *monotonic,
155                    *expected_group_size,
156                ) {
157                    *relation = new_relation_expr;
158                }
159            }
160        }
161        Ok(())
162    }
163}
164
165fn try_push_reduce_through_join(
166    inputs: &Vec<MirRelationExpr>,
167    equivalences: &Vec<Vec<MirScalarExpr>>,
168    group_key: &Vec<MirScalarExpr>,
169    aggregates: &Vec<AggregateExpr>,
170    monotonic: bool,
171    expected_group_size: Option<u64>,
172) -> Option<MirRelationExpr> {
173    // Variable name details:
174    // The goal is to turn `old` (`Reduce { Join { <inputs> }}`) into
175    // `new`, which looks like:
176    // ```
177    // Project {
178    //    Join {
179    //      Reduce { <component> }, ... , Reduce { <component> }
180    //    }
181    // }
182    // ```
183    //
184    // `<component>` is either `Join {<subset of inputs>}` or
185    // `<element of inputs>`.
186
187    // 0) Make sure that `equivalences` is a proper equivalence relation. Later, in 3a)/i), we'll
188    //    rely on expressions appearing in at most one equivalence class.
189    let mut eq_classes = EquivalenceClasses::default();
190    eq_classes.classes = equivalences.clone();
191    eq_classes.minimize(None);
192    let equivalences = eq_classes.classes;
193
194    let old_join_mapper = JoinInputMapper::new(inputs.as_slice());
195    // 1) Partition the join constraints into constraints containing a group
196    //    key and constraints that don't.
197    let (new_join_equivalences, component_equivalences): (Vec<_>, Vec<_>) = equivalences
198        .iter()
199        .cloned()
200        .partition(|cls| cls.iter().any(|expr| group_key.contains(expr)));
201
202    // 2) Find the connected components that remain after removing constraints
203    //    containing the group_key. Also, track the set of constraints that
204    //    connect the inputs in each component.
205    let mut components = (0..inputs.len()).map(Component::new).collect::<Vec<_>>();
206    for equivalence in component_equivalences {
207        // a) Find the inputs referenced by the constraint.
208        let inputs_to_connect = BTreeSet::<usize>::from_iter(
209            equivalence
210                .iter()
211                .flat_map(|expr| old_join_mapper.lookup_inputs(expr)),
212        );
213        // b) Extract the set of components that covers the inputs.
214        let (mut components_to_connect, other): (Vec<_>, Vec<_>) = components
215            .into_iter()
216            .partition(|c| c.inputs.iter().any(|i| inputs_to_connect.contains(i)));
217        components = other;
218        // c) Connect the components and push the result back into the list of
219        // components.
220        if let Some(mut connected_component) = components_to_connect.pop() {
221            connected_component.connect(components_to_connect, equivalence);
222            components.push(connected_component);
223        }
224        // d) Abort reduction pushdown if there are less than two connected components.
225        if components.len() < 2 {
226            return None;
227        }
228    }
229    components.sort();
230    // TODO: Connect components referenced by the same multi-input expression
231    // contained in a constraint containing a GroupBy key.
232    // For the example query below, there should be two components `{foo, bar}`
233    // and `baz`.
234    // ```
235    // select sum(foo.b) from foo, bar, baz
236    // where foo.a * bar.a = 24 group by foo.a * bar.a
237    // ```
238
239    // Maps (input idxs from old join) -> (idx of component it belongs to)
240    let input_component_map = BTreeMap::from_iter(
241        components
242            .iter()
243            .enumerate()
244            .flat_map(|(c_idx, c)| c.inputs.iter().map(move |i| (*i, c_idx))),
245    );
246
247    // 3) Construct a reduce to push to each input
248    let mut new_reduces = components
249        .into_iter()
250        .map(|component| ReduceBuilder::new(component, inputs, &old_join_mapper))
251        .collect::<Vec<_>>();
252
253    // The new projection and new join equivalences will reference columns
254    // produced by the new reduces, but we don't know the arities of the new
255    // reduces yet. Thus, they are temporarily stored as
256    // `(component_idx, column_idx_relative_to_new_reduce)`.
257    let mut new_projection = Vec::with_capacity(group_key.len());
258    let mut new_join_equivalences_by_component = Vec::new();
259
260    // 3a) Calculate the group key for each new reduce. We must make sure that
261    // the union of group keys across the new reduces can produce:
262    // (1) the group keys of the old reduce.
263    // (2) every expression in the equivalences of the new join.
264    for key in group_key {
265        // i) Find the equivalence class that the key is in.
266        //    This relies on the expression appearing in at most one equivalence class. This
267        //    invariant is ensured in step 0).
268        if let Some(cls) = new_join_equivalences
269            .iter()
270            .find(|cls| cls.iter().any(|expr| expr == key))
271        {
272            // ii) Rewrite the join equivalence in terms of columns produced by
273            // the pushed down reduction.
274            let mut new_join_cls = Vec::new();
275            for expr in cls {
276                if let Some(component) =
277                    lookup_corresponding_component(expr, &old_join_mapper, &input_component_map)
278                {
279                    if key == expr {
280                        new_projection.push((component, new_reduces[component].arity()));
281                    }
282                    new_join_cls.push((component, new_reduces[component].arity()));
283                    new_reduces[component].add_group_key(expr.clone());
284                } else {
285                    // Abort reduction pushdown if the expression does not
286                    // refer to exactly one component.
287                    return None;
288                }
289            }
290            new_join_equivalences_by_component.push(new_join_cls);
291        } else {
292            // If GroupBy key does not belong in an equivalence class,
293            // add the key to new projection + add it as a GroupBy key to
294            // the new component
295            if let Some(component) =
296                lookup_corresponding_component(key, &old_join_mapper, &input_component_map)
297            {
298                new_projection.push((component, new_reduces[component].arity()));
299                new_reduces[component].add_group_key(key.clone())
300            } else {
301                // Abort reduction pushdown if the expression does not
302                // refer to exactly one component.
303                return None;
304            }
305        }
306    }
307
308    // 3b) Deduce the aggregates that each reduce needs to calculate in order to
309    // reconstruct each aggregate in the old reduce.
310    for agg in aggregates {
311        if let Some(component) =
312            lookup_corresponding_component(&agg.expr, &old_join_mapper, &input_component_map)
313        {
314            if !agg.distinct {
315                // TODO: support non-distinct aggs.
316                // For more details, see https://github.com/MaterializeInc/database-issues/issues/2924
317                return None;
318            }
319            new_projection.push((component, new_reduces[component].arity()));
320            new_reduces[component].add_aggregate(agg.clone());
321        } else {
322            // TODO: support multi- and zero- component aggs
323            // For more details, see https://github.com/MaterializeInc/database-issues/issues/2924
324            return None;
325        }
326    }
327
328    // 4) Construct the new `MirRelationExpr`.
329    let new_join_mapper =
330        JoinInputMapper::new_from_input_arities(new_reduces.iter().map(|builder| builder.arity()));
331
332    let new_inputs = new_reduces
333        .into_iter()
334        .map(|builder| builder.construct_reduce(monotonic, expected_group_size))
335        .collect::<Vec<_>>();
336
337    let new_equivalences = new_join_equivalences_by_component
338        .into_iter()
339        .map(|cls| {
340            cls.into_iter()
341                .map(|(idx, col)| {
342                    MirScalarExpr::column(new_join_mapper.map_column_to_global(col, idx))
343                })
344                .collect::<Vec<_>>()
345        })
346        .collect::<Vec<_>>();
347
348    let new_projection = new_projection
349        .into_iter()
350        .map(|(idx, col)| new_join_mapper.map_column_to_global(col, idx))
351        .collect::<Vec<_>>();
352
353    Some(MirRelationExpr::join_scalars(new_inputs, new_equivalences).project(new_projection))
354}
355
356/// Returns None if `expr` does not belong to exactly one component.
357fn lookup_corresponding_component(
358    expr: &MirScalarExpr,
359    old_join_mapper: &JoinInputMapper,
360    input_component_map: &BTreeMap<usize, usize>,
361) -> Option<usize> {
362    let mut dedupped = old_join_mapper
363        .lookup_inputs(expr)
364        .map(|i| input_component_map[&i])
365        .collect::<BTreeSet<_>>();
366    if dedupped.len() == 1 {
367        dedupped.pop_first()
368    } else {
369        None
370    }
371}
372
373/// A subjoin represented as a multigraph.
374#[derive(Eq, Ord, PartialEq, PartialOrd)]
375struct Component {
376    /// Index numbers of the inputs in the subjoin.
377    /// Are the vertices in the multigraph.
378    inputs: Vec<usize>,
379    /// The edges in the multigraph.
380    constraints: Vec<Vec<MirScalarExpr>>,
381}
382
383impl Component {
384    /// Create a new component that contains only one input.
385    fn new(i: usize) -> Self {
386        Component {
387            inputs: vec![i],
388            constraints: Vec::new(),
389        }
390    }
391
392    /// Connect `self` with `others` using the edge `connecting_constraint`.
393    fn connect(&mut self, others: Vec<Component>, connecting_constraint: Vec<MirScalarExpr>) {
394        self.constraints.push(connecting_constraint);
395        for mut other in others {
396            self.inputs.append(&mut other.inputs);
397            self.constraints.append(&mut other.constraints);
398        }
399        self.inputs.sort();
400        self.inputs.dedup();
401    }
402}
403
404/// Constructs a Reduce around a component, localizing column references.
405struct ReduceBuilder {
406    input: MirRelationExpr,
407    group_key: Vec<MirScalarExpr>,
408    aggregates: Vec<AggregateExpr>,
409    /// Maps (global column relative to old join) -> (local column relative to `input`)
410    localize_map: BTreeMap<usize, usize>,
411}
412
413impl ReduceBuilder {
414    fn new(
415        mut component: Component,
416        inputs: &Vec<MirRelationExpr>,
417        old_join_mapper: &JoinInputMapper,
418    ) -> Self {
419        let localize_map = component
420            .inputs
421            .iter()
422            .flat_map(|i| old_join_mapper.global_columns(*i))
423            .enumerate()
424            .map(|(local, global)| (global, local))
425            .collect::<BTreeMap<_, _>>();
426        // Convert the subjoin from the `Component` representation to a
427        // `MirRelationExpr` representation.
428        let mut inputs = component
429            .inputs
430            .iter()
431            .map(|i| inputs[*i].clone())
432            .collect::<Vec<_>>();
433        // Constraints need to be localized to the subjoin.
434        for constraint in component.constraints.iter_mut() {
435            for expr in constraint.iter_mut() {
436                expr.permute_map(&localize_map)
437            }
438        }
439        let input = if inputs.len() == 1 {
440            let mut predicates = Vec::new();
441            for class in component.constraints {
442                for expr in class[1..].iter() {
443                    predicates.push(
444                        class[0]
445                            .clone()
446                            .call_binary(expr.clone(), mz_expr::func::Eq)
447                            .or(class[0]
448                                .clone()
449                                .call_is_null()
450                                .and(expr.clone().call_is_null())),
451                    );
452                }
453            }
454            inputs.pop().unwrap().filter(predicates)
455        } else {
456            MirRelationExpr::join_scalars(inputs, component.constraints)
457        };
458        Self {
459            input,
460            group_key: Vec::new(),
461            aggregates: Vec::new(),
462            localize_map,
463        }
464    }
465
466    fn add_group_key(&mut self, mut key: MirScalarExpr) {
467        key.permute_map(&self.localize_map);
468        self.group_key.push(key);
469    }
470
471    fn add_aggregate(&mut self, mut agg: AggregateExpr) {
472        agg.expr.permute_map(&self.localize_map);
473        self.aggregates.push(agg);
474    }
475
476    fn arity(&self) -> usize {
477        self.group_key.len() + self.aggregates.len()
478    }
479
480    fn construct_reduce(
481        self,
482        monotonic: bool,
483        expected_group_size: Option<u64>,
484    ) -> MirRelationExpr {
485        MirRelationExpr::Reduce {
486            input: Box::new(self.input),
487            group_key: self.group_key,
488            aggregates: self.aggregates,
489            monotonic,
490            expected_group_size,
491        }
492    }
493}