mz_transform/
predicate_pushdown.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Pushes predicates down through other operators.
11//!
12//! This action generally improves the quality of the query, in that selective per-record
13//! filters reduce the volume of data before they arrive at more expensive operators.
14//!
15//!
16//! The one time when this action might not improve the quality of a query is
17//! if a filter gets pushed down on an arrangement because that blocks arrangement
//! reuse. It is assumed that actions that need an arrangement are responsible for
19//! lifting filters out of the way.
20//!
21//! Predicate pushdown will not push down literal errors, unless it is certain that
22//! the literal errors will be unconditionally evaluated. For example, the pushdown
23//! will not happen if not all predicates can be pushed down (e.g. reduce and map),
24//! or if we are not certain that the input is non-empty (e.g. join).
25//! Note that this is not addressing the problem in its full generality, because this problem can
26//! occur with any function call that might error (although much more rarely than with literal
27//! errors). See <https://github.com/MaterializeInc/database-issues/issues/4972#issuecomment-1547391011>
28//!
29//! ```rust
30//! use mz_expr::{BinaryFunc, MirRelationExpr, MirScalarExpr, func};
31//! use mz_ore::id_gen::IdGen;
32//! use mz_repr::{SqlColumnType, Datum, SqlRelationType, SqlScalarType};
33//! use mz_repr::optimize::OptimizerFeatures;
34//! use mz_transform::{reprtypecheck, typecheck,Transform, TransformCtx};
35//! use mz_transform::dataflow::DataflowMetainfo;
36//!
37//! use mz_transform::predicate_pushdown::PredicatePushdown;
38//!
39//! let input1 = MirRelationExpr::constant(vec![], SqlRelationType::new(vec![
40//!     SqlScalarType::Bool.nullable(false),
41//! ]));
42//! let input2 = MirRelationExpr::constant(vec![], SqlRelationType::new(vec![
43//!     SqlScalarType::Bool.nullable(false),
44//! ]));
45//! let input3 = MirRelationExpr::constant(vec![], SqlRelationType::new(vec![
46//!     SqlScalarType::Bool.nullable(false),
47//! ]));
48//! let join = MirRelationExpr::join(
49//!     vec![input1.clone(), input2.clone(), input3.clone()],
50//!     vec![vec![(0, 0), (2, 0)].into_iter().collect()],
51//! );
52//!
53//! let predicate0 = MirScalarExpr::column(0);
54//! let predicate1 = MirScalarExpr::column(1);
55//! let predicate01 = MirScalarExpr::column(0).call_binary(MirScalarExpr::column(2), func::AddInt64);
56//! let predicate012 = MirScalarExpr::literal_false();
57//!
58//! let mut expr = join.filter(
59//!    vec![
60//!        predicate0.clone(),
61//!        predicate1.clone(),
62//!        predicate01.clone(),
63//!        predicate012.clone(),
64//!    ]);
65//!
66//! let features = OptimizerFeatures::default();
67//! let typecheck_ctx = typecheck::empty_context();
68//! let repr_typecheck_ctx = reprtypecheck::empty_context();
69//! let mut df_meta = DataflowMetainfo::default();
70//! let mut transform_ctx = TransformCtx::local(&features, &typecheck_ctx, &repr_typecheck_ctx, &mut df_meta, None, None);
71//!
72//! PredicatePushdown::default().transform(&mut expr, &mut transform_ctx);
73//!
74//! let predicate00 = MirScalarExpr::column(0).call_binary(MirScalarExpr::column(0), func::AddInt64);
75//! let expected_expr = MirRelationExpr::join(
76//!     vec![
77//!         input1.clone().filter(vec![predicate0.clone(), predicate00.clone()]),
78//!         input2.clone().filter(vec![predicate0.clone()]),
79//!         input3.clone().filter(vec![predicate0, predicate00])
80//!     ],
81//!     vec![vec![(0, 0), (2, 0)].into_iter().collect()],
82//! ).filter(vec![predicate012]);
83//! assert_eq!(expected_expr, expr)
84//! ```
85
86use std::collections::{BTreeMap, BTreeSet};
87
88use itertools::Itertools;
89use mz_expr::visit::{Visit, VisitChildren};
90use mz_expr::{
91    AggregateFunc, Id, JoinInputMapper, LocalId, MirRelationExpr, MirScalarExpr, RECURSION_LIMIT,
92    VariadicFunc, func,
93};
94use mz_ore::soft_assert_eq_no_log;
95use mz_ore::stack::{CheckedRecursion, RecursionGuard, RecursionLimitError};
96use mz_repr::{Datum, SqlColumnType, SqlScalarType};
97
98use crate::{TransformCtx, TransformError};
99
100/// Pushes predicates down through other operators.
101#[derive(Debug)]
102pub struct PredicatePushdown {
103    recursion_guard: RecursionGuard,
104}
105
106impl Default for PredicatePushdown {
107    fn default() -> PredicatePushdown {
108        PredicatePushdown {
109            recursion_guard: RecursionGuard::with_limit(RECURSION_LIMIT),
110        }
111    }
112}
113
114impl CheckedRecursion for PredicatePushdown {
115    fn recursion_guard(&self) -> &RecursionGuard {
116        &self.recursion_guard
117    }
118}
119
120impl crate::Transform for PredicatePushdown {
121    fn name(&self) -> &'static str {
122        "PredicatePushdown"
123    }
124
125    #[mz_ore::instrument(
126        target = "optimizer",
127        level = "debug",
128        fields(path.segment = "predicate_pushdown")
129    )]
130    fn actually_perform_transform(
131        &self,
132        relation: &mut MirRelationExpr,
133        _: &mut TransformCtx,
134    ) -> Result<(), TransformError> {
135        let mut empty = BTreeMap::new();
136        let result = self.action(relation, &mut empty);
137        mz_repr::explain::trace_plan(&*relation);
138        result
139    }
140}
141
142impl PredicatePushdown {
143    /// Predicate pushdown
144    ///
145    /// This method looks for opportunities to push predicates toward
146    /// sources of data. Primarily, this is the `Filter` expression,
147    /// and moving its predicates through the operators it contains.
148    ///
149    /// In addition, the method accumulates the intersection of predicates
150    /// applied to each `Get` expression, so that the predicate can
151    /// then be pushed through to a `Let` binding, or to the external
152    /// source of the data if the `Get` binds to another view.
153    pub fn action(
154        &self,
155        relation: &mut MirRelationExpr,
156        get_predicates: &mut BTreeMap<Id, BTreeSet<MirScalarExpr>>,
157    ) -> Result<(), TransformError> {
158        self.checked_recur(|_| {
159            // In the case of Filter or Get we have specific work to do;
160            // otherwise we should recursively descend.
161            match relation {
162                MirRelationExpr::Filter { input, predicates } => {
163                    // Reduce the predicates to determine as best as possible
164                    // whether they are literal errors before working with them.
165                    let input_type = input.typ();
166                    for predicate in predicates.iter_mut() {
167                        predicate.reduce(&input_type.column_types);
168                    }
169
170                    // It can be helpful to know if there are any non-literal errors,
171                    // as this is justification for not pushing down literal errors.
172                    let all_errors = predicates.iter().all(|p| p.is_literal_err());
173                    // Depending on the type of `input` we have different
174                    // logic to apply to consider pushing `predicates` down.
175                    match &mut **input {
176                        MirRelationExpr::Let { body, .. }
177                        | MirRelationExpr::LetRec { body, .. } => {
178                            // Push all predicates to the body.
179                            **body = body
180                                .take_dangerous()
181                                .filter(std::mem::replace(predicates, Vec::new()));
182
183                            self.action(input, get_predicates)?;
184                        }
185                        MirRelationExpr::Get { id, .. } => {
186                            // We can report the predicates upward in `get_predicates`,
187                            // but we are not yet able to delete them from the
188                            // `Filter`.
189                            get_predicates
190                                .entry(*id)
191                                .or_insert_with(|| predicates.iter().cloned().collect())
192                                .retain(|p| predicates.contains(p));
193                        }
194                        MirRelationExpr::Join {
195                            inputs,
196                            equivalences,
197                            ..
198                        } => {
199                            // We want to scan `predicates` for any that can
200                            // 1) become join variable constraints
201                            // 2) apply to individual elements of `inputs`.
202                            // Figuring out the set of predicates that belong to
203                            //    the latter group requires 1) knowing which predicates
204                            //    are in the former group and 2) that the variable
205                            //    constraints be in canonical form.
206                            // Thus, there is a first scan across `predicates` to
207                            //    populate the join variable constraints
208                            //    and a second scan across the remaining predicates
209                            //    to see which ones can become individual elements of
210                            //    `inputs`.
211
212                            let input_mapper = mz_expr::JoinInputMapper::new(inputs);
213
                            // Predicates not translated into join variable
                            // constraints. We will attempt to push them at all
                            // inputs, and failing that, retain them in this `Filter`.
217                            let mut pred_not_translated = Vec::new();
218
219                            for mut predicate in predicates.drain(..) {
220                                use mz_expr::{BinaryFunc, UnaryFunc};
221                                if let MirScalarExpr::CallBinary {
222                                    func: BinaryFunc::Eq(_),
223                                    expr1,
224                                    expr2,
225                                } = &predicate
226                                {
227                                    // Translate into join variable constraints:
228                                    // 1) `nonliteral1 == nonliteral2` constraints
229                                    // 2) `expr == literal` where `expr` refers to more
230                                    //    than one input.
231                                    let input_count =
232                                        input_mapper.lookup_inputs(&predicate).count();
233                                    if (!expr1.is_literal() && !expr2.is_literal())
234                                        || input_count >= 2
235                                    {
236                                        // `col1 == col2` as a `MirScalarExpr`
237                                        // implies `!isnull(col1)` as well.
238                                        // `col1 == col2` as a join constraint does
239                                        // not have this extra implication.
240                                        // Thus, when translating the
241                                        // `MirScalarExpr` to a join constraint, we
242                                        // need to retain the `!isnull(col1)`
243                                        // information.
244                                        if expr1.typ(&input_type.column_types).nullable {
245                                            pred_not_translated.push(
246                                                expr1
247                                                    .clone()
248                                                    .call_unary(UnaryFunc::IsNull(func::IsNull))
249                                                    .call_unary(UnaryFunc::Not(func::Not)),
250                                            );
251                                        } else if expr2.typ(&input_type.column_types).nullable {
252                                            pred_not_translated.push(
253                                                expr2
254                                                    .clone()
255                                                    .call_unary(UnaryFunc::IsNull(func::IsNull))
256                                                    .call_unary(UnaryFunc::Not(func::Not)),
257                                            );
258                                        }
259                                        equivalences
260                                            .push(vec![(**expr1).clone(), (**expr2).clone()]);
261                                        continue;
262                                    }
263                                } else if let Some((expr1, expr2)) =
264                                    Self::extract_equal_or_both_null(
265                                        &mut predicate,
266                                        &input_type.column_types,
267                                    )
268                                {
269                                    // Also translate into join variable constraints:
270                                    // 3) `((nonliteral1 = nonliteral2) || (nonliteral
271                                    //    is null && nonliteral2 is null))`
272                                    equivalences.push(vec![expr1, expr2]);
273                                    continue;
274                                }
275                                pred_not_translated.push(predicate)
276                            }
277
278                            mz_expr::canonicalize::canonicalize_equivalences(
279                                equivalences,
280                                std::iter::once(&input_type.column_types),
281                            );
282
283                            let (retain, push_downs) = Self::push_filters_through_join(
284                                &input_mapper,
285                                equivalences,
286                                pred_not_translated,
287                            );
288
289                            Self::update_join_inputs_with_push_downs(inputs, push_downs);
290
291                            // Recursively descend on the join
292                            self.action(input, get_predicates)?;
293
294                            // remove all predicates that were pushed down from the current Filter node
295                            *predicates = retain;
296                        }
297                        MirRelationExpr::Reduce {
298                            input: inner,
299                            group_key,
300                            aggregates,
301                            monotonic: _,
302                            expected_group_size: _,
303                        } => {
304                            let mut retain = Vec::new();
305                            let mut push_down = Vec::new();
306                            for predicate in predicates.drain(..) {
307                                // Do not push down literal errors unless it is only errors.
308                                if !predicate.is_literal_err() || all_errors {
309                                    let mut supported = true;
310                                    let mut new_predicate = predicate.clone();
311                                    new_predicate.visit_pre(|e| {
312                                        if let MirScalarExpr::Column(c, _) = e {
313                                            if *c >= group_key.len() {
314                                                supported = false;
315                                            }
316                                        }
317                                    });
318                                    if supported {
319                                        new_predicate.visit_mut_post(&mut |e| {
320                                            if let MirScalarExpr::Column(i, _) = e {
321                                                *e = group_key[*i].clone();
322                                            }
323                                        })?;
324                                        push_down.push(new_predicate);
325                                    } else if let MirScalarExpr::Column(col, _) = &predicate {
326                                        if *col == group_key.len()
327                                            && aggregates.len() == 1
328                                            && aggregates[0].func == AggregateFunc::Any
329                                        {
330                                            push_down.push(aggregates[0].expr.clone());
331                                            aggregates[0].expr = MirScalarExpr::literal_ok(
332                                                Datum::True,
333                                                SqlScalarType::Bool,
334                                            );
335                                        } else {
336                                            retain.push(predicate);
337                                        }
338                                    } else {
339                                        retain.push(predicate);
340                                    }
341                                } else {
342                                    retain.push(predicate);
343                                }
344                            }
345
346                            if !push_down.is_empty() {
347                                *inner = Box::new(inner.take_dangerous().filter(push_down));
348                            }
349                            self.action(inner, get_predicates)?;
350
351                            // remove all predicates that were pushed down from the current Filter node
352                            std::mem::swap(&mut retain, predicates);
353                        }
354                        MirRelationExpr::TopK {
355                            input,
356                            group_key,
357                            order_key: _,
358                            limit,
359                            offset: _,
360                            monotonic: _,
361                            expected_group_size: _,
362                        } => {
363                            let mut retain = Vec::new();
364                            let mut push_down = Vec::new();
365
366                            let mut support = BTreeSet::new();
367                            support.extend(group_key.iter().cloned());
368                            if let Some(limit) = limit {
369                                // Strictly speaking not needed because the
370                                // `limit` support should be a subset of the
371                                // `group_key` support, but we don't want to
372                                // take this for granted here.
373                                limit.support_into(&mut support);
374                            }
375
376                            for predicate in predicates.drain(..) {
377                                // Do not push down literal errors unless it is only errors.
378                                if (!predicate.is_literal_err() || all_errors)
379                                    && predicate.support().is_subset(&support)
380                                {
381                                    push_down.push(predicate);
382                                } else {
383                                    retain.push(predicate);
384                                }
385                            }
386
387                            // remove all predicates that were pushed down from the current Filter node
388                            std::mem::swap(&mut retain, predicates);
389
390                            if !push_down.is_empty() {
391                                *input = Box::new(input.take_dangerous().filter(push_down));
392                            }
393
394                            self.action(input, get_predicates)?;
395                        }
396                        MirRelationExpr::Threshold { input } => {
397                            let predicates = std::mem::take(predicates);
398                            *relation = input.take_dangerous().filter(predicates).threshold();
399                            self.action(relation, get_predicates)?;
400                        }
401                        MirRelationExpr::Project { input, outputs } => {
402                            let predicates = predicates.drain(..).map(|mut predicate| {
403                                predicate.permute(outputs);
404                                predicate
405                            });
406                            *relation = input
407                                .take_dangerous()
408                                .filter(predicates)
409                                .project(outputs.clone());
410
411                            self.action(relation, get_predicates)?;
412                        }
413                        MirRelationExpr::Filter {
414                            input,
415                            predicates: predicates2,
416                        } => {
417                            *relation = input
418                                .take_dangerous()
419                                .filter(predicates.clone().into_iter().chain(predicates2.clone()));
420                            self.action(relation, get_predicates)?;
421                        }
422                        MirRelationExpr::Map { input, scalars } => {
423                            let (retained, pushdown) = Self::push_filters_through_map(
424                                scalars,
425                                predicates,
426                                input.arity(),
427                                all_errors,
428                            )?;
429                            let scalars = std::mem::take(scalars);
430                            let mut result = input.take_dangerous();
431                            if !pushdown.is_empty() {
432                                result = result.filter(pushdown);
433                            }
434                            self.action(&mut result, get_predicates)?;
435                            result = result.map(scalars);
436                            if !retained.is_empty() {
437                                result = result.filter(retained);
438                            }
439                            *relation = result;
440                        }
441                        MirRelationExpr::FlatMap { input, .. } => {
442                            let (mut retained, pushdown) =
443                                Self::push_filters_through_flat_map(predicates, input.arity());
444
445                            // remove all predicates that were pushed down from the current Filter node
446                            std::mem::swap(&mut retained, predicates);
447
448                            if !pushdown.is_empty() {
449                                // put the filter on top of the input
450                                **input = input.take_dangerous().filter(pushdown);
451                            }
452
453                            // ... and keep pushing predicates down
454                            self.action(input, get_predicates)?;
455                        }
456                        MirRelationExpr::Union { base, inputs } => {
457                            let predicates = std::mem::take(predicates);
458                            *base = Box::new(base.take_dangerous().filter(predicates.clone()));
459                            self.action(base, get_predicates)?;
460                            for input in inputs {
461                                *input = input.take_dangerous().filter(predicates.clone());
462                                self.action(input, get_predicates)?;
463                            }
464                        }
465                        MirRelationExpr::Negate { input } => {
466                            // Don't push literal errors past a Negate. The problem is that it's
467                            // hard to appropriately reflect the negation in the error stream:
468                            // - If we don't negate, then errors that should cancel out will not
469                            //   cancel out. For example, see
470                            //   https://github.com/MaterializeInc/database-issues/issues/5691
471                            // - If we negate, then unrelated errors might cancel out. E.g., there
472                            //   might be a division-by-0 in both inputs to an EXCEPT ALL, but
473                            //   on different input data. These shouldn't cancel out.
474                            let (retained, pushdown): (Vec<_>, Vec<_>) = std::mem::take(predicates)
475                                .into_iter()
476                                .partition(|p| p.is_literal_err());
477                            let mut result = input.take_dangerous();
478                            if !pushdown.is_empty() {
479                                result = result.filter(pushdown);
480                            }
481                            self.action(&mut result, get_predicates)?;
482                            result = result.negate();
483                            if !retained.is_empty() {
484                                result = result.filter(retained);
485                            }
486                            *relation = result;
487                        }
488                        x => {
489                            x.try_visit_mut_children(|e| self.action(e, get_predicates))?;
490                        }
491                    }
492
493                    // remove empty filters (junk by-product of the actual transform)
494                    match relation {
495                        MirRelationExpr::Filter { predicates, input } if predicates.is_empty() => {
496                            *relation = input.take_dangerous();
497                        }
498                        _ => {}
499                    }
500
501                    Ok(())
502                }
503                MirRelationExpr::Get { id, .. } => {
504                    // Purge all predicates associated with the id.
505                    get_predicates
506                        .entry(*id)
507                        .or_insert_with(BTreeSet::new)
508                        .clear();
509
510                    Ok(())
511                }
512                MirRelationExpr::Let { id, body, value } => {
513                    // Push predicates and collect intersection at `Get`s.
514                    self.action(body, get_predicates)?;
515
516                    // `get_predicates` should now contain the intersection
517                    // of predicates at each *use* of the binding. If it is
518                    // non-empty, we can move those predicates to the value.
519                    Self::push_into_let_binding(get_predicates, id, value, &mut [body]);
520
521                    // Continue recursively on the value.
522                    self.action(value, get_predicates)
523                }
524                MirRelationExpr::LetRec {
525                    ids,
526                    values,
527                    limits: _,
528                    body,
529                } => {
530                    // Note: This could be extended to be able to do a little more pushdowns, see
531                    // https://github.com/MaterializeInc/database-issues/issues/5336#issuecomment-1477588262
532
533                    // Pre-compute which Ids are used across iterations
534                    let ids_used_across_iterations = MirRelationExpr::recursive_ids(ids, values);
535
536                    // Predicate pushdown within the body
537                    self.action(body, get_predicates)?;
538
539                    // `users` will be the body plus the values of those bindings that we have seen
540                    // so far, while going one-by-one through the list of bindings backwards.
541                    // `users` contains those expressions from which we harvested `get_predicates`,
542                    // and therefore we should attend to all of these expressions when pushing down
543                    // a predicate into a Let binding.
544                    let mut users = vec![&mut **body];
545                    for (id, value) in ids.iter_mut().rev().zip_eq(values.into_iter().rev()) {
546                        // Predicate pushdown from Gets in `users` into the value of a Let binding
547                        //
548                        // For now, we simply always avoid pushing into a Let binding that is
549                        // referenced across iterations to avoid soundness problems and infinite
550                        // pushdowns.
551                        //
552                        // Note that `push_into_let_binding` makes a further check based on
553                        // `get_predicates`: We push a predicate into the value of a binding, only
554                        // if all Gets of this Id have this same predicate on top of them.
555                        if !ids_used_across_iterations.contains(id) {
556                            Self::push_into_let_binding(get_predicates, id, value, &mut users);
557                        }
558
559                        // Predicate pushdown within a binding
560                        self.action(value, get_predicates)?;
561
562                        users.push(value);
563                    }
564
565                    Ok(())
566                }
567                MirRelationExpr::Join {
568                    inputs,
569                    equivalences,
570                    ..
571                } => {
572                    // The goal is to push
573                    //   1) equivalences of the form `expr = <runtime constant>`, where `expr`
574                    //      comes from a single input.
575                    //   2) equivalences of the form `expr1 = expr2`, where both
576                    //      expressions come from the same single input.
577                    let input_types = inputs.iter().map(|i| i.typ()).collect::<Vec<_>>();
578                    mz_expr::canonicalize::canonicalize_equivalences(
579                        equivalences,
580                        input_types.iter().map(|t| &t.column_types),
581                    );
582
583                    let input_mapper = mz_expr::JoinInputMapper::new_from_input_types(&input_types);
584                    // Predicates to push at each input, and to lift out the join.
585                    let mut push_downs = vec![Vec::new(); inputs.len()];
586
587                    for equivalence_pos in 0..equivalences.len() {
                        // Case 1: there is more than one literal in the
                        // equivalence class. Because equivalences have been
                        // deduplicated, this means that everything in the equivalence
                        // class must be equal to two different literals, so the
                        // entire relation zeroes out.
593                        if equivalences[equivalence_pos]
594                            .iter()
595                            .filter(|expr| expr.is_literal())
596                            .count()
597                            > 1
598                        {
599                            relation.take_safely(Some(relation.typ_with_input_types(&input_types)));
600                            return Ok(());
601                        }
602
603                        let runtime_constants = equivalences[equivalence_pos]
604                            .iter()
605                            .filter(|expr| expr.support().is_empty())
606                            .cloned()
607                            .collect::<Vec<_>>();
608                        if !runtime_constants.is_empty() {
609                            // Case 2: There is at least one runtime constant the equivalence class
610                            let gen_literal_equality_preds = |expr: MirScalarExpr| {
611                                let mut equality_preds = Vec::new();
612                                for constant in runtime_constants.iter() {
613                                    let pred = if constant.is_literal_null() {
614                                        MirScalarExpr::CallUnary {
615                                            func: mz_expr::UnaryFunc::IsNull(func::IsNull),
616                                            expr: Box::new(expr.clone()),
617                                        }
618                                    } else {
619                                        MirScalarExpr::CallBinary {
620                                            func: mz_expr::func::Eq.into(),
621                                            expr1: Box::new(expr.clone()),
622                                            expr2: Box::new(constant.clone()),
623                                        }
624                                    };
625                                    equality_preds.push(pred);
626                                }
627                                equality_preds
628                            };
629
630                            // Find all single input expressions in the equivalence
631                            // class and collect (position within the equivalence class,
632                            // input the expression belongs to, localized version of the
633                            // expression).
634                            let mut single_input_exprs = equivalences[equivalence_pos]
635                                .iter()
636                                .enumerate()
637                                .filter_map(|(pos, e)| {
638                                    let mut inputs = input_mapper.lookup_inputs(e);
639                                    if let Some(input) = inputs.next() {
640                                        if inputs.next().is_none() {
641                                            return Some((
642                                                pos,
643                                                input,
644                                                input_mapper.map_expr_to_local(e.clone()),
645                                            ));
646                                        }
647                                    }
648                                    None
649                                })
650                                .collect::<Vec<_>>();
651
652                            // For every single-input expression `expr`, we can push
653                            // down `expr = <runtime constant>` and remove `expr` from the
654                            // equivalence class.
655                            for (expr_pos, input, expr) in single_input_exprs.drain(..).rev() {
656                                push_downs[input].extend(gen_literal_equality_preds(expr));
657                                equivalences[equivalence_pos].remove(expr_pos);
658                            }
659
660                            // If none of the expressions in the equivalence depend on input
661                            // columns and equality predicates with them are pushed down,
662                            // we can safely remove them from the equivalence.
663                            // TODO: we could probably push equality predicates among the
664                            // remaining constants to all join inputs to prevent any computation
665                            // from happening until the condition is satisfied.
666                            if equivalences[equivalence_pos]
667                                .iter()
668                                .all(|e| e.support().is_empty())
669                                && push_downs.iter().any(|p| !p.is_empty())
670                            {
671                                equivalences[equivalence_pos].clear();
672                            }
673                        } else {
674                            // Case 3: There are no constants in the equivalence
675                            // class. Push a predicate for every pair of expressions
676                            // in the equivalence that either belong to a single
677                            // input or can be localized to a given input through
678                            // the rest of equivalences.
679                            let mut to_remove = Vec::new();
680                            for input in 0..inputs.len() {
681                                // Vector of pairs (position within the equivalence, localized
682                                // expression). The position is None for expressions derived through
683                                // other equivalences.
684                                let localized = equivalences[equivalence_pos]
685                                    .iter()
686                                    .enumerate()
687                                    .filter_map(|(pos, expr)| {
688                                        if let MirScalarExpr::Column(col_pos, _) = &expr {
689                                            let local_col =
690                                                input_mapper.map_column_to_local(*col_pos);
691                                            if input == local_col.1 {
692                                                // TODO(mgree) !!! is it safe to propagate the name here?
693                                                return Some((
694                                                    Some(pos),
695                                                    MirScalarExpr::column(local_col.0),
696                                                ));
697                                            } else {
698                                                return None;
699                                            }
700                                        }
701                                        let mut inputs = input_mapper.lookup_inputs(expr);
702                                        if let Some(single_input) = inputs.next() {
703                                            if input == single_input && inputs.next().is_none() {
704                                                return Some((
705                                                    Some(pos),
706                                                    input_mapper.map_expr_to_local(expr.clone()),
707                                                ));
708                                            }
709                                        }
710                                        // Equivalences not including the current expression
711                                        let mut other_equivalences = equivalences.clone();
712                                        other_equivalences[equivalence_pos].remove(pos);
713                                        let mut localized = expr.clone();
714                                        if input_mapper.try_localize_to_input_with_bound_expr(
715                                            &mut localized,
716                                            input,
717                                            &other_equivalences[..],
718                                        ) {
719                                            Some((None, localized))
720                                        } else {
721                                            None
722                                        }
723                                    })
724                                    .collect::<Vec<_>>();
725
726                                // If there are at least 2 expression in the equivalence that
727                                // can be localized to the same input, push all combinations
728                                // of them to the input.
729                                if localized.len() > 1 {
730                                    for mut pair in
731                                        localized.iter().map(|(_, expr)| expr).combinations(2)
732                                    {
733                                        let expr1 = pair.pop().unwrap();
734                                        let expr2 = pair.pop().unwrap();
735
736                                        push_downs[input].push(
737                                            MirScalarExpr::CallBinary {
738                                                func: func::Eq.into(),
739                                                expr1: Box::new(expr2.clone()),
740                                                expr2: Box::new(expr1.clone()),
741                                            }
742                                            .or(expr2
743                                                .clone()
744                                                .call_is_null()
745                                                .and(expr1.clone().call_is_null())),
746                                        );
747                                    }
748
749                                    if localized.len() == equivalences[equivalence_pos].len() {
750                                        // The equivalence is either a single input one or fully localizable
751                                        // to a single input through other equivalences, so it can be removed
752                                        // completely without introducing any new cross join.
753                                        to_remove.extend(0..equivalences[equivalence_pos].len());
754                                    } else {
755                                        // Leave an expression from this input in the equivalence to avoid
756                                        // cross joins
757                                        to_remove.extend(
758                                            localized.iter().filter_map(|(pos, _)| *pos).skip(1),
759                                        );
760                                    }
761                                }
762                            }
763
764                            // Remove expressions that were pushed down to at least one input
765                            to_remove.sort();
766                            to_remove.dedup();
767                            for pos in to_remove.iter().rev() {
768                                equivalences[equivalence_pos].remove(*pos);
769                            }
770                        };
771                    }
772
773                    mz_expr::canonicalize::canonicalize_equivalences(
774                        equivalences,
775                        input_types.iter().map(|t| &t.column_types),
776                    );
777
778                    Self::update_join_inputs_with_push_downs(inputs, push_downs);
779
780                    // Recursively descend on each of the inputs.
781                    for input in inputs.iter_mut() {
782                        self.action(input, get_predicates)?;
783                    }
784
785                    Ok(())
786                }
787                x => {
788                    // Recursively descend.
789                    x.try_visit_mut_children(|e| self.action(e, get_predicates))
790                }
791            }
792        })
793    }
794
795    fn update_join_inputs_with_push_downs(
796        inputs: &mut Vec<MirRelationExpr>,
797        push_downs: Vec<Vec<MirScalarExpr>>,
798    ) {
799        let new_inputs = inputs
800            .drain(..)
801            .zip_eq(push_downs)
802            .map(|(input, push_down)| {
803                if !push_down.is_empty() {
804                    input.filter(push_down)
805                } else {
806                    input
807                }
808            })
809            .collect();
810        *inputs = new_inputs;
811    }
812
813    // Checks `get_predicates` to see whether we can push a predicate into the Let binding given
814    // by `id` and `value`.
815    // `users` is the list of those expressions from which we will need to remove a predicate that
816    // is being pushed.
817    fn push_into_let_binding(
818        get_predicates: &mut BTreeMap<Id, BTreeSet<MirScalarExpr>>,
819        id: &LocalId,
820        value: &mut MirRelationExpr,
821        users: &mut [&mut MirRelationExpr],
822    ) {
823        if let Some(list) = get_predicates.remove(&Id::Local(*id)) {
824            if !list.is_empty() {
825                // Remove the predicates in `list` from the users.
826                for user in users {
827                    user.visit_pre_mut(|e| {
828                        if let MirRelationExpr::Filter { input, predicates } = e {
829                            if let MirRelationExpr::Get { id: get_id, .. } = **input {
830                                if get_id == Id::Local(*id) {
831                                    predicates.retain(|p| !list.contains(p));
832                                }
833                            }
834                        }
835                    });
836                }
837                // Apply the predicates in `list` to value. Canonicalize
838                // `list` so that plans are always deterministic.
839                let mut list = list.into_iter().collect::<Vec<_>>();
840                mz_expr::canonicalize::canonicalize_predicates(
841                    &mut list,
842                    &value.typ().column_types,
843                );
844                *value = value.take_dangerous().filter(list);
845            }
846        }
847    }
848
849    /// Returns `(<predicates to retain>, <predicates to push at each input>)`.
850    pub fn push_filters_through_join(
851        input_mapper: &JoinInputMapper,
852        equivalences: &Vec<Vec<MirScalarExpr>>,
853        mut predicates: Vec<MirScalarExpr>,
854    ) -> (Vec<MirScalarExpr>, Vec<Vec<MirScalarExpr>>) {
855        let mut push_downs = vec![Vec::new(); input_mapper.total_inputs()];
856        let mut retain = Vec::new();
857
858        for predicate in predicates.drain(..) {
859            // Track if the predicate has been pushed to at least one input.
860            let mut pushed = false;
861            // For each input, try and see if the join
862            // equivalences allow the predicate to be rewritten
863            // in terms of only columns from that input.
864            for (index, push_down) in push_downs.iter_mut().enumerate() {
865                #[allow(deprecated)] // TODO: use `might_error` if possible.
866                if predicate.is_literal_err() || predicate.contains_error_if_null() {
867                    // Do nothing. We don't push down literal errors,
868                    // as we can't know the join will be non-empty.
869                    //
870                    // We also don't want to push anything that involves `error_if_null`. This is
871                    // for the same reason why in theory we shouldn't really push anything that can
872                    // error, assuming that we want to preserve error semantics. (Because we would
873                    // create a spurious error if some other Join input ends up empty.) We can't fix
874                    // this problem in general (as we can't just not push anything that might
875                    // error), but we decided to fix the specific problem instance involving
876                    // `error_if_null`, because it was very painful:
877                    // <https://github.com/MaterializeInc/database-issues/issues/6258>
878                } else {
879                    let mut localized = predicate.clone();
880                    if input_mapper.try_localize_to_input_with_bound_expr(
881                        &mut localized,
882                        index,
883                        equivalences,
884                    ) {
885                        push_down.push(localized);
886                        pushed = true;
887                    } else if let Some(consequence) = input_mapper
888                        // (`consequence_for_input` assumes that
889                        // `try_localize_to_input_with_bound_expr` has already
890                        // been called on `localized`.)
891                        .consequence_for_input(&localized, index)
892                    {
893                        push_down.push(consequence);
894                        // We don't set `pushed` here! We want to retain the
895                        // predicate, because we only pushed a consequence of
896                        // it, but not the full predicate.
897                    }
898                }
899            }
900
901            if !pushed {
902                retain.push(predicate);
903            }
904        }
905
906        (retain, push_downs)
907    }
908
909    /// Computes "safe" predicates to push through a Map.
910    ///
911    /// In the case of a Filter { Map {...} }, we can always push down the Filter
912    /// by inlining expressions from the Map. We don't want to do this in general,
913    /// however, since general inlining can result in exponential blowup in the size
914    /// of expressions, so we only do this in the case where the size after inlining
915    /// is below a certain limit.
916    ///
917    /// Returns the predicates that can be pushed down, followed by ones that cannot.
918    pub fn push_filters_through_map(
919        map_exprs: &Vec<MirScalarExpr>,
920        predicates: &mut Vec<MirScalarExpr>,
921        input_arity: usize,
922        all_errors: bool,
923    ) -> Result<(Vec<MirScalarExpr>, Vec<MirScalarExpr>), TransformError> {
924        let mut pushdown = Vec::new();
925        let mut retained = Vec::new();
926        for predicate in predicates.drain(..) {
927            // We don't push down literal errors, unless all predicates are.
928            if !predicate.is_literal_err() || all_errors {
929                // Consider inlining Map expressions.
930                if let Some(cleaned) =
931                    Self::inline_if_not_too_big(&predicate, input_arity, map_exprs)?
932                {
933                    pushdown.push(cleaned);
934                } else {
935                    retained.push(predicate);
936                }
937            } else {
938                retained.push(predicate);
939            }
940        }
941        Ok((retained, pushdown))
942    }
943
944    /// This fn should be called with a Filter `expr` that is after a Map. `input_arity` is the
945    /// arity of the input of the Map. This fn eliminates such column refs in `expr` that refer not
946    /// to a column in the input of the Map, but to a column that is created by the Map. It does
947    /// this by transitively inlining Map expressions until no such expression remains that points
948    /// to a Map expression. The return value is the cleaned up expression. The fn bails out with a
949    /// None if the resulting expression would be made too big by the inlinings.
950    ///
951    /// OOO: (Optimizer Optimization Opportunity) This function might do work proportional to the
952    /// total size of the Map expressions. We call this function for each predicate above the Map,
953    /// which will be kind of quadratic, i.e., if there are many predicates and a big Map, then this
954    /// will be slow. We could instead pass a vector of Map expressions and call this fn only once.
955    /// The only downside would be that then the inlining limit being hit in the middle part of this
956    /// function would prevent us from inlining any predicates, even ones that wouldn't hit the
957    /// inlining limit if considered on their own.
958    fn inline_if_not_too_big(
959        expr: &MirScalarExpr,
960        input_arity: usize,
961        map_exprs: &Vec<MirScalarExpr>,
962    ) -> Result<Option<MirScalarExpr>, RecursionLimitError> {
963        let size_limit = 1000;
964
965        // Transitively determine the support of `expr` produced by `map_exprs`
966        // that needs to be inlined.
967        let cols_to_inline = {
968            let mut support = BTreeSet::new();
969
970            // Seed with `map_exprs` support in `expr`.
971            expr.visit_pre(|e| {
972                if let MirScalarExpr::Column(c, _) = e {
973                    if *c >= input_arity {
974                        support.insert(*c);
975                    }
976                }
977            });
978
979            // Compute transitive closure of supports in `map_exprs`.
980            let mut workset = support.iter().cloned().collect::<Vec<_>>();
981            let mut buffer = vec![];
982            while !workset.is_empty() {
983                // Swap the (empty) `drained` buffer with the `workset`.
984                std::mem::swap(&mut workset, &mut buffer);
985                // Drain the `buffer` and update `support` and `workset`.
986                for c in buffer.drain(..) {
987                    map_exprs[c - input_arity].visit_pre(|e| {
988                        if let MirScalarExpr::Column(c, _) = e {
989                            if *c >= input_arity {
990                                if support.insert(*c) {
991                                    workset.push(*c);
992                                }
993                            }
994                        }
995                    });
996                }
997            }
998            support
999        };
1000
1001        let mut inlined = BTreeMap::<usize, (MirScalarExpr, usize)>::new();
1002        // Populate the memo table in ascending column order (which respects the
1003        // dependency order of `map_exprs` references). Break early if memoization
1004        // fails for one of the columns in `cols_to_inline`.
1005        for c in cols_to_inline.iter() {
1006            let mut new_expr = map_exprs[*c - input_arity].clone();
1007            let mut new_size = 0;
1008            new_expr.visit_mut_post(&mut |expr| {
1009                new_size += 1;
1010                if let MirScalarExpr::Column(c, _) = expr {
1011                    if *c >= input_arity && new_size <= size_limit {
1012                        // (inlined[c] is safe, because we proceed in column order, and we break out
1013                        // of the loop when we stop inserting into memo.)
1014                        let (m_expr, m_size): &(MirScalarExpr, _) = &inlined[c];
1015                        *expr = m_expr.clone();
1016                        new_size += m_size - 1; // Adjust for the +1 above.
1017                    }
1018                }
1019            })?;
1020
1021            if new_size <= size_limit {
1022                inlined.insert(*c, (new_expr, new_size));
1023            } else {
1024                break;
1025            }
1026        }
1027
1028        // Try to resolve expr against the memo table.
1029        if inlined.len() < cols_to_inline.len() {
1030            Ok(None) // We couldn't memoize all map expressions within the given limit.
1031        } else {
1032            let mut new_expr = expr.clone();
1033            let mut new_size = 0;
1034            new_expr.visit_mut_post(&mut |expr| {
1035                new_size += 1;
1036                if let MirScalarExpr::Column(c, _) = expr {
1037                    if *c >= input_arity && new_size <= size_limit {
1038                        // (inlined[c] is safe because of the outer if condition.)
1039                        let (m_expr, m_size): &(MirScalarExpr, _) = &inlined[c];
1040                        *expr = m_expr.clone();
1041                        new_size += m_size - 1; // Adjust for the +1 above.
1042                    }
1043                }
1044            })?;
1045
1046            soft_assert_eq_no_log!(new_size, new_expr.size());
1047            if new_size <= size_limit {
1048                Ok(Some(new_expr)) // We managed to stay within the limit.
1049            } else {
1050                Ok(None) // Limit exceeded.
1051            }
1052        }
1053    }
1054    // fn inline_if_not_too_big(
1055    //     expr: &MirScalarExpr,
1056    //     input_arity: usize,
1057    //     map_exprs: &Vec<MirScalarExpr>,
1058    // ) -> Result<Option<MirScalarExpr>, RecursionLimitError> {
1059    //     let size_limit = 1000;
1060    //     // Memoize cleaned up versions of Map expressions. (Not necessarily all the Map expressions
1061    //     // will be involved.)
1062    //     let mut memo: BTreeMap<MirScalarExpr, MirScalarExpr> = BTreeMap::new();
1063    //     fn rec(
1064    //         expr: &MirScalarExpr,
1065    //         input_arity: usize,
1066    //         map_exprs: &Vec<MirScalarExpr>,
1067    //         memo: &mut BTreeMap<MirScalarExpr, MirScalarExpr>,
1068    //         size_limit: usize,
1069    //     ) -> Result<Option<MirScalarExpr>, RecursionLimitError> {
1070    //         // (We can't use Entry::or_insert_with, because the closure would need to be fallible.
1071    //         // We also can't manually match on the result of memo.entry, because that holds a
1072    //         // borrow of memo, but we need to pass memo to the recursive call in the middle.)
1073    //         match memo.get(expr) {
1074    //             Some(memoized_result) => Ok(Some(memoized_result.clone())),
1075    //             None => {
1076    //                 let mut expr_size = expr.size()?;
1077    //                 let mut cleaned_expr = expr.clone();
1078    //                 let mut bail = false;
1079    //                 cleaned_expr.try_visit_mut_post(&mut |expr| {
1080    //                     Ok(if !bail {
1081    //                         match expr {
1082    //                             MirScalarExpr::Column(col) => {
1083    //                                 if *col >= input_arity {
1084    //                                     let to_inline = rec(
1085    //                                         &map_exprs[*col - input_arity],
1086    //                                         input_arity,
1087    //                                         map_exprs,
1088    //                                         memo,
1089    //                                         size_limit,
1090    //                                     )?;
1091    //                                     if let Some(to_inline) = to_inline {
1092    //                                         // The `-1` is because the expression that we are
1093    //                                         // replacing has a size of 1.
1094    //                                         expr_size += to_inline.size()? - 1;
1095    //                                         *expr = to_inline;
1096    //                                         if expr_size > size_limit {
1097    //                                             bail = true;
1098    //                                         }
1099    //                                     } else {
1100    //                                         bail = true;
1101    //                                     }
1102    //                                 }
1103    //                             }
1104    //                             _ => (),
1105    //                         }
1106    //                     })
1107    //                 })?;
1108    //                 soft_assert_eq!(cleaned_expr.size()?, expr_size);
1109    //                 if !bail {
1110    //                     memo.insert(expr.clone(), cleaned_expr.clone());
1111    //                     Ok(Some(cleaned_expr))
1112    //                 } else {
1113    //                     Ok(None)
1114    //                 }
1115    //             }
1116    //         }
1117    //     }
1118    //     rec(expr, input_arity, map_exprs, &mut memo, size_limit)
1119    // }
1120
1121    /// Computes "safe" predicate to push through a FlatMap.
1122    ///
1123    /// In the case of a Filter { FlatMap {...} }, we want to push through all predicates
1124    /// that (1) are not literal errors and (2) have support exclusively in the columns
1125    /// provided by the FlatMap input.
1126    ///
1127    /// Returns the predicates that can be pushed down, followed by ones that cannot.
1128    fn push_filters_through_flat_map(
1129        predicates: &mut Vec<MirScalarExpr>,
1130        input_arity: usize,
1131    ) -> (Vec<MirScalarExpr>, Vec<MirScalarExpr>) {
1132        let mut pushdown = Vec::new();
1133        let mut retained = Vec::new();
1134        for predicate in predicates.drain(..) {
1135            // First, check if we can push this predicate down. We can do so if and only if:
1136            // (1) the predicate is not a literal error, and
1137            // (2) each column it references is from the input.
1138            if (!predicate.is_literal_err()) && predicate.support().iter().all(|c| *c < input_arity)
1139            {
1140                pushdown.push(predicate);
1141            } else {
1142                retained.push(predicate);
1143            }
1144        }
1145        (retained, pushdown)
1146    }
1147
1148    /// If `s` is of the form
1149    /// `(isnull(expr1) && isnull(expr2)) || (expr1 = expr2)`, or
1150    /// `(decompose_is_null(expr1) && decompose_is_null(expr2)) || (expr1 = expr2)`,
1151    /// extract `expr1` and `expr2`.
1152    fn extract_equal_or_both_null(
1153        s: &mut MirScalarExpr,
1154        column_types: &[SqlColumnType],
1155    ) -> Option<(MirScalarExpr, MirScalarExpr)> {
1156        if let MirScalarExpr::CallVariadic {
1157            func: VariadicFunc::Or,
1158            exprs,
1159        } = s
1160        {
1161            if let &[ref or_lhs, ref or_rhs] = &**exprs {
1162                // Check both orders of operands of the OR
1163                return Self::extract_equal_or_both_null_inner(or_lhs, or_rhs, column_types)
1164                    .or_else(|| {
1165                        Self::extract_equal_or_both_null_inner(or_rhs, or_lhs, column_types)
1166                    });
1167            }
1168        }
1169        None
1170    }
1171
1172    fn extract_equal_or_both_null_inner(
1173        or_arg1: &MirScalarExpr,
1174        or_arg2: &MirScalarExpr,
1175        column_types: &[SqlColumnType],
1176    ) -> Option<(MirScalarExpr, MirScalarExpr)> {
1177        use mz_expr::BinaryFunc;
1178        if let MirScalarExpr::CallBinary {
1179            func: BinaryFunc::Eq(_),
1180            expr1: eq_lhs,
1181            expr2: eq_rhs,
1182        } = &or_arg2
1183        {
1184            let isnull1 = eq_lhs.clone().call_is_null();
1185            let isnull2 = eq_rhs.clone().call_is_null();
1186            let both_null = MirScalarExpr::CallVariadic {
1187                func: VariadicFunc::And,
1188                exprs: vec![isnull1, isnull2],
1189            };
1190
1191            if Self::extract_reduced_conjunction_terms(both_null, column_types)
1192                == Self::extract_reduced_conjunction_terms(or_arg1.clone(), column_types)
1193            {
1194                return Some(((**eq_lhs).clone(), (**eq_rhs).clone()));
1195            }
1196        }
1197        None
1198    }
1199
1200    /// Reduces the given expression and returns its AND-ed terms.
1201    fn extract_reduced_conjunction_terms(
1202        mut s: MirScalarExpr,
1203        column_types: &[SqlColumnType],
1204    ) -> Vec<MirScalarExpr> {
1205        s.reduce(column_types);
1206
1207        if let MirScalarExpr::CallVariadic {
1208            func: VariadicFunc::And,
1209            exprs,
1210        } = s
1211        {
1212            exprs
1213        } else {
1214            vec![s]
1215        }
1216    }
1217}