mz_transform/
coalesce_case.rs1use mz_expr::func::variadic::Coalesce;
13use mz_expr::visit::Visit;
14use mz_expr::{AggregateExpr, MirRelationExpr, MirScalarExpr, VariadicFunc};
15use mz_ore::soft_panic_or_log;
16
17use crate::{TransformCtx, TransformError};
18
/// Optimizer transform that rewrites `coalesce` calls whose first argument is
/// a conditional, pushing the `coalesce` into both branches of that
/// conditional.
///
/// The type carries no state; `Default` is derived rather than hand-written
/// because the derived impl is exactly what the manual one produced.
#[derive(Debug, Default)]
pub struct CoalesceCase;
28
29impl crate::Transform for CoalesceCase {
30 fn name(&self) -> &'static str {
31 "CoalesceCase"
32 }
33
34 #[mz_ore::instrument(
35 target = "optimizer",
36 level = "debug",
37 fields(path.segment = "coalesce_case")
38 )]
39 fn actually_perform_transform(
40 &self,
41 relation: &mut MirRelationExpr,
42 _: &mut TransformCtx,
43 ) -> Result<(), TransformError> {
44 relation.visit_mut_post(&mut |e| self.action(e));
45 mz_repr::explain::trace_plan(&*relation);
46 Ok(())
47 }
48}
49
50impl CoalesceCase {
51 fn action(&self, relation: &mut MirRelationExpr) {
52 match relation {
53 MirRelationExpr::Constant { .. } |
54 MirRelationExpr::Get { .. } |
55 MirRelationExpr::Let { .. } |
56 MirRelationExpr::LetRec { .. } |
57 MirRelationExpr::Project { .. } |
58 MirRelationExpr::Negate { .. } |
59 MirRelationExpr::Threshold { .. } |
60 MirRelationExpr::Union { .. } |
61 MirRelationExpr::ArrangeBy { .. } => (),
64 MirRelationExpr::Map { scalars: exprs, .. } |
65 MirRelationExpr::Filter { predicates: exprs, .. } |
66 MirRelationExpr::FlatMap { exprs, .. } => {
67 for expr in exprs.iter_mut() {
69 self.rewrite_scalar(expr);
70 }
71 }
72 MirRelationExpr::Join { equivalences, implementation, .. } => {
73 if implementation.is_implemented() {
74 soft_panic_or_log!("unexpected implemented Join when optimizing coalesce/case, skipping: {implementation:?}");
75 return;
76 }
77
78 for equivalence in equivalences.iter_mut() {
79 for expr in equivalence {
80 self.rewrite_scalar(expr);
81 }
82 }
83 }
84 MirRelationExpr::Reduce { group_key, aggregates, .. } => {
85 for expr in group_key.iter_mut() {
86 self.rewrite_scalar(expr);
87 }
88
89 for agg in aggregates.iter_mut() {
90 self.rewrite_aggreagte(agg);
91 }
92 }
93 MirRelationExpr::TopK { limit, .. } => {
94 if let Some(expr) = limit {
95 self.rewrite_scalar(expr);
96 }
97 }
98 }
99 }
100
101 fn rewrite_scalar(&self, expr: &mut MirScalarExpr) {
102 expr.visit_mut_pre(&mut |e| self.try_combine_coalesce_case(e))
104 }
105
106 fn rewrite_aggreagte(&self, agg: &mut AggregateExpr) {
107 self.rewrite_scalar(&mut agg.expr)
109 }
110
111 fn try_combine_coalesce_case(&self, expr: &mut MirScalarExpr) {
112 expr.flatten_associative();
116
117 if let MirScalarExpr::CallVariadic { func, exprs } = expr
118 && *func == VariadicFunc::Coalesce(Coalesce)
119 {
120 if let MirScalarExpr::If { .. } = &exprs[0] {
121 let mut exprs = std::mem::take(exprs);
122 if let MirScalarExpr::If {
123 mut cond,
124 mut then,
125 mut els,
126 } = exprs.remove(0)
127 {
128 let cond = cond.take();
129
130 let mut then_exprs = Vec::with_capacity(exprs.len() + 1);
131 then_exprs.push(then.take());
132 then_exprs.extend(exprs.iter().cloned());
133
134 let mut else_exprs = Vec::with_capacity(exprs.len() + 1);
135 else_exprs.push(els.take());
136 else_exprs.append(&mut exprs);
137
138 let t =
139 MirScalarExpr::call_variadic(VariadicFunc::Coalesce(Coalesce), then_exprs);
140 let f =
141 MirScalarExpr::call_variadic(VariadicFunc::Coalesce(Coalesce), else_exprs);
142 *expr = cond.if_then_else(t, f);
143 } else {
144 unreachable!();
145 };
146 }
147 }
148 }
149}