mz_transform/
canonicalization.rs1mod flat_map_elimination;
18mod projection_extraction;
19mod topk_elision;
20
21pub use flat_map_elimination::FlatMapElimination;
22use itertools::Itertools;
23pub use projection_extraction::ProjectionExtraction;
24pub use topk_elision::TopKElision;
25
26use mz_expr::MirRelationExpr;
27
28use crate::TransformCtx;
29use crate::analysis::{DerivedBuilder, SqlRelationType};
30
/// Optimizer transform that reduces (simplifies) the scalar expressions held by a
/// relation expression and canonicalizes join equivalence classes.
#[derive(Debug)]
pub struct ReduceScalars;
34
35impl crate::Transform for ReduceScalars {
36 fn name(&self) -> &'static str {
37 "ReduceScalars"
38 }
39
40 #[mz_ore::instrument(
41 target = "optimizer",
42 level = "debug",
43 fields(path.segment = "reduce_scalars")
44 )]
45 fn actually_perform_transform(
46 &self,
47 relation: &mut MirRelationExpr,
48 ctx: &mut TransformCtx,
49 ) -> Result<(), crate::TransformError> {
50 let mut builder = DerivedBuilder::new(ctx.features);
51 builder.require(SqlRelationType);
52 let derived = builder.visit(&*relation);
53
54 let mut todo = vec![(&mut *relation, derived.as_view())];
56 while let Some((expr, view)) = todo.pop() {
57 match expr {
58 MirRelationExpr::Constant { .. }
59 | MirRelationExpr::Get { .. }
60 | MirRelationExpr::Let { .. }
61 | MirRelationExpr::LetRec { .. }
62 | MirRelationExpr::Project { .. }
63 | MirRelationExpr::Union { .. }
64 | MirRelationExpr::Threshold { .. }
65 | MirRelationExpr::Negate { .. } => {
66 }
68 MirRelationExpr::ArrangeBy { .. } => {
69 }
71 MirRelationExpr::Filter { predicates, .. } => {
72 let input_type = view
73 .last_child()
74 .value::<SqlRelationType>()
75 .expect("SqlRelationType required")
76 .as_ref()
77 .unwrap();
78 for predicate in predicates.iter_mut() {
79 predicate.reduce(input_type);
80 }
81 predicates.retain(|p| !p.is_literal_true());
82 }
83 MirRelationExpr::FlatMap { exprs, .. } => {
84 let input_type = view
85 .last_child()
86 .value::<SqlRelationType>()
87 .expect("SqlRelationType required")
88 .as_ref()
89 .unwrap();
90 for expr in exprs.iter_mut() {
91 expr.reduce(input_type);
92 }
93 }
94 MirRelationExpr::Map { scalars, .. } => {
95 let output_type = view
97 .value::<SqlRelationType>()
98 .expect("SqlRelationType required")
99 .as_ref()
100 .unwrap();
101 let input_arity = output_type.len() - scalars.len();
102 for (index, scalar) in scalars.iter_mut().enumerate() {
103 scalar.reduce(&output_type[..input_arity + index]);
104 }
105 }
106 MirRelationExpr::Join { equivalences, .. } => {
107 let mut children: Vec<_> = view.children_rev().collect::<Vec<_>>();
108 children.reverse();
109 let input_types = children
110 .iter()
111 .flat_map(|c| {
112 c.value::<SqlRelationType>()
113 .expect("SqlRelationType required")
114 .as_ref()
115 .unwrap()
116 .iter()
117 .cloned()
118 })
119 .collect::<Vec<_>>();
120
121 for class in equivalences.iter_mut() {
122 for expr in class.iter_mut() {
123 expr.reduce(&input_types[..]);
124 }
125 class.sort();
126 class.dedup();
127 }
128 equivalences.retain(|e| e.len() > 1);
129 equivalences.sort();
130 equivalences.dedup();
131 }
132 MirRelationExpr::Reduce {
133 group_key,
134 aggregates,
135 ..
136 } => {
137 let input_type = view
138 .last_child()
139 .value::<SqlRelationType>()
140 .expect("SqlRelationType required")
141 .as_ref()
142 .unwrap();
143 for key in group_key.iter_mut() {
144 key.reduce(input_type);
145 }
146 for aggregate in aggregates.iter_mut() {
147 aggregate.expr.reduce(input_type);
148 }
149 }
150 MirRelationExpr::TopK { limit, .. } => {
151 let input_type = view
152 .last_child()
153 .value::<SqlRelationType>()
154 .expect("SqlRelationType required")
155 .as_ref()
156 .unwrap();
157 if let Some(limit) = limit {
158 limit.reduce(input_type);
159 }
160 }
161 }
162 todo.extend(expr.children_mut().rev().zip_eq(view.children_rev()))
163 }
164
165 mz_repr::explain::trace_plan(&*relation);
166 Ok(())
167 }
168}