1use mz_repr::{Datum, DatumType, RowArena, SqlColumnType};
13
14use crate::{EvalError, MirScalarExpr};
15
16#[allow(unused)]
19pub(crate) trait LazyBinaryFunc {
20 fn eval<'a>(
21 &'a self,
22 datums: &[Datum<'a>],
23 temp_storage: &'a RowArena,
24 a: &'a MirScalarExpr,
25 b: &'a MirScalarExpr,
26 ) -> Result<Datum<'a>, EvalError>;
27
28 fn output_type(
30 &self,
31 input_type_a: SqlColumnType,
32 input_type_b: SqlColumnType,
33 ) -> SqlColumnType;
34
35 fn propagates_nulls(&self) -> bool;
37
38 fn introduces_nulls(&self) -> bool;
40
41 fn could_error(&self) -> bool {
43 true
45 }
46
47 fn negate(&self) -> Option<crate::BinaryFunc>;
49
50 fn is_monotone(&self) -> (bool, bool);
63
64 fn is_infix_op(&self) -> bool;
66}
67
68#[allow(unused)]
69pub(crate) trait EagerBinaryFunc<'a> {
70 type Input1: DatumType<'a, EvalError>;
71 type Input2: DatumType<'a, EvalError>;
72 type Output: DatumType<'a, EvalError>;
73
74 fn call(&self, a: Self::Input1, b: Self::Input2, temp_storage: &'a RowArena) -> Self::Output;
75
76 fn output_type(
78 &self,
79 input_type_a: SqlColumnType,
80 input_type_b: SqlColumnType,
81 ) -> SqlColumnType;
82
83 fn propagates_nulls(&self) -> bool {
85 !Self::Input1::nullable() && !Self::Input2::nullable()
87 }
88
89 fn introduces_nulls(&self) -> bool {
91 Self::Output::nullable()
93 }
94
95 fn could_error(&self) -> bool {
97 Self::Output::fallible()
98 }
99
100 fn negate(&self) -> Option<crate::BinaryFunc> {
102 None
103 }
104
105 fn is_monotone(&self) -> (bool, bool) {
106 (false, false)
107 }
108
109 fn is_infix_op(&self) -> bool {
110 false
111 }
112}
113
114impl<T: for<'a> EagerBinaryFunc<'a>> LazyBinaryFunc for T {
115 fn eval<'a>(
116 &'a self,
117 datums: &[Datum<'a>],
118 temp_storage: &'a RowArena,
119 a: &'a MirScalarExpr,
120 b: &'a MirScalarExpr,
121 ) -> Result<Datum<'a>, EvalError> {
122 let a = a.eval(datums, temp_storage)?;
123 let b = b.eval(datums, temp_storage)?;
124 let a = match T::Input1::try_from_result(Ok(a)) {
125 Ok(input) => input,
127 Err(Ok(datum)) if !datum.is_null() => {
129 return Err(EvalError::Internal("invalid input type".into()));
130 }
131 Err(res) => return res,
133 };
134 let b = match T::Input2::try_from_result(Ok(b)) {
135 Ok(input) => input,
137 Err(Ok(datum)) if !datum.is_null() => {
139 return Err(EvalError::Internal("invalid input type".into()));
140 }
141 Err(res) => return res,
143 };
144 self.call(a, b, temp_storage).into_result(temp_storage)
145 }
146
147 fn output_type(
148 &self,
149 input_type_a: SqlColumnType,
150 input_type_b: SqlColumnType,
151 ) -> SqlColumnType {
152 self.output_type(input_type_a, input_type_b)
153 }
154
155 fn propagates_nulls(&self) -> bool {
156 self.propagates_nulls()
157 }
158
159 fn introduces_nulls(&self) -> bool {
160 self.introduces_nulls()
161 }
162
163 fn could_error(&self) -> bool {
164 self.could_error()
165 }
166
167 fn negate(&self) -> Option<crate::BinaryFunc> {
168 self.negate()
169 }
170
171 fn is_monotone(&self) -> (bool, bool) {
172 self.is_monotone()
173 }
174
175 fn is_infix_op(&self) -> bool {
176 self.is_infix_op()
177 }
178}
179
/// Registers the eager binary function types with `derive_binary_from!`.
// NOTE(review): `derive_binary_from!` is defined elsewhere; presumably it
// generates conversions from each listed type into `BinaryFunc` — TODO
// confirm against the macro definition. The list order may matter to the
// macro, so keep it as-is.
mod derive {
    use crate::scalar::func::*;

    derive_binary_from! {
        // Entries written `Name(Type)` appear to pair a name with a
        // differently named implementing type (e.g. `DatePartInterval` with
        // `DatePartIntervalF64`); plain entries presumably use the same
        // name for both — confirm against the macro.
        AddDateInterval,
        AddDateTime,
        AddFloat32,
        AddFloat64,
        AddInt16,
        AddInt32,
        AddInt64,
        AddInterval,
        AddNumeric,
        AddTimeInterval,
        AddTimestampInterval,
        AddTimestampTzInterval,
        AddUint16,
        AddUint32,
        AddUint64,
        AgeTimestamp,
        AgeTimestampTz,
        ArrayArrayConcat,
        ArrayContains,
        ArrayLength,
        ArrayLower,
        ArrayRemove,
        ArrayUpper,
        BitAndInt16,
        BitAndInt32,
        BitAndInt64,
        BitAndUint16,
        BitAndUint32,
        BitAndUint64,
        BitOrInt16,
        BitOrInt32,
        BitOrInt64,
        BitOrUint16,
        BitOrUint32,
        BitOrUint64,
        BitShiftLeftInt16,
        BitShiftLeftInt32,
        BitShiftLeftInt64,
        BitShiftLeftUint16,
        BitShiftLeftUint32,
        BitShiftLeftUint64,
        BitShiftRightInt16,
        BitShiftRightInt32,
        BitShiftRightInt64,
        BitShiftRightUint16,
        BitShiftRightUint32,
        BitShiftRightUint64,
        BitXorInt16,
        BitXorInt32,
        BitXorInt64,
        BitXorUint16,
        BitXorUint32,
        BitXorUint64,
        ConstantTimeEqBytes,
        ConstantTimeEqString,
        ConvertFrom,
        DateBinTimestamp,
        DateBinTimestampTz,
        DatePartInterval(DatePartIntervalF64),
        DatePartTime(DatePartTimeF64),
        DatePartTimestamp(DatePartTimestampTimestampF64),
        DatePartTimestampTz(DatePartTimestampTimestampTzF64),
        DateTruncInterval,
        DateTruncTimestamp(DateTruncUnitsTimestamp),
        DateTruncTimestampTz(DateTruncUnitsTimestampTz),
        Decode,
        DigestBytes,
        DigestString,
        DivFloat32,
        DivFloat64,
        DivInt16,
        DivInt32,
        DivInt64,
        DivInterval,
        DivNumeric,
        DivUint16,
        DivUint32,
        DivUint64,
        ElementListConcat,
        Encode,
        EncodedBytesCharLength,
        Eq,
        ExtractDate(ExtractDateUnits),
        ExtractInterval(DatePartIntervalNumeric),
        ExtractTime(DatePartTimeNumeric),
        ExtractTimestamp(DatePartTimestampTimestampNumeric),
        ExtractTimestampTz(DatePartTimestampTimestampTzNumeric),
        GetBit,
        GetByte,
        Gt,
        Gte,
        IsLikeMatchCaseInsensitive,
        IsLikeMatchCaseSensitive,
        JsonbConcat,
        JsonbContainsJsonb,
        JsonbContainsString,
        JsonbDeleteInt64,
        JsonbDeleteString,
        Left,
        LikeEscape,
        ListElementConcat,
        ListListConcat,
        ListRemove,
        LogNumeric(LogBaseNumeric),
        Lt,
        Lte,
        MapContainsAllKeys,
        MapContainsAnyKeys,
        MapContainsKey,
        MapContainsMap,
        MapGetValue,
        ModFloat32,
        ModFloat64,
        ModInt16,
        ModInt32,
        ModInt64,
        ModNumeric,
        ModUint16,
        ModUint32,
        ModUint64,
        MulFloat32,
        MulFloat64,
        MulInt16,
        MulInt32,
        MulInt64,
        MulInterval,
        MulNumeric,
        MulUint16,
        MulUint32,
        MulUint64,
        MzAclItemContainsPrivilege,
        MzRenderTypmod,
        NotEq,
        ParseIdent,
        Position,
        Power,
        PowerNumeric,
        PrettySql,
        RangeAdjacent,
        RangeAfter,
        RangeBefore,
        RangeDifference,
        RangeIntersection,
        RangeOverlaps,
        RangeOverleft,
        RangeOverright,
        RangeUnion,
        Right,
        RoundNumeric(RoundNumericBinary),
        StartsWith,
        SubDate,
        SubDateInterval,
        SubFloat32,
        SubFloat64,
        SubInt16,
        SubInt32,
        SubInt64,
        SubInterval,
        SubNumeric,
        SubTime,
        SubTimeInterval,
        SubTimestamp,
        SubTimestampInterval,
        SubTimestampTz,
        SubTimestampTzInterval,
        SubUint16,
        SubUint32,
        SubUint64,
        TextConcat(TextConcatBinary),
        TimezoneOffset,
        ToCharTimestamp(ToCharTimestampFormat),
        ToCharTimestampTz(ToCharTimestampTzFormat),
        Trim,
        TrimLeading,
        TrimTrailing,
        UuidGenerateV5,
    }
}
382
#[cfg(test)]
mod test {
    use mz_expr_derive::sqlfunc;
    use mz_repr::SqlColumnType;
    use mz_repr::SqlScalarType;

    use crate::scalar::func::binary::LazyBinaryFunc;
    use crate::{BinaryFunc, EvalError, func};

    /// Cases `(a_nullable, b_nullable, output_nullable)` where the output is
    /// nullable exactly when at least one input is.
    const NULLABLE_IF_ANY_INPUT: [(bool, bool, bool); 4] = [
        (true, true, true),
        (true, false, true),
        (false, true, true),
        (false, false, false),
    ];

    /// Cases `(a_nullable, b_nullable, output_nullable)` where the output is
    /// never nullable.
    const NEVER_NULLABLE: [(bool, bool, bool); 4] = [
        (true, true, false),
        (true, false, false),
        (false, true, false),
        (false, false, false),
    ];

    /// Cases `(a_nullable, b_nullable, output_nullable)` where the output is
    /// always nullable.
    const ALWAYS_NULLABLE: [(bool, bool, bool); 4] = [
        (true, true, true),
        (true, false, true),
        (false, true, true),
        (false, false, true),
    ];

    /// Asserts the two NULL-handling flags reported by `f`.
    #[track_caller]
    fn assert_null_flags<F: LazyBinaryFunc>(f: &F, propagates: bool, introduces: bool) {
        assert_eq!(f.propagates_nulls(), propagates, "propagates_nulls");
        assert_eq!(f.introduces_nulls(), introduces, "introduces_nulls");
    }

    /// Asserts `f.output_type` for `Float32` arguments over every
    /// `(a_nullable, b_nullable, output_nullable)` case.
    #[track_caller]
    fn assert_f32_output_types<F: LazyBinaryFunc>(f: &F, cases: [(bool, bool, bool); 4]) {
        for (a_nullable, b_nullable, output_nullable) in cases {
            let actual = f.output_type(
                SqlScalarType::Float32.nullable(a_nullable),
                SqlScalarType::Float32.nullable(b_nullable),
            );
            assert_eq!(
                actual,
                SqlScalarType::Float32.nullable(output_nullable),
                "input nullability ({a_nullable}, {b_nullable})"
            );
        }
    }

    #[sqlfunc(sqlname = "INFALLIBLE", is_infix_op = true)]
    #[allow(dead_code)]
    fn infallible1(a: f32, b: f32) -> f32 {
        a + b
    }

    #[sqlfunc]
    #[allow(dead_code)]
    fn infallible2(a: Option<f32>, b: Option<f32>) -> f32 {
        a.unwrap_or_default() + b.unwrap_or_default()
    }

    #[sqlfunc]
    #[allow(dead_code)]
    fn infallible3(a: f32, b: f32) -> Option<f32> {
        Some(a + b)
    }

    #[mz_ore::test]
    fn elision_rules_infallible() {
        assert_eq!(Infallible1.to_string(), "INFALLIBLE");
        // Non-nullable inputs and output: NULL in, NULL out; never NULL otherwise.
        assert_null_flags(&Infallible1, true, false);
        // Nullable inputs: `call` sees the NULLs itself.
        assert_null_flags(&Infallible2, false, false);
        // Nullable output: `call` may return NULL on its own.
        assert_null_flags(&Infallible3, true, true);
    }

    #[mz_ore::test]
    fn output_types_infallible() {
        assert_f32_output_types(&Infallible1, NULLABLE_IF_ANY_INPUT);
        assert_f32_output_types(&Infallible2, NEVER_NULLABLE);
        assert_f32_output_types(&Infallible3, ALWAYS_NULLABLE);
    }

    #[sqlfunc]
    #[allow(dead_code)]
    fn fallible1(a: f32, b: f32) -> Result<f32, EvalError> {
        Ok(a + b)
    }

    #[sqlfunc]
    #[allow(dead_code)]
    fn fallible2(a: Option<f32>, b: Option<f32>) -> Result<f32, EvalError> {
        Ok(a.unwrap_or_default() + b.unwrap_or_default())
    }

    #[sqlfunc]
    #[allow(dead_code)]
    fn fallible3(a: f32, b: f32) -> Result<Option<f32>, EvalError> {
        Ok(Some(a + b))
    }

    #[mz_ore::test]
    fn elision_rules_fallible() {
        // Wrapping the output in `Result` must not change the NULL rules.
        assert_null_flags(&Fallible1, true, false);
        assert_null_flags(&Fallible2, false, false);
        assert_null_flags(&Fallible3, true, true);
    }

    #[mz_ore::test]
    fn output_types_fallible() {
        assert_f32_output_types(&Fallible1, NULLABLE_IF_ANY_INPUT);
        assert_f32_output_types(&Fallible2, NEVER_NULLABLE);
        assert_f32_output_types(&Fallible3, ALWAYS_NULLABLE);
    }

    #[mz_ore::test]
    fn test_equivalence_nullable() {
        test_equivalence_inner(true);
    }

    #[mz_ore::test]
    fn test_equivalence_non_nullable() {
        test_equivalence_inner(false);
    }

    /// Compares the eager function types against the `BinaryFunc` variants
    /// they correspond to, with argument columns of the given nullability.
    fn test_equivalence_inner(input_nullable: bool) {
        /// Asserts that `new` and `old` agree on every reported property.
        #[track_caller]
        fn check<T: LazyBinaryFunc + std::fmt::Display + std::fmt::Debug>(
            new: T,
            old: BinaryFunc,
            column_a_ty: &SqlColumnType,
            column_b_ty: &SqlColumnType,
        ) {
            assert_eq!(
                new.propagates_nulls(),
                old.propagates_nulls(),
                "{new:?} propagates_nulls mismatch"
            );
            assert_eq!(
                new.introduces_nulls(),
                old.introduces_nulls(),
                "{new:?} introduces_nulls mismatch"
            );
            assert_eq!(
                new.could_error(),
                old.could_error(),
                "{new:?} could_error mismatch"
            );
            assert_eq!(
                new.is_monotone(),
                old.is_monotone(),
                "{new:?} is_monotone mismatch"
            );
            assert_eq!(
                new.is_infix_op(),
                old.is_infix_op(),
                "{new:?} is_infix_op mismatch"
            );
            let new_output = new.output_type(column_a_ty.clone(), column_b_ty.clone());
            let old_output = old.output_type(column_a_ty.clone(), column_b_ty.clone());
            assert_eq!(new_output, old_output, "{new:?} output_type mismatch");
            assert_eq!(new.negate(), old.negate(), "{new:?} negate mismatch");
            assert_eq!(new.to_string(), old.to_string(), "{new:?} format mismatch");
        }

        let i32_ty = SqlColumnType {
            nullable: input_nullable,
            scalar_type: SqlScalarType::Int32,
        };

        use BinaryFunc as BF;

        // Checks a forward / reversed pair of `RangeContainsElem` functions
        // for one element type.
        macro_rules! check_elem_pair {
            ($fwd:ident, $rev:ident, $elem:expr) => {
                check(
                    func::$fwd,
                    BF::RangeContainsElem {
                        elem_type: $elem,
                        rev: false,
                    },
                    &i32_ty,
                    &i32_ty,
                );
                check(
                    func::$rev,
                    BF::RangeContainsElem {
                        elem_type: $elem,
                        rev: true,
                    },
                    &i32_ty,
                    &i32_ty,
                );
            };
        }

        // Checks a forward / reversed pair against a `BinaryFunc` variant
        // whose only field is `rev`.
        macro_rules! check_rev_pair {
            ($fwd:ident, $rev:ident, $variant:ident) => {
                check(func::$fwd, BF::$variant { rev: false }, &i32_ty, &i32_ty);
                check(func::$rev, BF::$variant { rev: true }, &i32_ty, &i32_ty);
            };
        }

        check_elem_pair!(RangeContainsI32, RangeContainsI32Rev, SqlScalarType::Int32);
        check_elem_pair!(RangeContainsI64, RangeContainsI64Rev, SqlScalarType::Int64);
        check_elem_pair!(RangeContainsDate, RangeContainsDateRev, SqlScalarType::Date);
        check_elem_pair!(
            RangeContainsNumeric,
            RangeContainsNumericRev,
            SqlScalarType::Numeric { max_scale: None }
        );
        check_elem_pair!(
            RangeContainsTimestamp,
            RangeContainsTimestampRev,
            SqlScalarType::Timestamp { precision: None }
        );
        check_elem_pair!(
            RangeContainsTimestampTz,
            RangeContainsTimestampTzRev,
            SqlScalarType::TimestampTz { precision: None }
        );

        check_rev_pair!(RangeContainsRange, RangeContainsRangeRev, RangeContainsRange);
        check_rev_pair!(ListContainsList, ListContainsListRev, ListContainsList);
        check_rev_pair!(ArrayContainsArray, ArrayContainsArrayRev, ArrayContainsArray);
    }
}