// mz_transform/fusion/reduce.rs
// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.
//! Fuses reduce operators with parent operators if possible.
use mz_expr::{MirRelationExpr, MirScalarExpr};
use crate::{TransformCtx, TransformError};
/// Transform that fuses a `Reduce` operator with a `Reduce` operator that is
/// its direct input, if possible.
///
/// The type carries no state; the rewrite applied to each expression node is
/// implemented by [`Reduce::action`].
#[derive(Debug)]
pub struct Reduce;
impl crate::Transform for Reduce {
    /// Returns the name under which this transform is reported.
    fn name(&self) -> &'static str {
        "ReduceFusion"
    }

    /// Runs [`Reduce::action`] on the nodes of `relation` through
    /// `visit_pre_mut`, then hands the resulting plan to
    /// `mz_repr::explain::trace_plan`.
    ///
    /// The transform context is not consulted. This implementation has no
    /// failure path of its own and always returns `Ok(())`.
    #[mz_ore::instrument(
        target = "optimizer",
        level = "debug",
        fields(path.segment = "reduce_fusion")
    )]
    fn actually_perform_transform(
        &self,
        relation: &mut MirRelationExpr,
        _: &mut TransformCtx,
    ) -> Result<(), TransformError> {
        // The traversal is run purely for its in-place rewrites; it yields no
        // value worth binding.
        relation.visit_pre_mut(|e| self.action(e));
        mz_repr::explain::trace_plan(&*relation);
        Ok(())
    }
}
impl Reduce {
    /// Attempts to fuse a `Reduce` with a `Reduce` that is its direct input.
    ///
    /// The rewrite fires only when all of the following hold:
    ///
    /// * `relation` is a `Reduce` whose input is itself a `Reduce`;
    /// * every column referenced by the outer group key has an index below
    ///   the length of the inner group key, i.e. the outer one does not
    ///   group by an aggregation performed by the inner one;
    /// * neither of the two operators has any aggregates.
    ///
    /// In that case the inner `Reduce` is replaced (no grouping) by a `map`
    /// that computes its non-column key expressions followed by a `project`
    /// that puts the key columns in their original order. In every other
    /// case `relation` is left untouched.
    pub fn action(&self, relation: &mut MirRelationExpr) {
        // Guard: the node itself must be a `Reduce`.
        let (input, group_key, aggregates) = match relation {
            MirRelationExpr::Reduce {
                input,
                group_key,
                aggregates,
                monotonic: _,
                expected_group_size: _,
            } => (input, group_key, aggregates),
            _ => return,
        };

        // Guard: its input must be a `Reduce` as well.
        let (inner_input, inner_group_key, inner_aggregates) = match &mut **input {
            MirRelationExpr::Reduce {
                input: inner_input,
                group_key: inner_group_key,
                aggregates: inner_aggregates,
                monotonic: _,
                expected_group_size: _,
            } => (inner_input, inner_group_key, inner_aggregates),
            _ => return,
        };

        // Scan every column reference in the outer group key. Fusion is only
        // possible as long as the outer operator does not group by an
        // aggregation performed by the inner one, i.e. all referenced
        // columns lie within the inner group key.
        let inner_key_count = inner_group_key.len();
        let mut touches_inner_aggregate = false;
        for key in group_key.iter() {
            key.visit_pre(|e| {
                if let MirScalarExpr::Column(col) = e {
                    touches_inner_aggregate |= *col >= inner_key_count;
                }
            });
        }
        if touches_inner_aggregate {
            return;
        }

        // Only the aggregate-free case is handled.
        if !aggregates.is_empty() || !inner_aggregates.is_empty() {
            return;
        }

        // Split the inner key into plain column references, which are
        // projected directly, and other expressions, which are appended
        // after the inner input's columns by a `map` and projected from there.
        let input_arity = inner_input.arity();
        let mut projection = Vec::with_capacity(inner_key_count);
        let mut mapped = Vec::new();
        for key in inner_group_key.iter() {
            match key {
                MirScalarExpr::Column(col) => projection.push(*col),
                computed => {
                    projection.push(input_arity + mapped.len());
                    mapped.push(computed.clone());
                }
            }
        }

        // Swap the inner `Reduce` for the equivalent map + project.
        **input = inner_input.take_dangerous().map(mapped).project(projection);
    }
}