Skip to main content

mz_expr/scalar/
optimizable.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! A trait for scalar expressions that can be optimized inside a `MapFilterProject`.
11//!
12//! This trait is implemented by both `MirScalarExpr` and `LirScalarExpr`,
13//! allowing `MapFilterProject` to be parameterized over either.
14
15use std::fmt::Debug;
16use std::hash::Hash;
17
18use mz_ore::stack::RecursionLimitError;
19use serde::Serialize;
20
21use crate::scalar::columns::Columns;
22use crate::scalar::func::{BinaryFunc, UnaryFunc, VariadicFunc};
23use crate::visit::{Visit, VisitChildren};
24use crate::{MirScalarExpr, func};
25
26/// A scalar expression type that can be optimized inside a `MapFilterProject`.
27///
28/// Implemented by `MirScalarExpr` and `LirScalarExpr`.
29pub trait OptimizableExpr:
30    Columns + VisitChildren<Self> + Clone + Eq + Ord + Hash + Debug + Sized + Serialize
31{
32    /// True if this expression is a literal.
33    fn is_literal(&self) -> bool;
34
35    /// True if this expression is a literal error.
36    fn is_literal_err(&self) -> bool;
37
38    /// True if this expression contains a temporal reference (`mz_now()`).
39    fn contains_temporal(&self) -> bool;
40
41    /// Count of AST nodes in the expression tree.
42    fn size(&self) -> usize;
43
44    /// For memoization: which children should be eagerly memoized?
45    ///
46    /// Returns `None` to visit all children (the common case).
47    /// Returns `Some(children)` for selective descent — e.g., for `If`, only the
48    /// condition should be eagerly memoized (branches may not be taken).
49    fn eager_children(&mut self) -> Option<Vec<&mut Self>>;
50
51    /// If `predicate` is `col = expr` (or `expr = col`) where `col` is a column
52    /// with index < `threshold`, return a clone of that column expression.
53    ///
54    /// Used by `optimize()` to detect equality-derived column aliases.
55    fn equality_column_alias(predicate: &Self, expr: &Self, threshold: usize) -> Option<Self>;
56
57    /// Extract temporal bounds from a list of temporal predicates.
58    ///
59    /// Returns `(lower_bounds, upper_bounds)` for use in `MfpPlan`.
60    fn extract_temporal_bounds(temporal: Vec<Self>) -> Result<(Vec<Self>, Vec<Self>), String>;
61
62    /// Visit in a pre-traversal. Defaults to the `Visit` implementation, but overridable.
63    fn visit_pre<F>(&self, f: &mut F) -> Result<(), RecursionLimitError>
64    where
65        F: FnMut(&Self),
66    {
67        Visit::visit_pre(self, f)
68    }
69}
70
71impl OptimizableExpr for MirScalarExpr {
72    fn is_literal(&self) -> bool {
73        self.is_literal()
74    }
75
76    fn is_literal_err(&self) -> bool {
77        self.is_literal_err()
78    }
79
80    fn contains_temporal(&self) -> bool {
81        self.contains_temporal()
82    }
83
84    fn size(&self) -> usize {
85        self.size()
86    }
87
88    fn eager_children(&mut self) -> Option<Vec<&mut Self>> {
89        // Do not eagerly memoize `if` branches that might not be taken.
90        if let MirScalarExpr::If { cond, .. } = self {
91            return Some(vec![cond]);
92        }
93
94        // Do not eagerly memoize `COALESCE` expressions after the first,
95        // as they are only meant to be evaluated if the preceding expressions
96        // evaluate to NULL.
97        if let MirScalarExpr::CallVariadic {
98            func: VariadicFunc::Coalesce(_),
99            exprs,
100        } = self
101        {
102            return Some(exprs.iter_mut().take(1).collect());
103        }
104
105        // Do not deconstruct temporal filters, because `MfpPlan::create_from` expects
106        // those to be in a specific form. However, attend to the expression on the
107        // opposite side of mz_now().
108        if let Ok((_func, other_side)) = self.as_mut_temporal_filter() {
109            return Some(vec![other_side]);
110        }
111
112        None
113    }
114
115    fn equality_column_alias(predicate: &Self, expr: &Self, threshold: usize) -> Option<Self> {
116        if let MirScalarExpr::CallBinary {
117            func: BinaryFunc::Eq(_),
118            expr1,
119            expr2,
120        } = predicate
121        {
122            if let MirScalarExpr::Column(c, name) = &**expr1 {
123                if *c < threshold && &**expr2 == expr {
124                    return Some(MirScalarExpr::Column(*c, name.clone()));
125                }
126            }
127            if let MirScalarExpr::Column(c, name) = &**expr2 {
128                if *c < threshold && &**expr1 == expr {
129                    return Some(MirScalarExpr::Column(*c, name.clone()));
130                }
131            }
132        }
133        None
134    }
135
136    fn extract_temporal_bounds(temporal: Vec<Self>) -> Result<(Vec<Self>, Vec<Self>), String> {
137        let mut lower_bounds = Vec::new();
138        let mut upper_bounds = Vec::new();
139
140        for mut predicate in temporal.into_iter() {
141            let (func, expr2) = predicate.as_mut_temporal_filter()?;
142            let expr2 = expr2.clone();
143
144            match func {
145                BinaryFunc::Eq(_) => {
146                    lower_bounds.push(expr2.clone());
147                    upper_bounds
148                        .push(expr2.call_unary(UnaryFunc::StepMzTimestamp(func::StepMzTimestamp)));
149                }
150                BinaryFunc::Lt(_) => {
151                    upper_bounds.push(expr2.clone());
152                }
153                BinaryFunc::Lte(_) => {
154                    upper_bounds
155                        .push(expr2.call_unary(UnaryFunc::StepMzTimestamp(func::StepMzTimestamp)));
156                }
157                BinaryFunc::Gt(_) => {
158                    lower_bounds
159                        .push(expr2.call_unary(UnaryFunc::StepMzTimestamp(func::StepMzTimestamp)));
160                }
161                BinaryFunc::Gte(_) => {
162                    lower_bounds.push(expr2.clone());
163                }
164                _ => {
165                    return Err(format!("Unsupported binary temporal operation: {:?}", func));
166                }
167            }
168        }
169
170        Ok((lower_bounds, upper_bounds))
171    }
172
173    fn visit_pre<F>(&self, f: &mut F) -> Result<(), RecursionLimitError>
174    where
175        F: FnMut(&Self),
176    {
177        self.visit_pre(f);
178        Ok(())
179    }
180}