Skip to main content

mz_expr/scalar/reduce/
if_then.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

//! Post-order rewrites for `If` nodes.
11
12use mz_repr::{Datum, ReprColumnType, ReprScalarType};
13
14use crate::MirScalarExpr;
15
16pub(super) fn reduce_if(e: &mut MirScalarExpr, column_types: &[ReprColumnType]) {
17    let MirScalarExpr::If { cond, then, els } = e else {
18        unreachable!()
19    };
20
21    if let Some(literal) = cond.as_literal() {
22        match literal {
23            Ok(Datum::True) => *e = then.take(),
24            Ok(Datum::False) | Ok(Datum::Null) => *e = els.take(),
25            Err(err) => {
26                *e = MirScalarExpr::Literal(
27                    Err(err.clone()),
28                    then.typ(column_types)
29                        .union(&els.typ(column_types))
30                        .unwrap(),
31                );
32            }
33            _ => unreachable!(),
34        }
35        return;
36    }
37    if then == els {
38        *e = then.take();
39        return;
40    }
41    if then.is_literal_ok()
42        && els.is_literal_ok()
43        && then.typ(column_types).scalar_type == ReprScalarType::Bool
44        && els.typ(column_types).scalar_type == ReprScalarType::Bool
45    {
46        match (then.as_literal(), els.as_literal()) {
47            // Note: NULLs from the condition should not be propagated to the result
48            // of the expression.
49            (Some(Ok(Datum::True)), _) => {
50                // Rewritten as ((<cond> IS NOT NULL) AND (<cond>)) OR (<els>)
51                // NULL <cond> results in: (FALSE AND NULL) OR (<els>) => (<els>)
52                *e = cond
53                    .clone()
54                    .call_is_null()
55                    .not()
56                    .and(cond.take())
57                    .or(els.take());
58            }
59            (Some(Ok(Datum::False)), _) => {
60                // Rewritten as ((NOT <cond>) OR (<cond> IS NULL)) AND (<els>)
61                // NULL <cond> results in: (NULL OR TRUE) AND (<els>) => TRUE AND (<els>) => (<els>)
62                *e = cond
63                    .clone()
64                    .not()
65                    .or(cond.take().call_is_null())
66                    .and(els.take());
67            }
68            (_, Some(Ok(Datum::True))) => {
69                // Rewritten as (NOT <cond>) OR (<cond> IS NULL) OR (<then>)
70                // NULL <cond> results in: NULL OR TRUE OR (<then>) => TRUE
71                *e = cond
72                    .clone()
73                    .not()
74                    .or(cond.take().call_is_null())
75                    .or(then.take());
76            }
77            (_, Some(Ok(Datum::False))) => {
78                // Rewritten as (<cond> IS NOT NULL) AND (<cond>) AND (<then>)
79                // NULL <cond> results in: FALSE AND NULL AND (<then>) => FALSE
80                *e = cond
81                    .clone()
82                    .call_is_null()
83                    .not()
84                    .and(cond.take())
85                    .and(then.take());
86            }
87            _ => {}
88        }
89        return;
90    }
91
92    // Equivalent expression structure would allow us to push the `If` into the expression.
93    // For example, `IF <cond> THEN x = y ELSE x = z` becomes `x = IF <cond> THEN y ELSE z`.
94    //
95    // We have to also make sure that the expressions that will end up in
96    // the two `If` branches have unionable types. Otherwise, the `If` could
97    // not be typed by `typ`. An example where this could cause an issue is
98    // when pulling out `cast_jsonbable_to_jsonb`, which accepts a wide
99    // range of input types. (In theory, we could still do the optimization
100    // in this case by inserting appropriate casts, but this corner case is
101    // not worth the complication for now.)
102    // See https://github.com/MaterializeInc/database-issues/issues/9182
103    match (&mut **then, &mut **els) {
104        (
105            MirScalarExpr::CallUnary { func: f1, expr: e1 },
106            MirScalarExpr::CallUnary { func: f2, expr: e2 },
107        ) if f1 == f2 && e1.typ(column_types) == e2.typ(column_types) => {
108            *e = cond
109                .take()
110                .if_then_else(e1.take(), e2.take())
111                .call_unary(f1.clone());
112        }
113        (
114            MirScalarExpr::CallBinary {
115                func: f1,
116                expr1: e1a,
117                expr2: e2a,
118            },
119            MirScalarExpr::CallBinary {
120                func: f2,
121                expr1: e1b,
122                expr2: e2b,
123            },
124        ) if f1 == f2 && e1a == e1b && e2a.typ(column_types) == e2b.typ(column_types) => {
125            *e = e1a
126                .take()
127                .call_binary(cond.take().if_then_else(e2a.take(), e2b.take()), f1.clone());
128        }
129        (
130            MirScalarExpr::CallBinary {
131                func: f1,
132                expr1: e1a,
133                expr2: e2a,
134            },
135            MirScalarExpr::CallBinary {
136                func: f2,
137                expr1: e1b,
138                expr2: e2b,
139            },
140        ) if f1 == f2 && e2a == e2b && e1a.typ(column_types) == e1b.typ(column_types) => {
141            *e = cond
142                .take()
143                .if_then_else(e1a.take(), e1b.take())
144                .call_binary(e2a.take(), f1.clone());
145        }
146        _ => {}
147    }
148}