1use mz_ore::assert_none;
13use mz_repr::{Datum, InputDatumType, OutputDatumType, ReprColumnType, RowArena, SqlColumnType};
14
15use crate::{EvalError, MirScalarExpr};
16
17pub(crate) trait LazyBinaryFunc {
20 fn eval<'a>(
21 &'a self,
22 datums: &[Datum<'a>],
23 temp_storage: &'a RowArena,
24 exprs: &[&'a MirScalarExpr],
25 ) -> Result<Datum<'a>, EvalError>;
26
27 fn output_sql_type(&self, input_types: &[SqlColumnType]) -> SqlColumnType;
29
30 fn output_type(&self, input_types: &[ReprColumnType]) -> ReprColumnType {
32 ReprColumnType::from(
33 &self.output_sql_type(
34 &input_types
35 .iter()
36 .map(SqlColumnType::from_repr)
37 .collect::<Vec<_>>(),
38 ),
39 )
40 }
41
42 fn propagates_nulls(&self) -> bool;
44
45 fn introduces_nulls(&self) -> bool;
47
48 fn could_error(&self) -> bool {
50 true
52 }
53
54 fn negate(&self) -> Option<crate::BinaryFunc>;
56
57 fn is_monotone(&self) -> (bool, bool);
70
71 fn is_infix_op(&self) -> bool;
73}
74
75pub(crate) trait EagerBinaryFunc {
76 type Input<'a>: InputDatumType<'a, EvalError>;
77 type Output<'a>: OutputDatumType<'a, EvalError>;
78
79 fn call<'a>(&self, input: Self::Input<'a>, temp_storage: &'a RowArena) -> Self::Output<'a>;
80
81 fn output_sql_type(&self, input_types: &[SqlColumnType]) -> SqlColumnType;
83
84 #[allow(dead_code)]
86 fn output_type(&self, input_types: &[ReprColumnType]) -> ReprColumnType {
87 ReprColumnType::from(
88 &self.output_sql_type(
89 &input_types
90 .iter()
91 .map(SqlColumnType::from_repr)
92 .collect::<Vec<_>>(),
93 ),
94 )
95 }
96
97 fn propagates_nulls(&self) -> bool {
99 !Self::Input::nullable()
101 }
102
103 fn introduces_nulls(&self) -> bool {
105 Self::Output::nullable()
107 }
108
109 fn could_error(&self) -> bool {
111 Self::Output::fallible()
112 }
113
114 fn negate(&self) -> Option<crate::BinaryFunc> {
116 None
117 }
118
119 fn is_monotone(&self) -> (bool, bool) {
120 (false, false)
121 }
122
123 fn is_infix_op(&self) -> bool {
124 false
125 }
126}
127
128impl<T: EagerBinaryFunc> LazyBinaryFunc for T {
129 fn eval<'a>(
130 &'a self,
131 datums: &[Datum<'a>],
132 temp_storage: &'a RowArena,
133 exprs: &[&'a MirScalarExpr],
134 ) -> Result<Datum<'a>, EvalError> {
135 let mut datums = exprs
136 .into_iter()
137 .map(|expr| expr.eval(datums, temp_storage));
138 let input = match T::Input::try_from_iter(&mut datums) {
139 Ok(input) => input,
141 Err(Ok(Some(datum))) if !datum.is_null() => {
143 return Err(EvalError::Internal("invalid input type".into()));
144 }
145 Err(Ok(None)) => {
146 return Err(EvalError::Internal("unexpectedly missing parameter".into()));
147 }
148 Err(Ok(Some(datum))) => return Ok(datum),
150 Err(Err(res)) => return Err(res),
151 };
152 assert_none!(datums.next(), "No leftover input arguments");
153 self.call(input, temp_storage).into_result(temp_storage)
154 }
155
156 fn output_sql_type(&self, input_types: &[SqlColumnType]) -> SqlColumnType {
157 self.output_sql_type(input_types)
158 }
159
160 fn propagates_nulls(&self) -> bool {
161 self.propagates_nulls()
162 }
163
164 fn introduces_nulls(&self) -> bool {
165 self.introduces_nulls()
166 }
167
168 fn could_error(&self) -> bool {
169 self.could_error()
170 }
171
172 fn negate(&self) -> Option<crate::BinaryFunc> {
173 self.negate()
174 }
175
176 fn is_monotone(&self) -> (bool, bool) {
177 self.is_monotone()
178 }
179
180 fn is_infix_op(&self) -> bool {
181 self.is_infix_op()
182 }
183}
184
185pub use derive::BinaryFunc;
186
/// Defines the `BinaryFunc` enum, which the parent module re-exports.
///
/// NOTE(review): `derive_binary!` is defined elsewhere. From its input it
/// presumably generates one `BinaryFunc` variant per `Variant(Type)` pair
/// (variant name on the left, implementing function type on the right) plus
/// the dispatching impls. Confirm against the macro definition.
mod derive {
    // NOTE(review): these imports are presumably consumed by the
    // `derive_binary!` expansion below; confirm before removing any.
    use std::fmt;

    use mz_repr::{Datum, ReprColumnType, RowArena, SqlColumnType};

    use crate::scalar::func::binary::LazyBinaryFunc;
    use crate::scalar::func::*;
    use crate::{EvalError, MirScalarExpr};

    // Variant name and implementing type usually match but not always, e.g.
    // `RoundNumeric(RoundNumericBinary)` or `TextConcat(TextConcatBinary)`.
    // NOTE(review): variant names are visible externally (the reflection test
    // looks up "AddInt64" and "Gte"), so renaming or reordering entries may
    // affect serialized or reflected forms; prefer appending new ones.
    derive_binary! {
        AddInt16(AddInt16),
        AddInt32(AddInt32),
        AddInt64(AddInt64),
        AddUint16(AddUint16),
        AddUint32(AddUint32),
        AddUint64(AddUint64),
        AddFloat32(AddFloat32),
        AddFloat64(AddFloat64),
        AddInterval(AddInterval),
        AddTimestampInterval(AddTimestampInterval),
        AddTimestampTzInterval(AddTimestampTzInterval),
        AddDateInterval(AddDateInterval),
        AddDateTime(AddDateTime),
        AddTimeInterval(AddTimeInterval),
        AddNumeric(AddNumeric),
        AgeTimestamp(AgeTimestamp),
        AgeTimestampTz(AgeTimestampTz),
        BitAndInt16(BitAndInt16),
        BitAndInt32(BitAndInt32),
        BitAndInt64(BitAndInt64),
        BitAndUint16(BitAndUint16),
        BitAndUint32(BitAndUint32),
        BitAndUint64(BitAndUint64),
        BitOrInt16(BitOrInt16),
        BitOrInt32(BitOrInt32),
        BitOrInt64(BitOrInt64),
        BitOrUint16(BitOrUint16),
        BitOrUint32(BitOrUint32),
        BitOrUint64(BitOrUint64),
        BitXorInt16(BitXorInt16),
        BitXorInt32(BitXorInt32),
        BitXorInt64(BitXorInt64),
        BitXorUint16(BitXorUint16),
        BitXorUint32(BitXorUint32),
        BitXorUint64(BitXorUint64),
        BitShiftLeftInt16(BitShiftLeftInt16),
        BitShiftLeftInt32(BitShiftLeftInt32),
        BitShiftLeftInt64(BitShiftLeftInt64),
        BitShiftLeftUint16(BitShiftLeftUint16),
        BitShiftLeftUint32(BitShiftLeftUint32),
        BitShiftLeftUint64(BitShiftLeftUint64),
        BitShiftRightInt16(BitShiftRightInt16),
        BitShiftRightInt32(BitShiftRightInt32),
        BitShiftRightInt64(BitShiftRightInt64),
        BitShiftRightUint16(BitShiftRightUint16),
        BitShiftRightUint32(BitShiftRightUint32),
        BitShiftRightUint64(BitShiftRightUint64),
        SubInt16(SubInt16),
        SubInt32(SubInt32),
        SubInt64(SubInt64),
        SubUint16(SubUint16),
        SubUint32(SubUint32),
        SubUint64(SubUint64),
        SubFloat32(SubFloat32),
        SubFloat64(SubFloat64),
        SubInterval(SubInterval),
        SubTimestamp(SubTimestamp),
        SubTimestampTz(SubTimestampTz),
        SubTimestampInterval(SubTimestampInterval),
        SubTimestampTzInterval(SubTimestampTzInterval),
        SubDate(SubDate),
        SubDateInterval(SubDateInterval),
        SubTime(SubTime),
        SubTimeInterval(SubTimeInterval),
        SubNumeric(SubNumeric),
        MulInt16(MulInt16),
        MulInt32(MulInt32),
        MulInt64(MulInt64),
        MulUint16(MulUint16),
        MulUint32(MulUint32),
        MulUint64(MulUint64),
        MulFloat32(MulFloat32),
        MulFloat64(MulFloat64),
        MulNumeric(MulNumeric),
        MulInterval(MulInterval),
        DivInt16(DivInt16),
        DivInt32(DivInt32),
        DivInt64(DivInt64),
        DivUint16(DivUint16),
        DivUint32(DivUint32),
        DivUint64(DivUint64),
        DivFloat32(DivFloat32),
        DivFloat64(DivFloat64),
        DivNumeric(DivNumeric),
        DivInterval(DivInterval),
        ModInt16(ModInt16),
        ModInt32(ModInt32),
        ModInt64(ModInt64),
        ModUint16(ModUint16),
        ModUint32(ModUint32),
        ModUint64(ModUint64),
        ModFloat32(ModFloat32),
        ModFloat64(ModFloat64),
        ModNumeric(ModNumeric),
        RoundNumeric(RoundNumericBinary),
        Eq(Eq),
        NotEq(NotEq),
        Lt(Lt),
        Lte(Lte),
        Gt(Gt),
        Gte(Gte),
        LikeEscape(LikeEscape),
        IsLikeMatchCaseInsensitive(IsLikeMatchCaseInsensitive),
        IsLikeMatchCaseSensitive(IsLikeMatchCaseSensitive),
        IsRegexpMatchCaseSensitive(IsRegexpMatchCaseSensitive),
        IsRegexpMatchCaseInsensitive(IsRegexpMatchCaseInsensitive),
        ToCharTimestamp(ToCharTimestampFormat),
        ToCharTimestampTz(ToCharTimestampTzFormat),
        DateBinTimestamp(DateBinTimestamp),
        DateBinTimestampTz(DateBinTimestampTz),
        ExtractInterval(DatePartIntervalNumeric),
        ExtractTime(DatePartTimeNumeric),
        ExtractTimestamp(DatePartTimestampTimestampNumeric),
        ExtractTimestampTz(DatePartTimestampTimestampTzNumeric),
        ExtractDate(ExtractDateUnits),
        DatePartInterval(DatePartIntervalF64),
        DatePartTime(DatePartTimeF64),
        DatePartTimestamp(DatePartTimestampTimestampF64),
        DatePartTimestampTz(DatePartTimestampTimestampTzF64),
        DateTruncTimestamp(DateTruncUnitsTimestamp),
        DateTruncTimestampTz(DateTruncUnitsTimestampTz),
        DateTruncInterval(DateTruncInterval),
        TimezoneTimestampBinary(TimezoneTimestampBinary),
        TimezoneTimestampTzBinary(TimezoneTimestampTzBinary),
        TimezoneIntervalTimestampBinary(TimezoneIntervalTimestampBinary),
        TimezoneIntervalTimestampTzBinary(TimezoneIntervalTimestampTzBinary),
        TimezoneIntervalTimeBinary(TimezoneIntervalTimeBinary),
        TimezoneOffset(TimezoneOffset),
        TextConcat(TextConcatBinary),
        JsonbGetInt64(JsonbGetInt64),
        JsonbGetInt64Stringify(JsonbGetInt64Stringify),
        JsonbGetString(JsonbGetString),
        JsonbGetStringStringify(JsonbGetStringStringify),
        JsonbGetPath(JsonbGetPath),
        JsonbGetPathStringify(JsonbGetPathStringify),
        JsonbContainsString(JsonbContainsString),
        JsonbConcat(JsonbConcat),
        JsonbContainsJsonb(JsonbContainsJsonb),
        JsonbDeleteInt64(JsonbDeleteInt64),
        JsonbDeleteString(JsonbDeleteString),
        MapContainsKey(MapContainsKey),
        MapGetValue(MapGetValue),
        MapContainsAllKeys(MapContainsAllKeys),
        MapContainsAnyKeys(MapContainsAnyKeys),
        MapContainsMap(MapContainsMap),
        ConvertFrom(ConvertFrom),
        Left(Left),
        Position(Position),
        Strpos(Strpos),
        Right(Right),
        RepeatString(RepeatString),
        Normalize(Normalize),
        Trim(Trim),
        TrimLeading(TrimLeading),
        TrimTrailing(TrimTrailing),
        EncodedBytesCharLength(EncodedBytesCharLength),
        ListLengthMax(ListLengthMax),
        ArrayContains(ArrayContains),
        ArrayContainsArray(ArrayContainsArray),
        ArrayContainsArrayRev(ArrayContainsArrayRev),
        ArrayLength(ArrayLength),
        ArrayLower(ArrayLower),
        ArrayRemove(ArrayRemove),
        ArrayUpper(ArrayUpper),
        ArrayArrayConcat(ArrayArrayConcat),
        ListListConcat(ListListConcat),
        ListElementConcat(ListElementConcat),
        ElementListConcat(ElementListConcat),
        ListRemove(ListRemove),
        ListContainsList(ListContainsList),
        ListContainsListRev(ListContainsListRev),
        DigestString(DigestString),
        DigestBytes(DigestBytes),
        MzRenderTypmod(MzRenderTypmod),
        Encode(Encode),
        Decode(Decode),
        LogNumeric(LogBaseNumeric),
        Power(Power),
        PowerNumeric(PowerNumeric),
        GetBit(GetBit),
        GetByte(GetByte),
        ConstantTimeEqBytes(ConstantTimeEqBytes),
        ConstantTimeEqString(ConstantTimeEqString),
        RangeContainsDate(RangeContainsDate),
        RangeContainsDateRev(RangeContainsDateRev),
        RangeContainsI32(RangeContainsI32),
        RangeContainsI32Rev(RangeContainsI32Rev),
        Range​ContainsI64(RangeContainsI64),
        RangeContainsI64Rev(RangeContainsI64Rev),
        RangeContainsNumeric(RangeContainsNumeric),
        RangeContainsNumericRev(RangeContainsNumericRev),
        RangeContainsRange(RangeContainsRange),
        RangeContainsRangeRev(RangeContainsRangeRev),
        RangeContainsTimestamp(RangeContainsTimestamp),
        RangeContainsTimestampRev(RangeContainsTimestampRev),
        RangeContainsTimestampTz(RangeContainsTimestampTz),
        RangeContainsTimestampTzRev(RangeContainsTimestampTzRev),
        RangeOverlaps(RangeOverlaps),
        RangeAfter(RangeAfter),
        RangeBefore(RangeBefore),
        RangeOverleft(RangeOverleft),
        RangeOverright(RangeOverright),
        RangeAdjacent(RangeAdjacent),
        RangeUnion(RangeUnion),
        RangeIntersection(RangeIntersection),
        RangeDifference(RangeDifference),
        UuidGenerateV5(UuidGenerateV5),
        MzAclItemContainsPrivilege(MzAclItemContainsPrivilege),
        ParseIdent(ParseIdent),
        PrettySql(PrettySql),
        RegexpReplace(RegexpReplace),
        StartsWith(StartsWith),
    }
}
411
#[cfg(test)]
mod test {
    use mz_expr_derive::sqlfunc;
    use mz_repr::SqlScalarType;

    use crate::EvalError;
    use crate::scalar::func::binary::LazyBinaryFunc;

    /// Input nullabilities `(first, second)` exercised by
    /// `assert_output_nullability`, in the order its expectations are listed.
    const INPUT_NULLABILITY: [(bool, bool); 4] =
        [(true, true), (true, false), (false, true), (false, false)];

    /// Returns `(propagates_nulls, introduces_nulls)` for `func`.
    fn null_rules(func: &impl LazyBinaryFunc) -> (bool, bool) {
        (func.propagates_nulls(), func.introduces_nulls())
    }

    /// Asserts that `func`, given two `Float32` inputs with the nullabilities
    /// listed in `INPUT_NULLABILITY`, yields a `Float32` output whose
    /// nullability is the corresponding entry of `expected`.
    fn assert_output_nullability(func: &impl LazyBinaryFunc, expected: [bool; 4]) {
        for (&(first, second), &output) in INPUT_NULLABILITY.iter().zip(expected.iter()) {
            let input_types = [
                SqlScalarType::Float32.nullable(first),
                SqlScalarType::Float32.nullable(second),
            ];
            assert_eq!(
                func.output_sql_type(&input_types),
                SqlScalarType::Float32.nullable(output),
                "input nullability: ({first}, {second})",
            );
        }
    }

    // Fixtures. The plain functions are never called directly; the tests use
    // the unit structs (`Infallible1`, `Fallible1`, ...) that `#[sqlfunc]`
    // presumably generates from them, hence `#[allow(dead_code)]`. They vary
    // `f32` vs `Option<f32>` for arguments and result.
    #[sqlfunc(sqlname = "INFALLIBLE", is_infix_op = true, test = true)]
    #[allow(dead_code)]
    fn infallible1(a: f32, b: f32) -> f32 {
        a + b
    }

    #[sqlfunc(test = true)]
    #[allow(dead_code)]
    fn infallible2(a: Option<f32>, b: Option<f32>) -> f32 {
        a.unwrap_or_default() + b.unwrap_or_default()
    }

    #[sqlfunc(test = true)]
    #[allow(dead_code)]
    fn infallible3(a: f32, b: f32) -> Option<f32> {
        Some(a + b)
    }

    #[mz_ore::test]
    fn elision_rules_infallible() {
        assert_eq!(Infallible1.to_string(), "INFALLIBLE");
        assert_eq!(null_rules(&Infallible1), (true, false));
        assert_eq!(null_rules(&Infallible2), (false, false));
        assert_eq!(null_rules(&Infallible3), (true, true));
    }

    #[mz_ore::test]
    fn output_types_infallible() {
        // Plain `f32` inputs and output: nullable exactly when an input is.
        assert_output_nullability(&Infallible1, [true, true, true, false]);
        // `Option` inputs absorb nulls and the `f32` output is never null.
        assert_output_nullability(&Infallible2, [false, false, false, false]);
        // An `Option` output is nullable regardless of the inputs.
        assert_output_nullability(&Infallible3, [true, true, true, true]);
    }

    #[sqlfunc(test = true)]
    #[allow(dead_code)]
    fn fallible1(a: f32, b: f32) -> Result<f32, EvalError> {
        Ok(a + b)
    }

    #[sqlfunc(test = true)]
    #[allow(dead_code)]
    fn fallible2(a: Option<f32>, b: Option<f32>) -> Result<f32, EvalError> {
        Ok(a.unwrap_or_default() + b.unwrap_or_default())
    }

    #[sqlfunc(test = true)]
    #[allow(dead_code)]
    fn fallible3(a: f32, b: f32) -> Result<Option<f32>, EvalError> {
        Ok(Some(a + b))
    }

    #[mz_ore::test]
    fn elision_rules_fallible() {
        assert_eq!(null_rules(&Fallible1), (true, false));
        assert_eq!(null_rules(&Fallible2), (false, false));
        assert_eq!(null_rules(&Fallible3), (true, true));
    }

    #[mz_ore::test]
    fn output_types_fallible() {
        // Wrapping the result in `Result` must not change nullability.
        assert_output_nullability(&Fallible1, [true, true, true, false]);
        assert_output_nullability(&Fallible2, [false, false, false, false]);
        assert_output_nullability(&Fallible3, [true, true, true, true]);
    }

    #[mz_ore::test]
    fn mz_reflect_binary_func() {
        use mz_lowertest::{MzReflect, ReflectedTypeInfo};

        use crate::BinaryFunc;

        let mut rti = ReflectedTypeInfo::default();
        BinaryFunc::add_to_reflected_type_info(&mut rti);

        // Sample variants must be registered on the enum ...
        let variants = rti
            .enum_dict
            .get("BinaryFunc")
            .expect("BinaryFunc should be in enum_dict");
        for name in ["AddInt64", "Gte"] {
            assert!(variants.contains_key(name), "{name} variant should exist");
        }

        // ... and their payload types registered as structs.
        for name in ["AddInt64", "Gte"] {
            assert!(
                rti.struct_dict.contains_key(name),
                "{name} should be in struct_dict"
            );
        }

        // The payload of a unit-like variant has no fields.
        let (names, types) = rti.struct_dict.get("AddInt64").unwrap();
        assert!(names.is_empty(), "AddInt64 should have no field names");
        assert!(types.is_empty(), "AddInt64 should have no field types");
    }
}