Skip to main content

mz_expr/scalar/reduce/
unary.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Post-order rewrites for `CallUnary` nodes.
11
12use mz_repr::{ReprColumnType, RowArena};
13
14use crate::scalar::func::{self, UnaryFunc, VariadicFunc};
15use crate::{Eval, MirScalarExpr};
16
17pub(super) fn reduce_call_unary(
18    e: &mut MirScalarExpr,
19    column_types: &[ReprColumnType],
20    temp_storage: &RowArena,
21) {
22    let MirScalarExpr::CallUnary { func, expr } = e else {
23        unreachable!()
24    };
25
26    if expr.is_literal() && *func != UnaryFunc::Panic(func::Panic) {
27        *e = MirScalarExpr::literal(e.eval(&[], temp_storage), e.typ(column_types).scalar_type);
28        return;
29    }
30
31    // `f(g(x))` → `x` when `f` is a left inverse of `g`. Mind the contract:
32    // `g.inverse()` does not always return a mathematical inverse. Per its
33    // documentation it returns a left inverse — what this elimination needs —
34    // exactly when *`g` itself* preserves uniqueness; when only the *returned
35    // function* preserves uniqueness it is merely a right inverse (`g(f(y)) =
36    // y`, useful for moving casts across equalities, useless here). Hence the
37    // `preserves_uniqueness` check below is on `g`, and is load-bearing.
38    //
39    // Eliding the calls must also not change error or null behavior, so both
40    // functions must be infallible and propagate nulls. (This is what
41    // excludes e.g. `-(-x)` on integers, whose inner negation errors on the
42    // minimum value, and `int8(int4(x))`-style widenings, whose outer
43    // narrowing can error — while admitting `NOT(NOT(x))`, `~(~x)`,
44    // `reverse(reverse(x))`, numeric `-(-x)`, and infallible bijective cast
45    // roundtrips such as `bool::int4::bool`.)
46    if let MirScalarExpr::CallUnary {
47        func: inner_func,
48        expr: inner_expr,
49    } = &mut **expr
50    {
51        if inner_func.inverse().as_ref() == Some(func)
52            && inner_func.preserves_uniqueness()
53            && !inner_func.could_error()
54            && !func.could_error()
55            && inner_func.propagates_nulls()
56            && func.propagates_nulls()
57        {
58            let inner = inner_expr.take();
59            *e = inner;
60            return;
61        }
62    }
63
64    // `RecordGet(i)(RecordCreate(args))` → `args[i]`.
65    if let UnaryFunc::RecordGet(func::RecordGet(i)) = *func {
66        if let MirScalarExpr::CallVariadic {
67            func: VariadicFunc::RecordCreate(..),
68            exprs,
69        } = &mut **expr
70        {
71            *e = exprs.swap_remove(i);
72        }
73    }
74}
75
#[cfg(test)]
mod tests {
    use mz_repr::ReprScalarType;

    use crate::MirScalarExpr;
    use crate::scalar::func;

    #[mz_ore::test]
    fn involution_folds() {
        // `f(g(x))` folds to `x` when all of the following hold:
        // - `g` names `f` as its inverse;
        // - `g` preserves uniqueness;
        // - both calls are infallible and propagate nulls.
        let col = || MirScalarExpr::column(0);

        // Column-type fixtures, each reused by every case of that type below.
        let bool_col = [ReprScalarType::Bool.nullable(true)];
        let int32 = [ReprScalarType::Int32.nullable(true)];
        let string = [ReprScalarType::String.nullable(true)];
        let numeric = [ReprScalarType::Numeric {}.nullable(true)];
        let int64 = [ReprScalarType::Int64.nullable(true)];

        let mut e = col().call_unary(func::Not).call_unary(func::Not);
        e.reduce(&bool_col);
        assert_eq!(e, col());

        let mut e = col()
            .call_unary(func::BitNotInt32)
            .call_unary(func::BitNotInt32);
        e.reduce(&int32);
        assert_eq!(e, col());

        let mut e = col().call_unary(func::Reverse).call_unary(func::Reverse);
        e.reduce(&string);
        assert_eq!(e, col());

        // Numeric negation is infallible, so its double application folds.
        let mut e = col()
            .call_unary(func::NegNumeric)
            .call_unary(func::NegNumeric);
        e.reduce(&numeric);
        assert_eq!(e, col());

        // An infallible roundtrip through an injective cast folds:
        // `int4(bool)` declares `bool(int4)` as its (left) inverse, and both
        // directions are total. `int4(bool)` is injective but not bijective;
        // injectivity is all the fold needs.
        let mut e = col()
            .call_unary(func::CastBoolToInt32)
            .call_unary(func::CastInt32ToBool);
        e.reduce(&bool_col);
        assert_eq!(e, col());

        // Integer negation errors on the minimum value, so eliding the double
        // negation would suppress that error; it must not fold.
        let mut e = col().call_unary(func::NegInt64).call_unary(func::NegInt64);
        e.reduce(&int64);
        assert_ne!(e, col());

        // The widening roundtrip `int4(int8(x))` stays: the outer narrowing can
        // error in general, and the gate does not reason about images.
        let mut e = col()
            .call_unary(func::CastInt32ToInt64)
            .call_unary(func::CastInt64ToInt32);
        e.reduce(&int32);
        assert_ne!(e, col());
    }
}