Skip to main content

mz_transform/
reduction_pushdown.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Tries to convert a reduce around a join to a join of reduces.
11//! Also absorbs Map operators into Reduce operators.
12//!
13//! In a traditional DB, this transformation has a potential benefit of reducing
14//! the size of the join. In our streaming system built on top of Timely
//! Dataflow and Differential Dataflow, there are three other potential benefits:
16//! 1) Reducing data skew in the arrangements constructed for a join.
17//! 2) The join can potentially reuse the final arrangement constructed for the
18//!    reduce and not have to construct its own arrangement.
19//! 3) Reducing the frequency with which we have to recalculate the result of a join.
20//!
21//! Suppose there are two inputs R and S being joined. According to
22//! [Galindo-Legaria and Joshi (2001)](http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.563.8492&rep=rep1&type=pdf),
23//! a full reduction pushdown to R can be done if and only if:
24//! 1) Columns from R involved in join constraints are a subset of the group by keys.
25//! 2) The key of S is a subset of the group by keys.
26//! 3) The columns involved in the aggregation all belong to R.
27//!
28//! In our current implementation:
29//! * We abide by condition 1 to the letter.
30//! * We work around condition 2 by rewriting the reduce around a join of R to
31//!   S with an equivalent relational expression involving a join of R to
32//!   ```ignore
33//!   select <columns involved in join constraints>, count(true)
34//!   from S
35//!   group by <columns involved in join constraints>
36//!   ```
//! * TODO: We work around condition 3 in some cases by noting that, within a
//!   group whose rows form the cross product of its R rows and its matching
//!   S rows, `sum(R.a * S.a)` is equivalent to `sum(R.a) * sum(S.a)`.
39//!
40//! Full documentation with examples can be found
41//! [here](https://docs.google.com/document/d/1xrBJGGDkkiGBKRSNYR2W-nKba96ZOdC2mVbLqMLjJY0/edit)
42//!
43//! The current implementation is chosen so that reduction pushdown kicks in
//! only in the subset of cases most likely to help users. In the future, we
45//! may allow the user to toggle the aggressiveness of reduction pushdown. A
46//! more aggressive reduction pushdown implementation may, for example, try to
47//! work around condition 1 by pushing down an inner reduce through the join
48//! while retaining the original outer reduce.
49
50use std::collections::{BTreeMap, BTreeSet};
51use std::iter::FromIterator;
52
53use mz_expr::visit::Visit;
54use mz_expr::{AggregateExpr, Columns, JoinInputMapper, MirRelationExpr, MirScalarExpr};
55
56use crate::TransformCtx;
57use crate::analysis::equivalences::EquivalenceClasses;
58
/// Pushes Reduce operators toward sources.
///
/// As a [`crate::Transform`], this rewrites a `Reduce` whose input is a `Join`
/// into a `Join` of smaller `Reduce`s (one per connected component of the join)
/// when the conditions described in the module documentation hold. It also
/// absorbs a `Map` that sits directly beneath a `Reduce` into that `Reduce`.
#[derive(Debug)]
pub struct ReductionPushdown;
62
63impl crate::Transform for ReductionPushdown {
64    fn name(&self) -> &'static str {
65        "ReductionPushdown"
66    }
67
68    #[mz_ore::instrument(
69        target = "optimizer",
70        level = "debug",
71        fields(path.segment = "reduction_pushdown")
72    )]
73    fn actually_perform_transform(
74        &self,
75        relation: &mut MirRelationExpr,
76        _: &mut TransformCtx,
77    ) -> Result<(), crate::TransformError> {
78        // `try_visit_mut_pre` is used here because after pushing down a reduction,
79        // we want to see if we can push the same reduction further down.
80        relation.visit_mut_pre(&mut |e| self.action(e));
81        mz_repr::explain::trace_plan(&*relation);
82        Ok(())
83    }
84}
85
86impl ReductionPushdown {
87    /// Pushes Reduce operators toward sources.
88    ///
89    /// A join can be thought of as a multigraph where vertices are inputs and
90    /// edges are join constraints. After removing constraints containing a
91    /// GroupBy, the reduce will be pushed down to all connected components. If
92    /// there is only one connected component, this method is a no-op.
93    pub fn action(&self, relation: &mut MirRelationExpr) {
94        if let MirRelationExpr::Reduce {
95            input,
96            group_key,
97            aggregates,
98            monotonic,
99            expected_group_size,
100        } = relation
101        {
102            // Map expressions can be absorbed into the Reduce at no cost.
103            if let MirRelationExpr::Map {
104                input: inner,
105                scalars,
106            } = &mut **input
107            {
108                let arity = inner.arity();
109
110                // Normalize the scalars to not be self-referential.
111                let mut scalars = scalars.clone();
112                for index in 0..scalars.len() {
113                    let (lower, upper) = scalars.split_at_mut(index);
114                    upper[0].visit_mut_post(&mut |e| {
115                        if let mz_expr::MirScalarExpr::Column(c, _) = e {
116                            if *c >= arity {
117                                *e = lower[*c - arity].clone();
118                            }
119                        }
120                    });
121                }
122                for key in group_key.iter_mut() {
123                    key.visit_mut_post(&mut |e| {
124                        if let mz_expr::MirScalarExpr::Column(c, _) = e {
125                            if *c >= arity {
126                                *e = scalars[*c - arity].clone();
127                            }
128                        }
129                    });
130                }
131                for agg in aggregates.iter_mut() {
132                    agg.expr.visit_mut_post(&mut |e| {
133                        if let mz_expr::MirScalarExpr::Column(c, _) = e {
134                            if *c >= arity {
135                                *e = scalars[*c - arity].clone();
136                            }
137                        }
138                    });
139                }
140
141                **input = inner.take_dangerous()
142            }
143            if let MirRelationExpr::Join {
144                inputs,
145                equivalences,
146                implementation: _,
147            } = &mut **input
148            {
149                if let Some(new_relation_expr) = try_push_reduce_through_join(
150                    inputs,
151                    equivalences,
152                    group_key,
153                    aggregates,
154                    *monotonic,
155                    *expected_group_size,
156                ) {
157                    *relation = new_relation_expr;
158                }
159            }
160        }
161    }
162}
163
/// Attempts to rewrite `Reduce { Join { inputs, equivalences }, group_key, aggregates }`
/// into a `Project` over a `Join` of one `Reduce` per connected component.
///
/// Returns `None` (so the caller keeps its expression unchanged) when:
/// * processing the equivalence classes that contain no group key merges the
///   inputs into fewer than two components;
/// * a group key, an expression in a group-key equivalence class, or an
///   aggregate argument does not reference exactly one component;
/// * some aggregate is not `distinct`.
///
/// NOTE(review): the component count is only checked inside the merge loop, so
/// a join with fewer than two inputs and no non-group-key constraints is still
/// rewritten.
///
/// `monotonic` and `expected_group_size` are copied unchanged onto every new
/// `Reduce`.
fn try_push_reduce_through_join(
    inputs: &Vec<MirRelationExpr>,
    equivalences: &Vec<Vec<MirScalarExpr>>,
    group_key: &Vec<MirScalarExpr>,
    aggregates: &Vec<AggregateExpr>,
    monotonic: bool,
    expected_group_size: Option<u64>,
) -> Option<MirRelationExpr> {
    // Variable name details:
    // The goal is to turn `old` (`Reduce { Join { <inputs> }}`) into
    // `new`, which looks like:
    // ```
    // Project {
    //    Join {
    //      Reduce { <component> }, ... , Reduce { <component> }
    //    }
    // }
    // ```
    //
    // `<component>` is either `Join {<subset of inputs>}` or
    // `<element of inputs>`.

    // 0) Make sure that `equivalences` is a proper equivalence relation. Later, in 3a)/i), we'll
    //    rely on expressions appearing in at most one equivalence class.
    let mut eq_classes = EquivalenceClasses::default();
    eq_classes.classes = equivalences.clone();
    eq_classes.minimize(None);
    let equivalences = eq_classes.classes;

    let old_join_mapper = JoinInputMapper::new(inputs.as_slice());
    // 1) Partition the join constraints into constraints containing a group
    //    key and constraints that don't.
    let (new_join_equivalences, component_equivalences): (Vec<_>, Vec<_>) = equivalences
        .iter()
        .cloned()
        .partition(|cls| cls.iter().any(|expr| group_key.contains(expr)));

    // 2) Find the connected components that remain after removing constraints
    //    containing the group_key. Also, track the set of constraints that
    //    connect the inputs in each component.
    // Initially every input is its own component.
    let mut components = (0..inputs.len()).map(Component::new).collect::<Vec<_>>();
    for equivalence in component_equivalences {
        // a) Find the inputs referenced by the constraint.
        let inputs_to_connect = BTreeSet::<usize>::from_iter(
            equivalence
                .iter()
                .flat_map(|expr| old_join_mapper.lookup_inputs(expr)),
        );
        // b) Extract the set of components that covers the inputs.
        let (mut components_to_connect, other): (Vec<_>, Vec<_>) = components
            .into_iter()
            .partition(|c| c.inputs.iter().any(|i| inputs_to_connect.contains(i)));
        components = other;
        // c) Connect the components and push the result back into the list of
        // components.
        // NOTE(review): if the constraint references no input at all,
        // `components_to_connect` is empty and the constraint is not recorded
        // in any component.
        if let Some(mut connected_component) = components_to_connect.pop() {
            connected_component.connect(components_to_connect, equivalence);
            components.push(connected_component);
        }
        // d) Abort reduction pushdown if there are fewer than two connected components.
        if components.len() < 2 {
            return None;
        }
    }
    // Put the components in a canonical order (derived `Ord`, i.e. primarily by
    // their input indices); this fixes the order of the new join's inputs.
    components.sort();
    // TODO: Connect components referenced by the same multi-input expression
    // contained in a constraint containing a GroupBy key.
    // For the example query below, there should be two components `{foo, bar}`
    // and `baz`.
    // ```
    // select sum(foo.b) from foo, bar, baz
    // where foo.a * bar.a = 24 group by foo.a * bar.a
    // ```

    // Maps (input idxs from old join) -> (idx of component it belongs to)
    let input_component_map = BTreeMap::from_iter(
        components
            .iter()
            .enumerate()
            .flat_map(|(c_idx, c)| c.inputs.iter().map(move |i| (*i, c_idx))),
    );

    // 3) Construct a reduce to push to each input
    let mut new_reduces = components
        .into_iter()
        .map(|component| ReduceBuilder::new(component, inputs, &old_join_mapper))
        .collect::<Vec<_>>();

    // The new projection and new join equivalences will reference columns
    // produced by the new reduces, but we don't know the arities of the new
    // reduces yet. Thus, they are temporarily stored as
    // `(component_idx, column_idx_relative_to_new_reduce)`.
    let mut new_projection = Vec::with_capacity(group_key.len());
    let mut new_join_equivalences_by_component = Vec::new();

    // 3a) Calculate the group key for each new reduce. We must make sure that
    // the union of group keys across the new reduces can produce:
    // (1) the group keys of the old reduce.
    // (2) every expression in the equivalences of the new join.
    for key in group_key {
        // i) Find the equivalence class that the key is in.
        //    This relies on the expression appearing in at most one equivalence class. This
        //    invariant is ensured in step 0).
        if let Some(cls) = new_join_equivalences
            .iter()
            .find(|cls| cls.iter().any(|expr| expr == key))
        {
            // ii) Rewrite the join equivalence in terms of columns produced by
            // the pushed down reduction.
            // The current arity of a builder is the index of the column that the
            // next `add_group_key` will produce, so it must be read first.
            let mut new_join_cls = Vec::new();
            for expr in cls {
                if let Some(component) =
                    lookup_corresponding_component(expr, &old_join_mapper, &input_component_map)
                {
                    if key == expr {
                        new_projection.push((component, new_reduces[component].arity()));
                    }
                    new_join_cls.push((component, new_reduces[component].arity()));
                    new_reduces[component].add_group_key(expr.clone());
                } else {
                    // Abort reduction pushdown if the expression does not
                    // refer to exactly one component.
                    return None;
                }
            }
            new_join_equivalences_by_component.push(new_join_cls);
        } else {
            // If GroupBy key does not belong in an equivalence class,
            // add the key to new projection + add it as a GroupBy key to
            // the new component
            if let Some(component) =
                lookup_corresponding_component(key, &old_join_mapper, &input_component_map)
            {
                new_projection.push((component, new_reduces[component].arity()));
                new_reduces[component].add_group_key(key.clone())
            } else {
                // Abort reduction pushdown if the expression does not
                // refer to exactly one component.
                return None;
            }
        }
    }

    // 3b) Deduce the aggregates that each reduce needs to calculate in order to
    // reconstruct each aggregate in the old reduce.
    // All group keys have been added above, so aggregate columns come after them.
    for agg in aggregates {
        if let Some(component) =
            lookup_corresponding_component(&agg.expr, &old_join_mapper, &input_component_map)
        {
            if !agg.distinct {
                // TODO: support non-distinct aggs.
                // For more details, see https://github.com/MaterializeInc/database-issues/issues/2924
                return None;
            }
            new_projection.push((component, new_reduces[component].arity()));
            new_reduces[component].add_aggregate(agg.clone());
        } else {
            // TODO: support multi- and zero- component aggs
            // For more details, see https://github.com/MaterializeInc/database-issues/issues/2924
            return None;
        }
    }

    // 4) Construct the new `MirRelationExpr`.
    // Now that every builder's arity is final, the temporary
    // `(component_idx, local_column)` pairs can be mapped to global columns of
    // the new join.
    let new_join_mapper =
        JoinInputMapper::new_from_input_arities(new_reduces.iter().map(|builder| builder.arity()));

    let new_inputs = new_reduces
        .into_iter()
        .map(|builder| builder.construct_reduce(monotonic, expected_group_size))
        .collect::<Vec<_>>();

    let new_equivalences = new_join_equivalences_by_component
        .into_iter()
        .map(|cls| {
            cls.into_iter()
                .map(|(idx, col)| {
                    MirScalarExpr::column(new_join_mapper.map_column_to_global(col, idx))
                })
                .collect::<Vec<_>>()
        })
        .collect::<Vec<_>>();

    let new_projection = new_projection
        .into_iter()
        .map(|(idx, col)| new_join_mapper.map_column_to_global(col, idx))
        .collect::<Vec<_>>();

    Some(MirRelationExpr::join_scalars(new_inputs, new_equivalences).project(new_projection))
}
354
355/// Returns None if `expr` does not belong to exactly one component.
356fn lookup_corresponding_component(
357    expr: &MirScalarExpr,
358    old_join_mapper: &JoinInputMapper,
359    input_component_map: &BTreeMap<usize, usize>,
360) -> Option<usize> {
361    let mut dedupped = old_join_mapper
362        .lookup_inputs(expr)
363        .map(|i| input_component_map[&i])
364        .collect::<BTreeSet<_>>();
365    if dedupped.len() == 1 {
366        dedupped.pop_first()
367    } else {
368        None
369    }
370}
371
372/// A subjoin represented as a multigraph.
373#[derive(Eq, Ord, PartialEq, PartialOrd)]
374struct Component {
375    /// Index numbers of the inputs in the subjoin.
376    /// Are the vertices in the multigraph.
377    inputs: Vec<usize>,
378    /// The edges in the multigraph.
379    constraints: Vec<Vec<MirScalarExpr>>,
380}
381
382impl Component {
383    /// Create a new component that contains only one input.
384    fn new(i: usize) -> Self {
385        Component {
386            inputs: vec![i],
387            constraints: Vec::new(),
388        }
389    }
390
391    /// Connect `self` with `others` using the edge `connecting_constraint`.
392    fn connect(&mut self, others: Vec<Component>, connecting_constraint: Vec<MirScalarExpr>) {
393        self.constraints.push(connecting_constraint);
394        for mut other in others {
395            self.inputs.append(&mut other.inputs);
396            self.constraints.append(&mut other.constraints);
397        }
398        self.inputs.sort();
399        self.inputs.dedup();
400    }
401}
402
403/// Constructs a Reduce around a component, localizing column references.
404struct ReduceBuilder {
405    input: MirRelationExpr,
406    group_key: Vec<MirScalarExpr>,
407    aggregates: Vec<AggregateExpr>,
408    /// Maps (global column relative to old join) -> (local column relative to `input`)
409    localize_map: BTreeMap<usize, usize>,
410}
411
412impl ReduceBuilder {
413    fn new(
414        mut component: Component,
415        inputs: &Vec<MirRelationExpr>,
416        old_join_mapper: &JoinInputMapper,
417    ) -> Self {
418        let localize_map = component
419            .inputs
420            .iter()
421            .flat_map(|i| old_join_mapper.global_columns(*i))
422            .enumerate()
423            .map(|(local, global)| (global, local))
424            .collect::<BTreeMap<_, _>>();
425        // Convert the subjoin from the `Component` representation to a
426        // `MirRelationExpr` representation.
427        let mut inputs = component
428            .inputs
429            .iter()
430            .map(|i| inputs[*i].clone())
431            .collect::<Vec<_>>();
432        // Constraints need to be localized to the subjoin.
433        for constraint in component.constraints.iter_mut() {
434            for expr in constraint.iter_mut() {
435                expr.permute_map(&localize_map)
436            }
437        }
438        let input = if inputs.len() == 1 {
439            let mut predicates = Vec::new();
440            for class in component.constraints {
441                for expr in class[1..].iter() {
442                    predicates.push(
443                        class[0]
444                            .clone()
445                            .call_binary(expr.clone(), mz_expr::func::Eq)
446                            .or(class[0]
447                                .clone()
448                                .call_is_null()
449                                .and(expr.clone().call_is_null())),
450                    );
451                }
452            }
453            inputs.pop().unwrap().filter(predicates)
454        } else {
455            MirRelationExpr::join_scalars(inputs, component.constraints)
456        };
457        Self {
458            input,
459            group_key: Vec::new(),
460            aggregates: Vec::new(),
461            localize_map,
462        }
463    }
464
465    fn add_group_key(&mut self, mut key: MirScalarExpr) {
466        key.permute_map(&self.localize_map);
467        self.group_key.push(key);
468    }
469
470    fn add_aggregate(&mut self, mut agg: AggregateExpr) {
471        agg.expr.permute_map(&self.localize_map);
472        self.aggregates.push(agg);
473    }
474
475    fn arity(&self) -> usize {
476        self.group_key.len() + self.aggregates.len()
477    }
478
479    fn construct_reduce(
480        self,
481        monotonic: bool,
482        expected_group_size: Option<u64>,
483    ) -> MirRelationExpr {
484        MirRelationExpr::Reduce {
485            input: Box::new(self.input),
486            group_key: self.group_key,
487            aggregates: self.aggregates,
488            monotonic,
489            expected_group_size,
490        }
491    }
492}