mz_transform/
non_null_requirements.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Push non-null requirements toward sources.
11//!
12//! This analysis derives NonNull requirements on the arguments to predicates.
13//! These requirements exist because most functions with Null arguments are
14//! themselves Null, and a predicate that evaluates to Null will not pass.
15//!
16//! These requirements are not here introduced as constraints, but rather flow
17//! to sources of data and restrict any constant collections to those rows that
18//! satisfy the constraint. The main consequence is when Null values are added
19//! in support of outer-joins and subqueries, we can occasionally remove that
20//! branch when we observe that Null values would be subjected to predicates.
21//!
22//! This analysis relies on a careful understanding of `ScalarExpr` and the
23//! semantics of various functions, *some of which may be non-Null even with
24//! Null arguments*.
25use std::collections::{BTreeMap, BTreeSet};
26
27use itertools::{Either, Itertools, zip_eq};
28use mz_expr::{Id, JoinInputMapper, MirRelationExpr, MirScalarExpr, RECURSION_LIMIT};
29use mz_ore::assert_none;
30use mz_ore::stack::{CheckedRecursion, RecursionGuard};
31
32use crate::TransformCtx;
33
34/// Push non-null requirements toward sources.
35#[derive(Debug)]
36pub struct NonNullRequirements {
37    recursion_guard: RecursionGuard,
38}
39
40impl Default for NonNullRequirements {
41    fn default() -> NonNullRequirements {
42        NonNullRequirements {
43            recursion_guard: RecursionGuard::with_limit(RECURSION_LIMIT),
44        }
45    }
46}
47
48impl CheckedRecursion for NonNullRequirements {
49    fn recursion_guard(&self) -> &RecursionGuard {
50        &self.recursion_guard
51    }
52}
53
54impl crate::Transform for NonNullRequirements {
55    fn name(&self) -> &'static str {
56        "NonNullRequirements"
57    }
58
59    #[mz_ore::instrument(
60        target = "optimizer",
61        level = "debug",
62        fields(path.segment = "non_null_requirements")
63    )]
64    fn actually_perform_transform(
65        &self,
66        relation: &mut MirRelationExpr,
67        _: &mut TransformCtx,
68    ) -> Result<(), crate::TransformError> {
69        let result = self.action(relation, BTreeSet::new(), &mut BTreeMap::new());
70        mz_repr::explain::trace_plan(&*relation);
71        result
72    }
73}
74
75impl NonNullRequirements {
76    /// Push non-null requirements toward sources.
77    ///
78    /// The action computes and pushes `columns` in a top-down manner and
79    /// simplifies the associated tree. The `columns` value denotes a set of
80    /// output columns that entail the associated `relation` will evaluate to
81    /// the constant empty collection if any column is null. This information is
82    /// used to simplify and prune sub-trees in the `Constant` and `Map` cases.
83    pub fn action(
84        &self,
85        relation: &mut MirRelationExpr,
86        mut columns: BTreeSet<usize>,
87        gets: &mut BTreeMap<Id, Vec<BTreeSet<usize>>>,
88    ) -> Result<(), crate::TransformError> {
89        self.checked_recur(|_| {
90            match relation {
91                MirRelationExpr::Constant { rows, .. } => {
92                    if let Ok(rows) = rows {
93                        let mut datum_vec = mz_repr::DatumVec::new();
94                        rows.retain(|(row, _)| {
95                            let datums = datum_vec.borrow_with(row);
96                            columns.iter().all(|c| datums[*c] != mz_repr::Datum::Null)
97                        })
98                    }
99                    Ok(())
100                }
101                MirRelationExpr::Get { id, .. } => {
102                    gets.entry(*id).or_insert_with(Vec::new).push(columns);
103                    Ok(())
104                }
105                MirRelationExpr::Let { id, value, body } => {
106                    // Let harvests any non-null requirements from its body,
107                    // and acts on the intersection of the requirements for
108                    // each corresponding Get, pushing them at its value.
109                    let id = Id::Local(*id);
110                    let prior = gets.insert(id, Vec::new());
111                    self.action(body, columns, gets)?;
112                    let columns = intersect_all(&gets.remove(&id).unwrap());
113                    if let Some(prior) = prior {
114                        gets.insert(id, prior);
115                    }
116                    self.action(value, columns, gets)?;
117                    Ok(())
118                }
119                MirRelationExpr::LetRec {
120                    ids,
121                    values,
122                    body,
123                    limits: _,
124                } => {
125                    // Determine the recursive IDs in this LetRec binding.
126                    let rec_ids = MirRelationExpr::recursive_ids(ids, values);
127
128                    // Seed the gets map with an empty vector for each ID.
129                    for id in ids.iter() {
130                        let prior = gets.insert(Id::Local(*id), vec![]);
131                        assert_none!(prior);
132                    }
133
134                    // Descend into the body with the supplied columns.
135                    self.action(body, columns, gets)?;
136
137                    // Descend into the values in reverse order.
138                    for (id, value) in zip_eq(ids.iter().rev(), values.iter_mut().rev()) {
139                        // Compute the required non-null columns for this value.
140                        let columns = if rec_ids.contains(id) {
141                            // For recursive IDs: conservatively don't assume
142                            // any non-null column requests. TODO: This can be
143                            // improved using a fixpoint-based approximation.
144                            BTreeSet::new()
145                        } else {
146                            // For non-recursive IDs: request the intersection
147                            // of all `columns` sets in the gets vector.
148                            intersect_all(gets.get(&Id::Local(*id)).unwrap())
149                        };
150                        self.action(value, columns, gets)?;
151                    }
152
153                    // Remove the entries for all ids.
154                    for id in ids.iter() {
155                        gets.remove(&Id::Local(*id));
156                    }
157
158                    Ok(())
159                }
160                MirRelationExpr::Project { input, outputs } => self.action(
161                    input,
162                    columns.into_iter().map(|c| outputs[c]).collect(),
163                    gets,
164                ),
165                MirRelationExpr::Map { input, scalars } => {
166                    let input_arity = input.arity();
167                    if columns
168                        .iter()
169                        .any(|c| *c >= input_arity && scalars[*c - input_arity].is_literal_null())
170                    {
171                        // A null value was introduced in a marked column;
172                        // the entire expression can be zeroed out.
173                        relation.take_safely(None);
174                        Ok(())
175                    } else {
176                        // For each column, if it must be non-null, extract the expression's
177                        // non-null requirements and include them too. We go in reverse order
178                        // to ensure we squeegee down all requirements even for references to
179                        // other columns produced in this operator.
180                        for column in (input_arity..(input_arity + scalars.len())).rev() {
181                            if columns.contains(&column) {
182                                scalars[column - input_arity].non_null_requirements(&mut columns);
183                            }
184                            columns.remove(&column);
185                        }
186                        self.action(input, columns, gets)
187                    }
188                }
189                MirRelationExpr::FlatMap { input, func, exprs } => {
190                    // Columns whose number is smaller than arity refer to
191                    // columns of `input`. Columns whose number is
192                    // greater than or equal to the arity refer to columns created
193                    // by the FlatMap. The latter group of columns cannot be
194                    // propagated down.
195                    let input_arity = input.arity();
196                    columns.retain(|c| *c < input_arity);
197
198                    if func.empty_on_null_input() {
199                        // we can safely disregard rows where any of the exprs
200                        // evaluate to null
201                        for expr in exprs {
202                            expr.non_null_requirements(&mut columns);
203                        }
204                    }
205
206                    // TODO: if `!func.empty_on_null_input()` and there are members
207                    // of `columns` that refer to columns created by the FlatMap, we
208                    // may be able to propagate some non-null requirements based on
209                    // which columns created by the FlatMap cannot be null. However,
210                    // we have been too lazy to handle this so far.
211
212                    self.action(input, columns, gets)
213                }
214                MirRelationExpr::Filter { input, predicates } => {
215                    for predicate in predicates {
216                        predicate.non_null_requirements(&mut columns);
217                        // TODO: Not(IsNull) should add a constraint!
218                    }
219                    self.action(input, columns, gets)
220                }
221                MirRelationExpr::Join {
222                    inputs,
223                    equivalences,
224                    ..
225                } => {
226                    let input_types = inputs.iter().map(|i| i.typ()).collect::<Vec<_>>();
227
228                    let input_mapper = JoinInputMapper::new_from_input_types(&input_types);
229
230                    let mut new_columns = input_mapper.split_column_set_by_input(columns.iter());
231
232                    // `variable` smears constraints around.
233                    // Also, any non-nullable columns impose constraints on their equivalence class.
234                    for equivalence in equivalences {
235                        let exists_constraint = equivalence.iter().any(|expr| {
236                            if let MirScalarExpr::Column(c) = expr {
237                                let (col, rel) = input_mapper.map_column_to_local(*c);
238                                new_columns[rel].contains(&col)
239                                    || !input_types[rel].column_types[col].nullable
240                            } else {
241                                false
242                            }
243                        });
244
245                        if exists_constraint {
246                            for expr in equivalence.iter() {
247                                if let MirScalarExpr::Column(c) = expr {
248                                    let (col, rel) = input_mapper.map_column_to_local(*c);
249                                    new_columns[rel].insert(col);
250                                }
251                            }
252                        }
253                    }
254
255                    for (input, columns) in inputs.iter_mut().zip(new_columns) {
256                        self.action(input, columns, gets)?;
257                    }
258                    Ok(())
259                }
260                MirRelationExpr::Reduce {
261                    input,
262                    group_key,
263                    aggregates,
264                    monotonic: _,
265                    expected_group_size: _,
266                } => {
267                    let mut new_columns = BTreeSet::new();
268                    let (group_key_columns, aggr_columns): (Vec<usize>, Vec<usize>) =
269                        columns.iter().partition(|c| **c < group_key.len());
270                    for column in group_key_columns {
271                        group_key[column].non_null_requirements(&mut new_columns);
272                    }
273
274                    if !aggr_columns.is_empty() {
275                        let (
276                            mut inferred_nonnull_constraints,
277                            mut ignored_nulls_by_remaining_aggregates,
278                        ): (Vec<BTreeSet<usize>>, Vec<BTreeSet<usize>>) =
279                            aggregates.iter().enumerate().partition_map(|(pos, aggr)| {
280                                let mut ignores_nulls_on_columns = BTreeSet::new();
281                                if let mz_repr::Datum::Null = aggr.func.identity_datum() {
282                                    aggr.expr
283                                        .non_null_requirements(&mut ignores_nulls_on_columns);
284                                }
285                                if aggr.func.propagates_nonnull_constraint()
286                                    && aggr_columns.contains(&(group_key.len() + pos))
287                                {
288                                    Either::Left(ignores_nulls_on_columns)
289                                } else {
290                                    Either::Right(ignores_nulls_on_columns)
291                                }
292                            });
293
294                        // Compute the intersection of all pushable non constraints inferred from
295                        // the non-null constraints on aggregate columns and the nulls ignored by
296                        // the remaining aggregates. Example:
297                        // - SUM(#0 + #2), MAX(#0 + #1), non-null requirements on both aggs => implies !isnull(#0)
298                        //  We don't want to push down a !isnull(#2) because deleting a row like (1,1, null) would
299                        //  make the MAX wrong.
300                        // - SUM(#0 + #2), MAX(#0 + #1), non-null requirements only on the MAX => implies !isnull(#0).
301                        let mut pushable_nonnull_constraints: Option<BTreeSet<usize>> = None;
302                        if !inferred_nonnull_constraints.is_empty() {
303                            for column_set in inferred_nonnull_constraints
304                                .drain(..)
305                                .chain(ignored_nulls_by_remaining_aggregates.drain(..))
306                            {
307                                if let Some(previous) = pushable_nonnull_constraints {
308                                    pushable_nonnull_constraints =
309                                        Some(column_set.intersection(&previous).cloned().collect());
310                                } else {
311                                    pushable_nonnull_constraints = Some(column_set);
312                                }
313                            }
314                        }
315
316                        if let Some(pushable_nonnull_constraints) = pushable_nonnull_constraints {
317                            new_columns.extend(pushable_nonnull_constraints);
318                        }
319                    }
320
321                    self.action(input, new_columns, gets)
322                }
323                MirRelationExpr::TopK {
324                    input, group_key, ..
325                } => {
326                    // We can only allow rows to be discarded if their key columns are
327                    // NULL, as discarding rows based on other columns can change the
328                    // result set, based on how NULL is ordered.
329                    columns.retain(|c| group_key.contains(c));
330                    // TODO(mcsherry): bind NULL ordering and apply the transformation
331                    // to all columns if the correct ASC/DESC ordering is observed
332                    // (with some care about orderings on multiple columns).
333                    self.action(input, columns, gets)
334                }
335                MirRelationExpr::Negate { input } => self.action(input, columns, gets),
336                MirRelationExpr::Threshold { input } => self.action(input, columns, gets),
337                MirRelationExpr::Union { base, inputs } => {
338                    self.action(base, columns.clone(), gets)?;
339                    for input in inputs {
340                        self.action(input, columns.clone(), gets)?;
341                    }
342                    Ok(())
343                }
344                MirRelationExpr::ArrangeBy { input, .. } => self.action(input, columns, gets),
345            }
346        })
347    }
348}
349
/// Returns the columns that appear in every set of `columns_vec`.
///
/// An empty `columns_vec` yields the empty set, i.e. no requirement at all.
///
/// The parameter is a slice rather than a `&Vec`, so any contiguous sequence
/// of sets can be passed; existing `&Vec<_>` arguments still coerce.
fn intersect_all(columns_vec: &[BTreeSet<usize>]) -> BTreeSet<usize> {
    match columns_vec.split_first() {
        None => BTreeSet::new(),
        // Start from the first set and keep only what every later set contains.
        Some((first, rest)) => rest.iter().fold(first.clone(), |mut intersection, columns| {
            intersection.retain(|col| columns.contains(col));
            intersection
        }),
    }
}