Skip to main content

mz_expr/scalar/func/impls/
case_literal.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! A lookup-based evaluation of `CASE expr WHEN lit1 THEN res1 ... ELSE els END`.
11//!
12//! [`CaseLiteral`] replaces chains of `If(Eq(expr, literal), result, If(...))`
13//! with a single `BTreeMap` lookup, turning O(n) evaluation into O(log n).
14//!
15//! Represented as a `CallVariadic { func: CaseLiteral { lookup, return_type }, exprs }`
16//! where:
//! * `exprs[0]` = input expression (the `x` in `CASE x WHEN ...`)
//! * `exprs[1..exprs.len() - 1]` = case result expressions
//! * `exprs[exprs.len() - 1]` = `els` (fallback)
//! * `lookup: BTreeMap<Row, usize>` maps literal values to indices in `exprs`
21
22use std::collections::BTreeMap;
23use std::fmt;
24
25use mz_lowertest::MzReflect;
26use mz_repr::{Datum, Row, RowArena, SqlColumnType};
27use serde::{Deserialize, Serialize};
28
29use crate::scalar::func::variadic::LazyVariadicFunc;
30use crate::{EvalError, MirScalarExpr};
31
32/// Evaluates a CASE expression by looking up the input datum in a `BTreeMap`.
33///
34/// The input expression (`exprs[0]`) is evaluated once, packed into a temporary
35/// `Row`, and looked up in `lookup`. If found, the corresponding result expression
36/// (`exprs[idx]`) is evaluated; otherwise the fallback (`exprs.last()`) is evaluated.
37/// NULL inputs go straight to the fallback (since SQL `NULL = x` is always NULL/falsy).
38#[derive(
39    Ord,
40    PartialOrd,
41    Clone,
42    Debug,
43    Eq,
44    PartialEq,
45    Serialize,
46    Deserialize,
47    Hash,
48    MzReflect
49)]
50pub struct CaseLiteral {
51    /// Map from literal values (as single-datum `Row`s) to indices in the `exprs` vector.
52    pub lookup: BTreeMap<Row, usize>,
53    /// The output type of this CASE expression.
54    pub return_type: SqlColumnType,
55}
56
57impl LazyVariadicFunc for CaseLiteral {
58    fn eval<'a>(
59        &'a self,
60        datums: &[Datum<'a>],
61        temp_storage: &'a RowArena,
62        exprs: &'a [MirScalarExpr],
63    ) -> Result<Datum<'a>, EvalError> {
64        let input = exprs[0].eval(datums, temp_storage)?;
65        // SQL NULL = x is always NULL/falsy, so go straight to the fallback.
66        if input.is_null() {
67            return exprs.last().unwrap().eval(datums, temp_storage);
68        }
69        let key = Row::pack_slice(&[input]);
70        if let Some(&idx) = self.lookup.get(&key) {
71            exprs[idx].eval(datums, temp_storage)
72        } else {
73            exprs.last().unwrap().eval(datums, temp_storage)
74        }
75    }
76
77    fn output_type(&self, _input_types: &[SqlColumnType]) -> SqlColumnType {
78        self.return_type.clone()
79    }
80
81    fn propagates_nulls(&self) -> bool {
82        // NULL input goes to the fallback, not automatically to NULL output.
83        false
84    }
85
86    fn introduces_nulls(&self) -> bool {
87        // Branch results or the fallback may be NULL.
88        true
89    }
90
91    fn could_error(&self) -> bool {
92        // The function itself does not error; errors in sub-expressions are
93        // checked separately by MirScalarExpr::could_error.
94        false
95    }
96
97    fn is_monotone(&self) -> bool {
98        false
99    }
100
101    fn is_associative(&self) -> bool {
102        false
103    }
104}
105
106// Note: this Display impl is unused at runtime because CaseLiteral has
107// custom printing in src/expr/src/explain/text.rs.
108impl fmt::Display for CaseLiteral {
109    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
110        write!(f, "case_literal[{} cases]", self.lookup.len())
111    }
112}