1use mz_repr::{Datum, DatumType, RowArena, SqlColumnType};
13
14use crate::{EvalError, MirScalarExpr};
15
16#[allow(unused)]
19pub(crate) trait LazyBinaryFunc {
20 fn eval<'a>(
21 &'a self,
22 datums: &[Datum<'a>],
23 temp_storage: &'a RowArena,
24 a: &'a MirScalarExpr,
25 b: &'a MirScalarExpr,
26 ) -> Result<Datum<'a>, EvalError>;
27
28 fn output_type(
30 &self,
31 input_type_a: SqlColumnType,
32 input_type_b: SqlColumnType,
33 ) -> SqlColumnType;
34
35 fn propagates_nulls(&self) -> bool;
37
38 fn introduces_nulls(&self) -> bool;
40
41 fn could_error(&self) -> bool {
43 true
45 }
46
47 fn negate(&self) -> Option<crate::BinaryFunc>;
49
50 fn is_monotone(&self) -> (bool, bool);
63
64 fn is_infix_op(&self) -> bool;
66}
67
68#[allow(unused)]
69pub(crate) trait EagerBinaryFunc<'a> {
70 type Input1: DatumType<'a, EvalError>;
71 type Input2: DatumType<'a, EvalError>;
72 type Output: DatumType<'a, EvalError>;
73
74 fn call(&self, a: Self::Input1, b: Self::Input2, temp_storage: &'a RowArena) -> Self::Output;
75
76 fn output_type(
78 &self,
79 input_type_a: SqlColumnType,
80 input_type_b: SqlColumnType,
81 ) -> SqlColumnType;
82
83 fn propagates_nulls(&self) -> bool {
85 !Self::Input1::nullable() && !Self::Input2::nullable()
87 }
88
89 fn introduces_nulls(&self) -> bool {
91 Self::Output::nullable()
93 }
94
95 fn could_error(&self) -> bool {
97 Self::Output::fallible()
98 }
99
100 fn negate(&self) -> Option<crate::BinaryFunc> {
102 None
103 }
104
105 fn is_monotone(&self) -> (bool, bool) {
106 (false, false)
107 }
108
109 fn is_infix_op(&self) -> bool {
110 false
111 }
112}
113
114impl<T: for<'a> EagerBinaryFunc<'a>> LazyBinaryFunc for T {
115 fn eval<'a>(
116 &'a self,
117 datums: &[Datum<'a>],
118 temp_storage: &'a RowArena,
119 a: &'a MirScalarExpr,
120 b: &'a MirScalarExpr,
121 ) -> Result<Datum<'a>, EvalError> {
122 let a = a.eval(datums, temp_storage)?;
123 let b = b.eval(datums, temp_storage)?;
124 let a = match T::Input1::try_from_result(Ok(a)) {
125 Ok(input) => input,
127 Err(Ok(datum)) if !datum.is_null() => {
129 return Err(EvalError::Internal("invalid input type".into()));
130 }
131 Err(res) => return res,
133 };
134 let b = match T::Input2::try_from_result(Ok(b)) {
135 Ok(input) => input,
137 Err(Ok(datum)) if !datum.is_null() => {
139 return Err(EvalError::Internal("invalid input type".into()));
140 }
141 Err(res) => return res,
143 };
144 self.call(a, b, temp_storage).into_result(temp_storage)
145 }
146
147 fn output_type(
148 &self,
149 input_type_a: SqlColumnType,
150 input_type_b: SqlColumnType,
151 ) -> SqlColumnType {
152 self.output_type(input_type_a, input_type_b)
153 }
154
155 fn propagates_nulls(&self) -> bool {
156 self.propagates_nulls()
157 }
158
159 fn introduces_nulls(&self) -> bool {
160 self.introduces_nulls()
161 }
162
163 fn could_error(&self) -> bool {
164 self.could_error()
165 }
166
167 fn negate(&self) -> Option<crate::BinaryFunc> {
168 self.negate()
169 }
170
171 fn is_monotone(&self) -> (bool, bool) {
172 self.is_monotone()
173 }
174
175 fn is_infix_op(&self) -> bool {
176 self.is_infix_op()
177 }
178}
179
// Hooks the listed binary function types into `BinaryFunc` through the
// `derive_binary_from!` macro.
// NOTE(review): the macro is defined elsewhere and its exact expansion is not
// visible here. Entries written `Name(Type)` appear to pair a `BinaryFunc`
// name with an implementing type of a different name — confirm against the
// macro. The list order may matter to the expansion, so keep it as is.
mod derive {
    // Presumably brings every listed function type into scope for the macro.
    use crate::scalar::func::*;

    derive_binary_from! {
        AddDateInterval,
        AddDateTime,
        AddFloat32,
        AddFloat64,
        AddInt16,
        AddInt32,
        AddInt64,
        AddInterval,
        AddNumeric,
        AddTimeInterval,
        AddTimestampInterval,
        AddTimestampTzInterval,
        AddUint16,
        AddUint32,
        AddUint64,
        AgeTimestamp,
        AgeTimestampTz,
        ArrayArrayConcat,
        ArrayContains,
        ArrayLength,
        ArrayLower,
        ArrayRemove,
        ArrayUpper,
        BitAndInt16,
        BitAndInt32,
        BitAndInt64,
        BitAndUint16,
        BitAndUint32,
        BitAndUint64,
        BitOrInt16,
        BitOrInt32,
        BitOrInt64,
        BitOrUint16,
        BitOrUint32,
        BitOrUint64,
        BitShiftLeftInt16,
        BitShiftLeftInt32,
        BitShiftLeftInt64,
        BitShiftLeftUint16,
        BitShiftLeftUint32,
        BitShiftLeftUint64,
        BitShiftRightInt16,
        BitShiftRightInt32,
        BitShiftRightInt64,
        BitShiftRightUint16,
        BitShiftRightUint32,
        BitShiftRightUint64,
        BitXorInt16,
        BitXorInt32,
        BitXorInt64,
        BitXorUint16,
        BitXorUint32,
        BitXorUint64,
        ConstantTimeEqBytes,
        ConstantTimeEqString,
        ConvertFrom,
        DateBinTimestamp,
        DateBinTimestampTz,
        DatePartInterval(DatePartIntervalF64),
        DatePartTime(DatePartTimeF64),
        DatePartTimestamp(DatePartTimestampTimestampF64),
        DatePartTimestampTz(DatePartTimestampTimestampTzF64),
        DateTruncInterval,
        DateTruncTimestamp(DateTruncUnitsTimestamp),
        DateTruncTimestampTz(DateTruncUnitsTimestampTz),
        Decode,
        DigestBytes,
        DigestString,
        DivFloat32,
        DivFloat64,
        DivInt16,
        DivInt32,
        DivInt64,
        DivInterval,
        DivNumeric,
        DivUint16,
        DivUint32,
        DivUint64,
        ElementListConcat,
        Encode,
        EncodedBytesCharLength,
        Eq,
        ExtractDate(ExtractDateUnits),
        ExtractInterval(DatePartIntervalNumeric),
        ExtractTime(DatePartTimeNumeric),
        ExtractTimestamp(DatePartTimestampTimestampNumeric),
        ExtractTimestampTz(DatePartTimestampTimestampTzNumeric),
        GetBit,
        GetByte,
        Gt,
        Gte,
        JsonbConcat,
        JsonbContainsJsonb,
        JsonbContainsString,
        JsonbDeleteInt64,
        JsonbDeleteString,
        Left,
        LikeEscape,
        ListElementConcat,
        ListListConcat,
        ListRemove,
        LogNumeric(LogBaseNumeric),
        Lt,
        Lte,
        MapContainsAllKeys,
        MapContainsAnyKeys,
        MapContainsKey,
        MapContainsMap,
        MapGetValue,
        ModFloat32,
        ModFloat64,
        ModInt16,
        ModInt32,
        ModInt64,
        ModNumeric,
        ModUint16,
        ModUint32,
        ModUint64,
        MulFloat32,
        MulFloat64,
        MulInt16,
        MulInt32,
        MulInt64,
        MulInterval,
        MulNumeric,
        MulUint16,
        MulUint32,
        MulUint64,
        MzAclItemContainsPrivilege,
        MzRenderTypmod,
        NotEq,
        ParseIdent,
        Position,
        Power,
        PowerNumeric,
        PrettySql,
        RangeAdjacent,
        RangeAfter,
        RangeBefore,
        RangeDifference,
        RangeIntersection,
        RangeOverlaps,
        RangeOverleft,
        RangeOverright,
        RangeUnion,
        Right,
        RoundNumeric(RoundNumericBinary),
        StartsWith,
        SubDate,
        SubDateInterval,
        SubFloat32,
        SubFloat64,
        SubInt16,
        SubInt32,
        SubInt64,
        SubInterval,
        SubNumeric,
        SubTime,
        SubTimeInterval,
        SubTimestamp,
        SubTimestampInterval,
        SubTimestampTz,
        SubTimestampTzInterval,
        SubUint16,
        SubUint32,
        SubUint64,
        TextConcat(TextConcatBinary),
        TimezoneOffset,
        ToCharTimestamp(ToCharTimestampFormat),
        ToCharTimestampTz(ToCharTimestampTzFormat),
        Trim,
        TrimLeading,
        TrimTrailing,
        UuidGenerateV5,
    }
}
381
#[cfg(test)]
mod test {
    use mz_expr_derive::sqlfunc;
    use mz_repr::SqlColumnType;
    use mz_repr::SqlScalarType;

    use crate::scalar::func::binary::LazyBinaryFunc;
    use crate::{BinaryFunc, EvalError, func};

    // Fixtures. `#[sqlfunc]` generates a value named in UpperCamelCase for each
    // function (`infallible1` -> `Infallible1`, ...). The tests query that value
    // through `LazyBinaryFunc`. Variants: 1 = non-nullable in/out,
    // 2 = `Option` inputs, 3 = `Option` output.

    #[sqlfunc(sqlname = "INFALLIBLE", is_infix_op = true)]
    #[allow(dead_code)]
    fn infallible1(a: f32, b: f32) -> f32 {
        a + b
    }

    #[sqlfunc]
    #[allow(dead_code)]
    fn infallible2(a: Option<f32>, b: Option<f32>) -> f32 {
        a.unwrap_or_default() + b.unwrap_or_default()
    }

    #[sqlfunc]
    #[allow(dead_code)]
    fn infallible3(a: f32, b: f32) -> Option<f32> {
        Some(a + b)
    }

    #[sqlfunc]
    #[allow(dead_code)]
    fn fallible1(a: f32, b: f32) -> Result<f32, EvalError> {
        Ok(a + b)
    }

    #[sqlfunc]
    #[allow(dead_code)]
    fn fallible2(a: Option<f32>, b: Option<f32>) -> Result<f32, EvalError> {
        Ok(a.unwrap_or_default() + b.unwrap_or_default())
    }

    #[sqlfunc]
    #[allow(dead_code)]
    fn fallible3(a: f32, b: f32) -> Result<Option<f32>, EvalError> {
        Ok(Some(a + b))
    }

    /// Returns `(propagates_nulls, introduces_nulls)` for `f`.
    fn null_flags<F: LazyBinaryFunc>(f: F) -> (bool, bool) {
        (f.propagates_nulls(), f.introduces_nulls())
    }

    /// Asserts the result of `f.output_type(..)` for every combination of
    /// nullable / non-nullable `Float32` inputs.
    ///
    /// `expected(a_nullable, b_nullable)` gives the expected nullability of the
    /// output column.
    #[track_caller]
    fn assert_float32_output_nullability<F: LazyBinaryFunc>(
        f: F,
        expected: impl Fn(bool, bool) -> bool,
    ) {
        for a_nullable in [true, false] {
            for b_nullable in [true, false] {
                let actual = f.output_type(
                    SqlScalarType::Float32.nullable(a_nullable),
                    SqlScalarType::Float32.nullable(b_nullable),
                );
                let wanted = SqlScalarType::Float32.nullable(expected(a_nullable, b_nullable));
                assert_eq!(actual, wanted);
            }
        }
    }

    #[mz_ore::test]
    fn elision_rules_infallible() {
        assert_eq!(Infallible1.to_string(), "INFALLIBLE");
        assert_eq!(null_flags(Infallible1), (true, false));
        assert_eq!(null_flags(Infallible2), (false, false));
        assert_eq!(null_flags(Infallible3), (true, true));
    }

    #[mz_ore::test]
    fn output_types_infallible() {
        // Non-nullable inputs: the output is nullable iff some input is.
        assert_float32_output_nullability(Infallible1, |a, b| a || b);
        // `Option` inputs consume NULLs: the output is never nullable.
        assert_float32_output_nullability(Infallible2, |_, _| false);
        // `Option` output: always nullable.
        assert_float32_output_nullability(Infallible3, |_, _| true);
    }

    #[mz_ore::test]
    fn elision_rules_fallible() {
        assert_eq!(null_flags(Fallible1), (true, false));
        assert_eq!(null_flags(Fallible2), (false, false));
        assert_eq!(null_flags(Fallible3), (true, true));
    }

    #[mz_ore::test]
    fn output_types_fallible() {
        assert_float32_output_nullability(Fallible1, |a, b| a || b);
        assert_float32_output_nullability(Fallible2, |_, _| false);
        assert_float32_output_nullability(Fallible3, |_, _| true);
    }

    #[mz_ore::test]
    fn test_equivalence_nullable() {
        test_equivalence_inner(true);
    }

    #[mz_ore::test]
    fn test_equivalence_non_nullable() {
        test_equivalence_inner(false);
    }

    /// Checks that each `func::*` implementation (`new`) reports the same
    /// properties, output type, negation and display text as the corresponding
    /// `BinaryFunc` value (`old`).
    fn test_equivalence_inner(input_nullable: bool) {
        #[track_caller]
        fn check<T: LazyBinaryFunc + std::fmt::Display + std::fmt::Debug>(
            new: T,
            old: BinaryFunc,
            column_a_ty: &SqlColumnType,
            column_b_ty: &SqlColumnType,
        ) {
            let (new_propagates, old_propagates) = (new.propagates_nulls(), old.propagates_nulls());
            assert_eq!(new_propagates, old_propagates, "{new:?} propagates_nulls mismatch");

            let (new_introduces, old_introduces) = (new.introduces_nulls(), old.introduces_nulls());
            assert_eq!(new_introduces, old_introduces, "{new:?} introduces_nulls mismatch");

            assert_eq!(new.could_error(), old.could_error(), "{new:?} could_error mismatch");
            assert_eq!(new.is_monotone(), old.is_monotone(), "{new:?} is_monotone mismatch");
            assert_eq!(new.is_infix_op(), old.is_infix_op(), "{new:?} is_infix_op mismatch");

            let new_output = new.output_type(column_a_ty.clone(), column_b_ty.clone());
            let old_output = old.output_type(column_a_ty.clone(), column_b_ty.clone());
            assert_eq!(new_output, old_output, "{new:?} output_type mismatch");

            assert_eq!(new.negate(), old.negate(), "{new:?} negate mismatch");
            assert_eq!(new.to_string(), old.to_string(), "{new:?} format mismatch");
        }

        let i32_ty = SqlColumnType {
            nullable: input_nullable,
            scalar_type: SqlScalarType::Int32,
        };

        use BinaryFunc as BF;

        // Shorthand for the `RangeContainsElem` variant.
        let elem = |elem_type: SqlScalarType, rev: bool| BF::RangeContainsElem { elem_type, rev };

        // Range contains element: forward (`rev: false`).
        check(func::RangeContainsI32, elem(SqlScalarType::Int32, false), &i32_ty, &i32_ty);
        check(func::RangeContainsI64, elem(SqlScalarType::Int64, false), &i32_ty, &i32_ty);
        check(func::RangeContainsDate, elem(SqlScalarType::Date, false), &i32_ty, &i32_ty);
        check(
            func::RangeContainsNumeric,
            elem(SqlScalarType::Numeric { max_scale: None }, false),
            &i32_ty,
            &i32_ty,
        );
        check(
            func::RangeContainsTimestamp,
            elem(SqlScalarType::Timestamp { precision: None }, false),
            &i32_ty,
            &i32_ty,
        );
        check(
            func::RangeContainsTimestampTz,
            elem(SqlScalarType::TimestampTz { precision: None }, false),
            &i32_ty,
            &i32_ty,
        );

        // Range contains element: reversed (`rev: true`).
        check(func::RangeContainsI32Rev, elem(SqlScalarType::Int32, true), &i32_ty, &i32_ty);
        check(func::RangeContainsI64Rev, elem(SqlScalarType::Int64, true), &i32_ty, &i32_ty);
        check(func::RangeContainsDateRev, elem(SqlScalarType::Date, true), &i32_ty, &i32_ty);
        check(
            func::RangeContainsNumericRev,
            elem(SqlScalarType::Numeric { max_scale: None }, true),
            &i32_ty,
            &i32_ty,
        );
        check(
            func::RangeContainsTimestampRev,
            elem(SqlScalarType::Timestamp { precision: None }, true),
            &i32_ty,
            &i32_ty,
        );
        check(
            func::RangeContainsTimestampTzRev,
            elem(SqlScalarType::TimestampTz { precision: None }, true),
            &i32_ty,
            &i32_ty,
        );

        // Container-contains-container variants, forward then reversed.
        check(func::RangeContainsRange, BF::RangeContainsRange { rev: false }, &i32_ty, &i32_ty);
        check(func::RangeContainsRangeRev, BF::RangeContainsRange { rev: true }, &i32_ty, &i32_ty);

        check(func::ListContainsList, BF::ListContainsList { rev: false }, &i32_ty, &i32_ty);
        check(func::ListContainsListRev, BF::ListContainsList { rev: true }, &i32_ty, &i32_ty);

        check(func::ArrayContainsArray, BF::ArrayContainsArray { rev: false }, &i32_ty, &i32_ty);
        check(func::ArrayContainsArrayRev, BF::ArrayContainsArray { rev: true }, &i32_ty, &i32_ty);
    }
}