mz_expr/scalar/reduce/if_then.rs
1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Post-order rewrites for `If` nodes.
11
12use mz_repr::{Datum, ReprColumnType, ReprScalarType};
13
14use crate::MirScalarExpr;
15
16pub(super) fn reduce_if(e: &mut MirScalarExpr, column_types: &[ReprColumnType]) {
17 let MirScalarExpr::If { cond, then, els } = e else {
18 unreachable!()
19 };
20
21 if let Some(literal) = cond.as_literal() {
22 match literal {
23 Ok(Datum::True) => *e = then.take(),
24 Ok(Datum::False) | Ok(Datum::Null) => *e = els.take(),
25 Err(err) => {
26 *e = MirScalarExpr::Literal(
27 Err(err.clone()),
28 then.typ(column_types)
29 .union(&els.typ(column_types))
30 .unwrap(),
31 );
32 }
33 _ => unreachable!(),
34 }
35 return;
36 }
37 if then == els {
38 *e = then.take();
39 return;
40 }
41 if then.is_literal_ok()
42 && els.is_literal_ok()
43 && then.typ(column_types).scalar_type == ReprScalarType::Bool
44 && els.typ(column_types).scalar_type == ReprScalarType::Bool
45 {
46 match (then.as_literal(), els.as_literal()) {
47 // Note: NULLs from the condition should not be propagated to the result
48 // of the expression.
49 (Some(Ok(Datum::True)), _) => {
50 // Rewritten as ((<cond> IS NOT NULL) AND (<cond>)) OR (<els>)
51 // NULL <cond> results in: (FALSE AND NULL) OR (<els>) => (<els>)
52 *e = cond
53 .clone()
54 .call_is_null()
55 .not()
56 .and(cond.take())
57 .or(els.take());
58 }
59 (Some(Ok(Datum::False)), _) => {
60 // Rewritten as ((NOT <cond>) OR (<cond> IS NULL)) AND (<els>)
61 // NULL <cond> results in: (NULL OR TRUE) AND (<els>) => TRUE AND (<els>) => (<els>)
62 *e = cond
63 .clone()
64 .not()
65 .or(cond.take().call_is_null())
66 .and(els.take());
67 }
68 (_, Some(Ok(Datum::True))) => {
69 // Rewritten as (NOT <cond>) OR (<cond> IS NULL) OR (<then>)
70 // NULL <cond> results in: NULL OR TRUE OR (<then>) => TRUE
71 *e = cond
72 .clone()
73 .not()
74 .or(cond.take().call_is_null())
75 .or(then.take());
76 }
77 (_, Some(Ok(Datum::False))) => {
78 // Rewritten as (<cond> IS NOT NULL) AND (<cond>) AND (<then>)
79 // NULL <cond> results in: FALSE AND NULL AND (<then>) => FALSE
80 *e = cond
81 .clone()
82 .call_is_null()
83 .not()
84 .and(cond.take())
85 .and(then.take());
86 }
87 _ => {}
88 }
89 return;
90 }
91
92 // Equivalent expression structure would allow us to push the `If` into the expression.
93 // For example, `IF <cond> THEN x = y ELSE x = z` becomes `x = IF <cond> THEN y ELSE z`.
94 //
95 // We have to also make sure that the expressions that will end up in
96 // the two `If` branches have unionable types. Otherwise, the `If` could
97 // not be typed by `typ`. An example where this could cause an issue is
98 // when pulling out `cast_jsonbable_to_jsonb`, which accepts a wide
99 // range of input types. (In theory, we could still do the optimization
100 // in this case by inserting appropriate casts, but this corner case is
101 // not worth the complication for now.)
102 // See https://github.com/MaterializeInc/database-issues/issues/9182
103 match (&mut **then, &mut **els) {
104 (
105 MirScalarExpr::CallUnary { func: f1, expr: e1 },
106 MirScalarExpr::CallUnary { func: f2, expr: e2 },
107 ) if f1 == f2 && e1.typ(column_types) == e2.typ(column_types) => {
108 *e = cond
109 .take()
110 .if_then_else(e1.take(), e2.take())
111 .call_unary(f1.clone());
112 }
113 (
114 MirScalarExpr::CallBinary {
115 func: f1,
116 expr1: e1a,
117 expr2: e2a,
118 },
119 MirScalarExpr::CallBinary {
120 func: f2,
121 expr1: e1b,
122 expr2: e2b,
123 },
124 ) if f1 == f2 && e1a == e1b && e2a.typ(column_types) == e2b.typ(column_types) => {
125 *e = e1a
126 .take()
127 .call_binary(cond.take().if_then_else(e2a.take(), e2b.take()), f1.clone());
128 }
129 (
130 MirScalarExpr::CallBinary {
131 func: f1,
132 expr1: e1a,
133 expr2: e2a,
134 },
135 MirScalarExpr::CallBinary {
136 func: f2,
137 expr1: e1b,
138 expr2: e2b,
139 },
140 ) if f1 == f2 && e2a == e2b && e1a.typ(column_types) == e1b.typ(column_types) => {
141 *e = cond
142 .take()
143 .if_then_else(e1a.take(), e1b.take())
144 .call_binary(e2a.take(), f1.clone());
145 }
146 _ => {}
147 }
148}