1use mz_ore::assert_none;
13use mz_repr::{Datum, InputDatumType, OutputDatumType, RowArena, SqlColumnType};
14
15use crate::{EvalError, MirScalarExpr};
16
17pub(crate) trait LazyBinaryFunc {
20 fn eval<'a>(
21 &'a self,
22 datums: &[Datum<'a>],
23 temp_storage: &'a RowArena,
24 exprs: &[&'a MirScalarExpr],
25 ) -> Result<Datum<'a>, EvalError>;
26
27 fn output_type(&self, input_types: &[SqlColumnType]) -> SqlColumnType;
29
30 fn propagates_nulls(&self) -> bool;
32
33 fn introduces_nulls(&self) -> bool;
35
36 fn could_error(&self) -> bool {
38 true
40 }
41
42 fn negate(&self) -> Option<crate::BinaryFunc>;
44
45 fn is_monotone(&self) -> (bool, bool);
58
59 fn is_infix_op(&self) -> bool;
61}
62
63pub(crate) trait EagerBinaryFunc {
64 type Input<'a>: InputDatumType<'a, EvalError>;
65 type Output<'a>: OutputDatumType<'a, EvalError>;
66
67 fn call<'a>(&self, input: Self::Input<'a>, temp_storage: &'a RowArena) -> Self::Output<'a>;
68
69 fn output_type(&self, input_types: &[SqlColumnType]) -> SqlColumnType;
71
72 fn propagates_nulls(&self) -> bool {
74 !Self::Input::nullable()
76 }
77
78 fn introduces_nulls(&self) -> bool {
80 Self::Output::nullable()
82 }
83
84 fn could_error(&self) -> bool {
86 Self::Output::fallible()
87 }
88
89 fn negate(&self) -> Option<crate::BinaryFunc> {
91 None
92 }
93
94 fn is_monotone(&self) -> (bool, bool) {
95 (false, false)
96 }
97
98 fn is_infix_op(&self) -> bool {
99 false
100 }
101}
102
103impl<T: EagerBinaryFunc> LazyBinaryFunc for T {
104 fn eval<'a>(
105 &'a self,
106 datums: &[Datum<'a>],
107 temp_storage: &'a RowArena,
108 exprs: &[&'a MirScalarExpr],
109 ) -> Result<Datum<'a>, EvalError> {
110 let mut datums = exprs
111 .into_iter()
112 .map(|expr| expr.eval(datums, temp_storage));
113 let input = match T::Input::try_from_iter(&mut datums) {
114 Ok(input) => input,
116 Err(Ok(Some(datum))) if !datum.is_null() => {
118 return Err(EvalError::Internal("invalid input type".into()));
119 }
120 Err(Ok(None)) => {
121 return Err(EvalError::Internal("unexpectedly missing parameter".into()));
122 }
123 Err(Ok(Some(datum))) => return Ok(datum),
125 Err(Err(res)) => return Err(res),
126 };
127 assert_none!(datums.next(), "No leftover input arguments");
128 self.call(input, temp_storage).into_result(temp_storage)
129 }
130
131 fn output_type(&self, input_types: &[SqlColumnType]) -> SqlColumnType {
132 self.output_type(input_types)
133 }
134
135 fn propagates_nulls(&self) -> bool {
136 self.propagates_nulls()
137 }
138
139 fn introduces_nulls(&self) -> bool {
140 self.introduces_nulls()
141 }
142
143 fn could_error(&self) -> bool {
144 self.could_error()
145 }
146
147 fn negate(&self) -> Option<crate::BinaryFunc> {
148 self.negate()
149 }
150
151 fn is_monotone(&self) -> (bool, bool) {
152 self.is_monotone()
153 }
154
155 fn is_infix_op(&self) -> bool {
156 self.is_infix_op()
157 }
158}
159
160pub use derive::BinaryFunc;
161
// Home of the `BinaryFunc` enum, which the parent module re-exports via
// `pub use derive::BinaryFunc`. The enum is produced by the `derive_binary!`
// invocation below.
mod derive {
    // NOTE(review): several of these imports are not referenced by name in this
    // module; presumably the `derive_binary!` expansion uses them. Verify
    // before removing any.
    use std::fmt;

    use mz_repr::{Datum, RowArena, SqlColumnType};

    use crate::scalar::func::binary::LazyBinaryFunc;
    use crate::scalar::func::*;
    use crate::{EvalError, MirScalarExpr};

    // Each entry has the form `VariantName(ImplType)`. The implementing types
    // presumably come from the `crate::scalar::func::*` glob import. The
    // variant name and the implementing type do not always match (for example
    // `RoundNumeric(RoundNumericBinary)`, `TextConcat(TextConcatBinary)` and
    // `LogNumeric(LogBaseNumeric)`).
    //
    // NOTE(review): variant order may be significant to the macro expansion
    // (e.g. derived ordering or serialization). Append new entries rather
    // than reordering unless confirmed otherwise.
    derive_binary! {
        AddInt16(AddInt16),
        AddInt32(AddInt32),
        AddInt64(AddInt64),
        AddUint16(AddUint16),
        AddUint32(AddUint32),
        AddUint64(AddUint64),
        AddFloat32(AddFloat32),
        AddFloat64(AddFloat64),
        AddInterval(AddInterval),
        AddTimestampInterval(AddTimestampInterval),
        AddTimestampTzInterval(AddTimestampTzInterval),
        AddDateInterval(AddDateInterval),
        AddDateTime(AddDateTime),
        AddTimeInterval(AddTimeInterval),
        AddNumeric(AddNumeric),
        AgeTimestamp(AgeTimestamp),
        AgeTimestampTz(AgeTimestampTz),
        BitAndInt16(BitAndInt16),
        BitAndInt32(BitAndInt32),
        BitAndInt64(BitAndInt64),
        BitAndUint16(BitAndUint16),
        BitAndUint32(BitAndUint32),
        BitAndUint64(BitAndUint64),
        BitOrInt16(BitOrInt16),
        BitOrInt32(BitOrInt32),
        BitOrInt64(BitOrInt64),
        BitOrUint16(BitOrUint16),
        BitOrUint32(BitOrUint32),
        BitOrUint64(BitOrUint64),
        BitXorInt16(BitXorInt16),
        BitXorInt32(BitXorInt32),
        BitXorInt64(BitXorInt64),
        BitXorUint16(BitXorUint16),
        BitXorUint32(BitXorUint32),
        BitXorUint64(BitXorUint64),
        BitShiftLeftInt16(BitShiftLeftInt16),
        BitShiftLeftInt32(BitShiftLeftInt32),
        BitShiftLeftInt64(BitShiftLeftInt64),
        BitShiftLeftUint16(BitShiftLeftUint16),
        BitShiftLeftUint32(BitShiftLeftUint32),
        BitShiftLeftUint64(BitShiftLeftUint64),
        BitShiftRightInt16(BitShiftRightInt16),
        BitShiftRightInt32(BitShiftRightInt32),
        BitShiftRightInt64(BitShiftRightInt64),
        BitShiftRightUint16(BitShiftRightUint16),
        BitShiftRightUint32(BitShiftRightUint32),
        BitShiftRightUint64(BitShiftRightUint64),
        SubInt16(SubInt16),
        SubInt32(SubInt32),
        SubInt64(SubInt64),
        SubUint16(SubUint16),
        SubUint32(SubUint32),
        SubUint64(SubUint64),
        SubFloat32(SubFloat32),
        SubFloat64(SubFloat64),
        SubInterval(SubInterval),
        SubTimestamp(SubTimestamp),
        SubTimestampTz(SubTimestampTz),
        SubTimestampInterval(SubTimestampInterval),
        SubTimestampTzInterval(SubTimestampTzInterval),
        SubDate(SubDate),
        SubDateInterval(SubDateInterval),
        SubTime(SubTime),
        SubTimeInterval(SubTimeInterval),
        SubNumeric(SubNumeric),
        MulInt16(MulInt16),
        MulInt32(MulInt32),
        MulInt64(MulInt64),
        MulUint16(MulUint16),
        MulUint32(MulUint32),
        MulUint64(MulUint64),
        MulFloat32(MulFloat32),
        MulFloat64(MulFloat64),
        MulNumeric(MulNumeric),
        MulInterval(MulInterval),
        DivInt16(DivInt16),
        DivInt32(DivInt32),
        DivInt64(DivInt64),
        DivUint16(DivUint16),
        DivUint32(DivUint32),
        DivUint64(DivUint64),
        DivFloat32(DivFloat32),
        DivFloat64(DivFloat64),
        DivNumeric(DivNumeric),
        DivInterval(DivInterval),
        ModInt16(ModInt16),
        ModInt32(ModInt32),
        ModInt64(ModInt64),
        ModUint16(ModUint16),
        ModUint32(ModUint32),
        ModUint64(ModUint64),
        ModFloat32(ModFloat32),
        ModFloat64(ModFloat64),
        ModNumeric(ModNumeric),
        RoundNumeric(RoundNumericBinary),
        Eq(Eq),
        NotEq(NotEq),
        Lt(Lt),
        Lte(Lte),
        Gt(Gt),
        Gte(Gte),
        LikeEscape(LikeEscape),
        IsLikeMatchCaseInsensitive(IsLikeMatchCaseInsensitive),
        IsLikeMatchCaseSensitive(IsLikeMatchCaseSensitive),
        IsRegexpMatchCaseSensitive(IsRegexpMatchCaseSensitive),
        IsRegexpMatchCaseInsensitive(IsRegexpMatchCaseInsensitive),
        ToCharTimestamp(ToCharTimestampFormat),
        ToCharTimestampTz(ToCharTimestampTzFormat),
        DateBinTimestamp(DateBinTimestamp),
        DateBinTimestampTz(DateBinTimestampTz),
        ExtractInterval(DatePartIntervalNumeric),
        ExtractTime(DatePartTimeNumeric),
        ExtractTimestamp(DatePartTimestampTimestampNumeric),
        ExtractTimestampTz(DatePartTimestampTimestampTzNumeric),
        ExtractDate(ExtractDateUnits),
        DatePartInterval(DatePartIntervalF64),
        DatePartTime(DatePartTimeF64),
        DatePartTimestamp(DatePartTimestampTimestampF64),
        DatePartTimestampTz(DatePartTimestampTimestampTzF64),
        DateTruncTimestamp(DateTruncUnitsTimestamp),
        DateTruncTimestampTz(DateTruncUnitsTimestampTz),
        DateTruncInterval(DateTruncInterval),
        TimezoneTimestampBinary(TimezoneTimestampBinary),
        TimezoneTimestampTzBinary(TimezoneTimestampTzBinary),
        TimezoneIntervalTimestampBinary(TimezoneIntervalTimestampBinary),
        TimezoneIntervalTimestampTzBinary(TimezoneIntervalTimestampTzBinary),
        TimezoneIntervalTimeBinary(TimezoneIntervalTimeBinary),
        TimezoneOffset(TimezoneOffset),
        TextConcat(TextConcatBinary),
        JsonbGetInt64(JsonbGetInt64),
        JsonbGetInt64Stringify(JsonbGetInt64Stringify),
        JsonbGetString(JsonbGetString),
        JsonbGetStringStringify(JsonbGetStringStringify),
        JsonbGetPath(JsonbGetPath),
        JsonbGetPathStringify(JsonbGetPathStringify),
        JsonbContainsString(JsonbContainsString),
        JsonbConcat(JsonbConcat),
        JsonbContainsJsonb(JsonbContainsJsonb),
        JsonbDeleteInt64(JsonbDeleteInt64),
        JsonbDeleteString(JsonbDeleteString),
        MapContainsKey(MapContainsKey),
        MapGetValue(MapGetValue),
        MapContainsAllKeys(MapContainsAllKeys),
        MapContainsAnyKeys(MapContainsAnyKeys),
        MapContainsMap(MapContainsMap),
        ConvertFrom(ConvertFrom),
        Left(Left),
        Position(Position),
        Strpos(Strpos),
        Right(Right),
        RepeatString(RepeatString),
        Normalize(Normalize),
        Trim(Trim),
        TrimLeading(TrimLeading),
        TrimTrailing(TrimTrailing),
        EncodedBytesCharLength(EncodedBytesCharLength),
        ListLengthMax(ListLengthMax),
        ArrayContains(ArrayContains),
        ArrayContainsArray(ArrayContainsArray),
        ArrayContainsArrayRev(ArrayContainsArrayRev),
        ArrayLength(ArrayLength),
        ArrayLower(ArrayLower),
        ArrayRemove(ArrayRemove),
        ArrayUpper(ArrayUpper),
        ArrayArrayConcat(ArrayArrayConcat),
        ListListConcat(ListListConcat),
        ListElementConcat(ListElementConcat),
        ElementListConcat(ElementListConcat),
        ListRemove(ListRemove),
        ListContainsList(ListContainsList),
        ListContainsListRev(ListContainsListRev),
        DigestString(DigestString),
        DigestBytes(DigestBytes),
        MzRenderTypmod(MzRenderTypmod),
        Encode(Encode),
        Decode(Decode),
        LogNumeric(LogBaseNumeric),
        Power(Power),
        PowerNumeric(PowerNumeric),
        GetBit(GetBit),
        GetByte(GetByte),
        ConstantTimeEqBytes(ConstantTimeEqBytes),
        ConstantTimeEqString(ConstantTimeEqString),
        RangeContainsDate(RangeContainsDate),
        RangeContainsDateRev(RangeContainsDateRev),
        RangeContainsI32(RangeContainsI32),
        RangeContainsI32Rev(RangeContainsI32Rev),
        RangeContainsI64(RangeContainsI64),
        RangeContainsI64Rev(RangeContainsI64Rev),
        RangeContainsNumeric(RangeContainsNumeric),
        RangeContainsNumericRev(RangeContainsNumericRev),
        RangeContainsRange(RangeContainsRange),
        RangeContainsRangeRev(RangeContainsRangeRev),
        RangeContainsTimestamp(RangeContainsTimestamp),
        RangeContainsTimestampRev(RangeContainsTimestampRev),
        RangeContainsTimestampTz(RangeContainsTimestampTz),
        RangeContainsTimestampTzRev(RangeContainsTimestampTzRev),
        RangeOverlaps(RangeOverlaps),
        RangeAfter(RangeAfter),
        RangeBefore(RangeBefore),
        RangeOverleft(RangeOverleft),
        RangeOverright(RangeOverright),
        RangeAdjacent(RangeAdjacent),
        RangeUnion(RangeUnion),
        RangeIntersection(RangeIntersection),
        RangeDifference(RangeDifference),
        UuidGenerateV5(UuidGenerateV5),
        MzAclItemContainsPrivilege(MzAclItemContainsPrivilege),
        ParseIdent(ParseIdent),
        PrettySql(PrettySql),
        RegexpReplace(RegexpReplace),
        StartsWith(StartsWith),
    }
}
386
#[cfg(test)]
mod test {
    use mz_expr_derive::sqlfunc;
    use mz_repr::SqlScalarType;

    use crate::EvalError;
    use crate::scalar::func::binary::LazyBinaryFunc;

    /// Asserts that `func`, given two `Float32` arguments, reports a `Float32`
    /// result whose nullability is `expected(a_nullable, b_nullable)`.
    ///
    /// The check covers all four combinations of argument nullability. On
    /// failure, the message names the combination that failed.
    fn assert_float32_output_nullability(
        func: &impl LazyBinaryFunc,
        expected: impl Fn(bool, bool) -> bool,
    ) {
        for a_nullable in [true, false] {
            for b_nullable in [true, false] {
                assert_eq!(
                    func.output_type(&[
                        SqlScalarType::Float32.nullable(a_nullable),
                        SqlScalarType::Float32.nullable(b_nullable),
                    ]),
                    SqlScalarType::Float32.nullable(expected(a_nullable, b_nullable)),
                    "argument nullability: ({a_nullable}, {b_nullable})",
                );
            }
        }
    }

    // `sqlfunc` generates a type from each function below (e.g. `Infallible1`
    // from `infallible1`). The tests exercise those generated types and never
    // call the functions directly, hence `#[allow(dead_code)]`.

    #[sqlfunc(sqlname = "INFALLIBLE", is_infix_op = true, test = true)]
    #[allow(dead_code)]
    fn infallible1(a: f32, b: f32) -> f32 {
        a + b
    }

    #[sqlfunc(test = true)]
    #[allow(dead_code)]
    fn infallible2(a: Option<f32>, b: Option<f32>) -> f32 {
        a.unwrap_or_default() + b.unwrap_or_default()
    }

    #[sqlfunc(test = true)]
    #[allow(dead_code)]
    fn infallible3(a: f32, b: f32) -> Option<f32> {
        Some(a + b)
    }

    #[mz_ore::test]
    fn elision_rules_infallible() {
        assert_eq!(Infallible1.to_string(), "INFALLIBLE");
        assert!(Infallible1.propagates_nulls());
        assert!(!Infallible1.introduces_nulls());

        assert!(!Infallible2.propagates_nulls());
        assert!(!Infallible2.introduces_nulls());

        assert!(Infallible3.propagates_nulls());
        assert!(Infallible3.introduces_nulls());
    }

    #[mz_ore::test]
    fn output_types_infallible() {
        // Non-nullable arguments: the output is nullable iff either input is.
        assert_float32_output_nullability(&Infallible1, |a, b| a || b);
        // `Option` arguments with a non-`Option` return: never nullable.
        assert_float32_output_nullability(&Infallible2, |_, _| false);
        // `Option` return: always nullable.
        assert_float32_output_nullability(&Infallible3, |_, _| true);
    }

    #[sqlfunc(test = true)]
    #[allow(dead_code)]
    fn fallible1(a: f32, b: f32) -> Result<f32, EvalError> {
        Ok(a + b)
    }

    #[sqlfunc(test = true)]
    #[allow(dead_code)]
    fn fallible2(a: Option<f32>, b: Option<f32>) -> Result<f32, EvalError> {
        Ok(a.unwrap_or_default() + b.unwrap_or_default())
    }

    #[sqlfunc(test = true)]
    #[allow(dead_code)]
    fn fallible3(a: f32, b: f32) -> Result<Option<f32>, EvalError> {
        Ok(Some(a + b))
    }

    #[mz_ore::test]
    fn elision_rules_fallible() {
        assert!(Fallible1.propagates_nulls());
        assert!(!Fallible1.introduces_nulls());

        assert!(!Fallible2.propagates_nulls());
        assert!(!Fallible2.introduces_nulls());

        assert!(Fallible3.propagates_nulls());
        assert!(Fallible3.introduces_nulls());
    }

    #[mz_ore::test]
    fn output_types_fallible() {
        // Wrapping the return type in `Result` must not change nullability.
        assert_float32_output_nullability(&Fallible1, |a, b| a || b);
        assert_float32_output_nullability(&Fallible2, |_, _| false);
        assert_float32_output_nullability(&Fallible3, |_, _| true);
    }

    #[mz_ore::test]
    fn mz_reflect_binary_func() {
        use crate::BinaryFunc;
        use mz_lowertest::{MzReflect, ReflectedTypeInfo};

        let mut rti = ReflectedTypeInfo::default();
        BinaryFunc::add_to_reflected_type_info(&mut rti);

        // The enum itself is registered with its variants.
        let variants = rti
            .enum_dict
            .get("BinaryFunc")
            .expect("BinaryFunc should be in enum_dict");
        assert!(
            variants.contains_key("AddInt64"),
            "AddInt64 variant should exist"
        );
        assert!(variants.contains_key("Gte"), "Gte variant should exist");

        // The types wrapped by the variants are registered as structs too.
        assert!(
            rti.struct_dict.contains_key("AddInt64"),
            "AddInt64 should be in struct_dict"
        );
        assert!(
            rti.struct_dict.contains_key("Gte"),
            "Gte should be in struct_dict"
        );

        // `AddInt64` carries no fields.
        let (names, types) = rti
            .struct_dict
            .get("AddInt64")
            .expect("AddInt64 should be in struct_dict");
        assert!(names.is_empty(), "AddInt64 should have no field names");
        assert!(types.is_empty(), "AddInt64 should have no field types");
    }
}