Skip to main content

mz_transform/
case_literal.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Rewrites chains of `If(Eq(expr, literal), result, If(...))` into
11//! `CallVariadic { func: CaseLiteral { lookup, return_type }, exprs }` for
12//! O(log n) evaluation via `BTreeMap` lookup.
13//!
14//! Uses the `ReprRelationType` analysis to obtain column types in O(n),
15//! avoiding repeated `input.typ()` calls. Each scalar is then visited
16//! bottom-up so inner CaseLiterals are created first, then outer If nodes
17//! fold into them.
18
19use std::collections::{BTreeMap, BTreeSet};
20
21use itertools::Itertools;
22use mz_expr::visit::Visit;
23use mz_expr::{BinaryFunc, MirRelationExpr, MirScalarExpr, VariadicFunc};
24use mz_repr::{ReprColumnType, Row, SqlColumnType};
25
26use crate::analysis::{DerivedBuilder, ReprRelationType};
27use crate::{Transform, TransformCtx, TransformError};
28
/// Rewrites If-chains matching a single expression against literals
/// into a `CaseLiteral` variadic function with `BTreeMap` lookup.
///
/// This is a stateless unit struct: all of the work happens in its
/// `Transform` implementation, which visits the scalar expressions of a
/// relation tree and replaces eligible `If(Eq(expr, literal), ..)` chains.
#[derive(Debug)]
pub struct CaseLiteralTransform;
33
34impl Transform for CaseLiteralTransform {
35    fn name(&self) -> &'static str {
36        "CaseLiteralTransform"
37    }
38
39    #[mz_ore::instrument(
40        target = "optimizer",
41        level = "debug",
42        fields(path.segment = "case_literal")
43    )]
44    fn actually_perform_transform(
45        &self,
46        relation: &mut MirRelationExpr,
47        ctx: &mut TransformCtx,
48    ) -> Result<(), TransformError> {
49        // Pre-compute column types for all nodes in a single pass.
50        let mut builder = DerivedBuilder::new(ctx.features);
51        builder.require(ReprRelationType);
52        let derived = builder.visit(&*relation);
53
54        let mut todo = vec![(&mut *relation, derived.as_view())];
55        while let Some((expr, view)) = todo.pop() {
56            match expr {
57                MirRelationExpr::Map { scalars, .. } => {
58                    // Use the output type (includes scalars' types).
59                    let output_type: &Vec<ReprColumnType> = view
60                        .value::<ReprRelationType>()
61                        .expect("ReprRelationType required")
62                        .as_ref()
63                        .unwrap();
64                    let input_arity = output_type.len() - scalars.len();
65                    for (index, scalar) in scalars.iter_mut().enumerate() {
66                        Self::rewrite_scalar(scalar, &output_type[..input_arity + index])?;
67                    }
68                }
69                MirRelationExpr::Filter { predicates, .. } => {
70                    let input_type: &Vec<ReprColumnType> = view
71                        .last_child()
72                        .value::<ReprRelationType>()
73                        .expect("ReprRelationType required")
74                        .as_ref()
75                        .unwrap();
76                    for predicate in predicates.iter_mut() {
77                        Self::rewrite_scalar(predicate, input_type)?;
78                    }
79                }
80                MirRelationExpr::Reduce { aggregates, .. } => {
81                    let input_type: &Vec<ReprColumnType> = view
82                        .last_child()
83                        .value::<ReprRelationType>()
84                        .expect("ReprRelationType required")
85                        .as_ref()
86                        .unwrap();
87                    for agg in aggregates.iter_mut() {
88                        Self::rewrite_scalar(&mut agg.expr, input_type)?;
89                    }
90                }
91                MirRelationExpr::FlatMap { exprs, .. } => {
92                    let input_type: &Vec<ReprColumnType> = view
93                        .last_child()
94                        .value::<ReprRelationType>()
95                        .expect("ReprRelationType required")
96                        .as_ref()
97                        .unwrap();
98                    for e in exprs.iter_mut() {
99                        Self::rewrite_scalar(e, input_type)?;
100                    }
101                }
102                MirRelationExpr::Join { equivalences, .. } => {
103                    let mut children: Vec<_> = view.children_rev().collect::<Vec<_>>();
104                    children.reverse();
105                    let input_types: Vec<ReprColumnType> = children
106                        .iter()
107                        .flat_map(|c| {
108                            c.value::<ReprRelationType>()
109                                .expect("ReprRelationType required")
110                                .as_ref()
111                                .unwrap()
112                                .iter()
113                                .cloned()
114                        })
115                        .collect();
116                    for class in equivalences.iter_mut() {
117                        for expr in class.iter_mut() {
118                            Self::rewrite_scalar(expr, &input_types)?;
119                        }
120                    }
121                }
122                MirRelationExpr::TopK { limit, .. } => {
123                    let input_type: &Vec<ReprColumnType> = view
124                        .last_child()
125                        .value::<ReprRelationType>()
126                        .expect("ReprRelationType required")
127                        .as_ref()
128                        .unwrap();
129                    if let Some(limit) = limit {
130                        Self::rewrite_scalar(limit, input_type)?;
131                    }
132                }
133                _ => {}
134            }
135            todo.extend(expr.children_mut().rev().zip_eq(view.children_rev()));
136        }
137
138        mz_repr::explain::trace_plan(&*relation);
139        Ok(())
140    }
141}
142
143impl CaseLiteralTransform {
144    /// Rewrites a scalar expression tree bottom-up, replacing If-chains of
145    /// `If(Eq(common_candidate, literal), result, ...)` with `CaseLiteral`.
146    fn rewrite_scalar(
147        expr: &mut MirScalarExpr,
148        column_types: &[ReprColumnType],
149    ) -> Result<(), TransformError> {
150        expr.try_visit_mut_post(&mut |node: &mut MirScalarExpr| {
151            try_fold_into_case_literal(node);
152            try_create_case_literal(node, column_types);
153            Ok(())
154        })
155    }
156}
157
158/// Fold rule: if node is `If(Eq(x, lit), res, CallVariadic(CaseLiteral{..}, [x, ...]))`
159/// where the CaseLiteral's input (`exprs[0]`) structurally equals `x`, insert (or
160/// overwrite) `res` into the existing CaseLiteral.
161///
162/// Because we traverse bottom-up, the current If is an *earlier* arm than anything
163/// already in the CaseLiteral. For duplicates, the outer/earlier arm wins per SQL
164/// CASE semantics, so we overwrite the existing entry.
165fn try_fold_into_case_literal(expr: &mut MirScalarExpr) {
166    let MirScalarExpr::If { cond, then, els } = expr else {
167        return;
168    };
169    let Some((common_candidate, literal_row)) = peek_eq_literal(cond) else {
170        return;
171    };
172    let MirScalarExpr::CallVariadic {
173        func: VariadicFunc::CaseLiteral(cl),
174        exprs,
175    } = els.as_mut()
176    else {
177        return;
178    };
179
180    // Check that the CaseLiteral's input matches the If's common expression.
181    if exprs[0] != *common_candidate {
182        return;
183    }
184
185    if let Some(&existing_idx) = cl.lookup.get(literal_row) {
186        // Duplicate literal: overwrite with the earlier arm's result (this If).
187        exprs[existing_idx] = then.take();
188    } else {
189        // New literal: insert before the fallback (last position).
190        let new_idx = exprs.len() - 1;
191        exprs.insert(new_idx, then.take());
192        cl.lookup.insert(literal_row.clone(), new_idx);
193    }
194
195    // Replace the If with the CaseLiteral.
196    *expr = els.take();
197}
198
199/// Chain-walk rule: if node is an If-chain with >= 2 consecutive arms matching
200/// `Eq(same_expr, literal)`, create a new CaseLiteral.
201fn try_create_case_literal(expr: &mut MirScalarExpr, column_types: &[ReprColumnType]) {
202    if !has_at_least_two_arms(expr) {
203        return;
204    }
205
206    // Take the expression and dismantle it.
207    let chain = expr.take();
208    let (collected_cases, common, els) = collect_if_chain_arms(chain);
209
210    let common = common.expect("common expr must be set when arm_count >= 2");
211
212    // Compute the return type as the union of all branch types and els type.
213    let mut return_type: Option<ReprColumnType> = None;
214    for (_, result) in &collected_cases {
215        let t = result.typ(column_types);
216        return_type = Some(match return_type {
217            None => t,
218            Some(prev) => prev.union(&t).expect("incompatible branch types"),
219        });
220    }
221    let els_type = els.typ(column_types);
222    let return_type = match return_type {
223        Some(prev) => prev.union(&els_type).expect("incompatible else type"),
224        None => els_type,
225    };
226    let sql_return_type = SqlColumnType::from_repr(&return_type);
227
228    // Build the exprs vector: [input, result1, result2, ..., els]
229    let mut exprs = Vec::with_capacity(collected_cases.len() + 2);
230    exprs.push(common);
231    let mut lookup = BTreeMap::new();
232    for (row, result_expr) in collected_cases {
233        let idx = exprs.len();
234        lookup.insert(row, idx);
235        exprs.push(result_expr);
236    }
237    exprs.push(els);
238
239    *expr = MirScalarExpr::CallVariadic {
240        func: VariadicFunc::CaseLiteral(mz_expr::func::CaseLiteral {
241            lookup,
242            return_type: sql_return_type,
243        }),
244        exprs,
245    };
246}
247
248/// Returns `true` if the If-chain has at least 2 arms matching `Eq(same_expr, literal)`.
249/// Bails early once 2 are found to avoid unnecessary traversal.
250fn has_at_least_two_arms(expr: &MirScalarExpr) -> bool {
251    let mut count = 0;
252    let mut common_candidate: Option<&MirScalarExpr> = None;
253    let mut current = expr;
254
255    loop {
256        match current {
257            MirScalarExpr::If { cond, then: _, els } => {
258                if let Some((expr_side, _literal_row)) = peek_eq_literal(cond) {
259                    match common_candidate {
260                        None => {
261                            common_candidate = Some(expr_side);
262                        }
263                        Some(existing) => {
264                            if existing != expr_side {
265                                break;
266                            }
267                        }
268                    }
269                    count += 1;
270                    if count >= 2 {
271                        return true;
272                    }
273                    current = els;
274                } else {
275                    break;
276                }
277            }
278            _ => break,
279        }
280    }
281
282    false
283}
284
285/// Inspects an `Eq(expr, literal)` condition and returns references to the
286/// non-literal expression and the literal `Row`.
287/// Returns `(non_literal_expr_ref, literal_row_ref)`.
288fn peek_eq_literal(cond: &MirScalarExpr) -> Option<(&MirScalarExpr, &Row)> {
289    let MirScalarExpr::CallBinary {
290        func: BinaryFunc::Eq(_),
291        expr1,
292        expr2,
293    } = cond
294    else {
295        return None;
296    };
297
298    if let Some(row) = expr1.as_literal_non_null_row() {
299        if !expr2.is_literal() {
300            return Some((expr2.as_ref(), row));
301        }
302    }
303    if let Some(row) = expr2.as_literal_non_null_row() {
304        if !expr1.is_literal() {
305            return Some((expr1.as_ref(), row));
306        }
307    }
308    None
309}
310
311/// Walks an If-chain and collects `(literal_row, result_expr)` pairs.
312///
313/// The input `chain` is consumed and dismantled.
314/// Returns `(cases, common_candidate, els)`.
315fn collect_if_chain_arms(
316    chain: MirScalarExpr,
317) -> (
318    Vec<(Row, MirScalarExpr)>,
319    Option<MirScalarExpr>,
320    MirScalarExpr,
321) {
322    let mut cases = Vec::new();
323    let mut seen = BTreeSet::new();
324    let mut common_candidate: Option<MirScalarExpr> = None;
325    let mut remaining = chain;
326
327    loop {
328        match remaining {
329            MirScalarExpr::If { cond, then, els } => {
330                if let Some((expr_side, literal_row)) = peek_eq_literal(&cond) {
331                    match &common_candidate {
332                        None => {
333                            common_candidate = Some(expr_side.clone());
334                        }
335                        Some(existing) => {
336                            if existing != expr_side {
337                                remaining = MirScalarExpr::If { cond, then, els };
338                                break;
339                            }
340                        }
341                    }
342
343                    // First occurrence of each literal wins (SQL CASE semantics).
344                    if seen.insert(literal_row.clone()) {
345                        cases.push((literal_row.clone(), *then));
346                    }
347
348                    remaining = *els;
349                } else {
350                    remaining = MirScalarExpr::If { cond, then, els };
351                    break;
352                }
353            }
354            _ => break,
355        }
356    }
357
358    (cases, common_candidate, remaining)
359}
360
#[cfg(test)]
mod tests {
    use mz_expr::func::Eq;
    use mz_expr::{MirRelationExpr, MirScalarExpr, VariadicFunc};
    use mz_repr::{Datum, ReprColumnType, ReprRelationType, ReprScalarType};

    use super::*;

    /// Helper: build an i64 literal.
    fn lit_i64(v: i64) -> MirScalarExpr {
        MirScalarExpr::literal_ok(Datum::Int64(v), ReprScalarType::Int64)
    }

    /// Wrap a scalar expression in a `Map` over a one-row, one-column constant
    /// whose single non-nullable column has type `scalar_type` and value `sample`.
    fn wrap_in_map_over(
        scalar: MirScalarExpr,
        sample: Datum<'_>,
        scalar_type: ReprScalarType,
    ) -> MirRelationExpr {
        MirRelationExpr::Map {
            input: Box::new(MirRelationExpr::constant(
                vec![vec![sample]],
                ReprRelationType::new(vec![ReprColumnType {
                    scalar_type,
                    nullable: false,
                }]),
            )),
            scalars: vec![scalar],
        }
    }

    /// Wrap a scalar expression in a `Map` over a constant to allow applying the transform.
    fn wrap_in_map(scalar: MirScalarExpr) -> MirRelationExpr {
        wrap_in_map_over(scalar, Datum::Int64(0), ReprScalarType::Int64)
    }

    /// Run the CaseLiteralTransform (feature flag enabled) over `relation`, which
    /// must be a `Map`, and return the first scalar of that `Map` afterwards.
    fn run_transform(mut relation: MirRelationExpr) -> MirScalarExpr {
        let mut features = mz_repr::optimize::OptimizerFeatures::default();
        features.enable_case_literal_transform = true;
        let typecheck_ctx = crate::typecheck::empty_typechecking_context();
        let mut df_meta = crate::dataflow::DataflowMetainfo::default();
        let mut transform_ctx =
            crate::TransformCtx::local(&features, &typecheck_ctx, &mut df_meta, None, None);
        crate::Transform::transform(&CaseLiteralTransform, &mut relation, &mut transform_ctx)
            .unwrap();
        match relation {
            MirRelationExpr::Map { scalars, .. } => scalars.into_iter().next().unwrap(),
            other => panic!("expected Map, got {other:?}"),
        }
    }

    /// Apply the CaseLiteralTransform to a relation and return the first scalar from the Map.
    fn apply_transform(scalar: MirScalarExpr) -> MirScalarExpr {
        run_transform(wrap_in_map(scalar))
    }

    /// Verify that the result is a CaseLiteral with the expected number of cases.
    fn assert_case_literal(expr: &MirScalarExpr, expected_cases: usize) {
        match expr {
            MirScalarExpr::CallVariadic {
                func: VariadicFunc::CaseLiteral(cl),
                ..
            } => {
                assert_eq!(
                    cl.lookup.len(),
                    expected_cases,
                    "expected {expected_cases} cases, got {}",
                    cl.lookup.len()
                );
            }
            other => panic!("expected CaseLiteral, got {other:?}"),
        }
    }

    /// Build a CASE-like If-chain: CASE #0 WHEN 1 THEN 10 WHEN 2 THEN 20 ELSE 0 END
    fn build_2_arm_chain() -> MirScalarExpr {
        MirScalarExpr::column(0)
            .call_binary(lit_i64(1), Eq)
            .if_then_else(
                lit_i64(10),
                MirScalarExpr::column(0)
                    .call_binary(lit_i64(2), Eq)
                    .if_then_else(lit_i64(20), lit_i64(0)),
            )
    }

    #[mz_ore::test]
    #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `rust_psm_stack_pointer` on OS `linux`
    fn test_null_literal_skipped() {
        // If(Eq(#0, NULL::int64), 10, If(Eq(#0, 2), 20, If(Eq(#0, 3), 30, 0)))
        // The NULL arm breaks the chain at the top level because
        // peek_eq_literal skips NULL literals, so arm1 doesn't match.
        let null_lit = MirScalarExpr::literal(Ok(Datum::Null), ReprScalarType::Int64);
        let expr = MirScalarExpr::column(0)
            .call_binary(null_lit, Eq)
            .if_then_else(
                lit_i64(10),
                MirScalarExpr::column(0)
                    .call_binary(lit_i64(2), Eq)
                    .if_then_else(
                        lit_i64(20),
                        MirScalarExpr::column(0)
                            .call_binary(lit_i64(3), Eq)
                            .if_then_else(lit_i64(30), lit_i64(0)),
                    ),
            );
        let result = apply_transform(expr);
        // The NULL arm breaks the chain at the top level. The inner 2 arms
        // (comparing #0 to 2 and 3) should still be converted.
        // With bottom-up, the inner chain becomes a CaseLiteral first.
        // Then the outer If(Eq(#0, NULL), 10, CaseLiteral) has a CaseLiteral
        // in els, but the cond is Eq(#0, NULL) which is not a valid literal
        // (NULL is skipped), so the fold rule doesn't fire.
        // Result: If(Eq(#0, NULL), 10, CaseLiteral(...))
        match &result {
            MirScalarExpr::If { els, .. } => {
                assert_case_literal(els, 2);
            }
            other => panic!("expected If with CaseLiteral in els, got {other:?}"),
        }
    }

    #[mz_ore::test]
    #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `rust_psm_stack_pointer` on OS `linux`
    fn test_64_arm_chain() {
        // Build a 64-arm If-chain and verify it converts to a single CaseLiteral.
        let n: usize = 64;
        let mut expr = lit_i64(-1);
        for i in (0..n).rev() {
            let i = i64::try_from(i).expect("arm index fits in i64");
            expr = MirScalarExpr::column(0)
                .call_binary(lit_i64(i), Eq)
                .if_then_else(lit_i64(100 * i), expr);
        }
        let result = apply_transform(expr);
        assert_case_literal(&result, n);

        // Spot-check evaluation.
        let arena = mz_repr::RowArena::new();
        assert_eq!(
            result.eval(&[Datum::Int64(0)], &arena).unwrap(),
            Datum::Int64(0)
        );
        assert_eq!(
            result.eval(&[Datum::Int64(32)], &arena).unwrap(),
            Datum::Int64(3200)
        );
        assert_eq!(
            result.eval(&[Datum::Int64(63)], &arena).unwrap(),
            Datum::Int64(6300)
        );
        assert_eq!(
            result.eval(&[Datum::Int64(999)], &arena).unwrap(),
            Datum::Int64(-1)
        );
    }

    #[mz_ore::test]
    #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `rust_psm_stack_pointer` on OS `linux`
    fn test_eval_basic() {
        // Verify that evaluating the CaseLiteral produces correct results.
        let expr = build_2_arm_chain();
        let result = apply_transform(expr);

        let arena = mz_repr::RowArena::new();

        // Input = 1 → output should be 10
        let out = result.eval(&[Datum::Int64(1)], &arena).unwrap();
        assert_eq!(out, Datum::Int64(10));

        // Input = 2 → output should be 20
        let out = result.eval(&[Datum::Int64(2)], &arena).unwrap();
        assert_eq!(out, Datum::Int64(20));

        // Input = 99 → output should be 0 (els)
        let out = result.eval(&[Datum::Int64(99)], &arena).unwrap();
        assert_eq!(out, Datum::Int64(0));

        // Input = NULL → output should be 0 (els, since NULL = x is falsy)
        let out = result.eval(&[Datum::Null], &arena).unwrap();
        assert_eq!(out, Datum::Int64(0));
    }

    #[mz_ore::test]
    #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `rust_psm_stack_pointer` on OS `linux`
    fn test_string_literals() {
        // CASE #0 WHEN 'a' THEN 1 WHEN 'b' THEN 2 ELSE 0 END
        // Verify that non-i64 literal types also work.
        fn lit_str(s: &str) -> MirScalarExpr {
            MirScalarExpr::literal_ok(Datum::String(s), ReprScalarType::String)
        }
        let expr = MirScalarExpr::column(0)
            .call_binary(lit_str("a"), Eq)
            .if_then_else(
                lit_i64(1),
                MirScalarExpr::column(0)
                    .call_binary(lit_str("b"), Eq)
                    .if_then_else(lit_i64(2), lit_i64(0)),
            );
        // Same driver as the i64 tests, but over a string-typed input column.
        let result = run_transform(wrap_in_map_over(
            expr,
            Datum::String("x"),
            ReprScalarType::String,
        ));
        assert_case_literal(&result, 2);

        let arena = mz_repr::RowArena::new();
        assert_eq!(
            result.eval(&[Datum::String("a")], &arena).unwrap(),
            Datum::Int64(1)
        );
        assert_eq!(
            result.eval(&[Datum::String("b")], &arena).unwrap(),
            Datum::Int64(2)
        );
        assert_eq!(
            result.eval(&[Datum::String("z")], &arena).unwrap(),
            Datum::Int64(0)
        );
    }
}
589}