1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

//! Push non-null requirements toward sources.
//!
//! This analysis derives NonNull requirements on the arguments to predicates.
//! These requirements exist because most functions with Null arguments are
//! themselves Null, and a predicate that evaluates to Null will not pass.
//!
//! These requirements are not introduced here as new constraints; rather, they
//! flow to the sources of data and restrict any constant collections to those
//! rows that satisfy them. The main consequence is that when Null values are
//! added in support of outer joins and subqueries, we can occasionally remove
//! that branch when we observe that those Null values would be subjected to
//! predicates.
//!
//! This analysis relies on a careful understanding of `ScalarExpr` and the
//! semantics of various functions, *some of which may be non-Null even with
//! Null arguments*.
use std::collections::{BTreeMap, BTreeSet};

use itertools::{zip_eq, Either, Itertools};
use mz_expr::{Id, JoinInputMapper, MirRelationExpr, MirScalarExpr, RECURSION_LIMIT};
use mz_ore::assert_none;
use mz_ore::stack::{CheckedRecursion, RecursionGuard};

use crate::TransformCtx;

/// Optimizer transform that propagates non-null requirements from predicates
/// toward the sources of a `MirRelationExpr`, dropping constant rows and
/// pruning `Map` branches that could only produce nulls in required columns.
#[derive(Debug)]
pub struct NonNullRequirements {
    /// Bounds the depth of the recursive descent performed by `action`.
    recursion_guard: RecursionGuard,
}

impl Default for NonNullRequirements {
    /// Builds the transform with a recursion guard capped at the
    /// `RECURSION_LIMIT` constant exported by `mz_expr`.
    fn default() -> Self {
        let recursion_guard = RecursionGuard::with_limit(RECURSION_LIMIT);
        Self { recursion_guard }
    }
}

impl CheckedRecursion for NonNullRequirements {
    /// Hands out the guard consulted by `checked_recur`, which wraps every
    /// recursive step of [`NonNullRequirements::action`].
    fn recursion_guard(&self) -> &RecursionGuard {
        &self.recursion_guard
    }
}

impl crate::Transform for NonNullRequirements {
    fn name(&self) -> &'static str {
        "NonNullRequirements"
    }

    /// Runs [`NonNullRequirements::action`] from the root of `relation` with
    /// no initial requirements and an empty `Get` registry, then traces the
    /// (possibly rewritten) plan. The plan is traced whether or not the
    /// action succeeded, and the action's result is returned unchanged.
    #[mz_ore::instrument(
        target = "optimizer",
        level = "debug",
        fields(path.segment = "non_null_requirements")
    )]
    fn actually_perform_transform(
        &self,
        relation: &mut MirRelationExpr,
        _: &mut TransformCtx,
    ) -> Result<(), crate::TransformError> {
        // Nothing above the root constrains its columns yet.
        let root_requirements = BTreeSet::new();
        let mut get_requests = BTreeMap::new();
        let outcome = self.action(relation, root_requirements, &mut get_requests);
        mz_repr::explain::trace_plan(&*relation);
        outcome
    }
}

impl NonNullRequirements {
    /// Push non-null requirements toward sources.
    ///
    /// The action computes and pushes `columns` in a top-down manner and
    /// simplifies the associated tree. The `columns` value denotes a set of
    /// output columns that entail the associated `relation` will evaluate to
    /// the constant empty collection if any column is null. This information is
    /// used to simplify and prune sub-trees in the `Constant` and `Map` cases.
    ///
    /// * `relation` - the expression to analyze; it is rewritten in place.
    /// * `columns` - output columns of `relation` for which any row holding a
    ///   null will be discarded by the consumer.
    /// * `gets` - for each binding id, the `columns` sets requested by every
    ///   `Get` of that id visited so far. `Let` and `LetRec` install entries
    ///   for their ids, visit their body, and push the intersection of the
    ///   collected requests into the bound value.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `checked_recur` (presumably raised
    /// when the recursion limit is exceeded) or by a recursive call.
    pub fn action(
        &self,
        relation: &mut MirRelationExpr,
        mut columns: BTreeSet<usize>,
        gets: &mut BTreeMap<Id, Vec<BTreeSet<usize>>>,
    ) -> Result<(), crate::TransformError> {
        self.checked_recur(|_| {
            match relation {
                MirRelationExpr::Constant { rows, .. } => {
                    // Drop literal rows holding a null in any required column;
                    // the consumer would discard them anyway. If `rows` is an
                    // `Err`, it is left untouched.
                    if let Ok(rows) = rows {
                        let mut datum_vec = mz_repr::DatumVec::new();
                        rows.retain(|(row, _)| {
                            let datums = datum_vec.borrow_with(row);
                            columns.iter().all(|c| datums[*c] != mz_repr::Datum::Null)
                        })
                    }
                    Ok(())
                }
                MirRelationExpr::Get { id, .. } => {
                    // Record this use's request so that an enclosing `Let` or
                    // `LetRec` binding `id` can push it into the bound value.
                    gets.entry(*id).or_default().push(columns);
                    Ok(())
                }
                MirRelationExpr::Let { id, value, body } => {
                    // Let harvests any non-null requirements from its body,
                    // and acts on the intersection of the requirements for
                    // each corresponding Get, pushing them at its value.
                    let id = Id::Local(*id);
                    // Save (and later restore) any entry from an outer binding
                    // that uses the same id.
                    let prior = gets.insert(id, Vec::new());
                    self.action(body, columns, gets)?;
                    let columns = intersect_all(&gets.remove(&id).unwrap());
                    if let Some(prior) = prior {
                        gets.insert(id, prior);
                    }
                    self.action(value, columns, gets)?;
                    Ok(())
                }
                MirRelationExpr::LetRec {
                    ids,
                    values,
                    body,
                    limits: _,
                } => {
                    // Determine the recursive IDs in this LetRec binding.
                    let rec_ids = MirRelationExpr::recursive_ids(ids, values);

                    // Seed the gets map with an empty vector for each ID.
                    for id in ids.iter() {
                        let prior = gets.insert(Id::Local(*id), vec![]);
                        assert_none!(prior);
                    }

                    // Descend into the body with the supplied columns.
                    self.action(body, columns, gets)?;

                    // Descend into the values in reverse order, so that the
                    // requests made by later values are recorded before the
                    // earlier values they reference are visited.
                    for (id, value) in zip_eq(ids.iter().rev(), values.iter_mut().rev()) {
                        // Compute the required non-null columns for this value.
                        let columns = if rec_ids.contains(id) {
                            // For recursive IDs: conservatively don't assume
                            // any non-null column requests. TODO: This can be
                            // improved using a fixpoint-based approximation.
                            BTreeSet::new()
                        } else {
                            // For non-recursive IDs: request the intersection
                            // of all `columns` sets in the gets vector.
                            intersect_all(gets.get(&Id::Local(*id)).unwrap())
                        };
                        self.action(value, columns, gets)?;
                    }

                    // Remove the entries for all ids.
                    for id in ids.iter() {
                        gets.remove(&Id::Local(*id));
                    }

                    Ok(())
                }
                MirRelationExpr::Project { input, outputs } => self.action(
                    input,
                    // Translate output positions to the input columns they project.
                    columns.into_iter().map(|c| outputs[c]).collect(),
                    gets,
                ),
                MirRelationExpr::Map { input, scalars } => {
                    let input_arity = input.arity();
                    if columns
                        .iter()
                        .any(|c| *c >= input_arity && scalars[*c - input_arity].is_literal_null())
                    {
                        // A null value was introduced in a marked column;
                        // the entire expression can be zeroed out.
                        relation.take_safely();
                        Ok(())
                    } else {
                        // For each column, if it must be non-null, extract the expression's
                        // non-null requirements and include them too. We go in reverse order
                        // to ensure we squeegee down all requirements even for references to
                        // other columns produced in this operator.
                        for column in (input_arity..(input_arity + scalars.len())).rev() {
                            if columns.contains(&column) {
                                scalars[column - input_arity].non_null_requirements(&mut columns);
                            }
                            columns.remove(&column);
                        }
                        self.action(input, columns, gets)
                    }
                }
                MirRelationExpr::FlatMap { input, func, exprs } => {
                    // Columns whose number is smaller than arity refer to
                    // columns of `input`. Columns whose number is
                    // greater than or equal to the arity refer to columns created
                    // by the FlatMap. The latter group of columns cannot be
                    // propagated down.
                    let input_arity = input.arity();
                    columns.retain(|c| *c < input_arity);

                    if func.empty_on_null_input() {
                        // we can safely disregard rows where any of the exprs
                        // evaluate to null
                        for expr in exprs {
                            expr.non_null_requirements(&mut columns);
                        }
                    }

                    // TODO: if `!func.empty_on_null_input()` and there are members
                    // of `columns` that refer to columns created by the FlatMap, we
                    // may be able to propagate some non-null requirements based on
                    // which columns created by the FlatMap cannot be null. However,
                    // we have been too lazy to handle this so far.

                    self.action(input, columns, gets)
                }
                MirRelationExpr::Filter { input, predicates } => {
                    // Each predicate contributes the columns that would make
                    // it evaluate to null (and therefore not pass).
                    for predicate in predicates {
                        predicate.non_null_requirements(&mut columns);
                        // TODO: Not(IsNull) should add a constraint!
                    }
                    self.action(input, columns, gets)
                }
                MirRelationExpr::Join {
                    inputs,
                    equivalences,
                    ..
                } => {
                    let input_types = inputs.iter().map(|i| i.typ()).collect::<Vec<_>>();

                    let input_mapper = JoinInputMapper::new_from_input_types(&input_types);

                    // Split the required output columns into per-input sets
                    // of local column indices.
                    let mut new_columns = input_mapper.split_column_set_by_input(columns.iter());

                    // `variable` smears constraints around.
                    // Also, any non-nullable columns impose constraints on their equivalence class.
                    for equivalence in equivalences {
                        let exists_constraint = equivalence.iter().any(|expr| {
                            if let MirScalarExpr::Column(c) = expr {
                                let (col, rel) = input_mapper.map_column_to_local(*c);
                                new_columns[rel].contains(&col)
                                    || !input_types[rel].column_types[col].nullable
                            } else {
                                false
                            }
                        });

                        if exists_constraint {
                            for expr in equivalence.iter() {
                                if let MirScalarExpr::Column(c) = expr {
                                    let (col, rel) = input_mapper.map_column_to_local(*c);
                                    new_columns[rel].insert(col);
                                }
                            }
                        }
                    }

                    for (input, columns) in inputs.iter_mut().zip(new_columns) {
                        self.action(input, columns, gets)?;
                    }
                    Ok(())
                }
                MirRelationExpr::Reduce {
                    input,
                    group_key,
                    aggregates,
                    monotonic: _,
                    expected_group_size: _,
                } => {
                    // Requirements on the input, derived below from the
                    // required group-key and aggregate output columns.
                    let mut new_columns = BTreeSet::new();
                    let (group_key_columns, aggr_columns): (Vec<usize>, Vec<usize>) =
                        columns.iter().partition(|c| **c < group_key.len());
                    for column in group_key_columns {
                        group_key[column].non_null_requirements(&mut new_columns);
                    }

                    if !aggr_columns.is_empty() {
                        // For each aggregate, collect the input columns on which
                        // it ignores nulls. Required aggregates that propagate
                        // non-null constraints go Left; all others go Right.
                        let (inferred_nonnull_constraints, ignored_nulls_by_remaining_aggregates): (
                            Vec<BTreeSet<usize>>,
                            Vec<BTreeSet<usize>>,
                        ) = aggregates.iter().enumerate().partition_map(|(pos, aggr)| {
                            let mut ignores_nulls_on_columns = BTreeSet::new();
                            if let mz_repr::Datum::Null = aggr.func.identity_datum() {
                                aggr.expr
                                    .non_null_requirements(&mut ignores_nulls_on_columns);
                            }
                            if aggr.func.propagates_nonnull_constraint()
                                && aggr_columns.contains(&(group_key.len() + pos))
                            {
                                Either::Left(ignores_nulls_on_columns)
                            } else {
                                Either::Right(ignores_nulls_on_columns)
                            }
                        });

                        // Compute the intersection of all pushable non-null constraints inferred
                        // from the non-null constraints on aggregate columns and the nulls ignored
                        // by the remaining aggregates. Example:
                        // - SUM(#0 + #2), MAX(#0 + #1), non-null requirements on both aggs => implies !isnull(#0)
                        //  We don't want to push down a !isnull(#2) because deleting a row like (1,1, null) would
                        //  make the MAX wrong.
                        // - SUM(#0 + #2), MAX(#0 + #1), non-null requirements only on the MAX => implies !isnull(#0).
                        // Nothing is pushed unless at least one required aggregate
                        // propagates a non-null constraint.
                        let pushable_nonnull_constraints = if inferred_nonnull_constraints.is_empty()
                        {
                            None
                        } else {
                            inferred_nonnull_constraints
                                .into_iter()
                                .chain(ignored_nulls_by_remaining_aggregates)
                                .reduce(|mut common, column_set| {
                                    common.retain(|col| column_set.contains(col));
                                    common
                                })
                        };

                        if let Some(pushable_nonnull_constraints) = pushable_nonnull_constraints {
                            new_columns.extend(pushable_nonnull_constraints);
                        }
                    }

                    self.action(input, new_columns, gets)
                }
                MirRelationExpr::TopK {
                    input, group_key, ..
                } => {
                    // We can only allow rows to be discarded if their key columns are
                    // NULL, as discarding rows based on other columns can change the
                    // result set, based on how NULL is ordered.
                    columns.retain(|c| group_key.contains(c));
                    // TODO(mcsherry): bind NULL ordering and apply the transformation
                    // to all columns if the correct ASC/DESC ordering is observed
                    // (with some care about orderings on multiple columns).
                    self.action(input, columns, gets)
                }
                // These operators pass their input's columns through unchanged,
                // so the requirements apply to the input as-is.
                MirRelationExpr::Negate { input } => self.action(input, columns, gets),
                MirRelationExpr::Threshold { input } => self.action(input, columns, gets),
                MirRelationExpr::Union { base, inputs } => {
                    // Every branch of the union is subject to the same requirements.
                    self.action(base, columns.clone(), gets)?;
                    for input in inputs {
                        self.action(input, columns.clone(), gets)?;
                    }
                    Ok(())
                }
                MirRelationExpr::ArrangeBy { input, .. } => self.action(input, columns, gets),
            }
        })
    }
}

/// Returns the columns that appear in *every* set of `columns_vec`.
///
/// Combines the non-null requests made by all uses (`Get`s) of one binding:
/// a column may only be required non-null in the bound value if every use
/// requires it. An empty slice yields the empty set, so no requirement is
/// pushed in that case.
///
/// Accepts any slice; existing `&Vec<_>` arguments coerce to it.
fn intersect_all(columns_vec: &[BTreeSet<usize>]) -> BTreeSet<usize> {
    match columns_vec.split_first() {
        None => BTreeSet::new(),
        Some((first, rest)) => rest.iter().fold(first.clone(), |mut intersection, columns| {
            intersection.retain(|col| columns.contains(col));
            intersection
        }),
    }
}