mz_expr/scalar/func/impls/case_literal.rs
1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! A lookup-based evaluation of `CASE expr WHEN lit1 THEN res1 ... ELSE els END`.
11//!
12//! [`CaseLiteral`] replaces chains of `If(Eq(expr, literal), result, If(...))`
13//! with a sorted `Vec` + binary-search lookup, turning O(n) evaluation into O(log n).
14//!
15//! Represented as a `CallVariadic { func: CaseLiteral { lookup, return_type }, exprs }`
16//! where:
17//! * `exprs[0]` = input expression (the `x` in `CASE x WHEN ...`)
18//! * `exprs[1..n]` = case result expressions
19//! * `exprs[last]` = `els` (fallback)
20//! * `lookup: Vec<CaseLiteralEntry>` maps literal values to indices in `exprs` (sorted by `Row`)
21
22use std::fmt;
23
24use mz_lowertest::MzReflect;
25use mz_repr::{Datum, Row, RowArena, SqlColumnType};
26use serde::{Deserialize, Serialize};
27
28use crate::scalar::func::variadic::LazyVariadicFunc;
29use crate::{EvalError, MirScalarExpr};
30
31/// A single entry in a [`CaseLiteral`] lookup table: a literal `Row` value
32/// paired with the index of the corresponding result expression in `exprs`.
33#[derive(
34 Ord,
35 PartialOrd,
36 Clone,
37 Debug,
38 Eq,
39 PartialEq,
40 Serialize,
41 Deserialize,
42 Hash,
43 MzReflect
44)]
45pub struct CaseLiteralEntry {
46 /// The literal value (as a single-datum `Row`).
47 #[mzreflect(ignore)]
48 pub literal: Row,
49 /// Index into the `exprs` vector of the corresponding result expression.
50 pub expr_index: usize,
51}
52
53/// Evaluates a CASE expression by looking up the input datum in a sorted `Vec`.
54///
55/// The input expression (`exprs[0]`) is evaluated once, packed into a temporary
56/// `Row`, and looked up in `lookup` via binary search. If found, the corresponding
57/// result expression (`exprs[idx]`) is evaluated; otherwise the fallback
58/// (`exprs.last()`) is evaluated.
59/// NULL inputs go straight to the fallback (since SQL `NULL = x` is always NULL/falsy).
60#[derive(
61 Ord,
62 PartialOrd,
63 Clone,
64 Debug,
65 Eq,
66 PartialEq,
67 Serialize,
68 Deserialize,
69 Hash,
70 MzReflect
71)]
72pub struct CaseLiteral {
73 /// Sorted vec of literal-to-index entries for binary-search lookup.
74 pub lookup: Vec<CaseLiteralEntry>,
75 /// The output type of this CASE expression.
76 pub return_type: SqlColumnType,
77}
78
79impl LazyVariadicFunc for CaseLiteral {
80 fn eval<'a>(
81 &'a self,
82 datums: &[Datum<'a>],
83 temp_storage: &'a RowArena,
84 exprs: &'a [MirScalarExpr],
85 ) -> Result<Datum<'a>, EvalError> {
86 let input = exprs[0].eval(datums, temp_storage)?;
87 // SQL NULL = x is always NULL/falsy, so go straight to the fallback.
88 if input.is_null() {
89 return exprs.last().unwrap().eval(datums, temp_storage);
90 }
91 let key = Row::pack_slice(&[input]);
92 if let Ok(pos) = self
93 .lookup
94 .binary_search_by(|entry| entry.literal.cmp(&key))
95 {
96 exprs[self.lookup[pos].expr_index].eval(datums, temp_storage)
97 } else {
98 exprs.last().unwrap().eval(datums, temp_storage)
99 }
100 }
101
102 fn output_type(&self, _input_types: &[SqlColumnType]) -> SqlColumnType {
103 self.return_type.clone()
104 }
105
106 fn propagates_nulls(&self) -> bool {
107 // NULL input goes to the fallback, not automatically to NULL output.
108 false
109 }
110
111 fn introduces_nulls(&self) -> bool {
112 // Branch results or the fallback may be NULL.
113 true
114 }
115
116 fn could_error(&self) -> bool {
117 // The function itself does not error; errors in sub-expressions are
118 // checked separately by MirScalarExpr::could_error.
119 false
120 }
121
122 fn is_monotone(&self) -> bool {
123 false
124 }
125
126 fn is_associative(&self) -> bool {
127 false
128 }
129}
130
// Note: the `fmt::Display` impl for `CaseLiteral` (at the end of this file) is
// unused at runtime because `CaseLiteral` has custom printing in
// src/expr/src/explain/text.rs.
133impl CaseLiteral {
134 /// Look up a key in the sorted lookup vec. Returns the expr index if found.
135 pub fn get(&self, key: &Row) -> Option<usize> {
136 self.lookup
137 .binary_search_by(|entry| entry.literal.cmp(key))
138 .ok()
139 .map(|pos| self.lookup[pos].expr_index)
140 }
141
142 /// Insert an entry, maintaining sorted order.
143 /// If the literal already exists, overwrites the index and returns the old one.
144 pub fn insert(&mut self, literal: Row, expr_index: usize) -> Option<usize> {
145 match self
146 .lookup
147 .binary_search_by(|entry| entry.literal.cmp(&literal))
148 {
149 Ok(pos) => {
150 let old = self.lookup[pos].expr_index;
151 self.lookup[pos].expr_index = expr_index;
152 Some(old)
153 }
154 Err(pos) => {
155 self.lookup.insert(
156 pos,
157 CaseLiteralEntry {
158 literal,
159 expr_index,
160 },
161 );
162 None
163 }
164 }
165 }
166}
167
168impl fmt::Display for CaseLiteral {
169 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
170 write!(f, "case_literal[{} cases]", self.lookup.len())
171 }
172}