mz_expr/scalar/func/impls/case_literal.rs
1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
//! A lookup-based evaluation of `CASE expr WHEN lit1 THEN res1 ... ELSE els END`.
//!
//! [`CaseLiteral`] replaces chains of `If(Eq(expr, literal), result, If(...))`
//! with a single `BTreeMap` lookup, turning O(n) equality tests per evaluation
//! (n = number of `WHEN` arms) into one O(log n) map lookup.
//!
//! Represented as a `CallVariadic { func: CaseLiteral { lookup, return_type }, exprs }`
//! where:
//! * `exprs[0]` = input expression (the `x` in `CASE x WHEN ...`)
//! * `exprs[1..exprs.len() - 1]` = case result expressions, one per `WHEN` arm
//! * `exprs[exprs.len() - 1]` = `els` (fallback)
//! * `lookup: BTreeMap<Row, usize>` maps literal values to indices in `exprs`
21
22use std::collections::BTreeMap;
23use std::fmt;
24
25use mz_lowertest::MzReflect;
26use mz_repr::{Datum, Row, RowArena, SqlColumnType};
27use serde::{Deserialize, Serialize};
28
29use crate::scalar::func::variadic::LazyVariadicFunc;
30use crate::{EvalError, MirScalarExpr};
31
32/// Evaluates a CASE expression by looking up the input datum in a `BTreeMap`.
33///
34/// The input expression (`exprs[0]`) is evaluated once, packed into a temporary
35/// `Row`, and looked up in `lookup`. If found, the corresponding result expression
36/// (`exprs[idx]`) is evaluated; otherwise the fallback (`exprs.last()`) is evaluated.
37/// NULL inputs go straight to the fallback (since SQL `NULL = x` is always NULL/falsy).
38#[derive(
39 Ord,
40 PartialOrd,
41 Clone,
42 Debug,
43 Eq,
44 PartialEq,
45 Serialize,
46 Deserialize,
47 Hash,
48 MzReflect
49)]
50pub struct CaseLiteral {
51 /// Map from literal values (as single-datum `Row`s) to indices in the `exprs` vector.
52 pub lookup: BTreeMap<Row, usize>,
53 /// The output type of this CASE expression.
54 pub return_type: SqlColumnType,
55}
56
57impl LazyVariadicFunc for CaseLiteral {
58 fn eval<'a>(
59 &'a self,
60 datums: &[Datum<'a>],
61 temp_storage: &'a RowArena,
62 exprs: &'a [MirScalarExpr],
63 ) -> Result<Datum<'a>, EvalError> {
64 let input = exprs[0].eval(datums, temp_storage)?;
65 // SQL NULL = x is always NULL/falsy, so go straight to the fallback.
66 if input.is_null() {
67 return exprs.last().unwrap().eval(datums, temp_storage);
68 }
69 let key = Row::pack_slice(&[input]);
70 if let Some(&idx) = self.lookup.get(&key) {
71 exprs[idx].eval(datums, temp_storage)
72 } else {
73 exprs.last().unwrap().eval(datums, temp_storage)
74 }
75 }
76
77 fn output_type(&self, _input_types: &[SqlColumnType]) -> SqlColumnType {
78 self.return_type.clone()
79 }
80
81 fn propagates_nulls(&self) -> bool {
82 // NULL input goes to the fallback, not automatically to NULL output.
83 false
84 }
85
86 fn introduces_nulls(&self) -> bool {
87 // Branch results or the fallback may be NULL.
88 true
89 }
90
91 fn could_error(&self) -> bool {
92 // The function itself does not error; errors in sub-expressions are
93 // checked separately by MirScalarExpr::could_error.
94 false
95 }
96
97 fn is_monotone(&self) -> bool {
98 false
99 }
100
101 fn is_associative(&self) -> bool {
102 false
103 }
104}
105
106// Note: this Display impl is unused at runtime because CaseLiteral has
107// custom printing in src/expr/src/explain/text.rs.
108impl fmt::Display for CaseLiteral {
109 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
110 write!(f, "case_literal[{} cases]", self.lookup.len())
111 }
112}