mz_transform/
canonicalization.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Transformations that bring relation expressions to their canonical form.
11//!
//! This is achieved by:
13//! 1. Bringing enclosed scalar expressions to a canonical form,
14//! 2. Converting / peeling off part of the enclosing relation expression into
15//!    another relation expression that can represent the same concept.
16
17mod flat_map_elimination;
18mod projection_extraction;
19mod topk_elision;
20
21pub use flat_map_elimination::FlatMapElimination;
22use itertools::Itertools;
23pub use projection_extraction::ProjectionExtraction;
24pub use topk_elision::TopKElision;
25
26use mz_expr::MirRelationExpr;
27
28use crate::TransformCtx;
29use crate::analysis::{DerivedBuilder, SqlRelationType};
30
/// Transform that walks a relation expression and reduces the scalar
/// expressions held by its operators.
///
/// The type carries no state; all of the work happens in its
/// `crate::Transform` implementation.
#[derive(Debug)]
pub struct ReduceScalars;
34
35impl crate::Transform for ReduceScalars {
36    fn name(&self) -> &'static str {
37        "ReduceScalars"
38    }
39
40    #[mz_ore::instrument(
41        target = "optimizer",
42        level = "debug",
43        fields(path.segment = "reduce_scalars")
44    )]
45    fn actually_perform_transform(
46        &self,
47        relation: &mut MirRelationExpr,
48        ctx: &mut TransformCtx,
49    ) -> Result<(), crate::TransformError> {
50        let mut builder = DerivedBuilder::new(ctx.features);
51        builder.require(SqlRelationType);
52        let derived = builder.visit(&*relation);
53
54        // Descend the AST, reducing scalar expressions.
55        let mut todo = vec![(&mut *relation, derived.as_view())];
56        while let Some((expr, view)) = todo.pop() {
57            match expr {
58                MirRelationExpr::Constant { .. }
59                | MirRelationExpr::Get { .. }
60                | MirRelationExpr::Let { .. }
61                | MirRelationExpr::LetRec { .. }
62                | MirRelationExpr::Project { .. }
63                | MirRelationExpr::Union { .. }
64                | MirRelationExpr::Threshold { .. }
65                | MirRelationExpr::Negate { .. } => {
66                    // No expressions to reduce
67                }
68                MirRelationExpr::ArrangeBy { .. } => {
69                    // Has expressions, but we aren't brave enough to reduce these yet.
70                }
71                MirRelationExpr::Filter { predicates, .. } => {
72                    let input_type = view
73                        .last_child()
74                        .value::<SqlRelationType>()
75                        .expect("SqlRelationType required")
76                        .as_ref()
77                        .unwrap();
78                    for predicate in predicates.iter_mut() {
79                        predicate.reduce(input_type);
80                    }
81                    predicates.retain(|p| !p.is_literal_true());
82                }
83                MirRelationExpr::FlatMap { exprs, .. } => {
84                    let input_type = view
85                        .last_child()
86                        .value::<SqlRelationType>()
87                        .expect("SqlRelationType required")
88                        .as_ref()
89                        .unwrap();
90                    for expr in exprs.iter_mut() {
91                        expr.reduce(input_type);
92                    }
93                }
94                MirRelationExpr::Map { scalars, .. } => {
95                    // Use the output type, to incorporate the types of `scalars` as they land.
96                    let output_type = view
97                        .value::<SqlRelationType>()
98                        .expect("SqlRelationType required")
99                        .as_ref()
100                        .unwrap();
101                    let input_arity = output_type.len() - scalars.len();
102                    for (index, scalar) in scalars.iter_mut().enumerate() {
103                        scalar.reduce(&output_type[..input_arity + index]);
104                    }
105                }
106                MirRelationExpr::Join { equivalences, .. } => {
107                    let mut children: Vec<_> = view.children_rev().collect::<Vec<_>>();
108                    children.reverse();
109                    let input_types = children
110                        .iter()
111                        .flat_map(|c| {
112                            c.value::<SqlRelationType>()
113                                .expect("SqlRelationType required")
114                                .as_ref()
115                                .unwrap()
116                                .iter()
117                                .cloned()
118                        })
119                        .collect::<Vec<_>>();
120
121                    for class in equivalences.iter_mut() {
122                        for expr in class.iter_mut() {
123                            expr.reduce(&input_types[..]);
124                        }
125                        class.sort();
126                        class.dedup();
127                    }
128                    equivalences.retain(|e| e.len() > 1);
129                    equivalences.sort();
130                    equivalences.dedup();
131                }
132                MirRelationExpr::Reduce {
133                    group_key,
134                    aggregates,
135                    ..
136                } => {
137                    let input_type = view
138                        .last_child()
139                        .value::<SqlRelationType>()
140                        .expect("SqlRelationType required")
141                        .as_ref()
142                        .unwrap();
143                    for key in group_key.iter_mut() {
144                        key.reduce(input_type);
145                    }
146                    for aggregate in aggregates.iter_mut() {
147                        aggregate.expr.reduce(input_type);
148                    }
149                }
150                MirRelationExpr::TopK { limit, .. } => {
151                    let input_type = view
152                        .last_child()
153                        .value::<SqlRelationType>()
154                        .expect("SqlRelationType required")
155                        .as_ref()
156                        .unwrap();
157                    if let Some(limit) = limit {
158                        limit.reduce(input_type);
159                    }
160                }
161            }
162            todo.extend(expr.children_mut().rev().zip_eq(view.children_rev()))
163        }
164
165        mz_repr::explain::trace_plan(&*relation);
166        Ok(())
167    }
168}