Skip to main content

mz_transform/
coalesce_case.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
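//! Pushes `COALESCE` into `CASE WHEN`.
//!
//! `COALESCE(CASE WHEN c THEN t ELSE e END, ...)` is rewritten to
//! `CASE WHEN c THEN COALESCE(t, ...) ELSE COALESCE(e, ...) END`.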
11
12use mz_expr::func::variadic::Coalesce;
13use mz_expr::visit::Visit;
14use mz_expr::{AggregateExpr, MirRelationExpr, MirScalarExpr, VariadicFunc};
15use mz_ore::soft_panic_or_log;
16
17use crate::{TransformCtx, TransformError};
18
/// Optimizer transform that pushes `COALESCE` into `CASE WHEN` where possible.
///
/// The transform carries no state, so `Default` is derived instead of being
/// written out by hand.
#[derive(Debug, Default)]
pub struct CoalesceCase;
28
29impl crate::Transform for CoalesceCase {
30    fn name(&self) -> &'static str {
31        "CoalesceCase"
32    }
33
34    #[mz_ore::instrument(
35        target = "optimizer",
36        level = "debug",
37        fields(path.segment = "coalesce_case")
38    )]
39    fn actually_perform_transform(
40        &self,
41        relation: &mut MirRelationExpr,
42        _: &mut TransformCtx,
43    ) -> Result<(), TransformError> {
44        relation.visit_mut_post(&mut |e| self.action(e));
45        mz_repr::explain::trace_plan(&*relation);
46        Ok(())
47    }
48}
49
50impl CoalesceCase {
51    fn action(&self, relation: &mut MirRelationExpr) {
52        match relation {
53            MirRelationExpr::Constant { .. } |
54            MirRelationExpr::Get { .. } |
55            MirRelationExpr::Let { .. } |
56            MirRelationExpr::LetRec { .. } |
57            MirRelationExpr::Project { .. } |
58            MirRelationExpr::Negate { .. } |
59            MirRelationExpr::Threshold { .. } |
60            MirRelationExpr::Union { .. } |
61            // don't mess with arrangements, even though their keys have MSEs in them
62            // these _mostly_ shouldn't occur, since we run before join implementation, but some may be inserted earlier for us
63            MirRelationExpr::ArrangeBy { .. } => (),
64            MirRelationExpr::Map { scalars: exprs, .. } |
65            MirRelationExpr::Filter { predicates: exprs, .. } |
66            MirRelationExpr::FlatMap { exprs, .. } => {
67                // NB TableFunc doesn't ever hold an MSE
68                for expr in exprs.iter_mut() {
69                    self.rewrite_scalar(expr);
70                }
71            }
72            MirRelationExpr::Join { equivalences, implementation, .. } => {
73                if implementation.is_implemented() {
74                    soft_panic_or_log!("unexpected implemented Join when optimizing coalesce/case, skipping: {implementation:?}");
75                    return;
76                }
77
78                for equivalence in equivalences.iter_mut() {
79                    for expr in equivalence {
80                        self.rewrite_scalar(expr);
81                    }
82                }
83            }
84            MirRelationExpr::Reduce { group_key, aggregates, .. } => {
85                for expr in group_key.iter_mut() {
86                    self.rewrite_scalar(expr);
87                }
88
89                for agg in aggregates.iter_mut() {
90                    self.rewrite_aggreagte(agg);
91                }
92            }
93            MirRelationExpr::TopK { limit, .. } => {
94                if let Some(expr) = limit {
95                    self.rewrite_scalar(expr);
96                }
97            }
98        }
99    }
100
101    fn rewrite_scalar(&self, expr: &mut MirScalarExpr) {
102        // Visiting in pre-order means that when we push a `COALESCE` down, we'll keep pushing if the `CASE` chain continues.
103        expr.visit_mut_pre(&mut |e| self.try_combine_coalesce_case(e))
104    }
105
106    fn rewrite_aggreagte(&self, agg: &mut AggregateExpr) {
107        // NB AggregateFunc doesn't contain any MSEs
108        self.rewrite_scalar(&mut agg.expr)
109    }
110
111    fn try_combine_coalesce_case(&self, expr: &mut MirScalarExpr) {
112        // COALESCE(CASE WHEN e_cond THEN e_then ELSE e_else END, ...)
113        // ->
114        // CASE WHEN e_cond THEN COALESCE(e_then, ...) ELSE COALESCE(e_else, ...) END
115        expr.flatten_associative();
116
117        if let MirScalarExpr::CallVariadic { func, exprs } = expr
118            && *func == VariadicFunc::Coalesce(Coalesce)
119        {
120            if let MirScalarExpr::If { .. } = &exprs[0] {
121                let mut exprs = std::mem::take(exprs);
122                if let MirScalarExpr::If {
123                    mut cond,
124                    mut then,
125                    mut els,
126                } = exprs.remove(0)
127                {
128                    let cond = cond.take();
129
130                    let mut then_exprs = Vec::with_capacity(exprs.len() + 1);
131                    then_exprs.push(then.take());
132                    then_exprs.extend(exprs.iter().cloned());
133
134                    let mut else_exprs = Vec::with_capacity(exprs.len() + 1);
135                    else_exprs.push(els.take());
136                    else_exprs.append(&mut exprs);
137
138                    let t =
139                        MirScalarExpr::call_variadic(VariadicFunc::Coalesce(Coalesce), then_exprs);
140                    let f =
141                        MirScalarExpr::call_variadic(VariadicFunc::Coalesce(Coalesce), else_exprs);
142                    *expr = cond.if_then_else(t, f);
143                } else {
144                    unreachable!();
145                };
146            }
147        }
148    }
149}