1use mz_repr::{ColumnType, Datum, DatumType, RowArena};
13
14use crate::{EvalError, MirScalarExpr};
15
/// A binary scalar function that is handed its two arguments as unevaluated
/// expressions.
///
/// Because [`LazyBinaryFunc::eval`] receives the argument expressions `a` and `b`
/// rather than their values, an implementation controls whether, and in which
/// order, each argument is evaluated.
// NOTE(review): the reason for `allow(unused)` is not visible here — confirm it is still needed.
#[allow(unused)]
pub(crate) trait LazyBinaryFunc {
    /// Evaluates the function for the input row `datums`.
    ///
    /// `a` and `b` are the argument expressions; `temp_storage` is the arena
    /// passed down to expression evaluation. Returns the resulting datum, or an
    /// [`EvalError`].
    fn eval<'a>(
        &'a self,
        datums: &[Datum<'a>],
        temp_storage: &'a RowArena,
        a: &'a MirScalarExpr,
        b: &'a MirScalarExpr,
    ) -> Result<Datum<'a>, EvalError>;

    /// Computes the output column type from the column types of the two inputs.
    fn output_type(&self, input_type_a: ColumnType, input_type_b: ColumnType) -> ColumnType;

    /// Whether a NULL in either argument yields a NULL result.
    fn propagates_nulls(&self) -> bool;

    /// Whether the function may return NULL even when both arguments are non-NULL.
    fn introduces_nulls(&self) -> bool;

    /// Whether evaluation may produce an error.
    ///
    /// Defaults to `true`, the conservative answer.
    fn could_error(&self) -> bool {
        true
    }

    /// Returns the [`crate::BinaryFunc`] that computes the negation of this
    /// function, if one exists.
    ///
    /// NOTE(review): presumably logical negation (e.g. `=` vs `<>`) — confirm.
    fn negate(&self) -> Option<crate::BinaryFunc>;

    /// Returns a pair of monotonicity flags.
    ///
    /// NOTE(review): presumably one flag per argument, in (first, second) order —
    /// confirm against implementors.
    fn is_monotone(&self) -> (bool, bool);

    /// Whether the function is an infix operator (e.g. `a + b`).
    ///
    /// NOTE(review): presumably this only affects how the function is rendered — confirm.
    fn is_infix_op(&self) -> bool;
}
63
64#[allow(unused)]
65pub(crate) trait EagerBinaryFunc<'a> {
66 type Input1: DatumType<'a, EvalError>;
67 type Input2: DatumType<'a, EvalError>;
68 type Output: DatumType<'a, EvalError>;
69
70 fn call(&self, a: Self::Input1, b: Self::Input2, temp_storage: &'a RowArena) -> Self::Output;
71
72 fn output_type(&self, input_type_a: ColumnType, input_type_b: ColumnType) -> ColumnType;
74
75 fn propagates_nulls(&self) -> bool {
77 !Self::Input1::nullable() && !Self::Input2::nullable()
79 }
80
81 fn introduces_nulls(&self) -> bool {
83 Self::Output::nullable()
85 }
86
87 fn could_error(&self) -> bool {
89 Self::Output::fallible()
90 }
91
92 fn negate(&self) -> Option<crate::BinaryFunc> {
94 None
95 }
96
97 fn is_monotone(&self) -> (bool, bool) {
98 (false, false)
99 }
100
101 fn is_infix_op(&self) -> bool {
102 false
103 }
104}
105
106impl<T: for<'a> EagerBinaryFunc<'a>> LazyBinaryFunc for T {
107 fn eval<'a>(
108 &'a self,
109 datums: &[Datum<'a>],
110 temp_storage: &'a RowArena,
111 a: &'a MirScalarExpr,
112 b: &'a MirScalarExpr,
113 ) -> Result<Datum<'a>, EvalError> {
114 let a = match T::Input1::try_from_result(a.eval(datums, temp_storage)) {
115 Ok(input) => input,
117 Err(Ok(datum)) if !datum.is_null() => {
119 return Err(EvalError::Internal("invalid input type".into()));
120 }
121 Err(res) => return res,
123 };
124 let b = match T::Input2::try_from_result(b.eval(datums, temp_storage)) {
125 Ok(input) => input,
127 Err(Ok(datum)) if !datum.is_null() => {
129 return Err(EvalError::Internal("invalid input type".into()));
130 }
131 Err(res) => return res,
133 };
134 self.call(a, b, temp_storage).into_result(temp_storage)
135 }
136
137 fn output_type(&self, input_type_a: ColumnType, input_type_b: ColumnType) -> ColumnType {
138 self.output_type(input_type_a, input_type_b)
139 }
140
141 fn propagates_nulls(&self) -> bool {
142 self.propagates_nulls()
143 }
144
145 fn introduces_nulls(&self) -> bool {
146 self.introduces_nulls()
147 }
148
149 fn could_error(&self) -> bool {
150 self.could_error()
151 }
152
153 fn negate(&self) -> Option<crate::BinaryFunc> {
154 self.negate()
155 }
156
157 fn is_monotone(&self) -> (bool, bool) {
158 self.is_monotone()
159 }
160
161 fn is_infix_op(&self) -> bool {
162 self.is_infix_op()
163 }
164}
165
#[cfg(test)]
mod test {
    use mz_expr_derive::sqlfunc;
    use mz_repr::ColumnType;
    use mz_repr::ScalarType;

    use crate::scalar::func::binary::LazyBinaryFunc;
    use crate::{BinaryFunc, EvalError, func};

    /// Every (first input nullable, second input nullable) combination, in the
    /// order the output-type checks visit them.
    const INPUT_NULLABILITY: [(bool, bool); 4] =
        [(true, true), (true, false), (false, true), (false, false)];

    /// Checks `f.output_type` for two `Float32` inputs across every input
    /// nullability combination. `expect_nullable(a, b)` gives the expected
    /// nullability of the `Float32` output.
    #[track_caller]
    fn assert_f32_output_types<F: LazyBinaryFunc>(f: F, expect_nullable: fn(bool, bool) -> bool) {
        for (a_nullable, b_nullable) in INPUT_NULLABILITY {
            let actual = f.output_type(
                ScalarType::Float32.nullable(a_nullable),
                ScalarType::Float32.nullable(b_nullable),
            );
            let expected = ScalarType::Float32.nullable(expect_nullable(a_nullable, b_nullable));
            assert_eq!(
                actual, expected,
                "inputs nullable: ({a_nullable}, {b_nullable})"
            );
        }
    }

    /// Checks the NULL-handling metadata reported by `f`.
    #[track_caller]
    fn assert_null_rules<F: LazyBinaryFunc>(f: F, propagates: bool, introduces: bool) {
        assert_eq!(f.propagates_nulls(), propagates);
        assert_eq!(f.introduces_nulls(), introduces);
    }

    #[sqlfunc(sqlname = "INFALLIBLE", is_infix_op = true)]
    fn infallible1(a: f32, b: f32) -> f32 {
        a + b
    }

    #[sqlfunc]
    fn infallible2(a: Option<f32>, b: Option<f32>) -> f32 {
        a.unwrap_or_default() + b.unwrap_or_default()
    }

    #[sqlfunc]
    fn infallible3(a: f32, b: f32) -> Option<f32> {
        Some(a + b)
    }

    #[mz_ore::test]
    fn elision_rules_infallible() {
        assert_eq!(Infallible1.to_string(), "INFALLIBLE");
        assert_null_rules(Infallible1, true, false);
        assert_null_rules(Infallible2, false, false);
        assert_null_rules(Infallible3, true, true);
    }

    #[mz_ore::test]
    fn output_types_infallible() {
        // Non-nullable inputs and output: nullable iff either input is.
        assert_f32_output_types(Infallible1, |a, b| a || b);
        // Nullable inputs, non-nullable output: never nullable.
        assert_f32_output_types(Infallible2, |_, _| false);
        // Nullable output: always nullable.
        assert_f32_output_types(Infallible3, |_, _| true);
    }

    #[sqlfunc]
    fn fallible1(a: f32, b: f32) -> Result<f32, EvalError> {
        Ok(a + b)
    }

    #[sqlfunc]
    fn fallible2(a: Option<f32>, b: Option<f32>) -> Result<f32, EvalError> {
        Ok(a.unwrap_or_default() + b.unwrap_or_default())
    }

    #[sqlfunc]
    fn fallible3(a: f32, b: f32) -> Result<Option<f32>, EvalError> {
        Ok(Some(a + b))
    }

    #[mz_ore::test]
    fn elision_rules_fallible() {
        assert_null_rules(Fallible1, true, false);
        assert_null_rules(Fallible2, false, false);
        assert_null_rules(Fallible3, true, true);
    }

    #[mz_ore::test]
    fn output_types_fallible() {
        assert_f32_output_types(Fallible1, |a, b| a || b);
        assert_f32_output_types(Fallible2, |_, _| false);
        assert_f32_output_types(Fallible3, |_, _| true);
    }

    #[mz_ore::test]
    fn test_equivalence() {
        /// Asserts that the new-style function `new` reports the same metadata,
        /// output type (with `column_ty` for both inputs) and rendering as the
        /// `BinaryFunc` variant `old`.
        #[track_caller]
        fn check<T: LazyBinaryFunc + std::fmt::Display>(
            new: T,
            old: BinaryFunc,
            column_ty: ColumnType,
        ) {
            assert_eq!(
                new.propagates_nulls(),
                old.propagates_nulls(),
                "propagates_nulls mismatch"
            );
            assert_eq!(
                new.introduces_nulls(),
                old.introduces_nulls(),
                "introduces_nulls mismatch"
            );
            assert_eq!(new.could_error(), old.could_error(), "could_error mismatch");
            assert_eq!(new.is_monotone(), old.is_monotone(), "is_monotone mismatch");
            assert_eq!(new.is_infix_op(), old.is_infix_op(), "is_infix_op mismatch");
            assert_eq!(
                new.output_type(column_ty.clone(), column_ty.clone()),
                old.output_type(column_ty.clone(), column_ty),
                "output_type mismatch"
            );
            assert_eq!(new.to_string(), old.to_string(), "format mismatch");
        }

        use BinaryFunc as BF;

        let i32_ty = ScalarType::Int32.nullable(true);
        let ts_tz_ty = ScalarType::TimestampTz { precision: None }.nullable(true);

        check(func::AddInt16, BF::AddInt16, i32_ty.clone());
        check(func::AddInt32, BF::AddInt32, i32_ty.clone());
        check(func::AddInt64, BF::AddInt64, i32_ty.clone());
        check(func::AddUint16, BF::AddUInt16, i32_ty.clone());
        check(func::AddUint32, BF::AddUInt32, i32_ty.clone());
        check(func::AddUint64, BF::AddUInt64, i32_ty.clone());
        check(func::AddFloat32, BF::AddFloat32, i32_ty.clone());
        check(func::AddFloat64, BF::AddFloat64, i32_ty.clone());
        check(func::AddDateTime, BF::AddDateTime, i32_ty.clone());
        check(func::AddDateInterval, BF::AddDateInterval, i32_ty.clone());
        check(func::AddTimeInterval, BF::AddTimeInterval, ts_tz_ty);
        check(func::RoundNumericBinary, BF::RoundNumeric, i32_ty);
    }
}