Skip to main content

mz_expr/scalar/func/
macros.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.
9
// Shorthand for writing `inverse` values: converts the given expression into a
// `crate::UnaryFunc` through its `From` impl and wraps the result in `Some`.
macro_rules! to_unary {
    ($func:expr) => {
        Option::Some(<crate::UnaryFunc as From<_>>::from($func))
    };
}
16
#[cfg(test)]
mod test {
    use mz_expr_derive::sqlfunc;
    use mz_repr::SqlScalarType;

    use crate::EvalError;
    use crate::scalar::func::LazyUnaryFunc;

    // Fixtures whose return type has no `Result`. They differ only in where `Option`
    // appears in the signature, which is what the assertions below key on.

    // Plain argument, plain return value; also carries an explicit SQL name.
    #[sqlfunc(sqlname = "INFALLIBLE")]
    fn infallible1(a: f32) -> f32 {
        a
    }

    // Optional argument, plain return value.
    #[sqlfunc]
    fn infallible2(a: Option<f32>) -> f32 {
        a.unwrap_or_default()
    }

    // Plain argument, optional return value.
    #[sqlfunc]
    fn infallible3(a: f32) -> Option<f32> {
        Some(a)
    }

    #[mz_ore::test]
    fn elision_rules_infallible() {
        // The `sqlname` given to the attribute is what `Display` prints.
        assert_eq!(Infallible1.to_string(), "INFALLIBLE");

        // Each pair is (propagates_nulls, introduces_nulls).
        assert_eq!(
            (Infallible1.propagates_nulls(), Infallible1.introduces_nulls()),
            (true, false)
        );
        assert_eq!(
            (Infallible2.propagates_nulls(), Infallible2.introduces_nulls()),
            (false, false)
        );
        assert_eq!(
            (Infallible3.propagates_nulls(), Infallible3.introduces_nulls()),
            (true, true)
        );
    }

    #[mz_ore::test]
    fn output_types_infallible() {
        let float32 = |nullable: bool| SqlScalarType::Float32.nullable(nullable);

        for input_nullable in [true, false] {
            // `f32 -> f32`: the output is exactly as nullable as the input.
            assert_eq!(
                Infallible1.output_type(float32(input_nullable)),
                float32(input_nullable)
            );
            // `Option<f32> -> f32`: never nullable, whatever the input.
            assert_eq!(
                Infallible2.output_type(float32(input_nullable)),
                float32(false)
            );
            // `f32 -> Option<f32>`: always nullable, whatever the input.
            assert_eq!(
                Infallible3.output_type(float32(input_nullable)),
                float32(true)
            );
        }
    }

    // The same three signatures again, this time wrapped in `Result<_, EvalError>`.

    #[sqlfunc]
    fn fallible1(a: f32) -> Result<f32, EvalError> {
        Ok(a)
    }

    #[sqlfunc]
    fn fallible2(a: Option<f32>) -> Result<f32, EvalError> {
        Ok(a.unwrap_or_default())
    }

    #[sqlfunc]
    fn fallible3(a: f32) -> Result<Option<f32>, EvalError> {
        Ok(Some(a))
    }

    #[mz_ore::test]
    fn elision_rules_fallible() {
        // Each pair is (propagates_nulls, introduces_nulls); the expectations match the
        // infallible fixtures with the same argument/return shape.
        assert_eq!(
            (Fallible1.propagates_nulls(), Fallible1.introduces_nulls()),
            (true, false)
        );
        assert_eq!(
            (Fallible2.propagates_nulls(), Fallible2.introduces_nulls()),
            (false, false)
        );
        assert_eq!(
            (Fallible3.propagates_nulls(), Fallible3.introduces_nulls()),
            (true, true)
        );
    }

    #[mz_ore::test]
    fn output_types_fallible() {
        let float32 = |nullable: bool| SqlScalarType::Float32.nullable(nullable);

        for input_nullable in [true, false] {
            // `f32 -> Result<f32, _>`: the output is exactly as nullable as the input.
            assert_eq!(
                Fallible1.output_type(float32(input_nullable)),
                float32(input_nullable)
            );
            // `Option<f32> -> Result<f32, _>`: never nullable, whatever the input.
            assert_eq!(
                Fallible2.output_type(float32(input_nullable)),
                float32(false)
            );
            // `f32 -> Result<Option<f32>, _>`: always nullable, whatever the input.
            assert_eq!(
                Fallible3.output_type(float32(input_nullable)),
                float32(true)
            );
        }
    }
}
140
/// Temporary macro that generates the equivalent of what `enum_dispatch` will do in the future. We
/// need this manual macro implementation to delegate to the previous manual implementation for
/// variants that use the old definitions.
///
/// Given a comma-separated list of type names (a trailing comma is accepted, as in
/// `derive_binary!`), this expands to:
///
/// * the `UnaryFunc` enum, with one variant per name wrapping the type of the same name;
/// * inherent methods on `UnaryFunc` that match on the variant and forward to the wrapped value;
/// * a `fmt::Display` impl that forwards to the wrapped value's `fmt`;
/// * one `From<$name> for crate::UnaryFunc` impl per variant.
///
/// The expansion refers to `Datum`, `RowArena`, `MirScalarExpr`, `EvalError`, `SqlColumnType`,
/// `LazyUnaryFunc` and `fmt` by their bare names, so those must be in scope where the macro is
/// invoked.
///
/// Once everything is handled by this macro we can remove it and replace it with `enum_dispatch`.
macro_rules! derive_unary {
    ($($name:ident),* $(,)?) => {
        #[derive(
            Ord, PartialOrd, Clone, Debug, Eq, PartialEq,
            serde::Serialize, serde::Deserialize, Hash,
            mz_lowertest::MzReflect,
        )]
        pub enum UnaryFunc {
            // One variant per listed name; the variant and its payload type share that name.
            $($name($name),)*
        }

        impl UnaryFunc {
            // Forwards to the wrapped value's `eval` using method-call syntax, passing the
            // arguments through unchanged.
            pub fn eval<'a>(
                &'a self,
                datums: &[Datum<'a>],
                temp_storage: &'a RowArena,
                a: &'a MirScalarExpr,
            ) -> Result<Datum<'a>, EvalError> {
                match self {
                    $(Self::$name(f) => f.eval(datums, temp_storage, a),)*
                }
            }

            // Every method below forwards to the `LazyUnaryFunc` method of the same name on the
            // wrapped value.
            pub fn output_type(&self, input_type: SqlColumnType) -> SqlColumnType {
                match self {
                    $(Self::$name(f) => LazyUnaryFunc::output_type(f, input_type),)*
                }
            }
            pub fn propagates_nulls(&self) -> bool {
                match self {
                    $(Self::$name(f) => LazyUnaryFunc::propagates_nulls(f),)*
                }
            }
            pub fn introduces_nulls(&self) -> bool {
                match self {
                    $(Self::$name(f) => LazyUnaryFunc::introduces_nulls(f),)*
                }
            }
            pub fn preserves_uniqueness(&self) -> bool {
                match self {
                    $(Self::$name(f) => LazyUnaryFunc::preserves_uniqueness(f),)*
                }
            }
            pub fn inverse(&self) -> Option<UnaryFunc> {
                match self {
                    $(Self::$name(f) => LazyUnaryFunc::inverse(f),)*
                }
            }
            pub fn is_monotone(&self) -> bool {
                match self {
                    $(Self::$name(f) => LazyUnaryFunc::is_monotone(f),)*
                }
            }
            pub fn could_error(&self) -> bool {
                match self {
                    $(Self::$name(f) => LazyUnaryFunc::could_error(f),)*
                }
            }
        }

        // Formatting is delegated to the wrapped value's own `fmt`.
        impl fmt::Display for UnaryFunc {
            fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
                match self {
                    $(Self::$name(func) => func.fmt(f),)*
                }
            }
        }

        // Lets each listed type be converted into its `UnaryFunc` variant.
        $(
            impl From<$name> for crate::UnaryFunc {
                fn from(variant: $name) -> Self {
                    Self::$name(variant)
                }
            }
        )*
    }
}
223
/// Expands to the `BinaryFunc` enum together with everything that forwards to its variants:
/// an inherent `impl` block with 8 delegating methods, a `Display` impl, and one
/// `From<InnerType>` impl per variant.
///
/// Each entry is written as `Name(Type)`, where `Name` becomes the variant and `Type` is the
/// payload it wraps; a trailing comma after the last entry is accepted. The explicit form is
/// required even when both names coincide, e.g. `AddInt16(AddInt16)`.
macro_rules! derive_binary {
    ($($name:ident ( $inner:ident )),* $(,)?) => {
        #[derive(
            Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash,
            serde::Serialize, serde::Deserialize,
            mz_lowertest::MzReflect,
        )]
        pub enum BinaryFunc {
            $(
                $name($inner),
            )*
        }

        impl BinaryFunc {
            // Evaluation goes through the wrapped value's `eval` with method-call syntax; the
            // arguments are handed over untouched.
            pub fn eval<'a>(
                &'a self,
                datums: &[Datum<'a>],
                temp_storage: &'a RowArena,
                exprs: &[&'a MirScalarExpr],
            ) -> Result<Datum<'a>, EvalError> {
                match self {
                    $(
                        Self::$name(func) => func.eval(datums, temp_storage, exprs),
                    )*
                }
            }

            // The remaining methods each call the `LazyBinaryFunc` method of the same name on
            // the wrapped value.

            pub fn output_type(&self, input_types: &[SqlColumnType]) -> SqlColumnType {
                match self {
                    $(
                        Self::$name(func) => LazyBinaryFunc::output_type(func, input_types),
                    )*
                }
            }

            pub fn propagates_nulls(&self) -> bool {
                match self {
                    $(
                        Self::$name(func) => LazyBinaryFunc::propagates_nulls(func),
                    )*
                }
            }

            pub fn introduces_nulls(&self) -> bool {
                match self {
                    $(
                        Self::$name(func) => LazyBinaryFunc::introduces_nulls(func),
                    )*
                }
            }

            pub fn could_error(&self) -> bool {
                match self {
                    $(
                        Self::$name(func) => LazyBinaryFunc::could_error(func),
                    )*
                }
            }

            pub fn is_monotone(&self) -> (bool, bool) {
                match self {
                    $(
                        Self::$name(func) => LazyBinaryFunc::is_monotone(func),
                    )*
                }
            }

            pub fn is_infix_op(&self) -> bool {
                match self {
                    $(
                        Self::$name(func) => LazyBinaryFunc::is_infix_op(func),
                    )*
                }
            }

            pub fn negate(&self) -> Option<BinaryFunc> {
                match self {
                    $(
                        Self::$name(func) => LazyBinaryFunc::negate(func),
                    )*
                }
            }
        }

        // Formatting is handed to the wrapped value's own `fmt`.
        impl fmt::Display for BinaryFunc {
            fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                match self {
                    $(
                        Self::$name(inner) => inner.fmt(formatter),
                    )*
                }
            }
        }

        // One conversion per entry: an inner value becomes the variant that wraps it.
        $(
            impl From<$inner> for crate::BinaryFunc {
                fn from(inner: $inner) -> Self {
                    Self::$name(inner)
                }
            }
        )*
    }
}