mz_transform/non_null_requirements.rs
// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

//! Push non-null requirements toward sources.
//!
//! This analysis derives NonNull requirements on the arguments to predicates.
//! These requirements exist because most functions with Null arguments are
//! themselves Null, and a predicate that evaluates to Null will not pass.
//!
//! These requirements are not introduced here as constraints, but rather flow
//! toward sources of data, restricting any constant collections to those rows
//! that satisfy the constraint. The main consequence is that when Null values
//! are added in support of outer joins and subqueries, we can occasionally
//! remove that branch once we observe that the Null values would be subjected
//! to predicates.
//!
//! This analysis relies on a careful understanding of `ScalarExpr` and the
//! semantics of various functions, *some of which may be non-Null even with
//! Null arguments*.
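//!
//! For example, in a plan like `Filter (#1 > 7) (Constant [(1, null), (2, 3)])`
//! the predicate `#1 > 7` can only be true when column `#1` is non-null, so the
//! constant input can be restricted to the single row `(2, 3)`.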
use std::collections::{BTreeMap, BTreeSet};

use itertools::{Either, Itertools, zip_eq};
use mz_expr::{Id, JoinInputMapper, MirRelationExpr, MirScalarExpr, RECURSION_LIMIT};
use mz_ore::assert_none;
use mz_ore::stack::{CheckedRecursion, RecursionGuard};

use crate::TransformCtx;

/// Push non-null requirements toward sources.
#[derive(Debug)]
pub struct NonNullRequirements {
    recursion_guard: RecursionGuard,
}

impl Default for NonNullRequirements {
    fn default() -> NonNullRequirements {
        NonNullRequirements {
            recursion_guard: RecursionGuard::with_limit(RECURSION_LIMIT),
        }
    }
}

impl CheckedRecursion for NonNullRequirements {
    fn recursion_guard(&self) -> &RecursionGuard {
        &self.recursion_guard
    }
}

impl crate::Transform for NonNullRequirements {
    fn name(&self) -> &'static str {
        "NonNullRequirements"
    }

    #[mz_ore::instrument(
        target = "optimizer",
        level = "debug",
        fields(path.segment = "non_null_requirements")
    )]
    fn actually_perform_transform(
        &self,
        relation: &mut MirRelationExpr,
        _: &mut TransformCtx,
    ) -> Result<(), crate::TransformError> {
        let result = self.action(relation, BTreeSet::new(), &mut BTreeMap::new());
        mz_repr::explain::trace_plan(&*relation);
        result
    }
}

impl NonNullRequirements {
    /// Push non-null requirements toward sources.
    ///
    /// The action computes and pushes `columns` in a top-down manner and
    /// simplifies the associated tree. The `columns` value denotes a set of
    /// output columns with the property that any row of `relation` containing
    /// a null value in one of them cannot contribute to the result; in
    /// particular, if such a column is always null, the entire `relation` can
    /// be replaced by the constant empty collection. This information is used
    /// to simplify and prune sub-trees in the `Constant` and `Map` cases.
    pub fn action(
        &self,
        relation: &mut MirRelationExpr,
        mut columns: BTreeSet<usize>,
        gets: &mut BTreeMap<Id, Vec<BTreeSet<usize>>>,
    ) -> Result<(), crate::TransformError> {
        self.checked_recur(|_| {
            match relation {
                MirRelationExpr::Constant { rows, .. } => {
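                    // Rows with a null value in any of the required columns
                    // cannot contribute to the result, so drop them from the
                    // constant.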
                    if let Ok(rows) = rows {
                        let mut datum_vec = mz_repr::DatumVec::new();
                        rows.retain(|(row, _)| {
                            let datums = datum_vec.borrow_with(row);
                            columns.iter().all(|c| datums[*c] != mz_repr::Datum::Null)
                        })
                    }
                    Ok(())
                }
                MirRelationExpr::Get { id, .. } => {
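                    // Record the requirements under this identifier; if it is
                    // bound by an enclosing `Let` or `LetRec`, they are
                    // harvested and combined there.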
                    gets.entry(*id).or_insert_with(Vec::new).push(columns);
                    Ok(())
                }
                MirRelationExpr::Let { id, value, body } => {
                    // Let harvests any non-null requirements from its body,
                    // and acts on the intersection of the requirements for
                    // each corresponding Get, pushing them at its value.
                    let id = Id::Local(*id);
                    let prior = gets.insert(id, Vec::new());
                    self.action(body, columns, gets)?;
                    let columns = intersect_all(&gets.remove(&id).unwrap());
                    if let Some(prior) = prior {
                        gets.insert(id, prior);
                    }
                    self.action(value, columns, gets)?;
                    Ok(())
                }
                MirRelationExpr::LetRec {
                    ids,
                    values,
                    body,
                    limits: _,
                } => {
                    // Determine the recursive IDs in this LetRec binding.
                    let rec_ids = MirRelationExpr::recursive_ids(ids, values);

                    // Seed the gets map with an empty vector for each ID.
                    for id in ids.iter() {
                        let prior = gets.insert(Id::Local(*id), vec![]);
                        assert_none!(prior);
                    }

                    // Descend into the body with the supplied columns.
                    self.action(body, columns, gets)?;

                    // Descend into the values in reverse order.
                    for (id, value) in zip_eq(ids.iter().rev(), values.iter_mut().rev()) {
                        // Compute the required non-null columns for this value.
                        let columns = if rec_ids.contains(id) {
                            // For recursive IDs: conservatively don't assume
                            // any non-null column requests. TODO: This can be
                            // improved using a fixpoint-based approximation.
                            BTreeSet::new()
                        } else {
                            // For non-recursive IDs: request the intersection
                            // of all `columns` sets in the gets vector.
                            intersect_all(gets.get(&Id::Local(*id)).unwrap())
                        };
                        self.action(value, columns, gets)?;
                    }

                    // Remove the entries for all ids.
                    for id in ids.iter() {
                        gets.remove(&Id::Local(*id));
                    }

                    Ok(())
                }
                MirRelationExpr::Project { input, outputs } => self.action(
                    input,
                    columns.into_iter().map(|c| outputs[c]).collect(),
                    gets,
                ),
                MirRelationExpr::Map { input, scalars } => {
                    let input_arity = input.arity();
                    if columns
                        .iter()
                        .any(|c| *c >= input_arity && scalars[*c - input_arity].is_literal_null())
                    {
                        // A null value was introduced in a marked column;
                        // the entire expression can be zeroed out.
                        relation.take_safely(None);
                        Ok(())
                    } else {
                        // For each column, if it must be non-null, extract the expression's
                        // non-null requirements and include them too. We go in reverse order
                        // to ensure we squeegee down all requirements even for references to
                        // other columns produced in this operator.
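                        // For example, if the last mapped column references an
                        // earlier mapped column and must be non-null, visiting it
                        // first records a requirement on the earlier column, which
                        // is then translated into input requirements when that
                        // column is visited in turn.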
                        for column in (input_arity..(input_arity + scalars.len())).rev() {
                            if columns.contains(&column) {
                                scalars[column - input_arity].non_null_requirements(&mut columns);
                            }
                            columns.remove(&column);
                        }
                        self.action(input, columns, gets)
                    }
                }
                MirRelationExpr::FlatMap { input, func, exprs } => {
                    // Columns whose number is smaller than arity refer to
                    // columns of `input`. Columns whose number is
                    // greater than or equal to the arity refer to columns created
                    // by the FlatMap. The latter group of columns cannot be
                    // propagated down.
                    let input_arity = input.arity();
                    columns.retain(|c| *c < input_arity);

                    if func.empty_on_null_input() {
                        // we can safely disregard rows where any of the exprs
                        // evaluate to null
                        for expr in exprs {
                            expr.non_null_requirements(&mut columns);
                        }
                    }

                    // TODO: if `!func.empty_on_null_input()` and there are members
                    // of `columns` that refer to columns created by the FlatMap, we
                    // may be able to propagate some non-null requirements based on
                    // which columns created by the FlatMap cannot be null. However,
                    // we have been too lazy to handle this so far.

                    self.action(input, columns, gets)
                }
                MirRelationExpr::Filter { input, predicates } => {
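                    // A predicate that evaluates to Null does not retain its row,
                    // so e.g. `#0 = #1` requires both #0 and #1 to be non-null.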
                    for predicate in predicates {
                        predicate.non_null_requirements(&mut columns);
                        // TODO: Not(IsNull) should add a constraint!
                    }
                    self.action(input, columns, gets)
                }
                MirRelationExpr::Join {
                    inputs,
                    equivalences,
                    ..
                } => {
                    let input_types = inputs.iter().map(|i| i.typ()).collect::<Vec<_>>();

                    let input_mapper = JoinInputMapper::new_from_input_types(&input_types);

                    let mut new_columns = input_mapper.split_column_set_by_input(columns.iter());

                    // Equivalence classes smear constraints around: a non-null
                    // requirement on any member of a class extends to every
                    // column in the class. Also, any non-nullable columns
                    // impose constraints on their equivalence class.
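                    // For example, if columns #0 and #5 belong to the same
                    // equivalence class and #0 must be non-null (or has a
                    // non-nullable type), then #5 must also be non-null in any
                    // row that can satisfy the join.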
                    for equivalence in equivalences {
                        let exists_constraint = equivalence.iter().any(|expr| {
                            if let MirScalarExpr::Column(c) = expr {
                                let (col, rel) = input_mapper.map_column_to_local(*c);
                                new_columns[rel].contains(&col)
                                    || !input_types[rel].column_types[col].nullable
                            } else {
                                false
                            }
                        });

                        if exists_constraint {
                            for expr in equivalence.iter() {
                                if let MirScalarExpr::Column(c) = expr {
                                    let (col, rel) = input_mapper.map_column_to_local(*c);
                                    new_columns[rel].insert(col);
                                }
                            }
                        }
                    }

                    for (input, columns) in inputs.iter_mut().zip(new_columns) {
                        self.action(input, columns, gets)?;
                    }
                    Ok(())
                }
                MirRelationExpr::Reduce {
                    input,
                    group_key,
                    aggregates,
                    monotonic: _,
                    expected_group_size: _,
                } => {
                    let mut new_columns = BTreeSet::new();
                    let (group_key_columns, aggr_columns): (Vec<usize>, Vec<usize>) =
                        columns.iter().partition(|c| **c < group_key.len());
                    for column in group_key_columns {
                        group_key[column].non_null_requirements(&mut new_columns);
                    }

                    if !aggr_columns.is_empty() {
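                        // Partition the aggregates: those whose output column is
                        // required to be non-null and which propagate that
                        // requirement to their inputs go to the left; for the
                        // rest we only record which input columns' nulls they
                        // ignore, since we must not drop rows that could change
                        // their results.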
                        let (
                            mut inferred_nonnull_constraints,
                            mut ignored_nulls_by_remaining_aggregates,
                        ): (Vec<BTreeSet<usize>>, Vec<BTreeSet<usize>>) =
                            aggregates.iter().enumerate().partition_map(|(pos, aggr)| {
                                let mut ignores_nulls_on_columns = BTreeSet::new();
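                                // An aggregate whose identity is Null treats Null
                                // input values like the identity, so rows where its
                                // expression evaluates to Null do not affect its
                                // result; record the columns whose nulls it ignores.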
                                if let mz_repr::Datum::Null = aggr.func.identity_datum() {
                                    aggr.expr
                                        .non_null_requirements(&mut ignores_nulls_on_columns);
                                }
                                if aggr.func.propagates_nonnull_constraint()
                                    && aggr_columns.contains(&(group_key.len() + pos))
                                {
                                    Either::Left(ignores_nulls_on_columns)
                                } else {
                                    Either::Right(ignores_nulls_on_columns)
                                }
                            });

                        // Compute the intersection of all pushable non-null constraints
                        // inferred from the non-null constraints on aggregate columns and
                        // the nulls ignored by the remaining aggregates. Example:
                        // - SUM(#0 + #2), MAX(#0 + #1), non-null requirements on both aggs
                        //   => !isnull(#0). We don't want to push down !isnull(#2), because
                        //   deleting a row like (1, 1, null) would make the MAX wrong.
                        // - SUM(#0 + #2), MAX(#0 + #1), non-null requirements only on the MAX
                        //   => !isnull(#0).
                        let mut pushable_nonnull_constraints: Option<BTreeSet<usize>> = None;
                        if !inferred_nonnull_constraints.is_empty() {
                            for column_set in inferred_nonnull_constraints
                                .drain(..)
                                .chain(ignored_nulls_by_remaining_aggregates.drain(..))
                            {
                                if let Some(previous) = pushable_nonnull_constraints {
                                    pushable_nonnull_constraints =
                                        Some(column_set.intersection(&previous).cloned().collect());
                                } else {
                                    pushable_nonnull_constraints = Some(column_set);
                                }
                            }
                        }

                        if let Some(pushable_nonnull_constraints) = pushable_nonnull_constraints {
                            new_columns.extend(pushable_nonnull_constraints);
                        }
                    }

                    self.action(input, new_columns, gets)
                }
                MirRelationExpr::TopK {
                    input, group_key, ..
                } => {
                    // We can only allow rows to be discarded if their key columns are
                    // NULL, as discarding rows based on other columns can change the
                    // result set, based on how NULL is ordered.
                    columns.retain(|c| group_key.contains(c));
                    // TODO(mcsherry): bind NULL ordering and apply the transformation
                    // to all columns if the correct ASC/DESC ordering is observed
                    // (with some care about orderings on multiple columns).
                    self.action(input, columns, gets)
                }
                MirRelationExpr::Negate { input } => self.action(input, columns, gets),
                MirRelationExpr::Threshold { input } => self.action(input, columns, gets),
                MirRelationExpr::Union { base, inputs } => {
                    self.action(base, columns.clone(), gets)?;
                    for input in inputs {
                        self.action(input, columns.clone(), gets)?;
                    }
                    Ok(())
                }
                MirRelationExpr::ArrangeBy { input, .. } => self.action(input, columns, gets),
            }
        })
    }
}

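/// Intersects a collection of column sets: a column appears in the result only
/// if it appears in every set. An empty collection yields the empty set.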
fn intersect_all(columns_vec: &Vec<BTreeSet<usize>>) -> BTreeSet<usize> {
    columns_vec.iter().skip(1).fold(
        columns_vec.first().cloned().unwrap_or_default(),
        |mut intersection, columns| {
            intersection.retain(|col| columns.contains(col));
            intersection
        },
    )
}