Skip to main content

mz_transform/
reduce_reduction.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
//! Breaks complex `Reduce` variants into a join of simpler variants.
//!
//! Specifically, the aggregates of a `Reduce` are grouped by their `ReductionType`: all
//! accumulable aggregates form one group, all hierarchical aggregates form another, and each
//! basic aggregate forms a group of its own. Whenever this yields at least two groups, the
//! `Reduce` is broken into one `Reduce` per group, and the results are then joined back
//! together on the group key.
16
17use crate::TransformCtx;
18use mz_compute_types::plan::reduce::reduction_type;
19use mz_expr::MirRelationExpr;
20
21/// Breaks complex `Reduce` variants into a join of simpler variants.
22#[derive(Debug)]
23pub struct ReduceReduction;
24
25impl crate::Transform for ReduceReduction {
26    fn name(&self) -> &'static str {
27        "ReduceReduction"
28    }
29
30    /// Transforms an expression through accumulated knowledge.
31    #[mz_ore::instrument(
32        target = "optimizer",
33        level = "debug",
34        fields(path.segment = "reduce_reduction")
35    )]
36    fn actually_perform_transform(
37        &self,
38        relation: &mut MirRelationExpr,
39        _ctx: &mut TransformCtx,
40    ) -> Result<(), crate::TransformError> {
41        relation.visit_pre_mut(&mut Self::action);
42        mz_repr::explain::trace_plan(&*relation);
43        Ok(())
44    }
45}
46
47impl ReduceReduction {
48    /// Breaks complex `Reduce` variants into a join of simpler variants.
49    pub fn action(relation: &mut MirRelationExpr) {
50        if let MirRelationExpr::Reduce {
51            input,
52            group_key,
53            aggregates,
54            monotonic,
55            expected_group_size,
56        } = relation
57        {
58            // We start by segmenting the aggregates into those that should be rendered independently.
59            // Each element of this list is a pair of lists describing a bundle of aggregations that
60            // should be applied independently. Each pair of lists correspond to the aggregates and
61            // the column positions in which they should appear in the output.
62            // Perhaps these should be lists of pairs, to ensure they align, but their subsequent use
63            // is as the shredded lists.
64            let mut segmented_aggregates: Vec<(Vec<mz_expr::AggregateExpr>, Vec<usize>)> =
65                Vec::new();
66
67            // Our rendering currently produces independent dataflow paths for 1. all accumulable aggregations,
68            // 2. all hierarchical aggregations, and 3. *each* basic aggregation.
69            // We'll form groups for accumulable, hierarchical, and a list of basic aggregates.
70            let mut accumulable = (Vec::new(), Vec::new());
71            let mut hierarchical = (Vec::new(), Vec::new());
72
73            use mz_compute_types::plan::reduce::ReductionType;
74            for (index, aggr) in aggregates.iter().enumerate() {
75                match reduction_type(&aggr.func) {
76                    ReductionType::Accumulable => {
77                        accumulable.0.push(aggr.clone());
78                        accumulable.1.push(group_key.len() + index);
79                    }
80                    ReductionType::Hierarchical => {
81                        hierarchical.0.push(aggr.clone());
82                        hierarchical.1.push(group_key.len() + index);
83                    }
84                    ReductionType::Basic => segmented_aggregates
85                        .push((vec![aggr.clone()], vec![group_key.len() + index])),
86                }
87            }
88
89            // Fold in hierarchical and accumulable aggregates.
90            if !hierarchical.0.is_empty() {
91                segmented_aggregates.push(hierarchical);
92            }
93            if !accumulable.0.is_empty() {
94                segmented_aggregates.push(accumulable);
95            }
96            segmented_aggregates.sort();
97
98            // Do nothing unless there are at least two distinct types of aggregations.
99            if segmented_aggregates.len() < 2 {
100                return;
101            }
102
103            // For each type of aggregation we'll plan the corresponding `Reduce`,
104            // and then join the at-least-two `Reduce` stages together.
105            // TODO: Perhaps we should introduce a `Let` stage rather than clone the input?
106            let mut reduces = Vec::with_capacity(segmented_aggregates.len());
107            // Track the current and intended locations of each output column.
108            let mut columns = Vec::new();
109
110            for (aggrs, indexes) in segmented_aggregates {
111                columns.extend(0..group_key.len());
112                columns.extend(indexes);
113
114                reduces.push(MirRelationExpr::Reduce {
115                    input: input.clone(),
116                    group_key: group_key.clone(),
117                    aggregates: aggrs,
118                    monotonic: *monotonic,
119                    expected_group_size: *expected_group_size,
120                });
121            }
122
123            // Now build a `Join` of the reduces, on their keys, followed by a permutation of their aggregates.
124            // Equate all `group_key` columns in all inputs.
125            let mut equivalences = vec![Vec::with_capacity(reduces.len()); group_key.len()];
126            for column in 0..group_key.len() {
127                for input in 0..reduces.len() {
128                    equivalences[column].push((input, column));
129                }
130            }
131
132            // Determine projection that puts aggregate columns in their intended locations,
133            // and projects away repeated key columns.
134            let max_column = columns.iter().max().expect("Non-empty aggregates expected");
135            let mut projection = Vec::with_capacity(max_column + 1);
136            for column in 0..max_column + 1 {
137                projection.push(columns.iter().position(|c| *c == column).unwrap())
138            }
139
140            // Now make the join.
141            *relation = MirRelationExpr::join(reduces, equivalences).project(projection);
142        }
143    }
144}