Skip to main content

mz_transform/
canonicalization.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Transformations that bring relation expressions to their canonical form.
11//!
//! This is achieved by:
13//! 1. Bringing enclosed scalar expressions to a canonical form,
14//! 2. Converting / peeling off part of the enclosing relation expression into
15//!    another relation expression that can represent the same concept.
16
17mod flat_map_elimination;
18mod projection_extraction;
19mod topk_elision;
20
21pub use flat_map_elimination::FlatMapElimination;
22use itertools::Itertools;
23pub use projection_extraction::ProjectionExtraction;
24pub use topk_elision::TopKElision;
25
26use crate::TransformCtx;
27use crate::analysis::{DerivedBuilder, ReprRelationType};
28use mz_expr::MirRelationExpr;
29use mz_repr::ReprColumnType;
30
31/// A transform that visits each AST node and reduces scalar expressions.
32#[derive(Debug)]
33pub struct ReduceScalars;
34
35impl crate::Transform for ReduceScalars {
36    fn name(&self) -> &'static str {
37        "ReduceScalars"
38    }
39
40    #[mz_ore::instrument(
41        target = "optimizer",
42        level = "debug",
43        fields(path.segment = "reduce_scalars")
44    )]
45    fn actually_perform_transform(
46        &self,
47        relation: &mut MirRelationExpr,
48        ctx: &mut TransformCtx,
49    ) -> Result<(), crate::TransformError> {
50        let mut builder = DerivedBuilder::new(ctx.features);
51        builder.require(ReprRelationType);
52        let derived = builder.visit(&*relation);
53
54        // Descend the AST, reducing scalar expressions.
55        let mut todo = vec![(&mut *relation, derived.as_view())];
56        while let Some((expr, view)) = todo.pop() {
57            match expr {
58                MirRelationExpr::Constant { .. }
59                | MirRelationExpr::Get { .. }
60                | MirRelationExpr::Let { .. }
61                | MirRelationExpr::LetRec { .. }
62                | MirRelationExpr::Project { .. }
63                | MirRelationExpr::Union { .. }
64                | MirRelationExpr::Threshold { .. }
65                | MirRelationExpr::Negate { .. } => {
66                    // No expressions to reduce
67                }
68                MirRelationExpr::ArrangeBy { .. } => {
69                    // Has expressions, but we aren't brave enough to reduce these yet.
70                }
71                MirRelationExpr::Filter { predicates, .. } => {
72                    let input_type: &Vec<ReprColumnType> = view
73                        .last_child()
74                        .value::<ReprRelationType>()
75                        .expect("ReprRelationType required")
76                        .as_ref()
77                        .unwrap();
78                    for predicate in predicates.iter_mut() {
79                        predicate.reduce(input_type);
80                    }
81                    predicates.retain(|p| !p.is_literal_true());
82                }
83                MirRelationExpr::FlatMap { exprs, .. } => {
84                    let input_type: &Vec<ReprColumnType> = view
85                        .last_child()
86                        .value::<ReprRelationType>()
87                        .expect("ReprRelationType required")
88                        .as_ref()
89                        .unwrap();
90                    for expr in exprs.iter_mut() {
91                        expr.reduce(input_type);
92                    }
93                }
94                MirRelationExpr::Map { scalars, .. } => {
95                    // Use the output type, to incorporate the types of `scalars` as they land.
96                    let output_type: &Vec<ReprColumnType> = view
97                        .value::<ReprRelationType>()
98                        .expect("ReprRelationType required")
99                        .as_ref()
100                        .unwrap();
101                    let input_arity = output_type.len() - scalars.len();
102                    for (index, scalar) in scalars.iter_mut().enumerate() {
103                        scalar.reduce(&output_type[..input_arity + index]);
104                    }
105                }
106                MirRelationExpr::Join { equivalences, .. } => {
107                    let mut children: Vec<_> = view.children_rev().collect::<Vec<_>>();
108                    children.reverse();
109                    let input_types: Vec<ReprColumnType> = children
110                        .iter()
111                        .flat_map(|c| {
112                            c.value::<ReprRelationType>()
113                                .expect("ReprRelationType required")
114                                .as_ref()
115                                .unwrap()
116                                .iter()
117                                .cloned()
118                        })
119                        .collect();
120                    for class in equivalences.iter_mut() {
121                        for expr in class.iter_mut() {
122                            expr.reduce(&input_types[..]);
123                        }
124                        class.sort();
125                        class.dedup();
126                    }
127                    equivalences.retain(|e| e.len() > 1);
128                    equivalences.sort();
129                    equivalences.dedup();
130                }
131                MirRelationExpr::Reduce {
132                    group_key,
133                    aggregates,
134                    ..
135                } => {
136                    let input_type: &Vec<ReprColumnType> = view
137                        .last_child()
138                        .value::<ReprRelationType>()
139                        .expect("ReprRelationType required")
140                        .as_ref()
141                        .unwrap();
142                    for key in group_key.iter_mut() {
143                        key.reduce(input_type);
144                    }
145                    for aggregate in aggregates.iter_mut() {
146                        aggregate.expr.reduce(input_type);
147                    }
148                }
149                MirRelationExpr::TopK { limit, .. } => {
150                    let input_type: &Vec<ReprColumnType> = view
151                        .last_child()
152                        .value::<ReprRelationType>()
153                        .expect("ReprRelationType required")
154                        .as_ref()
155                        .unwrap();
156                    if let Some(limit) = limit {
157                        limit.reduce(input_type);
158                    }
159                }
160            }
161            todo.extend(expr.children_mut().rev().zip_eq(view.children_rev()))
162        }
163
164        mz_repr::explain::trace_plan(&*relation);
165        Ok(())
166    }
167}