Skip to main content

mz_transform/
predicate_pushdown.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Pushes predicates down through other operators.
11//!
12//! This action generally improves the quality of the query, in that selective per-record
13//! filters reduce the volume of data before they arrive at more expensive operators.
14//!
15//!
16//! The one time when this action might not improve the quality of a query is
17//! if a filter gets pushed down on an arrangement because that blocks arrangement
18//! reuse. It is assumed that actions that need an arrangement are responsible for
19//! lifting filters out of the way.
20//!
21//! Predicate pushdown will not push down literal errors, unless it is certain that
22//! the literal errors will be unconditionally evaluated. For example, the pushdown
23//! will not happen if not all predicates can be pushed down (e.g. reduce and map),
24//! or if we are not certain that the input is non-empty (e.g. join).
25//! Note that this is not addressing the problem in its full generality, because this problem can
26//! occur with any function call that might error (although much more rarely than with literal
27//! errors). See <https://github.com/MaterializeInc/database-issues/issues/4972#issuecomment-1547391011>
28//!
29//! ```rust
30//! use mz_expr::{BinaryFunc, MirRelationExpr, MirScalarExpr, func};
31//! use mz_ore::id_gen::IdGen;
32//! use mz_repr::{ReprColumnType, ReprRelationType, ReprScalarType};
33//! use mz_repr::optimize::OptimizerFeatures;
34//! use mz_transform::{reprtypecheck, Transform, TransformCtx};
35//! use mz_transform::dataflow::DataflowMetainfo;
36//!
37//! use mz_transform::predicate_pushdown::PredicatePushdown;
38//!
39//! let bool_typ = ReprRelationType::new(vec![
40//!     ReprColumnType { scalar_type: ReprScalarType::Bool, nullable: false },
41//! ]);
42//! let input1 = MirRelationExpr::Constant {
43//!     rows: Ok(vec![]),
44//!     typ: bool_typ.clone(),
45//! };
46//! let input2 = MirRelationExpr::Constant {
47//!     rows: Ok(vec![]),
48//!     typ: bool_typ.clone(),
49//! };
50//! let input3 = MirRelationExpr::Constant {
51//!     rows: Ok(vec![]),
52//!     typ: bool_typ,
53//! };
54//! let join = MirRelationExpr::join(
55//!     vec![input1.clone(), input2.clone(), input3.clone()],
56//!     vec![vec![(0, 0), (2, 0)].into_iter().collect()],
57//! );
58//!
59//! let predicate0 = MirScalarExpr::column(0);
60//! let predicate1 = MirScalarExpr::column(1);
61//! let predicate01 = MirScalarExpr::column(0).call_binary(MirScalarExpr::column(2), func::AddInt64);
62//! let predicate012 = MirScalarExpr::literal_false();
63//!
64//! let mut expr = join.filter(
65//!    vec![
66//!        predicate0.clone(),
67//!        predicate1.clone(),
68//!        predicate01.clone(),
69//!        predicate012.clone(),
70//!    ]);
71//!
72//! let features = OptimizerFeatures::default();
73//! let typecheck_ctx = reprtypecheck::empty_context();
74//! let mut df_meta = DataflowMetainfo::default();
75//! let mut transform_ctx = TransformCtx::local(&features, &typecheck_ctx, &mut df_meta, None, None);
76//!
77//! PredicatePushdown::default().transform(&mut expr, &mut transform_ctx);
78//!
79//! let predicate00 = MirScalarExpr::column(0).call_binary(MirScalarExpr::column(0), func::AddInt64);
80//! let expected_expr = MirRelationExpr::join(
81//!     vec![
82//!         input1.clone().filter(vec![predicate0.clone(), predicate00.clone()]),
83//!         input2.clone().filter(vec![predicate0.clone()]),
84//!         input3.clone().filter(vec![predicate0, predicate00])
85//!     ],
86//!     vec![vec![(0, 0), (2, 0)].into_iter().collect()],
87//! ).filter(vec![predicate012]);
88//! assert_eq!(expected_expr, expr)
89//! ```
90
91use std::collections::{BTreeMap, BTreeSet};
92
93use itertools::Itertools;
94use mz_expr::func::variadic::And;
95use mz_expr::visit::{Visit, VisitChildren};
96use mz_expr::{
97    AggregateFunc, Id, JoinInputMapper, LocalId, MirRelationExpr, MirScalarExpr, RECURSION_LIMIT,
98    VariadicFunc, func,
99};
100use mz_ore::soft_assert_eq_no_log;
101use mz_ore::stack::{CheckedRecursion, RecursionGuard, RecursionLimitError};
102use mz_repr::{Datum, ReprColumnType, ReprScalarType};
103
104use crate::{TransformCtx, TransformError};
105
106/// Pushes predicates down through other operators.
107#[derive(Debug)]
108pub struct PredicatePushdown {
109    recursion_guard: RecursionGuard,
110}
111
112impl Default for PredicatePushdown {
113    fn default() -> PredicatePushdown {
114        PredicatePushdown {
115            recursion_guard: RecursionGuard::with_limit(RECURSION_LIMIT),
116        }
117    }
118}
119
120impl CheckedRecursion for PredicatePushdown {
121    fn recursion_guard(&self) -> &RecursionGuard {
122        &self.recursion_guard
123    }
124}
125
126impl crate::Transform for PredicatePushdown {
127    fn name(&self) -> &'static str {
128        "PredicatePushdown"
129    }
130
131    #[mz_ore::instrument(
132        target = "optimizer",
133        level = "debug",
134        fields(path.segment = "predicate_pushdown")
135    )]
136    fn actually_perform_transform(
137        &self,
138        relation: &mut MirRelationExpr,
139        _: &mut TransformCtx,
140    ) -> Result<(), TransformError> {
141        let mut empty = BTreeMap::new();
142        let result = self.action(relation, &mut empty);
143        mz_repr::explain::trace_plan(&*relation);
144        result
145    }
146}
147
148impl PredicatePushdown {
149    /// Predicate pushdown
150    ///
151    /// This method looks for opportunities to push predicates toward
152    /// sources of data. Primarily, this is the `Filter` expression,
153    /// and moving its predicates through the operators it contains.
154    ///
155    /// In addition, the method accumulates the intersection of predicates
156    /// applied to each `Get` expression, so that the predicate can
157    /// then be pushed through to a `Let` binding, or to the external
158    /// source of the data if the `Get` binds to another view.
159    pub fn action(
160        &self,
161        relation: &mut MirRelationExpr,
162        get_predicates: &mut BTreeMap<Id, BTreeSet<MirScalarExpr>>,
163    ) -> Result<(), TransformError> {
164        self.checked_recur(|_| {
165            // In the case of Filter or Get we have specific work to do;
166            // otherwise we should recursively descend.
167            match relation {
168                MirRelationExpr::Filter { input, predicates } => {
169                    // Reduce the predicates to determine as best as possible
170                    // whether they are literal errors before working with them.
171                    let input_type = input.typ();
172                    for predicate in predicates.iter_mut() {
173                        predicate.reduce(&input_type.column_types);
174                    }
175
176                    // It can be helpful to know whether every predicate is a literal error:
177                    // if some are not, that is justification for not pushing down literal errors.
178                    let all_errors = predicates.iter().all(|p| p.is_literal_err());
179                    // Depending on the type of `input` we have different
180                    // logic to apply to consider pushing `predicates` down.
181                    match &mut **input {
182                        MirRelationExpr::Let { body, .. }
183                        | MirRelationExpr::LetRec { body, .. } => {
184                            // Push all predicates to the body.
185                            **body = body
186                                .take_dangerous()
187                                .filter(std::mem::replace(predicates, Vec::new()));
188
189                            self.action(input, get_predicates)?;
190                        }
191                        MirRelationExpr::Get { id, .. } => {
192                            // We can report the predicates upward in `get_predicates`,
193                            // but we are not yet able to delete them from the
194                            // `Filter`.
195                            get_predicates
196                                .entry(*id)
197                                .or_insert_with(|| predicates.iter().cloned().collect())
198                                .retain(|p| predicates.contains(p));
199                        }
200                        MirRelationExpr::Join {
201                            inputs,
202                            equivalences,
203                            ..
204                        } => {
205                            // We want to scan `predicates` for any that can
206                            // 1) become join variable constraints
207                            // 2) apply to individual elements of `inputs`.
208                            // Figuring out the set of predicates that belong to
209                            //    the latter group requires 1) knowing which predicates
210                            //    are in the former group and 2) that the variable
211                            //    constraints be in canonical form.
212                            // Thus, there is a first scan across `predicates` to
213                            //    populate the join variable constraints
214                            //    and a second scan across the remaining predicates
215                            //    to see which ones can become individual elements of
216                            //    `inputs`.
217
218                            let input_mapper = mz_expr::JoinInputMapper::new(inputs);
219
220                            // Predicates not translated into join variable
221                            // constraints. We will attempt to push them to the
222                            // inputs; those that cannot be pushed stay in the filter.
223                            let mut pred_not_translated = Vec::new();
224
225                            for mut predicate in predicates.drain(..) {
226                                use mz_expr::{BinaryFunc, UnaryFunc};
227                                if let MirScalarExpr::CallBinary {
228                                    func: BinaryFunc::Eq(_),
229                                    expr1,
230                                    expr2,
231                                } = &predicate
232                                {
233                                    // Translate into join variable constraints:
234                                    // 1) `nonliteral1 == nonliteral2` constraints
235                                    // 2) `expr == literal` where `expr` refers to more
236                                    //    than one input.
237                                    let input_count =
238                                        input_mapper.lookup_inputs(&predicate).count();
239                                    if (!expr1.is_literal() && !expr2.is_literal())
240                                        || input_count >= 2
241                                    {
242                                        // `col1 == col2` as a `MirScalarExpr`
243                                        // implies `!isnull(col1)` as well.
244                                        // `col1 == col2` as a join constraint does
245                                        // not have this extra implication.
246                                        // Thus, when translating the
247                                        // `MirScalarExpr` to a join constraint, we
248                                        // need to retain the `!isnull(col1)`
249                                        // information.
250                                        if expr1.typ(&input_type.column_types).nullable {
251                                            pred_not_translated.push(
252                                                expr1
253                                                    .clone()
254                                                    .call_unary(UnaryFunc::IsNull(func::IsNull))
255                                                    .call_unary(UnaryFunc::Not(func::Not)),
256                                            );
257                                        } else if expr2.typ(&input_type.column_types).nullable {
258                                            pred_not_translated.push(
259                                                expr2
260                                                    .clone()
261                                                    .call_unary(UnaryFunc::IsNull(func::IsNull))
262                                                    .call_unary(UnaryFunc::Not(func::Not)),
263                                            );
264                                        }
265                                        equivalences
266                                            .push(vec![(**expr1).clone(), (**expr2).clone()]);
267                                        continue;
268                                    }
269                                } else if let Some((expr1, expr2)) =
270                                    Self::extract_equal_or_both_null(
271                                        &mut predicate,
272                                        &input_type.column_types,
273                                    )
274                                {
275                                    // Also translate into join variable constraints:
276                                    // 3) `((nonliteral1 = nonliteral2) || (nonliteral
277                                    //    is null && nonliteral2 is null))`
278                                    equivalences.push(vec![expr1, expr2]);
279                                    continue;
280                                }
281                                pred_not_translated.push(predicate)
282                            }
283
284                            mz_expr::canonicalize::canonicalize_equivalences(
285                                equivalences,
286                                std::iter::once(&input_type.column_types),
287                            );
288
289                            let (retain, push_downs) = Self::push_filters_through_join(
290                                &input_mapper,
291                                equivalences,
292                                pred_not_translated,
293                            );
294
295                            Self::update_join_inputs_with_push_downs(inputs, push_downs);
296
297                            // Recursively descend on the join
298                            self.action(input, get_predicates)?;
299
300                            // remove all predicates that were pushed down from the current Filter node
301                            *predicates = retain;
302                        }
303                        MirRelationExpr::Reduce {
304                            input: inner,
305                            group_key,
306                            aggregates,
307                            monotonic: _,
308                            expected_group_size: _,
309                        } => {
310                            let mut retain = Vec::new();
311                            let mut push_down = Vec::new();
312                            for predicate in predicates.drain(..) {
313                                // Do not push down literal errors unless it is only errors.
314                                if !predicate.is_literal_err() || all_errors {
315                                    let mut supported = true;
316                                    let mut new_predicate = predicate.clone();
317                                    new_predicate.visit_pre(|e| {
318                                        if let MirScalarExpr::Column(c, _) = e {
319                                            if *c >= group_key.len() {
320                                                supported = false;
321                                            }
322                                        }
323                                    });
324                                    if supported {
325                                        new_predicate.visit_mut_post(&mut |e| {
326                                            if let MirScalarExpr::Column(i, _) = e {
327                                                *e = group_key[*i].clone();
328                                            }
329                                        })?;
330                                        push_down.push(new_predicate);
331                                    } else if let MirScalarExpr::Column(col, _) = &predicate {
332                                        if *col == group_key.len()
333                                            && aggregates.len() == 1
334                                            && aggregates[0].func == AggregateFunc::Any
335                                        {
336                                            push_down.push(aggregates[0].expr.clone());
337                                            aggregates[0].expr = MirScalarExpr::literal_ok(
338                                                Datum::True,
339                                                ReprScalarType::Bool,
340                                            );
341                                        } else {
342                                            retain.push(predicate);
343                                        }
344                                    } else {
345                                        retain.push(predicate);
346                                    }
347                                } else {
348                                    retain.push(predicate);
349                                }
350                            }
351
352                            if !push_down.is_empty() {
353                                **inner = inner.take_dangerous().filter(push_down);
354                            }
355                            self.action(inner, get_predicates)?;
356
357                            // remove all predicates that were pushed down from the current Filter node
358                            std::mem::swap(&mut retain, predicates);
359                        }
360                        MirRelationExpr::TopK {
361                            input,
362                            group_key,
363                            order_key: _,
364                            limit,
365                            offset: _,
366                            monotonic: _,
367                            expected_group_size: _,
368                        } => {
369                            let mut retain = Vec::new();
370                            let mut push_down = Vec::new();
371
372                            let mut support = BTreeSet::new();
373                            support.extend(group_key.iter().cloned());
374                            if let Some(limit) = limit {
375                                // Strictly speaking not needed because the
376                                // `limit` support should be a subset of the
377                                // `group_key` support, but we don't want to
378                                // take this for granted here.
379                                limit.support_into(&mut support);
380                            }
381
382                            for predicate in predicates.drain(..) {
383                                // Do not push down literal errors unless it is only errors.
384                                if (!predicate.is_literal_err() || all_errors)
385                                    && predicate.support().is_subset(&support)
386                                {
387                                    push_down.push(predicate);
388                                } else {
389                                    retain.push(predicate);
390                                }
391                            }
392
393                            // remove all predicates that were pushed down from the current Filter node
394                            std::mem::swap(&mut retain, predicates);
395
396                            if !push_down.is_empty() {
397                                **input = input.take_dangerous().filter(push_down);
398                            }
399
400                            self.action(input, get_predicates)?;
401                        }
402                        MirRelationExpr::Threshold { input } => {
403                            let predicates = std::mem::take(predicates);
404                            *relation = input.take_dangerous().filter(predicates).threshold();
405                            self.action(relation, get_predicates)?;
406                        }
407                        MirRelationExpr::Project { input, outputs } => {
408                            let predicates = predicates.drain(..).map(|mut predicate| {
409                                predicate.permute(outputs);
410                                predicate
411                            });
412                            *relation = input
413                                .take_dangerous()
414                                .filter(predicates)
415                                .project(outputs.clone());
416
417                            self.action(relation, get_predicates)?;
418                        }
419                        MirRelationExpr::Filter {
420                            input,
421                            predicates: predicates2,
422                        } => {
423                            *relation = input
424                                .take_dangerous()
425                                .filter(predicates.clone().into_iter().chain(predicates2.clone()));
426                            self.action(relation, get_predicates)?;
427                        }
428                        MirRelationExpr::Map { input, scalars } => {
429                            let (retained, pushdown) = Self::push_filters_through_map(
430                                scalars,
431                                predicates,
432                                input.arity(),
433                                all_errors,
434                            )?;
435                            let scalars = std::mem::take(scalars);
436                            let mut result = input.take_dangerous();
437                            if !pushdown.is_empty() {
438                                result = result.filter(pushdown);
439                            }
440                            self.action(&mut result, get_predicates)?;
441                            result = result.map(scalars);
442                            if !retained.is_empty() {
443                                result = result.filter(retained);
444                            }
445                            *relation = result;
446                        }
447                        MirRelationExpr::FlatMap { input, .. } => {
448                            let (mut retained, pushdown) =
449                                Self::push_filters_through_flat_map(predicates, input.arity());
450
451                            // remove all predicates that were pushed down from the current Filter node
452                            std::mem::swap(&mut retained, predicates);
453
454                            if !pushdown.is_empty() {
455                                // put the filter on top of the input
456                                **input = input.take_dangerous().filter(pushdown);
457                            }
458
459                            // ... and keep pushing predicates down
460                            self.action(input, get_predicates)?;
461                        }
462                        MirRelationExpr::Union { base, inputs } => {
463                            let predicates = std::mem::take(predicates);
464                            **base = base.take_dangerous().filter(predicates.clone());
465                            self.action(base, get_predicates)?;
466                            for input in inputs {
467                                *input = input.take_dangerous().filter(predicates.clone());
468                                self.action(input, get_predicates)?;
469                            }
470                        }
471                        MirRelationExpr::Negate { input } => {
472                            // Don't push literal errors past a Negate. The problem is that it's
473                            // hard to appropriately reflect the negation in the error stream:
474                            // - If we don't negate, then errors that should cancel out will not
475                            //   cancel out. For example, see
476                            //   https://github.com/MaterializeInc/database-issues/issues/5691
477                            // - If we negate, then unrelated errors might cancel out. E.g., there
478                            //   might be a division-by-0 in both inputs to an EXCEPT ALL, but
479                            //   on different input data. These shouldn't cancel out.
480                            let (retained, pushdown): (Vec<_>, Vec<_>) = std::mem::take(predicates)
481                                .into_iter()
482                                .partition(|p| p.is_literal_err());
483                            let mut result = input.take_dangerous();
484                            if !pushdown.is_empty() {
485                                result = result.filter(pushdown);
486                            }
487                            self.action(&mut result, get_predicates)?;
488                            result = result.negate();
489                            if !retained.is_empty() {
490                                result = result.filter(retained);
491                            }
492                            *relation = result;
493                        }
494                        x => {
495                            x.try_visit_mut_children(|e| self.action(e, get_predicates))?;
496                        }
497                    }
498
499                    // remove empty filters (junk by-product of the actual transform)
500                    match relation {
501                        MirRelationExpr::Filter { predicates, input } if predicates.is_empty() => {
502                            *relation = input.take_dangerous();
503                        }
504                        _ => {}
505                    }
506
507                    Ok(())
508                }
509                MirRelationExpr::Get { id, .. } => {
510                    // Purge all predicates associated with the id.
511                    get_predicates
512                        .entry(*id)
513                        .or_insert_with(BTreeSet::new)
514                        .clear();
515
516                    Ok(())
517                }
518                MirRelationExpr::Let { id, body, value } => {
519                    // Push predicates and collect intersection at `Get`s.
520                    self.action(body, get_predicates)?;
521
522                    // `get_predicates` should now contain the intersection
523                    // of predicates at each *use* of the binding. If it is
524                    // non-empty, we can move those predicates to the value.
525                    Self::push_into_let_binding(get_predicates, id, value, &mut [body]);
526
527                    // Continue recursively on the value.
528                    self.action(value, get_predicates)
529                }
530                MirRelationExpr::LetRec {
531                    ids,
532                    values,
533                    limits: _,
534                    body,
535                } => {
536                    // Note: This could be extended to be able to do a little more pushdowns, see
537                    // https://github.com/MaterializeInc/database-issues/issues/5336#issuecomment-1477588262
538
539                    // Pre-compute which Ids are used across iterations
540                    let ids_used_across_iterations = MirRelationExpr::recursive_ids(ids, values);
541
542                    // Predicate pushdown within the body
543                    self.action(body, get_predicates)?;
544
545                    // `users` will be the body plus the values of those bindings that we have seen
546                    // so far, while going one-by-one through the list of bindings backwards.
547                    // `users` contains those expressions from which we harvested `get_predicates`,
548                    // and therefore we should attend to all of these expressions when pushing down
549                    // a predicate into a Let binding.
550                    let mut users = vec![&mut **body];
551                    for (id, value) in ids.iter_mut().rev().zip_eq(values.into_iter().rev()) {
552                        // Predicate pushdown from Gets in `users` into the value of a Let binding
553                        //
554                        // For now, we simply always avoid pushing into a Let binding that is
555                        // referenced across iterations to avoid soundness problems and infinite
556                        // pushdowns.
557                        //
558                        // Note that `push_into_let_binding` makes a further check based on
559                        // `get_predicates`: We push a predicate into the value of a binding, only
560                        // if all Gets of this Id have this same predicate on top of them.
561                        if !ids_used_across_iterations.contains(id) {
562                            Self::push_into_let_binding(get_predicates, id, value, &mut users);
563                        }
564
565                        // Predicate pushdown within a binding
566                        self.action(value, get_predicates)?;
567
568                        users.push(value);
569                    }
570
571                    Ok(())
572                }
573                MirRelationExpr::Join {
574                    inputs,
575                    equivalences,
576                    ..
577                } => {
578                    // The goal is to push
579                    //   1) equivalences of the form `expr = <runtime constant>`, where `expr`
580                    //      comes from a single input.
581                    //   2) equivalences of the form `expr1 = expr2`, where both
582                    //      expressions come from the same single input.
583                    let input_types = inputs.iter().map(|i| i.typ()).collect::<Vec<_>>();
584                    mz_expr::canonicalize::canonicalize_equivalences(
585                        equivalences,
586                        input_types.iter().map(|t| &t.column_types),
587                    );
588
589                    let input_mapper = mz_expr::JoinInputMapper::new_from_input_types(&input_types);
590                    // Predicates to push at each input, and to lift out the join.
591                    let mut push_downs = vec![Vec::new(); inputs.len()];
592
593                    for equivalence_pos in 0..equivalences.len() {
594                        // Case 1: there is more than one literal in the
595                        // equivalence class. Because equivalences have been
596                        // deduplicated, this means that everything in the equivalence
597                        // class must be equal to two different literals, so the
598                        // entire relation zeroes out.
599                        if equivalences[equivalence_pos]
600                            .iter()
601                            .filter(|expr| expr.is_literal())
602                            .count()
603                            > 1
604                        {
605                            relation.take_safely(Some(relation.typ_with_input_types(&input_types)));
606                            return Ok(());
607                        }
608
609                        let runtime_constants = equivalences[equivalence_pos]
610                            .iter()
611                            .filter(|expr| expr.support().is_empty())
612                            .cloned()
613                            .collect::<Vec<_>>();
614                        if !runtime_constants.is_empty() {
615                            // Case 2: There is at least one runtime constant the equivalence class
616                            let gen_literal_equality_preds = |expr: MirScalarExpr| {
617                                let mut equality_preds = Vec::new();
618                                for constant in runtime_constants.iter() {
619                                    let pred = if constant.is_literal_null() {
620                                        MirScalarExpr::CallUnary {
621                                            func: mz_expr::UnaryFunc::IsNull(func::IsNull),
622                                            expr: Box::new(expr.clone()),
623                                        }
624                                    } else {
625                                        MirScalarExpr::CallBinary {
626                                            func: mz_expr::func::Eq.into(),
627                                            expr1: Box::new(expr.clone()),
628                                            expr2: Box::new(constant.clone()),
629                                        }
630                                    };
631                                    equality_preds.push(pred);
632                                }
633                                equality_preds
634                            };
635
636                            // Find all single input expressions in the equivalence
637                            // class and collect (position within the equivalence class,
638                            // input the expression belongs to, localized version of the
639                            // expression).
640                            let mut single_input_exprs = equivalences[equivalence_pos]
641                                .iter()
642                                .enumerate()
643                                .filter_map(|(pos, e)| {
644                                    let mut inputs = input_mapper.lookup_inputs(e);
645                                    if let Some(input) = inputs.next() {
646                                        if inputs.next().is_none() {
647                                            return Some((
648                                                pos,
649                                                input,
650                                                input_mapper.map_expr_to_local(e.clone()),
651                                            ));
652                                        }
653                                    }
654                                    None
655                                })
656                                .collect::<Vec<_>>();
657
658                            // For every single-input expression `expr`, we can push
659                            // down `expr = <runtime constant>` and remove `expr` from the
660                            // equivalence class.
661                            for (expr_pos, input, expr) in single_input_exprs.drain(..).rev() {
662                                push_downs[input].extend(gen_literal_equality_preds(expr));
663                                equivalences[equivalence_pos].remove(expr_pos);
664                            }
665
666                            // If none of the expressions in the equivalence depend on input
667                            // columns and equality predicates with them are pushed down,
668                            // we can safely remove them from the equivalence.
669                            // TODO: we could probably push equality predicates among the
670                            // remaining constants to all join inputs to prevent any computation
671                            // from happening until the condition is satisfied.
672                            if equivalences[equivalence_pos]
673                                .iter()
674                                .all(|e| e.support().is_empty())
675                                && push_downs.iter().any(|p| !p.is_empty())
676                            {
677                                equivalences[equivalence_pos].clear();
678                            }
679                        } else {
680                            // Case 3: There are no constants in the equivalence
681                            // class. Push a predicate for every pair of expressions
682                            // in the equivalence that either belong to a single
683                            // input or can be localized to a given input through
684                            // the rest of equivalences.
685                            let mut to_remove = Vec::new();
686                            for input in 0..inputs.len() {
687                                // Vector of pairs (position within the equivalence, localized
688                                // expression). The position is None for expressions derived through
689                                // other equivalences.
690                                let localized = equivalences[equivalence_pos]
691                                    .iter()
692                                    .enumerate()
693                                    .filter_map(|(pos, expr)| {
694                                        if let MirScalarExpr::Column(col_pos, _) = &expr {
695                                            let local_col =
696                                                input_mapper.map_column_to_local(*col_pos);
697                                            if input == local_col.1 {
698                                                // TODO(mgree) !!! is it safe to propagate the name here?
699                                                return Some((
700                                                    Some(pos),
701                                                    MirScalarExpr::column(local_col.0),
702                                                ));
703                                            } else {
704                                                return None;
705                                            }
706                                        }
707                                        let mut inputs = input_mapper.lookup_inputs(expr);
708                                        if let Some(single_input) = inputs.next() {
709                                            if input == single_input && inputs.next().is_none() {
710                                                return Some((
711                                                    Some(pos),
712                                                    input_mapper.map_expr_to_local(expr.clone()),
713                                                ));
714                                            }
715                                        }
716                                        // Equivalences not including the current expression
717                                        let mut other_equivalences = equivalences.clone();
718                                        other_equivalences[equivalence_pos].remove(pos);
719                                        let mut localized = expr.clone();
720                                        if input_mapper.try_localize_to_input_with_bound_expr(
721                                            &mut localized,
722                                            input,
723                                            &other_equivalences[..],
724                                        ) {
725                                            Some((None, localized))
726                                        } else {
727                                            None
728                                        }
729                                    })
730                                    .collect::<Vec<_>>();
731
732                                // If there are at least 2 expression in the equivalence that
733                                // can be localized to the same input, push all combinations
734                                // of them to the input.
735                                if localized.len() > 1 {
736                                    for mut pair in
737                                        localized.iter().map(|(_, expr)| expr).combinations(2)
738                                    {
739                                        let expr1 = pair.pop().unwrap();
740                                        let expr2 = pair.pop().unwrap();
741
742                                        push_downs[input].push(
743                                            MirScalarExpr::CallBinary {
744                                                func: func::Eq.into(),
745                                                expr1: Box::new(expr2.clone()),
746                                                expr2: Box::new(expr1.clone()),
747                                            }
748                                            .or(expr2
749                                                .clone()
750                                                .call_is_null()
751                                                .and(expr1.clone().call_is_null())),
752                                        );
753                                    }
754
755                                    if localized.len() == equivalences[equivalence_pos].len() {
756                                        // The equivalence is either a single input one or fully localizable
757                                        // to a single input through other equivalences, so it can be removed
758                                        // completely without introducing any new cross join.
759                                        to_remove.extend(0..equivalences[equivalence_pos].len());
760                                    } else {
761                                        // Leave an expression from this input in the equivalence to avoid
762                                        // cross joins
763                                        to_remove.extend(
764                                            localized.iter().filter_map(|(pos, _)| *pos).skip(1),
765                                        );
766                                    }
767                                }
768                            }
769
770                            // Remove expressions that were pushed down to at least one input
771                            to_remove.sort();
772                            to_remove.dedup();
773                            for pos in to_remove.iter().rev() {
774                                equivalences[equivalence_pos].remove(*pos);
775                            }
776                        };
777                    }
778
779                    mz_expr::canonicalize::canonicalize_equivalences(
780                        equivalences,
781                        input_types.iter().map(|t| &t.column_types),
782                    );
783
784                    Self::update_join_inputs_with_push_downs(inputs, push_downs);
785
786                    // Recursively descend on each of the inputs.
787                    for input in inputs.iter_mut() {
788                        self.action(input, get_predicates)?;
789                    }
790
791                    Ok(())
792                }
793                x => {
794                    // Recursively descend.
795                    x.try_visit_mut_children(|e| self.action(e, get_predicates))
796                }
797            }
798        })
799    }
800
801    fn update_join_inputs_with_push_downs(
802        inputs: &mut Vec<MirRelationExpr>,
803        push_downs: Vec<Vec<MirScalarExpr>>,
804    ) {
805        let new_inputs = inputs
806            .drain(..)
807            .zip_eq(push_downs)
808            .map(|(input, push_down)| {
809                if !push_down.is_empty() {
810                    input.filter(push_down)
811                } else {
812                    input
813                }
814            })
815            .collect();
816        *inputs = new_inputs;
817    }
818
819    // Checks `get_predicates` to see whether we can push a predicate into the Let binding given
820    // by `id` and `value`.
821    // `users` is the list of those expressions from which we will need to remove a predicate that
822    // is being pushed.
823    fn push_into_let_binding(
824        get_predicates: &mut BTreeMap<Id, BTreeSet<MirScalarExpr>>,
825        id: &LocalId,
826        value: &mut MirRelationExpr,
827        users: &mut [&mut MirRelationExpr],
828    ) {
829        if let Some(list) = get_predicates.remove(&Id::Local(*id)) {
830            if !list.is_empty() {
831                // Remove the predicates in `list` from the users.
832                for user in users {
833                    user.visit_pre_mut(|e| {
834                        if let MirRelationExpr::Filter { input, predicates } = e {
835                            if let MirRelationExpr::Get { id: get_id, .. } = **input {
836                                if get_id == Id::Local(*id) {
837                                    predicates.retain(|p| !list.contains(p));
838                                }
839                            }
840                        }
841                    });
842                }
843                // Apply the predicates in `list` to value. Canonicalize
844                // `list` so that plans are always deterministic.
845                let mut list = list.into_iter().collect::<Vec<_>>();
846                mz_expr::canonicalize::canonicalize_predicates(
847                    &mut list,
848                    &value.typ().column_types,
849                );
850                *value = value.take_dangerous().filter(list);
851            }
852        }
853    }
854
855    /// Returns `(<predicates to retain>, <predicates to push at each input>)`.
856    pub fn push_filters_through_join(
857        input_mapper: &JoinInputMapper,
858        equivalences: &Vec<Vec<MirScalarExpr>>,
859        mut predicates: Vec<MirScalarExpr>,
860    ) -> (Vec<MirScalarExpr>, Vec<Vec<MirScalarExpr>>) {
861        let mut push_downs = vec![Vec::new(); input_mapper.total_inputs()];
862        let mut retain = Vec::new();
863
864        for predicate in predicates.drain(..) {
865            // Track if the predicate has been pushed to at least one input.
866            let mut pushed = false;
867            // For each input, try and see if the join
868            // equivalences allow the predicate to be rewritten
869            // in terms of only columns from that input.
870            for (index, push_down) in push_downs.iter_mut().enumerate() {
871                #[allow(deprecated)] // TODO: use `might_error` if possible.
872                if predicate.is_literal_err() || predicate.contains_error_if_null() {
873                    // Do nothing. We don't push down literal errors,
874                    // as we can't know the join will be non-empty.
875                    //
876                    // We also don't want to push anything that involves `error_if_null`. This is
877                    // for the same reason why in theory we shouldn't really push anything that can
878                    // error, assuming that we want to preserve error semantics. (Because we would
879                    // create a spurious error if some other Join input ends up empty.) We can't fix
880                    // this problem in general (as we can't just not push anything that might
881                    // error), but we decided to fix the specific problem instance involving
882                    // `error_if_null`, because it was very painful:
883                    // <https://github.com/MaterializeInc/database-issues/issues/6258>
884                } else {
885                    let mut localized = predicate.clone();
886                    if input_mapper.try_localize_to_input_with_bound_expr(
887                        &mut localized,
888                        index,
889                        equivalences,
890                    ) {
891                        push_down.push(localized);
892                        pushed = true;
893                    } else if let Some(consequence) = input_mapper
894                        // (`consequence_for_input` assumes that
895                        // `try_localize_to_input_with_bound_expr` has already
896                        // been called on `localized`.)
897                        .consequence_for_input(&localized, index)
898                    {
899                        push_down.push(consequence);
900                        // We don't set `pushed` here! We want to retain the
901                        // predicate, because we only pushed a consequence of
902                        // it, but not the full predicate.
903                    }
904                }
905            }
906
907            if !pushed {
908                retain.push(predicate);
909            }
910        }
911
912        (retain, push_downs)
913    }
914
915    /// Computes "safe" predicates to push through a Map.
916    ///
917    /// In the case of a Filter { Map {...} }, we can always push down the Filter
918    /// by inlining expressions from the Map. We don't want to do this in general,
919    /// however, since general inlining can result in exponential blowup in the size
920    /// of expressions, so we only do this in the case where the size after inlining
921    /// is below a certain limit.
922    ///
923    /// Returns the predicates that can be pushed down, followed by ones that cannot.
924    pub fn push_filters_through_map(
925        map_exprs: &Vec<MirScalarExpr>,
926        predicates: &mut Vec<MirScalarExpr>,
927        input_arity: usize,
928        all_errors: bool,
929    ) -> Result<(Vec<MirScalarExpr>, Vec<MirScalarExpr>), TransformError> {
930        let mut pushdown = Vec::new();
931        let mut retained = Vec::new();
932        for predicate in predicates.drain(..) {
933            // We don't push down literal errors, unless all predicates are.
934            if !predicate.is_literal_err() || all_errors {
935                // Consider inlining Map expressions.
936                if let Some(cleaned) =
937                    Self::inline_if_not_too_big(&predicate, input_arity, map_exprs)?
938                {
939                    pushdown.push(cleaned);
940                } else {
941                    retained.push(predicate);
942                }
943            } else {
944                retained.push(predicate);
945            }
946        }
947        Ok((retained, pushdown))
948    }
949
950    /// This fn should be called with a Filter `expr` that is after a Map. `input_arity` is the
951    /// arity of the input of the Map. This fn eliminates such column refs in `expr` that refer not
952    /// to a column in the input of the Map, but to a column that is created by the Map. It does
953    /// this by transitively inlining Map expressions until no such expression remains that points
954    /// to a Map expression. The return value is the cleaned up expression. The fn bails out with a
955    /// None if the resulting expression would be made too big by the inlinings.
956    ///
957    /// OOO: (Optimizer Optimization Opportunity) This function might do work proportional to the
958    /// total size of the Map expressions. We call this function for each predicate above the Map,
959    /// which will be kind of quadratic, i.e., if there are many predicates and a big Map, then this
960    /// will be slow. We could instead pass a vector of Map expressions and call this fn only once.
961    /// The only downside would be that then the inlining limit being hit in the middle part of this
962    /// function would prevent us from inlining any predicates, even ones that wouldn't hit the
963    /// inlining limit if considered on their own.
964    fn inline_if_not_too_big(
965        expr: &MirScalarExpr,
966        input_arity: usize,
967        map_exprs: &Vec<MirScalarExpr>,
968    ) -> Result<Option<MirScalarExpr>, RecursionLimitError> {
969        let size_limit = 1000;
970
971        // Transitively determine the support of `expr` produced by `map_exprs`
972        // that needs to be inlined.
973        let cols_to_inline = {
974            let mut support = BTreeSet::new();
975
976            // Seed with `map_exprs` support in `expr`.
977            expr.visit_pre(|e| {
978                if let MirScalarExpr::Column(c, _) = e {
979                    if *c >= input_arity {
980                        support.insert(*c);
981                    }
982                }
983            });
984
985            // Compute transitive closure of supports in `map_exprs`.
986            let mut workset = support.iter().cloned().collect::<Vec<_>>();
987            let mut buffer = vec![];
988            while !workset.is_empty() {
989                // Swap the (empty) `drained` buffer with the `workset`.
990                std::mem::swap(&mut workset, &mut buffer);
991                // Drain the `buffer` and update `support` and `workset`.
992                for c in buffer.drain(..) {
993                    map_exprs[c - input_arity].visit_pre(|e| {
994                        if let MirScalarExpr::Column(c, _) = e {
995                            if *c >= input_arity {
996                                if support.insert(*c) {
997                                    workset.push(*c);
998                                }
999                            }
1000                        }
1001                    });
1002                }
1003            }
1004            support
1005        };
1006
1007        let mut inlined = BTreeMap::<usize, (MirScalarExpr, usize)>::new();
1008        // Populate the memo table in ascending column order (which respects the
1009        // dependency order of `map_exprs` references). Break early if memoization
1010        // fails for one of the columns in `cols_to_inline`.
1011        for c in cols_to_inline.iter() {
1012            let mut new_expr = map_exprs[*c - input_arity].clone();
1013            let mut new_size = 0;
1014            new_expr.visit_mut_post(&mut |expr| {
1015                new_size += 1;
1016                if let MirScalarExpr::Column(c, _) = expr {
1017                    if *c >= input_arity && new_size <= size_limit {
1018                        // (inlined[c] is safe, because we proceed in column order, and we break out
1019                        // of the loop when we stop inserting into memo.)
1020                        let (m_expr, m_size): &(MirScalarExpr, _) = &inlined[c];
1021                        *expr = m_expr.clone();
1022                        new_size += m_size - 1; // Adjust for the +1 above.
1023                    }
1024                }
1025            })?;
1026
1027            if new_size <= size_limit {
1028                inlined.insert(*c, (new_expr, new_size));
1029            } else {
1030                break;
1031            }
1032        }
1033
1034        // Try to resolve expr against the memo table.
1035        if inlined.len() < cols_to_inline.len() {
1036            Ok(None) // We couldn't memoize all map expressions within the given limit.
1037        } else {
1038            let mut new_expr = expr.clone();
1039            let mut new_size = 0;
1040            new_expr.visit_mut_post(&mut |expr| {
1041                new_size += 1;
1042                if let MirScalarExpr::Column(c, _) = expr {
1043                    if *c >= input_arity && new_size <= size_limit {
1044                        // (inlined[c] is safe because of the outer if condition.)
1045                        let (m_expr, m_size): &(MirScalarExpr, _) = &inlined[c];
1046                        *expr = m_expr.clone();
1047                        new_size += m_size - 1; // Adjust for the +1 above.
1048                    }
1049                }
1050            })?;
1051
1052            soft_assert_eq_no_log!(new_size, new_expr.size());
1053            if new_size <= size_limit {
1054                Ok(Some(new_expr)) // We managed to stay within the limit.
1055            } else {
1056                Ok(None) // Limit exceeded.
1057            }
1058        }
1059    }
1060    // fn inline_if_not_too_big(
1061    //     expr: &MirScalarExpr,
1062    //     input_arity: usize,
1063    //     map_exprs: &Vec<MirScalarExpr>,
1064    // ) -> Result<Option<MirScalarExpr>, RecursionLimitError> {
1065    //     let size_limit = 1000;
1066    //     // Memoize cleaned up versions of Map expressions. (Not necessarily all the Map expressions
1067    //     // will be involved.)
1068    //     let mut memo: BTreeMap<MirScalarExpr, MirScalarExpr> = BTreeMap::new();
1069    //     fn rec(
1070    //         expr: &MirScalarExpr,
1071    //         input_arity: usize,
1072    //         map_exprs: &Vec<MirScalarExpr>,
1073    //         memo: &mut BTreeMap<MirScalarExpr, MirScalarExpr>,
1074    //         size_limit: usize,
1075    //     ) -> Result<Option<MirScalarExpr>, RecursionLimitError> {
1076    //         // (We can't use Entry::or_insert_with, because the closure would need to be fallible.
1077    //         // We also can't manually match on the result of memo.entry, because that holds a
1078    //         // borrow of memo, but we need to pass memo to the recursive call in the middle.)
1079    //         match memo.get(expr) {
1080    //             Some(memoized_result) => Ok(Some(memoized_result.clone())),
1081    //             None => {
1082    //                 let mut expr_size = expr.size()?;
1083    //                 let mut cleaned_expr = expr.clone();
1084    //                 let mut bail = false;
1085    //                 cleaned_expr.try_visit_mut_post(&mut |expr| {
1086    //                     Ok(if !bail {
1087    //                         match expr {
1088    //                             MirScalarExpr::Column(col) => {
1089    //                                 if *col >= input_arity {
1090    //                                     let to_inline = rec(
1091    //                                         &map_exprs[*col - input_arity],
1092    //                                         input_arity,
1093    //                                         map_exprs,
1094    //                                         memo,
1095    //                                         size_limit,
1096    //                                     )?;
1097    //                                     if let Some(to_inline) = to_inline {
1098    //                                         // The `-1` is because the expression that we are
1099    //                                         // replacing has a size of 1.
1100    //                                         expr_size += to_inline.size()? - 1;
1101    //                                         *expr = to_inline;
1102    //                                         if expr_size > size_limit {
1103    //                                             bail = true;
1104    //                                         }
1105    //                                     } else {
1106    //                                         bail = true;
1107    //                                     }
1108    //                                 }
1109    //                             }
1110    //                             _ => (),
1111    //                         }
1112    //                     })
1113    //                 })?;
1114    //                 soft_assert_eq!(cleaned_expr.size()?, expr_size);
1115    //                 if !bail {
1116    //                     memo.insert(expr.clone(), cleaned_expr.clone());
1117    //                     Ok(Some(cleaned_expr))
1118    //                 } else {
1119    //                     Ok(None)
1120    //                 }
1121    //             }
1122    //         }
1123    //     }
1124    //     rec(expr, input_arity, map_exprs, &mut memo, size_limit)
1125    // }
1126
1127    /// Computes "safe" predicate to push through a FlatMap.
1128    ///
1129    /// In the case of a Filter { FlatMap {...} }, we want to push through all predicates
1130    /// that (1) are not literal errors and (2) have support exclusively in the columns
1131    /// provided by the FlatMap input.
1132    ///
1133    /// Returns the predicates that can be pushed down, followed by ones that cannot.
1134    fn push_filters_through_flat_map(
1135        predicates: &mut Vec<MirScalarExpr>,
1136        input_arity: usize,
1137    ) -> (Vec<MirScalarExpr>, Vec<MirScalarExpr>) {
1138        let mut pushdown = Vec::new();
1139        let mut retained = Vec::new();
1140        for predicate in predicates.drain(..) {
1141            // First, check if we can push this predicate down. We can do so if and only if:
1142            // (1) the predicate is not a literal error, and
1143            // (2) each column it references is from the input.
1144            if (!predicate.is_literal_err()) && predicate.support().iter().all(|c| *c < input_arity)
1145            {
1146                pushdown.push(predicate);
1147            } else {
1148                retained.push(predicate);
1149            }
1150        }
1151        (retained, pushdown)
1152    }
1153
1154    /// If `s` is of the form
1155    /// `(isnull(expr1) && isnull(expr2)) || (expr1 = expr2)`, or
1156    /// `(decompose_is_null(expr1) && decompose_is_null(expr2)) || (expr1 = expr2)`,
1157    /// extract `expr1` and `expr2`.
1158    fn extract_equal_or_both_null(
1159        s: &mut MirScalarExpr,
1160        column_types: &[ReprColumnType],
1161    ) -> Option<(MirScalarExpr, MirScalarExpr)> {
1162        if let MirScalarExpr::CallVariadic {
1163            func: VariadicFunc::Or(_),
1164            exprs,
1165        } = s
1166        {
1167            if let &[ref or_lhs, ref or_rhs] = &**exprs {
1168                // Check both orders of operands of the OR
1169                return Self::extract_equal_or_both_null_inner(or_lhs, or_rhs, column_types)
1170                    .or_else(|| {
1171                        Self::extract_equal_or_both_null_inner(or_rhs, or_lhs, column_types)
1172                    });
1173            }
1174        }
1175        None
1176    }
1177
1178    fn extract_equal_or_both_null_inner(
1179        or_arg1: &MirScalarExpr,
1180        or_arg2: &MirScalarExpr,
1181        column_types: &[ReprColumnType],
1182    ) -> Option<(MirScalarExpr, MirScalarExpr)> {
1183        use mz_expr::BinaryFunc;
1184        if let MirScalarExpr::CallBinary {
1185            func: BinaryFunc::Eq(_),
1186            expr1: eq_lhs,
1187            expr2: eq_rhs,
1188        } = &or_arg2
1189        {
1190            let isnull1 = eq_lhs.clone().call_is_null();
1191            let isnull2 = eq_rhs.clone().call_is_null();
1192            let both_null = MirScalarExpr::call_variadic(And, vec![isnull1, isnull2]);
1193
1194            if Self::extract_reduced_conjunction_terms(both_null, column_types)
1195                == Self::extract_reduced_conjunction_terms(or_arg1.clone(), column_types)
1196            {
1197                return Some(((**eq_lhs).clone(), (**eq_rhs).clone()));
1198            }
1199        }
1200        None
1201    }
1202
1203    /// Reduces the given expression and returns its AND-ed terms.
1204    fn extract_reduced_conjunction_terms(
1205        mut s: MirScalarExpr,
1206        column_types: &[ReprColumnType],
1207    ) -> Vec<MirScalarExpr> {
1208        s.reduce(column_types);
1209
1210        if let MirScalarExpr::CallVariadic {
1211            func: VariadicFunc::And(_),
1212            exprs,
1213        } = s
1214        {
1215            exprs
1216        } else {
1217            vec![s]
1218        }
1219    }
1220}