Skip to main content

mz_expr/scalar/
optimizable.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! A trait for scalar expressions that can be optimized inside a `MapFilterProject`.
11//!
12//! This trait is implemented by both `MirScalarExpr` and `LirScalarExpr`,
13//! allowing `MapFilterProject` to be parameterized over either.
14
15use std::fmt::Debug;
16use std::hash::Hash;
17
18use serde::Serialize;
19
20use crate::scalar::columns::Columns;
21use crate::scalar::func::{BinaryFunc, UnaryFunc, VariadicFunc};
22use crate::visit::VisitChildren;
23use crate::{MirScalarExpr, func};
24
25/// A scalar expression type that can be optimized inside a `MapFilterProject`.
26///
27/// Implemented by `MirScalarExpr` and `LirScalarExpr`.
28pub trait OptimizableExpr:
29    Columns + VisitChildren<Self> + Clone + Eq + Ord + Hash + Debug + Sized + Serialize
30{
31    /// True if this expression is a literal.
32    fn is_literal(&self) -> bool;
33
34    /// True if this expression is a literal error.
35    fn is_literal_err(&self) -> bool;
36
37    /// True if this expression contains a temporal reference (`mz_now()`).
38    fn contains_temporal(&self) -> bool;
39
40    /// Count of AST nodes in the expression tree.
41    fn size(&self) -> usize;
42
43    /// For memoization: which children should be eagerly memoized?
44    ///
45    /// Returns `None` to visit all children (the common case).
46    /// Returns `Some(children)` for selective descent — e.g., for `If`, only the
47    /// condition should be eagerly memoized (branches may not be taken).
48    fn eager_children(&mut self) -> Option<Vec<&mut Self>>;
49
50    /// If `predicate` is `col = expr` (or `expr = col`) where `col` is a column
51    /// with index < `threshold`, return a clone of that column expression.
52    ///
53    /// Used by `optimize()` to detect equality-derived column aliases.
54    fn equality_column_alias(predicate: &Self, expr: &Self, threshold: usize) -> Option<Self>;
55
56    /// Extract temporal bounds from a list of temporal predicates.
57    ///
58    /// Returns `(lower_bounds, upper_bounds)` for use in `MfpPlan`.
59    fn extract_temporal_bounds(temporal: Vec<Self>) -> Result<(Vec<Self>, Vec<Self>), String>;
60}
61
62impl OptimizableExpr for MirScalarExpr {
63    fn is_literal(&self) -> bool {
64        self.is_literal()
65    }
66
67    fn is_literal_err(&self) -> bool {
68        self.is_literal_err()
69    }
70
71    fn contains_temporal(&self) -> bool {
72        self.contains_temporal()
73    }
74
75    fn size(&self) -> usize {
76        self.size()
77    }
78
79    fn eager_children(&mut self) -> Option<Vec<&mut Self>> {
80        // Do not eagerly memoize `if` branches that might not be taken.
81        if let MirScalarExpr::If { cond, .. } = self {
82            return Some(vec![cond]);
83        }
84
85        // Do not eagerly memoize `COALESCE` expressions after the first,
86        // as they are only meant to be evaluated if the preceding expressions
87        // evaluate to NULL.
88        if let MirScalarExpr::CallVariadic {
89            func: VariadicFunc::Coalesce(_),
90            exprs,
91        } = self
92        {
93            return Some(exprs.iter_mut().take(1).collect());
94        }
95
96        // Do not deconstruct temporal filters, because `MfpPlan::create_from` expects
97        // those to be in a specific form. However, attend to the expression on the
98        // opposite side of mz_now().
99        if let Ok((_func, other_side)) = self.as_mut_temporal_filter() {
100            return Some(vec![other_side]);
101        }
102
103        None
104    }
105
106    fn equality_column_alias(predicate: &Self, expr: &Self, threshold: usize) -> Option<Self> {
107        if let MirScalarExpr::CallBinary {
108            func: BinaryFunc::Eq(_),
109            expr1,
110            expr2,
111        } = predicate
112        {
113            if let MirScalarExpr::Column(c, name) = &**expr1 {
114                if *c < threshold && &**expr2 == expr {
115                    return Some(MirScalarExpr::Column(*c, name.clone()));
116                }
117            }
118            if let MirScalarExpr::Column(c, name) = &**expr2 {
119                if *c < threshold && &**expr1 == expr {
120                    return Some(MirScalarExpr::Column(*c, name.clone()));
121                }
122            }
123        }
124        None
125    }
126
127    fn extract_temporal_bounds(temporal: Vec<Self>) -> Result<(Vec<Self>, Vec<Self>), String> {
128        let mut lower_bounds = Vec::new();
129        let mut upper_bounds = Vec::new();
130
131        for mut predicate in temporal.into_iter() {
132            let (func, expr2) = predicate.as_mut_temporal_filter()?;
133            let expr2 = expr2.clone();
134
135            match func {
136                BinaryFunc::Eq(_) => {
137                    lower_bounds.push(expr2.clone());
138                    upper_bounds
139                        .push(expr2.call_unary(UnaryFunc::StepMzTimestamp(func::StepMzTimestamp)));
140                }
141                BinaryFunc::Lt(_) => {
142                    upper_bounds.push(expr2.clone());
143                }
144                BinaryFunc::Lte(_) => {
145                    upper_bounds
146                        .push(expr2.call_unary(UnaryFunc::StepMzTimestamp(func::StepMzTimestamp)));
147                }
148                BinaryFunc::Gt(_) => {
149                    lower_bounds
150                        .push(expr2.call_unary(UnaryFunc::StepMzTimestamp(func::StepMzTimestamp)));
151                }
152                BinaryFunc::Gte(_) => {
153                    lower_bounds.push(expr2.clone());
154                }
155                _ => {
156                    return Err(format!("Unsupported binary temporal operation: {:?}", func));
157                }
158            }
159        }
160
161        Ok((lower_bounds, upper_bounds))
162    }
163}