Skip to main content

mz_expr/scalar/func/impls/
case_literal.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! A lookup-based evaluation of `CASE expr WHEN lit1 THEN res1 ... ELSE els END`.
11//!
12//! [`CaseLiteral`] replaces chains of `If(Eq(expr, literal), result, If(...))`
13//! with a sorted `Vec` + binary-search lookup, turning O(n) evaluation into O(log n).
14//!
//! Represented as a `CallVariadic { func: CaseLiteral { lookup, return_type }, exprs }`
//! where:
//! * `exprs[0]` = input expression (the `x` in `CASE x WHEN ...`)
//! * `exprs[1..exprs.len() - 1]` = the `THEN` result expressions
//! * `exprs[exprs.len() - 1]` = `els` (the `ELSE` fallback)
//! * `lookup: Vec<CaseLiteralEntry>` maps literal values to indices in `exprs`; it is kept
//!   sorted by `Row`'s `Ord` so it can be binary-searched
21
22use std::fmt;
23
24use mz_lowertest::MzReflect;
25use mz_repr::{Datum, Row, RowArena, SqlColumnType};
26use serde::{Deserialize, Serialize};
27
28use crate::scalar::func::variadic::LazyVariadicFunc;
29use crate::{EvalError, MirScalarExpr};
30
31/// A single entry in a [`CaseLiteral`] lookup table: a literal `Row` value
32/// paired with the index of the corresponding result expression in `exprs`.
33#[derive(
34    Ord,
35    PartialOrd,
36    Clone,
37    Debug,
38    Eq,
39    PartialEq,
40    Serialize,
41    Deserialize,
42    Hash,
43    MzReflect
44)]
45pub struct CaseLiteralEntry {
46    /// The literal value (as a single-datum `Row`).
47    #[mzreflect(ignore)]
48    pub literal: Row,
49    /// Index into the `exprs` vector of the corresponding result expression.
50    pub expr_index: usize,
51}
52
53/// Evaluates a CASE expression by looking up the input datum in a sorted `Vec`.
54///
55/// The input expression (`exprs[0]`) is evaluated once, packed into a temporary
56/// `Row`, and looked up in `lookup` via binary search. If found, the corresponding
57/// result expression (`exprs[idx]`) is evaluated; otherwise the fallback
58/// (`exprs.last()`) is evaluated.
59/// NULL inputs go straight to the fallback (since SQL `NULL = x` is always NULL/falsy).
60#[derive(
61    Ord,
62    PartialOrd,
63    Clone,
64    Debug,
65    Eq,
66    PartialEq,
67    Serialize,
68    Deserialize,
69    Hash,
70    MzReflect
71)]
72pub struct CaseLiteral {
73    /// Sorted vec of literal-to-index entries for binary-search lookup.
74    pub lookup: Vec<CaseLiteralEntry>,
75    /// The output type of this CASE expression.
76    pub return_type: SqlColumnType,
77}
78
79impl LazyVariadicFunc for CaseLiteral {
80    fn eval<'a>(
81        &'a self,
82        datums: &[Datum<'a>],
83        temp_storage: &'a RowArena,
84        exprs: &'a [MirScalarExpr],
85    ) -> Result<Datum<'a>, EvalError> {
86        let input = exprs[0].eval(datums, temp_storage)?;
87        // SQL NULL = x is always NULL/falsy, so go straight to the fallback.
88        if input.is_null() {
89            return exprs.last().unwrap().eval(datums, temp_storage);
90        }
91        let key = Row::pack_slice(&[input]);
92        if let Ok(pos) = self
93            .lookup
94            .binary_search_by(|entry| entry.literal.cmp(&key))
95        {
96            exprs[self.lookup[pos].expr_index].eval(datums, temp_storage)
97        } else {
98            exprs.last().unwrap().eval(datums, temp_storage)
99        }
100    }
101
102    fn output_type(&self, _input_types: &[SqlColumnType]) -> SqlColumnType {
103        self.return_type.clone()
104    }
105
106    fn propagates_nulls(&self) -> bool {
107        // NULL input goes to the fallback, not automatically to NULL output.
108        false
109    }
110
111    fn introduces_nulls(&self) -> bool {
112        // Branch results or the fallback may be NULL.
113        true
114    }
115
116    fn could_error(&self) -> bool {
117        // The function itself does not error; errors in sub-expressions are
118        // checked separately by MirScalarExpr::could_error.
119        false
120    }
121
122    fn is_monotone(&self) -> bool {
123        false
124    }
125
126    fn is_associative(&self) -> bool {
127        false
128    }
129}
130
131// Note: this Display impl is unused at runtime because CaseLiteral has
132// custom printing in src/expr/src/explain/text.rs.
133impl CaseLiteral {
134    /// Look up a key in the sorted lookup vec. Returns the expr index if found.
135    pub fn get(&self, key: &Row) -> Option<usize> {
136        self.lookup
137            .binary_search_by(|entry| entry.literal.cmp(key))
138            .ok()
139            .map(|pos| self.lookup[pos].expr_index)
140    }
141
142    /// Insert an entry, maintaining sorted order.
143    /// If the literal already exists, overwrites the index and returns the old one.
144    pub fn insert(&mut self, literal: Row, expr_index: usize) -> Option<usize> {
145        match self
146            .lookup
147            .binary_search_by(|entry| entry.literal.cmp(&literal))
148        {
149            Ok(pos) => {
150                let old = self.lookup[pos].expr_index;
151                self.lookup[pos].expr_index = expr_index;
152                Some(old)
153            }
154            Err(pos) => {
155                self.lookup.insert(
156                    pos,
157                    CaseLiteralEntry {
158                        literal,
159                        expr_index,
160                    },
161                );
162                None
163            }
164        }
165    }
166}
167
168impl fmt::Display for CaseLiteral {
169    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
170        write!(f, "case_literal[{} cases]", self.lookup.len())
171    }
172}