Skip to main content

mz_expr/scalar/func/
binary.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Utilities for binary functions.
11
12use mz_ore::assert_none;
13use mz_repr::{Datum, InputDatumType, OutputDatumType, ReprColumnType, RowArena, SqlColumnType};
14
15use crate::{EvalError, MirScalarExpr};
16
17/// A description of an SQL binary function that has the ability to lazy evaluate its arguments
18// This trait will eventually be annotated with #[enum_dispatch] to autogenerate the UnaryFunc enum
19pub(crate) trait LazyBinaryFunc {
20    fn eval<'a>(
21        &'a self,
22        datums: &[Datum<'a>],
23        temp_storage: &'a RowArena,
24        exprs: &[&'a MirScalarExpr],
25    ) -> Result<Datum<'a>, EvalError>;
26
27    /// The output SqlColumnType of this function.
28    fn output_sql_type(&self, input_types: &[SqlColumnType]) -> SqlColumnType;
29
30    /// A wrapper around [`Self::output_sql_type`] that works with representation types.
31    fn output_type(&self, input_types: &[ReprColumnType]) -> ReprColumnType {
32        ReprColumnType::from(
33            &self.output_sql_type(
34                &input_types
35                    .iter()
36                    .map(SqlColumnType::from_repr)
37                    .collect::<Vec<_>>(),
38            ),
39        )
40    }
41
42    /// Whether this function will produce NULL on NULL input.
43    fn propagates_nulls(&self) -> bool;
44
45    /// Whether this function will produce NULL on non-NULL input.
46    fn introduces_nulls(&self) -> bool;
47
48    /// Whether this function might error on non-error input.
49    fn could_error(&self) -> bool {
50        // NB: override this for functions that never error.
51        true
52    }
53
54    /// Returns the negation of the function, if one exists.
55    fn negate(&self) -> Option<crate::BinaryFunc>;
56
57    /// Returns true if the function is monotone. (Non-strict; either increasing or decreasing.)
58    /// Monotone functions map ranges to ranges: ie. given a range of possible inputs, we can
59    /// determine the range of possible outputs just by mapping the endpoints.
60    ///
61    /// This describes the *pointwise* behaviour of the function:
62    /// ie. the behaviour of any specific argument as the others are held constant. (For example, `a - b` is
63    /// monotone in the first argument because for any particular value of `b`, increasing `a` will
64    /// always cause the result to increase... and in the second argument because for any specific `a`,
65    /// increasing `b` will always cause the result to _decrease_.)
66    ///
67    /// This property describes the behaviour of the function over ranges where the function is defined:
68    /// ie. the arguments and the result are non-error datums.
69    fn is_monotone(&self) -> (bool, bool);
70
71    /// Yep, I guess this returns true for infix operators.
72    fn is_infix_op(&self) -> bool;
73}
74
75pub(crate) trait EagerBinaryFunc {
76    type Input<'a>: InputDatumType<'a, EvalError>;
77    type Output<'a>: OutputDatumType<'a, EvalError>;
78
79    fn call<'a>(&self, input: Self::Input<'a>, temp_storage: &'a RowArena) -> Self::Output<'a>;
80
81    /// The output SqlColumnType of this function
82    fn output_sql_type(&self, input_types: &[SqlColumnType]) -> SqlColumnType;
83
84    /// The output of this function as a representation type.
85    #[allow(dead_code)]
86    fn output_type(&self, input_types: &[ReprColumnType]) -> ReprColumnType {
87        ReprColumnType::from(
88            &self.output_sql_type(
89                &input_types
90                    .iter()
91                    .map(SqlColumnType::from_repr)
92                    .collect::<Vec<_>>(),
93            ),
94        )
95    }
96
97    /// Whether this function will produce NULL on NULL input
98    fn propagates_nulls(&self) -> bool {
99        // If the inputs are not nullable then nulls are propagated
100        !Self::Input::nullable()
101    }
102
103    /// Whether this function will produce NULL on non-NULL input
104    fn introduces_nulls(&self) -> bool {
105        // If the output is nullable then nulls can be introduced
106        Self::Output::nullable()
107    }
108
109    /// Whether this function could produce an error
110    fn could_error(&self) -> bool {
111        Self::Output::fallible()
112    }
113
114    /// Returns the negation of the given binary function, if it exists.
115    fn negate(&self) -> Option<crate::BinaryFunc> {
116        None
117    }
118
119    fn is_monotone(&self) -> (bool, bool) {
120        (false, false)
121    }
122
123    fn is_infix_op(&self) -> bool {
124        false
125    }
126}
127
128impl<T: EagerBinaryFunc> LazyBinaryFunc for T {
129    fn eval<'a>(
130        &'a self,
131        datums: &[Datum<'a>],
132        temp_storage: &'a RowArena,
133        exprs: &[&'a MirScalarExpr],
134    ) -> Result<Datum<'a>, EvalError> {
135        let mut datums = exprs
136            .into_iter()
137            .map(|expr| expr.eval(datums, temp_storage));
138        let input = match T::Input::try_from_iter(&mut datums) {
139            // If we can convert to the input type then we call the function
140            Ok(input) => input,
141            // If we can't and we got a non-null datum something went wrong in the planner
142            Err(Ok(Some(datum))) if !datum.is_null() => {
143                return Err(EvalError::Internal("invalid input type".into()));
144            }
145            Err(Ok(None)) => {
146                return Err(EvalError::Internal("unexpectedly missing parameter".into()));
147            }
148            // Otherwise we just propagate NULLs and errors
149            Err(Ok(Some(datum))) => return Ok(datum),
150            Err(Err(res)) => return Err(res),
151        };
152        assert_none!(datums.next(), "No leftover input arguments");
153        self.call(input, temp_storage).into_result(temp_storage)
154    }
155
156    fn output_sql_type(&self, input_types: &[SqlColumnType]) -> SqlColumnType {
157        self.output_sql_type(input_types)
158    }
159
160    fn propagates_nulls(&self) -> bool {
161        self.propagates_nulls()
162    }
163
164    fn introduces_nulls(&self) -> bool {
165        self.introduces_nulls()
166    }
167
168    fn could_error(&self) -> bool {
169        self.could_error()
170    }
171
172    fn negate(&self) -> Option<crate::BinaryFunc> {
173        self.negate()
174    }
175
176    fn is_monotone(&self) -> (bool, bool) {
177        self.is_monotone()
178    }
179
180    fn is_infix_op(&self) -> bool {
181        self.is_infix_op()
182    }
183}
184
// Re-export the `BinaryFunc` enum generated by `derive_binary!` inside `mod derive` below.
pub use derive::BinaryFunc;

/// Home of the `derive_binary!` invocation that generates the `BinaryFunc` enum.
///
/// Each entry has the form `Variant(Type)`: `Variant` names an enum variant and `Type` is the
/// function struct it wraps. The two usually match, but need not (e.g.
/// `RoundNumeric(RoundNumericBinary)`).
// NOTE(review): the `use` items below are presumably referenced by the code `derive_binary!`
// expands to rather than by hand-written code here; confirm before removing any of them.
// NOTE(review): variant order may matter for derived traits (e.g. `Ord`) or serialization of
// the generated enum; confirm before reordering entries.
mod derive {
    use std::fmt;

    use mz_repr::{Datum, ReprColumnType, RowArena, SqlColumnType};

    use crate::scalar::func::binary::LazyBinaryFunc;
    use crate::scalar::func::*;
    use crate::{EvalError, MirScalarExpr};

    derive_binary! {
        AddInt16(AddInt16),
        AddInt32(AddInt32),
        AddInt64(AddInt64),
        AddUint16(AddUint16),
        AddUint32(AddUint32),
        AddUint64(AddUint64),
        AddFloat32(AddFloat32),
        AddFloat64(AddFloat64),
        AddInterval(AddInterval),
        AddTimestampInterval(AddTimestampInterval),
        AddTimestampTzInterval(AddTimestampTzInterval),
        AddDateInterval(AddDateInterval),
        AddDateTime(AddDateTime),
        AddTimeInterval(AddTimeInterval),
        AddNumeric(AddNumeric),
        AgeTimestamp(AgeTimestamp),
        AgeTimestampTz(AgeTimestampTz),
        BitAndInt16(BitAndInt16),
        BitAndInt32(BitAndInt32),
        BitAndInt64(BitAndInt64),
        BitAndUint16(BitAndUint16),
        BitAndUint32(BitAndUint32),
        BitAndUint64(BitAndUint64),
        BitOrInt16(BitOrInt16),
        BitOrInt32(BitOrInt32),
        BitOrInt64(BitOrInt64),
        BitOrUint16(BitOrUint16),
        BitOrUint32(BitOrUint32),
        BitOrUint64(BitOrUint64),
        BitXorInt16(BitXorInt16),
        BitXorInt32(BitXorInt32),
        BitXorInt64(BitXorInt64),
        BitXorUint16(BitXorUint16),
        BitXorUint32(BitXorUint32),
        BitXorUint64(BitXorUint64),
        BitShiftLeftInt16(BitShiftLeftInt16),
        BitShiftLeftInt32(BitShiftLeftInt32),
        BitShiftLeftInt64(BitShiftLeftInt64),
        BitShiftLeftUint16(BitShiftLeftUint16),
        BitShiftLeftUint32(BitShiftLeftUint32),
        BitShiftLeftUint64(BitShiftLeftUint64),
        BitShiftRightInt16(BitShiftRightInt16),
        BitShiftRightInt32(BitShiftRightInt32),
        BitShiftRightInt64(BitShiftRightInt64),
        BitShiftRightUint16(BitShiftRightUint16),
        BitShiftRightUint32(BitShiftRightUint32),
        BitShiftRightUint64(BitShiftRightUint64),
        SubInt16(SubInt16),
        SubInt32(SubInt32),
        SubInt64(SubInt64),
        SubUint16(SubUint16),
        SubUint32(SubUint32),
        SubUint64(SubUint64),
        SubFloat32(SubFloat32),
        SubFloat64(SubFloat64),
        SubInterval(SubInterval),
        SubTimestamp(SubTimestamp),
        SubTimestampTz(SubTimestampTz),
        SubTimestampInterval(SubTimestampInterval),
        SubTimestampTzInterval(SubTimestampTzInterval),
        SubDate(SubDate),
        SubDateInterval(SubDateInterval),
        SubTime(SubTime),
        SubTimeInterval(SubTimeInterval),
        SubNumeric(SubNumeric),
        MulInt16(MulInt16),
        MulInt32(MulInt32),
        MulInt64(MulInt64),
        MulUint16(MulUint16),
        MulUint32(MulUint32),
        MulUint64(MulUint64),
        MulFloat32(MulFloat32),
        MulFloat64(MulFloat64),
        MulNumeric(MulNumeric),
        MulInterval(MulInterval),
        DivInt16(DivInt16),
        DivInt32(DivInt32),
        DivInt64(DivInt64),
        DivUint16(DivUint16),
        DivUint32(DivUint32),
        DivUint64(DivUint64),
        DivFloat32(DivFloat32),
        DivFloat64(DivFloat64),
        DivNumeric(DivNumeric),
        DivInterval(DivInterval),
        ModInt16(ModInt16),
        ModInt32(ModInt32),
        ModInt64(ModInt64),
        ModUint16(ModUint16),
        ModUint32(ModUint32),
        ModUint64(ModUint64),
        ModFloat32(ModFloat32),
        ModFloat64(ModFloat64),
        ModNumeric(ModNumeric),
        RoundNumeric(RoundNumericBinary),
        Eq(Eq),
        NotEq(NotEq),
        Lt(Lt),
        Lte(Lte),
        Gt(Gt),
        Gte(Gte),
        LikeEscape(LikeEscape),
        IsLikeMatchCaseInsensitive(IsLikeMatchCaseInsensitive),
        IsLikeMatchCaseSensitive(IsLikeMatchCaseSensitive),
        IsRegexpMatchCaseSensitive(IsRegexpMatchCaseSensitive),
        IsRegexpMatchCaseInsensitive(IsRegexpMatchCaseInsensitive),
        ToCharTimestamp(ToCharTimestampFormat),
        ToCharTimestampTz(ToCharTimestampTzFormat),
        DateBinTimestamp(DateBinTimestamp),
        DateBinTimestampTz(DateBinTimestampTz),
        ExtractInterval(DatePartIntervalNumeric),
        ExtractTime(DatePartTimeNumeric),
        ExtractTimestamp(DatePartTimestampTimestampNumeric),
        ExtractTimestampTz(DatePartTimestampTimestampTzNumeric),
        ExtractDate(ExtractDateUnits),
        DatePartInterval(DatePartIntervalF64),
        DatePartTime(DatePartTimeF64),
        DatePartTimestamp(DatePartTimestampTimestampF64),
        DatePartTimestampTz(DatePartTimestampTimestampTzF64),
        DateTruncTimestamp(DateTruncUnitsTimestamp),
        DateTruncTimestampTz(DateTruncUnitsTimestampTz),
        DateTruncInterval(DateTruncInterval),
        TimezoneTimestampBinary(TimezoneTimestampBinary),
        TimezoneTimestampTzBinary(TimezoneTimestampTzBinary),
        TimezoneIntervalTimestampBinary(TimezoneIntervalTimestampBinary),
        TimezoneIntervalTimestampTzBinary(TimezoneIntervalTimestampTzBinary),
        TimezoneIntervalTimeBinary(TimezoneIntervalTimeBinary),
        TimezoneOffset(TimezoneOffset),
        TextConcat(TextConcatBinary),
        JsonbGetInt64(JsonbGetInt64),
        JsonbGetInt64Stringify(JsonbGetInt64Stringify),
        JsonbGetString(JsonbGetString),
        JsonbGetStringStringify(JsonbGetStringStringify),
        JsonbGetPath(JsonbGetPath),
        JsonbGetPathStringify(JsonbGetPathStringify),
        JsonbContainsString(JsonbContainsString),
        JsonbConcat(JsonbConcat),
        JsonbContainsJsonb(JsonbContainsJsonb),
        JsonbDeleteInt64(JsonbDeleteInt64),
        JsonbDeleteString(JsonbDeleteString),
        MapContainsKey(MapContainsKey),
        MapGetValue(MapGetValue),
        MapContainsAllKeys(MapContainsAllKeys),
        MapContainsAnyKeys(MapContainsAnyKeys),
        MapContainsMap(MapContainsMap),
        ConvertFrom(ConvertFrom),
        Left(Left),
        Position(Position),
        Strpos(Strpos),
        Right(Right),
        RepeatString(RepeatString),
        Normalize(Normalize),
        Trim(Trim),
        TrimLeading(TrimLeading),
        TrimTrailing(TrimTrailing),
        EncodedBytesCharLength(EncodedBytesCharLength),
        ListLengthMax(ListLengthMax),
        ArrayContains(ArrayContains),
        ArrayContainsArray(ArrayContainsArray),
        ArrayContainsArrayRev(ArrayContainsArrayRev),
        ArrayLength(ArrayLength),
        ArrayLower(ArrayLower),
        ArrayRemove(ArrayRemove),
        ArrayUpper(ArrayUpper),
        ArrayArrayConcat(ArrayArrayConcat),
        ListListConcat(ListListConcat),
        ListElementConcat(ListElementConcat),
        ElementListConcat(ElementListConcat),
        ListRemove(ListRemove),
        ListContainsList(ListContainsList),
        ListContainsListRev(ListContainsListRev),
        DigestString(DigestString),
        DigestBytes(DigestBytes),
        MzRenderTypmod(MzRenderTypmod),
        Encode(Encode),
        Decode(Decode),
        LogNumeric(LogBaseNumeric),
        Power(Power),
        PowerNumeric(PowerNumeric),
        GetBit(GetBit),
        GetByte(GetByte),
        ConstantTimeEqBytes(ConstantTimeEqBytes),
        ConstantTimeEqString(ConstantTimeEqString),
        RangeContainsDate(RangeContainsDate),
        RangeContainsDateRev(RangeContainsDateRev),
        RangeContainsI32(RangeContainsI32),
        RangeContainsI32Rev(RangeContainsI32Rev),
        RangeContainsI64(RangeContainsI64),
        RangeContainsI64Rev(RangeContainsI64Rev),
        RangeContainsNumeric(RangeContainsNumeric),
        RangeContainsNumericRev(RangeContainsNumericRev),
        RangeContainsRange(RangeContainsRange),
        RangeContainsRangeRev(RangeContainsRangeRev),
        RangeContainsTimestamp(RangeContainsTimestamp),
        RangeContainsTimestampRev(RangeContainsTimestampRev),
        RangeContainsTimestampTz(RangeContainsTimestampTz),
        RangeContainsTimestampTzRev(RangeContainsTimestampTzRev),
        RangeOverlaps(RangeOverlaps),
        RangeAfter(RangeAfter),
        RangeBefore(RangeBefore),
        RangeOverleft(RangeOverleft),
        RangeOverright(RangeOverright),
        RangeAdjacent(RangeAdjacent),
        RangeUnion(RangeUnion),
        RangeIntersection(RangeIntersection),
        RangeDifference(RangeDifference),
        UuidGenerateV5(UuidGenerateV5),
        MzAclItemContainsPrivilege(MzAclItemContainsPrivilege),
        ParseIdent(ParseIdent),
        PrettySql(PrettySql),
        RegexpReplace(RegexpReplace),
        StartsWith(StartsWith),
    }
}
411
#[cfg(test)]
mod test {
    use mz_expr_derive::sqlfunc;
    use mz_repr::SqlScalarType;

    use crate::EvalError;
    use crate::scalar::func::binary::LazyBinaryFunc;

    /// Nullability of the (first, second) `Float32` inputs, in the order in which the
    /// `expected` arrays given to `assert_output_nullability` are written.
    const NULLABILITY_CASES: [(bool, bool); 4] =
        [(true, true), (true, false), (false, true), (false, false)];

    /// Checks the NULL-handling flags `func` reports.
    fn assert_null_behavior<F: LazyBinaryFunc>(func: &F, propagates: bool, introduces: bool) {
        assert_eq!(func.propagates_nulls(), propagates, "propagates_nulls");
        assert_eq!(func.introduces_nulls(), introduces, "introduces_nulls");
    }

    /// Checks that, for every pair of input nullabilities in `NULLABILITY_CASES`, `func`
    /// produces a `Float32` output whose nullability is the matching entry of `expected`.
    fn assert_output_nullability<F: LazyBinaryFunc>(func: &F, expected: [bool; 4]) {
        let cases = NULLABILITY_CASES.into_iter().zip(expected);
        for ((lhs_nullable, rhs_nullable), out_nullable) in cases {
            let actual = func.output_sql_type(&[
                SqlScalarType::Float32.nullable(lhs_nullable),
                SqlScalarType::Float32.nullable(rhs_nullable),
            ]);
            assert_eq!(
                actual,
                SqlScalarType::Float32.nullable(out_nullable),
                "input nullability: ({lhs_nullable}, {rhs_nullable})",
            );
        }
    }

    #[sqlfunc(sqlname = "INFALLIBLE", is_infix_op = true, test = true)]
    #[allow(dead_code)]
    fn infallible1(a: f32, b: f32) -> f32 {
        a + b
    }

    #[sqlfunc(test = true)]
    #[allow(dead_code)]
    fn infallible2(a: Option<f32>, b: Option<f32>) -> f32 {
        a.unwrap_or_default() + b.unwrap_or_default()
    }

    #[sqlfunc(test = true)]
    #[allow(dead_code)]
    fn infallible3(a: f32, b: f32) -> Option<f32> {
        Some(a + b)
    }

    #[mz_ore::test]
    fn elision_rules_infallible() {
        assert_eq!(Infallible1.to_string(), "INFALLIBLE");
        // Non-nullable inputs propagate NULLs; only a nullable output introduces them.
        assert_null_behavior(&Infallible1, true, false);
        assert_null_behavior(&Infallible2, false, false);
        assert_null_behavior(&Infallible3, true, true);
    }

    #[mz_ore::test]
    fn output_types_infallible() {
        assert_output_nullability(&Infallible1, [true, true, true, false]);
        assert_output_nullability(&Infallible2, [false, false, false, false]);
        assert_output_nullability(&Infallible3, [true, true, true, true]);
    }

    #[sqlfunc(test = true)]
    #[allow(dead_code)]
    fn fallible1(a: f32, b: f32) -> Result<f32, EvalError> {
        Ok(a + b)
    }

    #[sqlfunc(test = true)]
    #[allow(dead_code)]
    fn fallible2(a: Option<f32>, b: Option<f32>) -> Result<f32, EvalError> {
        Ok(a.unwrap_or_default() + b.unwrap_or_default())
    }

    #[sqlfunc(test = true)]
    #[allow(dead_code)]
    fn fallible3(a: f32, b: f32) -> Result<Option<f32>, EvalError> {
        Ok(Some(a + b))
    }

    #[mz_ore::test]
    fn elision_rules_fallible() {
        // Wrapping the output in `Result` must not change NULL handling.
        assert_null_behavior(&Fallible1, true, false);
        assert_null_behavior(&Fallible2, false, false);
        assert_null_behavior(&Fallible3, true, true);
    }

    #[mz_ore::test]
    fn output_types_fallible() {
        assert_output_nullability(&Fallible1, [true, true, true, false]);
        assert_output_nullability(&Fallible2, [false, false, false, false]);
        assert_output_nullability(&Fallible3, [true, true, true, true]);
    }

    #[mz_ore::test]
    fn mz_reflect_binary_func() {
        use crate::BinaryFunc;
        use mz_lowertest::{MzReflect, ReflectedTypeInfo};

        let mut rti = ReflectedTypeInfo::default();
        BinaryFunc::add_to_reflected_type_info(&mut rti);

        // The enum (with a sample of its variants) goes into `enum_dict`, and the variants'
        // inner types go into `struct_dict`.
        let variants = rti
            .enum_dict
            .get("BinaryFunc")
            .expect("BinaryFunc should be in enum_dict");
        for name in ["AddInt64", "Gte"] {
            assert!(variants.contains_key(name), "{name} variant should exist");
            assert!(
                rti.struct_dict.contains_key(name),
                "{name} should be in struct_dict"
            );
        }

        // Zero-field unit structs are registered without field names or types.
        let (field_names, field_types) = rti.struct_dict.get("AddInt64").unwrap();
        assert!(field_names.is_empty(), "AddInt64 should have no field names");
        assert!(field_types.is_empty(), "AddInt64 should have no field types");
    }
}