Skip to main content

mz_transform/
coalesce_case.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Pushes `COALESCE` into `CASE WHEN`
11
12use mz_expr::func::variadic::Coalesce;
13use mz_expr::visit::Visit;
14use mz_expr::{AggregateExpr, MirRelationExpr, MirScalarExpr, VariadicFunc};
15use mz_ore::soft_panic_or_log;
16
17use crate::{TransformCtx, TransformError};
18
/// Push `COALESCE` into `CASE WHEN` as possible.
///
/// Rewrites `COALESCE(CASE WHEN c THEN t ELSE e END, rest...)` into
/// `CASE WHEN c THEN COALESCE(t, rest...) ELSE COALESCE(e, rest...) END`.
///
/// The transform is stateless, so the derived `Default` is all that is needed to construct it.
#[derive(Debug, Default)]
pub struct CoalesceCase;
28
29impl crate::Transform for CoalesceCase {
30    fn name(&self) -> &'static str {
31        "CoalesceCase"
32    }
33
34    #[mz_ore::instrument(
35        target = "optimizer",
36        level = "debug",
37        fields(path.segment = "coalesce_case")
38    )]
39    fn actually_perform_transform(
40        &self,
41        relation: &mut MirRelationExpr,
42        _: &mut TransformCtx,
43    ) -> Result<(), TransformError> {
44        relation.try_visit_mut_post(&mut |e| self.action(e))?;
45        mz_repr::explain::trace_plan(&*relation);
46        Ok(())
47    }
48}
49
50impl CoalesceCase {
51    fn action(&self, relation: &mut MirRelationExpr) -> Result<(), TransformError> {
52        match relation {
53            MirRelationExpr::Constant { .. } |
54            MirRelationExpr::Get { .. } |
55            MirRelationExpr::Let { .. } |
56            MirRelationExpr::LetRec { .. } |
57            MirRelationExpr::Project { .. } |
58            MirRelationExpr::Negate { .. } |
59            MirRelationExpr::Threshold { .. } |
60            MirRelationExpr::Union { .. } |
61            // don't mess with arrangements, even though their keys have MSEs in them
62            // these _mostly_ shouldn't occur, since we run before join implementation, but some may be inserted earlier for us
63            MirRelationExpr::ArrangeBy { .. } => (),
64            MirRelationExpr::Map { scalars: exprs, .. } |
65            MirRelationExpr::Filter { predicates: exprs, .. } |
66            MirRelationExpr::FlatMap { exprs, .. } => {
67                // NB TableFunc doesn't ever hold an MSE
68                for expr in exprs.iter_mut() {
69                    self.rewrite_scalar(expr)?;
70                }
71            }
72            MirRelationExpr::Join { equivalences, implementation, .. } => {
73                if implementation.is_implemented() {
74                    soft_panic_or_log!("unexpected implemented Join when optimizing coalesce/case, skipping: {implementation:?}");
75                    return Ok(());
76                }
77
78                for equivalence in equivalences.iter_mut() {
79                    for expr in equivalence {
80                        self.rewrite_scalar(expr)?;
81                    }
82                }
83            }
84            MirRelationExpr::Reduce { group_key, aggregates, .. } => {
85                for expr in group_key.iter_mut() {
86                    self.rewrite_scalar(expr)?;
87                }
88
89                for agg in aggregates.iter_mut() {
90                    self.rewrite_aggreagte(agg)?;
91                }
92            }
93            MirRelationExpr::TopK { limit, .. } => {
94                if let Some(expr) = limit {
95                    self.rewrite_scalar(expr)?;
96                }
97            }
98        }
99
100        Ok(())
101    }
102
103    fn rewrite_scalar(&self, expr: &mut MirScalarExpr) -> Result<(), TransformError> {
104        // Visiting in pre-order means that when we push a `COALESCE` down, we'll keep pushing if the `CASE` chain continues.
105        expr.visit_mut_pre(&mut |e| self.try_combine_coalesce_case(e))
106            .map_err(TransformError::from)
107    }
108
109    fn rewrite_aggreagte(&self, agg: &mut AggregateExpr) -> Result<(), TransformError> {
110        // NB AggregateFunc doesn't contain any MSEs
111        self.rewrite_scalar(&mut agg.expr)
112    }
113
114    fn try_combine_coalesce_case(&self, expr: &mut MirScalarExpr) {
115        // COALESCE(CASE WHEN e_cond THEN e_then ELSE e_else END, ...)
116        // ->
117        // CASE WHEN e_cond THEN COALESCE(e_then, ...) ELSE COALESCE(e_else, ...) END
118        expr.flatten_associative();
119
120        if let MirScalarExpr::CallVariadic { func, exprs } = expr
121            && *func == VariadicFunc::Coalesce(Coalesce)
122        {
123            if let MirScalarExpr::If { .. } = &exprs[0] {
124                let mut exprs = std::mem::take(exprs);
125                if let MirScalarExpr::If {
126                    mut cond,
127                    mut then,
128                    mut els,
129                } = exprs.remove(0)
130                {
131                    let cond = cond.take();
132
133                    let mut then_exprs = Vec::with_capacity(exprs.len() + 1);
134                    then_exprs.push(then.take());
135                    then_exprs.extend(exprs.iter().cloned());
136
137                    let mut else_exprs = Vec::with_capacity(exprs.len() + 1);
138                    else_exprs.push(els.take());
139                    else_exprs.append(&mut exprs);
140
141                    let t =
142                        MirScalarExpr::call_variadic(VariadicFunc::Coalesce(Coalesce), then_exprs);
143                    let f =
144                        MirScalarExpr::call_variadic(VariadicFunc::Coalesce(Coalesce), else_exprs);
145                    *expr = cond.if_then_else(t, f);
146                } else {
147                    unreachable!();
148                };
149            }
150        }
151    }
152}