mz_transform/
canonicalization.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Transformations that bring relation expressions to their canonical form.
11//!
12//! This is achieved by:
13//! 1. Bringing enclosed scalar expressions to a canonical form,
14//! 2. Converting / peeling off part of the enclosing relation expression into
15//!    another relation expression that can represent the same concept.
16
17mod flatmap_to_map;
18mod projection_extraction;
19mod topk_elision;
20
21pub use flatmap_to_map::FlatMapToMap;
22pub use projection_extraction::ProjectionExtraction;
23pub use topk_elision::TopKElision;
24
25use mz_expr::MirRelationExpr;
26
27use crate::TransformCtx;
28use crate::analysis::{DerivedBuilder, RelationType};
29
/// Transform that walks every node of a [`MirRelationExpr`] tree and reduces
/// the scalar expressions embedded in it.
///
/// The type carries no state; all of the work happens in its
/// `crate::Transform` implementation.
#[derive(Debug)]
pub struct ReduceScalars;
33
34impl crate::Transform for ReduceScalars {
35    fn name(&self) -> &'static str {
36        "ReduceScalars"
37    }
38
39    #[mz_ore::instrument(
40        target = "optimizer",
41        level = "debug",
42        fields(path.segment = "reduce_scalars")
43    )]
44    fn actually_perform_transform(
45        &self,
46        relation: &mut MirRelationExpr,
47        ctx: &mut TransformCtx,
48    ) -> Result<(), crate::TransformError> {
49        let mut builder = DerivedBuilder::new(ctx.features);
50        builder.require(RelationType);
51        let derived = builder.visit(&*relation);
52
53        // Descend the AST, reducing scalar expressions.
54        let mut todo = vec![(&mut *relation, derived.as_view())];
55        while let Some((expr, view)) = todo.pop() {
56            match expr {
57                MirRelationExpr::Constant { .. }
58                | MirRelationExpr::Get { .. }
59                | MirRelationExpr::Let { .. }
60                | MirRelationExpr::LetRec { .. }
61                | MirRelationExpr::Project { .. }
62                | MirRelationExpr::Union { .. }
63                | MirRelationExpr::Threshold { .. }
64                | MirRelationExpr::Negate { .. } => {
65                    // No expressions to reduce
66                }
67                MirRelationExpr::ArrangeBy { .. } => {
68                    // Has expressions, but we aren't brave enough to reduce these yet.
69                }
70                MirRelationExpr::Filter { predicates, .. } => {
71                    let input_type = view
72                        .last_child()
73                        .value::<RelationType>()
74                        .expect("RelationType required")
75                        .as_ref()
76                        .unwrap();
77                    for predicate in predicates.iter_mut() {
78                        predicate.reduce(input_type);
79                    }
80                    predicates.retain(|p| !p.is_literal_true());
81                }
82                MirRelationExpr::FlatMap { exprs, .. } => {
83                    let input_type = view
84                        .last_child()
85                        .value::<RelationType>()
86                        .expect("RelationType required")
87                        .as_ref()
88                        .unwrap();
89                    for expr in exprs.iter_mut() {
90                        expr.reduce(input_type);
91                    }
92                }
93                MirRelationExpr::Map { scalars, .. } => {
94                    // Use the output type, to incorporate the types of `scalars` as they land.
95                    let output_type = view
96                        .value::<RelationType>()
97                        .expect("RelationType required")
98                        .as_ref()
99                        .unwrap();
100                    let input_arity = output_type.len() - scalars.len();
101                    for (index, scalar) in scalars.iter_mut().enumerate() {
102                        scalar.reduce(&output_type[..input_arity + index]);
103                    }
104                }
105                MirRelationExpr::Join { equivalences, .. } => {
106                    let mut children: Vec<_> = view.children_rev().collect::<Vec<_>>();
107                    children.reverse();
108                    let input_types = children
109                        .iter()
110                        .flat_map(|c| {
111                            c.value::<RelationType>()
112                                .expect("RelationType required")
113                                .as_ref()
114                                .unwrap()
115                                .iter()
116                                .cloned()
117                        })
118                        .collect::<Vec<_>>();
119
120                    for class in equivalences.iter_mut() {
121                        for expr in class.iter_mut() {
122                            expr.reduce(&input_types[..]);
123                        }
124                        class.sort();
125                        class.dedup();
126                    }
127                    equivalences.retain(|e| e.len() > 1);
128                    equivalences.sort();
129                    equivalences.dedup();
130                }
131                MirRelationExpr::Reduce {
132                    group_key,
133                    aggregates,
134                    ..
135                } => {
136                    let input_type = view
137                        .last_child()
138                        .value::<RelationType>()
139                        .expect("RelationType required")
140                        .as_ref()
141                        .unwrap();
142                    for key in group_key.iter_mut() {
143                        key.reduce(input_type);
144                    }
145                    for aggregate in aggregates.iter_mut() {
146                        aggregate.expr.reduce(input_type);
147                    }
148                }
149                MirRelationExpr::TopK { limit, .. } => {
150                    let input_type = view
151                        .last_child()
152                        .value::<RelationType>()
153                        .expect("RelationType required")
154                        .as_ref()
155                        .unwrap();
156                    if let Some(limit) = limit {
157                        limit.reduce(input_type);
158                    }
159                }
160            }
161            todo.extend(expr.children_mut().rev().zip(view.children_rev()))
162        }
163
164        mz_repr::explain::trace_plan(&*relation);
165        Ok(())
166    }
167}