mz_expr/scalar/reduce/
unary.rs1use mz_repr::{ReprColumnType, RowArena};
13
14use crate::scalar::func::{self, UnaryFunc, VariadicFunc};
15use crate::{Eval, MirScalarExpr};
16
17pub(super) fn reduce_call_unary(
18 e: &mut MirScalarExpr,
19 column_types: &[ReprColumnType],
20 temp_storage: &RowArena,
21) {
22 let MirScalarExpr::CallUnary { func, expr } = e else {
23 unreachable!()
24 };
25
26 if expr.is_literal() && *func != UnaryFunc::Panic(func::Panic) {
27 *e = MirScalarExpr::literal(e.eval(&[], temp_storage), e.typ(column_types).scalar_type);
28 return;
29 }
30
31 if let MirScalarExpr::CallUnary {
47 func: inner_func,
48 expr: inner_expr,
49 } = &mut **expr
50 {
51 if inner_func.inverse().as_ref() == Some(func)
52 && inner_func.preserves_uniqueness()
53 && !inner_func.could_error()
54 && !func.could_error()
55 && inner_func.propagates_nulls()
56 && func.propagates_nulls()
57 {
58 let inner = inner_expr.take();
59 *e = inner;
60 return;
61 }
62 }
63
64 if let UnaryFunc::RecordGet(func::RecordGet(i)) = *func {
66 if let MirScalarExpr::CallVariadic {
67 func: VariadicFunc::RecordCreate(..),
68 exprs,
69 } = &mut **expr
70 {
71 *e = exprs.swap_remove(i);
72 }
73 }
74}
75
#[cfg(test)]
mod tests {
    use mz_repr::ReprScalarType;

    use crate::MirScalarExpr;
    use crate::scalar::func;

    /// Reduces `expr` against a schema with a single nullable column of type
    /// `scalar` and returns the simplified expression.
    fn reduce_over(mut expr: MirScalarExpr, scalar: ReprScalarType) -> MirScalarExpr {
        expr.reduce(&[scalar.nullable(true)]);
        expr
    }

    #[mz_ore::test]
    fn involution_folds() {
        let col = || MirScalarExpr::column(0);

        // `f(g(#0))` where `f` is the inverse of `g`: each pair should reduce
        // all the way back to the bare column.
        let folding = [
            (
                col().call_unary(func::Not).call_unary(func::Not),
                ReprScalarType::Bool,
            ),
            (
                col()
                    .call_unary(func::BitNotInt32)
                    .call_unary(func::BitNotInt32),
                ReprScalarType::Int32,
            ),
            (
                col().call_unary(func::Reverse).call_unary(func::Reverse),
                ReprScalarType::String,
            ),
            (
                col()
                    .call_unary(func::NegNumeric)
                    .call_unary(func::NegNumeric),
                ReprScalarType::Numeric {},
            ),
            (
                col()
                    .call_unary(func::CastBoolToInt32)
                    .call_unary(func::CastInt32ToBool),
                ReprScalarType::Bool,
            ),
        ];
        for (expr, scalar) in folding {
            assert_eq!(reduce_over(expr, scalar), col());
        }

        // Pairs that must NOT cancel back to the column. Presumably this is
        // because `NegInt64` and `CastInt64ToInt32` report that they could
        // error, and removing them would hide that error.
        let non_folding = [
            (
                col().call_unary(func::NegInt64).call_unary(func::NegInt64),
                ReprScalarType::Int64,
            ),
            (
                col()
                    .call_unary(func::CastInt32ToInt64)
                    .call_unary(func::CastInt64ToInt32),
                ReprScalarType::Int32,
            ),
        ];
        for (expr, scalar) in non_folding {
            assert_ne!(reduce_over(expr, scalar), col());
        }
    }
}