mz_transform/
coalesce_case.rs1use mz_expr::func::variadic::Coalesce;
13use mz_expr::visit::Visit;
14use mz_expr::{AggregateExpr, MirRelationExpr, MirScalarExpr, VariadicFunc};
15use mz_ore::soft_panic_or_log;
16
17use crate::{TransformCtx, TransformError};
18
/// Optimizer pass that pushes a `coalesce` whose first argument is an `If` into both
/// branches of that `If`:
///
/// ```text
/// coalesce(if c then a else b, x, ...)
///   => if c then coalesce(a, x, ...) else coalesce(b, x, ...)
/// ```
///
/// The pass carries no configuration; `Default` is derived because the struct has no fields.
#[derive(Debug, Default)]
pub struct CoalesceCase;
28
29impl crate::Transform for CoalesceCase {
30 fn name(&self) -> &'static str {
31 "CoalesceCase"
32 }
33
34 #[mz_ore::instrument(
35 target = "optimizer",
36 level = "debug",
37 fields(path.segment = "coalesce_case")
38 )]
39 fn actually_perform_transform(
40 &self,
41 relation: &mut MirRelationExpr,
42 _: &mut TransformCtx,
43 ) -> Result<(), TransformError> {
44 relation.try_visit_mut_post(&mut |e| self.action(e))?;
45 mz_repr::explain::trace_plan(&*relation);
46 Ok(())
47 }
48}
49
50impl CoalesceCase {
51 fn action(&self, relation: &mut MirRelationExpr) -> Result<(), TransformError> {
52 match relation {
53 MirRelationExpr::Constant { .. } |
54 MirRelationExpr::Get { .. } |
55 MirRelationExpr::Let { .. } |
56 MirRelationExpr::LetRec { .. } |
57 MirRelationExpr::Project { .. } |
58 MirRelationExpr::Negate { .. } |
59 MirRelationExpr::Threshold { .. } |
60 MirRelationExpr::Union { .. } |
61 MirRelationExpr::ArrangeBy { .. } => (),
64 MirRelationExpr::Map { scalars: exprs, .. } |
65 MirRelationExpr::Filter { predicates: exprs, .. } |
66 MirRelationExpr::FlatMap { exprs, .. } => {
67 for expr in exprs.iter_mut() {
69 self.rewrite_scalar(expr)?;
70 }
71 }
72 MirRelationExpr::Join { equivalences, implementation, .. } => {
73 if implementation.is_implemented() {
74 soft_panic_or_log!("unexpected implemented Join when optimizing coalesce/case, skipping: {implementation:?}");
75 return Ok(());
76 }
77
78 for equivalence in equivalences.iter_mut() {
79 for expr in equivalence {
80 self.rewrite_scalar(expr)?;
81 }
82 }
83 }
84 MirRelationExpr::Reduce { group_key, aggregates, .. } => {
85 for expr in group_key.iter_mut() {
86 self.rewrite_scalar(expr)?;
87 }
88
89 for agg in aggregates.iter_mut() {
90 self.rewrite_aggreagte(agg)?;
91 }
92 }
93 MirRelationExpr::TopK { limit, .. } => {
94 if let Some(expr) = limit {
95 self.rewrite_scalar(expr)?;
96 }
97 }
98 }
99
100 Ok(())
101 }
102
103 fn rewrite_scalar(&self, expr: &mut MirScalarExpr) -> Result<(), TransformError> {
104 expr.visit_mut_pre(&mut |e| self.try_combine_coalesce_case(e))
106 .map_err(TransformError::from)
107 }
108
109 fn rewrite_aggreagte(&self, agg: &mut AggregateExpr) -> Result<(), TransformError> {
110 self.rewrite_scalar(&mut agg.expr)
112 }
113
114 fn try_combine_coalesce_case(&self, expr: &mut MirScalarExpr) {
115 expr.flatten_associative();
119
120 if let MirScalarExpr::CallVariadic { func, exprs } = expr
121 && *func == VariadicFunc::Coalesce(Coalesce)
122 {
123 if let MirScalarExpr::If { .. } = &exprs[0] {
124 let mut exprs = std::mem::take(exprs);
125 if let MirScalarExpr::If {
126 mut cond,
127 mut then,
128 mut els,
129 } = exprs.remove(0)
130 {
131 let cond = cond.take();
132
133 let mut then_exprs = Vec::with_capacity(exprs.len() + 1);
134 then_exprs.push(then.take());
135 then_exprs.extend(exprs.iter().cloned());
136
137 let mut else_exprs = Vec::with_capacity(exprs.len() + 1);
138 else_exprs.push(els.take());
139 else_exprs.append(&mut exprs);
140
141 let t =
142 MirScalarExpr::call_variadic(VariadicFunc::Coalesce(Coalesce), then_exprs);
143 let f =
144 MirScalarExpr::call_variadic(VariadicFunc::Coalesce(Coalesce), else_exprs);
145 *expr = cond.if_then_else(t, f);
146 } else {
147 unreachable!();
148 };
149 }
150 }
151 }
152}