mz_transform/
reduce_reduction.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

//! Breaks complex `Reduce` variants into a join of simpler variants.
//!
//! Specifically, any `Reduce` that contains two different "types" of aggregation,
//! in the sense of `ReductionType`, will be broken into one `Reduce` for each
//! type of aggregation, each containing the aggregations of that type,
//! and the results are then joined back together. Basic aggregations are the
//! exception: each one gets a `Reduce` of its own, so a `Reduce` with two or more
//! basic aggregations is split as well.

17use crate::TransformCtx;
18use mz_compute_types::plan::reduce::reduction_type;
19use mz_expr::MirRelationExpr;
20
/// Optimizer transform that breaks complex `Reduce` variants into a join of simpler variants.
///
/// A `Reduce` whose aggregates would be rendered along several independent dataflow paths is
/// rewritten into one `Reduce` per path, and the pieces are joined back together on the group
/// key. The rewrite only happens when the `enable_reduce_reduction` feature flag is set.
#[derive(Debug)]
pub struct ReduceReduction;

25impl crate::Transform for ReduceReduction {
26    fn name(&self) -> &'static str {
27        "ReduceReduction"
28    }
29
30    /// Transforms an expression through accumulated knowledge.
31    #[mz_ore::instrument(
32        target = "optimizer",
33        level = "debug",
34        fields(path.segment = "reduce_reduction")
35    )]
36    fn actually_perform_transform(
37        &self,
38        relation: &mut MirRelationExpr,
39        ctx: &mut TransformCtx,
40    ) -> Result<(), crate::TransformError> {
41        if ctx.features.enable_reduce_reduction {
42            relation.visit_pre_mut(&mut Self::action);
43            mz_repr::explain::trace_plan(&*relation);
44        }
45        Ok(())
46    }
47}
48
49impl ReduceReduction {
50    /// Breaks complex `Reduce` variants into a join of simpler variants.
51    pub fn action(relation: &mut MirRelationExpr) {
52        if let MirRelationExpr::Reduce {
53            input,
54            group_key,
55            aggregates,
56            monotonic,
57            expected_group_size,
58        } = relation
59        {
60            // We start by segmenting the aggregates into those that should be rendered independently.
61            // Each element of this list is a pair of lists describing a bundle of aggregations that
62            // should be applied independently. Each pair of lists correspond to the aggregaties and
63            // the column positions in which they should appear in the output.
64            // Perhaps these should be lists of pairs, to ensure they align, but their subsequent use
65            // is as the shredded lists.
66            let mut segmented_aggregates: Vec<(Vec<mz_expr::AggregateExpr>, Vec<usize>)> =
67                Vec::new();
68
69            // Our rendering currently produces independent dataflow paths for 1. all accumulable aggregations,
70            // 2. all hierarchical aggregations, and 3. *each* basic aggregation.
71            // We'll form groups for accumulable, hierarchical, and a list of basic aggregates.
72            let mut accumulable = (Vec::new(), Vec::new());
73            let mut hierarchical = (Vec::new(), Vec::new());
74
75            use mz_compute_types::plan::reduce::ReductionType;
76            for (index, aggr) in aggregates.iter().enumerate() {
77                match reduction_type(&aggr.func) {
78                    ReductionType::Accumulable => {
79                        accumulable.0.push(aggr.clone());
80                        accumulable.1.push(group_key.len() + index);
81                    }
82                    ReductionType::Hierarchical => {
83                        hierarchical.0.push(aggr.clone());
84                        hierarchical.1.push(group_key.len() + index);
85                    }
86                    ReductionType::Basic => segmented_aggregates
87                        .push((vec![aggr.clone()], vec![group_key.len() + index])),
88                }
89            }
90
91            // Fold in hierarchical and accumulable aggregates.
92            if !hierarchical.0.is_empty() {
93                segmented_aggregates.push(hierarchical);
94            }
95            if !accumulable.0.is_empty() {
96                segmented_aggregates.push(accumulable);
97            }
98            segmented_aggregates.sort();
99
100            // Do nothing unless there are at least two distinct types of aggregations.
101            if segmented_aggregates.len() < 2 {
102                return;
103            }
104
105            // For each type of aggregation we'll plan the corresponding `Reduce`,
106            // and then join the at-least-two `Reduce` stages together.
107            // TODO: Perhaps we should introduce a `Let` stage rather than clone the input?
108            let mut reduces = Vec::with_capacity(segmented_aggregates.len());
109            // Track the current and intended locations of each output column.
110            let mut columns = Vec::new();
111
112            for (aggrs, indexes) in segmented_aggregates {
113                columns.extend(0..group_key.len());
114                columns.extend(indexes);
115
116                reduces.push(MirRelationExpr::Reduce {
117                    input: input.clone(),
118                    group_key: group_key.clone(),
119                    aggregates: aggrs,
120                    monotonic: *monotonic,
121                    expected_group_size: *expected_group_size,
122                });
123            }
124
125            // Now build a `Join` of the reduces, on their keys, followed by a permutation of their aggregates.
126            // Equate all `group_key` columns in all inputs.
127            let mut equivalences = vec![Vec::with_capacity(reduces.len()); group_key.len()];
128            for column in 0..group_key.len() {
129                for input in 0..reduces.len() {
130                    equivalences[column].push((input, column));
131                }
132            }
133
134            // Determine projection that puts aggregate columns in their intended locations,
135            // and projects away repeated key columns.
136            let max_column = columns.iter().max().expect("Non-empty aggregates expected");
137            let mut projection = Vec::with_capacity(max_column + 1);
138            for column in 0..max_column + 1 {
139                projection.push(columns.iter().position(|c| *c == column).unwrap())
140            }
141
142            // Now make the join.
143            *relation = MirRelationExpr::join(reduces, equivalences).project(projection);
144        }
145    }
146}