1use mz_repr::{Datum, DatumType, RowArena, SqlColumnType};
13
14use crate::{EvalError, MirScalarExpr};
15
16#[allow(unused)]
19pub(crate) trait LazyBinaryFunc {
20 fn eval<'a>(
21 &'a self,
22 datums: &[Datum<'a>],
23 temp_storage: &'a RowArena,
24 a: &'a MirScalarExpr,
25 b: &'a MirScalarExpr,
26 ) -> Result<Datum<'a>, EvalError>;
27
28 fn output_type(
30 &self,
31 input_type_a: SqlColumnType,
32 input_type_b: SqlColumnType,
33 ) -> SqlColumnType;
34
35 fn propagates_nulls(&self) -> bool;
37
38 fn introduces_nulls(&self) -> bool;
40
41 fn could_error(&self) -> bool {
43 true
45 }
46
47 fn negate(&self) -> Option<crate::BinaryFunc>;
49
50 fn is_monotone(&self) -> (bool, bool);
63
64 fn is_infix_op(&self) -> bool;
66}
67
68#[allow(unused)]
69pub(crate) trait EagerBinaryFunc<'a> {
70 type Input1: DatumType<'a, EvalError>;
71 type Input2: DatumType<'a, EvalError>;
72 type Output: DatumType<'a, EvalError>;
73
74 fn call(&self, a: Self::Input1, b: Self::Input2, temp_storage: &'a RowArena) -> Self::Output;
75
76 fn output_type(
78 &self,
79 input_type_a: SqlColumnType,
80 input_type_b: SqlColumnType,
81 ) -> SqlColumnType;
82
83 fn propagates_nulls(&self) -> bool {
85 !Self::Input1::nullable() && !Self::Input2::nullable()
87 }
88
89 fn introduces_nulls(&self) -> bool {
91 Self::Output::nullable()
93 }
94
95 fn could_error(&self) -> bool {
97 Self::Output::fallible()
98 }
99
100 fn negate(&self) -> Option<crate::BinaryFunc> {
102 None
103 }
104
105 fn is_monotone(&self) -> (bool, bool) {
106 (false, false)
107 }
108
109 fn is_infix_op(&self) -> bool {
110 false
111 }
112}
113
114impl<T: for<'a> EagerBinaryFunc<'a>> LazyBinaryFunc for T {
115 fn eval<'a>(
116 &'a self,
117 datums: &[Datum<'a>],
118 temp_storage: &'a RowArena,
119 a: &'a MirScalarExpr,
120 b: &'a MirScalarExpr,
121 ) -> Result<Datum<'a>, EvalError> {
122 let a = a.eval(datums, temp_storage)?;
123 let b = b.eval(datums, temp_storage)?;
124 let a = match T::Input1::try_from_result(Ok(a)) {
125 Ok(input) => input,
127 Err(Ok(datum)) if !datum.is_null() => {
129 return Err(EvalError::Internal("invalid input type".into()));
130 }
131 Err(res) => return res,
133 };
134 let b = match T::Input2::try_from_result(Ok(b)) {
135 Ok(input) => input,
137 Err(Ok(datum)) if !datum.is_null() => {
139 return Err(EvalError::Internal("invalid input type".into()));
140 }
141 Err(res) => return res,
143 };
144 self.call(a, b, temp_storage).into_result(temp_storage)
145 }
146
147 fn output_type(
148 &self,
149 input_type_a: SqlColumnType,
150 input_type_b: SqlColumnType,
151 ) -> SqlColumnType {
152 self.output_type(input_type_a, input_type_b)
153 }
154
155 fn propagates_nulls(&self) -> bool {
156 self.propagates_nulls()
157 }
158
159 fn introduces_nulls(&self) -> bool {
160 self.introduces_nulls()
161 }
162
163 fn could_error(&self) -> bool {
164 self.could_error()
165 }
166
167 fn negate(&self) -> Option<crate::BinaryFunc> {
168 self.negate()
169 }
170
171 fn is_monotone(&self) -> (bool, bool) {
172 self.is_monotone()
173 }
174
175 fn is_infix_op(&self) -> bool {
176 self.is_infix_op()
177 }
178}
179
180mod derive {
181 use crate::scalar::func::*;
182
183 derive_binary_from! {
184 AddDateInterval,
185 AddDateTime,
186 AddFloat32,
187 AddFloat64,
188 AddInt16,
189 AddInt32,
190 AddInt64,
191 AddInterval,
192 AddNumeric,
193 AddTimeInterval,
194 AddTimestampInterval,
195 AddTimestampTzInterval,
196 AddUint16,
197 AddUint32,
198 AddUint64,
199 AgeTimestamp,
200 AgeTimestampTz,
201 ArrayArrayConcat,
202 ArrayContains,
203 ArrayContainsArray,
204 ArrayContainsArrayRev,
205 ArrayLength,
206 ArrayLower,
207 ArrayRemove,
208 ArrayUpper,
209 BitAndInt16,
210 BitAndInt32,
211 BitAndInt64,
212 BitAndUint16,
213 BitAndUint32,
214 BitAndUint64,
215 BitOrInt16,
216 BitOrInt32,
217 BitOrInt64,
218 BitOrUint16,
219 BitOrUint32,
220 BitOrUint64,
221 BitShiftLeftInt16,
222 BitShiftLeftInt32,
223 BitShiftLeftInt64,
224 BitShiftLeftUint16,
225 BitShiftLeftUint32,
226 BitShiftLeftUint64,
227 BitShiftRightInt16,
228 BitShiftRightInt32,
229 BitShiftRightInt64,
230 BitShiftRightUint16,
231 BitShiftRightUint32,
232 BitShiftRightUint64,
233 BitXorInt16,
234 BitXorInt32,
235 BitXorInt64,
236 BitXorUint16,
237 BitXorUint32,
238 BitXorUint64,
239 ConstantTimeEqBytes,
240 ConstantTimeEqString,
241 ConvertFrom,
242 DateBinTimestamp,
243 DateBinTimestampTz,
244 DatePartInterval(DatePartIntervalF64),
245 DatePartTime(DatePartTimeF64),
246 DatePartTimestamp(DatePartTimestampTimestampF64),
247 DatePartTimestampTz(DatePartTimestampTimestampTzF64),
248 DateTruncInterval,
249 DateTruncTimestamp(DateTruncUnitsTimestamp),
250 DateTruncTimestampTz(DateTruncUnitsTimestampTz),
251 Decode,
252 DigestBytes,
253 DigestString,
254 DivFloat32,
255 DivFloat64,
256 DivInt16,
257 DivInt32,
258 DivInt64,
259 DivInterval,
260 DivNumeric,
261 DivUint16,
262 DivUint32,
263 DivUint64,
264 ElementListConcat,
265 Encode,
266 EncodedBytesCharLength,
267 Eq,
268 ExtractDate(ExtractDateUnits),
269 ExtractInterval(DatePartIntervalNumeric),
270 ExtractTime(DatePartTimeNumeric),
271 ExtractTimestamp(DatePartTimestampTimestampNumeric),
272 ExtractTimestampTz(DatePartTimestampTimestampTzNumeric),
273 GetBit,
274 GetByte,
275 Gt,
276 Gte,
277 IsLikeMatchCaseInsensitive,
278 IsLikeMatchCaseSensitive,
279 JsonbConcat,
281 JsonbContainsJsonb,
282 JsonbContainsString,
283 JsonbDeleteInt64,
284 JsonbDeleteString,
285 JsonbGetInt64,
286 JsonbGetInt64Stringify,
287 JsonbGetPath,
288 JsonbGetPathStringify,
289 JsonbGetString,
290 JsonbGetStringStringify,
291 Left,
292 LikeEscape,
293 ListContainsList,
294 ListContainsListRev,
295 ListElementConcat,
296 ListLengthMax,
297 ListListConcat,
298 ListRemove,
299 LogNumeric(LogBaseNumeric),
300 Lt,
301 Lte,
302 MapContainsAllKeys,
303 MapContainsAnyKeys,
304 MapContainsKey,
305 MapContainsMap,
306 MapGetValue,
307 ModFloat32,
308 ModFloat64,
309 ModInt16,
310 ModInt32,
311 ModInt64,
312 ModNumeric,
313 ModUint16,
314 ModUint32,
315 ModUint64,
316 MulFloat32,
317 MulFloat64,
318 MulInt16,
319 MulInt32,
320 MulInt64,
321 MulInterval,
322 MulNumeric,
323 MulUint16,
324 MulUint32,
325 MulUint64,
326 MzAclItemContainsPrivilege,
327 MzRenderTypmod,
328 NotEq,
330 ParseIdent,
331 Position,
332 Power,
333 PowerNumeric,
334 PrettySql,
335 RangeAdjacent,
336 RangeAfter,
337 RangeBefore,
338 RangeContainsRange,
340 RangeContainsRangeRev,
341 RangeDifference,
342 RangeIntersection,
343 RangeOverlaps,
344 RangeOverleft,
345 RangeOverright,
346 RangeUnion,
347 Right,
350 RoundNumeric(RoundNumericBinary),
351 StartsWith,
352 SubDate,
353 SubDateInterval,
354 SubFloat32,
355 SubFloat64,
356 SubInt16,
357 SubInt32,
358 SubInt64,
359 SubInterval,
360 SubNumeric,
361 SubTime,
362 SubTimeInterval,
363 SubTimestamp,
364 SubTimestampInterval,
365 SubTimestampTz,
366 SubTimestampTzInterval,
367 SubUint16,
368 SubUint32,
369 SubUint64,
370 TextConcat(TextConcatBinary),
371 TimezoneOffset,
375 ToCharTimestamp(ToCharTimestampFormat),
378 ToCharTimestampTz(ToCharTimestampTzFormat),
379 Trim,
380 TrimLeading,
381 TrimTrailing,
382 UuidGenerateV5,
383 }
384}
385
/// Tests for the property defaults derived from a function's signature and for the
/// equivalence between struct-based functions and their `BinaryFunc` variants.
#[cfg(test)]
mod test {
    use mz_expr_derive::sqlfunc;
    use mz_repr::SqlColumnType;
    use mz_repr::SqlScalarType;

    use crate::scalar::func::binary::LazyBinaryFunc;
    use crate::{BinaryFunc, EvalError, func};

    /// Every `(a, b)` combination of input nullability, in the order used by the
    /// expectation tables passed to [`assert_output_nullability`].
    const INPUT_NULLABILITY: [(bool, bool); 4] =
        [(true, true), (true, false), (false, true), (false, false)];

    /// Builds a `Float32` column type with the given nullability.
    fn float32(nullable: bool) -> SqlColumnType {
        SqlScalarType::Float32.nullable(nullable)
    }

    /// Asserts the nullability of `output_type` for every entry of
    /// [`INPUT_NULLABILITY`].
    ///
    /// `expected[i]` is the expected output nullability for `INPUT_NULLABILITY[i]`.
    /// `name` only labels the failure message. The function is passed as a closure so
    /// that each call site keeps invoking `output_type` on the concrete function type.
    #[track_caller]
    fn assert_output_nullability(
        name: &str,
        output_type: impl Fn(SqlColumnType, SqlColumnType) -> SqlColumnType,
        expected: [bool; 4],
    ) {
        for (&(a_nullable, b_nullable), &out_nullable) in
            INPUT_NULLABILITY.iter().zip(expected.iter())
        {
            assert_eq!(
                output_type(float32(a_nullable), float32(b_nullable)),
                float32(out_nullable),
                "{name}: output_type mismatch for input nullability ({a_nullable}, {b_nullable})"
            );
        }
    }

    #[sqlfunc(sqlname = "INFALLIBLE", is_infix_op = true)]
    #[allow(dead_code)]
    fn infallible1(a: f32, b: f32) -> f32 {
        a + b
    }

    #[sqlfunc]
    #[allow(dead_code)]
    fn infallible2(a: Option<f32>, b: Option<f32>) -> f32 {
        a.unwrap_or_default() + b.unwrap_or_default()
    }

    #[sqlfunc]
    #[allow(dead_code)]
    fn infallible3(a: f32, b: f32) -> Option<f32> {
        Some(a + b)
    }

    #[mz_ore::test]
    fn elision_rules_infallible() {
        assert_eq!(Infallible1.to_string(), "INFALLIBLE");
        assert!(Infallible1.propagates_nulls());
        assert!(!Infallible1.introduces_nulls());

        assert!(!Infallible2.propagates_nulls());
        assert!(!Infallible2.introduces_nulls());

        assert!(Infallible3.propagates_nulls());
        assert!(Infallible3.introduces_nulls());
    }

    #[mz_ore::test]
    fn output_types_infallible() {
        // Plain parameters, plain result: nullable as soon as one input is.
        assert_output_nullability(
            "Infallible1",
            |a, b| Infallible1.output_type(a, b),
            [true, true, true, false],
        );
        // `Option` parameters, plain result: never nullable.
        assert_output_nullability(
            "Infallible2",
            |a, b| Infallible2.output_type(a, b),
            [false, false, false, false],
        );
        // Plain parameters, `Option` result: always nullable.
        assert_output_nullability(
            "Infallible3",
            |a, b| Infallible3.output_type(a, b),
            [true, true, true, true],
        );
    }

    #[sqlfunc]
    #[allow(dead_code)]
    fn fallible1(a: f32, b: f32) -> Result<f32, EvalError> {
        Ok(a + b)
    }

    #[sqlfunc]
    #[allow(dead_code)]
    fn fallible2(a: Option<f32>, b: Option<f32>) -> Result<f32, EvalError> {
        Ok(a.unwrap_or_default() + b.unwrap_or_default())
    }

    #[sqlfunc]
    #[allow(dead_code)]
    fn fallible3(a: f32, b: f32) -> Result<Option<f32>, EvalError> {
        Ok(Some(a + b))
    }

    #[mz_ore::test]
    fn elision_rules_fallible() {
        assert!(Fallible1.propagates_nulls());
        assert!(!Fallible1.introduces_nulls());

        assert!(!Fallible2.propagates_nulls());
        assert!(!Fallible2.introduces_nulls());

        assert!(Fallible3.propagates_nulls());
        assert!(Fallible3.introduces_nulls());
    }

    #[mz_ore::test]
    fn output_types_fallible() {
        // Same expectations as the infallible variants: wrapping the result in `Result`
        // does not change the output nullability.
        assert_output_nullability(
            "Fallible1",
            |a, b| Fallible1.output_type(a, b),
            [true, true, true, false],
        );
        assert_output_nullability(
            "Fallible2",
            |a, b| Fallible2.output_type(a, b),
            [false, false, false, false],
        );
        assert_output_nullability(
            "Fallible3",
            |a, b| Fallible3.output_type(a, b),
            [true, true, true, true],
        );
    }

    #[mz_ore::test]
    fn test_equivalence_nullable() {
        test_equivalence_inner(true);
    }

    #[mz_ore::test]
    fn test_equivalence_non_nullable() {
        test_equivalence_inner(false);
    }

    /// Checks that each struct-based range-containment function reports the same
    /// properties as the corresponding `BinaryFunc::RangeContainsElem` value, using
    /// input columns whose nullability is `input_nullable`.
    fn test_equivalence_inner(input_nullable: bool) {
        /// Asserts that `new` and `old` agree on every property exposed by
        /// [`LazyBinaryFunc`] and on their `Display` output.
        #[track_caller]
        fn check<T: LazyBinaryFunc + std::fmt::Display + std::fmt::Debug>(
            new: T,
            old: BinaryFunc,
            column_a_ty: &SqlColumnType,
            column_b_ty: &SqlColumnType,
        ) {
            assert_eq!(
                new.propagates_nulls(),
                old.propagates_nulls(),
                "{new:?} propagates_nulls mismatch"
            );
            assert_eq!(
                new.introduces_nulls(),
                old.introduces_nulls(),
                "{new:?} introduces_nulls mismatch"
            );
            assert_eq!(
                new.could_error(),
                old.could_error(),
                "{new:?} could_error mismatch"
            );
            assert_eq!(
                new.is_monotone(),
                old.is_monotone(),
                "{new:?} is_monotone mismatch"
            );
            assert_eq!(
                new.is_infix_op(),
                old.is_infix_op(),
                "{new:?} is_infix_op mismatch"
            );
            assert_eq!(
                new.output_type(column_a_ty.clone(), column_b_ty.clone()),
                old.output_type(column_a_ty.clone(), column_b_ty.clone()),
                "{new:?} output_type mismatch"
            );
            assert_eq!(new.negate(), old.negate(), "{new:?} negate mismatch");
            assert_eq!(
                new.to_string(),
                old.to_string(),
                "{new:?} format mismatch"
            );
        }

        /// Runs [`check`] for `new` against `BinaryFunc::RangeContainsElem` built from
        /// `elem_type` and `rev`, using `column_ty` for both input columns.
        #[track_caller]
        fn check_range_contains<T: LazyBinaryFunc + std::fmt::Display + std::fmt::Debug>(
            new: T,
            elem_type: SqlScalarType,
            rev: bool,
            column_ty: &SqlColumnType,
        ) {
            check(
                new,
                BinaryFunc::RangeContainsElem { elem_type, rev },
                column_ty,
                column_ty,
            );
        }

        // NOTE(review): every case below is checked with Int32 input columns, including
        // the ones whose element type is not Int32. Presumably that is acceptable
        // because the functions under test do not inspect the column type; confirm.
        let i32_ty = SqlColumnType {
            nullable: input_nullable,
            scalar_type: SqlScalarType::Int32,
        };

        // range @> element
        check_range_contains(func::RangeContainsI32, SqlScalarType::Int32, false, &i32_ty);
        check_range_contains(func::RangeContainsI64, SqlScalarType::Int64, false, &i32_ty);
        check_range_contains(func::RangeContainsDate, SqlScalarType::Date, false, &i32_ty);
        check_range_contains(
            func::RangeContainsNumeric,
            SqlScalarType::Numeric { max_scale: None },
            false,
            &i32_ty,
        );
        check_range_contains(
            func::RangeContainsTimestamp,
            SqlScalarType::Timestamp { precision: None },
            false,
            &i32_ty,
        );
        check_range_contains(
            func::RangeContainsTimestampTz,
            SqlScalarType::TimestampTz { precision: None },
            false,
            &i32_ty,
        );

        // Reversed-operand variants (`rev: true`).
        check_range_contains(func::RangeContainsI32Rev, SqlScalarType::Int32, true, &i32_ty);
        check_range_contains(func::RangeContainsI64Rev, SqlScalarType::Int64, true, &i32_ty);
        check_range_contains(func::RangeContainsDateRev, SqlScalarType::Date, true, &i32_ty);
        check_range_contains(
            func::RangeContainsNumericRev,
            SqlScalarType::Numeric { max_scale: None },
            true,
            &i32_ty,
        );
        check_range_contains(
            func::RangeContainsTimestampRev,
            SqlScalarType::Timestamp { precision: None },
            true,
            &i32_ty,
        );
        check_range_contains(
            func::RangeContainsTimestampTzRev,
            SqlScalarType::TimestampTz { precision: None },
            true,
            &i32_ty,
        );
    }
}