Skip to main content

mz_expr/scalar/func/
binary.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Utilities for binary functions.
11
12use mz_ore::assert_none;
13use mz_repr::{Datum, InputDatumType, OutputDatumType, RowArena, SqlColumnType};
14
15use crate::{EvalError, MirScalarExpr};
16
17/// A description of an SQL binary function that has the ability to lazy evaluate its arguments
18// This trait will eventually be annotated with #[enum_dispatch] to autogenerate the UnaryFunc enum
19pub(crate) trait LazyBinaryFunc {
20    fn eval<'a>(
21        &'a self,
22        datums: &[Datum<'a>],
23        temp_storage: &'a RowArena,
24        exprs: &[&'a MirScalarExpr],
25    ) -> Result<Datum<'a>, EvalError>;
26
27    /// The output SqlColumnType of this function.
28    fn output_type(&self, input_types: &[SqlColumnType]) -> SqlColumnType;
29
30    /// Whether this function will produce NULL on NULL input.
31    fn propagates_nulls(&self) -> bool;
32
33    /// Whether this function will produce NULL on non-NULL input.
34    fn introduces_nulls(&self) -> bool;
35
36    /// Whether this function might error on non-error input.
37    fn could_error(&self) -> bool {
38        // NB: override this for functions that never error.
39        true
40    }
41
42    /// Returns the negation of the function, if one exists.
43    fn negate(&self) -> Option<crate::BinaryFunc>;
44
45    /// Returns true if the function is monotone. (Non-strict; either increasing or decreasing.)
46    /// Monotone functions map ranges to ranges: ie. given a range of possible inputs, we can
47    /// determine the range of possible outputs just by mapping the endpoints.
48    ///
49    /// This describes the *pointwise* behaviour of the function:
50    /// ie. the behaviour of any specific argument as the others are held constant. (For example, `a - b` is
51    /// monotone in the first argument because for any particular value of `b`, increasing `a` will
52    /// always cause the result to increase... and in the second argument because for any specific `a`,
53    /// increasing `b` will always cause the result to _decrease_.)
54    ///
55    /// This property describes the behaviour of the function over ranges where the function is defined:
56    /// ie. the arguments and the result are non-error datums.
57    fn is_monotone(&self) -> (bool, bool);
58
59    /// Yep, I guess this returns true for infix operators.
60    fn is_infix_op(&self) -> bool;
61}
62
63pub(crate) trait EagerBinaryFunc {
64    type Input<'a>: InputDatumType<'a, EvalError>;
65    type Output<'a>: OutputDatumType<'a, EvalError>;
66
67    fn call<'a>(&self, input: Self::Input<'a>, temp_storage: &'a RowArena) -> Self::Output<'a>;
68
69    /// The output SqlColumnType of this function
70    fn output_type(&self, input_types: &[SqlColumnType]) -> SqlColumnType;
71
72    /// Whether this function will produce NULL on NULL input
73    fn propagates_nulls(&self) -> bool {
74        // If the inputs are not nullable then nulls are propagated
75        !Self::Input::nullable()
76    }
77
78    /// Whether this function will produce NULL on non-NULL input
79    fn introduces_nulls(&self) -> bool {
80        // If the output is nullable then nulls can be introduced
81        Self::Output::nullable()
82    }
83
84    /// Whether this function could produce an error
85    fn could_error(&self) -> bool {
86        Self::Output::fallible()
87    }
88
89    /// Returns the negation of the given binary function, if it exists.
90    fn negate(&self) -> Option<crate::BinaryFunc> {
91        None
92    }
93
94    fn is_monotone(&self) -> (bool, bool) {
95        (false, false)
96    }
97
98    fn is_infix_op(&self) -> bool {
99        false
100    }
101}
102
103impl<T: EagerBinaryFunc> LazyBinaryFunc for T {
104    fn eval<'a>(
105        &'a self,
106        datums: &[Datum<'a>],
107        temp_storage: &'a RowArena,
108        exprs: &[&'a MirScalarExpr],
109    ) -> Result<Datum<'a>, EvalError> {
110        let mut datums = exprs
111            .into_iter()
112            .map(|expr| expr.eval(datums, temp_storage));
113        let input = match T::Input::try_from_iter(&mut datums) {
114            // If we can convert to the input type then we call the function
115            Ok(input) => input,
116            // If we can't and we got a non-null datum something went wrong in the planner
117            Err(Ok(Some(datum))) if !datum.is_null() => {
118                return Err(EvalError::Internal("invalid input type".into()));
119            }
120            Err(Ok(None)) => {
121                return Err(EvalError::Internal("unexpectedly missing parameter".into()));
122            }
123            // Otherwise we just propagate NULLs and errors
124            Err(Ok(Some(datum))) => return Ok(datum),
125            Err(Err(res)) => return Err(res),
126        };
127        assert_none!(datums.next(), "No leftover input arguments");
128        self.call(input, temp_storage).into_result(temp_storage)
129    }
130
131    fn output_type(&self, input_types: &[SqlColumnType]) -> SqlColumnType {
132        self.output_type(input_types)
133    }
134
135    fn propagates_nulls(&self) -> bool {
136        self.propagates_nulls()
137    }
138
139    fn introduces_nulls(&self) -> bool {
140        self.introduces_nulls()
141    }
142
143    fn could_error(&self) -> bool {
144        self.could_error()
145    }
146
147    fn negate(&self) -> Option<crate::BinaryFunc> {
148        self.negate()
149    }
150
151    fn is_monotone(&self) -> (bool, bool) {
152        self.is_monotone()
153    }
154
155    fn is_infix_op(&self) -> bool {
156        self.is_infix_op()
157    }
158}
159
160pub use derive::BinaryFunc;
161
/// Home of the `BinaryFunc` enum, which the parent module re-exports.
///
/// The enum is produced by the `derive_binary!` invocation below. Each entry has the form
/// `Variant(Type)`, where `Variant` is the enum variant name and `Type` is the function struct
/// it wraps. The two names usually agree but need not, e.g. `RoundNumeric(RoundNumericBinary)`.
// NOTE(review): variant order may be significant to code generated from this list (e.g. a
// derived `Ord` or serialization) — confirm before reordering or inserting mid-list.
mod derive {
    // None of these names is referenced directly in the visible code of this module;
    // presumably they are used by the `derive_binary!` expansion.
    use std::fmt;

    use mz_repr::{Datum, RowArena, SqlColumnType};

    use crate::scalar::func::binary::LazyBinaryFunc;
    use crate::scalar::func::*;
    use crate::{EvalError, MirScalarExpr};

    derive_binary! {
        // Entries are `VariantName(ImplementingType)`.
        AddInt16(AddInt16),
        AddInt32(AddInt32),
        AddInt64(AddInt64),
        AddUint16(AddUint16),
        AddUint32(AddUint32),
        AddUint64(AddUint64),
        AddFloat32(AddFloat32),
        AddFloat64(AddFloat64),
        AddInterval(AddInterval),
        AddTimestampInterval(AddTimestampInterval),
        AddTimestampTzInterval(AddTimestampTzInterval),
        AddDateInterval(AddDateInterval),
        AddDateTime(AddDateTime),
        AddTimeInterval(AddTimeInterval),
        AddNumeric(AddNumeric),
        AgeTimestamp(AgeTimestamp),
        AgeTimestampTz(AgeTimestampTz),
        BitAndInt16(BitAndInt16),
        BitAndInt32(BitAndInt32),
        BitAndInt64(BitAndInt64),
        BitAndUint16(BitAndUint16),
        BitAndUint32(BitAndUint32),
        BitAndUint64(BitAndUint64),
        BitOrInt16(BitOrInt16),
        BitOrInt32(BitOrInt32),
        BitOrInt64(BitOrInt64),
        BitOrUint16(BitOrUint16),
        BitOrUint32(BitOrUint32),
        BitOrUint64(BitOrUint64),
        BitXorInt16(BitXorInt16),
        BitXorInt32(BitXorInt32),
        BitXorInt64(BitXorInt64),
        BitXorUint16(BitXorUint16),
        BitXorUint32(BitXorUint32),
        BitXorUint64(BitXorUint64),
        BitShiftLeftInt16(BitShiftLeftInt16),
        BitShiftLeftInt32(BitShiftLeftInt32),
        BitShiftLeftInt64(BitShiftLeftInt64),
        BitShiftLeftUint16(BitShiftLeftUint16),
        BitShiftLeftUint32(BitShiftLeftUint32),
        BitShiftLeftUint64(BitShiftLeftUint64),
        BitShiftRightInt16(BitShiftRightInt16),
        BitShiftRightInt32(BitShiftRightInt32),
        BitShiftRightInt64(BitShiftRightInt64),
        BitShiftRightUint16(BitShiftRightUint16),
        BitShiftRightUint32(BitShiftRightUint32),
        BitShiftRightUint64(BitShiftRightUint64),
        SubInt16(SubInt16),
        SubInt32(SubInt32),
        SubInt64(SubInt64),
        SubUint16(SubUint16),
        SubUint32(SubUint32),
        SubUint64(SubUint64),
        SubFloat32(SubFloat32),
        SubFloat64(SubFloat64),
        SubInterval(SubInterval),
        SubTimestamp(SubTimestamp),
        SubTimestampTz(SubTimestampTz),
        SubTimestampInterval(SubTimestampInterval),
        SubTimestampTzInterval(SubTimestampTzInterval),
        SubDate(SubDate),
        SubDateInterval(SubDateInterval),
        SubTime(SubTime),
        SubTimeInterval(SubTimeInterval),
        SubNumeric(SubNumeric),
        MulInt16(MulInt16),
        MulInt32(MulInt32),
        MulInt64(MulInt64),
        MulUint16(MulUint16),
        MulUint32(MulUint32),
        MulUint64(MulUint64),
        MulFloat32(MulFloat32),
        MulFloat64(MulFloat64),
        MulNumeric(MulNumeric),
        MulInterval(MulInterval),
        DivInt16(DivInt16),
        DivInt32(DivInt32),
        DivInt64(DivInt64),
        DivUint16(DivUint16),
        DivUint32(DivUint32),
        DivUint64(DivUint64),
        DivFloat32(DivFloat32),
        DivFloat64(DivFloat64),
        DivNumeric(DivNumeric),
        DivInterval(DivInterval),
        ModInt16(ModInt16),
        ModInt32(ModInt32),
        ModInt64(ModInt64),
        ModUint16(ModUint16),
        ModUint32(ModUint32),
        ModUint64(ModUint64),
        ModFloat32(ModFloat32),
        ModFloat64(ModFloat64),
        ModNumeric(ModNumeric),
        RoundNumeric(RoundNumericBinary),
        Eq(Eq),
        NotEq(NotEq),
        Lt(Lt),
        Lte(Lte),
        Gt(Gt),
        Gte(Gte),
        LikeEscape(LikeEscape),
        IsLikeMatchCaseInsensitive(IsLikeMatchCaseInsensitive),
        IsLikeMatchCaseSensitive(IsLikeMatchCaseSensitive),
        IsRegexpMatchCaseSensitive(IsRegexpMatchCaseSensitive),
        IsRegexpMatchCaseInsensitive(IsRegexpMatchCaseInsensitive),
        ToCharTimestamp(ToCharTimestampFormat),
        ToCharTimestampTz(ToCharTimestampTzFormat),
        DateBinTimestamp(DateBinTimestamp),
        DateBinTimestampTz(DateBinTimestampTz),
        ExtractInterval(DatePartIntervalNumeric),
        ExtractTime(DatePartTimeNumeric),
        ExtractTimestamp(DatePartTimestampTimestampNumeric),
        ExtractTimestampTz(DatePartTimestampTimestampTzNumeric),
        ExtractDate(ExtractDateUnits),
        DatePartInterval(DatePartIntervalF64),
        DatePartTime(DatePartTimeF64),
        DatePartTimestamp(DatePartTimestampTimestampF64),
        DatePartTimestampTz(DatePartTimestampTimestampTzF64),
        DateTruncTimestamp(DateTruncUnitsTimestamp),
        DateTruncTimestampTz(DateTruncUnitsTimestampTz),
        DateTruncInterval(DateTruncInterval),
        TimezoneTimestampBinary(TimezoneTimestampBinary),
        TimezoneTimestampTzBinary(TimezoneTimestampTzBinary),
        TimezoneIntervalTimestampBinary(TimezoneIntervalTimestampBinary),
        TimezoneIntervalTimestampTzBinary(TimezoneIntervalTimestampTzBinary),
        TimezoneIntervalTimeBinary(TimezoneIntervalTimeBinary),
        TimezoneOffset(TimezoneOffset),
        TextConcat(TextConcatBinary),
        JsonbGetInt64(JsonbGetInt64),
        JsonbGetInt64Stringify(JsonbGetInt64Stringify),
        JsonbGetString(JsonbGetString),
        JsonbGetStringStringify(JsonbGetStringStringify),
        JsonbGetPath(JsonbGetPath),
        JsonbGetPathStringify(JsonbGetPathStringify),
        JsonbContainsString(JsonbContainsString),
        JsonbConcat(JsonbConcat),
        JsonbContainsJsonb(JsonbContainsJsonb),
        JsonbDeleteInt64(JsonbDeleteInt64),
        JsonbDeleteString(JsonbDeleteString),
        MapContainsKey(MapContainsKey),
        MapGetValue(MapGetValue),
        MapContainsAllKeys(MapContainsAllKeys),
        MapContainsAnyKeys(MapContainsAnyKeys),
        MapContainsMap(MapContainsMap),
        ConvertFrom(ConvertFrom),
        Left(Left),
        Position(Position),
        Strpos(Strpos),
        Right(Right),
        RepeatString(RepeatString),
        Normalize(Normalize),
        Trim(Trim),
        TrimLeading(TrimLeading),
        TrimTrailing(TrimTrailing),
        EncodedBytesCharLength(EncodedBytesCharLength),
        ListLengthMax(ListLengthMax),
        ArrayContains(ArrayContains),
        ArrayContainsArray(ArrayContainsArray),
        ArrayContainsArrayRev(ArrayContainsArrayRev),
        ArrayLength(ArrayLength),
        ArrayLower(ArrayLower),
        ArrayRemove(ArrayRemove),
        ArrayUpper(ArrayUpper),
        ArrayArrayConcat(ArrayArrayConcat),
        ListListConcat(ListListConcat),
        ListElementConcat(ListElementConcat),
        ElementListConcat(ElementListConcat),
        ListRemove(ListRemove),
        ListContainsList(ListContainsList),
        ListContainsListRev(ListContainsListRev),
        DigestString(DigestString),
        DigestBytes(DigestBytes),
        MzRenderTypmod(MzRenderTypmod),
        Encode(Encode),
        Decode(Decode),
        LogNumeric(LogBaseNumeric),
        Power(Power),
        PowerNumeric(PowerNumeric),
        GetBit(GetBit),
        GetByte(GetByte),
        ConstantTimeEqBytes(ConstantTimeEqBytes),
        ConstantTimeEqString(ConstantTimeEqString),
        RangeContainsDate(RangeContainsDate),
        RangeContainsDateRev(RangeContainsDateRev),
        RangeContainsI32(RangeContainsI32),
        RangeContainsI32Rev(RangeContainsI32Rev),
        RangeContainsI64(RangeContainsI64),
        RangeContainsI64Rev(RangeContainsI64Rev),
        RangeContainsNumeric(RangeContainsNumeric),
        RangeContainsNumericRev(RangeContainsNumericRev),
        RangeContainsRange(RangeContainsRange),
        RangeContainsRangeRev(RangeContainsRangeRev),
        RangeContainsTimestamp(RangeContainsTimestamp),
        RangeContainsTimestampRev(RangeContainsTimestampRev),
        RangeContainsTimestampTz(RangeContainsTimestampTz),
        RangeContainsTimestampTzRev(RangeContainsTimestampTzRev),
        RangeOverlaps(RangeOverlaps),
        RangeAfter(RangeAfter),
        RangeBefore(RangeBefore),
        RangeOverleft(RangeOverleft),
        RangeOverright(RangeOverright),
        RangeAdjacent(RangeAdjacent),
        RangeUnion(RangeUnion),
        RangeIntersection(RangeIntersection),
        RangeDifference(RangeDifference),
        UuidGenerateV5(UuidGenerateV5),
        MzAclItemContainsPrivilege(MzAclItemContainsPrivilege),
        ParseIdent(ParseIdent),
        PrettySql(PrettySql),
        RegexpReplace(RegexpReplace),
        StartsWith(StartsWith),
    }
}
386
#[cfg(test)]
mod test {
    use mz_expr_derive::sqlfunc;
    use mz_repr::{SqlColumnType, SqlScalarType};

    use crate::EvalError;
    use crate::scalar::func::binary::LazyBinaryFunc;

    /// Builds a `Float32` column type with the given nullability.
    fn float32(nullable: bool) -> SqlColumnType {
        SqlScalarType::Float32.nullable(nullable)
    }

    /// Checks `func.output_type` for all four combinations of argument nullability.
    ///
    /// Both arguments are `Float32`. `expected` maps `(lhs_nullable, rhs_nullable)` to the
    /// nullability the `Float32` output type should have.
    fn assert_output_nullability(
        func: &impl LazyBinaryFunc,
        expected: impl Fn(bool, bool) -> bool,
    ) {
        for lhs in [true, false] {
            for rhs in [true, false] {
                assert_eq!(
                    func.output_type(&[float32(lhs), float32(rhs)]),
                    float32(expected(lhs, rhs)),
                    "lhs nullable: {lhs}, rhs nullable: {rhs}",
                );
            }
        }
    }

    #[sqlfunc(sqlname = "INFALLIBLE", is_infix_op = true, test = true)]
    #[allow(dead_code)]
    fn infallible1(a: f32, b: f32) -> f32 {
        a + b
    }

    #[sqlfunc(test = true)]
    #[allow(dead_code)]
    fn infallible2(a: Option<f32>, b: Option<f32>) -> f32 {
        a.unwrap_or_default() + b.unwrap_or_default()
    }

    #[sqlfunc(test = true)]
    #[allow(dead_code)]
    fn infallible3(a: f32, b: f32) -> Option<f32> {
        Some(a + b)
    }

    #[mz_ore::test]
    fn elision_rules_infallible() {
        assert_eq!(format!("{}", Infallible1), "INFALLIBLE");
        assert!(Infallible1.propagates_nulls());
        assert!(!Infallible1.introduces_nulls());
        assert!(!Infallible1.could_error());
        // Only `infallible1` is declared with `is_infix_op = true`.
        assert!(Infallible1.is_infix_op());

        assert!(!Infallible2.propagates_nulls());
        assert!(!Infallible2.introduces_nulls());
        assert!(!Infallible2.could_error());
        assert!(!Infallible2.is_infix_op());

        assert!(Infallible3.propagates_nulls());
        assert!(Infallible3.introduces_nulls());
        assert!(!Infallible3.could_error());
        assert!(!Infallible3.is_infix_op());
    }

    #[mz_ore::test]
    fn output_types_infallible() {
        // Non-nullable inputs: the output is nullable iff either input is.
        assert_output_nullability(&Infallible1, |lhs, rhs| lhs || rhs);
        // Nullable inputs, non-nullable output: the output is never nullable.
        assert_output_nullability(&Infallible2, |_, _| false);
        // Nullable output: the output is always nullable.
        assert_output_nullability(&Infallible3, |_, _| true);
    }

    #[sqlfunc(test = true)]
    #[allow(dead_code)]
    fn fallible1(a: f32, b: f32) -> Result<f32, EvalError> {
        Ok(a + b)
    }

    #[sqlfunc(test = true)]
    #[allow(dead_code)]
    fn fallible2(a: Option<f32>, b: Option<f32>) -> Result<f32, EvalError> {
        Ok(a.unwrap_or_default() + b.unwrap_or_default())
    }

    #[sqlfunc(test = true)]
    #[allow(dead_code)]
    fn fallible3(a: f32, b: f32) -> Result<Option<f32>, EvalError> {
        Ok(Some(a + b))
    }

    #[mz_ore::test]
    fn elision_rules_fallible() {
        assert!(Fallible1.propagates_nulls());
        assert!(!Fallible1.introduces_nulls());
        assert!(Fallible1.could_error());

        assert!(!Fallible2.propagates_nulls());
        assert!(!Fallible2.introduces_nulls());
        assert!(Fallible2.could_error());

        assert!(Fallible3.propagates_nulls());
        assert!(Fallible3.introduces_nulls());
        assert!(Fallible3.could_error());
    }

    #[mz_ore::test]
    fn output_types_fallible() {
        // Wrapping the output in `Result` must not change nullability inference.
        assert_output_nullability(&Fallible1, |lhs, rhs| lhs || rhs);
        assert_output_nullability(&Fallible2, |_, _| false);
        assert_output_nullability(&Fallible3, |_, _| true);
    }

    #[mz_ore::test]
    fn mz_reflect_binary_func() {
        use crate::BinaryFunc;
        use mz_lowertest::{MzReflect, ReflectedTypeInfo};

        let mut rti = ReflectedTypeInfo::default();
        BinaryFunc::add_to_reflected_type_info(&mut rti);

        // Check that the enum is registered
        let variants = rti
            .enum_dict
            .get("BinaryFunc")
            .expect("BinaryFunc should be in enum_dict");
        assert!(
            variants.contains_key("AddInt64"),
            "AddInt64 variant should exist"
        );
        assert!(variants.contains_key("Gte"), "Gte variant should exist");

        // Check that inner types are registered in struct_dict
        assert!(
            rti.struct_dict.contains_key("AddInt64"),
            "AddInt64 should be in struct_dict"
        );
        assert!(
            rti.struct_dict.contains_key("Gte"),
            "Gte should be in struct_dict"
        );

        // Verify zero-field unit structs
        let (names, types) = rti.struct_dict.get("AddInt64").unwrap();
        assert!(names.is_empty(), "AddInt64 should have no field names");
        assert!(types.is_empty(), "AddInt64 should have no field types");
    }
}